Building our first classification model
If the goal is to separate the three types of flowers, we can immediately make a few suggestions just by looking at the data. For example, petal length seems to be able to separate Iris Setosa from the other two flower species on its own.
Intuitively, we can build a simple model in our heads: if the petal length is smaller than about 2, then this is an Iris Setosa flower; otherwise, it is either Iris Virginica or Iris Versicolor. Machine learning is when we write code to search for this kind of separation automatically.
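The observation that petal length on its own separates Iris Setosa can be checked directly. Here is a minimal sketch, assuming the `features` and `labels` arrays come from scikit-learn's `load_iris` helper (the chapter may load the data differently):

```python
from sklearn.datasets import load_iris

# Load the iris data; label 0 corresponds to Iris Setosa.
data = load_iris()
features, labels = data.data, data.target

# Petal length is the third column. Every Setosa sample has a petal
# length below 2, and every other sample is above it.
is_setosa = features[:, 2] < 2
print((is_setosa == (labels == 0)).all())  # True
```

The threshold rule agrees with the true labels on every single sample, which is why this species is so easy to recognize.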
The problem of telling Iris Setosa apart from the other two species was very easy. However, no cut immediately suggests itself for distinguishing Iris Virginica from Iris Versicolor. In fact, we can see that we will never achieve perfect separation with a simple rule of the form: if feature X is above a certain value, then A; otherwise B.
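We can confirm this claim by brute force: for each feature, try every observed value as a cut point between Iris Versicolor and Iris Virginica and record the best accuracy a single cut achieves. This is a sketch under the assumption that the data comes from scikit-learn's `load_iris`:

```python
import numpy as np
from sklearn.datasets import load_iris

data = load_iris()
features, labels = data.data, data.target

# Keep only Versicolor (label 1) and Virginica (label 2).
mask = labels != 0
feats, labs = features[mask], labels[mask]

# For each feature, try every observed value as a threshold and
# record the best accuracy any single cut achieves.
for fi, name in enumerate(data.feature_names):
    best = 0.0
    for t in feats[:, fi]:
        pred = feats[:, fi] > t
        # Accept either orientation of the rule.
        acc = max(np.mean(pred == (labs == 2)),
                  np.mean(pred == (labs == 1)))
        best = max(best, acc)
    print("{}: best single cut achieves {:.1%}".format(name, best))
```

No feature reaches 100 percent on its own, which motivates combining multiple rules.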
We can try combining multiple rules in a decision tree. This is one of the simplest models for classification and was one of the first models proposed for machine learning. It has the further advantage that the resulting model is simple to interpret.
With scikit-learn, it is easy to learn a decision tree:
from sklearn import tree
tr = tree.DecisionTreeClassifier(min_samples_leaf=10)
tr.fit(features, labels)
That's it. Visualizing the tree requires that we first write it out to a file in dot format and then display it:
import graphviz
tree.export_graphviz(tr, feature_names=feature_names, rounded=True, out_file='decision.dot')
graphviz.Source(open('decision.dot').read())
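If Graphviz is not installed, newer versions of scikit-learn can also render a fitted tree as plain text with `tree.export_text`. This is a self-contained sketch (refitting the same model on data loaded via `load_iris`), shown as a convenient alternative rather than the chapter's method:

```python
from sklearn import tree
from sklearn.datasets import load_iris

data = load_iris()
tr = tree.DecisionTreeClassifier(min_samples_leaf=10)
tr.fit(data.data, data.target)

# Text rendering of the fitted tree; no external tools required.
print(tree.export_text(tr, feature_names=list(data.feature_names)))
```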
We can see that the first split is on petal width and results in two nodes: one in which all the samples belong to the first class (denoted by [50, 0, 0]), and one containing the rest of the data ([0, 50, 50]).
How good is this model? We can try it out by applying it to the training data (using the predict method) and seeing how well the predictions match the labels:
import numpy as np
prediction = tr.predict(features)
print("Accuracy: {:.1%}".format(np.mean(prediction == labels)))
This prints out the accuracy: 96.0 percent.
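As an aside, scikit-learn classifiers also expose a score method that computes this same mean accuracy in a single call. A sketch, again refitting on data loaded via `load_iris`:

```python
from sklearn import tree
from sklearn.datasets import load_iris

data = load_iris()
tr = tree.DecisionTreeClassifier(min_samples_leaf=10)
tr.fit(data.data, data.target)

# score() returns the mean accuracy on the given data, equivalent
# to the manual np.mean(prediction == labels) computation.
print("Accuracy: {:.1%}".format(tr.score(data.data, data.target)))
```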