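I want to generate a precision-recall curve with 5-fold cross-validation that shows the standard deviation, as in the ROC curve example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html.
The code below (adapted from How to Plot PR-Curve Over 10 folds of Cross Validation in Scikit-Learn https://stackoverflow.com/questions/29656550/how-to-plot-pr-curve-over-10-folds-of-cross-validation-in-scikit-learn) gives a PR curve for each cross-validation fold as well as the mean PR curve. I also want to shade, in grey, the region one standard deviation above and below the mean PR curve. However, it raises the following error (details are in the link below the code):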
ValueError: operands could not be broadcast together with shapes (91,) (78,)
import matplotlib.pyplot as plt
import numpy
from sklearn.datasets import make_blobs
from sklearn.metrics import precision_recall_curve, auc
from sklearn.model_selection import KFold
from sklearn.svm import SVC
X, y = make_blobs(n_samples=500, n_features=2, centers=2, cluster_std=10.0,
                  random_state=10)
k_fold = KFold(n_splits=5, shuffle=True, random_state=10)
predictor = SVC(kernel='linear', C=1.0, probability=True, random_state=10)
y_real = []
y_proba = []
precisions, recalls = [], []
for i, (train_index, test_index) in enumerate(k_fold.split(X)):
    Xtrain, Xtest = X[train_index], X[test_index]
    ytrain, ytest = y[train_index], y[test_index]
    predictor.fit(Xtrain, ytrain)
    pred_proba = predictor.predict_proba(Xtest)
    precision, recall, _ = precision_recall_curve(ytest, pred_proba[:,1])
    lab = 'Fold %d AUC=%.4f' % (i+1, auc(recall, precision))
    plt.plot(recall, precision, alpha=0.3, label=lab)
    y_real.append(ytest)
    y_proba.append(pred_proba[:,1])
    precisions.append(precision)
    recalls.append(recall)
y_real = numpy.concatenate(y_real)
y_proba = numpy.concatenate(y_proba)
precision, recall, _ = precision_recall_curve(y_real, y_proba)
lab = 'Overall AUC=%.4f' % (auc(recall, precision))
plt.plot(recall, precision, lw=2,color='red', label=lab)
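# Attempted standard-deviation band. The numpy.std call is where the ValueError
# is raised: the per-fold precision arrays have different lengths.
std_precision = numpy.std(precisions, axis=0)
upper_precision = numpy.minimum(precision + std_precision, 1)
lower_precision = numpy.maximum(precision - std_precision, 0)
plt.fill_between(recall, upper_precision, lower_precision, alpha=0.5, linewidth=0, color='grey')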
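The reported error and the generated plot: https://i.stack.imgur.com/5ZPKh.png
Could you suggest how I can extend the code above to also show one standard deviation around the mean PR curve?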
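One possible way to get the band, sketched under some assumptions: `precision_recall_curve` returns one point per distinct threshold, so each fold produces arrays of a different length (91 and 78 in the error above), and they cannot be stacked for a per-point standard deviation. Following the idea of the linked ROC example, each fold's curve can first be interpolated onto a shared recall grid. The helper name `mean_std_precision`, its `n_points` parameter, and the toy fold arrays below are hypothetical and only for illustration; linear interpolation of a PR curve is itself an approximation.

```python
import numpy as np


def mean_std_precision(precisions, recalls, n_points=100):
    """Interpolate per-fold PR curves onto a shared recall grid.

    Hypothetical helper: returns (recall_grid, mean_precision, std_precision).
    """
    recall_grid = np.linspace(0.0, 1.0, n_points)
    interpolated = []
    for precision, recall in zip(precisions, recalls):
        # precision_recall_curve returns recall in decreasing order, while
        # np.interp expects increasing x values, so reverse both arrays.
        rec = np.asarray(recall)[::-1]
        prec = np.asarray(precision)[::-1]
        interpolated.append(np.interp(recall_grid, rec, prec))
    interpolated = np.array(interpolated)  # shape: (n_folds, n_points)
    return recall_grid, interpolated.mean(axis=0), interpolated.std(axis=0)


# Toy stand-ins for two folds with different numbers of points
# (made-up values, laid out like precision_recall_curve output).
fold_recalls = [np.array([1.0, 0.5, 0.0]),
                np.array([1.0, 0.75, 0.5, 0.25, 0.0])]
fold_precisions = [np.array([0.5, 0.75, 1.0]),
                   np.array([0.5, 0.6, 0.7, 0.9, 1.0])]

recall_grid, mean_precision, std_precision = mean_std_precision(
    fold_precisions, fold_recalls, n_points=5)

# Clip the band to the valid precision range [0, 1].
upper_precision = np.minimum(mean_precision + std_precision, 1)
lower_precision = np.maximum(mean_precision - std_precision, 0)
```

In the question's script, the helper would be called after the loop with the `precisions` and `recalls` lists collected there, followed by `plt.plot(recall_grid, mean_precision)` and `plt.fill_between(recall_grid, lower_precision, upper_precision, alpha=0.5, linewidth=0, color='grey')`. Note that the band is then centred on the mean of the interpolated fold curves, which may differ slightly from the pooled "Overall" curve computed from the concatenated predictions.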