google lab 深度学习_Google 深度学习笔记 - 深度神经网络实践

2023-10-27

优化

Regularization

在前面实现的RELU连接的两层神经网络中,加Regularization进行约束,采用加l2 norm的方法,进行负反馈:

代码实现上,只需要对tf_sgd_relu_nn中train_loss做修改即可:

  • 可以用tf.nn.l2_loss(t)对一个Tensor对象求l2 norm
  • 需要对我们使用的各个W都做这样的计算(参考tensorflow官方example)
l2_loss = tf.nn.l2_loss(weights1) + tf.nn.l2_loss(weights2)
  • 添加到train_loss上
  • 这里还有一个重要的点,Hyper Parameter: β
  • 我觉得这是一个拍脑袋参数,取什么值都行,但效果会不同,我这里解释一下我取β=0.001的理由
  • 如果直接将l2_loss加到train_loss上,每次的train_loss都特别大,几乎只取决于l2_loss
  • 为了让原本的train_loss与l2_loss都能较好地对参数调整方向起作用,它们应当至少在同一个量级
  • 观察不加l2_loss,step 0 时,train_loss在300左右
  • 加l2_loss后, step 0 时,train_loss在300000左右
  • 因此给l2_loss乘0.0001使之降到同一个量级
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=tf_train_labels)) + 0.001 * l2_loss
  • 所有其他参数不变,训练3000次,准确率提高到92.7%
  • 黑魔法之所以为黑魔法就在于,这个参数可以很容易地影响准确率,如果β = 0.002,准确率提高到93.5%

OverFit问题

在训练数据很少的时候,会出现训练结果准确率高,但测试结果准确率低的情况

  • 缩小训练数据范围:将把batch数据的起点offset的可选范围变小(只能选择0-1128之间的数据):
offset_range = 1000offset = (step * batch_size) % offset_range
  • 可以看到,在step500后,训练集就一直是100%,验证集一直是77.6%,准确度无法随训练次数上升,最后的测试准确度是85.4%

DropOut

采取Dropout方式强迫神经网络学习更多知识

参考aymericdamien/TensorFlow-Examples中dropout的使用

  • 我们需要丢掉RELU出来的部分结果
  • 调用tf.nn.dropout达到我们的目的:
keep_prob = tf.placeholder(tf.float32)if drop_out: hidden_drop = tf.nn.dropout(hidden, keep_prob) h_fc = hidden_drop
  • 这里的keep_prob是保留概率,即我们要保留的RELU的结果所占比例,tensorflow建议的语法是,让它作为一个placeholder,在run时传入
  • 当然我们也可以不用placeholder,直接传一个0.5:
if drop_out: hidden_drop = tf.nn.dropout(hidden, 0.5) h_fc = hidden_drop
  • 这种训练的结果就是,虽然在step 500对训练集预测没能达到100%(起步慢),但训练集预测率达到100%后,验证集的预测正确率仍然在上升
  • 这就是Dropout的好处,每次丢掉随机的数据,让神经网络每次都学习到更多,但也需要知道,这种方式只在我们有的训练数据比较少时很有效
  • 最后预测准确率为88.0%

Learning Rate Decay

随着训练次数增加,自动调整步长

  • 在之前单纯两层神经网络基础上,添加Learning Rate Decay算法
  • 使用tf.train.exponential_decay方法,指数下降调整步长,具体使用方法官方文档说的特别清楚
  • 注意这里面的cur_step传给优化器,优化器在训练中对其做自增计数
  • 与之前单纯两层神经网络对比,准确率直接提高到90.6%

Deep Network

增加神经网络层数,增加训练次数到20000

  • 为了避免修改网络层数需要重写代码,用循环实现中间层
# middle layerfor i in range(layer_cnt - 2): y1 = tf.matmul(hidden_drop, weights[i]) + biases[i] hidden_drop = tf.nn.relu(y1) if drop_out: keep_prob += 0.5 * i / (layer_cnt + 1) hidden_drop = tf.nn.dropout(hidden_drop, keep_prob)
  • 初始化weight在迭代中使用
for i in range(layer_cnt - 2): if hidden_cur_cnt > 2: hidden_next_cnt = int(hidden_cur_cnt / 2) else: hidden_next_cnt = 2 hidden_stddev = np.sqrt(2.0 / hidden_cur_cnt) weights.append(tf.Variable(tf.truncated_normal([hidden_cur_cnt, hidden_next_cnt], stddev=hidden_stddev))) biases.append(tf.Variable(tf.zeros([hidden_next_cnt]))) hidden_cur_cnt = hidden_next_cnt
  • 第一次测试时,用正太分布设置所有W的数值,将标准差设置为1,由于网络增加了一层,寻找step调整方向时具有更大的不确定性,很容易导致loss变得很大
  • 因此需要用stddev调整其标准差到一个较小的范围(怎么调整有许多研究,这里直接找了一个来用)
