带参数的自定义激活

2024-03-25

我正在尝试在 Keras 中创建一个可以接受参数的激活函数beta像这样:

from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

class Swish(Activation):

    def __init__(self, activation, beta, **kwargs):
        super(Swish, self).__init__(activation, **kwargs)
        self.__name__ = 'swish'
        self.beta = beta


def swish(x):
    return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})

它运行良好,无需beta参数,但是如何在激活定义中包含该参数?我也希望在执行此操作时保存该值model.to_json()就像 ELU 激活一样。


Update:我根据@today的回答编写了以下代码:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)
        self.__name__ = 'swish'

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
    arch_file.write(arch)

然而,它目前不保存beta.json 文件中的值。怎样才能让它保值呢?


由于您想在序列化模型时保存激活函数的参数,因此我认为最好将激活函数定义为像这样的层Keras 中定义的高级激活 https://github.com/keras-team/keras/blob/master/keras/layers/advanced_activations.py。你可以这样做:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

然后您可以像使用 Keras 层一样使用它:

# ...
model.add(Swish(beta=0.3))

Since get_config()方法已在其定义中实现,参数beta使用类似方法时会被保存to_json() or save().

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

带参数的自定义激活 的相关文章

  • ipdb 和 pdb++ 之间的区别?

    Python 有一个名为 pdb 的默认调试器 但社区创建了一些替代品 其中两个是ipdb https github com gotcha ipdb and pdb https github com pdbpp pdbpp 它们似乎迎合了相
  • Celery计划任务中的打印语句不会出现在终端中

    当我跑步时celery A tasks2 celery worker B我想看到每秒打印 芹菜任务 目前没有打印任何内容 为什么这不起作用 from app import app from celery import Celery from
  • 帮助需要在可选条件下编写正则表达式[关闭]

    我有一个日志文件包含如下内容 log Using data from yyyy mm dd 2011 8 3 0 files queued for scanning Warning E test H ndler pdf File not F
  • 无法在 mysql 表中的值中使用破折号(-)[重复]

    这个问题在这里已经有答案了 我一直在尝试从 python 将数据插入 MYSQL 表 我的sql表中的字段是id token start time end time和no of trans 我想存储使用生成的令牌uuid4在令牌栏中 但由于
  • Python 小数.InvalidOperation 错误

    当我运行这样的东西时 我总是收到此错误 from decimal import getcontext prec 30 b 2 3 Decimal b Error Traceback most recent call last File Te
  • 创建一个打开文件并创建字典的函数

    我有一个正在处理的文件 我想创建一个读取文件并将内容放入字典中的函数 然后该字典需要通过 main 函数传递 这是主程序 它无法改变 我所做的一切都必须与主程序配合 def main sunspot dict file str raw in
  • 有条件填写 pandas 数据框

    我有一个数据框df列中包含浮点值A 我想添加另一列B这样 B 0 A 0 for i gt 0 B i if np isnan A i then A i else Step3 B i if abs B i 1 A i B i 1 lt 0
  • Python 使用 M2Crypto 通过 S/MIME 对消息进行签名

    我现在花了几个小时 但找不到我的错误 我想要一个简单的例程来创建 S MIME 签名消息 稍后可以与 smtplib 一起使用 这是我到目前为止所拥有的 usr bin python2 7 coding utf 8 from future
  • Pyinstaller --onefile 警告文件已存在但不应存在

    跑步时Pyinstaller onefile 并开始得到结果 exe 会出现多个弹出窗口 并显示以下警告 WARNING file already exists but should not C Users myuser AppData L
  • Selenium:等到 WebElement 中的文本发生变化

    我在用着selenium使用Python 2 7 从网页上的搜索框检索内容 搜索框动态检索结果并在框本身中显示结果 from selenium import webdriver from selenium webdriver common
  • InvalidArgumentException:消息:无效参数:“using”必须是字符串

    我对 python 很陌生 试图创建可重用的代码 当我尝试通过传递 Login 类下使用的所有参数来调用 test main py 中的 Login 类和函数 login user 时 我收到错误 InvalidArgumentExcept
  • 在 MATLAB 中创建共享库

    一位研究人员在 MATLAB 中创建了一个小型仿真 我们希望其他人也能使用它 我的计划是进行模拟 清理一些东西并将其变成一组函数 然后我打算将其编译成C库并使用SWIG https en wikipedia org wiki SWIG创建一
  • Python 垃圾收集有时在 Jupyter Notebook 中不起作用

    我的一些 Jupyter 笔记本经常出现 RAM 不足的情况 而且我似乎无法释放不再需要的内存 这是一个例子 import gc thing Thing result thing do something thing None gc col
  • 如何在 Python 中执行相当于预处理器指令的操作?

    有没有办法在 Python 中执行以下预处理器指令 if DEBUG lt do some code gt else lt do some other code gt endif There s debug 这是编译器预处理的特殊值 if
  • 向量化 numpy bincount

    我有一个 2d numpy 数组 A我要申请np bincount 到矩阵的每一列A生成另一个二维数组B由原始矩阵每列的 bincounts 组成A 我的问题是 np bincount 是一个采用一维数组的函数 它不是像这样的数组方法B A
  • 如何将 pytest 装置与 django TestCase 一起使用

    我如何在TestCase方法 类似问题的几个答案似乎暗示我的例子应该有效 import pytest from django test import TestCase from myapp models import Category py
  • psutil:测量特定进程的CPU使用率

    我正在尝试测量进程树的 cpu 使用率 目前获取进程 没有子进程 的 cpu usage 就可以了 但我得到了奇怪的结果 import psutil p psutil Process PID p cpu percent 还给我float g
  • 在 Tensorflow 2.0 中的简单 LSTM 层之上添加 Attention

    我有一个由一个 LSTM 和两个 Dense 层组成的简单网络 如下所示 model tf keras Sequential model add layers LSTM 20 input shape train X shape 1 trai
  • 如何让你的精灵在pygame中跳跃

    目前我已经制作了一个平台游戏 可以左右移动我的角色 他从地上开始 关于如何让他跳的任何想法 因为我不明白 目前 如果我按住向上键 我的玩家精灵将连续向上移动 或者如果我按下它 我的玩家精灵将向上移动并保持向上 我想找个办法远离他 让我重新跌
  • 在读/写二进制数据结构时访问位域

    我正在为二进制格式编写一个解析器 这种二进制格式涉及不同的表 这些表同样采用二进制格式 通常包含不同的字段大小 其中 50 100 个之间 大多数这些结构都有位域 并且在 C 语言中表示时看起来像这样 struct myHeader uns

随机推荐