如何将 Tensorflow BatchNormalization 与 GradientTape 结合使用?

2023-12-08

假设我们有一个使用 BatchNormalization 的简单 Keras 模型:

model = tf.keras.Sequential([
                     tf.keras.layers.InputLayer(input_shape=(1,)),
                     tf.keras.layers.BatchNormalization()
])

如何实际使用 GradientTape?以下似乎不起作用,因为它没有更新移动平均线?

# model training... we want the output values to be close to 150
for i in range(1000):
  x = np.random.randint(100, 110, 10).astype(np.float32)
  with tf.GradientTape() as tape:
    y = model(np.expand_dims(x, axis=1))
    loss = tf.reduce_mean(tf.square(y - 150))
  grads = tape.gradient(loss, model.variables)
  opt.apply_gradients(zip(grads, model.variables))

特别是,如果您检查移动平均值,它们将保持不变(检查 model.variables,平均值始终为 0 和 1)。我知道可以使用 .fit() 和 .predict(),但我想使用 GradientTape 并且我不知道如何执行此操作。某些版本的文档建议更新 update_ops,但这似乎在急切模式下不起作用。

特别是,经过上述训练后,以下代码将不会输出任何接近 150 的结果。

x = np.random.randint(200, 210, 100).astype(np.float32)
print(model(np.expand_dims(x, axis=1)))

使用梯度磁带模式 BatchNormalization 层应使用参数 Training=True 进行调用

example:

inp = KL.Input( (64,64,3) )
x = inp
x = KL.Conv2D(3, kernel_size=3, padding='same')(x)
x = KL.BatchNormalization()(x, training=True)
model = KM.Model(inp, x)

然后移动变量被正确更新

>>> model.layers[2].weights[2]
<tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32, numpy
=array([-0.00062087,  0.00015137, -0.00013239], dtype=float32)>
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将 Tensorflow BatchNormalization 与 GradientTape 结合使用? 的相关文章

  • Django 的内联管理:一个“预填充”字段

    我正在开发我的第一个 Django 项目 我希望用户能够在管理中创建自定义表单 并向其中添加字段当他或她需要它们时 为此 我在我的项目中添加了一个可重用的应用程序 可在 github 上找到 https github com stephen
  • 与区域指示符字符类匹配的 python 正则表达式

    我在 Mac 上使用 python 2 7 10 表情符号中的标志由一对表示区域指示符号 https en wikipedia org wiki Regional Indicator Symbol 我想编写一个 python 正则表达式来在
  • Python 中的哈希映射

    我想用Python实现HashMap 我想请求用户输入 根据他的输入 我从 HashMap 中检索一些信息 如果用户输入HashMap的某个键 我想检索相应的值 如何在 Python 中实现此功能 HashMap
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • Pandas/Google BigQuery:架构不匹配导致上传失败

    我的谷歌表中的架构如下所示 price datetime DATETIME symbol STRING bid open FLOAT bid high FLOAT bid low FLOAT bid close FLOAT ask open
  • 跟踪 pypi 依赖项 - 谁在使用我的包

    无论如何 是否可以通过 pip 或 PyPi 来识别哪些项目 在 Pypi 上发布 可能正在使用我的包 也在 PyPi 上发布 我想确定每个包的用户群以及可能尝试积极与他们互动 预先感谢您的任何答案 即使我想做的事情是不可能的 这实际上是不
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 使用字典映射数据帧索引

    为什么不df index map dict 工作就像df column name map dict 这是尝试使用index map的一个小例子 import pandas as pd df pd DataFrame one A 10 B 2
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • 如何使用python在一个文件中写入多行

    如果我知道要写多少行 我就知道如何将多行写入一个文件 但是 当我想写多行时 问题就出现了 但是 我不知道它们会是多少 我正在开发一个应用程序 它从网站上抓取并将结果的链接存储在文本文件中 但是 我们不知道它会回复多少行 我的代码现在如下 r
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • 模拟pytest中的异常终止

    我的多线程应用程序遇到了一个错误 主线程的任何异常终止 例如 未捕获的异常或某些信号 都会导致其他线程之一死锁 并阻止进程干净退出 我解决了这个问题 但我想添加一个测试来防止回归 但是 我不知道如何在 pytest 中模拟异常终止 如果我只
  • 在 JavaScript 函数的 Django 模板中转义字符串参数

    我有一个 JavaScript 函数 它返回一组对象 return Func id name 例如 我在传递包含引号的字符串时遇到问题 Dr Seuss ABC BOOk 是无效语法 I tried name safe 但无济于事 有什么解
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是

