MMdetection之train_detector 源码解析

2023-10-29

目录

(一)构建 data loaders(mmdet/datasets/builder.py)

(2)构建分布式处理对象

(3)构建优化器

(4)创建 EpochBasedRunner 并进行训练


(一)构建 data loaders(mmdet/datasets/builder.py)

其主要步骤是创建采样器,并将采样器,collate 函数, worker_init_fn 函数传入 DataLoader 中,用于创建 pytorch dataloader。

def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    # 获取进程编号和总进程数
    rank, world_size = get_dist_info()
    # 如果是分布式训练, 即使用 dist_train.sh 会进入此 if.
    if dist:
        # DistributedGroupSampler 会进行 shuffle, 而且会保证每个 GPU 的样本都是同一组的.
        if shuffle:
            sampler = DistributedGroupSampler(dataset, samples_per_gpu,
                                              world_size, rank)
        # 不 shuffle, 使用 torch.utils.data 中的 DistributedSampler
        # 因为 pytorch < 1.2 没有 shuffle 形参, 
        # 为了版本适配, 重写了一个 DistributedSampler
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    # 不是分布式训练, 即直接使用 train.py 进行训练.
    else:
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)

采样器总共有三种:DistributedGroupSampler,DistributedSampler,GroupSampler。

DistributedGroupSampler 和 GroupSampler 用于按组进行分布式采样(注意:在 MMDetection 中将图像宽大于高和宽小于高分为两组,每个 GPU 中的图像应该取自同一组)

DistributedSampler 是 pytorch 自带分布式采样的重写类,因为 pytorch < 1.2 没有 shuffle 参数。也就是说,pytorch < 1.2 的 DistributedSampler 只支持顺序分布式采样。而 pytorch >= 1.2 支持顺序和乱序采样。为了版本适配和接口统一,重写了一个 DistributedSampler,将 shuffle 始终设置为 False。

① GroupSampler(datasets/samplers/group_samples.py)

from __future__ import division
import math

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler


class GroupSampler(Sampler):
    def __init__(self, dataset, samples_per_gpu=1):
        # 如果图片的  宽 > 高, 记为 为 1
        #            宽 < 高, 记为 为 0
        # flag 是一个记录了数据集中所有图片的 ndarray
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)
        # np.bincount 计算每个索引出现的次数
        # 在这里就相当于计算了有多少个宽 > 高的图片, 和有多少个宽 < 高的图片
        self.group_sizes = np.bincount(self.flag)
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            # 保证每组的 sample 数都能被 samples_per_gpu 的数量整除
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):
            # 如果数据集中的所有的图片的宽都 < 高, 那么进行下一次循环.
            if size == 0:
                continue
            # 找到 宽 < 高(i = 0) 或 宽 > 高(i = 1) 的所有的图片索引
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            # 随机打乱索引
            np.random.shuffle(indice)
            # 因为图片个数不一定会被 samples_per_gpu 整除, 所以添加额外的数据.
            # num_extra 即为添加额外数据的数量.
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            # np.concatenate(需要concat的list, axis=0)
            # np.random.choice(list, 选的size)
            # 生成所有的 index
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        # 整合所有的 index
        indices = np.concatenate(indices)
        # 如下操作可以保证每个 samples_per_gpu 的 flag 都相同
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

② DistributedGroupSampler(mmset/datasets/samplers/group_samples.py)

class DistributedGroupSampler(Sampler):
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        # 获取 rank 和 world_size (num_replicas)
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank

        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu

        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        # 统计了有多少个宽 > 高的图片, 和有多少个宽 < 高的图片
        self.group_sizes = np.bincount(self.flag)

        # 每个进程需要采样的样本数
        self.num_samples = 0

        for i, j in enumerate(self.group_sizes):
            # self.group_sizes[i] / self.samples_per_gpu:能分成多少组
            # 下面的代表计算了每个机器分的个数.
            self.num_samples += int(
                math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        # 所有进程要采样的样本总数。
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # 把当前的 epoch 作为随机数种子,
        # 这样能保证在相同的 epoch 的实验有可重复性,
        # 且在不同的 epoch 之间有随机性.
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        for i, size in enumerate(self.group_sizes):
            # 如果有样本
            if size > 0:
                # 找出所有属于此类的索引
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # 随机打乱索引
                indice = indice[list(torch.randperm(int(size),
                                                    generator=g))].tolist()
                # 总共需要额外添加的样本数
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas)
                ) * self.samples_per_gpu * self.num_replicas - len(indice)

                # 填充 indice
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                # 取随机后的前 extra 个作为 extra 样本.
                indice.extend(tmp[:extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        # 打乱 sample_per_gpu 之间的顺序,
        # 因为上面已经打乱了每个 group 之内的元素,
        # 所以这里只用打乱组之间的顺序即可.
        indices = [
            indices[j] for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g))
            for j in range(i * self.samples_per_gpu, (i + 1) *
                           self.samples_per_gpu)
        ]

        # 采样 num_samples 个.不同进程之间按照打乱的数据集顺序采样.
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

