宇宙の渚を眺めるエンジニアのブログ

技術的な備忘録や、日々思ったこと、たまに宇宙関連について綴る

Keras のBatchNormalization レイヤーのバグ?

Keras(version 2.2.4)でBatchNormarizationを下のようにかけようとして、

model = BatchNormalization(axis=1)(model)

次のようなエラーが出た。

Shape must be rank 1 but is rank 0 for 'batch_normalization_1/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,96,1,1], [].
Traceback (most recent call last):
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/site-packages/tqdm/std.py", line 1039, in __del__
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/site-packages/tqdm/std.py", line 1223, in close
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/site-packages/tqdm/std.py", line 555, in _decr_instances
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/site-packages/tqdm/_monitor.py", line 51, in exit
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/threading.py", line 521, in set
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/threading.py", line 364, in notify_all
  File "/home/hirotoshi/.pyenv/versions/anaconda3-2019.10/lib/python3.6/threading.py", line 347, in notify
TypeError: 'NoneType' object is not callable

入力 tensor の shape 等には問題なし。

https://github.com/keras-team/keras/issues/10648

を参考に、 Keras のライブラリ中の tensorflow_backend.py の "()"を"[]"に変更することでバグが解消された。該当行は、1908,1910,1914,1918。バージョンによっては、行が若干違うかも。