
Splitting a Dataset for Machine Learning


Appropriately splitting our dataset for training, validation and testing.
Goku Mohandas



Intuition

To determine the efficacy of our models, we need an unbiased way to measure their performance. To do this, we split our dataset into training, validation and testing data splits.

  1. Use the training split to train the model.

    Here the model has access to both inputs and outputs to optimize its internal weights.

  2. After each pass (epoch) through the training split, use the validation split to determine model performance.

    Here the model does not use the outputs to optimize its weights; instead, we use its performance to tune training hyperparameters such as the learning rate.

  3. After training stops, use the testing split to perform a one-time assessment of the model.

    This is our best estimate of how the model may behave on new, unseen data. Note that training stops when the improvement in validation performance is no longer significant, or when any other stopping criterion we specified is met.
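The lesson leaves the stopping criterion open. One common choice, assumed here purely for illustration, is patience-based early stopping on the validation loss: stop once the loss hasn't improved by at least min_delta for patience consecutive epochs. The hypothetical early_stopping_epoch helper below only replays a list of per-epoch validation losses; a real training loop would run the same check after every epoch.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=1e-3):
    """Hypothetical helper: return the epoch index at which training would stop.

    Training stops once the validation loss has failed to improve by at
    least min_delta for `patience` consecutive epochs. If that never
    happens, training runs through the last epoch.
    """
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # Significant improvement: remember it and reset the counter
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch
    return len(val_losses) - 1


# Validation loss plateaus after epoch 2, so training stops at epoch 5
print(early_stopping_epoch([1.0, 0.8, 0.7, 0.6999, 0.71, 0.72]))  # 5
```

A larger patience tolerates noisier validation curves at the cost of extra epochs.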

Creating proper data splits

What are the criteria we should focus on to ensure proper data splits?

  • the dataset (and each data split) should be representative of the data we will encounter
  • similar distributions of output values (labels) across all splits
  • shuffle your data if it's ordered in a way that would make the splits unrepresentative (ex. sorted by label)
  • avoid random shuffles if your task can suffer from data leaks (ex. time-series, where the splits should follow chronological order)
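For time-series data, a minimal sketch of a leak-free alternative to random shuffling is to sort by time and cut the splits in order, so the model is always evaluated on data from after the period it was trained on. The chronological_split helper and its (timestamp, value) record format are assumptions for illustration, not part of the lesson's code.

```python
from datetime import date


def chronological_split(records, train_size=0.7, val_size=0.15):
    """Hypothetical helper: split (timestamp, value) records in time order.

    The oldest records go to train, the next block to val and the most
    recent block to test, so no split is evaluated on data older than
    the data the model was trained on.
    """
    ordered = sorted(records, key=lambda record: record[0])
    n_train = round(len(ordered) * train_size)
    n_val = round(len(ordered) * val_size)
    train = ordered[:n_train]
    val = ordered[n_train:n_train + n_val]
    test = ordered[n_train + n_val:]  # whatever remains is the most recent data
    return train, val, test


# 20 daily records supplied out of order (newest first)
records = [(date(2023, 1, day), f"sample-{day}") for day in range(20, 0, -1)]
train, val, test = chronological_split(records)
print(len(train), len(val), len(test))  # 14 3 3
print(train[-1][0] < val[0][0] < test[0][0])  # True
```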

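We need to clean our data before splitting, at least the features that the split depends on (ex. the labels we stratify on). So the process looks more like: preprocessing (global, cleaning) → splitting → preprocessing (local, transformations), where local transformations that learn from the data (ex. normalization statistics) are fit on the training split only, so that no information leaks from the validation and test splits.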
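As a small sketch of the "local" step, assuming a single numeric feature and plain standardization (the lesson doesn't say which transformations it applies), the statistics are learned from the training split and then reused unchanged for the validation and test splits. fit_standardizer and standardize are hypothetical helpers, not a library API.

```python
from statistics import mean, pstdev


def fit_standardizer(train_values):
    """Learn normalization statistics from the training split only."""
    mu = mean(train_values)
    sigma = pstdev(train_values) or 1.0  # avoid dividing by zero for constant features
    return mu, sigma


def standardize(values, mu, sigma):
    """Apply previously fitted statistics to any split."""
    return [(value - mu) / sigma for value in values]


train_values = [1.0, 2.0, 3.0, 4.0, 5.0]
val_values = [2.5, 6.0]
test_values = [10.0]

mu, sigma = fit_standardizer(train_values)  # statistics come from train only
train_scaled = standardize(train_values, mu, sigma)
val_scaled = standardize(val_values, mu, sigma)
test_scaled = standardize(test_values, mu, sigma)
print(mu, round(sigma, 4))  # 3.0 1.4142
```

Fitting on all the data first would let the test values shift mu and sigma, which is exactly the leak the ordering above avoids.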

Naive split

We'll start by splitting our dataset into three data splits for training, validation and testing.

from sklearn.model_selection import train_test_split

# Split sizes
train_size = 0.7
val_size = 0.15
test_size = 0.15

For our multi-class task (each input has exactly one label), we want each data split to have similar class distributions. We can achieve this by stratifying the split, i.e. passing the labels (y) to the stratify keyword argument.
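Conceptually, stratifying means splitting each class separately with the same proportions. The sketch below is a simplified, stdlib-only illustration of that idea under made-up labels; the stratified_split helper is hypothetical and is not scikit-learn's implementation of stratify.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, train_size=0.7, seed=42):
    """Simplified stratified split: return (train_indices, rest_indices).

    Samples are grouped by label and each group is split with the same
    train_size, so every class keeps (roughly) its overall proportion.
    """
    rng = random.Random(seed)
    indices_by_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_label[label].append(index)

    train_indices, rest_indices = [], []
    for indices in indices_by_label.values():
        rng.shuffle(indices)  # shuffle within each class
        n_train = round(len(indices) * train_size)
        train_indices.extend(indices[:n_train])
        rest_indices.extend(indices[n_train:])
    return sorted(train_indices), sorted(rest_indices)


labels = ["cv"] * 10 + ["nlp"] * 20
train_idx, rest_idx = stratified_split(labels)
print(Counter(labels[i] for i in train_idx))  # Counter({'nlp': 14, 'cv': 7})
```

Both classes keep the 70% ratio in the training indices, which is what stratify=y asks train_test_split to do.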

# Split (train): X_ and y_ hold the remaining 30% for val and test
X_train, X_, y_train, y_ = train_test_split(
    X, y, train_size=train_size, stratify=y)
print(f"train: {len(X_train)} ({(len(X_train) / len(X)):.2f})\n"
      f"remaining: {len(X_)} ({(len(X_) / len(X)):.2f})")

train: 668 (0.70)
remaining: 287 (0.30)

# Split (val and test): half of the remaining 30% each, i.e. 15% of the dataset
X_val, X_test, y_val, y_test = train_test_split(
    X_, y_, train_size=0.5, stratify=y_)
print(f"train: {len(X_train)} ({len(X_train)/len(X):.2f})\n"
      f"val: {len(X_val)} ({len(X_val)/len(X):.2f})\n"
      f"test: {len(X_test)} ({len(X_test)/len(X):.2f})")

train: 668 (0.70)
val: 143 (0.15)
test: 144 (0.15)

# Get counts for each class
counts = {}
counts["train_counts"] = {tag: label_encoder.decode(y_train).count(tag) for tag in label_encoder.classes}
counts["val_counts"] = {tag: label_encoder.decode(y_val).count(tag) for tag in label_encoder.classes}
counts["test_counts"] = {tag: label_encoder.decode(y_test).count(tag) for tag in label_encoder.classes}
import pandas as pd

# View distributions
pd.DataFrame({
    "train": counts["train_counts"],
    "val": counts["val_counts"],
    "test": counts["test_counts"]
}).T.fillna(0)

       computer-vision  mlops  natural-language-processing  other
train              249     55                          272     92
val                 53     12                           58     20
test                54     12                           58     20

It's hard to compare these counts because our train, val and test splits have different sizes. Let's see what the distributions look like once we rescale them to the same size. What do we need to multiply our test counts by so that they match the size of our train split? (The same reasoning applies to the validation split.)

\[ \alpha * N_{test} = N_{train} \]
\[ \alpha = \frac{N_{train}}{N_{test}} \approx \frac{0.70}{0.15} \approx 4.67 \]

# Rescale val and test counts to the size of the train split
for k in counts["val_counts"].keys():
    counts["val_counts"][k] = int(counts["val_counts"][k] * (train_size / val_size))
for k in counts["test_counts"].keys():
    counts["test_counts"][k] = int(counts["test_counts"][k] * (train_size / test_size))
dist_df = pd.DataFrame({
    "train": counts["train_counts"],
    "val": counts["val_counts"],
    "test": counts["test_counts"]
}).T.fillna(0)
dist_df

       computer-vision  mlops  natural-language-processing  other
train              249     55                          272     92
val                247     56                          270     93
test               252     56                          270     93

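We can see how much deviation there is in our naive data splits by computing, for each class, the standard deviation of its rescaled counts across the splits, and then averaging over the classes. An ideal split would have identical rescaled counts in every split, i.e. a standard deviation of 0.

\[ \sigma = \sqrt{\frac{\sum_{i=1}^{N} (x_i - \bar{x})^2}{N}} \]

where \(x_i\) is a class's rescaled count in split \(i\), \(\bar{x}\) is its mean across the splits and \(N = 3\) is the number of splits.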
import numpy as np

# Standard deviation of each class (column) across the splits (rows), averaged over classes
np.mean(np.std(dist_df.to_numpy(), axis=0))
0.9851056877051131
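To make the formula concrete, a stdlib-only recomputation of the number above from the rescaled counts in dist_df (copied by hand, so this is only a check and not part of the pipeline). It also shows the per-class breakdown: most of the deviation comes from computer-vision.

```python
from statistics import mean, pstdev

# Rescaled class counts per split (train, val, test), copied from dist_df above
dist = {
    "computer-vision": [249, 247, 252],
    "mlops": [55, 56, 56],
    "natural-language-processing": [272, 270, 270],
    "other": [92, 93, 93],
}

# Population standard deviation (divides by N), matching np.std's default ddof=0
per_class_std = {tag: pstdev(counts) for tag, counts in dist.items()}
avg_std = mean(list(per_class_std.values()))

for tag, std in per_class_std.items():
    print(f"{tag}: {std:.4f}")
print(f"average: {avg_std:.4f}")  # average: 0.9851
```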
# Split DataFrames
train_df = pd.DataFrame({"text": X_train, "tag": label_encoder.decode(y_train)})
val_df = pd.DataFrame({"text": X_val, "tag": label_encoder.decode(y_val)})
test_df = pd.DataFrame({"text": X_test, "tag": label_encoder.decode(y_test)})
train_df.head()
   text                                                 tag
0  laplacian pyramid reconstruction refinement se...    computer-vision
1  extract stock sentiment news headlines project...    natural-language-processing
2  big bad nlp database collection 400 nlp datasets...  natural-language-processing
3  job classification job classification done usi...   natural-language-processing
4  optimizing mobiledet mobile deployments learn ...    computer-vision

Multi-label classification

If we had a multi-label classification task, we would have applied iterative stratification via the skmultilearn library instead. It considers each label individually and distributes samples label by label, starting with the label that has the fewest "positive" samples (the hardest to balance) and working up to the most common labels, assigning each sample to the split that currently needs that label the most.

from skmultilearn.model_selection import IterativeStratification

def iterative_train_test_split(X, y, train_size):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    X and y are expected to be NumPy arrays (y as a binary
    indicator matrix of shape (n_samples, n_labels)).
    """
    # order=1: balance each label individually across the two folds
    stratifier = IterativeStratification(
        n_splits=2, order=1, sample_distribution_per_fold=[1.0 - train_size, train_size])
    train_indices, test_indices = next(stratifier.split(X, y))
    # Index with the returned integer arrays to build each split
    X_train, y_train = X[train_indices], y[train_indices]
    X_test, y_test = X[test_indices], y[test_indices]
    return X_train, X_test, y_train, y_test
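To make the procedure described above concrete, here is a simplified, stdlib-only sketch of order-1 iterative stratification. It is not skmultilearn's implementation: the iterative_stratification helper is hypothetical, and its tie-breaking (the first fold wins) and handling of unlabeled samples are simplifications made for this sketch. It repeatedly picks the remaining label with the fewest samples and sends each of those samples to the fold that still wants that label the most.

```python
import random
from collections import Counter


def iterative_stratification(label_sets, proportions, seed=0):
    """Simplified order-1 iterative stratification (hypothetical helper).

    label_sets: one set of labels per sample.
    proportions: desired fraction of samples per fold, e.g. [0.3, 0.7].
    Returns a list with the fold index assigned to each sample.
    """
    rng = random.Random(seed)
    n_folds = len(proportions)
    # How many samples each fold still wants, overall and per label
    desired = [p * len(label_sets) for p in proportions]
    label_totals = Counter(lbl for labels in label_sets for lbl in labels)
    desired_per_label = [{lbl: p * count for lbl, count in label_totals.items()}
                         for p in proportions]

    assignment = [None] * len(label_sets)
    remaining = set(range(len(label_sets)))
    while remaining:
        remaining_counts = Counter(lbl for idx in remaining for lbl in label_sets[idx])
        if not remaining_counts:
            # Only unlabeled samples are left: fill the folds by overall need
            for idx in sorted(remaining):
                fold = max(range(n_folds), key=lambda f: desired[f])
                assignment[idx] = fold
                desired[fold] -= 1
            break
        # Process the rarest remaining label first (ties broken randomly)
        fewest = min(remaining_counts.values())
        label = rng.choice(sorted(lbl for lbl, c in remaining_counts.items() if c == fewest))
        for idx in sorted(i for i in remaining if label in label_sets[i]):
            # Fold that most needs this label; ties go to the fold needing more samples
            fold = max(range(n_folds), key=lambda f: (desired_per_label[f][label], desired[f]))
            assignment[idx] = fold
            desired[fold] -= 1
            for sample_label in label_sets[idx]:
                desired_per_label[fold][sample_label] -= 1
            remaining.discard(idx)
    return assignment


samples = [{"a"}] * 10 + [{"b"}] * 10
folds = iterative_stratification(samples, proportions=[0.5, 0.5])
print(Counter((min(labels), fold) for labels, fold in zip(samples, folds)))
```

Each label ends up split evenly between the two folds here, which is the order=1 goal.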

Iterative stratification essentially creates splits while "trying to maintain balanced representation with respect to order-th label combinations". We used order=1 for our iterative split, which means we only cared about each tag being represented proportionally across the splits. But we can also account for higher-order label relationships (ex. order=2), where we care about the distribution of label combinations, such as pairs of tags that appear together.
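To see what "order-th label combinations" refers to, the hypothetical helper below counts label combinations of a given order: with order=1 these are the individual tags, with order=2 the pairs of tags that appear together on the same input. These are the counts that stratification at each order tries to keep proportional across splits. The toy label sets are invented for illustration.

```python
from collections import Counter
from itertools import combinations


def label_combination_counts(label_sets, order):
    """Count how often each combination of `order` labels appears on one input."""
    counts = Counter()
    for labels in label_sets:
        # Sort so that ("a", "b") and ("b", "a") are counted as the same combination
        for combo in combinations(sorted(labels), order):
            counts[combo] += 1
    return counts


samples = [
    {"natural-language-processing", "mlops"},
    {"computer-vision"},
    {"natural-language-processing"},
    {"computer-vision", "mlops"},
    {"natural-language-processing", "mlops"},
]
print(label_combination_counts(samples, order=1))
print(label_combination_counts(samples, order=2))
```

With order=1, a split only needs the right share of each tag; with order=2, it also needs the right share of pairs like ("mlops", "natural-language-processing").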




To cite this content, please use:

@article{madewithml,
    author       = {Goku Mohandas},
    title        = { Splitting a Dataset for Machine Learning - Made With ML },
    howpublished = {\url{https://madewithml.com/}},
    year         = {2023}
}