hello paddle

2023-11-05

import paddle         #导入飞桨paddle和numpy
import numpy
print("paddle " + paddle.__version__)
print("numpy " + numpy.__version__)
paddle 2.2.1
numpy 1.21.2
 #用paddle.to_tensor把示例数据转换为paddle的Tensor数据。
    
x_data = paddle.to_tensor([[1.], [3.0], [5.0], [9.0], [10.0], [20.0]])   
y_data = paddle.to_tensor([[12.], [16.0], [20.0], [28.0], [30.0], [50.0]])

一、用飞桨定义模型的计算

在这里的示例中,根据经验,已经事先知道了distance_travelledtotal_fee之间是线性的关系,而在更实际的问题当中,x和y的关系通常是非线性的,因此也就需要使用更多类型,也更复杂的神经网络。(比如,BMI指数跟你的身高就不是线性关系,一张图片里的某个像素值跟这个图片是猫还是狗也不是线性关系。)

飞桨的线性变换层:paddle.nn.Linear来实现这个计算过程,这个公式里的变量x, y, w, b, y_predict,对应着飞桨里面的Tensor概念

linear = paddle.nn.Linear(in_features=1, out_features=1)

二、准备好运行飞桨

机器(计算机)在一开始的时候会随便猜w和b,这时候的w是一个随机值,b是0.0,这是飞桨的初始化策略,也是这个领域常用的初始化策略。(如果你愿意,也可以采用其他的初始化的方式,今后你也会看到,选择不同的初始化策略也是对于做好深度学习任务来说很重要的一点。

w_before_opt = linear.weight.numpy().item()
b_before_opt = linear.bias.numpy().item()

print("w before optimize: {}".format(w_before_opt))
print("b before optimize: {}".format(b_before_opt))
w before optimize: -1.2472419738769531
b before optimize: 0.0

三、告诉飞桨怎么样学习

前面定义好了神经网络(尽管是一个最简单的神经网络),还需要告诉飞桨,怎么样去学习,从而能得到参数w和b

机器学习/深度学习当中,机器(计算机)在最开始的时候,得到参数w和b的方式是随便猜一下,用这种随便猜测得到的参数值,去进行计算(预测)的时候,得到的y_predict,跟实际的y值一定是有差距的。接下来,机器会根据这个差距来调整w和b,随着这样的逐步的调整,w和b会越来越正确,y_predict跟y之间的差距也会越来越小,从而最终能得到好用的w和b。这个过程就是机器学习的过程。

用更加技术的语言来说,衡量差距的函数(一个公式)就是损失函数,用来调整参数的方法就是优化算法

用最简单的均方误差(mean square error)作为损失函数(paddle.nn.MSELoss);和最常见的优化算法SGD(stocastic gradient descent)作为优化算法(传给paddle.optimizer.SGD的参数learning_rate,你可以理解为控制每次调整的步子大小的参数)。

mse_loss = paddle.nn.MSELoss()      #均方误差mse作为损失函数,
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters = linear.parameters())   #调参算法SGD作为优化算法

四、运行优化算法

接下来,让飞桨运行一下这个优化算法,这会是一个前面介绍过的逐步调整参数的过程,你应该可以看到loss值(衡量y和y_predict的差距的loss)在不断的降低。

total_epoch = 5000
for i in range(total_epoch):
    y_predict = linear(x_data)
    loss = mse_loss(y_predict, y_data)
    loss.backward()
    sgd_optimizer.step()
    sgd_optimizer.clear_grad()
    
    if i%1000 == 0:
        print("epoch {} loss {}".format(i, loss.numpy()))
        
print("finished training, loss {}".format(loss.numpy()))

epoch 0 loss [1702.1357]
epoch 1000 loss [7.9019628]
epoch 2000 loss [1.7668242]
epoch 3000 loss [0.39504835]
epoch 4000 loss [0.08833063]
finished training, loss [0.0197801]

五、机器学习出来的参数

