mnist example for lstm in caffe

2023-11-17

下面给出在caffe中使用lstm的一个例子,其中数据集采用mnist。

  1. 为了实现mnist数据的序列话,将mnist的每一行看成一帧,每一列则就是该帧的特征矢量。
  2. 在使用lstm时,一定要注意clip_markers,每个序列以0开始,后面接1保持为当前序列。
  3. 损失的计算有两种方式,每个序列的最后一帧参加到损失的计算,或者每个序列中所有帧都参加损失的计算。注意下面代码中最后全连接fc1和损失层中的axis参数,采用第二种方式时,需要将这两个层的axis设置为2。

为了方便测试,这里分享代码和数据:
链接https://pan.baidu.com/s/1grTdZhP4pqZmDzs7-WB31w
提取码:rxyn

训练代码为:
train_mnist_classification.py

#coding=gbk

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio # loadmat
from scipy.misc.pilutil import  *
import h5py
from os.path import join, realpath, dirname 
from caffe import layers as L, params as P
import sys, os
from caffe._caffe import Solver
import scipy.io as io
import time
import caffe
from caffe.proto import caffe_pb2



# enum Engine { DEFAULT = 0; CAFFE = 1; CUDNN = 2; } 
# enum NormRegion { ACROSS_CHANNELS = 0; WITHIN_CHANNEL = 1; } 

# mnist 10class

def make_lstm(trainSource,batchSize,nframe,cross_id,nclass,type1):
    
    n = caffe.NetSpec()
    
    n.data, n.labels, n.clip_markers = L.Python(name='data',ntop=3, 
                                          python_param=dict(module='python_read_data_for_mnist',layer='AllDataLayer',
                                          param_str='{\'phase\': \'train\', \'dataset_name\': \'mnist\', \'data_type\': \'image\',\'batch_size\': '+str(batchSize)
                                          + ',\'cross_id\':'+str(cross_id)+'}'),)
    
    n.fc0 = L.InnerProduct(n.data,name='fc0',num_output=128,
                           weight_filler=dict(type='xavier',std=0.005),
                            bias_filler=dict(type='constant',value=0.1),
                            param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)]
                           )
    
    n.reshape_data = L.Reshape(n.fc0,name='reshape_data',reshape_param={'shape':{'dim':[nframe,batchSize,128]}})
    n.reshape_labels = L.Reshape(n.labels,name='reshape_labels',reshape_param={'shape':{'dim':[nframe,batchSize]}})
    n.reshape_clipmarkers = L.Reshape(n.clip_markers,name='reshape_clipmarkers',reshape_param={'shape':{'dim':[nframe,batchSize]}})
    n.lstm1 = L.LSTM(n.reshape_data,n.reshape_clipmarkers,name='lstm1',recurrent_param={'num_output':64,
                                                                       'weight_filler':{'type':'uniform','min':-0.01,'max':0.01},
                                                                       'bias_filler':{'type':'constant','value':0}})
    # the output of lstm convert to 
    
    if type1 == 'using_last_frame_compute_loss':
    
        n.last_frame_data = L.Python(n.lstm1,name='last_frame_data',ntop=1,
                                  python_param=dict(module='data_separate_for_mnist',layer='data_separate'),
                                  propagate_down=[1]) # for main tasks
        
        n.last_frame_label = L.Python(n.reshape_labels,name='last_frame_label',ntop=1,
                                  python_param=dict(module='label_separate_for_mnist',layer='label_separate'),
                                  propagate_down=[0]) # for main tasks
        
        n.fc1 = L.InnerProduct(n.last_frame_data,name='fc1',num_output=nclass,
                               weight_filler=dict(type='xavier',std=0.005),
                                bias_filler=dict(type='constant',value=0.1),
                                param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)])
        
        n.loss = L.SoftmaxWithLoss(n.fc1,n.last_frame_label,
                                           name='loss',ntop=1)
    
    elif type1 == 'using_all_frame_compute_loss':
        
        n.fc1 = L.InnerProduct(n.lstm1,name='fc1',num_output=nclass,
                               weight_filler=dict(type='xavier',std=0.005),
                                bias_filler=dict(type='constant',value=0.1),
                                param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)],
                                axis=2)
        n.loss = L.SoftmaxWithLoss(n.fc1,n.reshape_labels,
                                           name='loss',ntop=1,
                                           softmax_param={'axis':2})
   
    
    return n.to_proto()




