PaddleOCR 文字检测/文字块检测的模型训练过程,DBnet的前处理和后处理流程损失函数

2023-05-16

文章目录

  • 1、环境搭建
  • 2、数据集
  • 3、下载预训练模型
  • 4、配置文件
    • DecodeImage
    • DetLabelEncode
    • IaaAugment
    • EastRandomCropData
    • MakeBorderMap
  • 5、开启训练
  • 6、纯记录,我在我服务器做的事情
  • 7、显示label
  • 8、推导模型的导出与预测
  • 9、转到onnx模型和mnn模型
    • 到onnx
    • 到mnn
  • 10、在线可视化
  • 11、推理的时候,从概率图到文字框的过程(后处理)
  • 12、DBnet损失函数
    • DiceLoss-->loss_binary_maps
    • MaskL1Loss-->loss_threshold_maps
    • BalanceLoss-->bce_loss -->loss_shrink_maps
    • 损失函数配置
  • 13、评估指标

1、环境搭建

官网环境准备参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/environment.md#1.3
paddle官网:https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html

linux cpu用于调试:

conda create -n paddle python=3.7 -y
conda activate paddle
pip install -r requirements.txt
python -m pip install paddlepaddle==2.1.3 -i https://pypi.tuna.tsinghua.edu.cn/simple

如果是gpu,docker用起来:

nvidia-docker run -v $PWD:/paddle -v /ssd/xiedong/datasets/ICDAR2015/:/ICDAR2015 --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 /bin/bash # 我这里把数据也挂载进去
python -m pip install -i https://pypi.douban.com/simple --upgrade pip && pip config set global.index-url https://pypi.douban.com/simple
cd /paddle
pip install -r requirements.txt

2、数据集

官网的数据集参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/dataset/ocr_datasets.md

以ICDAR 2015数据集来说,可以直接下载PaddleOCR给的标注,也可以自行使用ppocr/utils/gen_label.py进行标签转换:

python gen_label.py --mode="det" --root_path="/ssd/xiedong/datasets/ICDAR2015/ch4_training_images/"  \
                    --input_path="/ssd/xiedong/datasets/ICDAR2015/ch4_training_localization_transcription_gt" \
                    --output_label="/ssd/xiedong/datasets/ICDAR2015/train_icdar2015_label.txt"

3、下载预训练模型

官网参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/detection.md

下载预训练模型:
https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.0/README_cn.md#resnet%E5%8F%8A%E5%85%B6vd%E7%B3%BB%E5%88%97

我这里下载后放这里:
在这里插入图片描述
修改label文件中所描述的路径:

# 需要重写label文件中的label
import os

lb_file = r"C:\Users\dong.xie\Desktop\test_icdar2015_label.txt"
lb_file_dst = lb_file
# 读取文件所有行
with open(lb_file, 'r') as f:
    lines = f.readlines()

res_list = []
for line in lines:
    pathname, label = line.split('\t')
    basename = os.path.basename(pathname)
    pathname_new = r"/ICDAR2015/ch4_test_images/" + basename  # 改为绝对路径
    res_list.append(pathname_new + '\t' + label)

# 重写label文件
with open(lb_file_dst, 'w') as f:
    f.write(''.join(res_list))

4、配置文件

配置文件修改为:

Global:
  use_gpu: true
  use_xpu: false
  use_mlu: false
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations
  eval_batch_step: [ 0, 2000 ]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_small_x0_35_ssld_pretrained
  checkpoints:
  save_inference_dir:
  use_visualdl: true
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt  #设置检测模型的结果保存地址

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.35
    model_name: small
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /ICDAR2015/ch4_training_images/
    label_file_list:
      - /ICDAR2015/train_icdar2015_label.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }
            - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }
      - EastRandomCropData:
          size: [ 640, 640 ]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask' ] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 16
    num_workers: 8
    use_shared_memory: True

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /ICDAR2015/ch4_test_images/
    label_file_list:
      - /ICDAR2015/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          image_shape: [ 736, 1280 ]
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ]
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 8
    use_shared_memory: True

选择模型的配置是下面,在初始构建模型的代码中会拉起:
Architecture:
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.35
model_name: small
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50