③ DistributedSampler(datasets/samplers/distributed_sampler.py)

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


# pytorch < 1.2 没有 shuffle, 为了版本适配, 这里选择重写
class DistributedSampler(_DistributedSampler):

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        # 把当前的 epoch 作为随机数种子,
        # 这样能保证在相同的 epoch 的实验有可重复性,
        # 且在不同的 epoch 之间有随机性.
        if self.shuffle:
            # 使用随机数生成器, 根据 epoch 生成随机数种子.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # 添加额外的样本使其均匀可分
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

collate 函数用于整理与合并每个 batch 的数据。并将这个函数传给 DataLoader 的 collate_fn 参数。

from collections.abc import Mapping, Sequence

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

from .data_container import DataContainer


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes
    """
    # batch 是一个长度为 batch_size 的列表, 每个元素是一个字典, 每个字典代表一张图片.
    # 字典的键为:dict_keys(['img_metas', 'img', 'gt_bboxes', 'gt_labels'])

    # 确保 batch 是一个序列
    if not isinstance(batch, Sequence):
        raise TypeError(f'{batch.dtype} is not supported.')

    if isinstance(batch[0], DataContainer):
        assert len(batch) % samples_per_gpu == 0
        stacked = []

        # cpu_only 说明是 meta data
        if batch[0].cpu_only:
            # batch[0].stack:           False
            # batch[0].padding_value:   0
            for i in range(0, len(batch), samples_per_gpu):
                # 每 samples_per_gpu 个, 创建一个列表
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
            # 转成 DataContainer 对象
            return DataContainer(
                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
        # stack 为 True 说明是图片类型的数据 或 label 数据
        elif batch[0].stack:
            for i in range(0, len(batch), samples_per_gpu):
                assert isinstance(batch[i].data, torch.Tensor)
                # 需要填充维度
                if batch[i].pad_dims is not None:
                    ndim = batch[i].dim()
                    assert ndim > batch[i].pad_dims
                    max_shape = [0 for _ in range(batch[i].pad_dims)]
                    for dim in range(1, batch[i].pad_dims + 1):
                        max_shape[dim - 1] = batch[i].size(-dim)
                    for sample in batch[i:i + samples_per_gpu]:
                        for dim in range(0, ndim - batch[i].pad_dims):
                            assert batch[i].size(dim) == sample.size(dim)
                        for dim in range(1, batch[i].pad_dims + 1):
                            max_shape[dim - 1] = max(max_shape[dim - 1],
                                                     sample.size(-dim))
                    padded_samples = []
                    for sample in batch[i:i + samples_per_gpu]:
                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
                        for dim in range(1, batch[i].pad_dims + 1):
                            pad[2 * dim -
                                1] = max_shape[dim - 1] - sample.size(-dim)
                        padded_samples.append(
                            F.pad(
                                sample.data, pad, value=sample.padding_value))
                    stacked.append(default_collate(padded_samples))
                # 不填充维度
                elif batch[i].pad_dims is None:
                    stacked.append(
                        default_collate([
                            sample.data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
                else:
                    raise ValueError(
                        'pad_dims should be either None or integers (1-3)')
        # 说明是 gt bboxes
        else:
            # 取 samples_per_gpu 个, 创建列表返回.
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
    # 是序列
    elif isinstance(batch[0], Sequence):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    
    # 最开始传入的是一个字典, 里面有 图像属性, 图像, gt, label 等信息.
    # 所以会先进入下面的 if
    elif isinstance(batch[0], Mapping):
        # 返回一个字典: 每个 key 的值是原来所有 key 的值的 collate 后的结果
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            # 遍历每个 key
            for key in batch[0]
        }
    # 采用默认的整理方式
    else:
        return default_collate(batch)

每个线程的随机数种子默认为线程ID,每次运行时随机数种子不固定。考虑到实验的可重复性,创建一个 worker_init_fn 函数传参给 DataLoader 中的 worker_init_fn 参数,此参数是 worker 的初始化函数。将 num_worker * rank + worker_id + user_seed 作为随机数种子,可以解决每个线程中随机数种子不确定的情况。

def worker_init_fn(worker_id, num_workers, rank, seed):
    # 将 num_worker * rank + worker_id + user_seed 作为随机数种子
    # 可以解决线程之间随机数种子不固定的情况
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)

(2)构建分布式处理对象

MMDetection 对 pytorch 的 DistributedDataParallel 和 DataParallel 在外面有一层封装,重写了 scatter 方法,额外实现了 train_step 和 val_step 方法。scatter 方法用于将数据分发到指定的 GPU,train_step 和 val_step 对于传入的一个 batch 的数据,会调用 Detector 的 train_step 或 val_step 计算损失或得到模型输出值。(MMDetection 所有 Detector 都有 train_step 和 val_step 方法,这样在训练的时候就不需要传入损失函数来计算损失了,不同的模型可以使用不同的损失,同一个模型也可以使用不同的损失函数。这样更灵活)

如果单机多卡会使用 MMDistributedDataParallel 构建对象。如果单机单卡会使用 MMDataParallel 构建对象。

① MMDistributedDataParallel

# Copyright (c) Open-MMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
                                           _find_tensors)

from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(DistributedDataParallel):
    """The DDP module that supports DataContainer.

    MMDDP has two main differences with PyTorch DDP:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data.
    - It implement two APIs ``train_step()`` and ``val_step()``.
    """

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output

    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output

② MMDataParallel

# Copyright (c) Open-MMLab. All rights reserved.
from itertools import chain

from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):

    def scatter(self, inputs, kwargs, device_ids):
        """将数据分散到指定的 GPU设备"""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.train_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             'instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])

    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.val_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             'instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])

(3)构建优化器

构建优化器使用 build_optimizer 函数,我们可以看出它的本质也是调用 build_from_cfg。

import copy
import inspect

import torch

from ...utils import Registry, build_from_cfg

OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')


def register_torch_optimizers():
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module()(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()


def build_optimizer_constructor(cfg):
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer

(4)创建 EpochBasedRunner 并进行训练

它继承了 BaseRunner。对于 BaseRunner 主要提供了公共的属性和方法如,获取训练的属性(epoch 数量,iter 次数等)注册 hook,查看 hook 等。还有 4 个抽象方法,需要子类继承,分别是:train,val,run,save_checkpoint。

EpochBasedRunner 继承 BaseRunner,重写了 train,val,run,save_checkpoint 方法。

调用 run 方法,传入 dataloaders,work_flow,最大的循环次数等。就可以实现训练。对于不同的阶段(如:run 前,epoch 前等)调用所有相关注册的 hook。这样可定制性很强。

# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings

import torch

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            if self.batch_processor is None:
                outputs = self.model.train_step(data_batch, self.optimizer,
                                                **kwargs)
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))


class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MMdetection之train_detector 源码解析 的相关文章

随机推荐

  • C++代码复习(三+)——SeqList顺序表内数据的数组方式实现

    include
  • 排序二叉树转变为有序双向链表

    要点 1 直接改变树的结构 2 排序二叉树在中序遍历的时候是有序的 3 双向链表 需要前后两个指针 可以将Tree的节点作为链表节点 代码实现 中序的递归实现 void ToList Tree pTree Tree pHead Tree p
  • 图形识别工具-百度AI接口实现

    简介 借出百度AI平台后看到有一个图像识别工具 就简单实现了下 效果挺好的 使用也简单 百度提供了两种实现方式 1 api方式调用 2 sdk方式调用 此方式简单 本文就以此为主讲一下 a 先下载图像识别sdk 地址 https cloud
  • 【毕业设计】 微信小程序购物商城系统 【含代码】

    文章目录 0 前言 1 开发工具 2 总体架构 3 项目规划 4 云数据库 5 项目解构 5 1 购买首页 5 2 商品详情页 5 3 搜索页 5 4 品牌分类页 5 5 筛选排序页 6 最后 0 前言 Hi 同学们好呀 学长今天带大家做一
  • 06-限流策略有哪些,滑动窗口算法和令牌桶区别,使用场景?【Java面试题总结】

    限流策略有哪些 滑动窗口算法和令牌桶区别 使用场景 常见的限流算法有固定窗口 滑动窗口 漏桶 令牌桶等 6 1 固定窗口 概念 固定窗口 又称计算器限流 对一段固定时间窗口内的请求进行一个计数 如果请求数量超过阈值 就会舍弃这个请求 如果没
  • 【EI会议】2022年第三届纳米材料与纳米技术国际会议(NanoMT 2022)

    2022年第三届纳米材料与纳米技术国际会议 NanoMT 2022 重要信息 会议网址 www nanomt org 会议时间 2022年9月23 25日 召开地点 中国南京 截稿时间 2022年8月21日 录用通知 投稿后2周内 收录检索
  • QT鼠标控制

    文章目录 鼠标状态改变 限制鼠标活动区域 鼠标状态改变 void QApplication setOverrideCursor const QCursor cursor bool replace FALSE 设置应用程序强制光标为 curs
  • 晨读-为什么有时控制不了我的情绪?

    情绪是天生的 而且每一种情绪都有它的功能 例如恐惧让我们远离危险 焦虑让我们提升行动力 等等 但是我们还是会出现的情况是 明明我都理解 那些道理我都懂 为什么我还是忍不住难受 这些冒出来的情绪还是不受控制 在控制情绪之前 我们先要了解 我们
  • [避坑指南]GD32F130系列TIMER14

    前言 本人在使用GD32F130F8P6时 使能PA3引脚输出PWM波 但是检查代码没有问题 就是不出PWM波 折磨了3天 最后发现是该款单片机没有TIMER14定时器 手册误导用户啊 代码部分 此代码驱动TIMER14是没有问题的 voi
  • 9道常见的java笔试选择题

    1 关于Java编译 下面哪一个正确 选择一项 A Java程序经编译后产生machine code B Java程序经编译后会生产byte code C Java程序经编译后会产生DLL D 以上都不正确 答案 B 分析 Java是解释型
  • 北京大学肖臻老师《区块链技术与应用》公开课笔记15——ETH概述篇

    北京大学肖臻老师 区块链技术与应用 公开课笔记 以太坊概述篇 对应肖老师视频 click here 全系列笔记请见 click here About Me 点击进入我的Personal Page BTC和ETH为最主要的两种加密货币 BTC
  • 山东大学项目实训开发日志——基于vue+springboot的医院耗材管理系统(16)

    今天我们解决了一个困扰了我们很久的问题 isqr值的获取与使用 功能的设想 通过isqr这个值来确定该耗材是否使用二维码管理 在新增耗材种类的时候加入该属性 选择是或否 并写入数据库 在显示库存数据的时候通过耗材的id查找该值 以此决定是否
  • 解决:Cannot deserialize value of type `java.util.Date` from String “xxx“: not a valid representation..

    一 问题 在做数据更新操作的时候 后台数据为Date时 前端把String类型数据传到后台时 Date类型无法识别这个String数据 所以会报错 二 错误描述 主要问题 Caused by com fasterxml jackson da
  • linux重启命令

    shutdown重启系统 usr sbin shutdown r now usr sbin 指定了命令的位置 路径 shutdown 是命令本身 r 是指示重新启动系统的选项 now 表示立即执行命令 不进行倒计时 也可以指定一个时间延迟
  • el-input校验,只能输入正整数

    一 表单校验方式 fileSort required true message 请输入排序 trigger blur pattern 1 9 d message 请输入正整数 trigger blur 二 el input的type设置为n
  • mybatis笔记(老杜版本)

    一 MyBatis概述 1 1框架 Java常 框架 SSM三 框架 Spring SpringMVC MyBatis SpringBoot SpringCloud 等 SSM三 框架的学习顺序 MyBatis Spring SpringM
  • mysql jdbc url utf8_Mysql JDBC Url参数与异常问题

    今天在写Java项目使用了 SELECT FROM plan WHERE isDelete isDelete AND nestId in open close separator gt nestId 但是很不幸 后台报异常 java sql
  • springboot整合七牛云对象存储

    目录 一 测试 二 整合 一 测试 注册七牛云账号 并进行邮箱绑定和实名认证 七牛云每个月送10G完全够我们开发 创建一个空间 存储区域哪里离你近选哪里 访问控制一定要公开 创建完成后 后期上传的静态资源 可以根据域名 文件名直接访问 自定
  • Java中正则表达式的使用

    Java中正则表达式的使用 在Java中 我们为了查找某个给定字符串中是否有需要查找的某个字符或者子字串 或者对字符串进行分割 或者对字符串一些字符进行替换 删除 一般会通过if else for 的配合使用来实现这些功能 如下所示 Jav
  • MMdetection之train_detector 源码解析

    目录 一 构建 data loaders mmdet datasets builder py 2 构建分布式处理对象 3 构建优化器 4 创建 EpochBasedRunner 并进行训练 一 构建 data loaders mmdet d