好的,先看结果:
性能(我的笔记本电脑有 RTX-2070,PyTorch 正在使用它):
# Method 1: Use the jacobian function
CPU times: user 34.6 s, sys: 165 ms, total: 34.7 s
Wall time: 5.8 s
# Method 2: Sample with appropriate vectors
CPU times: user 1.11 ms, sys: 0 ns, total: 1.11 ms
Wall time: 191 µs
速度快了大约 30000 倍。
为什么你应该使用backward
代替jacobian
(就你而言)
我不是 PyTorch 的专业人士。但是,根据我的经验,如果不需要雅可比矩阵中的所有元素,那么计算雅可比矩阵的效率相当低。
如果只需要对角线元素,可以使用backward
计算函数vector- 与一些特定向量的雅可比乘法。如果您设置vector如果正确,您可以从雅可比矩阵中采样/提取特定元素。
一点线性代数:
j = np.array([[1,2],[3,4]]) # 2x2 jacobi you want
sv = np.array([[1],[0]]) # 2x1 sampling vector
first_diagonal_element = sv.T.dot(j).dot(sv) # it's j[0, 0]
对于这个简单的例子来说,它的功能并不是那么强大。但是如果 PyTorch 需要计算所有雅可比矩阵(j
可能是一长串矩阵-矩阵乘法的结果),它会太慢。相反,如果我们计算向量雅可比乘法序列,计算速度将会非常快。
Solution
雅可比的示例元素:
import torch
from torch.autograd import grad
import torch.nn as nn
import torch.optim as optim
class net_x(nn.Module):
def __init__(self):
super(net_x, self).__init__()
self.fc1=nn.Linear(1, 20)
self.fc2=nn.Linear(20, 20)
self.out=nn.Linear(20, 400) #a,b,c,d
def forward(self, x):
x=torch.tanh(self.fc1(x))
x=torch.tanh(self.fc2(x))
x=self.out(x)
return x
nx = net_x()
#input
val = 100
a = torch.rand(val, requires_grad = True) #input vector
t = torch.reshape(a, (val,1)) #reshape for batch
#method
%time dx = torch.autograd.functional.jacobian(lambda t_: nx(t_), t)
dx = torch.diagonal(torch.diagonal(dx, 0, -1), 0)[0] #first vector
#dx = torch.diagonal(torch.diagonal(dx, 1, -1), 0)[0] #2nd vector
#dx = torch.diagonal(torch.diagonal(dx, 2, -1), 0)[0] #3rd vector
#dx = torch.diagonal(torch.diagonal(dx, 3, -1), 0)[0] #4th vector
print(dx)
out = nx(t)
m = torch.zeros((val,400))
m[:, 0] = 1
%time out.backward(m)
print(a.grad)
a.grad
应等于第一个张量dx
. And, m
是与代码中所谓的“第一个向量”相对应的采样向量。
- 但如果我再次运行它,答案就会改变。
是的,你已经明白了。每次调用时梯度都会累积backward
。所以你必须设置a.grad
如果您必须多次运行该单元,请先为零。
- 你能解释一下背后的想法吗
m
方法?两者都使用torch.zeros
并将该列设置为1
。还有,怎么毕业了a
而不是在t
?
- 背后的想法
m
方法是:功能是什么backward
计算实际上是一个向量雅可比矩阵乘法,其中向量代表所谓的“上游梯度”,雅可比矩阵是“局部梯度”(这个雅可比矩阵也是你用jacobian
函数,因为你的lambda
可以被视为单个“本地”操作)。如果您需要来自雅可比的一些元素,您可以伪造(或者更准确地说,构造)一些“上游梯度”,用它您可以从雅可比中提取特定的条目。然而,有时如果涉及复杂的张量运算,这些上游梯度可能很难找到(至少对我来说)。
- PyTorch 在计算图的叶节点上累积梯度。而且,你原来的代码行
t = torch.reshape(t, (3,1))
失去叶节点的句柄,并且t
现在指的是中间节点而不是叶节点。为了访问叶节点,我创建了张量a
.