你需要设置shuffle=False
因为要预测新标签,您需要维护数据顺序。
下面是我运行预测的代码(我已经测试过)。输入文件类似于测试数据(csv),但没有标签列。
def predict_input_fn(data_file):
global CSV_COLUMNS
CSV_COLUMNS = CSV_COLUMNS[:-1]
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine='python',
skiprows=1
)
# remove NaN elements
df_data = df_data.dropna(how='any', axis=0)
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
num_epochs=1,
shuffle=False
)
调用它:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
input_fn=predict_input_fn(predict_file_name)
)
for result in results:
print 'result: {}'.format(result)
一个样本的预测结果如下:
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
每个字段的含义是
-
“概率”:数组([0.78595656,0.21404342],dtype = float32).
它预测输出标签为 0 类(在本例中
-
'logits':数组([-1.3007226],dtype = float32)
方程 1/(1+e^(-z)) 中 z 的值为 -1.3。
-
'类':数组(['0'],dtype =对象)
类标签为0