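You can generate a Cartesian product using numpy's broadcasting rules. The numpy.ix_ function creates the appropriately shaped arrays. It is equivalent to the following: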
>>> import numpy
>>> def pseudo_ix_gen(*arrays):
...     base_shape = [1 for arr in arrays]
...     for dim, arr in enumerate(arrays):
...         # every axis has length 1 except axis `dim`, which holds arr
...         shape = base_shape[:]
...         shape[dim] = len(arr)
...         yield numpy.array(arr).reshape(shape)
...
>>> def pseudo_ix_(*arrays):
...     return list(pseudo_ix_gen(*arrays))
Or, more concisely:
>>> def pseudo_ix_(*arrays):
...     # row i is the shape for arrays[i]: its length on the diagonal, 1 elsewhere
...     shapes = numpy.diagflat([len(a) - 1 for a in arrays]) + 1
...     return [numpy.array(a).reshape(s) for a, s in zip(arrays, shapes)]
The result is a list of broadcastable arrays:
>>> numpy.ix_(*[[2, 4], [1, 3], [0, 2]])
[array([[[2]],
       [[4]]]), array([[[1],
        [3]]]), array([[[0, 2]]])]
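A minimal, self-contained sketch (assuming numpy imported under the conventional alias np) that checks the concise pseudo_ix_ against numpy.ix_ array by array. Converting each shape row to a tuple before reshape is an added precaution, not part of the original:

```python
import numpy as np

def pseudo_ix_(*arrays):
    # Row i of this matrix is the target shape for arrays[i]:
    # len(arrays[i]) on the diagonal, 1 everywhere else.
    shapes = np.diagflat([len(a) - 1 for a in arrays]) + 1
    return [np.array(a).reshape(tuple(s)) for a, s in zip(arrays, shapes)]

inputs = [[2, 4], [1, 3], [0, 2]]
ours = pseudo_ix_(*inputs)
ref = np.ix_(*inputs)

# Each pair should have the same shape and the same values.
matches = all(m.shape == r.shape and np.array_equal(m, r) for m, r in zip(ours, ref))
print(matches)                  # True
print([m.shape for m in ours])  # [(2, 1, 1), (1, 2, 1), (1, 1, 2)]
```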
Compare the numpy.ix_ result with that of numpy.ogrid:
>>> numpy.ogrid[0:2, 0:2, 0:2]
[array([[[0]],
       [[1]]]), array([[[0],
        [1]]]), array([[[0, 1]]])]
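The comparison can also be checked programmatically. This sketch assumes contiguous ranges starting at 0 for the ogrid side, since ogrid only accepts slices; the second part shows numpy.ix_ taking arbitrary index lists:

```python
import numpy as np

# For contiguous ranges starting at 0, ix_ and ogrid build the same open grid.
grid_ix = np.ix_(range(2), range(2), range(2))
grid_og = np.ogrid[0:2, 0:2, 0:2]
same = all(x.shape == y.shape and np.array_equal(x, y)
           for x, y in zip(grid_ix, grid_og))
print(same)  # True

# ix_ also accepts non-contiguous index lists; each one lands on its own axis.
sparse = np.ix_([2, 4], [1, 3], [0, 2])
print([g.ravel().tolist() for g in sparse])  # [[2, 4], [1, 3], [0, 2]]
print([g.shape for g in sparse])             # [(2, 1, 1), (1, 2, 1), (1, 1, 2)]
```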
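As you can see, the arrays have the same shapes, but numpy.ix_ lets you use non-contiguous indices. Now, when we apply numpy's broadcasting rules, we get a Cartesian product: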
>>> list(numpy.broadcast(*numpy.ix_(*[[2, 4], [1, 3], [0, 2]])))
[(2, 1, 0), (2, 1, 2), (2, 3, 0), (2, 3, 2),
(4, 1, 0), (4, 1, 2), (4, 3, 0), (4, 3, 2)]
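A small sketch confirming that broadcasting the open grid visits the same tuples, in the same order, as itertools.product. The elements that numpy.broadcast yields are numpy scalars, so they are converted to plain ints before comparing:

```python
import itertools
import numpy as np

axes = [[2, 4], [1, 3], [0, 2]]

# numpy.broadcast iterates over the broadcast shape (2, 2, 2) in C order,
# so the last axis varies fastest, exactly like itertools.product.
via_broadcast = [tuple(int(v) for v in t) for t in np.broadcast(*np.ix_(*axes))]
via_product = list(itertools.product(*axes))

print(via_broadcast == via_product)  # True
print(len(via_broadcast))            # 8
```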
If, instead of passing the result of numpy.ix_ to numpy.broadcast, we use it to index an array, we get this:
>>> a = numpy.arange(6 ** 4).reshape((6, 6, 6, 6))
>>> a[numpy.ix_(*[[2, 4], [1, 3], [0, 2]])]
array([[[[468, 469, 470, 471, 472, 473],
         [480, 481, 482, 483, 484, 485]],
        [[540, 541, 542, 543, 544, 545],
         [552, 553, 554, 555, 556, 557]]],
       [[[900, 901, 902, 903, 904, 905],
         [912, 913, 914, 915, 916, 917]],
        [[972, 973, 974, 975, 976, 977],
         [984, 985, 986, 987, 988, 989]]]])
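To see what that indexing selects, here is a sketch that rebuilds the same sub-array with explicit loops over the Cartesian product. The names rows, cols and deps are just illustrative labels for the three index lists:

```python
import numpy as np

a = np.arange(6 ** 4).reshape((6, 6, 6, 6))
rows, cols, deps = [2, 4], [1, 3], [0, 2]  # illustrative names for the index lists

sub = a[np.ix_(rows, cols, deps)]
print(sub.shape)  # (2, 2, 2, 6): the fourth, unindexed axis is kept whole

# The same selection, built element by element from the Cartesian product.
expected = np.array([[[a[i, j, k] for k in deps] for j in cols] for i in rows])
print(np.array_equal(sub, expected))  # True
```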
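However, caveat emptor. Broadcastable arrays are useful for indexing, but if you actually want to enumerate the values, you may be better off using itertools.product: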
>>> import itertools
>>> %timeit list(itertools.product(range(5), repeat=5))
10000 loops, best of 3: 196 us per loop
>>> %timeit list(numpy.broadcast(*numpy.ix_(*([range(5)] * 5))))
100 loops, best of 3: 2.74 ms per loop
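%timeit is an IPython magic; outside IPython, a rough equivalent can be sketched with the standard timeit module. The numbers you get will depend on your machine and numpy version, so none are shown here:

```python
import itertools
import timeit
import numpy as np

def with_product():
    return list(itertools.product(range(5), repeat=5))

def with_broadcast():
    return list(np.broadcast(*np.ix_(*([range(5)] * 5))))

# Both enumerate the same 5 ** 5 = 3125 combinations.
assert len(with_product()) == len(with_broadcast()) == 5 ** 5

for fn in (with_product, with_broadcast):
    secs = timeit.timeit(fn, number=20) / 20  # average over 20 calls
    print(f"{fn.__name__}: {secs * 1e3:.3f} ms per call")
```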
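So, given these timings, if you are going to use a for loop anyway, itertools.product is likely to be faster. Still, you can use the method above to get similar data structures in pure numpy: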
>>> pgrid_idx = numpy.ix_(*[[2, 4], [1, 3], [0, 2]])
>>> sub_indices = numpy.rec.fromarrays(numpy.indices((6, 6, 6)))
>>> a[pgrid_idx].reshape((8, 6))
array([[468, 469, 470, 471, 472, 473],
       [480, 481, 482, 483, 484, 485],
       [540, 541, 542, 543, 544, 545],
       [552, 553, 554, 555, 556, 557],
       [900, 901, 902, 903, 904, 905],
       [912, 913, 914, 915, 916, 917],
       [972, 973, 974, 975, 976, 977],
       [984, 985, 986, 987, 988, 989]])
>>> sub_indices[pgrid_idx].reshape((8,))
rec.array([(2, 1, 0), (2, 1, 2), (2, 3, 0), (2, 3, 2),
           (4, 1, 0), (4, 1, 2), (4, 3, 0), (4, 3, 2)],
          dtype=[('f0', '<i8'), ('f1', '<i8'), ('f2', '<i8')])
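If a plain integer array is more convenient than a record array, a related pure-numpy variant expands the numpy.ix_ grids with numpy.broadcast_arrays and stacks them into an (N, k) table of index tuples. This is a sketch of a swapped-in technique, not the approach shown above; the helper name cartesian_product is hypothetical:

```python
import numpy as np

def cartesian_product(*axes):
    """Hypothetical helper: (N, k) array whose rows are the Cartesian product."""
    # ix_ gives k open-grid arrays; broadcast_arrays expands each one
    # to the full grid shape (as views), without enumerating in Python.
    grids = np.broadcast_arrays(*np.ix_(*axes))
    # Stack along a new last axis, then flatten the grid axes into rows.
    return np.stack(grids, axis=-1).reshape(-1, len(axes))

combos = cartesian_product([2, 4], [1, 3], [0, 2])
print(combos.tolist())
# [[2, 1, 0], [2, 1, 2], [2, 3, 0], [2, 3, 2],
#  [4, 1, 0], [4, 1, 2], [4, 3, 0], [4, 3, 2]]

# Each column can serve as an index array into the 4-D array from above.
a = np.arange(6 ** 4).reshape((6, 6, 6, 6))
values = a[tuple(combos.T)]
print(values.shape)  # (8, 6), same rows as a[pgrid_idx].reshape((8, 6))
```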