-
torch代码计算
def paras_cnn(k,s,p,i=64):
x=torch.ones(1,1,i,i)
conv = torch.nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p)
convt= torch.nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
h1=conv(x)
h2=convt(x)
y=convt(h1)
print("conv(x):{} \t convT(x):{} \t convT(conv(x)):{}".format((h1.shape[2],h1.shape[3]),(h2.shape[2],h2.shape[3]),(y.shape[2],y.shape[3])))
return h1.shape[2],h1.shape[3],h2.shape[2],h2.shape[3],y.shape[2],y.shape[3]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)