如何在 Java 中为 TensorFlow DNNRegressor 提供输入?

2024-01-06

我设法使用 DNNRegressor 编写了 TensorFlow python 程序。我已经训练了模型,并且能够通过手动创建的输入(常量张量)从 Python 中的模型中获得预测。我还能够以二进制格式导出模型。

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util

#######################
# Setup
#######################

# Converting Data into Tensors
def input_fn(df, training = True):
    # Creates a dictionary mapping from each continuous feature column name (k) to
    # the values of that column stored in a constant Tensor.
    continuous_cols = {k: tf.constant(df[k].values)
                    for k in continuous_features}

    feature_cols = dict(list(continuous_cols.items()))

    if training:
        # Converts the label column into a constant Tensor.
        label = tf.constant(df[LABEL_COLUMN].values)

        # Returns the feature columns and the label.
        return feature_cols, label

    # Returns the feature columns    
    return feature_cols

def train_input_fn():
    return input_fn(train_df)

def eval_input_fn():
    return input_fn(evaluate_df)

#######################
# Data Preparation
#######################
df_train_ori = pd.read_csv('training.csv')
df_test_ori = pd.read_csv('test.csv')
train_df = df_train_ori.head(10000)
evaluate_df = df_train_ori.tail(5)
test_df = df_test_ori.head(1)
MODEL_DIR = "/tmp/model"
BIN_MODEL_DIR = "/tmp/modelBinary"
features = train_df.columns
continuous_features = [feature for feature in features if 'label' not in feature]
LABEL_COLUMN = 'label'

engineered_features = []

for continuous_feature in continuous_features:
    engineered_features.append(
        tf.contrib.layers.real_valued_column(
            column_name=continuous_feature,
            dimension=1,
            default_value=None,
            dtype=tf.int64,
            normalizer=None
            ))


#######################
# Define Our Model
#######################
regressor = tf.contrib.learn.DNNRegressor(
    feature_columns=engineered_features,
    label_dimension=1,
    hidden_units=[128, 256, 512], 
    model_dir=MODEL_DIR
    )

#######################
# Training Our Model
#######################
wrap = regressor.fit(input_fn=train_input_fn, steps=5)

#######################
# Evaluating Our Model
#######################
results = regressor.evaluate(input_fn=eval_input_fn, steps=1)
for key in sorted(results):
    print("%s: %s" % (key, results[key]))

#######################
# Save binary model (to be used in Java)
#######################
tfrecord_serving_input_fn = tf.contrib.learn.build_parsing_serving_input_fn(tf.contrib.layers.create_feature_spec_for_parsing(engineered_features)) 
regressor.export_savedmodel(
    export_dir_base=BIN_MODEL_DIR, 
    serving_input_fn = tfrecord_serving_input_fn,
    assets_extra=None,
    as_text=False,
    checkpoint_path=None,
    strip_default_attrs=False)

我的下一步是将模型加载到 java 中并做出一些预测。然而,我在用 Java 指定模型的输入时遇到了问题。

import org.tensorflow.*;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import java.util.List;
import java.util.Map;

public class ModelEvaluator {
    public static void main(String[] args) throws Exception {
        System.out.println("Using TF version: " + TensorFlow.version());

        SavedModelBundle model = SavedModelBundle.load("/tmp/modelBinary/1546510038", "serve");
        Session session = model.session();

        printSignature(model);
        printAllNodes(model);

        float[][] km1 = new float[1][1];
        km1[0][0] = 10;
        Tensor inKm1 = Tensor.create(km1);

        float[][] km2 = new float[1][1];
        km2[0][0] = 10000;
        Tensor inKm2 = Tensor.create(km2);

        List<Tensor<?>> outputs = session.runner()
                .feed("dnn/input_from_feature_columns/input_from_feature_columns/km1/ToFloat", inKm1)
                .feed("dnn/input_from_feature_columns/input_from_feature_columns/km2/ToFloat", inKm2)
                .fetch("dnn/regression_head/predictions/Identity:0")
                .run();

        System.out.println("\n\nOutputs from evaluation:");
        for (Tensor<?> output : outputs) {
            if (output.dataType() == DataType.STRING) {
                System.out.println(new String(output.bytesValue()));
            } else {
                float[] outArray = new float[1];
                output.copyTo(outArray);
                System.out.println(outArray[0]);
            }
        }
    }

