Skip to main content

WILDS 分布偏移基准

项目描述


派皮 执照

概述

WILDS 是野外分布变化的基准,跨越多种数据模式和应用,从肿瘤识别到野生动物监测再到贫困地图。

WILDS 包包含:

  1. 自动处理数据下载、处理和拆分的数据加载器,以及
  2. 为每个数据集标准化模型评估的数据集评估器。

此外,示例脚本包含默认模型、优化器、调度器和训练/评估代码。可以轻松添加新算法并在所有 WILDS 数据集上运行。

有关更多信息,请访问我们的网站或阅读 WILDS 的主要论文 ( 1 ) 及其后续整合未标记数据 ( 2 )。如有问题和反馈,请在讨论区发帖。

安装

我们建议使用 pip 安装 WILDS:

pip install wilds

如果您已经安装了它,请检查您是否拥有最新版本:

python -c "import wilds; print(wilds.__version__)"
# This should print "2.0.0". If it doesn't, update by running:
pip install -U wilds

如果您打算编辑或为 WILDS 做出贡献,您应该从源代码安装:

git clone git@github.com:p-lambda/wilds.git
cd wilds
pip install -e .

examples/中,我们提供了一组脚本,可用于在 WILDS 数据集上训练模型。这些脚本也用于我们的论文 [ 1 , 2 ] 中的基准基准测试。这些脚本不是已安装的 WILDS 包的一部分。要使用它们,您应该从源代码安装,如上所述。

要求

WILDS 包取决于以下要求:

  • numpy>=1.19.1
  • ogb>=1.2.6
  • 过时>=0.2.0
  • 熊猫>=1.1.0
  • 枕头>=7.2.0
  • pytz>=2020.4
  • 火炬>=1.7.0
  • 火炬散射>=2.0.5
  • 火炬几何>=2.0.1
  • 火炬视觉>=0.8.2
  • tqdm>=4.53.0
  • scikit-learn>=0.20.0
  • scipy>=1.5.4

运行pip install wildsorpip install -e .将自动检查并安装所有这些要求 ,除了torch-scattertorch-geometric,它们需要 快速手动安装

示例脚本要求

要运行示例脚本,您还需要安装这些附加依赖项:

论文中的所有基线实验均在 Python 3.8.5 和 CUDA 10.1 上运行。

数据集

WILDS 目前包括 10 个数据集,我们在下面简要列出了这些数据集。有关完整的数据集描述,请参阅我们的论文 ( 1 , 2 )。

数据集 模态 标记的拆分 未标记的拆分
iwildcam 图片 训练、验证、测试、id_val、id_test extra_unlabeled
骆驼17 图片 训练、验证、测试、id_val train_unlabeled、val_unlabeled、test_unlabeled
rxrx1 图片 训练、验证、测试、id_test -
ogb-molpcba 图形 训练、验证、测试 train_unlabeled、val_unlabeled、test_unlabeled
全球小麦 图片 训练、验证、测试、id_val、id_test train_unlabeled、val_unlabeled、test_unlabeled、extra_unlabeled
民事评论 文本 训练、验证、测试 extra_unlabeled
关注 图片 训练、验证、测试、id_val、id_test train_unlabeled、val_unlabeled、test_unlabeled
贫困 图片 训练、验证、测试、id_val、id_test train_unlabeled、val_unlabeled、test_unlabeled
亚马逊 文本 训练、验证、测试、id_val、id_test val_unlabeled、test_unlabeled、extra_unlabeled
py150 文本 训练、验证、测试、id_val、id_test -

使用 WILDS 包

数据

WILDS 包为基准测试中的所有数据集提供了一个简单、标准化的接口。这个简短的 Python 片段涵盖了开始使用 WILDS 数据集的所有步骤,包括数据集下载和初始化、访问各种拆分以及准备用户可自定义的数据加载器。我们将在#Data loading中更详细地讨论数据加载。

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms

# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="iwildcam", download=True)

# Get the training set
train_data = dataset.get_subset(
    "train",
    transform=transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    ),
)

# Prepare the standard data loader
train_loader = get_train_loader("standard", train_data, batch_size=16)

# (Optional) Load unlabeled data
dataset = get_dataset(dataset="iwildcam", download=True, unlabeled=True)
unlabeled_data = dataset.get_subset(
    "test_unlabeled",
    transform=transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    ),
)
unlabeled_loader = get_train_loader("standard", unlabeled_data, batch_size=16)

# Train loop
for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):
    x, y, metadata = labeled_batch
    unlabeled_x, unlabeled_metadata = unlabeled_batch
    ...

其中metadata包含诸如域身份之类的信息,例如,照片是从哪个相机拍摄的,或者患者的数据来自哪个医院等,以及其他元数据。

域名信息

