多输入深度学习中的平均层

2024-03-14

我正在努力在 Keras 中创建一个用于图像分类的多输入卷积神经网络 (CNN) 模型,该模型采用两个图像并给出一个输出,即两个图像的类别。

我有两个数据集:type1 和 type2,每个数据集包含相同的类。该模型应从 Type1 数据集中获取一张图像,从 Type2 数据集中获取一张图像,然后将这些图像分类为一个类别(ClassA 或 ClassB 或 ------)。

我想创建一个模型来预测两个图像,然后计算预测的平均值,类似于下图:

我怎样才能创建这个模型? 如何在 fit_generator 中创建生成器?


选项 1 - 双方模型相同,只是使用不同的输入

假设你有一个达到“预测”的模型,称为predModel.
创建两个输入张量:

input1 = Input(shape)   
input2 = Input(shape)

获取每个输入的输出:

pred1 = predModel(input1)
pred2 = predModel(input2)   

平均输出:

output = Average()([pred1,pred2])

创建最终模型:

model = Model([input1,input2], output)

选项2 - 双方都是相似的模型,但使用不同的重量

基本上与上面相同,但为每一侧单独创建图层。

def createCommonPart(inputTensor):
    out = ZeroPadding2D(...)(inputTensor)
    out = Conv2D(...)(out)

    ...
    out = Flatten()(out)
    return Dense(...)(out)

进行两个输入:

input1 = Input(shape)   
input2 = Input(shape)

获取两个输出:

pred1 = createCommonPart(input1)
pred2 = createCommonPart(input2)

平均输出:

output = Average()([pred1,pred2])

创建最终模型:

model = Model([input1,input2], output)

发电机

任何能产生的东西[xTrain1,xTrain2], y.

您可以像这样创建一个:

