Deep Learning (20): Using ConvNeXt

2023-11-05


This post uses ConvNeXt for a classification task, specifically ConvNeXt-Tiny, which is made up of five main parts (stage0, stage1, stage2, stage3, and the classification head).

Part 1 Model

model.py

# -*- coding: utf-8 -*-
"""
original code from facebook research:
https://github.com/facebookresearch/ConvNeXt
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise ValueError(f"not support data format '{self.data_format}'")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # [batch_size, channels, height, width]
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_rate (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_last")
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

        x = shortcut + self.drop_path(x)
        return x


class MiniConvNext(nn.Module):
    r""" ConvNeXt
            A PyTorch impl of : `A ConvNet for the 2020s`  -
              https://arxiv.org/pdf/2201.03545.pdf
        Args:
            in_chans (int): Number of input image channels. Default: 3
            num_classes (int): Number of classes for classification head. Default: 1000
            depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
            dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
            drop_path_rate (float): Stochastic depth rate. Default: 0.
            layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
            head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        """

    def __init__(self, in_chans: int = 3, num_classes: int = 1000, depths: list = None,
                 dims: list = None, drop_path_rate: float = 0., layer_scale_init_value: float = 1e-6,
                 head_init_scale: float = 1.):
        super().__init__()
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
        self.downsample_layers.append(stem)

        # the 3 downsampling layers placed before stage2-stage4
        for i in range(3):
            downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                                             nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        # build the blocks stacked in each stage
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_rate=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value)
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)  # std 0.02 as in the original ConvNeXt code
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        d_0 = self.downsample_layers[0](x)
        x_0 = self.stages[0](d_0)

        d_1 = self.downsample_layers[1](x_0)
        x_1 = self.stages[1](d_1)

        d_2 = self.downsample_layers[2](x_1)
        x_2 = self.stages[2](d_2)

        d_3 = self.downsample_layers[3](x_2)
        x_3 = self.stages[3](d_3)

        return x_3  # feature map from the last stage, shape [N, dims[3], H/32, W/32]


class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self, in_chans: int = 3, num_classes: int = 1000, depths: list = None,
                 dims: list = None, drop_path_rate: float = 0., layer_scale_init_value: float = 1e-6,
                 head_init_scale: float = 1.):
        super().__init__()
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
        self.downsample_layers.append(stem)

        # the 3 downsampling layers placed before stage2-stage4
        for i in range(3):
            downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                                             nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        # build the blocks stacked in each stage
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_rate=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value)
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)  # std 0.02 as in the original ConvNeXt code
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

        return self.norm(x.mean([-2, -1])), x  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, x_original = self.forward_features(x)
        x = self.head(x)
        return x


def convnext_tiny(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 9, 3],
                     dims=[96, 192, 384, 768],
                     num_classes=num_classes)
    return model


def convnext_small(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[96, 192, 384, 768],
                     num_classes=num_classes)
    return model


def convnext_base(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[128, 256, 512, 1024],
                     num_classes=num_classes)
    return model


def convnext_large(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[192, 384, 768, 1536],
                     num_classes=num_classes)
    return model


def convnext_xlarge(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[256, 512, 1024, 2048],
                     num_classes=num_classes)
    return model
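
A quick way to check the structure described above is to instantiate the tiny variant and push a dummy batch through it. The following is a minimal sketch (not part of the original post); it assumes model.py sits in the current directory and uses a 224x224 input.

# smoke_test.py -- sanity check for the model defined above
import torch
from model import convnext_tiny

model = convnext_tiny(num_classes=5)
# the four stages plus the classification head are grouped into these top-level modules
print([name for name, _ in model.named_children()])  # ['downsample_layers', 'stages', 'norm', 'head']

x = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 5])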

Part 2 Utils

utils.py

# -*- coding: utf-8 -*-
"""
Created on Fri Sep  2 15:25:33 2022

@author: Lenovo
"""

import sys
import json
import pickle
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
import os
from PIL import Image
from torchvision import transforms
import random

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.BCEWithLogitsLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.sigmoid(pred).gt(0.5).int()
        accu_num += torch.eq(pred_classes.squeeze(1), labels.to(device)).sum()

        labels = labels.float()
        loss = loss_function(pred, labels.unsqueeze(-1).to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}, lr: {:.5f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num,
            optimizer.param_groups[0]["lr"]
        )

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.BCEWithLogitsLoss()
    model.eval()
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.sigmoid(pred).gt(0.5).int()
        accu_num += torch.eq(pred_classes.squeeze(1), labels.to(device)).sum()

        labels = labels.float()
        loss = loss_function(pred, labels.unsqueeze(-1).to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num
        )

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num

def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3,
                        end_factor=1e-6):
    assert num_step > 0 and epochs > 0
    if warmup is False:
        warmup_epochs = 0

    def f(x):
        """
        Return a learning-rate multiplier for the given step.
        Note that PyTorch calls lr_scheduler.step() once before training starts.
        """
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            # during warmup the lr factor ramps from warmup_factor -> 1
            return warmup_factor * (1 - alpha) + alpha
        else:
            current_step = (x - warmup_epochs * num_step)
            cosine_steps = (epochs - warmup_epochs) * num_step
            # after warmup the lr factor decays from 1 -> end_factor following a cosine schedule
            return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)


def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-5):
    # parameters the optimizer will train, split into decay / no-decay groups
    parameter_group_vars = {"decay": {"params": [], "weight_decay": weight_decay},
                            "no_decay": {"params": [], "weight_decay": 0.}}

    # the corresponding parameter names (used only for logging)
    parameter_group_names = {"decay": {"params": [], "weight_decay": weight_decay},
                             "no_decay": {"params": [], "weight_decay": 0.}}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        if len(param.shape) == 1 or name.endswith(".bias"):
            group_name = "no_decay"
        else:
            group_name = "decay"

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
    return list(parameter_group_vars.values())

def load_img(img_path, data_transform):
    img = Image.open(img_path)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)
    return img
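
One thing to keep in mind: train_one_epoch and evaluate above are written for a single-logit binary head (BCEWithLogitsLoss plus a 0.5 sigmoid threshold), while the training script in Part 3 builds the model with num_classes = 5. If you train on more than two classes you need a CrossEntropy variant; the sketch below is one hypothetical way to adapt the training loop (it reuses the sys/tqdm/torch imports already at the top of utils.py) and is not part of the original post.

def train_one_epoch_multiclass(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))        # [N, num_classes] logits
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num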
    

Part 3 Training

  • ConvNeXt comes in several sizes; for ordinary classification tasks the tiny variant is usually sufficient.
  • You can write your own dataset/dataloader, but for classification datasets.ImageFolder is more convenient, provided the images are organized in the folder layout sketched below.
  • ConvNeXt-Tiny consists of five main parts (in order: stage0, stage1, stage2, stage3, and the head). With freeze_layers, the first training run freezes everything except the head, and later runs unfreeze the stages one by one (see the helper sketched after train.py).
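
For reference, the folder layout datasets.ImageFolder expects looks roughly like this (the grade0-grade4 class folders match the json mapping written by train.py; file names are hypothetical). Each subfolder name becomes one class, and ImageFolder assigns class indices by sorting the folder names alphabetically.

five_class/
├── train/
│   ├── grade0/  xxx.png, xxy.png, ...
│   ├── grade1/  ...
│   ├── grade2/  ...
│   ├── grade3/  ...
│   └── grade4/  ...
└── val/
    ├── grade0/  ...
    └── ...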

train.py

# -*- coding: utf-8 -*-
"""
Created on Fri Sep  2 15:25:18 2022

@author: Lenovo
"""

import os
import json
import torch
import torch.optim as optim
from pandas.core.frame import DataFrame
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from model import convnext_tiny as create_model  # import the tiny variant
from utils import get_params_groups, train_one_epoch, evaluate, setup_seed, create_lr_scheduler

log = []
import time

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"using {device} device.")
    start_time = time.time()
    start_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime(start_time))

    image_path = '/data/home/yangjy/data/five_class/'  # data root; folder layout as described above
    # pretrained_weight = '/home/yangjy/projects/Jane_TF_classification/convnext/weights/convnext_tiny_1k_224_ema.pth'  # ImageNet-pretrained ConvNeXt weights, used as the initial weights on the first run
    pretrained_weight = '/home/yangjy/projects/Jane_TF_classification/convnext/weights/2022-12-12-00-29.pth'  # weights from one of our own previous runs, used as the initial weights here
    weight_path = f"/home/yangjy/projects/Jane_TF_classification/convnext/weights/{start_time}.pth"  # where the trained weights are saved
    csv_path = f'/home/yangjy/projects/Jane_TF_classification/convnext/results/train/{start_time}.csv'  # csv file for logging loss and acc

    num_classes = 5  # number of classes; sets the size of the final fully connected layer

    batch_size = 64
    freeze_layers = False  # freeze part of the model, or train everything

    learning_rate = 1e-4
    weight_decay = 1e-3
    # learning-rate decay
    step_size = 20
    gamma = 0.95
    # early stopping
    early_stop_step = 30
    epochs = 200
    best_acc = 0.5

    data_transform = {
        "train": transforms.Compose([transforms.Resize([224, 224]),
                                     transforms.RandomHorizontalFlip(p=0.5),
                                     transforms.RandomVerticalFlip(p=0.5),
                                     transforms.ColorJitter(0.1, 0.1, 0.1),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize([224, 224]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    # build the training dataset with ImageFolder
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)  # make sure the image path exists
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # write the class-to-index mapping to a json file, e.g.
    #    {"0": "grade0", "1": "grade1", "2": "grade2", "3": "grade3", "4": "grade4"}
    grade_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in grade_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=5)
    with open('./class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    # wrap the datasets in DataLoaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)
    # load the validation dataset
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model = create_model(num_classes=num_classes).to(device)
    # A separate flag would be cleaner, but pretrained_weight doubles as the check here:
    # if it points to the ImageNet checkpoint, this is the first run and it is used as the initial weights
    if pretrained_weight == '/home/yangjy/projects/Jane_TF_classification/convnext/weights/convnext_tiny_1k_224_ema.pth':
        assert os.path.exists(pretrained_weight), "weights file: '{}' not exist.".format(pretrained_weight)
        weights_dict = torch.load(pretrained_weight, map_location=device)["model"]
        # drop the classification-head weights, since the number of classes differs
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))
        print("Loaded convnext pretrained in ImageNet!")

    elif os.path.exists(pretrained_weight):
        model.load_state_dict(torch.load(pretrained_weight, map_location=device))
        print("Loaded weight pretrained in our data!")
    else:
        print("SORRY!   No pretrained weight!!")
	
    if freeze_layers:
        for name, para in model.named_parameters():
            # on the first run freeze everything except the head; later runs unfreeze stage3 down to stage0 one at a time
            if ("head" not in name) and ("stages.3" not in name) and ("stages.2" not in name) and ("stages.1" not in name) and ("stages.0" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=weight_decay)
    optimizer = optim.AdamW(pg, lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)  # multiply the lr by gamma every step_size epochs
    #    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), epochs,warmup=True, warmup_epochs=10)

    # early-stopping bookkeeping
    total_batch = 0
    last_decrease = 0
    min_loss = 1000
    flag = False

    for epoch in range(epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch, )
        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=validate_loader,
                                     device=device,
                                     epoch=epoch)
        lr_scheduler.step()

        if val_acc > best_acc:  # save the weights whenever val accuracy improves
            best_acc = val_acc
            torch.save(model.state_dict(), weight_path)

        if val_loss < min_loss:  # record the epoch at which val loss last decreased
            min_loss = val_loss

            last_decrease = total_batch
            print((min_loss, last_decrease))
        total_batch += 1

        if total_batch - last_decrease > early_stop_step:
            print("No optimization for a long time, auto-stopping...")
            flag = True
            break
        log.append([epoch, train_loss, val_loss, train_acc, val_acc])
    print('Finished Training')
    data = DataFrame(data=log, columns=['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc'])
    data.to_csv(csv_path)


if __name__ == '__main__':
    main()
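
The freeze block above switches layers on and off by editing the long condition by hand. A small helper like the hypothetical sketch below (not part of the original script) expresses the same stage-by-stage unfreezing more directly; in the workflow described here, training is restarted from the previously saved weights after each unfreezing step.

def set_trainable(model, prefixes=("head",)):
    # freeze every parameter whose name does not start with one of the given prefixes
    for name, para in model.named_parameters():
        para.requires_grad_(any(name.startswith(p) for p in prefixes))
        if para.requires_grad:
            print("training {}".format(name))

# first run: head only
# set_trainable(model, ("head",))
# second run: head + last stage
# set_trainable(model, ("head", "stages.3"))
# third run: head + last two stages, and so on
# set_trainable(model, ("head", "stages.3", "stages.2"))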

Note: ConvNeXt needs a relatively large learning rate during training; it is best not to go below 1e-5. I usually start at 1e-3 and switch to 1e-4 once everything is unfrozen. Its weight decay is also not as small as what is typically used for ResNet: ConvNeXt's weight decay usually falls between 5e-2 and 1e-4, and I generally use 1e-3.
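
As a concrete illustration of that note, the hyperparameters inside main() could be set as in the hedged sketch below for the first (mostly frozen) run. The warmup scheduler from utils is also shown, since it is only commented out in train.py; it computes its factor from the global step, so it would be stepped once per batch rather than once per epoch. This is a sketch of one possible configuration, not the exact recipe used in the post.

    learning_rate = 1e-3   # initial run; drop to 1e-4 after everything is unfrozen
    weight_decay = 1e-3    # ConvNeXt tolerates a much larger weight decay than ResNet

    pg = get_params_groups(model, weight_decay=weight_decay)
    optimizer = optim.AdamW(pg, lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), epochs,
                                       warmup=True, warmup_epochs=10)
    # then call lr_scheduler.step() after every optimizer.step() inside the training loop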

Part 4 Predict

# -*- coding: utf-8 -*-
import json
from PIL import Image
from torchvision import transforms
from model import convnext_tiny as create_model
from pandas.core.frame import DataFrame
from utils import setup_seed
import torch
import os
log = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
    [transforms.Resize([224, 224]),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

num_classes = 5
grade = 'grade4'
img_dir = f'/home/yangjy/data/five_class/val/{grade}'  # directory of images to predict
pretrained_weight_path = '/home/yangjy/projects/Jane_TF_classification/convnext/weights/2022-12-12-00-29.pth'  # trained weights
save_predict_path = f'/home/yangjy/projects/Jane_TF_classification/convnext/results/predict/{grade}.csv'  # where to save the prediction results

# read class_indict
json_path = '/home/yangjy/projects/Jane_TF_classification/convnext/code/class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
with open(json_path, "r") as f:
    class_indict = json.load(f)

model = create_model(num_classes=num_classes).to(device)
model.load_state_dict(torch.load(pretrained_weight_path,map_location=device))
color_list = os.listdir(img_dir)
for picture in color_list:
    img_path = os.path.join(img_dir, picture)
    try:
        img = Image.open(img_path)
        img = data_transform(img)
        img = torch.unsqueeze(img, dim=0)

        model.eval()
        with torch.no_grad():
            # predict class
            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = torch.argmax(predict).numpy()
        res = [picture, class_indict[str(predict_cla)], grade]
        log.append(res)


    except Exception as e:
        print(e)
        continue

data = DataFrame(data=log, columns=['pic_name', 'predict_result', 'label'])
data.to_csv(save_predict_path)
print('Finished Predicting')
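
Since the script writes one CSV per class folder, per-class and overall accuracy can be computed afterwards with a few lines of pandas. The following is a hypothetical post-processing sketch, reusing the paths and column names from the script above.

import pandas as pd

frames = []
for grade in ['grade0', 'grade1', 'grade2', 'grade3', 'grade4']:
    csv_path = f'/home/yangjy/projects/Jane_TF_classification/convnext/results/predict/{grade}.csv'
    frames.append(pd.read_csv(csv_path, index_col=0))

results = pd.concat(frames, ignore_index=True)
results['correct'] = results['predict_result'] == results['label']
print(results.groupby('label')['correct'].mean())  # per-class accuracy
print(results['correct'].mean())                   # overall accuracy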



