e-009 matlab,matlab使用贝叶斯优化的深度学习

2023-10-27

此示例说明如何将贝叶斯优化应用于深度学习,以及如何为卷积神经网络找到最佳网络超参数和训练选项。

要训练深度神经网络,必须指定神经网络架构以及训练算法的选项。选择和调整这些超参数可能很困难并且需要时间。贝叶斯优化是一种非常适合用于优化分类和回归模型的超参数的算法。

准备数据

下载CIFAR-10数据集[1]。该数据集包含60,000张图像,每个图像的大小为32 x 32和三个颜色通道(RGB)。整个数据集的大小为175 MB。

加载CIFAR-10数据集作为训练图像和标签,并测试图像和标签。

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);idx = randperm(numel(YTest),5000);XValidation = XTest(:,:,:,idx);XTest(:,:,:,idx) = [];YValidation = YTest(idx);YTest(idx) = [];

您可以使用以下代码显示训练图像的样本。

figure; idx = randperm(numel(YTrain),20); for i = 1:numel(idx) subplot(4,5,i); imshow(XTrain(:,:,:,idx(i))); end

选择要优化的变量

选择要使用贝叶斯优化进行优化的变量,并指定要搜索的范围。此外,指定变量是否为整数以及是否在对数空间中搜索区间。优化以下变量:

网络部分的深度。此参数控制网络的深度。该网络具有三个部分,每个部分具有SectionDepth相同的卷积层。因此,卷积层的总数为3*SectionDepth。脚本后面的目标函数将每一层中的卷积过滤器数量与成正比1/sqrt(SectionDepth)。结果,对于不同的截面深度,每次迭代的参数数量和所需的计算量大致相同。

最佳学习率取决于您的数据以及您正在训练的网络。

随机梯度下降动量。

L2正则化强度。

optimVars = [optimizableVariable('SectionDepth',[1 3],'Type','integer')optimizableVariable('InitialLearnRate',[1e-2 1],'Transform','log')optimizableVariable('Momentum',[0.8 0.98])optimizableVariable('L2Regularization',[1e-10 1e-2],'Transform','log')];

执行贝叶斯优化

使用训练和验证数据作为输入,为贝叶斯优化器创建目标函数。目标函数训练卷积神经网络,并在验证集上返回分类误差。

ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation);

通过最小化验证集上的分类误差来执行贝叶斯优化。 为了充分利用贝叶斯优化的功能,您应该至少执行30个目标函数评估。

每个网络完成训练后,bayesopt将结果打印到命令窗口。bayesopt然后该函数返回中的文件名BayesObject.UserDataTrace。目标函数将训练有素的网络保存到磁盘,并将文件名返回给bayesopt。

|===================================================================================================================================|| Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | SectionDepth | InitialLearn-| Momentum | L2Regulariza-|| | result | | runtime | (observed) | (estim.) | | Rate | | tion ||===================================================================================================================================|| 1 | Best | 0.19 | 2201 | 0.19 | 0.19 | 3 | 0.012114 | 0.8354 | 0.0010624 |

| 2 | Accept | 0.3224 | 1734.1 | 0.19 | 0.19636 | 1 | 0.066481 | 0.88231 | 0.0026626 |

| 3 | Accept | 0.2076 | 1688.7 | 0.19 | 0.19374 | 2 | 0.022346 | 0.91149 | 8.242e-10 |

| 4 | Accept | 0.1908 | 2167.2 | 0.19 | 0.1904 | 3 | 0.97586 | 0.83613 | 4.5143e-08 |

| 5 | Accept | 0.1972 | 2157.4 | 0.19 | 0.19274 | 3 | 0.21193 | 0.97995 | 1.4691e-05 |

| 6 | Accept | 0.2594 | 2152.8 | 0.19 | 0.19 | 3 | 0.98723 | 0.97931 | 2.4847e-10 |

| 7 | Best | 0.1882 | 2257.5 | 0.1882 | 0.18819 | 3 | 0.1722 | 0.8019 | 4.2149e-06 |

| 8 | Accept | 0.8116 | 1989.7 | 0.1882 | 0.18818 | 3 | 0.42085 | 0.95355 | 0.0092026 |

| 9 | Accept | 0.1986 | 1836 | 0.1882 | 0.18821 | 2 | 0.030291 | 0.94711 | 2.5062e-05 |

| 10 | Accept | 0.2146 | 1909.4 | 0.1882 | 0.18816 | 2 | 0.013379 | 0.8785 | 7.6354e-09 |

| 11 | Accept | 0.2194 | 1562 | 0.1882 | 0.18815 | 1 | 0.14682 | 0.86272 | 8.6242e-09 |

