class EnDecoder(nn.Module):
def __init__(self):
super(EnDecoder,self).__init__()
# 定义Encoder
self.Encoder = nn.Sequential(
nn.Linear(784,512),
nn.Tanh(),
nn.Linear(512,256),
nn.Tanh(),
nn.Linear(256,128),
nn.Tanh(),
nn.Linear(128,3),
nn.Tanh()
)
# 定义Decoder
self.Decoder = nn.Sequential(
nn.Linear(3,128),
nn.Tanh(),
nn.Linear(128,256),
nn.Tanh(),
nn.Linear(256,512),
nn.Tanh(),
nn.Linear(512,784),
nn.Sigmoid()
)
# 定义网络的前向传播路径
def forward(self,x):##----------------------》这是正确的
encoder = self.Encoder(x)
decoder = self.Decoder(encoder)
return encoder,decoder
def forward(self,x):##------------》这是不正确的,会出现NotImplementedError
encoder = self.Encoder(x)
decoder = self.Decoder(encoder)
return encoder,decoder
如果在学习自动编码器的时候出现NotImplementedError,一定要检查forward方法是否有缩进问题。