在 keras 中集成采样的 softmax 失败

2024-05-14

基于如何在 Keras 模型中使用 TensorFlow 的采样 softmax 损失函数? https://stackoverflow.com/questions/47892380/how-can-i-use-tensorflows-sampled-softmax-loss-function-in-a-keras-model,我创建了这段代码:

class SampledSoftmax(tensorflow.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(SampledSoftmax, self).__init__(**kwargs)


    def call(self, inputs):

        def f1(inputs):
            return tf.nn.sampled_softmax_loss(
                                inputs[0]._keras_history[0].weights[0],
                                inputs[0]._keras_history[0].bias,
                                tf.reshape(tf.argmax(inputs[1], 1), [-1, 1]),
                                inputs[0],                                
                                8192,
                                817496)
        def f2(inputs):
            logits = tf.matmul(inputs[0], tf.transpose(inputs[0]._keras_history[0].weights[0]))
            logits = tf.nn.bias_add(logits, inputs[0]._keras_history[0].bias)            
            return tf.nn.softmax_cross_entropy_with_logits_v2(
                                labels=inputs[1],
                                logits=logits)    

        return tf.cond(K.learning_phase(), true_fn=f1(inputs), false_fn=f2(inputs))

与以下型号一起使用时:

#model
input_layer = Input(shape=(None,), dtype='int32')
target_input = Input(shape=(None,vocab_size), dtype='int8')

embedding_layer = Embedding(vocab_size,
                            EMBEDDING_DIM,
                            trainable=True,
                            mask_zero=True) (input_layer)
common = LSTM(LSTM_UNITS, return_sequences=True,dropout=0.2, recurrent_dropout=0.2)(embedding_layer)
common = (Dense(PROJ_UNITS, activation='linear'))(common)
out = (Dense(vocab_size, name='output_layer'))(common)
out = (SampledSoftmax())([out, target_input])


model = Model(inputs=[input_layer,target_input], outputs=out)

它因以下错误而失败: ValueError:形状必须为等级 2,但对于“sampled_softmax/sampled_softmax_loss/MatMul”(操作:“MatMul”),形状必须为等级 3,输入形状为:[?,?,817496]、[?,817496]。

我根据谷歌搜索取得了一些进展:

class MyLayer(tensorflow.keras.layers.Dense):
    def __init__(self, num_sampled, num_classes, mode,  **kwargs):
        self.num_sampled = num_sampled
        self.num_classes = num_classes
        self.mode = mode
        super(MyLayer, self).__init__(num_classes, **kwargs)
        self.input_spec = [InputSpec(ndim=2)]

    def build(self, input_shape):
        #self.input_spec = [InputSpec(shape=input_shape)]
        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs_and_labels):
        inputs, labels = inputs_and_labels
        if self.mode == "train":
            loss = tf.nn.sampled_softmax_loss(
                weights=self.kernel,
                biases=self.bias,
                labels=tf.reshape(tf.argmax(labels, 1), [-1, 1]),
                inputs=inputs,
                num_sampled=self.num_sampled,
                num_classes=self.num_classes,
                num_true=1)

        elif self.mode == "eval":
            logits = tf.matmul(inputs, tf.transpose(self.kernel))
            logits = tf.nn.bias_add(logits, self.bias)
            loss = tf.nn.softmax_cross_entropy_with_logits(
                labels=labels,
                logits=logits)

        return loss

    def compute_output_shape(self, input_shape):
        dense_shape, classes_shape = input_shape
        return (dense_shape[0], )    

现在的错误: 现在的错误:

ValueError: Layer my_layer expects 1 inputs, but it received 2 input tensors. Inputs received: [<tf.Tensor 'dense/BiasAdd:0' shape=(?, ?, 512) dtype=float32>, <tf.Tensor 'input_2:0' shape=(?, ?, 817496) dtype=int8>]

我尝试使用 self.input_spec 但直到现在才起作用。


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 keras 中集成采样的 softmax 失败 的相关文章

