元组索引超出范围,Tensorflow

2024-01-03

这是模型。它是基本的张量流模型,可以拍摄数字的图片并告诉您它是什么数字。* 我知道python中的索引从0开始。我遇到的问题是这行代码“model.fit(np.array(test), np.array(num))”。阅读下面的代码以获取更多信息。 *

import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
for train in range(len(x_train)):
    for row in range(28):
        for x in range(28):
            if x_train[train][row][x] != 0:
                x_train[train][row][x] = 1

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.save('epic_num_reader.model')
print("Model saved")

在下面的代码中,函数“user_train”给出了错误。特别是“model.fit(np.array(test), np.array(num))”行。代码会弹出一个框,让您绘制一个数字,一旦您单击空格键,模型就会尝试找出您绘制的内容。我想让你可以画一些东西,然后用你画的东西训练模型。

import sys, os, random
stdout = sys.__stdout__
stderr = sys.__stderr__
sys.stdout = open(os.devnull,'w')
sys.stderr = open(os.devnull,'w')
import pygame
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tkinter import *
from tkinter import messagebox
sys.stdout = stdout
sys.stderr = stderr

class pixel(object):
def __init__(self, x, y, width, height):
    self.x = x
    self.y = y
    self.width = width
    self.height = height
    self.color = (255,255,255)
    self.neighbors = []

def draw(self, surface):
    pygame.draw.rect(surface, self.color, (self.x, self.y, self.x + self.width, self.y + self.height))

def getNeighbors(self, g):
    # Get the neighbours of each pixel in the grid, this is used for drawing thicker lines
    j = self.x // 20 # the var i is responsible for denoting the current col value in the grid
    i = self.y // 20 # the var j is responsible for denoting thr current row value in the grid
    rows = 28
    cols = 28

    # Horizontal and vertical neighbors
    if i < cols - 1:  # Right
        self.neighbors.append(g.pixels[i + 1][j])
    if i > 0:  # Left
        self.neighbors.append(g.pixels[i - 1][j])
    if j < rows - 1:  # Up
        self.neighbors.append(g.pixels[i][j + 1])
    if j > 0:  # Down
        self.neighbors.append(g.pixels[i][j - 1])

    # Diagonal neighbors
    if j > 0 and i > 0:  # Top Left
        self.neighbors.append(g.pixels[i - 1][j - 1])

    if j + 1 < rows and i > -1 and i - 1 > 0:  # Bottom Left
        self.neighbors.append(g.pixels[i - 1][j + 1])

    if j - 1 < rows and i < cols - 1 and j - 1 > 0:  # Top Right
        self.neighbors.append(g.pixels[i + 1][j - 1])

    if j < rows - 1 and i < cols - 1:  # Bottom Right
        self.neighbors.append(g.pixels[i + 1][j + 1])


class grid(object):
pixels = []

def __init__(self, row, col, width, height):
    self.rows = row
    self.cols = col
    self.len = row * col
    self.width = width
    self.height = height
    self.generatePixels()
    pass

def draw(self, surface):
    for row in self.pixels:
        for col in row:
            col.draw(surface)

def generatePixels(self):
    x_gap = self.width // self.cols
    y_gap = self.height // self.rows
    self.pixels = []
    for r in range(self.rows):
        self.pixels.append([])
        for c in range(self.cols):
            self.pixels[r].append(pixel(x_gap * c, y_gap * r, x_gap, y_gap))

    for r in range(self.rows):
        for c in range(self.cols):
            self.pixels[r][c].getNeighbors(self)

def clicked(self, pos): #Return the position in the grid that user clicked on
    try:
        t = pos[0]
        w = pos[1]
        g1 = int(t) // self.pixels[0][0].width
        g1 = int(t) // self.pixels[0][0].width
        g2 = int(w) // self.pixels[0][0].height

        return self.pixels[g2][g1]
    except:
        pass

