MNIST with scikit-learn

Classifying the MNIST handwritten digits with scikit-learn takes only a few lines of code:

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

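# Download the 70,000 MNIST digits from OpenML (cached under ./data).
# as_frame=False returns plain NumPy arrays; the labels come back as strings.
mnist = datasets.fetch_openml(
    'mnist_784', version=1, data_home='data', as_frame=False)
train_data, test_data, train_labels, test_labels = train_test_split(
    mnist.data, mnist.target)

# Train a k-nearest neighbours classifier with the default settings
clf = KNeighborsClassifier()
clf.fit(train_data, train_labels)

# Evaluate on the held-out test data
predicted_labels = clf.predict(test_data)
print('Prediction accuracy:', accuracy_score(test_labels, predicted_labels))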

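I get 97% accuracy, which is great for such a simple implementation. The exact figure varies slightly between runs because train_test_split() shuffles the data randomly.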

With train_test_split() I split the data into training and test sets (by default 75% for training and 25% for testing). For simplicity I used a k-nearest neighbours classifier, but other classifiers could be used too (http://brianfarris.me/static/digit_recognizer.html).
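
To make the split and the choice of classifier more concrete, here is a minimal sketch. It rests on a few assumptions that are not part of the code above: it uses scikit-learn's small bundled digits dataset (load_digits) as a stand-in for MNIST so that it runs without a download, it sets test_size and random_state explicitly, and it picks LogisticRegression purely as an example of "another classifier".

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Assumption: the bundled 8x8 digits dataset stands in for MNIST here.
digits = load_digits()

# test_size=0.25 matches the default split; random_state (an arbitrary
# choice) makes the split reproducible between runs.
train_data, test_data, train_labels, test_labels = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)

# Any estimator with fit() and predict() can be swapped in.
classifiers = {
    'k-nearest neighbours': KNeighborsClassifier(),
    'logistic regression': LogisticRegression(max_iter=5000),
}

accuracies = {}
for name, clf in classifiers.items():
    clf.fit(train_data, train_labels)
    accuracies[name] = accuracy_score(test_labels, clf.predict(test_data))
    print(name, 'accuracy:', accuracies[name])
```

The numbers this prints are for the small digits dataset, so they say nothing about how the two classifiers compare on the full MNIST data.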

Reusability

clf.fit() takes some time to train the classifier. But we can store the trained classifier on disk and reuse it multiple times:

import joblib

# Train once and write the fitted classifier to disk
clf = KNeighborsClassifier()
clf.fit(train_data, train_labels)
joblib.dump(clf, 'classifier_model.pkl')

# ...

# Later: load the stored classifier instead of training again
clf = joblib.load('classifier_model.pkl')
predicted_labels = clf.predict(test_data)

The serialized classifier takes up to 520 MB on disk, largely because a k-nearest neighbours model keeps the entire training set.
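
If the file size is a concern, joblib.dump() accepts a compress argument. The following is only a sketch under some assumptions: it again uses the small bundled digits dataset instead of MNIST, writes to a temporary directory, and uses compression level 3 as an arbitrary choice. How much space this saves on the real MNIST classifier is not something I measured here.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier

# Assumption: the bundled 8x8 digits dataset stands in for MNIST here.
digits = load_digits()
clf = KNeighborsClassifier().fit(digits.data, digits.target)

with tempfile.TemporaryDirectory() as tmp_dir:
    plain_path = os.path.join(tmp_dir, 'classifier_model.pkl')
    compressed_path = os.path.join(tmp_dir, 'classifier_model_compressed.pkl')

    # Same model, stored without and with compression
    joblib.dump(clf, plain_path)
    joblib.dump(clf, compressed_path, compress=3)

    plain_size = os.path.getsize(plain_path)
    compressed_size = os.path.getsize(compressed_path)
    print('uncompressed bytes:', plain_size)
    print('compressed bytes:', compressed_size)

    # Loading works the same way for compressed files
    restored = joblib.load(compressed_path)

# The restored classifier should predict exactly like the original one
same_predictions = np.array_equal(
    clf.predict(digits.data[:100]), restored.predict(digits.data[:100]))
print('same predictions after reload:', same_predictions)
```

Compression trades disk space for extra time when saving and loading, so whether it pays off depends on how often the model is reloaded.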
