Pytorch:获取最终层的正确尺寸

2024-02-28

Pytorch 新手来了!我正在尝试微调 VGG16 模型来预测 3 个不同的类别。我的部分工作涉及将 FC 层转换为 CONV 层。但是,我的预测值不会落在 0 到 2(3 个类别)之间。

有人可以向我指出有关如何计算最后一层的正确尺寸的好资源吗?

以下是 VGG16 的原始 fC 层:

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

我将 FC 层转换为 CONV 的代码:

 def convert_fc_to_conv(self, fc_layers):
        # Replace first FC layer with CONV layer
        fc = fc_layers[0].state_dict()
        in_ch = 512
        out_ch = fc["weight"].size(0)
        first_conv = nn.Conv2d(512, out_ch, kernel_size=(1, 1), stride=(1, 1))

        conv_list = [first_conv]
        for idx, layer in enumerate(fc_layers[1:]):
            if isinstance(layer, nn.Linear):
                fc = layer.state_dict()
                in_ch = fc["weight"].size(1)
                out_ch = fc["weight"].size(0)
                if idx == len(fc_layers)-4:
                    in_ch = 3
                conv = nn.Conv2d(out_ch, in_ch, kernel_size=(1, 1), stride=(1, 1))
                conv_list += [conv]
            else:
                conv_list += [layer]
            gc.collect()

        avg_pool = nn.AvgPool2d(kernel_size=2, stride=1, ceil_mode=False)
        conv_list += [avg_pool, nn.Softmax()]
        top_layers = nn.Sequential(*conv_list)  
        return top_layers

最终模型架构:

    Model(
    (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))

    (classifier): Sequential(
    (0): Conv2d(512, 4096, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Conv2d(4096, 3, kernel_size=(1, 1), stride=(1, 1))
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): AvgPool2d(kernel_size=2, stride=1, padding=0)
    (7): Softmax()
  )
)

模型总结:

            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Conv2d-32           [-1, 4096, 7, 7]       2,101,248
             ReLU-33           [-1, 4096, 7, 7]               0
          Dropout-34           [-1, 4096, 7, 7]               0
           Conv2d-35              [-1, 3, 7, 7]          12,291
             ReLU-36              [-1, 3, 7, 7]               0
          Dropout-37              [-1, 3, 7, 7]               0
        AvgPool2d-38              [-1, 3, 6, 6]               0
          Softmax-39              [-1, 3, 6, 6]               0

我编写了一个函数,以 Pytorch 模型作为输入,并将分类层转换为卷积层。目前它适用于 VGG 和 Alexnet,但您也可以将其扩展到其他模型。

import torch
import torch.nn as nn
from torchvision.models import alexnet, vgg16

def convolutionize(model, num_classes, input_size=(3, 224, 224)):
    '''Converts the classification layers of VGG & Alexnet to convolutions

    Input:
        model: torch.models
        num_classes: number of output classes
        input_size: size of input tensor to the model

    Returns:
        model: converted model with convolutions
    '''
    features = model.features
    classifier = model.classifier

    # create a dummy input tensor and add a dim for batch-size
    x = torch.zeros(input_size).unsqueeze_(dim=0)

    # change the last layer output to the num_classes
    classifier[-1] = nn.Linear(in_features=classifier[-1].in_features,
                               out_features=num_classes)

    # pass the dummy input tensor through the features layer to compute the output size
    for layer in features:
        x = layer(x)

    conv_classifier = []
    for layer in classifier:
        if isinstance(layer, nn.Linear):
            # create a convolution equivalent of linear layer
            conv_layer = nn.Conv2d(in_channels=x.size(1),
                                   out_channels=layer.weight.size(0),
                                   kernel_size=(x.size(2), x.size(3)))

            # transfer the weights
            conv_layer.weight.data.view(-1).copy_(layer.weight.data.view(-1))
            conv_layer.bias.data.view(-1).copy_(layer.bias.data.view(-1))
            layer = conv_layer

        x = layer(x)
        conv_classifier.append(layer)

    # replace the model.classifier with newly created convolution layers
    model.classifier = nn.Sequential(*conv_classifier)

    return model

def visualize(model, input_size=(3, 224, 224)):
    '''Visualize the input size though the layers of the model'''
    x = torch.zeros(input_size).unsqueeze_(dim=0)
    print(x.size())
    for layer in list(model.features) + list(model.classifier):
        x = layer(x)
        print(x.size())

这是输入通过模型时的样子

_vgg = vgg16()
vgg = convolutionize(_vgg, 100)
print('\n\nVGG')
visualize(vgg)

...

VGG
torch.Size([1, 3, 224, 224])
torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 112, 112])
torch.Size([1, 128, 112, 112])
torch.Size([1, 128, 112, 112])
torch.Size([1, 128, 112, 112])
torch.Size([1, 128, 112, 112])
torch.Size([1, 128, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 7, 7])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 4096, 1, 1])
torch.Size([1, 100, 1, 1])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch:获取最终层的正确尺寸 的相关文章

