How to build a multi-task DNN with PyTorch, e.g. for more than 100 tasks?

2024-04-29

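Below is example code that uses PyTorch to build a DNN for two regression tasks. The forward function returns two outputs (x1, x2). What about a network for a large number of regression/classification tasks, e.g. 100 or 1000 outputs? Hard-coding every output (e.g. x1, x2, ..., x100) is definitely not a good idea. Is there a simple way to do this? Thanks.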

import torch
from torch import nn
import torch.nn.functional as F

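class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.lin1 = nn.Linear(5, 10)  # shared layer used by both tasks
        self.lin2 = nn.Linear(10, 3)  # head for task 1 (3 outputs)
        self.lin3 = nn.Linear(10, 4)  # head for task 2 (4 outputs)

    def forward(self, x):
        x = self.lin1(x)   # shared features
        x1 = self.lin2(x)  # task 1 prediction
        x2 = self.lin3(x)  # task 2 prediction
        return x1, x2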

if __name__ == '__main__':
    x = torch.randn(1000, 5)
    y1 = torch.randn(1000, 3)
    y2 = torch.randn(1000, 4)
    model = mynet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        out1, out2 = model(x)
        loss = 0.2 * F.mse_loss(out1, y1) + 0.8 * F.mse_loss(out2, y2)
        loss.backward()
        optimizer.step()

You can (and should) use nn containers https://pytorch.org/docs/stable/nn.html#containers, such as nn.ModuleList https://pytorch.org/docs/stable/nn.html#modulelist or nn.ModuleDict https://pytorch.org/docs/stable/nn.html#moduledict, to manage an arbitrary number of submodules.

For example (using nn.ModuleList https://pytorch.org/docs/stable/nn.html#modulelist):

class MultiHeadNetwork(nn.Module):
    def __init__(self, list_with_number_of_outputs_of_each_head):
        super(MultiHeadNetwork, self).__init__()
        self.backbone = ...  # build the basic "backbone" on top of which all other heads come
        # one "head" per task; nn.ModuleList registers each head as a submodule
        self.heads = nn.ModuleList([])
        for nout in list_with_number_of_outputs_of_each_head:
            # 10 is the assumed number of features produced by the backbone
            self.heads.append(nn.Sequential(
              nn.Linear(10, nout * 2),
              nn.ReLU(inplace=True),
              nn.Linear(nout * 2, nout)))

    def forward(self, x):
        common_features = self.backbone(x)  # compute the shared features
        outputs = []
        for head in self.heads:
            outputs.append(head(common_features))
        return outputs

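Note that in this example, each head is more complex than a single nn.Linear layer.
The number of different "heads" (and the number of outputs) is determined by the length of the argument list_with_number_of_outputs_of_each_head.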
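
To make the pattern above concrete, here is a minimal runnable sketch with 100 heads. Everything the answer leaves open is an assumption made for illustration: the backbone is a single linear layer plus ReLU, the constructor arguments in_features and hidden_features are hypothetical, the data is random, and all tasks get equal loss weights. The training loop mirrors the one in the question, with the per-task losses summed over the list of outputs instead of being written out by hand.

```python
import torch
from torch import nn
import torch.nn.functional as F


class MultiHeadNetwork(nn.Module):
    def __init__(self, in_features, hidden_features, head_sizes):
        super().__init__()
        # Assumed backbone (the answer leaves it unspecified): Linear + ReLU.
        self.backbone = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
        )
        # One small head per task, registered through nn.ModuleList.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_features, nout * 2),
                nn.ReLU(),
                nn.Linear(nout * 2, nout),
            )
            for nout in head_sizes
        ])

    def forward(self, x):
        common_features = self.backbone(x)  # shared features
        return [head(common_features) for head in self.heads]


torch.manual_seed(0)
head_sizes = [3, 4] * 50                      # 100 tasks with 3 or 4 outputs each
x = torch.randn(64, 5)                        # random inputs, for illustration only
targets = [torch.randn(64, n) for n in head_sizes]
weights = [1.0 / len(head_sizes)] * len(head_sizes)  # assumed equal task weights

model = MultiHeadNetwork(5, 10, head_sizes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

for epoch in range(5):
    model.train()
    optimizer.zero_grad()
    outputs = model(x)
    # Weighted sum of the per-task losses, one term per head.
    loss = sum(w * F.mse_loss(out, y)
               for w, out, y in zip(weights, outputs, targets))
    loss.backward()
    optimizer.step()
```

Changing the number of tasks only requires changing head_sizes (and the matching targets and weights); neither the model class nor the loop needs to be edited.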


Important note: it is crucial to use nn containers https://pytorch.org/docs/stable/nn.html#containers, rather than plain Python lists/dictionaries, to store all submodules. Otherwise PyTorch will have difficulty managing them.
See, e.g., this answer https://stackoverflow.com/a/59279872/1714410, this question https://stackoverflow.com/q/54678896/1714410 and this one https://stackoverflow.com/q/57320958/1714410.
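
A small sketch of what goes wrong with a plain list: the two classes below (hypothetical names, made up for this illustration) hold the same layers, but only the nn.ModuleList version exposes them through parameters().

```python
import torch
from torch import nn


class PlainListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: these layers are NOT registered as submodules.
        self.heads = [nn.Linear(10, 3), nn.Linear(10, 4)]


class ModuleListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: each layer is registered as a submodule.
        self.heads = nn.ModuleList([nn.Linear(10, 3), nn.Linear(10, 4)])


plain_count = len(list(PlainListHeads().parameters()))
registered_count = len(list(ModuleListHeads().parameters()))
print(plain_count, registered_count)  # prints: 0 4
```

With the plain list, model.parameters() is empty, so an optimizer built from it has nothing to update, and calls such as model.to(device) or model.state_dict() skip those layers as well.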
