TensorFlow - 为什么这个 softmax 回归没有学到任何东西?

2024-05-06

我的目标是用 TensorFlow 做大事,但我正在尝试从小事做起。

我有一些小的灰度方块(有一点噪音),我想根据它们的颜色对它们进行分类(例如 3 个类别:黑色、灰色、白色)。我编写了一个小 Python 类来生成正方形和 1-hot 向量,并修改了它们的基本 MNIST 示例以将它们输入。

但它不会学到任何东西 - 例如对于 3 个类别,它的猜测总正确率约为 33%。

import tensorflow as tf
import generate_data.generate_greyscale

data_generator = generate_data.generate_greyscale.GenerateGreyScale(28, 28, 3, 0.05)
ds = data_generator.generate_data(10000)
ds_validation = data_generator.generate_data(500)
xs = ds[0]
ys = ds[1]
num_categories = data_generator.num_categories

x = tf.placeholder("float", [None, 28*28])
W = tf.Variable(tf.zeros([28*28, num_categories]))
b = tf.Variable(tf.zeros([num_categories]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,num_categories])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# let batch_size = 100 --> therefore there are 100 batches of training data
xs = xs.reshape(100, 100, 28*28) # reshape into 100 minibatches of size 100
ys = ys.reshape((100, 100, num_categories)) # reshape into 100 minibatches of size 100

for i in range(100):
  batch_xs = xs[i]
  batch_ys = ys[i]
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

xs_validation = ds_validation[0]
ys_validation = ds_validation[1]
print sess.run(accuracy, feed_dict={x: xs_validation, y_: ys_validation})

我的数据生成器如下所示:

import numpy as np
import random

class GenerateGreyScale():
    def __init__(self, num_rows, num_cols, num_categories, noise):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.num_categories = num_categories
        # set a level of noisiness for the data
        self.noise = noise

    def generate_label(self):
        lab = np.zeros(self.num_categories)
        lab[random.randint(0, self.num_categories-1)] = 1
        return lab

    def generate_datum(self, lab):
        i = np.where(lab==1)[0][0]
        frac = float(1)/(self.num_categories-1) * i
        arr = np.random.uniform(max(0, frac-self.noise), min(1, frac+self.noise), self.num_rows*self.num_cols)
        return arr

    def generate_data(self, num):
        data_arr = np.zeros((num, self.num_rows*self.num_cols))
        label_arr = np.zeros((num, self.num_categories))
        for i in range(0, num):
            label = self.generate_label()
            datum = self.generate_datum(label)
            data_arr[i] = datum
            label_arr[i] = label
        #data_arr = data_arr.astype(np.float32)
        #label_arr = label_arr.astype(np.float32)
        return data_arr, label_arr

对于初学者,尝试使用随机值(而不是零)初始化 W 矩阵 - 当所有输入的输出全为零时,您不会为优化器提供任何可处理的内容。

代替:

W = tf.Variable(tf.zeros([28*28, num_categories]))

Try:

W = tf.Variable(tf.truncated_normal([28*28, num_categories],
                                    stddev=0.1))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow - 为什么这个 softmax 回归没有学到任何东西? 的相关文章

  • 在 sympy 绘图中,如何获得具有固定纵横比的绘图?

    如果我用这个片段画一个圆 from sympy import x y symbols x y p1 plot implicit Eq x 2 y 2 1 aspect ratio 1 1 我会得到一个像这样的图形窗口 现在长宽比不是我所期望
  • goJS 下拉菜单删除项目

    我有简单的 python Flask goJS 图形应用程序 如下所示 节点和链接文本的源是从应用程序的后端加载的 我将它们设置为model modelData像这样的部分 var graphDataString JSON parse di
  • 如何在Python中找到低精度浮点值的原始文本表示?

    我遇到了显示问题floatPython 中的值 从外部数据源加载 它们是 32 位浮点数 但这也适用于较低精度的浮点数 以防万一 这些值是由人类在 C C 中输入的 因此与任意计算值不同 与round数字很 可能not预期的 但不能被忽略
  • Python 如果 kwargs 中的 key 并且 key 为 true

    if force in kwargs and kwargs force is True 感觉应该有更好的方法来编写这个条件 因为我重复了键和变量 假设您确实想检查返回的关键字参数是否is True 这是另一种稍微不同的方式 if kwarg
  • 如何从数据库模式自动生成示例 Django 应用程序?

    我正在评估概念验证应用程序的框架 该应用程序的生命周期约为 30 天 之后它将被遗忘或完全重写 我已确定要从现有数据库模式自动生成示例应用程序 然后调整视觉设计的某些方面 我看过一个演示红宝石 on Rails 它会为数据库中的每个表自动生
  • 如何在模型 Django 中创建必需:布尔字段

    我有一个模型 其中有一个名为的字段is student and is teacher Student and Teacher forms is teacher models BooleanField teacher status defau
  • tkinter 上的“NoneType”对象没有属性“get”错误[重复]

    这个问题在这里已经有答案了 我最近开始使用 python 3 6 进行编码tkinter并尝试创建我自己的项目repl it 该项目是一个简单的交互式待办事项列表 但是我陷入困境并且无法使该功能正常工作 该函数只是简单地获取条目并将其添加到
  • Python绕相机轴旋转图像

    假设我有一个图像 是在对某些原始图像应用单应性变换 H 后获得的 未显示原始图像 将单应性 H 应用于原始图像的结果是该图像 我想围绕合适的轴 可能是相机所在的位置 如果有的话 将此图像旋转 30 度以获得此图像 如果我不知道相机参数 如何
  • 如何逐行替换(更新)文件中的文本

    我试图通过读取每一行 测试它 然后写入是否需要更新来替换文本文件中的文本 我不想保存为新文件 因为我的脚本已经先备份文件并对备份进行操作 这是我到目前为止所拥有的 我从 os walk 获取路径 并且保证 pathmatch var 正确返
  • 将Python嵌入到C中——导入模块

    我在使用嵌入式 Python for C 时遇到问题文档 http docs python org extending embedding html 每当我尝试使用导入的模块时 我都会得到 PythonIncl exe 中 0x1e089e
  • PyInstaller 是否包含 CUDA

    我正在开发一个Python脚本 我使用Python 3 7 3 它使用tensorflow gpu 1 14 0 并使用PyInstaller 3 5将此脚本转换为可执行文件 我使用的是 CUDA 10 0 和 cuDNN 7 6 1 我的
  • UTF-8 解码如何知道字节边界?

    我一直在阅读大量有关 unicode 编码的文章 尤其是有关 Python 的文章 我想我现在对此已经有了相当深入的了解 但仍有一个小细节我有点不确定 解码如何知道字节边界 例如 假设我有一个带有两个 unicode 字符的 unicode
  • Tensorflow:Cuda 计算能力 3.0。所需的最低 Cuda 能力为 3.5

    我正在从源安装tensorflow 文档 https www tensorflow org versions r0 10 get started os setup html installing from sources Cuda驱动版本
  • Web 应用程序框架:C++ 与 Python

    作为一名程序员 我熟悉 Python 和 C 我正在考虑编写自己的简单 Web 应用程序 并且想知道哪种语言更适合服务器端 Web 开发 我正在寻找一些东西 它必须是直观的 我认识到 Wt 存在并且它遵循 Qt 的模型 我讨厌 Qt 的一件
  • 有什么理由不在Python中混合使用多处理和线程模块

    我正在考虑使用Python来实现一个需要大量多线程的程序 另一个要求是它将在桌面上运行 因此拥有许多进程将使应用程序显得混乱且难以杀死 在任务管理器中 因此 我正在考虑使用线程和多处理模块来减少进程数量 据我了解 GIL 仅适用于单个进程
  • Python 柯里化任意数量的变量

    我正在尝试使用柯里化在 Python 中进行简单的函数添加 我找到了这个咖喱装饰器here https gist github com JulienPalard 021f1c7332507d6a494b def curry func def
  • 如何使 cx-oracle 将查询结果绑定到字典而不是元组?

    这是我的代码 我想找到一种方法将查询结果作为字典列表而不是元组列表返回 看起来 cx oracle 通过部分文档讨论 绑定 来支持这一点 虽然我不知道它是如何工作的 def connect dsn cx Oracle makedsn hos
  • 写入文件的正确方法?

    我想知道这样做是否有什么区别 var1 open filename w write Hello world 并做 var1 open filename w var1 write Hello world var1 close 我发现没有必要
  • Django Python - LDAP 身份验证

    我目前正在研究 Django Python 我的目标是从 Ldap 目录对用户进行身份验证 我确实有 python 代码来访问 ldap 目录并检索信息 Code import ldap try l ldap open ldap forum
  • 我收到错误:rest_framework.request.WrappedAttributeError:'CSRFCheck'对象没有属性'process_request'

    urls py from django conf urls import url from django contrib import admin from django conf import settings from django c

随机推荐