使用 Tensorflow.js 计算损失梯度

2024-02-08

我正在尝试使用 Tensorflow.js 计算与网络可训练权重相关的损失梯度,以便将这些梯度应用于我的网络权重。在 python 中,这可以使用 tf.gradients() 函数轻松完成,该函数需要两个表示 dx 和 dy 的最小输入。 但是,我无法重现 Tensorflow.js 中的行为。我不确定我对权重损失梯度的理解是否错误,或者我的代码是否包含错误。

我花了一些时间分析 tfjs-node 包的核心代码,以了解当我们调用函数 tf.model.fit() 时它是如何完成的,但到目前为止收效甚微。

let model = build_model(); //Two stacked dense layers followed by two parallel dense layers for the output
let loss = compute_loss(...); //This function returns a tf.Tensor of shape [1] containing the mean loss for the batch.
const f = () => loss;
const grad = tf.variableGrads(f);
grad(model.getWeights());

model.getWeights() 函数返回一个 tf.variable() 数组,因此我假设该函数会计算每一层的 dL/dW,稍后我可以将其应用于网络的权重,但是,情况并非如此,因为我得到这个错误:

Error: Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.

我不太明白这个错误是什么意思。 那么我应该如何使用 Tensorflow.js 计算损失的梯度(类似于 Python 中的 tf.gradients())?

编辑 : 这是计算损失的函数:

function compute_loss(done, new_state, memory, agent, gamma=0.99) {
    let reward_sum = 0.;
    if(done) {
        reward_sum = 0.;
    } else {
        reward_sum = agent.call(tf.oneHot(new_state, 12).reshape([1, 9, 12]))
                    .values.flatten().get(0);
    }

    let discounted_rewards = [];
    let memory_reward_rev = memory.rewards;
    for(let reward of memory_reward_rev.reverse()) {
        reward_sum = reward + gamma * reward_sum;
        discounted_rewards.push(reward_sum);
    }
    discounted_rewards.reverse();

    let onehot_states = [];
    for(let state of memory.states) {
        onehot_states.push(tf.oneHot(state, 12));
    }
    let init_onehot = onehot_states[0];

    for(let i=1; i<onehot_states.length;i++) {
        init_onehot = init_onehot.concat(onehot_states[i]);
    }

    let log_val = agent.call(
        init_onehot.reshape([memory.states.length, 9, 12])
    );

    let disc_reward_tensor = tf.tensor(discounted_rewards);
    let advantage = disc_reward_tensor.reshapeAs(log_val.values).sub(log_val.values);
    let value_loss = advantage.square();
    log_val.values.print();

    let policy = tf.softmax(log_val.logits);
    let logits_cpy = log_val.logits.clone();

    let entropy = policy.mul(logits_cpy.mul(tf.scalar(-1))); 
    entropy = entropy.sum();

    let memory_actions = [];
    for(let i=0; i< memory.actions.length; i++) {
        memory_actions.push(new Array(2000).fill(0));
        memory_actions[i][memory.actions[i]] = 1;
    }
    memory_actions = tf.tensor(memory_actions);
    let policy_loss = tf.losses.softmaxCrossEntropy(memory_actions.reshape([memory.actions.length, 2000]), log_val.logits);

    let value_loss_copy = value_loss.clone();
    let entropy_mul = (entropy.mul(tf.scalar(0.01))).mul(tf.scalar(-1));
    let total_loss_1 = value_loss_copy.mul(tf.scalar(0.5, dtype='float32'));

    let total_loss_2 = total_loss_1.add(policy_loss);
    let total_loss = total_loss_2.add(entropy_mul);
    total_loss.print();
    return total_loss.mean();

}

EDIT 2:

我设法使用compute_loss作为model.compile()上指定的损失函数。但是,它只需要两个输入(预测、标签),所以它不适合我,因为我想输入多个参数。

我真的对这件事迷失了。


错误说明了一切。 您的问题与 tf.variableGrads 有关。loss应该是使用所有可用的计算得出的标量tf张量运算符。loss不应返回问题中所示的张量。

以下是损失应该是什么的示例:

const a = tf.variable(tf.tensor1d([3, 4]));
const b = tf.variable(tf.tensor1d([5, 6]));
const x = tf.tensor1d([1, 2]);

const f = () => a.mul(x.square()).add(b.mul(x)).sum(); // f is a function
// df/da = x ^ 2, df/db = x 
const {value, grads} = tf.variableGrads(f); // gradient of f as respect of each variable

Object.keys(grads).forEach(varName => grads[varName].print());

/!\ 请注意,梯度是根据使用创建的变量来计算的tf.variable

Update:

您没有按应有的方式计算梯度。这是修复方法。