为了允许算法利用域注释以及可用元数据上的其他分组,WILDS 包提供了Grouper对象。这些Grouper对象是从元数据中提取组注释的辅助对象,允许用户以灵活的方式指定分组方案。它们用于初始化组感知数据加载器(如#Data loading中所述)并实现依赖于域注释的算法(例如,Group DRO)。在下面的代码片段中,我们初始化并使用Grouper在 iWildCam 数据集上提取域注释的 a,其中域是位置。

from wilds.common.grouper import CombinatorialGrouper

# Initialize grouper, which extracts domain information
# In this example, we form domains based on location
grouper = CombinatorialGrouper(dataset, ['location'])

# Train loop
for x, y_true, metadata in train_loader:
    z = grouper.metadata_to_group(metadata)
    ...

数据加载

对于训练,WILDS 包提供了两种类型的数据加载器。标准数据加载器对训练集中的示例进行洗牌,并用于经验风险最小化 (ERM) 的标准方法,我们将平均损失最小化。

from wilds.common.data_loaders import get_train_loader

# Prepare the standard data loader
train_loader = get_train_loader('standard', train_data, batch_size=16)

为了支持其他依赖特定数据加载方案的算法,我们还提供了组数据加载器。在每个小批量中,组加载器首先对指定数量的组进行采样,然后从每个组中采样固定数量的示例。(默认情况下,组是随机均匀采样的,因此会增加少数组的权重。这可以通过uniform_over_groups参数进行切换。)我们如下初始化组加载器,使用Grouper它指定分组方案。

# Prepare a group data loader that samples from user-specified groups
train_loader = get_train_loader(
    "group", train_data, grouper=grouper, n_groups_per_batch=2, batch_size=16
)

最后,我们还提供了一个用于评估的数据加载器,它可以在不打乱的情况下加载示例(与训练加载器不同)。

from wilds.common.data_loaders import get_eval_loader

# Get the test set
test_data = dataset.get_subset(
    "test",
    transform=transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    ),
)

# Prepare the evaluation data loader
test_loader = get_eval_loader("standard", test_data, batch_size=16)

评价者

WILDS 包对每个数据集进行标准化和自动化评估。调用eval每个数据集的方法会产生论文和排行榜上报告的所有指标。

from wilds.common.data_loaders import get_eval_loader

# Get the test set
test_data = dataset.get_subset(
    "test",
    transform=transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    ),
)

# Prepare the data loader
test_loader = get_eval_loader("standard", test_data, batch_size=16)

# Get predictions for the full test set
for x, y_true, metadata in test_loader:
    y_pred = model(x)
    # Accumulate y_true, y_pred, metadata

# Evaluate
dataset.eval(all_y_pred, all_y_true, all_metadata)
# {'recall_macro_all': 0.66, ...}

大多数eval方法默认采用预测标签all_y_pred,但默认输入因数据集而异,并记录在eval相应数据集类的文档字符串中。

使用示例脚本

examples/中,我们提供了一组脚本,可用于在 WILDS 数据集上训练模型。

python examples/run_expt.py --dataset iwildcam --algorithm ERM --root_dir data
python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data

这些脚本被配置为使用默认模型和合理的超参数。有关我们论文中使用的精确超参数设置,请参阅我们的 CodaLab 可执行论文

下载和训练 WILDS 数据集

第一次运行这些脚本时,您可能需要下载数据集。您可以使用--download参数执行此操作,例如:

# downloads (labeled) dataset
python examples/run_expt.py --dataset globalwheat --algorithm groupDRO --root_dir data --download

# additionally downloads all unlabeled data
python examples/run_expt.py --dataset globalwheat --algorithm groupDRO --root_dir data --download  --unlabeled_split [...]

请注意,下载大量未标记数据是可选的;只有设置了一些,才会下载未标记的数据--unlabeled_split。(设置哪个无关紧要--unlabeled_split;所有未标记的数据将一起下载。)

或者,您可以使用独立wilds/download_datasets.py脚本下载数据集,例如:

# downloads (labeled) data
python wilds/download_datasets.py --root_dir data

# downloads (unlabeled) data
python wilds/download_datasets.py --root_dir data --unlabeled

这会将所有数据集下载到指定data文件夹。您还可以使用该--datasets参数下载特定数据集。

这些是我们每个数据集的大小,以及使用 NVIDIA V100 GPU 训练和评估单个 ERM 运行的默认模型所需的大致时间。

数据集命令 模态 下载大小 (GB) 磁盘大小 (GB) 训练+评估时间(小时)
iwildcam 图片 11 25 7
骆驼17 图片 10 15 2
rxrx1 图片 7 7 11
ogb-molpcba 图形 0.04 2 15
全球小麦 图片 10 10 2
民事评论 文本 0.1 0.3 4.5
关注 图片 50 55 6
贫困 图片 12 14 5
亚马逊 文本 7 7 5
py150 文本 0.1 0.8 9.5

以下是未标记数据包的大小:

数据集命令 模态 下载大小 (GB) 磁盘大小 (GB)
iwildcam 图片 41 41
骆驼17 图片 69.4 96
ogb-molpcba 图形 1.2 21
全球小麦 图片 103 108
民事评论 文本 0.3 0.6
关注* 图片 50 55
贫困 图片 172 184
亚马逊* 文本 7 7