def convert_binary(self):
    li = self.pixels

    newMatrix = [[] for x in range(len(li))]

    for i in range(len(li)):
        for j in range(len(li[i])):
            if li[i][j].color == (255,255,255):
                newMatrix[i].append(0)
            else:
                newMatrix[i].append(1)

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_test = tf.keras.utils.normalize(x_test, axis=1)
    for row in range(28):
        for x in range(28):
            x_test[0][row][x] = newMatrix[row][x]

    return x_test[:1]


def guess(li):
    model = tf.keras.models.load_model('epic_num_reader.model')

    predictions = model.predict(li)
    print(predictions[0])
    t = (np.argmax(predictions[0]))
    print("I predict this number is a:", t)
    window = Tk()
    window.withdraw()
    messagebox.showinfo("Prediction", "I predict this number is a: " + str(t))
    window.destroy()
    #plt.imshow(li[0], cmap=plt.cm.binary)
    #plt.show()


############################

### Function with error ####
def user_train(test, num):
    model = tf.keras.models.load_model('epic_num_reader.model')
    test = np.array(test)
    test = np.reshape(test, (28,28))
    model.fit(np.array(test), np.array(num))
    model.save('epic_num_reader.model')
############################

############################


def main():
    run = True
    while run:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    li = g.convert_binary()
                    guess(li)
                    g.generatePixels()
                elif event.key == pygame.K_0:
                    test = g.convert_binary()
                    user_train(test, 0)
                    g.generatePixels()

            if pygame.mouse.get_pressed()[0]:

                pos = pygame.mouse.get_pos()
                clicked = g.clicked(pos)
                clicked.color = (0,0,0)
                for n in clicked.neighbors:
                    n.color = (0,0,0)

            if pygame.mouse.get_pressed()[2]:
                try:
                    pos = pygame.mouse.get_pos()
                    clicked = g.clicked(pos)
                    clicked.color = (255,255,255)
                except:
                    pass

        g.draw(win)
        pygame.display.update()

pygame.init()
width = height = 560
win = pygame.display.set_mode((width, height))
pygame.display.set_caption("Number Guesser")
g = grid(28, 28, width, height)
main()

pygame.quit()
quit()

这是完整的错误:

Traceback (most recent call last):
File "D:/Users/user/AppData/Local/Programs/Pycharm/numbersML/drawNumber.py", line 184, in <module>
main()
File "D:/Users/user/AppData/Local/Programs/Pycharm/numbersML/drawNumber.py", line 157, in main
user_train(test, 0)
File "D:/Users/user/AppData/Local/Programs/Pycharm/numbersML/drawNumber.py", line 140, in user_train
model.fit(np.array(test), np.array(num))
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 718, in fit
use_multiprocessing=use_multiprocessing)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_v2.py", line 235, in fit
use_multiprocessing=use_multiprocessing)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_v2.py", line 582, in _process_training_inputs
use_multiprocessing=use_multiprocessing)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_v2.py", line 635, in _process_inputs
x, y, sample_weight=sample_weights)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2186, in _standardize_user_data
batch_size=batch_size)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2281, in _standardize_tensors
training_utils.check_array_lengths(x, y, sample_weights)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 730, in check_array_lengths
set_y = set_of_lengths(targets)
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 725, in set_of_lengths
for y in x
File "D:\Users\user\AppData\Local\Programs\Pycharm\numbersML\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 726, in <listcomp>
if y is not None and not is_tensor_or_composite_tensor(y)
IndexError: tuple index out of range

问题是,“测试”变量只是一个二维数组,而张量流期望数组中有一个二维数组。
This:
test = np.array(test) model.fit(np.array(test), np.array(num))
变成这样:
test = [test] num = [num] model.fit(np.array(test), np.array(num))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

