Skip to main content

PyTorch 中的 LARS 实现

项目描述

火炬手

派皮 构建状态

PyTorch 中的LARS实现。

from torchlars import LARS
optimizer = LARS(optim.SGD(model.parameters(), lr=0.1))

什么是 LARS?

LARS(Layer-wise Adaptive Rate Scaling)是由 You、Gitman 和 Ginsburg 发布的为大批量训练设计的优化算法,它在每个优化步骤计算每层的局部学习率。根据论文,当使用 LARS 在 ImageNet ILSVRC (2016) 分类任务上训练 ResNet-50 时,即使批量大小扩大到 32K。

卷积网络的大批量训练

最初,LARS 是根据 SGD 优化器制定的,论文中没有提到扩展到其他优化器。相比之下,torchlars将 LARS 实现为一个包装器,它可以将包括 SGD 在内的任何优化器作为基础。

此外,与现有实现相比,火炬的 LARS 旨在更多地考虑在 CUDA 环境中的操作。多亏了这一点,在没有发生 CPU 到 GPU 同步的环境中,与仅使用 SGD 相比,您只能看到很少的速度损失。

用法

目前,火炬手需要以下环境:

  • Linux
  • Python 3.6+
  • PyTorch 1.1+
  • CUDA 10+

要使用 torchlars,请通过 PyPI 安装它:

$ pip install torchlars

要使用 LARS,只需用torchlars.LARS. LARS 继承torch.optim.Optimizer,因此您可以简单地将 LARS 用作代码的优化器。之后,当您调用 LARS 的 step 方法时,LARS 会在运行基础优化器(如 SGD 或 Adam)之前自动计算局部学习率

下面的示例代码显示了如何使用 LARS,使用 SGD 作为基础优化器。

from torchlars import LARS

base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

output = model(input)
loss = loss_fn(output, target)
loss.backward()

optimizer.step()

基准测试

ImageNet 分类上的 ResNet-50

批量大小 LR政策 LR 暖身 时代 最佳 Top-1 准确率,%
256 聚(2) 0.2 不适用 90 73.79
8k LARS+聚(2) 12.8 5 90 73.78
16K LARS+聚(2) 25.0 5 90 73.36
32K LARS+聚(2) 29.0 5 90 72.26

上图和表格显示了 ResNet-50 上的重现性能基准,如本文的表 4 和图 5 所示。

青色线代表baseline结果,即batch size 256的训练结果,其他分别代表8K、16K、32K的训练结果。如您所见,每个结果都显示出相似的学习曲线和最佳的 top-1 准确度。

大多数实验条件与论文中使用的相似,但我们稍微改变了一些条件,如学习率,以观察 LARS 论文提出的可比较结果。

注意:我们参考 论文提供的日志文件 来获取上述超参数。

作者和许可

torchlars 项目由Kakao BrainChunmyong ParkHeungsub LeeMyungryong JeongWoonhyuk BaekChiheon Kim的帮助下开发。它在Apache License 2.0下分发。

引文

如果您将此库应用于任何项目和研究,请引用我们的代码:

@misc{torchlars,
  author       = {Park, Chunmyong and Lee, Heungsub and Jeong, Myungryong and
                  Baek, Woonhyuk and Kim, Chiheon},
  title        = {torchlars, {A} {LARS} implementation in {PyTorch}},
  howpublished = {\url{https://github.com/kakaobrain/torchlars}},
  year         = {2019}
}

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

torchlars-0.1.2.tar.gz (6.5 kB 查看哈希

已上传 source