stddev = np.sqrt(2.0 / n)
  • 启用regular时,也要适当调一下β,不要让它对原本的loss造成过大的影响
  • DropOut时,因为后面的layer得到的信息越重要,需要动态调整丢弃的比例,到后面的layer,丢弃的比例要减小
keep_prob += 0.5 * i / (layer_cnt + 1)
  • 训练时,调节参数,你可能遇到消失的梯度问题, 对于一个幅度为1的信号,在BP反向传播梯度时,每隔一层下降0.25,指数下降使得后面的层级根本接收不到有效的训练信号
  • 官方教程表示最好的训练结果是,准确率97.5%,
  • 我的nn_overfit.py开启六层神经网络, 启用Regularization、DropOut、Learning Rate Decay, 训练次数20000(应该还有再训练的希望,在这里虽然loss下降很慢了,但仍然在下降),训练结果是,准确率95.2%
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

google lab 深度学习_Google 深度学习笔记 - 深度神经网络实践 的相关文章

  • Triangle Tessellation with OpenGL 4.0

    FROM http prideout net blog p 48 This is the first of a two part article on tessellation shaders with OpenGL 4 0 This en
  • AVPlayer 视频播放

    1 AVPlayer AVPlayer 是一个用来播放基于时间的视听媒体的控制器对象 一个队播放和资源时间相隔信息进行管理的对象 而非一个视图或窗口控制器 AVPlayer支持播放从本地 分步下载或通过HTTP Live Streaming
  • 2023版golang面试题100道(map)

    面试题合集目录 map查找 假设当前 B 4 即桶数量为2 B 16个 要从map中获取k4对应的value 外链图片转存失败 源站可能有防盗链机制 建议将图片保存下来直接上传 k4的查找步骤 计算k4的hash值 通过低B位来确定在哪号桶
  • 静态代码和动态代码的区别_静态和动态代码分析之间有什么区别,您如何知道使用哪个?...

    让我们从一个运动类比开始 以帮助说明这两种方法之间的差异 静态代码分析类似于练习网和投球机练习棒球挥杆 最小的惊喜 经过几次挥杆后 您每次都知道球的确切位置 这有助于处理基础知识并确保您拥有良好的形式 虽然这有助于改善你的游戏 但它只能让你
  • 订单管理系统

    本专栏介绍了使用Qt开发的一些小型桌面软件 其中包括软件功能介绍 软件截图 主要代码等内容 此外 本专栏还提供完整的软件源码和安装包供有需要的同学下载 我的目标是开发一些简洁美观且实用的客户端小软件 如果能够为大家提供有用的软件或对学习有益
  • Hypertable 快速安装,仅需上载一个RPM包,零编译

    Hypertable 快速安装 仅需上载一个RPM包 零编译 Hypertable 快速安装 仅需下载一个RPM包 零编译 本文采用 单机安装 1 Hypertable 安装 Hypertable 的几种安装方式 单机 安装于单机 采用本地
  • Arduino core for the ESP32 安装失败问题处理方法

    文章目录 目的 离线开发板数据包 鱼 安装最新开发板数据包 渔 总结 目的 理论上Arduino IDE安装开发板数据包是非常方便的 不过在国内的网络环境下有时候就会很纠结 另外Arduino IDE对于下载数据这块也存在问题 经常下着下着
  • SQL语句连接筛选条件放在on和where后的区别(一篇足矣)

    sql查询这个东西 要说它简单 可以很简单 通常情况下只需使用增删查改配合编程语言的逻辑表达能力 就能实现所有功能 但是增删查改并不能代表sql语句的所有 完整的sql功能会另人望而生畏 就拿比普通增删查改稍微复杂一个层次的连接查询来说 盲
  • 【BP时序预测】基于BP神经网络的时间序列预测附matlab完整代码

    作者简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 matlab项目合作可私信 个人主页 Matlab科研工作室 个人信条 格物致知 更多Matlab仿真内容点击 智能优化算法 神经网络预测 雷达通信 无线传感器 电力系统 信号
  • PY32F003F18之RS485通讯

    PY32F003F18将USART2连接到RS485芯片 和其它RS485设备实现串口接收后再转发的功能 一 测试电路 二 测试程序 include USART2 h include stdio h getchar putchar scan
  • 单链表中什么时候使用二级指针

    在使用单链表时 一直有一个疑惑 初始化单链表时为什么要用二级指针 代码如下 typedef int ElemType ElemType类型根据实际情况而定 这里假设为int typedef struct Node ElemType data
  • CH9-HarmonyOS传感器和媒体管理

    文章目录 前言 目标 传感器概述 运动类传感器 运动类传感器工作原理 主流传感器表示 运作机制 核心模块 接口说明 开发步骤 使用传感器 方向传感器调用示例 相机调用 基本概念 主要接口 位置传感器 位置能力 基本概念 运作机制 获取设备的
  • PCM data flow - 7 - Frame & Period

    后面章节将分析 dma buffer 的管理 其中细节需要对音频数据相关概念有一定的了解 因此本章说明下音频数据中的几个重要概念 Sample 样本长度 音频数据最基本的单位 常见的有 8 位和 16 位 Channel 声道数 分为单声道
  • git-在现有项目上创建新项目

    简单说一下需求 假设你有一个项目A 现在需要在创建项目B 但是B是在A的基础上进行修改的 其实在A项目中创建分支可以 不过有些情况需要单独创建一个项目 1 将A项目拷贝一份 拷贝版就是我们的B 新 项目了 2 到B目录下 找到隐藏文件 gi
  • docker使用

    例子 docker run d name game p 8080 80 game2048 映射到系统的8080端口 http mirrors aliyun com docker ce linux centos 7 x86 64 stable
  • 密度图是一种用于可视化连续变量分布的图表类型,在R语言中可以使用各种库和函数来创建密度图。下面是一个示例代码,展示如何使用R语言创建密度图。

    密度图是一种用于可视化连续变量分布的图表类型 在R语言中可以使用各种库和函数来创建密度图 下面是一个示例代码 展示如何使用R语言创建密度图 首先 我们需要加载必要的库 在R中 可以使用ggplot2库来创建美观的图表 并使用density函
  • Flutter 开发 一个 字母+数字的随机数图片验证码

    Flutter 一个 字母 数字的随机数图片验证码 废话不多说 首先上效果图 使用方法 override void initState super initState getCode 调用随机数方法 getCode code String
  • redis设置缓存时间一般多少

    redis过期时间 redis过期时间介绍有时候我们并不希望redis的key一直存在 例如缓存 验证码等数据 我们希望它们能在一定时间内自动的被销毁 redis提供了一些命令 能够让我们对key设置过期时间 并且让key过期之后被自动删除
  • 医院PACS系统

    一 什么是PACS系统 医学影像系统 Picture Archiving and CommunicationSystems 简称PACS 是应用在医院影像科室的系统 主要的任务就是把日常产生的各种医学影像 包括核磁 CT 超声 各种X光机

