如何使用tensorflow keras在网络中一起使用嵌入层和其他特征列

2024-04-01

让我们考虑一个包含 6 列和 10 行的示例数据集。

这 3 列是数字,其余 3 列是分类变量。

分类列被转换为大小为 10x3 的多热编码数组。

我有目标列,我想要预测它也是分类变量,它可以再次采用 3 个可能的值。这一列是一个热编码的列。

现在我想使用这个多重热编码数组作为嵌入层的输入。嵌入层应输出 2 个单位。

然后我想使用数据集中的 3 个数字列和嵌入层的 2 个输出单元,总共 5 个单元作为隐藏层的输入。

这就是我被卡住的地方。我不知道如何使用tensorflow keras桥接嵌入层和其他特征列,也不知道如何传递嵌入层和其他2个单元的输入。

我已经用谷歌搜索过了。我尝试了以下代码,但仍然出现错误。我猜 tf.keras 包中没有 Merge 层.

对此的任何帮助将不胜感激。

        import tensorflow as tf
        from tensorflow import keras
        import numpy as np

        num_data = np.random.random(size=(10,3))
        multi_hot_encode_data = np.random.randint(0,2, 30).reshape(10,3)
        target =  np.eye(3)[np.random.randint(0,3, 10)]

        model = keras.Sequential()
        model.add(keras.layers.Embedding(input_dim=multi_hot_encode_data.shape[1], output_dim=2))
        model.add(keras.layers.Dense(3, activation=tf.nn.relu, input_shape=(num_data.shape[1],)))
        model.add(keras.layers.Dense(3, activation=tf.nn.softmax)

        model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
                      loss=keras.losses.categorical_crossentropy,
                      metrics=[keras.metrics.categorical_accuracy])

        #model.fit([multi_hot_encode_data, num_data], target)   # I get error here 

我的网络结构将是

    multi-hot-encode-input  num_data_input 
            |                   |
            |                   |
            |                   |
        embedding_layer         |
            |                   |
            |                   | 
             \                 /        
               \              / 
              dense_hidden_layer
                     | 
                     | 
                  output_layer 

这种“合并”模式与顺序模型不兼容。我认为使用函数式 keras API 更容易keras.Model https://keras.io/models/model/代替keras.Sequential (主要差异的简短解释 https://jovianlin.io/keras-models-sequential-vs-functional/):

import tensorflow as tf
from tensorflow import keras
import numpy as np

num_data = np.random.random(size=(10,3))
multi_hot_encode_data = np.random.randint(0,2, 30).reshape(10,3)
target =  np.eye(3)[np.random.randint(0,3, 10)]

# Use Input layers, specify input shape (dimensions except first)
inp_multi_hot = keras.layers.Input(shape=(multi_hot_encode_data.shape[1],))
inp_num_data = keras.layers.Input(shape=(num_data.shape[1],))
# Bind nulti_hot to embedding layer
emb = keras.layers.Embedding(input_dim=multi_hot_encode_data.shape[1], output_dim=2)(inp_multi_hot)  
# Also you need flatten embedded output of shape (?,3,2) to (?, 6) -
# otherwise it's not possible to concatenate it with inp_num_data
flatten = keras.layers.Flatten()(emb)
# Concatenate two layers
conc = keras.layers.Concatenate()([flatten, inp_num_data])
dense1 = keras.layers.Dense(3, activation=tf.nn.relu, )(conc)
# Creating output layer
out = keras.layers.Dense(3, activation=tf.nn.softmax)(dense1)
model = keras.Model(inputs=[inp_multi_hot, inp_num_data], outputs=out)

model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
              loss=keras.losses.categorical_crossentropy,
              metrics=[keras.metrics.categorical_accuracy])
  • 您应该在连接嵌入层之前压平嵌入层的输出,或者 numeric_data 应该具有兼容的形状和至少三个维度
  • 在各层之后定义功能模型。输入和输出可以是单层或可迭代的层

输出model.summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            (None, 3)            0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 3, 2)         6           input_5[0][0]                    
__________________________________________________________________________________________________
flatten (Flatten)               (None, 6)            0           embedding_2[0][0]                
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, 3)            0                                            
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 9)            0           flatten[0][0]                    
                                                                 input_6[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 3)            30          concatenate_2[0][0]              
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3)            12          dense[0][0]                      
==================================================================================================
Total params: 48
Trainable params: 48
Non-trainable params: 0
__________________________________________________________________________________________________

此外,它也成功适配:

model.fit([multi_hot_encode_data, num_data], target)
Epoch 1/1
10/10 [==============================] - 0s 34ms/step - loss: 1.0623 - categorical_accuracy: 0.3000
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用tensorflow keras在网络中一起使用嵌入层和其他特征列 的相关文章

