如何提高自编码器的准确率?

2024-01-27

我有一个自动编码器,我使用不同的解决方案检查了模型的准确性,例如更改转换层的数量并增加它们,添加或删除批量归一化,更改激活函数,但所有这些解决方案的准确性都是相似的,并且不一样有任何奇怪的改进。我很困惑,因为我认为这些不同解决方案的准确度应该不同,但它是 0.8156。你能帮我看看有什么问题吗?我还用 10000 个 epoch 对其进行训练,但 50 个 epoch 的输出是相同的!我的代码是错误的还是不能变得更好?!准确度图 https://i.stack.imgur.com/KBnrp.png

我也不确定学习率衰减是否有效?! 我也把我的代码放在这里:

from keras.layers import Input, Concatenate, GaussianNoise,Dropout,BatchNormalization
from keras.layers import Conv2D
from keras.models import Model
from keras.datasets import mnist,cifar10
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
import numpy as np
import pylab as pl
import matplotlib.cm as cm
import keract
from matplotlib import pyplot
from keras import optimizers
from keras import regularizers
from tensorflow.python.keras.layers import Lambda;

image = Input((28, 28, 1))
conv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1e')(image)
conv2 = Conv2D(32, (3, 3), activation='elu', padding='same', name='convl2e')(conv1)
conv3 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl3e')(conv2)
#conv3 = Conv2D(8, (3, 3), activation='relu', padding='same', name='convl3e', kernel_initializer='Orthogonal',bias_initializer='glorot_uniform')(conv2)
BN=BatchNormalization()(conv3)
#DrO1=Dropout(0.25,name='Dro1')(conv3)
DrO1=Dropout(0.25,name='Dro1')(BN)
encoded =  Conv2D(1, (3, 3), activation='elu', padding='same',name='encoded_I')(DrO1)



#-----------------------decoder------------------------------------------------
#------------------------------------------------------------------------------
deconv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1d')(encoded)
deconv2 = Conv2D(32, (3, 3), activation='elu', padding='same', name='convl2d')(deconv1)
deconv3 = Conv2D(16, (3, 3), activation='elu',padding='same', name='convl3d')(deconv2)
BNd=BatchNormalization()(deconv3)
DrO2=Dropout(0.25,name='DrO2')(BNd)
#DrO2=Dropout(0.5,name='DrO2')(deconv3)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='decoder_output')(DrO2) 
#model=Model(inputs=[image,wtm],outputs=decoded)

#--------------------------------adding noise----------------------------------
#decoded_noise = GaussianNoise(0.5)(decoded)


watermark_extraction=Model(inputs=image,outputs=decoded)

watermark_extraction.summary()
#----------------------training the model--------------------------------------
#------------------------------------------------------------------------------
#----------------------Data preparation----------------------------------------

(x_train, _), (x_test, _) = mnist.load_data()
x_validation=x_train[1:10000,:,:]
x_train=x_train[10001:60000,:,:]
#(x_train, _), (x_test, _) = cifar10.load_data()
#x_validation=x_train[1:10000,:,:]
#x_train=x_train[10001:60000,:,:]
#
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_validation = x_validation.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_validation = np.reshape(x_validation, (len(x_validation), 28, 28, 1))

#---------------------compile and train the model------------------------------
# is accuracy sensible metric for this model?
learning_rate = 0.1
decay_rate = learning_rate / 50
opt = optimizers.SGD(lr=learning_rate, momentum=0.9, decay=decay_rate, nesterov=False)

watermark_extraction.compile(optimizer=opt, loss=['mse'], metrics=['accuracy'])
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20)
#rlrp = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_delta=1E-7, verbose=1)
history=watermark_extraction.fit(x_train, x_train,
          epochs=50,
          batch_size=32, 
          validation_data=(x_validation, x_validation),
          callbacks=[TensorBoard(log_dir='E:/output of tensorboard', histogram_freq=0, write_graph=False),es])
