Keras ImageDataGenerator:为什么我的 CNN 的输出是相反的?

2024-02-04

我正在尝试编写一个区分猫和狗的 CNN 代码。 我已经设置了标签,例如狗:0和猫:1,所以我希望我的CNN在它是狗时输出0,如果它是猫则输出1。然而,它却做了相反的事情(当它是猫时给出 0,对于狗给出 1)。请检查我的代码并看看我哪里出错了。谢谢

我目前使用的是python 3.6.8,使用jupyter笔记本(里面的所有代码都是我从jupyter笔记本复制粘贴代码的不同部分)

import os
import cv2
from random import shuffle
import numpy as np
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense, Activation, Conv2D, MaxPooling2D, Flatten, Dropout, BatchNormalization
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
%matplotlib inline

train_dir = r'C:\Users\tohho\Desktop\Python pypipapp\Machine Learning\data\PetImages\train'
test_dir = r'C:\Users\tohho\Desktop\Python pypipapp\Machine Learning\data\PetImages\test1'
IMG_WIDTH = 100
IMG_HEIGHT = 100
batch_size = 32



######## THIS IS WHERE I LABELLED 0 FOR DOG AND 1 FOR CAT ##########
filenames = os.listdir(train_dir)
categories = [] 
for filename in filenames:
    category = filename.split('.')[0]
    if category == 'cat':
        categories.append(1)
    elif category == 'dog':
        categories.append(0)

df = pd.DataFrame({'filename':filenames, 'class':categories}) # making the dataframe

#### I SPLIT THE DATA INTO TRAIN AND VALIDATION DATASETS ####
df_train, df_validate = train_test_split(df, test_size=0.25) # splitting data for train/test
 # need to reset index for both dataframs so imagedatagenerator works properly
df_train = df_train.reset_index(drop=True)
df_validate = df_validate.reset_index(drop=True)

print(df_train['class'].value_counts())
print(df_validate['class'].value_counts())

len_training = df_train.shape[0]
len_validate = df_validate.shape[0]
print('{} training eg, {} test eg'.format(len_training, len_validate))



#### CREATE IMAGE DATA GENERATORS ####
train_datagen = ImageDataGenerator(rescale=1./255,
                               shear_range = 0.2,
                               zoom_range = 0.2,
                               horizontal_flip = True)
# our train_datagen generator will use the following transformations on the images
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_dataframe(df_train, 
                                                    train_dir,
                                                    target_size=(IMG_WIDTH, IMG_HEIGHT),
                                                    batch_size=batch_size,
                                                    x_col='filename',
                                                    y_col='class',
                                                    class_mode = 'binary')

# generator = ImageDataGenerator(*args).flow_from_dataframe(dataframe, directory, target_size,
# batch_size, x_col, y_col, class_mode)
# your dataframe shoudl be in the format such that x_col = features, y_col = class/label
# binary class mode since output is either 0(dog) or 1(cat)

validation_generator = validation_datagen.flow_from_dataframe(df_validate, 
                                                   train_dir,
                                                    target_size=(IMG_WIDTH, IMG_HEIGHT),
                                                    x_col='filename',
                                                    y_col='class',
                                                    class_mode='binary', 
                                                  batch_size=batch_size)

