使用 Transformer 模型进行多实例分类

2023-12-04

我使用这里的变压器Keras 文档示例用于多实例分类。每个实例的类取决于一个包中的其他实例。我使用变压器模型是因为:

它不对数据之间的时间/空间关系做出任何假设。这非常适合处理一组对象

例如,每个包最多可以有 5 个实例,每个实例有 3 个特征。

# Generate data
max_length = 5
x_lst = []
y_lst = []
for _ in range(10):
    num_instances = np.random.randint(2, max_length + 1)
    x_bag = np.random.randint(0, 9, size=(num_instances, 3))
    y_bag = np.random.randint(0, 2, size=num_instances)
    
    x_lst.append(x_bag)
    y_lst.append(y_bag)

前 2 个袋子的特征和标签(分别有 5 个和 2 个实例):

x_lst[:2]

[array([[8, 0, 3],
        [8, 1, 0],
        [4, 6, 8],
        [1, 6, 4],
        [7, 4, 6]]),
 array([[5, 8, 4],
        [2, 1, 1]])]

y_lst[:2]

[array([0, 1, 1, 1, 0]), array([0, 0])]

接下来,我用零填充特征,用 -1 填充目标:

x_padded = []
y_padded = []

for x, y in zip(x_lst, y_lst):
    x_p = np.zeros((max_length, 3))
    x_p[:x.shape[0], :x.shape[1]] = x
    x_padded.append(x_p)

    y_p = np.negative(np.ones(max_length))
    y_p[:y.shape[0]] = y
    y_padded.append(y_p)

X = np.stack(x_padded)
y = np.stack(y_padded)

where X.shape等于(10, 5, 3) and y.shape等于(10, 5).

我对原始模型做了两处更改:添加了 Masking 层 在输入层之后,并将最后一个密集层中的单元数量设置为包的最大尺寸(加上“sigmoid”激活):

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Attention and Normalization
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(inputs, inputs)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    res = x + inputs

    # Feed Forward Part
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    return x + res

def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    inputs = keras.layers.Masking(mask_value=0)(inputs) # ADDED MASKING LAYER
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(5, activation='sigmoid')(x) # CHANGED ACCORDING TO MY OUTPUT
    return keras.Model(inputs, outputs)

input_shape = (5, 3)

model = build_model(
    input_shape,
    head_size=256,
    num_heads=4,
    ff_dim=4,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)

model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["binary_accuracy"],
)
model.summary()

看起来我的模型并没有学到太多东西。如果我使用每个包的真实值数量(y.sum(axis=1) and Dense(1))作为目标而不是对每个实例进行分类,模型学习效果很好。我的错误在哪里?在这种情况下我应该如何构建输出层?我需要自定义丢失功能吗?

更新: 我做了一个自定义损失函数:

def my_loss_fn(y_true, y_pred):
    mask = tf.cast(tf.math.not_equal(y_true, tf.constant(-1.)), tf.float32)
    y_true, y_pred = tf.expand_dims(y_true, axis=-1), tf.expand_dims(y_pred, axis=-1)
    bce = tf.keras.losses.BinaryCrossentropy(reduction='none')
    return tf.reduce_sum(tf.cast(bce(y_true, y_pred), tf.float32) * mask)

mask = (y_test != -1).astype(int)
pd.DataFrame({'n_labels': mask.sum(axis=1), 'preds': ((preds * mask) >= .5).sum(axis=1)}).plot(figsize=(20, 5))

And it looks like the model learns: enter image description here

But it predicts all nonmasked labels as 1. enter image description here

@thushv89 这是我的问题。我取 2 个时间点:t1 和 t2,并查找在时间 t1 进行维护的所有车辆以及计划在时间 t2 进行维护的所有车辆。这是我的袋子里的东西。然后我计算一些特征,例如 t1 车辆已经花费了多少时间进行维护、从 t1 到 t2 车辆的计划开始需要多长时间等。如果我尝试预测 t2 时刻进行维护的车辆数量,我的模型会学得很好,但我想预测他们中的哪一个会离开,哪一个会进来(3 vs [True, False, True, True] 包里有 4 辆车)。


