Word2Vec のメモ その1

Chainer 本 を読みながら Word2Vec を Chainer で実装してみたので、その過程でわかったことをメモしておく。

注意: 素人なので完全に間違っているかもしれない。

Word2Vec

  • Word2Vec の目的は、各単語の 分散表現 を求めること。
    • 単語の分散表現とは、単語の意味を低次元(100次元とか)の密な実ベクトルで表したもの。
  • Word2Vec では、分散表現を求めるために Continuous Bag-of-Words (CBOW) か、skip-gram が利用できる。Chainer 本では skip-gram のみが解説されていたため、この記事でも skip-gram のみを扱う。

skip-gram

skip-gram とは、雑に言うと、「ある単語が与えられたときに、その周囲に現れる単語を予測する」という問題を考え、これを精度よく答えられるように各単語の分散表現ベクトルを学習してくモデル。

入力されたコーパス上の位置 { t } からオフセット { b } 以内にある単語群 { w_{t-b}, \dots, w_{t-1}, w_{t+1}, \dots, w_{t+b} } を位置 { t } における 文脈 { c_t } と定義する。 文脈という言葉を使って上の問題を言い換えると「位置 { t } にある単語 { w_t } が与えられたとき、その文脈 { c_t } を予測する問題」となる。

文脈内の各単語の条件付き独立性を仮定して、{ p( c_t \mid w_t ) } を以下のようにモデル化する。

{ \displaystyle
    p( c_t \mid w_t ) = \prod_{ c \in c_t } p( c \mid w_t )
}

図で書くとこんな感じ( b=2 のとき)

f:id:nojima718:20170821012826p:plain

さらに、{ p( c \mid w_t ) } を、{ c, w_t } の分散表現 { \vec{v}_c, \vec{v}_{w_t} } を使って、以下のように分散表現の内積のソフトマックスでモデル化する。(各単語の分散表現はパラメータであり、学習によって求める)

{ \displaystyle
    p( c \mid w_t ) = \frac{\exp( \vec{v}_c \cdot \vec{v}_{w_t} )}{\sum_{w' \in V} \exp( \vec{v}_{w'} \cdot \vec{v}_{w_t} )}
}

ただし、 V は語彙集合。

……と本には書いてあるんだけど、自分が元論文を読んだ感じでは単語のベクトル表現は2種類 (入力用、出力用) があり、{ w_t } は入力用のベクトル表現(=分散表現)でベクトル化し、{ c } は出力用のベクトル表現でベクトル化しているように見えた。 つまり、2種類のベクトル表現 { \vec{u}, \vec{v} } を使って以下のようにモデル化する。

{\displaystyle
    p( c \mid w_t ) = \frac{\exp( \vec{u}_c \cdot \vec{v}_{w_t} )}{\sum_{w' \in V} \exp( \vec{v}_{w'} \cdot \vec{v}_{w_t} )}
}

どっちがいいのかよくわからなかったので、両方実装して実験してみたので結果については後述する。

いずれにしても、これで文脈の条件付き確率が定義できたので、コーパスの対数尤度が以下の式で計算できる。

{
\begin{align*}
    \mathcal{L} &= \sum_{t=1}^T \log p( {\bf c}_t \mid w_t ) \\
                &= \sum_{t=1}^T \sum_{c \in {\bf c}_t} \log p( c \mid w_t ) \\
                &= \sum_{t=1}^T \sum_{c \in {\bf c}_t} \left( (\vec{v}_c \cdot \vec{v}_{w_t}) - \log \sum_{w' \in V} \exp( \vec{v}_{w'} \cdot \vec{v}_{w_t} ) \right)
\end{align*}
}

ただし、 T は訓練データのサイズ。

しかし、この尤度は { \log \sum_{w' \in V} \exp( \vec{v}_{w'} \cdot \vec{v}_{w_t} ) } の項があるため、学習のための計算コストが大きい。 元論文では語彙のサイズが70万ぐらいで、分散表現の次元が300であるため、70万×300 = 2億1000万回の演算が一つの入力ごとに必要になる。 入力データセットは、元論文の実験では300億単語あり、context size が 5 なので、1500億個の入力があることになる。 よって、 \mathcal{L} を計算するには 2億1000万×1500億=3150京回の演算が必要になる。 スーパーコンピューターなら計算できるかもしれないけど、普通はこのサイズの計算は無理。

