这可能并不明显,但是pd.Series.isin
uses O(1)
-查找每个元素。
经过分析证明上述陈述后,我们将利用其见解来创建一个 Cython 原型,它可以轻松击败最快的开箱即用解决方案。
我们假设“集合”有n
元素和“系列”有m
元素。那么运行时间为:
T(n,m)=T_preprocess(n)+m*T_lookup(n)
对于纯 python 版本,这意味着:
-
T_preprocess(n)=0
- 无需预处理
-
T_lookup(n)=O(1)
- python 集合的众所周知的行为
- 结果是
T(n,m)=O(m)
会发生什么pd.Series.isin(x_arr)
?显然,如果我们跳过预处理并在线性时间内搜索,我们将得到O(n*m)
,这是不可接受的。
借助调试器或分析器(我使用 valgrind-callgrind+kcachegrind)很容易看到发生了什么:工作的马是函数__pyx_pw_6pandas_5_libs_9hashtable_23ismember_int64
。其定义可以查到here:
- 在预处理步骤中,哈希映射(pandas 使用来自 klib 的卡什)创建于
n
元素来自x_arr
,即在运行时O(n)
.
-
m
查找发生在O(1)
每个或O(m)
总共在构造的哈希图中。
- 结果是
T(n,m)=O(m)+O(n)
我们必须记住 - numpy-array 的元素是原始 C 整数,而不是原始集合中的 Python 对象 - 所以我们不能按原样使用该集合。
将 Python 对象集转换为 C 整数集的另一种方法是将单个 C 整数转换为 Python 对象,从而能够使用原始集。这就是发生在[i in x_set for i in ser.values]
-变体:
- 没有预处理。
- m 次查找发生在
O(1)
每次或O(m)
总共,但由于需要创建 Python 对象,查找速度较慢。
- 结果是
T(n,m)=O(m)
显然,您可以使用 Cython 稍微加快此版本的速度。
但理论已经足够了,让我们看看不同情况下的运行时间n
s 与固定m
s:
我们可以看到:预处理的线性时间在大型的 numpy 版本中占主导地位n
是。从 numpy 转换为纯 python 的版本(numpy->python
)与纯 python 版本具有相同的恒定行为,但由于必要的转换而速度较慢 - 这一切都符合我们的分析。
这在图中看不清楚:如果n < m
numpy 版本变得更快 - 在这种情况下,查找速度更快khash
-lib 起着最重要的作用,而不是预处理部分。
我从这个分析中得出的结论是:
n < m
: pd.Series.isin
应该采取,因为O(n)
- 预处理并不那么昂贵。
n > m
:(可能是 cythonized 版本)[i in x_set for i in ser.values]
应采取,因此O(n)
避免了。
显然存在一个灰色地带n
and m
近似相等,未经测试很难判断哪种解决方案最好。
如果你能控制它:最好的办法是构建set
直接作为 C 整数集 (khash
(已经被熊猫包裹了)或者甚至可能是一些 C++ 实现),从而消除了预处理的需要。我不知道 pandas 中是否有可以重用的东西,但在 Cython 中编写该函数可能不是什么大问题。
问题是最后一个建议不能开箱即用,因为 pandas 和 numpy 在它们的界面中都没有集合的概念(至少就我有限的知识而言)。但拥有原始 C 集接口将是两全其美:
- 不需要预处理,因为值已经作为一组传递
- 不需要转换,因为传递的集合由原始 C 值组成
我编写了一个快速而肮脏的代码khash 的 Cython 包装器(受到 pandas 中包装器的启发),可以通过以下方式安装pip install https://github.com/realead/cykhash/zipball/master
然后与 Cython 一起使用以获得更快的速度isin
版本:
%%cython
import numpy as np
cimport numpy as np
from cykhash.khashsets cimport Int64Set
def isin_khash(np.ndarray[np.int64_t, ndim=1] a, Int64Set b):
cdef np.ndarray[np.uint8_t,ndim=1, cast=True] res=np.empty(a.shape[0],dtype=np.bool)
cdef int i
for i in range(a.size):
res[i]=b.contains(a[i])
return res
作为进一步的可能性,c++unordered_map
可以被包装(参见清单 C),它的缺点是需要 c++ 库并且(正如我们将看到的)速度稍慢。
比较这些方法(参见清单 D 创建时序):
khash 的速度大约比 20 倍快numpy->python
,比纯 python 快约 6 倍(但无论如何,纯 python 都不是我们想要的),甚至比 cpp 版本快约 3 倍。
Listings
1) 使用 valgrind 进行分析:
#isin.py
import numpy as np
import pandas as pd
np.random.seed(0)
x_set = {i for i in range(2*10**6)}
x_arr = np.array(list(x_set))
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
for _ in range(10):
ser.isin(x_arr)
and now:
>>> valgrind --tool=callgrind python isin.py
>>> kcachegrind
导致以下调用图:
B:用于生成运行时间的 ipython 代码:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
np.random.seed(0)
x_set = {i for i in range(10**2)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
lst = arr.tolist()
n=10**3
result=[]
while n<3*10**6:
x_set = {i for i in range(n)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
t1=%timeit -o ser.isin(x_arr)
t2=%timeit -o [i in x_set for i in lst]
t3=%timeit -o [i in x_set for i in ser.values]
result.append([n, t1.average, t2.average, t3.average])
n*=2
#plotting result:
for_plot=np.array(result)
plt.plot(for_plot[:,0], for_plot[:,1], label='numpy')
plt.plot(for_plot[:,0], for_plot[:,2], label='python')
plt.plot(for_plot[:,0], for_plot[:,3], label='numpy->python')
plt.xlabel('n')
plt.ylabel('running time')
plt.legend()
plt.show()
C:cpp 包装器:
%%cython --cplus -c=-std=c++11 -a
from libcpp.unordered_set cimport unordered_set
cdef class HashSet:
cdef unordered_set[long long int] s
cpdef add(self, long long int z):
self.s.insert(z)
cpdef bint contains(self, long long int z):
return self.s.count(z)>0
import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
def isin_cpp(np.ndarray[np.int64_t, ndim=1] a, HashSet b):
cdef np.ndarray[np.uint8_t,ndim=1, cast=True] res=np.empty(a.shape[0],dtype=np.bool)
cdef int i
for i in range(a.size):
res[i]=b.contains(a[i])
return res
D:使用不同的集合包装器绘制结果:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
from cykhash import Int64Set
np.random.seed(0)
x_set = {i for i in range(10**2)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
lst = arr.tolist()
n=10**3
result=[]
while n<3*10**6:
x_set = {i for i in range(n)}
x_arr = np.array(list(x_set))
cpp_set=HashSet()
khash_set=Int64Set()
for i in x_set:
cpp_set.add(i)
khash_set.add(i)
assert((ser.isin(x_arr).values==isin_cpp(ser.values, cpp_set)).all())
assert((ser.isin(x_arr).values==isin_khash(ser.values, khash_set)).all())
t1=%timeit -o isin_khash(ser.values, khash_set)
t2=%timeit -o isin_cpp(ser.values, cpp_set)
t3=%timeit -o [i in x_set for i in lst]
t4=%timeit -o [i in x_set for i in ser.values]
result.append([n, t1.average, t2.average, t3.average, t4.average])
n*=2
#ploting result:
for_plot=np.array(result)
plt.plot(for_plot[:,0], for_plot[:,1], label='khash')
plt.plot(for_plot[:,0], for_plot[:,2], label='cpp')
plt.plot(for_plot[:,0], for_plot[:,3], label='pure python')
plt.plot(for_plot[:,0], for_plot[:,4], label='numpy->python')
plt.xlabel('n')
plt.ylabel('running time')
ymin, ymax = plt.ylim()
plt.ylim(0,ymax)
plt.legend()
plt.show()