Skip to main content

OML 是一个基于 PyTorch 的框架,用于训练和验证生成高质量嵌入的模型。

项目描述

示例工作流程 示例工作流程 示例工作流程 皮皮版 PyPI 状态 文件状态

OML 是一个基于 PyTorch 的框架,用于训练和验证生成高质量嵌入的模型。

常问问题

为什么需要 OML?

您可能会想“如果我需要图像嵌入,我可以简单地训练一个普通分类器并采用它的倒数第二层”。好吧,作为一个起点,这是有道理的。但是有几个可能的缺点:

  • If you want to use embeddings to perform searching you need to calculate some distance among them (for example, cosine or L2). Usually, you don't directly optimize these distances during the training in the classification setup. So, you can only hope that final embeddings will have the desired properties.

  • The second problem is the validation process. In the searching setup, you usually care how related your top-N outputs are to the query. The natural way to evaluate the model is to simulate searching requests to the reference set and apply one of the retrieval metrics. So, there is no guarantee that classification accuracy will correlate with these metrics.

  • 最后,您可能希望自己实现一个度量学习管道。 有很多工作:要使用triplet loss,您需要以特定方式形成batch,实现不同类型的triplets挖掘,跟踪距离等。对于验证,您还需要实现检索指标,包括有效的嵌入累积在这个时代,覆盖极端情况等。如果你有多个 GPU 并使用 DDP,那就更难了。您可能还想通过突出显示好的和坏的搜索结果来可视化您的搜索请求。您可以简单地将 OML 用于您的目的,而不是自己做。

什么是度量学习?

度量学习问题(也称为极端分类问题)是指我们有一些实体的数千个 id,但每个实体只有几个样本的情况。通常我们假设在测试阶段(或生产)我们将处理看不见的实体,这使得无法直接应用普通分类管道。在许多情况下,获得的嵌入用于对它们执行搜索或匹配过程。

以下是计算机视觉领域此类任务的一些示例:

  • 人/动物重新识别
  • 人脸识别
  • 地标识别
  • 在线商店和许多其他搜索引擎。

词汇表(命名约定)

  • embedding- 模型的输出(也称为features vectoror descriptor)。
  • query- 在检索过程中用作请求的样本。
  • gallery set - the set of entities to search items similar to query (also known as reference or index).
  • Sampler - an argument for DataLoader which is used to form batches
  • Miner - the object to form pairs or triplets after the batch was formed by Sampler. It's not necessary to form the combinations of samples only inside the current batch, thus, the memory bank may be a part of Miner.
  • Samples// Labels-Instances作为一个例子,让我们考虑一下 DeepFashion 数据集。它包括数千个时尚商品 ID(我们将它们命名为labels)和每个商品 ID 的几张照片(我们将单个照片命名为instancesample)。所有的时尚单品 id 都有它们的分组,比如“裙子”、“夹克”、“短裤”等等(我们命名它们categories)。请注意,我们避免使用该术语class以避免误解。
  • training epoch- 我们用于基于组合的损失的批量采样器的长度通常等于 [number of labels in training dataset] / [numbers of labels in one batch]. 这意味着我们不会在一个时期内观察到所有可用的训练样本(与普通分类相反),而是观察所有可用的标签。

OML 如何在幕后工作?

Training part implies using losses, well-established for metric learning, such as the angular losses (like ArcFace) or the combinations based losses (like TripletLoss or ContrastiveLoss). The latter benefits from effective mining schemas of triplets/pairs, so we pay great attention to it. Thus, during the training we:

  1. Use DataLoader + Sampler to form batches (for example BalanceSampler)
  2. [Only for losses based on combinations] Use Miner to form effective pairs or triplets, including those which utilize a memory bank.
  3. Compute loss.

Validation part consists of several steps:

  1. Accumulating all of the embeddings (EmbeddingMetrics).
  2. Calculating distances between them with respect to query/gallery split.
  3. 应用一些特定的检索技术,如查询重新排序或分数规范化。
  4. 计算检索指标,如CMC@kPrecision@kMeanAveragePrecision@k

自我监督学习呢?

最近对 SSL 的研究肯定取得了很好的成果。问题是这些方法需要大量的计算来训练模型。但在我们的框架中,我们考虑了最常见的情况,即普通用户的 GPU 不超过几个。

同时,忽视这个领域的成功是不明智的,所以我们仍然以两种方式利用它:

  • 作为检查点的来源,可以很好地开始训练。从出版物和我们的经验来看,它们在初始化方面比在 ImageNet 上训练的默认监督模型要好得多。因此,我们添加了仅通过在配置或构造函数中传递参数来使用这些预训练检查点初始化模型的可能性。
  • 作为灵感来源。例如,我们为 TripletLoss 改编了MoCo内存库的想法。

我是否需要了解其他框架才能使用 OML?

