Tensorflow,在另一个 tf.estimator model_fn 中使用 tf.estimator 训练的模型

2024-06-28

有没有办法在另一个模型 B 中使用 tf.estimator 训练的模型 A?

这是情况, 假设我有一个经过 model_a_fn() 训练的“模型 A”。 “模型 A”获取图像作为输入,并输出一些类似于 MNIST 分类器的向量浮点值。 还有另一个“模型 B”,在 model_b_fn() 中定义。 它还获取图像作为输入,并在训练“模型 B”时需要“模型 A”的矢量输出。

所以基本上我想训练“模型 B”,它需要输入作为“模型 A”的图像和预测输出。 (不再需要训练“模型A”,只需在训练“模型B”时获得预测输出)

我尝试过三种情况:

  1. 在 model_b_fn() 内使用估计器对象('Model A')
  2. 使用 tf.estimator.export_savedmodel() 导出“模型 A”,并创建预测函数。使用 params dict 将其传递给 model_b_fn() 。
  3. 与 2 相同,但在 model_b_fn() 中恢复“模型 A”

但所有情况都显示错误:

  1. ...必须来自与...相同的图表
  2. 类型错误:无法 pickle _thread.RLock 对象
  3. 类型错误:提要的值不能是 tf.Tensor 对象。

这是我使用的代码......仅附加重要部分

train_model_a.py

def model_a_fn(features, labels, mode, params):
    # ...
    # ...
    # ...
    return

def main():
    # model checkpoint location
    model_a_dir = './model_a'

    # create estimator for Model A
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # train Model A
    model_a.train(input_fn=lambda : input_fn_a)
    # ...
    # ...
    # ...

    # export model a
    model_a.export_savedmodel(model_a_dir, serving_input_receiver_fn=serving_input_receiver_fn)
    # exported to ./model_a/123456789
    return

if __name__ == '__main__':
    main()

train_model_b_case_1.py

# follows model_a's input format
def bypass_input_fn(x):
    features = {
        'x': x,
    }
    return features

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a = params['model_a']
    predictions = model_a.predict(
        input_fn=lambda: bypass_input_fn(inputs)
    )
    for results in predictions:
        # Error occurs!!!
        model_a_output = results['class_id']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, same, activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(10)

    # let say layern's output shape and model_a_output's output shape is same
    add_layer = tf.add(flatten, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a'
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model A
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a': model_a,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

train_model_b_case_2.py

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = params['model_a_predict_fn']
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, same, activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(10)

    # let say layern's output shape and model_a_output's output shape is same
    add_layer = tf.add(flatten, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a/123456789'
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model A
    # Error occurs!!!
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_predict_fn': model_a_predict_fn,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

train_model_b_case_3.py

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=params['model_a_dir'])
    # Error occurs!!!
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, same, activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(10)

    # let say layern's output shape and model_a_output's output shape is same
    add_layer = tf.add(flatten, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a/123456789'

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model A
    # Error occurs!!!
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_dir': model_a_dir,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

那么在另一个 tf.estimator 中使用经过训练的自定义 tf.estimator 有什么想法吗?


我已经找到了解决这个问题的方法。

如果遇到同样的问题,可以使用这种方法。

  1. 创建一个运行tensorflow.contrib.predictor.from_saved_model()的函数 -> 称之为“pretrained_predictor()”
  2. 在模型 B 的 model_fn() 内部,调用上面预定义的 'pretrained_predictor()' 用 tensorflow.py_func() 包装它

示例案例参见https://github.com/moono/tf-cnn-mnist/blob/master/4_3_estimator_within_estimator.py https://github.com/moono/tf-cnn-mnist/blob/master/4_3_estimator_within_estimator.py对于简单的用例。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow,在另一个 tf.estimator model_fn 中使用 tf.estimator 训练的模型 的相关文章

随机推荐