



















整个量化算法采用对称量化:将 (-max, max) 映射到 (-127, 127)。


计算最佳阈值的方法:1. 统计激活值的直方图,2. 采用遍历的方法找到量化后KL散度最小时对应的最佳阈值。伪代码如下:








量化时不区分该层后面是否接有 ReLU,统一将 (-max, max) 量化到 (-127, 127)。这样简化了运算,无需再分情况处理。(格灵深瞳的算法只考虑了带 ReLU 的情况,即将 (0, max) 量化到 (0, 127)。)



from torch import nn
import torch

# Identity placeholder that takes the slot of a BatchNorm2d layer once it is folded.
class DummyModule(nn.Module):
    """Identity layer: ``forward`` hands back its input object unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

# BN flod
def bn_folding(conv, bn):
    # ******************** BN parameter *********************
    mean = bn.running_mean
    std = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight
    beta = bn.bias
    # ******************* conv parameter********************
    w = conv.weight
    w_fold = w.clone()
    if conv.bias is not None:
        b = conv.bias
        b = mean.new_zeros(mean.shape)
    b_fold = b.clone()
    w_fold = w * (gamma / std).reshape([conv.out_channels, 1, 1, 1])
    b_fold = beta + (b - mean) * (gamma / std) 
    bnfold_conv = nn.Conv2d(conv.in_channels,
    bnfold_conv.weight.data = w_fold
    bnfold_conv.bias.data = b_fold
    return bnfold_conv

'''BN must be after convolution'''
def model_bn_folding(model):
    children = list(model.named_children())
    # children = list(model.named_modules())
    name_temp = None
    child_temp = None
    for name, child in children:
        #print(name, '   ', child)
        if isinstance(child, nn.BatchNorm2d):
            bnfold_conv = bn_folding(child_temp, child) # BN融合
            model._modules[name_temp] = bnfold_conv
            model._modules[name] = DummyModule()
            child_temp = None
        elif isinstance(child, nn.Conv2d):
            name_temp = name
            child_temp = child
    return model



import torch
from torch import nn

import torch.nn.functional as F
from quant_utils import ConvRelu, LinearRelu, DummyModule

# Run on the first visible CUDA device when one exists, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

'''BN must be after convolution'''
def model_relu_folding(model):
    children = list(model.named_children())
    # children = list(model.named_modules())
    name_temp = None
    child_temp = None
    is_conv = True
    for name, child in children:
        print(name, '   ', child)
        if isinstance(child, nn.ReLU):
            if is_conv:
                model._modules[name_temp] = ConvRelu(child_temp, is_relu=1).to(device)
                model._modules[name_temp] = LinearRelu(child_temp, is_relu=1).to(device)
            model._modules[name] = DummyModule().to(device)
            # child_temp = None
            # name_temp = None
        elif isinstance(child, nn.Conv2d):
            name_temp = name
            child_temp = child               
            model._modules[name] = ConvRelu(child, is_relu=0).to(device)            
            is_conv = True
        elif isinstance(child, nn.Linear):
            name_temp = name
            child_temp = child            
            model._modules[name] = LinearRelu(child, is_relu=0).to(device)
            is_conv = False
    return model







from torch import nn
import torch
import torch.nn.functional as F
import copy
from collections import OrderedDict
import numpy as np

QUANTIZE_NUM = 127    # 7bit

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Identity placeholder that takes the slot of a ReLU layer once it is fused away.
class DummyModule(nn.Module):
    """Identity layer: ``forward`` hands back its input object unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
