scouty AI LAB

HR TECH・AI企業であるscoutyが、自社で利用しているAI技術や話題のAIテクノロジーを紹介していきます。

生存時間解析について - ノンパラメトリック

scoutyの高濱です。本記事では、生存時間解析におけるノンパラメトリックな推定量のうち代表的なものである Kaplan-Meier 推定量と Nelson-Aalen 推定量について紹介します。

生存時間解析の概要については以下の記事をご参照ください。

上の記事で紹介しましたが、生存時間解析では、モデルの形状をどの程度の強さで仮定するかに依存して、モデルの種類が仮定が強い順に「フルパラメトリック」「セミパラメトリック」「ノンパラメトリック」に分類されます。 しかし、個別のサンプル*1の特徴量*2をモデルに含めるか否かについての表現がさまざまな文献でしばしば雑に扱われ、読者の混乱を招いていると認識しています*3。 具体的には、セミパラメトリックなモデルの多くは上述した特徴量を明示的に含む表現が使われがちである一方、フルパラメトリックなモデルに関しては、特徴量を含む前提で説明されている場合と含まない前提で説明されている場合があります。

それに対して、今回紹介するノンパラメトリック推定では、基本的に特徴量を含まない表現しか使われません。 このため、ノンパラメトリック推定は、ある母集団についてすでに知られたハザードと観測された生存時間データの間の差、あるいは複数の生存時間データの集団ごとの差の有無に関して検定を行うために用いることはできますが、 scoutyでの利用目的である「ある候補者の転職時期を予測したい」といった予測モデリングのニーズに応えることを想定した手法ではありません*4

以上を考慮し、ノンパラメトリックな手法については検定に関する事項などまで深くは立ち入らず、推定量の存在とその定義について簡単に紹介するに留めたいと思います。

ノンパラメトリック推定

本章では、生存関数 $S(t)$ の推定量である Kaplan-Meier 推定量と累積ハザード関数 $H(t)$ の推定量である Nelson-Aalen 推定量を紹介します。

以下では、 [Klein and Moeschberger] で紹介されている、急性白血病臨床試験からの寛解持続時間に関するデータセット [Friereich et al.] のうち、抗癌剤を投与された患者に関する生存時間データ 6-MP を用いて関数の推定量の描画の例を示します。 6-MPはデータ数21件の小さなデータセットで、生存時間を示す time と、打ち切りか否かを示す censored からなります:

id time censored
1 10 False
2 7 False
3 32 True
4 23 False
5 22 False
6 6 False
7 16 False
8 34 True
9 32 True
10 25 True
11 11 True
12 20 True
13 19 True
14 6 False
15 17 True
16 35 True
17 6 False
18 13 False
19 9 True
20 6 True
21 10 True

定量を定義するに際して必要な変数を定義します。 打ち切られたものを含まない死亡のイベントが $D$ 個の異なる時間 $t_1 < t_2 < \dots < t_D$ で発生し、各時間 $t_i$ では $d_i$ 人が死亡しているとします。 また、時間 $t_i$ において死亡する可能性がある対象者数を $Y_i$ します。 $Y_i$ は、時間が経つに従って対象者が死亡したり観測が打ち切られたりすることで減少していきます。

Kaplan-Meier 推定量