随机推荐

  • 将子文档添加到现有 Solr 6.4 集合文档会创建重复文档

    这个问题类似于Solr 不会覆盖 重复的 uniqueKey 条目 但我所处的情况是 我有大量现有文档已添加到集合中 没有子文档 并且我正在使用 独立而不是云 Solr 6 4 而不是 5 3 1 我们最近启用了子文档 以便我们可以存储更丰
  • 为什么 getSpeed() 在 android 上总是返回 0

    我需要从 GPS 获取速度和航向 然而我唯一拥有的号码是location getSpeed 为 0 或有时不可用 我的代码 String provider initLocManager if provider null return fal
  • Java WebStart 和认可的目录

    如何在 java webstart jnlp 文件中指定我的某些 jar 正在覆盖 JRE 内置实现 就像常规应用程序上认可的 lib 属性一样 似乎没有办法在网络启动中定义认可的目录 即使将 java endorsed dirs 属性定义
  • IntelliJ 问题 -> 无法创建名为“Main”的类

    标题说明了我的问题 我收到此错误消息 无法创建类无法解析模板 Class 错误信息 选定的类文件名 Main java 映射到非 java 文件类型 通过 TextMate 捆绑包支持的文件 有人对我如何解决这个问题有任何想法吗 请检查文件
  • 拆分字符串列值

    acctcode primekey groupby lt columns WDS 1 NULL lt values varchar FDS 2 NULL IRN 3 NULL SUM 4 1 2 3 STL 5 NULL WTR 6 NUL
  • 扩展 Asp.NET MVC3 控制器类

    我是一位经验丰富的 NET 程序员 也是一位使用 PHP 的 MVC 程序员 现在我是 MVC3 的新手 并尝试在其上构建我的第一个作品 因此我正在处理一些问题 对于初学者来说 如何扩展控制器类 有人可以指出我应该实施的指南 方法列表吗 T
  • 无法释放 C 中的 const 指针

    我怎样才能释放一个const char 我使用分配新内存malloc 当我尝试释放它时 我总是收到错误 不兼容的指针类型 导致此问题的代码类似于 char name Arnold const char str const char mall
  • Android 获取当前时间戳?

    我想像这样获取当前时间戳 1320917972 int time int System currentTimeMillis Timestamp tsTemp new Timestamp time String ts tsTemp toStr
  • Jenkins:根据相同 Jenkins 作业中的每个构建步骤结果发送电子邮件

    我只是想知道如何发送电子邮件电子邮件分机插件基于相同 Jenkins 作业的每个构建步骤结果 这是我的场景 我的 Jenkins 工作有 3 个构建步骤 构建步骤1 Pull latest code from github and Buil
  • 如何从 C++ 调用 fortran 例程?

    我希望从我的 C 代码中调用 fortran 例程 cbesj f 如何实现此目的 以下是我已完成的步骤 从 netlib amos 网页下载 cbesj f 以及依赖项 http www netlib org cgi bin netlib
  • 自动完成建议列表的 z-index 错误,我该如何更改?

    似乎我的自动完成列表的 z index 比我网站的某些元素低 所以它暴露不足 我应该编辑什么类 使用editCSS我播种这些类 并添加 我网站的z索引 但很少有不影响的是1 ui corner all ui menu item ingred
  • 如何打印第三列到最后一列?

    我正在尝试从 DbgView 日志文件中删除前两列 我对其中不感兴趣 我似乎找不到从第 3 列开始打印直到行尾的示例 请注意 每行都有可变数量的列 或更简单的解决方案 cut f 3 INPUTFILE只需添加正确的分隔符 d 即可获得相同
  • JTable 中的列的多个单元格渲染器?

    假设我有以下 JTable 按下按钮后就会显示 Name True Hello World False Foo Bar True Foo False Bar 我想渲染那些单元格最初对于 JCheckBox 来说是正确的 并且所有单元格都是最
  • MonoTouch.Dialog 崩溃

    我有一个小型测试应用程序 它仅在 3 个页面之间循环 这是应用程序委托 public override bool FinishedLaunching UIApplication app NSDictionary options sessio
  • 如何从嵌套函数内部访问 Stimulus JS 控制器方法?

    我有一个 Stimulus 控制器 其中有一个 setSegments 函数 然后在 connect 方法中使用以下代码 connect const options overview container document getElemen
  • 十六进制到二进制转换

    我已通过十六进制转换器将 jpeg 文件转换为十六进制代码 现在如何将该十六进制转换为二进制并另存为Jpeg磁盘上的文件 Like var 声明为十六进制代码 然后将该 var 十六进制代码转换为二进制并保存在磁盘上 Edit Var my
  • 如何使用X509使用JDBC连接MySQL?

    我已经设置了 MySQL 社区服务器 5 1 数据库服务器 我已经设置了 SSL 创建了证书等 我创建了一个具有 REQUIRES X509 属性的用户 我可以使用命令行客户端 mysql 使用此用户进行连接 并且 status 命令显示
  • 请解释一下此电子邮件验证正则表达式:[关闭]

    很难说出这里问的是什么 这个问题模棱两可 含糊不清 不完整 过于宽泛或言辞激烈 无法以目前的形式合理回答 如需帮助澄清此问题以便重新打开 访问帮助中心 我有这个脚本使用正则表达式来检查表单字段是否包含有效的电子邮件地址 请从声明中解释一下
  • Firebase 安全规则 - Auth 生成的 UID 是否应该保密? [复制]

    这个问题在这里已经有答案了 我一直在阅读 Firebase 实时数据库安全规则指南 https firebase google com docs database security 我有点困惑是否应该将 Firebase Auth 生成的
  • 如何将 Tensorflow BatchNormalization 与 GradientTape 结合使用?

    假设我们有一个使用 BatchNormalization 的简单 Keras 模型 model tf keras Sequential tf keras layers InputLayer input shape 1 tf keras la