如何计算 CNN 第一个线性层的维度

2024-04-27

目前,我正在使用 CNN,其中附加了一个完全连接的层,并且我正在使用尺寸为 32x32 的 3 通道图像。我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积/最大池层的输入。我希望能够计算第一个线性层的尺寸,仅给出最后一个 conv2d 层和 maxpool 的信息。换句话说,我希望能够计算该值,而不必使用之前层的信息(因此我不必手动计算非常深的网络的权重维度)

我还想了解可接受尺寸的计算,例如这些计算的推理是什么?

由于某种原因,这些计算有效并且 Pytorch 接受了这些尺寸:

val = int((32*32)/4)
self.fc1 = nn.Linear(val, 200)

这也有效

self.fc1 = nn.Linear(64*4*4, 200)

为什么这些值有效?这些方法的计算是否有限制?例如,我觉得如果我要改变步幅距离或内核大小,这就会中断。

这是我正在使用的一般模型架构:

# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # convolutional layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)  


        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,kernel_size=3)
        self.pool2 = nn.MaxPool2d(2,2)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2,2)
        
        self.dropout = nn.Dropout(0.25)

        # H*W/4
        val = int((32*32)/4)
        #self.fc1 = nn.Linear(64*4*4, 200)
        ################################################
        self.fc1 = nn.Linear(val, 200)  # dimensions of the layer I wish to calculate
        ###############################################
        self.fc2 = nn.Linear(200,100)
        self.fc3 = nn.Linear(100,10)


    def forward(self, x):
        # add sequence of convolutional and max pooling layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        #print(x.shape)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

# create a complete CNN
model = Net()
print(model)

谁能告诉我如何计算第一个线性层的尺寸并解释其推理?


给定输入空间维度 w,2d 卷积层将在该维度上输出具有以下大小的张量:

int((w + 2*p - d*(k - 1) - 1)/s + 1)

完全相同的情况也适用于nn.MaxPool2d https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html。作为参考,您可以在此处查找PyTorch 文档 https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.

模型的卷积部分由三个 (Conv2d + MaxPool2d) 块组成。您可以使用此辅助函数轻松推断输出的空间维度大小:

def conv_shape(x, k=1, p=0, s=1, d=1):
    return int((x + 2*p - d*(k - 1) - 1)/s + 1)

递归调用它,您将得到最终的空间维度:

>>> w = conv_shape(conv_shape(32, k=3, p=1), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)

>>> w
2

由于您的卷积具有平方内核和相同的步幅、填充(水平等于垂直),因此上述计算适用于张量的宽度和高度尺寸。最后看最后一个卷积层conv3,它有 64 个过滤器,在全连接层之前每批元素的最终元素数量为:w*w*64, i.e. 256.


但是,没有什么可以阻止您调用图层来找出输出形状!

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten())

        n_channels = self.feature_extractor(torch.empty(1, 3, 32, 32)).size(-1)

        self.classifier = nn.Sequential(
            nn.Linear(n_channels, 200),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(100, 10))

    def forward(self, x):
        features = self.feature_extractor(x)
        out = self.classifier(features)
        return out

model = Net()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何计算 CNN 第一个线性层的维度 的相关文章

