If you drop the batch_size dimension, the whole forward pass is just ordinary matrix multiplication. Trace the shape annotations in the code below to see it for yourself; the small shape sketch right after this paragraph makes it concrete.
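A minimal sketch of those shapes with a made-up feature map size (C = 32, H = W = 8 are assumptions for illustration only, not part of the module): the query/key projections have C/8 channels, so the attention map is HW x HW, and the re-weighted value keeps the original C x HW shape.

import torch

C, H, W = 32, 8, 8               # hypothetical sizes, HW = 64, C // 8 = 4
HW = H * W
q = torch.randn(HW, C // 8)      # query:  HW x C/8   (after view + permute)
k = torch.randn(C // 8, HW)      # key:    C/8 x HW
v = torch.randn(C, HW)           # value:  C x HW

energy = q @ k                   # HW x HW  -- similarity between every pair of positions
attention = torch.softmax(energy, dim=-1)
out = v @ attention.T            # C x HW   -- each position becomes a weighted sum over all positions
print(energy.shape, out.shape)   # torch.Size([64, 64]) torch.Size([32, 64])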
import torch
from torch import nn
from torch.nn import Conv2d, Parameter, Softmax


class PAM_Module(nn.Module):
    """ Position attention module """
    def __init__(self, in_dim) -> None:
        super().__init__()
        self.channel_in = in_dim
        # query/key are reduced to in_dim // 8 channels to cut the cost of the HW x HW attention map
        self.query_conv = Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = Conv2d(in_dim, in_dim, kernel_size=1)
        # learnable residual weight, initialized to 0 so the module starts as an identity mapping
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : attention value + input feature (B x C x H x W);
                  the attention map computed internally is B x (HxW) x (HxW)
        """
        ## B C H W
        m_batchsize, C, height, width = x.size()
        ## B C H W --> B C/8 HW --> B HW C/8
        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)
        ## B C/8 HW
        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)
        ## B HW HW
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        ## B C HW
        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)
        ## bmm(B C HW, B HW HW) --> B C HW
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        ## B C H W
        out = out.view(m_batchsize, C, height, width)
        ## residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out
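A quick smoke test (a minimal sketch; the channel count and spatial size here are made up for illustration). Because gamma is initialized to zero, the module is an exact identity mapping before any training, which is easy to verify:

import torch

pam = PAM_Module(in_dim=64)          # 64 input channels, chosen only for this example
x = torch.randn(2, 64, 16, 16)       # B x C x H x W
out = pam(x)
print(out.shape)                     # torch.Size([2, 64, 16, 16]) -- spatial and channel dims are preserved
print(torch.allclose(out, x))        # True: gamma starts at 0, so out == x until gamma is learned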