keras自定义损失纯python(没有keras后端)

2024-02-29

我目前正在编写一个用于图像压缩的自动编码器。我想使用用纯 python 编写的自定义损失函数,即不使用 keras 后端函数。这是否可能?如果可能的话,如何实现? 如果可能的话,我将非常感谢您提供一个最小工作示例(MWE)。 请查看这个 MWE,特别是 mse_keras 函数:

# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import keras.backend as K
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import Input, Dense


def mse_keras(A,B):
    mse = K.mean(K.square(A - B), axis=-1)
    return mse


# Loads the training and test data sets (ignoring class labels)
(x_train, _), (x_test, _) = mnist.load_data()

# Scales the training and test data to range between 0 and 1.
max_value = float(x_train.max())
x_train = x_train.astype('float32') / max_value
x_test = x_test.astype('float32') / max_value


x_train.shape, x_test.shape
# ((60000, 28, 28), (10000, 28, 28))


x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

(x_train.shape, x_test.shape)
# ((60000, 784), (10000, 784))


# input dimension = 784
input_dim = x_train.shape[1]
encoding_dim = 32

compression_factor = float(input_dim) / encoding_dim
print("Compression factor: %s" % compression_factor)

autoencoder = Sequential()
autoencoder.add(Dense(encoding_dim, input_shape=(input_dim,), activation='relu'))
autoencoder.add(Dense(input_dim, activation='sigmoid'))

autoencoder.summary()

input_img = Input(shape=(input_dim,))
encoder_layer = autoencoder.layers[0]
encoder = Model(input_img, encoder_layer(input_img))

encoder.summary()


autoencoder.compile(optimizer='adam', loss=mse_keras, metrics=['mse'])
history=autoencoder.fit(x_train, x_train,
                        epochs=3,
                        batch_size=256,
                        shuffle=True,
                        validation_data=(x_test, x_test))

num_images = 10
np.random.seed(42)
random_test_images = np.random.randint(x_test.shape[0], size=num_images)

decoded_imgs = autoencoder.predict(x_test)


#print(history.history.keys())

plt.figure()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])

plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test', 'mse1', 'val_mse1'], loc='upper left')
plt.show()


plt.figure(figsize=(18, 4))

