PyTorch-02梯度下降Gradient Descent、回归案例、手写数字识别案例

2023-10-29

PyTorch-02梯度下降Gradient Descent、回归案例、手写数字识别案例

了解梯度下降

梯度下降是深度学习的精髓。整个deep learning是靠梯度下降所支撑的。可以求解一个非常难的函数,使用的方法就是梯度下降算法。

求一个函数的极小值,就可以先求其函数对应的导数,再检验这个函数的导数是否为极大值或者极小值点。梯度下降与上述的方法类似,但这里还需要一个迭代计算的过程。
例子:
函数y=x^2·sin(x)
该函数导数y’=2·x·sin(x) + x^2·cos(x)
在每获得一个倒数的时候,在x基础上减去y函数在x处的导数。即x’=x-∂x这样就可以获得一个新的x。
比如说,该函数在某一处的x为2.5,其函数y在x为2.5处时的导数为-0.9,新的x=2.5-0.9,但是这里如果之间减去-0.9,会让新的x变化过快,所以这里需要给一个缩放倍数(即学习的速率)使得新x调整的过程不会过大。
所以最终新的 x’ = x - ∂x·learningRate,这样新的x变化情况就可以变化的非常非常小。
假设x=5时,learningRate = 0.005,∂x=0,该函数y有最小值,即x’=5 - 0·0.005 = 5 ,这样x’依然为5,实际上∂x不会是0,会有一定的计算误差,x’会在5附加产生一定的抖动。
在这里插入图片描述

梯度下降思想在线性回归情况中的应用:

理想情况:精确求解Closed Form Solution

在这里插入图片描述
对于精确求解很多实际情况下是做不到的,但是如果可以达到近似求解并且在实际情况下被证明可行,这样就已经满足我们的目的了,不需要一个非常准确的精确求解Closed Form Solution。

真实情况:伴随着噪声

在这里插入图片描述
这里我们需要求解的是y = w·x + b 这个函数,与真实y之间的差值,即wx + b 与y真实值之间的差值和最小。这里表示均方误差最小。loss = (WX+b-y)^2的最小值。使得y真实与wx+b之间更加接近。

案例:
在这里插入图片描述
对于具体的方程:loss = ∑i (w·xi+b-yi)^2 要求解w和b,观测到的具体值是xi和yi(xi表示第i个观测到的样本)。借助梯度下降的方法,计算出一个极小值,希望一个wx+b逼近于y,即(y-(wx+b))这个的极小值。在这个取到极小值的时候的w’ 和b’ 的值就是我们需要求解的,即获得最能拟合出真实值y情况的w’x+b’。
在这里插入图片描述
loss = (w·xi+b-yi) ^2
∑i (w·xi+b-yi)^2 最小,这个过程是很好理解的:
在这里插入图片描述
loss是大于等于0的,所以b=0,w=1左右,loss是最小的。如果loss函数能够可视化的话,就可以搜索出loss最低点的w和b的值,就可以很直观的求解出loss最小值。但是现实情况往往不能够可视化,因为x的维度非常非常高,使得w的维度也非常高,因此很难在3维图像上将loss函数曲面图绘制出来

凸优化Convex optimization

上图是有一个全局最小点的,这样的函数叫做凸函数,会有专门的学科叫凸优化,对于深度学习,我们只需要了解一下即可,不需要深入了解,只需要会使用现成的凸函数优化器即可,即使是非凸函数,也能找到一个局部最小值。

查看一下求解过程:

随机的一个初始点,w=0,b=0为初始点,然后在每一处对w和b来求导,从而不断跟新w值,最终希望有一条直线穿过整个数据集,使得整体的误差偏小。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

回归问题实战:

在这里插入图片描述
data:

#.csv
32.50234527	31.70700585
53.42680403	68.77759598
61.53035803	62.5623823
47.47563963	71.54663223
59.81320787	87.23092513
55.14218841	78.21151827
52.21179669	79.64197305
39.29956669	59.17148932
48.10504169	75.3312423
52.55001444	71.30087989
45.41973014	55.16567715
54.35163488	82.47884676
44.1640495	62.00892325
58.16847072	75.39287043
56.72720806	81.43619216
48.95588857	60.72360244
44.68719623	82.89250373
60.29732685	97.37989686
45.61864377	48.84715332
38.81681754	56.87721319
66.18981661	83.87856466
65.41605175	118.5912173
47.48120861	57.25181946
41.57564262	51.39174408
51.84518691	75.38065167
59.37082201	74.76556403
57.31000344	95.45505292
63.61556125	95.22936602
46.73761941	79.05240617
50.55676015	83.43207142
52.22399609	63.35879032
35.56783005	41.4128853
42.43647694	76.61734128
58.16454011	96.76956643
57.50444762	74.08413012
45.44053073	66.58814441
61.89622268	77.76848242
33.09383174	50.71958891
36.43600951	62.12457082
37.67565486	60.81024665
44.55560838	52.68298337
43.31828263	58.56982472
50.07314563	82.90598149
43.87061265	61.4247098
62.99748075	115.2441528
32.66904376	45.57058882
40.16689901	54.0840548
53.57507753	87.99445276
33.86421497	52.72549438
64.70713867	93.57611869
38.11982403	80.16627545
44.50253806	65.10171157
40.59953838	65.56230126
41.72067636	65.28088692
51.08863468	73.43464155
55.0780959	71.13972786
41.37772653	79.10282968
62.49469743	86.52053844
49.20388754	84.74269781
41.10268519	59.35885025
41.18201611	61.68403752
50.18638949	69.84760416
52.37844622	86.09829121
50.13548549	59.10883927
33.64470601	69.89968164
39.55790122	44.86249071
56.13038882	85.49806778
57.36205213	95.53668685
60.26921439	70.25193442
35.67809389	52.72173496
31.588117	50.39267014
53.66093226	63.64239878
46.68222865	72.24725107
43.10782022	57.81251298
70.34607562	104.2571016
44.49285588	86.64202032
57.5045333	91.486778
36.93007661	55.23166089
55.80573336	79.55043668
38.95476907	44.84712424
56.9012147	80.20752314
56.86890066	83.14274979
34.3331247	55.72348926
59.04974121	77.63418251
57.78822399	99.05141484
54.28232871	79.12064627
51.0887199	69.58889785
50.28283635	69.51050331
44.21174175	73.68756432
38.00548801	61.36690454
32.94047994	67.17065577
53.69163957	85.66820315
68.76573427	114.8538712
46.2309665	90.12357207
68.31936082	97.91982104
50.03017434	81.53699078
49.23976534	72.11183247
50.03957594	85.23200734
48.14985889	66.22495789
25.12848465	53.45439421
import numpy as np

# y = wx + b
def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (w * x + b)) ** 2
    return totalError / float(len(points))

def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        b_gradient += -(2/N) * (y - ((w_current * x) + b_current))
        w_gradient += -(2/N) * x * (y - ((w_current * x) + b_current))
    new_b = b_current - (learningRate * b_gradient)
    new_m = w_current - (learningRate * w_gradient)
    return [new_b, new_m]

def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):
    b = starting_b
    m = starting_m
    for i in range(num_iterations):
        b, m = step_gradient(b, m, np.array(points), learning_rate)
    return [b, m]

def run():
    points = np.genfromtxt("data.csv", delimiter=",")
    learning_rate = 0.0001
    initial_b = 0 # initial y-intercept guess
    initial_m = 0 # initial slope guess
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, m = {1}, error = {2}"
          .format(initial_b, initial_m,
                  compute_error_for_line_given_points(initial_b, initial_m, points))
          )
    print("Running...")
    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)
    print("After {0} iterations b = {1}, m = {2}, error = {3}".
          format(num_iterations, b, m,
                 compute_error_for_line_given_points(b, m, points))
          )

if __name__ == '__main__':
    run()

在这里插入图片描述
可以发现error由最初的5565.107到112.614810。线性回归的拟合的均方误差明显减小了。

手写数字识别问题案例:

MNIST数据集
这个数据集每一个数字照片大小是28*28,有7000张。将前6000张做训练集,剩余1000张做测试集。
在这里插入图片描述
每一个手写数字图片可以由矩阵来表示,即28行28列,每一个index所对应的值是0到1区间内的值,该值表示灰度值。将这个28行28列的矩阵通过flat操作进行降维,拉伸成1维向量,数据量不变784,这样忽略了位置相关性。在降维后的向量前添加一个维度,使其变为[1,784],这样数据是不变的,只是前面多了一个1。

一个简单的线性模型y=wx+b,但是对于手写数字来说,仅仅使用简单的线性模型是不够的,我们使用三个线性函数的嵌套。
在这里插入图片描述
loss损失函数如何计算:
loss = ([预测结果]-[真实值])^2
这里的真实值,需要转换成独热编码one-hot。
在这里插入图片描述
在这里插入图片描述
对于线性模型很难处理非线性情况,这里引入非线性因子Non-linear Factor。这里非线性因子主要有sigmoid函数(激活函数),ReLU函数等。这里我们选用ReLU函数作为非线性因子。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

将上述过程通过python来实现:

在这里插入图片描述
最后一层不加relu非线性因子:
在这里插入图片描述
主要分为四个步骤对手写数字进行识别:
1、加载图片
2、新建模型
3、训练集
4、测试集
在这里插入图片描述

utils.py:

import  torch
from    matplotlib import pyplot as plt

#绘制下降的曲线:
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

#绘制图片,识别的结果
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

#实现独热编码
def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

mnist_train.py:

import torch
from torch import nn #表示神经网络的一些工作
from torch.nn import functional as F #表示常用的一些函数
from torch import optim #导入优化工具包

#mnist是一个视觉的数据集,所以需要导入torchvision
import torchvision
from matplotlib import pyplot as plt

#导入utils类,即utils.py文件
from utils import plot_image,plot_curve,one_hot


# step1. load dataset
#这里有一个batch_size概念:
#gpu性能非常强大,一次处理1张图片可能需要3毫秒,一次处理多张图片可能也就4到5毫秒。
#这样通过并行处理多张图片,可以大大节省计算时间。

batch_size = 512 #因为图片是28*28的大小,图片比较小,所以这里batch_size就稍微大一些。

#加载训练集
#torch专门加载数据集的方法:
train_loader = torch.utils.data.DataLoader(
    #加载mnist数据集
    #参数1:指定数据集
    #参数2train:指定下载的数据是那60000张训练集数据集,因为有60000张图片是做训练集的,有10000张是测试集的。
    #参数3download:表示如果当前文件没有mnist数据集,就从网络上下载下来。
    #参数4transform:转换格式,网络下载下来的数据集是numpy格式,需要转换成Tensor格式,该格式是torch的一个数据载体。
    #此外还有一个正则化过程normalize,因为神经网络所接受的数据最好均匀分布在0附近,但是我们图片的像素是0到1之间(0的右侧分布),通过减去0.1307再除以标准差这样一个过程使得我们的数据能够在0周围均匀的分布,这样更方便神经网络去优化,这个正则化可以不做,注释掉的话性能会差一些可能70%作用,如果使用正则化效果会好很多80%多。
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    # 参数5batch_size:一次加载多少张图片,并行处理多张图片。
    # 参数6shuffle:表示加载的时候随机打散
    batch_size=batch_size, shuffle=True)