元组索引超出范围,Tensorflow 的相关文章

  • 使用 glGetFloatv 检索 pyglet 中的模型视图矩阵

    我正在使用 pyglet 在 python 中进行 3D 可视化 并且需要检索模型视图和投影矩阵来进行一些选择 我使用以下方式定义我的窗口 from pyglet gl import from pyglet window import wi
  • 为什么我会得到“ufunc 'multiply' did not contains a loop with Signature Matching types dtype('S32') dtype('S32') dtype('S32')”,其值来自 raw_

    我正在尝试创建一个非常简单的程序 它将绘制一个抛物线 其中v是速度 a是加速度和x是时间 用户将输入值v and a then v and a and x将决定y 我试图用这个来做到这一点 x np linspace 0 9 10 a ra
  • 如果每个区域内至少有 5 个连续行,如何在每个标题区域的末尾使用 Title[Name]2 发布新行?

    我想在每个 Title 区域的末尾使用 Title Name 2 发布新行的最简单方法是通过一个计算连续行数的变量 其中至少有 5 个连续行包含 1 1 1 1在每个 标题区域内 我不确定我对计数变量做错了什么 也许 确实必须在每个 Tit
  • 使用 Matplotlib 的范围绘制图像的 3D 轮廓

    正如我所介绍的here https stackoverflow com questions 18792624 fits image input to a range in plot python 在二维中 我想知道如何 缩放 要绘制到绘图中
  • 使用pip安装pylibmc时出错

    您好 当我尝试使用 pip 在 OSX Lion 上安装 pylibmc 时 出现以下错误 pylibmcmodule h 42 10 fatal error libmemcached memcached h file not found
  • Python 中嵌套列表的排序和分组

    我有以下数据结构 列表的列表 4 21 1 14 2008 10 24 15 42 58 3 22 4 2somename 2008 10 24 15 22 03 5 21 3 19 2008 10 24 15 45 45 6 21 1 1
  • 完全定制的Python帮助用法

    我正在尝试使用 Python 创建完全自定义的 帮助 用法 我计划将其导入到许多我想要具有风格一致性的程序中 但遇到了一些麻烦 我不知道为什么我的描述忽略换行符 尝试过 和 我无法让 出现在 ARGS 行的 换行符之后 显然它们坐在自己的行
  • 如何使直方图列的宽度都相同

    我在操作直方图时遇到了一些麻烦 我有一个包含两列的 df 我将它们绘制为堆叠直方图 我将它们放入特定的垃圾箱中 请参阅下面的代码 但我想在最后制作一个大垃圾箱 4000 10000 但是 默认情况下 大垃圾箱的列宽很大 有没有办法让这个大垃
  • 模拟类:Mock() 还是 patch()?

    我在用mock http www voidspace org uk python mock index html使用Python 想知道这两种方法中哪一种更好 阅读 更Pythonic 方法一 只需创建一个模拟对象并使用它 代码如下 def
  • 收到“/:未找到事件。”使用 PyCharm 远程调试器时

    当我使用 PyCharm 通过 ssh 进行远程调试时tcsh shell 服务器 很多时候它停止工作 并显示 未找到事件 更具体地说 我在 pycharm 调试控制台中遇到以下内容 ssh username hostserver 22 p
  • 返回吃异常

    我至少发现了以下行为weird def errors try ErrorErrorError finally return 10 print errors prints 10 It should raise NameError name E
  • 如何使用 PyAudio 选择特定的输入设备

    通过 PyAudio 录制音频时 如何指定要使用的确切输入设备 我的电脑有两个麦克风 一个内置 一个通过 USB 我想使用 USB 麦克风进行录音 这流类 https people csail mit edu hubert pyaudio
  • 如何使用资源模块来衡量函数的运行时间?

    我想使用Python代码测量函数的CPU运行时间和挂钟运行时间 此处建议资源模块 如何以 Python 代码 不是从终端 的形式分别测量函数的 CPU 运行时间和挂钟运行时间 https stackoverflow com q 192046
  • Python 中的颜色处理

    对于我的聚类 GUI 我目前对聚类使用随机颜色 因为我事先不知道最终会得到多少个聚类 在 Python 中 这看起来像 import random def randomColor return random random random ra
  • 如何从 IDLE 命令行运行 Python 脚本?

    在 bash shell 中 我可以使用 bash 或 source 手动调用脚本 我可以在 Python IDLE 的交互式 shell 中做类似的事情吗 我知道我可以转到文件 gt gt 打开模块 然后在单独的窗口中运行它 但这很麻烦
  • 使用神经网络包进行多项分类

    这个问题应该很简单 但文档没有帮助 我正在使用 R 我必须使用neuralnet多项式分类问题的包 所有示例均针对二项式或线性输出 我可以使用二项式输出进行一些一对一的实现 但我相信我应该能够通过使用 3 个单元作为输出层来做到这一点 其中
  • PyMC3 和 Theano - 导入 pymc3 后,有效的 Theano 代码停止工作

    一些简单的 theano 代码可以完美运行 当我导入 pymc3 时停止工作 这里有一些片段可以重现错误 Initial Theano Code this works import theano tensor as tsr x tsr ds
  • 混合两个列表的Pythonic方法[重复]

    这个问题在这里已经有答案了 我有两个长度为 n 和 n 1 的列表 a 1 a 2 a n b 1 b 2 b n 1 我想要一个函数作为结果给出一个列表 其中包含两个中的替代元素 即 b 1 a 1 b n a n b n 1 以下方法有
  • 重新安装后使用 pandas dataframes 时出现问题

    我已经重新安装了 Python 和 Anaconda 现在面临以下问题 在我将 pkl 文件加载到数据帧并尝试 查看 该文件后 如下所示 df pd read pickle example pkl df 我收到错误 AttributeErr
  • 将 .parquet 编码为 io.Bytes

    目标 将 Parquet 文件上传到 MinIO 这需要将文件转换为字节 我已经能够做到这一点了 csv json and txt bytes data to csv encode utf 8 bytes json dumps self d