经过了这样的对参数w和b的调整(学习),再通过下面的程序,来看看现在的参数变成了多少。你应该会发现w变成了很接近2.0的一个值,b变成了接近10.0的一个值。虽然并不是正好的2和10,但却是从数据当中学习出来的还不错的模型的参数,可以在未来的时候,用从这批数据当中学习到的参数来预估了。(如果你愿意,也可以通过让机器多学习一段时间,从而得到更加接近2.0和10.0的参数值。)

w_after_opt = linear.weight.numpy().item()
b_after_opt = linear.bias.numpy().item()

print("w after optimize: {}".format(w_after_opt))
print("b after optimize: {}".format(b_after_opt))

w after optimize: 2.017909288406372
b after optimize: 9.771002769470215
print("hello paddle")
hello paddle
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

hello paddle 的相关文章

  • 微信小程序----相对路径图片不显示

    WXRUI体验二维码 如果文章对你有帮助的话 请打开微信扫一下二维码 点击一下广告 支持一下作者 谢谢 出现场景 在本地调试的时候本地图片显示 但是手机浏览的时候本地图片不显示 出现图片不显示的原因 小程序只支持网络路径和base64的图片
  • vue-别名路径联想提示的配置

    在根路径下 新建 jsconfig json 文件 即可 在输入 自动联想到src目录 代码如下 别名路径联想提示 输入 自动联想 compilerOptions baseUrl paths src 注 这里只是提示 其 实际的路径转换在v
  • .net 只需三步让Swagger显示注释

    net 只需三步让Swagger显示注释 先看效果 第一步 导包 我使用的是 net5 0的框架 所以导入5 x的包 如果你使用 net6 0的框架 注意改版本号 Install Package Swashbuckle AspNetCore
  • 最长连续不重复子序列python题解-双指针

    看了很久 发现这道题目应该是有些前提的 以至于我开始一直往错的方向去想 将题目大大复杂化了 在csdn上看了许多博客文章 终于在看到12235这道例题的答案是235才恍然大悟 反应过来 这题很简单一直被连续 类似12345 再不重复所误导

