机器学习好伙伴之scikit-learn的使用——学习曲线
什么是学习曲线呢,其内容主要包含当训练量增加时,loss的变化情况。
什么是学习曲线
学习曲线主要反应的是学习的一个过程,常用的表示方法是训练集的loss和测试集的loss与训练量之间的关系。其示意图如下:
sklearn中学习曲线的实现
在进行学习曲线的绘制之前,首先要导入学习曲线的绘制的模块。
from sklearn.model_selection import learning_curve
学习曲线的绘制的重要函数是:
learning_curve(
estimator,
X, y,
train_sizes=array([0.1, 0.325, 0.55, 0.775, 1. ]),
cv=None,
scoring=None,
exploit_incremental_learning=False,
n_jobs=1,
pre_dispatch='all',
verbose=0
)
其常用参数如下:
1、estimator:用于预测的模型
2、X:预测的特征数据
3、y:预测结果
4、train_sizes:训练样本相对的或绝对的数字,这些量的样本将会生成learning curve,当其为[0.1, 0.325, 0.55, 0.775, 1. ]时代表使用10%训练集训练,32.5%训练集训练,55%训练集训练,77.5%训练集训练100%训练集训练时的分数。
5、cv:交叉验证生成器或可迭代的次数
6、scoring:调用的方法
可进行的scoring方式具体可以查阅https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
使用方式如下:
train_sizes, train_loss, test_loss = learning_curve(
SVC(gamma=0.01), X, y, cv=10, scoring='neg_mean_squared_error',
train_sizes=np.linspace(.1, 1.0, 5))
代表使用SVM的分类模型,输入特征为X,输出label为y,进行10折交叉验证,通过均值平方差的方式计分,学习曲线分为5段。
其一共具有3个返回值,分别是train_sizes, train_loss, test_loss,其中train_loss指的是训练集的loss,其shape为(5,10),第n行对应学习曲线的第n段,第n行的内容代表着第n段的10折交叉验证的结果;test_loss的含义与train_loss类似,其对应的是测试集的loss。
应用示例
代码源自莫烦python教学网站
# 学习曲线模块
from sklearn.model_selection import learning_curve
# 导入digits数据集
from sklearn.datasets import load_digits
# 支持向量机
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
# neg_mean_squared_error代表求均值平方差
train_sizes, train_loss, test_loss = learning_curve(
SVC(gamma=0.01), X, y, cv=10, scoring='neg_mean_squared_error',
train_sizes=np.linspace(.1, 1.0, 5))
# loss值为负数,需要取反
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
# 设置样式与label
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",
label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",
label="Cross-validation")
plt.xlabel("Training examples")
plt.ylabel("Loss")
# 显示图例
plt.legend(loc="best")
plt.show()
实验结果为:
如上图所示的训练结果存在过拟合的现象。
调整GAMMA = 0.001后过拟合现象消失,Cross-validation不再上升。