Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

2023-11-13

1、红酒数据介绍

经典的红酒分类数据集是指UCI机器学习库中的Wine数据集。该数据集包含178个样本,每个样本有13个特征,可以用于分类任务。

具体每个字段的含义如下:
alcohol:酒精含量百分比
malic_acid:苹果酸含量(克/升)
ash:灰分含量(克/升)
alcalinity_of_ash:灰分碱度(以mEq/L为单位)
magnesium:镁含量(毫克/升)
total_phenols:总酚含量(以毫克/升为单位)
flavanoids:类黄酮含量(以毫克/升为单位)
nonflavanoid_phenols:非类黄酮酚含量(以毫克/升为单位)
proanthocyanins:原花青素含量(以毫克/升为单位)
color_intensity:颜色强度(以 absorbance 为单位,对应于 1cm 路径长度处的相对宽度)
hue:色调,即色彩的倾向性或相似性(在 1 至 10 之间的一个数字)
od280/od315_of_diluted_wines:稀释葡萄酒样品的光密度比值,用于测量葡萄酒中各种化合物的浓度
proline:脯氨酸含量(以毫克/升为单位),是一种天然氨基酸,与葡萄酒的品质和口感有关。

2、红酒数据集分析

2.1 加载红酒数据集

# 加载红酒数据集
wineBunch = load_wine()
type(wineBunch)

sklearn.utils.Bunch
sklearn.utils.Bunch是Scikit-learn库中的一个数据容器,类似于Python字典(dictionary),
它可以存储任意数量和类型的数据,并且可以使用点(.)操作符来访问数据。Bunch常用于存储机器学习模型的数据集,
例如描述特征矩阵的数据、相关联的目标向量、特征名称等等,以便于组织和传递这些数据到模型中进行训练或预测。

2.2 红酒数据集形状

len(wineBunch.data),len(wineBunch.target)

(178, 178)

2.3 红酒数据集打印前5行和后5行

featuresDf = pd.DataFrame(data=wineBunch.data, columns=wineBunch.feature_names)   # 特征数据
labelDf = pd.DataFrame(data=wineBunch.target, columns=["target"])               # 标签数据
wineDf = pd.concat([featuresDf, labelDf], axis=1)  # 横向拼接
wineDf.head(5).append(wineDf.tail(5))              # 打印首尾5行

在这里插入图片描述

2.4 红酒数据集列名

wineDf.columns

Index([‘alcohol’, ‘malic_acid’, ‘ash’, ‘alcalinity_of_ash’, ‘magnesium’,
‘total_phenols’, ‘flavanoids’, ‘nonflavanoid_phenols’,
‘proanthocyanins’, ‘color_intensity’, ‘hue’,
‘od280/od315_of_diluted_wines’, ‘proline’, ‘target’],
dtype=‘object’)

2.5 红酒数据集目标标签

print(wineDf.target.unique())
[0 1 2]

3、Transformer对红酒进行分类

3.1 Transformer介绍

Transformer是一种基于注意力机制的神经网络结构,主要用于自然语言处理领域中的序列到序列转换任务,比如机器翻译、文本摘要等。它在2017年被Google提出,并被成功应用于Google Translate中。

Transformer的主要特点在于使用了完全基于注意力机制的编码器-解码器结构,避免了传统循环神经网络(如LSTM)中存在的长序列依赖问题和梯度消失问题。此外,Transformer还使用了残差连接和层归一化等技术,增强了模型的训练效果和泛化能力。

在Transformer模型中,输入序列和输出序列都被表示为固定长度的向量,称为词向量,由多个嵌入层和多个编码器和解码器层组成。其中,编码器和解码器层包括多头注意力机制、前馈神经网络和残差连接等模块,以实现对序列的有效建模和转换。

3.2 引入依赖库

import random         # 导入 random 模块,用于随机数生成
import torch          # 导入 PyTorch 模块,用于深度学习任务
import numpy as np    # 导入 numpy 模块,用于数值计算
from torch import nn  # 从 PyTorch 中导入神经网络模块
from sklearn import datasets  # 从sklearn引入数据集
from sklearn.model_selection import train_test_split  # 导入 sklearn 库中的 train_test_split 函数,用于数据划分
from sklearn.preprocessing import StandardScaler     # 导入 sklearn 库中的 StandardScaler 类,用于数据标准化

