如何为解码器加载经过训练的自动编码器权重?

2024-01-05

我有一个 CNN 1d 自动编码器,它有一个密集的中央层。我想训练这个自动编码器并保存它的模型。我还想保存解码器部分,目标是:将一些中心特征(独立计算)提供给经过训练和加载的解码器,通过解码器查看这些独立计算的特征的图像是什么。

## ENCODER
encoder_input = Input(batch_shape=(None,501,1))
x  = Conv1D(256,3, activation='tanh', padding='valid')(encoder_input)
x  = MaxPooling1D(2)(x)
x  = Conv1D(32,3, activation='tanh', padding='valid')(x)
x  = MaxPooling1D(2)(x)
_x = Flatten()(x)
encoded = Dense(32,activation = 'tanh')(_x)

## DECODER (autoencoder)
y = Conv1D(32, 3, activation='tanh', padding='valid')(x)
y = UpSampling1D(2)(y)
y = Conv1D(256, 3, activation='tanh', padding='valid')(y)
y = UpSampling1D(2)(y)
y = Flatten()(y)
y = Dense(501)(y)
decoded = Reshape((501,1))(y)

autoencoder = Model(encoder_input, decoded)
autoencoder.save('autoencoder.hdf5')

## DECODER (independent)
decoder_input = Input(batch_shape=K.int_shape(x))  # import keras.backend as K
y = Conv1D(32, 3, activation='tanh', padding='valid')(decoder_input)
y = UpSampling1D(2)(y)
y = Conv1D(256, 3, activation='tanh', padding='valid')(y)
y = UpSampling1D(2)(y)
y = Flatten()(y)
y = Dense(501)(y)
decoded = Reshape((501,1))(y)

decoder = Model(decoder_input, decoded)
decoder.save('decoder.hdf5')

EDIT:

为了确保清楚,我首先需要加入encoded和第一个y, 在某种意义上说y必须采取encoded作为输入。完成此操作后,我需要一种方法来加载经过训练的解码器并替换encoded具有一些新的核心功能,我将向我的解码器提供这些功能。

编辑以下答案:

我实施了建议,请参阅下面的代码

## ENCODER
encoder_input = Input(batch_shape=(None,501,1))
x  = Conv1D(256,3, activation='tanh', padding='valid')(encoder_input)
x  = MaxPooling1D(2)(x)
x  = Conv1D(32,3, activation='tanh', padding='valid')(x)
x  = MaxPooling1D(2)(x)
_x = Flatten()(x)
encoded = Dense(32,activation = 'tanh')(_x)

## DECODER (autoencoder)
encoded = Reshape((32,1))(encoded)
y = Conv1D(32, 3, activation='tanh', padding='valid')(encoded)
y = UpSampling1D(2)(y)
y = Conv1D(256, 3, activation='tanh', padding='valid')(y)
y = UpSampling1D(2)(y)
y = Flatten()(y)
y = Dense(501)(y)
decoded = Reshape((501,1))(y)

autoencoder = Model(encoder_input, decoded)
autoencoder.compile(optimizer='adam', loss='mse')
epochs = 10
batch_size = 100
validation_split = 0.2
# train the model
history = autoencoder.fit(x = training, y = training,
                    epochs=epochs,
                    batch_size=batch_size,
                    validation_split=validation_split)
autoencoder.save_weights('autoencoder_weights.h5')


## DECODER (independent)
decoder_input = Input(batch_shape=K.int_shape(encoded))  # import keras.backend as K
y = Conv1D(32, 3, activation='tanh', padding='valid', name='decod_conv1d_1')(decoder_input)
y = UpSampling1D(2, name='decod_upsampling1d_1')(y)
y = Conv1D(256, 3, activation='tanh', padding='valid', name='decod_conv1d_2')(y)
y = UpSampling1D(2, name='decod_upsampling1d_2')(y)
y = Flatten(name='decod_flatten')(y)
y = Dense(501, name='decod_dense1')(y)
decoded = Reshape((501,1), name='decod_reshape')(y)

