如何在 keras 中创建可训练参数?

2023-11-25

感谢您查看我的问题。

例如。

最终的输出是两个矩阵A和B的和,如下所示:

output = keras.layers.add([A, B])

现在,我想构建一个新参数 x 来更改输出。

我想让新输出 = Ax+B(1-x)

x 是一个可训练参数在我的网络中。

我应该怎么办? 请帮助我~非常感谢!

编辑(部分代码):

conv1 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(input)
drop1 = Dropout(0.5)(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(drop1)

conv2 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
conv2 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
drop2 = Dropout(0.5)(conv2)

up1 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop2))

#the line I want to change:
merge = add([drop2,up1])
#this layer is simply add drop2 and up1 layer.now I want to add a trainable parameter x to adjust the weight of thoese two layers.

我尝试使用代码,但仍然出现一些问题:

1.如何使用自己的图层?

merge = Mylayer()(drop2,up1)

或者其他方式?

2.out_dim的含义是什么? 这些参数都是3维矩阵。out_dim的含义是什么?

谢谢你……T.T

编辑2(已解决)

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

from keras.layers import add

class MyLayer(Layer):

def __init__(self, **kwargs):
    super(MyLayer, self).__init__(**kwargs)

def build(self, input_shape):

    self._x = K.variable(0.5)
    self.trainable_weights = [self._x]

    super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

def call(self, x):
    A, B = x
    result = add([self._x*A ,(1-self._x)*B])
    return result

def compute_output_shape(self, input_shape):
    return input_shape[0]

您必须创建一个继承自的自定义类Layer并使用创建可训练参数self.add_weight(...)。你可以找到一个这样的例子here and there.

对于您的示例,该图层在某种程度上看起来像这样:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self._A = self.add_weight(name='A', 
                                    shape=(input_shape[1], self.output_dim),
                                    initializer='uniform',
                                    trainable=True)
        self._B = self.add_weight(name='B', 
                                    shape=(input_shape[1], self.output_dim),
                                    initializer='uniform',
                                    trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self._A) + K.dot(1-x, self._B)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

Edit:仅仅基于我(错误地)假设的名字x是层输入并且您想要优化A and B。但是,正如您所说,您想要优化x。为此,您可以执行以下操作:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

