Keras训练部分模型问题(关于GAN模型)

2023-12-01

我在使用keras实现GAN模型时遇到了一个奇怪的问题。

对于GAN,我们需要先建立G和D,然后添加一个新的序列模型(GAN),然后依次添加(G)、添加(D)。

当我这样做时,Keras 似乎反向传播回 G(通过 GAN 模型)D.train_on_batch,我得到了一个InvalidArgumentError: You must feed a value for placeholder tensor 'dense_input_1' with dtype float.

如果我删除GAN model(最后堆叠的 G 然后 D 顺序模型),它计算d_loss正确。

我的环境是:

  • 乌班图16.04
  • 喀拉拉邦1.2.2
  • 张量流GPU 1.0.0
  • 喀拉拉邦配置:{ "backend": "tensorflow", "image_dim_ordering": "tf", "epsilon": 1e-07, "floatx": "float32" }

我知道很多人已经成功地用 keras 实现了 GAN,所以我想知道我哪里错了。

import numpy as np
import keras.layers as kl
import keras.models as km
import keras.optimizers as ko
from keras.datasets import mnist

batch_size = 16
lr = 0.0001

def noise_gen(batch_size, z_dim):
    noise = np.zeros((batch_size, z_dim), dtype=np.float32)
    for i in range(batch_size):
        noise[i, :] = np.random.uniform(-1, 1, z_dim)
    return noise

# --------------------Generator Model--------------------

model = km.Sequential()

model.add(kl.Dense(input_dim=100, output_dim=1024))
model.add(kl.Activation('relu'))

model.add(kl.Dense(7*7*128))
model.add(kl.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Reshape((7, 7, 128), input_shape=(7*7*128,)))

model.add(kl.Deconvolution2D(64, 5, 5, (None, 14, 14, 64), subsample=(2, 2),
    input_shape=(7, 7, 128), border_mode='same'))
model.add(kl.BatchNormalization())
model.add(kl.Activation('relu'))

model.add(kl.Deconvolution2D(1, 5, 5, (None, 28, 28, 1), subsample=(2, 2),
    input_shape=(14, 14, 64), border_mode='same'))

G = model
G.compile(  loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------Discriminator Model--------------------

model = km.Sequential()

model.add(kl.Convolution2D( 64, 5, 5, subsample=(2, 2), input_shape=(28, 28, 1)))
model.add(kl.LeakyReLU(alpha=0.2))

model.add(kl.Convolution2D(128, 5, 5, subsample=(2, 2)))
model.add(kl.BatchNormalization())
model.add(kl.LeakyReLU(alpha=0.2))

model.add(kl.Flatten())
model.add(kl.Dense(1))
model.add(kl.Activation('sigmoid'))

D = model
D.compile(  loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------GAN Model--------------------

model = km.Sequential()
model.add(G)
D.trainable = False  # Is this necessary?
model.add(D)
GAN = model
GAN.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------Main Code--------------------
(X, _), _ = mnist.load_data()
X = X / 255.
X = X[:, :, :, np.newaxis]

X_batch = X[0:batch_size, :]
Z1_batch = noise_gen(batch_size, 100)
Z2_batch = noise_gen(batch_size, 100)

fake_batch = G.predict(Z1_batch)
real_batch = X_batch
print('--------------------Fake Image Generated!--------------------')

combined_X_batch = np.concatenate((real_batch, fake_batch))
combined_y_batch = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))
print('real_batch={}, fake_batch={}'.format(real_batch.shape, fake_batch.shape))

D.trainable = True
d_loss = D.train_on_batch(combined_X_batch, combined_y_batch)
print('--------------------Discriminator trained!--------------------')
print(d_loss)

D.trainable = False
g_loss = GAN.train_on_batch(Z2_batch, np.ones((batch_size, 1)))
print('--------------------GAN trained!--------------------')
print(g_loss)

错误信息:

W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
W tensorflow/core/framework/op_kernel.cc:993] Invalid argument: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1022, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1004, in _run_fn
    status, run_metadata)
  File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: mul_5/_77 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_1018_mul_5", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./gen.py", line 84, in <module>
    d_loss = D.train_on_batch(combined_X_batch, combined_y_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 766, in train_on_batch
    class_weight=class_weight)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1320, in train_on_batch
    outputs = self.train_function(ins)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 1943, in __call__
    feed_dict=feed_dict)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: mul_5/_77 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_1018_mul_5", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Caused by op 'dense_input_1', defined at:
  File "./gen.py", line 20, in <module>
    model.add(kl.Dense(input_dim=100, output_dim=1024))
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 299, in add
    layer.create_input_layer(batch_input_shape, input_dtype)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 397, in create_input_layer
    dtype=input_dtype, name=name)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 1198, in Input
    input_tensor=tensor)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 1116, in __init__
    name=self.name)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 321, in placeholder
    x = tf.placeholder(dtype, shape=shape, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 1520, in placeholder
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2149, in _placeholder
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'dense_input_1' with dtype float
     [[Node: dense_input_1 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: mul_5/_77 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_1018_mul_5", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

首先,我建议您切换到功能 API 模型。功能模型更容易处理这些类型的混合模型。

老实说,我不知道为什么你的解决方案不起作用,似乎当你将 D 模型链接到新输入时,它会有点“损坏”并与之链接。 我发现解决这个问题的方法是定义层并将它们用于判别器和 GAN 模型。这是代码:

import numpy as np
from keras.layers import *
import keras.models as km
import keras.optimizers as ko
from keras.datasets import mnist

batch_size = 16
lr = 0.0001

def noise_gen(batch_size, z_dim):
    noise = np.zeros((batch_size, z_dim), dtype=np.float32)
    for i in range(batch_size):
        noise[i, :] = np.random.uniform(-1, 1, z_dim)
    return noise

# Changes the traiable argument for all the layers of model
# to the boolean argument "trainable"
def make_trainable(model, trainable):
    model.trainable = trainable
    for l in model.layers:
        l.trainable = trainable

# --------------------Generator Model--------------------

g_input = Input(shape=(100,))

g_hidden = Dense(1024, activation='relu')(g_input)
g_hidden = Dense(7*7*128, activation='relu')(g_hidden)
g_hidden = BatchNormalization()(g_hidden)
g_hidden = Reshape((7,7,128))(g_hidden)

g_hidden = Deconvolution2D(64,5,5, (None, 14, 14, 64), subsample=(2,2),
        border_mode='same', activation='relu')(g_hidden)
g_hidden = BatchNormalization()(g_hidden)
g_output = Deconvolution2D(1,5,5, (None, 28, 28, 1), subsample=(2,2),
        border_mode='same')(g_hidden)

G = km.Model(input=g_input,output=g_output)
G.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))
G.summary()

# --------------------Discriminator Model--------------------

d_input = Input(shape=(28,28,1))

d_l1 = Convolution2D(64,5,5, subsample=(2,2))
d_hidden_1 = d_l1(d_input)
d_l2 = LeakyReLU(alpha=0.2)
d_hidden_2 = d_l2(d_hidden_1)

d_l3 = Convolution2D(128,5,5, subsample=(2,2))
d_hidden_3 = d_l3(d_hidden_2)
d_l4 = BatchNormalization()
d_hidden_4 = d_l4(d_hidden_3)
d_l5 = LeakyReLU(alpha=0.2)
d_hidden_5 = d_l5(d_hidden_4)

d_l6 = Flatten()
d_hidden_6 = d_l6(d_hidden_5)
d_l7 = Dense(1, activation='sigmoid')
d_output = d_l7(d_hidden_6)

D = km.Model(input=d_input,output=d_output)
D.compile(loss='binary_crossentropy',optimizer=ko.SGD(lr=lr,momentum=0.9, nesterov=True))
D.summary()

# --------------------GAN Model--------------------
make_trainable(D,False)

gan_input = Input(shape=(100,))
gan_hidden = G(gan_input)
gan_hidden = d_l1(gan_hidden)
gan_hidden = d_l2(gan_hidden)
gan_hidden = d_l3(gan_hidden)
gan_hidden = d_l4(gan_hidden)
gan_hidden = d_l5(gan_hidden)
gan_hidden = d_l6(gan_hidden)
gan_output = d_l7(gan_hidden)

GAN = km.Model(input=gan_input,output=gan_output)
GAN.compile(loss='binary_crossentropy',optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))
GAN.summary()

