PyTorch中的9种常见梯度下降算法与案例

2023-10-29

本文将介绍PyTorch中的几种常见梯度下降算法,并提供相应的Python案例。

1. 批量梯度下降(Batch Gradient Descent)

批量梯度下降是最基础的梯度下降算法,通过使用全部训练数据计算损失函数的梯度来更新参数。在PyTorch中,可以通过定义损失函数和优化器来实现批量梯度下降。

例如,在简单线性回归问题中使用批量梯度下降:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数
loss_fn = torch.nn.MSELoss()

# 定义优化器和超参数
optimizer = torch.optim.SGD([w, b], lr=0.01)
epochs = 100

# 批量梯度下降
for epoch in range(epochs):
    # 前向传播
    Y_pred = w * X + b
    # 计算损失
    loss = loss_fn(Y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 1.9999034404754639, b = 0.00041007260753999686

绘制的拟合直线如下图所示:

在这里插入图片描述

2. 随机梯度下降(Stochastic Gradient Descent)

随机梯度下降是一种在每次更新时随机选择一个样本进行梯度计算和参数更新的梯度下降算法。相比批量梯度下降,随机梯度下降具有更快的收敛速度,但更新过程较为不稳定。在PyTorch中,可以通过定义DataLoader来实现随机梯度下降。

例如,在简单线性回归问题中使用随机梯度下降:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD([w, b], lr=0.01)

# 定义超参数
batch_size = 1
epochs = 100

# 随机梯度下降
for epoch in range(epochs):
    # 创建DataLoader
    loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, Y), batch_size=batch_size, shuffle=True)
    for x_batch, y_batch in loader:
        # 前向传播
        y_pred = w * x_batch + b
        # 计算损失
        loss = loss_fn(y_pred, y_batch)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 清空梯度
        optimizer.zero_grad()


# 输出结果

print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 2.0002050399780273, b = -0.0005163848866038325

绘制的拟合直线如下图所示:

在这里插入图片描述

3. 小批量梯度下降(Mini-batch Gradient Descent)

小批量梯度下降是介于批量梯度下降和随机梯度下降之间的梯度下降算法,即每次更新时选取一定数量的样本进行梯度计算和参数更新。在PyTorch中,可以通过定义DataLoader来实现小批量梯度下降。

例如,在简单线性回归问题中使用小批量梯度下降:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD([w, b], lr=0.01)

# 定义超参数
batch_size = 2
epochs = 100

# 小批量梯度下降
for epoch in range(epochs):
    # 创建DataLoader
    loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, Y), batch_size=batch_size, shuffle=True)
    for x_batch, y_batch in loader:
        # 前向传播
        y_pred = w * x_batch + b
        # 计算损失
        loss = loss_fn(y_pred, y_batch)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 清空梯度
        optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 1.9998304843902588, b = -0.00010240276428798664

绘制的拟合直线如下图所示:

在这里插入图片描述

4. 动量梯度下降(Momentum Gradient Descent)

动量梯度下降是一种在梯度下降更新过程中加入动量项的优化算法,可以加速收敛并减少震荡。在PyTorch中,可以通过设置momentum参数实现动量梯度下降。

例如,在简单线性回归问题中使用动量梯度下降:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数和动量
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
momentum = 0.9

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD([w, b], lr=0.01, momentum=momentum)



# 定义超参数
epochs = 100

# 动量梯度下降
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 1.9999991655349731, b = -3.718109681471333e-05

绘制的拟合直线如下图所示:

在这里插入图片描述

5. AdaGrad

AdaGrad是一种自适应学习率的优化算法,在更新参数时根据历史梯度信息来动态调整每个参数的学习率。在PyTorch中,可以通过设置optim.Adagrad()来使用该优化器。

例如,在简单线性回归问题中使用AdaGrad:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adagrad([w, b], lr=0.1)

# 定义超参数
epochs = 100

# AdaGrad
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 1.999985694885254, b = -0.0010528544095081096

绘制的拟合直线如下图所示:

在这里插入图片描述

6. RMSprop

RMSprop是一种自适应学习率的优化算法,在更新参数时根据历史梯度平方的加权平均来动态调整每个参数的学习率。在PyTorch中,可以通过设置optim.RMSprop()来使用该优化器。

例如,在简单线性回归问题中使用RMSprop:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.RMSprop([w, b], lr=0.01)

# 定义超参数
epochs = 100

# RMSprop
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 2.000011920928955, b = -0.00020079404229614145

绘制的拟合直线如下图所示:

在这里插入图片描述

7. Adam

Adam是一种融合了动量梯度下降和自适应学习率的优化算法,在更新参数时既考虑历史梯度的加权平均又考虑历史梯度平方的加权平均。在PyTorch中,可以通过设置optim.Adam()来使用该优化器。

例如,在简单线性回归问题中使用Adam:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([w, b], lr=0.01)

# 定义超参数
epochs = 100

# Adam
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 2.0000016689300537, b = -1.788223307551633e-05

绘制的拟合直线如下图所示:

在这里插入图片描述

8. AdamW

AdamW是一种基于Adam优化算法的变体,它引入了权重衰减(weight decay)来解决Adam可能存在的参数过度拟合问题。在PyTorch中,可以通过设置optim.AdamW()来使用该优化器。

例如,在简单线性回归问题中使用AdamW:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW([w, b], lr=0.01, weight_decay=0.1)

# 定义超参数
epochs = 100

# AdamW
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 1.9564942121505737, b = 0.063056387424469

绘制的拟合直线如下图所示:

