Skip to main content

用于模型多样化的强化数据采样

项目描述

RDS

为模型多样化实施强化数据采样。

要求

  • 麻木的
  • 火炬
  • scikit-学习
  • 熊猫
  • tqdm

机器学习任务

该存储库支持多变量、文本和视觉数据的多项机器学习任务:

  • 二进制分类
  • 多类分类
  • 回归

安装

pip install torchRDS

用法

from torchRDS.RDS import RDS

trainer = RDS(data_file="datasets/madelon.csv", target=[0], task="classification", measure="auc", 
              model_classes=["models.MDL_RF", "models.MDL_MLP", "models.MDL_LR"], 
              learn="deterministic", ratio=0.7695, iters=100)
sample = trainer.train()

print("No of observations in training set: ", sum(sample))

实际用例

如果您想在此处列出真实世界的比赛或用例,请联系我们。

实验结果

已经在以下四个数据集上进行了实验。

数据集 任务 挑战 数据大小 评估
玛德隆 二进制分类 NIPS 2013 特征选择挑战 2,600 x 500(多变量) 曲线下面积 2003年
博士 回归 药物评论(Kaggle 黑客马拉松) 215,063 x 6(多变量,文本) R^2 2018
MNIST 多类分类 手写数字识别 70,000 x 28 x 28(图像) 微F1 1998
KLP 二进制分类 卡拉帕信用评分挑战 50,000 x 64(多变量,文本) 曲线下面积 2020

MADELON - 结果

采样 #样本 班级比例 LR 射频 MLP 合奏 上市
火车 测试 火车 测试
预设 2000 600 1.0000 1.0000 .6019 .8106 .5590 .6783 .9063
随机的 2000 600 .9920 1.0270 .5742 .7729 .5774 .6453 .9002
分层 2000 600 1.0000 1.0000 .5673 .7470 .6153 .6360 .8828
RDS^{DET} 2001年 599 1.0375 .9137 .6192 .8050 .6228 .6973 .8915
RDS^{STO} 2021 579 1.0010 .9966 .6192 .8050 .6050 .6947 .9106

博士 - 结果

采样 火车 测试 MLP 美国有线电视新闻网 合奏 上市
预设 161,297 53,766 .4580 .5787 .7282 .6660 .7637
随机的 161,297 53,766 .4597 .4179 .7353 .6485 .7503
RDS^{DET} 162,070 52,993 .4646 .5776 .7355 .6692 .7649
RDS^{STO} 161,944 53,119 .4647 .5370 .7509 .6562 .7600

MNIST - 结果

采样 #样本 班级比例 LR 射频 美国有线电视新闻网 合奏 上市
火车 测试 火车 测试
预设 60000 10000 .8571 .1429 .9647 .9524 .9824 .9819 .9917
随机的 59500 10500 .8500 .1500 .9603 .9465 .9779 .9768 .9914
分层 59500 10500 .8500 .1500 .9625 .9510 .9795 .9792 .9901
RDS^{DET} 59938 10062 .8562 .1438 .9495 .9382 .9757 .9769 .9927
RDS^{STO} 59496 10504 .8499 .1501 .9583 .9486 .9851 .9830 .9931

KLP - 结果

采样 #样本 班级比例 LR 射频 MLP 合奏 上市
火车 测试 火车 测试
预设 30000 20000 .0165 .0186 .5799 .5517 .5635 .5723 .5953
简单的 30000 20000 .0169 .0179 .5886 .5374 .5914 .5856 .6042
分层 30000 20000 .0173 .0173 .5952 .5608 .5780 .5983 .6014
RDS^{DET} 29999 20001 .0180 .0163 .6045 .5350 .5802 .6057 .5362
RDS^{STO} 30031 19969 .0172 .0174 .5997 .5491 .6354 .6072 .6096

引用这项工作

如果这项工作对您的研究有用,请考虑引用我们:

@misc{nguyen2020reinforced,
    title={Reinforced Data Sampling for Model Diversification},
    author={Hoang D. Nguyen and Xuan-Son Vu and Quoc-Tuan Truong and Duc-Trong Le},
    year={2020},
    eprint={2006.07100},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

参考

  • Lee, S.、Prakash, SPS、Cogswell, M.、Ranjan, V.、Crandall, D. 和 Batra, D.,2016 年。用于训练不同深度集成的随机多项选择学习。在神经信息处理系统的进展中(第 2119-2127 页)。
  • Peng, M.、Zhang, Q.、Xing, X.、Gui, T.、Huang, X.、Jiang, YG、Ding, K. 和 Chen, Z.,2019 年 7 月。用于类不平衡学习的可训练欠采样。在AAAI 人工智能会议论文集中(第 33 卷,第 4707-4714 页)。
  • Gong, Z.、Zhong, P. 和 Hu, W.,2019 年。机器学习的多样性。IEEE 访问7,第 64323-64350 页。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

torchRDS-0.3.tar.gz (10.2 kB 查看哈希

已上传 source

内置分布

torchRDS-0.3-py3-none-any.whl (9.5 kB 查看哈希

已上传 py3