function compute_loss(done, new_state, memory, agent, gamma=0.99) {
    const f = () => { let reward_sum = 0.;
    if(done) {
        reward_sum = 0.;
    } else {
        reward_sum = agent.call(tf.oneHot(new_state, 12).reshape([1, 9, 12]))
                    .values.flatten().get(0);
    }

    let discounted_rewards = [];
    let memory_reward_rev = memory.rewards;
    for(let reward of memory_reward_rev.reverse()) {
        reward_sum = reward + gamma * reward_sum;
        discounted_rewards.push(reward_sum);
    }
    discounted_rewards.reverse();

    let onehot_states = [];
    for(let state of memory.states) {
        onehot_states.push(tf.oneHot(state, 12));
    }
    let init_onehot = onehot_states[0];

    for(let i=1; i<onehot_states.length;i++) {
        init_onehot = init_onehot.concat(onehot_states[i]);
    }

    let log_val = agent.call(
        init_onehot.reshape([memory.states.length, 9, 12])
    );

    let disc_reward_tensor = tf.tensor(discounted_rewards);
    let advantage = disc_reward_tensor.reshapeAs(log_val.values).sub(log_val.values);
    let value_loss = advantage.square();
    log_val.values.print();

    let policy = tf.softmax(log_val.logits);
    let logits_cpy = log_val.logits.clone();

    let entropy = policy.mul(logits_cpy.mul(tf.scalar(-1))); 
    entropy = entropy.sum();

    let memory_actions = [];
    for(let i=0; i< memory.actions.length; i++) {
        memory_actions.push(new Array(2000).fill(0));
        memory_actions[i][memory.actions[i]] = 1;
    }
    memory_actions = tf.tensor(memory_actions);
    let policy_loss = tf.losses.softmaxCrossEntropy(memory_actions.reshape([memory.actions.length, 2000]), log_val.logits);

    let value_loss_copy = value_loss.clone();
    let entropy_mul = (entropy.mul(tf.scalar(0.01))).mul(tf.scalar(-1));
    let total_loss_1 = value_loss_copy.mul(tf.scalar(0.5, dtype='float32'));

    let total_loss_2 = total_loss_1.add(policy_loss);
    let total_loss = total_loss_2.add(entropy_mul);
    total_loss.print();
    return total_loss.mean().asScalar();
}

return tf.variableGrads(f);
}

