TensorFlow ValueError:变量不存在,或者不是使用 tf.get_variable() 创建的

2023-12-30

我是 Tensorflow 的新手,正在尝试实现生成对抗网络。我正在关注this https://github.com/adeshpande3/Generative-Adversarial-Networks/blob/master/Generative%20Adversarial%20Networks%20Tutorial.ipynb我们正在尝试使用生成模型生成 MNIST 数据集(例如图像)。然而,该代码似乎使用旧版本的 TensorFlow (

行:trainerD = tf.train.AdamOptimizer().minimize(d_loss, var_list=d_vars)

ValueError:变量 d_wconv1/Adam/ 不存在,或未创建 使用 tf.get_variable()。您的意思是在 VarScope 中设置reuse=None 吗?

其代码如下:

import tensorflow as tf
import random
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x_train = mnist.train.images[:55000,:]
#print (x_train.shape)

#randomNum = random.randint(0,55000)
#image = x_train[randomNum].reshape([28,28])
#plt.imshow(image, cmap=plt.get_cmap('gray_r'))
#plt.show()

def conv2d(x, W):
  return tf.nn.conv2d(input=x, filter=W, strides=[1, 1, 1, 1], padding='SAME')

def avg_pool_2x2(x):
  return tf.nn.avg_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def discriminator(x_image, reuse=False):
    if (reuse):
        tf.get_variable_scope().reuse_variables()
    #First Conv and Pool Layers
    W_conv1 = tf.get_variable('d_wconv1', [5, 5, 1, 8], initializer=tf.truncated_normal_initializer(stddev=0.02))
    b_conv1 = tf.get_variable('d_bconv1', [8], initializer=tf.constant_initializer(0))
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = avg_pool_2x2(h_conv1)

    #Second Conv and Pool Layers
    W_conv2 = tf.get_variable('d_wconv2', [5, 5, 8, 16], initializer=tf.truncated_normal_initializer(stddev=0.02))
    b_conv2 = tf.get_variable('d_bconv2', [16], initializer=tf.constant_initializer(0))
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = avg_pool_2x2(h_conv2)

    #First Fully Connected Layer
    W_fc1 = tf.get_variable('d_wfc1', [7 * 7 * 16, 32], initializer=tf.truncated_normal_initializer(stddev=0.02))
    b_fc1 = tf.get_variable('d_bfc1', [32], initializer=tf.constant_initializer(0))
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*16])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    #Second Fully Connected Layer
    W_fc2 = tf.get_variable('d_wfc2', [32, 1], initializer=tf.truncated_normal_initializer(stddev=0.02))
    b_fc2 = tf.get_variable('d_bfc2', [1], initializer=tf.constant_initializer(0))

    #Final Layer
    y_conv=(tf.matmul(h_fc1, W_fc2) + b_fc2)
    return y_conv