decoder = Model(decoder_input, decoded)
decoder.save_weights('decoder_weights.h5')


encoder = Model(inputs=encoder_input, outputs=encoded, name='encoder')
features = encoder.predict(training) # features
np.savetxt('features.txt', np.squeeze(features))

predictions = autoencoder.predict(training)
predictions = np.squeeze(predictions)
np.savetxt('predictions.txt', predictions)

然后我打开另一个文件

import h5py
import keras.backend as K

def load_weights(model, filepath):
    with h5py.File(filepath, mode='r') as f:
        file_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
        model_layer_names = [layer.name for layer in model.layers]

        weight_values_to_load = []
        for name in file_layer_names:
            if name not in model_layer_names:
                print(name, "is ignored; skipping")
                continue
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]

            weight_values = []
            if len(weight_names) != 0:
                weight_values = [g[weight_name] for weight_name in weight_names]
            try:
                layer = model.get_layer(name=name)
            except:
                layer = None
            if layer is not None:
                symbolic_weights = (layer.trainable_weights + 
                                    layer.non_trainable_weights)
                if len(symbolic_weights) != len(weight_values):
                    print('Model & file weights shapes mismatch')
                else:
                    weight_values_to_load += zip(symbolic_weights, weight_values)

        K.batch_set_value(weight_values_to_load)

## DECODER (independent)
decoder_input = Input(batch_shape=(None,32,1))
y = Conv1D(32, 3, activation='tanh',padding='valid',name='decod_conv1d_1')(decoder_input)
y = UpSampling1D(2, name='decod_upsampling1d_1')(y)
y = Conv1D(256, 3, activation='tanh', padding='valid', name='decod_conv1d_2')(y)
y = UpSampling1D(2, name='decod_upsampling1d_2')(y)
y = Flatten(name='decod_flatten')(y)
y = Dense(501, name='decod_dense1')(y)
decoded = Reshape((501,1), name='decod_reshape')(y)

decoder = Model(decoder_input, decoded)
#decoder.save_weights('decoder_weights.h5')

load_weights(decoder, 'autoencoder_weights.h5')

# Read autoencoder
decoder.summary()

# read encoded features
features = np.loadtxt('features.txt'.format(batch_size, epochs))
features = np.reshape(features, [1500,32,1])

# evaluate loaded model on features
prediction = decoder.predict(features)



autoencoderpredictions = np.loadtxt('predictions.txt'.format(batch_size, epochs))

fig, ax = plt.subplots(5, figsize=(10,20))
for i in range(5):
        ax[i].plot(prediction[100*i], color='blue', label='Decoder')
        ax[i].plot(autoencoderpredictions[100*i], color='red', label='AE')
        ax[i].set_xlabel('Time components', fontsize='x-large')
        ax[i].set_ylabel('Amplitude', fontsize='x-large')
        ax[i].set_title('Seismogram n. {:}'.format(1500+100*i+1), fontsize='x-large')
        ax[i].legend(fontsize='x-large')
plt.subplots_adjust(hspace=1)
plt.close()

prediction and autoencoderpredictions不同意。看起来好像prediction只是很小的噪音,而autoencoder predictions具有合理的价值。


你需要:(1)保存AE(自动编码器)的权重; (2)负载权重文件; (3) 反序列化文件并仅分配那些与新模型(解码器)兼容的权重。

  • (1): .save确实包括权重,但有一个额外的反序列化步骤,可以通过使用来避免.save_weights反而。还,.save保存优化器状态和模型架构,后者与您的新解码器无关
  • (2): load_weights默认情况下尝试分配all节省了重量,但这是行不通的