请注意,您很快就会遇到内存消耗问题。建议将功能区分为tf.tidy处理张量。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Tensorflow.js 计算损失梯度 的相关文章

  • 动态速度计 javascript 或 jquery 插件

    我希望有动态ajax插件在页面上显示速度计 一个想法是我设置一个背景并旋转针 有人知道相关插件吗 这里有一些供您参考 http bernii github com gauge js http bernii github com gauge
  • jquery 验证错误位置

    这看起来很简单 但我无法弄清楚 我正在使用 jquery 验证插件 我验证所有文件 但我想要的是在输入文本行中显示验证消息警报 例如在电子邮件输入中 请填写电子邮件地址 但现在它出现在所有字段下 在我的html中
  • 具有 Firebase (FCM) 推送通知的 Node js

    我正在使用 Node js 开发 REST api 并且有一个休息端点来发送 firebase 推送通知 我的代码如下 const bodyParser require body parser var cors require cors v
  • jQuery 选择 # id 以单词为前缀,计数器为后缀

    有没有办法用 jQuery 选择所有带有前缀 my 和后缀 0 9 的 id 像这样的 my 1 4 还是可以用循环来实现 div div div div div div div div div div 第一个想法 似乎效果很好 div i
  • 如何按照编写的顺序迭代 javascript 对象属性

    我发现了代码中的一个错误 我希望通过最少的重构工作来解决该错误 此错误发生在 Chrome 和 Opera 浏览器中 问题 var obj 23 AA 12 BB iterating through obj s properties for
  • Node.js - console.log 不显示数组中的项目,而是显示 [Object]

    我在注销对象内数组的内容时遇到问题 实际的物体看起来像这样 var stuff accepted item1 item2 rejected response Foo envelope from The sender to new item1
  • 仅一页 JavaScript 应用程序

    您是否尝试过单页 Web 应用程序 即浏览器仅从服务器 获取 一页 其余部分由客户端 JavaScript 代码处理 此类 应用程序页面 的一个很好的例子是 Gmail 对于更简单的应用程序 例如博客和 CMS 使用这种方法有哪些优点和缺点
  • JavaScript 中数组的 HTML 数据列表值

    我有一个简单的程序 它必须从服务器上的文本文件中获取值 然后将数据列表填充为输入文本字段中的选择 为此 我想要采取的第一步是我想知道如何动态地将 JavaScript 数组用作数据列表选项 我的代码是
  • 未捕获的错误:找不到模块“jquery”

    我在用Electron https github com atom electron制作桌面应用程序 在我的应用程序中 我正在加载一个外部站点 Atom 应用程序之外 可以说http mydummysite index html http
  • Typeahead.js substringMatcher 函数说明

    我只是在做一些研究Typeahead js这是一个非常酷的图书馆 感谢文档 我已经成功地获得了一个基本的示例 该文档也非常好 但是我试图弄清楚以下代码块实际上在做什么 var substringMatcher function strs r
  • 如何使JavaScript函数在Eclipse“大纲视图”中可见?

    我有这样的代码 但如果它在匿名函数中定义 则无法打开函数大纲 类没有问题 我该如何概述something2 请分享一些提示 我可以将所有函数标记为构造函数 但这是无效的方法 start of track event required deb
  • 流星内存不足

    我正在使用流星来制作报废引擎 我必须执行一个 HTTP GET 请求 这会向我发送一个 xml 但这个 xml 大于 400 ko 我得到一个异常 内存不足 result Meteor http get http SomeUrl com 致
  • Chartjs刻度标签位置

    尝试让 Y 轴刻度标签看起来像image https i stack imgur com XgoxX png 位于秤顶部且不旋转 缩放选项当前如下所示 scales yAxes id temp scaleLabel display true
  • 类型“typeof import("/home/kartik/Desktop/Ecommerce/ecommerce/node_modules/firebase/index")”上不存在属性“auth”。 TS(2339)

    我是 FireBase 的初学者 我正在尝试使用 Angular 通过 FireBase 实现 Google 登录 我在 auth 时收到上述错误 我特此附上login component ts和package json package l
  • 为什么“tbody”不设置表格的背景颜色?

    我在用 tbody 作为 CSS 选择器来设置background color在一个表中 我这样做是因为我有多个 tbody 表内的部分 它们具有不同的背景颜色 我的问题是 当使用border radius在细胞上 细胞不尊重backgro
  • 如何在jquery中获取保存时间和当前时间的差异?

    我想在 javascript 或 jquery 中获取保存时间和当前时间之间的时差 我节省的时间看起来像Sun Oct 24 15 55 56 GMT 05 30 2010 java中的日期格式代码如下 String newDate 201
  • 数据表日期范围过滤器

    如何添加日期范围过滤器 like From To 我开始进行常规搜索和分页等工作 但我不知道如何制作日期范围过滤器 我正在使用数据表 1 10 11 版本 My code var oTable function callFilesTable
  • 如何在 Google 地图 V3 中创建编号地图标记?

    我正在制作一张上面有多个标记的地图 这些标记使用自定义图标 但我还想在顶部添加数字 我已经了解了如何使用旧版本的 API 来实现这一点 我怎样才能在V3中做到这一点 注意 当您将鼠标悬停在标记上时 标题 属性会创建一个工具提示 但我希望即使
  • Jquery - 通过在字符串中构建 id 的 id 获取元素

    我在使用 jquery 元素时遇到问题 我正在 var 中构造名称 例如 var myId myGotId myId attr title changed myId 返回空 我想通过 id 获取我的元素 但动态构建我的 Id 连接字符串 编
  • 如何在 gulp.src 中使用基本正则表达式?

    我正在尝试选择两个文件gulp src highcharts js and highcharts src js 当然 我知道我可以使用数组表达式显式添加这两个表达式 但出于学习目的 我尝试为它们编写一个表达式 我读过可以使用简单的正则表达式

