Within a dataset we often face the problem of imbalance, in which some classes hold the bulk of the observations while the rest have only a few. Class imbalance is common in disease diagnosis (such as cancer detection), fraud detection, and spam filtering. In this article, we’ll look at a few techniques that can be used to deal with imbalanced datasets.
If we apply the wrong evaluation metric to an imbalanced dataset, it can give us misleading results. Let us take the example of credit card fraud detection. There will be far fewer data points for the “fraud” class (suppose 1%). So, if we take accuracy to measure the performance of the model and our model predicts every observation as “not fraud”, it will still get a 99% accuracy score. In this way, we can understand that accuracy is not the best metric in this situation.
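To make this concrete, here is a minimal sketch (using scikit-learn and a synthetic dataset, so the exact numbers are only illustrative) of a baseline that always predicts “not fraud” and still scores close to 99% accuracy while catching zero fraud cases:
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Synthetic data: roughly 1% "fraud" (class 1) and 99% "not fraud" (class 0)
X, y = make_classification(n_classes=2, weights=[0.99, 0.01], n_samples=10000, random_state=42)
print(Counter(y))

# A baseline that always predicts the majority class ("not fraud")
baseline = DummyClassifier(strategy='most_frequent').fit(X, y)
y_pred = baseline.predict(X)

print('Accuracy:', accuracy_score(y, y_pred))      # close to 0.99
print('Fraud recall:', recall_score(y, y_pred))    # 0.0 - not a single fraud case is caught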
What can be used instead (a short scikit-learn sketch follows this list):
Confusion Matrix: It tells us the number of false positives, false negatives, true positives and true negatives. It can be used to understand the number of correct and incorrect predictions per class.
Precision: The number of true positives divided by the total number of predicted positives (true positives plus false positives). It tells us how many of the instances the model flags are actually relevant (fraud in our example).
Recall: It is calculated by dividing the number of true positives by the sum of true positives and false negatives. It is also known as Sensitivity. It tells us how many of the relevant instances the model actually catches.
F1-Score: Harmonic mean of precision and recall.
AUC-ROC: It summarises the trade-off between the true positive rate and the false positive rate across classification thresholds.
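As a rough illustration (the labels and scores below are made up for the example, not taken from a real model), these metrics can be computed with scikit-learn like this:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score

# Hypothetical ground truth, hard predictions and predicted probabilities (1 = fraud)
y_true  = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
y_pred  = [0, 0, 0, 0, 0, 1, 0, 1, 1, 0]
y_score = [0.1, 0.2, 0.1, 0.3, 0.2, 0.6, 0.4, 0.9, 0.8, 0.4]

print(confusion_matrix(y_true, y_pred))
print('Precision:', precision_score(y_true, y_pred))
print('Recall:', recall_score(y_true, y_pred))
print('F1-score:', f1_score(y_true, y_pred))
print('AUC-ROC:', roc_auc_score(y_true, y_score))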
Resampling means changing the class distribution of the imbalanced dataset. We can either undersample, by reducing the number of observations from the majority class, or oversample, by adding data for the minority class. Let’s see each of these types of resampling in action:
Undersampling: In this case we reduce the number of observations from the majority class. Python has a library for this, imblearn. We can understand it from this example:
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import RandomUnderSampler

# Create an imbalanced dataset (~10% minority class)
X, y = make_classification(n_classes=2, random_state=42, n_features=20, n_informative=3,
                           weights=[0.1, 0.9], n_redundant=1, flip_y=0,
                           n_clusters_per_class=1, n_samples=1000, class_sep=2)
print('Original data shape %s' % Counter(y))

# Randomly drop observations from the majority class
random_sampler = RandomUnderSampler(random_state=42)
X_res, y_res = random_sampler.fit_resample(X, y)
print('Resampled data shape %s' % Counter(y_res))
Oversampling: In this case we increase the number of observations for the minority class by duplicating existing minority samples. We can understand it from this example:
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.over_sampling import RandomOverSampler

# Same imbalanced dataset as before (~10% minority class)
X, y = make_classification(n_classes=2, random_state=42, n_features=20, n_informative=3,
                           weights=[0.1, 0.9], n_redundant=1, flip_y=0,
                           n_clusters_per_class=1, n_samples=1000, class_sep=2)
print('Original data shape %s' % Counter(y))

# Randomly duplicate observations from the minority class
random_sampler = RandomOverSampler(random_state=42)
X_res, y_res = random_sampler.fit_resample(X, y)
print('Resampled data shape %s' % Counter(y_res))
NOTE: Before doing oversampling, we should split the dataset into train and test sets and oversample only the training data. Otherwise, copies of the same data points can appear in both train and test data, which leaks information, causes overfitting, and makes our evaluation look better than the model really generalizes.
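A minimal sketch of the right order (the variable names are illustrative, and X, y come from the example above): split first, then resample only the training portion.
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler

# Split first so the test set is never touched by resampling
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)

# Oversample only the training data
ros = RandomOverSampler(random_state=42)
X_train_res, y_train_res = ros.fit_resample(X_train, y_train)
print('Resampled training shape %s' % Counter(y_train_res))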
SMOTE: There is one more technique that can be used for oversampling: SMOTE, which stands for Synthetic Minority Over-sampling Technique. Instead of duplicating existing observations, it generates synthetic data points for the minority class by interpolating between a minority sample and its nearest neighbors.
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

# df is assumed to be a DataFrame with a binary 'Class' column (e.g. a credit card fraud dataset)
target = df.Class
train = df.drop('Class', axis=1)

# Split first, then apply SMOTE only to the training data
X_train_data, X_test_data, y_train, y_test = train_test_split(train, target, test_size=0.25, random_state=27)

sm = SMOTE(sampling_strategy=1.0, random_state=27)
X_train, y_train = sm.fit_resample(X_train_data, y_train)
One way is to try different algorithms and see how they perform. Generally, algorithms like logistic regression or random forest tend to get biased towards the majority class and largely ignore the rare class while training. A decision tree can be useful here, as its if-else splitting rules can still carve out regions for the minority class. Another way to deal with an imbalanced dataset is to use an ensemble technique: we split the majority class data into n samples, club each of them with the rare class, and train n different models (a sketch follows below). This way we can better generalize our final model.
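Here is a rough sketch of that ensemble idea, assuming X and y from the earlier make_classification examples (the helper names and parameters are illustrative, not a fixed recipe): each decision tree is trained on the full minority class plus a different random subset of the majority class, and the models vote on the final prediction.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def balanced_ensemble(X, y, n_models=5, random_state=42):
    # Identify minority and majority class labels from their counts
    counts = np.bincount(y)
    minority_idx = np.where(y == np.argmin(counts))[0]
    majority_idx = np.where(y == np.argmax(counts))[0]
    rng = np.random.RandomState(random_state)
    models = []
    for i in range(n_models):
        # Each model sees all minority samples plus a random, equally sized majority subset
        subset = rng.choice(majority_idx, size=len(minority_idx), replace=False)
        idx = np.concatenate([minority_idx, subset])
        models.append(DecisionTreeClassifier(random_state=i).fit(X[idx], y[idx]))
    return models

def ensemble_predict(models, X):
    # Simple majority vote over the individual trees (binary 0/1 labels)
    votes = np.array([m.predict(X) for m in models])
    return (votes.mean(axis=0) >= 0.5).astype(int)

models = balanced_ensemble(X, y)
print(ensemble_predict(models, X[:10]))
imblearn also ships ready-made estimators for this pattern, such as BalancedBaggingClassifier, which may be more convenient in practice.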
I hope you have found this article helpful. It is a good starting point for learning how to handle an imbalanced dataset. You can try these various techniques to find the one that works best in your case. Till then, keep learning :)