Kerasで自作レイヤーを保存&ロードするときにはget_configが必要
Kerasで自作レイヤーを作るときには、最低限build
, call
, compute_output_shape
の3つのメソッドを定義していればよいとある。
例えば、コンストラクタに、以下のような定義をしてあるPFNN
という自作レイヤーを作ったときに、
def __init__(self, ff_dim, **kwargs): self.ff_dim = ff_dim super(PFNN, self).__init__(**kwargs)
その自作レイヤーを使ったモデルを保存&ロードしようとすると、
TypeError: __init__() missing 1 required positional argument: 'ff_dim'
のようなエラーが出てしまう。これは、ロード時にget_config
というメソッドが必要なためで、自作レイヤーに
def get_config(self): config = {'ff_dim': self.ff_dim} base_config = super(PFNN, self).get_config() return dict(list(base_config.items()) + list(config.items()))
のように定義しておけばよい。