使用LineProfiler找出代码的计算瓶颈

2023-11-12

实现同样一个功能,笔者运行需要11秒,而同窗的运行仅需要1秒不到,但是实际实现逻辑是类似的,所以需要使用性能分析工具对瓶颈进行分析。

安装

  • 命令行安装:
pip install line_profiler
  • 本地下载后安装:
    https://www.lfd.uci.edu/~gohlke/pythonlibs/#line_profiler
    根据平台选择对应whl文件,然后本地安装。

修改代码

先来一个demo,do_stuff是我们的目标,要测试这个函数每一行的耗时。

from line_profiler import LineProfiler
import random
 
def do_other_stuff(numbers):
    s = sum(numbers)
 
def do_stuff(numbers):
    do_other_stuff(numbers)
    l = [numbers[i]/43 for i in range(len(numbers))]
    m = ['hello'+str(numbers[i]) for i in range(len(numbers))]
 
numbers = [random.randint(1,100) for i in range(1000)]
lp = LineProfiler()
lp.add_function(do_other_stuff)   # add additional function to profile
lp_wrapper = lp(do_stuff)
lp_wrapper(numbers)
lp.print_stats()

所以实际上最后五行内容是添加进来的。

在笔者的问题中,是加载mnist数据集:

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import random 
from line_profiler import LineProfiler

import time 
# WORK1: --------------BEGIN-------------------
# 构建数据平衡采样方法:make_batch
# 参数等都可以自定义
# 返回值为(input_a, input_b), label
# input_a形状为(batch_size,28,28),input_b形状为(batch_size,28,28),label形状为(batch_size,)

def make_batch(batch_size, dataset):
    label = []
    input_a = []
    input_b = []
    
    x1 = np.array(dataset[0])
    y1 = np.array(dataset[1])

    cls_num = batch_size // 20 # 每个类采样个数, pos neg
    cls_idx_same = [np.where(y1 == i)[0] for i in range(10)]
    cls_idx_diff = [np.where(y1 != i)[0] for i in range(10)]

    # pos
    for class_num in range(10): # num of classes
        for _ in range(cls_num): # 每个类采样个数
            choose_two = random.sample(cls_idx_same[class_num].tolist(),2)
            input_a.append(x1[choose_two[0]])
            input_b.append(x1[choose_two[1]])
            label.append(0)
    
    # # neg
    for class_num in range(10):
        for _ in range(cls_num):
            choose_same = random.sample(cls_idx_same[class_num].tolist(), 1)
            choose_diff = random.sample(cls_idx_diff[class_num].tolist(), 1)

            input_a.append(x1[choose_same[0]])
            input_b.append(x1[choose_diff[0]])
            label.append(1)

    input_a = np.array(input_a)
    input_b = np.array(input_b)
    label = np.array(label).astype(np.float)
    return (input_a, input_b), label


if __name__ == "__main__":
    path =   './dataset/mnist.npz'
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()

    # 测试也是255归一化的数据,请不要改归一化
    x_train = x_train / 255.
    x_test = x_test / 255.
    idx_shuffle = np.arange(len(x_train))
    np.random.shuffle(idx_shuffle)
    x_train = x_train[idx_shuffle]
    y_train = y_train[idx_shuffle]

    slice_08 = int(len(x_train)*0.8)

    train_set = [x_train[:slice_08],y_train[:slice_08]]

    # train_set = [np.array(x_train[:slice_08]),
                #  np.array(y_train[:slice_08])]
    # val_set = [x_test, y_test]
    lasttime = time.time()
    # for i in range(100):
    #     make_batch(64, train_set)
    lp = LineProfiler()
    lp_warpper = lp(make_batch)
    lp_warpper(64, train_set)
    lp.print_stats()
    print(time.time()-lasttime)

来运行分析一下运行一次的耗时:

在这里插入图片描述

可以发现 np.array(dataset[0]) 耗时非常严重,不适合放到循环中。所以改动这个部分,在循环外提前转换格式可以节约很长时间。

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import random 
from line_profiler import LineProfiler

import time 
# WORK1: --------------BEGIN-------------------
# 构建数据平衡采样方法:make_batch
# 参数等都可以自定义
# 返回值为(input_a, input_b), label
# input_a形状为(batch_size,28,28),input_b形状为(batch_size,28,28),label形状为(batch_size,)