#加载测试集
test_loader = torch.utils.data.DataLoader(
    #这里参数shuffle设置为false,因为测试集就没必要打散了
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

#将加载的图片显示出来
#next() 返回迭代器的下一个项目。
#next() 函数要和生成迭代器的 iter() 函数一起使用。
#iter() 函数用来生成迭代器,iter(object[, sentinel]),object -- 支持迭代的集合对象。
#sentinel -- 如果传递了第二个参数,则参数 object 必须是一个可调用的对象(如,函数),此时,iter 创建了一个迭代器对象,每次调用这个迭代器对象的__next__()方法时,都会调用 object。
x,y = next(iter(train_loader))
#512张图片,1个通道,28行,28列,label有512个,最小值-0.4242,最大值2.8215,说明在0周围分布。
print(train_loader)
print(x.shape,y.shape,x.min(),x.max())
# print(x)
plot_image(x,y,'image sample')


#step2.create net
class Net(nn.Module):

    def __init__(self):
        super(Net,self).__init__()

        #xw+b
        self.fc1 = nn.Linear(28*28,256)
        self.fc2 = nn.Linear(256,64)
        self.fc3 = nn.Linear(64,10)

    def forward(self,x):
        # x:[b,1,28,28]
        # h1 = xw+b
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3+b3
        x=self.fc3(x)
        return x

net = Net()
#[w1,b1,w2,b2,w3,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)


#step.3 trainning

train_loss = []
for epoch in range(3):
    for batch_idx,(x,y) in enumerate(train_loader):
        # x: [b,1,28,28],y:[512]
        # [b,1,28,28] => [b , feature]
        x= x.view(x.size(0),28*28)
        # => [b,10]
        out = net(x)
        #[b,10]
        y_onehot = one_hot(y)
        #loss = mse(out , y_onehot)
        loss = F.mse_loss(out,y_onehot)

        optimizer.zero_grad()
        loss.backward()
        #w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 ==0:
            print(epoch,batch_idx,loss.item())

plot_curve(train_loss)
#we get optimal [w1,b1,w2,b2,w3,b3]

#step.4 accuracy

total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)
    #out :[b,10] => pred:[b]
    pred = out.argmax(dim = 1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct/total_num
print('test acc:',acc)

#step.5
x,y = next(iter(test_loader))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim =1)
plot_image(x,pred,'test')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch-02梯度下降Gradient Descent、回归案例、手写数字识别案例 的相关文章

