用于模型多样化的强化数据采样
项目描述
RDS
为模型多样化实施强化数据采样。
要求
- 麻木的
- 火炬
- scikit-学习
- 熊猫
- tqdm
机器学习任务
该存储库支持多变量、文本和视觉数据的多项机器学习任务:
- 二进制分类
- 多类分类
- 回归
安装
pip install torchRDS
用法
from torchRDS.RDS import RDS
trainer = RDS(data_file="datasets/madelon.csv", target=[0], task="classification", measure="auc",
model_classes=["models.MDL_RF", "models.MDL_MLP", "models.MDL_LR"],
learn="deterministic", ratio=0.7695, iters=100)
sample = trainer.train()
print("No of observations in training set: ", sum(sample))
实际用例
如果您想在此处列出真实世界的比赛或用例,请联系我们。
实验结果
已经在以下四个数据集上进行了实验。
| 数据集 | 任务 | 挑战 | 数据大小 | 评估 | 年 |
|---|---|---|---|---|---|
| 玛德隆 | 二进制分类 | NIPS 2013 特征选择挑战 | 2,600 x 500(多变量) | 曲线下面积 | 2003年 |
| 博士 | 回归 | 药物评论(Kaggle 黑客马拉松) | 215,063 x 6(多变量,文本) | R^2 | 2018 |
| MNIST | 多类分类 | 手写数字识别 | 70,000 x 28 x 28(图像) | 微F1 | 1998 |
| KLP | 二进制分类 | 卡拉帕信用评分挑战 | 50,000 x 64(多变量,文本) | 曲线下面积 | 2020 |
MADELON - 结果
| 采样 | #样本 | 班级比例 | LR | 射频 | MLP | 合奏 | 上市 | ||
|---|---|---|---|---|---|---|---|---|---|
| 火车 | 测试 | 火车 | 测试 | ||||||
| 预设 | 2000 | 600 | 1.0000 | 1.0000 | .6019 | .8106 | .5590 | .6783 | .9063 |
| 随机的 | 2000 | 600 | .9920 | 1.0270 | .5742 | .7729 | .5774 | .6453 | .9002 |
| 分层 | 2000 | 600 | 1.0000 | 1.0000 | .5673 | .7470 | .6153 | .6360 | .8828 |
| RDS^{DET} | 2001年 | 599 | 1.0375 | .9137 | .6192 | .8050 | .6228 | .6973 | .8915 |
| RDS^{STO} | 2021 | 579 | 1.0010 | .9966 | .6192 | .8050 | .6050 | .6947 | .9106 |
博士 - 结果
| 采样 | 火车 | 测试 | 岭 | MLP | 美国有线电视新闻网 | 合奏 | 上市 |
|---|---|---|---|---|---|---|---|
| 预设 | 161,297 | 53,766 | .4580 | .5787 | .7282 | .6660 | .7637 |
| 随机的 | 161,297 | 53,766 | .4597 | .4179 | .7353 | .6485 | .7503 |
| RDS^{DET} | 162,070 | 52,993 | .4646 | .5776 | .7355 | .6692 | .7649 |
| RDS^{STO} | 161,944 | 53,119 | .4647 | .5370 | .7509 | .6562 | .7600 |
MNIST - 结果
| 采样 | #样本 | 班级比例 | LR | 射频 | 美国有线电视新闻网 | 合奏 | 上市 | ||
|---|---|---|---|---|---|---|---|---|---|
| 火车 | 测试 | 火车 | 测试 | ||||||
| 预设 | 60000 | 10000 | .8571 | .1429 | .9647 | .9524 | .9824 | .9819 | .9917 |
| 随机的 | 59500 | 10500 | .8500 | .1500 | .9603 | .9465 | .9779 | .9768 | .9914 |
| 分层 | 59500 | 10500 | .8500 | .1500 | .9625 | .9510 | .9795 | .9792 | .9901 |
| RDS^{DET} | 59938 | 10062 | .8562 | .1438 | .9495 | .9382 | .9757 | .9769 | .9927 |
| RDS^{STO} | 59496 | 10504 | .8499 | .1501 | .9583 | .9486 | .9851 | .9830 | .9931 |
KLP - 结果
| 采样 | #样本 | 班级比例 | LR | 射频 | MLP | 合奏 | 上市 | ||
|---|---|---|---|---|---|---|---|---|---|
| 火车 | 测试 | 火车 | 测试 | ||||||
| 预设 | 30000 | 20000 | .0165 | .0186 | .5799 | .5517 | .5635 | .5723 | .5953 |
| 简单的 | 30000 | 20000 | .0169 | .0179 | .5886 | .5374 | .5914 | .5856 | .6042 |
| 分层 | 30000 | 20000 | .0173 | .0173 | .5952 | .5608 | .5780 | .5983 | .6014 |
| RDS^{DET} | 29999 | 20001 | .0180 | .0163 | .6045 | .5350 | .5802 | .6057 | .5362 |
| RDS^{STO} | 30031 | 19969 | .0172 | .0174 | .5997 | .5491 | .6354 | .6072 | .6096 |
引用这项工作
如果这项工作对您的研究有用,请考虑引用我们:
@misc{nguyen2020reinforced,
title={Reinforced Data Sampling for Model Diversification},
author={Hoang D. Nguyen and Xuan-Son Vu and Quoc-Tuan Truong and Duc-Trong Le},
year={2020},
eprint={2006.07100},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
参考
- Lee, S.、Prakash, SPS、Cogswell, M.、Ranjan, V.、Crandall, D. 和 Batra, D.,2016 年。用于训练不同深度集成的随机多项选择学习。在神经信息处理系统的进展中(第 2119-2127 页)。
- Peng, M.、Zhang, Q.、Xing, X.、Gui, T.、Huang, X.、Jiang, YG、Ding, K. 和 Chen, Z.,2019 年 7 月。用于类不平衡学习的可训练欠采样。在AAAI 人工智能会议论文集中(第 33 卷,第 4707-4714 页)。
- Gong, Z.、Zhong, P. 和 Hu, W.,2019 年。机器学习的多样性。IEEE 访问,7,第 64323-64350 页。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
torchRDS-0.3.tar.gz
(10.2 kB
查看哈希)
内置分布
torchRDS-0.3-py3-none-any.whl
(9.5 kB
查看哈希)