def generator(z, batch_size, z_dim, reuse=False):
    if (reuse):
        tf.get_variable_scope().reuse_variables()
    g_dim = 64 #Number of filters of first layer of generator 
    c_dim = 1 #Color dimension of output (MNIST is grayscale, so c_dim = 1 for us)
    s = 28 #Output size of the image
    s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) #We want to slowly upscale the image, so these values will help
                                                              #make that change gradual.

    h0 = tf.reshape(z, [batch_size, s16+1, s16+1, 25])
    h0 = tf.nn.relu(h0)
    #Dimensions of h0 = batch_size x 2 x 2 x 25

    #First DeConv Layer
    output1_shape = [batch_size, s8, s8, g_dim*4]
    W_conv1 = tf.get_variable('g_wconv1', [5, 5, output1_shape[-1], int(h0.get_shape()[-1])], 
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
    b_conv1 = tf.get_variable('g_bconv1', [output1_shape[-1]], initializer=tf.constant_initializer(.1))
    H_conv1 = tf.nn.conv2d_transpose(h0, W_conv1, output_shape=output1_shape, strides=[1, 2, 2, 1], padding='SAME')
    H_conv1 = tf.contrib.layers.batch_norm(inputs = H_conv1, center=True, scale=True, is_training=True, scope="g_bn1")
    H_conv1 = tf.nn.relu(H_conv1)
    #Dimensions of H_conv1 = batch_size x 3 x 3 x 256

    #Second DeConv Layer
    output2_shape = [batch_size, s4 - 1, s4 - 1, g_dim*2]
    W_conv2 = tf.get_variable('g_wconv2', [5, 5, output2_shape[-1], int(H_conv1.get_shape()[-1])], 
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
    b_conv2 = tf.get_variable('g_bconv2', [output2_shape[-1]], initializer=tf.constant_initializer(.1))
    H_conv2 = tf.nn.conv2d_transpose(H_conv1, W_conv2, output_shape=output2_shape, strides=[1, 2, 2, 1], padding='SAME')
    H_conv2 = tf.contrib.layers.batch_norm(inputs = H_conv2, center=True, scale=True, is_training=True, scope="g_bn2")
    H_conv2 = tf.nn.relu(H_conv2)
    #Dimensions of H_conv2 = batch_size x 6 x 6 x 128

    #Third DeConv Layer
    output3_shape = [batch_size, s2 - 2, s2 - 2, g_dim*1]
    W_conv3 = tf.get_variable('g_wconv3', [5, 5, output3_shape[-1], int(H_conv2.get_shape()[-1])], 
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
    b_conv3 = tf.get_variable('g_bconv3', [output3_shape[-1]], initializer=tf.constant_initializer(.1))
    H_conv3 = tf.nn.conv2d_transpose(H_conv2, W_conv3, output_shape=output3_shape, strides=[1, 2, 2, 1], padding='SAME')
    H_conv3 = tf.contrib.layers.batch_norm(inputs = H_conv3, center=True, scale=True, is_training=True, scope="g_bn3")
    H_conv3 = tf.nn.relu(H_conv3)
    #Dimensions of H_conv3 = batch_size x 12 x 12 x 64

    #Fourth DeConv Layer
    output4_shape = [batch_size, s, s, c_dim]
    W_conv4 = tf.get_variable('g_wconv4', [5, 5, output4_shape[-1], int(H_conv3.get_shape()[-1])], 
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
    b_conv4 = tf.get_variable('g_bconv4', [output4_shape[-1]], initializer=tf.constant_initializer(.1))
    H_conv4 = tf.nn.conv2d_transpose(H_conv3, W_conv4, output_shape=output4_shape, strides=[1, 2, 2, 1], padding='VALID')
    H_conv4 = tf.nn.tanh(H_conv4)
    #Dimensions of H_conv4 = batch_size x 28 x 28 x 1

    return H_conv4

sess = tf.Session()
z_dimensions = 100
z_test_placeholder = tf.placeholder(tf.float32, [None, z_dimensions])

sample_image = generator(z_test_placeholder, 1, z_dimensions)
test_z = np.random.normal(-1, 1, [1,z_dimensions])

sess.run(tf.global_variables_initializer())
temp = (sess.run(sample_image, feed_dict={z_test_placeholder: test_z}))

my_i = temp.squeeze()
#plt.imshow(my_i, cmap='gray_r')
#plt.show()

batch_size = 16
tf.reset_default_graph() #Since we changed our batch size (from 1 to 16), we need to reset our Tensorflow graph

sess = tf.Session()
x_placeholder = tf.placeholder("float", shape = [None,28,28,1]) #Placeholder for input images to the discriminator
z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions]) #Placeholder for input noise vectors to the generator

Dx = discriminator(x_placeholder) #Dx will hold discriminator prediction probabilities for the real MNIST images
Gz = generator(z_placeholder, batch_size, z_dimensions) #Gz holds the generated images
Dg = discriminator(Gz, reuse=True) #Dg will hold discriminator prediction probabilities for generated images

g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.ones_like(Dg)))

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg)))
d_loss = d_loss_real + d_loss_fake

tvars = tf.trainable_variables()
d_vars = [var for var in tvars if 'd_' in var.name]
g_vars = [var for var in tvars if 'g_' in var.name]

trainerD = tf.train.AdamOptimizer().minimize(d_loss, var_list=d_vars)
trainerG = tf.train.AdamOptimizer().minimize(g_loss, var_list=g_vars)

