训练后,TensorFlow 始终会收敛到所有项目的相同输出

2023-12-05

这是我正在使用的代码片段:

import tensorflow as tf
import numpy as np
from PIL import Image
from os import listdir

nodes_l1 = 500
nodes_l2 = 100
nodes_l3 = 500
num_batches = 20
num_epochs = 50

# Array of file dirs
human_file_array = listdir('human/')
human_file_array = [['human/'+human_file_array[i],[1,0]] for i in range(len(human_file_array))]
cucumber_file_array = listdir('cucumber/')
cucumber_file_array = [['cucumber/'+cucumber_file_array[i],[0,1]] for i in range(len(cucumber_file_array))]
file_array_shuffled = human_file_array + cucumber_file_array
np.random.shuffle(file_array_shuffled)

htest_file_array = listdir('human_test/')
htest_file_array = [['human_test/'+htest_file_array[i],[1,0]] for i in range(len(htest_file_array))]
ctest_file_array = listdir('cucumber_test/')
ctest_file_array = [['cucumber_test/'+ctest_file_array[i],[0,1]] for i in range(len(ctest_file_array))]
test_file_array = ctest_file_array + htest_file_array
np.random.shuffle(test_file_array)

input_data = tf.placeholder('float', [None, 250*250*3]
output_data = tf.placeholder('float')

hl1_vars = {
    'weight': tf.Variable(tf.random_normal([250*250*3, nodes_l1])),
    'bias': tf.Variable(tf.random_normal([nodes_l1]))
}

hl2_vars = {
    'weight': tf.Variable(tf.random_normal([nodes_l1, nodes_l2])),
    'bias': tf.Variable(tf.random_normal([nodes_l2]))
}

hl3_vars = {
    'weight': tf.Variable(tf.random_normal([nodes_l2, nodes_l3])),
    'bias': tf.Variable(tf.random_normal([nodes_l3]))
}

output_layer_vars = {
    'weight': tf.Variable(tf.random_normal([nodes_l3, 2])),
    'bias': tf.Variable(tf.random_normal([2]))
}

layer1 = tf.add(tf.matmul(input_data, hl1_vars['weight']),hl1_vars['bias'])
layer1 = tf.nn.softmax(layer1)

layer2 = tf.add(tf.matmul(layer1, hl2_vars['weight']), hl2_vars['bias'])
layer2 = tf.nn.softmax(layer2)

layer3 = tf.add(tf.matmul(layer2, hl3_vars['weight']), hl3_vars['bias'])
layer3 = tf.nn.softmax(layer3)

output = tf.add(tf.matmul(layer3, output_layer_vars['weight']), output_layer_vars['bias'])
output = tf.nn.softmax(output)

def convert_image(path):
    with Image.open(path) as img:
        img = img.resize((250,250))
        img = img.convert('RGB')
        return img

def train_network():
    #prediction = output
    cost = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(output, output_data)) # output is the prediction, output_data is key
    optimizer = tf.train.AdamOptimizer().minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver()

        for epoch in range(num_epochs):
            epoch_error = 0
            batch_size = int((len(file_array_shuffled)/num_batches))
            for i in range(num_batches):
                path_var = []
                key_var = []
                img_var = []
                #Still Filename Batch!!
                batch_file_array = file_array_shuffled[batch_size*i:(batch_size*i)+batch_size] #batch1['file&val array']['val']
                for batch_val in batch_file_array:
                    path_var.append(batch_val[0])
                    key_var.append(batch_val[1])
                #FROM HERE ON path_var AND key_var HAVE MATCHING INDEXES DO NOT RANDOMIZE!!!

                #This section here is complicated!
                for path in path_var:
                    img = convert_image(path)
                    img_var.append(np.reshape(np.array(img), 250*250*3))
                #print np.shape(img_var),np.shape(key_var) #img_var is array of size (batch#, 64*64*3) key_var is the key [human, cucumber]

                #End of complicationimage conversion
                _,c = sess.run([optimizer, cost], feed_dict={input_data:img_var, output_data:key_var})
                epoch_error += c
                #print "Batch",i+1,"done out of",num_batches
            print "Epoch",epoch+1,"completed out of",num_epochs,"\tError",epoch_error
            save_path = saver.save(sess, "model.ckpt")

train_network()


def use_network():
    #prediction = output
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        saver = tf.train.Saver()
        saver.restore(sess, "model.ckpt")

        for test_file in test_file_array:
            #print test_file
            img = np.reshape(np.array(convert_image(test_file[0])), 250*250*3)
            result = output.eval(feed_dict={input_data:[img]})
            print result,tf.argmax(result,1).eval(),test_file[1]

