在tensorflow keras中采样softmax

2024-02-17

我想在 tf keras 中进行采样的 softmax 损失。我通过子类化 keras 模型来定义自己的模型。在 init 中,我指定了所需的层,包括最后一个密集投影层。但是这个密集层不应该在训练中调用,因为我想做采样的softmax并且只使用它的权重和偏差。然后我这样定义损失函数:

class SampledSoftmax:
    def init( self,
              num_sampled,
              num_classes,
              projection,
              bias,
              hidden_size):
        self.weights = tf.transpose(projection)
        self.bias = bias
        self.num_classes = num_classes
        self.num_sampled = num_sampled
        self.hidden_size = hidden_size

    def call(self, y_true, input):
        """ reshaping of y_true and input to make them fit each other """
        input = tf.reshape(input, (-1,self.hidden_size))
        y_true = tf.reshape(y_true, (-1,1))

        return tf.nn.sampled_softmax_loss(
                   weights=self.weights,
                   biases=self.bias,
                   labels=y_true,
                   inputs=input,
                   num_sampled=self.num_sampled,
                   num_classes=self.num_classes,
                   partition_strategy='div')

它接受必要的参数进行初始化,并且类调用将是所需的采样 softmax 损失函数。问题是,为了在模型编译中添加损失,我需要最后一个 Dense 的权重等。但是 1)在训练中 Dense 不包含在模型中,2)即使包含在模型中,Dense 层也只会与输入连接,从而在调用我的自定义模型时获取其输入尺寸等。简而言之,权重等在编译模型之前是不可用的。谁能提供一些帮助来指出我正确的方向?

现在是导致它失败的代码。我首先对模型进行子类化,如下所示:

class LanguageModel(tf.keras.Model):
    def __init__(self, 
                 vocal_size=15003, 
                 embedding_size=512
                 input_len=64)
       self.embedding = Embedding(vocal_size, embedding_size, 
                                  input_length=input_len)
       self.lstm = LSTM(hidden_size, return_sequences=True)
       self.dense = Dense(vocal_size, activation='softmax')

   def call(self, inputs, training=False):
       emb_out = self.embedding(inputs)
       lstm_out = self.lstm(embrace_out)
       res = self.dense(lstm_out)
       if (training)
           ''' shouldn't use the last dense as we want to do sampling'''
           return lstm_out
       return res

然后训练模型的部分如下

sampled_loss = SampledSoftmax(num_sampled, vocal_size, 
                   model.dense.kernel, model.dense.bias,
                   hidden_size)

model.compile(optimizer=tf.train.RMSPropOptimizer(lr),
              loss=sampled_loss)

然而我使用它会失败,因为 model.dense.kernel 无法访问,因为在编译模型时,密集层尚未在调用方法中初始化。错误信息如下:

Traceback (most recent call last):
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/wuxinyu/workspace/nlu/lm/main.py", line 72, in <module>
    train_main()
  File "/home/wuxinyu/workspace/nlu/lm/main.py", line 64, in train_main
    train_model.build_lm_model()
  File "/home/wuxinyu/workspace/nlu/lm/main.py", line 26, in build_lm_model
self.model.dense.kernel,
AttributeError: 'Dense' object has no attribute 'kernel'

顺便说一句,上面定义的损失将适用于如下所示的小型测试用例。

x = Input(shape=(10,), name='input_x')
emb_out = Embedding(10000,200,input_length=10)(x)
lstm_out = LSTM(200, return_sequences=True)(emb_out)

dense = Dense(10000, activation='sigmoid')
output = dense(lstm_out)

sl = SampledSoftmax(10, 10000, dense.kernel, dense.bias)

model = Model(inputs=x, outputs=lstm_out)
model.compile(optimizer='adam', loss=sl)
model.summary()
model.fit(dataset, epochs=20, steps_per_epoch=5)

None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在tensorflow keras中采样softmax 的相关文章