for i, image_idx in enumerate(random_test_images):
    # plot original image
    ax = plt.subplot(3, num_images, i + 1)
    plt.imshow(x_test[image_idx].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # plot reconstructed image
    ax = plt.subplot(3, num_images, 2*num_images + i + 1)
    plt.imshow(decoded_imgs[image_idx].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

上面的代码是使用 Keras 后端的自定义损失函数的 MWE。然而,这不是我想要的!我想用以下代码替换代码中的 mse_keras 函数:

def my_mse(A,B):
    mse = ((A - B) ** 2).mean(axis=None)
    return mse

这又只是一个 MWE。它是纯Python和scipy。没有 KERAS 后端! 是否可以使用纯 python 函数作为损失函数(我尝试使用 py_func,但它对我不起作用。) 我之所以问这个问题是因为最终我想使用一种更复杂的损失函数,该函数已经在Python中实现了。而且,我不知道如何使用 keras 后端重新实现它。 (说实话,我也没有时间这样做)

(出于好奇:我想用作损失函数的函数可以在这里看到:https://github.com/aizvorski/video-quality https://github.com/aizvorski/video-quality)

任何帮助将不胜感激。后端可以是theano,tensorflow,我不在乎。如果可能的话,请给我提供 python 3.X 中的 MWE。

提前谢谢了。非常感谢您的帮助。


您不能使用纯 Python 函数作为 Keras 的损失。由于您可能在 GPU 上进行训练,而 python 使用 CPU,这会通过将结果从 GPU 内存传输到 GPU 内存来产生开销。

from https://keras.io/losses/ https://keras.io/losses/

您可以传递现有损失函数的名称,也可以传递 TensorFlow/Theano符号函数返回每个数据点的标量并采用以下两个参数:y_true、y_pred

你的功能将是(与原来的功能相同)

def my_mse(A,B):
    mse = K.mean(K.pow(A - B, 2), axis=None)
    return mse

但是,检查 Keras API,它需要每个数据点都有一个标量,因此取均值可能无法像这样工作axis=None.

我快速浏览了您链接的损失函数,并在 Keras 中实现它们应该是可能的并且不太困难。 Keras(或者实际上是后端 Tensorflow)具有与 numpy 类似的接口。了解后端的计算图(即张量流)如何实现损失可能会很有用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras自定义损失纯python(没有keras后端) 的相关文章

随机推荐

  • 将 UI 图像放入网格布局组 Unity C#

    我在将 UI 图像放入 Unity 的网格布局组中时遇到问题 这会发生什么 检查下面的链接 IMG http i65 tinypic com fp2dly jpg http i65 tinypic com fp2dly jpg IMG IM
  • Docile.jl 在 Julia 0.3 中的使用示例

    我是朱莉娅的新手 我有兴趣使用温顺 jl https michaelhatherly github io Docile jl index html向现有 Julia 项目添加文档 根据这个帖子 https stackoverflow com
  • PHP:字符串到正则表达式

    我尝试使用字符串作为正则表达式模式 但出现以下错误 PHP Warning preg match Unknown modifier gt in Applications MAMP htdocs cruncher Plugins wordpr
  • 如何使 PHPUnit 在有风险的测试中失败

    我想要 PHPUnitfail如果一项或多项测试被认为有风险 实际上 PHPUnit 5 3 4 by Sebastian Bergmann and contributors RRR 7 7 100 Time 2 83 seconds Me
  • 在哪里可以找到新的 azure devops 扩展的所有可用贡献目标?

    新的azure扩展开发文档 https developer microsoft com en us azure devops develop extensions指向一个示例项目github https github com Microso
  • LLVM 和编译器术语

    我正在研究 LLVM 系统并且我已经阅读了入门文档 http llvm org docs GettingStarted html 然而 一些术语 以及 clang 示例中的措辞 仍然有点令人困惑 以下术语和命令都是编译过程的一部分 我想知道
  • 如何对嵌入 JSON 的 JSON 进行编码

    我有一个 JSON 字符串 其中一个字段是文本字段 此文本字段可以包含用户在 UI 中输入的文本 如果他们输入的文本是 JSON 文本 也许为了说明一些编码 我需要对其文本进行编码 以便它不会在发送的实际 JSON 结构中被解释为 JSON
  • 创建具有多个子文件夹链接的 Ajax 网站失败

    我正在尝试创建一个site那是loading全部都是通过 Ajax 的内容 假设该网站是www abc net I have abc net index html并且无论输入什么 URL 文件夹 文件 该文件都将始终被调用 abc net
  • RGDAL 无法安装

    我无法在 R 中安装 RGDAL 我使用的是 Ubuntu 12 04 configure error gdal config not found or not executable ERROR configuration failed f
  • 关于 REST 响应和 XMLElement

    我有下面一个需要在代码中创建的 REST 响应
  • Page_Load 未在 UserControl 中触发

    我在类库中用 C 创建了一个类 并将此控件添加到了 default aspx 但我的代码没有触发 page load 事件 这是代码 我究竟做错了什么 页面已加载 但页面上未显示标签 我已将控件正确添加到页面 没有任何错误 我已经在其中添加
  • 同步集合包装器工厂方法如何“拥有”传递给它的对象?

    Brian Goetz 在 Java Concurrency in Practice 一书中说 传递给类的构造函数和方法的对象是不拥有由类本身 是因为他们是从外面来的 班级无法控制他们吗 他接着说 如果方法被明确设计为转移传入对象的所有权
  • C++ 隐式参数的顺序: this 和返回的对象,哪个在先?

    在 C 中 成员函数最多可以有 2 个隐式参数 this指针和返回对象的地址 它们位于显式参数之前 但是 哪个先走 我特别对 Android NDK 基于 gcc ARM 中发生的情况感兴趣 Example class MyClass pu
  • Java Kafka adminClient 主题配置。配置值被覆盖

    在尝试使用 java kafka adminClient 配置新创建的 kafka 主题时 值被覆盖 我尝试使用控制台命令设置相同的主题配置 并且它有效 不幸的是 当我尝试通过 Java 代码时 一些值发生冲突并被覆盖 ConfigReso
  • Celery 在任何更改时自动重新加载

    当模块发生更改时 我可以使 celery 自动重新加载CELERY IMPORTS in settings py 我尝试让母模块检测子模块的变化 但它没有检测到子模块的变化 这让我明白检测不是由 celery 递归完成的 我在文档中搜索了它
  • 在android中使用内容提供程序获取联系号码

    我按照本教程学习了内容提供商的基础知识 http www vogella de articles AndroidSQLite article html http www vogella de articles AndroidSQLite a
  • 使用更改 django 模板中表单字段的名称属性

    我有表单字段 表单 项目 这将呈现为
  • 数据类型映射参数中的键只能使用列名

    我已经使用 dask read sql table 从 Oracle 数据库成功引入了一张表 但是 当我尝试引入另一个表时 出现此错误KeyError 只有列名可以用作数据类型映射参数中的键 我已经检查了我的连接字符串和架构 所有这些都很好
  • 转换为日期格式 dd/mm/yyyy

    我有以下日期 2010 04 19 18 31 27 我想将此日期转换为日 月 年 format 您可以使用正则表达式或一些手动字符串摆弄 但我想我更喜欢 date d m Y strtotime str
  • keras自定义损失纯python(没有keras后端)

    我目前正在编写一个用于图像压缩的自动编码器 我想使用用纯 python 编写的自定义损失函数 即不使用 keras 后端函数 这是否可能 如果可能的话 如何实现 如果可能的话 我将非常感谢您提供一个最小工作示例 MWE 请查看这个 MWE