基于 PyTorch 的神经网络训练管道。旨在标准化培训过程并提高编码性能
项目描述
神经管道
基于 PyTorch 的神经网络训练管道。旨在标准化培训过程并提高编码性能。
- 核心是大约 2K 行,被测试覆盖,你不需要再次编写
- 灵活且可定制的培训流程
- 检查点管理和训练过程恢复(源和目标设备独立)
- 通过内置( tensorboard,Matplotlib)或自定义监视器进行指标处理和可视化
- 训练最佳实践(例如学习率衰减和硬负挖掘)
- 指标记录和比较(DVC 兼容)
训练 MNIST 示例:
此代码使用 Tensorboard 监控运行 MNIST 图像分类。基于 PyTorch示例的代码。
请参阅那里的完整示例。
from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor
from neural_pipeline import DataProducer, AbstractDataset, TrainConfig, TrainStage,\
ValidationStage, Trainer, FileStructManager
import torch
from torch import nn
from torchvision import datasets, transforms
class Net(nn.Module):
# Network implementation
class MNISTDataset(AbstractDataset):
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def __init__(self, data_dir: str, is_train: bool):
self.dataset = datasets.MNIST(data_dir, train=is_train, download=True)
def __len__(self):
return len(self.dataset)
def __getitem__(self, item):
data, target = self.dataset[item]
return {'data': self.transforms(data), 'target': target}
fsm = FileStructManager(base_dir='data', is_continue=False)
model = Net()
train_dataset = DataProducer([MNISTDataset('data/dataset', True)], batch_size=4, num_workers=2)
validation_dataset = DataProducer([MNISTDataset('data/dataset', False)], batch_size=4, num_workers=2)
train_config = TrainConfig([TrainStage(train_dataset), ValidationStage(validation_dataset)], torch.nn.NLLLoss(),
torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.5))
trainer = Trainer(model, train_config, fsm, torch.device('cuda:0')).set_epoch_num(50)
trainer.monitor_hub.add_monitor(TensorboardMonitor(fsm, is_continue=False))
trainer.train()
安装:
pip install neural-pipeline
对于builtin使用安装的模块:
pip install tensorboardX matplotlib
在 PyPi 上发布之前安装最新版本
pip install -U git+https://github.com/toodef/neural-pipeline
入门:
文档
数据流方案:
查看示例
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
内置分布
neural_pipeline-0.1.0-py3-none-any.whl
(30.3 kB
查看哈希)
关
Neural_pipeline -0.1.0-py3-none-any.whl 的哈希值
| 算法 | 哈希摘要 | |
|---|---|---|
| SHA256 | 74a32a7fe0d33efb1ae36fc7a223a93139d9b1bd68ccb998ae3908357a02d8ac |
|
| MD5 | 3da235fbbbdbcf079f4a439d0114bbbb |
|
| 布莱克2-256 | f6ae440a1d20745d5de34c3a79605a63ad00669b93ffa0125f01535c9c72134c |