Tensorflow 2 抛出 ValueError:as_list() 未在未知 TensorShape 上定义

2024-03-04

我正在尝试在 Tensorflow 2.0 中训练 Unet 模型,该模型将图像和分割掩模作为输入,但我得到了ValueError : as_list() is not defined on an unknown TensorShape。堆栈跟踪显示问题发生在_get_input_from_iterator(inputs):

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in _prepare_feed_values(model, inputs, mode)
    110     for inputs will always be wrapped in lists.
    111   """
--> 112   inputs, targets, sample_weights = _get_input_from_iterator(inputs)
    113 
    114   # When the inputs are dict, then we want to flatten it in the same order as

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in _get_input_from_iterator(iterator)
    147   # Validate that all the elements in x and y are of the same type and shape.
    148   dist_utils.validate_distributed_dataset_inputs(
--> 149       distribution_strategy_context.get_strategy(), x, y, sample_weights)
    150   return x, y, sample_weights
    151 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_distributed_dataset_inputs(distribution_strategy, x, y, sample_weights)
    309 
    310   if y is not None:
--> 311     y_values_list = validate_per_replica_inputs(distribution_strategy, y)
    312   else:
    313     y_values_list = None

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_per_replica_inputs(distribution_strategy, x)
    354     if not context.executing_eagerly():
    355       # Validate that the shape and dtype of all the elements in x are the same.
--> 356       validate_all_tensor_shapes(x, x_values)
    357     validate_all_tensor_types(x, x_values)
    358 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_all_tensor_shapes(x, x_values)
    371 def validate_all_tensor_shapes(x, x_values):
    372   # Validate that the shape of all the elements in x have the same shape
--> 373   x_shape = x_values[0].shape.as_list()
    374   for i in range(1, len(x_values)):
    375     if x_shape != x_values[i].shape.as_list():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_shape.py in as_list(self)
   1169     """
   1170     if self._dims is None:
-> 1171       raise ValueError("as_list() is not defined on an unknown TensorShape.")
   1172     return [dim.value for dim in self._dims]
   1173

我浏览了其他几个 Stackoverflow 帖子(here https://stackoverflow.com/questions/48136804/tf-estimator-train-throws-as-list-is-not-defined-on-an-unknown-tensorshape and here https://stackoverflow.com/questions/50752787/valueerror-as-list-is-not-defined-on-an-unknown-tensorshape)出现此错误,但就我而言,我认为问题出现在我传递给数据集的映射函数中。我打电话给process_path下面定义的函数map张量流数据集的功能。这接受图像的路径并构造相应的分割掩模的路径,该分割掩模是numpy file。然后将 numpy 文件中的 (256 256) 数组转换为 (256 256 10),使用kerasUtil.to_categorical其中 10 个通道代表每个类别。我用的是check_shape函数来确认张量形状是否正确,但当我打电话时仍然如此model.fit无法导出形状。

# --------------------------------------------------------------------------------------
# DECODE A NUMPY .NPY FILE INTO THE REQUIRED FORMAT FOR TRAINING
# --------------------------------------------------------------------------------------
def decode_npy(npy):
  filename = npy.numpy()
  data = np.load(filename)
  data = kerasUtils.to_categorical(data, 10)
  return data

def check_shape(image, mask):
  print('shape of image: ', image.get_shape())
  print('shape of mask: ', mask.get_shape())
  return 0.0

# --------------------------------------------------------------------------------------
# DECODE AN IMAGE (PNG) FILE INTO THE REQUIRED FORMAT FOR TRAINING
# --------------------------------------------------------------------------------------
def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_png(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  return tf.image.convert_image_dtype(img, tf.float32)

# --------------------------------------------------------------------------------------
# PROCESS A FILE PATH FOR THE DATASET
# input - path to an image file
# output - an input image and output mask
# --------------------------------------------------------------------------------------
def process_path(filePath):
  parts = tf.strings.split(filePath, '/')
  fileName = parts[-1]
  parts = tf.strings.split(fileName, '.')
  prefix = tf.convert_to_tensor(convertedMaskDir, dtype=tf.string)
  suffix = tf.convert_to_tensor("-mask.npy", dtype=tf.string)
  maskFileName = tf.strings.join((parts[-2], suffix))
  maskPath = tf.strings.join((prefix, maskFileName), separator='/')

  # load the raw data from the file as a string
  img = tf.io.read_file(filePath)
  img = decode_img(img)
  mask = tf.py_function(decode_npy, [maskPath], tf.float32)

  return img, mask

# --------------------------------------------------------------------------------------
# CREATE A TRAINING and VALIDATION DATASETS
# --------------------------------------------------------------------------------------
trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)

allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))
# allDataSet = allDataSet.map(process_path, num_parallel_calls=AUTOTUNE)
# allDataSet = allDataSet.map(process_path)

trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.map(process_path).batch(64)
validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.map(process_path).batch(64)

...

# this code throws the error!
model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)

我在图像和蒙版方面遇到了与您相同的问题,并通过在预处理函数期间手动设置它们的形状来解决它,特别是在 tf.map 期间调用 pyfunc 时。

def process_path(filePath):
  ...

  # load the raw data from the file as a string
  img = tf.io.read_file(filePath)
  img = decode_img(img)
  mask = tf.py_function(decode_npy, [maskPath], tf.float32)

  # TODO:
  img.set_shape([MANUALLY ENTER THIS])
  mask.set_shape([MANUALLY ENTER THIS])

  return img, mask
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 2 抛出 ValueError:as_list() 未在未知 TensorShape 上定义 的相关文章