随机推荐

  • LeetCode 226. 翻转二叉树

    题目链接 https leetcode cn com problems invert binary tree 先序遍历 Java 代码 class Solution public TreeNode invertTree TreeNode r
  • 我的世界ess服务器信息,我的世界ess指令怎么用 ess指令大全及用法详解

    我的世界ess指令都有哪些 作为风靡全球的沙盒游戏 我的世界带给玩家太多的乐趣 为了能更方便的游戏 ess指令能帮助我们更好的游戏 很多新手玩家刚接触就被搞晕了 这么多的指令看起来有些复杂 下面就由小编给大家带来 我的世界ess指令都有哪些
  • mybatis进行批量插入 返回批量插入主键ID 插入不成功等问题

    这篇博文讲的是批量插入的例子 dao层框架用的mybatis 最一开始我的批量插入其实是个伪批量 是类似吧很多条insert into语句 直接拼成一条 然后直接运行 发现这样的效率真的是十分低 我做测试时285条数据 插入一次需要10S多
  • uniapp添加.gitignore以及不生效解决办法

    一 第一次新建 gitignore 首先进入项目 命令行新建 gitignore文件 touch gitignore 然后编辑器打开 进入到项目中新建的 gitignore 文件 复制粘贴以下 node modules project un
  • C++实现鼠标点击其他程序

    1 主要是SendInput函数 代码如下 初始化 INPUT input 0 input type INPUT MOUSE dx dy代表的是进行点击的坐标 下面显示的是 950 150 input mi dx static cast
  • 【Proteus仿真】555组成的多谐振荡器电路

    Proteus仿真 555组成的多谐振荡器电路 Proteus仿真演示 多谐振荡器电路 多谐振荡器电路是一种矩形波产生电路 属于数字电路 三极管不工作在放大线性区 这种电路不需要外加触发信号便能连续地 周期性地自行产生矩形脉冲 该脉冲是由基
  • Stable Diffusion:ChatGPT与AI绘画,引领艺术的未来

    人工智能 AI 的快速发展正在为各个领域带来革命性的变化 其中包括艺术与创意领域 AI绘画是一种将人工智能技术与艺术创作相结合的新兴范式 通过深度学习和生成对抗网络 GAN 等技术 AI绘画可以生成各种富有创意和想象力的艺术作品 本文将探讨
  • python - __str__ 和 __repr__

    内建函数str 和repr representation 表达 表示 或反引号操作符 可以方便地以字符串的方式获取对象的内容 类型 数值属性等信息 str 函数得到的字符串可读性好 故被print调用 而repr 函数得到的字符串通常可以用
  • Docker+docker-compose+nginx部署已有项目

    项目背景 在异地服务器拷docker相关项目到新的服务器 具体操作 1 新服务器安装好docker 2 新服务器安装好docker compose 3 从老服务器拷贝镜像到新服务器 4 新服务器导入镜像 5 构建项目地址挂载目录 找到doc
  • 用U盘作启动盘装Windows10系统整套流程 纯净版(不用其他乱七八糟的软件)(macOS适用)

    简介 本人的电脑是MacBook Air 2014年版的 因为内存小而且文件杂乱 所以一下子都给格式化了 但是要用Mac自带的恢复系统的话需要连接校园网 连接校园网又需要打开网页输入账号和密码 我们学校的校园网是这样的 所以只能用U盘作为格
  • gqrx编译过程记录

    gqrx编译过程记录 目标 环境 编译 下载源代码 建立编译位置 修改CMakefile txt中的模块 编译安装 运行界面 没有更多 目标 在ubuntun 20 04桌面版编译gqrx 通过USRP 205mini实现收音机功能 环境
  • 【解决】docker容器怎么使用宿主机的IPv6地址

    在IPv4时代 我们对外访问都是端口映射 都没有公网IP 但是在IPv6时大家都有公网IP 可能需要容器地址和主机地址一致 可以在docker run时使用参数 network host 则此容器网络和宿主机一致 docker run ne
  • AQS详解

    AQS详解 文章目录 AQS详解 AQS简单介绍 AQS原理 AQS原理概览 AQS对资源的共享方式 AQS定义两种资源共享方式 Exclusive 独占 Share 共享 AQS底层使用了模板方法模式 Semaphore 信号量 Coun
  • 浅谈可重入锁

    一 可重入锁 递归锁 1 概念 同一个线程在外层方法获取锁的时候 再进入该线程的内层方法会自动获取锁 前提是 锁对象是同一个对象 不是因为之前已经获取过还没有释放而阻塞 2 java中的ReentrantLock和synchronied都是
  • 关于Gdi+和GdiplusStartup

    GDI 实际上是一组类的定义 封装了gdi 的几乎所有API 当然使用方法就要从这些 例子 里边寻找了 本文正是尝试用GDI 写一个纯SDK的程序 语言自然是我最喜欢的语言WIN32ASM 这个程序很简单 就是用GDI 画了一条直线 算是抛
  • HCIA-FusionCompute华为企业级虚拟化

    一 云计算 按需付费 集中资源对外提供服务 1 云本身没有资源 云是资源整合者 整合底层的所有计算机资源 cpu 内存 磁盘等 云计算是一种模型 它可以实现随时随地 随需应变地从可配置计算资源共享池中获取所需的资源 例如 网络 服务器 存储
  • BigDecimal 问题小结

    BigDecimal 加法 add 函数 乘法multiply 函数 除法divide 函数 绝对值abs 函数 减法subtract 函数 ROUND CEILING 向正无穷方向舍入 ROUND DOWN 向零方向舍入 ROUND FL
  • 【Redis】新增数据结构

    BitMap位图 Redis提供了Bitmaps这个 数据类型 可以实现对位的操作 1 Bitmaps本身不是一种数据类型 实际上它就是字符串 key value 但是它可以对字符串的位进行操作 2 Bitmaps单独提供了一套命令 所以在
  • RabbitMQ与SpringBoot整合实战

    SpringBoot整合RabbitMQ SpringBoot与RabbitMQ集成非常筒単 不需要做任何的额外设置只需要两步即可 step1 引入相关依赖 spring boot starter amqp step2 対applicati
  • PyTorch-02梯度下降Gradient Descent、回归案例、手写数字识别案例

    PyTorch 02梯度下降Gradient Descent 回归案例 手写数字识别案例 了解梯度下降 梯度下降是深度学习的精髓 整个deep learning是靠梯度下降所支撑的 可以求解一个非常难的函数 使用的方法就是梯度下降算法 求一