keras

2023-05-16

大神笔记,转载自http://blog.csdn.net/u012162613/article/details/45397033

  1. Keras简介

Keras是基于Theano的一个深度学习框架,它的设计参考了Torch,用Python语言编写,是一个高度模块化的神经网络库,支持GPU和CPU。使用文档在这:http://keras.io/,这个框架貌似是刚刚火起来的,使用上的问题可以到github提issue:https://github.com/fchollet/keras 

下面简单介绍一下怎么使用Keras,以Mnist数据库为例,编写一个CNN网络结构,你将会发现特别简单。

  1. Keras里的模块介绍

Optimizers

顾名思义,Optimizers包含了一些优化的方法,比如最基本的随机梯度下降SGD,另外还有Adagrad、Adadelta、RMSprop、Adam,一些新的方法以后也会被不断添加进来

上面的代码是SGD的使用方法,lr表示学习速率,momentum表示动量项,decay是学习速率的衰减系数(每个epoch衰减一次),Nesterov的值是False或者True,表示使不使用Nesterov momentum。其他的请参考文档。

Objectives

这是目标函数模块,keras提供了mean_squared_error,mean_absolute_error
,squared_hinge,hinge,binary_crossentropy,categorical_crossentropy这几种目标函数。

这里binary_crossentropy 和 categorical_crossentropy也就是常说的logloss.

Activations

这是激活函数模块,keras提供了linear、sigmoid、hard_sigmoid、tanh、softplus、relu、softplus,另外softmax也放在Activations模块里(我觉得放在layers模块里更合理些)。此外,像LeakyReLU和PReLU这种比较新的激活函数,keras在keras.layers.advanced_activations模块里提供。

Initializations

这是参数初始化模块,在添加layer的时候调用init进行初始化。keras提供了uniform、lecun_uniform、normal、orthogonal、zero、glorot_normal、he_normal这几种。

layers

layers模块包含了core、convolutional、recurrent、advanced_activations、normalization、embeddings这几种layer。

其中core里面包含了flatten(CNN的全连接层之前需要把二维特征图flatten成为一维的)、reshape(CNN输入时将一维的向量弄成二维的)、dense(就是隐藏层,dense是稠密的意思),还有其他的就不介绍了。convolutional层基本就是Theano的Convolution2D的封装。

Preprocessing

这是预处理模块,包括序列数据的处理,文本数据的处理,图像数据的处理。重点看一下图像数据的处理,keras提供了ImageDataGenerator函数,实现data augmentation,数据集扩增,对图像做一些弹性变换,比如水平翻转,垂直翻转,旋转等。

Models

这是最主要的模块,模型。上面定义了各种基本组件,model是将它们组合起来,下面通过一个实例来说明。

3.一个实例:用CNN分类Mnist

数据下载

Mnist数据在其官网上有提供,但是不是图像格式的,因为我们通常都是直接处理图像,为了以后程序能复用,我把它弄成图像格式的,这里可以下载:http://pan.baidu.com/s/1qCdS6,共有42000张图片。

读取图片数据

keras要求输入的数据格式是numpy.array类型(numpy是一个python的数值计算的库),所以需要写一个脚本来读入mnist图像,保存为一个四维的data,还有一个一维的label,代码:

#coding:utf-8

import os
from PIL import Image
import numpy as np

#读取文件夹mnist下的42000张图片,图片为灰度图,所以为1通道,
#如果是将彩色图作为输入,则将1替换为3,并且data[i,:,:,:] = arr改为data[i,:,:,:] = [arr[:,:,0],arr[:,:,1],arr[:,:,2]]
def load_data():
    data = np.empty((42000,1,28,28),dtype="float32")
    label = np.empty((42000,),dtype="uint8")

    imgs = os.listdir("./mnist")
    num = len(imgs)
    for i in range(num):
        img = Image.open("./mnist/"+imgs[i])
        arr = np.asarray(img,dtype="float32")
        data[i,:,:,:] = arr
        label[i] = int(imgs[i].split('.')[0])
    return data,label

构建CNN,训练

短短二十多行代码,构建一个三个卷积层的CNN,直接读下面的代码吧,有注释,很容易读懂:

#导入各种用到的模块组件
from __future__ import absolute_import
from __future__ import print_function
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD, Adadelta, Adagrad
from keras.utils import np_utils, generic_utils
from six.moves import range
from data import load_data

