目标数组形状与使用 Tensorflow 的预期输出不同

2024-03-18

我正在尝试制作 CNN(仍然是初学者)。当尝试拟合模型时,我收到此错误:

ValueError:形状为 (10000, 10) 的目标数组被传递用于形状 (None, 6, 6, 10) 的输出,同时用作损失categorical_crossentropy。这种损失期望目标具有与输出相同的形状。

标签的形状 = (10000, 10) 图像数据的形状 = (10000, 32, 32, 3)

Code:

import pickle
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Dense, Dropout, Activation, Flatten, 
                                     Conv2D, MaxPooling2D)
from tensorflow.keras.callbacks import TensorBoard
from keras.utils import to_categorical
import numpy as np
import time

MODEL_NAME = f"_________{int(time.time())}"
BATCH_SIZE = 64

class ConvolutionalNetwork():
    '''
    A convolutional neural network to be used to classify images
    from the CIFAR-10 dataset.
    '''

    def __init__(self):
        '''
        self.training_images -- a 10000x3072 numpy array of uint8s. Each 
                                a row of the array stores a 32x32 colour image. 
                                The first 1024 entries contain the red channel 
                                values, the next 1024 the green, and the final 
                                1024 the blue. The image is stored in row-major 
                                order, so that the first 32 entries of the array are the red channel values of the first row of the image.
        self.training_labels -- a list of 10000 numbers in the range 0-9. 
                                The number at index I indicates the label 
                                of the ith image in the array data.
        '''
        # List of image categories
        self.label_names = (self.unpickle("cifar-10-batches-py/batches.meta",
                            encoding='utf-8')['label_names'])

        self.training_data = self.unpickle("cifar-10-batches-py/data_batch_1")
        self.training_images = self.training_data[b'data']
        self.training_labels = self.training_data[b'labels']

        # Reshaping the images + scaling 
        self.shape_images()  

        # Converts labels to one-hot
        self.training_labels = np.array(to_categorical(self.training_labels))

        self.create_model()

        self.tensorboard = TensorBoard(log_dir=f'logs/{MODEL_NAME}')

    def unpickle(self, file, encoding='bytes'):
        '''
        Unpickles the dataset files.
        '''
        with open(file, 'rb') as fo:
            training_dict = pickle.load(fo, encoding=encoding)
        return training_dict

    def shape_images(self):
        '''
        Reshapes the images and scales by 255.
        '''
        images = list()
        for d in self.training_images:
            image = np.zeros((32,32,3), dtype=np.uint8)
            image[...,0] = np.reshape(d[:1024], (32,32)) # Red channel
            image[...,1] = np.reshape(d[1024:2048], (32,32)) # Green channel
            image[...,2] = np.reshape(d[2048:], (32,32)) # Blue channel
            images.append(image)

        for i in range(len(images)):
            images[i] = images[i]/255

        images = np.array(images)
        self.training_images = images
        print(self.training_images.shape)

    def create_model(self):
        '''
        Creating the ConvNet model.
        '''
        self.model = Sequential()
        self.model.add(Conv2D(64, (3, 3), input_shape=self.training_images.shape[1:]))
        self.model.add(Activation("relu"))
        self.model.add(MaxPooling2D(pool_size=(2,2)))

        self.model.add(Conv2D(64, (3,3)))
        self.model.add(Activation("relu"))
        self.model.add(MaxPooling2D(pool_size=(2,2)))

        # self.model.add(Flatten())
        # self.model.add(Dense(64))
        # self.model.add(Activation('relu'))

        self.model.add(Dense(10))
        self.model.add(Activation(activation='softmax'))

        self.model.compile(loss="categorical_crossentropy", optimizer="adam", 
                           metrics=['accuracy'])

    def train(self):
        '''
        Fits the model.
        '''
        print(self.training_images.shape)
        print(self.training_labels.shape)
        self.model.fit(self.training_images, self.training_labels, batch_size=BATCH_SIZE, 
                       validation_split=0.1, epochs=5, callbacks=[self.tensorboard])


network = ConvolutionalNetwork()
network.train()

感谢您的帮助,已经尝试修复一个小时了。


