我目前正在尝试延长a model https://github.com/microsoft/MASS这是基于 FairSeq/PyTorch 的。在训练过程中,我需要训练两个编码器:一个使用目标样本,另一个使用源样本。
所以当前的forward函数看起来像这样:
def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
并以此为基础这个想法 https://github.com/golsun/SpaceFusion我想要这样的东西:
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
concat = some_concatination_func(encoder_out, autoencoder_out)
decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
return decoder_out
有什么办法可以做到这一点吗?
编辑:
这些是我所面临的限制,因为我需要扩展FairseqEncoderDecoder模型:
@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
编辑2:
传递给 Fairseq 中的前向函数的参数可以通过实现您自己的标准来更改,请参见示例交叉熵准则 https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L28, where sample['net_input']
被传递到__call__
模型的函数,它调用forward
method.