#加载数据
data, label = load_data()
print(data.shape[0], ' samples')

#label为0~9共10个类别,keras要求格式为binary class matrices,转化一下,直接调用keras提供的这个函数
label = np_utils.to_categorical(label, 10)

###############
#开始建立CNN模型
###############

#生成一个model
model = Sequential()

#第一个卷积层,4个卷积核,每个卷积核大小5*5。1表示输入的图片的通道,灰度图为1通道。
#border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d
#激活函数用tanh
#你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5))
model.add(Convolution2D(4, 1, 5, 5, border_mode='valid')) 
model.add(Activation('tanh'))

#第二个卷积层,8个卷积核,每个卷积核大小3*3。4表示输入的特征图个数,等于上一层的卷积核个数
#激活函数用tanh
#采用maxpooling,poolsize为(2,2)
model.add(Convolution2D(8,4, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))
model.add(MaxPooling2D(poolsize=(2, 2)))

#第三个卷积层,16个卷积核,每个卷积核大小3*3
#激活函数用tanh
#采用maxpooling,poolsize为(2,2)
model.add(Convolution2D(16, 8, 3, 3, border_mode='valid')) 
model.add(Activation('tanh'))
model.add(MaxPooling2D(poolsize=(2, 2)))

#全连接层,先将前一层输出的二维特征图flatten为一维的。
#Dense就是隐藏层。16就是上一层输出的特征图个数。4是根据每个卷积层计算出来的:(28-5+1)得到24,(24-3+1)/2得到11,(11-3+1)/2得到4
#全连接有128个神经元节点,初始化方式为normal
model.add(Flatten())
model.add(Dense(16*4*4, 128, init='normal'))
model.add(Activation('tanh'))

#Softmax分类,输出是10类别
model.add(Dense(128, 10, init='normal'))
model.add(Activation('softmax'))

#############
#开始训练模型
##############
#使用SGD + momentum
#model.compile里的参数loss就是损失函数(目标函数)
sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

#调用fit方法,就是一个训练过程. 训练的epoch数设为10,batch_size为100.
#数据经过随机打乱shuffle=True。verbose=1,训练过程中输出的信息,0、1、2三种方式都可以,无关紧要。show_accuracy=True,训练时每一个epoch都输出accuracy。
#validation_split=0.2,将20%的数据作为验证集。
model.fit(data, label, batch_size=100,nb_epoch=10,shuffle=True,verbose=1,show_accuracy=True,validation_split=0.2)

代码使用与结果

结果如下所示,在Epoch 9达到了0.98的训练集识别率和0.97的验证集识别率:

这里写图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras 的相关文章