下面的代码完成 (3)(以及补救措施 (2))如下:

  1. 加载所有重量
  2. 检索加载的重量名称并将其存储在file_layer_names (list)
  3. 取回当前型号权重名称并将它们存储在model_layer_names (list)
  4. 迭代一遍file_layer_names as name; if name is in model_layer_names,将带有该名称的加载重量附加到weight_values_to_load
  5. 分配权重weight_values_to_load建模使用K.batch_set_value

请注意,这需要您nameAE 和解码器模型中的每一层并使它们匹配。可以重写此代码以在 a 中按顺序进行暴力分配try-except循环,但这既低效又容易出错。


Usage:

## omitted; use code as in question but name all ## DECODER layers as below
autoencoder.save_weights('autoencoder_weights.h5')

## DECODER (independent)
decoder_input = Input(batch_shape=K.int_shape(x))
y = Conv1D(32, 3, activation='tanh',padding='valid',name='decod_conv1d_1')(decoder_input)
y = UpSampling1D(2, name='decod_upsampling1d_1')(y)
y = Conv1D(256, 3, activation='tanh', padding='valid', name='decod_conv1d_2')(y)
y = UpSampling1D(2, name='decod_upsampling1d_2')(y)
y = Flatten(name='decod_flatten')(y)
y = Dense(501, name='decod_dense1')(y)
decoded = Reshape((501,1), name='decod_reshape')(y)

decoder = Model(decoder_input, decoded)
decoder.save_weights('decoder_weights.h5')

load_weights(decoder, 'autoencoder_weights.h5')

功能:

import h5py
import keras.backend as K

def load_weights(model, filepath):
    with h5py.File(filepath, mode='r') as f:
        file_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
        model_layer_names = [layer.name for layer in model.layers]

        weight_values_to_load = []
        for name in file_layer_names:
            if name not in model_layer_names:
                print(name, "is ignored; skipping")
                continue
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]

            weight_values = []
            if len(weight_names) != 0:
                weight_values = [g[weight_name] for weight_name in weight_names]
            try:
                layer = model.get_layer(name=name)
            except:
                layer = None
            if layer is not None:
                symbolic_weights = (layer.trainable_weights + 
                                    layer.non_trainable_weights)
                if len(symbolic_weights) != len(weight_values):
                    print('Model & file weights shapes mismatch')
                else:
                    weight_values_to_load += zip(symbolic_weights, weight_values)

        K.batch_set_value(weight_values_to_load)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何为解码器加载经过训练的自动编码器权重? 的相关文章

