The numpy.where
函数用于检索 ndarray 中给定条件为 true 的元素的索引。
学完本教程后,您将充分了解如何使用numpy.where
查询NumPy 数组.
语法和参数
The numpy.where
函数允许您对 NumPy 数组执行复杂的查询。
这是基本语法:
numpy.where(condition[, x, y])
-
condition
:该参数是一个包含布尔值的数组。它定义了必须满足的条件。
您可以使用比较运算符来定义给定数组的条件。
-
x
, y
:这些是可选参数。如果提供的话,numpy.where
返回从中选择的元素x
or y
视情况而定。
如果未提供这些参数,该函数将返回条件为 true 的索引。
让我们通过一个例子来探索语法:
import numpy as np
array = np.array([10, 20, 30, 40])
condition = array > 25
result = np.where(condition)
print(result)
Output:
(array([2, 3], dtype=int64),)
在此示例中,我们定义了一个条件数组 > 25。
The numpy.where
函数检查数组中每个元素的此条件,并返回一个包含满足条件的元素索引的元组。
元素 30 和 40 满足条件,并返回它们的索引(2 和 3)。
可选参数 x 和 y 提供对输出的进一步控制。
使用 x 和 y 参数替换值
The x
and y
参数在numpy.where
为函数的行为提供额外的灵活性。当提供这些参数时,该函数返回值x
and y
基于条件,而不是返回索引。
下面是一个例子来演示使用x
and y
:
import numpy as np
array = np.array([5, 15, 25, 35])
result = np.where(array > 20, 'High', 'Low')
print(result)
Output:
['Low' 'Low' 'High' 'High']
在此示例中,x 和 y 参数分别设置为 High 和 Low。条件是数组 > 20。
当条件满足时(对于元素 25 和 35),返回 High 值。
如果不满足条件(对于元素 5 和 15),则返回值 Low。
Using numpy.where
,我们用字符串“Low”替换所有不匹配的数字,用字符串“High”替换所有匹配的数字。
返回值
您可以返回满足的值numpy.where
查询而不是像这样返回索引:
import numpy as np
array = np.array([25, 15, 35, 10, 40])
filtered_indices = np.where(array > 20)
filtered_values = array[filtered_indices]
print("Filtered indices:", filtered_indices)
print("Filtered values:", filtered_values)
Output:
Filtered indices: (array([0, 2, 4]),)
Filtered values: [25 35 40]
在这个例子中,我们首先使用numpy.where
查找条件数组 > 20 为 true 的索引。然后,我们使用这些索引从原始数组中提取相应的值。
结果是一个仅包含满足条件的值的新数组。
在多个条件下使用 where
这是一个演示如何使用的示例numpy.where
具有多个条件:
import numpy as np
array = np.array([5, 15, 25, 35, 45])
condition = (array > 20) & (array < 40)
result = np.where(condition, 'Match', 'No Match')
print(result)
Output:
['No Match' 'No Match' 'Match' 'Match' 'No Match']
在此示例中,我们使用逻辑 AND 运算符 & 来组合两个条件:array > 20 和 array
The numpy.where
函数对于同时满足条件(25 和 35)的元素返回 Match,对于不满足条件的元素返回 No Match。
将where与逻辑运算相结合
numpy.where
可以与逻辑运算结合起来在数组上创建复杂的查询。
通过使用逻辑运算符,例如&
(and), |
(或),以及~
(不是),可以组合多个条件。
这是一个演示组合的示例numpy.where
与逻辑运算:
import numpy as np
array = np.array([10, 20, 30, 40, 50])
result = np.where((array > 15) & (array < 45) | (array == 10), 'Selected', 'Not Selected')
print(result)
Output:
['Selected' 'Selected' 'Selected' 'Selected' 'Not Selected']
在此示例中,我们组合了三个条件:
1. (array > 15):选择大于15的元素。
2. (array 3. (array == 10):选择等于10的元素。
我们使用了&
运算符组合前两个条件和|
运算符包括第三个条件。
结果是一个数组,将除最后一个 (50) 之外的所有元素标记为'Selected'
.
将 where 与数学函数结合使用
The numpy.where
函数可以与数学函数结合以根据条件执行计算。
这允许您根据是否满足条件对元素应用不同的数学转换。
这是一个例子:
import numpy as np
array = np.array([1, 2, 3, 4, 5])
result = np.where(array > 3, np.square(array), np.sqrt(array))
print(result)
Output:
[1. 1.41421356 1.73205081 16. 25. ]
在此示例中,numpy.where 函数根据条件数组 > 3 应用两个不同的数学函数:
如果条件为 true,则应用 np.square 函数,对值进行平方。
如果条件为假,则应用 np.sqrt 函数,取值的平方根。
对于元素 1、2 和 3(条件为假),计算平方根。
对于元素 4 和 5(条件为真),计算平方。
嵌套 where 函数
The numpy.where
函数可以嵌套在其自身内以创建条件链,从而可以对输出进行更精细的控制。
当您想要应用多个级别的条件时,这非常有用。
这是一个嵌套的例子numpy.where
功能:
import numpy as np
array = np.array([5, 15, 25, 35, 45])
result = np.where(array < 20, 'Low', np.where(array < 40, 'Medium', 'High'))
print(result)
Output:
['Low' 'Low' 'Medium' 'Medium' 'High']
在这个例子中,我们使用了两个嵌套的numpy.where
函数将元素分为三组。
首先numpy.where
函数检查元素是否小于 20。如果为 true,则返回 Low。
如果为 false,则调用第二个numpy.where
函数,进一步将元素分类为“中”或“高”。
与原生 Python 的性能比较
这是使用两者的基准测试numpy.where
以及原生 Python 方法:
import numpy as np
import time
array = np.random.randint(0, 100, size=100000000)
# Using numpy.where
start_time = time.time()
result_np = np.where(array > 50, 'Greater', 'Smaller')
end_time = time.time()
print("Using numpy.where:", end_time - start_time)
# Using native Python
start_time = time.time()
result_python = ['Greater' if x > 50 else 'Smaller' for x in array]
end_time = time.time()
print("Using native Python:", end_time - start_time)
Output:
Using numpy.where: 1.0875394344329834
Using native Python: 10.121704816818237
在此比较中,我们使用以下方法测量了执行相同操作所需的时间numpy.where
以及原生的 Python 列表理解。
The numpy.where
速度明显更快,因为它利用了底层 C 实现并避免了 Python 的循环开销。
向量化运算 where
向量化操作是指立即将函数或操作应用于整个数组,而不是逐个元素地迭代它。
numpy.where
支持向量化操作,使其能够高效地进行大规模数据操作。
这是一个演示矢量化操作的示例numpy.where
:
import numpy as np
array1 = np.array([1, 2, 3, 4, 5])
array2 = np.array([5, 4, 3, 2, 1])
condition = array1 > array2
result = np.where(condition, array1 + array2, array1 - array2)
print(result)
Output:
[-4 -2 6 6 10]
在此示例中,我们创建了两个 NumPy 数组和一个比较它们对应元素的条件。
Using numpy.where
,我们根据条件应用了两种不同的向量化操作:
如果条件为真,则将 array1 和 array2 的相应元素相加。
如果条件为假,则将 array1 和 array2 的相应元素相减。
由于条件仅适用于第三个、第四个和第五个元素,因此将它们相加,而将其余元素相减。
使用 where 进行广播(处理不同的形状)
NumPy 中的广播是指能够对不同形状和大小的数组执行操作,并将其自动广播为通用形状。
这是一个例子:
import numpy as np
array = np.array([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
condition = np.array([True, False, True])
result = np.where(condition, array, -array)
print(result)
Output:
[[ 1 -2 3]
[ 4 -5 6]
[ 7 -8 9]]
在此示例中,条件数组的形状为 (3,),而数组的形状为 (3, 3)。
The numpy.where
函数广播条件以匹配数组的形状。
对于第一列和第三列(条件为真),保留原始值。
对于第二列(条件为假),值被否定。
资源
https://numpy.org/doc/stable/reference/ generated/numpy.where.html