sess.run(tf.global_variables_initializer())
iterations = 3000
for i in range(iterations):
    z_batch = np.random.normal(-1, 1, size=[batch_size, z_dimensions])
    real_image_batch = mnist.train.next_batch(batch_size)
    real_image_batch = np.reshape(real_image_batch[0],[batch_size,28,28,1])
    _,dLoss = sess.run([trainerD, d_loss],feed_dict={z_placeholder:z_batch,x_placeholder:real_image_batch}) #Update the discriminator
    _,gLoss = sess.run([trainerG,g_loss],feed_dict={z_placeholder:z_batch}) #Update the generator

sample_image = generator(z_placeholder, 1, z_dimensions)
z_batch = np.random.normal(-1, 1, size=[1, z_dimensions])
temp = (sess.run(sample_image, feed_dict={z_placeholder: z_batch}))
my_i = temp.squeeze()
plt.imshow(my_i, cmap='gray_r')
plt.show()

它似乎有一个简单的解决方案,不幸的是我无法弄清楚。任何帮助,将不胜感激。


请修改您的代码如下,

with tf.variable_scope(tf.get_variable_scope(),reuse=False): trainerD = tf.train.AdamOptimizer().minimize(d_loss, var_list=d_vars) trainerG = tf.train.AdamOptimizer().minimize(g_loss, var_list=g_vars)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow ValueError:变量不存在,或者不是使用 tf.get_variable() 创建的 的相关文章

  • 从文本文件中删除特定字符

    我对 Python 和编码都很陌生 我当时正在做一个小项目 但遇到了一个问题 44 1 6 23 2 7 49 2 3 53 2 1 68 1 6 71 2 7 我只需要从每行中删除第三个和第六个字符 或者更具体地说 从整个文件中删除 字符
  • 使用python查找txt文件中字母出现的次数

    我需要从 txt 文件中读取该字母并打印 txt 文件中出现的次数 到目前为止 我已经能够在一行中打印内容 但计数有问题 有人可以指导吗 infile open grades txt content infile read for char
  • Virtualenv 在 OS X Yosemite 上失败并出现 OSError

    我最近更新到 OSX Yosemite 现在无法使用virtualenv pip 每当我执行 virtualenv env 它抛出一个 OSError Command Users administrator ux env bin pytho
  • 使用 django-rest-framework 设置对象级权限

    尝试使用 django rest framework 最干净 最规范地管理 django guardian 对象级权限 我想将对象的读取权限 module view object 分配给在执行 POST 时发出请求的用户 我的基于阶级的观点
  • 按边距(“全部”)值列对 Pandas 数据透视表进行排序

    我试图根据 pandas 数据透视表中的行总和对最后一列 边距 aggrfunc 进行降序排序 我知道我在这里错过了一些简单的东西 但我无法弄清楚 数据框 数据透视表 WIDGETS DATE 2 1 16 2 2 16 2 3 16 Al
  • 使用 Django 将文件异步上传到 Amazon S3

    我使用此文件存储引擎在上传文件时将文件存储到 Amazon S3 http code welldev org django storages wiki Home http code welldev org django storages w
  • 使用 Paramiko 进行 DSA 密钥转发?

    我正在使用 Paramiko 在远程服务器上执行 bash 脚本 在其中一些脚本中 存在与其他服务器的 ssh 连接 如果我只使用 bash 不使用 Python 我的 DSA 密钥将被第一个远程服务器上的 bash 脚本转发并使用 以连接
  • 如何在 Tensorflow 对象检测 API 中查找边界框坐标

    我正在使用 Tensorflow 对象检测 API 代码 我训练了我的模型并获得了很高的检测百分比 我一直在尝试获取边界框坐标 但它不断打印出 100 个奇怪数组的列表 经过在线广泛搜索后 我发现数组中的数字意味着什么 边界框坐标相对于底层
  • 如何确保 re.findall() 停止在正确的位置?

    这是我的代码 a import re re findall r lt title gt lt title gt a 结果是 title aaa
  • 如何使用scrapy检查网站是否支持http、htts和www前缀

    我正在使用 scrapy 来检查某些网站是否工作正常 当我使用http example com https example com or http www example com 当我创建 scrapy 请求时 它工作正常 例如 在我的pa
  • Python Anaconda:如何测试更新的库是否与我现有的代码兼容?

    我在 Windows 7 机器上使用 Python 2 7 Anaconda 安装进行数据分析和科学计算 当新的库发布时 例如新版本的 pandas patsy 等 您建议我如何测试新版本与现有代码的兼容性 是否可以在同一台机器上安装两个
  • 移动设备上的 TensorFlow(Android、iOS、Windows Phone)

    我目前正在寻找不同的深度学习框架 特别是用于训练和部署卷积神经网络 要求是 它可以在带有 GPU 的普通 PC 上进行训练 但训练后的模型必须部署在三个主要的移动操作系统上 即 Android iOS 和 Windows Phone Ten
  • 使用 for 循环创建一系列元组

    我已经搜索过 但找不到答案 尽管我确信它已经存在了 我对 python 很陌生 但我以前用其他语言做过这种事情 我正在以行形式读取数据文件 我想将每行数据存储在它自己的元组中 以便在 for 循环之外访问 tup i inLine wher
  • Plotly:如何检查基本图形结构(版本 4)

    对于旧版本的plotly 例如在 Jupyterlab 中 您可以简单地运行figure像这样检查你的图形的基础知识 Ouput data marker color red size 10 symbol 104 mode markers l
  • 如何指示 urwid 列表框的项目数多于当前显示的项目数?

    有没有办法向用户显示 urwid 列表框在显示部分上方 下方有其他项目 我正在考虑类似滚动条的东西 它可以显示条目的数量 或者列表框顶部 底部的单独栏 如果这个行为无法实现 有哪些方法可以实现这个通知 在我的研究过程中 我发现这个问题 ht
  • Django 管理器链接

    我想知道是否有可能 如果可以的话 如何 将多个管理器链接在一起以生成受两个单独管理器影响的查询集 我将解释我正在研究的具体示例 我有多个抽象模型类 用于为其他模型提供小型的特定功能 其中两个模型是DeleteMixin 和GlobalMix
  • rpy2 无法加载外部库

    希望有人能帮忙解决这个问题 R版本 2 14 1rpy2版本 2 2 5蟒蛇版本 2 7 3 一直在尝试在 python 脚本中使用 rpy2 加载 R venneuler 包 该包以 rJava 作为依赖项 venneuler 和 rJa
  • 如何从namedtuple实例列表创建pandas DataFrame(带有索引或多索引)?

    简单的例子 from collections import namedtuple import pandas Price namedtuple Price ticker date price a Price GE 2010 01 01 30
  • 如何获取pandas中groupby对象中的组数?

    我想知道有多少个独特的组需要执行计算 给定一个名为 groupby 的对象dfgroup 我们如何找到组的数量 简单 快速 Pandaic ngroups 较新版本的 groupby API pandas gt 0 23 提供了此 未记录的
  • 无法安装最新版本的 Numpy (1.22.3)

    我正在尝试安装最新版本的 numpy 即 1 22 3 但看起来 pip 无法找到最后一个版本 我知道我可以从源代码本地安装它 但我想了解为什么我无法使用 pip 安装它 PS 我有最新版本的pip 22 0 4 ERROR Could n

