宇宙の渚を眺めるエンジニアのブログ

技術的な備忘録や、日々思ったこと、たまに宇宙関連について綴る

Multi-Head Self-Attentionを用いたSNLIタスク

勤め先のグループで、ここ一年間SNLI(Stanford Natural Language Inference)というタスクに取り組もうということになっていた。どういう手法でタスクに取り組もうかと調べていたときに、最近発表されたBERTモデルがその元となったTransformerというモデルのMulti-Head Self-Attentionを利用していると知り、Multi-Head Self-Attentionを自分で実装して、SNLIタスクに取り組んでみた。

SNLIタスクについて

SNLIは、英語で書かれた前提となる文と仮説となる文の2文を比較して、それが、entailment(含意)・contradiction(矛盾)・neutral(中立)のどれに当てはまるかという推論を行う3値分類のタスクである。より詳しくは、

などを参照。

SNLIのデータ件数
  • train: 549367/550152 (gold_label ありデータ/全データ)
  • validation: 9842/ 10000 (gold_label ありデータ/全データ)
  • test: 9824/ 10000 (gold_label ありデータ/全データ)

Multi-Head Self-Attention について

Scaled Dot-Product Attention

Transformerの論文では、まずScaled Dot-Product Attentionが導入されており、以下のような式で表される。


{\rm Attention}(Q, K, V) = {\rm softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Self-Attentionの場合は、Q, K, V はすべて同じものが入力される。 また、Q, K^{T}内積\sqrt{d_k}でスケーリングしているのは、次元数\sqrt{d_k}が大きいと内積が大きくなり、逆伝搬のsoftmaxの勾配が小さくなることを防いでいるようだ。

Multi-Head Attention

続いて論文では、以下の式でMulti-Head Attentionが導入されている。


\begin{eqnarray}
{\rm MultiHead}(Q, K, V) &=& {\rm Concat}({\rm head_1, ..., head_h})W^O \\\
{\rm where}  \;\;\; {\rm head_i} &=& {\rm Attention}(QW^Q_i, KW^K_i, VW^V_i) \\\
\\\
W^Q_i \in \mathbb{R}^{d_{\rm model}\times d_k} , W^K_i &&\in \mathbb{R}^{d_{\rm model}\times d_k}, W^V_i \in \mathbb{R}^{d_{\rm model}\times d_v}, W^O_i \in \mathbb{R}^{hd_{v}\times d_{\rm model}}
\end{eqnarray}

 これは、オリジナルの単語分散表現の次元( d_{\rm model} )をh分割して、h個のScaled Dot-Product Attentionを実行し、それらを連結させてW^{O}との内積をとることで元の d_{\rm model} 次元に写像するものである。このh分割のAttentionを使用することをMulti-Head Attentionと呼んでおり、Q, K, Vが全て同じ入力の場合はMulti-Head Self-Attentionとなる。
 単語分散表現の次元をh分割することによって、一つ一つのAttentionの性能としては落ちるものの、分散表現次元の特定の部分空間のAttentionを、各Headが役割を分担させて実施させることが、性能向上につながるのだろうというのが、自分なりの理解である。

TransformerやMulti-Head Attentionについて、より詳しくは、

などの記事を参照。

実験

モデル

Keras(バックエンド: Tensorflow) を用いて、以下のようなSNLI用のモデルを作成した。コードはこちら

f:id:hiro-uchi:20190317213739p:plain
  1. InputLayer: 系列の入力
  2. Embedding: 系列を単語分散表現に変換
  3. MHA: Multi-Head Attention層
  4. PFFN: Position-wise Feed-Forward Networks層
  5. Flatten: 135x300次元
  6. Dense: 全結合層

である。

4 の Position-wise Feed-Forward Networks (PFFN)は、系列の各位置の分散表現 xに対して、


{\rm FFN}(x) = {\rm max}\left(0, xW_1+b_1\right)W_2 + b_2

という変換を施すものである。

実験設定と環境

上記のようなモデルを用いて、以下のような設定・環境で実験を行った。

  • 前提と仮説の2文は以下のように、"|"を用いて結合して入力とした
    • 前提: a biker race
    • 仮説: the car is yellow
    • 結合文: a biker races [padding] | the car is yellow [padding]
    • 結合文の[padding]の部分は、それぞれ(仮説文/前提文)の最大文長まで調整することを意味
  • Embeddingには、事前学習モデルとしてglove.6B.300d を用いて、事後学習を実施
  • gold_labelありのデータについてのみ学習と評価に用いた
  • GPU: GeForce GTX 1080Ti

結果と考察

Head数ごとの精度

 8epoch 回した、 Head数に応じた評価セットの精度は以下のとおり

Head数 1 2 6 10 20
train 精度 [%] 79.1 81.0 83.8 84.3 85.6
test 精度 [%] 74.3 75.4 78.3 77.8 78.4

 確かに Head数を増やすことで、testの精度が上昇する結果となった。 8epoch より先まで実行すると、train精度のみが上昇する傾向であった。 Head数をさらに増やすとどうなるかというところが気になる点であったが、Head数を 20にした段階で、実行の速度の低下がみられたのでそれ以上の実験は行わなかった。 実行速度の低下は、Headの分割を増やしたことで、GPU計算での並列性が落ちたためと考えられる。

考察 ~HeadごとのAttention重みの可視化~

 最後に、HeadごとのAttentionがどのような役割を果たしているかを確認するため、 Attention重みの

 
  {\rm softmax}(QW_i(KW_i)^T/ \sqrt{d_k})

を可視化してみる。

entailment と正答した場合

前提: a land rover is being driven across a river
仮説: a vehicle is crossing a river
の2文に対して、2つのHead で推論した場合のAttenstion 重みを可視化したのが以下の2つの図である。

f:id:hiro-uchi:20190317204813p:plain
Head 1 のAttention重み

f:id:hiro-uchi:20190317204834p:plain
Head 2 のAttention重み

色が濃い部分が、より注意が向いていることを表している。Head 1 の方は、主に "driven"という単語に注目し、Head 2 の方は"crossing"という単語に最も注意が向いている。他にも"river"などにも注意が向いており、結果的に2つの文が合致するという判定するのに必要な単語に注意が向いていることが見て取れる。この場合、モデルはentailmentと正答している。

contradiction を entailmentと誤答した場合

前提: a woman with a green headscarf blue shirt and a very big grin
仮説: the woman has been shot
の2文に対して、2つのHead で推論した場合のAttenstion 重みを可視化したのが以下の2つの図である。

f:id:hiro-uchi:20190317210124p:plain
Head 1 のAttention重み

f:id:hiro-uchi:20190317210316p:plain
Head 2 のAttention重み

2つHeadのAttention重みを見てみると、"woman"や"grin"等に注意が向いており、最終的にentailmentという推論が得られている。 contradictionを導くのに重要と思われる"shot"等には注意が向いていないが、Headを増やすとどうなるか、というのがさらなる考察ポイントだろう。

まとめ

 今回は、Transformerモデルで導入されているMulti-Head Self-Attention層を、Kerasで自作してSNLIタスクに応用した。その結果、Multi-Head Attentionの有用性が確認された。ただし、今回はMulti-Head Self-Attentionを利用するというところも目的に置いていたので、SNLIタスクのように、2文の比較による推論の場合は、通常のSource-Target Attentionを用いたほうが良さそうな気がするのでやってみたいと思う。