基于pytorch的手势识别

2023-11-13

本次实验主要是使用pytorch完成手势识别。网络包含两个隐藏层,第一层隐藏层有576个节点,第二层隐藏层有144个节点,输入784个节点(图片大小为28×28),输出10个节点(10种手势)。

目录

1. 数据集处理

2. 神经网络的建立

3. 神经网络的训练

4. 神经网络的测试


1. 数据集处理

本次实验所用数据集为自建数据集,首先预览了解数据,确保数据能够被正常载入。

import pandas
from torch.utils.data import Dataset

import torch
import matplotlib.pyplot as plt


class GestureDataset(Dataset):
    def __init__(self, csv_file):
        self.dataset = pandas.read_csv(csv_file, header=0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # 图像标签
        label = self.dataset.iloc[index, 0]
        target = torch.zeros(10)  # 神经网络预期输出
        target[label] = 1.0
        # 图像数据,取值范围是0~255,标准化为0~1
        image_value = torch.FloatTensor(self.dataset.iloc[index, 1:].values) / 255
        # 返回标签、图像数据张量以及目标张量
        return label, image_value, target

    def plot_image(self, index):
        arr = self.dataset.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(self.dataset.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='gray')
        plt.show()


# 查看图片
gesture_dataset = GestureDataset('train.csv')
gesture_dataset.plot_image(9)
print(gesture_dataset[100])
print(len(gesture_dataset))

以上代码中各函数含义如下:

__len__() 函数的作用是返回DataFrame的大小。

__getitem__()函数索引获取数据集中的第 n 项,数据集中的第index项中提取一个标签(label)。返回值中的 target 表示神经网络的预期输出。除了与标签相对应的位置是1之外,其他值皆为0。比如手势 2 的 target 应该表示为[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]。

运行以上代码,结果如下:

2. 神经网络的建立

import torch
import torch.nn as nn

import pandas, numpy
import matplotlib.pyplot as plt


class Classifier(nn.Module):
    def __init__(self):
        # 初始化pytorch父类
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(784, 576),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(576),

            nn.Linear(576, 144),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(144),

            nn.Linear(144, 10),
            nn.Sigmoid()
        )

        # 创建损失函数
        self.Loss_function = nn.BCELoss()
        # 优化器
        self.optimiser = torch.optim.Adam(self.parameters(),
                                          lr=0.0001)
        # 记录训练进展的计数器和列表
        self.counter = 0
        self.process = []

    def forward(self, inputs):
        # 直接运行模型
        return self.model(inputs)

    def cnn_train(self, inputs, targets):
        # 计算网络的输出值
        outputs = self.forward(inputs)
        # 计算损失值
        loss = self.Loss_function(outputs, targets)
        # 梯度归零,反向传播,并更新权重
        self.optimiser.zero_grad()  # 梯度全部归零
        loss.backward()
        self.optimiser.step()  # 使用梯度更新可学习参数

        # 每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾,共36080张图片
        self.counter += 1
        if self.counter % 10 == 0:
            self.process.append(loss.item())

        # 在每10000次训练后打印计数器的值,了解训练进展的快慢
        if self.counter % 10000 == 0:
            print("counter=", self.counter)

    # 绘制训练过程的损失值
    def plot_progress(self):
        df = pandas.DataFrame(self.process, columns=["loss"])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.',
                grid=True, yticks=(0, 0.25, 0.5))
        plt.show()

3. 神经网络的训练

数据集处理部分代码保存为gesture_dataset.py,神经网络的建立部分代码保存为gesture_cnn.py,在创建了网络后,需要使用数据集训练网络,并保存网络参数以便后续使用。

import torch

from gesture_dataset import GestureDataset
from gesture_cnn import Classifier

# 创建神经网络
C = Classifier()
gesture_dataset = GestureDataset('train.csv')
# 在数据集训练神经网络
epochs = 3
for i in range(epochs):
    print('training epoch', i + 1, "of", epochs)
    for label, image_data_tensor, target_tensor in gesture_dataset:
        C.cnn_train(image_data_tensor, target_tensor)
    pass
pass

# 绘制分类器损失值
C.plot_progress()
# 保存网络
torch.save(C.model, 'gesture_cnn_model.pkl')

4. 神经网络的测试

from gesture_dataset import GestureDataset

import torch
import pandas
import numpy as np
import matplotlib.pyplot as plt

# 加载测试集数据
test_dataset = GestureDataset('test.csv')
record = 19
test_dataset.plot_image(record)

image_data = test_dataset[record][1]
# 调用训练后的神经网络
cnn_model = torch.load('gesture_cnn_model.pkl')
output = cnn_model(image_data)
# 绘制输出张量
# pandas.DataFrame(output.detach().numpy()).plot(kind='bar',
#                                                legend=False, ylim=(0, 1))
# plt.show()

