如何使用经过训练的 Tensorflow 模型进行预测

2024-03-24

我已经创建并训练了一个神经网络,但我希望能够输入测试点并查看其结果(而不是使用评估函数)。

该模型运行良好,并且成本减少了每个时期,但我只想在末尾添加一行来传递一些输入坐标,并让它告诉我预测的转换坐标。

import tensorflow as tf
import numpy as np

def coordinate_transform(size, angle):
    input = np.random.rand(size, 2)
    output = np.zeros((size, 2))
    noise = 0.05*(np.add(np.random.rand(size) * 2, -1))
    theta = np.add(np.add(np.arctan(input[:,1] / input[:,0]) , angle) , noise)
    radii = np.sqrt(np.square(input[:,0]) + np.square(input[:,1]))
    output[:,0] = np.multiply(radii, np.cos(theta))
    output[:,1] = np.multiply(radii, np.sin(theta))
    return input, output

#Data
input, output = coordinate_transform(2000, np.pi/2)
train_in = input[:1000]
train_out = output[:1000]
test_in = input[1000:]
test_out = output[1000:]

# Parameters
learning_rate = 0.001
training_epochs = 15
batch_size = 1
display_step = 1

# Network Parameters
n_hidden_1 = 100 # 1st layer number of features
n_input = 2 # [x,y]
n_classes = 2 # output x,y coords

# tf Graph input
x = tf.placeholder("float", [1,n_input])
y = tf.placeholder("float", [1, n_input])

# Create model
def multilayer_perceptron(x, weights, biases):
    # Hidden layer with RELU activation
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    layer_1 = tf.nn.relu(layer_1)
    # Output layer with linear activation
    out_layer = tf.matmul(layer_1, weights['out']) + biases['out']
    return out_layer

# Store layers weight & bias
weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'out': tf.Variable(tf.random_normal([n_hidden_1, n_classes]))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}

# Construct model
pred = multilayer_perceptron(x, weights, biases)

# Define loss and optimizer
#cost = tf.losses.mean_squared_error(0, (tf.slice(pred, 0, 1) - x)**2 + (tf.slice(pred, 1, 1) - y)**2)
cost = tf.losses.mean_squared_error(y, pred)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer = optimizer.minimize(cost)