def make_solver(train_net,snapshot,snapshot_prefix):
    
    maxIter = 20000
    
    s = caffe_pb2.SolverParameter()
    s.random_seed = 0xCAFFE
    #s.type = 'Adam'
    s.type = 'SGD'
    s.display = 20
    s.iter_size = 1
    s.base_lr = 0.01 # 0.0005 for customs, 0.0001 for mean3
    s.lr_policy = "step"
    s.gamma = 0.1
    s.momentum = 0.9
    s.stepsize = 5000
    s.max_iter = maxIter
    s.weight_decay = 0.0005
    s.snapshot = snapshot
    s.snapshot_prefix = snapshot_prefix
    s.train_net = train_net
    
    return s

ncross = 1 # 仅进行一次实验。
for nc in range(ncross):

    print('cross_id: ',str(nc+1))
    ##########################
    nclass = 10  # for caspeal: 21 ,for pointing04: 93 , cmupie: 13
    zoo_path = 'ZOO_lstm/'
    snap_path = 'snapshots_lstm_for_mnist_all/'
    pretrained_weight = None
    
    debug_suf = ''
    trainSource = None  
    snapshot = 1000000 # not use
    niter = 10000 #5000
    snap_interval = 2000
    isShow = True
    batchSize = 256
    nframe = 28 # imgSize[0]
    type1 = 'using_last_frame_compute_loss' # using_last_frame_compute_loss, using_all_frame_compute_loss
    ##########################
    nstep = 0
    
    if not os.path.exists(snap_path):
        os.mkdir(snap_path)
    
    file_path = snap_path+'/cross'+str(nc+1)
    if not os.path.exists(file_path):
        os.mkdir(file_path)
    
    with open(zoo_path + 'train.prototxt', 'w') as f:
            f.write(str(make_lstm(trainSource,batchSize,nframe,nc,nclass,type1))) # 这里的 nc不需要加1
    

    caffe.set_device(0)
    caffe.set_mode_gpu()
    #caffe.set_mode_cpu()
    
    train_net = zoo_path + 'train.prototxt'
    snapshot_prefix = snap_path + 'cross' + str(nc+1) + debug_suf + '/vgg_'
    
    print(snapshot_prefix)
    
    print(train_net)
    
    solver_pro = zoo_path + 'solver.prototxt'
    
    with open(solver_pro, 'w') as f:
        f.write(str(make_solver(train_net,snapshot,snapshot_prefix))) 
        
    print(solver_pro)
    
    mysolver = caffe.get_solver(solver_pro)  
    

    loss1 = np.zeros(niter)
    
    disp_interval = 1
    
    isPrint = True
    
    if isShow:
        plt.ion()
        #plt.axis([0,niter,0,1])
        #fig = plt.figure()
        #pass
    for it in range(niter):
        
        mysolver.step(1)
    
        loss1[it] = mysolver.net.blobs['loss'].data.copy()
              
              
        if it % disp_interval == 0 or it+1 == niter :
            print('it:',it,' loss1:', loss1[it])
        
        if isShow and it>=1:
            
            plt.plot([it-1,it],[loss1[it-1],loss1[it]],'r-')
            plt.show()
            plt.pause(0.00001)
            
        if (it+1) % snap_interval == 0:
            plt.savefig(snap_path+str(it+1)+'.pdf')
            mysolver.net.save(snapshot_prefix + 'iter_' + str(it+1) + '.caffemodel')

在这里插入图片描述

预测代码为:

#coding=gbk


import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio # loadmat
from scipy.misc.pilutil import  *
import h5py
from os.path import join, realpath, dirname 
from caffe import layers as L, params as P
import sys, os
from caffe._caffe import Solver
import scipy.io as io
import time
import caffe
from caffe.proto import caffe_pb2
import mnist_version2_for_caffe_lstm as mnist
from numpy import *

# enum Engine { DEFAULT = 0; CAFFE = 1; CUDNN = 2; } 
# enum NormRegion { ACROSS_CHANNELS = 0; WITHIN_CHANNEL = 1; } 

# mnist 10class