predict = output.detach().numpy()
print(np.where(predict == np.max(predict))[0][0])

# 测试正确率
T_test = 0
counter_test = 0

for label, image_data_tensor, target_tensor in test_dataset:
    predict = cnn_model(image_data_tensor).detach().numpy()
    if np.where(predict == np.max(predict))[0][0] == label:
        T_test += 1
    pass
    counter_test += 1
    if counter_test % 100 == 0:
        print('counter_test = ', counter_test)
pass

test_accuracy = T_test/len(test_dataset)
print('Test Accuracy = ', test_accuracy)

# 训练集正确率
train_dataset = GestureDataset('train.csv')

T_train = 0
counter_train = 0

for label, image_data_tensor, target_tensor in train_dataset:
    predict = cnn_model(image_data_tensor).detach().numpy()
    if np.where(predict == np.max(predict))[0][0] == label:
        T_train += 1
    pass
    counter_train += 1
    if counter_train % 1000 == 0:
        print('counter_train = ', counter_train)
pass

train_accuracy = T_train/len(train_dataset)
print('Train Accuracy = ', train_accuracy)
print(len(train_dataset))

测试结果如下:

网络最终在训练集上的正确率约为99.83%,在测试集上的正确率约为97.50%,测试结果表明网络性能较好,训练结果较好,最终的手势识别效果较好。

