为了更好理解Pytorch基本类的实现方法,我这里给出了关于参数方面的3个类的源码详解。
此部分可以更好的了解实现逻辑结构,有助于后续代码理解,学pytorch的话这个不是必须掌握的,看不懂也没关系。
文章目录
- 1 Parameter 参数类源码
- 2 ParameterList 参数列表类源码
- 3 ParameterDict 参数字典类源码
- 总结
1 Parameter 参数类源码
此部分参考《pytorch源码阅读系列之Parameter类》,《通俗的讲解Python中的__new__()方法》
因为Parameter继承于torch.Tensor,没有新的变量和添加函数,只是对一些辅助函数进行了定义
Parameter作为Module类的参数,可以自动的添加到Module类的参数列表中,并且可以使用Module.parameters()提供的迭代器获取到,所以这个类是一切网络结构数据的核心。
class Parameter(torch.Tensor):
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
def __reduce_ex__(self, proto):
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)
2 ParameterList 参数列表类源码
这个类实际上是将一个Parameter的List转为ParameterList,如下例所示[nn.Parameter(torch.randn(10, 10)) for i in range(10)]
类型是List,List的每个元素是Parameter,然后这个List作为参数传入这个类构造ParameterList类型。
ParameterList输入一定是一个Parameter的List,其他类型会报错,在注册时候就会提示元素不是Parameter类型。
parms = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
下面是对应的源码。
class ParameterList(Module):
def __init__(self, parameters=None):
super(ParameterList, self).__init__()
if parameters is not None:
self += parameters
def _get_abs_string_index(self, idx):
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
if idx < 0:
idx += len(self)
return str(idx)
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(list(self._parameters.values())[idx])
else:
idx = self._get_abs_string_index(idx)
return self._parameters[str(idx)]
def __setitem__(self, idx, param):
idx = self._get_abs_string_index(idx)
return self.register_parameter(str(idx), param)
def __len__(self):
return len(self._parameters)
def __iter__(self):
return iter(self._parameters.values())
def __iadd__(self, parameters):
return self.extend(parameters)
def __dir__(self):
keys = super(ParameterList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def append(self, parameter):
self.register_parameter(str(len(self)), parameter)
return self
def extend(self, parameters):
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(parameters).__name__)
offset = len(self)
for i, param in enumerate(parameters):
self.register_parameter(str(offset + i), param)
return self
def extra_repr(self):
child_lines = []
for k, p in self._parameters.items():
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
3 ParameterDict 参数字典类源码
ParameterDict 是一个字典类源码,与python的字典非常相似,下面就是字典的一个例子,输入参数是个普通字典,然后转换为ParameterDict类型。
params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10))})
下面给出这个类的源码,并对其进行详细分析理解。
class ParameterDict(Module):
def __init__(self, parameters=None):
super(ParameterDict, self).__init__()
if parameters is not None:
self.update(parameters)
def __getitem__(self, key):
return self._parameters[key]
def __setitem__(self, key, parameter):
self.register_parameter(key, parameter)
def __delitem__(self, key):
del self._parameters[key]
def __len__(self):
return len(self._parameters)
def __iter__(self):
return iter(self._parameters.keys())
def __contains__(self, key):
return key in self._parameters
def clear(self):
self._parameters.clear()
def pop(self, key):
v = self[key]
del self[key]
return v
def keys(self):
return self._parameters.keys()
def items(self):
return self._parameters.items()
def values(self):
r"""Return an iterable of the ParameterDict values.
"""
return self._parameters.values()
def update(self, parameters):
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError("ParametersDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(parameters).__name__)
if isinstance(parameters, container_abcs.Mapping):
if isinstance(parameters, (OrderedDict, ParameterDict)):
for key, parameter in parameters.items():
self[key] = parameter
else:
for key, parameter in sorted(parameters.items()):
self[key] = parameter
else:
for j, p in enumerate(parameters):
if not isinstance(p, container_abcs.Iterable):
raise TypeError("ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(p).__name__)
if not len(p) == 2:
raise ValueError("ParameterDict update sequence element "
"#" + str(j) + " has length " + str(len(p)) +
"; 2 is required")
self[p[0]] = p[1]
def extra_repr(self):
child_lines = []
for k, p in self._parameters.items():
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
child_lines.append(' (' + k + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
总结
关于参数的三个类的分析就到这里了,其实感觉跟正常的python用法也没啥区别,为了方便用户使用pytorch,官方重载了大量的函数,方便用户使用,很大程度上降低了使用难度。后续,我再对模型的几个类比如Sequential,ModuleList,ModuleDict进行分析,Module这个类我估计不会进行分析了,将近1000行,实现了太多太多功能,我觉得太底层了,就不分析了,如果有人感兴趣的话,欢迎一起讨论研究。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)