ImageDataGenerator 预测类 - 为什么预测未正确从概率转换为预测类?

2024-05-01

我有一个这样设置的目录:

images

-- val
    --class1
    --class2
-- test
   --all_classes
-- train
    --class1
    --class2

每个目录中都有一组图像。我想预测测试中的每个图像是否属于 1 类或 2 类。

我写这个是为了读取训练和验证数据:

train_path = "/content/drive/train/"
valid_path = "/content/drive/val/"

train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator=train_datagen.flow_from_directory(
  directory=train_path,
  batch_size=32,
  class_mode='binary',
  target_size=(150,150)
)

validation_generator=test_datagen.flow_from_directory(
  directory=valid_path,
  batch_size=32,
  class_mode='binary',
  target_size=(150,150)
)

创建了一个网络:

def create_network(): 
  model = Sequential()
  model.add(Input(shape=(150,150,3)))

  model.add(Conv2D(32, kernel_size=3,strides=(1, 1),activation='relu', padding='valid', dilation_rate=1))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=3, strides=(1, 1), activation='relu',padding='valid', dilation_rate=1))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Flatten())
  model.add(Dense(512, activation='relu'))

  model.add(Dense(1, activation='sigmoid'))
  plot_model(model, to_file='/content/drive/question1_model.png', show_shapes=True, show_layer_names=True)

  model.compile(optimizer = 'adam',
                   loss = 'binary_crossentropy', 
                   metrics = ['accuracy'])
  return model

拟合模型:

def fit_model(train_generator=train_generator, validation_generator=validation_generator,network=create_network()):
  checkpoint_path = "/content/drive/question1_checkpoint.h5"
  checkpoint_dir = os.path.dirname(checkpoint_path)

  callbacks_list = [
      callbacks.EarlyStopping(
          monitor = 'accuracy',
          patience = 5,
      ),

      callbacks.ModelCheckpoint(
          filepath=checkpoint_path,
          monitor = 'val_loss',
          #save_weights_only=True,
          save_best_only=True,
      ),

  ]

  model = network
  history = model.fit(train_generator,
                      epochs=200,
                      validation_data=validation_generator,
                      batch_size=32, 
                      callbacks = callbacks_list,
                      verbose=1
                      )
  return history,model,time_taken

history,model = fit_model(train_generator,validation_generator)

模型的准确率和验证准确率>80%,我将其重新加载进行预测:

model = load_model('/content/drive/question1_checkpoint.h5')

然后我想预测测试目录中的一组图像:

test_datagen = ImageDataGenerator(rescale=1./255)
test_path = "/content/drive/test/"

test_generator = test_datagen.flow_from_directory(
  directory=test_path,
  batch_size=16,
  class_mode='binary',
  target_size=(150,150),
  shuffle = False
)
test_generator.reset()
filenames = test_generator.filenames
nb_samples = len(filenames)
batch_size=16
predict = model.predict(test_generator,steps=test_generator.n/batch_size)

当我打印预测的开始时,我可以看到:

[[6.09035552e-01]
 [2.47541070e-02]
 [7.37663209e-02]
 [5.22839129e-02]
 [2.94408262e-01]
 [1.39171720e-01]
 [6.15863085e-01]

我认为这给了我 1 类正确的概率。但是当我打印每个预测的类别时:

predicted_class_indices=np.argmax(predict,axis=-1)
print(predicted_class_indices)

输出是:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]

这意味着我的预测概率没有被正确地转换到课堂上,对吗?因为例如 2.47541070e-02 是 0.02,而 6.09035552e-01 是 0.60,所以这些不应该被预测为不同的类别吗?有人可以告诉我哪里出错了吗?


这段代码给出了以下输出:(在使用相同的二进制类数据集训练模型后,我拍摄了 10 张图像进行测试 - 5 张狗的图像、5 张猫的图像)。

nb_samples = len(filenames)
batch_size=5
predict = model.predict(test_generator,steps=test_generator.n/batch_size)
predict

Output:

array([[0.06690815],
       [0.7787118 ],
       [0.109512  ],
       [0.39706784],
       [0.07243159],
       [0.61042166],
       [0.5808931 ],
       [0.86361384],
       [0.9961897 ],
       [0.61571515]], dtype=float32)

你用过哪个是正确的sigmoid https://www.tensorflow.org/api_docs/python/tf/keras/activations/sigmoid最后一层的激活函数,则输出范围将从 0 到 1。

请不要使用argmax for sigmoid价值观。您可以使用argmax使用时的方法softmax https://www.tensorflow.org/api_docs/python/tf/keras/activations/softmax激活函数来查找其他类别概率中类别的最高概率值。

在这里你可以使用下面的代码:

import tensorflow as tf

predictions = tf.where(predict <= 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())

Output:

Predictions:

 [[0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ImageDataGenerator 预测类 - 为什么预测未正确从概率转换为预测类? 的相关文章

随机推荐

  • 在 SugarCRM 中,将帐户所有权转让给其他用户不会更新联系人所有权

    我正在使用 SugarCRM v6 x 并发现当将帐户所有权转移给新的销售代表 分配的用户 ID 字段 时 联系人和其他相关子记录也不会转移 这是 SugarCRM 作者的实际设计选择吗 如果是 其背后的原因是什么 是否有推荐的帐户转移方法
  • Android 数字格式不知为何是错误的,我得到的不是 3.5,而是 3.499999999,为什么?

    我将一些数据存储在数据库中 然后使用游标读取这些数据 所有数据均为 56 45 3 04 0 03 类型 即小数点后两位 现在我想对它们求和 但这似乎并不容易 我得到这些数字c getDouble 3 然后我将它添加到 sum 变量中 如下
  • iOS 信号处理程序可以轻松收集哪些原因信息?

    我正在尝试向应用程序添加一些崩溃日志记录 并且我有一个signal设置处理程序以捕获标准 致命 信号 我可以在信号处理程序中实际 简单地收集哪些 原因 信息 如果有 以进行记录 我花了大约 2 小时谷歌搜索内容 但我找到的大部分内容都是针对
  • Java多线程和安全发布[关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 看完之后 Java并发实践 http jcip net and OSGI 实践 http neilbartlett name blog osgi
  • PayPal API 监听器网站支付标准 URI

    PayPal IPN 指南文档说得很清楚 将请求发布到 www paypal com 或 www sandbox paypal com 具体取决于您是要在沙盒中上线还是测试您的侦听器 等待 PayPal 的响应 该响应要么已验证 要么无效
  • 在 MVC 中重用 WPF ViewModel 是否可行?

    我们有一个用 WPF WCF 编写的富客户端应用程序 并打算在 ASP net 中创建一个配套网站 如果可能 使用 MVC 我被要求弄清楚我们当前的代码库中有多少是可以重用的 由一个单独的团队 而且我对 ASP net 几乎没有经验 我们将
  • CSS 中的圆帽下划线

    你能用 CSS 制作圆形下划线 如上图所示 吗 如何 有没有办法做到这一点border bottom border radius相反 会产生这种时尚的效果 编辑 我误解了皮克想要什么 但这应该有效 test font size 50px b
  • 根据条件过滤数据集

    我正在使用 asp net 2 0 和 c 我有一个数据集 正在获取员工信息 现在我想根据用户在搜索文本框中输入的名称来过滤网格视图 我正在这样做 DataSet ds new DataSet EmployeeInformation loa
  • 使用VBA从Zip中删除一些特定文件[重复]

    这个问题在这里已经有答案了 在完整的宏观过程中 我正在创建一个Zip的文件Folder 该文件夹有多个子文件夹和文件 使用此代码 Dim oApp As Object NewZip s path acc name zip Set oApp
  • ORA-02289: 序列不存在,hibernbate 中出错

    ORA 02289 序列不存在 hibernbate 中出错 在 Oracle 中 您无法自动生成值 您应该创建一个序列 我们称之为 VEHICLE SEQ 然后你应该把这个注释放在你的 id 上 GeneratedValue strate
  • 在 fork() 之后寻求有关“文件描述符”的简单描述

    Unix 环境中的高级编程 第二版 作者 W Richard Stevens 第 8 3 节 fork 函数 描述如下 父级和子级共享相同的文件偏移量非常重要 考虑一个分叉子进程 然后等待子进程完成的进程 假设两个进程都写入标准输出作为其正
  • 使用选择器获取最接近的父元素(不包括当前元素)

    我正在尝试获取元素的最接近的父元素 看着 closest https developer mozilla org en US docs Web API Element closest 如果选择器与元素匹配 它似乎会返回元素本身 Closes
  • WEBHID API:条形码扫描仪未触发输入报告

    我几乎使用 Nintendo Switch Joy Con 控制器演示 我对其进行了一些修改以使其与我的条形码扫描仪一起使用 它就是行不通 如果行得通 则每 100 次站点刷新就会工作一次 console log text gt log t
  • 如何将 Mercurial 存储库克隆到已存在的目录中?

    我有一个客户的 Django 项目 正在本地开发 使用 Mercurial 进行版本控制 我将本地存储库推送到我的个人远程服务器 我保存所有项目的地方 然后当我部署它时 在任何 Web 服务器上 我从我的个人服务器克隆该存储库 这在大多数服
  • 作为颜色表示的值

    将值转换为颜色是众所周知的 我确实理解以下两种方法 在改变 RGB 颜色值来表示一个值 https stackoverflow com questions 1423925 changing rgb color values to repre
  • 如何从控制器 Symfony2 内部访问不同的控制器

    我需要从另一个控制器内的不同控制器访问方法 我该怎么做 我可以用吗this gt get method 我可以将控制器包含在当前控制器中并创建它的对象并通过该对象访问该方法吗 这样做 可以 吗 我想调用另一个控制器的表单方法 newActi
  • 找不到 build.xml (Android)

    我一直在寻找这个问题的答案有一段时间了 但我似乎找不到它 我通过 perfoce 移动了 NeBeans Android 项目 现在出现以下错误 ZYAndroidAPP build xml 81 Cannot find F Program
  • 我应该如何处理 Android 应用程序中 http post 的服务器超时和错误代码响应?

    我的 Android 应用程序会向 URL 发送 http 帖子 例如http example com 电子邮件受保护 http example com abc php email abc xyz com因此 Android 应用程序基本上
  • 在哪里可以找到所有 HQL 关键字的列表?

    在哪里可以找到所有 HQL 关键字的列表 在完整的 Hibernate 源代码下载中 有一个grammar hql g文件 这是ANTLR http www antlr org 语言定义 您可以从官方GitHub源码仓库查看该文件的最新版本
  • ImageDataGenerator 预测类 - 为什么预测未正确从概率转换为预测类?

    我有一个这样设置的目录 images val class1 class2 test all classes train class1 class2 每个目录中都有一组图像 我想预测测试中的每个图像是否属于 1 类或 2 类 我写这个是为了读