随机推荐

  • 通过更短的拖动使 ViewPager 对齐

    有什么办法可以让支持包ViewPager用更短的拖动来捕捉到下一页吗 默认行为似乎是 即使我拖动近 75 当我放开时 页面仍然会弹回到上一页 我想缩短捕捉阈值并使 ViewPager 捕捉到下一页 请注意 这适用于拖动手势 猛击手势已经需要
  • 管理频繁数据库轮询的良好 C#.NET 解决方案

    我目前正在开发一个 c NET 桌面应用程序 该应用程序将通过 WCF 和 WCF 数据服务通过互联网与数据库进行通信 应用程序中有许多地方可能需要每隔一段时间刷新一次 最简单的解决方案是将这些区域放在计时器上并重新查询数据库 然而 由于有
  • 为另一个分区/目录运行 apt-get?

    我已经从 Live Ubuntu CD 启动了系统 并且需要修复一些软件包问题 我已经安装了硬盘 现在我想像正常启动一样运行 apt get 即更改 apt get 的工作目录 以便它可以在我的硬盘上工作 我以前做过这个 但我不记得语法了
  • 如何在 CURL 重定向上传递 cookie?

    想象以下场景 我打开一个 CURL 连接并通过 POST 传递一些 XML Logindata 服务器以 302 重定向进行响应 其中设置会话 cookie 并将我重定向到以下 欢迎 页面 如果我启用 FOLLOWLOCATION 则重定向
  • TypeScript:从装饰器推断返回类型?

    当装饰器更改其返回类型时 如何让 TypeScript 推断装饰方法的类型 在下面的基本示例中 我装饰一个方法以返回字符串化对象 function jsonStringify return function target decorated
  • AutoMapper 展平相同类型的复杂对象

    我在映射以下复杂类型时遇到问题 RequestDTO int OldUserId string OldUsername int NewUserId string NewUsername Request User OldUser User N
  • React内联样式中的CSS伪代码“li::before”

    我有一个名为 ExplanationLists 的 React 组件 我想将动态内联样式添加到li带有 css 伪代码的 html 元素li after 这样我可以更好地用图形来设计要点 例如 li before content dynam
  • 如何使用具有多个输入参数的 HttpGet 属性? (并大摇大摆地工作)

    它与下面的代码配合得很好 我只有一个参数 但如何处理两个输入参数 如果我只使用 HttpGet 则不会发送任何参数 尽管它在 Swagger 之外工作正常 帮助 HttpGet Consumes application json HttpG
  • 使用 angular2 进行服务器端渲染是什么?

    我知道 angular2 用于服务器端渲染 所以我想了解更多 我对这种现象有以下疑问 1 什么是服务端渲染 2 它解决什么问题 3 它的应用有哪些 4 为什么使用服务端渲染 5 支持服务端渲染的技术有哪些 6 在Angular2中 服务器端
  • 如何在 git 上恢复旧提交中的特定文件

    我想我的问题很接近这个one https stackoverflow com questions 20971306 hg how do i revert a single file several commits back 但我正在使用 g
  • 使用 SQL Developer 或 Toad 等 IDE 工具的 Oracle 并行查询行为

    一段时间以来 我一直在努力抽出时间来写这个问题并尽可能地解释这个问题 所以请提前原谅我的长文 我的环境 Oracle Database 12 2 在 Red Hat 7 R A C 2 个节点 上运行 每个节点 16CPU 和 64GB R
  • 在一个 SELECT 语句中设置两个标量变量?

    我想做这个 Declare a int Declare b int SET a b SELECT StartNum EndNum FROM Users Where UserId 1223 PRINT a PRINT b 但这是无效的语法 如
  • 如何在 Gatsby URL 中添加发布日期?

    All the Gatsby 入门演示 https github com gatsbyjs gatsby gatsby starters有一条像这样的路径 gatsby starter blog hi folks 我该如何设置 2015 0
  • Cron 作业 + Twitter

    从 12 30 开始 一直到 1 30 2 30 等 我的应用程序每小时都会发布一条静态推文 我目前正在使用 themattharris 的 twitter API 我也有一个 cron 工作 30 php q home1 USER NAM
  • PyQt5 和 datetime.datetime.strptime 之间的冲突

    所以我正在编写一个工具 可以使用基于 python 3 52 和 Qt5 的图形用户界面从文件中读取时间 最少的操作 datetime datetime strptime Tue a 在隔离环境中工作 输出 1900 01 01 00 00
  • 在 php 5.5 中使用什么来代替 apc 用户数据缓存?

    PHP 5 5 默认包含 zend opcache 这基本上意味着几乎没有人会使用 APC 但是用什么来代替 APC 的用户数据缓存部分 apc store apc fetch 类似 呢 我真正喜欢使用 APC 用户数据缓存的一个用例是静态
  • 如何在vba中另存为.txt

    我希望让我的宏将我创建的新工作表保存为 txt 文件 这是我到目前为止的代码 Sub Move Move Macro Keyboard Shortcut Ctrl m Sheets Sheet1 Select Range A1 Select
  • 绘制完成后清除CGPath路径

    我已经在 iOS 中编写了一个在 TouchMoved 方法中绘图的程序 CGContextAddPath UIGraphicsGetCurrentContext path CGPathMoveToPoint path NULL lastP
  • OpenCV - Java:inRange 函数

    我有我的形象mRgba当我这样做时 Core inRange mRgba B1 B2 mRgba 我得到了我期望的结果 我的所有 RGBA 图像的阈值都在 B1 和 B2 之间 现在我想这样做 Mat roi mRgba submat re
  • Pytorch:获取最终层的正确尺寸

    Pytorch 新手来了 我正在尝试微调 VGG16 模型来预测 3 个不同的类别 我的部分工作涉及将 FC 层转换为 CONV 层 但是 我的预测值不会落在 0 到 2 3 个类别 之间 有人可以向我指出有关如何计算最后一层的正确尺寸的好