随机推荐

  • Go 正则表达式中的转义括号

    我想在 Go 中的字符串上运行以下正则表达式 0 9 0 9 0 9 但我不断收到错误unknown escape sequence 我运行它的字符串是 1 53 38 45 2 88 62 98 3 78 48 3 4 72 30 76
  • Xcode 模拟器表视图是黑色的

    当在模拟器中运行我的 Xcode 项目时 我的UITableView in my UIViewController not UITableViewController 是黑色的 这是我的模拟器的图像 我的代码cellForRowAtInde
  • 字符数组后出现奇怪的字符

    我是 C 的真正初学者 但我正在学习 我以前偶然发现过这个问题 并决定询问其原因是什么 请解释你的答案 以便我学习 我制作了一个程序 允许您输入 5 个字符 然后显示您编写的字符并恢复它们 例如 asdfg gfdsa 奇怪的是 输入的原始
  • 如何处理 x86 与 x64 软件包

    We use NuGet管理我们的第三方包 我们还必须建立x86 and x64 builds 我们现在依赖于NuGet包裹 zeromq 依赖于 C dll 因此有一个x86 and x64发布 在 Nuget 中搜索时 我只看到两个不同
  • Express 使用高级服务,无法创建全文索引

    我已经安装了 SQL Server 2012 Express Edition 高级服务 其中声明它包含全文索引 这是一个链接 说明了这一点 http msdn microsoft com en us library cc645993 asp
  • 如何从 SQL Server 获取 DateTime 数据而忽略时区问题?

    我的情况是我们将数据存储在SQL Server数据库中 支持2005年以上 当存储 DateTime 值时 它是客户端的本地时间 我需要能够在任何其他客户端上取回该日期 无论其他客户端可能位于哪个时区 例如 当纽约的用户输入 2012 12
  • 根据模板参数引用不同基类的函数

    include
  • CryptoJS AES-128-ECB 和 PHP openssl_encrypt 不匹配

    我有一些 PHP 代码 无法编辑 还有一个充满加密消息的数据库 key 297796CCB81D2553B07B379D78D87618 return encrypted openssl encrypt data AES 128 ECB k
  • 用于Linux进程管理的Python库[关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 通过我的网络界面 我想启动 停止某些进程并确定启动的进程是否仍在运行 我现有的网站基于 Python
  • 在每个时区的凌晨 12 点运行 cron 作业

    所以我认为每个时区之间大约有30分钟的时间 我想运行我的脚本cron php每个时区的中午 12 点 午夜 我怎样才能做到这一点 我正在看这段代码 TZ UTC 7 root date mail root TZ CEST 7 root da
  • Django 错误 ---index() 缺少 1 个必需的位置参数:'pk'

    尝试打开路径时出现此错误 它需要在我的 def 中进行 pk 并插入它 但问题仍然存在 如果有人能帮忙 我会欠你很多 这是我在浏览器中遇到的错误 TypeError at batches index missing 1 required p
  • Android 操作系统是否有 /etc/passwd、/etc/shadow 和 /etc/group 等文件?

    如果不是 android如何判断用户是否属于某个组 该线程讨论了如何完成此操作 http groups google com group android ndk browse thread thread adddb27c1a5438e9 h
  • Unity C# ArgumentOutOfRangeException:参数超出范围。参数名称:索引[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我正在创建一个像蛇一样的游戏 在我下面的代码中 蛇身体的每个部分都是 Character 类的一个实例 当我尝试添加新角色时 出现错误
  • python中数组的就地修改

    我发现这个问题要求对数组进行就地修改 以便将所有零移动到数组末尾 并保持非零元素的剩余顺序 根据问题陈述 就地意味着不复制原始数组 这取自 Leetcode 可以在 283 Move Zeroes 中找到 输入和输出的示例是 0 1 0 1
  • 如何在 Github Gist 中软换行

    我有一个很长的字符串 我想使用 Github Gist 将其嵌入到我的博客中 我想为其启用换行 以便读者不必向右滚动即可查看整个字符串 即使在编辑时单击 软换行 选项 最终的要点也不会换行 编辑时 启用软包装 保存后 无换行 How can
  • PHP 将任何尺寸的图像调整为 16:9 的宽高比

    午安 我目前正在尝试了解如何以 16 9 的宽高比裁剪服务器上已加载的图像 为了更好地理解 如果我有 4 3 图像 我必须剪切顶部和底部图像部分以使其适合 16 9 比例 Thanks 我举了这个代码示例 http myrusakov ru
  • Javascript 中类似 Python 的“类”

    我想知道如何在 Javascript 中创建类似于 Python 中的 类 采用此处列出的 Python 类和函数 class one def foo bar some code 函数 foo 将被调用one foo bar JS 的等价物
  • 在纯原生 Android 应用程序中渲染文本

    我有一个纯原生的 Android NDK 应用程序 需要在每一帧渲染一些文本 我读过一些帖子 说我需要使用字体的所有字符创建一个图像文件 然后将每个字符渲染为该图像的四边形 这听起来需要大量工作 而且我不知道从哪里获取简单字体 例如 Ari
  • 将lucene索引分成两半

    将现有 Lucene 索引拆分为两半的最佳方法是什么 即每个拆分应包含原始索引中文档总数的一半 拆分现有索引 无需重新索引所有文档 的最简单方法是 制作现有索引的另一个副本 即 cp r myindex mycopy 打开第一个索引 并删除
  • 元组索引超出范围,Tensorflow

    这是模型 它是基本的张量流模型 可以拍摄数字的图片并告诉您它是什么数字 我知道python中的索引从0开始 我遇到的问题是这行代码 model fit np array test np array num 阅读下面的代码以获取更多信息 im