Learning from Imbalanced Data

Most machine learning tutorials work with the datasets built into libraries like TensorFlow or PyTorch. These datasets are excellent for understanding the concepts behind a given algorithm. But we all know that such data is far from what we see when working on real-life problems. In real life, a huge amount of effort goes into processing the available data to extract a useful, meaningful dataset that can train the AI models.

There are many different types of problems one might face in doing this. Each problem is unique and deserves a detailed analysis in order to fix it. But there are some common classes of problems that show up in almost all projects. Researchers have worked on them and provided typical solutions that can help us overcome the problem at hand. Again, each problem is unique, and it is possible that none of these solutions work for “this” problem. But these solutions certainly give us a direction for our analysis.

One of the major problems we face is imbalanced data. Models learn from the data we provide. If the data is imbalanced, the resulting model is bound to be imbalanced — increasing the chances of over-fitting the imbalance. Consider, for example, training a model to identify network intrusion based on various traffic parameters. We can gather data from random samples over a year. That means a huge amount of data. But if we think about it, most of the year was spent safely — with very few intrusion attempts. If we train the model on such data, it will end up learning that the world is a safe place with very few intrusions. Such a bias towards “no intrusion” is disastrous for the security of the system.

The data we use has to be real — but it should be balanced. Each new packet should be treated as a potential intrusion. Thus, the data we use should have comparable numbers of positives and negatives. Now we have a problem. We have a good amount of intrusion data, but it is tiny compared to the rest. Do we discard most of the available data just to achieve this “balance”? Was all the effort spent on collecting the data wasted? Not really. There are a few techniques that help us overcome this imbalance.

  • Custom Loss Function
  • Partial Augmentation
  • Clustered Ensembles

Let us look at each of these in detail. These are just a few of the ways in which we can alter the learning process to help it learn from imbalanced data. Each of them touches a different phase of the process, so we can employ any or all of them to get the job done.

Custom Loss Function

This is perhaps the simplest and most intuitive way of dealing with the problem. We all know that the loss function is the most important aspect of any regression. All the effort of the regression algorithm is focused on minimizing the loss function. Typically, we use predefined loss functions like Mean Squared Error or Cross Entropy that are built into the implementation library we use. In simple words, the loss function is a measure of how wrong the current model is. The generic loss functions consider a false positive just as bad as a false negative. If our data is dominated by negatives, even a small rate of false positives adds up to a huge loss, while the false negatives may go unnoticed — there are too few positives to raise any concern. In such a case, if we alter the loss function so that the loss is amplified for a false negative, the balance is restored. Now an error on either side results in a similar loss: a false positive shows up as a high loss because of the sheer number of negative samples, and a false negative shows up as a high loss because of the amplified penalty on the small number of positives.
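
As a concrete illustration, here is a minimal sketch of such a weighted loss in PyTorch, which supports this directly through the pos_weight argument of BCEWithLogitsLoss. The sample counts and the intrusion-detection framing are assumptions made up for the example; in practice you would compute the weight from your own class ratio.

```python
import torch
import torch.nn as nn

# Hypothetical counts taken from the training data: roughly 1 intrusion per 200 packets.
num_negatives = 200_000
num_positives = 1_000

# Amplify the loss on the rare positive class by the imbalance ratio,
# so a false negative hurts roughly as much as the accumulated false positives.
pos_weight = torch.tensor([num_negatives / num_positives])  # = 200.0

# PyTorch's built-in weighted binary cross-entropy (expects raw logits).
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Toy usage: logits from some model vs. true labels (1 = intrusion).
logits = torch.tensor([2.3, -1.7, 0.4])
labels = torch.tensor([1.0, 0.0, 1.0])
print(loss_fn(logits, labels).item())
```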

Partial Augmentation

Data augmentation is an important step in training any model. The data we get from observation is essentially a chunk of binary values obtained from the field. Our understanding of the domain can help us add a lot more information to it. This process is called data augmentation. For example, if we train a model to generate transcripts from audio signals, we can get a huge amount of data from the field. For the model, this data is of the form: binary input mapped to binary output. From our domain knowledge, we know that the spoken word does not change if we make small changes to the pitch or speed of the sound. Using this understanding, we can generate a lot more data from what we have. The augmentation technique depends upon the kind of problem we work on. Each problem has its own approach to augmentation, but we can find ways of doing it if we dig into the domain and the data available. If nothing else works, just adding Gaussian noise can help us to a great extent. If we have a huge dataset where only one class is under-represented, we can use partial augmentation — where we augment the data for just that one class. Of course, augmentation has its limits and we cannot get a million records out of one. But it is a powerful technique and should be employed wherever possible.
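
To make the idea concrete, here is a rough sketch of partial augmentation on tabular data using plain NumPy. The function name partially_augment and the parameters copies and noise_scale are invented for this illustration; the noise level in particular would need tuning against a real dataset.

```python
import numpy as np

def partially_augment(X, y, minority_label=1, copies=5, noise_scale=0.01, seed=0):
    """Augment only the minority class by jittering its samples with Gaussian noise.

    X: (n_samples, n_features) float array, y: (n_samples,) label array.
    Returns the original data plus `copies` noisy replicas of each minority sample.
    """
    rng = np.random.default_rng(seed)
    minority = X[y == minority_label]

    augmented, labels = [X], [y]
    for _ in range(copies):
        # Scale the noise per feature so the jitter stays small relative to the data.
        noise = rng.normal(0.0, noise_scale * X.std(axis=0), size=minority.shape)
        augmented.append(minority + noise)
        labels.append(np.full(len(minority), minority_label))

    return np.concatenate(augmented), np.concatenate(labels)

# Toy usage: 1000 negatives, 20 positives -> positives replicated with small jitter.
X = np.random.randn(1020, 8)
y = np.array([0] * 1000 + [1] * 20)
X_aug, y_aug = partially_augment(X, y, copies=10)
print(X_aug.shape, (y_aug == 1).sum())  # (1220, 8) 220
```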

Clustered Ensembles

Ensembles are a powerful way to reduce the chances of overfitting, as they generate rich models out of small data. Essentially, we train many small models on the available data. Each can be based on a different technique or hyperparameters — or just different initial values. Each leads to a model that generates an output slightly different from the others. We then combine these individual models to generate one output that is a lot more accurate than any individual model. Now suppose we have training data with a positive-to-negative ratio of 1:n. We can build an ensemble of n small models, each trained with the whole set of positives and a different part of the negatives. These small models are then combined into an ensemble to form the final model. The accuracy of such an ensemble can be further enhanced by clustering the larger set. We can cluster the negative records into n categories and then build the n models using the chunk of positives along with each individual category. The ensemble of these n models can be a lot more effective. Intuitively, we can think of this as teaching each individual model how a positive differs from that particular category of negatives.
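
Below is one possible sketch of this idea using scikit-learn, with KMeans to cluster the negatives and one logistic regression per cluster. The helper names and the simple averaging of predicted probabilities are illustrative choices for this example, not a prescribed recipe.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression

def train_clustered_ensemble(X, y, n_clusters=5, seed=0):
    """Cluster the negatives, then train one model per (all positives + one negative cluster)."""
    X_pos, X_neg = X[y == 1], X[y == 0]
    clusters = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(X_neg)

    models = []
    for c in range(n_clusters):
        X_c = np.vstack([X_pos, X_neg[clusters == c]])
        y_c = np.concatenate([np.ones(len(X_pos)), np.zeros((clusters == c).sum())])
        models.append(LogisticRegression(max_iter=1000).fit(X_c, y_c))
    return models

def predict_ensemble(models, X):
    """Average the positive-class probabilities of the individual models."""
    probs = np.mean([m.predict_proba(X)[:, 1] for m in models], axis=0)
    return (probs >= 0.5).astype(int)

# Toy usage: 1000 negatives, 50 positives -> each small model sees a far better balance.
X = np.random.randn(1050, 8)
y = np.array([0] * 1000 + [1] * 50)
models = train_clustered_ensemble(X, y, n_clusters=5)
print(predict_ensemble(models, X[:10]))
```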

Summary

An AI model can only be as good as the data we use to train it. If the data is biased, the model will be biased too — unless we take special care to avoid the problem. There are several ways to work around the problem of biased training data. We looked at three of the major ones.

If you are interested, you can check out arxiv.org/pdf/1106.1813.pdf for a detailed analysis.