########## BUILDING MODEL ############
model = Sequential()
model.add(Conv2D(32, (3,3), input_shape=(IMG_WIDTH, IMG_HEIGHT, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(64, (3,3), input_shape=(IMG_WIDTH, IMG_HEIGHT, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(128, (3,3), input_shape=(IMG_WIDTH, IMG_HEIGHT, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten()) # remember to flatten conv2d to dense layer
model.add(Dense(256))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.4))

model.add(Dense(1))
model.add(Activation('sigmoid')) 
# since we have only 1 output with range [0,1], we use sigmoid
# if there were n categories, use softmax

# binary_crossentropy since output is either 0,1
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

earlystop = EarlyStopping(monitor='val_loss', patience=3) # stops learning if val_loss doesnt improve
learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', 
                                            patience=2, 
                                            verbose=1, 
                                            factor=0.5, 
                                            min_lr=0.000001) 
# reduces learning rate if val_acc doesnt improve
callbacks = [earlystop, learning_rate_reduction]

##### FIT THE MODEL #####
epochs = 50
model.fit_generator(train_generator,
                   steps_per_epoch=len_training//batch_size,
                   verbose=1,
                   epochs=epochs,
                   validation_data=validation_generator,
                   validation_steps=len_validate//batch_size,
                   callbacks=callbacks) # fitting model


######### PREDICTING #############
output_generator = validation_datagen.flow_from_dataframe(df_output,
                                                   outputdir,
                                                   x_col='filename',
                                                   y_col=None,
                                                   class_mode=None,
                                                   target_size=(IMG_WIDTH, IMG_HEIGHT),
                                                   shuffle=False,
                                                   batch_size=batch_size)
predictions = model.predict_generator(output_generator, 
                                      steps=np.ceil(len_output/batch_size))
df_output['probability'] = predictions
df_output['label'] = np.where(df_output['probability'] > 0.5, 'cat','dog')
df_output.head()

CNN给出了与正确答案相反的结果,当反转输出时,我得到了预期的结果(正确的识别和准确性)。 我知道只要改变线路df_output['label'] = np.where(df_output['probability'] > 0.5, 'cat','dog') to df_output['label'] = np.where(df_output['probability'] < 0.5, 'cat','dog')解决了问题,但这并不能帮助我弄清楚为什么 CNN 的输出是相反的。


你的问题的原因很微妙。我将用一个玩具示例来说明发生了什么。假设我们使用以下代码实例化一个数据生成器:

# List of image paths, doesn't matter here
image_paths = ['./img_{}.png'.format(i) for i in range(5)] 
labels = ...  # List of labels

df = pd.DataFrame()
df['filename'] = image_paths
df['class'] = labels

generator = ImageDataGenerator().flow_from_dataframe(dataframe=df, 
                                                    directory='./',
                                                    x_col='filename',
                                                    y_col='class')

ImageDataGenerator 期望class数据框中的列包含与图像关联的字符串标签。在内部,它将这些标签映射到类整数。您可以通过调用来检查此映射class_indices属性。使用以下标签列表实例化我们的生成器后:

labels = ['cat', 'cat', 'cat', 'dog', 'dog']

the class_indices映射将如下所示:

generator.class_indices
> {'cat': 0, 'dog': 1}

让我们再次实例化生成器,但更改第一张图像的标签:

labels = ['dog', 'cat', 'cat', 'dog', 'dog']
# After re-instantiating the generator
generator.class_indices
> {'dog': 0, 'cat': 1}

我们类的整数编码被交换,这表明标签到类整数的内部映射取决于遇到不同类的顺序.

您正在绘制地图cat到 1 和dog为 0,但 ImageDataGenerator 将它们解释为标签字符串并在内部将它们映射为整数。

现在,如果目录中的第一张图片是一只猫,会发生什么?

labels = [1, 0, 1, 0, 0] # ['cat', 'dog', 'cat', 'dog', 'dog']
# After re-instantiating the generator
generator.class_indices
> {1: 0, 0: 1}  # !

这就是你困惑的根源。 :) 为了避免这种情况,可以:

  • 在数据框的标签列中使用“cat”和“dog”,并让 ImageDataGenerator 为您处理映射
  • 将类列表传递给classes调用中的参数flow_from_dataframe显式指定映射。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras ImageDataGenerator:为什么我的 CNN 的输出是相反的? 的相关文章

  • Kivy - 有所有颜色名称的列表吗?

    在 Kivy 中 小部件 color属性允许输入其值作为字符串颜色名称 也 例如在 kv file Label color red 是否有所有可能的颜色名称的列表 就在这里 来自Kivy 的文档 https kivy org doc sta
  • pyspark 数据框中的自定义排序

    是否有推荐的方法在 pyspark 中实现分类数据的自定义排序 我理想地寻找 pandas 分类数据类型提供的功能 因此 给定一个数据集Speed列 可能的选项是 Super Fast Fast Medium Slow 我想实现适合上下文的
  • Mac OS X 中文件系统的 Unicode 编码在 Python 中不正确?

    在 OS X 和 Python 中处理 Unicode 文件名有点困难 我试图在代码中稍后使用文件名作为正则表达式的输入 但文件名中使用的编码似乎与 sys getfilesystemencoding 告诉我的不同 采取以下代码 usr b
  • Emacs 24.x 上的 IPython 支持

    我对 IPython 与 Emacs 的集成感到困惑 从 Emacs 24 开始 Emacs 附带了自己的python el 该文件是否支持 IPython 还是仅支持 Python 另外 维基百科 http emacswiki org e
  • Django 的 request.FILES 出现 UnicodeDecodeError

    我在视图调用中有以下代码 def view request body u for filename f in request FILES items body body Filename filename n f read n 在某些情况下
  • pytest:同一接口的不同实现的可重用测试

    想象一下我已经实现了一个名为的实用程序 可能是一个类 Bar在一个模块中foo 并为其编写了以下测试 测试 foo py from foo import Bar as Implementation from pytest import ma
  • .pyx 文件出现未知文件类型错误

    我正在尝试构建一个包含 pyx 文件的 Python 包 pyregion 但在构建过程中出现错误 检查以下输出 python setup py build running build running build py creating b
  • 使用Python计算目录的大小?

    在我重新发明这个特殊的轮子之前 有没有人有一个很好的例程来使用 Python 计算目录的大小 如果例程能够很好地以 Mb Gb 等格式格式化大小 那就太好了 这会遍历所有子目录 总结文件大小 import os def get size s
  • 通过索引访问Python字典的元素

    考虑一个像这样的字典 mydict Apple American 16 Mexican 10 Chinese 5 Grapes Arabian 25 Indian 20 例如 我如何访问该字典的特定元素 例如 我想在对 Apple 的第一个
  • Matplotlib 中 x 轴标签的频率和旋转

    我在下面编写了一个简单的脚本来使用 matplotlib 生成图形 我想将 x tick 频率从每月增加到每周并轮换标签 我不知道从哪里开始 x 轴频率 我的旋转线产生错误 TypeError set xticks got an unexp
  • Mac OSX 10.6 上的 Python mysqldb 不工作

    我正在使用 Python 2 7 并尝试让 Django 项目在 MySQL 后端运行 我已经下载了 mysqldb 并按照此处的指南进行操作 http cd34 com blog programming python mysql pyth
  • Anaconda 无法导入 ssl 但 Python 可以

    Anaconda 3 Jupyter笔记本无法导入ssl 但使用Atom终端导入ssl没有问题 我尝试在 Jupyter 笔记本中导入 ssl 但出现以下错误 C ProgramData Anaconda3 lib ssl py in
  • 在系统托盘中隐藏 tkinter 窗口 [重复]

    这个问题在这里已经有答案了 我正在制作一个程序来提醒我朋友的生日 这样我就不会忘记祝福他们 为此 我制作了两个 tkinter 窗口 1 First one is for entering name and birth date 2 Sec
  • 从 NumPy 数组到 Mat 的 C++ 转换 (OpenCV)

    我正在围绕 ArUco 增强现实库 基于 OpenCV 编写一个薄包装器 我试图构建的界面非常简单 Python 将图像传递给 C 代码 C 代码检测标记并将其位置和其他信息作为字典元组返回给 Python 但是 我不知道如何在 Pytho
  • 动态过滤 pandas 数据框

    我正在尝试使用三列的阈值来过滤 pandas 数据框 import pandas as pd df pd DataFrame A 6 2 10 5 3 B 2 5 3 2 6 C 5 2 1 8 2 df df loc df A gt 0
  • 双击打开 ipython 笔记本

    相关文章 通过双击 osx 打开 ipython 笔记本 https stackoverflow com questions 16158893 open an ipython notebook via double click on osx
  • python 线程安全可变对象复制

    Is 蟒蛇的copy http docs python org 2 library copy html模块线程安全吗 如果不是 我应该如何在 python 中以线程安全的方式复制 deepcopy 可变对象 蟒蛇的GIL http en w
  • 从 pandas DataFrame 中删除少于 K 个连续 NaN

    我正在处理时间序列数据 我在从数据帧列中删除小于或等于阈值的连续 NaN 时遇到问题 我尝试查看一些链接 例如 标识连续 NaN 出现的位置以及计数 Pandas NaN 孔的游程长度 https stackoverflow com que
  • 如何为不同操作系统/Python 版本编译 Python C/C++ 扩展?

    我注意到一些成熟的Python库已经为大多数架构 Win32 Win amd64 MacOS 和Python版本提供了预编译版本 针对不同环境交叉编译扩展的标准方法是什么 葡萄酒 虚拟机 众包 我们使用虚拟机和Hudson http hud
  • Apache Beam Pipeline 写表后查询表

    我有一个 Apache Beam Dataflow 管道 它将结果写入 BigQuery 表 然后我想查询该表以获取管道的单独部分 但是 我似乎无法弄清楚如何正确设置此管道依赖性 我编写的新表 然后想要查询 与一个单独的表连接以进行某些过滤

随机推荐

  • 如何获取意图服务中的上下文

    场景如下 我有一个 WakefulBroadcastReceiver 它执行以下操作 备份到网络计算机或云端 它设置为在 半夜 当我知道平板电脑可以访问 局域网 备份会将数据存储到实例化 WakefulBroadcastReceiver 的
  • visio 的 vba 编程

    目前 我正在IVR 交互式语音应答 系统工作 要在 IVR 中添加服务 我必须在 visio 中制作流程 该流程具有 IVR 系统卖家提供的预定义形状 形状是用 VBA 编程的 我决定学习VBA来修改预定义的形状 在google中搜索时 它
  • 属性更改时重新构建/重新渲染 Angular2 组件

    如何实施 我的子组件 import Component Input ngOnInit from angular2 core Component selector my component template div In child comp
  • 存储库名称作为 GitHub Action 环境变量?

    如何获取存储库名称 而不是用户或组织 作为 GitHub Actions 中的环境变量 我发现github repository但这包含所有者作为第一部分 如下所示 owner repo Try github event repositor
  • 如何控制 Honeycomb 中的软菜单按钮?

    我有一个应用程序 我想在其中关闭菜单按钮 我正在选择其他人的项目 并且不确定是什么引起了菜单按钮的出现 它没有 做任何事情 有没有办法手动关闭该图标 或者我必须首先找出它被显示的原因 请不要批评寻找解决方法 显然在理想的世界中 我会对代码足
  • C++11 std::bind 和 boost::bind 之间的区别

    两者有什么区别吗 或者我可以安全地替换每次出现的boost bind by std bind在我的代码中 从而消除对Boost的依赖 boost bind 关系运算符重载 http www boost org libs bind bind
  • 为 fa 圆添加边框

    如何给图片添加边框circleFont Awesome 的图标 其实我的结果是 http jsfiddle net 0jhdvj0k http jsfiddle net 0jhdvj0k 边框类似于省略号 而不是圆形边框 table cla
  • Jquery - 更改标签中的文本

    这是标签 有文字 使用 20 公里 使用 jquery 我想将文本 20 Kms 替换为 10 kms 我用手像这样贴上标签 label for applyDistanceSlab 我怎样才能做到这一点 label for applyDis
  • 在嵌套对象内搜索文本(以 Backbone.js 集合为例)

    我有一个backbone js 集合 我需要在其中进行全文搜索 我手头的工具如下 Backbone js 下划线 js jQuery 对于那些不熟悉主干的人 骨干集合只是一个对象 在集合内有一个包含模型的数组 每个模型都有一个带有属性的数组
  • CMake Qt UIC 失败

    我目前正在将我的项目从 qmake 移植到 CMake 并且我遇到了 Qt UIC 的问题 它尝试处理不存在的 UI 文件 而不是我希望它处理的实际文件 我有以下文件层次结构 CMakeLists txt MyProject pro mai
  • Visual Studio 解决方案——有什么方法可以创建“特殊”文件夹吗?

    基本上 我希望我的一个文件夹作为一种 特殊文件夹 出现在其他文件夹上方 类似于 属性 如何拥有自己的特殊位置 即使它是一个文件夹 与 App Data 等相同 这可能吗 默认情况下 Visual Studio 不支持添加特殊项目文件夹 Pr
  • 对外界隐藏内部服务以确保使用正确的高级服务[关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 我正在一个电子商务网站上工作 我有广告实体 其中包括属性和照片 属性写入数据库 照片存储在文件系统中 我创建了一个WriterService
  • 在 Dash/Plotly 中显示属性会导致 KeyError

    我正在尝试可视化文档中的引用 为此 我有Elements csv 看起来像这样 Doc Description DocumentID SOP Laboratory This SOP should be used in the lab 104
  • __attribute__((force)) 有什么作用?

    这听起来像是我应该能够通过谷歌搜索的东西 但我找不到很好的参考 到底是做什么的 attribute force 做 如 return attribute force uint32 t p 这是针对 ARM 系统 与 clang 交叉编译的
  • qtmaind.lib 中未解析的外部符号

    我正在尝试将我的 Qt 项目设置从 Visual Studio 2013 升级到 2015 它几乎完成了 但我在 qtmaind lib 中遇到了一些错误 1 gt qtmaind lib qtmain winrt obj 错误LNK201
  • Django:URLconf 中的变量参数

    我一直在寻找这个问题 但找不到任何问题 如果重复的话 抱歉 我正在建立某种电子商务网站 类似于 eBay 当我尝试浏览 类别 和 过滤器 时出现了问题 例如 您可以浏览 监视器 类别 这将向您显示大量监视器和一些应用它们的过滤器 与 eBa
  • Spring Boot:java.lang.IllegalArgumentException:找到多个名为 [spring_web] 的片段

    我在 tomcat 9 上部署 spring boot war 时得到了这个 我尝试了很多解决方案 例如清理项目以及我在 stackoverflow 中找到的所有可能的解决方案 但没有任何效果 其中一个在 web xml 中提供绝对排序 但
  • 具有 2 个中心部分的 Windows Phone 8.1 应用程序

    我创建了一个包含两个中心部分的 WP8 1 中心应用程序 这会产生两个轮毂之间滑动的奇怪行为 它们不会像预期的那样 飞 到位 它更像是垂直可滚动视图 如果我添加第三个集线器部分 一切都会正常工作 我在这里上传了有关它的 YouTube 视频
  • 在节点外部标记 networkx 节点属性

    我正在研究属于两种类型的小示例节点集 human machine 我想在每个节点之外以字典形式标记节点属性networkx图中 如下图的节点c e j所示 我使用MS Word在图表上添加了字典类型的属性 基本图是使用以下代码生成的 imp
  • Keras ImageDataGenerator:为什么我的 CNN 的输出是相反的?

    我正在尝试编写一个区分猫和狗的 CNN 代码 我已经设置了标签 例如狗 0和猫 1 所以我希望我的CNN在它是狗时输出0 如果它是猫则输出1 然而 它却做了相反的事情 当它是猫时给出 0 对于狗给出 1 请检查我的代码并看看我哪里出错了 谢