use_network()

http://pastebin.com/Gp6SVYJR

由于我对使用张量流还很陌生,因此我认为尝试创建一个可以识别人类和黄瓜之间差异的程序是个好主意。我从Image-Net中提取图像,并将人类图片放入human/和黄瓜照片到黄瓜/

我创建了一个我认为该程序正在执行的步骤列表:

  1. 创建文件路径和键的数组,然后进行打乱。

  2. 批量创建文件路径。

  3. 批次中的文件路径将转换为图像,调整大小为 250x250,并添加到图像批次数组中。(此时键和图像仍对齐)。

  4. 图像批次和关键批次输入到阵列中。

  5. 在所有 epoch 结束时,它会针对每个图像 10 个来测试网络。

当我运行 use_network() 时,我在控制台中得到以下输出:

[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.61422414  0.38577583]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [1, 0]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [0, 1]
[[ 0.53653401  0.46346596]] [0] [1, 0]

第一个数组是输出节点,第二个数组是输出的 tf.argmax(),第三个数组是预期的数组。

实际的学习似乎也相当小,这是学习的输出:

Epoch 1 completed out of 50     Error 3762.83390808
Epoch 2 completed out of 50     Error 3758.51748657
Epoch 3 completed out of 50     Error 3753.70425415
Epoch 4 completed out of 50     Error 3748.32539368
Epoch 5 completed out of 50     Error 3742.45524597
Epoch 6 completed out of 50     Error 3736.21272278
Epoch 7 completed out of 50     Error 3729.56756592
...
Epoch 45 completed out of 50    Error 3677.34605408
Epoch 46 completed out of 50    Error 3677.34388733
Epoch 47 completed out of 50    Error 3677.34150696
Epoch 48 completed out of 50    Error 3677.3391571
Epoch 49 completed out of 50    Error 3677.33673096
Epoch 50 completed out of 50    Error 3677.33418274

我尝试做以下事情来尝试改变事情:

  1. 使图像更小,例如 32x32,和/或黑白。看看较小的图像是否会导致预测发生变化。

  2. 更改reduce_sum和reduce_mean之间的成本方程,以及sigmoid_cross_entropy到softmax_cross_entropy之间的内部方程。

关于它为什么不起作用的原因,我有一些想法,如下:

  1. 只是糟糕的代码

  2. 输入数据太大,没有足够的节点/层来处理。

  3. 图像和相关密钥在某处被打乱。


我认为这可能存在一些问题。首先,您正在使用密集连接的层来处理大型图像网络图像。您应该对图像使用卷积网络。我认为这是你最大的问题。只有在应用卷积/池化层金字塔将空间维度减少为“特征”之后,才应该添加密集层。

https://www.tensorflow.org/versions/r0.11/tutorials/deep_cnn/index.html

其次,即使您打算使用密集层,也不应该将 softmax 函数用作隐藏层之间的激活(有一些例外,例如在注意力模型中,但这是一个更高级的概念。)Softmax 强制每个激活的总和在您可能不想要的图层中。我会将隐藏层之间的激活更改为 relu 或至少 tanh。

最后,我发现当网络接近恒定值时,它可以帮助降低学习率。但我认为这不是你的问题。我的前两条评论是你应该关注的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

训练后,TensorFlow 始终会收敛到所有项目的相同输出 的相关文章