def make_lstm_deploy(batchSize,nframe,feat_len,nclass,type1):
    
    n = caffe.NetSpec()
    
    
    n.data = L.Input(name='data',ntop=1,input_param=dict(shape=dict(dim=[batchSize*nframe,feat_len])))
    n.labels = L.Input(name='labels',ntop=1,input_param=dict(shape=dict(dim=[batchSize*nframe,1])))
    n.clip_markers = L.Input(name='clip_markers',ntop=1,input_param=dict(shape=dict(dim=[batchSize*nframe,1])))
    
    n.fc0 = L.InnerProduct(n.data,name='fc0',num_output=128,
                           weight_filler=dict(type='xavier',std=0.005),
                            bias_filler=dict(type='constant',value=0.1),
                            param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)]
                           )
    
    n.reshape_data = L.Reshape(n.fc0,name='reshape_data',reshape_param={'shape':{'dim':[nframe,batchSize,128]}})
    n.reshape_labels = L.Reshape(n.labels,name='reshape_labels',reshape_param={'shape':{'dim':[nframe,batchSize]}})
    n.reshape_clipmarkers = L.Reshape(n.clip_markers,name='reshape_clipmarkers',reshape_param={'shape':{'dim':[nframe,batchSize]}})
    n.lstm1 = L.LSTM(n.reshape_data,n.reshape_clipmarkers,name='lstm1',recurrent_param={'num_output':64,
                                                                       'weight_filler':{'type':'uniform','min':-0.01,'max':0.01},
                                                                       'bias_filler':{'type':'constant','value':0}})
    # the output of lstm convert to 
    if type1 == 'using_last_frame_compute_loss':
    
        n.last_frame_data = L.Python(n.lstm1,name='last_frame_data',ntop=1,
                                  python_param=dict(module='data_separate_for_mnist',layer='data_separate'),
                                  propagate_down=[1]) # for main tasks
        
        n.last_frame_label = L.Python(n.reshape_labels,name='last_frame_label',ntop=1,
                                  python_param=dict(module='label_separate_for_mnist',layer='label_separate'),
                                  propagate_down=[0]) # for main tasks
        
        n.fc1 = L.InnerProduct(n.last_frame_data,name='fc1',num_output=nclass,
                               weight_filler=dict(type='xavier',std=0.005),
                                bias_filler=dict(type='constant',value=0.1),
                                param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)])
        
        
        n.prob = L.Softmax(n.fc1,name='prob',)
        #n.loss = L.SoftmaxWithLoss(n.fc1,n.last_frame_label,
        #                                   name='loss',ntop=1)
    
    elif type1 == 'using_all_frame_compute_loss':
        
        n.fc1 = L.InnerProduct(n.lstm1,name='fc1',num_output=nclass,
                               weight_filler=dict(type='xavier',std=0.005),
                                bias_filler=dict(type='constant',value=0.1),
                                param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)],
                                axis=2)
        
        n.prob = L.Softmax(n.fc1,name='prob',softmax_param={'axis':2})
        
        #n.loss = L.SoftmaxWithLoss(n.fc1,n.reshape_labels,
        #                                   name='loss',ntop=1,softmax_param={'axis':2})
   
    
    return n.to_proto()



###########################################################
nclass = 10  # mnist
zoo_path = 'ZOO_lstm/'
model_root = 'snapshots_lstm_for_mnist/'

pretrained_weight = None

trainSource = None  
snapshot = 1000000 # not use
niter = 10000 #5000
snap_interval = 2000
isShow = True
batchSize = 100
nframe = 28   # imgSize[0]
feat_len = 28 # imgSize[1]
type1 = 'using_last_frame_compute_loss' # using_last_frame_compute_loss, using_all_frame_compute_loss
##########################
nstep = 0

with open(zoo_path + 'deploy.prototxt', 'w') as f: 
    f.write(str(make_lstm_deploy(batchSize,nframe,feat_len,nclass,type1)))

