在 Java 中加载 sklearn 模型。在 python 中使用 DNNClassifier 创建的模型

2024-03-06

目标是在 Java 中打开在 python 中创建/训练的模型tensorflow.contrib.learn.learn.DNNClassifier.

目前主要问题是知道在会话运行器方法上在 java 中给出的“张量”的名称。

我在 python 中有这个测试代码:

    from __future__ import division, print_function, absolute_import
import tensorflow as tf
import pandas as pd
import tensorflow.contrib.learn as learn
import numpy as np
from sklearn import metrics
from sklearn.cross_validation import train_test_split
from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.util.compat import as_text

print(tf.VERSION)

df = pd.read_csv('../NNNormalizeData-out.csv')

inputs = []
target = []

y=0;    
for x in df.columns:
    if y != 35 :
        #print("added %d" %y)
        inputs.append(x)
    else :
        target.append(x)
    y+=1

total_inputs,total_output = df.as_matrix(inputs).astype(np.float32),df.as_matrix([target]).astype(np.int32)

train_inputs, test_inputs, train_output, test_output = train_test_split(total_inputs, total_output, test_size=0.2, random_state=42)

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=train_inputs.shape[1],dtype=tf.float32)]
#target_column = [tf.contrib.layers.real_valued_column("output", dimension=train_output.shape[1])]

classifier = learn.DNNClassifier(hidden_units=[10, 20, 5], n_classes=5
                                 ,feature_columns=feature_columns)

classifier.fit(train_inputs, train_output, steps=100)

#Save Model into saved_model.pbtxt file (possible to Load in Java)
tfrecord_serving_input_fn = tf.contrib.learn.build_parsing_serving_input_fn(layers.create_feature_spec_for_parsing(feature_columns))  
classifier.export_savedmodel(export_dir_base="test", serving_input_fn = tfrecord_serving_input_fn,as_text=True)


# Measure accuracy
pred = list(classifier.predict(test_inputs, as_iterable=True))
score = metrics.accuracy_score(test_output, pred)
print("Final score: {}".format(score))

# test individual samples 
sample_1 = np.array( [[0.37671986791414125,0.28395908337619136,-0.0966095873607713,-1.0,0.06891621389763203,-0.09716678086712205,0.726029084013637,4.984689881073479E-4,-0.30296253267499107,-0.16192917054985334,0.04820256230479658,0.4951319883569152,0.5269983894210499,-0.2560313828048315,-0.3710980821053321,-0.4845867212612598,-0.8647234314469595,-0.6491591208322198,-1.0,-0.5004549422844073,-0.9880910165770813,0.5540293108747256,0.5625990251930839,0.7420121698556554,0.5445551415657979,0.4644276850235627,0.7316976292340245,0.636690006814346,0.16486621649984112,-0.0466018967678159,0.5261100063227044,0.6256168612312738,-0.544295484930702,0.379125782517193,0.6959368575211544]], dtype=float)
sample_2 = np.array( [[1.0,0.7982741870963959,1.0,-0.46270838239235024,0.040320274521029376,0.443451913224413,-1.0,1.0,1.0,-1.0,0.36689718911339564,-0.13577379160035796,-0.5162916256414466,-0.03373651520104648,1.0,1.0,1.0,1.0,0.786999801054777,-0.43856035121103853,-0.8199093927945158,1.0,-1.0,-1.0,-0.1134921695894473,-1.0,0.6420892436196663,0.7871737734493178,1.0,0.6501788845358409,1.0,1.0,1.0,-0.17586627413625022,0.8817194210401085]], dtype=float)

pred = list(classifier.predict(sample_2, as_iterable=True))
print("Prediction for sample_1 is:{} ".format(pred))

pred = list(classifier.predict_proba(sample_2, as_iterable=True))
print("Prediction for sample_2 is:{} ".format(pred))

创建 model_saved.pbtxt 文件。

我尝试使用以下代码在 Java 中加载此模型:

    public class HelloTF {
    public static void main(String[] args) throws Exception {
        SavedModelBundle bundle=SavedModelBundle.load("/java/workspace/APIJavaSampleCode/tfModels/dnn/ModelSave","serve");
        Session s = bundle.session();

        double[] inputDouble = {1.0,0.7982741870963959,1.0,-0.46270838239235024,0.040320274521029376,0.443451913224413,-1.0,1.0,1.0,-1.0,0.36689718911339564,-0.13577379160035796,-0.5162916256414466,-0.03373651520104648,1.0,1.0,1.0,1.0,0.786999801054777,-0.43856035121103853,-0.8199093927945158,1.0,-1.0,-1.0,-0.1134921695894473,-1.0,0.6420892436196663,0.7871737734493178,1.0,0.6501788845358409,1.0,1.0,1.0,-0.17586627413625022,0.8817194210401085};
        float [] inputfloat=new float[inputDouble.length];
        for(int i=0;i<inputfloat.length;i++)
        {
            inputfloat[i]=(float)inputDouble[i];
        }
        Tensor inputTensor = Tensor.create(new long[] {35}, FloatBuffer.wrap(inputfloat) );

        Tensor result = s.runner()
                .feed("input_example_tensor", inputTensor)
                .fetch("dnn/multi_class_head/predictions/probabilities")
                .run().get(0);


         float[] m = new float[5];
         float[] vector = result.copyTo(m);
         float maxVal = 0;
         int inc = 0;
         int predict = -1;
         for(float val : vector) 
         {
             System.out.println(val+"  ");
             if(val > maxVal) {
                 predict = inc;
                 maxVal = val;
             }
             inc++;
         }
         System.out.println(predict);



    }
} 

