# 预测脸的下半部分
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_olivetti_faces
from sklearn.utils.validation import check_random_state
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
data = fetch_olivetti_faces()
targets = data.target
print data.images
data = data.images.reshape((len(data.images), -1))
train = data[targets < 30] # data 和 target 的行数必须要相同
test = data[targets >= 30]
n_faces = 5
rng = check_random_state(4)
"""
check_random_state 函数说明:
将种子转换为np.random.RandomState实例
如果seed为None,返回np.random使用的RandomState单例。
如果seed是一个int,返回一个新的RandomState实例种子。
如果seed已经是一个RandomState实例,则返回它。
否则引发ValueError。
提出修改建议
"""
face_ids = rng.randint(test.shape[0], size=(n_faces, ))
test = test[face_ids, :]
n_pixels = data.shape[1]
# 将图像分为上下部分
X_train = train[:, :np.ceil(0.5 * n_pixels)]
y_train = train[:, np.floor(0.5 * n_pixels):]
X_test = test[:, :np.ceil(0.5 * n_pixels)]
y_test = test[:, np.floor(0.5 * n_pixels):]
ESTIMATORS = {
"Extra trees": ExtraTreesRegressor(n_estimators=10, max_features=32, random_state=0),
"K-nn": KNeighborsRegressor(),
"Linear regression": LinearRegression(),
"Ridge": RidgeCV(),
}
y_test_predict = dict()
for name, estimator in ESTIMATORS.items():
estimator.fit(X_train, y_train)
y_test_predict[name] = estimator.predict(X_test)
image_shape = (64, 64)
n_cols = 1 + len(ESTIMATORS)
plt.figure(figsize=(2. * n_cols, 2.26 * n_faces))
plt.suptitle("Face completion with multi-output estimators", size=16)
for i in range(n_faces):
true_face = np.hstack((X_test[i], y_test[i]))
"""
hstack 函数说明:
使多维数组变为一维数组,如:
>>> c
array([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
>>> np.hstack(c)
array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
或者使两个数组合并为一个数组(但是行数必须要相同),如:
>>> np.hstack((a, c))
array([[1, 4, 7, 1, 2, 3],
[2, 5, 8, 4, 5, 6],
[3, 6, 9, 7, 8, 9]], dtype=int64)
其它类似的函数如下:
stack:沿着新轴连接数组序列。
vstack:按照垂直(按行)顺序堆叠数组。
dstack:按照深度顺序堆叠数组(沿第三轴)。
concatenate:沿着现有轴连接数组序列。
hsplit:沿第二轴拆分数组
"""
if i:
sub = plt.subplot(n_faces, n_cols, i * n_cols + 1)
else:
sub = plt.subplot(n_faces, n_cols, i * n_cols + 1, title="true faces")
sub.axis("off")
sub.imshow(true_face.reshape(image_shape), cmap=plt.cm.gray, interpolation="nearest")
for j, est in enumerate(sorted(ESTIMATORS)):
completed_face = np.hstack((X_test[i], y_test_predict[est][i]))
if i:
sub = plt.subplot(n_faces, n_cols, i * n_cols + 2 + j)
else:
sub = plt.subplot(n_faces, n_cols, i * n_cols + 2 + j, title=est)
sub.axis("off")
sub.imshow(completed_face.reshape(image_shape), cmap=plt.cm.gray, interpolation="nearest")
plt.show()