我基于GRU训练以下模型,请注意,我将参数stateful=True
传递给GRU构建器。
class LearningToSurpriseModel(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, rnn_units):
super().__init__(self)
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.gru = tf.keras.layers.GRU(rnn_units,
stateful=True,
return_sequences=True,
return_state=True,
reset_after=True
)
self.dense = tf.keras.layers.Dense(vocab_size)
def call(self, inputs, states=None, return_state=False, training=False):
x = inputs
x = self.embedding(x, training=training)
if states is None:
states = self.gru.get_initial_state(x)
x, states = self.gru(x, initial_state=states, training=training)
x = self.dense(x, training=training)
if return_state:
return x, states
else:
return x
@tf.function
def train_step(self, inputs):
[defining here my training step]
我实例化我的模型
model = LearningToSurpriseModel(
vocab_size=len(ids_from_chars.get_vocabulary()),
embedding_dim=embedding_dim,
rnn_units=rnn_units
)
[编译并做事情]
和EPOCHS
纪元
for i in range(EPOCHS):
model.fit(train_dataset, validation_data=validation_dataset, epochs=1, callbacks = [EarlyS], verbose=1)
model.reset_states()
此代码关于GRU状态的行为是什么:状态是针对每个新的数据批次进行更新,还是仅针对每个新时期进行更新?所需的行为是仅对每个新纪元进行重置。如果没有完成,如何实现?
编辑
TensorFlow为Models
实现reset_states
函数
def reset_states(self):
for layer in self.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
layer.reset_states()
是否意味着(与文档的其他含义相反)只有在stateful=False
的情况下才能重置状态?这是我从getattr(layer, 'stateful', False)
上的条件推断的。
您可以尝试重置自定义Callback
中的状态:
model = LearningToSurpriseModel(
vocab_size=len(ids_from_chars.get_vocabulary()),
embedding_dim=embedding_dim,
rnn_units=rnn_units
)
gru_layer = model.layers[1]
class CustomCallback(tf.keras.callbacks.Callback):
def __init__(self, gru_layer):
self.gru_layer = gru_layer
def on_epoch_end(self, epoch, logs=None):
self.gru_layer.reset_states()
model.fit(train_dataset, validation_data=validation_dataset, epochs=1, callbacks = [EarlyS, CustomCallback(gru_layer)], verbose=1)
另请参阅post有关何时重置状态的信息。