代码注释详细,作者能力有限,如有发现问题欢迎评论提出。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于pytorch的手势识别 的相关文章

  • python:查找围绕某个 GPS 位置的圆的 GPS 坐标的优雅方法

    我有一组以十进制表示的 GPS 坐标 并且我正在寻找一种方法来查找每个位置周围半径可变的圆中的坐标 这是一个例子 http green and energy com downloads test circle html我需要什么 这是一个圆
  • 为什么从 Pandas 1.0 中删除了日期时间?

    我在 pandas 中处理大量数据分析并每天使用 pandas datetime 最近我收到警告 FutureWarning pandas datetime 类已弃用 并将在未来版本中从 pandas 中删除 改为从 datetime 模块
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • 需要在python中找到print或printf的源代码[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我正在做一些我不能完全谈论的事情 我
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 立体太阳图 matplotlib 极坐标图 python

    我正在尝试创建一个与以下类似的简单的立体太阳路径图 http wiki naturalfrequent com wiki Sun Path Diagram http wiki naturalfrequency com wiki Sun Pa
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • 如何使用python在一个文件中写入多行

    如果我知道要写多少行 我就知道如何将多行写入一个文件 但是 当我想写多行时 问题就出现了 但是 我不知道它们会是多少 我正在开发一个应用程序 它从网站上抓取并将结果的链接存储在文本文件中 但是 我们不知道它会回复多少行 我的代码现在如下 r
  • 如何通过索引列表从 dask 数据框中选择数据?

    我想根据索引列表从 dask 数据框中选择行 我怎样才能做到这一点 Example 假设我有以下 dask 数据框 dict A 1 2 3 4 5 6 7 B 2 3 4 5 6 7 8 index x1 a2 x3 c4 x5 y6 x
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • javascript 是否有等效的 __repr__ ?

    我最接近Python的东西repr这是 function User name password this name name this password password User prototype toString function r
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • 从 NumPy ndarray 中选择行

    我只想从 a 中选择某些行NumPy http en wikipedia org wiki NumPy基于第二列中的值的数组 例如 此测试数组的第二列包含从 1 到 10 的整数 gt gt gt test numpy array nump
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • Python ImportError:无法导入名称 __init__.py

    我收到此错误 ImportError cannot import name life table from cdc life tables C Users tony OneDrive Documents Retirement retirem
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重

随机推荐

  • 一起学nRF51xx 10 -  rng

    前言 随机数产生器 RNG 的结构 随机数发生器 RNG 根据内部热产生真实的非确定性随机数噪音 RNG通过触发START任务启动 并通过触发STOP任务停止 当随机数已经生成 它会产生一个VALRDY事件 同时把随机数存入VALUE寄存器
  • 智慧城市领域大单,巨头占尽优势

    智慧城市领域 哪个公司做的比较好 一 前言 二 智慧城市中标大单 清单 三 中标厂商分析 1 华为 2 科大讯飞 3 腾讯 4 阿里 5 中国电科 6 中国电子 7 百度 8 数字广东 四 获取 智慧城市等全套最新解决方案合集 一 前言 在
  • python eclipse+pydev(An error has occurred when creating this preference page)

    Eclipse 安装pydev Help gt Install New Software gt add gt Location http pydev org updates 点击pydev左边的小三角勾选pydev for eclipse
  • Shell init Ubuntu

    echo HISTFILESIZE 99999 gt gt bashrc echo HISTSIZE 99999 gt gt bashrc echo HISTTIMEFORMAT F T gt gt bashrc echo PROMPT C
  • Thrift原理简析(JAVA)

    Apache Thrift是一个跨语言的服务框架 本质上为RPC 同时具有序列化 反序列化机制 当我们开发的service需要开放出去的时候 就会遇到跨语言调用的问题 JAVA语言开发了一个UserService用来提供获取用户信息的服务
  • CUDA编程 基础与实践 学习笔记(十)

    线程束 warp 一个GPU由多个SM组成 一个SM上可以放多个线程块 不同线程块之间并行或顺序执行 一个线程块分为多个线程束 一个线程束由32个线程 有连续的线程号 组成 从更细粒度来看 一个SM以一个线程束为单位产生 管理 调度 执行线
  • Java面向对象 - 封装、继承和多态

    第1关 什么是封装 如何使用封装 相关知识 为了完成本关任务 你需要掌握 1 什么是封装 2 封装的意义 3 实现Java封装的步骤 package case1 public class TestPersonDemo public stat
  • GoLang之”奇怪用法“实践总结

    2013 11 23 wcdj 0 摘要 本文通过对A Tour of Go的实践 总结Go语言的基础用法 1 Go语言 奇怪用法 有哪些 1 go的变量声明顺序是 先写变量名 再写类型名 此与C C 的语法孰优孰劣 可见下文解释 http
  • 销售心理学

    销售中的心理学 影响你一生的销售心理学书籍 要想钓到鱼 就要像鱼一样思考 在生活中 如果想钓到鱼 你就得像鱼那样思考 而不是像渔夫那样思考 当你对鱼了解得越多 你也就越来越会钓鱼了 这样的想法用在销售中同样适用 要知道 销售的过程其实就是销
  • 【Redis17】Redis进阶:管道

    Redis进阶 管道 管道是啥 我们做开发的同学们经常会在 Linux 环境中用到管道命令 比如 ps ef grep php 在之前学习 Laravel框架时的 Laravel6 4 管道过滤器https mp weixin qq com
  • Latex使用

    问题 在使用latex的过程中插入图片 在某些条件下 图片可能会出现越过后续的文字出现在下一页的页首 解决办法 在该tex文件首部加上 usepackage stfloats 然后参数设置成H如下 begin figure H center
  • 使用frp 实现内网穿透 & 将私人电脑变成一个服务器

    使用frp 实现内网穿透 frp 是什么 frp 是一个可用于内网穿透的高性能的反向代理应用 支持 tcp udp 协议 为 http 和 https 应用协议提供了额外的能力 且尝试性支持了点对点穿透 作用 比如你需要用到云服务器部署你的
  • 阅读GFS论文

    GFS论文发表距今已经十几年了 据之开源的hdfs也已经在业界得到了广泛应用 为了取得分布式系统的真经 拜读一下这篇经典论文 重要假设 软硬件失败乃家常便饭 我们写大文件 不屑小文件 文件改动的主流是追加新数据 随机写是非主流 一旦写完 仅
  • Neon Instruction C支持的向量运算

    转载请标明出处 https blog csdn net u013752202 article details 92008843 文章目的 快速索引到需要的向量运算 vadd gt ri ai bi 1 Vector add 正常指令 r a
  • pagehelper使用方法及参数说明

    pagehelper使用方法及参数说明 使用方法 Override public PageInfo
  • spring源码--10--IOC高级特性--autowiring实现原理

    spring源码 10 IOC高级特性 autowiring实现原理 1 Spring IoC容器提供了2种方式 管理Bean的依赖关系 1 1 显式管理 通过BeanDefinition的属性值和构造方法实现Bean依赖关系管理 1 2
  • vue学习笔记:在vscode中使用@提示路径

    在vscode中输入 后如果可以智能提示路径 可以有效防止路径名称输入错误 减少不必要的麻烦 效果如下图所示 安装 Path Autocomplete 插件后可以实现路径的智能提示 步骤如下 1 在vscode中查找Path Autocom
  • 关于shell运行python文件中的错误——shell脚本换行

    问题 https ask csdn net questions 7900411 spm 1001 2014 3001 5505 问题由来 由于工程需要在本地window中写 当需要比较少的算力时在本地跑 当需要比较大的算力时就需要在auto
  • K8S调用GPU资源配置指南

    06 09 K8S调用GPU资源配置指南 时间 版本号 修改描述 修改人 2022年6月9日15 33 12 V0 1 新建K8S调用GPU资源配置指南 编写了Nvidia驱动安装过程 2022年6月10日11 16 52 V0 2 添加K
  • 基于pytorch的手势识别

    本次实验主要是使用pytorch完成手势识别 网络包含两个隐藏层 第一层隐藏层有576个节点 第二层隐藏层有144个节点 输入784个节点 图片大小为28 28 输出10个节点 10种手势 目录 1 数据集处理 2 神经网络的建立 3 神经