def make_batch(batch_size, dataset):
    label = []
    input_a = []
    input_b = []
    
    x1 = dataset[0]
    y1 = dataset[1]

    cls_num = batch_size // 20 # 每个类采样个数, pos neg
    cls_idx_same = [np.where(y1 == i)[0] for i in range(10)]
    cls_idx_diff = [np.where(y1 != i)[0] for i in range(10)]

    # pos
    for class_num in range(10): # num of classes
        for _ in range(cls_num): # 每个类采样个数
            choose_two = random.sample(cls_idx_same[class_num].tolist(),2)
            input_a.append(x1[choose_two[0]])
            input_b.append(x1[choose_two[1]])
            label.append(0)
    
    # # neg
    for class_num in range(10):
        for _ in range(cls_num):
            choose_same = random.sample(cls_idx_same[class_num].tolist(), 1)
            choose_diff = random.sample(cls_idx_diff[class_num].tolist(), 1)

            input_a.append(x1[choose_same[0]])
            input_b.append(x1[choose_diff[0]])
            label.append(1)

    input_a = np.array(input_a)
    input_b = np.array(input_b)
    label = np.array(label).astype(np.float)
    return (input_a, input_b), label


if __name__ == "__main__":
    path =   './dataset/mnist.npz'
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()

    # 测试也是255归一化的数据,请不要改归一化
    x_train = x_train / 255.
    x_test = x_test / 255.
    idx_shuffle = np.arange(len(x_train))
    np.random.shuffle(idx_shuffle)
    x_train = x_train[idx_shuffle]
    y_train = y_train[idx_shuffle]

    slice_08 = int(len(x_train)*0.8)

    train_set = [np.array(x_train[:slice_08]),np.array(y_train[:slice_08])]
    # val_set = [x_test, y_test]
    lasttime = time.time()
    # for i in range(100):
    #     make_batch(64, train_set)
    lp = LineProfiler()
    lp_warpper = lp(make_batch)
    lp_warpper(64, train_set)
    lp.print_stats()
    print(time.time()-lasttime)

运行结果如下:
在这里插入图片描述
这样瓶颈就转移到其他地方了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用LineProfiler找出代码的计算瓶颈 的相关文章