DecodeImage

从文件路径加载图片到内存:

     - DecodeImage: # load image
          img_mode: BGR
          channel_first: False

DetLabelEncode

处理label文本,可看出文字标记为"*“或者”###"的,tag都会是True,后续将忽略此标记。

      - DetLabelEncode: # Class handling label

在这里插入图片描述

IaaAugment

这段代码定义了一个名为 IaaAugment 的类,用于进行图像增强操作。该类的实例化方法 init 接受一个参数 augmenter_args 和其他可选参数 **kwargs。如果没有提供 augmenter_args,则默认使用三个图像增强器:水平翻转、仿射变换和尺寸调整,每个增强器都有不同的参数设置。

call 方法接受一个数据字典 data,其中包括一个 image 键,对应于输入的图像。方法首先获取输入图像的形状,然后将图像传递给 augmenter 对象进行增强操作。增强器对象首先被转换为确定性的,然后对输入图像进行增强,返回增强后的图像。此外,方法还调用 may_augment_annotation 方法对标注进行相同的增强操作,并将增强后的数据字典返回。

may_augment_annotation 方法接受 aug 增强器对象、 data 数据字典和 shape 输入图像的形状。该方法使用 may_augment_poly 方法对 polys 键中的多边形进行增强操作,并将增强后的多边形数组存储在 data[‘polys’] 中。

may_augment_poly 方法接受 aug 增强器对象、img_shape 输入图像的形状和 poly 多边形数组。该方法将多边形的每个点转换为 imgaug.Keypoint 对象,并将它们作为一个 imgaug.KeypointsOnImage 对象传递给增强器的 augment_keypoints 方法。该方法返回一个包含增强后的关键点的 imgaug.KeypointsOnImage 对象,然后将这些点的坐标转换回原始多边形的形式,最后返回增强后的多边形。

      - IaaAugment:
          augmenter_args: # 水平翻转、仿射变换和尺寸调整
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }
            - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }

在这里插入图片描述

EastRandomCropData

这段代码定义了两个类 EastRandomCropData 和 RandomCropImgMask,它们都是用于数据增强(data augmentation)的。

EastRandomCropData 类实现了一种基于 EAST(Efficient and Accurate Scene Text detection)文本检测算法的数据增强方法,主要作用是对原始图片进行随机裁剪,使得裁剪出的区域包含至少一个文本框,并保持裁剪后的图片与原始图片比例一致。具体实现过程如下:

首先,通过调用 crop_area 函数计算出一个合适的裁剪区域,该函数会根据输入的 all_care_polys 参数(即所有不被标记为忽略的文本框)和裁剪区域的最小边长比例 min_crop_side_ratio,随机生成多个裁剪区域并返回其中包含至少一个文本框的那个区域。如果经过 max_tries 次尝试后还是没有找到合适的裁剪区域,则直接返回原始图片。

接着,根据裁剪区域的大小和目标大小(即 self.size 参数),计算出缩放比例 scale,并将裁剪后的图片缩放到目标大小。如果 keep_ratio 参数为 True,则先将目标大小的全零矩阵 padimg 创建出来,将缩放后的图片居中填充到 padimg 中,最后返回 padimg;否则直接将缩放后的图片调整为目标大小并返回。

最后,对原始文本框进行相应的裁剪和缩放操作,以保证其与裁剪后的图片大小相匹配,并返回增强后的数据。

RandomCropImgMask 类实现了一种随机裁剪的数据增强方法,主要作用是对输入的图像和掩码进行随机裁剪。具体实现过程如下:

首先,根据输入图像的大小和目标大小,生成一个随机裁剪区域,并将该区域应用于输入图像和掩码上。具体来说,如果掩码中存在非零像素(即存在文本区域),则保证裁剪区域至少包含一个文本区域;否则随机生成一个裁剪区域。

接着,对输入数据中包含在 self.crop_keys 列表中的数据(如图像、掩码等)进行裁剪操作,并将裁剪后的数据更新到原始数据中。

最后,返回增强后的数据。

      - EastRandomCropData:
          size: [ 640, 640 ]
          max_tries: 50
          keep_ratio: true