# Initializing the variables
#init = tf.global_variables_initializer()
init = tf.initialize_all_variables()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = 1000#int(len(train_in)/batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_x = train_in[i].reshape((1,2))
            batch_y = train_out[i].reshape((1,2))

            #print(batch_x.shape)
            #print(batch_y.shape)
            #print(batch_y, batch_x)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
                                                          y: batch_y})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if epoch % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", \
                "{:.9f}".format(avg_cost))
    print("Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    #Make predictions

好吧,“pred”操作是您的实际结果(因为它用于在计算损失时与 y 进行比较),因此类似以下内容应该可以解决问题:

print(sess.run([pred], feed_dict={x: _INPUT_GOES_HERE_ })

明显地_INPUT_GOES_HERE_需要替换为实际输入。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用经过训练的 Tensorflow 模型进行预测 的相关文章

随机推荐

  • Adobe AIR 执行程序

    我想按下 Adob e AIR 应用程序中的按钮并执行某些已安装的程序 例如 我有一个名为 Start Winamp 的按钮 当按下这个按钮时 它应该直接启动 Winamp exe 我不想执行一些命令行 我只想启动一个 exe 或者 是同一
  • CSS - 将文本添加到样式表中的样式

    我还没有找到任何文档 所以我认为这是不可行的 但值得一问 我可以在样式表内指定样式内的实际文本吗 我有几个地方在相同的 div 位置使用相同的文本 我没有使用 javascript 或在 div 中重新输入相同的文本 而是在考虑样式是否可以
  • Json.NET - 防止重新序列化已经序列化的属性[重复]

    这个问题已经存在了 在 ASP NET Web API 应用程序中 我正在使用的一些模型包含一块仅在客户端有用的临时 JSON 在服务器上 它只是作为字符串进出关系数据库 性能是关键 在服务器端处理 JSON 字符串似乎根本没有意义 所以在
  • 无法在不指定完整路径的情况下运行 python 脚本

    您好 我正在尝试直接从终端运行 python 脚本 为此 我已将包含代码的目录添加到我的环境 PATH 变量中 但是 当我指定完整路径时我可以运行 但当我只调用脚本时则不能运行 base DS home user abc my codes
  • 侦听器拒绝连接并出现以下错误:ORA-12505,TNS:侦听器当前不知道连接描述符中给出的 SID

    从昨天开始我的数据库已经工作一年多了 突然间 我无法再连接 我得到的错误是 Status Failure Test failed Listener refused the connection with the following erro
  • .NET 委托是否用于事件?

    我有点困惑 我知道委托就像函数指针 它们用于将函数作为参数传递到方法中 这如何融入事件模型 Calling myButton OnClick new 当事件发生时 内部是否只是将方法 函数作为参数传递 并且所有订阅者都收到有关该事件的通知
  • 有没有简单的方法可以在目标 c 中对一位数字的浮点数进行四舍五入?

    是的 你是对的 当然 这是一个重复的问题 在标记我的问题之前 请继续阅读下面的内容 我想四舍五入一个浮点值 即 56 6748939 to 56 7 56 45678 to 56 5 56 234589 to 56 2 实际上它可以是任意数
  • onBeforeRequest 侦听器中的异步调用替代方案

    对于我的 Chrome 扩展程序 我希望具有阻止请求功能 我有一个很大的域列表 10000 我正在考虑使用 IndexedDb 来存储域列表 但据我现在了解 不可能进行异步调用并在请求处理程序中返回结果 我最初的计划是 function r
  • iframe 被认为是“不好的做法”吗? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 在此过程中 我发现使用 iframe 是 不好的做法 这是真的 使用它们的优点 缺点是什么 与所有技术一样 它也有其优点和缺点 如果您使用 if
  • 向 ExpandoObject 添加与字符串同名的属性

    有没有办法向 ExpandoObject 添加与字符串值同名的属性 例如 如果我有 string propName ProductNumber dynamic obj new System Dynamic ExpandoObject 我可以
  • 如何选择每个部门的最高工资,包括赚取该工资的员工[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 给定一个表Employees EMP
  • 如何使用 Strawberry 在 Windows 上编译 Perl 模块?

    这更多的是一个公开的讨论和结论 而不是一个真正的问题 希望它能在某个时候帮助别人 我正在寻找如何在断开互联网的服务器上制作 Perl 模块 否则答案很简单 使用cpan 所以我唯一的选择就是直接在服务器上手动编译从互联网 CPAN或其他 下
  • 递归地循环遍历对象(树)

    有没有办法 在 jQuery 或 JavaScript 中 循环遍历每个对象及其子对象和孙对象等等 如果是这样 我也能读到他们的名字吗 Example foo bar child grand greatgrand and so on 所以循
  • 如何生成 Facebook Marketing API 访问令牌以在 Windows 应用程序中使用它

    我使用 Facebook 作为广告平台在 Apple 和 Google 商店上推广我的应用程序 我想制作一个 Windows 服务 该服务将下载有关我运行一个 Facebook 的营销活动的广告状态的每日报告 最好使用 60 天的令牌 或某
  • lambda 表达式语法与 LambdaExpression 类

    这行代码尝试将 lambda 表达式分配给LambaExpression http msdn microsoft com en us library system linq expressions lambdaexpression aspx
  • 如何在 ruby​​ 中将 ruby​​ 格式的 json 字符串转换为 json 哈希?

    我想像哈希对象一样访问 json 字符串 以便我可以使用像这样的键值访问 jsontemp anykey 如何将 ruby 格式的 json 字符串转换为 json 对象 我有以下 json 字符串 temp accept gt host
  • 添加性能计数器类别使计算机挂起

    我正在尝试从 ASP NET MVC 应用程序 在 Windows 8 x64 PC 上使用 VS 2012 添加性能计数器 但我遇到的问题是 如果我检查类别是否存在或添加新的性能计数器类别 计算机就会挂起 我的代码是 namespace
  • 模糊边缘检测

    我对图像处理和识别的背景知识很少 我正在尝试检测灰度图像 例如肖像 上的主要边缘 灰度过渡 问题是在某些部分 边缘模糊 因为焦点 我使用具有多个阈值的 Canny 边缘检测器 但我永远无法检测到这些边缘 下巴 衣服 耳朵 脸的侧面 Orig
  • 如何在Python中使用正则表达式替换字符串中的多个单词?

    我有一本字典 比如 dic xl xlarg l larg m medium 我想使用 re sub 或类似的方法查找 dic keys 中的任何字符串 包括单个字母 并将其替换为键的值 def multiple replace dict
  • 如何使用经过训练的 Tensorflow 模型进行预测

    我已经创建并训练了一个神经网络 但我希望能够输入测试点并查看其结果 而不是使用评估函数 该模型运行良好 并且成本减少了每个时期 但我只想在末尾添加一行来传递一些输入坐标 并让它告诉我预测的转换坐标 import tensorflow as