| 12 | Accept | 0.2246 | 1591.2 | 0.1882 | 0.18813 | 1 | 0.70438 | 0.82809 | 1.0102e-06 |

| 13 | Accept | 0.2648 | 1621.8 | 0.1882 | 0.18824 | 1 | 0.010109 | 0.89989 | 1.0481e-10 |

| 14 | Accept | 0.2222 | 1562 | 0.1882 | 0.18812 | 1 | 0.11058 | 0.97432 | 2.4101e-07 |

| 15 | Accept | 0.2364 | 1625.7 | 0.1882 | 0.18813 | 1 | 0.079381 | 0.8292 | 2.6722e-05 |

| 16 | Accept | 0.26 | 1706.2 | 0.1882 | 0.18815 | 1 | 0.010041 | 0.96229 | 1.1066e-05 |

| 17 | Accept | 0.1986 | 2188.3 | 0.1882 | 0.18635 | 3 | 0.35949 | 0.97824 | 3.153e-07 |

| 18 | Accept | 0.1938 | 2169.6 | 0.1882 | 0.18817 | 3 | 0.024365 | 0.88464 | 0.00024507 |

| 19 | Accept | 0.3588 | 1713.7 | 0.1882 | 0.18216 | 1 | 0.010177 | 0.89427 | 0.0090342 |

| 20 | Accept | 0.2224 | 1721.4 | 0.1882 | 0.18193 | 1 | 0.09804 | 0.97947 | 1.0727e-10 |

|===================================================================================================================================|| Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | SectionDepth | InitialLearn-| Momentum | L2Regulariza-|| | result | | runtime | (observed) | (estim.) | | Rate | | tion ||===================================================================================================================================|| 21 | Accept | 0.1904 | 2184.7 | 0.1882 | 0.18498 | 3 | 0.017697 | 0.95057 | 0.00022247 |

| 22 | Accept | 0.1928 | 2184.4 | 0.1882 | 0.18527 | 3 | 0.06813 | 0.9027 | 1.3521e-09 |

| 23 | Accept | 0.1934 | 2183.6 | 0.1882 | 0.1882 | 3 | 0.018269 | 0.90432 | 0.0003573 |

| 24 | Accept | 0.303 | 1707.9 | 0.1882 | 0.18809 | 1 | 0.010157 | 0.88226 | 0.00088737 |

| 25 | Accept | 0.194 | 2189.1 | 0.1882 | 0.18808 | 3 | 0.019354 | 0.94156 | 9.6197e-07 |

| 26 | Accept | 0.2192 | 1752.2 | 0.1882 | 0.18809 | 1 | 0.99324 | 0.91165 | 1.1521e-08 |

| 27 | Accept | 0.1918 | 2185 | 0.1882 | 0.18813 | 3 | 0.05292 | 0.8689 | 1.2449e-05 |

__________________________________________________________

Optimization completed.MaxTime of 50400 seconds reached.Total function evaluations: 27Total elapsed time: 51962.3666 seconds.Total objective function evaluation time: 51942.8833Best observed feasible point:SectionDepth InitialLearnRate Momentum L2Regularization____________ ________________ ________ ________________3 0.1722 0.8019 4.2149e-06Observed objective function value = 0.1882Estimated objective function value = 0.18813Function evaluation time = 2257.4627Best estimated feasible point (according to models):SectionDepth InitialLearnRate Momentum L2Regularization____________ ________________ ________ ________________3 0.1722 0.8019 4.2149e-06Estimated objective function value = 0.18813Estimated function evaluation time = 2166.2402

评估最终网络

加载优化中发现的最佳网络及其验证准确性。

valError = 0.1882

预测测试集的标签并计算测试误差。将测试集中每个图像的分类视为具有一定成功概率的独立事件,这意味着错误分类的图像数量遵循二项式分布。使用它来计算标准误差(testErrorSE)和testError95CI广义误差率的大约95%置信区间()。这种方法通常称为_Wald方法_。

testError = 0.1864

testError95CI = 1×20.1756 0.1972

绘制混淆矩阵以获取测试数据。通过使用列和行摘要显示每个类的精度和召回率。

1460000022068382

您可以使用以下代码显示一些测试图像及其预测的类以及这些类的概率。

优化目标函数

定义用于优化的目标函数。

定义卷积神经网络架构。

在卷积层上添加填充,以便空间输出大小始终与输入大小相同。

每次使用最大池化层对空间维度进行2倍的下采样时,将过滤器的数量增加2倍。这样做可确保每个卷积层所需的计算量大致相同。

