TensorFlow ValueError:应定义输入的通道维度。发现“无”

2024-01-02

我正在尝试实现一个“扩张残差网络”,如此处所述Paper https://arxiv.org/abs/1705.09914在 TensorFlow 中(s. PyTorch 实现here https://github.com/fyu/drn)来训练它城市景观数据集 https://www.cityscapes-dataset.com/并将其用于语义图像分割。不幸的是,我在尝试训练时遇到错误,并且似乎无法找到修复方法。

由于此类网络可以看作是 ResNet 的扩展,因此我使用了官方的 TensorFlow ResNet 模型(Link https://github.com/tensorflow/models/tree/master/official/resnet)并通过改变步幅、添加膨胀(作为 tf.layers.conv2d 函数中的参数)和删除残余连接来修改架构。

为了训练这个网络,我想使用与 TensorFlow ResNet 模型相同的方法:tf.estimator 与 input_fn 结合使用(可以在本文末尾找到)。

现在,当我想使用 CityScapes 数据集训练该网络时,出现以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-19-263240bbee7e> in <module>()
----> 1 main()

<ipython-input-16-b57cd9b52bc7> in main()
     27         print('Starting a training cycle.')
     28         drn_classifier.train(
---> 29             input_fn=lambda: input_fn(True, _BATCH_SIZE, _EPOCHS_PER_EVAL),hooks=[logging_hook])
     30 
     31         print(2)

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\estimator\estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
    300 
    301     saving_listeners = _check_listeners_type(saving_listeners)
--> 302     loss = self._train_model(input_fn, hooks, saving_listeners)
    303     logging.info('Loss for final step: %s.', loss)
    304     return self

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\estimator\estimator.py in _train_model(self, input_fn, hooks, saving_listeners)
    709       with ops.control_dependencies([global_step_read_tensor]):
    710         estimator_spec = self._call_model_fn(
--> 711             features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
    712       # Check if the user created a loss summary, and add one if they didn't.
    713       # We assume here that the summary is called 'loss'. If it is not, we will

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\estimator\estimator.py in _call_model_fn(self, features, labels, mode, config)
    692     if 'config' in model_fn_args:
    693       kwargs['config'] = config
--> 694     model_fn_results = self._model_fn(features=features, **kwargs)
    695 
    696     if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):

<ipython-input-15-797249462151> in drn_model_fn(features, labels, mode, params)
      7         params['arch'], params['size'], _LABEL_CLASSES, params['data_format'])
      8     print(4)
----> 9     logits = network(inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
     10     print(12)
     11     predictions = {

\Code\Semantic Image Segmentation\drn.py in model(inputs, is_training)
    255             print(16)
    256         inputs = conv2d_fixed_padding(
--> 257             inputs=inputs, filters=16, kernel_size=7, strides=2,
    258             data_format=data_format,dilation_rate=1)
    259                 print(17)

\Code\Semantic Image Segmentation\drn.py in conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format, dilation_rate)
     90       kernel_initializer=tf.variance_scaling_initializer(),
     91       data_format=data_format,
---> 92       dilation_rate=dilation_rate)
     93 
     94 

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\layers\convolutional.py in conv2d(inputs, filters, kernel_size, strides, padding, data_format, dilation_rate, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, trainable, name, reuse)
    606       _reuse=reuse,
    607       _scope=name)
