生存時間データへのDeep Learningの適用 - DeepSurv

生存時間データの分析に関してちょこちょこ取り上げていますが、今回はそんな生存時間データにDeep Learningを適用してみた論文、DeepSurv論文を読んでまとめてみました。

bmcmedresmethodol.biomedcentral.com

1分で理解するDeepSurv

  • Cox比例ハザードモデルのような標準的な生存モデルでは、個人レベルでの治療の相互作用をモデル化するために、広範な事前の医学的知識が必要となる。
  • 通常のCox比例ハザードモデルでは、共変数の線形性を仮定していることから、個人レベルでの治療の相互作用をモデル化することが困難であった。Neural NetworkやSurvival Forestといった非線形手法は、原理上高次元の相互作用項をモデル化できるが、効果的な治療推奨システムとしての有用性はまだ示されていない。
  • DeepSurvは古典的なCox比例ハザードモデルの拡張として、生存時間データに適用可能なDeep Learning手法。DeepSurvを用いると患者の共変量と治療効果の相互作用をモデル化することが可能。さらに、患者の特徴と様々な治療の有効性との関係をモデル化することで、治療推奨に有用であることを示した。
  • 実際の臨床研究データを使用して、個別化された治療の推奨によって各データセットにおける患者の生存期間がどれだけ延びるかを検証した。治療を受けるように推奨した群は、治療を受けないように推奨した群に対して有意に生存時間が延長することが示された。

マテメソ

Cox比例ハザードモデルの復習

Cox比例ハザードモデルにおいて、ハザード関数はベースラインハザード関数 $\lambda_{0}(t)$とリスク関数 $r(x)=e^h(x)$ の積で表されるのでした。 $$ \lambda(t | x) = \lambda_{0}(t) \cdot e^{h(x)} $$ Cox回帰を実行するには、重み $\beta$ を調整してCox比例ハザードモデルの部分尤度関数を最適化(最大化)します。 部分尤度関数 $L_{c}(\beta)$ は、$T_i$: イベント発生時間、 $E_i$: イベントの発生有無、$x_i$: ベースライン時の共変量としたとき下記の式で表されます。 $$ L_{c}(\beta) = \underset{i : E_{i} = 1}{\prod} \frac{\hat{r}_{\beta}(x_{i})} { \underset{j \in \Re(T_{i})}{\sum} \hat{r}_{\beta}(x_{j})} = \underset{i : E_{i} = 1}{\prod} \frac{\exp (\hat{h}_{\beta}(x_{i}))} { \underset{j \in \Re(T_{i})}{\sum} \exp (\hat{h}_{\beta}(x_{j}))}, $$

モデルの構成

モデル自体は至ってシンプルな順伝播型ニューラルネットワークとなっています。
隠れ層は全結合層(Fully-Connected Layer)とそれに続くドロップアウト層(Dropout Layer)で構成されています。最終層は線形結合層となっており、出力は対数リスク関数$\hat{h}_{\theta(x)}$となっています。

なお、ドロップアウト層は、全結合層のノード群と出力層の間の接続の一部をランダムに切断することで過学習を防ぐ働きを担っています。
ネットワークに関するハイパーパラメータ(隠れ層の数、各層におけるノード数、ドロップアウト率等)はRandom Searchによって決定しています。

ここで、目的関数は負の対数部分尤度を平均化した値にL2ノルムを追加した値を使用しています(下記)。

$$ {}l(\theta) \!:=\! - \frac{1}{N_{E=1}} \sum_{i: E_{i} = 1} \left(\hat{h}_{\theta}(x_{i}) - \log \sum_{j \in \Re(T_{i})} e^{\hat{h}_{\theta}(x_{j})} \right) + \lambda \cdot ||\theta||^{2}_{2}, $$

治療推奨システムへの応用

