R语言中的Softmax Regression建模(MNIST手写体识别和文档多分类应用)

2023-05-16

关于softmax regression的数学模型部分可以参考Stanford的中英文Wiki:

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

 

softmaxregR包的下载地址:

https://cran.r-project.org/web/packages/softmaxreg/index.html


一、介绍

        Softmax Regression模型本质还是一个多分类模型,对Logistic Regression 逻辑回归的拓展。如果将Softmax Regression模型和神经网络隐含层结合起来,可以进一步提升模型的性能,构成包含多个隐含层和最后一个Softmax层的多层神经网络模型。之前发现R里面没有特别适合的方法支持多层的Softmax 模型,于是就想直接用R语言写一个softmaxreg 包。目前可以支持大部分的多分类问题,如下面两个示例:MNIST手写体识别和多文档分类(Multi-Class DocumentClassification) 。

 

二、示例

2.1 MNIST手写体识别数据集

        MNIST手写体识别的数据集是图像识别领域一个基本数据集,很多模型诸如CNN卷积神经网络等模型都经常在这个数据集上测试都能够达到97%以上的准确率。 这里想比较一下包含隐含层的softmaxreg模型,测试结果显示模型的准确率能达到93% 左右。

 

Part1、下载和Load数据

        MNIST手写体识别的数据集可以直接从网站下载http://yann.lecun.com/exdb/mnist/,一共四个文件,分别下载下来并解压。文件格式比较特殊,可以用softmaxreg 包中的load_image_file 和load_label_file 两个函数读取。(读取MNIST数据方法参考 Reference: brendano'connor - gist.github.com/39760)

        训练集有60000幅图片,每个图片都是由16*16个像素构成,代表了0-9中的某一个数字,比如下图。



      利用softmaxreg 包训练一个10分类的MNIST手写体识别的模型,用load_image_file 和load_label_file 来分别读取训练集的图像数据和标签的数据。

library(softmaxreg)
path = "D: \\DeepLearning\\MNIST\\"
#10-class classification, Digit 0-9
x = load_image_file(paste(path,'train-images-idx3-ubyte', sep=""))
y = load_label_file(paste(path,'train-labels-idx1-ubyte', sep=""))
xTest = load_image_file(paste(path,'t10k-images-idx3-ubyte',sep=""))
yTest = load_label_file(paste(path,'t10k-labels-idx1-ubyte', sep=""))


        可以用show_digit函数来看一个数字的图像,比如查看某一个图片,比如第2幅图片

show_digit(x[2,])

 

Part2、训练模型

        利用softmaxReg函数,训练集输入和标签分别为为x和y,maxit 设置最多多少个Epoch, algorithm为优化的算法,rate为学习率,batch参数为SGD随机梯度下降每个Mini-Batch的样本个数。 收敛后用predict方法来看看测试集Test的准确率怎么样

## Normalize Input Data
x = x/255
xTest = xTest/255
model1 = softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1, type = "class", algorithm = "sgd", rate = 0.01, batch = 1000)
loss1 = model1$loss

# Test Accuracy
yFit = predict(model1, newdata = x)
table(y, yFit)


Part3、比较不同优化算法的收敛速度

model2 = softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1, type = "class", algorithm = "adagrad", rate = 0.01, batch = 1000)
loss2 = model2$loss
model3 = softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1, type = "class", algorithm = "rmsprop", rate = 0.01, batch = 1000)
loss3 = model3$loss
model4 = softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1, type = "class", algorithm = "momentum", rate = 0.01, batch = 1000)
loss4 = model4$loss
model5 = softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1, type = "class", algorithm = "nag", rate = 0.01, batch = 1000)
loss5 = model5$loss

# plot the loss convergence
iteration = c(1:length(loss1))
myplot = plot(iteration, loss1, xlab = "iteration", ylab = "loss", ylim = c(0, max(loss1,loss2,loss3,loss4,loss5) + 0.01), 
    type = "p", col = "black", cex = 0.7)
title("Convergence Comparision Between Learning Algorithms")
points(iteration, loss2, col = "red", pch = 2, cex = 0.7)
points(iteration, loss3, col = "blue", pch = 3, cex = 0.7)
points(iteration, loss4, col = "green", pch = 4, cex = 0.7)
points(iteration, loss5, col = "magenta", pch = 5, cex = 0.7)

legend("topright", c("SGD", "Adagrad", "RMSprop", "Momentum", "NAG"), 
col = c("black", "red", "blue", "green", "magenta"),pch = c(1,2,3,4,5))
save.image()


      如果maxit 迭代次数过大,模型运行时间较长,可以保存图像,最后可以看到AdaGrad, rmsprop,momentum, nag 和标准SGD这几种优化算法的收敛速度的比较效果。关于优化算法这个帖子有很好的总结:

http://cs231n.github.io/neural-networks-3/



2.2 多类别的文档分类

        Softmax regression模型的每个输入为一个文档,用一个字符串表示。其中每个词word都可以用一个word2vec模型训练的word Embedding低维度的实数词向量表示。在softmaxreg包中有一个预先训练好的模型:长度为20维的英文词向量的字典,直接用data(word2vec) 调用就可以了。

        假设我们需要对UCI的C50新闻数据集进行分类,数据集包含多个作者写的新闻报道,每个作者的新闻文件都在一个单独的文件夹中。 我们假设挑选5个作者的文章进行训练softmax regression 模型,然后在测试集中预测任意文档属于哪一个作者,这就构成了一个5分类的问题。

 

Part1, 载入预先训练好的 英文word2vec 字典表

library(softmaxreg)
data(word2vec) # default 20 dimension word2vec dataset
#### Reuter 50 DataSet UCI Archived Dataset from


Part2,利用loadURLData函数从网址下载数据并且解压到folder目录

URL = "http://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip"
folder = getwd()
loadURLData(URL, folder, unzip = TRUE)

 

Part3,利用wordEmbed() 函数作为lookup table,从默认的word2vec数据集中查找每个单词的向量表示,默认20维度,可以自己训练自己的字典数据集来替换。

##Training Data
subFoler = c('AaronPressman', 'AlanCrosby', 'AlexanderSmith', 'BenjaminKangLim', 'BernardHickey')

docTrain = document(path = paste(folder, "/C50train/",subFoler, sep = ""), pattern = 'txt')

xTrain = wordEmbed(docTrain, dictionary = word2vec)
yTrain = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
# Assign labels to 5 different authors

##Testing Data
docTest = document(path = paste(folder, "/C50test/",subFoler, sep = ""), pattern = 'txt')
xTest = wordEmbed(docTest, dictionary = word2vec)
yTest = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
samp = sample(250, 50)
xTest = xTest[samp,]
yTest = yTest[samp]


Part4,训练模型,构建一个结构为20-10-5的模型,输入层为20维,即词向量的维度,隐含层的节点数为10,最后softmax层输出节点个数为5.

## Train Softmax Classification Model, 20-10-5
softmax_model = softmaxReg(xTrain, yTrain, hidden = c(10), maxit = 500, type = "class",
algorithm = "nag", rate = 0.05, batch = 10, L2 = TRUE)
summary(softmax_model)
yFit = predict(softmax_model, newdata = xTrain)
table(yTrain, yFit)
## Testing
yPred = predict(softmax_model, newdata = xTest)
table(yTest, yPred)


word2vec 文件也可以用自己训练的word2vec 词向量的字典模型导入,增加embedding的维度到50或者100可以提升模型准确度;


CRAN 文档地址

https://cran.r-project.org/web/packages/softmaxreg/softmaxreg.pdf

http://blog.csdn.net/rockingdingo/article/details/52769178



本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

R语言中的Softmax Regression建模(MNIST手写体识别和文档多分类应用) 的相关文章