随机推荐

  • ASP.NET Core 3.0 Razor Pages 中的路由本地化

    我想在 ASP NET Core 3 0 Razor Pages 应用程序中使用路由本地化 https stackoverflow com a 52976625 107718 https stackoverflow com a 529766
  • 使 JPA EntityManager 会话失效

    我正在开发的一个项目使用 Spring 2 5 和 JPA 并以 Hibernate 作为提供程序 我的 DAO 类扩展了 JpaDaoSupport 因此我使用 getJpaTemplate 方法获取 JpaTemplate 后端数据库可
  • 在VB6中编译DLL时出现“加载DLL时出错”

    我有一个使用引用的 Visual Basic 6 dll 项目 当单击 文件 gt 生成 dll 选项时 它应该生成一个 dll 文件 好吧 当单击 文件 gt 生成 dll 时 我收到错误 加载 DLL 时出错 如何查看缺少哪些参考文献
  • 由 twine python 发布的包未出现在存储库中

    我正在尝试将我的 python 包发布到私有存储库 我是按照官方指南来的https packaging python org en latest tutorials packaging projects https packaging py
  • 如何在已被 Rails 转义的正则表达式中转义 \\ ?

    我试图将正则表达式存储在数据库中 但它们被 Rails 转义了 例如 w s s变成 w s s在数据库中以及检索时 我插入尝试将它们与 mystring sub regex variable 一起使用 但转义的正则表达式未按预期匹配 解决
  • Mongodb:如何检查点是否包含在多边形中?

    我有一个点数组 纬度 经度 中某个区域的点列表 我已经在这些数组上创建了一个索引 现在我想知道一个点是否在该多边形内部 MongoDB 可以吗 我已经尝试过这些命令但没有运气 gt polygonA 48 780809 2 307129 4
  • 具有左右标签的 UITableViewCell 的最佳方法

    我的应用程序有多个可选择的设置 例如枚举值 我想复制 iOS 的声音设置表视图单元格 其中名称位于左侧 所选值位于右侧 后面是公开指示器 gt 到目前为止 我的方法是创建一个自定义表格视图单元格 xib和定制UITableViewCell类
  • 如何使用 jQuery 或纯 JS 重置所有复选框?

    如何使用 jQuery 或纯 JS 重置文档中的所有复选框 如果您的意思是如何从所有复选框中删除 选中 状态 input checkbox removeAttr checked
  • 有没有适用于Python3的工作内存分析器[关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 Python 2 中有一些工具 但一切似乎都已经过时了 我找到了 PySizer 和 Heapy 但一
  • 使用 LINQ 查找数组中的最小和最大日期?

    我有一系列带有属性的类Date i e class Record public DateTime Date get private set void Summarize Record arr foreach var r in arr do
  • 多少个线程太多了?

    我正在编写一个服务器 当收到请求时 我将每个操作发送到一个单独的线程中 我这样做是因为几乎每个请求都会进行数据库查询 我正在使用线程池库来减少线程的构造 销毁 我的问题是 对于这样的 I O 线程来说 什么是一个好的截止点 我知道这只是一个
  • 使用新标签页替换插件打开新标签时,如何保持地址栏清晰?

    我正在为 Firefox 开发一个新的标签页替换插件 安装后 当我单击新选项卡图标打开新选项卡时 新选项卡打开正常 但地址栏显示混乱的 URL 资源 firefox p at getblog dot com getblog buttons
  • 使用字典更新 pandas DataFrame 行

    我在 pandas DataFrames 中发现了我不理解的行为 df pd DataFrame np random randint 1 10 3 3 index one one two columns col1 col2 col3 new
  • Rails 升级到 Angular 2

    我想升级现有的 Rails 和 Angular 1 x 应用程序 我正在关注 ng upgrade文档 https angular io docs ts latest guide upgrade html并看到有很多依赖项 包括system
  • WPF:TabControl 和动态 TabItem

    我正在尝试使用 C 中的 WPF 为我当前的项目创建一个 GUI 我想要有选项卡 在运行时动态创建 并且每个选项卡应该打开一个具有相同列标题但内容不同的表 我知道我可以实现这样的选项卡和表格
  • 如何组合列表元素并找到最大组合的价格

    我有一个类 其中包含特定项目的详细信息 如下所示 Detail class Long detailsId Integer price List
  • 如何处理 REST API 中的更新?

    我想了解一些有关使用 RESTful API 执行写入的方法的观点 对于此示例 假设有一个 Person 对象 id 1 name Example Person addresses id 11 friends id 21 name John
  • 为什么 cabal 不能动态构建 mighttpd2?

    GHC 当静态链接我的可执行文件时太慢 所以我想使用 dynamic 选项进行测试 尽管以下两个命令会导致相同的错误cabal install mighttpd2 is ok cabal install ghc options dynami
  • 如何通过帖子链接阅读 Telegram 频道帖子的内容?

    右键单击 Telegram 频道帖子时会显示帖子链接 格式如下 https telegram me channel name post ID https telegram me channel name post ID 问题是我们如何使用服
  • 如何为解码器加载经过训练的自动编码器权重?

    我有一个 CNN 1d 自动编码器 它有一个密集的中央层 我想训练这个自动编码器并保存它的模型 我还想保存解码器部分 目标是 将一些中心特征 独立计算 提供给经过训练和加载的解码器 通过解码器查看这些独立计算的特征的图像是什么 ENCODE