haiku定义简单的模型并初始化参数

2024-01-04

Haiku 是一个基于 JAX 的深度学习库,旨在提供简洁、灵活且易于使用的 API,以构建和训练神经网络模型。

import haiku as hk
import jax
import jax.numpy as jnp

### 1. 定义简单的二层神经网络
class SimpleNN(hk.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

    def __call__(self, x):
        x = hk.Linear(self.hidden_size)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self.output_size)(x)
        out = jax.nn.sigmoid(x)
        return out
    
    
### 2. 创建模块实例
# hk.transform将普通的Python函数转换为可训练的Haiku模块。
# 转换后可以进行参数初始化、模块应用等操作。
model = hk.transform(lambda x: SimpleNN(64, 10)(x)) 
#print(type(model))
#print(model)

### 3. 模块参数初始化
# jax.random.PRNGKey用于伪随机数生成。
# 使用伪随机数生成器(PRNG)可以确保在相同的初始状态下获得相同的随机数序列,从而保持实验的可重复性。
rng = jax.random.PRNGKey(42)
# print(rng)
## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
input_data = jnp.ones((1, 128))
params = model.init(rng, input_data)

## 查看随机初始化的参数,rng保证每次初始化出相同的参数
#print("Initialized Parameters:", params)
#print(params)
print(params['simple_nn/linear']['w'].shape)
print(params['simple_nn/linear_1']['w'].shape)

### 4.模型预测
# apply方法接受模块参数和输入数据,并返回模块的输出数据
# 在模型训练时,apply方法是对整个模块进行前向传播的操作
output_data = model.apply(params, rng, input_data)
print("Output Data:", output_data)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

haiku定义简单的模型并初始化参数 的相关文章

随机推荐

  • 深入理解左倾红黑树 | 京东物流技术团队

    平衡二叉搜索树 平衡二叉搜索树 Balanced Binary Search Tree 的每个节点的左右子树高度差不超过 1 它可以在 O logn 时间复杂度内完成插入 查找和删除操作 最早被提出的自平衡二叉搜索树是 AVL 树 AVL
  • Win7系统提示找不到KBDURDU.DLL文件的解决办法

    其实很多用户玩单机游戏或者安装软件的时候就出现过这种问题 如果是新手第一时间会认为是软件或游戏出错了 其实并不是这样 其主要原因就是你电脑系统的该dll文件丢失了或没有安装一些系统软件平台所需要的动态链接库 这时你可以下载这个KBDURDU
  • 如何使用Requests库采集前程无忧招聘数据

    使用Requests库来采集前程无忧 智联招聘 的数据涉及以下步骤 了解目标网站结构 首先 需要了解前程无忧网站的结构 查看其页面布局 URL结构和需要采集的信息位置 发送HTTP请求 使用Requests库发送HTTP请求获取页面内容 通
  • 使用pytorch构建图卷积网络预测化学分子性质

    在本文中 我们将通过化学的视角探索图卷积网络 我们将尝试将网络的特征与自然科学中的传统模型进行比较 并思考为什么它的工作效果要比传统的方法好 图和图神经网络 化学或物理中的模型通常是一个连续函数 例如y f x x x x 其中x x x
  • CAD Exchanger SDK 3.24 for Android Crack

    CAD Exchanger SDK Software Libraries to Read Write and Visualize 3D CAD files Quickly and easily enrich your web server
  • 蜜罐溯源以及蜜罐HFish的使用

    一 蜜罐是什么 蜜罐技术本质上是一种对攻击方进行欺骗的技术 通过布置一些作为诱饵的主机 网络服务或者信息 诱使攻击方对它们实施攻击 从而可以对攻击行为进行捕获和分析 了解攻击方所使用的工具与方法 推测攻击意图和动机 能够让防御方清晰地了解他
  • Matlab图像处理系列——图像复原之噪声模型仿真

    微信公众号上线 搜索公众号 小灰灰的FPGA 关注可获取相关源码 定期更新有关FPGA的项目以及开源项目源码 包括但不限于各类检测芯片驱动 低速接口驱动 高速接口驱动 数据信号处理 图像处理以及AXI总线等 本节目录 一 图像复原的模型 二
  • VS Code 自动选择Python3 venv

    我们使用VS Code写Python代码时 往往希望这个项目的依赖和其他项目或者全局的python环境隔离开 VS Code不像PyCharm那样自动完成 但是我们也可以快速的进行设置 首先我们需要把python项目所在的目录添加为VS C
  • CAD Exchanger SDK 3.24 FOR WIN Crack

    Manufacturing Toolkit and Web Toolkit enhancements Unity performance optimization renaming and rotating SDK examples in
  • vscode插件离线安装地址

    因内网开发 编辑器不可联网 插件需要离线安装 vscode插件商店 Extensions for Visual Studio family of products Visual Studio Marketplace
  • SpringCloud+saToken实现登录及权限认证

    SpringCloud saToken实现登录及权限认证 文章目录 SpringCloud saToken实现登录及权限认证 1 为什么要用sa Token 2 sa Token功能 3 springcloud集成sa token 3 1
  • CAD Exchanger SDK 3.24 for Linux Crack

    CAD Exchanger SDK Software Libraries to Read Write and Visualize 3D CAD files Quickly and easily enrich your web server
  • Jlink V9刷入自动升级固件

    Jlink V9刷入自动升级固件 1 所需工具 一个可用的jlink 一个待刷jlink 2 接线如图 3 查看待刷Jlink的主控芯片型号 我的型号为stm32f205rc 4 刷入固件 固件下载地址 https download csd
  • 一批J-link V9变砖拯救

    一批J link V9变砖拯救 weixin 51547258 于 2023 05 05 16 05 09 发布 阅读量282 收藏 点赞数 文章标签 单片机 stm32 嵌入式硬件 版权 手里有一批J link V9版本 由于误操作升级固
  • 指尖互鉴APP 毕业设计源码48084

    赠送源码 毕业设计 SSM指尖互鉴app https www bilibili com video BV15t4y1Z7Gs vd source 72970c26ba7734ebd1a34aa537ef5301 题目 SSM 指尖互鉴APP
  • 14.9-时序和组合的混合逻辑——使用非阻塞赋值

    时序和组合的混合逻辑 使用非阻塞赋值 1 在一个always块中同时实现组合逻辑和时序逻辑 2 将组合和时序逻辑分别写入两个always块中 原则4 在同一个always块中描述时序和组合逻辑混合电路时 用非阻塞赋值 1 在一个always
  • 14.11-对同一变量进行多次赋值

    对同一变量进行多次赋值 1 同一变量多次赋值 即便是非阻塞赋值 也会存在竞争冒险 原则6 严禁在多个always块中对同一个变量赋值 包括阻塞和非阻塞赋值 1 同一变量多次赋值 即便是非阻塞赋值 也会存在竞争冒险 两个always块都对输出
  • zzz666

    6
  • 拓展:vue 父组件调用子组件方法ref(且父组件可通过ref调用的方法传值给子组件)

    1 ref被用来给元素或子组件注册引用信息 引用信息将会注册在父组件的 refs对象上 一 ref被用来给元素或子组件注册引用信息 引用信息将会注册在父组件的 refs对象上 div class formBtn fl 111 div div
  • haiku定义简单的模型并初始化参数

    Haiku 是一个基于 JAX 的深度学习库 旨在提供简洁 灵活且易于使用的 API 以构建和训练神经网络模型 import haiku as hk import jax import jax numpy as jnp 1 定义简单的二层神