watermark_extraction.summary()
#--------------------visuallize the output layers------------------------------
#_, train_acc = watermark_extraction.evaluate(x_train, x_train)
#_, test_acc = watermark_extraction.evaluate([x_test[5000:5001],wt_expand], [x_test[5000:5001],wt_expand])
#print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))
## plot loss learning curves
pyplot.subplot(211)
pyplot.title('MSE Loss', pad=-40)
pyplot.plot(history.history['loss'], label='train')
pyplot.plot(history.history['val_loss'], label='validation')
pyplot.legend()

pyplot.subplot(212)
pyplot.title('Accuracy', pad=-40)
pyplot.plot(history.history['acc'], label='train')
pyplot.plot(history.history['val_acc'], label='test')
pyplot.legend()
pyplot.show

既然你说你是一个初学者,我将尝试从下往上构建,并尝试尽可能多地用该解释来解释你的代码。

Part 1自动编码器由两部分组成(编码器和解码器)。自动编码器减少存储信息所需的变量数量,而解码器尝试从压缩形式中获取此信息。 (请注意,由于自动编码器的不确定性和数据依赖性,因此在实际数据压缩任务中不使用自动编码器)。

现在在你的代码中你保留padding一样。

conv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1e')(image)

这基本上消除了自动编码器的压缩和扩展功能,即在每个步骤中,您都使用相同数量的变量来表示信息。

Part 2现在开始训练算法

history=watermark_extraction.fit(x_train, x_train,
          epochs=50,
          batch_size=32, 
          validation_data=(x_validation, x_validation),
          callbacks=[TensorBoard(log_dir='E:/PhD/thesis/deepwatermark/journal code/autoencoder_watermark/11-2-2019/output of tensorboard', histogram_freq=0, write_graph=False),es])

从这个表达式/语句/代码行我得出的结论是,您想要生成与您放入代码中的相同的图像,现在,由于图像存储在相同数量的变量中,您的模型只需传递相同的图像图像到每个步骤而不更改图像中的任何内容,这会激励您的模型将每个过滤器参数优化为 1。

Part 3现在棺材上最大的钉子来了,你已经实现了一个dropout层,首先你应该NEVER在卷积层中实现dropout。此链接解释了原因,并讨论了我认为如果您是初学者应该查看的各种想法。 https://towardsdatascience.com/dont-use-dropout-in-convolutional-networks-81486c823c16现在让我们看看为什么你使用 Dropout 的方式真的很糟糕。正如已经解释过的,最适合您模型的参数是学习值 1 的过滤器中的所有参数。现在发生的情况是您强制关闭其中一些过滤器,这除了关闭所讨论的一些过滤器之外没有任何作用在文章中,这一切都会降低下一层图像的强度。(因为 CNN 过滤器对所有输入通道取平均值)

DrO2=Dropout(0.25,name='DrO2')(BNd)

Part 4这只是一点建议,不会成为任何问题的根源BNd=BatchNormalization()(deconv3)

在这里,您尝试对批次中的数据进行标准化,在大多数情况下,数据标准化非常重要,因为您可能知道它不会让一个特征决定模型,并且每个特征在模型中获得平等的发言权,但在图像数据中每个点都已在 0 到 255 之间缩放,因此使用归一化将其缩放到 0 到 1 之间不会增加任何值,只会向模型添加不必要的计算。

