Machine Learning Lesson of the Day – Cross-Validation
January 17, 2014
Validation is a good way to assess the predictive accuracy of a supervised learning algorithm, and the rule of thumb of using 70% of the data for training and 30% for validation generally works well. However, what if the data set is not very large, so that the small amount of training data leads to high sampling error? A good way to overcome this problem is K-fold cross-validation.
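The 70/30 holdout split described above can be sketched in a few lines of Python. This is a minimal sketch, not the post's own procedure: `holdout_split` is a hypothetical helper, and shuffling the data before splitting is an assumption (the post does not say how the split is made).

```python
import random


def holdout_split(data, train_fraction=0.7, seed=0):
    """Hypothetical helper: shuffle a copy of the data, then split it into training and validation sets."""
    rng = random.Random(seed)
    shuffled = list(data)
    rng.shuffle(shuffled)  # assumption: a random split rather than, e.g., a time-ordered one
    n_train = round(train_fraction * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]


# 100 made-up observations: 70 go to training, 30 to validation.
train, valid = holdout_split(range(100))
print(len(train), len(valid))  # 70 30
```

With only 20 observations, the same rule would leave just 14 for training, which is exactly the small-sample situation the next paragraph is about.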
Cross-validation is best defined by describing its steps:
For each model under consideration,
1. Divide the data set into K partitions (folds) of roughly equal size, usually after shuffling it randomly.
2. Designate the first partition as the validation set and the remaining K-1 partitions together as the training set.
3. Use the training set to train the algorithm.
4. Use the validation set to assess the predictive accuracy of the algorithm; a common measure of predictive accuracy (for regression) is the mean squared error.
5. Repeat Steps 2-4 with the second partition, the third partition, … , the (K-1)th partition, and the Kth partition as the validation set. (Essentially, rotate the role of validation set through every partition, so that each observation is used for validation exactly once.)
6. Calculate the average of the K mean squared errors.
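The steps above can be sketched in plain Python. This is a minimal sketch under stated assumptions: the model is a hypothetical one-variable least-squares line (`fit_line`), the data are shuffled before being dealt into folds, and all helper names are invented for illustration rather than taken from any library.

```python
import random


def k_fold_indices(n, k, seed=0):
    """Shuffle indices 0..n-1 and deal them into k folds whose sizes differ by at most 1."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # Fold i takes every k-th shuffled index, starting at position i.
    return [indices[i::k] for i in range(k)]


def fit_line(xs, ys):
    """Hypothetical model: ordinary least-squares line y = a + b * x; returns a predictor."""
    n = len(xs)
    x_bar = sum(xs) / n
    y_bar = sum(ys) / n
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = y_bar - slope * x_bar
    return lambda x: intercept + slope * x


def mse(predict, xs, ys):
    """Mean squared error of a prediction function on (xs, ys)."""
    return sum((y - predict(x)) ** 2 for x, y in zip(xs, ys)) / len(xs)


def k_fold_cv_mse(xs, ys, k, fit, seed=0):
    """Average validation MSE over k folds."""
    folds = k_fold_indices(len(xs), k, seed)  # Step 1: partition the data
    fold_errors = []
    for fold in folds:  # Step 5: rotate the validation role through every fold
        held_out = set(fold)
        # Step 2: this fold validates; all other folds form the training set.
        train = [i for i in range(len(xs)) if i not in held_out]
        # Step 3: train on the training set.
        model = fit([xs[i] for i in train], [ys[i] for i in train])
        # Step 4: measure the MSE on the validation fold.
        fold_errors.append(mse(model, [xs[i] for i in fold], [ys[i] for i in fold]))
    return sum(fold_errors) / k  # Step 6: average the K mean squared errors


# Usage on synthetic, noisy data (made up for illustration).
rng = random.Random(1)
xs = [i / 10 for i in range(50)]
ys = [2 * x + 1 + rng.gauss(0, 0.5) for x in xs]
print(k_fold_cv_mse(xs, ys, k=5, fit=fit_line))
```

Dealing shuffled indices round-robin (`indices[i::k]`) is one simple way to get folds of nearly equal size when N is not a multiple of K.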
Compare the average mean squared errors of all models and pick the one with the smallest average mean squared error as the best model. Finally, test the chosen model on a separate data set (called the test set), which was used neither for training nor for validation, to assess its predictive accuracy on new, fresh data.
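Model selection plus a final test can be sketched as follows. Everything here is an illustrative assumption: the two candidate models (`fit_mean`, which ignores x, and `fit_line`), the 20% test fraction, and `select_and_test` itself are hypothetical. The sketch also refits the winning model on all non-test data before testing it; that is a common choice, but the post does not specify it.

```python
import random


def mse(predict, xs, ys):
    """Mean squared error of a prediction function on (xs, ys)."""
    return sum((y - predict(x)) ** 2 for x, y in zip(xs, ys)) / len(xs)


def fit_mean(xs, ys):
    """Candidate model 1: ignore x and always predict the training mean of y."""
    y_bar = sum(ys) / len(ys)
    return lambda x: y_bar


def fit_line(xs, ys):
    """Candidate model 2: ordinary least-squares line y = a + b * x."""
    x_bar = sum(xs) / len(xs)
    y_bar = sum(ys) / len(ys)
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = y_bar - slope * x_bar
    return lambda x: intercept + slope * x


def k_fold_cv_mse(xs, ys, k, fit, seed=0):
    """Average validation MSE over k shuffled folds."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    errors = []
    for fold in folds:
        held_out = set(fold)
        train = [i for i in idx if i not in held_out]
        model = fit([xs[i] for i in train], [ys[i] for i in train])
        errors.append(mse(model, [xs[i] for i in fold], [ys[i] for i in fold]))
    return sum(errors) / k


def select_and_test(xs, ys, candidates, k=5, test_fraction=0.2, seed=0):
    """Pick the candidate with the smallest CV error, then score it once on a held-out test set."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    n_test = round(test_fraction * len(idx))
    test, rest = idx[:n_test], idx[n_test:]
    x_rest, y_rest = [xs[i] for i in rest], [ys[i] for i in rest]
    # Cross-validate every candidate on the non-test data only.
    cv_scores = {name: k_fold_cv_mse(x_rest, y_rest, k, fit, seed)
                 for name, fit in candidates.items()}
    best = min(cv_scores, key=cv_scores.get)
    # Assumption: refit the winner on all non-test data before testing it.
    final_model = candidates[best](x_rest, y_rest)
    test_mse = mse(final_model, [xs[i] for i in test], [ys[i] for i in test])
    return best, cv_scores, test_mse


rng = random.Random(2)
xs = [i / 10 for i in range(60)]
ys = [3 * x - 2 + rng.gauss(0, 1.0) for x in xs]
best, cv_scores, test_mse = select_and_test(xs, ys, {"mean": fit_mean, "line": fit_line})
print(best, cv_scores, test_mse)
```

The test set is carved off before any cross-validation runs, so it plays no part in choosing the model.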
If the data set contains N observations and K = N, this type of K-fold cross-validation has a special name: leave-one-out cross-validation (LOOCV), because each validation set consists of a single observation.
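Here is a sketch of LOOCV for a deliberately trivial, hypothetical model that predicts the mean of the training responses. With that model each held-out error has a closed form: leaving out observation i shifts the mean so that the residual becomes n/(n-1) times (y_i minus the full-sample mean). This makes the result easy to check by hand. The four data values are made up.

```python
def fit_mean(train_ys):
    """Hypothetical model: predict the mean of the training responses."""
    return sum(train_ys) / len(train_ys)


def loocv_mse(ys, fit):
    """K-fold cross-validation with K = N: each validation fold holds exactly one observation."""
    n = len(ys)
    squared_errors = []
    for i in range(n):
        train = ys[:i] + ys[i + 1:]  # every observation except the i-th
        prediction = fit(train)
        squared_errors.append((ys[i] - prediction) ** 2)
    # Each fold's MSE is a single squared error, so their average is the LOOCV estimate.
    return sum(squared_errors) / n


ys = [1.0, 2.0, 4.0, 7.0]
print(round(loocv_mse(ys, fit_mean), 4))  # 9.3333
```

By hand: the mean is 3.5, the squared deviations sum to 21, and (4/3)^2 × 21/4 = 84/9 ≈ 9.3333, matching the loop.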
There are some trade-offs between a large and a small K. Compared with a smaller K, a larger K leads to
- less bias in the estimate of the prediction error, because more data are used for training in each fold
- higher variance in that estimate, because the training sets overlap more and are therefore more similar to one another, so the K error estimates are more highly correlated
- slower computation, because the algorithm must be trained K times, each time on a larger training set
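The first and third trade-offs come down to simple arithmetic, sketched below with a hypothetical helper `cv_cost`. Treating training cost as proportional to the number of training points processed is an assumption that holds only roughly for many algorithms. The variance trade-off cannot be shown this way.

```python
def cv_cost(n, k):
    """Hypothetical helper: for N observations and K folds, return
    (smallest training-set size, total training points processed over all K fits)."""
    largest_fold = -(-n // k)  # ceiling of n / k
    smallest_train = n - largest_fold
    # Each observation is validated exactly once, so it is trained on in K - 1 of the K fits.
    total_training_points = n * (k - 1)
    return smallest_train, total_training_points


for k in (2, 5, 10, 100):
    train_size, total = cv_cost(100, k)
    print(f"K={k}: {k} fits, training set of {train_size}, {total} training points in total")
```

For N = 100, this prints training sets of 50, 80, 90 and 99 observations for K = 2, 5, 10 and 100, at a cost of 100, 400, 900 and 9900 training points processed. So moving from K = 10 to LOOCV adds only 9 observations per training set but about eleven times the work.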
In The Elements of Statistical Learning (2nd edition, 2009, Chapter 7, pages 241-243), Hastie, Tibshirani and Friedman recommend K = 5 or K = 10 as a good compromise.