您需要取消注释Flatten创建模型时的图层。本质上,该层的作用是接受 4D 输入(batch_size, height, width, num_filters)并将其展开为 2D 的(batch_size, height * width * num_filters)。这是获得您想要的输出形状所必需的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

目标数组形状与使用 Tensorflow 的预期输出不同 的相关文章

  • 将 transaction.commit_manually() 升级到 Django > 1.6

    我继承了为 Django 1 4 编写的应用程序的一些代码 我们需要更新代码库以使用 Django 1 7 并最终更新到 1 8 作为下一个长期支持版本 在一些地方它使用旧风格 transaction commit manually and
  • 在 Pandas 中,如何从基于另一个数据框的数据框中删除行?

    我有 2 个数据框 一个名为 USERS 另一个名为 EXCLUDE 他们都有一个名为 电子邮件 的字段 基本上 我想删除 USERS 中包含 EXCLUDE 中包含电子邮件的每一行 我该怎么做 您可以使用boolean indexing
  • 使用 Python 创建 MIDI

    本质上 我正在尝试从头开始创建 MIDI 并将它们放到网上 我对不同的语言持开放态度 但更喜欢使用Python 两种语言之一 如果这有什么区别的话 并且想知道我应该使用哪个库 提前致谢 看起来这就是您正在寻找的 适用于 Python 的简单
  • TensorFlow:带有轴选项的 bincount

    在 TensorFlow 中 我可以使用 tf bincount 获取数组中每个元素的计数 x tf placeholder tf int32 None freq tf bincount x tf Session run freq feed
  • 在 python 3 中使用子进程

    我使用 subprocess 模块在 python 3 中运行 shell 命令 这是我的代码 import subprocess filename somename py in practical i m using a real fil
  • cv2.drawContours() - 取消填充字符内的圆圈(Python,OpenCV)

    根据 Silencer的建议 我使用了他发布的代码here https stackoverflow com questions 48244328 copy shape to blank canvas opencv python 482465
  • 小部件之间的自定义信号

    尝试将信号从一个 gtk EventBox 子级发送到另一个 在 init HeadMode 第 75 行 上出现错误 类型错误 未知信号名称 消息发送 why usr bin env python coding utf8 import p
  • 编辑 Jupyter Notebook 时 VS Code 中缺少“在选择中查找”

    使用 Jupyter Notebook 时 VSCode 中缺少 在选择中查找 按钮 它会减慢开发速度 所以我想请问有人知道如何激活它吗 第一张图显示了在 python 文件中的搜索 替换 第二张图显示了笔记本电脑中缺少的按钮 Python
  • 在相同任务上,Keras 比 TensorFlow 慢

    我正在使用 Python 运行斩首 DCNN 本例中为 Inception V3 来获取图像特征 我使用的是 Anaconda Py3 6 和 Windows7 使用 TensorFlow 时 我将会话保存在变量中 感谢 jdehesa 并
  • Alembic:如何迁移模型中的自定义类型?

    My User模型是 class User UserMixin db Model tablename users noinspection PyShadowingBuiltins uuid Column uuid GUID default
  • 设置 verify_certs=False 但 elasticsearch.Elasticsearch 因证书验证失败而引发 SSL 错误

    self host KibanaProxy 自我端口 443 self user 测试 self password 测试 我需要禁止证书验证 使用选项时它与curl一起使用 k在命令行上 但是 在使用 Elasticsearch pytho
  • 对使用 importlib.util 导入的对象进行酸洗

    我在使用Python的pickle时遇到了一个问题 我需要通过将文件路径提供给 importlib util 来加载一些 Python 模块 如下所示 import importlib util spec importlib util sp
  • Pandas 堆积条形图中元素的排序

    我正在尝试绘制有关某个地区 5 个地区的家庭在特定行业赚取的收入比例的信息 我使用 groupby 按地区对数据框中的信息进行排序 df df orig groupby District Portion of income value co
  • Python:我不明白 sum() 的完整用法

    当然 我明白你使用 sum 与几个数字 然后它总结所有 但我正在查看它的文档 我发现了这一点 sum iterable start 第二个参数 start 的作用是什么 这太尴尬了 但我似乎无法通过谷歌找到任何示例 并且对于尝试学习该语言的
  • Flask 应用程序的测试覆盖率不起作用

    您好 想在终端的 Flask 应用程序中测试 删除路由 我可以看到测试已经过去 它说 test user delete test app LayoutTestCase ok 但是当我打开封面时 它仍然是红色的 这意味着没有覆盖它 请有人向我
  • OSX 上的 locale.getlocale() 问题

    我需要获取系统区域设置来执行许多操作 最终我想使用 gettext 翻译我的应用程序 我打算在 Linux 和 OSX 上分发它 但我在 OSX Snow Leopard 上遇到了问题 python Python 2 5 2 r252 60
  • 使用 Python 将对象列表转为 JSON

    我在转换时遇到问题Object实例到 JSON ob Object list name scaping myObj base url u number page for ob in list name json string json du
  • 使用Multiprocessing和Pool时如何访问全局变量?

    我试图避免将变量冗余地传递到dataList e g 1 globalDict 2 globalDict 3 globalDict 并在全球范围内使用它们 global globalDict然而 在下面的代码中并不是这样做的解决方案 是否有
  • tkinter:打开一个带有按钮提示的新窗口[关闭]

    Closed 这个问题需要调试细节 help minimal reproducible example 目前不接受答案 用户如何按下 tkinter GUI 中的按钮来打开新窗口 我只需要非常简单的解决方案 如果代码也能被解释那就太好了 这
  • python 中的 after() 与 update()

    我是 python 新手 开始使用 tkinter 作为画布 到目前为止 我使用 update 来更新我的画布 但还有一个 after 方法 谁能给我解释一下这个函数 请举个例子 两者之间有什么区别 root after integer c