3.3 设置随机种子

# 设置随机种子,让模型每次输出的结果都一样
seed_value = 42
random.seed(seed_value)                         # 设置 random 模块的随机种子
np.random.seed(seed_value)                      # 设置 numpy 模块的随机种子
torch.manual_seed(seed_value)                   # 设置 PyTorch 中 CPU 的随机种子
#tf.random.set_seed(seed_value)                 # 设置 Tensorflow 中随机种子
if torch.cuda.is_available():                   # 如果可以使用 CUDA,设置随机种子
    torch.cuda.manual_seed(seed_value)          # 设置 PyTorch 中 GPU 的随机种子
    torch.backends.cudnn.deterministic = True   # 使用确定性算法,使每次运行结果一样
    torch.backends.cudnn.benchmark = False      # 不使用自动寻找最优算法加速运算

3.4 检测GPU是否可用

# 检测GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3.5 加载数据集

# 加载红酒数据集
wine = datasets.load_wine()
X = wine.data
y = wine.target

3.6 拆分训练集和测试集

# 拆分成训练集和测试集,训练集80%和测试集20%
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

3.7 缩放数据

# 缩放数据
scaler = StandardScaler() # 创建一个标准化转换器的实例
X_train = scaler.fit_transform(X_train) # 对训练集进行拟合(计算平均值和标准差)
X_test = scaler.transform(X_test) # 对测试集进行标准化转换,使用与训练集相同的平均值和标准差

3.8 转化成pytorch张量

# 将训练集转换为 PyTorch 张量,并转换为浮点数类型,如果 GPU 可用,则将张量移动到 GPU 上
X_train = torch.tensor(X_train).float().to(device)
# 将训练集转换为 PyTorch 张量,并转换为长整型,如果 GPU 可用,则将张量移动到 GPU 上
y_train = torch.tensor(y_train).long().to(device)
X_test = torch.tensor(X_test).float().to(device)
y_test = torch.tensor(y_test).long().to(device)

3.9 定义Transformer模型

定义 Transformer 模型

class TransformerModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerModel, self).__init__()
        # 构建Transformer编码层,参数包括输入维度、注意力头数
        # 其中d_model要和模型输入维度相同
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size,  # 输入维度
                                                        nhead=1)             # 注意力头数
        # 构建Transformer编码器,参数包括编码层和层数
        self.encoder = nn.TransformerEncoder(self.encoder_layer,             # 编码层
                                             num_layers=1)                   # 层数
        # 构建线性层,参数包括输入维度和输出维度(num_classes)
        self.fc = nn.Linear(input_size,                                      # 输入维度
                            num_classes)                                     # 输出维度

    def forward(self, x):
        #print("A:", x.shape)  # torch.Size([142, 13])
        x = x.unsqueeze(1)    # 增加一个维度,变成(batch_size, 1, input_size)的形状
        #print("B:", x.shape)  # torch.Size([142, 1, 13])
        x = self.encoder(x)   # 输入Transformer编码器进行编码
        #print("C:", x.shape)  # torch.Size([142, 1, 13])
        x = x.squeeze(1)      # 压缩第1维,变成(batch_size, input_size)的形状
        #print("D:", x.shape)  # torch.Size([142, 13])
        x = self.fc(x)        # 输入线性层进行分类预测
        #print("E:", x.shape)  # torch.Size([142, 3])
        return x
# 初始化Transformer模型,并移动到GPU
model = TransformerModel(input_size=13,             # 输入维度
                         num_classes=3).to(device)  # 输出维度

3.10 定义损失函数和优化器

定义损失函数和优化器

criterion = nn.CrossEntropyLoss() # 定义损失函数-交叉熵损失函数

定义优化器

optimizer = torch.optim.Adam(model.parameters(), # 模型参数
lr=0.01) # 学习率

3.11 训练模型

# 训练模型
num_epochs = 100     # 训练100轮
for epoch in range(num_epochs):
    # 正向传播:将训练数据放到模型中,得到模型的输出
    outputs = model(X_train)
    loss = criterion(outputs, y_train)  # 计算交叉熵损失

    # 反向传播和优化:清零梯度、反向传播计算梯度,并根据梯度更新模型参数
    optimizer.zero_grad()  # 清零梯度
    loss.backward()        # 反向传播计算梯度
    optimizer.step()       # 根据梯度更新模型参数

    # 每10轮打印一次损失值,查看模型训练的效果
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