随机推荐

  • 如何在知道线程 id 的情况下获取消息线程 URL?

    有如果我有消息 ID 如何构建链接以查看 facebook com 上的消息 http facebook stackoverflow com questions 7747622 how can i construct a link to v
  • jquery mobile 和 ui 不兼容

    尽管有很多人提到类似的兼容性问题 但 50 的问题在 StackOverflow 上得到了解决 我希望我的问题能够成为 51 49 考虑这段代码
  • macOS 公证:找不到 altool

    我想开始构建一个公证自动化脚本 但是 当我尝试在终端中使用 xcrun altool 时 出现以下错误 xcrun error unable to find utility altool not a developer tool or in
  • 如何正确引用本地XML Schema文件?

    我在 XML 文件中引用 XML 架构时遇到此问题 我的 XSD 位于此路径中 C environment workspace maven ws ProjectXmlSchema email xsd 但是 当我在 XML 文件中尝试像这样查
  • 服务器标记格式不正确

    这真是太愚蠢了 但却让我快疯了
  • 堆叠 UITableViews 不会在其视图下方传递触摸事件

    我将 3 个 UIView 堆叠在一起 UI表格视图平面视图根视图 TableView 位于顶部 rootView 位于底部 rootView 不可见 因为 TableView 在它上面 我在 rootView 中实现了以下代码 code
  • 错误 TS2707 通用类型“ɵɵDirectiveDeclaration”需要 6 到 8 个类型参数

    安装角度材料并将角度材料导入 app module ts 添加到项目后 我遇到错误 并且到目前为止所有解决方案都不起作用 我的角度为 14 节点为 16 第一个错误 实际上要长得多 Error node modules angular cd
  • 如何使用 Python 从巨大的 Excel 工作表中提取特定行的数据?

    我需要获取其中包含某些关键字 名称 的特定数据行并将它们写入另一个文件 起始文件是 1 5 GB Excel 文件 我不能只是打开它并将其另存为不同的格式 我应该如何使用 python 处理这个问题 我是 xlrd 的作者和维护者 请编辑您
  • 如何提高Python循环速度?

    我有一个包含 370k 记录的数据集 存储在 Pandas Dataframe 中 需要集成 我尝试了多处理 线程 Cpython 和循环展开 但我没有成功 显示的计算时间是 22 小时 任务如下 matplotlib inline fro
  • 开发游戏服务器用什么语言好?

    我只是想知道什么语言是开发支持大量 数千 用户的游戏服务器的不错选择 我涉足Python 但意识到这太麻烦了 因为它不会跨核心产生线程 意味着8个核心服务器 1个核心服务器 我也不太喜欢这种语言 自我 的东西让我感到恶心 我知道 C 就性能
  • 在xamarin forms pcl项目中打开远程pdf的最佳方法

    在适用于 Ios 和 Android 的 xamarin pcl 应用程序中 在服务器上加载 pdf 的最佳方式是什么 是否有一个好的 nuget 或者我们必须编写自定义渲染器 在应用程序中打开 PDF 您有几个选项 iOS 在其 WebV
  • 使用 Cython 将 Python 链接到共享库

    我正在尝试集成用以下语言编写的第三方库C和我的python应用程序使用Cython 我已经为测试编写了所有 python 代码 我无法找到设置此功能的示例 我有一个pyd pyx我手动创建的文件 第三方给了我一个header file h
  • 使用Delphi RTTI获取接口的字符串名称

    我已经证明我可以使用 Delphi 2010 从其 GUID 获取接口的名称 例如 IMyInterface 转换为字符串 IMyInterface 我想在 Delphi 7 中实现此目的 为了兼容性 这可能吗 或者是存在基本的编译器限制
  • 哪种数据结构最适合 VirtualStringTree?

    我想每个曾经使用过Delphi的VirtualStringTree的人都会同意它是一个很棒的控件 它是一个 虚拟 控件 您的数据必须保存在其他地方 所以我在想什么数据结构最适合这样的任务 IMO认为数据结构必须支持层次结构 它必须快速且易于
  • 扩展器的默认控制模板

    有人 可能使用 Blend 可以为我提供 WPF Expander 的工作默认 ControlTemplate 吗 我想做一些细微的修改 但似乎找不到有效模板的来源 提前致谢 我有混合 可以帮助你 这是 Blend 为我生成的内容
  • 根据日期分割数据框

    我正在尝试根据日期将数据框分成两个 此处的相关问题已解决 根据日期将数据帧分成两部分 https stackoverflow com questions 37532098 split dataframe into two on the ba
  • Chrome 语音合成具有较长的文本

    我在 Chrome 33 中尝试使用语音合成 API 时遇到问题 它可以完美地处理较短的文本 但如果我尝试较长的文本 它就会停在中间 一旦停止后 语音合成将无法在 Chrome 中的任何地方工作 直到浏览器重新启动 示例代码 http js
  • 责任链模式是否可以很好地替代一系列条件?

    当您需要按特定顺序执行一系列操作时 责任链模式是否可以很好地替代一系列条件 用这样的条件替换简单的方法是个好主意吗 public class MyListener implements MyHttpListener if false the
  • 线程安全类是否应该在其构造函数末尾有一个内存屏障?

    当实现一个线程安全的类时 我是否应该在其构造函数末尾包含一个内存屏障 以确保任何内部结构在可以访问之前已完成初始化 或者消费者有责任在使实例可供其他线程使用之前插入内存屏障 简化问题 下面的代码中是否存在竞争危险 由于初始化和线程安全类的访
  • 使用 Tensorflow.js 计算损失梯度

    我正在尝试使用 Tensorflow js 计算与网络可训练权重相关的损失梯度 以便将这些梯度应用于我的网络权重 在 python 中 这可以使用 tf gradients 函数轻松完成 该函数需要两个表示 dx 和 dy 的最小输入 但是