# --------------------Main Code--------------------
(X, _), _ = mnist.load_data()
X = X / 255.
X = X[:, :, :, np.newaxis]

X_batch = X[0:batch_size, :]
Z1_batch = noise_gen(batch_size, 100)
Z2_batch = noise_gen(batch_size, 100)

print(type(X_batch),X_batch.shape)
print(type(Z1_batch),Z1_batch.shape)

fake_batch = G.predict(Z1_batch)
real_batch = X_batch
print('--------------------Fake Image Generated!--------------------')

combined_X_batch = np.concatenate((real_batch, fake_batch))
combined_y_batch = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))
print('real_batch={}, fake_batch={}'.format(real_batch.shape, fake_batch.shape))
print(type(combined_X_batch),combined_X_batch.dtype,combined_X_batch.shape)
print(type(combined_y_batch),combined_y_batch.dtype,combined_y_batch.shape)
make_trainable(D,True)
d_loss = D.train_on_batch(combined_X_batch, combined_y_batch)
print('--------------------Discriminator trained!--------------------')
print(d_loss)

make_trainable(D,False)
g_loss = GAN.train_on_batch(Z2_batch, np.ones((batch_size, 1)))
print('--------------------GAN trained!--------------------')
print(g_loss)

这有帮助吗?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras训练部分模型问题(关于GAN模型) 的相关文章

