You can (and should) use the nn containers https://pytorch.org/docs/stable/nn.html#containers, such as nn.ModuleList https://pytorch.org/docs/stable/nn.html#modulelist or nn.ModuleDict https://pytorch.org/docs/stable/nn.html#moduledict, to manage an arbitrary number of sub-modules.
For example (using nn.ModuleList https://pytorch.org/docs/stable/nn.html#modulelist):
class MultiHeadNetwork(nn.Module):
    def __init__(self, list_with_number_of_outputs_of_each_head):
        super(MultiHeadNetwork, self).__init__()
        self.backbone = ...  # build the basic "backbone" on top of which all other heads come
        # all other "heads"
        self.heads = nn.ModuleList([])
        for nout in list_with_number_of_outputs_of_each_head:
            self.heads.append(nn.Sequential(
                nn.Linear(10, nout * 2),
                nn.ReLU(inplace=True),
                nn.Linear(nout * 2, nout)))

    def forward(self, x):
        common_features = self.backbone(x)  # compute the shared features
        outputs = []
        for head in self.heads:
            outputs.append(head(common_features))
        return outputs
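To make the sketch above actually runnable, here is a minimal complete version. The backbone is a hypothetical placeholder (a single `nn.Linear(20, 10)` followed by a ReLU, producing the 10 shared features the heads expect); any feature extractor with a matching output size would do:

```python
import torch
import torch.nn as nn

class MultiHeadNetwork(nn.Module):
    def __init__(self, list_with_number_of_outputs_of_each_head, in_features=20):
        super(MultiHeadNetwork, self).__init__()
        # hypothetical backbone: maps in_features -> 10 shared features
        self.backbone = nn.Sequential(
            nn.Linear(in_features, 10),
            nn.ReLU(inplace=True))
        # one head per requested output size
        self.heads = nn.ModuleList([])
        for nout in list_with_number_of_outputs_of_each_head:
            self.heads.append(nn.Sequential(
                nn.Linear(10, nout * 2),
                nn.ReLU(inplace=True),
                nn.Linear(nout * 2, nout)))

    def forward(self, x):
        common_features = self.backbone(x)  # compute the shared features
        return [head(common_features) for head in self.heads]

net = MultiHeadNetwork([2, 5, 3])
outs = net(torch.randn(4, 20))  # batch of 4 samples
print([tuple(o.shape) for o in outs])  # [(4, 2), (4, 5), (4, 3)]
```

Each element of the returned list is the output of one head, so downstream you can apply a different loss to each as needed.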
Note that in this example each head is more complex than a single nn.Linear layer.
The number of different "heads" (and the number of outputs of each) is determined by the length of the argument list_with_number_of_outputs_of_each_head.
Important reminder: it is crucial to use nn containers https://pytorch.org/docs/stable/nn.html#containers, rather than plain pythonic lists/dicts, to store all the sub-modules. Otherwise pytorch will have difficulty managing the sub-modules (e.g., their parameters will not be registered, so they will be invisible to optimizers and to .to(device)).
See, e.g., this answer https://stackoverflow.com/a/59279872/1714410, this question https://stackoverflow.com/q/54678896/1714410 and this one https://stackoverflow.com/q/57320958/1714410.
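To see concretely what goes wrong with a plain list, here is a small comparison (the class names are made up for illustration): sub-modules stored in a Python list are not registered, so their parameters do not show up in .parameters():

```python
import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super(WithPlainList, self).__init__()
        self.heads = [nn.Linear(10, 2)]  # plain list: sub-module NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super(WithModuleList, self).__init__()
        self.heads = nn.ModuleList([nn.Linear(10, 2)])  # properly registered

print(len(list(WithPlainList().parameters())))   # 0 -- optimizer would see nothing
print(len(list(WithModuleList().parameters())))  # 2 -- weight and bias are found
```

The same applies to state_dict() (the plain-list version would silently save/load nothing for those heads), which is why the nn containers are not optional here.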