Skip to main content

生存分析的深度学习

项目描述

火炬生活

使用 pytorch 进行生存分析

该库采用深度学习方法进行生存分析。

安装

pip install torchlife

如何使用

我们需要一个数据框,它有一个名为“t”的列表示时间,“e”表示一个死亡事件。

import pandas as pd
import numpy as np
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
e 年龄 种族 wexp 马尔 帕罗 优先
0 20 1 0 27 1 0 0 1 3
1 17 1 0 18 1 0 0 1 8
2 25 1 0 19 0 1 0 1 13
3 52 0 1 23 1 1 1 1 1
4 52 0 0 19 0 1 0 1 3
from torchlife.model import ModelHazard

model = ModelHazard('cox', lr=0.5)
model.fit(df)
λ, S = model.predict(df)
时代 train_loss 有效损失 时间
0 6.993955 10.741218 00:00
1 8.774823 14.736155 00:00
2 9.991431 16.564432 00:00
3 10.995527 17.174604 00:00
4 11.723181 16.920387 00:00
5 12.060142 15.983603 00:00
6 12.174074 14.553919 00:00
7 12.038597 12.683950 00:00
8 11.702325 10.452137 00:00
9 11.218502 7.981377 00:00
10 10.570101 5.209520 00:00
11 9.859859 4.039678 00:00
12 9.155064 3.643379 00:00
13 8.514476 2.742133 00:00
14 7.915660 3.074418 00:00
15 7.413548 2.585245 00:00
16 6.967895 2.710384 00:00
17 6.569957 2.544009 00:00
18 6.215098 2.433515 00:00
19 5.880322 2.342750 00:00

让我们绘制数据框中第 4 个元素的生存函数:

x = df.drop(['t', 'e'], axis=1).iloc[2]
t = np.arange(df['t'].max())
model.plot_survival_function(t, x)

PNG

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

torchlife-0.0.2.tar.gz (10.9 kB 查看哈希

已上传 source

内置分布

torchlife-0.0.2-py3-none-any.whl (17.2 kB 查看哈希

已上传 py3