ncross = 1    
for nc in range(ncross):

    test_sample_num = 0
    
    data_obj = mnist.Mnist(nc)
    test_data = data_obj.test_data    
    ntest     = data_obj.ntest
    
    labels = test_data['labels']
    datas  = test_data['datas']
  
    
    test_sample_num = ntest
    
    assert(ntest == 10000)
    

    iter_num  = range(10000,10001,1000) # trained model
    
    iter_result = np.zeros(shape=(1,len(iter_num)))
    inum = 0
    for im in iter_num:
    
        #===========================================================================#
        weight_file = model_root + 'cross' + str(nc+1) + '/vgg_iter_' + str(im) + '.caffemodel'  
        deploy_pro = zoo_path + '/deploy.prototxt'                    
        #===========================================================================#
    
        print('deploy_pro: ',deploy_pro)
        net = caffe.Net(deploy_pro,weight_file,caffe.TEST)   
        
        
        cnt_num = 0

        result_number = np.zeros(test_sample_num)
        result_number_probs = np.zeros((test_sample_num,nclass))
        count = 0 
        
        test_patchSize = batchSize # 
        test_times = int32(ceil(datas.shape[0] / float32(test_patchSize) ))
    
        # true label
        for c in range(test_times):
            
            start_ind = c*test_patchSize
            if c == test_times-1:
                end_ind = datas.shape[0]
            else:
                end_ind = (c+1)*test_patchSize 
            batchSize2 = len(range(start_ind,end_ind))
            

            true_labels   = labels[start_ind:end_ind,...]
            test_features = datas[start_ind:end_ind,...]
            
            testLabels = np.zeros(shape=(batchSize2*nframe,1),dtype=np.float32)
             
            # train_features.shape = [T,N,H,W]
            testFeatures = np.zeros(shape=(batchSize2*nframe,feat_len),dtype=np.float32)
            # train_clipmarks.shape = [T,N,H,W]
            testClipmarks =  np.ones(shape=(batchSize2*nframe,1),dtype=np.float32)
                
            for i in range(nframe):
                for j in range(batchSize2):
                                
                    # the first img 28 * 28  H, W    
                    img = test_features[j,:,:]
                    label = true_labels[j,:]
                    frame = img[i,:]
                    
                    testFeatures[i*batchSize2+j,:] = frame
                    testLabels[i*batchSize2+j,:]   = label
                    if j == 0:
                        testClipmarks[i*batchSize2+j,:] = 0
            
            # forward net
            net.blobs['data'].reshape(batchSize2*nframe,feat_len)
            net.blobs['clip_markers'].reshape(batchSize2*nframe,1)
            
            
            net.blobs['data'].data[...] = testFeatures
            net.blobs['clip_markers'].data[...] = testClipmarks
       
            #plt.imshow(testFeatures)
            #plt.show()
            output = net.forward()
            probs = output['prob']

            # get predict labels
            if type1 != 'using_all_frame_compute_loss' and type1 != 'using_last_frame_compute_loss':
                raise Exception('please input correct type1...')
            
            if type1 == 'using_all_frame_compute_loss':
                probs = probs[nframe-1,:,:]
            
            predict_labels  = probs.argmax(1) # 
            
            i = 0
            isShow = False
                       
            for j in arange(start_ind,end_ind):
                
                predict_label = predict_labels[i]
                true_label = true_labels[i] 
    
                result_number[count] = predict_label
                
                print('headpose: ',predict_label,true_label)
                # 
                if predict_label == true_label:
                    cnt_num += 1

                        
                i += 1
                count +=1
            print("iter",c,'predict over...')
        print('cnt_mnist: ',cnt_num ,';total:', test_sample_num)
        print('the accuracy: ',cnt_num*1.0/test_sample_num)
         
        iter_result[0,inum] = cnt_num*1.0/test_sample_num 
        inum += 1 #统计测试模型的次数
ir = 0
for im in iter_num:
    print('mnist: ', str(im), ' :', iter_result[0,ir])
    ir += 1
    

在这里插入图片描述

结果:我们训练模型10k次,在10000个测试样本上的识别为:91.33%。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mnist example for lstm in caffe 的相关文章