随机推荐

  • 【报错】resultMap认知错误

    数据库改了一个字段的名字 xff0c 后来牵扯到实体类标准化都要改 xff0c 原来以为 xff0c mybatis使用的sql语句都是通过resultMap映射后 xff0c 可以使用后面的property xff0c 因为之前colum
  • Spring框架知识点

    1 Spring概述 1 1 什么是框架 xff1f 框架 xff08 Framework xff09 xff1a 框 xff08 指其约束性 xff09 架 xff08 指其支撑性 xff09 xff0c 在软件设计中指为解决一个开放性问
  • 前端脚手架开发工具包

    文章目录 文件操作fs extrareaddirpevent stream监听文件变化 chokidar 文件匹配glob 远程下载模板代码download git repo 命令行参数解析 minimist轻量级的命令行参数解析引擎 Co
  • 关于联想Y7000P睡眠后无法唤醒问题修复

    这个新的机器是WINDOWS11的 xff0c 症状了自己睡眠后就醒不过来了 xff0c 于是我找到了公众号 xff0c 提示下载一个软件修复驱动 xff0c http tools lenovo com cn tools exeTools
  • GPU渲染管线之旅|05 图元处理、Clip/Cull, 投影和视图变换

    上一篇中我们讨论了关于 纹理和采样 xff0c 这一篇我们回到3D管线的前端 在执行完顶点着色之后 xff0c 就可以实际的渲染东西了 xff0c 对吗 xff1f 暂时还不行 xff0c 因为在我们实际开始光栅化图元之前 xff0c 仍然
  • mac地址的作用

    最近读一本关于linux编程的书籍 xff0c 看到一部分很迷茫 xff0c 忽然不知道mac地址的作用 xff0c 既然已经有了ip地址了要mac地址何用呢 xff1f MAC地址是数据链路层的地址 xff0c 如果mac地址不可直达 直
  • 谈谈OpenCV中的四边形

    首先抛出一个问题 xff0c 给定一系列二维平面上的的点 xff0c 这些点是可以组成一个封闭的二维图形 因为这些点是矩形区域拍摄图像后识别得到的图形的边界点 xff0c 所以我们要抽象出来这个矩形 xff0c 也就是我们要反映出这个矩形
  • GPU渲染管线之旅|07 深度处理、模板处理

    在这一篇中 xff0c 我们来讨论Z pipline的前端部分 简称它为early Z 以及它是在光栅化中怎么起作用的 和上一篇一样 xff0c 本篇也不会按实际的管道顺序进行讨论 xff1b 我将首先描述基础算法 xff0c 然后再补充管
  • GPU渲染管线之旅|08 Pixel Shader

    在这一部分中 xff0c 我们来谈谈像素处理的前半部分 dispatch和实际的像素着色 事实上 xff0c 这部分是大多数图形开发者在谈到PS stage时所关心的内容 有关alpha blend和Late Z的内容则会下一篇文章中去探讨
  • MFC基于CSplitterWnd类的多窗口分割

    使用平台 xff1a win7 64bit 使用环境 xff1a VS2012 1 CSplitterWnd介绍 上图是从MSDN中截取的类的继承图表 xff0c CSplitterWnd类继承自CWnd类 这个类主要就是提供窗口分割的功能
  • OpenCV - 区域生长算法

    1 理论基础 区域生长算法的基本思想是将有相似性质的像素点合并到一起 对每一个区域要先指定一个种子点作为生长的起点 xff0c 然后将种子点周围领域的像素点和种子点进行对比 xff0c 将具有相似性质的点合并起来继续向外生长 xff0c 直
  • 不规则Contours内部像素的操作

    在findContours函数使用了之后 xff0c 有时候就会面临对Contours内部区域的访问 由于contours不一定是凸图形 xff0c 所以使用循环操作的时候总感觉不那么方便 比如在下图中 xff0c 已经使用findCont
  • Ubuntu 16.04 使用

    这篇博客用来专门记录尝试搬迁工作环境到Linux下的使用笔记 xff0c 主要包含有常用软件的安装 xff0c 配置 1 安装输入法 ubuntu 16 04中支持ibus输入系统 1 系统 gt 首选项 gt IBus设置 在弹出的IBu
  • 牛顿迭代法求解方程

    说明 xff1a 该篇博客源于博主的早些时候的一个csdn博客中的一篇 xff0c 由于近期使用到了 xff0c 所以再次作一总结 原文地址 概述 牛顿迭代法 xff08 Newton s method xff09 又称为牛顿 拉夫逊 xf
  • OpenCV - 均值迭代分割

    题外话 之前在博客中写过一篇 区域生长 的博客 xff0c 区域生长在平时经常用到 xff0c 也比较容易理解和代码实现 xff0c 所以在很多情况下大家会选择这种方法 但是区域生长有一个最致命的点就是需要选取一个生长的种子点 为了交流学习
  • [常见Bug]Kotlin,编译报错“Unresolved reference: ......”的解决方法

    注 xff1a 第1 2种情况较常见 第1种可能的情况 原因 xff1a Android Studio中目前的Kotlin插件版本 和 kotlin gradle plugin版本不一致 当版本不一致时 xff0c 检查build grad
  • mac xcode出现xxx.h没有出现的问题

    mac xcode出现xxx h没有出现的问题 xff0c 在命令行使用g 43 43 lxx编译代码却完全没有问题 xff0c 得出的结论是xcode的配置不到位 找了半个小时没找到莫名的烦躁 xff0c 使用xarman studio
  • 关于hive数据导入的小实验

    首先在自己本地路径编写2个数据文件 xff1a pv txt xff1a 1 111 2 111 1 222 user txt 111 25 111 18 222 32 然后hive中直接创建对应的2个表pv和users xff1a cre
  • MySQL索引(什么是索引、如何创建索引、什么时候用索引、索引的作用)

    1什么是索引 xff1f 简单来讲就是排好序的快速查找数据结构 2索引的优势劣势 3索引分类和创建索引的命令 4 BTree索引检索原理 5 那些情况适合索引 6 哪些情况不适合创建索引 下图是关于第三种情况的一个计算选择性的公式 xff0
  • R语言中的Softmax Regression建模(MNIST手写体识别和文档多分类应用)

    关于softmax regression的数学模型部分可以参考Stanford的中英文Wiki http ufldl stanford edu wiki index php Softmax E5 9B 9E E5 BD 92 softmax