Python 中 FFT 的循环加速(使用“np.einsum”)

2024-04-25

Problem:我想加速包含大量乘积和求和的 python 循环np.einsum,但我也愿意接受任何其他解决方案。

我的函数采用形状为 (n,n,3) 的向量配置 S(我的情况:n=72),并对 N*N 点的相关函数进行傅里叶变换。相关函数定义为每个向量与其他向量的乘积。它乘以向量位置乘以 kx 和 ky 值的余弦函数。每个位置i,j最后求和得到k空间中的一个点p,m:

def spin_spin(S,N):
    n= len(S)
    conf = np.reshape(S,(n**2,3))
    chi = np.zeros((N,N))
    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

    x=np.reshape(triangular(n)[0],(n**2))
    y=np.reshape(triangular(n)[1],(n**2))
    for p in range(N):
        for m in range(N):
            for i in range(n**2):
                for j in range(n**2):        
                    chi[p,m] += 2/(n**2)*np.dot(conf[i],conf[j])*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))
    return(chi,kx,ky)

我的问题是,我需要大约 100*100 个点,用 kx*ky 表示,并且循环需要很多小时才能完成具有 72*72 向量的晶格的这项工作。 计算次数:72*72*72*72*100*100 我无法使用内置的 FFTnumpy,由于我的三角形网格,所以我需要一些其他选项来减少这里的计算成本。

My idea:首先,我认识到将配置重塑为向量列表而不是矩阵可以降低计算成本。此外,我使用了 numba 包,这也降低了成本,但仍然太慢。我发现计算此类对象的一个​​好方法是np.einsum功能。计算每个向量与每个向量的乘积是通过以下方式完成的:

np.einsum('ij,kj -> ik',np.reshape(S,(72**2,3)),np.reshape(S,(72**2,3)))

棘手的部分是计算里面的项np.cos。在这里,我想计算形状 (100,1) 列表与向量位置之间的乘积(例如np.shape(x)=(72**2,1))。特别是我真的不知道如何实现x方向和y方向的距离np.einsum.

重现代码(可能您不需要这个):首先你需要一个矢量配置。你可以简单地用np.ones((72,72,3)或者以随机向量为例:

def spherical_to_cartesian(r, theta, phi):
    '''Convert spherical coordinates (physics convention) to cartesian coordinates'''
    sin_theta = np.sin(theta)
    x = r * sin_theta * np.cos(phi)
    y = r * sin_theta * np.sin(phi)
    z = r * np.cos(theta)

    return x, y, z # return a tuple

def random_directions(n, r):
    '''Return ``n`` 3-vectors in random directions with radius ``r``'''
    out = np.empty(shape=(n,3), dtype=np.float64)

    for i in range(n):
        # Pick directions randomly in solid angle
        phi = random.uniform(0, 2*np.pi)
        theta = np.arccos(random.uniform(-1, 1))
        # unpack a tuple
        x, y, z = spherical_to_cartesian(r, theta, phi)
        out[i] = x, y, z

    return out
S = np.reshape(random_directions(72**2,1),(72,72,3))

(本例中的重塑需要在函数中对其进行整形spin_spin回到 (72**2,3) 形状。)

对于向量的位置,我使用由下式定义的三角形网格

def triangular(nsize):
    '''Positional arguments of the spin configuration'''

    X=np.zeros((nsize,nsize))
    Y=np.zeros((nsize,nsize))
    for i in range(nsize):
        for j in range(nsize):
            X[i,j]+=1/2*j+i
            Y[i,j]+=np.sqrt(3)/2*j
    return(X,Y)

优化的 Numba 实施

代码中的主要问题是调用外部 BLAS 函数np.dot反复地以极其small数据。在此代码中,仅计算一次会更有意义,但如果您必须在循环中执行此计算,请编写 Numba 实现。Example https://stackoverflow.com/a/59356461/4045774

优化功能(暴力破解)

import numpy as np
import numba as nb

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin(S,N):
    n= len(S)
    conf = np.reshape(S,(n**2,3))
    chi = np.zeros((N,N))
    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N).astype(np.float32)
    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N).astype(np.float32)

    x=np.reshape(triangular(n)[0],(n**2)).astype(np.float32)
    y=np.reshape(triangular(n)[1],(n**2)).astype(np.float32)

    #precalc some values
    fact=nb.float32(2/(n**2))
    conf_dot=np.dot(conf,conf.T).astype(np.float32)

    for p in nb.prange(N):
        for m in range(N):
            #accumulating on a scalar is often beneficial
            acc=nb.float32(0)
            for i in range(n**2):
                for j in range(n**2):        
                    acc+= conf_dot[i,j]*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))
            chi[p,m]=fact*acc

    return(chi,kx,ky)

优化功能(去除多余计算)

