【python、pytorch】

2024-01-09

什么是Pytorch

Pytorch是一个基于Numpy的科学计算包,向它的使用者提供了两大功能。作为Numpy的替代者,向用户提供使用GPU强大功能的能力。做为一款深度学习的平台,向用户提供最大的灵活性和速度。

基本元素操作

Tenors张量:张量的概念类似于Numpy中的ndarray数据结构,最大的区别在于Tensor可以利用
GPU的加速功能。

引用

from __future__ import print_function
import torch

创建矩阵

未初始化矩阵 (脏数据)

x=torch.empty(5,3)

初始化矩阵

x=torch.rand(5,3)

全零矩阵

x=torch.zeros(5,3,dtype=torch.long)

dtype可以设置数据类型

创建张量

x=torch.tensor([[1,2],[3,4]])

创建同尺寸的张量(数据随机)

x=x.new_ones(5,3,dtype=torch.double)
y=torch.rand_like(x,dtype=torch.float)

张量的尺寸(返回的是元组)

x=x.new_ones(5,3,dtype=torch.double)
x.size()

基本运算操作

加法操作

x,y尺寸相同

x+y

尺寸相同和上个方法一致

torch.add(x,y)

把结果存入result

result=torch.empty(5,3)
torch.add(x,y,out=result)

就地置换(存入x)

x.add_(y)

所有就地置换函数都有_的后缀

切片操作

和numpy几乎一致

改变张量形状(-1在前匹配列,在后或没有匹配行)

取出唯一个元素

x.item()

与其他格式的相互转换

tensor转numpy array

b=x.numpy()

注意:如果对其中一个数据操作,另一个也会随之发生改变

numpy array转tensor

a=np.ones(5)
b=torch.from_numpy(a)

注意:如果对其中一个数据操作,另一个也会随之发生改变

cuda tensor

if torch.cuda.is_available():
    device = torch.device("cuda")
    y=torch.ones_like(x,device=device)
    x=x.to(device)
    z=x+y
    print(z)
    print(z.to("cpu",torch.double))

autograd

在整个Pytorch框架中,所有的神经网络本质上都是一个autograd pickage(自动求导工具包)。

autograd package提供了一个对Tensors上所有的操作进行自动微分的功能。

torch.Tensor

requires_grad

torch.Tensor是整个package中的核心类,如将属性requires_grad设置为True,它将追踪在这个类上定义的所有操作,当代码要进行反向传播的时候,直接调用backward()就可以自动计算所有的梯度.在这个Tensor上的所有梯度将被累加进属性grad中。

x = torch.ones(2, 2, requires_grad=True)
y=x+2
print(x.grad_fn)
print(y.grad_fn)

改变 requires_grad属性
x.requires_grad_(True)

detach()

如果想终止一个Tensor在计算图中的追踪回溯,只需要执行detach()就可以将该Tensor从计算图中撤下,在未来的回溯计算中也不会再计算该Tensor。

除了detach(),如果想终止对计算图的回溯,也就是不再进行方向传播求导数的过程,也可以采用代码块的方式with torch.no_grad():, 这种方式非常适用于对模型进行预测的时候,因为预测阶段不再需要对梯度进行计算.

torch.Function

Function类是和Tensor类同等重要的一个核心类,它和Tensor共同构建了一个完整的类,每一个Tensor拥有一个grad_fn属性,代表引用了哪个具体的Function创建了该Tensor。

