import torch.nn as nn
import torch
class net(nn.Module):
    """Four stacked 3x3 stride-2 convolutions demonstrating that ``detach()`` cuts autograd.

    Calling the module runs a forward pass, reduces the output to a scalar with
    ``torch.mean``, calls ``backward()`` on it and prints two weight gradients:

    * ``layer2`` (applied *before* ``detach()``) has ``.grad`` of ``None``,
      as do ``conv`` and ``layer1``.
    * ``layer3`` (applied *after* ``detach()``) receives a gradient.
    """

    def __init__(self):
        super().__init__()
        # With kernel 3, stride 2 and padding 1, each conv maps spatial size H
        # to ceil(H / 2) while doubling the channel count.
        self.conv = nn.Conv2d(3, 6, 3, stride=2, padding=1)
        self.layer1 = nn.Conv2d(6, 12, 3, stride=2, padding=1)
        self.layer2 = nn.Conv2d(12, 24, 3, stride=2, padding=1)
        self.layer3 = nn.Conv2d(24, 48, 3, stride=2, padding=1)

    def forward(self, x):
        """Run the demo forward/backward pass and return the scalar output.

        Args:
            x: input batch; must have 3 channels to match ``self.conv``
               (presumably ``(N, 3, H, W)`` -- confirm with callers).

        Returns:
            The 0-dim tensor produced by ``torch.mean`` over ``layer3``'s output.
            Its graph has already been consumed by ``backward()``.

        Side effects:
            Calls ``backward()``, which populates ``layer3``'s parameter
            ``.grad`` fields. Nothing zeroes them, so repeated calls accumulate.
            Prints ``layer2.weight.grad`` and ``layer3.weight.grad``.
        """
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # detach() returns a tensor that is not part of the autograd graph, so
        # no gradient can flow back into conv / layer1 / layer2.
        x = x.detach()
        x = self.layer3(x)
        # Reduce to a scalar so backward() needs no explicit gradient argument.
        x = torch.mean(x)
        x.backward()
        print(self.layer2.weight.grad)  # None: layer2 ran before detach()
        print(self.layer3.weight.grad)  # tensor with shape torch.Size([48, 24, 3, 3])
        # Return the output so callers (e.g. `pred = n(data)`) do not get None.
        return x
if __name__ == '__main__':
    # Draw the random input before building the model, so the RNG is
    # consumed in the same order: input first, then weight initialisation.
    sample = torch.randn(1, 3, 224, 224)
    model = net()
    # Calling the model runs forward(), which also calls backward() and
    # prints the layer2 / layer3 weight gradients.
    pred = model(sample)