有以下三项重要改进:

  1. 将 GlobalAveragePooling1D 层替换为 Flatten 层。
  2. 添加自定义损失函数以从计算中排除目标填充(已添加到我的问题中),如果您想查看真实指标,请添加自定义指标函数。
  3. 将attention_mask添加到MultiHeadAttention(而不是Masking层)以掩盖填充。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Transformer 模型进行多实例分类 的相关文章

  • 行未从树视图复制

    该行未在树视图中复制 我在按行并复制并粘贴到未粘贴的任何地方后制作了弹出复制 The code popup tk Menu tree opportunity tearoff 0 def row copy item tree opportun
  • 一次将Python dict的内容分配给多个变量?

    我想做这样的事情 def f return a 1 b 2 c 3 a b f or a b f IE 这样 a 被分配为 1 b 被分配为 2 并且 c 是未定义的 这与此类似 def f return 1 2 a b f 依赖于变量名称
  • on_delete=models.PROTECT 和 on_delete=models.CASCADE 在 Django 模型上有什么作用?

    我对 Django 很熟悉 但最近注意到有一个on delete models CASCADE and on delete models PROTECT模型的选项 on delete models CASCADE and on delete
  • 如何确定非阻塞套接字是否真正连接?

    这个问题不仅限于Python 这是一个一般的套接字问题 我有一个非阻塞套接字 想要连接到一台可访问的机器 在另一端 该端口不存在 为什么 select 仍然成功 我预计会超时 sock send 因管道损坏而失败 select 之后如何确定
  • python 中分割字符串以获得一个值?

    需要帮助 假设我在名为 input 的变量中有一个字符串 Sam Person name kind input split 通过执行上述操作 我得到两个具有不同字符串 Sam 和 Person 的变量 有没有办法只获取第一个值 name S
  • PIL Image.size 返回相反的宽度/高度

    使用PIL确定图像的宽度和高度 在特定图像上 幸运的是只有这一个 但这很麻烦 从 image size 返回的宽度 高度是相反的 图片 http storage googleapis com cookila 533ebf752b9d1f7c
  • 带图像的简单 GUI [关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我试图在简单的 GUI 上显示一些卡
  • 如果字段值在外部列表中,Django 会注释布尔值

    想象一下我有这个 Django 模型 class Letter models Model name models CharField max length 1 unique True 还有这个列表 vowels a e i o u 我想查询
  • 无法打开 Python。错误 0xc000007b

    我最近一直在学习 Python 3 我在我的上网本 32 位 Windows 7 上创建简单的小程序没有任何问题 当我将它安装在我的上网本上时 我没有遇到任何问题 但现在我已经开始使用它了 我想将它安装在我的台式机上 并且我有一个 我的桌面
  • matplotlib matshow 标签

    我一个月前开始使用 matplotlib 所以我仍在学习 我正在尝试用 matshow 制作热图 我的代码如下 data numpy array a reshape 4 4 cax ax matshow data interpolation
  • 如何在python中检索aws批处理参数值?

    流程 Dynamo DB gt Lambda gt 批处理 如果将角色 arn 插入动态数据库 它是从 lambda 事件中检索的 然后使用submit job角色 arn 的 API 被传递为 parameters role arn ar
  • 如何使用Python的super()来更新父值?

    我对继承很陌生 之前所有关于继承和 Python 的 super 函数的讨论都有点超出我的理解 我当前使用以下代码来更新父对象的值 usr bin env python test py class Master object mydata
  • 为什么我用 beautifulSoup 刮的时候有桌子,但没有 pandas

    尝试抓取条目页面转换为制表符分隔格式 主要拉出序列和 UniProt 登录号 当我跑步时 url www signalpeptide de index php sess m listspdb bacteria s details id 10
  • 如何列出 python PDB 中的当前行?

    在 perl 调试器中 如果重复列出离开当前行的代码段 可以通过输入命令返回到当前行 点 我无法使用 python PDB 模块找到任何类似的东西 如果我list如果我自己离开当前行并想再次查看它 似乎我必须记住当前正在执行的行号 对我来说
  • conda-env list / conda info --envs 如何查找环境?

    我一直在尝试 anaconda miniconda 因为我的用户使用随 miniconda 安装的结构生物学程序 并且作者都没有 A 考虑到可能存在其他 miniconda 应用程序 B 他们的程序将在多用户环境中使用 因此 使用 Arch
  • 如何有效地从 loadmat 函数生成的嵌套 numpy 数组中提取值?

    python中是否有更有效的方法从嵌套的python列表中提取数据 例如A array array 12000000 dtype object 我一直在使用A 0 0 0 0 当你有很多像 A 这样的数据时 这似乎不是一个有效的方法 我也用
  • 在 MacO 和 Linux 上安装 win32com [重复]

    这个问题在这里已经有答案了 我的问题很简单 我可以安装吗win32com蟒蛇API pywin32特别是 在非 Windows 操作系统上 我一直在Mac上尝试多个版本pip install pywin32 都失败了 下面是一个例子 如果你
  • Scipy 稀疏 Cumsum

    假设我有一个scipy sparse csr matrix代表下面的值 0 0 1 2 0 3 0 4 1 0 0 2 0 3 4 0 我想就地计算非零值的累积和 这会将数组更改为 0 0 1 3 0 6 0 10 1 0 0 3 0 6
  • 为什么我们应该在 def __init__(self, n) -> None: 中使用 -> ?

    我们为什么要使用 gt in def init self n gt None 我读了以下摘录来自 PEP 484 https www python org dev peps pep 0484 the meaning of annotatio
  • 异步和协程与任务队列

    我一直在阅读有关 python 3 中的 asyncio 模块的内容 以及更广泛地了解 python 中的协程的内容 但我不明白是什么让 asyncio 成为如此出色的工具 我的感觉是 你可以用协程做的所有事情 通过使用基于多处理模块 例如

随机推荐

  • 如何让awk忽略双引号内的字段分隔符? [复制]

    这个问题在这里已经有答案了 我需要删除逗号分隔值文件中的 2 列 考虑 csv 文件中的以下行 email protected www example com field2 field3 field4 email protected fie
  • 如何在 React 中循环一个对象?

    React 新手 尝试循环对象属性 但 React 抱怨对象不是有效的 React 子对象 有人可以给我一些关于如何解决此问题的建议吗 我已经添加了 createFragment 但不完全确定需要去哪里或者我应该采取什么方法 JS var
  • 获取 mongoid 生成的原始 mongo db 查询表达式

    我想获取 mongoid 生成的 mongo 查询表达式该怎么做 例如这是 mongoid 语法 History where report type params report type order by ts 1 only ts last
  • 在 MVC4 中使用 DotNetOpenAuth 获取 Twitter 访问密钥

    我正在使用 MVC4 创建一个应用程序 该应用程序将授权用户使用 Twitter 并允许他们也从该应用程序发送推文 我可以使用 MVC4 中的BuiltInOAuthClient Twitter 毫无问题地对用户进行身份验证 http ww
  • 如何在 Web Api 中手动执行 Breeze 过滤器?

    我想使用一些外部服务器端逻辑来修改查询结果的属性 为此 我需要应用 Breeze 查询选项 修改结果集并返回它 我基本上知道如何申请OdataQueryOptions我的查询 但我不想错过 BreezeJS 所做而 Web Api 的 OD
  • PHP 两次获取数据

    我的功能看起来就是这样 private function generateTree courseID q SELECT l id l name AS lesson name c name AS course name FROM lesson
  • 什么意思 ”!”在 require.js 中

    什么意思 当我包含模块时在 require js 中 语法是什么 在我的项目中包含动态样式表 我发现https github com martinsb require css插入 效果很好 require css css sample cs
  • 目标版本 1.8 无效

    我尝试在 OPENSHIFT 上部署我的应用程序 但 Maven 无法编译它并出现错误 目标版本 1 8 无效 我的 构建 action hook export JAVA HOME OPENSHIFT DATA DIR jdk1 8 0 0
  • Web 服务上的 X509Certificate2 验证

    我正在开发 WCF Web 服务 用于检查 XML 签名中的证书是否有效 XML 使用合格且有效的 X509 证书进行签名 当我在 Visual Studio 开发环境中运行服务时 X509Certificate2 Verify 和 X50
  • 无状态 Spring MVC

    我目前正在阅读 Spring in Action 第三版 并且一直在尝试 Spring MVC 一切正常 直到我尝试将示例 Web 应用程序 移植 到无状态 Web 应用程序 为了确定是否创建了会话对象 我在 URL映射 只打印出req g
  • 使用 jQuery 加载图像并将其附加到 DOM

    我正在尝试从给定的链接加载图像 var imgPath imgLink attr href 并将其附加到页面上 这样我就可以将其插入到给定元素中对于图像查看器 尽管我搜索过堆栈溢出和jQuery文档没有尽头 我无法弄清楚 加载图像后 我想设
  • 保护C++程序免遭反编译[重复]

    这个问题在这里已经有答案了 可能的重复 是否可以反编译C Builder exe C Builder exe 安全吗 我使用 Microsoft Visual C 2010 Express 来编写程序 当我想分发我的程序时 我使用 发布 配
  • 如何处理JPA命名查询中数字类型的空值

    我想将两个参数传递给namedquery 一种是数字类型 另一种是字符串类型 它们都可以为空 例如 id null username joe 和 id 1 username joe 是两个不同的结果 在namedQuery中 如果id为nu
  • 找不到静态文件 - 在 Heroku 上部署 Django

    我正在尝试在 Heroku 上部署 Django 站点 但在让应用程序查找我的静态文件时遇到问题 我用过python manage py collectstatic将我的静态文件收集到 staticfiles 文件夹中 但我的应用程序似乎仍
  • HttpWebRequest 不发送 UserAgent

    我对 net 的整个 Web 端很陌生 并且遇到了一个小问题 我正在尝试执行以下 HttpWebRequest 操作 String uri https skyid sky com signup HttpWebRequest request
  • 在 Pydantic v2 中使用 bson.ObjectId

    I found 一些例子关于如何在其中使用 ObjectIdBaseModel类 基本上 这可以通过创建 Pydantic 友好的类来实现 如下所示 class PyObjectId ObjectId classmethod def get
  • 什么是交错音频? [关闭]

    Closed 这个问题不符合堆栈溢出指南 目前不接受答案 我在核心音频文档中多次看到此交错音频 有人可以向我解释此属性的真正功能是什么吗 一般来说 如果您有 2 个通道 我们将它们称为 L 左 和 R 右 并且您想要传输或存储 20 个样本
  • PHP 中的换行帮助

    上面是我正在使用的以下代码 我想要的输出是 title reportno 但我得到的输出是 title reportno 谁能告诉我我在换行中做错了什么 您需要添加一个 br 标记到您的输出 abc output
  • httpclient.execute(httpget) 之后的 Android 代码没有在 try 中运行(使用 AsyncTask)

    我正在尝试从网站获取数据并将其解析到我的 Android 应用程序中 不幸的是我什至没有到达解析数据的部分 该代码在以下行之后不会运行 HttpResponse response httpclient execute httpget 结果是
  • 使用 Transformer 模型进行多实例分类

    我使用这里的变压器Keras 文档示例用于多实例分类 每个实例的类取决于一个包中的其他实例 我使用变压器模型是因为 它不对数据之间的时间 空间关系做出任何假设 这非常适合处理一组对象 例如 每个包最多可以有 5 个实例 每个实例有 3 个特