宇宙の渚を眺めるエンジニアのブログ

技術的な備忘録や、日々思ったこと、たまに宇宙関連について綴る

Kerasで自作レイヤーを保存&ロードするときにはget_configが必要

Kerasで自作レイヤーを作るときには、最低限build, call, compute_output_shapeの3つのメソッドを定義していればよいとある。 例えば、コンストラクタに、以下のような定義をしてあるPFNNという自作レイヤーを作ったときに、

def __init__(self, ff_dim, **kwargs):
    self.ff_dim = ff_dim
    super(PFNN, self).__init__(**kwargs)

その自作レイヤーを使ったモデルを保存&ロードしようとすると、

TypeError: __init__() missing 1 required positional argument: 'ff_dim'

のようなエラーが出てしまう。これは、ロード時にget_configというメソッドが必要なためで、自作レイヤーに

def get_config(self):
    config = {'ff_dim': self.ff_dim}
    base_config = super(PFNN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

のように定義しておけばよい。