选择与成正比的滤波器数量,以1/sqrt(SectionDepth)使不同深度的网络具有大致相同数量的参数,并且每次迭代所需的计算量大致相同。要增加网络参数的数量和整体网络灵活性,请增加numF。要训练更深的网络,请更改SectionDepth变量的范围。

使用convBlock(filterSize,numFilters,numConvLayers)创建的块numConvLayers卷积层,每个具有指定filterSize和numFilters过滤器,并且每个随后分批正常化层和RELU层。该convBlock函数在本示例的末尾定义。

指定验证数据,然后选择一个'ValidationFrequency'值,以便trainNetwork每个时期对网络进行一次验证。训练固定的时期数,并在最后一个时期将学习率降低10倍。这减少了参数更新的噪音,并使网络参数的沉降更接近损耗函数的最小值。

使用数据增强可沿垂直轴随机翻转训练图像,并将它们随机水平和垂直转换为四个像素。

训练网络并在训练过程中绘制训练进度。

1460000022068384

1460000022068383

在验证集上评估经过训练的网络,计算预测的图像标签,并在验证数据上计算错误率。

创建一个包含验证错误的文件名,然后将网络,验证错误和培训选项保存到磁盘。目标函数fileName作为输出参数bayesopt返回,并返回中的所有文件名BayesObject.UserDataTrace。

该convBlock函数创建一个numConvLayers卷积层块,每个卷积层都有一个指定的filterSize和numFilters过滤器,每个卷积层后面都有一个批处理归一化层和一个ReLU层。

参考文献

[1]克里热夫斯基,亚历克斯。“从微小的图像中学习多层功能。” (2009)。https://www.cs.toronto.edu/~k...

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

