无法在 Keras 中复制 matconvnet CNN 架构

2024-02-17

我在 matconvnet 中有以下卷积神经网络架构,我用它来训练我自己的数据:

function net = cnn_mnist_init(varargin)
% CNN_MNIST_LENET Initialize a CNN similar for MNIST
opts.batchNormalization = false ;
opts.networkType = 'simplenn' ;
opts = vl_argparse(opts, varargin) ;

f= 0.0125 ;
net.layers = {} ;
net.layers{end+1} = struct('name','conv1',...
                           'type', 'conv', ...
                           'weights', {{f*randn(3,3,1,64, 'single'), zeros(1, 64, 'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','pool1',...
                           'type', 'pool', ...
                           'method', 'max', ...
                           'pool', [3 3], ...
                           'stride', 1, ...
                           'pad', 0);
net.layers{end+1} = struct('name','conv2',...
                           'type', 'conv', ...
                           'weights', {{f*randn(5,5,64,128, 'single'),zeros(1,128,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','pool2',...
                           'type', 'pool', ...
                           'method', 'max', ...
                           'pool', [2 2], ...
                           'stride', 2, ...
                           'pad', 0) ;
net.layers{end+1} = struct('name','conv3',...
                           'type', 'conv', ...
                           'weights', {{f*randn(3,3,128,256, 'single'),zeros(1,256,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','pool3',...
                           'type', 'pool', ...
                           'method', 'max', ...
                           'pool', [3 3], ...
                           'stride', 1, ...
                           'pad', 0) ;
net.layers{end+1} = struct('name','conv4',...
                           'type', 'conv', ...
                           'weights', {{f*randn(5,5,256,512, 'single'),zeros(1,512,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','pool4',...
                           'type', 'pool', ...
                           'method', 'max', ...
                           'pool', [2 2], ...
                           'stride', 1, ...
                           'pad', 0) ;
net.layers{end+1} = struct('name','ip1',...
                           'type', 'conv', ...
                           'weights', {{f*randn(1,1,256,256, 'single'),  zeros(1,256,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','relu',...
                           'type', 'relu');
net.layers{end+1} = struct('name','classifier',...
                           'type', 'conv', ...
                           'weights', {{f*randn(1,1,256,2, 'single'), zeros(1,2,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0,...
                           'learningRate', [1 2]) ;
net.layers{end+1} = struct('name','loss',...
                           'type', 'softmaxloss') ;

% optionally switch to batch normalization
if opts.batchNormalization
  net = insertBnorm(net, 1) ;
  net = insertBnorm(net, 4) ;
  net = insertBnorm(net, 7) ;
  net = insertBnorm(net, 10) ;
  net = insertBnorm(net, 13) ;
end

% Meta parameters
net.meta.inputSize = [28 28 1] ;
net.meta.trainOpts.learningRate = [0.01*ones(1,10) 0.001*ones(1,10) 0.0001*ones(1,10)];
disp(net.meta.trainOpts.learningRate);
pause;
net.meta.trainOpts.numEpochs = length(net.meta.trainOpts.learningRate) ;
net.meta.trainOpts.batchSize = 256 ;
net.meta.trainOpts.momentum = 0.9 ;
net.meta.trainOpts.weightDecay = 0.0005 ;

% --------------------------------------------------------------------
function net = insertBnorm(net, l)
% --------------------------------------------------------------------
assert(isfield(net.layers{l}, 'weights'));
ndim = size(net.layers{l}.weights{1}, 4);
layer = struct('type', 'bnorm', ...
               'weights', {{ones(ndim, 1, 'single'), zeros(ndim, 1, 'single')}}, ...
               'learningRate', [1 1], ...
               'weightDecay', [0 0]) ;
net.layers{l}.biases = [] ;
net.layers = horzcat(net.layers(1:l), layer, net.layers(l+1:end)) ;

我想做的是在 Keras 中构建相同的架构,这就是我到目前为止所尝试的:

model = Sequential()

model.add(Conv2D(64, (3, 3), strides=1, input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(3, 3), strides=1))

model.add(Conv2D(128, (5, 5), strides=1))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))

model.add(Conv2D(256, (3, 3), strides=1))
model.add(MaxPooling2D(pool_size=(3, 3), strides=1))

model.add(Conv2D(512, (5, 5), strides=1))
model.add(MaxPooling2D(pool_size=(2, 2), strides=1))

model.add(Conv2D(256, (1, 1)))
convout1=Activation('relu')
model.add(convout1)

model.add(Flatten())
model.add(Dense(num_classes, activation='softmax'))

opt = keras.optimizers.rmsprop(lr=0.0001, decay=0.0005)  
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

然而,当我运行 matconvnet 网络时,我的准确率是 87%,如果我运行 keras 版本,我的准确率是 77%。如果它们应该是同一个网络并且数据是相同的,那么区别在哪里?我的 Keras 架构出了什么问题?


在您的 MatConvNet 版本中,您使用带有动量的 SGD。

在 Keras 中,您使用 rmsprop

使用不同的学习规则,您应该尝试不同的学习率。有时,动量在训练 CNN 时也很有帮助。

您能尝试一下 Keras 中的 SGD+momentum 并让我知道会发生什么吗?

另一件可能不同的事情是初始化。例如,在 MatConvNet 中,您使用高斯初始化,并以 f= 0.0125 作为标准差。在 Keras 中,我不确定默认初始化。

一般来说,如果不使用批量归一化,网络很容易出现许多数值问题。如果您在两个网络中使用批量归一化,我敢打赌结果会相似。您有什么理由不想使用批量归一化吗?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

无法在 Keras 中复制 matconvnet CNN 架构 的相关文章

  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • 使用字典映射数据帧索引

    为什么不df index map dict 工作就像df column name map dict 这是尝试使用index map的一个小例子 import pandas as pd df pd DataFrame one A 10 B 2
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • Cython 和类的构造函数

    我对 Cython 使用默认构造函数有疑问 我的 C 类 Node 如下 Node h class Node public Node std cerr lt lt calling no arg constructor lt lt std e
  • javascript 是否有等效的 __repr__ ?

    我最接近Python的东西repr这是 function User name password this name name this password password User prototype toString function r
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside
  • Django-tables2 列总计

    我正在尝试使用此总结列中的所有值文档 https github com bradleyayers django tables2 blob master docs pages column headers and footers rst 但页
  • 如何应用一个函数 n 次? [关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 假设我有一个函数 它接受一个参数并返回相同类型的结果 def increment x return x 1 如何制作高阶函数repeat可以
  • 如何计算Python中字典中最常见的前10个值

    我对 python 和一般编程都很陌生 所以请友善 我正在尝试分析包含音乐信息的 csv 文件并返回最常听的前 n 个乐队 从下面的代码中 每听一首歌曲都是一个列表中的字典条目 格式如下 album Exile on Main Street
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是
  • 张量流中的复杂卷积

    我正在尝试运行一个简单的卷积 但包含复数 r np random random 1 10 10 10 i np random random 1 10 10 10 x tf complex r i conv layer tf layers c

随机推荐

  • 如何在 Android Studio 首次运行时禁用下载组件

    我提取Android Studio IDE 135 1740770 还为 SDK 安装了这些软件包 Tools Android SDK工具24 1 2 Android SDK平台 工具22 Android SDK构建工具22 0 1 Ext
  • 使用 STAX 解析器将 XML 解组为三个不同对象的列表

    有没有一种方法可以使用 STAX 解析器来有效地解析具有不同类 POJO 的多个对象列表的 XML 文档 我的 XML 的确切结构如下 类名不是真实的
  • 在Scheme 中是否有相当于Lisp 的“运行时”原语?

    根据SICP 第 1 2 6 节 http mitpress mit edu sicp full text book book Z H 11 html sec 1 2 6 练习 1 22 大多数 Lisp 实现都包含一个称为运行时的原语 它
  • IDIV 汇编语言的问题

    CX 等于 14 AX 等于 16 IDIV CX 但 ALL 中的某个地方有 37 个 该行之前没有任何错误或错误 我哪里做错了 谢谢你 附注在Emu8086上写 IDIV CX除 32 位值DX AX by CX 并将商存储在AX和剩余
  • 如何在本机 C++ 项目中使用 tlb 文件

    我有一个 tlb 文件 其中包含一些我需要使用的函数声明 If I use import type library tlb 我可以从我的代码中正确引用该函数 tlb namespace required function 但是当我编译项目时
  • AVfoundation 反向视频

    我尝试制作反向视频 在 AVPlayer 中播放资源时 我将速率设置为 1 以使其以反向格式工作 但如何导出该视频呢 我查看了文档 阅读有关 avassetwrite sambuffers compositions 的内容 但没有找到任何方
  • Python Pygame 无法正确显示图像

    我是 Python 新手 我开始学习 Eric Matthes 的 Python 速成课程 我在 Pygame 章节的开头 遵循代码 但我加载的图像总是看起来损坏 我不知道为什么 代码来自书本 第一个文件 import pygame cla
  • Kivy RecycleView 作为 ListView 的替代品?它是如何工作的?

    我应该先说一下 我仍然是 Kivy 的新手 我尝试寻找类似的问题 但它们要么过时 要么不清楚 我正在寻找一些东西来显示元素列表 用户可以在其中选择一个元素来与其他小部件 按钮等 进行交互 我偶然发现了ListView 上的文档页面 http
  • 帮助理解 PHP5 错误

    简而言之 问题是 说什么 扩展 我没有收到错误 严格标准 非静态方法 Pyro Template preLoad 不应静态调用 假设 this 来自 opt lampp htdocs dc pyro app controllers admi
  • 选择 Xamarin Forms 中存储文件的路径

    我有一个 Xamarin 表单应用程序 我想保存文件 当用户在手机中打开文件管理器或手机连接到计算机时 应该显示该文件 我读了这个article https developer xamarin com guides xamarin form
  • Openblas 没有链接到 Scipy

    我目前在 Debian Jessie 上运行 scipy 我已经从 apt get 安装了 scipy 我还从 apt 安装了 blas 和 lapack sudo apt get install python scipy libblas
  • MySQL-SUM 日期时间?

    我需要总和日期时间值 但我不知道如何做到这一点 我有桌子 我的查询 SELECT SUM h dtplay AS Time FROM tblhistory AS h tblgame AS g WHERE h idgame g id AND
  • 允许在 asp.net 文本框中使用 html

    我将 ValidateRequest false 添加到页面指令中 但页面的行为就像没有回发一样 如果我删除 html 那么它会正常回发 使用更新面板内的文本框应该不会产生影响 对吗 我正在尝试使用 html 格式将文本存储在我的数据库中
  • Qt 全局样式表加载?

    如何使用 Qt 全局加载样式表 qss 样式资源 我正在努力让事情变得比以下更有效率 middleIntText gt setStyleSheet QLineEdit border 1px solid gray border radius
  • 是否可以将 LIMIT 与子查询结果一起使用?

    当需要有序集的最后几行时 通常会创建派生表并重新排序 例如 返回自动递增表的最后 3 个元素id SELECT FROM SELECT FROM table ORDER BY id DESC LIMIT 3 t ORDER BY t id
  • \n 在 Sklabel SpriteKit 中不起作用

    我在我的游戏中使用了以下代码 问题是我无法像使用 CCLabelTTF 那样在 spritekit 中制作多行标签 有人可以帮助我吗 另外我无法在我的代码中使用 t 或 n 感谢您的提前回复 SKLabelNode winner SKLab
  • C++11 std::threads 并等待线程完成

    我有一个计时器对象向量 每个计时器对象都会启动一个模拟生长期的 std thread 我正在使用命令模式 发生的情况是每个计时器都被一个接一个地执行 但我真正想要的是一个被执行 然后一旦完成 下一个 一旦完成下一个 同时不干扰主程序的执行
  • 工具提示内的图像tiptip

    我需要在具有悬停效果的工具提示中插入图像 tel view tipTip defaultPosition top delay 400 fadeIn 400 keepAlive true activation click HTML img s
  • 多个神经网络各有一个输出还是一个有多个输出?

    我想将输入分类为三种可能性之一 使用 3 个网络 每个网络有一个输出 还是 1 个网络 每个网络有 3 个输出 更好 即 3 个网络输出0 or 1或 1 个输出长度为 3 的单热向量的网络 1 0 0 答案是否会根据输入数据分类的复杂程度
  • 无法在 Keras 中复制 matconvnet CNN 架构

    我在 matconvnet 中有以下卷积神经网络架构 我用它来训练我自己的数据 function net cnn mnist init varargin CNN MNIST LENET Initialize a CNN similar fo