    public static void printAllNodes(SavedModelBundle model) {
        model.graph().operations().forEachRemaining(x -> {
            System.out.println(x.name() + "   " + x.numOutputs());
        });
    }


    /**
     * This info can also be obtained from a command prompt via the command:
     * saved_model_cli show  --dir <dir-to-the-model> --tag_set serve --signature_def serving_default
     * <p>
     * See this where they also try to input data to a DNN regressor:
     * https://github.com/tensorflow/tensorflow/issues/12367
     * <p>
     * https://github.com/tensorflow/tensorflow/issues/14683
     * <p>
     * https://github.com/migueldeicaza/TensorFlowSharp/issues/293
     */
    public static void printSignature(SavedModelBundle model) throws Exception {
        MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
        SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
        int numInputs = sig.getInputsCount();
        int i = 1;
        System.out.println("-----------------------------------------------");
        System.out.println("MODEL SIGNATURE");
        System.out.println("Inputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
        }
        int numOutputs = sig.getOutputsCount();
        i = 1;
        System.out.println("Outputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
        }
        System.out.println("-----------------------------------------------");
    }
}

从java代码中可以看出,我为两个节点提供了输入(用“km1”和“km2”命名)。但我想这不是正确的方法。我猜我需要为节点“input_example_tensor:0”提供输入?

所以问题是:如何实际为加载到 java 中的模型创建输入?在Python中,我必须创建一个带有键“km1”和“km2”的字典,并值两个常量张量。


在 Python 上,尝试

feature_spec = tf.feature_column.make_parse_example_spec(columns)
example_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

请查看build_parsing_serving_input_receiver_fn,以及一个名为输入示例张量需要序列化的 tf.Example。

在 Java 上,尝试创建一个Example http://static.javadoc.io/org.tensorflow/proto/1.12.0/org/tensorflow/example/package-summary.html输入(封装在org.tensorflow:原型工件 https://mvnrepository.com/artifact/org.tensorflow/proto/1.12.0),以及一些这样的代码:

public static void main(String[] args) {
    Example example = buildExample(yourFeatureNameAndValueMap);
    byte[][] exampleBytes = {example.toByteArray()};
    try (Tensor<String> inputBatch = Tensors.create(exampleBytes);
         Tensor<Float> output =
                 yourSession
                         .runner()
                         .feed(yourInputsName, inputBatch)
                         .fetch(yourOutputsName)
                         .run()
                         .get(0)
                         .expect(Float.class)) {
        long[] shape = output.shape();
        int batchSize = (int) shape[0];
        int labelNum = (int) shape[1];
        float[][] resultValues = output.copyTo(new float[batchSize][labelNum]);
        System.out.println(resultValues);
    }
}

public static Example buildExample(Map<String, ?> yourFeatureNameAndValueMap) {
    Features.Builder builder = Features.newBuilder();
    for (String attr : yourFeatureNameAndValueMap.keySet()) {
        Object value = yourFeatureNameAndValueMap.get(attr);
        if (value instanceof Float) {
            builder.putFeature(attr, feature((Float) value));
        } else if (value instanceof float[]) {
            builder.putFeature(attr, feature((float[]) value));
        } else if (value instanceof String) {
            builder.putFeature(attr, feature((String) value));
        } else if (value instanceof String[]) {
            builder.putFeature(attr, feature((String[]) value));
        } else if (value instanceof Long) {
            builder.putFeature(attr, feature((Long) value));
        } else if (value instanceof long[]) {
            builder.putFeature(attr, feature((long[]) value));
        } else {
            throw new UnsupportedOperationException("Not supported attribute value data type!");
        }
    }
    Features features = builder.build();
    Example example = Example.newBuilder()
            .setFeatures(features)
            .build();
    return example;
}