3.12 测试模型

# 测试模型,在没有梯度更新的情况下,对测试集进行推断
with torch.no_grad():
    outputs = model(X_test)   # 使用训练好的模型对测试集进行预测
    _, predicted = torch.max(outputs.data, 1)  # 对输出的结果取 argmax,得到预测概率最大的类别
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)  # 计算模型在测试集上的准确率
    print(f'Test Accuracy: {accuracy:.2f}')   # 打印测试集准确率

3.13 控制输出

Epoch [10/100], Loss: 0.1346
Epoch [20/100], Loss: 0.0325
Epoch [30/100], Loss: 0.0116
Epoch [40/100], Loss: 0.0064
Epoch [50/100], Loss: 0.0040
Epoch [60/100], Loss: 0.0029
Epoch [70/100], Loss: 0.0026
Epoch [80/100], Loss: 0.0021
Epoch [90/100], Loss: 0.0019
Epoch [100/100], Loss: 0.0019
Test Accuracy: 1.00

Process finished with exit code 0

正确率:100%

3.14 保存模型

# 保存模型
PATH = "model.pt"
torch.save(model.state_dict(), PATH)

3.15 加载模型

加载模型

model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子 的相关文章

  • Spark 请求最大计数

    我是 Spark 的初学者 我尝试请求允许我检索最常访问的网页 我的要求如下 mostPopularWebPageDF logDF groupBy webPage agg functions count webPage alias cntW
  • Python中Decimal类型的澄清

    每个人都知道 或者至少 每个程序员都应该知道 http docs oracle com cd E19957 01 806 3568 ncg goldberg html 即使用float类型可能会导致精度错误 然而 在某些情况下 精确的解决方
  • 我怎样才能更多地了解Python的内部原理? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我使用Python编程已经有半年多了 我对Python内部更感兴趣 而不是使用Python开发应用程序
  • Python模块可以访问英语词典,包括单词的定义[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 python 模块 它可以帮助我从英语词典中获取单词的定义 当然有enchant 这可以帮助我检查该单词是否存在于英语中
  • 从 ffmpeg 获取实时输出以在进度条中使用(PyQt4,stdout)

    我已经查看了很多问题 但仍然无法完全弄清楚 我正在使用 PyQt 并且希望能够运行ffmpeg i file mp4 file avi并获取流式输出 以便我可以创建进度条 我看过这些问题 ffmpeg可以显示进度条吗 https stack
  • 忽略 Mercurial hook 中的某些 Mercurial 命令

    我有一个像这样的善变钩子 hooks pretxncommit myhook python path to file myhook 代码如下所示 def myhook ui repo kwargs do some stuff 但在我的例子中
  • 在Python中调整图像大小

    我有一张尺寸为 288 352 的图像 我想将其大小调整为 160 240 我尝试了以下代码 im imread abc png img im resize 160 240 Image ANTIALIAS 但它给出了一个错误TypeErro
  • 更改 `base_compiledir` 以将编译后的文件保存在另一个目录中

    theano base compiledir指编译后的文件存放的目录 有没有办法可以永久设置theano base compiledir到不同的位置 也许通过修改一些内部 Theano 文件的内容 http deeplearning net
  • 使用 OLS 回归预测未来值(Python、StatsModels、Pandas)

    我目前正在尝试在 Python 中实现 MLR 但不确定如何将我找到的系数应用于未来值 import pandas as pd import statsmodels formula api as sm import statsmodels
  • TensorFlow的./configure在哪里以及如何启用GPU支持?

    在我的 Ubuntu 上安装 TensorFlow 时 我想将 GPU 与 CUDA 结合使用 但我却停在了这一步官方教程 http www tensorflow org get started os setup md 这到底是哪里 con
  • 如何设置 Celery 来调用自定义工作器初始化?

    我对 Celery 很陌生 我一直在尝试设置一个具有 2 个独立队列的项目 一个用于计算 另一个用于执行 到目前为止 一切都很好 我的问题是执行队列中的工作人员需要实例化一个具有唯一 object id 的类 每个工作人员一个 id 我想知
  • Seaborn Pairplot 图例不显示颜色

    我一直在学习如何在Python中使用seaborn和pairplot 这里的一切似乎都工作正常 但由于某种原因 图例不会显示相关的颜色 我无法找到解决方案 因此如果有人有任何建议 请告诉我 x sns pairplot stats2 hue
  • 创建嵌套字典单行

    您好 我有三个列表 我想使用一行创建一个三级嵌套字典 i e l1 a b l2 1 2 3 l3 d e 我想创建以下嵌套字典 nd a 1 d 0 e 0 2 d 0 e 0 3 d 0 e 0 b a 1 d 0 e 0 2 d 0
  • 使用 Firefox 绕过弹出窗口下载文件:Selenium Python

    我正在使用 selenium 和 python 来从中下载某些文件web page http www oceanenergyireland com testfacility corkharbour observations 我之前一直使用设
  • mac osx 10.8 上的初学者 python

    我正在学习编程 并且一直在使用 Ruby 和 ROR 但我觉得我更喜欢 Python 语言来学习编程 虽然我看到了 Ruby 和 Rails 的优点 但我觉得我需要一种更容易学习编程概念的语言 因此是 Python 但是 我似乎找不到适用于
  • 如何在 OSX 上安装 numpy 和 scipy?

    我是 Mac 新手 请耐心等待 我现在使用的是雪豹 10 6 4 我想安装numpy和scipy 所以我从他们的官方网站下载了python2 6 numpy和scipy dmg文件 但是 我在导入 numpy 时遇到问题 Library F
  • 默认情况下,Keras 自定义层参数是不可训练的吗?

    我在 Keras 中构建了一个简单的自定义层 并惊讶地发现参数默认情况下未设置为可训练 我可以通过显式设置可训练属性来使其工作 我无法通过查看文档或代码来解释为什么会这样 这是应该的样子还是我做错了什么导致默认情况下参数不可训练 代码 im
  • 如何为每个屏幕添加自己的 .py 和 .kv 文件?

    我想为每个屏幕都有一个单独的 py 和 kv 文件 应通过 main py main kv 中的 ScreenManager 选择屏幕 设计应从文件 screen X kv 加载 类等应从文件 screen X py 加载 Screens
  • 当鼠标悬停在上面时,intellisense vscode 不显示参数或文档

    我正在尝试将整个工作流程从 Eclipse 和 Jupyter Notebook 迁移到 VS Code 我安装了 python 扩展 它应该带有 Intellisense 但它只是部分更糟糕 我在输入句点后收到建议 但当将鼠标悬停在其上方
  • 您可以将操作直接应用于map/reduce/filter 中的参数吗?

    map and filter通常可以与列表理解互换 但是reduce并不那么容易被交换map and filter 此外 在某些情况下我仍然更喜欢函数语法 但是 当您需要对参数本身进行操作时 我发现自己正在经历语法体操 最终必须编写整个函数

随机推荐

  • html与python后端交互,python后端中取表单

    参考 http www manongjc com detail 13 owqqwhqvsqworkh html 前端
  • CDZSC_2022寒假个人训练赛21级(1)

    A 题意 略 题解 将n个数加起来的总和除以n即可 include
  • 红帽Linux系统管理员学习哪些内容?

    开源技术现在越来越火 无论是从事DBA 网络运维还是开发 云计算 人工智能等岗位 都需要具备些Linux基础知识 本文主要介绍Redhat Linux系统管理员一般学习哪些内容 Redhat Linux系统管理学习内容 课程概述 一 红帽系
  • 计算机考研经验分享:一战暨南大学(死亡计专),调剂七天上岸华侨大学

    计算机考研经验分享 一战暨南大学 死亡计专 调剂七天上岸华侨大学 前言 这篇文章我本来很早就打算写了 调剂过程只有过来人才懂吧 因此 我希望自己的这篇文章能对看到的人考研有所帮助 我是12号晚上11点左右收到的录取通知 然后13号太兴奋了
  • 使用Element-UI中的Upload控件上传文件 (Vue + Flask)

    知识点 前端 使用 http request 覆盖默认的上传行为 可以自定义上传的实现 使用 DataForm 携带需要上传的文件 需要将http request 的 headers中的content type 设置为 content ty
  • 微信小程序简介

    一 了解微信小程序微信小程序 小程序的一种 英文名Wechat Mini Program 是一种不需要下载安装即可使用的应用 张小龙 发布时间2017年1月9日 二 微信小程序和普通H5的区别1 微信小程序用开发者工具来查看预览页面 H5用
  • R中的统计模型

    R中的统计模型 这一部分假定读者已经对统计方法 特别是回归分析和方差分析有一定的了解 后面我们还会假定读者对广义线性模型和非线性模型也有所了解 R已经很好地定义了统计模型拟合中的一些前提条件 因此我们能构建出一些通用的方法以用于各种问题 R
  • 股票资金建仓分仓补仓计算器

    软件演示图 百度网盘下载地址 http pan baidu com s 1o8Prq6A 软件功能原理与应用价值 我们每个人买股票基本很难做到一买就涨的 买了后可能会下跌一波段再涨 则此就会另到我们时常赚不到钱而纠结卖出一分钱都不能获利而离
  • grep常用需要转义字符汇总

    最近用grep的时候发现转义非常恶心 干脆做个测试 统计一下表示特殊语意时 需要转义的字符 这里的特殊语意是指非匹配自己本身 有特殊含义的时候
  • 神经网络参数理解与设置

    一 超参数 1 学习率 每次迭代的步长 决定着目标函数能否收敛到局部最小值以及何时收敛到最小值 学习率越高 步长越大 2 batch 当训练数据过多时 无法一次将所有的数据送入计算 所以需要将数据分成几个部分 多个batch 逐一地送入计算
  • excel重复的数据只计数一次_你还在加班核对重复数据?3个Excel技巧教你快速进行数据查重...

    相信使用Excel办公的同学 绝大多数都会碰到一个问题 它就是数据重复值的问题 因为数据里面有重复内容 经常会让我们的工作变得非常的棘手 如上图所示 里面是我们仓库发出的单号 我们需要里面就有包含重复发货的单号 如果我们单凭肉眼去看基本是不
  • 联想电脑安装虚拟机出现不可恢复的错误

    VMware Workstation 不可恢复错误 vcpu 0 vcpu 0 VERIFY vmcore vmm main cpuid c 376 bugNr 1036521 日志文件位于 F centos vmware log 中 您可
  • websocket协议与实现原理

    文章目录 一 websocket 二 websocket的协议实现 websocket的协议格式 websocket如何验证客户端合法 websocket传输的明文和密文的传输 websocket如何断开 实现 一 websocket we
  • __attribute__((__aligned__(n)))对结构体对齐的影响

    1 attribute 是什么 attribute 是GCC里的编译参数 用法有很多种 感兴趣可以阅读一下gcc的相关文档 这里说一下 attribute 对变量和结构体对齐的影响 这里的影响大概分为两个方面 对齐和本身占用的字节数的大小
  • android ndk常见的问题及解决的方法

    原文 http blog csdn net fangyuanseu article details 6857911 在ndk编译的过程中遇到的一些问题 1 在用ndk build编译的时候 被编译文件的路径中不能包含空格 如果包含有空格将会
  • Content-Type的几种常用数据编码格式

    Content Type 内容类型 一般是指网页中存在的Content Type ContentType属性指定请求和响应的HTTP内容类型 如果未指定 ContentType 默认为text html 1 text html 文本方式的网
  • Ubuntu 环境下使用中文输入法

    Ubuntu 环境下使用中文输入法 安装fcitx 1 进入系统设置 gt 语言支持 将汉语 中国 拖到最上面 如果列表中没有 选择 添加或删除语言 来添加 2 切换键盘输入法系统 将其修改为fcitx 如果下拉框中没有显示fcitx 则需
  • java poi导入Excel、导出excel

    java poi导入Excel 导出excel 导出meven架包
  • hive中判断一个字符串是否包含另一个子串的四种方法,sql中也可用

    hive中判断一个字符串是否包含另一个子串的四种方法 如果你有一个数据需求 需要从一个字段中 判断是否有一个字符串 你该怎么做 一 方法1 like和rlike 最能想到的方法 用like或者rlike select i want to t
  • Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

    1 红酒数据介绍 经典的红酒分类数据集是指UCI机器学习库中的Wine数据集 该数据集包含178个样本 每个样本有13个特征 可以用于分类任务 具体每个字段的含义如下 alcohol 酒精含量百分比 malic acid 苹果酸含量 克 升