传送门
[Detectron2] 01-注册机制 Registry 实现
一、为什么使用注册类
以下转自知乎 https://zhuanlan.zhihu.com/p/93835858
对于detectron2这种,需要支持许多不同的模型的大型框架,理想情况下所有的模型的参数都希望写在配置文件中,那问题来了,如果我希望根据我的配置文件,决定我是需要用VGG还是用ResNet ,我要怎么写呢?
如果是我,我可能会写出这种可扩展性超级低的暴搓的代码:
if class_name == 'VGG':
model = build_VGG(args)
elif class_name == 'ResNet':
model = build_ResNet(args)
但是如果用了注册类,代码就是这样的:
class_name = 'VGG'
model = model_registry(class_name)(args)
可以看到代码的可扩展性变得非常强了
二、注册类的实现
注册类的源码
class Registry(object):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, object] = {}
def _do_register(self, name: str, obj: object) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: object = None) -> Optional[object]:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
def deco(func_or_class: object) -> object:
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
name = obj.__name__
print("name: ", name)
self._do_register(name, obj)
def get(self, name: str) -> object:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(name, self._name)
)
return ret
def __contains__(self, name: str) -> bool:
return name in self._obj_map
如何使用
registry_machine = Registry('registry_machine')
@registry_machine.register()
def print_hello_world(word):
print("he world")
print(word)
@registry_machine.register()
def print_hi_world(word):
print("hi world")
print(word)
cfg = "print_hello_world"
registry_machine.get(cfg)('hello world')
cfg = "print_hi_world"
registry_machine.get(cfg)('hello world2')
利用@装饰器,将函数传入到registry中,在registry内部获取回调函数的名字,并创建字典,字典对应 函数名:函数,每次装饰一个函数,字典就会添加一次。
如此,就可以通过函数名来找到对应的函数
回调函数获取函数名,可以通过 __name__获取
def hello23():
print("hello2")
def printHello(hello):
print("world")
print(hello.__name__)
printHello(hello23)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)