如果某个张量Tensor是用户自定义的,则其对应的grad_fn is None。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【python、pytorch】 的相关文章

  • 在PyGI中获取窗口句柄

    在我的程序中 我使用 PyGObject PyGI 和 GStreamer 在 GUI 中显示视频 该视频显示在Gtk DrawingArea因此我需要获取它的窗口句柄realize 信号处理程序 在 Linux 上 我使用以下方法获取该句
  • 如何指定聚类的距离函数?

    我想对给定距离的点进行聚类 奇怪的是 似乎 scipy 和 sklearn 聚类方法都不允许指定距离函数 例如 在sklearn cluster AgglomerativeClustering 我唯一可以做的就是输入一个亲和力矩阵 这将非常
  • 通过 boto3 承担 IAM 用户角色时访问被拒绝

    Issue 我有一个 IAM 用户和一个 IAM 角色 我正在尝试将 IAM 用户配置为有权使用 STS 承担 IAM 角色 我不确定为什么收到 访问被拒绝 错误 Details IAM 角色 arn aws iam 123456789 r
  • Flask中使用的路由装饰器是如何工作的

    我熟悉 Python 装饰器的基础知识 但是我不明白这个用于 Flask 路由的特定装饰器是如何工作的 以下是 Flask 网站上的代码片段 from flask import Flask escape request app Flask
  • 如何删除 PyCharm 中的项目?

    如果我关闭一个项目 然后删除该项目文件夹 则在 PyCharm 重新启动后 会再次创建一个空的项目文件夹 只需按顺序执行以下步骤即可 他们假设您当前在 PyCharm 窗口中打开了该项目 单击 文件 gt 关闭项目 关闭项目 在 PyCha
  • 如何使用 python 的 http.client 准确读取一个响应块?

    Using http client在 Python 3 3 或任何其他内置 python HTTP 客户端库 中 如何一次读取一个分块 HTTP 响应一个 HTTP 块 我正在扩展现有的测试装置 使用 python 编写 http clie
  • McNemar 在 Python 中的测试以及分类机器学习模型的比较 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 有没有用 Python 实现的好的 McNemar 测试 我在 Scipy stats 或 Scikit
  • 数据框 - 平均列

    我在 pandas 中有以下数据框 Column 1 Column 2 Column3 Column 4 2 2 2 4 1 2 2 3 我正在创建一个数据框 其中包含第 1 列和第 2 列 第 3 列和第 4 列等的平均值 ColumnA
  • 如何使用 Python boto3 获取 redshift 中的列名称

    我想使用 python boto3 获取 redshift 中的列名称 创建Redshift集群 将数据插入其中 配置的机密管理器 配置 SageMaker 笔记本 打开Jupyter Notebook写入以下代码 import boto3
  • 如何将 sql 数据输出到 QCalendarWidget

    我希望能够在日历小部件上突出显示 SQL 数据库中的一天 就像启动程序时突出显示当前日期一样 在我的示例中 它是红色突出显示 我想要发生的是 当用户按下突出显示的日期时 数据库中日期旁边的文本将显示在日历下方的标签上 这是我使用 QT De
  • python celery -A 的无效值无法加载应用程序

    我有一个以下项目目录 azima init py main py tasks py task py from main import app app task def add x y return x y app task def mul
  • 使用标签或 href 传递 Django 数据

    我有一个包含链接的表 当单击该链接进行更多操作时 我想将一些数据传递给我的函数 my html table tbody for query in queries tr td value a href internal my func que
  • App Engine 实体到字典

    将 google app engine 实体 在 python 中 复制到字典对象的好方法是什么 我正在使用 db Expando 对象 所有属性均为扩展属性 Thanks 有一个名为foo尝试 foo dict
  • 使用seaborn绘制简单线图

    我正在尝试使用seaborn python 绘制ROC曲线 对于 matplotlib 我只需使用该函数plot plt plot one minus specificity sensitivity bs where one minus s
  • PyInstaller“ValueError:源代码字符串不能包含空字节”

    我得到了一个ValueError source code string cannot contain null bytes执行命令时pyinstaller main py在具有和不具有管理员权限的cmd中 Traceback most re
  • numpy polyfit 中使用的权重值是多少以及拟合误差是多少

    我正在尝试对 numpy 中的某些数据进行线性拟合 Ex 其中 w 是该值的样本数 即对于点 x 0 y 0 我只有 1 个测量值 该测量值是2 2 但对于这一点 1 1 我有 2 个测量值 值为3 5 x np array 0 1 2 3
  • 寻找完美的正方形

    我有这个Python代码 def sqrt x ans 0 if x gt 0 while ans ans lt x ans ans 1 if ans ans x print x is not a perfect square return
  • 在matlab中,如何读取python pickle文件?

    在 python 中 我生成了一个 p 数据文件 pickle dump allData open myallData p wb 现在我想在Matlab中读取myallData p 我的Matlab安装在Windows 8下 其中没有Pyt
  • 如何绘制更大的边界框和仅裁剪边界框文本 Python Opencv

    我正在使用 easyocr 来检测图像中的文本 该方法给出输出边界框 输入图像如下所示 Image 1 Image 2 使用下面的代码获得输出图像 But I want to draw a Single Bigger bounding bo
  • Python 中的 Unix cat 函数 (cat * > merged.txt)? [复制]

    这个问题在这里已经有答案了 一旦建立了目录 有没有办法在Python中使用Unix中的cat函数或类似的函数 我想将 files 1 3 合并到 merged txt 我通常会在 Unix 中找到该目录 然后运行 cat gt merged

随机推荐