今天在看DSSINet代码的ssim.py时,遇到了一个用法
class NORMMSSSIM(torch.nn.Module):
def __init__(self, sigma=1.0, levels=5, size_average=True, channel=1):
super(NORMMSSSIM, self).__init__()
self.sigma = sigma
self.window_size = 5
self.levels = levels
self.size_average = size_average
self.channel = channel
self.register_buffer('window', create_window(self.window_size, self.channel, self.sigma))
self.register_buffer('weights', torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))
那么这个register_buffer()是干什么用呢?官方解释如下
nn.modules.module.py
Adds a persistent buffer to the module.向模块添加持久缓冲区。
This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the persistent state.这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm的“running_mean”不是参数,而是持久状态的一部分。
Buffers can be accessed as attributes using given names.
缓冲区可以使用给定的名称作为属性访问。
Args:
name (string): name of the buffer. The buffer can be accessed
from this module using the given name 名称(字符串):缓冲区的名称。可以使用给定的名称从该模块访问缓冲区
tensor (Tensor): buffer to be registered.
Example::
>>> self.register_buffer('running_mean', torch.zeros(num_features))
应该就是在内存中定一个常量,同时,模型保存和加载的时候可以写入和读出。
pytorch一般情况下,是将网络中的参数保存成orderedDict形式的,这里的参数其实包含两种,一种是模型中各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数,另一种就是buffer,前者每次optim.step会得到更新,而不会更新后者。
class myModel(nn.Module):
def __init__(self, kernel_size=3):
super(Depth_guided1, self).__init__()
self.kernel_size = kernel_size
self.back_end = torch.nn.Sequential(
torch.nn.Conv2d(3, 32, 3, padding=1),
torch.nn.ReLU(True),
torch.nn.Conv2d(3, 64, 3, padding=1),
torch.nn.ReLU(True),
torch.nn.Conv2d(64, 3, 3, padding=1),
torch.nn.ReLU(True),
)
mybuffer = np.arange(1,10,1)
self.mybuffer_tmp = np.randn((len(mybuffer), 1, 1, 10), dtype='float32')
self.mybuffer_tmp = torch.from_numpy(self.mybuffer_tmp)
# register preset variables as buffer
# So that, in testing , we can use buffer variables.
self.register_buffer('mybuffer', self.mybuffer_tmp)
# Learnable weights
self.conv_weights = nn.Parameter(torch.FloatTensor(64, 10).normal_(mean=0, std=0.01))
# Other code
def forward(self):
...
# 这里使用 self.mybuffer!
注记:
1.定义parameter和buffer都只需要传入Tensor即可。也不需要将其转成gpu,这是因为,当网络进行.cuda时候,会自动将里面的层的参数,buffer等转换成相应的GPU上。
2. self.register_buffer可以将tensor注册成buffer,在forward中使用self.mybuffer,而不是self.mybuffer_tmp
3.网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。
4.buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数。