随机推荐

  • 如何在 po gettext 文件中将空翻译 (msgstr) 标记为已翻译?

    我发现字符串 msgid 的翻译为空 所有 gettext 工具都会将该字符串视为未翻译 有解决方法吗 我确实想要一个空字符串作为该项目的翻译 由于这似乎是 gettext 规范中的一个很大的设计缺陷 我决定使用 Unicode Chara
  • Spark Streaming数据放入HBase的问题

    我是这个领域的初学者 所以我无法理解它 HBase 版本 0 98 24 hadoop2 火花版本 2 1 0 以下代码尝试将从 Spark Streming Kafka 生产者接收的数据放入 HBase 中 Kafka输入数据格式是这样的
  • 点“.”的 java keyevent 字段是什么?

    我知道如何使用 keyevent 调用 1 应该像 aaa keyPress KeyEvent VK 1 现在我需要输入 点 但我找不到 KeyEvent VK DOT 或一些类似的命令 请帮忙 Thanks 这个 点 被称为period
  • 如何使用带有条纹元素的引导浮动标签?

    我想知道如何使用浮动标签设置条纹元素的样式 bootstrap 5 我的所有其他字段都采用这种方式设计 因此最好对信用卡输入和 cvv 输入进行设计 以匹配我网站的主题 我尝试过使用以下答案 如何使用 Bootstrap 设置 Stripe
  • 从本地开发环境访问ElastiCache memcache实例

    有没有办法从本地开发环境访问缓存节点 尽管可以从 EC2 实例访问相同的缓存节点 我正在使用带有 C 的 Enyim memcache 客户端库 我发现很少有文章说这是不可能的 那么最好的方法应该是什么 我是否需要在本地设置内存缓存以进行开
  • 最流行的 C 通用集合数据结构库是什么?

    我正在寻找一个提供通用集合数据结构 例如列表 关联数组 集合等 的 C 库 该库应该稳定且经过良好测试 我基本上是在寻找比蹩脚的 C 标准库更好的东西 哪些 C 库符合此描述 编辑 我希望该库是跨平台的 但如果做不到这一点 任何可以在 Ma
  • 将数据存储在自定义字段中或将附件存储在 ics iCal 文件中

    我需要为我手动构建的 iCal 文件 ics 提供一些我实际上不希望日历应用程序用户看到的附加信息 因此 当我在 iOS 应用程序中创建事件并 稍后 从日历事件中读取它们时 我需要能够手动设置它们 我想知道是否可以将自定义字段 属性添加到
  • 使用 dplyr 进行 SQL in-db 操作时的 ifelse 和 grepl 命令

    在R数据帧上运行的dplyr中 很容易运行 df lt df gt mutate income topcoded ifelse income gt topcode income topcode 我现在正在使用一个大型 SQL 数据库 使用
  • SharePoint Designer 动态重新格式化 HTML,是否可以禁用?

    在我彻底放弃之前 我一直在尝试修改 SharePoint Designer 中的一些母版页 每当我更改 HTML 标记时 它都会根据需要重新设置它们的格式 例如 我试图使代码可读 因此我将项目移动到自己的行等 一旦我保存 它就会将所有内容移
  • 将数据从 s3 复制到带有前缀的本地

    我正在尝试使用 aws cli 将数据从 s3 复制到带有前缀的本地 但我在使用不同的正则表达式时遇到错误 aws s3 cp s3 my bucket name RAW TIMESTAMP 0506 profile prod error
  • DirectQuery 模式下的 AAS 表格模型性能优势

    假设您有 10 个相当大的事实表 每个 50 100 GB 应该使用 Power BI 进行查询 它们不适合 Azure Analysis Services RAM 价格合理 因此 为了使用表格模型和 AAS 您必须使用以下模式 1 Pow
  • 如何在 Playframework 中将 Oracle 存储过程与 Scala Anorm 结合使用

    我有许多存储过程 其结果是字符串列表 我如何使用scala访问play 2 0框架中的refcurser 有人可以举一个简单的例子 我如何填写一个列表吗 我试过这个 case class XXXX name String descripti
  • 为什么 UIView 中有一个框架矩形和一个边界矩形?

    好吧 虽然已经是深夜了 但我不明白为什么有两个不同的矩形 frame and bounds 据我了解 一个矩形就足以完成所有操作 相对于另一个坐标系定位视图本身 然后将其内容剪切到指定的大小 你还想用两个矩形做什么 他们如何相互作用 有人有
  • 通过循环在renderUI中创建Value Box

    我想根据我拥有的数据创建一个值框 假设我有 5 个数据变量consumerdata像这样 id data number1 number2 1 k4j A 67 53 2 rls B 30 62 3 yv9 C 45 28 4 l6h D 6
  • 如何在 Eclipse 中使用 SonarLint

    我被分配使用 SonarQube 来提高代码质量 但是当我将它的插件下载到 Eclipse 时 我知道它已被弃用 新的 插件是 SonarLint 但到目前为止 我找不到任何关于如何使用 SonarLint 的好的文档 如何使用它检查jav
  • Delphi 2010远程调试-无法使断点工作

    我最近发布了这个问题 https stackoverflow com questions 4579654 no breakpoints when remote debugging with delphi 2010 so stuck on d
  • 如何从C中的文件中读取最后n行

    这是一道微软面试题 使用 C 读取文件的最后 n 行 精确地 实现这一目标的方法有很多 但其中很少有 gt 最简单的是 在第一遍中 计算文件中的行数 在第二遍中显示最后 n 行 gt 或者可以为每一行维护一个双向链表 并通过向后遍历链表直到
  • 查询 Firestore 文档中的参考字段

    我正在尝试编写一个函数 在文档 Firestore 艺术家 集合中 中的数据发生更改后 Google Cloud Functions 将查找另一个集合 显示 中具有引用字段的所有文档 artist 指向刚刚更改的文档 在 artists 集
  • XMLHttpRequest 已弃用。用什么代替?

    尝试使用纯 JS 方法来检查我是否有有效的 JS 图像 url 我收到警告XMLHttpRequest已弃用 有什么更好的方法来做到这一点 urlExists url const http new XMLHttpRequest http o
  • 目标数组形状与使用 Tensorflow 的预期输出不同

    我正在尝试制作 CNN 仍然是初学者 当尝试拟合模型时 我收到此错误 ValueError 形状为 10000 10 的目标数组被传递用于形状 None 6 6 10 的输出 同时用作损失categorical crossentropy 这