随机推荐

  • Perl 中最小的非零正浮点数是多少?

    我有一个 Perl 程序 它处理的概率有时可能非常小 由于舍入误差 有时其中一个概率为零 我想做以下检查 use constant TINY FLOAT gt 1e 200 my prob calculate prob if prob 0
  • sha1() 用于密码哈希

    我使用 sha1 来保证密码安全 我以这种方式将密码存储在register php中 secure password salt openssl random pseudo bytes 20 secured password sha1 pas
  • HTML5 Canvas putImageData,翻译它,更改图像

    我想使用 HTML5 画布绘制图像 翻译图像 然后更改图像 但保留我所做的转换 这可能吗 这是一些伪代码来说明我的问题 initially draw an image and translate it var context canvas
  • 如何在Python中打印嵌套列表的所有可能性?

    这是我的清单 pos det noun adj noun vb det vb noun adj Or pos det noun adj noun vb det vb noun adj 我正在尝试打印所有组合 det noun noun vb
  • 为什么要修复 E_NOTICE 错误?

    作为一名开发人员 我在打开 E NOTICE 的情况下工作 最近有人问我为什么应该修复 E NOTICE 错误 我能想到的唯一原因是纠正这些问题是最佳实践 还有其他人有任何理由证明花费额外的时间 成本来纠正这些问题是合理的吗 更具体地说 如
  • 在头文件中使用 extern 的优点

    这里有一个类似的问题标题 但在阅读答案时 它似乎没有解决该特定问题 C 头文件中的 extern 有什么用 它更像是 为什么使用头文件 在下面的用法中extern extern int a int b structs have no ext
  • ggplot2:箱线图是否仅用于计算 y 轴范围内的值?

    我注意到箱线图的中位数 用受限的 ylim 参数构建 可能与中位数 函数或未调整 y 轴的箱线图获得的中位数不同 这是否意味着箱线图仅用于计算位于 y 轴定义间隔内的值 如果是这样 我怎样才能获得正确的箱线图 基于所有值 但将其绘制在 y
  • findViewById 不适用于特定视图

    我有一个从 XML 加载的活动 视图像往常一样具有 ID
  • 初始化和解构控制器的最佳实践

    构建应用程序时设置控制器的最佳方法是什么 我知道路由 事件侦听器和大多数交互性都应该由控制器管理 但是我的主控制器开始失去控制 我不确定如何最好地将我的逻辑分离到单独的控制器中而不保留它们全部 一直在奔跑 即使有数十数千个控制器 也可以在应
  • Rails 3 - 登录另一个站点并在会话中保留 cookie

    我正在组装一个小应用程序 它允许用户登录一个没有足够 API 的半流行社交网站 因此我使用 mechanize gem 来自动化我想要添加的一些功能任何人都可以使用 例如群发消息 由于 API 限制 我被迫假装是与站点的 http 接口交互
  • sa 用户被锁定,Windows 身份验证登录失败。怎么办?

    我以管理员身份登录 Windows 7 计算机 我正在尝试连接到同一台计算机上的 SQL Server 2012 SA 帐户已被锁定 当我尝试 Windows 身份验证时 登录失败 我现在如何登录 有没有外部工具可以解锁SA用户 我使用外部
  • 如何将行数据溢出到其他行下方

    在下面的变量中 我想将 900000 划分到其各自的三行下面 其中存在零 对于其他值类似 在新值到达之前动态计数的零的数量 A B是数据和A Bnew是期望的输出 请告诉我如何在 R 中做到这一点 A B A Bnew 0 0 0 0 90
  • 从 PowerShell 格式化 XML

    我有一个 PowerShell 脚本 它运行一个返回 XML 的存储过程 然后 我将 XML 导出到一个文件中 但当我打开该文件时 每行末尾都有 3 个点 并且该行不完整 这是使用out file 当我使用Export Clixml从查询返
  • jQuery 电子邮件地址输入

    我希望我的网站上有一个自动完成 自动格式 收件人 字段 其工作方式类似于 GMail 中的字段 有谁知道 jQuery 有这样的东西吗 纯 JavaScript 或者还有其他替代方案吗 http bassistance de jquery
  • $_POST 在 php 5.3.5 中不起作用

    我正在 PHP 5 3 5 中处理一个页面 看起来 POST不包含从我的表单提交的数据 这是 HTML 文件
  • 如何从命令行执行 Cucumber Spring Boot 打包的 Jar?

    我对 cucumber jvm 世界相当陌生 尝试将 Cucumber Spring Boot 应用程序打包为 Jar 应用程序在 Eclipse 中运行良好 但是当我打包为可执行 jar 时 它失败并出现 Exception 主线程异常
  • WooCommerce 使用辅助 PHP 文件(调用模板)

    我该如何调用辅助 PHP 文件 这是我的代码 add filter woocommerce product tabs woo simfree product tab function woo simfree product tab tabs
  • 使用 xarray 获取 netcdf 文件的平均值

    我使用 xarray 在 python 中打开了一个 netcdf 文件 数据集摘要如下所示 Dimensions latitude 721 longitude 1440 time 41 Coordinates longitude long
  • 自托管 asp.net mvc

    是否可以在另一个应用程序中自行托管 asp net mvc 控制台 Windows 窗体 服务等 我想构建一个应用程序 提供一个 Web 界面来控制它 并且我想使用 asp net mvc 作为它的 Web 部分 我确实快速浏览了一下 Na
  • Keras训练部分模型问题(关于GAN模型)

    我在使用keras实现GAN模型时遇到了一个奇怪的问题 对于GAN 我们需要先建立G和D 然后添加一个新的序列模型 GAN 然后依次添加 G 添加 D 当我这样做时 Keras 似乎反向传播回 G 通过 GAN 模型 D train on