李沐老师 《动手学深度学习》笔记

2023-11-08

         

        08 线性回归 + 基础优化算法

文章目录

  • 前言
  • 一、true_w和w以及true_b和b之间的关系
  • 二、代码实现


前言

        这个是我在B站上看李沐老师《动手学深度学习》之后,针对自己不懂和想记录的部分的一个记录。由于本人是刚接触深度学习的小白,所以可能会有许许多多的错误,如果您碰巧看到了这篇文章,有发现了错误,请批评指正!

一、true_w和w以及true_b和b之间的关系

        true_w和true_b是自己设定的,是用来构造数据集的。

        而w和b是需要通过学习来“接近”true_w和true_b的。

        f = loss(net(X,w,b),y),X和y分别是由 true_w和true_b构造出来的数据和对应的标签,这段代码就是将数据X和参数w,b放入网络中进行计算,将计算的结果和标签比对并计算误差。刚开始由于w和b都是随便设定的,所以误差可能较大。

        f.sum().backward(),利用反向传播函数来计算误差f对于每个分量X,w,b,y的梯度;因为后面要用到w和b的梯度。(但没有用到X和y的梯度)

        sgd([w, b], batch_size, lr),调用sgd函数;利用sgd函数中的 param -= lr * param.grad / batch_size 这一段代码,使用参数梯度更新参数。

        train_l = loss(net(features, w, b), labels),在更新完了w,b之后,再将数据和参数w,b放入网络中和标签对比并计算误差

for epoch in range(0, num_epochs):
    for X, y in read_data(batch_size, features, labels):
        f = loss(net(X, w, b), y)
        f.sum().backward()
        sgd([w, b], batch_size, lr)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print("w {0} \nb {1} \nloss {2:f}".format(w, b, float(train_l.mean())))
def sgd(params, batch_size, lr):
    with torch.no_grad():  # with torch.no_grad() 则主要是用于停止autograd模块的工作,
        for param in params:
            param -= lr * param.grad / batch_size  ##  这里用param = param - lr * param.grad / batch_size会导致导数丢失, zero_()函数报错
            param.grad.zero_()  ## 导数如果丢失了,会报错‘NoneType’ object has no attribute ‘zero_’

二、代码实现

import random
import torch


## 人造数据集
def create_data(w, b, nums_example):
    X = torch.normal(0, 1, (nums_example, len(w)))
    y = torch.matmul(X, w) + b
    print("y_shape:", y.shape)
    y += torch.normal(0, 0.01, y.shape)  # 加入噪声
    return X, y.reshape(-1, 1)  # y从行向量转为列向量


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = create_data(true_w, true_b, 1000)


## 读数据集
def read_data(batch_size, features, lables):
    nums_example = len(features)
    indices = list(range(nums_example))  # 生成0-999的元组,然后将range()返回的可迭代对象转为一个列表
    random.shuffle(indices)  # 将序列的所有元素随机排序。
    for i in range(0, nums_example, batch_size):  # range(start, stop, step)
        index_tensor = torch.tensor(indices[i: min(i + batch_size, nums_example)])
        yield features[index_tensor], lables[index_tensor]  # yield就是 return 返回一个值,并且记住这个返回的位置,下次迭代就从这个位置后开始。


batch_size = 10
for X, y in read_data(batch_size, features, labels):
    print("X:", X, "\ny", y)
    break;

##初始化参数
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)


# 定义模型
def net(X, w, b):
    return torch.matmul(X, w) + b


# 定义损失函数
def loss(y_hat, y):
    # print("y_hat_shape:",y_hat.shape,"\ny_shape:",y.shape)
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2  # 这里为什么要加 y_hat_shape: torch.Size([10, 1])  y_shape: torch.Size([10])


# 定义优化算法
def sgd(params, batch_size, lr):
    with torch.no_grad():  # with torch.no_grad() 则主要是用于停止autograd模块的工作,
        for param in params:
            param -= lr * param.grad / batch_size  ##  这里用param = param - lr * param.grad / batch_size会导致导数丢失, zero_()函数报错
            param.grad.zero_()  ## 导数如果丢失了,会报错‘NoneType’ object has no attribute ‘zero_’


# 训练模型
lr = 0.03
num_epochs = 3

for epoch in range(0, num_epochs):
    for X, y in read_data(batch_size, features, labels):
        f = loss(net(X, w, b), y)
        # 因为`f`形状是(`batch_size`, 1),而不是一个标量。`f`中的所有元素被加到一起,
        # 并以此计算关于[`w`, `b`]的梯度
        f.sum().backward()
        sgd([w, b], batch_size, lr)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print("w {0} \nb {1} \nloss {2:f}".format(w, b, float(train_l.mean())))

print("w误差 ", true_w - w, "\nb误差 ", true_b - b)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

李沐老师 《动手学深度学习》笔记 的相关文章

