Keras.backend.reshape:类型错误:无法将 类型的对象转换为张量。考虑将元素转换为受支持的类型

2024-01-19

我正在为我的神经网络设计一个自定义层,但我的代码出现错误。

我想做一个注意力层,如论文中所述:SAGAN https://arxiv.org/abs/1805.08318。还有原始tf代码 https://github.com/taki0112/Self-Attention-GAN-Tensorflow

class AttentionLayer(Layer):
def __init__(self, **kwargs):
    super(AttentionLayer, self).__init__(**kwargs)

def build(self, input_shape):
    input_dim = input_shape[-1]
    filters_f_g = input_dim // 8
    filters_h = input_dim
    kernel_shape_f_g = (1, 1) + (input_dim, filters_f_g)
    kernel_shape_h = (1, 1) + (input_dim, filters_h)
    # Create a trainable weight variable for this layer:
    self.gamma = self.add_weight(name='gamma', shape=[1], initializer='zeros', trainable=True)
    self.kernel_f = self.add_weight(shape=kernel_shape_f_g,
                                    initializer='glorot_uniform',
                                    name='kernel')
    self.kernel_g = self.add_weight(shape=kernel_shape_f_g,
                                    initializer='glorot_uniform',
                                    name='kernel')
    self.kernel_h = self.add_weight(shape=kernel_shape_h,
                                    initializer='glorot_uniform',
                                    name='kernel')
    self.bias_f = self.add_weight(shape=(filters_f_g,),
                                  initializer='zeros',
                                  name='bias')
    self.bias_g = self.add_weight(shape=(filters_f_g,),
                                  initializer='zeros',
                                  name='bias')
    self.bias_h = self.add_weight(shape=(filters_h,),
                                  initializer='zeros',
                                  name='bias')
    super(AttentionLayer, self).build(input_shape)

def call(self, x):
    def hw_flatten(x):
        return K.reshape(x, shape=[x.shape[0], x.shape[1]*x.shape[2], x.shape[-1]])

    f = K.conv2d(x, kernel=self.kernel_f, strides=(1, 1), padding='same')  # [bs, h, w, c']
    f = K.bias_add(f, self.bias_f)
    g = K.conv2d(x, kernel=self.kernel_g, strides=(1, 1), padding='same')  # [bs, h, w, c']
    g = K.bias_add(g, self.bias_g)
    h = K.conv2d(x, kernel=self.kernel_h, strides=(1, 1), padding='same')  # [bs, h, w, c]
    h = K.bias_add(h, self.bias_h)

    # N = h * w
    flatten_g = hw_flatten(g)
    flatten_f = hw_flatten(f)
    s = K.batch_dot(flatten_g, flatten_f, axes=1)  # # [bs, N, N]

    beta = K.softmax(s, axis=-1)  # attention map

    o = K.batch_dot(beta, hw_flatten(h))  # [bs, N, C]

    o = K.reshape(o, shape=x.shape)  # [bs, h, w, C]
    x = self.gamma * o + x

    return x

当我在模型中添加这一层时,出现错误:

TypeError: Expected binary or unicode string, got Dimension(None)

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-5-9d4e83945ade> in <module>()
      5 X = Conv2D(64, kernel_size=5, strides=1, name='conv1')(X)
      6 X = Activation('relu')(X)
----> 7 X = AttentionLayer()(X)
      8 X = Flatten(name='flatten2')(X)
      9 X = Dense(1000, activation='relu')(X)

/anaconda3/envs/pycharm/lib/python3.6/site-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
    617 
    618             # Actually call the layer, collecting output(s), mask(s), and shape(s).
--> 619             output = self.call(inputs, **kwargs)
    620             output_mask = self.compute_mask(inputs, previous_mask)
    621 

~/Projects/inpainting/models/attention.py in call(self, x)
     49 
     50         # N = h * w
---> 51         flatten_g = hw_flatten(g)
     52         flatten_f = hw_flatten(f)
     53         s = K.batch_dot(flatten_g, flatten_f, axes=1)  # # [bs, N, N]

~/Projects/inpainting/models/attention.py in hw_flatten(x)
     39     def call(self, x):
     40         def hw_flatten(x):
---> 41             return K.reshape(x, shape=[x.shape[0], x.shape[1]*x.shape[2], x.shape[-1]])
     42 
     43         f = K.conv2d(x, kernel=self.kernel_f, strides=(1, 1), padding='same')  # [bs, h, w, c']

