对多输出 Keras 模型中的每个输出使用不同的样本权重

2024-02-23

我的输入数组是image_array,包含 10000 张大小为 512x512、4 个通道的图像的数据。 IE。image_array.shape = (10000, 512, 512, 4)。每张图像都有一个相关的指标,我想训练 CNN 来为我进行预测。因此metric_array.shape = (10000)。由于我不希望网络偏向更频繁出现的指标值,因此我有一个加权数组,其中包含指标的每个值的权重。因此weightArray.shape = (10000).

我正在使用 Keras。这是我的顺序模型:

model = Sequential()
model.add(Conv2D(32, use_bias=True, kernel_size=(3,3), strides=(1, 1), activation='relu', input_shape=(512,512,4))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, use_bias=True, kernel_size=(3,3), strides=(1, 1), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(128, use_bias=True, kernel_size=(3,3), strides=(1, 1), activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(32))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(Dense(1, activation=relu_max))

我想使用均方误差损失函数和随机梯度下降优化器。我编译我的模型:

model.compile(loss='mean_squared_error', optimizer=optimizers.SGD(lr=0.01))

我将数据集分为训练和验证:

X_train, X_validate, Y_train, Y_validate, W_train, W_validate \
= train_test_split(image_array, metric_array, weightArray, test_size=0.3)

最后训练模型:

model.fit(X_train, Y_train, epochs=100, batch_size = 32, \
           validation_data=(X_validate,Y_validate), sample_weight=W_train)

以上所有工作。现在,我想做的是使用 2 个指标而不是 1 个。我对每个图像都有一个 metric1 值和一个 metric2 值。 metric1 和 metric2 的每个值都有一个关联的权重。因此

metric_array1.shape = metric_array2.shape = weightArray1.shape = weightArray2.shape = (10000)

然后,我的网络将有两个输出节点,每个节点对应一个指标。

我尝试将上面的最后一层更改为:

model.add(Dense(2, activation=relu_max))

然后,我将度量和权重数据组合成一个 metric_array 和一个元组的weightArray,形状为 (10000, 2)。 这让我发现顺序模型是为单个输出而设计的,因此我应该使用函数模型。

我读过一些文档,看起来相当复杂。我尝试使用上面的模型(但最后一层有 2 个节点),然后执行

from keras.models import Model
new_model = Model(model)

但当我尝试编译它时它不喜欢它,因为模型没有选项.add.

有没有一种简单的方法来修改我已经拥有的东西以获得我的新目的?我真的很感激任何指导。


首先我们先澄清一个误区:

如果你的模型有one输出输入layer那么你可以使用 Sequential API 来构建你的模型,无论神经元数量在输出层和输入层。另一方面,如果你的模型有multiple输出输入layers,那么您必须使用功能 API 来定义您的模型(无论输入/输出层可能有多少个神经元)。

现在,您已经声明您的模型有两个输出值,并且对于每个输出值,您想要使用不同的样本权重。为了能够做到这一点,您的模型必须具有两个输出层,然后您可以设置sample_weight参数作为字典,包含对应于两个输出层的两个权重数组。

为了更清楚地说明这一点,请考虑这个虚拟示例:

from keras import layers
from keras import models 
import numpy as np

inp = layers.Input(shape=(5,))
# assign names to output layers for more clarity
out1 = layers.Dense(1, name='out1')(inp)
out2 = layers.Dense(1, name='out2')(inp)

model = models.Model(inp, [out1, out2])
model.compile(loss='mse',
              optimizer='adam')

# create some dummy training data as well as sample weight
n_samples = 100
X = np.random.rand(n_samples, 5)
y1 = np.random.rand(n_samples,1)
y2 = np.random.rand(n_samples,1)

w1 = np.random.rand(n_samples,)
w2 = np.random.rand(n_samples,)

model.fit(X, [y1, y2], epochs=5, batch_size=16, sample_weight={'out1': w1, 'out2': w2})
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

对多输出 Keras 模型中的每个输出使用不同的样本权重 的相关文章

随机推荐

  • GROUP BY 子句在 sqlite 中获取逗号分隔值

    我的表结构是这样的 使用sqlite3 CREATE TABLE enghindi eng TEXT hindi TEXT 我有一张名为enghindi其中有两列名为hindi eng 我想合并 eng 列的记录 并通过逗号分隔合并印地文单
  • 如果没有人调用interrupt(),可以忽略InterruptedException吗?

    如果我创建自己的线程 即不是线程池 并且在某个地方调用sleep或任何其他可中断方法 是否可以忽略 InterruptedException如果我知道代码中没有其他人在线程上进行中断 换句话说 如果线程的寿命应该与 JVM 一样长 这意味着
  • 如何让 PHP SOAP 客户端与使用无效证书通过 SSL 运行的服务进行通信

    我尝试使用 PHP SOAP 客户端使用 SOAP 服务 但失败并显示以下消息 SoapFault SOAP ERROR Parsing WSDL Couldn t load from https domain com webservice
  • MathJax 渲染模糊

    MathJax http www mathjax org 在浏览器中的渲染 右 比在 LaTeX 中的等效 PDF 渲染 左 要模糊得多 这是 Javascript 限制 浏览器限制 MathJax 限制 错误 设计原因还是其他原因 有什么
  • 创建 AMI 映像作为 cloudformation 堆栈的一部分

    我想创建一个 EC2 cloudformation 堆栈 基本上可以按以下步骤描述 1 启动实例 2 配置实例 3 停止实例并从中创建 AMI 映像 4 使用创建的 AMI 映像作为源创建自动缩放组以启动新实例 基本上我可以在一个 clou
  • 当我需要转义 Html 字符串时?

    在我的遗留项目中 我可以在字符串发送到浏览器之前看到 escapeHtml 的用法 StringEscapeUtils escapeHtml stringBody 我从 api 文档知道 escapeHtml 的作用 这里是给出的示例 Fo
  • 构建 Mac 和 Windows GUI 应用程序

    我计划为 Mac 和 Windows 构建一个 GUI 应用程序 我一直在技术选择方面进行一些研究 例如语言 库和构建工具 以便我可以在两个平台之间共享尽可能多的代码 主要要求是 满足 Mac App Store 要求 Mac 和 Wind
  • C# 对象类型比较

    如何比较声明为类型的两个对象的类型 我想知道两个对象是否属于同一类型或来自同一基类 任何帮助表示赞赏 e g private bool AreSame Type a Type b Say a and b是两个对象 如果你想看看是否a and
  • QFormLayout 中的 QSpacerItem - 垂直展开

    我想在我的内心拓展一个空间QFormLayout 但无论如何QFormLayout仅使用QSpaceItem sizeHint 有谁知道解决这个问题的方法 或者处理这个问题的正确方法 MyWidget MyWidget QWidget pa
  • Kinesis 分区键始终位于同一个分片中

    我有一个包含 2 个分片的运动流 如下所示 StreamDescription StreamStatus ACTIVE StreamName my stream Shards ShardId shardId 000000000001 Has
  • fullcalendar jquery 插件标题字符串中的 HTML

    我认为 fullcalendar jquery plugin 是一个非常好的解决方案 但是 我注意到插件转义了 htmlEscape 标题 但我需要格式化标题中的一些字符串 例如粗体文本 颜色或小图像 使用另一个插件 例如 qTip 如示例
  • 陷入 Gradle Build 运行状态

    当尝试在 Android Studio 2 1 在 Ubuntu 16 04 上 上构建我的应用程序时 它陷入了以下注释 Executing tasks app generateDebugSources app mockableAndroi
  • Helm:从键可变的 Map 中获取值

    我有一个舵图如下 dns entries cluster1 xx xx xx xx cluster2 xx xx xx xx 安装 Helm Chart 时 集群值也会动态设置 在模板中 我需要从上面的地图中动态选择它 if hasKey
  • 为什么我可以使用 nullptr 而不包含 STL?

    The C nullptr属于类型std nullptr t 为什么一个程序喜欢 int main int ptr nullptr 仍然可以工作 尽管它不包含任何 STL 库 在C 11中他们想添加一个关键字来替换宏NULL 基本上定义为
  • 使用 Cromis IPC 进行双向通信

    我已经下载并玩了克罗米斯工控机 http www cromis net blog 2009 11 cromis ipc fast inter process communication named pipes 来自 Iztok Kacin
  • 比较字符串(文字和数字)的最快方法

    我有一个与字符串比较 Java 中 相关的性能问题 我正在开发一个需要对巨大列表进行排序的项目 Eclipse 中的 TableViewer 无论如何 我已经将瓶颈定位到对要比较的字符串的compareTo 的调用 有什么方法可以优化字符串
  • C#:使用通用字典 保存混合类型的设置并返回正确的值和类型转换

    我正在尝试实现一个类 以优雅且易于维护的方式保存用户设置 有一个广泛的可能设置列表 其中包含多种类型的设置 int double string 等 我试图使用字典 但由于我的类型是混合的 所以我使用通用对象类型作为键返回值 我还有另一个字典
  • 多个 JFrame 的使用:好还是坏实践? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我正在开发一个显示图像并播放数据库中的声音的应用程序 我正在尝试决定是否使用单独的 JFrame 从 GUI 将图像添加到数据库中 我只是想知道
  • getrandbits 不产生恒定长度的数字

    我使用Python 2 6 6 我使用 getrandbits 128 获取 128 位随机数 a random getrandbits 128 然而 位数并不总是 128 有时甚至少于 128 这是什么原因呢 有没有比较稳定的库 这 12
  • 对多输出 Keras 模型中的每个输出使用不同的样本权重

    我的输入数组是image array 包含 10000 张大小为 512x512 4 个通道的图像的数据 IE image array shape 10000 512 512 4 每张图像都有一个相关的指标 我想训练 CNN 来为我进行预测