随机推荐

  • java 读取excel数据

    本文共介绍两种方式 第一种是常规POI读取 第二种是大文件读取 依赖包
  • AbstractQueuedSynchronizer(AQS) 原理

    一 AQS 简介 1 1 AQS 是什么 AQS AbstractQueuedSynchronizer 抽象队列同步器 是一个用来构建锁和同步器的框架 使用 AQS 能简单且高效地构造出应用广泛的大量的同步器 比如我们提到的 Reentra
  • 高光谱图像处理

    Development of a classification algorithm for efficient handling of multiple classes in sorting systems basesd on hypers
  • ps语义分割_解决实例分割任务中边缘不够精细:PointRend: Image Segmentation as Rendering...

    加入极市专业CV交流群 与6000 来自腾讯 华为 百度 北大 清华 中科院等名企名校视觉开发者互动交流 更有机会与李开复老师等大牛群内互动 同时提供每月大咖直播分享 真实项目需求对接 干货资讯汇总 行业技术交流 关注 极市平台 公众号 回
  • 计算机网络基础知识整理

    计算机网络 用通信设备和线路将处在不同地理位置 操作相对独立的多台计算机连接起来 并配置相应的系统和应用软件 在原本各自独立的计算机之间实现软硬件资源共享和信息传递等功能的系统 计算机网络的功能 数据通信 2 资源共享 3 增加可靠性 4
  • 【技术解析笔记】DDPM解析

    本文为youtube上一个ddpm解析视频的摘录笔记 youtube原视频链接 https www youtube com watch v W O7AZNzbzQ 基本介绍 DDPM指的是Denoising diffusion probal
  • hive - 面试题 - 最近一次购物在一年前(近一年内无购物)

    要求 有表 用户id 订单id 下单日期 该用户符合365天内无交易且当日有交易的数据打标签 如果当天有多条记录 同样打标签 思路 当前订单时间 最近一次的下单时间 gt 365 即最近365天内无订单记录 中间有个问题 一天内多次下单 只
  • 【GCC-RT-Thread】gcc交叉编译 STM32 - RT-Thread

    GCC RT Thread gcc交叉编译 STM32 RT Thread 最近在公司实习 公司想将原来在Windows keil上开发的项目移到Linux 并上RTT操作系统 最近就被安排做了这件事 首先 下载 RT Thread Nan
  • 学习记录396@git clone 只克隆到.git文件

    github上的仓库 但是使用乌龟克隆时只克隆到 git文件和README文件 原因是在我的仓库中 没有选择分支 默认是main分支 但我的项目在master分支 因此加上分支选项处在clone即可 如果是使用命令行clone 需要使用如下
  • python基础 -15- 深浅拷贝

    浅拷贝 data name alex age 18 scores 语文 130 数学 60 英语 98 浅拷贝 data copy data copy 再看一下各自的内存地址 可以发现指向的内存地址不一样 print data的内存地址 i
  • 白盒测试题(13-16道题目+详细代码)

    白盒测试 题 13 根据下列流程图编写程序实现相应分析处理并显示结果 并设计最少的测试数据进行判定覆盖测试 输入数据打印出 输入 x 值 输入 y 值 输出文字 a 和 a 的值 输出文字 b 和 b 的值 其中变量 x y 均须为整型 i
  • 红队靶场内网渗透(从DMZ主机渗透到域内机器)

    目录 一 红队靶场内网渗透 1 靶机工具下载 2 本实验网络拓扑图 3 内网渗透攻击流程 二 环境搭建 1 DMZ区win7 2 内网办公区 3 域控主机 三 开始攻击 1 DMZ区win7渗透 1 1信息收集 1 2收集到的信息 1 3远
  • 安卓手机使用Termux实现gitee云端代码本地化修改

    Termux是什么 Termux是一个Android终端仿真器和Linux环境应用程序 直接工作 无需根目录或设置 额外的软件包可以使用APT软件包管理器来使用 不需要root 有root更方便修改代码 下载地址 Termux 0 99 T
  • 人脸识别(dlib.face_recognition_model_v1 方法 -- 使用resnet模型)

    人脸识别 思路 通过检测面部特征 对该特征与数据存放的特征进行比对 文件结构 文件名 weights 的目录下 resnet模型 dat文件 识别68个关键点模型 dat文件 共两个模型文件 补充 你如果不使用dlib库中自带的HOG人脸检
  • 网络基础 (深信服)

    一 走进网络世界 1 1 1 企业网络环境介绍 计算机网络类型 LAN 本地局域网 Local Area Network 通常指几千米以内的 可通过某种介质互联的计算机 打印机 modem或其他设备的集合 WAN 广 域 网 Wide Ar
  • 【C语言】你还在写void main()吗?我劝你别用,小心出BUG

    目录 前言 C语言标准并不支持void main 用void main 可能会报错 总结 前言 你的教材上是不是经常出现void main 呢 我想说 永远不要写void main 为什么 这种写法普遍存在于我们国内的很多教材 既然出现在教
  • day02-08 python基础语法

    模块一 python基础语法 day2 快速上手 今日概要 课程目标 学习Python最基础的语法知识 可以用代码快速实现一些简单的功能 课程概要 初识编码 密码本 编程初体验 输出 初识数据类型 变量 注释 输入 条件语句 1 编码 密码
  • [编程题]输出元素组成数组的排列组合形式

    题目 一个由有限个不同元素组成的数组的所有组合排列形式 要求排列的顺序以从小到大的顺序排列 按首列排序 首列相同 则按照第二列排序 前两列相同 则以第三列排序 以此顺序递推 输入例子1 1 2 输出例子1 1 2 2 1 例子说明1 输出结
  • 服务器划分多台虚拟pc,pc服务器建立多台虚拟主机

    pc服务器建立多台虚拟主机 内容精选 换一换 虚拟IP主要用在弹性云服务器的主备切换 达到高可用性HA High Availability 的目的 当主服务器发生故障无法对外提供服务时 动态将虚拟IP切换到备服务器 继续对外提供服务 了解更
  • 使用LineProfiler找出代码的计算瓶颈

    实现同样一个功能 笔者运行需要11秒 而同窗的运行仅需要1秒不到 但是实际实现逻辑是类似的 所以需要使用性能分析工具对瓶颈进行分析 安装 命令行安装 pip install line profiler 本地下载后安装 https www l