private static Feature feature(String... strings) {
    BytesList.Builder b = BytesList.newBuilder();
    for (String s : strings) {
        b.addValue(ByteString.copyFromUtf8(s));
    }
    return Feature.newBuilder().setBytesList(b).build();
}

private static Feature feature(float... values) {
    FloatList.Builder b = FloatList.newBuilder();
    for (float v : values) {
        b.addValue(v);
    }
    return Feature.newBuilder().setFloatList(b).build();
}

private static Feature feature(long... values) {
    Int64List.Builder b = Int64List.newBuilder();
    for (long v : values) {
        b.addValue(v);
    }
    return Feature.newBuilder().setInt64List(b).build();
}

如果你想要自动获取您的输入名称 and 你的输出名称, 你可以试试

SignatureDef signatureDef;
try {
    signatureDef = MetaGraphDef.parseFrom(model.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF_KEY);
} catch (InvalidProtocolBufferException e) {
    throw new RuntimeException(e.getMessage(), e);
}
String yourInputsName = signatureDef.getInputsOrThrow(SIGNATURE_DEF_INPUT_KEY).getName();
String yourOutputsName = signatureDef.getOutputsOrThrow(SIGNATURE_DEF_OUTPUT_KEY).getName();

关于java,请参考检测对象.java https://github.com/tensorflow/models/blob/master/samples/languages/java/object_detection/src/main/java/DetectObjects.java。关于Python,请参考宽_深 https://github.com/tensorflow/models/tree/master/official/wide_deep

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 Java 中为 TensorFlow DNNRegressor 提供输入? 的相关文章