不,你没有。OML 与框架无关。尽管我们使用 PyTorch Lightning 作为实验的循环运行器,但我们也保留了在纯 PyTorch 上运行所有内容的可能性。因此,只有 OML 的一小部分是 Lightning 专用的,我们将此逻辑与其他代码分开(请参阅 参考资料oml.lightning)。即使您使用 Lightning,您也不需要知道它,因为我们提供了随时可用的Config API

在实现必要的包装器之后,使用纯 PyTorch 和代码的模块化结构的可能性为使用 OML 和您喜欢的框架留下了空间。

我可以在没有任何数据科学知识的情况下使用 OML 吗?

是的。要使用Config API运行实验, 您只需将转换器写入我们的格式(这意味着准备 .csv具有 5 个预定义列的表)。而已!

可能我们已经在Models Zoo中为您的领域提供了合适的预训练模型。在这种情况下,你甚至不需要训练它。

文档

文档可通过链接获得。

安装

OML 在 PyPI 中可用:

pip install -U open-metric-learning

您还可以从 DockerHub 中提取准备好的图像...

docker pull omlteam/oml:gpu
docker pull omlteam/oml:cpu

...或者自己建造一个

make docker_build RUNTIME=cpu
make docker_build RUNTIME=gpu

开始使用 Config API

如果您的数据集和管道足够标准,或者您没有机器学习或 Python 方面的经验,则使用配置是最佳选择。您可以在 示例中找到更多详细信息。

开始使用 Python

最灵活但需要知识的方法。您不受我们项目结构的限制,您只能使用您需要的那部分功能。您可以从下面的完整工作代码片段开始,在很小的图形数据集上训练和验证模型 。

训练

import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)

model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

验证

import torch
from tqdm import tqdm

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models.vit.vit import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root =  "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)

model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()

val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
calculator = EmbeddingMetrics()
calculator.setup(num_samples=len(val_dataset))

with torch.no_grad():
    for batch in tqdm(val_loader):
        batch["embeddings"] = model(batch["input_tensors"])
        calculator.update_data(batch)

metrics = calculator.compute_metrics()

培训 + 验证 [闪电]

import pytorch_lightning as pl
import torch

from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
from oml.lightning.modules.retrieval import RetrievalModule
from oml.lightning.callbacks.metric import  MetricValCallback
from oml.losses.triplet import TripletLossWithMiner
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root =  "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)

# model
model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(metric=EmbeddingMetrics())

# run
pl_model = RetrievalModule(model, criterion, optimizer)
trainer = pl.Trainer(max_epochs=1, callbacks=[metric_callback], num_sanity_val_steps=0)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

动物园

模型 cmc1 数据集 权重 配置 哈希(开头)
ViTExtractor(weights="vits16_inshop", arch="vits16", ...) 0.903 DeepFashion Inshop 关联 关联 e1017d
ViTExtractor(weights="vits16_sop", arch="vits16", ...) 0.830 斯坦福在线产品 关联 关联 85cfa5
ViTExtractor(weights="vits16_cars", arch="vits16", ...) 0.907 汽车 196 关联 关联 9f1e59
ViTExtractor(weights="vits16_cub", arch="vits16", ...) 0.837 幼崽 200 2011 关联 关联 e82633

请注意,上述模型期望的是感兴趣区域的裁剪,而不是整个图片。

您可以指定所需的权重和架构并自动下载预训练的检查点(通过与 的类比torchvision.models)。但是,您也可以通过weights列中的链接手动执行此操作。

import oml
from oml.models.vit.vit import ViTExtractor

# We are downloading vits16 pretrained on CARS dataset:
model = ViTExtractor(weights="vits16_cars", arch="vits16", normalise_features=False)

# You can also check other available pretrained models...
print(list(ViTExtractor.pretrained_models.keys()))

# ...or check other available types of architectures
print(oml.registry.models.MODELS_REGISTRY)

# It's also possible to use `weights` argument to directly pass the path to the checkpoint:
model_from_disk = ViTExtractor(weights=oml.const.CKPT_SAVE_ROOT / "vits16_cars.ckpt", arch="vits16", normalise_features=False)

有关培训过程的更多详细信息,请访问示例子模块,它是 自述文件

致谢

该项目于 2020 年作为Catalyst库的一个模块启动。我要感谢在该模块上与我一起工作的人: Julia ShenshinaNikita BalaganskySergey Kolesnikov 和其他人。

我要感谢当它成为一个单独的项目时继续致力于这条管道的人: Julia ShenshinaAleksei TarasovVerkhovtsev Leonid

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

open-metric-learning-0.1.25.tar.gz (107.7 kB 查看哈希

已上传 source

内置分布

open_metric_learning-0.1.25-py3-none-any.whl (143.4 kB 查看哈希

已上传 py3