随机推荐

  • ini_set、set_time_limit、(max_execution_time) - 不起作用

    If I do set time limit 50 or ini set max execution time 50 然后当我回声时ini get max execution time 在我的本地主机上我得到50 但是当我在另一台服务器上执
  • 保持 LDAP 会话

    在 PHP 中 假设我在第 1 页有一个 LDAP 连接 ldapconn ldap connect ldapserver if ldapconn binding to ldap server ldapbind ldap bind ldap
  • 如何在不加载整个文件的情况下向 CSV 添加标题行?

    我有一个console application我想添加一个header row到 CSV 文件 而不将数据加载到应用程序中 我需要什么代码来执行此操作 并且仅检查第一行以查看标题是否已存在 如果不存在则不添加标题行 我尝试了几种方法来执行此
  • Excel VBA 类型不匹配将范围传递给数组时出错[重复]

    这个问题在这里已经有答案了 我正在尝试检索工作表中单元格数组的值 存储为数组而不是简单单元格 但由于某种原因不断收到运行时错误 13 类型不匹配 我读过有关类似问题的帖子 但其中许多似乎与数组类型错误 即不是变体类型 或静态大小有关 这是调
  • 如何指定 xsi:type zeep python

    我使用 python 的 zeep SOAP 客户端 尝试将一些数据获取到某些 wsdl address 我现在有以下内容 ambCase data1 value1 data2 value2 client zeep Client wsdl
  • 如何访问Hadoop HDFS中的文件?

    我的 Hadoop HDFS 中有一个 jar 文件 包含我想要修改的 Java 项目 我想在 Eclipse 中打开它 当我打字时hdfs dfs ls user 我可以看到 jar 文件在那里 但是 当我打开 Eclipse 并尝试导入
  • 如何创建一个不会重新创建具有相同输入参数的对象的类

    我正在尝试创建一个不会重新创建具有相同输入参数的对象的类 当我尝试使用与创建已存在对象相同的参数实例化一个类时 我只希望我的新类返回指向已创建 昂贵创建的 对象的指针 这是我到目前为止所尝试过的 class myobject0 object
  • Laravel 4 无法运行整个 RAW 查询

    我想使用 laravel 的 DB 类来执行 mysql 查询 但 Laravel 提供的功能都不起作用 这些都不起作用 DB statement DB select DB raw DB update DB select DB raw 这是
  • 转义字符串以在 XML 中使用

    我正在使用Python的xml dom minidom创建 XML 文档 逻辑结构 gt XML 字符串 而不是相反 如何让它转义我提供的字符串 这样它们就不会弄乱 XML 像这样的东西吗 gt gt gt from xml sax sax
  • 如何将 AWS WAF 与应用程序 ELB 结合使用

    我需要对 AWS 上托管的 Web 应用程序使用 AWS WAF 以为其提供额外的基于规则的安全性 我找不到任何方法直接将 WAF 与 ELB 结合使用 并且 WAF 需要 Cloudfront 添加 WEB ACL 以根据规则阻止操作 因
  • 在Interface Builder中设计UITableView的节标题

    我有一个xib文件带有UITableView我想使用委托方法添加自定义节标题视图tableView viewForHeaderInSection 有没有可能设计成Interface Builder然后以编程方式更改其一些子视图属性 My U
  • 遗留代码中的泛型

    我们有相当多的代码刚刚跳转到 Java 5 我们一直在那些打算在 Java 5 版本中发布的组件中使用泛型 但是剩下的代码当然充满了原始代码类型 我已将编译器设置为生成原始类型错误并开始手动清除它们 但按照目前的速度 这将需要very很长时
  • 什么样的日志记录对您的应用程序来说是好的日志记录?

    因此 我们已经讨论了在我的工作地点进行登录 我想知道这里的一些人是否可以给我一些关于你们的方法的想法 通常我们的场景是 根本没有日志记录 并且大多数是 NET 应用程序 winforms WPF 客户端通过 Web 服务进行通信或直接与数据
  • 如何知道推送通知发送状态

    我正在应用程序中使用推送通知 一切都很顺利 有时从服务器发送的消息但在应用程序端它没有收到 在这种情况下 我必须知道缺少哪条消息无法传递 应用程序未收到 有没有办法从服务器端知道应用程序收到了哪些消息 哪些没有收到 不 推送通知是一劳永逸的
  • 如何在 Laravel Passport 中获取刷新令牌?

    我正在使用 Laravel 6 7 并尝试使用Passport用于用户身份验证 我可以在用户注册时为他们创建访问令牌 这是代码 user User create input user gt createToken auth token 正如
  • NUnit 与 Windows Phone 7 [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我想使用 NUnit 对我的 Windows Phone 7 库进行单元测试 是否有与 Windows Phone 7 兼容的 NUnit
  • 通过 React App 的数据表按钮

    在 React 中工作时 我在尝试添加 Excel 导出按钮时遇到了问题 我认为它与导入有关 但在这方面我在网上找不到太多与 React 和 DataTables net 相关的帮助 我只希望用户能够下载到 Excel 这些是我与 jque
  • 我在使用 log4js-protractor-appender 时遇到麻烦

    我的 log4js js 文件代码 use strict var log4js require log4js var log4jsGen getLogger function getLogger log4js loadAppender fi
  • 模块错误“模块是使用不兼容的 Kotlin 版本编译的。其元数据的二进制版本是 1.5.1,预期版本是 1.1.16”

    我正在为我们的项目编写一个 kotlin 库 完成后 我构建了一个 aar 文件并将其发送给团队 但他们有一个错误 Module was compiled with an incompatible version of Kotlin The
  • 如何使用tensorflow keras在网络中一起使用嵌入层和其他特征列

    让我们考虑一个包含 6 列和 10 行的示例数据集 这 3 列是数字 其余 3 列是分类变量 分类列被转换为大小为 10x3 的多热编码数组 我有目标列 我想要预测它也是分类变量 它可以再次采用 3 个可能的值 这一列是一个热编码的列 现在