Kaplan と Meier によって提案された生存関数 $S(t)$ の推定量 $\hat{S}(t)$ は Kaplan-Meier 推定量*5と呼ばれ、以下のように定義されます: \begin{align*} \hat{S}(t) = \left\{ \begin{array}{ll} 1 \displaystyle\frac{\mbox{}}{\mbox{}}& (\mathrm{if} \quad t < t_1) \\ \displaystyle\prod_{i = t_1}^t \left(1 - \frac{d_{i}}{Y_{i}} \right) & (\mathrm{if} \quad t_1 \le t)\end{array} \right. \end{align*}

6-MPに対する Kaplan-Meier 推定量をプロットすると以下のようになります: f:id:scouty:20180411132818p:plain

変数及び Kaplan-Meier 推定量の定義から明らかですが、Kaplan-Meier 推定量は 1 からスタートし、誰かの死亡が観測された時点で値が小さくなるような階段関数です。

前記事でも述べた通り、生存関数は $S(t)$ は「ある時間 $t$ 以上生存する確率」を表すものです。「誰がいつ死亡した」というデータから生存関数を推定する場合、死亡が観測された時間ごとに生存関数の値が小さくなっていくように推定された関数の形を決めるのは直観的に自然であると言えると思います。さらにそれに加えて、打ち切りが発生した場合には推定生存関数の値を小さくしないような力を加えたものが Kaplan-Meier 推定量であるというように理解することができます。

Nelson-Aalen 推定量

Nelson によって最初に提案され、その後 Aalen によって再発見された累積ハザード関数 $H(t)$ の推定量 $\tilde{H}(t)$ Nelson-Aalen 推定量と呼ばれ、以下のように定義されます:

\begin{align*} \tilde{H}(t) = \left\{ \begin{array}{ll} 0 \displaystyle\frac{\mbox{}}{\mbox{}}& (\mathrm{if} \quad t < t_1) \\ \displaystyle\sum_{i = t_1}^t \frac{d_{i}}{Y_{i}} & (\mathrm{if} \quad t_1 \le t)\end{array} \right. \end{align*}

6-MPに対する Nelson-Aalen 推定量をプロットすると以下のようになります: f:id:scouty:20180411132812p:plain

Kaplan-Meier 推定量と Nelson-Aalen 推定量で値が変化するタイミングは同じです。Nelson-Aalen 推定量は 0 からスタートして、誰かの死亡が観測された時点で値が大きくなるような階段関数です。

累積ハザード関数 $H(t)$ は「ある時点 $t$ 直前まで死亡していない上で $t$ に死亡する確率」である $h(t)$ を累積したものですから、生存関数とは逆に、死亡が観測された時間ごとに値が大きくなるような関数形になっているのもまた自然であると思います。

定量の変換

前記事で述べた基本量の関係とも関連しますが、生存関数 $S(t)$, ハザード関数 $h(t)$, 確率密度関数 $f(t)$, 累積ハザード関数 $H(t)$ のような生存時間解析における基本量は、どれか一つが求まるとその他の関数も一意に求めることができます。 特に生存関数 $S(t)$ と累積ハザード関数 $H(t)$ の関係は以下のように表されます:

\begin{align*} H(t) &= -\ln\left[ S(t) \right] \\ S(t) &= \exp\left[ -H(t) \right] \end{align*}

これを Kaplan-Meier 推定量 $\hat{S}(t)$ あるいは Nelson-Aalen 推定量 $\tilde{H}(t)$ に関しても同様に適用すると、以下のようになります: \begin{align*} \hat{H}(t) &= -\ln\left[ \hat{S}(t) \right] \\ \tilde{S}(t) &= \exp\left[ -\tilde{H}(t) \right] \end{align*}

つまり、

  • Kaplan-Meier 推定量から導かれた累積ハザード関数 $\hat{H}(t)$
  • Nelson-Aalen 推定量から導かれた生存関数 $\tilde{S}(t)$

が求まります*6

6-MPに対して、 Kaplan-Meier 推定量による推定生存関数 $\hat{S}(t)$ と Nelson-Aalen 推定量による推定生存関数 $\tilde{S}(t)$ を同時にプロットすると以下のようになります: f:id:scouty:20180411132814p:plain

また、Kaplan-Meier 推定量による推定累積ハザード関数 $\hat{H}(t)$ と Nelson-Aalen 推定量による推定累積ハザード関数 $\tilde{H}(t)$ を同時にプロットすると以下のようになります: f:id:scouty:20180411132801p:plain

これらより、死亡のイベントが発生した時間に関数値が変化するという性質が常に保たれていることが確認できます。 異なる推定式から導かれた関数なので、当然関数値は異なります。Nelson-Aalen 推定量は Kaplan-Meier 推定量と比較して「小さいサイズの標本に対してよりよい性質をもつ」 [Klein and Moeschberger] とのことですが、ここでは詳細には触れないこととします。

おわりに

本記事では、生存時間解析におけるノンパラメトリックな推定量のうち代表的なものである Kaplan-Meier 推定量と Nelson-Aalen 推定量について非常に簡単に紹介しました。 最初に述べた通り、ノンパラメトリック推定は予測モデリングのニーズに応えることを想定した手法ではないため、この手法を直接転職可能性予測に組み込んでいるわけではありません。 参考までに、セミパラメトリックなモデルの一つであるCox比例ハザードモデルはノンパラメトリックなモデルとフルパラメトリックなモデルのハイブリッドともいえるものですが、そのノンパラメトリックな部分に関しては Kaplan-Meier 推定量の考え方が応用されています。 そのため、本記事をご覧いただいたことによってセミパラメトリックなモデルに関する理解を早めることができると思います。 今後はフルパラメトリックなモデル、セミパラメトリックなモデルに関する記事を公開する予定です。

参考文献

  • [Friereich et al.] Freireich, Emil J., et al. "The effect of 6-mercaptopurine on the duration of steroid-induced remissions in acute leukemia: A model for evaluation of other potentially useful therapy." Blood 21.6 (1963): 699-716.
  • [Klein and Moeschberger] Klein, John P., and Melvin L. Moeschberger. Survival analysis: techniques for censored and truncated data. Springer Science & Business Media, 2005.

脚注

*1:例えば、薬効に関する分析であれば各患者、転職時期予測に関する分析であれば各候補者を指すようなものです。

*2:例えば年齢や性別といった、個人ごとに異なる情報のことです。

*3:もちろん、そうでない文献も存在します。

*4:特に、実世界の問題に対する予測モデリングのために生存時間解析の手法を応用する場合は、セミパラメトリックなモデル、あるいはフルパラメトリックなモデルを利用するのが適切でしょう。

*5:「積-極限推定量」とも。

*6:\hat{} で Kaplan-Meier 推定量を、 \tilde{} で Nelson-Aalen 推定量をそれぞれ表しています。

生存時間解析について - 概要編

scoutyの高濱です。本記事では、インターン田村くんと協力してscoutyでの転職可能性予測ロジックに組み込んだ生存時間解析に関する基礎的な事項の説明を行います。 転職可能性予測は、こちらの記事の通り、候補者の現在の転職の可能性を推定して提示し、スカウトメールを送るか否かの判断を助けます。 生存時間解析は、予測ロジックのコンポーネントのひとつとして、経歴などの情報から候補者が現職を退職する時期(月単位)を予測するために利用されています。

生存時間解析とは

生存時間解析とは、生物の死亡や機械系の故障など、一つまたは複数の事象(イベント)が起こるまでの予想される時間を分析するための一般化線形モデルの一分野です。分野によって様々な呼び方があり、例えば工学では信頼性分析 (reliability analysis)、経済学では期間分析 (duration analysis)、社会学ではイベント分析 (event history analysis) などと呼ばれます。

具体的には、各分野で以下のような時間に関する分析を与える際に生存時間解析が用いられます。

  • 労働経済: 就職してから離職に至るまでの時間。
  • 機械工学: 機械が使われ始めてから正常に動作しなくなるまでの時間。
  • 医療: ある疾患と診断されてから死亡するまでの時間。
  • Webサービス: 会員登録から離脱までの時間。

このうち、最初に挙げた「就職してから離職に至るまでの時間」は、scoutyで転職可能性予測を行う際にまさに分析の対象としたいものに該当します。本記事では、生存時間解析の基礎に関する解説を行います。

関数名とその定義

最初に、生存関数 $S(t)$, 死亡関数 $F(t)$, ハザード関数 $h(t)$ をそれぞれ紹介します。これらの、時間 $t$ に関する関数は生存時間解析のモデルを決定づけるもので、このうち一つでも定まればその他の関数を比較的容易に導出することができます。そのため、生存時間をモデリングする場合は通常、生存(死亡)関数もしくはハザード関数を仮定したのち、パラメータ推定に必要なその他の関数を計算するという手順を踏みます。

以下では、 $T$ は生存時間の確率変数、 $t > 0$ は生存時間の実現値、 $f$ は生存時間の確率密度関数を表します。

ハザード関数 (hazard function)

ハザード関数とは、時間 $t$ までは生存しているという条件のもとで、$t$ と $t + \delta t$ の間に死亡する確率を表す関数で、十分小さい値 $\Delta t$ を用いて以下のように定義されます。 \begin{align*} h(t) = \lim _{\Delta \rightarrow 0} \frac{\mathrm{Pr}( {t \le T < t + \Delta t} \ | \ { T \ge t }) }{\Delta t} \end{align*}

一例として、よく使われるパラメトリックモデルであるワイブル回帰モデルでのハザード関数( $\lambda = 1.5$ として、 $\theta$ を動かしたもの)は以下のようになります。

f:id:scouty:20180223181559p:plain
図1: ワイブル分布 (λ = 1.5) のハザード関数

$h(t)$ はパラメータ推定などの式変形には現れませんが、モデリングする際のドメイン知識の埋め込みに有効である場合があります。例えば、ある機械はその使用年月に応じて疲労が蓄積するため、故障の確率は徐々に高まっていくことが知られていたとします。この機械に対する生存時間分析を行う際には、単調増加なハザード関数を仮定することで故障に関する知見をモデリングに反映することができます。

生存関数 (survival function)

ある時間 $t$ 以上生存する確率を表す生存関数 $S(t)$ は以下のように定義されます。 \begin{align*} S(t) = \mathrm{Pr}( T \ge t) = \int_t^\infty f(u)\mathrm{d}u \end{align*}

$f(t) \ge 0$ が常に成り立つことから、 $S(t)$ は単調減少な関数です。 一例として、よく使われるパラメトリックモデルであるワイブル回帰モデルでの生存関数( $\lambda = 1.5$ として、 $\theta$ を動かしたもの)は以下のようになります。

f:id:scouty:20180223181607p:plain
図2: ワイブル分布 (λ = 1.5) の生存関数
図2は図1に対応しています。先述のハザード関数がある時点での死亡率に該当する値をもつことを考慮すると、 $\theta$ が大きくなればなるほど死亡しやすさが大きくなると解釈できるので、生存関数は早い段階で0に近づきます。

$F(t) = 1 - S(t)$ によって定義される $F(t)$ を死亡関数と呼びます。 $t: 0 \rightarrow \infty$ のとき $S(t): 1 \rightarrow 0$ ですから、 $F(t)$ は $t: 0 \rightarrow \infty$ のとき $F(t): 0 \rightarrow 1$ となる単調増加な関数で、 $y = 0.5$ を軸にして $S(t)$ を上下反転させた形になります。

関数間の関係

先述した通り、生存関数 $S(t)$, 死亡関数 $F(t)$, ハザード関数 $h(t)$, 確率密度関数 $f(t)$ は、このうち一つが決定されるとその他の関数を容易に導出することができます。これらの関数に関してよく用いられる関係式を記載します。 \begin{align*} S(t) = \int_t^\infty f(u)\mathrm{d}u, \ \ f(t) = \frac{\mathrm{d}}{\mathrm{d}t}S(t) \end{align*} \begin{align*} f(t) = h(t) \times S(t), \ \ h(t) = \frac{f(t)} {S(t)} \end{align*} \begin{align*} F(t) = 1 - S(t) \end{align*}

データの形式について

打切り (censoring)

生存時間解析では、しばしば打切りが発生したデータを扱います。例えば、がん患者の生存時間を調査するとき、実際にデータ収集を行う場合を考えると、現実的には調査対象の患者の状態を永久に追跡し続けることは不可能であるため、調査期間が限定されてしまう状況を想定するのが自然です。そのため、収集されたデータには、調査期間中に死亡する患者に関するデータと、調査期間が終了するまでに死亡しなかった患者に関するデータが共に存在することになります。

もし観察対象(患者、従業員、機械など)の観察期間中にイベント(死亡、退職、故障など)が発生しなかった場合、打切りデータとして記録されます。打切りデータには、以下のような種類があります。

  • 右側打切り (right censored): 観察期間終了までにイベントが発生しなかった場合。
  • 左側打切り (left censored): 観察期間開始前にイベントが発生していた場合。
  • 中間打切り (interval censored): 2つの観察期間で挟まれた非観察期間にイベントが発生した場合。

特によく扱われるのは右側打切りデータであるため、本記事でも特に断りがない場合は打切りデータとして右側打切りのみを扱うこととします。

一般的なデータフォーマット

生存時間解析では以下のようなデータフォーマットを扱うことが多いです。

id 生存時間 (t) 打切り (censored)
1243 4 False
2389 20 True
1289 6 False
4389 12 False

上表に示したデータを用いて、ある群の死亡に関する傾向を発見するのが最も基本的な生存時間解析です。また、これに加えて各 id の特徴量(例えば、ある人の年齢・身長・体重・病歴*1など)を考慮した解析を行うように拡張することもできます。

パラメータ推定方法の概要

生存時間解析では、基本的に生存関数 $S(t)$ や確率密度関数 $f(t)$ に含まれているパラメータを最尤推定によって求めることで学習を行います。 $S(t)$ や $f(t)$ に仮定する関数形によって推定の手順が変化しますが、本節ではどのモデルの最尤推定でも共通して必要な(対数)尤度関数の定義について述べます。

一般的な最尤推定と異なり、生存時間解析では用いるデータに打切りが存在することを容認するため、尤度関数が多少複雑になります。先に述べた右側打切りを考慮した尤度関数の定義は以下のようになります。

\begin{align*} L(\theta) = \prod_{ i \ \in \ \mathrm{unc.} } \mathrm{Pr}( T = t_{i} ; \theta ) \times \prod_{i \ \in \ \mathrm{csd.}}\mathrm{Pr}( T > t_{i} ; \theta ) \end{align*}

ここで、式中の各変数が表す内容は以下のようになります。

  • $L( \theta )$: 尤度関数。最尤推定によって求めたい全てのパラメータを便宜的に $\theta$ で表現しています。
  • $\mathrm{unc.}$: 非打切り (uncensored) データからなる集合。
  • $\mathrm{csd.}$: 打切り (censored) データからなる集合。

さて、非打切りデータに対して用いる $\mathrm{Pr}( T = t_{i} ; \theta )$ の意味を考えると、打切りが発生しなかったため、尤度として「 $i$ がちょうど時間 $t_i$ で死亡する確率」を用いるのが適切なので、以下のように密度関数 $f(t)$ を用いて表現されます*2

\begin{align*} \mathrm{Pr} ( T = t_{i} ; \theta ) = f ( t_{i} ; \theta ) \end{align*}

次に、打切りデータに対して用いる $\mathrm{Pr}( T > t_{i} ; \theta )$ の意味を考えると、 $t$ の時点で打切りが発生したため、尤度として「 $i$ が少なくとも時点 $t_i$ まで生存する確率」を用いるのが適切なので、以下のように生存関数 $S(t)$ を用いて表現されます。

\begin{align*} \mathrm{Pr} ( T > t_{i} ; \theta ) = S( t_{i} ; \theta ) \end{align*}

では、 $f(t)$, $S(t)$ を用いて尤度関数を書き直します。さらに、サンプル $i$ が打切りデータである場合は $c_i = 0$, 非打切りデータである場合は $c_i = 1$ となるような変数 $c_i$ を用いて、尤度関数 $L(\theta)$ を以下のように簡略に表すことができます*3

\begin{align*} L(\theta) = \prod_{i = 1}^{N} \left\{ f ( t_{i} ; \theta )^{c_i} \times S( t_{i} ; \theta )^{1 - c_i} \right\} \end{align*}

最後に、最尤推定ではしばしば尤度関数の最大化のために対数尤度関数を扱うので、上式に対応する対数尤度関数もここで計算しておきます。

\begin{align*} \ln L(\theta) = \sum_{i = 1}^{N} \left\{c_i \ln f ( t_{i} ; \theta ) + (1 - c_i) \ln S( t_{i} ; \theta ) \right\} \end{align*}

実際のパラメータ推定では、対数尤度関数 $\ln L(\theta)$ の $\theta$ に関する微分を計算することで $\ln L(\theta)$ を最大化する $\theta$ を求めます。

生存時間解析で利用されるモデルの分類

生存時間解析でよく利用されるモデルは大きく(完全)パラメトリックなモデルとセミパラメトリックなモデルに分かれます。

パラメトリックなモデル

パラメトリックなモデルはハザード関数の形を陽に仮定します。一例として、特によく用いられる指数回帰モデルワイブル回帰モデルについては、ハザード関数 $h(t)$, 生存関数 $S(t)$, 確率密度関数 $f(t)$ は以下のように定義されます。

モデル ハザード関数 $h(t)$ 生存関数 $S(t)$ 確率密度関数 $f(t)$
指数回帰モデル $\theta$ $e^{-\theta t}$ $\theta e^{-\theta t}$
ワイブル回帰モデル $\theta \lambda (\theta t)^{\lambda - 1}$ $e^{-(\theta t)^\lambda}$ $\theta \lambda (\theta t)^{\lambda - 1} e^{-(\theta t)^\lambda}$

ここで、関数中の $\theta > 0$, $\lambda > 0$ は分布の形を決めるパラメータで、学習データを用いて上述した最尤推定を行うことで決定されます。なお、ワイブル回帰モデルは指数回帰モデルの一般化であり、ワイブル回帰モデルで $\lambda = 1$ とすれば各式は実際に指数回帰モデルの式と一致します。

これらの分布の特徴、導関数の導出、パラメータの推定、拡張の方法などについては別記事で詳細に解説します。

セミパラメトリックなモデル

セミパラメトリックなモデルでは、モデルを決定づける関数群は以下のように定義されます。

  • ハザード関数 $h(t) = h_0(t) r(\boldsymbol{x}, \boldsymbol{\beta}) \ \ (> 0)$
    • $h_0(t)$: 基準ハザード関数
    • $r(\boldsymbol{x}, \boldsymbol{\beta})$: $\boldsymbol{x}$ および $\boldsymbol{\beta}$ に依存する何らかの関数
      • $\boldsymbol{x}$: サンプルの特徴ベクトル
      • $\boldsymbol{\beta}$: 各特徴に対する重みベクトル
  • 累積ハザード関数 $H(t, \boldsymbol{x}, \boldsymbol{\beta}) = r(\boldsymbol{x}, \boldsymbol{\beta}) H_0(t)$
    • $H_0(t) = \int_0^t h_0(u) \mathrm{d}u$: 累積基準ハザード関数
  • 生存関数 $S(t) = \left[ S_0(t) \right]^{r(\boldsymbol{x}, \boldsymbol{\beta})}$
    • $S_0(t) = e^{-H_0(t)} $: 基準生存関数

パラメトリックなモデルとの本質的な違いとして、セミパラメトリックなモデルでは $h_0(t)$ の関数形を仮定しません。学習データに含まれるサンプルが「いつ死亡したか」「いつ観察が打ち切られたか」という情報から求められる生存関数の推定量*4のうち最もよく用いられる Kaplan-Meier 推定量の考え方をもとにして $h_0(t)$ を決定します。

ちなみに、サンプルの特徴 $\boldsymbol{x}$ を明示的に考慮することを前提としたモデルなので、その点も異なっているように見えますが、パラメトリックなモデルを $\boldsymbol{x}$ を考慮するように拡張することもできるので、この部分はパラメトリックなモデルに対する本質的なアドバンテージであるというわけではありません。

セミパラメトリックなモデルのうち、最もよく用いられるCox比例ハザードモデルでは $r(\boldsymbol{x}, \boldsymbol{\beta}) = \exp(\boldsymbol{\beta}^\top\boldsymbol{x})$ を利用します。

先に述べた対数尤度関数にCox比例ハザードモデルの生存関数と確率密度関数を代入して得られるものを完全尤度関数 (full likelihood function) と呼びますが、完全尤度関数の直接の最大化は不可能である [Kalbfleisch and Prentice] とのことで、実際には以下の処理を順に行うことでパラメータの推定を行います。

  1. 部分尤度関数 (partial likelihood function) の最大化によって $\boldsymbol{\beta}$ を推定
  2. 求まった $\boldsymbol{\beta}$ を使って基準生存関数 $S_0(t)$ に関する尤度を計算して最大化し、 $S_0(t)$ を推定
  3. 求まった $S_0(t)$ から基準ハザード関数 $h_0(t)$ を推定

Cox比例ハザードモデルに関するパラメータ推定の方法についても別記事で扱うこととします。

パラメトリック/セミパラメトリックの使い分けについて

モデリングを行う対象の死亡率が時間によってどのように変化するかについて非常に詳細に分かっている場合はパラメトリックなモデルが非常に効果的です。すでに知られている観察対象のふるまいをよく表現するハザード関数をもつようなモデルを選択することで、ドメイン知識を適切にモデルに取り込むことができます。 一方、ハザード関数に対しての事前情報に乏しい場合、複雑なパラメトリックモデルを使うべきではなく、セミパラメトリックなモデルを利用するのが無難であるといえます。

おわりに

本記事では生存時間解析の全体を俯瞰するために、用語の説明とモデルの簡単な説明を行いました。 重要な各モデルの詳細な解説は、今後に公開する記事中で行いますので、そちらをご参照ください。

参考文献

  • John D. Kalbfleisch and Ross L. Prentice, "The Statistical Analysis of Failure Time Data", 2nd Edition (2002).

  • David W. Hosmer, Stanley Lemeshow and Susanne May, "Applied Survival Analysis: Regression Modeling of Time-to-Event Data", 2nd Edition (2008).

脚注

*1:例えば、高血圧と診断されたことがあるか否かを表現する二値の特徴などが考えられます。

*2:$f(t)$ は密度関数なので、厳密には確率ではありませんが、 $f(t) = \lim_{\Delta t \rightarrow 0} (F(t + \Delta t) - F(t) ) / \Delta t$ であるとしてこれが用いられるようです。

*3:この表現が業界では一般的なようですが、 $c$ が censored の略であることを考えると、 $\mathrm{censored} = 1 (= \mathrm{True})$ を打切りデータに対応させるべきでは?(論理値逆転してないか?)という気持ち悪さがあります...。

*4:「生存関数の推定量」という表現は微妙ですが、ある期間 $t_i$ (たとえば 20ヶ月目から22ヶ月目 というような範囲)で一定となるような値を全て推定するようなイメージです。

Poincaré Embeddings による職種の類似度計算とその利用

scouty アルゴリズムエンジニアの高濱です。外部への情報発信はこの記事が最初なのでこの場を借りて自己紹介させていただきますが、私は scouty 代表の島田、リードエンジニアの伊藤京都大学工学部情報学科での同期で、京都大学大学院情報学研究科鹿島研究室で修士課程を修了した後、株式会社リクルートホールディングスを経て scouty に入社しました。代表的な著作物は [Takahama et al., 2018]*1, [Takahama et al., 2016]*2, [Takahama et al., 2014]*3 などです。よろしくお願いします。

さて、本記事では、 Poincaré Embeddings*4 を用いた職種の関係の埋め込みに関してご紹介します。第一段階として、サービス適用の実現性を測る目的での実験を行ったので、その実験の詳細と実験結果について記したいと思います。

なお、本記事で紹介する内容は scouty インターン山田くんに実装・実験してもらった結果をまとめたものです。

背景

職種間の関連度

scouty では、企業と候補者のマッチングや退職率予測を行うにあたり、候補者の職種情報を利用しています。 2つの職種(を表現する文字列)を考えたとき、職種間の関係には、 エンジニアEngineer のように等価なものと、 エンジニアインフラエンジニア のように包含関係(あるいは階層構造)にあるものなどがあり、これらをうまく扱える何らかの表現を利用したいと考えていました。

そこで今回、包含関係のデータから、適切な埋め込みを行い、例えば エンジニアインフラエンジニア意味的な距離を求める目的で、 Poincaré Embeddings の応用を検討しました。 乱暴に言うと、 Poincaré Embeddings を用いることで、職種間に定義された包含関係のみから、任意の職種間の意味的な距離を計算できるようになることを期待して検証を行います。

Poincaré Embeddings とその実装

Poincaré Embeddings は端的に言うと word2vec の埋め込み先をユークリッド空間ではなく双曲空間にするという手法で、階層構造やべき分布をもつデータを埋め込むという問題設定において、低次元でもよい表現を与えられるという特徴があります。また、実装が容易で、内積の計算式にちょっとした補正を加えることで利用することができます。 Poincaré Embeddings は2017年中頃に arXiv に投稿され、以上のような素晴らしい性質から非常に話題になりました。各所で参照されているので今更感はありますが、以下の記事でわかりやすく解説されています。

ABEJA Tech Blog のNIPS2017参加報告にもある通り、NIPS2017にも Spotlight で採択されました*5。私見ですが、 Poincaré Embeddings に限らず、非ユークリッド幾何学を応用して機械学習手法を再考する試みは近年のトレンドのようで、NIPS2017でこれに関連する分野のチュートリアルも行われています*6

FAIR による公式の PyTorch での実装*7も公開されましたが、今回は Gensim を用いて実験を行いました。なお、Gensim の API, Exemple に関するドキュメントは以下になります。

Poincaré Embeddings のAPI

Gensim で提供されている Poincaré Embeddings のAPIは以下のようになっています。

class gensim.models.poincare.PoincareModel(train_data,
                                           size=50,
                                           alpha=0.1,
                                           negative=10,
                                           workers=1,
                                           epsilon=1e-05,
                                           regularization_coeff=1.0,
                                           burn_in=10,
                                           burn_in_alpha=0.01,
                                           init_range=(-0.001, 0.001),
                                           dtype=<type 'numpy.float64'>,
                                           seed=0)

これら引数のうち、特に本記事と関連が深いものの定義を以下に挙げます。

  • train_data (iterable of (str, str)) – Iterable of relations, e.g. a list of tuples, or a PoincareRelations instance streaming from a file. Note that the relations are treated as ordered pairs, i.e. a relation (a, b) does not imply the opposite relation (b, a). In case the relations are symmetric, the data should contain both relations (a, b) and (b, a).
  • size (int, optional) – Number of dimensions of the trained model.
  • negative (int, optional) – Number of negative samples to use.

データセットの作成

本記事の実験で用意した、 gensim.models.poincare.PoincareModel に与える学習データ train_data の構造について述べます。 train_data は、関係を表すタプル (a, b)iterable です。 オブジェクト a がオブジェクト b に含まれる場合( ba の意味的な下位に該当する場合)に (a, b) というタプルで ab の関係を表現します。

scouty のデータベースにある職種のデータのうち、特に使用頻度の高いものを選び、関係がある職種のペアのリストを作成します。 「 ab に含まれる」という関係をもつ職種のペアには、例えば以下のようなものがあります。

  • Software EngineerEngineer に含まれる
  • Senior Software EngineerSoftware Engineer に含まれる
  • Web ProgrammerProgrammer に含まれる
  • UI DesignerDesigner に含まれる

また、 エンジニアEngineer のように、「 ab は等価」という関係をもつ職種のペアもあります。こういったペアの場合は、

  • エンジニアEngineer に含まれる
  • Engineerエンジニア に含まれる

という2つの包含関係のタプルに分解します。

上記の関係は、 Python では以下のように表現されることになります。

train_data = [
    ('Software Engineer', 'Engineer'),
    ('Senior Software Engineer', 'Software Engineer'),
    ('Web Programmer', 'Programmer'),
    ('UI Designer', 'Designer'),
    ('エンジニア', 'Engineer'),
    ('Engineer', 'エンジニア'),
]

実験

以下のコードは、 Jupyter Notebook 上で実行を行うためのものです。

最初に、必要なライブラリの import を行います。 今回はCSVファイルを読み込むために pandas を、可視化のために plotly を使用します。

from gensim.models.poincare import PoincareModel
from gensim.viz.poincare import poincare_2d_visualization
from IPython import display
from plotly.offline import init_notebook_mode, iplot
import pandas as pd

init_notebook_mode(connected=True)

続いて、データセットのロードと整形を行います。ここでは、前節に記したような訓練データが path/to/occupation_relations.csv に配置されているとします。 pandas でCSVファイルを読み込み、1つ目の要素が下位概念、2つ目の要素が上位概念になるようなタプルのリスト occupation_relations_list を定義します。

occupation_relations_list = [(a, b) for a, b in pd.read_csv('path/to/occupation_relations.csv', header=None).as_matrix()]

次に、 occupation_relations_list をモデルに与えて学習を行います。 PoincareModel の引数では、 size で分散表現の次元数、 negative で negative sampling の際に使用するデータ数をそれぞれ指定します。 本記事では、可視化を可能にするために二次元空間に埋め込みを行います (size=2)。また negative=8 とします。 学習は、メモリ16GBの2.3GHzデュアルコアIntelCore i5 プロセッサのCPUを搭載したMacBook Pro を用いたとき、500件程度の relation を含むサイズのデータセットに対して1分足らずで終了しました。

model = PoincareModel(occupation_relations_list, size=2, negative=8)
model.train(epochs=500)

最後に、学習の結果として実際にどのような埋め込みが行われるかを、 Gensim の poincare_2d_visualization を利用することで確認します。可視化は plotly を用いて行います。 poincare_2d_visualizationAPIは以下のようになっています。

gensim.viz.poincare.poincare_2d_visualization(model, tree, figure_title, num_nodes=50, show_node_labels=())

各引数で特に説明が必要なものは以下の通りです:

  • model (PoincareModel): 学習された model をそのまま渡します。
  • tree (set): 学習データで張られている relation の情報を渡します。(なぜか) set 型のみしか受け付けないので、本記事で言うところの occupation_relation_listset 型に変換します。
  • num_nodes (int or None): tree で渡した relation のうち、いくつのエッジを可視化するかを指定します。例えば num_node=10 とすると、エッジが10本のみ表示されます。 num_node=None とすると、全てのリレーションを可視化できます。
  • show_node_labels (iterable): 学習に用いたラベルのうち、可視化するもののリストを渡します。本記事では、500件以上の職種をデータセットとして与えたので、その中の代表的な職種のみを可視化するために、可視化するラベルのリスト major_occupation_list を指定します。
relations_set = set(occupation_relations_list)
# 代表的な職種のみをラベルとして可視化する
major_occupation_list = ['Engineer', 'Designer', 'Web', '取締役', 'Freelance', 'Intern', 'Programmer', 'Founder', 'Director', 'Marketer', 'Software Engineer']
figure_title = ''
iplot(poincare_2d_visualization(model, relations_set, figure_title, num_nodes=None, show_node_labels=major_occupation_list))

このコードを実行することによって、2次元の embedding を行った結果が以下のように可視化されます。

Poincaré Embeddings では、原点に近づけば近づくほど抽象的な概念であるといえますが、 Engineer取締役 といった、他の様々な職種を包含するような職種は比較的中心に近いことが確認できます。 また CEO のような特定の職種は他の職種をあまり包含しないので、中心から遠いところにプロットされており、これも妥当な結果であると言えると思います。 一方、(空間が歪んでいるので正しいか怪しいところですが) Software EngineerEngineer よりも Software Engineer取締役 の方が近く見えたりするなど、それぞれの職種の位置関係に関しては直感的に不自然な部分も見当たるため、もう少しデータの作り方や学習のチューニングを改善する余地が残されていると言えます。

おわりに

今回の記事では、 Poincaré Embeddings を利用することで、職種の類似度を計算するための(主観的に、概ね)尤もらしい embedding が得られることを確認しました。

現在、職種の類似度計算の結果を scouty のサービスへ適用する準備を進めています。具体的には、

  • 職種ごとに同様の退職傾向が見られることが確認されているため、等価あるいは近い職種に就いている人をうまく丸めることで、退職率予測の精度を高める
  • scouty を利用して候補者の職種を考慮した検索を行うとき、「ある職種にある程度近い職種の候補者を探す」といったような検索を行えるようにする

といった応用を検討しています。

また、 RubyRails, セールス法人営業 といった単語に関するデータセットを作成することで、職種ではなくスキルに関する embedding を作成することもできます。 これも、検索の際に「 Rails にある程度近いスキルをもった候補者」といった条件を指定できるようにするなどの応用が可能であり、今後取り組んでいきたいと考えています。

参考文献・脚注

*1:Ryusuke Takahama, Yukino Baba, Nobuyuki Shimizu, Sumio Fujita, Hisashi Kashima, "AdaFlock: Adaptive Feature Discovery for Human-in-the-loop Predictive Modeling", The 23nd AAAI Conference on Artificial Intelligence (AAAI-18)

*2:Ryusuke Takahama, Toshihiro Kamishima, Hisashi Kashima, "Progressive Comparison for Ranking Estimation", https://www.ijcai.org/Proceedings/16/Papers/546.pdf, The 25th International Joint Conference on Artificial Intelligence (IJCAI-16)

*3:Ryusuke Takahama, Naoki Otani, Sho Yokoi, Tomohiro Arai, Nozomi Nori, Norie Ugai, Koji Nakazawa, Hisashi Kashima, "私たちはお土産にどの八ッ橋を買えばよいのか", http://www.ml.ist.i.kyoto-u.ac.jp/wp/wp-content/uploads/2014/10/yatsuhashi.pdf

*4:Maximillian Nickel and Douwe Kiela, “Poincaré Embeddings for Learning Hierarchical Representations”, https://arxiv.org/abs/1705.08039

*5:Maximillian Nickel and Douwe Kiela, “Poincaré Embeddings for Learning Hierarchical Representations”, https://nips.cc/Conferences/2017/Schedule?showEvent=10034

*6:Michael Bronstein, Joan Bruna, Arthur Szlam, Xavier Bresson and Yann LeCun, “Geometric Deep Learning on Graphs and Manifolds”, https://nips.cc/Conferences/2017/Schedule?showEvent=8735

*7:https://github.com/facebookresearch/poincare-embeddings

AIビジネスのレシピ:AIビジネスを成功に導くアルゴリズムマネージャ

scouty代表の島田です。 今回の記事は、2017年9月21日に scouty ✕ talentio ✕ eureka で行われた下記の 「Ventures Engineer MeetUp #02 - AI & Server Side」というイベントで発表した「アルゴリズムマネジメント&デザイン 〜機械学習の技術選定とマネジメントについて〜」という発表を記事化したものです。(当日は、自分が突如入院してしまったのでリードエンジニアの showwinに代理登壇してもらいました) eure.connpass.com

今回は、scoutyで実際にAIを用いたアルゴリズム開発を例に、「AIビジネスを成功に導くにはどうすればよいか」という問の下、アルゴリズムのプロダクトマーケットフィットの確認のほか、事前リサーチやアルゴリズムの仕様決定や技術選定を行なう「アルゴリズムマネージャ」という役割とその必要性について説明します。 画像は、主にイベントのスライドから抜粋したものになります。

本記事でのAIの定義

本記事でいうAIおよび人工知能という言葉は、人工知能分野に含まれる研究や技術のことを指します。次の図(ML, DM, and AI Conference Map より引用)における Artificial Intelligence をはじめ、 Artificial Intelligence の領域と接している Computer VisionNLP、Data Mining、Machine Learning などの分野を含む広い領域がこれに該当します。

f:id:scouty:20171204163755p:plain:w580

AIビジネスというのは、事業のコア課題の解決の手法として上記でいうAIを利用しているビジネスのことを指します。

何故アルゴリズムマネージャが必要か

まず、本題に入る前に伝えたいことは、AI分野の研究とAIビジネスは全く別物だということです。 AI分野の中でも特に機械学習分野の研究では、一般に新しい手法を使って既存手法よりも高い数値を叩き出すことが大きな目標となっています。 そのためエンジニアとしては、少し精度を改善しただけで悦に浸りたくなります(実際、こういった瞬間はとても嬉しいものです):

f:id:scouty:20171202231943p:plain:w280

しかし、ビジネスの世界では精度の向上が必ずしも事業上のKGIの向上に繋がるとは限りません。

また、実際、上のようなタイプのエンジニアよりも、現実として見かけるのは次のようなエンジニアです:

f:id:scouty:20171202231957p:plain:w280

自分もエンジニアなのでよくわかりますが、エンジニアはHOWにこだわる生き物です。問題をどうスマートに、どうエレガントに(そして時には派手な手法を使って)解決することができるかに関心を払います。つまり、手法にこだわります。しかし、研究(あるいは趣味)と ビジネスの最大の違いは、顧客がいることです。 ルールベース(人工無能)だろうが人工知能だろうが、顧客の課題を解決できなければビジネスでは意味が無いのです。

つまり、エンジニアに丸投げしても、良いAIビジネスは生まれません。おそらく、この業界で蔓延しているひとつの大きな誤解は、「優秀な機械学習エンジニアがいればAIビジネスはうまくいく」という考え方でしょう。

それでは、ビジネスを一番良くわかっているCEOに聞いてみましょう:

f:id:scouty:20171202232007p:plain:w280

やはりダメでした。以前より「何でもかんでもディープラーニングを使えばいいわけじゃない」という考え方は浸透してきたものの、依然としてこれに近いことを言っている経営者やマネージャは多いものです。つまり、良いAIビジネスを作るためには、技術とビジネスの間に立つ人が必要なのです。

では、正しい問いはどのようなものになるのでしょうか。おそらく、それは次のようなものです:

f:id:scouty:20171202233030p:plain:w310

つまり、アルゴリズムがプロダクトマーケットフィットしているか(つまり、顧客が求めているか)を検証することがもっとも重要なのです。これを行なう人を、scoutyではアルゴリズムマネージャと呼んでいます。アルゴリズムマネージャはいわばアルゴリズムのプロダクトマネージャです。一般的なCTOの仕事とは守備範囲が異なりますが、多くの組織ではCTOが似たような役割を兼任していることが多いでしょう。 「顧客が必要としているの?」と問うこと自体はたとえ非技術者でもセンスがあればCEOでもできることですが、アルゴリズムを実際に顧客が必要とするものにすることは、技術的知識が無いとできません。次のセクションで、実際にアルゴリズムマネージャどんな仕事をするのか見てみましょう。

アルゴリズムマネージャの仕事

scoutyは、2017年11月8日に転職可能性予測アルゴリズムのβ版を発表しました。(下記参照)

scouty.co.jp

この転職可能性予測アルゴリズムの開発を例にとり、アルゴリズムマネージャの実際の仕事を次の各項目に整理しました。

  • システムダイナミクスマップの構築
  • メトリクス定義と評価
  • ロードマップ構築
  • 事前リサーチ
  • 技術選定
  • オペレーションの設計

各項目で用いられている図は、説明用に作成されたものなので、実際のアルゴリズム開発で用いられたものとは異なります。

システムダイナミクスマップの構築

システムダイナミクスマップは、「アルゴリズムの精度が1%上がると、ビジネスKGIがどのくらい伸びるのか?」ということを規定します。図中では、ビジネスKGIとKPI、アルゴリズムのメトリクスを抜き出し、正の因果関係がある要素ごとに矢印を張り、それぞれの強度を矢印の太さで表しています。scoutyの転職可能性予測(退職率予測)アルゴリズムでの例は次のようになります。

f:id:scouty:20171203014332p:plain:w680

このマップを描く際のコツは、マップを右(ゴール)から描くということです。売上を最大化しようとした結果実は機械学習がいらなかった、ということもあるので、ビジネス成果にこだわるのであればソリューションドリブン(解決策・アルゴリズムから考える)なやり方よりも、課題やゴールから出発したほうが有効でしょう。 このマップを正確に持つために、アルゴリズムマネージャは日々顧客のところに行ったり、現場の人と喋ったり、開発以外のことをしなければいけません。

メトリクス定義と評価

アルゴリズムは、改善のPDCAを回すにあたって評価することが必須になります。そこで、評価の指標となるメトリクスを定義します。 メトリクスは一般にアルゴリズムの精度に近いものをとりますが、一口に精度といっても、Precision、Recall、F値、平均二乗誤差などいろいろな指標があるので、そのどれを使って評価を行なうかを定義します。次の画像は、scoutyの転職可能性予測(退職率予測)アルゴリズムにおける評価軸の一例です。

f:id:scouty:20171203014121p:plain:w680

その数値が上がると、必ずKGIが上がり、
かつ測定しやすいもの をメトリクスとして定めると良いでしょう。しっかりとマップを見ながらKGIに結びつくものを決定します。

ロードマップ構築

最初から完璧なversion1はありません。最初から完璧なものを作るよりも、どのようなマイルストーンをどういうスケジュールで達成していくかというロードマップを構築することが重要となります。 例えば、次のようなスケジューリングが考えられます:

f:id:scouty:20171205150722p:plain:w680

これは短期的なスケジュールの例ですが、次にどのような条件が揃えば次のバージョンがリリースできるのか等、長期的なロードマップを作ることも重要です。 scoutyでは、「ルールベースレベル」「ライブラリ適用レベル」「最先端技術適用レベル」のような形で各アルゴリズムをレベル分けし、それぞれがどのフェーズにいるかを確認できるシートを作っています。フェーズごとに求められるエンジニアや技術の種類が全く異なるので、アルゴリズムがどのフェーズにいるかを常に把握しておく必要があります。

ロードマップやスケジュールに精度を含めるかに関しては、議論の余地があります。精度を含めても、単純に機械学習系のタスクは「〜〜をすると◯%上がる」といった単純な世界ではないので、結局設定した精度に達しない(そのための施策もわからない)、ということになり無意味なスケジュールとなってしまいます。「ビジネス上◯%までいかないと使えないから、◯%までいかなければ開発を辞める / リリースをしない」といった判断は可能かもしれませんが、「△ヶ月後に◯%まで伸ばす」というスケジュールはナンセンスなものになりがちです。したがって、「いつまでに◯◯を行なう」という施策ベースでスケジュールを立てるほうが現実的かもしれません。

事前リサーチ

事前リサーチをすることで、アルゴリズム開発の後工程が数倍楽になります。 事前リサーチをする主な目的は、「仮説を事前検証することで実装工数を減らすこと」と言って良いでしょう。実際のアルゴリズムを実装しながら、「あれも効果ない、これも効果ない」などとやっていては、工数がかさみます。

scoutyは、Jupyter Notebookで雑でもいいので特徴量ごとに転職年数分布の有意差が出そうかを実験してまとめていました。実際にリサーチを行ってまとめる人はアルゴリズムマネージャでなくても構いません(scoutyでもリサーチ業務はインターンが担当)が、リサーチ項目を決めたり仮設を立てたりする部分はアルゴリズムマネージャが率先してやるべきでしょう。 転職可能性予測アルゴリズムでは、「使用プログラミング言語ごとに勤務年数分布が変わるんじゃないか?」という仮説が出たので、実際のアルゴリズムに組み込む前に雑に集計を採りプロットをしました。結果、グラフは以下のようになり(横軸が勤務年数、縦軸が割合)、大きな差が無いことが確認できたので、アルゴリズムに取り込むことは止めました。

f:id:scouty:20171203180935p:plain:w490
使用プログラミング言語ごとの勤務年数分布

逆に、分布に大きな差があるために積極利用していこうと確認できた属性は、「所属会社規模」で、下記のようにグラフごとに分布の差があることが確認できます。

f:id:scouty:20171204155339p:plain:w490
所属会社規模ごとの勤務年数分布

技術選定

技術選定自体は、実装するエンジニアとのディスカッションを通じて行われるべきですが、可能な選択肢のリサーチやそれぞれのメリット・デメリットの整理などはアルゴリズムマネージャが中心となってやるべきタスクとなります。 技術選定の基準は様々ですが、自分は以下の4つの基準を用いることが多いという印象です。(広告業界だと実行速度といった指標も入れるべきなので、ここはケースバイケースです)

f:id:scouty:20171203190227p:plain:w540

もちろん、手法インプットとアウトプットが解きたい課題にフィットしているということが前提となります。学習データ量によっても手法ごとにパフォーマンス違う(少なくてもそこそこの精度が出る手法も、一定精度出すまでにかなりのデータがいる手法もある)ので、そういった点も考慮して手法を選定できればベストでしょう。 ビジネスでは必要なデータが貯まるまでに時間もかかったりもするので、「いついつまでに◯◯という手法を使い、その後は△△を使った実装に着手する」といったロードマップと合わせて検討するのが良いでしょう。

オペレーションの設計

新しいアルゴリズムを作ると、それに伴い実際のビジネスのオペレーションが変化します。例えば何かを自動化するアルゴリズムだとしたら、人手が不要になるので、その後の人の采配や、自動化した部分をどう運用していくかを考えなければいけません。転職可能性予測アルゴリズムの例では、「ラベル付けされた転職可能性で絞り込む」作業が顧客に発生するので、カスタマーサポートチームにその工程のサポートを加える必要がありました。(せっかく作ったアルゴリズムが顧客に利用されなければ無意味なので) アルゴリズムの開発に常にオペレーションの変化が発生するとは限りませんが、発生する場合は、アルゴリズムマネージャが率先してビジネス側の担当者とオペレーションを設計し、運用していかなければなりません。 また、アルゴリズムの保守の部分もアルゴリズムマネージャが担わなければいけません。データが変遷するとアルゴリズムの中身も変えなければいけないので、運用の型作りも作らなければいけません。(このあたりは手法が確立されている領域ではないので、作っていきたいですね)

アルゴリズムマネージャに必要な能力

以上をふまえて、アルゴリズムマネージャにとって必要なスキルセットは以下のようなものとなるでしょう。

  • 各手法に関しての広い知識 (できれば実装経験があり、工数や必要なデータ量の肌感が掴めるのが望ましい)
  • AI / ML系のライブラリ・クラウドサービスなどの広い知識(実装しなくて済む部分と実装が必要な部分の理解)
  • CRISP-DM などのデータ分析のプロセスモデルへの理解
  • クライアント・ユーザとの対話能力
  • 経営や事業上のKPIへの理解
  • 実装するエンジニアのマネジメント能力

知識が無いから、実装が苦手だからマネージャになるわけではありません。アルゴリズムマネージャがミスをすると、ビジネス上結果につながらないことを工数をかけて実行するといったことが発生するなど経営的にも大きなロスになるので、むしろ一番経験がある人がやるべきでしょう。

scoutyでは、アルゴリズムマネージャを募集しています。

データサイエンティストという職業はここ数年で浸透しましたが、アルゴリズムマネージャないしそれに類する役割は未だ必要性や仕事内容が認知されていませんし、仕事内容も曖昧なままです。scoutyでは、アルゴリズムマネージャの仕事をこなすと同時に、他社にも活かせる一般的な型やフローを構築し、それを世の中に発信していける人材を探しています。 現在は代表の島田がこれにあたる役割を兼任していますが、今後アルゴリズムも増えるにあたって専任の方を募集していく予定ですので、興味のある方は、下記のリンクから応募をお願いします!

www.wantedly.com

また、今回のテーマ+α の内容は再度外部のイベントで発表予定です。イベントを共催してくれるの企業の方や、発表内容のリクエストなど、連絡やコメントお待ちしております!

RNNで言語モデルを作る - 実装編

代表の島田です。前回の記事 RNNで言語モデルを作るための理論 では、言語モデルを作るという目的で一般的なRNNの構造についての解説を行いました。それを踏まえて、今回の記事では Python で実際に言語モデルを実装し、その言語モデルを用いて自動で生成された文章の内容を確認してみます。 scoutyでもRNNは今後文生成や、スカウトメールの文面と返信率の相関性検証などに使っていこうと考えている技術です。

今はフルスクラッチで書かなくても、Chainerのような便利なライブラリがあるので、実践で使うならそちらの方が便利でしょう。 ただ、実際にフルスクラッチで実装することで、中身の原理を理解し、チューニングすることも自分でできるようになるので、今回はそういった趣旨でRNNの実装を行ってみます。

モジュール構成

一般的なNNと同様、以下のようなモジュール構成になります:

  • predict: Forward Propagation を行う関数
    • 入力 x: 文(単語 $x^{(t)}$  を並べた1次元配列)
      • e.g., "I have a" を表現する [0, 4, 2]
    • 出力 y: 予測値(単語 $y^{(t)}$ を並べた1次元配列)
      • e.g., "have a pen" を表現する [4, 2, 3]
    • 出力 s: 隠れ層
  • compute_mean_loss: 損失関数
    • 入力 X: 訓練データ(文 $\boldsymbol{x}$ を並べた2次元配列)
      • e.g., "I have a" と "You want to" を表現する [[0, 4, 2], [1, 3, 0]]
    • 入力 D: 教師データ(文 $\boldsymbol{d}$ を並べた2次元配列)
      • e.g., "have a pen" と "want to make" を表現する [[4, 2, 3], [3, 0, 3]]
    • 出力 mean_loss: XD から計算された損失関数の値
  • acc_deltas_bptt: 重み更新値(デルタ)を計算する関数
    • 入力 x: 文(単語 $x^{(t)}$ を並べた1次元配列)
    • 入力 d: 教師データ(単語 $d^{(t)}$ を並べた1次元配列)
    • 入力 y: 予測値(単語 $y^{(t)}$ を並べた1次元配列)
    • 入力 s: 隠れ層
    • 入力 steps: BPTTで遡るタイムステップ数
  • train: トレーニングを実行する関数
    • 入力は数が多いので省略します(コード内の docstring をご参照ください)

なお、もちろんこれは実装の一例であり、入出力や関数の切り分けパターンは他にも存在します。

各モジュール

今回の実装では、以下のように RNN クラスのインスタンス変数として重みや重み更新値を設定します:

class RNN(object):
    def __init__(self, vocab_size, hidden_dims):
        self.vocab_size = vocab_size
        self.hidden_dims = hidden_dims 
        
        # matrices
        # V (input -> hidden)
        self.V = random.randn(self.hidden_dims, self.vocab_size) * sqrt(0.1)
        # W (hidden -> output)
        self.W = random.randn(self.vocab_size, self.hidden_dims) * sqrt(0.1)
        # U (hidden ->; hidden)
        self.U = random.randn(self.hidden_dims, self.hidden_dims) * sqrt(0.1)
        
        # aggregated weight changes for V, W, U
        self.deltaV = zeros((self.hidden_dims, self.vocab_size))
        self.deltaW = zeros((self.vocab_size, self.hidden_dims))
        self.deltaU = zeros((self.hidden_dims, self.hidden_dims))

predict の実装

Forward Propagation, つまり入力(例えば "The bank went bankrupt")から次単語列(例えば "bank went bankrupt again")の予測を行う関数 predict は、次のように実装されます:

def predict(self, x):
    '''
    Args:
        x: list of words, as indices (e.g.: [0, 4, 2])
    Returns:
        y: matrix of probability vectors for each input word
        s: matrix of hidden layers for each input word
    '''
    # NOTE: in this implement, s[0] = [0. 0. ... 0.]
    #       and s[t+1] is the hidden layer corresponding to y[t]
    s = zeros((len(x) + 1, self.hidden_dims))
    y = zeros((len(x), self.vocab_size))

    for t in range(len(x)):
        x_vector = self.get_one_hot_vector(x[t])

        net_in = dot(self.V, x_vector) + dot(self.U, s[t])
        s[t + 1] =  sigmoid(net_in)

        net_out = dot(self.W, s[t + 1])
        y[t] = softmax(net_out)

    return y, s

前回記事の図に示したように、単語の長さに応じてレイヤーの枚数が変わっていくようなイメージです。

$t$ 番目の単語を one-hot ベクトル $\boldsymbol{x}^{(t)}$ で表した場合は文の表現は行列になりますが、本記事の実装では $\boldsymbol{x}^{(t)}$ のうち1となっているインデックスだけを整数で表すことで、文を1次元のベクトルで表現することとします。つまり、 predict の引数 x は1次元配列 (list) になります。

損失関数 compute_loss, compute_mean_loss の実装

今回の実装では、損失関数として Cross Entropy Loss を用います。このとき、文一つの損失 $J$ は各単語の損失 $J^{(t)}$ の和*1として、次のように表されます:

\begin{align} J^{(t)} &= -\sum^{|V|}_{j=1} d^{(t)}_j \log y_j^{(t)}\\ J &= \sum^{n}_{t=1} J^{(t)} \end{align}

今回はデータセットの評価にはmean_loss関数、つまり、データ全体でみた損失関数の単語ごとの平均を用いました。

def compute_loss(self, x, d):
    y, s = self.predict(x)
    loss_t = zeros(len(y))

    for t in range(len(y)):
        d_t_vector = self.get_one_hot_vector(d[t])
        y_t_vector = y[t]

        # dot product works as product &amp;amp; summation
        loss_t[t] = -dot(d_t_vector, log(y_t_vector)) 

    # combined loss is summation of loss over t
    loss = sum(loss_t)

    return loss

def compute_mean_loss(self, X, D):
    '''
    Args:
        X: a list of input vectors   (e.g., [[0, 4, 2], [1, 3, 0]]
        D: a list of desired outputs (e.g., [[4, 2, 3], [3, 0, 3]]
    Returns:
        mean_loss
    '''
    sum_of_loss = 0
    sum_of_length = 0

    for i in range(len(X)):
       sum_of_loss += self.compute_loss(X[i], D[i])
       sum_of_length += len(X[i])

    # loss per word = total loss in the dataset devided by total number of words in the dataset.
    mean_loss = sum_of_loss / float(sum_of_length)

    return mean_loss

重み更新値デルタ関数 acc_deltas_bptt の実装

前回記事の更新式をそのままコードにすることで、重み更新値デルタを計算する関数 acc_deltas_bptt を以下のように実装できます:

def acc_deltas_bptt(self, x, d, y, s, steps):
    '''
    Args:
        steps: number of time steps to go back in BPTT
    Return:
        None
    '''

    net_out_grad = ones(len(x))
    net_in_grad = array([s_t * (ones(len(s_t)) - s_t) for s_t in s])
    sum_deltaW = zeros((self.vocab_size, self.hidden_dims))
    sum_deltaV = zeros((self.hidden_dims, self.vocab_size))
    sum_deltaU = zeros((self.hidden_dims, self.hidden_dims))

    # NOTE: in this implement, s[0] = [0. 0. ... 0.] 
    #       and s[t+1] is the hidden layer corresponding to y[t] (same in net_in_grad)
    for t in reversed(range(len(x))):
        d_t_vector = self.get_one_hot_vector(d[t])
        delta_out_t = (d_t_vector - y[t]) * net_out_grad[t]
        sum_deltaW += outer(delta_out_t, s[t + 1])

        delta_in = zeros((len(x), self.hidden_dims))

        for tau in range(0, 1 + steps):
            if t - tau < 0:
                break
            if tau == 0:
                delta_in[t - tau] = dot(self.W.T, delta_out_t) * net_in_grad[t + 1]
            else:
                delta_in[t - tau] = dot(self.U.T, delta_in[t - tau + 1]) * net_in_grad[t - tau + 1]
            sum_deltaV += outer(delta_in[t - tau], x[t - tau])
            sum_deltaU += outer(delta_in[t - tau], s[t - tau])

    # multiply learning rate when actually applying delta
    self.deltaW = sum_deltaW
    self.deltaV = sum_deltaV
    self.deltaU = sum_deltaU

遡るタイムステップを指定する引数 steps はハイパーパラメータで、 steps を大きくすれば大きくするほど計算結果は正確になりますが、計算量も大きくなります。

train の実装

トレーニングの際には、「エポック」という単位で処理を繰り返します。各エポックでは、まずトレーニングセンテンスごとに acc_delta_bptt を実行し、誤差逆伝搬でデルタを計算します。デルタの計算が終わると重みを更新し、次のエポックに移り、これを繰り返す流れです。 ここで、重みの更新の大きさをコントロールする 学習率 というパラメータがありますが、学習率は各エポックごとにだんだん小さくしていくのが一般的です。学習率をどのように変化させて行くかもハイパーパラメータのひとつとして実装者が決定します。今回の実装では、  m エポック目での学習率  \eta_m を以下のように定義します:

 \eta_m = \eta_0 \frac{r}{m + r}.

$r$ は実数の定数で、今回は $r=5$ としています。 train は以下のように実装します:

def train(self, X, D, X_dev, D_dev, epochs, learning_rate, anneal, back_steps, batch_size, min_change):
    '''
  
    Args:
        X: a list of input vectors       e.g., [[0, 4, 2], [1, 3, 0]]
        D: a list of desired outputs     e.g., [[4, 2, 3], [3, 0, 3]]
        X_dev: a list of input vectors   e.g., [[0, 4, 2], [1, 3, 0]]
        D_dev: a list of desired outputs e.g., [[4, 2, 3], [3, 0, 3]]
        epochs: maximum number of epochs (iterations) over the training set
        learning_rate: initial learning rate for training
        anneal: positive integer. if > 0, lowers the learning rate in a harmonically after each epoch.
                higher annealing rate means less change per epoch.
                anneal=0 will not change the learning rate over time
        back_steps: positive integer. number of timesteps for BPTT. if back_steps < 2, standard BP will be used
        batch_size: number of training instances(sentence?) to use before updating the RNN's weight matrices.
                    if set to 1, weights will be updated after each instance. if set to len(X), weights are only updated after each epoch
        min_change: minimum change in loss between 2 epochs. if the change in loss is smaller than min_change, training stops regardless of
                    number of epochs left

    Returns:
        None
    '''

    print("Training model for {0} epochs\ntraining set: {1} sentences (batch size {2})".format(epochs, len(X), batch_size))
    print("Optimizing loss on {0} sentences".format(len(X_dev)))
    print("Vocab size: {0}\nHidden units: {1}".format(self.vocab_size, self.hidden_dims))
    print("Steps for back propagation: {0}".format(back_steps))
    print("Initial learning rate set to {0}, annealing set to {1}".format(learning_rate, anneal))
    print("\ncalculating initial mean loss on dev set")

    initial_loss = self.compute_mean_loss(X_dev, D_dev)
    print("initial mean loss: {0}".format(initial_loss))

    prev_loss = initial_loss
    loss_watch_count = -1
    min_change_count = -1

    a0 = learning_rate

    best_loss = initial_loss
    bestU, bestV, bestW = self.U, self.V, self.W
    best_epoch = 0

    for epoch in range(epochs):
        t0 = time.time()
        
        if anneal > 0:
            learning_rate = a0 / ((epoch + 0.0 + anneal) / anneal)
        else:
            learning_rate = a0
        print("\nepoch %d, learning rate %.04f" % (epoch + 1, learning_rate))

        count = 0

        # use random sequence of instances in the training set (tries to avoid local maxima when training on batches)
        for i in random.permutation(range(len(X))):
            count += 1
            stdout.write("\r\tpair {0}".format(count))

            x_i = X[i]
            d_i = D[i]
            y_i, s_i = self.predict(x_i)
            if back_steps < 2:
                self.acc_deltas(x_i, d_i, y_i, s_i)
            else:
                self.acc_deltas_bptt(x_i, d_i, y_i, s_i, back_steps)

            if count % batch_size == 0:
                self.deltaU /= batch_size
                self.deltaV /= batch_size
                self.deltaW /= batch_size
                self.apply_deltas(learning_rate)

        if count % batch_size > 0:
            self.deltaU /= (count % batch_size)
            self.deltaV /= (count % batch_size)
            self.deltaW /= (count % batch_size)
            self.apply_deltas(learning_rate)

        print("\n\tcalculating new loss on dev set")
        loss = self.compute_mean_loss(X_dev, D_dev)
        print("\tmean loss: {0}".format(loss))
        print("\tepoch done in %.02f seconds" % (time.time() - t0))
        
        if loss < best_loss:
            best_loss = loss
            bestU, bestV, bestW = self.U.copy(), self.V.copy(), self.W.copy()
            best_epoch = epoch

        # make sure we change the RNN enough
        if abs(prev_loss - loss) < min_change:
            min_change_count += 1
        else:
            min_change_count = 0
        if min_change_count > 2:
            print("\ntraining finished after {0} epochs due to minimal change in loss".format(epoch + 1))
            break

        prev_loss = loss
    print("\ntraining finished after reaching maximum of {0} epochs".format(epochs))
    print("best observed loss was {0}, at epoch {1}".format(best_loss, (best_epoch + 1)))
    print("setting U, V, W to matrices from best epoch")

    self.U, self.V, self.W = bestU, bestV, bestW

トレーニング結果と文章生成結果

今回の実装では、隠れ層のユニット数 $D_h$, BPTTにおいて遡るタイムステップ数  \tau, 学習率の初期値  \eta_0 をそれぞれ以下のような範囲で変更しました:

  • $D_h \in \{10,\ 50,\ 100\}$
  •  \tau \in \{0,\ 3,\ 10\}
  •  \eta_0 \in \{0.5,\ 0.1,\ 0.05\}

パラメータの組み合わせは計27通りです。実験のためのデータセットとして Penn Treebank の WSJ Corpus を利用しています。

まず、ハイパーパラメータのチューニングを行います。このとき、全てのデータを利用するとトレーニングに時間がかかるので、トレーニングに用いるセンテンスを1000個に絞ります。なお、絞られたデータセットでのボキャブラリーサイズ $|V|$ は $|V| = 2000$ であり、エポック数は25として実験を行いました。 各パラメータに対する最終的な平均損失 (mean loss) の値は以下の表1のようになりました:

表1: `train` の際の mean loss の一覧表。  \tau を定めたときの mean loss を  \mathcal{L}(\tau) とすると、各セルの中の値はそれぞれ ( \mathcal{L}(0),  \mathcal{L}(3),  \mathcal{L}(10)) を表す。
表1

今回の実装で試したパラメータでは、  \eta_0 = 0.5, $D_h = 50$,  \tau = 5.94 が最も良い組み合わせだとわかりました。

次に、このハイパーパラメータの組み合わせを固定してフルトレーニングセット(6万件くらい)でネットワークのトレーニングをやり直し*2、できた言語モデルを使って自動で文章を生成させてみました。文章を作成する関数 generate_sequence のコードは以下のようになります:

def generate_sequence(self, start, end, maxLength):
    sequence = [start]
    loss = 0.
    x = [start]
    while True:
        # predict next word from current sequene x
        y,s = self.predict(x)

        # generate next word by sampling the word  according to the last element of y
        word_index = multinomial_sample(y[-1])

        x.append(word_index)
        sequence.append(word_index)
        pointwise_loss = -log(y[-1][word_index])
        loss += pointwise_loss

        if word_index == end or len(sequence) > maxLength:
            break

    return sequence, loss

文頭記号 <s> から出発し、 predict 関数を使って $\boldsymbol{y}^{(t)}$ ベクトルを計算し、その確率分布に応じて次の単語を選択、今までのセンテンスと結合させて、再度 predict 関数で次の単語を推定する... という流れで文章を作り、文末記号 </s> に到達したらセンテンスを返す、という処理です。

出力結果のうち、特に mean loss の小さかった例が以下になります。

mean loss: 3.36850018043
[’<s>’, ’for’, ’offered’, ’market’, ’.’, ’,’, ’that’, ’says’, ’,’, ’</s>’] 

mean loss: 3.63943955362
[’<s>;’, ’net’, ’was’, ’to’, ’share’, ’the’, ’</s>’]

mean loss: 3.78469000827
[’<s>;’, ’these’, ’has’, ’the’, ’resources’, ’the’, ’he’, ’,’, ’.’, ’it’,
’which’, ’fund’, ’</s>’]

意味の無い文が多くなっていますが、今回はデータセットも絞った上でハイパーパラメータのチューニングもまだ十分にできていなかったので、改善の余地はありそうです。 ", that says," や "was to share" といった部分的に意味の通っている表現は見受けられるものの、全体を通しては意味不明な文章が多いですね。この精度であればマルコフ連鎖モデル(品詞 → 品詞の遷移 (transition) 確率と、品詞 → 単語の emission 確率を組み合わせるモデル)などを用いる方がいいかもしれません。 意味不明な文章が生成されてしまった原因のひとつとして考えられるのは、ボキャブラリーサイズ2000のデータセットに対して one-hot ベクトルを使うことで、単語を表現するベクトル空間がスパースになってしまっていることです。 $\boldsymbol{y}^{(t)}$ ベクトルの一番大きな値でも0.01か0.02(つまり一番出やすい単語でも1%か2%の確率、他が0.5%くらい)というケースが多く、この分布に従ったらほとんどランダムに単語をピックアップしているようなものです。ピックアップのアルゴリズム(例えば上位100単語だけ見て確率を正規化し直すとか)を変えるだけで性能は向上するかもしれません。 今回は実験と原理理解のために実装した形なので、精度向上に関しての考察はこのあたりに留めておきます。

まとめ

言語モデルにおいて最も伝統的なNgramは、直前N個の単語まで考慮してコーパス内の単語のつながりをカウントし、単にカウントとカウントの割り算で確率を計算する、というシンプルなモデルでした。 このモデルの問題点は、コーパスにある単語のつながりしかカウントできないので、実際には出現する可能性があってもコーパスに出ていないというだけで単語の生起確率が0になってしまう、というものです。 これを解消する手法が単語カウントに他の単語カウントから得られた適当な数を足して0をなくすというSmoothingですが、これにもいろいろな問題があります(ここでは省略します)。端的に言うと、Ngramは汎化能力が低いという弱みがあります。

この点、RNNは単語をカウントしているのではなく、重みを学ぶことで単語と単語のつながり方を学習しているようなものなので、汎化能力が相対的に高いのです。つまり、コーパスに無い初見の単語でも確率0を与えないような出力を行うことができます。 しかし、今回のようなモデルでは one-hot ベクトルを利用するという性質上ベクトルがスパースになるという(おなじみの)問題は避けられません。今回のRNNアーキテクチャは単語の分散表現 (embedding) ベクトルでも応用可能なのか、少し考えてみる必要はあるかもしれません。

また、RNNは言語以外のシーケンシャルなデータに対しても応用できます。例えば、各タイムステップでの画像を入力にして動画を解析したり、株価の推移やFX取引のアルゴリズムとして使うことが可能です。

*1:あるいは平均を用いることもできます

*2:筆者の環境では1エポックに30分近くかかりました。

RNNで言語モデルを作るための理論

代表の島田です。scoutyでは、10/14 と 10/28 に下記の機械学習講習会を行いました。機械学習講習会#1は、このブログで紹介した言語モデルの基礎とNgramによる実践を半日で行うイベントです。
講習会の際に、RNNによる言語モデルについて教えて欲しいといった意見が多く頂きましたので、今回の記事ではRNNによる言語モデルを扱います。

scouty.connpass.com

今回は、ニューラルネットワークの拡張型RNNを用いて、ある文が与えられたときその次に来る単語の確率を与える言語モデルを作る過程を紹介します。この記事は、言語モデルに関わらずRNNの一般的な構造を取り扱うので、言語モデル以外にも応用することができます(間違いなどがあればコメントでお知らせください)。
RNNによる言語モデルは、scoutyにおいてもスカウトメールの解析などで使っていこうとしている技術のひとつです。

言語モデルとは

言語モデルとは、クロスエントロピーで名前から国籍判定する - scouty AI LAB の記事でご紹介したように、ある文が(学習元となったデータにおいて)生起する確率を与えるモデルで、例えば、以下のような関係を知ることができます。
$$
P(\mathrm{the\ cat\ slept\ peacefully}) > P(\mathrm{slept\ the\ peacefully\ cat})
$$

言語モデルがあることで、どの文章がより起こりやすいか=どの文章がより自然か を知ることができるので、機械翻訳音声認識など応用範囲は様々です。
これを応用すると、ある文 $x_t, \cdots ,x_1$ ($x_i$ は文内 $i$ 番目の単語)が与えられたとき、その次に続く単語として何が続くかを予測することができます。これを数式で表すと以下のようになります:
$$
P(x_{t+1} =v_j |x_t, \cdots ,x_1).
$$
つまり、言語モデルは ある単語の列 {x_t, \cdots , x_1} が与えられた時の次の単語(あるいは単語の列){x_{t+1}}{v_j} である確率を与えます。条件を付さなければ上のように、純粋に文章が生起する確率が得られます。

RNNとは

RNN (Recurrent Neural Network) とは、通常のフィードフォワード型のニューラルネットワークの拡張で、時系列データ(各時間でのスナップショットのシーケンス)や文(単語のシーケンス)のようなシーケンスを扱うことができるようにしたものです。各時刻 $t$ での隠れ層は、時刻 $t$ でのインプットに加え、時刻 $t-1$ の隠れ層を受け取り、両者の和をとります。基本的な構造は以下のようになります:


スクリーンショット 2016-03-25 0.35.02
画像は[1]より引用

言語モデルを作る場合、{\boldsymbol{x}(t)} は 文中 $t$ 番目の単語の one-hot ベクトル*1になります。そして、予測すべきは $\boldsymbol{x}(t)$ の次に来る単語で、これを $\boldsymbol{y}(t)$ として吐き出させることで、RNNで言語モデルを構築することができます。
実際に、$\boldsymbol{s}(t-1)$の先にも同じように $\boldsymbol{y}(t-1)$ の出力がついていて、これは $\boldsymbol{x}(t-1)$ の次に来る単語、つまり $\boldsymbol{x}(t)$ を予測するという構造になっています。

RNNでも重み行列を学習しますが、通常のNNでの学習と異なるのは3種類の異なる重みの推定を行うという点です。本記事では、3種類の重みを以下のように定義することとします:

  • $\boldsymbol{U}$: 隠れ層 → 隠れ層 の重み
  • $\boldsymbol{V}$: インプット → 隠れ層 の重み
  • $\boldsymbol{W}$: 隠れ層 → 出力 の重み

図中では同じ $\boldsymbol{U}$ がたくさん現れていますが、これらはすべて同じ行列を表しています。ある単語の次にどの単語が来るかは、単語の絶対的な位置に依存しないので、重み行列は同じであって当然という前提に基づきます。

基本構造:もう少し詳しく

入力 $\boldsymbol{x}^{(t)}$ に対して、出力 $\boldsymbol{y}^{(t)}$ を計算する Forward Propagation は次のようになります*2:


\begin{align}
\boldsymbol{\mathrm{net}}^{(t)}_{\mathrm{in}} &= \boldsymbol{Vx}^{(t)} + \boldsymbol{Us}^{(t-1)},\\\\
\boldsymbol{s}^{(t)} &= \mathrm{sigmoid}(\boldsymbol{\mathrm{net}}^{(t)}_{\mathrm{in}}),\\\\
\boldsymbol{\mathrm{net}}_{\mathrm{out}}^{(t)} &= \boldsymbol{W}\boldsymbol{s}^{(t)}, \\\\
\boldsymbol{y}^{(t)} &= \mathrm{softmax}(\boldsymbol{\mathrm{net}}^{(t)}_{\mathrm{out}}).
\end{align}


これを $t=0$ から $t=n-1$ まで順に計算してやればよいわけです。

$\boldsymbol{y}^{(t)}$ ベクトルは $t$ 番目の単語の次にくる予測単語の確率を表しています。ある単語のベクトル表現を $\boldsymbol{w}_j$ とすると、 $\boldsymbol{w}_j$ は $j$ 番目の要素が1、他が0のベクトルとなります。 $\boldsymbol{w}_j$ が $t + 1$ 番目の単語として出現する確率 $y^t_j$ は次のように表されます:
\begin{eqnarray*}
P(\boldsymbol{x}^{(t+1)}=\boldsymbol{w}_j|\boldsymbol{x}^{(t)}, \cdots, \boldsymbol{x}^{(1)}) = y^t_j.
\end{eqnarray*}

$\boldsymbol{x}^{(t)}$ は $t$ 番目の単語の one-hot ベクトルであるように、 $\boldsymbol{y}^{(t)}$ は次元数が全単語数のベクトルで、softmax 関数を使っているので全要素の和が1になります。

ボキャブラリーサイズ(全単語数)を $|V|$, 隠れ層 $\boldsymbol{s}$ の次元数を $D_h$*3 としたとき、各重み行列は以下のような形をとります:

\begin{align}
\boldsymbol{U} \in \mathbb{R}^{D_h \times D_h}, \
\boldsymbol{V} \in \mathbb{R}^{D_h \times |V|}, \
\boldsymbol{W} \in \mathbb{R}^{|V| \times D_h}.
\end{align}


学習アルゴリズム:BPTT

通常のNNでは Back Propagation (誤差逆伝播法)が学習アルゴリズムに用いられますが、RNNでは Back Propagation Through Time (BPTT) というアルゴリズムが用いられます。基本アイディアは同じで、出力層の誤差を重みを通じて伝播させていきますが、BPTTでは出力層 $\boldsymbol{y}^{(t)}$ での誤差のみならず、一つの前の時刻 $t-1$ の誤差を加算するという点が異なります。もちろん、一つだけとは言わず任意のタイムステップ  \tau だけ遡った誤差を加味してもよいでしょう。ただし、  \tau を大きくすれば当然計算量は増えます。例えば、10ステップ前の誤差まで加味するRNNはレイヤー数の10の FeedFoward NN の計算量に匹敵するので、適切な  \tau を設定する必要があります(これもまたハイパーパラメータ)。遡るタイムステップ数  \tau を限定したBPTTを Truncated BPTT といいます。また、多層NNと同じ理由で Vanishing Gradient Problem も発生します*4

以下が、各重みの更新式となります。添字 $p$ は、データ内から取ってきたセンテンス $p$ を表します。$\boldsymbol{x}_p^{(t)}$ は $p$ 内の $t$ 番目の単語です。$\boldsymbol{d}_p^{(t)}$ は教師データ (desired vector)、つまり、$\boldsymbol{x}_p^{(t)}$ の次に実際に来る単語です。通常、$\boldsymbol{d}_p^{(t)} = \boldsymbol{x}_p^{(t+1)}$ となります。

\begin{align}
\boldsymbol{\delta}_{\mathrm{out},\ p}^{(t)} &= (\boldsymbol{d}_p^{(t)} - \boldsymbol{y}_p^{(t)})g'(\boldsymbol{\mathrm{net}}_{\mathrm{out},\ p}^{(t)})\\
\Delta\boldsymbol{W}_p^{(t)} &= \boldsymbol{\delta}_{\mathrm{out},\ p}^{(t)}\ \boldsymbol{s}_p^{(t)} \\
\end{align}

$\boldsymbol{W}$ は純粋にアウトプットのエラー($\boldsymbol{d}_p^{(t)}$ と $\boldsymbol{y}_p^{(t)}$ の差)だけに影響を受けるので、タイムステップを遡る必要はありません。

\begin{align}
\boldsymbol{\delta}_{\mathrm{in},\ p}^{(t-k)} &= \begin{cases}
\boldsymbol{W}^T\boldsymbol{\delta}_{\mathrm{out},\ p}^{(t)}\ f'(\boldsymbol{\mathrm{net}}_{\mathrm{in},\ p}^{(t)}) & (k=0) \\
\boldsymbol{U}^T\boldsymbol{\delta}_{\mathrm{in},\ p}^{(t-k+1)} f'(\boldsymbol{\mathrm{net}}_{\mathrm{in},\ p}^{(t-k)}) & (k>0)
\end{cases}\\
\Delta\boldsymbol{V}_p^{(t-k)} &= \boldsymbol{\delta}_{\mathrm{in},\ p}^{(t-k)} \otimes \boldsymbol{x}_p^{(t-k)}\\
\Delta\boldsymbol{U}_p^{(t-k)} &= \boldsymbol{\delta}_{\mathrm{in},\ p}^{(t-k)} \otimes \boldsymbol{s}_p^{({t-k}-1)}
\end{align}

なお、$\otimes$は直積 (outer product) を表します。
今回は損失関数に Cross Entropy Loss を使用し、活性化関数に sigmoid と softmax を利用しているので、それらの微分は次のようになります:
\begin{align}
g'(\boldsymbol{\mathrm{net}}_{\mathrm{out},\ p}^{(t)}) &= \boldsymbol{I},\\
f'(\boldsymbol{\mathrm{net}}_{\mathrm{in},\ p}^{(t)}) &= \boldsymbol{s}_p^{(t)} ( \boldsymbol{I} - \boldsymbol{s}_p^{(t)}).
\end{align}

$\boldsymbol{I}$ は成分がすべて1である適当な長さのベクトルです。なお、断りがない限りベクトルどうしの積は対応要素ごとの積を並べたベクトルとします。

これを遡るタイムステップ  \tau ぶんと、全データ $N$ について足し合わせればよいというわけです。つまり、重みの更新式は以下のようになります:


 \boldsymbol{W} \leftarrow \boldsymbol{W} + \eta \sum_{p}^{N} \sum_{t=1}^{n} \Delta\boldsymbol{W}p^{(t)}.

ネットワークの形は異なりますが、学習や順伝播など、基本的なアイディアはほとんど通常のニューラルネットワークと同じですね。
希望が多ければ、次回の scouty 機械学習講習会イベントのテーマにすることも検討しています。

参考文献

[1] Guo, Jiang. "Backpropagation through time." Unpubl. ms., Harbin Institute of Technology, 2013.
https://pdfs.semanticscholar.org/c77f/7264096cc9555cd0533c0dc28e909f9977f2.pdf

[2] Frank Keller, Natural Language Understanding - Lecture 11 Recurrent Neural Network, University of Edinburgh, 2016.

*1:次元数=全単語数で、そのベクトルが表現する単語の成分だけ1,他が0になっているベクトル。

*2:$t$ に関する表記法が [1] とやや異なりますが、本記事の以降の部分はこちらのノーテーションに従うこととします。

*3:$D_h$は自由に決められるハイパーパラメータの一つです。

*4:その解決策がいわゆるLSTMです。

Recursive Autoencoder で文の分散表現

scouty 代表の島田です。 トピックモデルで単語の分散表現 - 理論編 - scouty AI LAB では、局所表現・分散表現の違いに関して説明しましたが、「単語の分散表現と同じように、文*1の分散表現を作るにはどうすればよいか?」というのが今回のテーマです。 CNNで文の識別タスクを解く - scouty AI LAB でもCNNによって文の分散表現を作る方法を扱いましたが、本記事では Recursive Autoencoder によって文の分散表現を作る方法をご紹介します。

Autoencoder とは何か

Recursive Autoencoder は、 Autoencoder (オートエンコーダー)を組み合わせることによって文の意味表現をひとつのベクトルとして表そうとするモデルです。 Autoencoder というのは、入力ベクトルを受け取ったら、入力ベクトルと全く同一のベクトルを出力することを目的として学習を行う特殊なニューラルネットワークモデルのひとつで、次のような図で表されます。

スクリーンショット 2016-04-16 17.05.28

入力次元よりも隠れ層の次元が大きければ入力 $x_i$ をただ単に出力 $\hat{x}_i$ に受け流すという自明な解が存在してしまうので、一般的には隠れ層の次元は入力層の次元より小さく設定します。学習が終了したのちに何らかのベクトルを入力すると、隠れ層の値は入力ベクトルの圧縮表現になっています。つまり、隠れ層は「学習した重みと掛け合わせることで入力が再現できるベクトル」であり、隠れ層の次元は入力ベクトルの次元より小さくなっているので、 Autoencoder はベクトルの次元縮約器であると解釈することができます*2

Autoencoder では入力ベクトルを再現するように学習を行うため、教師データが必要ありません。そのため、 Autoencoder は教師なし学習に分類されます。$d$ 次元の入力ベクトル $\boldsymbol{x} = (x_1, \dots, x_d)^T$ に対する Autoencoder の出力を $\boldsymbol{\hat{x}} = (\hat{x}_1, \dots, \hat{x}_d)^T$ とすると、$\boldsymbol{\hat{x}}$ は以下のように表されます。

$$ \boldsymbol{h} = \boldsymbol{f}(\boldsymbol{W_1}\boldsymbol{x} + \boldsymbol{b_1}),\\ \boldsymbol{\hat{x}} = \boldsymbol{f}(\boldsymbol{W_2}\boldsymbol{h} + \boldsymbol{b_2}). $$

ここで、 $\boldsymbol{h}$ は隠れ層を表すベクトル、$\boldsymbol{b_1}$, $\boldsymbol{b_2}$ はバイアスを表すベクトル、 $\boldsymbol{W_1}$, $\boldsymbol{W_2}$ は重みを表す行列です。 また、$\boldsymbol{f}$ は sigmoid などの活性化関数であり、上式では、ベクトルを入力として同次元のベクトルを出力するものとして定式化しています。

Autoencoder の二乗誤差は

$$L(\boldsymbol{x}, \boldsymbol{\hat{x}}) = ||\boldsymbol{x} - \boldsymbol{\hat{x}}||^2 $$

Cross Entropy 誤差は

$$L(\boldsymbol{x}, \boldsymbol{\hat{x}}) = -\sum_{k=1}^d x_k \log \hat{x}_k + (1-x_k) \log(1-\hat{x}_k)$$

と表され、トレーニングなどは通常のNNと同じように行われます。 ちなみに、Autoencoder の最も代表的な使いみちは Recursive Autoencoder ではなく、深層学習のプレトレーニングで使われる Stacked Autoencoder でしょう。こちらに関しては既に扱っている文献が多いので、そちらを参照することをおすすめします。

Recursive Autoencoder

Recursive Autoencoder は、Autoencoder を次のように積み上げます。

スクリーンショット 2016-04-16 17.27.04

一番下の各 $\boldsymbol{x}$ は、文内の各単語の分散表現になります。まず文章内の単語を二分木で表すのが第一ステップです。$[\boldsymbol{x}_3; \boldsymbol{x}_4]$ は、ベクトル $\boldsymbol{x}_3$ と $\boldsymbol{x}_4$ を単に縦に連結させたベクトルです。この2つのベクトルを Autoencoder で圧縮し、その隠れ層をベクトル $\boldsymbol{y}_1$ とします*3。次に、 $\boldsymbol{x}_2$ と $\boldsymbol{y}_1$ の連結を圧縮し、$\boldsymbol{y}_2$ を作ります。このように単語を再帰的に圧縮していってできた最後のベクトル $\boldsymbol{y}_3$ が、この文の分散表現となります

文から上図のような二分木を作る手法は様々ですが、次のような greedy な手法が用いられることが多いようです。

  1. 文中の単語数を $n$ 個として、隣り合う $n-1$ 個のペアを考える。
  2. それぞれのペアの Reconstruction Error を計算する。
  3. その中で一番小さいエラーのペアを選択し、実際にペアを作る。
  4. ペアをひとつの単語(ノード)に見立てて、1〜3を繰り返す。 隣り合う単語しか見ないので完全なアルゴリズムではありませんが、妥当な時間でそこそこ良い結果を与えることが知られています。

Recursive Autoencoder の性能の尺度である Reconstruction Error (ノード $y_k$ のエラー) $E_{\mathrm{rec}}(k)$ は次のように表されます。

$$E_{\mathrm{rec}}(k) = \frac{n_i}{n_i + n_j} ||x_i - \hat{x}_i||^2 + \frac{n_j}{n_i + n_j} ||x_j - \hat{x}_j||^2$$

$n_i$ はノード $i$ に含まれている単語数。長い単語の連続ほど再現が難しくなるので、重みをつけています。また、全ツリーのコストは単にこれらのエラーの和とすればよいです。

スクリーンショット 2016-04-16 20.40.47

これを使って文のクラス識別器を作る場合、上図のように最終的に得られた $n-1$ 個の隠れ層 $y_k$ を入力として、各クラスの確率分布を出力とするニューラルネットを新たに作れば良いわけです。各 $y_k$ のラベルはすべて文のクラス(ラベル)と同じとします。教師データは該当クラスが1、他が0になっているベクトルになります。 これはあくまでクラス識別や類似性評価に使うものであるので、このベクトル自体から文の意味を抽出して推論などができるかと言われれば、それはできません。しかし、Recursive Autoencoder を使えば文のクラス識別を、人間の手作りの特徴(e.g.センチメント解析ならポジティブ・ネガティブを表す特徴語)を使うことなく行うことができるという点では非常に便利なモデルであると言えるでしょう。

ただし、実運用においては、2017年10月時点では便利なライブラリが無いため、実際にビジネスの場で利用するのは難しそうです。 CNNによる文の分散表現のほうが一般的に文識別などでは精度が高いことが知られていますが、実際に実装する場合を考えたら、Recursive Autoencoder の方が実装コストは少なくて済むでしょう。 Autoencoder は、今回の例のように文だけでなく一般的な次元圧縮器になるので、自然言語処理のみならず幅広く応用することができると考えられます。

詳細のアーキテクチャや説明は、こちらの論文 に詳しく書かれていますので、詳細が気になった方はそちらをご覧ください。

*1:"I saw him swimming." など、連続した単語のピリオドまでのまとまりのこと。

*2:アウトプットは特に重要ではないので捨てちゃうのが普通。

*3:出力はいらないので無視。