随机推荐

  • Stream流

    Stream流 Stream 流 是一个来自数据源的元素队列并支持聚合操作 元素是特定类型的对象 形成一个队列 Java中的Stream并不会存储元素 而 是按需计算 数据源 流的来源 可以是集合 数组等 聚合操作 类似SQL语句一样的操作
  • Bes 充电盒协议总结

    1 开盖 上升沿信号开机 a 充电脚设成3 0 v 然后延迟160ms b 充电脚设成5v 然后延时100 ms c充电脚设成3 0 v 2 合盖 a 开5v 然后延时3s b 关5v 然后延时45ms c 发送复位pattern 0101
  • c++ 字符串相等比较

    介绍 在C 中比较字符串的技术 Techniques to Compare Strings in C Strings in C can be compared using either of the following techniques
  • mysql命令 show_mysql--SHOW命令大全

    SHOW AUTHORS 顾名思义 这个要展示的是各位MYSQL开发者的信息 包括姓名 住址及相关注解 e g 1 mysql gt show authors G 1 row Name Brian Krow Aker Location Se
  • LeetCode 62. Unique Paths

    题目链接 题目描述 A robot is located at the top left corner of a m x n grid marked Start in the diagram below The robot can only
  • Microsoft Store无法打开解决方案 错误代码:0x80131500

    这种情况大部分是设置了Vpn代理 提供两种解决方案 一 打开 运行 输入 inetcpl cpl 点还原高级设置 注意看看勾选了TLS 1 2没有 二 如果上述方法没有解决 那么就打开Internet选项 gt 安全选项卡 gt 点一下 将
  • pip安装opencv-python

    文章目录 前言 一 基本概念 二 操作步骤 1 删除旧版本 2 pip升级 3 opencv python安装 总结 前言 OpenCV的全称是Open Source Computer Vision Library 是一个跨平台的计算机视觉
  • 跳转至tabBar页面不触发页面的onLoad,点击底部tabar不触发onLoad

    小程序想跳转tabar页面带参数 使用了全局变量app js的全局 跳转到页面后发现不是每次都执行onLoad方法 传参失败 更换跳转的方法解决 由wx switchTab改为wx reLaunch 就可以了 点击底部导航不触发解决 js
  • Ubuntu挂载Win10下的NTFS硬盘出错的解决方案

    概述 在Ubuntu下打开Win10的NTFS硬盘总是提示出错了 而且是全部的NTFS盘都出错 其中sdb1错误显示如下 he disk contains an unclean file system 0 0 Metadata kept i
  • matplotlib函数总结

    导入matplotlib import matplotlib pyplot as plt import matplotlib Figures对象包含一个或多个Asex对象 方法 matplotlib rc figure figsize 14
  • 在Ubuntu18.04.3系统中安装谷歌拼音输入法(Google Pinyin)

    一 安装前的准备 在Ubuntu18 04下 谷歌拼音输入法是基于Fcitx输入法的 因此 我们需要首先安装Fcitx 一般来说 Ubuntu最新版中都默认安装了Fcitx 但是为了确保一下 我们可以在系统终端中运行如下命令 sudo ap
  • 如何用PHP解决高并发与大流量问题

    举个例子 高速路口 1秒钟来5部车 每秒通过5部车 高速路口运作正常 突然 这个路口1秒钟只能通过4部车 车流量仍然依旧 结果必定出现大塞车 5条车道忽然变成4条车道的感觉 同理 某一个秒内 20 500个可用连接进程都在满负荷工作中 却仍
  • StrangeIOC中Signal类使用详解

    在讲解Signal类之前 先复习一下dispatch的用法 1 View层调用自身的dispatch view 告知绑定的Mediator层也调用自身的dispatch mediator 2 Mediator层的dispatch media
  • 如何实现在的Windows上运行的Linux程序(附示例代码)

    而今天的这篇文章将会讲解如何自己实现一个简单的原生Linux程序运行器 这个运行器在用户层实现 原理和Bash On Windows不完全一样 比较接近Linux上的Wine 示例程序完整的代码在github上 地址是 https gith
  • SuperSocket教程六:配置文件启动后使用自己的请求处理

    上一教程虽然实现了配置文件启动 但是发送信息后返回的缺失AppServer的原始信息 而不是我在教程四锁自定义的信息回复 配置文件启动是实现了 接下来做什么修改可以实现自定义的请求处理呢 其实很简单 只是把原来的那些代码换了个位置罢了 这个
  • 机器学习案例3:基于逻辑回归的肿瘤预测

    案例3 基于逻辑回归的肿瘤预测 为什么写本博客 前人种树 后人乘凉 希望自己的学习笔记可以帮助到需要的人 需要的基础 懂不懂原理不重要 本系列的目标是使用python实现机器学习 必须会的东西 python基础 numpy pandas m
  • 阿里云上的gitlab不能使用ssh

    晚上突然发现ssh到gitlab的项目失败 提示 ssh exchange identification read Connection reset by peer fatal Could not read from remote repo
  • 【电路参考】缓启动电路

    一 外部供电直接上电可能导致的问题 1 在热拔插的过程中 两个连接器的机械接触 触点在瞬间会出现弹跳 电源不稳 发生震荡 这期间系统工作可能造成不稳定 2 由于电路中存在滤波或大电解电容 在上电瞬间 会产生较大的脉冲电流 有时候会看到DC接
  • react+antd修改主题色

    第一步 安装需要的插件 npm install react app rewired customize cra babel plugin import less less loader 第二步 修改package json文件 将原本 sc
  • 李沐老师 《动手学深度学习》笔记

    08 线性回归 基础优化算法 文章目录 前言 一 true w和w以及true b和b之间的关系 二 代码实现 前言 这个是我在B站上看李沐老师 动手学深度学习 之后 针对自己不懂和想记录的部分的一个记录 由于本人是刚接触深度学习的小白 所