使用 Tensorflow 2.0 进行逻辑回归?

2023-11-22

我正在尝试使用 TensorFlow 2.0 构建多类逻辑回归,并且我编写了我认为正确的代码,但它没有给出好的结果。我的准确率实际上是 0.1%,甚至损失也没有减少。我希望有人能在这里帮助我。

这是我到目前为止编写的代码。请指出我在这里做错了什么,我需要改进,这样我的模型才能工作。感谢您!

from tensorflow.keras.datasets import fashion_mnist
from sklearn.model_selection import train_test_split
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train/255., x_test/255.

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.15)
x_train = tf.reshape(x_train, shape=(-1, 784))
x_test  = tf.reshape(x_test, shape=(-1, 784))

weights = tf.Variable(tf.random.normal(shape=(784, 10), dtype=tf.float64))
biases  = tf.Variable(tf.random.normal(shape=(10,), dtype=tf.float64))

def logistic_regression(x):
    lr = tf.add(tf.matmul(x, weights), biases)
    return tf.nn.sigmoid(lr)

def cross_entropy(y_true, y_pred):
    y_true = tf.one_hot(y_true, 10)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return tf.reduce_mean(loss)

def accuracy(y_true, y_pred):
    y_true = tf.cast(y_true, dtype=tf.int32)
    preds = tf.cast(tf.argmax(y_pred, axis=1), dtype=tf.int32)
    preds = tf.equal(y_true, preds)
    return tf.reduce_mean(tf.cast(preds, dtype=tf.float32))

def grad(x, y):
    with tf.GradientTape() as tape:
        y_pred = logistic_regression(x)
        loss_val = cross_entropy(y, y_pred)
    return tape.gradient(loss_val, [weights, biases])

epochs = 1000
learning_rate = 0.01
batch_size = 128

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.repeat().shuffle(x_train.shape[0]).batch(batch_size)

optimizer = tf.optimizers.SGD(learning_rate)

for epoch, (batch_xs, batch_ys) in enumerate(dataset.take(epochs), 1):
    gradients = grad(batch_xs, batch_ys)
    optimizer.apply_gradients(zip(gradients, [weights, biases]))

    y_pred = logistic_regression(batch_xs)
    loss = cross_entropy(batch_ys, y_pred)
    acc = accuracy(batch_ys, y_pred)
    print("step: %i, loss: %f, accuracy: %f" % (epoch, loss, acc))

    step: 1000, loss: 2.458979, accuracy: 0.101562

该模型没有收敛,问题似乎是您正在执行 sigmoid 激活,然后直接进行tf.nn.softmax_cross_entropy_with_logits。在文档中tf.nn.softmax_cross_entropy_with_logits它说:

警告:此操作需要未缩放的 logits,因为它执行softmax on logits内部为了效率。不要使用以下输出调用此操作softmax,因为它会产生不正确的结果。

因此,在传递到前一层的输出之前,不应对前一层的输出进行 softmax、sigmoid、relu、tanh 或任何其他激活tf.nn.softmax_cross_entropy_with_logits。 有关何时使用 sigmoid 或 softmax 输出激活的更深入描述,请参阅here.

因此通过替换return tf.nn.sigmoid(lr)只用return lr in the logistic_regression函数,模型是收敛的。

下面是经过上述修复的代码的工作示例。我还更改了变量名称epochs to n_batches因为你的训练循环实际上经历了 1000 个批次,而不是 1000 个时期(我也将其提高到 10000,因为有迹象表明需要更多迭代)。

from tensorflow.keras.datasets import fashion_mnist
from sklearn.model_selection import train_test_split
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train/255., x_test/255.

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.15)
x_train = tf.reshape(x_train, shape=(-1, 784))
x_test  = tf.reshape(x_test, shape=(-1, 784))

weights = tf.Variable(tf.random.normal(shape=(784, 10), dtype=tf.float64))
biases  = tf.Variable(tf.random.normal(shape=(10,), dtype=tf.float64))

def logistic_regression(x):
    lr = tf.add(tf.matmul(x, weights), biases)
    #return tf.nn.sigmoid(lr)
    return lr


def cross_entropy(y_true, y_pred):
    y_true = tf.one_hot(y_true, 10)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return tf.reduce_mean(loss)

def accuracy(y_true, y_pred):
    y_true = tf.cast(y_true, dtype=tf.int32)
    preds = tf.cast(tf.argmax(y_pred, axis=1), dtype=tf.int32)
    preds = tf.equal(y_true, preds)
    return tf.reduce_mean(tf.cast(preds, dtype=tf.float32))

def grad(x, y):
    with tf.GradientTape() as tape:
        y_pred = logistic_regression(x)
        loss_val = cross_entropy(y, y_pred)
    return tape.gradient(loss_val, [weights, biases])