/anaconda3/envs/pycharm/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in reshape(x, shape)
   1896         A tensor.
   1897     """
-> 1898     return tf.reshape(x, shape)
   1899 
   1900 

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
   6111   if _ctx is None or not _ctx._eager_context.is_eager:
   6112     _, _, _op = _op_def_lib._apply_op_helper(
-> 6113         "Reshape", tensor=tensor, shape=shape, name=name)
   6114     _result = _op.outputs[:]
   6115     _inputs_flat = _op.inputs

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    511           except TypeError as err:
    512             if dtype is None:
--> 513               raise err
    514             else:
    515               raise TypeError(

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    508                 dtype=dtype,
    509                 as_ref=input_arg.is_ref,
--> 510                 preferred_dtype=default_dtype)
    511           except TypeError as err:
    512             if dtype is None:

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
   1102 
   1103     if ret is None:
-> 1104       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1105 
   1106     if ret is NotImplemented:

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    233                                          as_ref=False):
    234   _ = as_ref
--> 235   return constant(v, dtype=dtype, name=name)
    236 
    237 

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    212   tensor_value.tensor.CopyFrom(
    213       tensor_util.make_tensor_proto(
--> 214           value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    215   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    216   const_tensor = g.create_op(

/anaconda3/envs/pycharm/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    519       raise TypeError("Failed to convert object of type %s to Tensor. "
    520                       "Contents: %s. Consider casting elements to a "
--> 521                       "supported type." % (type(values), values))
    522     tensor_proto.string_val.extend(str_values)
    523     return tensor_proto

TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [Dimension(None), Dimension(64), Dimension(8)]. Consider casting elements to a supported type.

我试着做x_shape = x.shape.as_list()在 hw_flatten 函数中,但它不起作用,我不知道如何调试此错误。


您正在访问张量.shape属性为您提供 Dimension 对象,而不是实际的形状值。您有 2 个选择:

  1. 如果您知道形状并且它在图层创建时固定,您可以使用K.int_shape(x)[0]这将给出整数值。然而它会返回None如果创建时形状未知;例如,如果batch_size未知。
  2. 如果形状将在运行时确定,那么您可以使用K.shape(x)[0]它将返回一个符号张量,该张量将在运行时保存形状值。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras.backend.reshape:类型错误:无法将 类型的对象转换为张量。考虑将元素转换为受支持的类型 的相关文章

  • 与区域指示符字符类匹配的 python 正则表达式

    我在 Mac 上使用 python 2 7 10 表情符号中的标志由一对表示区域指示符号 https en wikipedia org wiki Regional Indicator Symbol 我想编写一个 python 正则表达式来在
  • 使用特定的类/函数预加载 Jupyter Notebook

    我想预加载一个笔记本 其中包含我在另一个文件中定义的特定类 函数 更具体地说 我想用 python 来做到这一点 比如加载一个配置文件 包含所有相关的类 函数 目前 我正在使用 python 生成笔记本并在服务器上自动启动它们 因为不同的
  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • Python 中的哈希映射

    我想用Python实现HashMap 我想请求用户输入 根据他的输入 我从 HashMap 中检索一些信息 如果用户输入HashMap的某个键 我想检索相应的值 如何在 Python 中实现此功能 HashMap
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 跟踪 pypi 依赖项 - 谁在使用我的包

    无论如何 是否可以通过 pip 或 PyPi 来识别哪些项目 在 Pypi 上发布 可能正在使用我的包 也在 PyPi 上发布 我想确定每个包的用户群以及可能尝试积极与他们互动 预先感谢您的任何答案 即使我想做的事情是不可能的 这实际上是不
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 使用Python请求登录Google帐户

    在多个登录页面上 需要谷歌登录才能继续 我想用requestspython 中的库以便让我自己登录 通常这很容易使用requests库 但是我无法让它工作 我不确定这是否是由于 Google 做出的一些限制 也许我需要使用他们的 API 或
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • 立体太阳图 matplotlib 极坐标图 python

    我正在尝试创建一个与以下类似的简单的立体太阳路径图 http wiki naturalfrequent com wiki Sun Path Diagram http wiki naturalfrequency com wiki Sun Pa
  • YOLOv8获取预测边界框

    我想将 OpenCV 与 YOLOv8 集成ultralytics 所以我想从模型预测中获取边界框坐标 我该怎么做呢 from ultralytics import YOLO import cv2 model YOLO yolov8n pt
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • Jupyter Notebook 找不到 Python 模块

    不知道发生了什么 但每当我使用 ipython 氢 原子 或 jupyter 笔记本时都找不到任何已安装的模块 我知道我安装了 pandas 但笔记本说找不到 我应该补充一点 当我正常运行脚本时 python script py 它确实导入
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • 如何使用原始 SQL 查询实现搜索功能

    我正在创建一个由 CS50 的网络系列指导的应用程序 这要求我仅使用原始 SQL 查询而不是 ORM 我正在尝试创建一个搜索功能 用户可以在其中查找存储在数据库中的书籍列表 我希望他们能够查询 书籍 表中的 ISBN 标题 作者列 目前 它
  • 如何在 Windows 命令行中使用参数运行 Python 脚本

    这是我的蟒蛇hello py script def hello a b print hello and that s your sum sum a b print sum import sys if name main hello sys
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 模拟pytest中的异常终止

    我的多线程应用程序遇到了一个错误 主线程的任何异常终止 例如 未捕获的异常或某些信号 都会导致其他线程之一死锁 并阻止进程干净退出 我解决了这个问题 但我想添加一个测试来防止回归 但是 我不知道如何在 pytest 中模拟异常终止 如果我只
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重

随机推荐