Machine Learning for dummies – overfitting

Before we can move on to more advanced topics, we should discuss overfitting and underfitting. It reminds me of my high school biology lessons. We were expected to recognize different kinds of flowers. In the lessons we were given images of flowers, and we were supposed to learn that this flower is yellow and has this type of leaves, and that one is blue with a different kind of leaves. But there was an alternative strategy for passing the exam. The images had numbers in the corner, and it was much easier to memorize the numbers together with the corresponding names. You did not learn anything about the flowers, but if your only objective was to pass the exam, it was perfect.

We have exactly the same problem in machine learning. If you do not have many training examples and you have a complicated model, the learning algorithm may just memorize the training samples without trying to generalize. It basically does the same thing my classmates did in the biology lessons. That is great if you only want to classify samples from the training set. If, however, you want to classify unknown samples, you want your algorithm to generalize.
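To see what "memorizing without generalizing" looks like in code, here is a deliberately silly sketch. `LookupTableClassifier` is a hypothetical class made up for this illustration, not part of any library: it stores every training sample together with its label, exactly like memorizing the numbers in the corner, and falls back to the most common label for anything it has never seen. The data is assumed to be random noise with random labels, so there is nothing to generalize from at all.

```python
import numpy as np

class LookupTableClassifier:
    """Hypothetical 'classifier' that only memorizes its training samples."""

    def fit(self, X, y):
        # remember every training vector (as raw bytes) together with its label
        self.table = {x.tobytes(): label for x, label in zip(X, y)}
        # for anything never seen before, guess the most frequent training label
        values, counts = np.unique(y, return_counts=True)
        self.fallback = values[np.argmax(counts)]
        return self

    def predict(self, X):
        # exact lookup if the sample was memorized, otherwise the fallback guess
        return np.array([self.table.get(x.tobytes(), self.fallback) for x in X])

# random "images" with random labels 0-9: there is no pattern to learn
rng = np.random.default_rng(0)
X_train, y_train = rng.random((1000, 784)), rng.integers(0, 10, size=1000)
X_new, y_new = rng.random((1000, 784)), rng.integers(0, 10, size=1000)

clf = LookupTableClassifier().fit(X_train, y_train)
train_acc = np.mean(clf.predict(X_train) == y_train)  # perfect memory: 1.0
new_acc = np.mean(clf.predict(X_new) == y_new)        # around 0.1, i.e. guessing
print(train_acc, new_acc)
```

A real model does not cheat this blatantly, but the symptom is the same one we will be looking for below: great results on the training set, much worse results on samples it has never seen.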

Let's look at some examples. I will use the same algorithm as last time on the notMNIST data; I will just make the training set smaller, let's say 1k samples. In the notMNIST set there are 10 letters, A-J. We have shuffled our training set, so those 1k samples contain around 100 samples of each letter. I am using Logistic Regression in a 784-dimensional space (one dimension per pixel of a 28x28 image), which maps to those 10 classes. There is also a bias element for each class, so the model has 784 × 10 + 10 = 7850 parameters. Parameters are just numbers that the algorithm learns from the samples. We dummies do not need to know more. So the algorithm learns ~8000 numbers from 1000 samples, or, for each letter, ~800 parameters from ~100 samples. There is quite a substantial danger that the model will just memorize the training set.
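If you want to see where those 7850 numbers actually live, scikit-learn exposes them after fitting as `coef_` (one weight per pixel per class) and `intercept_` (one bias per class). The sketch below assumes random stand-in data with the notMNIST shapes, since the values do not matter here, only the shapes of what the model learns.

```python
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

# random stand-in data with notMNIST shapes: 784 pixels per image, 10 classes
rng = np.random.default_rng(0)
X = rng.random((200, 784))
y = np.arange(200) % 10  # make sure every class appears

with warnings.catch_warnings():
    # the data is noise, so we do not care whether the optimizer fully converges
    warnings.simplefilter("ignore", ConvergenceWarning)
    clf = LogisticRegression(max_iter=50).fit(X, y)

# weights (10 x 784) plus biases (10) are all the numbers the model learned
n_params = clf.coef_.size + clf.intercept_.size
print(clf.coef_.shape, clf.intercept_.shape, n_params)  # (10, 784) (10,) 7850
```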

import numpy as np
from sklearn import linear_model

# train_dataset, train_labels, valid_dataset and valid_labels are the notMNIST
# arrays prepared in the previous post
clf_l = linear_model.LogisticRegression()
# we do not want to use the whole set
size = 1000

# images are 28x28, we have to flatten each one into a 784-element vector
tr = train_dataset[:size, :, :].reshape(size, 784)

# This is the most important line. I just feed the model training samples and it learns
clf_l.fit(tr, train_labels[:size])

# This is the prediction on the training set (printed value is the accuracy)
prd_tr = clf_l.predict(tr)
print(np.mean(prd_tr == train_labels[:size]))

# let's check that the model is not cheating; it has never seen these letters before
prd_l = clf_l.predict(valid_dataset.reshape(valid_dataset.shape[0], 784))
print(np.mean(prd_l == valid_labels))

And indeed, the model correctly classifies 99% of the samples from the training set. On its own, that number tells us very little, because a model that had simply memorized the training set would score just as well. That's why we have the validation set: samples the model never saw during training, which we use to check it. It's as if our biology teacher had one set of images for the lessons and another set for the exams. Actually, it's a bit more complicated, since we also have a test set, but we will not be using that one so often, so let's forget about it for now.
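The notMNIST data from the previous post already comes split into training, validation and test sets, but the idea is simple enough to sketch. `split_dataset` below is a hypothetical helper, not a library function: it shuffles the samples once and cuts them into three disjoint parts, so no "exam" image ever shows up in the "lessons". The 10%/10% fractions are just an assumption for the example.

```python
import numpy as np

def split_dataset(X, y, valid_fraction=0.1, test_fraction=0.1, seed=0):
    """Shuffle once and cut into disjoint training, validation and test sets."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_valid = int(len(X) * valid_fraction)
    n_test = int(len(X) * test_fraction)
    valid_idx = idx[:n_valid]
    test_idx = idx[n_valid:n_valid + n_test]
    train_idx = idx[n_valid + n_test:]  # everything left over is for training
    return ((X[train_idx], y[train_idx]),
            (X[valid_idx], y[valid_idx]),
            (X[test_idx], y[test_idx]))

# toy data: 1000 "images" of 784 pixels with labels 0-9
X = np.random.default_rng(1).random((1000, 784))
y = np.arange(1000) % 10
(train_X, train_y), (valid_X, valid_y), (test_X, test_y) = split_dataset(X, y)
print(len(train_X), len(valid_X), len(test_X))  # 800 100 100
```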

Our model recognizes 74% of the samples from our validation set. Not too bad. Luckily for us, Logistic Regression is just trying to find linear (flat) hyperplanes to split the space, so it is not able to cheat much. It cannot simply pick out individual training examples; it has to take a whole half of the space with them. But let's try to force it to cheat anyway. Logistic Regression has a parameter C (as in cheating) which controls how much we allow the algorithm to overfit. Higher C means more cheating, lower C means less cheating. (For those used to the more usual notation, C = 1/lambda, the inverse of the regularization strength.)
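For the curious, this is roughly where C lives. For the binary case with an L2 penalty and labels y_i in {-1, +1}, the scikit-learn user guide writes the objective in the first form below; dividing the whole thing by C (which does not change the minimizer) gives the more familiar lambda form. A large C lets the data term dominate, so the model fits the training set at almost any cost; a small C lets the penalty on large weights dominate.

```latex
\min_{w,b}\; \frac{1}{2}\lVert w\rVert^2 + C\sum_{i=1}^{n}\log\left(1 + e^{-y_i (x_i^\top w + b)}\right)

% dividing by C > 0 leaves the minimizer unchanged:
\min_{w,b}\; \sum_{i=1}^{n}\log\left(1 + e^{-y_i (x_i^\top w + b)}\right) + \frac{\lambda}{2}\lVert w\rVert^2,
\qquad \lambda = \frac{1}{C}
```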

# Cheating level Carnage
clf_l = linear_model.LogisticRegression(C=100)
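To see what the knob actually changes inside the model, here is a small sketch that fits the same model with the two extreme values of C and measures how large the learned weights get. It uses scikit-learn's small built-in 8x8 digits set as a stand-in for notMNIST (an assumption, purely so the example runs on its own). With a tiny C the weights stay close to zero, so the model cannot lean heavily on individual pixels.

```python
import warnings
import numpy as np
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

# built-in 8x8 digit images stand in for notMNIST here
X, y = load_digits(return_X_y=True)

norms = {}
for C in (0.001, 100):
    with warnings.catch_warnings():
        # a very large C can need more iterations than max_iter allows
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf = LogisticRegression(C=C, max_iter=1000).fit(X, y)
    # overall size of all learned weights (biases excluded)
    norms[C] = np.linalg.norm(clf.coef_)
    print(C, norms[C])
```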

If we set the cheating parameter to 100, we get 99.9% accuracy on the training set and 71% on the validation set. If we use a value from the other end of the spectrum, C=0.001, we get 75% on the training set and 76% on the validation set. It's up to us to decide which is better. Here are the results for several values of C with 1k training samples:

C        training set   validation set
0.001    75%            76%
0.01     81%            79%
0.1      92%            78%
1        99%            75%
10       99%            72%
100      99.9%          71%
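A table like this is easy to produce with a loop over C. The sketch below again assumes scikit-learn's built-in digits set instead of notMNIST, with 1000 training samples and the rest used for validation, so the exact percentages will differ from the table above; what to look for is the pattern of training accuracy climbing with C while validation accuracy does not keep up.

```python
import warnings
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, train_size=1000, random_state=0)

results = {}
for C in (0.001, 0.01, 0.1, 1, 10, 100):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf = LogisticRegression(C=C, max_iter=1000).fit(X_train, y_train)
    # score() is the mean accuracy of predict() on the given samples
    results[C] = (clf.score(X_train, y_train), clf.score(X_valid, y_valid))
    print(f"{C:<8} {results[C][0]:.1%}  {results[C][1]:.1%}")
```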

Another way to prevent overfitting is to provide more training examples. Imagine our biology lessons with many more images, each with a different number. Beyond a certain number of images, it is much easier to learn how to recognize those damn flowers than to memorize all the numbers in the corner. The same applies to machine learning algorithms. If I use C=100 with 10k training samples, I get 93% accuracy on the training set and 74% on the validation set. Notice that this is actually worse than most of the results with 1k examples and different C values.

But we can combine both approaches. Let's say I pick C=0.01 and use 20k training samples. I get 83% accuracy on the training set and 82% on the validation set. As you can see, the two numbers are converging; the difference is only one percentage point. The model does barely better on samples it has seen than on samples it has not, so it is no longer memorizing, and at 83% it cannot even fit the training set well. That means we are approaching the limits of our model, and using more samples is not likely to help much. I am afraid we have to move to a better model to get better results. That's it for today. You can go now.
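Here is a sketch of that experiment: keep C fixed and train on more and more samples, watching the gap between training and validation accuracy. As before, it assumes scikit-learn's built-in digits set as a stand-in for notMNIST, so the sizes are much smaller than 10k or 20k and the numbers will not match the ones above. scikit-learn also has a ready-made `sklearn.model_selection.learning_curve` for this; the plain loop just keeps it close to what we did by hand.

```python
import warnings
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
# set a fixed validation set aside, then train on growing slices of the rest
X_pool, X_valid, y_pool, y_valid = train_test_split(
    X, y, test_size=400, random_state=0)

C = 0.01
curve = []
for size in (100, 300, 1000):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf = LogisticRegression(C=C, max_iter=1000).fit(X_pool[:size], y_pool[:size])
    train_acc = clf.score(X_pool[:size], y_pool[:size])
    valid_acc = clf.score(X_valid, y_valid)
    curve.append((size, train_acc, valid_acc))
    # the gap between the two numbers is a rough measure of how much the model memorizes
    print(f"{size:>5} samples: train {train_acc:.1%}, "
          f"valid {valid_acc:.1%}, gap {train_acc - valid_acc:.1%}")
```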

Update: Fixed number of parameters

2 Responses to “Machine Learning for dummies – overfitting”

  1. Miroslav Spousta Says:

    Nice, looking forward to the following posts!

    "It's up to us to decide what is better." -- for judging model performance, the training set result is not very useful; usually you only focus on the validation set performance. There are simple techniques like cross-validation that make a better estimate of model performance (split the data into training/validation sets multiple times and average the results). The reason is that (large enough) data is usually not very homogeneous.
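    The cross-validation procedure described in this comment can be sketched with scikit-learn's `cross_val_score`, which splits the data into k folds and uses each fold once as the validation set. This assumes the built-in digits set as stand-in data and 5 folds; the spread of the scores shows how much a single split could mislead you.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = load_digits(return_X_y=True)
clf = LogisticRegression(C=0.01, max_iter=1000)

# 5 different training/validation splits; each fold is the validation set once
scores = cross_val_score(clf, X, y, cv=5)
print(scores.round(3), scores.mean(), scores.std())
```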

  2. Lukáš Křečan Says:

    You are right, you are not a dummy.