随机推荐

  • 如何移除后退堆栈片段的焦点?

    我在我的应用程序中使用片段 我有一个片段包含EditText还有一些Dialogfragment 当我单击一个特定的小部件时 它将移动到下一个片段 我需要后台堆栈中的第一个片段 因此我还添加了 addToBackStack 方法 第二个片段
  • 以图形方式显示 IntelliJ 中 git log --follow 的等效项

    IntelliJ 14 有没有办法显示特定文件的完整日志 我的意思是 执行一种git log follow以图形方式查看旧版本 在这些文件可能被重命名之前 目前 当我这样做时Git gt 显示历史记录在文件上 它仅显示相当于git log
  • 不带任何操作的 asp.net 路由语法

    我正在尝试建立一条没有任何操作而只有一个参数的路线 domain com 不带任何参数 应转到一个控制器 however 域名 com somestring 域名 com anotherstring 域名 com anythingreall
  • 将 uint16_t 转换为 char[2] 以通过套接字发送(unix)

    我知道大致上有关于这方面的事情 但是我的大脑受伤了 我找不到任何东西可以让这项工作发挥作用 我正在尝试通过 unix 套接字发送一个 16 位无符号整数 为此 我需要将 uint16 t 转换为两个字符 然后我需要在连接的另一端读入它们并将
  • 系统设置意图后无法返回活动

    在我的应用程序中 我需要进入手机的设置活动来激活 GPS 并希望使用以下代码返回我的应用程序 Intent intent new Intent Settings ACTION LOCATION SOURCE SETTINGS startAc
  • 仅包含标准库的 Golang 中间件

    我的第一个 stackoverflow 问题 所以请不要介意我对 stackoverflow 的天真和所问的问题 golang 的初学者 我想知道这两个调用之间的区别以及简单的理解Handle Handler HandleFunc Hand
  • 将列名添加到 dplyr 函数内的 vars()

    我有一个函数 可用于根据一些用户定义的组来汇总变量 利用dplyr library tidyverse get var summary lt function data target var group vars vars target v
  • ggplot2 分类x轴的不同面宽度[重复]

    这个问题在这里已经有答案了 我正在绘制分类数据的不同方面 df lt as data frame as factor c A B C D E F names df lt Xvar df Yvar lt c 2 1 4 5 3 7 df fa
  • 我可以使用什么方法来代替 python 中的 __file__ ?

    我通过 cython 将 python 代码转换为 c 然后编译 c 文件并在我的项目中使用 so 我的问题 我用 file 在我的 python 代码和 gcc 编译时 它不会出现错误 但是当我运行程序并在其他 python 文件中导入
  • 将 R闪亮应用程序部署为独立应用程序[关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我开发了一个 RShiny 应用程序 我想与我的同事在内部共享 现阶段无法在服务器上托管该应用程序 我正
  • Webkit 伪元素文档

    我实际上完成了我想做的事情 当我想打印页面时隐藏一些 webkit 伪元素 代码如下所示 问题是我没有从我的研究中学到任何东西来做到这一点 而且我找不到任何关于它的文档 而且我看到的关于这个主题的每个答案都只显示了代码 没有任何进一步的解释
  • 优雅关闭失败

    我有一个带有 server shutdown graceful 的 spring boot 2 3 应用程序 当关闭时会抛出 2020 11 30 11 07 35 485 WARN 3038 SpringContextShutdownHo
  • SQL 存储过程 - 请帮我写这个! (第2部分)

    我有下表 其中值为 501 CREATE TABLE Numbers Number numeric 20 0 NOT NULL PRIMARY KEY INSERT INTO Numbers VALUES 501 我如何在此上编写一个存储过
  • 在 Java/Swing 的全屏程序中停止使用 Tab/Alt-F4

    我需要一种方法来阻止人们在我的 Java 程序运行时使用其他程序 即阻止人们切换选项卡并按 Alt F4 使程序全屏使用 window setExtendedState Frame MAXIMIZED BOTH maximise windo
  • C# ASCII 或 Unicode

    您好 我是编程和网络开发的初学者 我有一个关于 ASCII 和 Unicode 编码的问题 在 msdn 和其他 Web 示例中执行以下操作 byte byteData Encoding ASCII GetBytes data 这是因为这些
  • 如何将 Google Cloud AI Platform Jupyter Lab 升级到 Python 3.7+

    Google Cloud Platform的AI Platform可以方便地部署Jupyter Lab 但仅适用于Python 2和Python 3 5 3 如何升级我的实例才能运行 Python 3 7 或更高版本 笔记本 该解决方案是基
  • 在 Visual Studio 2012 的新 C++ 项目中自动创建的 stdafx.cpp 文件是什么

    据我了解 stdafx h 是一个预编译头文件 用于使 Visual Studio 中的编译时间更快 当我在 Visual Studio 2012 中创建 C 项目时 还有一个 stdafx cpp 有人可以解释 stdafx h 和 st
  • 将配置文件共享给多个 docker 容器

    假设我的 Docker 主机上有以下配置文件 并且我希望多个 Docker 容器能够访问该文件 opt shared config file yml 在典型的非 Docker 环境中 我可以使用符号链接 例如 opt app1 config
  • 检测舞台何时再次聚焦并加载场景

    我有一个父舞台 可以在其顶部显示弹出窗口 这是代码 private static Stage chooseBreedStage static chooseBreedStage new Stage chooseBreedStage setTi
  • 在tensorflow keras中采样softmax

    我想在 tf keras 中进行采样的 softmax 损失 我通过子类化 keras 模型来定义自己的模型 在 init 中 我指定了所需的层 包括最后一个密集投影层 但是这个密集层不应该在训练中调用 因为我想做采样的softmax并且只