在这里插入图片描述

MakeBorderMap

配置:

      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7

代码块:
在这里插入图片描述
从代码可以看出,输入是文本框polys,输出是canvas和mask,利用下面的代码可以将两者保存为灰度图:

        # 保存成图
        import cv2
        # 将 canvas 和 mask 缩放到 [0, 255] 区间内并转换为无符号 8 位整数类型
        canvas_uint8 = cv2.normalize(canvas, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        mask_uint8 = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        # 保存为灰度图像
        cv2.imwrite('canvas.png', canvas_uint8)
        cv2.imwrite('mask.png', mask_uint8)

canvas就是threshold_map:
在这里插入图片描述
mask就是threshold_mask:
在这里插入图片描述

在这里插入图片描述

5、开启训练

开启训练的指令:

# 单机单卡训练 mv3_db 模型
python3 tools/train.py -c configs/det/det_mv3_db.yml \
     -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained

# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
     -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained

det_mv3_db.yml的含义:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/config.md

6、纯记录,我在我服务器做的事情

在19服务器将docker镜像拉下来:

docker pull kevinchina/deeplearning:paddle2.1.3paddleocr

使用rsync同步文件:

rsync -avz xiedong@10.20.31.16:/ssd/xiedong/workplace/PaddleOCR-release-2.6/ ./PaddleOCR-release-2.6/

启动容器:

nvidia-docker run -v $PWD:/paddle -v /ssd/xiedong/datasets/ICDAR2015/:/ICDAR2015 --shm-size=64G --network=host -it  kevinchina/deeplearning:paddle2.1.3paddleocr /bin/bash

nvidia-docker run -v $PWD:/paddle  -v /ssd/xiedong/datasets/para_det_paddle:/para_det_paddle --shm-size=64G --network=host -it  kevinchina/deeplearning:paddle2.1.3paddleocr /bin/bash

单机多卡训练:

python3 -m paddle.distributed.launch --gpus '0,1,2' tools/train.py -c configs/det/det_mv3_db_para_det_paddle.yml

python3 -m paddle.distributed.launch --gpus '0,1,2' tools/train.py -c configs/det/det_mv3_db_para_det_paddle_1_0_20230506.yml


断点resume训练:
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:

python3 -m paddle.distributed.launch --gpus '0,1,2' tools/train.py -c configs/det/det_mv3_db_para_det_paddle.yml -o Global.checkpoints=./output/db_mv3_para_det_paddle/latest.pdparams

7、显示label

将图片中的文字画框后显示出来:

import os

import PIL.Image as Image
import PIL.ImageDraw as ImageDraw
import cv2

src = r"E:\q\ICDAR2015"
labelfilename = r"E:\q\ICDAR2015\train_icdar2015_label.txt"

# ch4_test_images/img_61.jpg	[{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {"transcription": "###", "points": [[1197, 126], [1252, 118], [1257, 136], [1203, 144]]}, {"transcription": "###", "points": [[1137, 140], [1177, 132], [1180, 148], [1140, 156]]}, {"transcription": "###", "points": [[1096, 152], [1130, 145], [1133, 158], [1100, 165]]}, {"transcription": "###", "points": [[1061, 161], [1092, 154], [1093, 168], [1062, 175]]}, {"transcription": "###", "points": [[1030, 168], [1055, 162], [1056, 177], [1030, 183]]}, {"transcription": "###", "points": [[1000, 173], [1023, 168], [1025, 184], [1002, 189]]}, {"transcription": "###", "points": [[223, 293], [313, 288], [313, 311], [222, 316]]}]
with open(labelfilename, "r", encoding="utf-8") as f:
    lines = f.read().splitlines()
    for line in lines:
        line = line.split("\t")
        imgpath = os.path.join(src, line[0])
        img = Image.open(imgpath)
        img = img.convert("RGB")
        draw = ImageDraw.Draw(img)
        for d in eval(line[1]):
            points = d["points"]
            draw.line(points[0] + points[1] + points[2] + points[3] + points[0], fill=(255, 0, 0), width=3)
        img.show()


8、推导模型的导出与预测

检测模型转inference 模型方式:

python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model="./output/db_mv3_3/best_accuracy" Global.save_inference_dir="./output/db_mv3_3_inference/"

在这里插入图片描述
段落框:

python3 tools/export_model.py -c configs/det/det_mv3_db_para_det_paddle.yml -o Global.pretrained_model="./output/db_mv3_para_det_paddle/best_accuracy" Global.save_inference_dir="./output/db_mv3_para_det_paddle_inference/"

用于预测:

python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./output/det_db_inference/" --image_dir="./doc/imgs/" --use_gpu=True

用于预测,段落框:

python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./output/db_mv3_para_det_paddle_inference/" --image_dir="./doc/imgs/" --use_gpu=True

9、转到onnx模型和mnn模型

到onnx

安装:pip install paddle2onnx onnx onnx-simplifier onnxruntime-gpu
导出模型:paddle2onnx --model_dir ./output/db_mv3_para_det_paddle_inference/ --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./output/db_mv3_para_det_paddle_inference/model.onnx --opset_version 10 --enable_dev_version True --enable_onnx_checker True

参数选项

参数参数说明
–model_dir配置包含Paddle模型的目录路径
–model_filename[可选] 配置位于--model_dir下存储网络结构的文件名
–params_filename[可选] 配置位于--model_dir下存储模型参数的文件名称
–save_file指定转换后的模型保存目录路径
–opset_version[可选] 配置转换为ONNX的OpSet版本,目前支持7~15等多个版本,默认为9
–enable_dev_version[可选] 是否使用新版本Paddle2ONNX(推荐使用),默认为False
–enable_onnx_checker[可选] 配置是否检查导出为ONNX模型的正确性, 建议打开此开关。若指定为True, 默认为False
–enable_auto_update_opset[可选] 是否开启opset version自动升级,当低版本opset无法转换时,自动选择更高版本的opset 默认为True
–input_shape_dict[可选] 配置输入的shape, 默认为空; 此参数即将移除,如需要固定Paddle模型输入Shape,请使用此工具处理
–version[可选] 查看paddle2onnx版本
  • 使用onnxruntime验证转换模型, 请注意安装最新版本(最低要求1.10.0):

如你有ONNX模型优化的需求,推荐使用onnx-simplifier,也可使用如下命令对模型进行优化:

python -m paddle2onnx.optimize --input_model model.onnx --output_model new_model.onnx

如需要修改导出的模型输入形状,如改为静态shape:

python -m paddle2onnx.optimize --input_model model.onnx \
                               --output_model new_model.onnx \
                               --input_shape_dict "{'x':[1,3,224,224]}"

到mnn

sudo /ssd/xiedong/miniconda3/envs/py37c/bin/mnnconvert -f ONNX --modelFile model.onnx --MNNModel model.mnn --bizCode MNN

10、在线可视化

百度提供的visualdl 对log可视化:

visualdl --logdir="vdl"

下面这图是我resume训练了,看起来loss还在波动。
在这里插入图片描述

也可以将inference.pdmodel模型放入后查看模型结构:
在这里插入图片描述

11、推理的时候,从概率图到文字框的过程(后处理)

可以执行下面的指令,添加断点:

python3 tools/export_model.py -c configs/det/det_mv3_db_para_det_paddle.yml -o Global.pretrained_model="./output/db_mv3_para_det_paddle/best_accuracy" Global.save_inference_dir="./output/db_mv3_para_det_paddle_inference/"

此内容在 class DBPostProcess(object) 中, 使用config的一些配置,完成了从概率图到文本框的转换。

在这里插入图片描述

12、DBnet损失函数

源码:


from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss


class DBLoss(nn.Layer):
    """
    Differentiable Binarization (DB) Loss Function
    args:
        param (dict): the super paramter for DB Loss
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)

    def forward(self, predicts, labels):
        predict_maps = predicts['maps']
        label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
            1:]
        shrink_maps = predict_maps[:, 0, :, :]
        threshold_maps = predict_maps[:, 1, :, :]
        binary_maps = predict_maps[:, 2, :, :]

        loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
                                         label_shrink_mask)
        loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
                                           label_threshold_mask)
        loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
                                          label_shrink_mask)
        loss_shrink_maps = self.alpha * loss_shrink_maps
        loss_threshold_maps = self.beta * loss_threshold_maps

        loss_all = loss_shrink_maps + loss_threshold_maps \
                   + loss_binary_maps
        losses = {'loss': loss_all, \
                  "loss_shrink_maps": loss_shrink_maps, \
                  "loss_threshold_maps": loss_threshold_maps, \
                  "loss_binary_maps": loss_binary_maps}
        return losses

一步一步拆解开看,

DiceLoss–>loss_binary_maps

源码:

class DiceLoss(nn.Layer):
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt, mask, weights=None):
        """
        DiceLoss function.
        """

        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask
        intersection = paddle.sum(pred * gt * mask)

        union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss

公式计算:
在这里插入图片描述

MaskL1Loss–>loss_threshold_maps

源码:


class MaskL1Loss(nn.Layer):
    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt, mask):
        """
        Mask L1 Loss
        """
        loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
        loss = paddle.mean(loss)
        return loss

公式表达:
在这里插入图片描述

BalanceLoss–>bce_loss -->loss_shrink_maps

源码:


class BalanceLoss(nn.Layer):
    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        """
               The BalanceLoss for Differentiable Binarization text detection
               args:
                   balance_loss (bool): whether balance loss or not, default is True
                   main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
                       'Euclidean','BCELoss', 'MaskL1Loss'], default is  'DiceLoss'.
                   negative_ratio (int|float): float, default is 3.
                   return_origin (bool): whether return unbalanced loss or not, default is False.
                   eps (float): default is 1e-6.
               """
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        if self.main_loss_type == "CrossEntropy":
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
            ]
            raise Exception(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        The BalanceLoss for Differentiable Binarization text detection
        args:
            pred (variable): predicted feature maps.
            gt (variable): ground truth feature maps.
            mask (variable): masked maps.
        return: (variable) balanced loss
        """
        positive = gt * mask
        negative = (1 - gt) * mask

        positive_count = int(positive.sum())
        negative_count = int(
            min(negative.sum(), positive_count * self.negative_ratio))
        loss = self.loss(pred, gt, mask=mask)

        if not self.balance_loss:
            return loss

        positive_loss = positive * loss
        negative_loss = negative * loss
        negative_loss = paddle.reshape(negative_loss, shape=[-1])
        if negative_count > 0:
            sort_loss = negative_loss.sort(descending=True)
            negative_loss = sort_loss[:negative_count]
            # negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                positive_count + negative_count + self.eps)
        else:
            balance_loss = positive_loss.sum() / (positive_count + self.eps)
        if self.return_origin:
            return balance_loss, loss

        return balance_loss