e-009 matlab,matlab使用贝叶斯优化的深度学习 的相关文章

  • 【动态规划】合唱队形

    题目描述 n位同学站成一排 音乐老师要请其中的 n K 位同学出列 使得剩下的K位同学排成合唱队形 合唱队形是指这样的一种队形 设K位同学从左到右依次编号为1 2 K 他们的身高分别为T1 T2 TK 则他们的身高满足T1 lt Ti l
  • 关于代码家(干货集中营)共享android端知识点综合整理

    关于代码家 干货集中营 共享android端知识点综合整理 标签 开源项目自定义控件教程特效工具 2016 03 08 13 23 8520人阅读 评论 2 收藏 举报 分类 移动开发 28 版权声明 本文为博主原创文章 未经博主允许不得转
  • 探索MySQL错误: 1241 - Operand should contain 1 column(s)问题解决方案

    AI绘画关于SD MJ GPT SDXL百科全书 面试题分享点我直达 2023Python面试题 2023最新面试合集链接 2023大厂面试题PDF 面试题PDF版本 java python面试题 项目实战 AI文本 OCR识别最佳实践 A
  • Qt中moc问题(qt moc 处理 cpp)

    我用的是QT Designer 一般只有用到信号signals和槽slots时才会用到MOC 因为采用信号signals和槽slots是QT的特性 而C 没有 所以采用了MOC 元对象编译器 把信号signals和槽slots部分编译成C
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一
  • libevent源码学习(5):TAILQ_QUEUE解析

    目录 前言 结点定义 链表初始化 链表查询及遍历 链表查询 链表遍历 插入结点 头插法 尾插法 前插法 后插法 删除结点 替换结点 总结 前言 在libevent中使用到了TAILQ数据结构 看了一下其他资料 发现TAILQ这一数据结构不仅
  • TVM 0.9 在 ubuntu(任意版本)上的安装(简单且保姆级!)

    近一年来尝试过TVM在ubuntu16 04 ubuntu18 04 ubuntu20 04 以及windows上的安装 也看了官方教程和网上各种博客 踩坑无数 现在总结在Ubuntu上踩坑几率最小的安装流程如下 建议学习TVM一开始就在u
  • fisco-bcos使用caliper进行压力测试

    使用caliper对fisco bcos进行压力测试 通过Caliper进行压力测试程序 注意 官网给出的测试案例会出现错误 我会给出相应的解决方案 本文以centos系统为例进行测试 1 环境要求 第一步 配置基本环境 部署Caliper
  • 勇敢的人

    昨天刚写完做个勇敢的人 发现一篇博客写的非常好 于是果断给它转载一下 这位老师的发言 在某个瞬间打动了现在的我 请观看 https blog csdn net zkl99999 article details 46683535 关键字 毛尖
  • iOS检测网络连接状态

    请从Apple网站下载示例 点此下载 然后将Reachability h 和 Reachability m 加到自己的项目中 并引用 SystemConfiguration framework 就可以使用了 Reachability 中定义
  • Java 23种设计模式的分类和使用场景

    听说过GoF吧 GoF是设计模式的经典名著Design Patterns Elements of Reusable Object Oriented Software 中译本名为 设计模式 可复用面向对象软件的基础 的四位作者 他们分为是 E
  • Qt创建的子线程不断循环,主线程界面一直处于无响应状态

    说明 今天用子线程处理数据 但只创建了子线程 还没有来得及让子线程处理大量的数据 在子线程只作了简单处理 发现主线程界面一直不能响应 在主线程让子线程参数isStop true 也跳不出循环 while isStop emit mySign
  • KCF追踪器在opencv和RM中的应用

    1 理论部分 参考文档 KCF目标跟踪方法分析与总结 概念 1 判别式模型和生成式模型 判别式模型 根据训练数据得到分类函数和分界面 比如说根据SVM模型得到一个分界面 然后直接计算条件概率 P y x 我们将最大的 P y x 生成式模型
  • 模拟登陆 Selenium

    模拟登陆 使用爬虫实现登录操作 为何需要做模拟登陆 有些平台只有登录之后才可以访问其内部其他的子页面 如何实现模拟登陆 模拟点击登录按钮发起的请求即可 阻力 验证码的识别 验证码识别 使用线上的打码平台进行各种各样验证码的识别 不包含滑动验
  • eclipse创建动态web项目

    1 打开eclipse 2 依次选择File new Dynamic Web Project 点击new如果没有Dynamic Web Project 选择Other 3 在wizards下输入web 在下面的选框中选择 Dynamic W
  • 前端(十七)——gitee上开源一个移动端礼盒商城项目(前端+后台)

    博主 小猫娃来啦 文章核心 gitee上开源一个移动端礼盒商城项目 文章目录 前言 开源地址 项目运行命令 项目基本展示 前端效果细节展示视频 前端代码细节展示视频 后台效果展示 后台代码展示 经典优势 思维导图 实现思路 前言 项目样式老
  • 养娃探索记录

    0 3岁 搞好身体 3 6岁 培养好生活习惯 6 9岁 培养好学习习惯 9 12岁 培养自学能力 12 15岁 了解三百六十行 不同行业不同职业都是做什么的 15 18岁 确定未来发展方向和人生目标 自我决定理论对我们的教育有着重要的指导作
  • Anaconda安装Pytorch,以及在Pycharm中的配置

    Anaconda安装Pytorch 以及在Pycharm中的配置 有关Anaconda安装请移步安装Anaconda Python3 9 Tensorflow pytorch cpu 打开Anaconda Prompt 创建环境 这个步骤同
  • 八大排序(二)-----堆排序

    基本思想 1 将带排序的序列构造成一个大顶堆 根据大顶堆的性质 当前堆的根节点 堆顶 就是序列中最大的元素 2 将堆顶元素和最后一个元素交换 然后将剩下的节点重新构造成一个大顶堆 3 重复步骤2 如此反复 从第一次构建大顶堆开始 每一次构建

