Kaggle手势符号识别项目实战

2023-11-06

项目数据集地址:https://www.kaggle.com/datasets/ardamavi/sign-language-digits-dataset

观察到数据集已经做过预先的整理,十分工整,txt文件中类别标记清晰详细

 项目文件如上图所示,接下来分文件展示代码。

dataset.py

import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Lambda
from PIL import Image

ANNOTATIONS_FILE = "./images/train.txt"
IMG_DAR = "./images/train"


class CustomImageDataset(Dataset):
    def __init__(self):
        with open(ANNOTATIONS_FILE, "r") as f:
            # 读取标签文件 读取一行,去掉结尾的\n 然后根据空格分割为图片地址和标签
            self.labels = [line.strip('\n').split(" ") for line in f.readlines()]  
        self.img_dir = IMG_DAR       # 图片地址
        self.transform = ToTensor()  # 图片转化方法
        self.target_transform = Lambda(lambda y: int(y))  # 标签转换方法

    # __len__ 方法返回数据集的总长度
    def __len__(self):
        return len(self.labels)

    # __getitem__ 方法使数据集可以使用下表索引,返回值为一个样本
    def __getitem__(self, idx):
        image = Image.open(self.labels[idx][0])
        label = self.labels[idx][1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        sample = {'image': image, 'label': label}
        return sample

model.py 

from torch import nn


class CnnNet(nn.Module):
    def __init__(self):
        super(CnnNet, self).__init__()
        self.conve1 = nn.Sequential(
            # Conv2d 卷积神经网络, 
            #rgb色彩,所以是3
            #位深
            #卷积核的大小 5*5
            nn.Conv2d(3, 24, 5, padding=2), 
            # 归一化层:数据经过处理之后就会变成均值为零方差为一的正态分布,防止梯度消失
            nn.BatchNorm2d(24),
            # 激活函数
            nn.ReLU()
        )
        # 池化层:作用是防止过拟合和减小训练参数
        # 原来的图片是64*64 的 现在已经变成了32 * 32的 
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conve2 = nn.Sequential(
            nn.Conv2d(24, 48, 3, padding=1),
            nn.BatchNorm2d(48),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Sequential(
            #输出的“图片”16 * 16 * 48 的大小,全连接层开始时有16*16*48
            nn.Linear(48 * 16 * 16, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 6)
        )

    def forward(self, x):
        x = self.conve1(x)
        x = self.pool1(x)
        x = self.conve2(x)
        x = self.pool2(x)
        # 把原先tensor中的数据按照行优先的顺序排成一个一维的数据
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from dataset import CustomImageDataset
from model import MyCnnNet


def train(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    for i in range(0, 100):
        for batch, data in enumerate(dataloader):
            pred = model(data['image'].to(device))  
            loss = loss_fn(pred, data['label'].to(device))

            # 清空一下梯度
            optimizer.zero_grad()
            # 进行反向传播和模型优化
            loss.backward()
            optimizer.step()

            # 每隔一段时间输出一下过程
            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(data['image'])
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    # 保存模型
    torch.save(model, './model.pkl')


if __name__ == '__main__':
    model = MyCnnNet()
    device = torch.device('cuda:0')
    model.to(device)

    train_dataloader = DataLoader(CustomImageDataset(), batch_size=32)

    loss_fn = nn.CrossEntropyLoss()#交叉熵

    learning_rate = 1e-3
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    #optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # 调用训练方法
    train(train_dataloader, model, loss_fn, optimizer, device)

模型会存储在./路径下,调用模型使用torch.load()既可运行识别自己的图像。

测试

运行以下测试代码片段

import torch
from PIL import Image
from torchvision.transforms import ToTensor, Lambda
# 1. 模型的加载
model = torch.load('./model.pkl')

# 2. 加载图片并且转化为tensor
transform = ToTensor()
img_in = Image.open("./images/test/signs/img_0008.png")
device = torch.device('cuda:0')
# 使用 unsqueeze(0)将序列扩充一维,在训练时我们使用的是一组图片,这里的一张成组
img_in = transform(img_in).unsqueeze(0).to(device)

# 3. 将图片扔到模型里得到输出
out = model(img_in)
print(out)

结果:

tensor([[-4.7332, -5.4074, -2.5515,  3.7113,  6.9314,  9.3520]],
       device='cuda:0', grad_fn=<AddmmBackward>)

概率最高的为数字5

 经检查结果与实际图像匹配。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Kaggle手势符号识别项目实战 的相关文章

  • numba 函数何时编译?

    我正在研究这个例子 http numba pydata org numba doc 0 15 1 examples html multi threading http numba pydata org numba doc 0 15 1 ex
  • Spark 请求最大计数

    我是 Spark 的初学者 我尝试请求允许我检索最常访问的网页 我的要求如下 mostPopularWebPageDF logDF groupBy webPage agg functions count webPage alias cntW
  • 围绕 readline 构建的 python 批处理的触发器选项卡完成

    背景 我有一个 python 程序 它导入并使用 readline 模块来构建自制的命令行界面 我有第二个 python 程序 围绕 Bottle 一个 Web 微框架构建 充当该 CLI 的前端 第二个 python 程序向第一个程序打开
  • django_openid_auth TypeError openid.yadis.manager.YadisServiceManager 对象不是 JSON 可序列化

    I used django openid auth在我的项目上 一段时间以来它运行得很好 但今天 我测试了该应用程序并遇到了这个异常 Environment Request Method GET Request URL http local
  • 使用 python 进行串行数据记录

    Intro 我需要编写一个小程序来实时读取串行数据并将其写入文本文件 我在读取数据方面取得了一些进展 但尚未成功地将这些信息存储在新文件中 这是我的代码 from future import print function import se
  • python future 和元组解包

    实现像使用 future 进行元组解包这样的事情的优雅 惯用的方法是什么 我有这样的代码 a b c f x y g a b z h y c 我想将其转换为使用期货 理想情况下我想写一些类似的东西 a b c ex submit f x y
  • python 模拟第三方模块

    我正在尝试测试一些处理推文的类 我使用 Sixohsix twitter 来处理 Twitter API 我有一个类充当 Twitter 类的外观 我的想法是模拟实际的 Sixohsix 类 通过随机生成新推文或从数据库检索它们来模拟推文的
  • Django 模型在模板中不可迭代

    我试图迭代模型以获取列表中的第一个图像 但它给了我错误 即模型不可迭代 以下是我的模型和模板的代码 我只需要获取与单个产品相关的列表中的第一个图像 模型 py class Product models Model title models
  • Argparse nargs="+" 正在吃位置参数

    这是我的解析器配置的一小部分 parser add argument infile help The file to be imported type argparse FileType r default sys stdin parser
  • 忽略 Mercurial hook 中的某些 Mercurial 命令

    我有一个像这样的善变钩子 hooks pretxncommit myhook python path to file myhook 代码如下所示 def myhook ui repo kwargs do some stuff 但在我的例子中
  • 如何计算numpy数组中元素的频率?

    我有一个 3 D numpy 数组 其中包含重复的元素 counterTraj shape 13530 1 1 例如 counterTraj 包含这样的元素 我只显示了几个元素 array 136 129 130 103 102 101 我
  • 切片 Dataframe 时出现 KeyError

    我的代码如下所示 d pd read csv Collector Output csv df pd DataFrame data d dfa df copy dfa dfa rename columns OBJECTID Object ID
  • 在Python中调整图像大小

    我有一张尺寸为 288 352 的图像 我想将其大小调整为 160 240 我尝试了以下代码 im imread abc png img im resize 160 240 Image ANTIALIAS 但它给出了一个错误TypeErro
  • 如何从Python中的字符串中提取变量名称和值

    我有一根绳子 data var1 id 12345 name John White python中有没有办法将var1提取为python变量 更具体地说 我对字典变量感兴趣 这样我就可以获得变量的值 id和name python 这是由提供
  • Numpy 过滤器平滑零区域

    我有一个 0 及更大整数的 2D numpy 数组 其中值代表区域标签 例如 array 9 9 9 0 0 0 0 1 1 1 9 9 9 9 0 7 1 1 1 1 9 9 9 9 0 2 2 1 1 1 9 9 9 8 0 2 2 1
  • 首先对列表中最长的项目进行排序

    我正在使用 lambda 来修改排序的行为 sorted list key lambda item item lower len item 对包含元素的列表进行排序A1 A2 A3 A B1 B2 B3 B 结果是A A1 A2 A3 B
  • 将 matplotlib 颜色图集中在特定值上

    我正在使用 matplotlib 颜色图 seismic 绘制绘图 并且希望白色以 0 为中心 当我在不进行任何更改的情况下运行脚本时 白色从 0 下降到 10 我尝试设置 vmin 50 vmax 50 但在这种情况下我完全失去了白色 关
  • 使用 NumPy 将非均匀数据从文件读取到数组中

    假设我有一个如下所示的文本文件 33 346 1223 10 23 11 23 12 23 13 23 14 23 15 23 16 24 10 24 11 24 12 24 13 24 14 24 15 24 16 25 14 25 15
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0
  • 具有自定义值的 Django 管理外键下拉列表

    我有 3 个 Django 模型 class Test models Model pass class Page models Model test models ForeignKey Test class Question model M

随机推荐

  • 第一章:VUE3学习(一)---Nodejs安装以及环境变量配置

    Nodejs安装以及环境变量配置 1 下载Nodejs 1 1最新版下载 1 2历史版本下载 2 安装 3 验证 4 环境变量配置 5 npm下载设置 6 测试 6 设置国内镜像提高下载速度 1 下载Nodejs 1 1最新版下载 直接官网
  • 用QT写一个类似的安装向导界面

    本文目录 功能描述 功能实现 框架 功能1 点击同意协议 才能进行下一步 功能2 选中指定路径的文件夹 并遍历该文件夹下所有的文件 功能3 设置进度条 功能4 两种激活方式 完整代码 功能描述 1 点击同意协议 才能进行下一步 2 选择一个
  • 2020软件测试学习自学路线分享,附完整资料,绝对有用哟

    2020软件测试学习路线图 内附自学路线 视频 工具经验 面试篇 划重点 资源链接 黑马程序员社区 想毕业后做测试相关的工作的 找学习资源找的头大 还好终于找到这么优质的可以系统地学习测试知识的途径 想学测试的小伙伴看看 真的可以跟着一步步
  • 误差向量幅度(EVM)

    转自 http blog sina com cn s blog 6c46cb860100otm3 html 误差向量幅度 EVM 误差向量 包括幅度和相位的矢量 是在一个给定时刻理想无误差基准信号与实际发射信号的向量差 Error Vect
  • 微信小程序添加插件腾讯位置服务路线规划,找不着的solution

    第一个 找到网页点击添加插件 提示类别不一样pass 第二个 在后台管理添加插件 提示找不着 pass 这两方法都不行 解决方法是 开发者后台登陆后 右上角服务 进入微信服务市场 选择开发者资源 然后选择插件 搜索腾讯位置服务路线规划 亲测
  • 3045 Lcm与Gcd构造

    已知 gcd a b n lcm a b m 求min a b 是多少 通过gcd的了解我们可以知道 两个数a k1 n以及b k2 n并且gcd k1 k2 1 ab n m m a b n ab k1 k2 n n 于是可以得到 m k
  • Yii Framework 开发教程(44) Zii组件-Resizable示例

    CJuiResizable可以使包含在其中的UI组件支持缩放功能 它封装了 JUI Resizable插件 CJuiResizable基本使用方法如下 php view plain copy print
  • Anaconda Prompt的用法

    Windows 开始菜单 打开Anaconda Prompt 这个窗口和cmd窗口一样的 用命令 conda list 查看已安装的包 从这些库中我们可以发现NumPy Matplotlib Pandas 说明已经安装成功了 下一步可以测试
  • ACM入门攻略(紫书入门,不间断更新)

    声明 本文仅供参考 并且假定读者已经可以熟练运用C语言及其相关知识 大神请走开 谢谢配合 目录 一 ACM入门的相关准备 书籍 OJ 编程语言 常用网站或工具 二 入门阶段的学习路线及其策略 全文以紫书为例 1 紫书第五章语言篇写题策略 2
  • JS之对象-对象声明及静态方法

    声明对象 1 原型实例化 声明对象的方式1 原型实例化 let obj1 new Object obj1 name obj1 张三 obj1 getName function return this name console log obj
  • 八十七.查找与排序习题总结(二)

    查找与排序习题总结 一 查找与排序习题总结 三 题一 调整数组顺序 奇数在左 偶数在右 调整数组的顺序使奇数位于偶数前面 输入一个整数数组 调整数组中数字的顺序使得所有奇数位于数组的前半部分 所有偶数位于数组的后半部分 要求时间复杂度为O
  • Rot.js 随机地牢,迷宫地图生成

    js 插件随机地牢 迷宫地图生成 插件git https github com ondras rot js tree master dist 使用 1 我们的游戏是在网页内进行的 一个基本的 HTML 文件就足够了
  • SPDK块设备

    SPDK视角每个App由多个子系统 subsystem 构成 同时每个子系统又包含多个模块 module 子系统和模块的注入都是可插拔的 通过相关的宏定义声明集成到SPDK组件容器里 其中子系统的注入可通过声明SPDK SUBSYSTEM
  • MMDeploy部署实战系列【第一章】:Docker,Nvidia-docker安装

    MMDeploy部署实战系列 第一章 Docker Nvidia docker安装 这个系列是一个随笔 是我走过的一些路 有些地方可能不太完善 如果有那个地方没看懂 评论区问就可以 我给补充 版权声明 本文为博主原创文章 遵循 CC 4 0
  • Type cannot use 'try' with exceptions disabled

    cannot use throw with exceptions disabled 在为 DragonBonesCPP refactoring 的 cocos2d x 3 2 demo 增加 Android 编译时 NDK 报了一个编译错误
  • 数据结构刷题训练营1

    开启蓝桥杯备战计划 每日练习算法一题 坚持下去 想必下一年的蓝桥杯将会有你 笔者是在力扣上面进行的刷题 由于是第一次刷题 找到的题目也不咋样 所以 就凑合凑合吧 笔者打算从数据结构开始刷起 毕竟现在刚刚接触到数据结构 在力扣上找到的刷题链接
  • 计算机方面英语文献翻译(学习记录更新中)

    在万方找的英文文献摘要 自己翻译的 1 考虑到时间序列数据的高维度和复杂性给数据挖掘带来的困难以及聚类分析在时间序列数据挖掘领域中的重要性 本文总结了国内外时间序列数据聚类的研究现状 时间序列聚类可以被分为全时间序列聚类和子序列聚类 并且可
  • Python流体动力学共形映射库埃特式流

    流体动力学简述 在物理学和工程学中 流体动力学是流体力学的一个分支学科 它描述了流体 液体和气体的流动 它有几个子学科 包括空气动力学 研究空气和其他运动中的气体 和流体动力学 研究运动中的液体 流体动力学具有广泛的应用 包括计算飞机上的力
  • 携程酒店数据爬取2020.5

    携程酒店数据爬取2020 5 1 开题 目前网上有好多爬取携程网站的教程 大多数通过xpath beautifulsoup 正则来解析网页的源代码 然后我这个菜b贪方便 直接copy源码的xpath paste在xpath helper改改
  • Kaggle手势符号识别项目实战

    项目数据集地址 https www kaggle com datasets ardamavi sign language digits dataset 观察到数据集已经做过预先的整理 十分工整 txt文件中类别标记清晰详细 项目文件如上图所