随机推荐

  • Tomcat - 将旧上下文根重定向到新上下文根

    我们想要更改 Tomcat Web 应用程序的上下文根 并让旧的 URL 引导用户访问新命名的应用程序 http hostname oldappname http hostname newappname 实现此目的的一种方法是部署具有 ne
  • 方法返回 IOrderedEnumerable 而不是 IEnumerable 是否有利?

    Can it be advantageous for a method to return IOrderedEnumerable instead of IEnumerable 仅当您希望人们每次都订购该枚举并且发现很难弄清楚如何执行此操作时
  • 编程错误:(psycopg2.errors.UndefinedColumn)关系“task_fail”的列“execution_date”不存在

    我正在尝试在气流中运行 DAG 以将数据集摄取到谷歌云存储 这是 DAG 脚本 import os from airflow import DAG from airflow utils dates import days ago from
  • 编辑 InitializeComponent() 方法 C#

    我已经浏览了多个资源 试图找到何时手动向 InitializeComponent 添加代码的用例 但没有找到任何具体的内容 这表明我们不应该这样做 InitializeComponent 方法中的代码由设计者生成 不应手动修改 https
  • 原子聚合的使用

    我想在下一个查询中找到年龄最小的人 d q find name min age in name age John 20 Bill 25 Jack 20 Steve 28 Andrew 30 但结果是 Andrew 30 Bill 25 Ja
  • 在viewpager2中禁用动画

    我有 viewpager2 和扩展 FragmentStateAdapter 的适配器 我希望用户仅通过单击选项卡布局即可转到另一个页面 我已禁用此 viewpager2 的用户输入 但是当我单击选项卡时 有页面之间快速滑动的动画 但我只想
  • sympy 任意函数范围

    我想定义任意函数f 我知道 f 总是返回一个正数 我希望 sympy 在运行简化时能够使用这些知识 特别是简化文档中提到的三个幂规则 有没有办法做到这一点 我正在寻找类似下面的东西 f Function f positive True g
  • 如何从 shell 编译 macOS Sierra 上使用 dylib 路径的源代码

    我正在编译一些源代码 需要我已经构建的其他项目中的一些 dylib 我越来越 ld 未找到架构 x86 64 的符号 每当我执行 g some code cpp I usr local include o executable binary
  • 我什么时候应该关闭游标和数据库?

    我在自定义视图中以不同的方法多次使用相同的光标 我应该在每次使用后关闭光标还是可以保持它打开直到视图被破坏 对于数据库也是如此 是否可以在创建保存此视图的活动时打开它并在活动销毁时关闭它 当我按照上述操作时 我不断收到错误 close 从未
  • 是否有用于封闭类型名称的简短版本的 Eclipse 模板变量

    我想在 Eclipse 中为 Java 类创建一个构造函数模板 我有一个适用于大多数课程的版本 尽管它不适用于嵌套在其他类中的类 见类Inner如下 如何获得类名的简短版本 模板不起作用 public newType enclosing t
  • Jackson:无法反序列化 START_OBJECT 令牌中的 Number 实例

    我的 GWT 服务返回LinkedList
  • 为什么 FLT_MIN 等于 0?

    limits h指定非浮点数学类型的限制 例如INT MIN and INT MAX 这些值是可以使用 int 表示的最大负值和最大正值 In float h 有定义FLT MIN and FLT MAX 如果您执行以下操作 NSLog f
  • Laravel MySQL 按计数排序

    我正在使用 Laravel 和 MySQL 并且我有一个表post代表用户可以评论的帖子 现在我想按照每个帖子的评论数量按升序 降序对帖子进行排序 我该如何在 Laravel 中执行此操作 我不想添加字段post表来跟踪每个帖子的评论数量
  • 如何创建基于多个 IEnumerable 的集合

    我想要操作的类提供了类型的吸气剂IEnumerable
  • mongorestore 从独立到复制集

    我已转储在默认端口上运行的独立 mongo 数据库 14Gb 大 如下所示 mongodump username
  • 在 Inno Setup 中单击“下一步”按钮时验证自定义页面上的数据

    我已经设法获得一个基本脚本来显示向导 使用CreateInputFilePage 供用户识别我用来更新 XML 文件中某些设置的文件位置 但是 我想对所选文件的输入进行一些基本检查 而不是简单地接受用户提供的任何内容 例如 如果用户在内容无
  • 识别相似图像的库

    我想确定 2 张图像的相似程度 图像可能已被缩放 裁剪等 因此简单的像素比较将不起作用 我环顾四周 有很多关于这个主题的学术论文 但他们没有发布他们的代码 那么 您知道有一个可以比较图像的已发布库 适用于 Linux 和 Windows 吗
  • Windows 版 GitKraken 中的文件名太长

    正如建议的Q22575737 https stackoverflow com a 22575737 6623589 我已经更新了我的注册表并设置了git config system core longpaths true在处理长路径时 问题
  • 是否可以使用.NET 跟踪文件操作?

    当以某种方式调用文件操作 例如打开或关闭 时 我是否可以在操作系统继续请求之前处理它 如果可能的话可以通过以下方式取消它 NET http en wikipedia org wiki NET Framework 如果 NET没有这样的能力
  • 在 keras 中集成采样的 softmax 失败

    基于如何在 Keras 模型中使用 TensorFlow 的采样 softmax 损失函数 https stackoverflow com questions 47892380 how can i use tensorflows sampl