随机推荐

  • CentOS系统中常用查看日志命令

    cat tail f 日 志 文 件 说 明 var log message 系统启动后的信息和错误日志 是Red Hat Linux中最常用的日志之一 var log secure 与安全相关的日志信息 var log maillog 与
  • 一键生成App图标所有尺寸的三个酷站分享

    目前很多app设计新手不懂如何去生成各种APP图标尺寸 其实很简单的 目前网上有很多一键生成App图标所有尺寸酷站和工具 在这里25学堂不啰嗦了 大家可以前往iOS和安卓APP启动图标的尺寸和圆角大小详解 去查看图标的尺寸大小 常见的ios
  • c++ 函数返回引用

    一 c 函数的返回分为以下几种情况 1 主函数main的返回值 这里提及一点 返回0表示程序运行成功 2 返回非引用类型 函数的返回值用于初始化在跳用函数出创建的临时对象 用函数返回值初始化临时对象与用实参初始化形参的方法是一样 的 如果返
  • 用实例去看看url传参怎么用

    用实例去剖析url传参方式 常见的url传参 1 传确定的值 2 传的是变量 3 传定值 多个 4 传变量 多个 常见的url传参 1 传确定的值 url https www baidu com data 123 通过一个例子去看一下怎么用
  • 国产开源python IDE 介绍

    1 目的 纯粹为了宣传 2 测试版本 1 2 4 3 感受 如果是写脚本还是挺好用的 而且轻便 但是如果写django项目等可能就要麻烦一些 纯粹个人感受 我之前使用pycharm 中间使用sublime 目前正在学习使用vscode 因为
  • 图的邻接矩阵、邻接表存储和图的广度优先搜索(BFS)、深度优先搜索(DFS)

    图的邻接矩阵 邻接表存储和图的广度优先搜索 BFS 深度优先搜索 DFS 图及其存储方式 广度优先搜索 深度优先搜索 本文将先介绍图的存储方式 邻接矩阵和邻接表 接着介绍图的基本算法 广度优先搜索和深度优先搜索 图及其存储方式 图是一种非线
  • [架构之路-201]-《软考-系统分析师》- 关键技术 - 结构化分析方法与面向对象分析(分析与设计的区别、pre架构设计、架构前设计)

    目录 前言 一 分析与设计的区别 二 结构化分析方法 2 1 实体关系图 E R 图 名词 2 2 数据流图 数据的流动 1 顶层图 2 逐层分解 2 3 状态转换图 动作 2 4 数据字典 三 面向对象分析方法 3 1 用例模型 3 2
  • 相量的加减乘除计算

    相量的加减乘除计算 矢量是物理学中的术语 是指具有大小 magnitude 和方向的量 如速度 加速度 力等等就是这样的量 向量是数学中的术语 也称为欧几里得向量 几何向量 矢量 与向量对应的量叫做数量 在物理学中称为标量 数量只有大小 没
  • docker容器启动的问题 - docker容器和虚拟机的比较 - docker的底层隔离机制

    目录 一 docker容器启动的问题 二 什么是docker仓库 三 虚拟机和docker容器的区别 docker的优势 docker的缺点 对比 四 docker的底层隔离机制 参考文献 LXC linux容器简介 在操作系统层次上为进程
  • java代码比较数据_比对两个数据库的差异:Java篇

    人类之所以进步 在于会使用工具 我们知道 有代码比对工具 有版本控制控制工具比对同一个文件不同人修改的地方 还有eclipse工具提供的Compare History 工具 我同事比较 同情 我每次发布产品版本都要手动比对本地和在线数据库的
  • 解决Navicat远程服务器2013-Lost connection to MYSQL server at 'reading for initial communication packet'

    问题所在 使用Navicat远程服务器mysql数据库时报错误 2013 Lost connection to MYSQL server at reading for initial communication packet system
  • 工具类Util中的@Value注解注入为空

    1 原因分析 在后端开发当中我们可能会使用到工具类 而一般的工具类中的方法都是静态方法 而 Value注解只能给普通变量注入值 不能直接给静态变量赋值 2 延伸 静态变量 即类变量 是一个类的属性 而不是对象的属性 spring依赖注入是基
  • 【Java基础知识 4】秒懂数组拷贝,感知新境界

    目录 一 前言 二 为什么数组的起始索引是0而不是1 三 起别名 四 System arraycopy与Arrays copyOf 浅拷贝
  • socat工具

    socat socat 是一个功能强大的网络工具 它允许在两个连接的数据流之间建立双向通信 该工具可以用于创建虚拟串口 转发网络流量 调试和测试网络应用程序等 以下是 socat 的一些主要特点和用途 连接不同类型的套接字 socat 可以
  • Asymmetric Gained Deep Image Compression With Continuous Rate Adaptation文献复现

    前言 相关论文阅读自行解决 这里主要是记录代码的学习与实验的复现 github地址 此代码非官方部署代码 而是私人实现的 本博客仅做学习记录 1 代码学习 1 1 主要框架部分 这里的主编解码器与高斯建模的方式 采用的是同joint上下联合
  • 推荐系统与深度学习-学习笔记六

    仅供学习 第六章 基于深度学习的推荐模型 6 1 基于DNN的推荐算法 wide deep 6 2 基于DeepFM的推荐算法 6 3 基于矩阵分解和图像特征的推荐算法 6 4 基于循环网络的推荐算法 6 5 基于生成对抗网络的推荐算法 第
  • 发布镜像【DockerHub或阿里云】

    发布镜像到DockerHub 登录DockerHub root us4ci6jaxom1jjz2 docker login u windrose0318 Password WARNING Your password will be stor
  • vs2013中静态库lib文件的生成与使用

    一 静态库lib文件的生成 1 文件 新建项目 Visual C win32项目 输入项目名称 例如 CMath 2 项目右键 添加 新建项 CMath h class CMath public CMath CMath void setX
  • 【iOS】UserDefaults使用的一些“坑”

    UserDefaults使用的一些 坑 项目场景 问题1 初始化程序组对应UserDefaults失败 原因分析 问题2 没有记录数据的时候 读取值为0 or false 导致配置使用时错误 原因分析 问题3 extension进程中监听需
  • google lab 深度学习_Google 深度学习笔记 - 深度神经网络实践

    优化 Regularization 在前面实现的RELU连接的两层神经网络中 加Regularization进行约束 采用加l2 norm的方法 进行负反馈 代码实现上 只需要对tf sgd relu nn中train loss做修改即可