# the module that replace conv layer
class ConvRelu(nn.Module):
    def __init__(self, conv, is_relu=0, bits=8, threshold=204800):
        super(ConvRelu, self).__init__()
        #self.conv_relu_fold = conv
        self.threshold = threshold
        self.bits = bits
        self.is_relu = is_relu
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        self.groups = conv.groups
        self.bias = conv.bias
        self.weight = conv.weight
        '''mode : Normal, TRT_weight_quant, TRT_activate_collection_max, TRT_activate_collection_hist, TRT_activate_KL, Normal_TRT'''
        self.mode = 'TRT_weight_quant'         
        #self.register_buffer('is_relu', torch.tensor(is_relu))
        self.register_buffer('quant_num', torch.tensor((1 << bits) - 1))
        self.register_buffer('activate_flag', torch.zeros(1))
        self.register_buffer('activate_distubution', torch.zeros(INTERVAL_NUM))
        self.register_buffer('activate_distubution_edges', torch.zeros(INTERVAL_NUM+1))
        self.register_buffer('activate_max', torch.zeros(1))
        self.register_buffer('th', torch.zeros(1))
        self.register_buffer('optimal_th', torch.zeros(1))
        # self.register_buffer('activate_distubution_interval', torch.zeros(1))
        self.register_buffer('weight_flag', torch.zeros(1))
        self.register_buffer('weight_scale', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_zero', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_max', torch.zeros(conv.weight.data.shape[0]))
    def initial_activate_max(self, input):
        max_val = torch.max(input)
        min_val = torch.min(input)
        self.activate_max = torch.max(self.activate_max, torch.max(torch.abs(max_val), torch.abs(min_val)))
        # Avoid unusually large activation by clip blob_max with threshold
        self.th= min(self.activate_max, self.threshold)
        # print('test: ', self.th)
    def weight_quant(self):
        '''Avoid multiple operations caused by multiple identification of the module'''
        self.weight_flag = torch.ones(1).to(device)
        weight_max = torch.max(torch.max(torch.max(self.weight, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
        weight_min = torch.min(torch.min(torch.min(self.weight, 3, keepdim=True)[0], 2, keepdim=True)[0], 1, keepdim=True)[0]
        # weight_max_min = torch.cat((torch.abs(weight_max), torch.abs(weight_min)), 0).view([2,-1])
        # self.weight_max = torch.max(weight_max_min,0,keepdim=True)[0]
        weight_threshold = torch.max(torch.abs(weight_max), torch.abs(weight_min))
        self.weight_max = weight_threshold
        # print('weight_shape: ', weight_threshold.shape)   
        self.weight_scale = torch.where(weight_threshold < torch.tensor(0.0001).to(device), torch.tensor(0.0).to(device), ((1 << (self.bits-1))-1) / weight_threshold)
        # print('weight_scale111: ', self.weight_scale)
        self.weight_zero = torch.where(weight_threshold < torch.tensor(0.0001).to(device), torch.tensor(1.0).to(device), torch.tensor(0.0).to(device))
    # def initial_activate_distubution_interval(self):
    #     self.activate_distubution_interval = (torch.tensor(STATISTIC).to(device)) * self.th / torch.tensor(INTERVAL_NUM).to(device).astype(float)
    def initial_histograms(self, input):
        # Truncate the boundary of the active hist graph,
        # so the number exceeding the boundary value will not fall into statistics.
        # print('id0: ', id(input))
        input_cpu = input.cpu()
        # print('id1: ', id(input_cpu))
        # print(input_cpu)
        input_cpu_numpy = input_cpu.numpy().flatten()
        th = self.th.cpu().item()
        # print(th)
        hist, hist_edges = np.histogram(input_cpu_numpy, bins=INTERVAL_NUM, range=(-th, th))
        #hist = torch.histc(input, bins=INTERVAL_NUM, min=-self.th, max=self.th)
        self.activate_distubution += torch.from_numpy(hist).to(device)
        self.activate_distubution[2000] = torch.tensor(0).to(device)
        self.activate_distubution_edges = torch.from_numpy(hist_edges).to(device)
    def plot_hist(self, optimal_th=None):
        a = self.activate_distubution_edges.cpu().numpy()[:-1]
        b = self.activate_distubution.cpu().numpy()
        print('hist: ', a)
        print('hist_edge: ', b)
        import matplotlib.pyplot as plt
        plt.plot(self.activate_distubution_edges.cpu().numpy()[:-1], self.activate_distubution.cpu().numpy())
        if optimal_th is not None:
            plt.plot(optimal_th, 0, 'om')
            plt.annotate('optimal_th', xy=(optimal_th, 0), xytext=(optimal_th+1, 10000), arrowprops=dict(arrowstyle='->'))
        plt.ylabel('activate distubution')
    def get_optimal_threshold(self):
        '''Avoid multiple operations caused by multiple identification of the module'''
        self.activate_flag = torch.ones(1).to(device)
        length = self.activate_distubution.shape[0]
        assert (length % 2 == 1)
        hist = self.activate_distubution.cpu().numpy()
        hist_edge = self.activate_distubution_edges.cpu().numpy()
        num_quantized_bins = self.quant_num.cpu().item()
        optimal_threshold = calibrate(hist, hist_edge, num_quantized_bins)
        self.optimal_th = torch.tensor(optimal_threshold).to(device)
        print('th: ', self.th)
        print('optimal_th: ', self.optimal_th)

    def forward(self, x):
        assert self.training is False
        # print('test')
        x  = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.groups)
        # x = self.conv_relu_fold(x)
        if self.is_relu:
            x = F.relu(x)
        if self.mode == 'TRT_activate_collection_max':
            '''collect max,min,threshold'''
        elif self.mode == 'TRT_activate_collection_hist':
            '''collect histograms'''
        elif self.mode == 'TRT_activate_KL':
            '''calibrate for optimal_threshold'''
            # self.get_optimal_threshold()
        elif self.mode == 'Normal_TRT':
        elif self.mode != 'TRT_weight_quant':
            raise ValueError("mode error")
        return x

以下代码改写自第二个参考代码中调用的 C++ 代码。原 C++ 代码有一处错误:在“merge hist into num_quantized_bins bins”部分处理边界时存在重复累加,下面的实现已作修改,阅读时请注意区分。

def calibrate(hist, hist_edge, num_quantized_bins=255):
    num_bins = hist.size
    assert num_bins+1 == hist_edge.size
    zero_bin_idx = num_bins // 2
    num_half_quantized_bins = num_quantized_bins // 2
    thresholds = np.zeros(zero_bin_idx + 1 - num_half_quantized_bins)
    divergence = np.zeros(zero_bin_idx + 1 - num_half_quantized_bins)
    for i in range(num_half_quantized_bins, zero_bin_idx+1, 1):
        p_bin_index_start = zero_bin_idx - i
        p_bin_index_stop = zero_bin_idx + i + 1
        thresholds[i - num_half_quantized_bins] = hist_edge[p_bin_index_stop];
        sliced_nd_hist = np.zeros(p_bin_index_stop - p_bin_index_start)
        p = np.zeros(p_bin_index_stop - p_bin_index_start)
        # for j in range(num_bins):
        #     if j <= p_bin_index_start:
        #         p[0] +=
        p[1:] = hist[p_bin_index_start+1 : p_bin_index_stop]
        sliced_nd_hist[1:] = hist[p_bin_index_start+1 : p_bin_index_stop]
        p[0] = np.sum(hist[:p_bin_index_start+1])
        p[-1] = p[-1] + np.sum(hist[p_bin_index_stop:])
        # print(p)
        # print(sliced_nd_hist)
        '''calculate how many bins should be merged to generate quantized distribution q'''
        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
        '''merge hist into num_quantized_bins bins'''
        quantized_bins = np.zeros(num_quantized_bins)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = (j+1) * num_merged_bins
            quantized_bins[j] = np.sum(sliced_nd_hist[start:stop])
        quantized_bins[-1] = quantized_bins[-1] + np.sum(sliced_nd_hist[num_quantized_bins * num_merged_bins : ])
        '''expand quantized_bins into p.size bins'''
        q = np.zeros(p_bin_index_stop - p_bin_index_start)
        is_nonzeros = (p != 0).astype(np.int64)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = q.size if (j == num_quantized_bins-1)  else (j+1) * num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        q[p == 0] = 0
        p = _smooth_distribution(p);
        q = _smooth_distribution(q);
        # p[p == 0] = 0.0001
        # q[q == 0] = 0.0001
        # print('p: ', p)
        # print('q: ', q)
        divergence[i - num_half_quantized_bins] = ComputeEntropy(p, q)
        # print(divergence[i - num_half_quantized_bins])
        # print('done')
    min_kl_divergence = np.argmin(divergence)
    return thresholds[min_kl_divergence]
def _smooth_distribution(p, eps=0.0001):
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist

#from scipy import *
def ComputeEntropy(p, q):
    """Return KL(P || Q) after normalising ``p`` and ``q`` to sum to 1.

    Both arrays must have the same number of elements. The logarithm is
    ``np.lib.scimath.log``, which returns complex values for negative inputs
    instead of raising; inputs are expected to be strictly positive.
    """
    assert p.size == q.size
    p_prob = p / np.sum(p)
    q_prob = q / np.sum(q)
    log_ratio = np.lib.scimath.log(p_prob / q_prob)
    return np.sum(p_prob * log_ratio)


import torch
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from VggNet import * 
from datetime import datetime
from torch.utils.data import DataLoader

from torchvision import datasets,transforms

from ConvReluFold import model_relu_folding

from ConvBNFold import model_bn_folding
from quant_utils import ConvRelu, LinearRelu, DummyModule, TRT_Quantizer

model = torch.load('./model/vgg0.904_bnrelufold.pth')

'''---------------------- TRT_weight_quant ------------------------------------'''
TRT_Quantizer(model, mode='TRT_weight_quant')

'''---------------------- TRT_activate_collection_max ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_collection_max')

correct = 0.0
total = 0
num = 0
with torch.no_grad():  
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device) 
        outputs = model(inputs)
        pred = outputs.argmax(dim = 1)  # 
        total += inputs.size(0)
        correct += torch.eq(pred,labels).sum().item()
        num += 1
        if num > 20:
print('Accuracy of the network on the 10000 test images: %.2f %%' % (100.0 * correct / total))

'''---------------------- TRT_activate_collection_hist ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_collection_hist')
correct = 0.0
total = 0
num = 0
with torch.no_grad():  # 训练集不需要反向传播
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device) 
        outputs = model(inputs)
        pred = outputs.argmax(dim = 1)  
        total += inputs.size(0)
        correct += torch.eq(pred,labels).sum().item()
        num += 1
        if num > 20:
print('Accuracy of the network on the 10000 test images: %.2f %%' % (100.0 * correct / total))

'''---------------------- TRT_activate_KL ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_KL')





