I'm trying to create a transform that shuffles the patches of each image in a batch. My goal is to use it the same way as the other `torchvision` transforms:
```python
from torchvision import transforms

trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16))  # our new transform
])
```
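More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size `patch_size`, shuffle them, and reassemble them into a single image.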
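The three steps (split, shuffle, reassemble) can also be written with plain `reshape`/`permute`, which makes the patch bookkeeping explicit. This is a minimal sketch of an alternative technique, not the `unfold`/`fold` approach used in the answer; `shuffle_patches_reshape` is a hypothetical helper name, and it assumes H and W are exact multiples of the patch size.

```python
import torch


def shuffle_patches_reshape(x: torch.Tensor, patch_size: tuple) -> torch.Tensor:
    """Shuffle the non-overlapping patches of each image in a BxCxHxW batch.

    Hypothetical alternative to unfold/fold; assumes H % ph == 0 and W % pw == 0.
    """
    B, C, H, W = x.shape
    ph, pw = patch_size
    assert H % ph == 0 and W % pw == 0, "H and W must be multiples of the patch size"
    gh, gw = H // ph, W // pw  # patch grid: gh rows by gw columns
    # B x C x H x W -> B x C x gh x ph x gw x pw -> B x (gh*gw) x C x ph x pw
    patches = (x.reshape(B, C, gh, ph, gw, pw)
                .permute(0, 2, 4, 1, 3, 5)
                .reshape(B, gh * gw, C, ph, pw))
    # an independent random permutation of the patch order for every image
    perm = torch.argsort(torch.rand(B, gh * gw, device=x.device), dim=1)
    patches = patches[torch.arange(B, device=x.device)[:, None], perm]
    # undo the reshape/permute to stitch the shuffled patches back into images
    return (patches.reshape(B, gh, gw, C, ph, pw)
                   .permute(0, 3, 1, 4, 2, 5)
                   .reshape(B, C, H, W))


x = torch.randn(8, 3, 224, 224)
y = shuffle_patches_reshape(x, (112, 112))
print(y.shape)  # torch.Size([8, 3, 224, 224])
```

Drawing the permutation with `argsort` of uniform noise gives every image its own random patch order without a Python loop over the batch.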
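For example, given a 224x224 image, I would like `ShufflePatches(patch_size=(112,112))` to produce an output image in which the image's four 112x112 patches appear in shuffled order.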
I think the solution involves `torch.unfold` and `torch.fold`, but I haven't managed to get any further. Any help would be appreciated!
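To make the shapes concrete for the 224x224 example, here is a short sketch of what `torch.nn.functional.unfold` returns when the stride equals the kernel size. It assumes a 3-channel image (as the three-value `Normalize` in the question suggests) and a batch dimension of 1; each column of the result is one flattened patch of C·ph·pw values.

```python
import torch
import torch.nn.functional as nnf

# hypothetical single 3-channel 224x224 image with a batch dimension of 1
x = torch.randn(1, 3, 224, 224)

# unfold with stride == kernel_size extracts non-overlapping patches:
# result is B x (C * ph * pw) x L, where L = (H // ph) * (W // pw)
shapes = {
    ps: tuple(nnf.unfold(x, kernel_size=ps, stride=ps).shape)
    for ps in [(112, 112), (16, 16)]
}
print(shapes)
```

For 112x112 patches this gives 2 × 2 = 4 columns of 3 · 112 · 112 = 37632 values; for 16x16 patches it gives 14 × 14 = 196 columns of 768 values.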
Indeed, `unfold` and `fold` seem appropriate in this case (see https://stackoverflow.com/a/53972525/1714410).
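The reason this pair fits: with `stride == kernel_size` the patches don't overlap, so `fold` exactly inverts `unfold`, and any reordering of the patch columns in between becomes a reordering of patches in the image. A minimal check of the round trip, assuming a random 2x3x224x224 batch and 16x16 patches:

```python
import torch
import torch.nn.functional as nnf

x = torch.randn(2, 3, 224, 224)
ps = (16, 16)

# B x C x H x W -> B x (C*16*16) x 196 patch columns
u = nnf.unfold(x, kernel_size=ps, stride=ps)
# non-overlapping patches: each output pixel receives exactly one value
x_rec = nnf.fold(u, x.shape[-2:], kernel_size=ps, stride=ps)

print(u.shape)                 # torch.Size([2, 768, 196])
print(torch.equal(x_rec, x))   # True
```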
```python
import torch
import torch.nn.functional as nnf


class ShufflePatches(object):
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # divide the batch of images into non-overlapping patches:
        # u has shape B x (C*ph*pw) x L, one column per patch
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # permute the patches (columns) of each image in the batch independently
        pu = torch.cat([b_[:, torch.randperm(b_.shape[-1], device=b_.device)][None, ...]
                        for b_ in u], dim=0)
        # fold the permuted patches back together into B x C x H x W images
        f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
        return f
```
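The class above is written for a BxCxHxW batch, while inside `transforms.Compose` each transform receives a single CxHxW tensor from `ToTensor()`. A hedged sketch of a variant that accepts either shape by temporarily adding a batch dimension; `ShufflePatchesPerImage` is a hypothetical name, and the per-image loop is replaced by one `torch.gather` over a batch of random permutations.

```python
import torch
import torch.nn.functional as nnf


class ShufflePatchesPerImage:
    """Hypothetical variant of ShufflePatches that also accepts a single CxHxW image."""

    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)  # CxHxW -> 1xCxHxW so unfold always sees 4-D input
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps)  # B x (C*ph*pw) x L
        # one random permutation of the L patch columns per image
        perm = torch.argsort(torch.rand(u.shape[0], u.shape[-1], device=u.device), dim=1)
        # reorder the columns: pu[b, r, l] = u[b, r, perm[b, l]]
        pu = torch.gather(u, 2, perm[:, None, :].expand_as(u))
        f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps)
        return f.squeeze(0) if unbatched else f


img = torch.randn(3, 224, 224)  # stand-in for the output of ToTensor() + Normalize()
out = ShufflePatchesPerImage((112, 112))(img)
print(out.shape)  # torch.Size([3, 224, 224])
```

Behavior matches the loop-based version (each image gets an independent shuffle); `gather` just avoids iterating over the batch in Python.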
Here's an example with `patch_size=(16, 16)`:
![ShufflePatches example output with 16x16 patches](https://i.stack.imgur.com/SJXrK.png)