做了很多多余的计算。这是有关如何删除它们的示例。这也是一个以双精度进行计算的版本。

@nb.njit()
def precalc(S):
    #There may not be all redundancies removed
    n= len(S)
    conf = np.reshape(S,(n**2,3))
    conf_dot=np.dot(conf,conf.T)
    x=np.reshape(triangular(n)[0],(n**2))
    y=np.reshape(triangular(n)[1],(n**2))

    x_s=set()
    y_s=set()
    for i in range(n**2):
        for j in range(n**2):
            x_s.add((x[i]-x[j]))
            y_s.add((y[i]-y[j]))

    x_arr=np.sort(np.array(list(x_s)))
    y_arr=np.sort(np.array(list(y_s)))


    conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))
    for i in range(n**2):
        for j in range(n**2):
            ii=np.searchsorted(x_arr,x[i]-x[j])
            jj=np.searchsorted(y_arr,y[i]-y[j])
            conf_dot_sel[ii,jj]+=conf_dot[i,j]

    return x_arr,y_arr,conf_dot_sel

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin_opt_2(S,N):
    chi = np.empty((N,N))
    n= len(S)

    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

    x_arr,y_arr,conf_dot_sel=precalc(S)
    fact=2/(n**2)
    for p in nb.prange(N):
        for m in range(N):
            acc=nb.float32(0)
            for i in range(x_arr.shape[0]):
                for j in range(y_arr.shape[0]):        
                    acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])
            chi[p,m]=acc

    return(chi,kx,ky)

@nb.njit()
def precalc(S):
    #There may not be all redundancies removed
    n= len(S)
    conf = np.reshape(S,(n**2,3))
    conf_dot=np.dot(conf,conf.T)
    x=np.reshape(triangular(n)[0],(n**2))
    y=np.reshape(triangular(n)[1],(n**2))

    x_s=set()
    y_s=set()
    for i in range(n**2):
        for j in range(n**2):
            x_s.add((x[i]-x[j]))
            y_s.add((y[i]-y[j]))

    x_arr=np.sort(np.array(list(x_s)))
    y_arr=np.sort(np.array(list(y_s)))


    conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))
    for i in range(n**2):
        for j in range(n**2):
            ii=np.searchsorted(x_arr,x[i]-x[j])
            jj=np.searchsorted(y_arr,y[i]-y[j])
            conf_dot_sel[ii,jj]+=conf_dot[i,j]

    return x_arr,y_arr,conf_dot_sel

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin_opt_2(S,N):
    chi = np.empty((N,N))
    n= len(S)

    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

    x_arr,y_arr,conf_dot_sel=precalc(S)
    fact=2/(n**2)
    for p in nb.prange(N):
        for m in range(N):
            acc=nb.float32(0)
            for i in range(x_arr.shape[0]):
                for j in range(y_arr.shape[0]):        
                    acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])
            chi[p,m]=acc

    return(chi,kx,ky)

Timings

#brute-force
%timeit res=spin_spin(S,100)
#48 s ± 671 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#new version
%timeit res_2=spin_spin_opt_2(S,100)
#5.33 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit res_2=spin_spin_opt_2(S,1000)
#1min 23s ± 2.43 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

编辑(SVML 检查)

import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def foo(n):
    x   = np.empty(n*8, dtype=np.float64)
    ret = np.empty_like(x)
    for i in range(ret.size):
            ret[i] += np.cos(x[i])
    return ret

foo(1000)

if 'intel_svmlcc' in foo.inspect_llvm(foo.signatures[0]):
    print("found")
else:
    print("not found")

#found

