像下面这样的实现可行吗？
from torch.nn import MaxPool1d
import torch.nn.functional as F
class ChannelPool(MaxPool1d):
    """Max-pool across the channel axis instead of a spatial one.

    Reuses MaxPool1d's settings (kernel_size, stride, padding, dilation,
    ceil_mode, return_indices) but slides the 1-D window over channels,
    independently at every spatial position.

    Input:  ``(N, C, *spatial)`` with at least one spatial dim, e.g. the
    4-D ``(N, C, W, H)`` case.
    Output: ``(N, C_out, *spatial)``, where ``C_out`` is MaxPool1d's output
    length for an input of length ``C``. When ``return_indices`` is true,
    a ``(pooled, indices)`` pair of that shape is returned; the indices
    refer to positions along the channel axis.
    """

    def forward(self, input):
        n = input.size(0)
        spatial = input.shape[2:]
        # (N, C, *S) -> (N, prod(S), C). F.max_pool1d pools over the last
        # axis, so moving channels there makes it pool across channels.
        # flatten() copies when necessary (unlike view()), so non-contiguous
        # inputs such as transposed tensors are accepted.
        flat = input.flatten(2).transpose(1, 2)
        result = F.max_pool1d(
            flat,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )

        def restore(t):
            # Undo the flattening: (N, prod(S), C_out) -> (N, C_out, *S).
            return t.transpose(1, 2).reshape(n, t.size(-1), *spatial)

        if self.return_indices:
            # max_pool1d yields (values, indices) in this mode; reshape both.
            pooled, indices = result
            return restore(pooled), restore(indices)
        return restore(result)
或者，使用 einops：
from torch.nn import MaxPool1d
import torch.nn.functional as F
from einops import rearrange
class ChannelPool(MaxPool1d):
    """Max-pool across the channel axis of a 4-D tensor (einops variant).

    Reuses MaxPool1d's settings (kernel_size, stride, padding, dilation,
    ceil_mode, return_indices) but slides the 1-D window over channels,
    independently at every spatial position.

    Input:  ``(N, C, W, H)``.
    Output: ``(N, C_out, W, H)``, where ``C_out`` is MaxPool1d's output
    length for an input of length ``C``. When ``return_indices`` is true,
    a ``(pooled, indices)`` pair of that shape is returned; the indices
    refer to positions along the channel axis.
    """

    def forward(self, input):
        # Unpacking also rejects inputs that are not 4-D.
        _, _, w, h = input.size()
        # Channels last, spatial dims merged: F.max_pool1d pools over the
        # final axis, so its window now slides along channels.
        flat = rearrange(input, "n c w h -> n (w h) c")
        result = F.max_pool1d(
            flat,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )

        def restore(t):
            # Undo the flattening: (N, W*H, C_out) -> (N, C_out, W, H).
            return rearrange(t, "n (w h) c -> n c w h", w=w, h=h)

        if self.return_indices:
            # max_pool1d yields (values, indices) in this mode; reshape both.
            pooled, indices = result
            return restore(pooled), restore(indices)
        return restore(result)