* 这些未标注数据集与标注数据同时下载,无需单独下载。

虽然camelyon17数据集小且训练速度快,但我们建议不要将其用作原型方法的唯一数据集,因为在此数据集上训练的模型的测试性能往往表现出与随机种子相比有很大程度的可变性。

图像数据集(iwildcamcamelyon17rxrx1globalwheatfmowpoverty)往往具有较高的磁盘 I/O 使用率。如果您的训练时间比上面列出的大致时间慢得多,请考虑检查 I/O 是否是瓶颈(例如,如果您使用网络驱动器,则通过移动到本地磁盘,或者通过增加数据加载器工作人员的数量)。为了加快训练速度,您还可以通过切换--evaluate_all_splits和相关参数禁用每个 epoch 或所有拆分的评估。

算法

在该文件夹中,我们提供了在我们的论文( 1、2examples/algorithms作为基准的自适应算法的实现。所有算法都在来自 WILDS 数据集拆分的标记数据上进行训练。一些算法旨在利用未标记的数据。要加载未标记的数据,请在运行时指定。train--unlabeled_split

除了 、 、 和 等共享超参数之外lrweight_decay脚本batch_sizeunlabeled_batch_size接受算法特定超参数的命令行参数。

算法命令 超参数 笔记 见野生纸
风险管理 - 仅使用标记数据 ( 1 , 2 )
组DRO group_dro_step_size 仅使用标记数据 ()
深珊瑚 coral_penalty_weight 可以选择使用未标记的数据 ( 1 , 2 )
风险管理 irm_lambda,irm_penalty_anneal_iters 仅使用标记数据 ()
dann_penalty_weight, dann_classifier_lr, dann_featurizer_lr,dann_discriminator_lr 可以使用未标记的数据 ( 2 )
AFN afn_penalty_weight, safn_delta_r,hafn_r 旨在使用未标记的数据 ( 2 )
固定匹配 self_training_lambda,self_training_threshold 旨在使用未标记的数据 ( 2 )
伪标签 self_training_lambda, self_training_threshold,pseudolabel_T2 旨在使用未标记的数据 ( 2 )
吵闹的学生 soft_pseudolabels,noisystudent_dropout_rate 旨在使用未标记的数据 ( 2 )

该存储库的设置是为了促进通用算法的开发:可以将新算法添加到examples/algorithms所有 WILDS 数据集,然后使用默认模型在所有 WILDS 数据集上运行。

评估训练有素的模型

我们还提供了一个评估脚本,用于聚合不同复制的预测 CSV 文件并报告它们的组合评估。要使用它,请运行:

python examples/evaluate.py <predictions_dir> <output_dir> --root-dir <root_dir>

where<predictions_dir>是预测目录的路径,<output_dir>结果 JSON 将被写入的位置,并且<root_dir>是数据集根目录。iwildcam预测目录应该为每个包含要评估的预测 CSV 文件的数据集(例如)有一个子目录;请参阅我们的格式提交指南。评估脚本将跳过任何缺少预测文件的数据集。任何不在其中的数据集<root_dir>都将下载到<root_dir>.

再现性

我们在 CodaLab 上有我们论文的可执行版本,其中包含我们论文中报告的实验的确切命令、代码和数据,这些实验依赖于这些脚本。所有数据集的训练模型权重也可以在此处找到。所有配置和超参数也可以在examples/configs这个 repo 的文件夹中找到,数据集特定的参数在examples/configs/datasets.py.

排行榜

如果您正在 WILDS 上开发新的训练算法和/或模型,请考虑将它们提交到我们的公共排行榜

引用 WILDS ( Bibtex )

如果您在工作中使用 WILDS 数据集,请引用我们的论文:

  1. WILDS:野外分布变化的基准。Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw、Imran S. Haque、Sara Beery、Jure Leskovec、Anshul Kundaje、Emma Pierson、Sergey Levine、Chelsea Finn 和 Percy Liang。ICML 2021。

如果您使用来自 WILDS 数据集的未标记数据,请同时引用:

  1. 扩展无监督适应的 WILDS 基准。Shiori Sagawa*, Pang Wei Koh*, Tony Lee*, Irena Gao*, Sang Michael Xie, Kendrick Shen, Ananya Kumar, Weihua Hu, Michihiro Yasunaga, Henrik Marklund, Sara Beery, Etienne David, Ian Stavness, Wei Guo, Jure Leskovec 、Kate Saenko、Tatsunori Hashimoto、Sergey Levine、Chelsea Finn 和 Percy Liang。NeurIPS 2021 分布转变研讨会。

此外,请引用数据集页面上列出的介绍数据集的原始论文。

致谢

WILDS 基准测试的设计灵感来自Open Graph Benchmark,我们感谢 Open Graph Benchmark 团队在设置 WILDS 时提供的建议和帮助。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

wilds-2.0.0.tar.gz (100.1 kB 查看哈希

已上传 source

内置分布

wilds-2.0.0-py3-none-any.whl (126.2 kB 查看哈希

已上传 py3