随机推荐

  • Ableton Live 10 Suite功能特色

    Ableton Live 10 Suite软件介绍 Ableton Live 10 Suite 是来自国外Ableton公司的一款旗舰级音乐创作软件 本站为大家分享的Ableton Live 10 Suite 是目前的最新版本 拥有四个全新
  • 非连续性概率分布的概率密度(有间断点时如何求数学期望)

    一 笔者做张宇试卷的时候 第三套试卷22题遇到一道这样的题 这里求出来的关于Y的分布函数在Y 1处并不连续 右连续 故而不能直接求导然后再积分 答案给出一种思路 就是利用关于X的概率密度是连续的 间接利用X的概率密度来求Y的数学期望 这是一
  • 提问 未来计算机的发展趋势是什么,计算机今后的发展趋势是什么?

    计算机今后的发展趋势是 1 巨型化 为了适应尖端科学技术的需要 发展高速度 大存储容量和功能强大的超级计算机 2 网络化 3 人工智能化 4 多媒体化 5 微型化 家用计算机的体积不断的缩小 逐步微型化 为人们提供便捷的服务 计算机 是现代
  • 蓝屏代码大全(留着自己看)

    1 常见蓝屏代码 蓝屏代码 蓝屏原因 处理方法 A5 主板 主板BIOS问题 主板放电 编程器尝试刷BIOS 不行就寄回换主板 0A 内存或硬盘 memtest测试一下内存是否报错 一般都是内存问题 更换内存 EA 显卡驱动或者显卡 完全卸
  • C++中模板函数以及类模板的示例(template)

    模板是泛型编程的基础 泛型编程即以一种独立于任何特定类型的方式编写代码 模板是创建泛型类或函数的蓝图或公式 库容器 比如迭代器和算法 都是泛型编程的例子 它们都使用了模板的概念 下面是具体的使用 include
  • 【记录】Django shell

    参考书目依旧与前几章一样 记录使用Django Shell 可以使用交互式终端会话以编程方式查看我们创建的数据 激活虚拟环境 并打开服务器 python manage py shell 一个方便快捷的可以查看bug和debug的地方
  • Unity-3DRPG游戏 学习笔记(1)--使用URP渲染管线

    教程地址 Unity2020 3DRPG游戏开发教程 Core核心功能01 Create Project 创建项目导入素材 Unity中文课堂 哔哩哔哩 bilibili 创建URP通用渲染管线 2021版本 1 打开 Windows Pa
  • 报错:java.lang.NullPointerException: Attempt to invoke virtual method ‘void android.widget.ImageView

    小编在调用View的时候出错 错误代码 java lang NullPointerException Attempt to invoke virtual method void android widget ImageView setIma
  • get和post请求方式总结

    前端发送请求最常 的是get请求还有post请求 get请求只能传params参数 params参数都是拼在请求地址上的 post可以传body和params两种形式 注意 params形式传递数据不管是get还是post请求 参数最后都是
  • java---多线程编程

    Java 多线程编程 Java 给多线程编程提供了内置的支持 一条线程指的是进程中一个单一顺序的控制流 一个进程中可以并发多个线程 每条线程并行执行不同的任务 与之对比的是 多线程是多任务的一种特别的形式 但多线程使用了更小的资源开销 这里
  • 异步处理及其他

    异步的方式 Spring事件发布 开启新的线程 其他资料 非异步 事务提交后 做其他的事情 事务回滚 同时记录异常信息 Spring事件发布 https blog csdn net root zhb article details 1256
  • 算法基础/递归回溯

    当要求解全排列或者全部的组合时 常采用递归 回溯的方式 标准的递归 回溯 DFS形式DFS nums index 表示当前位置是index 1 对于每个位置的数 要么被选中 temp push back nums index DFS num
  • Ubuntu 扩展内存或断电之后卡在 /dev/sda1 clean 和 /dev/sda1 recovering journal

    当ubuntu虚拟机硬盘空间不够用的时候 往往会出现新增扩展硬盘空间之后 出现开机卡死的现象 通过查阅相关资料 排坑如下 一 原VM硬盘空间已满 当原VM硬盘空间已满的情况下 千万不要重启或者关机操作 极容易引起卡死的状况发生 解决方案为
  • 计及电池储能寿命损耗的微电网经济调度(matlab代码)

    目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序参考文献 考虑寿命损耗的微网电池储能容量优化配置 模型 以购售电成本 燃料成本和储能寿命损耗成本三者之和为目标函数 创新考虑储能寿命损耗约束 放电深度约束和储
  • 分立式BUCK电路原理与制作持续更新

    目录 一 分立式BUCK电路总体原理图 二 BUCK电路与LDO的区别 三 BUCK电路为什么要加电感 四 BUCK电路要加续流二极管 五 BUCK电路导通与断开的回路 六 电源公式的中的几个表示方式 1 输入功率用Pin表示 2 输出功率
  • springboot+vue商城项目实战-springboot后端搭建

    搭建Spring Boot Vue商城后端项目 要搭建Spring Boot Vue商城后端项目 你需要掌握一系列的技术背景 下面我将为你介绍 开发这种项目所需的主要技术要求 Spring Boot框架 Spring Boot是一个开发Ja
  • redis常见操作命令-list

    1 将1个或者多个的value压入key的表头 LPUSH key value value 127 0 0 1 6379 gt LPUSH list abc integer 1 127 0 0 1 6379 gt LGET list err
  • Nginx设置成网站为https

    首先 获取SSL证书 我的证书是阿里云获取的 免费版dv证书 一年有效期 购买后 自动跳转到证书控制台 点击申请 然后选择如下设置 打码内容填入自己的个人信息 等待审核通过 我大概等了半小时 然后下载证书 解压 获得以下两个文件 在服务器的
  • 多益网络人工智能面试和入职问题

    以下几点是我在技术面试中技术hr问到的一些问题 1 简单自我介绍 2 网测的智商检测问题怎么看 3 分别介绍两个项目 4 基于第一个项目 有没有做过法律相关的知识图谱构建来优化模型结果 5 基于第二个项目 在做方案研究的时候就只是模型的融合
  • hello paddle

    文章目录 一 用飞桨定义模型的计算 二 准备好运行飞桨 三 告诉飞桨怎么样学习 四 运行优化算法 五 机器学习出来的参数 import paddle 导入飞桨paddle和numpy import numpy print paddle pa