PyTorch:如何检查训练期间某些权重是否没有改变?

2024-04-20

如何检查 PyTorch 训练期间某些权重是否未更改?

据我了解,一种选择可以是在某些时期转储模型权重,并检查它们是否通过迭代权重进行更改,但也许有一些更简单的方法?


有两种方法可以解决这个问题:

First

        for name, param in model.named_parameters():
            if 'weight' in name:
                temp = torch.zeros(param.grad.shape)
                temp[param.grad != 0] += 1
                count_dict[name] += temp

此步骤在您之后进行loss.backward()培训模块中的步骤。这count_dict[name]字典跟踪梯度更新。您可以在训练开始之前以这种方式初始化它:

    for name, param in model.named_parameters():
        if 'weight' in name:
            count_dict[name] = torch.zeros(param.grad.shape)

现在,另一种方法是注册一个钩子函数,然后创建该钩子函数,您甚至可以根据需要更新或修改渐变。这对于跟踪权重更新来说并不是必需的,但如果您想对梯度做一些事情,它就会派上用场。 假设,我在这里随机稀疏梯度。

def hook_fn(grad):
    '''
    Randomly sparsify the gradients
    :param grad: Input gradient of the layer
    :return: grad_clone - the sparsified FC layer gradients
    '''
    grad_clone = grad.clone()
    temp = torch.cuda.FloatTensor(grad_clone.shape).uniform_()
    grad_clone[temp < 0.8] = 0
    return grad_clone

在这里我给模型一个钩子。

for name, param in model.named_parameters():
    if 'weight' in name:
            param.register_hook(hook_fn)

因此,这可能只是为您稀疏梯度,您可以通过以下方式跟踪钩子函数本身的梯度:

def hook_func(module, input, output):
    temp = torch.zeros(output.shape)
    temp[output != 0] += 1
    count_dict[module] += temp

虽然,我不建议这样做。这在可视化前向传递特征/激活的情况下通常很有用。而且,输入和输出可能会混淆,因为梯度和参数输入和输出是相反的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch:如何检查训练期间某些权重是否没有改变? 的相关文章

随机推荐