class MyLayer(Layer):

    def __init__(self, **kwargs):
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self._x = self.add_weight(name='x', 
                                    shape=(1,),
                                    initializer='uniform',
                                    trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        A, B = x
        return K.dot(self._x, A) + K.dot(1-self._x, B)

    def compute_output_shape(self, input_shape):
        return input_shape[0]

Edit2:您可以使用调用该层

merge = Mylayer()([drop2,up1])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 keras 中创建可训练参数? 的相关文章

  • Python XLWT调整列宽

    XLWT 的易用性给我留下了深刻的印象 但有一件事我还没有弄清楚该怎么做 我正在尝试将某些行调整为显示所有字符所需的最小宽度 换句话说 如果双击单元格之间的分隔线 excel 会做什么 我知道如何将列宽调整为预定量 但我不确定如何确定显示所
  • 如何显示 pymongo.errors.OperationFailure 详细信息?

    写入 MongoDB 时 我在 python 中遇到 pymongo OperationsFailure 除了回溯之外 还有没有办法打印出详细信息或代码属性 另请参阅 http api mongodb org python current
  • 来自 yahoo 的 python lxml etree 小程序信息

    雅虎财经更新了他们的网站 我有一个 lxml etree 脚本 用于提取分析师建议 然而现在 分析师的建议已经存在 但只是以图表的形式出现 你可以看到一个例子这一页 https finance yahoo com quote CSX ana
  • 在 Windows 中安装 IPOPT 求解器以与 pyomo 一起使用

    如何安装 IPOPT 求解器以在 Windows 中与 pyomo 一起使用 我尝试了 pip install ipopt 但收到此错误 错误 为 ipopt 构建轮子失败 我正在使用 Windows 10 和 Python 3 7 4 在
  • SQLAlchemy如何为同一个表定义两个模型

    我有一个表 其中一列是具有两个值的 varchar groupA groupB 当我创建模型时 我想实现两件事 A 组模型 包含 X 数量的相关函数 B 组模型 包含 Y 数量的相关函数 两个模型的功能并不相同 尽管它们代表了same ta
  • 尝试将行附加到按对象分组中的每个组时出现奇怪的行为

    这个问题是关于一个函数在应用于两个不同的数据帧时以意想不到的方式表现的 更准确地说 是 groupby 对象 要么是我遗漏了一些明显错误的东西 要么是 pandas 中存在错误 我编写了以下函数 将一行附加到 groupby 对象中的每个组
  • 如何在Tensorflow中读取json文件?

    我正在尝试编写一个函数 用于读取张量流中的 json 文件 json 文件具有以下结构 bounding box y 98 5 x 94 0 height 197 width 188 rotation yaw 27 970195770263
  • 如何实例化具有已知系数的 Scikit-Learn 线性模型而不进行拟合

    背景 作为实验的一部分 我正在测试各种保存的模型 但其中一个模型来自我编写的算法 而不是来自 sklearn 模型拟合 但是 我的自定义模型仍然是线性模型 所以我想实例化一个LinearModel实例并设置coef and intercep
  • 对训练和测试数据帧使用相同的标签编码器

    我有 2 个不同的 csv 其中包含训练数据和测试数据 我从这些 train features df 和 test features df 创建了两个不同的数据帧 请注意 测试和训练数据有多个分类列 因此我需要对它们应用 labelEnco
  • 视频的 EXIF 之类的东西

    有没有从视频文件中获取信息的标准方法 对于图像 我们有 EXIF 数据 可用于获取有关图像文件的日期 时间 大小等信息 我想知道视频是否也有这样的东西 用例是 我有很多用数码相机拍摄的视频 我想将它们重命名为更有意义的名称 例如 YYYY
  • OSMNX - 边缘的哪个“部分”被认为是最近的

    我正在使用 OSMNX 中的 returned edges 函数 我不清楚在进行此计算时使用边缘的哪个 部分 它是边缘的任何部分吗 是中间点吗 对于网络中的长边来说 这会产生很大的差异 这取决于您如何参数化该函数 来自nearest edg
  • matplotlib 示例代码不适用于 python 虚拟环境

    我正在尝试在 matplotlib 中显示图像的 x y z 坐标 示例代码 http matplotlib org examples api image zcoord html在全局 python 安装上工作得很好 当我移动光标时 x y
  • 当从 python 使用 TSQL(SQL Server 上的 mssql)时,如何自动生成 SQLAlchemy 的 ORM 代码?

    SQLAlchemy 依赖于我构建这样的 ORM 类 from sqlalchemy import Column DateTime String Integer ForeignKey func from sqlalchemy orm imp
  • 在 SQLAlchemy 中删除父级后删除子级

    我的问题如下 我有两个型号Entry and Tag通过 SQLAlchemy 中的多对多关系链接 现在我想删除所有Tag没有任何对应的Entry后Entry被删除 示例来说明我想要的内容 Entry 1带标签python java Ent
  • Pygame - 两个圆圈的碰撞检测

    我正在制作一个碰撞检测程序 其中我的光标是一个半径为 20 的圆 当它碰到另一个圆时应该将值更改为 TRUE 出于测试目的 我在屏幕中心有一个半径为 50 的固定圆 我可以测试光标圆是否击中固定圆 但它不能正常工作 因为它实际上是在测试它是
  • 相比之下,超出了最大递归深度

    我写了这段代码来计算组合的数量 def fact n return 1 if n 1 else n fact n 1 def combinations n k return fact n fact n k fact k while True
  • 如何创建使用几个客户端权重的 FL 算法?

    基于此link https github com tensorflow federated tree 3c0852c5fef375198f5931ce31fd97f2df9c4d05 tensorflow federated python
  • pandas 数据帧和聚合中的行明智排序

    我在 pandas dataframe df 中有一个表 col1 col2 count 12 15 3 13 17 5 1 36 4 15 12 7 36 1 4 等等 我想要的是将 12 和 15 和 15 和 12 等计算值视为相同
  • Android Systrace 没有这样的文件或目录

    这是错误消息 D Programming Tools ADT bundle sdk platform tools systrace gt python systrace py Traceback most recent call last
  • Python:正则表达式 findall

    我使用 python 正则表达式从给定字符串中提取某些值 这是我的字符串 mystring txt sometext somemore text here some other text course course1 Id Name mar

随机推荐

  • app.js 中的全局变量可在路由中访问吗?

    我如何设置一个变量app js并使其在所有路线上都可用 至少在index js文件位于路径中 使用express框架和node js 实际上 使用 Express 对象上可用的 set 和 get 方法可以很容易地做到这一点 示例如下 假设
  • 如何在 MATLAB 中删除轴

    axis off不工作 function displayResults filename hObject eventdata handles Open filename file for reading fid fopen filename
  • 为什么Java的Arrays.sort方法对不同的类型使用两种不同的排序算法?

    Java 6 的Arrays sort方法对基元数组使用快速排序 对对象数组使用合并排序 我相信大多数时候快速排序比合并排序更快并且消耗更少的内存 我的实验支持这一点 尽管两种算法都是 O n log n 那么为什么不同的类型使用不同的算法
  • AFNetworking-2 waitUntilFinished 不起作用

    我知道有另一个类似的问题 但它适用于旧版本的 AFNetworking 而且并没有真正回答它 我有以下代码 AFHTTPRequestOperationManager manager AFHTTPRequestOperationManage
  • Google Maps API a.lat 不是函数错误

    我正在创建一个代码 可以通过使用分割作为分隔符来协调 CSV 文件中的数据 并计算两个输入坐标之间的距离 但结果总是显示错误a lat is not a function 我已经在网上浏览了有关此特定错误类型的信息 但似乎找不到正确的解决方
  • C# 中可调整大小的表格布局面板

    我发现 c net 2 0 中的表格布局面板非常原始 我希望允许我的用户调整表格布局面板中的列大小 但没有现成的选项可以执行此操作 有没有办法至少找出光标是否直接位于单元格的任何边框上 如果是 则哪个单元格位于其下方 可能有了这些信息 我们
  • 如何使用应用内自定义键盘的按钮输入文本

    我制作了一个应用程序内自定义键盘 它取代了系统键盘 并在我点击内部时弹出UITextField 这是我的代码 class ViewController UIViewController var myCustomKeyboard UIView
  • 数据表过滤:linq 与过滤器?

    过滤内存对象 数据表 这样做之间有很大的不同吗 var t dt Select id 2 vs var g dt AsEnumerable Where f gt f id ToString 2 我假设DataTable Select需要更多
  • 在 3D 世界中渲染 2D 精灵?

    假设我有精灵的 png 如何在 OpenGL 中渲染 2D 精灵 将图像作为我想要实现的效果的示例 另外 我想在屏幕上覆盖武器 例如底部图像中的步枪 有谁知道我如何实现这两种效果 任何帮助是极大的赞赏 在 3D 术语中 这称为 广告牌 广告
  • 计算太阳位于地平线以下/之上 X 度的时间

    我想知道太阳在地平线以下 之上 X 度的时间是什么时候 例如 我想找到太阳位于地平线以下 19 75 度的时间 我认为这与函数中的最高点有关date sunrise date sunset但我不确定 提前致谢 收集您需要的日期的太阳星历数据
  • 当我运行 Angular 4 应用程序时,哪个文件首先运行该应用程序?

    我正在使用 Angular 4 我有一个问题 当我运行项目并使用 ngserve 时 项目中的哪个文件首先呈现 有很多文件 例如main ts angular cli json app module我不明白当我运行 ngserve 时发生了
  • 从 ElasticSearch 中的数组中删除元素/对象,然后进行匹配查询

    我在尝试从 elasticsearch 中的数组中删除元素 对象时遇到问题 这是索引的映射 example1 mappings doc properties locations type geo point postDate type da
  • 是什么 !! JavaScript 中的(不是 not)运算符?

    我看到一些代码似乎使用了我不认识的运算符 以两个感叹号的形式 如下所示 有人可以告诉我这个操作员是做什么的吗 我看到这个的背景是 this vertical vertical undefined vertical this vertical
  • 分析生产代码

    我正在考虑实现一些在生产服务器上分析代码的东西 并需要一些最佳实践建议 显然 分析所有请求是一个坏主意 因为会增加开销 因此我正在研究一些可以根据请求随机调用分析器的技术 类似于每 10 000 个请求 1 个配置文件 我知道有一种方法可以
  • 列元素上的 CSS 3 动画“变换:缩放”在 Chrome 上不起作用

    我在 Chrome v44 中遇到问题 我尝试使用 transform scale 1 1 放大列项中的图像 但动画不起作用 如果我尝试在 Firefox 上使用 效果很好 我认为问题是由于 chrome 造成的 但我想知道是否有人找到了解
  • 我如何中断正在执行 (*TCPListener) Accept 的 goroutine?

    我最近正在玩 go 并尝试创建一些服务器来响应 tcp 连接上的客户端 我的问题是如何干净地关闭服务器并中断当前在以下调用中 被阻止 的 go 例程 func TCPListener 接受吗 根据接受的文档 Accept实现Listener
  • 使用 Interface Builder 嵌套自定义类/XIB

    我会尽力使其简短 我编写了一个自定义类 它使用几个 IBOutlet 属性扩展 UIView 并且它有一个与之关联的 XIB 这些 IBOutlet 链接到该 XIB 然后我想学习该类 将其嵌入到其他 XIB 例如 表格单元格 中 然后让它
  • Apple 和私有 API

    既然众所周知 App Store 提交正在接受私有 API 的使用测试 我需要问一个问题 私有 API 到底是什么 以便我可以避免使用它们 私有 API 是未记录在 SDK 中的 API 例如 框架类可能声明一个不适合外部开发人员使用的方法
  • IonAuth - 似乎随机将我注销

    我正在使用 ionAuth 它似乎几乎随机地将我注销 我正在使用 Codeigniter v2 1 4 它登录得很好 但是 ionAuth 似乎会随机注销 有没有办法强制会话保持活动状态 直到我调用 ionAuth gt logout 函数
  • 如何在 keras 中创建可训练参数?

    感谢您查看我的问题 例如 最终的输出是两个矩阵A和B的和 如下所示 output keras layers add A B 现在 我想构建一个新参数 x 来更改输出 我想让新输出 Ax B 1 x x 是一个可训练参数在我的网络中 我应该怎