如果有一个not found read 这个链接。 https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#intel-svml它应该可以在 Linux 和 Windows 上运行,但我还没有在 macOS 上进行测试。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Python 中 FFT 的循环加速(使用“np.einsum”) 的相关文章

  • 在云服务器中运行 python 脚本的最简单方法是什么?

    我有一个网络爬行 python 脚本 需要几个小时才能完成 并且无法在我的本地计算机上完整运行 有没有一种方便的方法可以将其部署到简单的 Web 服务器 该脚本基本上将网页下载到文本文件中 如何最好地实现这一点 谢谢 既然你说性能是一个问题
  • 在 Heroku 应用程序中同时运行 Django 和 Node

    我想在我的 heroku 实例上同时运行 django 应用程序和节点应用程序 这是我的进程文件 web python manage py runserver 0 0 0 0 PORT web node bin node modules a
  • mod_wsgi 下的 psp(python 服务器页面)代码?

    有没有办法在 apache mod wsgi 下运行 psp python 服务器页面 代码 虽然我们正在转向基于 wsgi 的新框架 但我们仍然有一些用 psp 编写的遗留代码 这些代码在 mod python 下运行 我们希望能够在托管
  • 如何在 Python 中使用 Selenium 运行无头 Chrome?

    我正在尝试使用 selenium 进行一些操作 我真的希望我的脚本能够快速运行 我认为使用无头 Chrome 运行我的脚本会使其速度更快 首先 这个假设是否正确 或者我是否使用无头驱动程序运行我的脚本并不重要 我希望无头 Chrome 能够
  • 使用SchemDraw库自动保存图像

    我想在Python中使用这个库来生成电气图 https cdelker bitbucket io SchemDraw https cdelker bitbucket io SchemDraw 我想在服务器中运行这段代码 这个想法是生成图像
  • python 2.7 字符 \u2013 [重复]

    这个问题在这里已经有答案了 我有以下代码 coding utf 8 print u William Burges 1827 81 was an English architect and designer 当我尝试从cmd运行它时 我收到以
  • python 线程是如何工作的?

    我想知道 python 线程是并发运行还是并行运行 例如 如果我有两个任务并在两个线程中运行它们 它们是同时运行还是计划同时运行 我知道GIL并且线程仅使用一个 CPU 核心 这是一个复杂的问题 需要大量解释 我将坚持使用 CPython
  • 将numpy字符串数组转换为int数组[重复]

    这个问题在这里已经有答案了 我有一个 numpy ndarray a 0 99 0 56 0 56 2 02 0 96 如何将其转换为int 输出 a 0 99 0 0 0 56 0 56 2 02 0 96 我想要 0 0 代替空白 im
  • Python 中意外的缩进错误[重复]

    这个问题在这里已经有答案了 我有一段简单的代码 我不明白我的错误来自哪里 解析器在第 5 行 if 语句 上用意外的缩进向我咆哮 有人看到这里的问题吗 我不 def gen fibs a b 0 1 while True a b b a b
  • 如何动态构造方法?

    我设计了一个类 它非常标准 具有一些方法属性 class foo def f1 self print f1 def f2 self print f2 def fn self print fn 现在我想创建一个包含一组 foo 实例的类 cl
  • 使用光栅重新投影 .tiff 文件:CRSError:无法解析 WKT。 OGR 错误代码 6

    我正在尝试使用以下代码将 tiff 文件重新投影到 EPSG 32638 我安装过的版本 光栅版本 1 1 5 Numpy 版本 1 18 1 这是我正在使用的代码 https rasterio readthedocs io en late
  • 为 Mercurial 执行 hgweb.cgi 时,指定的 CGI 应用程序行为不当...

    我有 IIS 6 我将 Mercurial 安装在 c program files mercurial 中 我在 c program files python 中安装了 Python 2 6 I added extension handli
  • 如何在 python 中将 selenium webelement 转换为字符串变量

    from selenium import webdriver from time import sleep from selenium common exceptions import NoSuchAttributeException fr
  • 在 Python 中删除表达式树及其每个子表达式树中第一个元素周围的括号

    目标是实现简化操作 删除表达式树及其每个子表达式树中第一个元素周围的括号 其中表达式作为括在各个括号中的字符串输入给出 这必须适用于任意数量的括号 例如 12 3 45 6 gt 123 45 6 删除 12 周围的括号 然后删除 45 周
  • 抑制来自 python pandas 描述的名称 dtype

    可以说我有 r pd DataFrame A 1 B pd Series 1 index list range 4 dtype float32 And r B describe mean std min max 给出输出 mean 1 0
  • 如何编辑多个 Pandas DataFrame 浮点列的字符串格式?

    我有一个pd DataFrame浮点数 import numpy as np import pandas as pd pd DataFrame np random rand 5 5 0 1 2 3 4 0 0 795329 0 125540
  • Pandas:Drop() int64 基于值返回对象

    我需要删除其中一列低于某个值的所有行 我使用了下面的命令 但这将列作为对象返回 我需要将其保留为int64 df customer id df drop df customer id df customer id lt 9999999 in
  • pyspark:将 schemaRDD 保存为 json 文件

    我正在寻找一种将数据从 Apache Spark 以 JSON 格式导出到各种其他工具的方法 我认为一定有一种非常简单的方法来做到这一点 示例 我有以下 JSON 文件 jfile json key value a1 key2 value
  • Django - 在启动时执行代码

    我正在使用 Django 1 9 3 我有一个包含多个应用程序的项目 我想在项目启动时更新其中一个应用程序的表 用例 例如 假设我想在我的网站上销售商品 我有一个包含模型项目的应用程序 我在 Django 之外有一个网络服务 它提供服务 g
  • 将下载的字体添加到 Tkinter

    我想下载一个开源字体并在我的 Python Tkinter 程序中使用它 如何告诉 Tkinter 从目录导入字体或将字体放在与程序相同的文件夹中 Note 我已经寻找答案一段时间了 甚至阅读了 Tkinter 的 API 参考 了解我能找

随机推荐