公式表达:

在这里插入图片描述

损失函数配置

Loss:
  name: DBLoss
  balance_loss: true    # 需要平衡损失
  main_loss_type: DiceLoss
  alpha: 5  # 给loss_shrink_maps加权
  beta: 10 # 给loss_threshold_maps加权
  ohem_ratio: 3

loss_shrink_maps:该损失函数是基于二分类交叉熵损失函数的,用于优化预测结果与实际文本区域边界收缩情况之间的差异,以便网络学习到更好的文本边界信息。

loss_threshold_maps:该损失函数是基于L1损失函数的,用于优化预测结果与实际文本区域阈值之间的差异,以便网络学习到更好的文本区域阈值信息。在二值化处理中,使用了一个预设的阈值,该损失函数可以使网络更好地学习到最优阈值。

loss_binary_maps:该损失函数是基于Dice Loss的,用于优化预测结果与实际二值化文本区域之间的重叠率,以便网络学习到更准确的二值化文本区域信息。

这三种损失函数都是为了优化网络预测结果与实际标签之间的差异,从而提高文本检测的性能。同时,不同的损失函数也针对不同的方面进行优化,使网络学习到更多的相关信息,并最终提高整个文本检测系统的性能。

13、评估指标

在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PaddleOCR 文字检测/文字块检测的模型训练过程,DBnet的前处理和后处理流程损失函数 的相关文章

随机推荐