随机推荐

  • 将某些工作表从 Excel 工作簿导出为 PDF

    我正在编写一个 VBA 代码 将 Excel 中的一些工作表导出到同一个 PDF 我的 Excel 文件中有几个图表工作表 每个图表工作表的名称都以 name Chart 结尾 我想将名称以图表结尾的所有工作表导出到一个 PDF 文件 这是
  • 在 C/C++ 中获取大随机数

    标准rand 函数给出的数字对我来说不够大 我需要unsigned long long那些 我们如何获得非常大的随机数 我尝试修改一个简单的哈希函数 但它太big 运行时间太长 并且永远不会产生小于 1e5 的数字 你可以轻松地做到这一点s
  • Android 自定义 Widget 膨胀异常

    XML
  • 自定义 li 列表样式,带有很棒的字体图标

    我想知道是否可以利用 font awesome 或任何其他标志性字体 类来创建自定义 li 列表样式类型 我目前正在使用 jQuery 来执行此操作 即 li myClass prepend i class i 然而 当 li li 文本环
  • POST Restful API 的响应代码 400 或 403

    我正在设计一个 POST Restful API 在这种情况下 我必须根据请求正文中提供的元素之一来授权用户 例如 division 1 name MyName address no 123 street abc pincode 22211
  • Minimongo 尚不支持投影中的 $ 运算符

    我有这个文件 username torayeff profile friends id aSD4wmMEsFnYLcvmP state active id ShFTXxuQLAxWCh4cq state active id EQjoKMNB
  • 如何在android的WebView中显示滚动条

    是否可以在 android 的 WebView 中显示可滚动 html 元素的滚动条 怎么做 in your onCreate 方法 试试这个 Override public void onCreate Bundle savedInstan
  • ASP.NET Web API JSON 输出中没有时间的日期

    有没有一种简单的方法来配置 JSON NET 以便some DateTime字段将被格式化 没有时间和其他DateTime字段仍会随时间格式化吗 例子 firstName John lastName Doe birthday 1965 09
  • vbscript 调试器[关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我有一个可以在开发中使用的 vbscript 但不能在服务器上使用 我想调试这个 但我不想在服务器上安
  • 当我们使用rest api进行调用时,我们在该url参数上使用什么twiml

    我正在使用其余 api 创建呼叫 try Initiate a new outbound call call this gt client gt calls gt create to call num2 Step 5 Change the
  • 是否有可能伪造 Windows 控制台 api?

    我用 C 编写了一个 ssh 服务器 我认为将 powershell 作为 shell 连接起来会很漂亮 我尝试了两种方法来使其正常工作 但这两种方法都远非完美 这是我尝试过的 启动 powershell exe 并将其重定向到 std i
  • Elasticsearch批量索引超时错误!错误:30000 毫秒后请求超时

    最近 我想将旧的指数数据滚动到新的月度指数 存储的数据从2015 07开始至今 每个月几乎有30 000条记录 跟着scroll and bulk中提供的方法2 2 API https www elastic co guide en ela
  • 在 PHP 中处理大量数据

    To use 模幂 http en wikipedia org wiki Modular exponentiation正如您在使用时所需要的费马素性测试 http en wikipedia org wiki Fermat primality
  • 如何在 LARAVEL 5.2 中将数据存储到数据库

    我是 Laravel 的初学者 当我想将数据存储到数据库时遇到问题 当视图上的名称与数据库上的字段名称不同时 数据未保存在数据库中 但当视图上的输入名称与数据库上的字段名称相同时 数据已正确存储 例子 这是视图 div class form
  • 在 C 中实现 ceil()

    我想实现我自己的ceil in C 在库中搜索源代码并找到 但似乎很难理解 我想要干净优雅的代码 我也在SO上搜索 找到了一些答案here https stackoverflow com questions 2796639 implemen
  • 将信息从一层传递到另一层

    抽象视图 我想将信息从一层传递到另一层 注意 当该主题有更好的标题时请告诉我 我有一个 ViewModel 它与我的视图和服务层进行通信 我与持久层有服务层通信 假设我有以下课程 public class EmployeeViewModel
  • java序列化导致utfdataformatException

    我正在尝试将多个对象序列化到一个文件中 特别是 当我尝试写时 public void execute PipelineContext context throws Exception FileOutputStream fos new Fil
  • Lucene 中的 WordnetSynonymParser

    我是 Lucene 的新手 我正在尝试使用 WordnetSynonymParser 来使用 wordnet 同义词序言来扩展查询 这是我到目前为止所拥有的 public class CustomAnalyzer extends Analy
  • 应用程序不显示 adwhirl 广告

    我创建了一个简单的应用程序 显示使用广告旋转添加 它不显示任何广告 我添加了 logcat 文件 提前致谢 11 18 15 08 55 940 ERROR AdWhirl SDK 619 Caught IOException in fet
  • TensorFlow ValueError:变量不存在,或者不是使用 tf.get_variable() 创建的

    我是 Tensorflow 的新手 正在尝试实现生成对抗网络 我正在关注this https github com adeshpande3 Generative Adversarial Networks blob master Genera