我建议你逐步理解,如果有不清楚的地方,请在下面评论,尽量不要使用 CNN 来谈论自动编码器(无论如何它们没有任何实际应用),而是用它来理解 ConvNet 的各种复杂性( CNN),我选择写这样的答案来解释你的网络部分而不是代码的原因是因为你正在寻找的代码只需谷歌搜索即可,如果你对这个答案感兴趣并且想要了解 CNN 的具体工作原理,请查看此内容,如果您对此答案中的任何内容甚至对这些视频有任何疑问,请在下面评论。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何提高自编码器的准确率? 的相关文章

  • 需要根据数据框中的行号应用不同的公式

    我正在努力在数据框中找到某种移动平均值 该公式将根据正在计算的行数而变化 实际场景是我需要计算Z列 Edit 2 以下是我正在使用的实际数据 Date Open High Low Close 0 01 01 2018 1763 95 176
  • 按每个元素中出现的数字对字符串列表进行排序[重复]

    这个问题在这里已经有答案了 我有一个脚本 其目的是对不断下载到服务器上的空间数据集文件进行排序和处理 我的列表目前大致如下 list file t00Z wrff02 grib2 file t00Z wrff03 grib2 file t0
  • virtualenvwrapper 函数在 shell 脚本中不可用

    所以 我再一次制作了一个很棒的 python 程序 它让我的生活变得更加轻松 并节省了大量时间 当然 这涉及到一个 virtualenv 用mkvirtualenvvirtualenvwrapper 的功能 该项目有一个requiremen
  • 是否可以在 IPython 控制台中显示 pandas 样式?

    是否可以显示熊猫风格 https pandas pydata org pandas docs stable user guide style html在 iPython 控制台中 Jupyter 笔记本中的以下代码 import panda
  • pandas read_csv 之前预处理数据文件

    我使用 SAP 的数据输出 但它既不是 CSV 因为它不引用包含其分隔符的字符串 也不是固定宽度 因为它具有多字节字符 它是一种 固定宽度 字符 为了将其放入 pandas 我当前读取文件 获取分隔符位置 对分隔符周围的每一行进行切片 然后
  • 打印出网络架构中每一层的形状

    在 Keras 中 我们可以如下定义网络 有什么办法可以输出每层之后的形状 例如 我想打印出以下形状inputs在定义行之后inputs 然后打印出形状conv1在定义行之后conv1 etc inputs Input 1 img rows
  • 在 python 中发送标头[重复]

    这个问题在这里已经有答案了 我有以下 python 脚本 我想发送 假 标头信息 以便我的应用程序就像 Firefox 一样运行 我怎么能这么做呢 import urllib urllib2 cookielib username passw
  • NumPy 数组与 SQLite

    我在 Python 中见过的最常见的 SQLite 接口是sqlite3 但是有什么东西可以很好地与 NumPy 数组或 rearray 配合使用吗 我的意思是 它可以识别数据类型 不需要逐行插入 并提取到 NumPy rec 数组中 有点
  • 将列表值转换为 pandas 中的行

    我有数据帧 其中一列具有相同长度的 numpy ndarray 值 df list 0 Out 92 array 0 0 0 0 29273096 0 30691767 0 27531403 我想将这些列表值转换为数据框并从 df iloc
  • 使用 python 从 CSV 创建字典

    我有一个 CSV 格式的文件 其中 A B 和 C 是标题 我如何以Python方式将此CSV转换为以下形式的字典 A 1 B 4 C 7 A 2 B 5 C 8 A 3 B 6 C 9 到目前为止我正在尝试以下代码 import csv
  • 类变量:“类列表”与“类布尔值”[重复]

    这个问题在这里已经有答案了 我不明白以下示例的区别 一次类的实例可以更改另一个实例的类变量 而另一次则不能 示例1 class MyClass object mylist def add self self mylist append 1
  • 机器学习的周期性数据(例如度角 -> 179 与 -179 相差 2)

    我使用 Python 进行核密度估计 并使用高斯混合模型对多维数据样本的可能性进行排名 每一条数据都是一个角度 我不确定如何处理机器学习的角度数据的周期性 首先 我通过添加 360 来删除所有负角 因此所有负角都变成了正角 179 变成了
  • 在python中使用编解码器utf-8打开文件错误

    我在 windows xp 和 python 2 6 4 上执行以下代码 但它显示 IOError 如何打开名称带有 utf 8 编解码器的文件 gt gt gt open unicode txt euc kr encode utf 8 T
  • 为什么我无法杀死 k8s pod 中的 python 进程?

    我试图杀死一个 python 进程 ps aux grep python root 1 12 6 2 1 2234740 1332316 Ssl 20 04 19 36 usr bin python3 batch run py root 4
  • 列表中的特定范围(python)

    我有一个从文本字符串中提取的整数列表 因此当我打印该列表 我称之为test I get 135 2256 1984 3985 1991 1023 1999 我想打印或制作一个仅包含特定范围内的数字的新列表 例如1000 2000之间 我尝试
  • 在Python中随机交错2个数组

    假设我有两个数组 a 1 2 3 4 b 5 6 7 8 9 我想将这两个数组交错为变量 c 注意 a 和 b 不一定具有相同的长度 但我不希望它们以确定性的方式交错 简而言之 仅仅压缩这两个数组是不够的 我不想要 c 1 5 2 6 3
  • 在 4K 屏幕上使用 Matplotlib 和 TKAgg 或 Qt5Agg 后端

    我在 Ubuntu 16 04 上使用 Matplotlib 2 0 和 Python 3 6 来创建数据图 电脑显示器的分辨率为 4k 分辨率为 3840x2160 绘图数字看起来非常小 字体也很小 我已经尝试过TKAgg and Qt5
  • python 中的 F 字符串前缀给出语法错误[重复]

    这个问题在这里已经有答案了 我有一个名为 method 的变量 它的值是 POST 但是当我尝试运行时print f method method is used 它不断在最后一个双引号处给出语法错误 我找不到它这样做的原因 我正在使用 py
  • Django 中使用外键的抽象基类继承

    我正在尝试在 Django 支持的网站上进行模型继承 以遵守 DRY 我的目标是使用一个名为 BasicCompany 的抽象基类来为三个子类提供通用信息 Butcher Baker CandlestickMaker 它们位于各自的应用程序
  • 在Python中使用os.makedirs创建目录时出现权限问题

    我只是想处理上传的文件并将其写入工作目录中 该目录的名称是系统时间戳 问题是我想以完全权限创建该目录 777 但我不能 使用以下代码创建的目录755权限 def handle uploaded file upfile cTimeStamp

