Skip to main content

基于 PyTorch 的神经网络训练管道。旨在标准化培训过程并提高编码性能

项目描述

神经管道

基于 PyTorch 的神经网络训练管道。旨在标准化培训过程并提高编码性能。

构建状态 覆盖状态 可维护性

  • 核心是大约 2K 行,被测试覆盖,你不需要再次编写
  • 灵活且可定制的培训流程
  • 检查点管理和训练过程恢复(源和目标设备独立)
  • 通过内置( tensorboardMatplotlib)或自定义监视器进行指标处理和可视化
  • 训练最佳实践(例如学习率衰减和硬负挖掘)
  • 指标记录和比较(DVC 兼容)

训练 MNIST 示例:

此代码使用 Tensorboard 监控运行 MNIST 图像分类。基于 PyTorch示例的代码。

请参阅那里的完整示例。

from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor
from neural_pipeline import DataProducer, AbstractDataset, TrainConfig, TrainStage,\
    ValidationStage, Trainer, FileStructManager

import torch
from torch import nn
from torchvision import datasets, transforms

class Net(nn.Module):
    # Network implementation

class MNISTDataset(AbstractDataset):
    transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def __init__(self, data_dir: str, is_train: bool):
        self.dataset = datasets.MNIST(data_dir, train=is_train, download=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data, target = self.dataset[item]
        return {'data': self.transforms(data), 'target': target}

fsm = FileStructManager(base_dir='data', is_continue=False)
model = Net()

train_dataset = DataProducer([MNISTDataset('data/dataset', True)], batch_size=4, num_workers=2)
validation_dataset = DataProducer([MNISTDataset('data/dataset', False)], batch_size=4, num_workers=2)

train_config = TrainConfig([TrainStage(train_dataset), ValidationStage(validation_dataset)], torch.nn.NLLLoss(),
                           torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.5))

trainer = Trainer(model, train_config, fsm, torch.device('cuda:0')).set_epoch_num(50)
trainer.monitor_hub.add_monitor(TensorboardMonitor(fsm, is_continue=False))
trainer.train()

安装:

PyPI 版本 PyPI 下载/月 PyPI 下载

pip install neural-pipeline

对于builtin使用安装的模块:

pip install tensorboardX matplotlib

在 PyPi 上发布之前安装最新版本

pip install -U git+https://github.com/toodef/neural-pipeline

入门:

文档

文件状态 在那里查看完整的文档

数据流方案: 数据流

查看示例

下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

内置分布

neural_pipeline-0.1.0-py3-none-any.whl (30.3 kB 查看哈希

已上传 py3