Machine Learning with Scala Quick Start Guide

Preparing training data and training a classifier

Next, we split the data into a training set and a test set. Let's say that 80% of the dataset will be used for training and the remaining 20% will be used to evaluate the trained model:

val splits = numericDF.randomSplit(Array(0.8, 0.2))
val trainDF = splits(0)
val testDF = splits(1)
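Because randomSplit samples rows at random, the two sets only approximate the 80/20 ratio and change from run to run unless a seed is supplied. The following is a minimal sketch of a reproducible split; the toy DataFrame, the helper name splitWithSeed, and the seed value are hypothetical stand-ins, not part of the original example:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object SeededSplitSketch {
  // Hypothetical helper: the same 80/20 weights as above, plus a fixed seed
  // so that repeated runs over the same data yield the same split.
  def splitWithSeed(df: DataFrame, seed: Long): (DataFrame, DataFrame) = {
    val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed)
    (train, test)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SeededSplitSketch")
      .master("local[*]")
      .getOrCreate()

    // Stand-in for numericDF: 1,000 rows with a single id column.
    val toyDF = spark.range(0, 1000).toDF("id")

    val (trainDF, testDF) = splitWithSeed(toyDF, 12345L)

    // The sizes are close to, but not exactly, 800 and 200.
    println(s"train rows: ${trainDF.count()}, test rows: ${testDF.count()}")

    spark.stop()
  }
}
```

Fixing the seed is mainly useful when comparing models, since every candidate is then trained and evaluated on the same rows.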

Instantiate a decision tree classifier, specifying the impurity measure, the maximum number of bins, and the maximum depth of the tree. We also set the label and feature columns:

val dt = new DecisionTreeClassifier()
  .setImpurity("gini")
  .setMaxBins(10)
  .setMaxDepth(30)
  .setLabelCol("label")
  .setFeaturesCol("features")
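To make the "gini" setting more concrete, here is a small sketch in plain Scala of how Gini impurity can be computed for the labels that reach a single tree node. It illustrates the formula only; it is not Spark's internal implementation, and the object and function names are hypothetical:

```scala
object GiniSketch {
  // Gini impurity of a node: 1 - sum over classes of p_k^2,
  // where p_k is the fraction of the node's rows carrying label k.
  // An empty node is treated as pure (impurity 0.0) by assumption.
  def gini(labels: Seq[Double]): Double = {
    if (labels.isEmpty) 0.0
    else {
      val n = labels.size.toDouble
      val sumOfSquares = labels
        .groupBy(identity)
        .values
        .map { group =>
          val p = group.size / n
          p * p
        }
        .sum
      1.0 - sumOfSquares
    }
  }

  def main(args: Array[String]): Unit = {
    println(gini(Seq(1.0, 1.0, 1.0, 1.0))) // pure node: prints 0.0
    println(gini(Seq(0.0, 0.0, 1.0, 1.0))) // evenly mixed binary node: prints 0.5
  }
}
```

A split is attractive when the child nodes it produces have lower impurity than the parent, which is the quantity the tree tries to reduce at each step.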

Now that the data and the classifier are ready, we can perform the training:

val dtModel = dt.fit(trainDF)
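As a rough end-to-end illustration of the steps so far, the sketch below fits the same kind of classifier on a small synthetic dataset, inspects the fitted tree, and scores it on the held-out 20%. The toy data, the seed, the shallower maxDepth of 5, and the helper name trainAndScore are all assumptions made for this example; with the real data, trainDF and testDF from above would take their place:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object EvaluateTreeSketch {
  // Hypothetical helper: trains on toy data and returns test-set accuracy.
  def trainAndScore(spark: SparkSession): Double = {
    import spark.implicits._

    // Toy data: the label is 1.0 when the first feature exceeds 0.5.
    val toyDF = (0 until 200).map { i =>
      val x = i / 200.0
      (if (x > 0.5) 1.0 else 0.0, Vectors.dense(x, (i % 7).toDouble))
    }.toDF("label", "features")

    val Array(trainDF, testDF) = toyDF.randomSplit(Array(0.8, 0.2), 42L)

    val dt = new DecisionTreeClassifier()
      .setImpurity("gini")
      .setMaxBins(10)
      .setMaxDepth(5) // shallower than the 30 used above; the toy data is simple
      .setLabelCol("label")
      .setFeaturesCol("features")

    val dtModel = dt.fit(trainDF)

    // Inspect the learned tree: its size and its if/else split rules.
    println(s"depth = ${dtModel.depth}, nodes = ${dtModel.numNodes}")
    println(dtModel.toDebugString)

    // transform appends a "prediction" column to the held-out rows.
    val predictions = dtModel.transform(testDF)

    new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(predictions)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("EvaluateTreeSketch")
      .master("local[*]")
      .getOrCreate()

    println(s"Test accuracy: ${trainAndScore(spark)}")

    spark.stop()
  }
}
```

Printing toDebugString is a quick way to check that the tree splits on the features you expect before trusting the accuracy figure.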