--> 608   return layer.apply(inputs)
    609 
    610 

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\layers\base.py in apply(self, inputs, *args, **kwargs)
    669       Output tensor(s).
    670     """
--> 671     return self.__call__(inputs, *args, **kwargs)
    672 
    673   def _add_inbound_node(self,

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\layers\base.py in __call__(self, inputs, *args, **kwargs)
    557           input_shapes = [x.get_shape() for x in input_list]
    558           if len(input_shapes) == 1:
--> 559             self.build(input_shapes[0])
    560           else:
    561             self.build(input_shapes)

~\Anaconda3\envs\master-thesis\lib\site-packages\tensorflow\python\layers\convolutional.py in build(self, input_shape)
    130       channel_axis = -1
    131     if input_shape[channel_axis].value is None:
--> 132       raise ValueError('The channel dimension of the inputs '
    133                        'should be defined. Found `None`.')
    134     input_dim = input_shape[channel_axis].value

ValueError: The channel dimension of the inputs should be defined. Found `None`.

我已经在网上搜索了这个错误,但只找到了与 Keras 相关的帖子,当时后端未正确初始化(s.1)。this https://github.com/keras-team/keras/issues/5900).

如果有人能指出我寻找错误的方向,我会很高兴。

这是我的 input_fn:

def input_fn(is_training, batch_size, num_epochs=1):
    """Input function which provides batches for train or eval."""
    # Get list of paths belonging to training images and corresponding label images
    filename_list = filenames(is_training)
    filenames_train = []
    filenames_labels = []
    for i in range(len(filename_list)):
        filenames_train.append(train_dataset_dir+filename_list[i])
        filenames_labels.append(gt_dataset_dir+filename_list[i])


    filenames_train = tf.convert_to_tensor(tf.constant(filenames_train, dtype=tf.string))
    filenames_labels = tf.convert_to_tensor(tf.constant(filenames_labels, dtype=tf.string))

    dataset = tf.data.Dataset.from_tensor_slices((filenames_train,filenames_labels))

    if is_training:
        dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
        dataset = dataset.map(image_parser)
        dataset = dataset.prefetch(batch_size)

        if is_training:
          # When choosing shuffle buffer sizes, larger sizes result in better
          # randomness, while smaller sizes have better performance.
            dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

      # We call repeat after shuffling, rather than before, to prevent separate
      # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels

这是 input_fn 中使用的 image_parser 函数:

def image_parser(filename, label): 
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_image(image_string,_NUM_CHANNELS)  
    image_decoded = tf.image.convert_image_dtype(image_decoded, dtype=tf.float32)
    label_string = tf.read_file(label)
    label_decoded = tf.image.decode_image(label)
    return image_decoded, tf.one_hot(label_decoded, _LABEL_CLASSES)

之后尝试这个tf.read_file:

image_decoded = tf.image.decode_image(image_string, channels=3)
image_decoded.set_shape([None, None, 3])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow ValueError:应定义输入的通道维度。发现“无” 的相关文章

随机推荐

  • JQuery UI 自动完成,更改事件不会触发

    我的 JQuery 自动完成有一些问题 我的代码如下 var mySource label Value one id 1 label Value two id 2 label Value three id 3 txtAutocomplete
  • 如何获取RAM名称?

    我有一个关于 RAM 信息的问题 如何获得它的名称 到目前为止 从Win32 PhysicalMemory and Win32 PhysicalMemoryArray类我能够获得除名称之外的大部分信息 显示名称 例如 CRUCIAL BAL
  • OS X Leopard 上的多个版本的 Python

    目前 我的 Mac 上安装了多个版本的 Python 其中包括随机附带的版本 最近从 python org 下载的版本 用于本地运行 Zope 的旧版本以及 Appengine 正在使用的另一个版本 有点乱 有什么建议可以使用一个版本的 p
  • 在 Chrome 和 Safari 中,使用 valign=top 的表格中 的错误呈现

    我有以下 HTML p style font family Verdana test p
  • Twitter Typeahead Ajax 结果未定义

    我正在使用对 PHP 文件的 ajax JSON 调用来构建 Twitter 预输入自动完成功能来获取一些数据 但它一直在下拉结果列表中显示以下内容 不明确的 不明确的 不明确的 但是当我这样做时 alert data 我显示了正确的数据
  • Visual Studio 2017前缀文件嵌套

    有没有办法对具有相同后缀但具有变体前缀的文件进行分组 Example hero model ts power hero model ts weapon hero model ts bullet weapon hero model ts 本指
  • 大量删除文档会影响ES查询性能吗

    在我的 ES 集群中 我几乎没有读取大量索引 开始看到这些索引的性能问题 该集群拥有大约 5000 万个文档 并注意到其中大多数文档的总文档数的 25 左右被删除 我知道当后台合并操作发生时 这些已删除的文档数量会随着时间的推移而减少 但就
  • 如何使用 angular-cli webpack 调试 Angular 应用程序?

    I used 电子邮件受保护 cdn cgi l email protection之前和现在我更新到 angular cli webpack beta 11 经过大量的自定义更改后 我让它工作了 唯一的问题是 现在我无法使用 webstor
  • 为什么 font-sizing vw 在 safari 中不起作用?

    我使用 vw 作为字体大小的单位 这样在调整浏览器大小时它看起来会很漂亮 但是 当我在 Safari 中浏览它时 内容会运行 有人知道如何解决吗 谢谢 CSS flatNav background image url img navBar
  • OpenLayers 3 和 XYZ 层

    I have a map which I want to display It consists of a standard map OSM Google or Bing and a layer provided by Openseamap
  • NextJS - getServerSideProps - 错误 400 - 错误请求

    我在 NEXT JS 中使用 getServerSideProps 函数进行 fetch 时遇到问题 当我开始使用这个框架时 我可能做得不好 访问外部 API 的所有凭据都是正确的 当我在 React 中使用相同的参数进行获取时 它会为我带
  • 在 OS X JavaScript for Automation (JXA) 中附加事件侦听器

    如何在 OS X JavaScript 中监听事件以实现自动化 在消息应用程序的脚本库中 有一个事件处理程序列表 例如messageSent and messageReceived 但是 我不知道如何使用它们 尝试传递函数会产生错误 尝试将
  • 使用 JPA(带注释的实体)和 liquibase 的 Hibernate

    liquibase 是 hibernate 的完美替代品hbm2ddl 自动 http www jroller com eyallupu entry hibernate s hbm2ddl tool属性 如果您使用 xml 映射 但我使用
  • 是否有猫鼬连接错误回调

    如果猫鼬无法连接到我的数据库 我如何设置错误处理的回调 我知道 connection on open function 但有没有类似的东西 connection on error function err 连接后 您可以在回调中发现错误 m
  • 构建 GoogleSignInOptions 时的 firebase serverClientId

    我正在使用 google 登录工作流程来获取 GoogleSignInAccount 对象 我想对我的 firebase 应用程序验证 google 用户的身份 这需要一个可以使用请求的令牌requestIdToken 字符串服务器客户端I
  • 防止发送内容类型为“多部分/相关”的 SOAP 消息

    我正在从 Web 服务客户端 代码由 IBM RAD 7 5 生成 向主机发送一条 SOAP 消息 该消息带有一个 Web 服务故障以及一条在主机日志中显示为 序言中不允许的内容 的消息 当我使用 SoapUI 或简单的 apache Ht
  • Java 中的 Throwable 方法重写

    首先 抱歉我的英语不好 问题 如果我有一个子类扩展了一个抛出 CHECKED 异常的方法 那么为什么 Java 允许我在子类的重写方法中抛出 RuntimeException 如下例所示 public class A public void
  • 如何打开 Eclipse Web 浏览器?

    Eclipse 有一个 Web 浏览器 但我不知道打开它的任何直接方法 我知道如果你去 Eclipse 市场并要求更多结果它会打开 如果你让浏览器打开你的 html 你可以打开它 但是有没有更直接的方法来打开它 例如 显示浏览器的按钮或菜单
  • pandas 格式日期时间索引到季度

    通过重新采样作业 我将每月值转换为季度值 hs hs resample QS axis 1 mean 效果很好 我的专栏如下所示 hs columns DatetimeIndex 2000 01 01 2000 04 01 2000 07
  • TensorFlow ValueError:应定义输入的通道维度。发现“无”

    我正在尝试实现一个 扩张残差网络 如此处所述Paper https arxiv org abs 1705 09914在 TensorFlow 中 s PyTorch 实现here https github com fyu drn 来训练它城