使用估计器训练 Tensorflow 模型 (from_generator)

2024-01-11

我正在尝试使用生成器训练估计器,但我想为每次迭代提供该估计器的样本包。我展示代码:

def _generator():
for i in range(100):
    feats  = np.random.rand(4,2)
    labels = np.random.rand(4,1)

    yield feats, labels


def input_func_gen():
    shapes = ((4,2),(4,1))
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                         output_types=(tf.float32, tf.float32),
                                         output_shapes=shapes)
dataset = dataset.batch(4)
# dataset = dataset.repeat(20)
iterator = dataset.make_one_shot_iterator()
features_tensors, labels = iterator.get_next()
features = {'x': features_tensors}
return features, labels


x_col = tf.feature_column.numeric_column(key='x', shape=(4,2))
es = tf.estimator.LinearRegressor(feature_columns=[x_col],model_dir=tf_data)
es = es.train(input_fn=input_func_gen,steps = None)

当我运行此代码时,它会引发此错误:

    raise ValueError(err.message)
ValueError: Dimensions must be equal, but are 2 and 3 for 'linear/head/labels/assert_equal/Equal' (op: 'Equal') with input shapes: [2], [3].

我该如何调用这个结构?

thx!!!


批量大小由 Tensorflow 自动计算并添加到张量形状中,因此不必手动完成。您的生成器还应该定义为输出单个样本。

假设4形状的位置 0 是批量大小,然后:

import tensorflow as tf
import numpy

def _generator():
    for i in range(100):
        feats  = numpy.random.rand(2)
        labels = numpy.random.rand(1)

        yield feats, labels


def input_func_gen():
    shapes = ((2),(1))
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                         output_types=(tf.float32, tf.float32),
                                         output_shapes=shapes)
    dataset = dataset.batch(4)
    # dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels


x_col = tf.feature_column.numeric_column(key='x', shape=(2))
es = tf.estimator.LinearRegressor(feature_columns=[x_col])
es = es.train(input_fn=input_func_gen,steps = None)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用估计器训练 Tensorflow 模型 (from_generator) 的相关文章

随机推荐