问题描述
训练SSD网络时报错:RuntimeError: output with shape [1, 300, 300] doesn’t match the broadcast shape [3, 300, 300]
导致原因
数据集中存在单通道图片
解决办法
1. 若使用的是opencv
import cv2
gray_image = cv2.imread(‘path_to_your_gray_image’, cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
2. 若使用的是PIL
from PIL import Image
gray_image = Image.open(‘path_to_your_gray_image’).convert(‘L’)
rgb_image = gray_image.convert(‘RGB’)
SSD中有如下代码
img_path = os.path.join(self.img_root, data["filename"])
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image '{}' format not JPEG".format(img_path))
在 image 读取之后,加入转换RGB代码
img_path = os.path.join(self.img_root, data["filename"])
image = Image.open(img_path)
#################### 悠青 #####################
if image.mode != 'RGB':
# print('the picture is not rgb:')
# print(img_path)
image = image.convert('RGB')
#################### 悠青 #####################
if image.format != "JPEG":
raise ValueError("Image '{}' format not JPEG".format(img_path))
但此时又遇到了新的问题
使用image.convert()之后,image.mode = None 而不是JEPG了,导致判断语句image.format != “JPEG” 成立,就会报错退出。
在转换后的图像上调用image.format,它可能会返回None,因为经过色彩空间转换后的图像不再关联原始的文件格式。这不会影响到图像的数据和色彩模式,你仍然可以正常地对图像进行操作和处理。
1. 暴力解决
由于报错代码如下,将其注释掉即可
if image.format != "JPEG":
raise ValueError("Image '{}' format not JPEG".format(img_path))
- 定义一个新函数
这个方法我没有试,具体如下:
class ConvertToRGB:
def __init__(self):
self.original_format = None
def __call__(self, image):
self.original_format = image.format
if image.mode != 'RGB':
image = image.convert('RGB')
return image