Splitting a Dataset for Machine Learning
To determine the efficacy of our models, we need an unbiased way to measure them. To do this, we split our dataset into training, validation, and testing data splits.
- Use the training split to train the model.
Here the model will have access to both inputs and outputs to optimize its internal weights.
- After each loop (epoch) of the training split, we will use the validation split to determine model performance.
Here the model will not use the outputs to optimize its weights but instead, we will use the performance to optimize training hyperparameters such as the learning rate, etc.
- After training stops (epoch(s)), we will use the testing split to perform a one-time assessment of the model.
This is our best measure of how the model may behave on new, unseen data. Note that training stops when the performance improvement is no longer significant or when any other stopping criterion we have specified is met.
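The roles of the three splits can be sketched roughly as follows. This is a minimal illustration, not the lesson's own training code: the synthetic dataset, the choice of `SGDClassifier`, the `patience` value, and the improvement threshold are all assumptions made for the example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption: not the lesson's dataset).
X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=10, n_classes=3, random_state=0
)
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, train_size=0.7, stratify=y, random_state=0
)
X_val, X_test, y_val, y_test = train_test_split(
    X_rest, y_rest, train_size=0.5, stratify=y_rest, random_state=0
)

model = SGDClassifier(random_state=0)
classes = np.unique(y)
patience = 3  # hypothetical stopping criterion: epochs allowed without improvement
best_val_acc, epochs_without_improvement, epochs_run = 0.0, 0, 0

for epoch in range(50):
    # Training split: the model sees inputs AND outputs to update its weights.
    model.partial_fit(X_train, y_train, classes=classes)
    epochs_run += 1

    # Validation split: used only to measure performance, never to fit weights.
    val_acc = model.score(X_val, y_val)
    if val_acc > best_val_acc + 1e-3:
        best_val_acc = val_acc
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    if epochs_without_improvement >= patience:
        break  # improvement is no longer significant

# Testing split: a one-time assessment after training has stopped.
test_acc = model.score(X_test, y_test)
print(f"epochs: {epochs_run}, best val acc: {best_val_acc:.3f}, test acc: {test_acc:.3f}")
```

For simplicity the sketch evaluates the final model on the test split; in practice we would often restore the weights from the best validation epoch first.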
Creating proper data splits
What are the criteria we should focus on to ensure proper data splits?
- the dataset (and each data split) should be representative of data we will encounter
- equal distributions of output values across all splits
- shuffle your data if it's organized in a way that prevents input variance
- avoid random shuffles if your task can suffer from data leaks (ex.
We need to clean our data first before splitting, at least for the features that splitting depends on. So the process is more like: preprocessing (global, cleaning) → splitting → preprocessing (local, transformations).
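A small sketch of that ordering is below. The toy DataFrame, its column names (`feature`, `label`), and the use of `StandardScaler` as the "local" transformation are assumptions chosen for illustration only.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data with a few missing labels.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "feature": rng.normal(loc=5.0, scale=2.0, size=200),
    "label": rng.choice(["a", "b"], size=200),
})
df.loc[df.index % 25 == 0, "label"] = None  # simulate 8 rows with no label

# 1. Global preprocessing (cleaning): the split is stratified on `label`,
#    so rows without a label must be handled before splitting.
df = df.dropna(subset=["label"]).reset_index(drop=True)

# 2. Splitting.
train_df, test_df = train_test_split(
    df, train_size=0.7, stratify=df["label"], random_state=0
)

# 3. Local preprocessing (transformations): fit on the training split only,
#    then apply the fitted transformation to the other splits to avoid leaks.
scaler = StandardScaler().fit(train_df[["feature"]])
X_train = scaler.transform(train_df[["feature"]])
X_test = scaler.transform(test_df[["feature"]])
```

Fitting the scaler on the full dataset instead would let statistics from the test split influence training, which is the kind of leak the ordering is meant to prevent.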
We'll start by splitting our dataset into three data splits for training, validation and testing.
For our multi-class task (each input has one label), we want to ensure that each data split has a similar class distribution. We can achieve this by passing the stratify keyword argument when we split.
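A minimal sketch of a stratified three-way split using scikit-learn's `train_test_split` is shown here. The data is a synthetic stand-in: 955 samples (the total implied by the split counts reported in this lesson) with four made-up class names. The 0.70/0.15/0.15 proportions follow the reported counts; everything else is an assumption.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Split sizes
train_size, val_size, test_size = 0.7, 0.15, 0.15

# Synthetic stand-in data (assumption): 955 samples, four hypothetical classes.
rng = np.random.default_rng(0)
X = np.arange(955).reshape(-1, 1)
y = rng.choice(["a", "b", "c", "d"], size=955, p=[0.4, 0.3, 0.2, 0.1])

# First split: train vs. everything else, stratified on the labels.
X_train, X_, y_train, y_ = train_test_split(
    X, y, train_size=train_size, stratify=y, random_state=0
)
print(f"train: {len(X_train)} ({len(X_train) / len(X):.2f})")
print(f"remaining: {len(X_)} ({len(X_) / len(X):.2f})")

# Second split: divide the remainder evenly into validation and test,
# again stratified so each split keeps a similar class distribution.
X_val, X_test, y_val, y_test = train_test_split(
    X_, y_, train_size=0.5, stratify=y_, random_state=0
)
print(f"val: {len(X_val)} ({len(X_val) / len(X):.2f})")
print(f"test: {len(X_test)} ({len(X_test) / len(X):.2f})")
```

Splitting in two stages is needed because `train_test_split` only produces two partitions per call.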
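The first split leaves 30% of the data for validation and testing:

train: 668 (0.70)
remaining: 287 (0.30)

Splitting the remainder evenly then gives the final three splits:

train: 668 (0.70)
val: 143 (0.15)
test: 144 (0.15)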
It's hard to compare the class counts in each split directly because our train and test proportions are different. Let's see what the distribution looks like once we balance it out. What do we need to multiply our test ratio by so that we have the same amount as our train ratio?
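One way to answer this: multiply the smaller split's counts by `train_size / test_size` (about 4.67 for 0.70 versus 0.15). The sketch below applies that scaling to per-class counts; the synthetic labels and class names are assumptions, not the lesson's data.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split

train_size, val_size, test_size = 0.7, 0.15, 0.15

# Synthetic stand-in labels (assumption): three hypothetical classes.
rng = np.random.default_rng(0)
y = rng.choice(["a", "b", "c"], size=955, p=[0.5, 0.3, 0.2])
X = np.arange(len(y)).reshape(-1, 1)

X_train, X_, y_train, y_ = train_test_split(
    X, y, train_size=train_size, stratify=y, random_state=0
)
X_val, X_test, y_val, y_test = train_test_split(
    X_, y_, train_size=0.5, stratify=y_, random_state=0
)

# Raw per-class counts for each split.
counts = {
    "train": Counter(y_train.tolist()),
    "val": Counter(y_val.tolist()),
    "test": Counter(y_test.tolist()),
}

# Scale val/test counts up to the size of the training split so that the
# three distributions can be compared on the same footing.
adjusted = {
    "train": dict(counts["train"]),
    "val": {k: int(v * (train_size / val_size)) for k, v in counts["val"].items()},
    "test": {k: int(v * (train_size / test_size)) for k, v in counts["test"].items()},
}
for split, split_counts in adjusted.items():
    print(split, dict(sorted(split_counts.items())))
```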
We can see how much deviation there is in our naive data splits by computing the standard deviation of each split's class counts from the mean (the ideal split).
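A rough sketch of that computation follows. The adjusted counts here are hypothetical numbers invented for the example (they are not results from this lesson); the idea is to take, for each class, the standard deviation of its count across the three splits and then average over classes.

```python
import numpy as np
import pandas as pd

# Hypothetical adjusted per-class counts for each split (illustrative only).
adjusted_counts = {
    "train": {"a": 100, "b": 60, "c": 40},
    "val": {"a": 98, "b": 65, "c": 37},
    "test": {"a": 102, "b": 56, "c": 42},
}

# Rows are splits, columns are classes; a class missing from a split counts as 0.
dist_df = pd.DataFrame(adjusted_counts).T.fillna(0)

# Standard deviation of each class's count across splits (population std),
# then the mean over classes as a single summary number.
per_class_std = dist_df.std(axis=0, ddof=0)
mean_std = float(per_class_std.mean())
print(per_class_std.round(2).to_dict(), round(mean_std, 2))
```

A value near zero means every split has close to the same (scaled) class counts; larger values indicate that the splits diverge from the ideal.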
| | text | tag |
|---|---|---|
| 0 | laplacian pyramid reconstruction refinement se... | computer-vision |
| 1 | extract stock sentiment news headlines project... | natural-language-processing |
| 2 | big bad nlp database collection 400 nlp datasets... | natural-language-processing |
| 3 | job classification job classification done usi... | natural-language-processing |
| 4 | optimizing mobiledet mobile deployments learn ... | computer-vision |
If we had a multi-label classification task, we would instead apply iterative stratification via the skmultilearn library. It essentially splits the inputs into subsets (where each label is considered individually) and then distributes the samples, starting with the fewest "positive" samples and working up to the inputs that have the most labels.
```python
from skmultilearn.model_selection import IterativeStratification


def iterative_train_test_split(X, y, train_size):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'
    """
    stratifier = IterativeStratification(
        n_splits=2,
        order=1,
        sample_distribution_per_fold=[1.0 - train_size, train_size],
    )
    # X and y must support integer-array indexing (e.g. NumPy arrays).
    train_indices, test_indices = next(stratifier.split(X, y))
    X_train, y_train = X[train_indices], y[train_indices]
    X_test, y_test = X[test_indices], y[test_indices]
    return X_train, X_test, y_train, y_test
```
Iterative stratification essentially creates splits while "trying to maintain balanced representation with respect to order-th label combinations". We used order=1 for our iterative split, which means we cared about providing a representative distribution of each tag across the splits. But we can account for higher-order label relationships as well, where we may care about the distribution of label combinations.
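To make the idea of "order" more concrete, the sketch below counts how often single labels (order 1) and label pairs (order 2) occur in a tiny, made-up multi-label indicator matrix. These are the kinds of quantities a stratifier would try to balance across splits. The helper `label_combination_counts` is hypothetical and simplified; it is not skmultilearn's implementation.

```python
from collections import Counter
from itertools import combinations

import numpy as np

# Made-up multi-label indicator matrix: 4 samples, 3 labels.
y = np.array([
    [1, 0, 1],
    [1, 1, 0],
    [0, 1, 1],
    [1, 0, 1],
])


def label_combination_counts(y, order):
    """Count how many samples contain each combination of `order` labels."""
    counts = Counter()
    for row in y:
        active = np.flatnonzero(row).tolist()  # indices of labels present in this sample
        for combo in combinations(active, order):
            counts[combo] += 1
    return counts


order1 = label_combination_counts(y, order=1)  # per-label frequencies
order2 = label_combination_counts(y, order=2)  # label-pair frequencies
print(dict(order1))
print(dict(order2))
```

With order 1 only the per-label frequencies matter, whereas with order 2 a split would also need to preserve how often labels co-occur, which is a stricter requirement.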