随机推荐

  • Bash 脚本 - Do-While 循环中的变量作用域

    我有一个 do while 循环 我在其中向自身添加一个变量 while read line do let variable variable someOtherVariable done return variable 当我回显 vari
  • 使用 C# .accdb 文件的 Microsoft Access 压缩和修复

    我需要使用 C 压缩并修复 accdb 最后一个 MS Access 版本 我尝试使用这个 var jroEngine new JRO JetEngineClass var old Provider Microsoft ACE OLEDB
  • 用 rpy 制作的图发送到 X11 突然关闭?

    我正在使用 RPy2 来绘制一些图 绘图显示 但 X11 窗口立即消失 我输入的内容如下CCFS是一个数据矩阵 import rpy2 robjects as robjects r robjects r pca r princomp CCF
  • 将 mongo ObjectId 转换为字符串并将其用于 URL 可以吗?

    document show id 4cf8ce8a8aad6957ff00005b 一般来说 我认为您应该谨慎向客户端公开内部结构 例如数据库 ID URL 很容易被操纵 并且用户可能访问您不希望他访问的对象 特别是对于 MongoDB 对
  • SQL Server:如何获取排它锁以防止竞争条件?

    我有以下 T SQL 代码 SET TRANSACTION ISOLATION LEVEL SERIALIZABLE BEGIN TRANSACTION T1 Test This is a dummy table used for lock
  • PHP 多个复选框删除

    我很难解决删除多个复选框的问题 有人可以指导我找到解决方案吗 这里应该发生的是 用户可以勾选复选框并单击删除按钮来删除勾选的框 不幸的是 我的代码似乎不起作用 你能为我指出正确的方向吗 div class page img class pa
  • 为什么必须声明 Typescript 的环境接口实现?

    我有一些接口及其实现的定义 每个实现类都必须声明许多方法 我发现它乏味且多余 因为它只是一个定义 是否只是缺乏时间来实现此功能 或者为什么应该强制执行环境实现定义背后的某些想法 或者我错过了什么 UPDATE 我现在不喜欢我的问题 它是从一
  • 这是批处理文件注入吗?

    C gt batinjection OFF DEL c c batinjection bat 的内容为ECHO 我听说过 SQL 注入 虽然我从未真正做过 但这就是注入吗 有不同类型的注射吗 这是其中之一吗 或者还有另一个技术术语吗 或者更
  • 如何覆盖 AWS-SDK-CPP 中的端点以连接到 localhost:9000 处的 minio 服务器

    我尝试过类似的东西 Aws Client ClientConfiguration config config endpointOverride Aws String localhost 9000 这是行不通的 看来AWS SDK CPP默认
  • Pyspark 按另一个数据帧的列过滤数据帧

    不知道为什么我在这方面遇到困难 考虑到在 R 或 pandas 中相当容易做到 它看起来很简单 我想避免使用 pandas 因为我正在处理大量数据 而且我相信toPandas 将所有数据加载到 pyspark 中的驱动程序内存中 我有 2
  • 使用 Jquery 在导航菜单中突出显示父链接

    我使用以下 Jquery 在导航中突出显示当前页面的链接 Add Active Class To Current Link var url window location get current URL nav a href url add
  • 如何在Python中将浮点数格式化为固定宽度

    如何将浮点数格式化为固定宽度并满足以下要求 如果 n 添加尾随小数零以填充固定宽度 截断超过固定宽度的小数位 对齐所有小数点 例如 formatter something like 06 numbers 23 23 0 123334987
  • Apex 5:动态操作设置页面项值

    使用新的 apex 5 版本时 我遇到以下问题 无法通过plsql获取页面项的值 nv P2 TO P2 FROM lt lt lt DOESN T WORK I Yes P FROM exist and verified nv P2 TO
  • SPSS 按行分组并将字符串连接成一个变量

    我试图导出 SPSS 元数据使用 SPSS 语法转换为自定义格式 具有值标签的数据集包含一个或多个变量标签 但是 现在我想将每个变量的值标签连接成一个字符串 例如对于变量SEX将行组合或分组F Female and M Male转化为一个变
  • 在 R 中的 ggplot2 地图上叠加栅格图层?

    我正在尝试将栅格图层叠加到 ggplot 中的地图上 栅格图层包含卫星标签中每个时间点的似然面 我还想在栅格图层上设置累积概率 95 75 50 我已经弄清楚如何在 ggplot 地图上显示栅格图层 但坐标未彼此对齐 我尝试使每个都有相同的
  • 如何在 Netbeans 中使用 -g 选项进行编译?

    调试时 我收到一条有关异常的警告消息 variable info not available compiled without g 如何在 netbeans 中设置使用 g 进行编译 thanks 据我所知你的own代码是用调试信息编译的
  • MVC3 EF 工作单元 + 通用存储库 + Ninject

    我是 MVC3 的新手 一直在关注 asp net 网站上的精彩教程 然而 我不太清楚如何将工作单元和通用存储库模式与 Ninject 结合使用 我使用本教程作为起点 http www asp net mvc tutorials getti
  • 使用设备策略控制器在后台升级应用程序

    我有一个正在运行的 DPC 应用程序 它是设备所有者 我已在两台不同的 Android 6 0 1 设备上尝试过此操作 以排除任何设备 制造商问题 I used adb shell dpm set device owner com exam
  • VB.NET 中默认启用选项 Strict

    每当我创建一个新的 VB NET 程序时 我必须进入该项目的属性并将 Option strict 设置为打开 我可以这样做一次 这样每次创建新项目时它都是默认的吗 在 Visual Studio 中 转到菜单Tools gt Options
  • 训练后,TensorFlow 始终会收敛到所有项目的相同输出

    这是我正在使用的代码片段 import tensorflow as tf import numpy as np from PIL import Image from os import listdir nodes l1 500 nodes