I am following the official TensorFlow data augmentation tutorial.
First, I create a Sequential model made of augmentation layers:
from tensorflow.keras import Sequential, layers

def _getAugmentationFunction(self):
    if not self.augmentation:
        return None
    pipeline = []
    pipeline.append(layers.RandomFlip('horizontal_and_vertical'))
    # RandomRotation's factor is a fraction of 2*pi, so 30 allows any rotation angle
    pipeline.append(layers.RandomRotation(30))
    pipeline.append(layers.RandomTranslation(0.1, 0.1, fill_mode='nearest'))
    pipeline.append(layers.RandomBrightness(0.1, value_range=(0.0, 1.0)))
    model = Sequential(pipeline)
    return lambda x, y: (model(x, training=True), y)
Then I apply it to the dataset with map:
data_augmentation = self._getAugmentationFunction()
if data_augmentation is not None:  # the helper returns None when augmentation is off
    self.train_data = self.train_data.map(data_augmentation,
                                          num_parallel_calls=AUTOTUNE)
The code works as expected, but I get the following warnings:
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
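These messages come from TensorFlow's auto-vectorization machinery (`pfor`, which backs `tf.vectorized_map`): when it has no vectorized converter for an op, it falls back to running that op element by element in a `while_loop`, which would plausibly explain a large slowdown. Presumably the augmentation layers trigger this internally. To see at a glance which ops fall back, a small hypothetical helper (`summarize_while_loop_warnings`, not part of TensorFlow) can tally the op names in a captured log:

```python
from collections import Counter

# Text that precedes the op name in each fallback warning (copied from the log).
PREFIX = "Using a while_loop for converting "


def summarize_while_loop_warnings(log: str) -> Counter:
    """Count how often each op name appears in 'while_loop' fallback warnings."""
    counts: Counter = Counter()
    for line in log.splitlines():
        _, found, op_name = line.partition(PREFIX)
        if found:  # skip lines that are not fallback warnings
            counts[op_name.strip()] += 1
    return counts


if __name__ == "__main__":
    sample = """\
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting Bitcast
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3
"""
    for op_name, n in summarize_while_loop_warnings(sample).most_common():
        print(f"{op_name}: {n}")
```

Run over the full log in the question, it reports 12 `Bitcast`, 6 `RngReadAndSkip`, 6 `StatelessRandomUniformV2` and 2 `ImageProjectiveTransformV3` fallbacks, i.e. mostly random-number-generation ops plus one image transform op.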
What causes these warnings, and how can I get rid of them?
I am using TF v2.9.1.
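It is not just a warning: these layers are extremely slow! In my case, the time per epoch went from 30 seconds to several minutes.
This appears to be a bug in Keras 2.9 and 2.10 (which ship with TensorFlow): https://github.com/keras-team/keras-cv/issues/581
It works correctly on TF v2.8.3: no warning messages, and training is fast.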
On my Arch system, where I had installed the python-tensorflow-opt-cuda package with pacman, I solved the problem by running:
python -m pip install tensorflow-gpu==2.8.3
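Because a pacman-installed TensorFlow and a pip-installed wheel may end up side by side, it can be worth confirming which TensorFlow distribution Python actually sees after the downgrade. The sketch below is hypothetical (the helper names are illustrative) and treats 2.9.x and 2.10.x as affected purely on the basis of the report above, not an official compatibility table. It reads package metadata with the standard-library `importlib.metadata`, so TensorFlow itself is never imported.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Tuple

# Minor versions reported as affected in the answer above (an assumption, not an
# official list from the TensorFlow or Keras projects).
AFFECTED = {(2, 9), (2, 10)}


def major_minor(ver: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.9.1' or '2.10.0rc1'."""
    major, minor = ver.split(".")[:2]
    return int(major), int(minor)


def is_affected(ver: str) -> bool:
    """True if the version falls in the range reported as slow and warning-prone."""
    return major_minor(ver) in AFFECTED


def installed_tensorflow() -> Dict[str, str]:
    """Map each installed TensorFlow distribution name to its version string."""
    found = {}
    for dist in ("tensorflow", "tensorflow-gpu"):
        try:
            found[dist] = version(dist)
        except PackageNotFoundError:
            pass  # this distribution is not installed
    return found


if __name__ == "__main__":
    dists = installed_tensorflow()
    if not dists:
        print("No TensorFlow distribution found")
    for dist, ver in dists.items():
        note = "reported as affected" if is_affected(ver) else "outside the reported range"
        print(f"{dist} {ver}: {note}")
```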