随机推荐

  • JavaScript 浅层克隆和深度克隆

    文章目录 JS 浅层克隆和深度克隆 1 相关知识点 2 浅层克隆 2 1 浅克隆函数 2 2 运用实例 3 深度克隆 3 1 深克隆步骤分析 3 2 深克隆函数 3 3 运用实例 3 4 hasOwnProperty JS 浅层克隆和深度克
  • 【自学笔记】后端01_Web服务器01_Tomcat

    一 web服务器 web服务器是一个应用程序 对HTTP协议的操作进行封装 使得程序员不必直接对协议进行操作 简化开发 主要功能是 提供网上信息浏览服务 说人话 服务器的作用就是封装了HTTP协议操作简化网站部署 让人可以在浏览器访问部署的
  • 【Leetcode笔记】重复的子字符串

    Leetcode原题链接 重复的子字符串 零 前情提要 自己写了个暴力法 LeetCode的一个巨长的字符串测试 然后嫌弃我时间太长了没通过 先放这里了 自己测试着没什么问题 代码 class Solution def repeatedSu
  • python基础:基本数据类型:整数,浮点数,字符串

    基本数据类型介绍 整数 浮点数 字符串 在介绍之前先来了解一个python在终端窗口运行的解释器 在你安装好python之后 运行终端命令行 输入python后回车 你会看到你的python版本以及提示信息和 gt gt gt 等待输入 你
  • 安卓刷机之pixel

    刷机记录 提示 本例子是pixel sailfish 刷rom 提示 刷rom及刷容比较简单一点 1 首先去谷歌的官网去下载手机对应的机型 2 下载刷机工具platform tools zip github地址好像404了 不过下载的地方很
  • windows 64位 apachect2.4+php5.5 无法加载php_curl模块

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 环境 windows2008 64位 单独装了apachect2 4 php5 5版本 今天服务迁移后 一直报call to undefined function curl
  • unity的触摸类touch使用

    这篇博文将简单的记录 如何用unity处理在移动设备上的触控操作 iOS和Android设备能够支持多点触控 在unity中你可以通过Input touches属性集合访问在最近一帧中触摸在屏幕上的每一根手指的状态数据 简单的触控响应实现起
  • Android studio单击按钮时,提示文字信息,并实现图片的显示

    MainActivity2 java package com example myapplication one import android media Image import android os Bundle import andr
  • 2021年计算机专业研究生分数线,2021年计算机科学与技术在职研究生分数线是多少?...

    计算机科学与技术为时下热门专业 国内对此专业人才的需求量比较大 薪资待遇也是比较不错的 国内有多所院校开设了此专业在职研课程班 但有学员对其分数线不了解 那么2021年计算机科学与技术在职研究生分数线是多少 据了解得知 就读计算机科学与技术
  • springboot非配置实现动态多数据库查询

    1需求 数据库配置信息不能在项目代码中配置或写死 系统能接入用户配置的数据库并保存和读取 每个用户可添加多个数据库 不同数据库类型 不同host 多个用户可添加相同的一个数据库 同一个数据库只创建一个连接池 数据库类型差异对业务逻辑透明 2
  • Python列表转换为字典

    如何将两个列表组合生成字典 1 源码 how to convert list into dict list1 1 2 3 list1 print list1 list2 one two three list2 print list2 obj
  • http2-浏览器支持的情况

    毕竟http2是新事物 尽管它的协议文本已经正式发布 但是相应的服务器和客户端代码依然在演进中 我本人也特别关注浏览器部分 因为研究了颇有一段时间的node http2 希望它可以和浏览器互操作 而不是自己的client 自己的server
  • Redis系列 - 单线程的Redis为什么那么快?

    Redis系列 单线程的Redis为什么那么快 Redis为什么使用单线程 在说这个问题之前我们先来了解下引入多线程常见的开销 1 上下文切换 即使是单核CPU也支持多线程执行代码 CPU通过给每个线程分配CPU时间片来实现这个机制 时间片
  • 三种开窗函数详细用法,图文详解

    开窗函数的详细用法 一 开窗函数的语法 二 从聚合开窗函数sum score over partition by name 讲起 三 开窗函数之first value last value lead lag 四 排名开窗函数ROW NUMB
  • 什么是百分比堆积条形图?

    条形图实际上范围很广 它是以横置图形展示数据的一种图表类型 百分比堆积条形图即以堆积条形图的形式来显示多个数据序列 但是每个堆积元素的累积比例始终总计为 100 它主要用于显示一段时间内的多项数据占比情况 百分比堆叠条形图将多个数据集的条形
  • Web网站的性能测试工具

    随着Web 2 0技术的迅速发展 许多公司都开发了一些基于Web的网站服务 通常在设计开发Web应用系统的时候很难模拟出大量用户同时访问系统的实际情况 因此 当Web网站遇到访问高峰时 容易发生服务器响应速度变慢甚至服务中断 为了避免这种情
  • 三子棋大致构建思路

    设计思路 1 菜单 输入选择 1 PLAY 开始游戏 0 EXIT 退出游戏 其他 重新进入菜单选择 2 PLAY 开始游戏 大致结构 1 创建并打印棋盘 2 玩家下棋 3 电脑下棋 4 判断局势 5 得出结果 6 返回1 菜单 3 创建并
  • Unity 入门打字机效果

    Unity 入门打字机效果 使用协程加延迟 public class UIDazhi MonoBehaviour public Text t private string currentstr public string str 欢迎来到U
  • Nginx HTTP 健康检查

    通过发送定期健康检查 包括 NGINX Plus 中可自定义的主动健康检查 来监控上游组中 HTTP 服务器的健康状况 介绍 NGINX 和 NGINX Plus 可以持续测试您的上游服务器 避免出现故障的服务器 并将恢复的服务器优雅地添加
  • e-009 matlab,matlab使用贝叶斯优化的深度学习

    此示例说明如何将贝叶斯优化应用于深度学习 以及如何为卷积神经网络找到最佳网络超参数和训练选项 要训练深度神经网络 必须指定神经网络架构以及训练算法的选项 选择和调整这些超参数可能很困难并且需要时间 贝叶斯优化是一种非常适合用于优化分类和回归