def generator(files1,files2, batch_size):

    while True: #must be infinite

        for i in range(len(files1)//batch_size)):
            bStart = i*batch_size
            bEnd = bStart+batch_size

            x1 = loadImagesSomehow(files1[bStart:bEnd])
            x2 = loadImagesSomehow(files2[bStart:bEnd])
            y = loadPredictionsSomeHow(forSamples[bStart:bEnd])

            yield [x1,x2], y

您还可以实施keras.utils.Sequence以类似的方式。

class gen(Sequence):
    def __init__(self, files1, files2, batchSize):
        self.files1 = files1
        self.files2 = files2
        self.batchSize = batchSize

    def __len__(self):
        return self.len(files1) // self.batchSize

    def __getitem__(self,i):

        bStart = i*self.batchSize
        bEnd = bStart+self.batchSize 

        x1 = loadImagesSomehow(files1[bStart:bEnd])
        x2 = loadImagesSomehow(files2[bStart:bEnd])
        y = loadPredictionsSomeHow(forSamples[bStart:bEnd])

        return [x1,x2], y
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多输入深度学习中的平均层 的相关文章

随机推荐

  • 生成按字母顺序位于其他两个字符串之间的字母字符串的算法?

    我试图解决的一个问题 假设您有两个由小写字母 a 到 z 组成的不同字符串 请在两个字符串之间找到一个字符串 以便始终可以找到更多中间字符串 更多细节 鉴于按字母顺序 a 位于 b 之前 当按照字典排序时 a 和 b 之间存在无限数量的字符
  • 如何使用可达性类来检测有效的互联网连接?

    我是 iOS 开发新手 正在努力让reachability h 类正常工作 这是我的视图控制器代码 void viewWillAppear BOOL animated NSNotificationCenter defaultCenter a
  • 如何在给定的任意年份中获得去年的相同工作日?

    我希望任何一年都能得到去年的同一天 我怎样才能最好地在 R 中做到这一点 例如 给定星期日 2010 01 03 我想获取前一年同一周的星期日 Sunday weekdays as Date 2010 01 03 format Y m d
  • 在 Magento 交易电子邮件中添加密件抄送

    我创建了一个新的电子邮件模板 在 Magento 中运行良好 但我不知道如何将密件抄送地址添加到电子邮件中 您可以在发送电子邮件的代码中添加密件抄送 Mage getModel core email template gt addBcc e
  • 排序 if/else if 语句的最快/正确方法

    在 PHP 中 是否有最快 正确的方法来排序 if else if 语句 出于某种原因 在我看来 我喜欢认为第一个 if 语句应该是预期的 最受欢迎 满足条件 然后是第二个 依此类推 但是 这真的重要吗 如果第二个条件是最流行的选择 是否会
  • 如何按幂 bi 矩阵的降序对列日期进行排序

    我需要按日期降序对矩阵列进行排序 我还有什么选择吗 检查这个图像matrix https i stack imgur com sj9Et png我需要从 1 月 20 日到 1 月 19 日订购 此列已按日期列排序 提前致谢 一种解决方案是
  • Python - 将列表列表分组

    考虑以下简化情况 lol John Polak 5 3 7 9 John Polak 7 9 2 3 Mark Eden 0 3 3 1 Mark Eden 5 1 2 9 什么会是pythonic 和内存 速度高效根据前两个参数将此列表列
  • 如何覆盖自带 .d.ts 的包中的错误类型?

    我正在使用 chalk 处理 JavaScript 项目 并使用 TypeScript 检查该项目checkJs flag JavaScript 代码像这样导入它 const chalk require chalk 不幸的是 粉笔有自己的类
  • RxJS 6 获取 Observable 数组的过滤列表

    在我的 ThreadService 类中 我有一个函数getThreads 给我返回一个Observable
  • Android Honeycomb 上的 DexClassLoader

    我正在开发一个项目 尝试通过加载外部库 Dex类加载器 这在 2 3 中效果很好 public class FormularDisplayLoader public final static String PATH data data at
  • Angular 可重复使用模板

    是否可以编写可重用的ng template 我的很多组件都使用完全相同的ng template 例如
  • 为什么 Tensorflow 对象检测 API 使用 YUV420SP 到 ARGB8888 转换

    所以我得到了tensorflow object detection API在Android上运行 我注意到在浏览代码时 在处理从相机拍摄的帧之前 它们是一个像这样的转换CameraActivity java imageConverter n
  • 如何在 ASP.NET Core 应用程序中使用位图资源?

    我正在尝试使用其中的一些位图资源 netcore2 1应用程序 但是当我将图像资源添加到我的项目时 它显示以下错误 严重性代码 说明 项目文件行抑制状态 错误资源 sign here tag 无法实例化 找不到类型 System Drawi
  • 在 MySQL 数据库中存储纬度/经度时使用的理想数据类型是什么?

    请记住 我将对纬度 经度对执行计算 哪种数据类型最适合与 MySQL 数据库一起使用 基本上 这取决于您所在位置所需的精度 使用 DOUBLE 您将获得 3 5nm 的精度 DECIMAL 8 6 9 6 下降到 16 厘米 浮子是1 7m
  • 如何将 Asp.Net 身份验证与 Azure AD 身份验证连接

    我在我的 asp net 项目中使用 UseOpenIdConnectAuthentication 协议来连接到我的 Azure AD 并且工作正常 今天 我也需要在 Asp net Identity 或其他与 Azure AD 不同的身份
  • 如何在google-colaboratory上安装需要编译的库

    当尝试安装需要的库时cmake像这样 pip install dlib 笔记本返回以下错误 error Errno 2 No such file or directory cmake cmake 您可以使用aptgoogle colabor
  • 静态定义的 IDT [重复]

    这个问题在这里已经有答案了 我正在开发一个启动时间要求很紧的项目 目标架构是基于 IA 32 的处理器 在 32 位保护模式下运行 已确定可以改进的领域之一是当前系统动态初始化处理器的 IDT 中断描述符表 由于我们没有任何即插即用设备并且
  • 使用 account-ui 包时,meteor 中是否有 post createUser 挂钩?

    假设我有一个待办事项应用程序 我想确保每个注册的用户都至少有一个待办事项开始 例如 第一个待办事项要划掉 我将如何在流星中做到这一点 一般来说 在我看来 我可以在第一次创建用户时执行此操作 理想 或者检查他们每次登录时是否需要新的待办事项
  • 复杂的自定义标签助手

    基本上 我扩展了之前回答的问题 更新相关实体 https stackoverflow com questions 53380176 updating related entities 因此它是一个自定义标签助手 我想向自定义标签助手发送与用
  • 多输入深度学习中的平均层

    我正在努力在 Keras 中创建一个用于图像分类的多输入卷积神经网络 CNN 模型 该模型采用两个图像并给出一个输出 即两个图像的类别 我有两个数据集 type1 和 type2 每个数据集包含相同的类 该模型应从 Type1 数据集中获取