Here is a sample implementation of an EarlyStoppingHook:
```python
import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests stop when the monitored value stops improving."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        Args:
            monitor: name of the tensor (or op) to monitor.
            min_delta: minimum change to qualify as an improvement.
            patience: number of steps with no improvement after which
                training is stopped.
            mode: one of 'auto', 'min', 'max'.
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def begin(self):
        # Convert names to tensors if given
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        current = run_values.results
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()
```
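The bookkeeping in `after_run` is independent of TensorFlow, so it can be sanity-checked in isolation. A minimal pure-Python sketch of the same logic for a 'min'-mode monitor (`should_stop` is a hypothetical helper written for this illustration, not part of the hook):

```python
def should_stop(values, min_delta=0.0, patience=0):
    """Return the step index at which early stopping would trigger, or None.

    Mirrors the hook's after_run bookkeeping for a 'min' mode monitor:
    a value only counts as an improvement when it is more than min_delta
    below the best value seen so far.
    """
    best = float('inf')
    wait = 0
    for step, current in enumerate(values):
        if current < best - min_delta:  # improvement: reset the counter
            best = current
            wait = 0
        else:                           # no improvement: run down patience
            wait += 1
            if wait >= patience:
                return step
    return None


# The loss improves for two steps, then plateaus; with patience=2 the
# second consecutive non-improving step triggers the stop.
print(should_stop([1.0, 0.9, 0.91, 0.92, 0.93], patience=2))  # → 3
```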
This implementation is based on the Keras implementation: https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/python/keras/_impl/keras/callbacks.py.
To use it with the CNN MNIST example https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/layers/cnn_mnist.py, create the hook and pass it to train:
```python
early_stopping_hook = EarlyStoppingHook(
    monitor='sparse_softmax_cross_entropy_loss/value', patience=10)

mnist_classifier.train(
    input_fn=train_input_fn,
    steps=20000,
    hooks=[logging_hook, early_stopping_hook])
```
Here sparse_softmax_cross_entropy_loss/value is the name of the loss op in that example.
EDIT 1:
There seems to be no "official" way to find the loss node when using estimators (or I could not find one). For the DNNRegressor this node has the name dnn/head/weighted_loss/Sum. Here is how to find it in the graph:
- Start tensorboard in the model directory. In my case I did not set any directory, so the estimator used a temporary directory and printed this line: WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpInj8SC. Then start tensorboard: tensorboard --logdir /tmp/tmpInj8SC
- Open it in a browser and navigate to the GRAPHS tab.
- Find the loss in the graph. Expand blocks in the sequence dnn → head → weighted_loss and click on the Sum node (note that there is a summary node named loss connected to it).
- The name shown in the info "window" on the right is the name of the selected node; this is what needs to be passed to the monitor argument of EarlyStoppingHook.

The loss node of DNNClassifier has the same name by default. Both DNNClassifier and DNNRegressor have an optional argument loss_reduction that influences the loss node name and behavior (it defaults to losses.Reduction.SUM).
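If you would rather search for candidate loss nodes programmatically instead of clicking through TensorBoard, you can filter the graph's node names. A minimal sketch, where `find_loss_nodes` is a hypothetical helper and `node_names` stands in for `[n.name for n in tf.get_default_graph().as_graph_def().node]`:

```python
def find_loss_nodes(node_names):
    """Return the node names that look like loss nodes.

    node_names stands in for the list
    [n.name for n in tf.get_default_graph().as_graph_def().node].
    """
    return [name for name in node_names if 'loss' in name.lower()]


# With names like those in the DNNRegressor graph described above:
print(find_loss_nodes(['dnn/head/weighted_loss/Sum',
                       'dnn/logits/BiasAdd',
                       'loss']))
```

This only narrows down the candidates; you still have to confirm the right node, e.g. by checking which one feeds the loss summary.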
EDIT 2:
There is a way to find the loss without looking at the graph: you can use the GraphKeys.LOSSES collection. But this only works after training has started, so you can only use it inside a hook. For example, you can remove the monitor argument from the EarlyStoppingHook class and change its begin function to always use the first loss in the collection:

```python
self.monitor = tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)[0]
```

You may also want to check that a loss actually exists in the collection.
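That existence check can be sketched without TensorFlow. Here `first_loss` is a hypothetical helper, and `losses` stands in for the result of `tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)`:

```python
def first_loss(losses):
    """Pick the first entry of a loss collection, failing loudly if empty."""
    if not losses:
        raise RuntimeError('No loss found in the GraphKeys.LOSSES collection; '
                           'the model function may not have been called yet.')
    return losses[0]


# In the hook's begin() this would be used as:
#   self.monitor = first_loss(
#       tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES))
print(first_loss(['total_loss']))
```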