すべての患者をn個の治療グループ $\tau \in \{0,1、…、n-1\}$ の1つに割り当てます。 各治療 $\tau = i$ を行った患者は独立したリスク関数 $e^{h_{i}(x)}$ を持つと仮定すると、$\tau = i$ の患者のハザード関数は下記で表されます。 $$ \lambda(t; x | \tau = i) = \lambda_{0}(t) \cdot e^{h_{i}(x)}. $$ 対数ハザードの差を、推奨関数として定義します。 $$ \begin{aligned} {}\text{rec}_{ij}(x) &= \log \left(\frac{\lambda(t;x | \tau = i)} {\lambda(t; x | \tau = j)} \right) = \log \left(\frac{\lambda_{0}(t) \cdot e^{h_{i}(x)}}{\lambda_{0}(t) \cdot e^{h_{j}(x)}} \right) \\ &= h_{i}(x) - h_{j}(x). \end{aligned} $$ 推奨関数 $\text{rec}_{ij}(x)$ が正となる場合、治療 $i$ は治療 $j$ よりも高い死亡リスクにつながります。したがって、患者は治療 $j$ を処方されるべきであると判断することができます。 一方、推奨関数 $\text{rec}_{ij}(x)$ が負となる場合、治療 $i$ は治療 $j$ よりも効果的であるとみなすことができます。したがって、患者は治療 $i$ を処方されるべきであると判断することができます。 なお、Cox比例ハザードモデルにおいて推奨関数を考えると、対数リスク関数の線形制約からモデルが患者の特徴に関係なく定数を返すことが分かります。 $$ \begin{aligned} \text{rec}_{ij}(x) &= \log \left(\frac{\lambda(t;x | \tau = i)} {\lambda(t; x | \tau = j)} \right) \\ &= \log \left(\frac{\lambda_{0}(t) \cdot e^{\beta_{0} i + \beta_{1} x_{1} + \ldots + \beta_{n} x_{n}}}{\lambda_{0}(t) \cdot e^{\beta_{0} j + \beta_{1} x_{1} + \ldots + \beta_{n} x_{n}}} \right) \\ &= \log \left(e^{\beta_{0} i + \beta_{1} x_{1} + \ldots + \beta_{n} x_{n} - (\beta_{0} j + \beta_{1} x_{1} + \ldots + \beta_{n} x_{n})} \right) \\ &= \beta_{0} i - \beta_{0} j \\ &= \beta_{0} (i-j). \end{aligned} $$

これは解釈として、モデルが重み $β_0$ を正または負と推定するかどうかに基づいて、すべての患者に同じ治療オプションの選択を推奨することを意味します。

治療(患者)ごとに異なるリスク関数を推定する特性から、個別化された治療推奨システムへの応用はDeep Learningベースの手法に分があるといえます。

結果

本論文では、大きく下記の2点を主張しています。

  1. (非線形性を持った対数リスク関数に対する)予測精度の高さ
  2. 治療推奨システムとしての有用性

予測精度の高さ

精度指標にはC-Indexを使用し、CPH(Cox比例ハザード)、RSF(Random Survival Forest)と比較しています。
シミュレーションにより生成した線形データへの当てはまりは極わずかにCox比例ハザードモデルに負けているものの、それ以外のデータセットではCox比例ハザードモデルの精度を上回っており、多くのデータセットでRSFと同等、またはそれ以上の精度が出ていることが示されています。

予測精度をシミュレーションにより検証

共変量を2変数としたときの対数リスク関数をシミュレーションすることでその予測精度の高さを示しています。

  1. リスク関数に線形性を仮定
    共変量 $x_0$, $x_1$、実際のリスク関数が共変量の線形関数($h(x) = x_0 + 2x_1$)で表されると仮定したときの各手法で予測されるリスク値、誤差を表示しています。 CPH(Cox比例ハザード)、DeepSurv共に精度よく推定できていることが分かります。

  2. リスク関数に非線形性を仮定
    共変量 $x_0$, $x_1$、実際のリスク関数が共変量の非線形関数($h(x) = \log (\lambda_{\max}) : \exp \left(-{\frac{x_{0}^{2} + x_{1}^{2}}{2 r^{2}}} \right)$)で表されると仮定したときの各手法で予測されるリスク値、誤差を表示しています。 当たり前ですが、CPH(Cox比例ハザード)では全く推定できていないことが分かります。
    一方、DeepSurvでは精度よく推定できていることが分かります。