随机推荐

  • 区块链数据库

    大家好 这里是链客区块链技术问答社区 链客 有问必答 区块链是互联网未来十年中举足轻重的技术 区块链 Blockchain 或者说分布式账本 DLT Distributed Ledger Technology 最早是起源于比特币的一个重要概
  • ruoyi框架时间范围range增加今日,近7日,近30日时间选择

    原先layui时间控件是不支持今日 近7日 近30日选择的 网上的解决方法是直接在引用的js中修改代码 这是一种方法 但是对于不能修改源代码的童鞋来说是不行的 所以一下解决方法诞生了 直接添加这三个按钮并和时间控件 laydate 有友好的
  • 【数字图像处理】图像形态学算法C语言实现(图像卷积,膨胀,腐蚀,开运算,闭运算,顶帽,黑帽,雕版,锐化)

    文章目录 一 图像卷积 1 图像卷积 2 数字信号处理中的卷积 3 数字图像处理中的卷积 二 图像卷积实现各种形态学运算 腐蚀 膨胀 形态学梯度 开运算 闭运算 顶帽 黑帽 雕版 锐化 li conv c main c 三 效果展示 原图
  • SPI转can芯片CSM300详解以及Linux驱动移植调试笔记

    更多嵌入式Linux干货 请关注 一口Linux 一 CSM300概述 CSM300 A 系列是一款可以支持 SPI UART 接口的CAN模块 1 简介 CSM300 A 系列隔离 SPI UART 转 CAN 模块是集成微处理器 CAN
  • Linux常用命令合集(二)

    file命令 该命令用于判断接在file命令后的文件的基本数据 因为在Linux下文件的类型并不是以后缀为分的 所以这个命令对我们来说就很有用了 gt file rumenz txt 查看rumenz txt的文件类型 rumenz txt
  • “泰迪杯”挑战赛 - 通过聚类方法对航空客运的客户进行细分

    目 录 挖掘目标 分析方法与过程 2 1 总体流程 2 2 具体步骤 步骤一 数据预处理 步骤二 群体聚类 步骤三 行为特征聚类 2 3 结果分析 第一类 第二类 第三类 结论 参考文献 1 挖掘目标 本次建模目标是在航空公司的海量会员数据
  • 分享几个好用的WP插件,让你的网站牛逼起来

    1 WP Rocket WPRocket缓存插件是目前最高效灵活的WordPress静态缓存插件 它可以优化你的JSCSS文件结构 减少多次请求 达到优化速度的目的 它还集成了图像延迟加载 对于想要最极致加速的用户来说是一个不错的选择 通过
  • Type Script 之 类型

    Type Script 中的类型有很多 常见的类型有 undefined null boolean number bigint string symbol void object unknown never any 其中基本类型有 void
  • java内存模型 堆栈_Java内存模型分析

    一 Java内存的构成 先上一个官方java document里的图 由上图可知 整块区域分为Young Generation Tenured Generation Permanent Generation 详细解释一下Young区 You
  • Laravel 表单验证器的常用的2种使用方法

    1 使用控制器的 validate 方法进行参数验证 场景一 前后端未分离 保存一篇新的博客文章 param Request request return Response public function store Request req
  • 【csv】csv文件存储上数据精度丢失问题

    最近发现较长的id信息在csv文件中会发生精度丢失 当然python直接处理数据是没问题的 只是csv显示有问题 case1 通常在Excel中输入数值时 如果超过11位 12位及以上 Excel就会用科学计数法显示该数值 如 123456
  • 让macOS支持读写NTFS格式的移动硬盘

    第一步 获知磁盘的名称 两种方法可以知道磁盘名称 第一种 当插入移动硬盘时 桌面上会出现移动硬盘的图标还有名称 第二种 打开终端 输入diskutil list 即可知道磁盘名称 由图中可知我的移动硬盘名称是 备份 第二步 打开终端 按照以
  • 操作系统的逻辑结构

    2 1 操作系统的逻辑结构 逻辑结构 OS的设计和实现思路 逻辑结构的种类 1 整体结构 2 层次式结构 3 微内核结构 客户 服务器结构 Client Server 操作系统作为一个大型软件 它的设计逻辑实现的思路 我们叫做操作系统的逻辑
  • 壁纸网站研究:强大到没朋友的壁纸网站整理(动漫/二次元/宅男/风景/真人)

    1 wallhaven 域名 https wallhaven cc 介绍 一个强大的壁纸网站 包含人物 动漫 风景 同时有一些老司机内容 需要选择NSFW 但需要登录才能观看 隐藏功能 但是海外网站 国内网站较慢 有时候打不开 总结 语言
  • 【华为OD机试真题2023B卷 JAVA&JS】内存资源分配

    华为OD2023 B卷 机试题库全覆盖 刷题指南点这里 内存资源分配 知识点贪心编程基础 时间限制 1s 空间限制 32MB 限定语言 不限 题目描述 有一个简易内存池 内存按照大小粒度分类 每个粒度有若干个可用内存资源 用户会进行一系列内
  • Spring 中容器启动分析之refresh方法执行之前

    内容来自 自学星球 欢迎大家来了解我的星球 和星主 也就是我 一起学习 Java 深入 Java 体系中的所有技术 我给自己定的时间是一年 无论结果如何 必定能给星球中的各位带来点东西 想要了解更多 欢迎访问 自学星球 SSM系列源码文章及
  • JavaScript随机生成颜色

    function getRandomColor const letters 0123456789ABCDEF let color for let i 0 i lt 6 i color letters Math floor Math rand
  • IPv4数据报的分段与重组

    文章摘自书籍 深入理解计算机网络 王达 机械工业出版社 IPv4数据报头格式请点击此处 IPv4数据报的封装与解封装请点击此处 IPv4数据报的分段与重组 在网络层中还涉及一个分段的问题 那就是因为不同网络线路上可以传输的数据报大小是有限制
  • QT学习记录(三)通过ui和代码的方式往窗口添加组件

    写在前面 本文是b站教程的https www bilibili com video BV1g4411H78N p 5 vd source a3efe214b8a2ba185e92e79cb6d6321b的笔记 外加自己的一些其他想法 如有侵
  • mnist example for lstm in caffe

    下面给出在caffe中使用lstm的一个例子 其中数据集采用mnist 为了实现mnist数据的序列话 将mnist的每一行看成一帧 每一列则就是该帧的特征矢量 在使用lstm时 一定要注意clip markers 每个序列以0开始 后面接