随机推荐

  • 关于 iPhone 应用程序包对于 App Store 的合理大小的问题。存储内存!

    我计划将我的新应用程序提交到 App Store App包含大量图像资源 动画 超过40M App Store 对应用程序包大小有正式限制吗 我从来没有见过这样的情况 认为应该没问题 这样对吗 我的主要问题是 如果应用程序很大 可用性是否会
  • Vector 是一个过时的集合

    检查报告 java util Vector 或 java util hashtable 的任何使用 虽然仍然受支持 但这些类已被 JDK 1 2 Collection 类废弃 并且可能不应该在新的开发中使用 我有一个 Java 项目 它使用
  • 如何将缺失的行插入到该数据集中?

    我想做的是每当缺少一行时将记录插入到数据集中 如果您查看上面的数据集 它包含 3 列属性 然后是 2 个数值 第三列 TTF 是增量的 不应跳过任何值 在此示例中 缺少显示在底部的 2 行 因此 我希望我的代码执行的操作是将这两行插入到结果
  • 如果长度 > 5 如何修剪数组

    如果长度 gt 5 如何修剪数组 我的 JSON 是 name aaa files name A link string com name Q link string com name M link string com
  • 如何删除 Github 网络视图中显示的 git 中的未命名分支

    在我的 git 存储库的 Github 网络视图中 有一个没有名称的 幻影 分支 请看下图 为了简单起见 我想删除黑色分支 只留下蓝色分支 如何才能做到这一点 一些带有哈希值的标签 假设黑色分支从提交 A 开始 到提交 Z 结束 A 和 Z
  • 请帮助我完成康威生命游戏的基本 java 实现

    我花了很长时间试图编写一个程序来实现康威的生命游戏 链接更多信息 http en wikipedia org wiki Conway 27s Game of Life 我正在遵循一些在线指南 并获得了大部分功能 我编写了如下所示的 next
  • Xamarin 表单向左滑动/向右滑动手势

    我想先说一下 我对移动开发 Xamarin C Net 完全陌生 我正在使用 Xamarin Forms 创建移动应用程序 但遇到了无法使用滑动手势的问题 至少根据我看到的文档是这样 我找到了这个网站 http arteksoftware
  • 如何使用CSS替换PNG图像的颜色? [复制]

    这个问题在这里已经有答案了 我在网页中有一个图标 div class icon container img src img gavel3 png class gavel icon style width 80px div 我正在尝试用颜色替
  • 如何使用向量化从数组中选择最接近数组中值的值?

    我有一个值数组 我想根据线性最接近的选择从一系列选择中替换这些值 问题是选择的大小是在运行时定义的 import numpy as np a np array 0 0 0 4 4 4 9 9 9 choices np array 1 5 1
  • 设置 Fullcalendar 单元格背景颜色

    我看到了几个关于如何在全日历中设置单元格背景颜色的主题 但它们都不适合我 我猜日历曾经使用日期来列出日期 例如 fc day5 或 fc day17 但在版本 1 6 2 中不再这样做了 我有一个正在渲染的多个事件的列表 我想将它们的单元格
  • 无法访问 Facebook 活动

    我无法使用 Facebook Graph API 获取我的营销活动列表 回应 me adaccounts data account id 123456789000001234 id act 123456789000001234 paging
  • Google Cloud 虚拟实例试用后消失?

    我创建了两个虚拟机实例 审判结束后他们就消失了 我已经在计算引擎菜单中搜索 但找不到任何内容 你知道我是否可以恢复它们或者我能做什么吗 试用结束后 您在试用期间创建的资源将停止 但如果您在 30 天内升级到付费帐户 则可以恢复 在这 30
  • 如何向 .NET 4.5 WCF 服务添加异步支持,使其不会中断现有客户端?

    我有一个带有 SOAP 端点的现有 WCF 服务 使用 NET 4 5 大多数现有的客户端代码都使用ChannelFactory
  • ProgramFiles64Folder 正在安装到 WIX 安装程序中的 \Program Files (x86)\

    我目前有两个 WIX 项目 一个用于创建 x86 安装程序 另一个用于创建 x64 安装程序 我想将这两个项目合并为一个使用变量来控制程序流程的项目 我有以下内容
  • 编码 UI 播放 - 在特定文本框中输入文本时抛出错误(使用 javaScript 过滤击键)

    我刚刚开始编写一些编码的 ui 测试 当我尝试在文本框中输入一个值时 我在播放过程中遇到了问题 该文本框通过 javaScript 函数仅限于数字 十进制 值 我已将该脚本确定为 罪魁祸首 因为测试在禁用该脚本时成功运行 我在测试中输入的值
  • CRTP:基于派生类内容启用基类中的方法

    有没有办法从 CRTP 基类查询派生类的内容 与 SFINAE 一起使用来启用或禁用基类方法 我想要完成的事情可能如下所示 template
  • Excel 互操作打印

    我需要使用以下打印设置打印 Excel 工作表的选定区域 我使用 Range Select 选择的区域 打印机 Microsoft XPS 文档编写器打印选择景观方向 A4正常边距一页适合尺寸表 如何使用 Worksheet PrintOu
  • ggplot2facets:每个图的不同注释文本

    我有以下生成的数据框 称为 Raw Data Time Velocity Type 1 10 1 a 2 20 2 a 3 30 3 a 4 40 4 a 5 50 5 a 6 10 2 b 7 20 4 b 8 30 6 b 9 40 8
  • 如何模块化这个react状态容器?

    因此 在工作中 我们构建了这个很棒的状态容器挂钩 以便在 React 应用程序和相关包中使用 首先介绍一下这个钩子的一些背景 以及在开始我想用它做什么之前我想要保留的内容 这是工作代码 您会注意到它带有注释 可以轻松复制和粘贴以创建新的 S
  • 如何在 Java 中为 TensorFlow DNNRegressor 提供输入?

    我设法使用 DNNRegressor 编写了 TensorFlow python 程序 我已经训练了模型 并且能够通过手动创建的输入 常量张量 从 Python 中的模型中获得预测 我还能够以二进制格式导出模型 import pandas