よって、モデルをちょっと変えて、もっと計算しやすいモデルを作る。

次の問題を考える。

「位置  t の単語  w_t と、以下のいずれかの単語  c が与えられる。

  • 正例: 位置  t の文脈  c_t に属する単語
  • 負例: ランダムな単語

このとき、 c が正例なのか負例なのかを判別せよ1

つまり、もともとの問題は与えられた単語の近傍の単語の確率分布を求めていたが、新しい問題は近傍の単語とノイズとを区別する2クラス分類問題になっている。

まず、確率変数  D を導入して、答えが正例であるとき  D=1 とし、負例であるとき  D=0 とする。

 c w_t が与えられたときの  D の確率を以下の式でモデル化する。

{
    p(D=1 \mid c, w_t) = \sigma(\vec{v}_c \cdot \vec{v}_{w_t})
}

ただし、 \sigma(x)シグモイド関数 { 1 / (1 + \exp(-x)) }

要するに、分散表現の内積が大きい単語同士は近傍に出現しやすいというモデルになっている。 (ベクトルを2つ使うモデルだとちょっと違う解釈をしないといけない)

次にパラメータ(=分散表現)の学習のために誤差関数を作る。 Word2Vec ではクロスエントロピー誤差  H(p, q) = -p \log q - (1-p) \log (1-q) を用いる。

訓練データ  (w_t, c, D) に対するこのモデルのクロスエントロピー誤差は以下の式で与えられる。

{
\begin{align*}
    E &= H( D, p(D=1 \mid c, w_t) ) \\
      &= -D \log p(D=1 \mid c, w_t) - (1 - D) \log (1 - p(D=1 \mid c, w_t)) \\
      &= -D \log \sigma(\vec{v}_c \cdot \vec{v}_{w_t}) - (1 - D) \log ( 1 - \sigma(\vec{v}_c \cdot \vec{v}_{w_t}) )
\end{align*}
}

もともとの問題の対数尤度関数  \mathcal{L} と違って語彙集合の上を走る和がないことに注目してほしい。

 E の形をよく見ると、ロジスティック回帰の対数尤度関数とほぼ同じ式になっている。 なので、この手法は2クラス分類問題をロジスティック回帰を用いて解いていると言ってもいいのかもしれない。 (ロジスティック回帰の定義をよく知らないので間違ってるかも)

あとは訓練データを作れば学習できる。 訓練データは後述する ノイズ分布 とハイパーパラメータ  k \in \mathbb{N} を用いて以下のように作る。

  • 学習データ内の各位置  t について
    • 位置  t の文脈内の各単語  c について
      • 正例  (w_t, c, D=1) を訓練データに加える。
      • ノイズ分布から  k 個単語  c' をサンプルし、それぞれの  c' に対して負例  (w_t, c', D=0) を訓練データに加える。

ノイズ分布 は負例をサンプリングするための単語の確率分布で、Word2Vec では以下の分布を使うのが実験的によい結果を残しているらしい。

{ \displaystyle
p(w) = \frac{\mathrm{freq}(w)^{0.75}}{\sum_{w' \in V} \mathrm{freq}(w')^{0.75} }
}

ただし、 \mathrm{freq}(w) は、コーパスにおける単語  w の頻度。

このように負例をサンプリングして学習する手法を negative sampling と呼ぶらしい。

実装と実験

これを Chainer で実装していろいろ実験してみたけど、今日は疲れたので明日の記事で


  1. Chainer 本の説明では「文脈上の単語  c と (1)  w_t または (2) ランダムな単語 が与えられて、(1) か (2) かを区別する」問題を解いており、ここで紹介した問題は解いていない。元の論文では、(自分の理解が正しいとすると)この記事で紹介したほうの問題を解いていると思う。