我遇到一个 scipy 函数,无论传递给它什么,它似乎都会返回一个 numpy 数组。在我的应用程序中,我需要能够仅传递标量和列表,因此唯一的“问题”是,当我将标量传递给函数时,返回一个包含一个元素的数组(当我期望标量时)。我应该忽略这种行为,还是修改该函数以确保在传递标量时返回标量?
示例代码:
#! /usr/bin/env python
import scipy
import scipy.optimize
from numpy import cos
# This a some function we want to compute the inverse of
def f(x):
y = x + 2*cos(x)
return y
# Given y, this returns x such that f(x)=y
def f_inverse(y):
# This will be zero if f(x)=y
def minimize_this(x):
return y-f(x)
# A guess for the solution is required
x_guess = y
x_optimized = scipy.optimize.fsolve(minimize_this, x_guess) # THE PROBLEM COMES FROM HERE
return x_optimized
# If I call f_inverse with a list, a numpy array is returned
print f_inverse([1.0, 2.0, 3.0])
print type( f_inverse([1.0, 2.0, 3.0]) )
# If I call f_inverse with a tuple, a numpy array is returned
print f_inverse((1.0, 2.0, 3.0))
print type( f_inverse((1.0, 2.0, 3.0)) )
# If I call f_inverse with a scalar, a numpy array is returned
print f_inverse(1.0)
print type( f_inverse(1.0) )
# This is the behaviour I expected (scalar passed, scalar returned).
# Adding [0] on the return value is a hackey solution (then thing would break if a list were actually passed).
print f_inverse(1.0)[0] # <- bad solution
print type( f_inverse(1.0)[0] )
在我的系统上,其输出是:
[ 2.23872989 1.10914418 4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989 1.10914418 4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989]
<type 'numpy.ndarray'>
2.23872989209
<type 'numpy.float64'>
我正在使用 MacPorts 提供的 SciPy 0.10.1 和 Python 2.7.3。
SOLUTION
阅读下面的答案后,我决定采用以下解决方案。将返回线替换为f_inverse
with:
if(type(y).__module__ == np.__name__):
return x_optimized
else:
return type(y)(x_optimized)
Here return type(y)(x_optimized)
导致返回类型与调用函数的类型相同。不幸的是,如果 y 是 numpy 类型,这不起作用,所以if(type(y).__module__ == np.__name__)
习惯于使用此处介绍的想法检测 numpy 类型 https://stackoverflow.com/questions/12569452/how-to-identify-numpy-types-in-python/12569453#12569453并将它们排除在类型转换之外。