我在 .run().get(0); 上收到错误线 :

Exception in thread "main" org.tensorflow.TensorFlowException: Output 0 of type float does not match declared output type string for node _recv_input_example_tensor_0 = _Recv[_output_shapes=[[-1]], client_terminated=true, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=3663984897684684554, tensor_name="input_example_tensor:0", tensor_type=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"]()
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:285)
    at org.tensorflow.Session$Runner.run(Session.java:235)
    at tensorflow.HelloTF.main(HelloTF.java:35)

好吧,我终于解决了:主要问题是在java中使用的输入名称是“”dnn/input_from_feature_columns/input_from_feature_columns/concat”而不是“input_example_tensor”。

我使用图形导航发现了这一点:tensorboard --logdir=D:\python\Workspace\Autoencoder\src\dnn\ModelSave

这是java代码:

public class HelloTF {
public static void main(String[] args) throws Exception {
    SavedModelBundle bundle=SavedModelBundle.load("/java/workspace/APIJavaSampleCode/tfModels/dnn/ModelSave","serve");
    Session s = bundle.session();

    double[] inputDouble = {1.0,0.7982741870963959,1.0,-0.46270838239235024,0.040320274521029376,0.443451913224413,-1.0,1.0,1.0,-1.0,0.36689718911339564,-0.13577379160035796,-0.5162916256414466,-0.03373651520104648,1.0,1.0,1.0,1.0,0.786999801054777,-0.43856035121103853,-0.8199093927945158,1.0,-1.0,-1.0,-0.1134921695894473,-1.0,0.6420892436196663,0.7871737734493178,1.0,0.6501788845358409,1.0,1.0,1.0,-0.17586627413625022,0.8817194210401085};
    float [] inputfloat=new float[inputDouble.length];
    for(int i=0;i<inputfloat.length;i++)
    {
        inputfloat[i]=(float)inputDouble[i];
    }
FloatBuffer.wrap(inputfloat) );
    float[][] data= new float[1][35];
    data[0]=inputfloat;
    Tensor inputTensor=Tensor.create(data);


    Tensor result = s.runner()
            .feed("dnn/input_from_feature_columns/input_from_feature_columns/concat", inputTensor)
            //.feed("input_example_tensor", inputTensor)
            //.fetch("tensorflow/serving/classify")
            .fetch("dnn/multi_class_head/predictions/probabilities")
            //.fetch("dnn/zero_fraction_3/Cast")
            .run().get(0);


     float[][] m = new float[1][5];
     float[][] vector = result.copyTo(m);
     float maxVal = 0;
     int inc = 0;
     int predict = -1;
     for(float val : vector[0]) 
     {
         System.out.println(val+"  ");
         if(val > maxVal) {
             predict = inc;
             maxVal = val;
         }
         inc++;
     }
     System.out.println(predict);



}

}

我测试了输出:

蟒蛇方面:

Prediction for sample_2 is:[3] 
Prediction for sample_2 is:[array([ 0.17157166,  0.24475774,  0.16158019,  0.24648622,  0.17560424], dtype=float32)] 

Java端:

0.17157166  
0.24475774  
0.16158019   
0.24648622  
0.17560424  
3
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Java 中加载 sklearn 模型。在 python 中使用 DNNClassifier 创建的模型 的相关文章

  • 从文本文件中读取阿拉伯字符

    我完成了一个项目 在该项目中我读取了用记事本编写的文本文件 我的文本文件中的字符是阿拉伯语 文件编码类型是UTF 8 当在 Netbeans 7 0 1 中启动我的项目时 一切似乎都正常 但是当我将项目构建为 jar 文件时 字符以这种方式
  • 在 Java 中使用 Batik 检查和删除 SVG 中的属性

    这个问题基本上说明了一切 如何检查 SVG 是否具有 viewBox 属性 我正在使用蜡染库 我需要这个 因为我需要 至少 通知用户有一个 viewBox 属性 我可以删除它吗 使用 org w3c dom 类 您可以按照以下方式做一些事情
  • 在 jTextfield 中禁用“粘贴”

    我有一个用 Swing awt 编写的应用程序 我想阻止用户将值粘贴到文本字段中 有没有办法在不使用动作监听器的情况下做到这一点 您可以使用 null 参数调用 setTransferHandler 如下所示 textComponent s
  • 对象数组的数组(二维数组)JNI

    我正在努力创建自定义对象类型 ShareStruct 的二维数组 jobjectArray ret jobjectArray ins jobjectArray outs jclass myClass env gt FindClass env
  • 如何作为应用程序发布到页面?

    所以 我有一个应用程序 Facebook 应用程序实体 并且我有一个页面 我想使用应用程序通过java代码 通过restfb或任何其他建议 发布到页面 看起来我错过了页面授予应用程序发布权限的阶段 不知道该怎么做 谢谢你们 乌里 您只能 作
  • 重写 getPreferredSize() 会破坏 LSP

    我总是在这个压倒一切的网站上看到建议getPreferredSize 而不是使用setPreferredSize 例如 如前面的线程所示 对于固定大小的组件 使用重写 getPreferredSize 而不是使用 setPreferredS
  • Java 卡布局。多张卡中的一个组件

    一个组件 例如JLabel 在多张卡中使用CardLayout 目前看来该组件仅出现在它添加到的最后一张卡上 如果有办法做到这一点 我应该吗 这是不好的做法吗 或者有其他选择吗 你是对的 它只出现在 添加到的最后一张卡 中 但这与CardL
  • 运行 java -jar 时出现 java.lang.ClassNotFoundException

    我正在使用 ant 来构建我的build xml文件 它编译正常 但随后得到运行时java lang NoClassDefFoundError通过 运行生成的 jar 时java jar my jar jar 似乎这个问题出现了很多 但没有
  • 以有效的方式从 Map 中删除多个键?

    我有一个Map
  • 这个等待通知线程语义的真正目的是什么?

    我刚刚遇到一些代码 它使用等待通知构造通过其其他成员方法与类中定义的线程进行通信 有趣的是 获取锁后 同步范围内的所有线程都会在同一锁上进行定时等待 请参见下面的代码片段 随后 在非同步作用域中 线程执行其关键函数 即 做一些有用的事情1
  • PIL.Image.open和tf.image.decode_jpeg返回值的区别

    我使用 PIL Image open 和 tf image decode jpeg 将图像文件解析为数组 但发现PIL Image open 中的像素值与tf image decode jpeg不一样 为什么会出现这种情况 Thanks 代
  • 嵌套字段的 Comparator.comparing(...)

    假设我有一个这样的域模型 class Lecture Course course getters class Course Teacher teacher int studentSize getters class Teacher int
  • 为什么无法从 WEB-INF 文件夹内加载 POSModel 文件?

    我在我的 Web 项目中使用 Spring MVC 我将模型文件放在 WEB INF 目录中 String taggerModelPath WEB INF lib en pos maxent bin String chunkerModelP
  • C 与 C++ 中的 JNI 调用不同?

    所以我有以下使用 Java 本机接口的 C 代码 但是我想将其转换为 C 但不知道如何转换 include
  • 删除 JFX 中选项卡后面的灰色背景

    So is there any way to remove the gray area behind the tab s 我尝试过用 CSS 来做到这一点 但没有找到方法 要设置 tabpane 标题的背景颜色 请在 CSS 文件中写入 t
  • javax.media.jai 类的公共下载?

    这是一个非常简单的问题 我一直在寻找可以下载 javax media jai 库的地方 我找到了 jai imageio 库 但是我发现的所有其他 jai 内容要么已经过时 2008 年及之前 然后我遇到了登录屏幕 是否有 javax me
  • Android UnityPlayerActivity 操作栏

    我正在构建一个 Android 应用程序 其中包含 Unity 3d 交互体验 我已将 Unity 项目导入 Android Studio 但启动时该 Activity 是全屏的 并且不显示 Android 操作栏 我怎样才能做到这一点 整
  • 使用 PC/SC 读卡器验证 Ultralight EV1

    我在尝试使用 Java 中的 PC SC 读卡器 特别是 ACR1222L 验证 Ultralight EV1 卡时遇到问题 我能够使用 ISO 14443 3 标签的相应 APDU 在不受保护的标签上进行写入和读取 但是 我找不到运行 P
  • 如何使用自定义 JDK 构建 Jenkins 项目?

    我有一个常规的 Jenkins 实例 运行一些多分支管道 该实例在 JDK 11 上运行 因为 Jenkins 并不真正支持更高版本 没关系 但不好的是 我的所有管道似乎也都受到 Java 11 的限制 Jenkins 仅使用它自己也使用的
  • 决策树和规则引擎 (Drools)

    In the application that I m working on right now I need to periodically check eligibility of tens of thousands of object

随机推荐