如题：共需要修改以下四个文件（其中第 3 个为新增文件）：
1. mmcv/runner/hooks/logger/__init__.py（修改后的内容）：
"""Logger hooks exported by ``mmcv.runner.hooks.logger``."""
# ``LoggerHook`` is listed in ``__all__`` below, so it must also be imported
# here; otherwise ``from .logger import LoggerHook`` and star-imports fail
# because the name is never bound in this package namespace.
from .base import LoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook

__all__ = [
    'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
    'WandbLoggerHook'
]
2. mmcv/runner/__init__.py（仅修改以下几行）：
第6行（在导入语句末尾加入 WandbLoggerHook）：
TensorboardLoggerHook, WandbLoggerHook)
第18–20行（在 __all__ 列表中加入 'WandbLoggerHook'）：
'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_host_info', 'get_dist_info', 'master_only', 'get_time_str','obj_from_dict'
3. mmcv/runner/hooks/logger/wandb.py（新增文件，完整内容如下）：
from ...utils import master_only
from .base import LoggerHook
import numbers
class WandbLoggerHook(LoggerHook):
    """Logger hook that sends numeric entries of ``runner.log_buffer`` to
    Weights & Biases (wandb).

    Args:
        log_dir (str, optional): Accepted for signature compatibility with
            the other logger hooks; it is NOT used by this hook. To control
            where wandb stores its files, pass ``init_kwargs=dict(dir=...)``.
        interval (int): Forwarded unchanged to ``LoggerHook.__init__``.
        ignore_last (bool): Forwarded unchanged to ``LoggerHook.__init__``.
        reset_flag (bool): Forwarded unchanged to ``LoggerHook.__init__``.
        init_kwargs (dict, optional): Keyword arguments passed to
            ``wandb.init()`` (e.g. ``project``, ``name``, ``config``).
            ``None`` (default) calls ``wandb.init()`` with no arguments.

    Raises:
        ImportError: if the ``wandb`` package is not installed.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 init_kwargs=None):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag)
        # Import eagerly so a missing wandb install is reported when the
        # hook is constructed rather than when the run starts.
        self.import_wandb()
        self.init_kwargs = init_kwargs

    def import_wandb(self):
        """Import ``wandb`` and store the module on ``self.wandb``.

        Raises:
            ImportError: if wandb is not installed.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Start a wandb run, passing ``init_kwargs`` to ``wandb.init``."""
        # getattr: tolerate instances whose __init__ did not run, where the
        # attributes are missing instead of None.
        if getattr(self, 'wandb', None) is None:
            self.import_wandb()
        init_kwargs = getattr(self, 'init_kwargs', None) or {}
        self.wandb.init(**init_kwargs)

    @master_only
    def log(self, runner):
        """Log numeric values from ``runner.log_buffer.output`` to wandb.

        Each value is tagged ``'<name>/<runner.mode>'`` and logged at step
        ``runner.iter``. The ``time`` and ``data_time`` entries and all
        non-numeric values are skipped. Nothing is sent if no value remains.
        """
        metrics = {}
        for var, val in runner.log_buffer.output.items():
            if var in ('time', 'data_time'):
                continue
            if isinstance(val, numbers.Number):
                metrics['{}/{}'.format(var, runner.mode)] = val
        if metrics:
            self.wandb.log(metrics, step=runner.iter)

    @master_only
    def after_run(self, runner):
        """Finish the wandb run."""
        # ``wandb.join`` is a deprecated alias of ``wandb.finish``; prefer
        # the current name when the installed wandb provides it.
        finish = getattr(self.wandb, 'finish', None) or self.wandb.join
        finish()
4. mmcv/runner/hooks/__init__.py（仅修改以下几行）：
第10行（在导入语句末尾加入 WandbLoggerHook）：
TensorboardLoggerHook, WandbLoggerHook)
第15行（在 __all__ 列表中加入 'WandbLoggerHook'）：
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook','WandbLoggerHook'