治療推奨システムとしての有用性

シミュレーションデータ(非線形)、臨床データセットの2セットにおいて検証されています。
有用性の指標としては、推奨グループと非推奨グループにおける生存関数の中央値、およびログランク検定におけるp-valueを使用しています。RSF(Random Survival Forest)と比較しています。

治療推奨のグループ間差異を検証

  1. リスク関数に非線形性を仮定したデータセットを使用
    データセット内のシミュレートされた各患者に対して、治療グループ $\tau\in{0,1}$ を均一に割り当てた後、治療グループ $\tau = 1$ にはガウス分布に従う非線形の治療効果を与えています。
    その後、推奨グループと非推奨グループの両方について、各モデルによって推定されたカプラン・マイヤー生存曲線をプロットしています。ログランク検定の結果から、DeepSurvでは有意な差が得られた一方、RSFでは有意性が得られなかったことが分かります。

  2. 臨床データセットを使用
    続いて、治療の推奨、非推奨に関するデータを含む実際の臨床データセット(Rotterdam & German Breast Cancer Study Group (GBSG) )を使用します。推奨グループと非推奨グループの両方について、各モデルによって推定されたカプラン・マイヤー生存曲線をプロットしています。
    DeepSurv、RSFのいずれにおいても有意な差が得られています。一方で、RSFのp-valueの方が高く、治療推奨システムとしての有用性はDeepSurvの方が高いことが予想されます。

DeepSurvを試してみる

DeepSurvはPythonのパッケージとして、GitHub上で提供されています。

※インストールの際には、h5pyのインストール及び最新版のLasagneのインストールが必要でした。

$ pip install --upgrade https://github.com/Lasagne/Lasagne/archive/master.zip
$ pip install h5py

上記のインストールが完了したらDeepSurvのインストールに移ります。

$ git clone https://github.com/jaredleekatzman/DeepSurv.git
$ cd DeepSurv
$ pip install .

./example_data.csvにテスト用のデータセットが公開されています。
こちらを使用してテストランを行います(テストランのノートブックはGitHub上に公開されています)。

train_dataset_fp = './example_data.csv'
train_df = pd.read_csv(train_dataset_fp)
train_df.head()

##    Variable_1   Variable_2  Variable_3  Variable_4  Event  Time
## 0            0           3           2         4.6      1    43
## 1            0           2           0         1.6      0    52
## 2            0           3           0         3.5      1    73
## 3            0           3           1         5.1      0    51
## 4            0           2           0         1.7      0    51
{
    'x': numpy array of float32
    'e': numpy array of int32
    't': numpy array of float32
    'hr': (optional) numpy array of float32
}

pandasのDFをDeepSurvの入力フォーマットに変換するdataframe_to_deepsurv_ds関数を作成・実行します。

# event_col にはイベントの発生有無を示す列を指定 
# time_col にはイベント発生時間を示す列を指定
def dataframe_to_deepsurv_ds(df, event_col = 'Event', time_col = 'Time'):
    # 各データをnumpy arrayに変換
    e = df[event_col].values.astype(np.int32)
    t = df[time_col].values.astype(np.float32)
    x_df = df.drop([event_col, time_col], axis = 1)
    x = x_df.values.astype(np.float32)
    
    # DeepSurvの入力フォーマットとして返す
    return {
        'x' : x,
        'e' : e,
        't' : t
    }

train_data = dataframe_to_deepsurv_ds(train_df, event_col = 'Event', time_col= 'Time')