随机推荐

  • curl:(7)无法连接到192.168.99.100端口31591:连接被拒绝

    这些是我的豆荚 hello kubernetes 5569fb7d8f 4rkhs 0 1 ImagePullBackOff 0 5d2h hello minikube 5857d96c67 44kfg 1 1 Running 1 5d2h
  • “char *_EXFUN(index,(const char *, int));”的含义

    我发现这是 eclipse idexer intelisence 的一个命题 无论它叫什么 就是这样 char EXFUN index const char int 首先 它看起来像一个返回 char 指针的函数 但参数 如果它是一个函数
  • 使用 M1 在 dockerized Linux 上安装节点画布

    我有以下Dockerfile我在 MacBook Air M1 上运行 所以在 docker 中我有带有 M1 的 linux FROM node 16 7 0 WORKDIR work CMD while true do sleep 10
  • Spirit qi 解析为嵌套函数的抽象语法树

    我正在尝试使用 boost 的spirit qi 解析器创建一个解析器 它正在解析包含三种类型值的字符串 常量 变量或函数 这些函数可以相互嵌套 测试字符串是f a b f g z x g x h x c where a e是常数 f r是
  • 如何获得修改任何参数的函数?

    我的目标 我必须创建一个将两个分数相加的函数 我定义了一个新的struct typedef 称为fraction 该函数不能有返回类型fraction 它一定要是void 因此它必须修改输入的参数之一 我该如何实现这一点 也许是指点 您将如
  • 从状态栏中删除通知图标

    我在状态栏中显示一个图标 现在我想在打开该内容时立即删除该图标 一段时间后如果我们收到任何警报 该图标将再次显示 我怎样才能做到这一点 使用NotificationManager取消您的通知 您只需提供您的通知 ID https devel
  • 将一个表的所有行复制到另一个表

    我有两个数据库MySQL and SQL Server 我想在其中创建表SQL Server并复制表中的所有行MySQL到新表中SQL Server 我可以在中创建表SQL Server与 一样MySQL 使用以下代码 List
  • 自 2012 年以来,WinSock 注册 IO 性能是否有所下降?

    我最近使用 MS 为该 API 提供的稍微可接受的文档编写了基于 WinSock Registered IO RIO 的 UDP 接收 最终的性能非常令人失望 单套接字性能有些稳定 约为每秒 180k 数据包 使用多个 RSS 队列 即多个
  • 选择从查询中检索列名称的列

    我正在寻找一种优雅的方法来从表 A 中选择列 其中列名是从表 B 上的查询中检索的 对表 B 的查询结果 col01 表 A 有几个名为 col01 col02 col03 最终查询应该是为了结果 result from B effecti
  • 根据区域设置获取货币 ISO 4217 代码

    假设我用以下命令解析 HTTP Accept Language 标头Locale acceptFromHttp http www php net manual en locale acceptfromhttp php是否有一种简单可靠的方法
  • Java 中的字符串拆分:可变长度的前向和后向

    我想使用数字作为分隔符来破坏 Java 中的字符串 但保留数字 一些研究表明 使用 String 中的 split method 是合适的 但我不明白如何做到这一点 为了进一步解释我的问题 我将使用一个例子 Input 20 55 50 0
  • 使用 VBA 从分布生成随机数到内存

    我想从 VBA Excel 2007 中选定的分布生成随机数 我目前正在使用带有以下代码的分析工具库 Application Run ATPVBAEN XLAM Random A B C D E F Where A how many var
  • 如何在 POSIXct 中获取一天的开始

    我的一天开始于2016 03 02 00 00 00 Not 2016 03 02 00 00 01 我如何开始一天的工作POSIXct当地时间 我的困惑可能来自于 R 认为这是 2016 03 01 的结束日期这一事实 鉴于 R 使用 I
  • 如何减少基于位置的 Android 应用程序的功耗?

    如何减少应用程序的功耗 我可以使用什么代码来实现这个 有几种不同的方法可以减少尝试获取位置信息时所用的电量 Use the 最后已知位置 http developer android com reference android locati
  • HtmlAgilityPack 设置节点 InnerText

    我想用其他文本替换 HTML 标签的内部文本 我正在使用 HtmlAgilityPack我使用这段代码来提取所有文本 HtmlDocument doc new HtmlDocument doc Load some path foreach
  • Spring MVC:在表单处理操作中有多个@ModelAttribute

    上下文 我在两个实体之间有一个简单的关联 Category and Email NtoM 我正在尝试创建用于浏览和管理它们的网络界面 要浏览类别并将电子邮件添加到该类别中 我使用包含以下内容的控制器 RequestMapping带有类别 I
  • 使用线程或异步任务的位图工厂动画

    这个问题是我在这个论坛上提出的多个问题的后续问题 这些问题涉及为什么我一直在尝试的动画不起作用 简单回答一下之前的问题 我的动画作为 2 个班级的单独项目工作 但无法工作 当包含在我的包含多个类的项目中时 使用 finish 类关闭了导致我
  • 如何在 Visual Studio Code 中的事件上使用 JSDoc 自定义 EventEmitter?

    我一直致力于 Node js 项目 只是注意到 Visual Studio Code 提供了有关基本 EventEmitter 对象的信息 所以我想也应该可以为自定义提供 JSDoc 我已经尝试遵循 JSDochttp usejsdoc o
  • sql查询使用pivot动态添加会计月份

    ALTER PROCEDURE dbo sp GetDMActivityTrackerReport CoachId VARCHAR 7 Month INT FiscalYear INT AS BEGIN INSERT FiscalMonth
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出