随机推荐

  • 如果子文档值不存在,Mongodb 插入子文档

    我对 mongodb 很陌生 我有点迷失 我有 mongo 数据库集合 如下所示 id id createdAt new Date name name friends name 1 children name sarah age 12 do
  • html 或 java 脚本代码在硬盘中创建文本文件

    请有人给我一个代码来在硬盘驱动器中创建一个文本文件 结果应该是一个html文件 当双击 html 文件时 它需要在硬盘驱动器 本地 的给定路径中创建一个文本文件 谢谢 出于安全原因 浏览器中常规 HTML 页面中的 JavaScript 不
  • 构造函数中的默认参数--C++

    我有一个 C 类 其中有一个构造函数char char ostream 我想提供一个默认值ostream cerr 这是在标题或 cpp file 您需要将参数设置为参考参数 您不应该尝试复制std cerr 您可能需要在头文件中指定默认参
  • 更改 WPF 中单个/活动窗口的系统语言

    WPF 中是否可以仅更改一个窗口的系统语言 我知道关于InputLanguageManager但我认为它会改变整个系统的语言 InputLanguageManager 完全符合您的要求 它更改当前应用程序的键盘布局 操作系统为每个正在运行的
  • 有没有办法循环遍历 r 中的线性模型的列名称(而不是数字)?

    我有一个包含 40 个数据列 40 种不同的营养素 的数据表 还有用于绘图数字和因子的附加列 我想自动循环每个列名称并为每个列生成一个线性模型和摘要 数据列从第 10 列开始 for i in 10 ncol df for loop ove
  • 将带有 json 的 numpy 数组发送到带有请求的 Flask 应用程序

    使用请求 我需要在单个帖子中将带有 json 数据的 numpy 数组发送到我的 Flask 应用程序 我该怎么做呢 转换 numpy 数组arr到 json 时 可以将其序列化 同时保留维度json dumps arr tolist 然后
  • LIKE '%...' 如何在索引上查找?

    我期待这两个SELECT具有相同的执行计划和性能 由于有一个前导通配符LIKE 我期望进行索引扫描 当我运行这个并查看计划时 第一个SELECT行为符合预期 通过扫描 但第二个SELECT计划显示索引查找 并且运行速度快 20 倍 Code
  • 将散列数据作为散列的密钥传递再次返回不正确的结果

    我正在创建一个将使用 GAS 访问 AWS 服务的脚本 我使用实用程序库中的哈希函数来执行创建 v4 签名所需的所有哈希 这些函数似乎能够成功地对数据进行一次哈希处理 但尝试将哈希数据传递到参数中会产生不正确的结果 还有其他人遇到这个问题并
  • 在swift中实现函数

    我是 swift 的新手 试图实现一个简单的函数 该函数将最小和最大数字作为输入 并返回一个包含所有限制数字的数组 我收到错误 错误 对泛型类型 Array 的引用需要 中的参数 我可以知道我错过了什么吗 func serialNumber
  • Scikit-learn、KMeans:如何使用 max_iter

    我想了解类中的参数 max itersklearn cluster KMeans http scikit learn org stable modules generated sklearn cluster KMeans html 根据文档
  • Haskell:如何停止程序打印向左或向右

    我用 haskell 制作了一个计算器 我在 GHCi 中运行它 然而 由于最终的数字可以是整数或双精度数 我已经进行了类型声明 calc String gt Either Integer Double 然而 函数的输出总是在其前面有左或右
  • C++ 判断类是否具有可比性

    我或多或少是Java程序员 所以这可能是一个愚蠢的问题 但我没有找到任何简单的解决方案 我在 C 中有一个这样的类 template
  • JavaScript 排序方法处理大写字母

    注意到 JavaScript 可能有些奇怪的地方sort 方法 给定以下数组 var arr Aaa CUSTREF Copy a template Copy of Statementsmm Copy1 of Default Email T
  • 在 ColdFusion 中维护出站 TCP 连接池

    我希望从 ColdFusion 应用程序中大量使用 RESTful API 我不是 CF 专家 但我预计重复的 cfhttp 调用将成为瓶颈 因为我相信每次调用都会导致建立连接 发送请求 收到响应和断开连接 我很好奇 有没有办法维护一个连接
  • 在批处理文件中定义和使用变量

    我正在尝试在批处理文件中定义和使用变量 看起来应该很简单 echo off set location bob echo We re working with location 我得到的输出如下 We re working with 这里发生
  • jQuery Mobile 视口在 Windows Phone 中无法工作

    我正在 WindowsPhone 中测试 jQueryMobile 但视口无法正常工作 有一个解决方法这一页 http forum jquery com topic problem with virtual viewport size on
  • FireStore Tasks.whenAllComplete 与协程

    我想同步实现这段代码 但job join deferred await和火力基地await 不工作 有谁知道解决方案吗 CoroutineScope Dispatchers Main launch val job launch Tasks
  • 如何使用 System.Net.HttpClient 发布复杂类型?

    我有一个自定义复杂类型 我想使用 Web API 来使用它 public class Widget public int ID get set public string Name get set public decimal Price
  • 如何使用复选框更改 QGraphicsView 背景

    在此代码中 更改了QGraphicsView背景 现在当我检查 true 时我需要更改背景checkBox 当我设置为checkBox去检查true我需要像这段代码一样设置背景 当我设置时checkBox去检查false 我需要设置QGra
  • Tensorflow 2 抛出 ValueError:as_list() 未在未知 TensorShape 上定义

    我正在尝试在 Tensorflow 2 0 中训练 Unet 模型 该模型将图像和分割掩模作为输入 但我得到了ValueError as list is not defined on an unknown TensorShape 堆栈跟踪显