各ハイパーパラメータは辞書型で定義します。
本来であればデータセットごとにチューニングする必要がありますが、今回はテストランのため適当に与えます。

# ハイパーパラメータを定義
hyperparams = {
    'L2_reg': 10.0,
    'batch_norm': True,
    'dropout': 0.4,
    'hidden_layers_sizes': [25, 25],
    'learning_rate': 1e-05,
    'lr_decay': 0.001,
    'momentum': 0.9,
    'n_in': train_data['x'].shape[1],
    'standardize': True
}

DeepSurvインスタンスを作成し、.trainメソッドにより学習を行います。
学習の過程では、損失関数 (loss) の値と C-Index (ci)の値が表示されます。

# ハイパーパラメータを与えてDeepSurvインスタンスを作成
model = deepsurv.DeepSurv(**hyperparams)

# DeepSurv は学習データと検証データの追跡にTensorBoardを使用できます。

# ログを出力したくない場合は下記の3行をコメントアウトした上でlogger = Noneを指定します。
# logger = None

experiment_name = 'test_experiment_sebastian'
logdir = './logs/tensorboard/'
logger = TensorboardLogger(experiment_name, logdir=logdir)

# 学習
update_fn=lasagne.updates.nesterov_momentum # 使用する最適化手法を選択 \
                                            # 参照: http://lasagne.readthedocs.io/en/latest/modules/updates.html \
n_epochs = 2000

# If you have validation data, you can add it as the second parameter to the function
metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)

## 2020-11-22 12:36:03,215 - Training step 0/2000    |                         | - loss: 26.2085 - ci: 0.3693
## 2020-11-22 12:36:32,527 - Training step 250/2000  |***                      | - loss: 13.5646 - ci: 0.3617
## 2020-11-22 12:37:02,550 - Training step 500/2000  |******                   | - loss: 8.7107 - ci: 0.3730
## 2020-11-22 12:37:32,982 - Training step 750/2000  |*********                | - loss: 7.0760 - ci: 0.3919
## 2020-11-22 12:38:03,217 - Training step 1000/2000 |************             | - loss: 6.4698 - ci: 0.4375
## 2020-11-22 12:38:33,590 - Training step 1250/2000 |***************          | - loss: 6.2745 - ci: 0.5634
## 2020-11-22 12:39:04,227 - Training step 1500/2000 |******************       | - loss: 6.1954 - ci: 0.6801
## 2020-11-22 12:39:36,684 - Training step 1750/2000 |*********************    | - loss: 6.1768 - ci: 0.7048
## 2020-11-22 12:40:07,556 - Finished Training with 2000 iterations in 245.75s

学習の過程(学習曲線、検証曲線)は .plotメソッドにより描画することが可能です。
学習中に表示される各種精度指標は.trainメソッドの返り値(metrics)に含まれています。

# 最終的なC-Index指標を出力
print('Train C-Index:', metrics['c-index'][-1])
# print('Valid C-Index: ',metrics['valid_c-index'][-1])

# 学習曲線、検証曲線を描画
viz.plot_log(metrics)

## Train C-Index: (1999, 0.695762840452166)

終わりに

METABRICのデータセットでは、患者の臨床的特徴(ホルモン治療指標、放射線療法指標、化学療法指標、ER陽性指標、診断時年齢)に加え、遺伝子発現量(MKI67, EGFR, PGR, ERBB2)のデータを使用していました。
こういう分析はワクワクしますね。

本論文によりDeep Learningの生存時間データへの適用可能性が示されたことにより、 時系列の医療画像を入力データとするようなことも可能となるのではないでしょうか(もうあるかも)。

一方で、Deep Learningとなるとその説明性・解釈性は皆無となっており、あくまで「結果としてこれだけの精度が出せたので臨床応用が期待できる」といった書きぶりになってしまっています。実際にこのようなリコメンドシステムが医療機関で利用されるようになるには、精度のみならずその精度の説明性が担保される必要があると感じました。