昨天我使用Spark MLlib的朴素贝叶斯进行手写数字识别,准确率在0.83左右,今天使用了RandomForest
来训练模型,并进行了参数调优。
首先来说说RandomForest
训练分类器时使用到的一些参数:
val numClasses = 10
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 3
val featureSubsetStrategy = "auto"
val impurity = "gini"
val maxDepth = 4
val maxBins = 32
第一次我将树的数目设定为3,每棵树深度为4。下面开始训练模型:
val randomForestModel = RandomForest.trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
与使用朴素贝叶斯时评估准确率方式一样,我使用训练数据来计算准确率:
val nbTotalCorrect = data.map { point =>
if (randomForestModel.predict(point.features) == point.label) 1 else 0
}.sum
val numData = data.count()
println(numData)
val nbAccuracy = nbTotalCorrect / numData
下面是每次对上面所说到的两个参数进行调整后得到的准确率:
可以发现,准确率在numTree=11,maxDepth=12
附近开始收敛到0.999。这次得到的准确率要比上次使用朴素贝叶斯训练得出的准确率(0.826)要高出许多。现在开始对测试数据进行预测,使用的参数是numTree=29,maxDepth=30
:
val predictions = randomForestModel.predict(features).map { p => p.toInt }
把训练出来的结果上传到Kaggle上,得到的准确率为0.95929
,经过我的四次参数调整,得到的最高的准确率是0.96586
,设置的参数是:numTree=55,maxDepth=30
,当我将参数改为numTree=70,maxDepth=30
时,准确率有所下降,为0.96271
,看来这个时候出现过拟合了。不过准确率能从昨天的0.83提高到0.96还是挺兴奋的,我还会继续尝试使用其他方式进行手写数字识别,不知何时能达到1.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)