随机推荐

  • jsTree 拖放按类限制文件夹

    如何通过类名 class locked 锁定文件夹上的拖动功能 同时锁定其他要拖到该文件夹 中的文件夹class locked 我想要一个既具有拖放功能又具有上下文菜单的设置 如果节点的类名 锁定 我只想禁用上下文菜单的编辑以及拖入此文件夹
  • 使用 python 有效提取 1-5 克

    我有一个 3 000 000 行的巨大文件 每行有 20 40 个单词 我必须从语料库中提取 1 到 5 个 ngram 我的输入文件是标记化的纯文本 例如 This is a foo bar sentence There is a com
  • 用于从 Google Sheets URL 中提取电子表格 ID 和工作表 ID 的 JavaScript 正则表达式

    我想要 Javascript 正则表达式从 google 表格 URL 中提取电子表格 ID 和工作表 ID Sheets google com 电子表格的 URL 如下所示 https docs google com spreadshee
  • 删除 d3js 不工作的事件侦听器

    我有一个 SVG 结构 里面有一些形状 我想在单击形状时触发一个事件 在 SVG 上单击时触发另一个事件 问题是 SVG 事件总是被触发 为了防止这种情况 我禁用了形状的事件冒泡 我还尝试使用 d3 禁用该事件 但似乎不起作用 还尝试使用本
  • 朱莉娅 git 错误

    几个月前我在使用 Julia 最近我想再次使用它 我想要一个新版本 所以我删除了以前的版本和我拥有的所有软件包 现在 安装新版本后 0 6 2 我无法使用任何 Pkg 命令 使用后会出现以下错误init add or update 错误 G
  • 通过 pod 访问 kubernetes python api

    所以我需要通过 pod 连接到 python kubernetes 客户端 我一直在尝试使用config load incluster config 基本上遵循以下示例here https github com kubernetes cli
  • Spearman 与底座 R 的尺距距离

    给定两个排列 v1 1 4 3 1 5 2 v2 1 2 3 4 5 1 如何计算以 R 为基数的 Spearman 尺尺距离 所有元素的总位移 可灵活用于任意两种尺寸排列n 例如 对于这两个向量 如下 1被感动了2地点来自v1 to v2
  • 如何为多个开发人员使用 git

    对于经验丰富的 Git 用户来说 这是一个非常简单的问题 我已经在 git 托管上创建了存储库并设置了我的电脑 git init git remote add origin git sourcerepo com git 然后 经过一些更改后
  • 爪哇。 GUI WindowBuilder 通过单击按钮从 JTextField 读取

    I m useing WindowBuilder and I want to ask how to search in a text file for specific word which I enter to JTextField by
  • 如何在 Python 中使用 Selenium 获取
    1. 元素的长度?

    我有一个 ol 在我的 HTML 中列出 如下所示 ol li class foo li li class foo li li class foo li li class foo li ol 我需要做的是验证 ol 列表包含 li 内的项目
  • ReaderWriterLockSlim 和 async\await

    我有一些问题ReaderWriterLockSlim 我无法理解它是如何发挥作用的 My code private async Task LoadIndex if File Exists FileName index txt return
  • 在 vi 中删除连续的重复行而不排序

    这个问题 https stackoverflow com questions 351161 removing duplicate rows in vi已经解决了如何删除重复行 但强制首先对列表进行排序 我想执行删除连续重复行步骤 即uniq
  • 带数组的 SwitchMap 运算符

    我正在尝试学习 rxjs 和 Observable 的一般概念 并且有一个场景 我有一类
  • 如何防止引用的包含搜索当前源文件的目录?

    海湾合作委员会提供 I 选项 其中 I之前的目录 I 搜索引用的包含 include foo h and I以下目录 I 搜索括号内的包含 include
  • 在verilog中将wire值转换为整数

    我想将电线中的数据转换为整数 例如 wire 2 0 w 3 b101 我想要一个将其转换为 5 并将其存储在整数中的方法 我怎样才能以比这更好的方式做到这一点 j 1 for i 0 i lt 2 i i 1 begin a a w i
  • 如何通过 Google Drive API 使用刷新令牌生成访问令牌?

    我已完成授权步骤并获得访问令牌和刷新令牌 接下来我应该做什么来使用我通过 google Drive API 存储的刷新令牌生成访问令牌 由于我在 Force com 上工作 因此我无法使用任何 sdk 因此请建议直接通过 API 实现它的方
  • 经典 asp - 仅接收肥皂响应的一部分

    我试图从经典 asp 调用肥皂请求 它将在稍后更新 但现在它仍然是经典 asp 但我只得到一半的响应 当我在 SoapUI 中使用请求字符串时 我得到了我正在寻找的响应 但在 asp 中我只收到了部分响应 ASP 请求 Set oXmlHT
  • scala:重写构造函数的隐式参数

    我有一个类 它采用隐式参数 该参数由类内部方法调用的函数使用 我希望能够覆盖该隐式参数 或者从其源复制隐式参数 举个例子 def someMethod implicit p List Int uses p class A implicit
  • 如何在市场上发布应用程序的两个版本?

    我想将我的应用程序的两个版本添加到 Android 市场 一种只需几美分 另一种是带有广告的免费版本 这是一种非常常见的做法 我目前正在将 AdMod 构建到我的应用程序中 看来我必须更改相当多的文件 因此最好为此制作一个单独的应用程序版本
  • 如何提高自编码器的准确率?

    我有一个自动编码器 我使用不同的解决方案检查了模型的准确性 例如更改转换层的数量并增加它们 添加或删除批量归一化 更改激活函数 但所有这些解决方案的准确性都是相似的 并且不一样有任何奇怪的改进 我很困惑 因为我认为这些不同解决方案的准确度应