使用可用的训练挂钩在 tf.estimator.DNNRegressor 中实现提前停止

2024-01-08

我是张量流新手,想要实现提前停止tf.estimator.DNNRegressor带有可用的训练挂钩训练挂钩 https://www.tensorflow.org/api_guides/python/train#Training_Hooks对于 MNIST 数据集。如果在指定的步数内损失没有改善,早期停止钩子将停止训练。 Tensorflow 文档仅提供示例记录钩子 https://www.tensorflow.org/tutorials/layers#set_up_a_logging_hook。有人可以编写一个代码片段来实现它吗?


这里有一个EarlyStoppingHook示例实现:

import numpy as np
import tensorflow as tf
import logging
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests stop at a specified step."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode, RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def begin(self):
        # Convert names to tensors if given
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        current = run_values.results

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()

这个实现是基于Keras 实现 https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/python/keras/_impl/keras/callbacks.py.

与 CNN MNIST 一起使用example https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/layers/cnn_mnist.py创建钩子并将其传递给train.

early_stopping_hook = EarlyStoppingHook(monitor='sparse_softmax_cross_entropy_loss/value', patience=10)

mnist_classifier.train(
  input_fn=train_input_fn,
  steps=20000,
  hooks=[logging_hook, early_stopping_hook])

Here sparse_softmax_cross_entropy_loss/value是该示例中损失操作的名称。

EDIT 1:

使用估计器时似乎没有“官方”方法来查找损失节点(或者我找不到它)。

For the DNNRegressor该节点有名称dnn/head/weighted_loss/Sum.

以下是如何在图中找到它:

  1. 在模型目录中启动tensorboard。就我而言,我没有设置任何目录,因此估算器使用临时目录并打印此行:
    WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpInj8SC
    启动张量板:

    tensorboard --logdir /tmp/tmpInj8SC
    
  2. Open it in browser and navigate to GRAPHS tab. enter image description here

  3. Find loss in the graph. Expand blocks in the sequence: dnnheadweighted_loss and click on the Sum node (note that there is summary node named loss connected to it). enter image description here

  4. 右侧信息“窗口”中显示的名称是所选节点的名称,需要传递给monitor参数pfEarlyStoppingHook.

的损失节点DNNClassifier默认情况下具有相同的名称。两个都DNNClassifier and DNNRegressor有可选参数loss_reduction影响丢失节点名称和行为(默认为losses.Reduction.SUM).

EDIT 2:

有一种不看图表就能找到损失的方法。
您可以使用GraphKeys.LOSSES收集以获得损失。但这种方式只有在训练开始后才有效。所以你只能在钩子中使用它。

例如,您可以删除monitor论证来自EarlyStoppingHook类并改变它的begin函数始终使用集合中的第一个损失:

self.monitor = tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)[0]

您可能还需要检查集合中是否存在丢失。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用可用的训练挂钩在 tf.estimator.DNNRegressor 中实现提前停止 的相关文章

随机推荐