n_batches = 10000
learning_rate = 0.01
batch_size = 128

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.repeat().shuffle(x_train.shape[0]).batch(batch_size)

optimizer = tf.optimizers.SGD(learning_rate)

for batch_numb, (batch_xs, batch_ys) in enumerate(dataset.take(n_batches), 1):
    gradients = grad(batch_xs, batch_ys)
    optimizer.apply_gradients(zip(gradients, [weights, biases]))

    y_pred = logistic_regression(batch_xs)
    loss = cross_entropy(batch_ys, y_pred)
    acc = accuracy(batch_ys, y_pred)
    print("Batch number: %i, loss: %f, accuracy: %f" % (batch_numb, loss, acc))

(removed printouts)
>> Batch number: 1000, loss: 2.868473, accuracy: 0.546875
(removed printouts)
>> Batch number: 10000, loss: 1.482554, accuracy: 0.718750
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Tensorflow 2.0 进行逻辑回归? 的相关文章

  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 如何结合pytube和tkinter标签来显示进度?

    我正在编写从 youtube 下载歌曲的小程序 使用 pytube 我想添加 python tkinter GUI 以在下载文件时显示百分比值 现在 当我执行代码时 程序首先下载文件 大约需要 60 秒 然后才显示 100 的标签 如果我希
  • 当变量取特定值时如何使 PyCharm 中断?

    我有一本大字典 其中一些元素偶尔会出现非法值 我想弄清楚非法值从何而来 PyCharm 应该不断监视我的字典的值 一旦它们中的任何一个取了非法值 它就应该中断并让我检查程序的状态 我知道我可以通过为我的字典创建一个 getter sette
  • 处理 Python 行为测试框架中的异常

    我一直在考虑从鼻子转向行为测试 摩卡 柴等已经宠坏了我 到目前为止一切都很好 但除了以下之外 我似乎无法找出任何测试异常的方法 then It throws a KeyError exception def step impl contex
  • Pandas Merge (pd.merge) 如何设置索引和连接

    我有两个 pandas 数据框 dfLeft 和 dfRight 以日期作为索引 dfLeft cusip factorL date 2012 01 03 XXXX 4 5 2012 01 03 YYYY 6 2 2012 01 04 XX
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • 如何使用python在一个文件中写入多行

    如果我知道要写多少行 我就知道如何将多行写入一个文件 但是 当我想写多行时 问题就出现了 但是 我不知道它们会是多少 我正在开发一个应用程序 它从网站上抓取并将结果的链接存储在文本文件中 但是 我们不知道它会回复多少行 我的代码现在如下 r
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Jupyter Notebook 找不到 Python 模块

    不知道发生了什么 但每当我使用 ipython 氢 原子 或 jupyter 笔记本时都找不到任何已安装的模块 我知道我安装了 pandas 但笔记本说找不到 我应该补充一点 当我正常运行脚本时 python script py 它确实导入
  • import matplotlib.pyplot 给出 AttributeError: 'NoneType' 对象没有属性 'is_interactive'

    我尝试在 Pycharm 控制台中导入 matplotlib pyplt import matplotlib pyplot as plt 然后作为回报我得到 Traceback most recent call last File D Pr
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 如何解决 PDFBox 没有 unicode 映射错误?

    我有一个现有的 PDF 文件 我想使用 python 脚本将其转换为 Excel 文件 目前正在使用PDFBox 但是存在多个类似以下错误 org apache pdfbox pdmodel font PDType0Font toUnico
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是