随机推荐

  • 我的大学——学习生活总结

    纪念我终将逝去的青春 大一上學期 專業 1 C語言K amp R amp amp 習題 2 C語言經典習題 3 C語言趣味習題 4 C陷阱与缺陷 5 彙編語言 6 C 43 43 程序設計 7 C 程序設計
  • latex论文作图(python+matplotlib)

    20210425 0 引言 论文中进行作图 xff0c 需要对图片中的各种元素进行控制 xff0c 最近在论文写作过程中为了能够得到匹配文章的高质量图片 xff0c 也是花了很多心血 除了对图片中的风格进行控制 xff0c 另一方面比较重要
  • SAP结转方法:表结法、帐结法

    SAP 处理会计期间结帐方法主要有两种方法 xff1a 表结法和帐结法 国内在会计期末结帐大都采用 帐结 的方法 xff0c 而 SAP 一般都是采用 表 结 xff0c 通过财务报表的编制来披露当期利润 xff0c 即 xff1a 销售科
  • V4L2读取摄像头YUYV(YUV420)帧后使用C语言转存为bmp格式

    摄像头配置读取一帧YUV420 xff08 YUYV xff09 保存为RGB24图像 BRG的顺序 xff0c bmp 下面是内存中摄像头读取的数据直接转存为RGB图片的源码 输入 xff1a 图像指针地址 xff0c 图像长度 xff0
  • Linux内核系统调用原理与实现

    解决什么问题 Linux系统调用主要是操作系统实现的应用编程接口 xff0c 简单的说就是linux内核提供对外 对于应用程序 的接口函数 xff0c 进程通过调用系统调用完成自身的功能 系统调用在每个平台的实现方式都不同相同 xff0c
  • Docker容器基础

    1 介绍 Docker官网 xff1a https docs docker com Docker的github地址 xff1a https github com moby moby Dockerhub官网 https registry hu
  • 【自动驾驶】常见位姿估计算法的比较: 三角测量、PNP、ICP、

    PnP问题 3D 2D DLT 直接线性变换算法 相机标定工程用到的是DLT 直接线性变换算法 xff0c 它是一类PnP问题 3D 2D 请参考 位姿估计 视觉SLAM 笔记 常见位姿估计算法的比较 PnP xff08 Perspecti
  • CC2530 BootLoader,不带协议栈,任意跳转

    最近业余研究了下CC2530的远程固件更新 空中下载 现做个总结 一则方便大家学习共进 二则自己做个记录以防日后忘了 一 BootLoader主要技术点 nbsp nbsp 1 程序跳转到指定位置 nbsp nbsp 2 设置好相应的中断向
  • 使用 VNC 实现多用户登录

    Virtual Network Computing VNC 是一种提供计算机远程访问的流行工具 常规的 VNC 配置是针对单用户工作台而进行优化的 xff0c 可登录到 VNC 端口直接访问单一用户的桌面 然而 xff0c 这一配置在多用户
  • STLink V2烧录SWIM和SWD接口接线图

    stm8 采用SVTP软件烧录 xff0c 烧录接口为SWIM xff08 stlink v2烧录器带有该接口 xff09 xff0c 如下图 xff1a stm32可采用stlink v2 的SWD接口烧录 xff0c 接线图如下 xff
  • 车辆姿态角(Euler角)Pitch、Yaw、Roll 的设定

    首先申明 xff1a 此坐标系是针对车辆而设定的 xff0c 对于无人机来说是不同的 pitch xff1a 俯仰角 xff0c pitchAngleC2W orientation radian Y yaw xff1a 航向角 xff0c
  • Docker(四)----Docker-Compose 详解

    1 什么是Docker Compose Compose项目来源于之前的fig项目 xff0c 使用python语言编写 与docker swarm配合度很高 Compose 是 Docker 容器进行编排的工具 xff0c 定义和运行多容器
  • 转贴:ERP实施过程中的40个问题

    笔者在多年的实践中 xff0c 结合自身经验和多年的理论积累 xff0c 总结出有关ERP 实施的最关键的39 个问题 xff0c 以问答的形式 xff0c 让您在最短的时间内对ERP 实施有一个全面而客观的认识 xff0c 以免陷入日新月
  • VS Code 常用设置集合

    常用设置 xff08 setting json xff09 34 editor parameterHints enabled 34 true 开启参数预览窗口 设置字体颜色 34 editor semanticTokenColorCusto
  • Arduino--LCD1602(IIC)

    xff08 1 xff09 简介 前篇文章介绍了LCD1602的四位数据线控制方法 xff1a https blog csdn net u011816009 article details 106573622 但是该方法还是需要较多的IO口
  • Px4 ULog文件详解

    Px4 ULog文件详解 简介数据类型文件组织文件头定义段消息标记位消息格式定义消息信息消息复合信息消息参数消息 数据段订阅消息取消订阅消息日志数据消息字符串消息同步消息丢失 附录 简介 ULog 是用于记录数据的文件格式 xff0c 该格
  • 开发日记(一)

    这是自己编程第二天 xff0c 自己解决了好几个问题 xff0c 觉得很有成就感 xff0c 决定写下以后开发中遇到的问题 1 在多个Activity中传递数据 xff0c 之前只学过绑定基本的putExtra xff0c 今天上网一搜 x
  • 源程序生成控制流图和du-path

    最近上 源代码分析技术 这个课 xff0c 老师让写一个程序 xff0c 由一段c代码 xff0c 生成生成控制流图和du path xff0c 控制流图不用解释了 xff0c 说一下du path xff0c 这个术语是针对变量来说的 x
  • pandas使用笔记

    DataFrame使用笔记 dates 61 pd date range span class hljs string 39 20160728 39 span periods 61 span class hljs number 6 span
  • keras

    大神笔记 xff0c 转载自http blog csdn net u012162613 article details 45397033 Keras简介 Keras是基于Theano的一个深度学习框架 xff0c 它的设计参考了Torch