希望您已使用类似于下面提到的代码保存了估计器模型:
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
export_path = estimator.export_saved_model(
"/tmp/from_estimator/", serving_input_fn)
您可以使用下面提到的代码加载模型:
imported = tf.saved_model.load(export_path)
To Predict
通过传递输入特征来使用您的模型,您可以使用以下代码:
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
欲了解更多详情,请参阅这个链接 https://www.tensorflow.org/guide/saved_model#savedmodels_from_estimators其中解释了使用 TF Estimator 保存的模型。