随机推荐

  • OSMdroid 添加自定义图标到 ItemizedOverlay

    我正在使用 ItemizedIconOverlay 类 当前正在地图上显示事件以及具有相同默认图标的用户位置 如何更改每个叠加层的图标集 是否有类似于 google maps 示例的内容 drawable getResources getD
  • Keras:类型错误:无法使用 KerasClassifier pickle _thread.lock 对象

    import pandas as pd import numpy as np import matplotlib pyplot as plt dataset pd read csv Churn Modelling csv X dataset
  • 在选择框中重新填充日期

    我在 Rails 中创建了一个 date select 它有 3 个选择框 一个代表年份 一个代表月份 一个代表日期 2 月 31 日在他们身上是相当令人困惑的 我希望能够只让选择框包含有效日期 我的意思是 当您选择二月时 31 日 30
  • 重建/获取 PHP 函数的源代码

    我可以通过编程方式通过函数名称获取函数的源代码吗 Like function blah a b return a b echo getFunctionCode blah 是否可以 是否有任何 php 自描述函数可以重构函数 类代码 我的意思
  • 如何禁用 Android Studio 3.0 的即时运行

    进行一些更改后 我收到错误 会话 app 安装 APK 时出错 据一些人说 这是因为 Instant Run 在最新的Stable Android Studio 3 0上 在构建 执行 部署我没有任何 即时运行 选项 即使在设置搜索中进行了
  • 如何在C++中将自定义项目添加到系统菜单?

    我需要枚举所有正在运行的应用程序 特别是所有顶部窗户 对于每个窗口 我需要将自定义项目添加到该窗口的系统菜单中 我怎样才能在 C 中实现这一点 Update 我非常乐意为 Windows MacOS 和 Ubuntu 提供解决方案 不过 我
  • 使用钉枪加快 Clojure 启动时间

    我时不时地想用一下会很好clojure for 外壳脚本 但是大约 900ms 的启动时间太慢了 然后我会google首页对于 nailgun clojure 但显示的唯一结果是针对像 vimclojure 这样的特殊情况 那时我假装没有时
  • 从python中的字节中提取LSB位

    我在变量 DATA 中有一个字节 我想从中提取 LSB 位并打印它 我对 python 很陌生 我发现很多文章都有复杂的按位加法逻辑 而且所有这些都很难理解 我正在寻找一个简单的逻辑 就像我们对字符串所做的那样 例如 DATA 7 1 请帮
  • 在 Android 中上传带有进度条且没有 OutOfMemory 错误的大文件

    我在上传大型视频文件 最大 150MB 时遇到问题 1 当我使用此代码时Link1 我可以上传带有进度条的小文件 但是由于我的文件很大 所以android给了我内存不足错误 2 如果我使用这个代码Link2 我能够上传大文件 这确实是一个很
  • 在 C# 中使用 Lisp

    正如很多人指出的那样这个问题 Lisp主要是作为一种学习体验 尽管如此 如果我能以某种方式使用我的 Lisp 算法并将它们与我的 C 程序结合起来 那就太好了 在大学里 我的教授从来没能告诉我如何在程序中使用 Lisp 例程 不 不是用 L
  • 我如何才能收到 Cocoa 应用程序中系统时间更改的通知?

    我有一个可可应用程序 用于记录事件的日期戳 我需要知道系统时间何时重置以及重置多少 但我似乎无法在任何地方发出通知来告诉我发生了这样的事情 由于 NTP 重置时钟或用户重置 例如从系统偏好设置 可能会发生此更改 如果有一个就太好了NSNot
  • 自定义改造 ErrorHandler 给出 UndeclaredThrowableException

    基于这篇文章我应该如何在 Android 上使用 Retrofit 处理 无互联网连接 我做了一个定制ErrorHandler private static class CustomErrorHandler implements Error
  • 如何将字符串解析为java.sql.date

    我有一个字符串 String s 01 NOVEMBER 2012 然后我想将其解析为 sqlDate 并将其插入数据库 是否可以将字符串解析为sql Date 是的 sql日期格式是 yyyy mm dd Use SimpleDateFo
  • 获取模型后渲染 Marionette 区域

    我想使用 Derick Bailey 在 通用问题解决方案 在这个thread获取模型后渲染视图 我将在这里报告他的解决方案 MyView Backbone View extend initialize function this mode
  • 在 ocaml 中输入级别整数

    有人可以给我关于在 OCaml 3 12 中制作类型级整数支持加法和减法运算的建议 建议吗 例如 如果我有这样表示的数字 type zero type a succ type pos1 zero succ type pos2 zero su
  • 如何在Python中从负纪元创建日期时间

    第一次使用 StackExchange 我正在使用 ArcGIS Server 和 Python 在尝试使用地图服务的 REST 端点执行查询时 我在 JSON 响应中获取负纪元中 esriFieldTypeDate 字段的值 JSON 响
  • XamlParseException 无法分配给属性。绑定不适用于附加属性

    我想为 Windows 应用商店应用程序创建带有附加属性的自定义文本框 我正在关注这个解决方案 现在它使用硬编码值作为属性值 但我想使用绑定来设置值 但它不起作用 我尝试搜索很多但没有帮助我任何解决方案 异常详细信息是这样的 Windows
  • CSV 的替代品?

    我打算构建一个 RESTful 服务 它将返回自定义文本格式 鉴于我的数据量非常大 XML JSON 太冗长了 我正在寻找一种基于行的文本格式 CSV 是一个明显的候选者 不过我想知道是否还有更好的东西 我通过一些研究唯一发现的是CTX a
  • 正确安装 mingw-get - mingw/msys 路径缺失以及更多!

    我运行的是Windows XP 我一直在关注本教程所以下载 mingw get insthere 我已经这样做过几次了 最后一次我检查了 boes 以安装所有内容 包括但不限于 gcc g MSYS 和 MinGW 编译套件 我告诉它也创建
  • 使用 Tensorflow 2.0 进行逻辑回归?

    我正在尝试使用 TensorFlow 2 0 构建多类逻辑回归 并且我编写了我认为正确的代码 但它没有给出好的结果 我的准确率实际上是 0 1 甚至损失也没有减少 我希望有人能在这里帮助我 这是我到目前为止编写的代码 请指出我在这里做错了什