Python中的嵌入层:如何正确使用Torchsummary?

2024-02-29

这是一个最低限度工作/可重现的示例:

import torch
import torch.nn as nn
from torchsummary import summary

class Network(nn.Module): 
    def __init__(self, channels_img, features_d, num_classes, img_size): 
        super(Network, self).__init__()
        self.img_size = img_size
        self.disc = nn.Conv2d(
            in_channels = channels_img + 1, 
            out_channels = features_d, 
            kernel_size = (4,4)
        )

        # ConditionalGan: 
        self.embed = nn.Embedding(
            num_embeddings = num_classes, 
            embedding_dim = img_size * img_size
        )

   def forward(self, x, labels): 
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, embedding], dim = 1)
        return self.disc(x) 
    
# device: 
device = torch.device("cpu")

# hyperparameter: 
batch_size = 64

# Initialize model: 
model = Network(
    channels_img = 1, 
    features_d = 16, 
    num_classes = 10, 
    img_size = 28).to(device) 

# Print model summary: 
summary(
    model, 
    input_size = [(1, 28, 28), (1, 28, 28)], # MNIST
    batch_size = batch_size
)

我收到的错误消息是(对于带有summary(...)):

Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

我在这看到post https://stackoverflow.com/questions/56360644/pytorch-runtimeerror-expected-tensor-for-argument-1-indices-to-have-scalar-t, that .to(torch.int64)应该有帮助,但老实说我不知道​​该写在哪里。

谢谢你!


问题就出在这里:

self.embed(labels)...

如上所述,嵌入层是离散索引和连续值之间的映射here https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html。也就是说,它的输入应该是整数,并且它将返回浮点数。例如,就您而言,您是嵌入MNIST 的类标签,范围从 0 到 9,到一个连续体(出于某种原因,我不知道,因为我不熟悉 GAN :))。但简而言之,该嵌入层将给出以下变换:10 -> 784PyTorch 说,对你来说,这 10 个数字应该是整数。

整数类型的一个奇特名称是“long”,因此您需要确保输入的数据类型self.embed属于那种类型。有一些方法可以做到这一点:

self.embed(labels.long())

or

self.embed(labels.to(torch.long))

or

self.embed(labels.to(torch.int64))

Long 数据类型实际上是一个 64 位整数(你可能会看到here https://pytorch.org/docs/stable/tensor_attributes.html),所以所有这些都有效。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python中的嵌入层:如何正确使用Torchsummary? 的相关文章

随机推荐