生存分析的深度学习
项目描述
火炬生活
使用 pytorch 进行生存分析
该库采用深度学习方法进行生存分析。
安装
pip install torchlife
如何使用
我们需要一个数据框,它有一个名为“t”的列表示时间,“e”表示一个死亡事件。
import pandas as pd
import numpy as np
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
| 吨 | e | 鳍 | 年龄 | 种族 | wexp | 马尔 | 帕罗 | 优先 | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 20 | 1 | 0 | 27 | 1 | 0 | 0 | 1 | 3 |
| 1 | 17 | 1 | 0 | 18 | 1 | 0 | 0 | 1 | 8 |
| 2 | 25 | 1 | 0 | 19 | 0 | 1 | 0 | 1 | 13 |
| 3 | 52 | 0 | 1 | 23 | 1 | 1 | 1 | 1 | 1 |
| 4 | 52 | 0 | 0 | 19 | 0 | 1 | 0 | 1 | 3 |
from torchlife.model import ModelHazard
model = ModelHazard('cox', lr=0.5)
model.fit(df)
λ, S = model.predict(df)
| 时代 | train_loss | 有效损失 | 时间 |
|---|---|---|---|
| 0 | 6.993955 | 10.741218 | 00:00 |
| 1 | 8.774823 | 14.736155 | 00:00 |
| 2 | 9.991431 | 16.564432 | 00:00 |
| 3 | 10.995527 | 17.174604 | 00:00 |
| 4 | 11.723181 | 16.920387 | 00:00 |
| 5 | 12.060142 | 15.983603 | 00:00 |
| 6 | 12.174074 | 14.553919 | 00:00 |
| 7 | 12.038597 | 12.683950 | 00:00 |
| 8 | 11.702325 | 10.452137 | 00:00 |
| 9 | 11.218502 | 7.981377 | 00:00 |
| 10 | 10.570101 | 5.209520 | 00:00 |
| 11 | 9.859859 | 4.039678 | 00:00 |
| 12 | 9.155064 | 3.643379 | 00:00 |
| 13 | 8.514476 | 2.742133 | 00:00 |
| 14 | 7.915660 | 3.074418 | 00:00 |
| 15 | 7.413548 | 2.585245 | 00:00 |
| 16 | 6.967895 | 2.710384 | 00:00 |
| 17 | 6.569957 | 2.544009 | 00:00 |
| 18 | 6.215098 | 2.433515 | 00:00 |
| 19 | 5.880322 | 2.342750 | 00:00 |
让我们绘制数据框中第 4 个元素的生存函数:
x = df.drop(['t', 'e'], axis=1).iloc[2]
t = np.arange(df['t'].max())
model.plot_survival_function(t, x)
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
torchlife-0.0.2.tar.gz
(10.9 kB
查看哈希)
内置分布
torchlife-0.0.2-py3-none-any.whl
(17.2 kB
查看哈希)