在这里插入图片描述

9. Adadelta

Adadelta是一种自适应学习率的优化算法,它与RMSprop相似,但引入了一个衰减系数来平衡历史梯度平方和目标函数变化量。在PyTorch中,可以通过设置optim.Adadelta()来使用该优化器。

例如,在简单线性回归问题中使用Adadelta:

import torch
import matplotlib.pyplot as plt

# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adadelta([w, b], lr=0.1)

# 定义超参数
epochs = 100

# Adadelta
for epoch in range(epochs):
    # 前向传播
    y_pred = w * X + b
    # 计算损失
    loss = loss_fn(y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出:

w = 2.0000007152557373, b = 4.5047908031606675e-08

绘制的拟合直线如下图所示:

在这里插入图片描述

以上是常见的优化算法,它们在不同的场景下表现出不同的效果。除此之外,PyTorch还支持其他优化算法,如Adagrad、Adamax等,并且用户也可以自定义优化器来满足特定需求。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch中的9种常见梯度下降算法与案例 的相关文章

  • 如果两点之间的距离低于某个阈值,则从列表中删除点

    我有一个点列表 只有当它们之间的距离大于某个阈值时 我才想保留列表中的点 因此 从第一个点开始 如果第一个点和第二个点之间的距离小于阈值 那么我将删除第二个点 然后计算第一个点和第三个点之间的距离 如果该距离小于阈值 则比较第一点和第四点
  • 如何手动计算分类交叉熵?

    当我手动计算二元交叉熵时 我应用 sigmoid 来获取概率 然后使用交叉熵公式并平均结果 logits tf constant 1 1 0 1 2 labels tf constant 0 0 1 1 1 probs tf nn sigm
  • 与区域指示符字符类匹配的 python 正则表达式

    我在 Mac 上使用 python 2 7 10 表情符号中的标志由一对表示区域指示符号 https en wikipedia org wiki Regional Indicator Symbol 我想编写一个 python 正则表达式来在
  • 元组有什么用?

    我现在正在学习 Python 课程 我们刚刚介绍了元组作为数据类型之一 我阅读了它的维基百科页面 但是 我无法弄清楚这种数据类型在实践中会有什么用处 我可以提供一些需要一组不可变数字的示例吗 也许是在 Python 中 这与列表有何不同 每
  • Python 中的舍入浮点问题

    我遇到了 np round np around 的问题 它没有正确舍入 我无法包含代码 因为当我手动设置值 而不是使用我的数据 时 返回有效 但这是输出 In 177 a Out 177 0 0099999998 In 178 np rou
  • 处理 Python 行为测试框架中的异常

    我一直在考虑从鼻子转向行为测试 摩卡 柴等已经宠坏了我 到目前为止一切都很好 但除了以下之外 我似乎无法找出任何测试异常的方法 then It throws a KeyError exception def step impl contex
  • 需要在python中找到print或printf的源代码[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我正在做一些我不能完全谈论的事情 我
  • 使用 kivy textinput 的 'input_type' 属性的问题

    您好 我在使用 kivy 的文本输入小部件的 input type 属性时遇到问题 问题是我制作了两个自定义文本输入 其中一个称为 StrText 其中设置了 input type text 然后是第二个文本输入 名为 NumText 其
  • Python zmq SUB 套接字未接收 MQL5 Zmq PUB 套接字

    我正在尝试在 MQL5 中设置一个 PUB 套接字 并在 Python 中设置一个 SUB 套接字来接收消息 我在 MQL5 中有这个 include
  • 您可以格式化 pandas 整数以进行显示,例如浮点数的“pd.options.display.float_format”?

    我见过this https stackoverflow com questions 18404946 py pandas formatdataframe and this https stackoverflow com questions
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • 如何在 Python 中解析和比较 ISO 8601 持续时间? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 Python v2 库 它允许我解析和比较 ISO 8601 持续时间may处于不同单
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • 如何使用python在一个文件中写入多行

    如果我知道要写多少行 我就知道如何将多行写入一个文件 但是 当我想写多行时 问题就出现了 但是 我不知道它们会是多少 我正在开发一个应用程序 它从网站上抓取并将结果的链接存储在文本文件中 但是 我们不知道它会回复多少行 我的代码现在如下 r
  • 如何通过索引列表从 dask 数据框中选择数据?

    我想根据索引列表从 dask 数据框中选择行 我怎样才能做到这一点 Example 假设我有以下 dask 数据框 dict A 1 2 3 4 5 6 7 B 2 3 4 5 6 7 8 index x1 a2 x3 c4 x5 y6 x
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • Python:XML 内所有标签名称中的字符串替换(将连字符替换为下划线)

    我有一个格式不太好的 XML 标签名称内有连字符 我想用下划线替换它 以便能够与 lxml objectify 一起使用 我想替换所有标签名称 包括嵌套的子标签 示例 XML
  • 模拟pytest中的异常终止

    我的多线程应用程序遇到了一个错误 主线程的任何异常终止 例如 未捕获的异常或某些信号 都会导致其他线程之一死锁 并阻止进程干净退出 我解决了这个问题 但我想添加一个测试来防止回归 但是 我不知道如何在 pytest 中模拟异常终止 如果我只
  • 如何计算Python中字典中最常见的前10个值

    我对 python 和一般编程都很陌生 所以请友善 我正在尝试分析包含音乐信息的 csv 文件并返回最常听的前 n 个乐队 从下面的代码中 每听一首歌曲都是一个列表中的字典条目 格式如下 album Exile on Main Street

随机推荐