
Evaluating Machine Learning Models


Evaluating ML models by assessing overall, per-class and slice performances.
Goku Mohandas
Repository · Notebook


Intuition

Evaluation is an integral part of modeling, yet it's often glossed over. Evaluation frequently amounts to simply computing accuracy or other global metrics, but for many real-world applications a much more nuanced evaluation process is required. However, before evaluating our model, we always want to:

  • be clear about what metrics we are prioritizing
  • be careful not to over-optimize any one metric because it may mean we're compromising something else

# Metrics
metrics = {"overall": {}, "class": {}}
# Data to evaluate
device = torch.device("cuda")
loss_fn = nn.BCEWithLogitsLoss(weight=class_weights_tensor)
trainer = Trainer(model=model.to(device), device=device, loss_fn=loss_fn)
test_loss, y_true, y_prob = trainer.eval_step(dataloader=test_dataloader)
y_pred = np.array([np.where(prob >= threshold, 1, 0) for prob in y_prob])

Coarse-grained

While we were iteratively developing our baselines, our evaluation process involved computing coarse-grained metrics such as overall precision, recall and f1.

# Overall metrics
overall_metrics = precision_recall_fscore_support(y_test, y_pred, average="weighted")
metrics["overall"]["precision"] = overall_metrics[0]
metrics["overall"]["recall"] = overall_metrics[1]
metrics["overall"]["f1"] = overall_metrics[2]
metrics["overall"]["num_samples"] = np.float64(len(y_true))
print (json.dumps(metrics["overall"], indent=4))
{
    "precision": 0.7896647806486397,
    "recall": 0.5965665236051502,
    "f1": 0.6612830799421741,
    "num_samples": 218.0
}

Note

The precision_recall_fscore_support() function from scikit-learn has an input parameter called average with the following options. We'll be using the different averaging methods for different metric granularities (a short sketch comparing them follows the list below).

  • None: metrics are calculated for each unique class.
  • binary: used for binary classification tasks where the pos_label is specified.
  • micro: metrics are calculated using global TP, FP, and FN.
  • macro: per-class metrics which are averaged without accounting for class imbalance.
  • weighted: per-class metrics which are averaged by accounting for class imbalance.
  • samples: metrics are calculated at the per-sample level.
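As a quick illustration of how the averaging choice changes the reported numbers, here's a small sketch on a made-up multilabel example (the toy arrays below are for demonstration only and are not part of our dataset):

# Comparing averaging methods on a tiny, made-up multilabel example
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
y_true_toy = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
y_pred_toy = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])
for average in [None, "micro", "macro", "weighted"]:
    print (average, precision_recall_fscore_support(y_true_toy, y_pred_toy, average=average))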

Fine-grained

Inspecting these coarse-grained, overall metrics is a start, but we can go deeper by evaluating the same metrics at finer granularities, such as per class.

# Per-class metrics
class_metrics = precision_recall_fscore_support(y_test, y_pred, average=None)
for i, _class in enumerate(label_encoder.classes):
    metrics["class"][_class] = {
        "precision": class_metrics[0][i],
        "recall": class_metrics[1][i],
        "f1": class_metrics[2][i],
        "num_samples": np.float64(class_metrics[3][i]),
    }
# Metrics for a specific class
tag = "transformers"
print (json.dumps(metrics["class"][tag], indent=2))

{
  "precision": 0.6428571428571429,
  "recall": 0.6428571428571429,
  "f1": 0.6428571428571429,
  "num_samples": 28.0
}

As a general rule, classes with fewer samples will have lower performance, so we should always work to identify the classes (or fine-grained slices) of data that our model needs to see more samples of in order to learn.

# Number of training samples per class
num_samples = np.sum(y_train, axis=0).tolist()
# Number of samples vs. performance (per class)
f1s = [metrics["class"][_class]["f1"]*100. for _class in label_encoder.classes]
sorted_lists = sorted(zip(num_samples, f1s, label_encoder.classes))  # sort by # of samples
num_samples, f1s, classes = list(zip(*sorted_lists))
# Plot
n = 7 # num. top classes to label
fig, ax = plt.subplots()
ax.set_xlabel("# of training samples")
ax.set_ylabel("test performance (f1)")
fig.set_size_inches(25, 5)
ax.plot(num_samples, f1s, "bo-")
for x, y, label in zip(num_samples[-n:], f1s[-n:], classes[-n:]):  # annotate classes with the most samples
    ax.annotate(label, xy=(x,y), xytext=(-5, 5), ha="right", textcoords="offset points")

There are, of course, nuances to this general rule, such as the complexity of distinguishing between some classes, where we may not need as many samples for easier sub-tasks. In our case, classes with over 100 training samples consistently perform better than a 0.6 f1 score, whereas the other classes' performances are mixed.

Confusion matrix

Besides just inspecting the metrics for each class, we can also identify the true positives, false positives and false negatives. Each of these will give us insight about our model beyond what the metrics can provide.

  • True positives (TP): prediction = ground-truth → learn about where our model performs well.
  • False positives (FP): falsely predict sample belongs to class → identify potentially mislabeled samples.
  • False negatives (FN): falsely predict sample does not belong to class → identify the model's less performant areas to upsample later.

It's good to have our FP/FN samples feed back into our annotation pipelines in the event we want to fix their labels and have those changes reflected everywhere.

# TP, FP, FN samples
index = label_encoder.class_to_index[tag]
tp, fp, fn = [], [], []
for i in range(len(y_test)):
    true = y_test[i][index]
    pred = y_pred[i][index]
    if true and pred:
        tp.append(i)
    elif not true and pred:
        fp.append(i)
    elif true and not pred:
        fn.append(i)
print (tp)
print (fp)
print (fn)

[4, 9, 27, 38, 40, 52, 58, 74, 79, 88, 97, 167, 174, 181, 186, 191, 194, 195]
[45, 54, 98, 104, 109, 137, 146, 152, 162, 190]
[55, 59, 63, 70, 87, 93, 125, 144, 166, 201]
index = tp[0]
print (X_test_raw[index])
print (f"true: {label_encoder.decode([y_test[index]])[0]}")
print (f"pred: {label_encoder.decode([y_pred[index]])[0]}\n")
simple transformers transformers classification ner qa language modeling language generation t5 multi modal conversational ai
true: ['language-modeling', 'natural-language-processing', 'question-answering', 'transformers']
pred: ['attention', 'huggingface', 'language-modeling', 'natural-language-processing', 'transformers']

# Sorted tags
sorted_tags_by_f1 = OrderedDict(sorted(
        metrics["class"].items(), key=lambda tag: tag[1]["f1"], reverse=True))
# Samples
num_samples = 3
if len(tp):
    print ("\n=== True positives ===")
    for i in tp[:num_samples]:
        print (f"  {X_test_raw[i]}")
        print (f"    true: {label_encoder.decode([y_test[i]])[0]}")
        print (f"    pred: {label_encoder.decode([y_pred[i]])[0]}\n")
if len(fp):
    print ("=== False positives === ")
    for i in fp[:num_samples]:
        print (f"  {X_test_raw[i]}")
        print (f"    true: {label_encoder.decode([y_test[i]])[0]}")
        print (f"    pred: {label_encoder.decode([y_pred[i]])[0]}\n")
if len(fn):
    print ("=== False negatives ===")
    for i in fn[:num_samples]:
        print (f"  {X_test_raw[i]}")
        print (f"    true: {label_encoder.decode([y_test[i]])[0]}")
        print (f"    pred: {label_encoder.decode([y_pred[i]])[0]}\n")

class = 'transformers'

{
  "precision": 0.6428571428571429,
  "recall": 0.6428571428571429,
  "f1": 0.6428571428571429,
  "num_samples": 28.0
}

=== True positives ===
  simple transformers transformers classification ner qa language modeling language generation t5 multi modal conversational ai
    true: ['language-modeling', 'natural-language-processing', 'question-answering', 'transformers']
    pred: ['attention', 'huggingface', 'language-modeling', 'natural-language-processing', 'transformers']

  bertviz tool visualizing attention transformer model bert gpt 2 albert xlnet roberta ctrl etc
    true: ['attention', 'interpretability', 'natural-language-processing', 'transformers']
    pred: ['attention', 'natural-language-processing', 'transformers']

  summary transformers models high level summary differences model huggingfacetransformer library
    true: ['huggingface', 'natural-language-processing', 'transformers']
    pred: ['huggingface', 'natural-language-processing', 'transformers']

=== False positives ===
  help read text summarization using flask huggingface text summarization translation questions answers generation using huggingface deployed using flask streamlit detailed guide github
    true: ['huggingface', 'natural-language-processing']
    pred: ['huggingface', 'natural-language-processing', 'transformers']

  silero models pre trained enterprise grade stt models silero speech text models provide enterprise grade stt compact form factor several commonly spoken languages
    true: ['pytorch', 'tensorflow']
    pred: ['natural-language-processing', 'transformers']

  evaluation metrics language modeling article focus traditional intrinsic metrics extremely useful process training language model
    true: ['language-modeling', 'natural-language-processing']
    pred: ['language-modeling', 'natural-language-processing', 'transformers']

=== False negatives ===
  t5 fine tuning colab notebook showcase fine tune t5 model various nlp tasks especially non text 2 text tasks text 2 text approach
    true: ['natural-language-processing', 'transformers']
    pred: ['natural-language-processing']

  universal adversarial triggers attacking analyzing nlp create short phrases cause specific model prediction concatenated input dataset
    true: ['natural-language-processing', 'transformers']
    pred: ['natural-language-processing']

  tempering expectations gpt 3 openai api closer look magic behind gpt 3 caveats aware
    true: ['natural-language-processing', 'transformers']
    pred: []

Tip

While this view is great for cursory inspection, we should have a scaled version that's tied to labeling and boosting workflows so we can act on our findings from this view.

Confidence learning

While the confusion-matrix sample analysis was a coarse-grained process, we can also use fine-grained, confidence-based approaches to identify potentially mislabeled samples. Here we're going to focus on the specific prediction probabilities as opposed to the final model predictions.

Simple confidence-based techniques include identifying samples whose:

  • Categorical
    • prediction is incorrect (these are the FP and FN samples)
    • confidence score for the correct class is below a threshold
    • confidence score for an incorrect class is above a threshold
    • standard deviation of confidence scores over top N samples is low
    • different predictions from same model using different/previous parameters
  • Continuous
    • difference between predicted and ground-truth values is above some %
# Confidence score for the incorrect class is above a threshold
high_confidence = []
max_threshold = 0.2
for i in range(len(y_test)):
    indices = np.where(y_test[i]==0)[0]
    probs = y_prob[i][indices]
    classes = []
    for index in np.where(probs>=max_threshold)[0]:
        classes.append(label_encoder.index_to_class[indices[index]])
    if len(classes):
        high_confidence.append({"text": test_df.text[i], "classes": classes})
high_confidence[0:5]
[{'classes': ['computer-vision', 'scikit-learn'],
  'text': 'mljar supervised automated machine learning python package designed save time data scientist'},
 {'classes': ['computer-vision', 'pytorch'],
  'text': 'bootstrap latent approach self supervised learning new approach self supervised image representation learning'},
 {'classes': ['attention', 'huggingface'],
  'text': 'simple transformers transformers classification ner qa language modeling language generation t5 multi modal conversational ai'},
 {'classes': ['embeddings'],
  'text': 'entity embedding lstm time series demonstration using lstm forecasting structured time series data containing categorical numerical features'},
 {'classes': ['huggingface'],
  'text': 'bertviz tool visualizing attention transformer model bert gpt 2 albert xlnet roberta ctrl etc'}]
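We can sketch other techniques from the list above in the same fashion. For example, here's a minimal sketch for flagging samples whose confidence for a correct class falls below a threshold (the 0.5 value and the low_confidence name are arbitrary choices for illustration):

# Confidence score for the correct class is below a threshold
low_confidence = []
min_threshold = 0.5
for i in range(len(y_test)):
    indices = np.where(y_test[i]==1)[0]
    probs = y_prob[i][indices]
    classes = []
    for index in np.where(probs<=min_threshold)[0]:
        classes.append(label_encoder.index_to_class[indices[index]])
    if len(classes):
        low_confidence.append({"text": test_df.text[i], "classes": classes})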

Calibration

But these are fairly crude techniques because neural networks are easily overconfident and so their confidences cannot be used without calibrating them.

Modern (large) neural networks achieve higher accuracies but are overconfident (accuracy vs. confidence).
On Calibration of Modern Neural Networks
  • Assumption: "the probability associated with the predicted class label should reflect its ground truth correctness likelihood."
  • Reality: "modern (large) neural networks are no longer well-calibrated"
  • Solution: apply temperature scaling (an extension of Platt scaling) on the model outputs (a minimal sketch follows below)
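Here's a minimal sketch of temperature scaling, assuming we have logits and labels from a held-out validation split as tensors (val_logits and val_targets are hypothetical names, not defined in this lesson):

# Post-hoc calibration via temperature scaling (minimal sketch)
import torch
import torch.nn as nn

class TemperatureScaler(nn.Module):
    """Divide logits by a single learned temperature T."""
    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))
    def forward(self, logits):
        return logits / self.temperature

# Fit T on the validation split (val_logits, val_targets are assumed float tensors)
scaler = TemperatureScaler()
optimizer = torch.optim.LBFGS([scaler.temperature], lr=0.01, max_iter=50)
loss_fn = nn.BCEWithLogitsLoss()
def closure():
    optimizer.zero_grad()
    loss = loss_fn(scaler(val_logits), val_targets)
    loss.backward()
    return loss
optimizer.step(closure)
calibrated_probs = torch.sigmoid(scaler(val_logits))  # calibrated probabilities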

Recent work on confident learning (cleanlab) focuses on identifying noisy labels (with calibration), which can then be properly relabeled and used for training.

import cleanlab
from cleanlab.util import onehot2int
from cleanlab.pruning import get_noise_indices
# Format our noisy labels `s` (cleanlab expects list of integers for multilabel tasks)
correctly_formatted_labels = onehot2int(y_test)
# Determine potential labeling errors
label_error_indices = get_noise_indices(
            s=correctly_formatted_labels,
            psx=y_prob,
            multi_label=True,
            sorted_index_method="self_confidence",
            verbose=0)

Not all of these are necessarily labeling errors; some are situations where the predicted probabilities were not very confident. Therefore, it's useful to attach the predicted outcomes alongside the results. This way, we can decide whether we need to relabel, upsample, etc. as mitigation strategies to improve our performance.

num_samples = 5
for index in label_error_indices[:num_samples]:
    print ("text:", test_df.iloc[index].text)
    print ("labels:",test_df.iloc[index].tags)
    print ("pred:", label_encoder.decode([y_pred[index]]))
    print ()
text: simclr keras tensorflow keras implementation simclr
labels: ['keras', 'self-supervised-learning', 'tensorflow']
pred: [['keras', 'tensorflow']]

text: tensorflow js object detection browser real time object detection model browser using tensorflow js
labels: ['computer-vision', 'object-detection', 'tensorflow', 'tensorflow-js']
pred: [['computer-vision', 'convolutional-neural-networks', 'keras', 'object-detection', 'tensorflow', 'tensorflow-js']]

text: pokezoo deep learning based web app developed using mern stack tensorflow js
labels: ['computer-vision', 'image-classification', 'tensorflow', 'tensorflow-js']
pred: [['computer-vision', 'keras', 'tensorflow', 'tensorflow-js']]

text: pcdet 3d point cloud detection pcdet toolbox pytorch 3d object detection point cloud
labels: ['computer-vision', 'convolutional-neural-networks', 'object-detection', 'pytorch']
pred: [['computer-vision', 'object-detection', 'pytorch']]

text: clustered graph convolutional networks pytorch implementation cluster gcn efficient algorithm training deep large graph convolutional networks kdd 2019
labels: ['embeddings', 'graphs', 'node-classification', 'pytorch', 'representation-learning']
pred: [['embeddings', 'graph-neural-networks', 'graphs', 'node-classification', 'pytorch', 'representation-learning']]

Manual slices

Just inspecting the overall and class metrics isn't enough to deploy our new version to production. There may be key slices of our dataset that we need to do really well on:

  • Target / predicted classes (+ combinations)
  • Features (explicit and implicit)
  • Metadata (timestamps, sources, etc.)
  • Priority slices / experience (minority groups, large customers, etc.)

An easy way to create and evaluate slices is to define slicing functions.

from snorkel.slicing import PandasSFApplier
from snorkel.slicing import slice_dataframe
from snorkel.slicing import slicing_function
@slicing_function()
def cv_transformers(x):
    """Projects with the `computer-vision` and `transformers` tags."""
    return all(tag in x.tags for tag in ["computer-vision", "transformers"])
@slicing_function()
def short_text(x):
    """Projects with short titles and descriptions."""
    return len(x.text.split()) < 7  # less than 7 words

Here we're using Snorkel's slicing_function to create our different slices. We can visualize our slices by applying this slicing function to a relevant DataFrame using slice_dataframe.

short_text_df = slice_dataframe(test_df, short_text)
short_text_df[["text", "tags"]].head()
     text                                                tags
44   flask sqlalchemy adds sqlalchemy support flask      [flask]
69   scikit lego extra blocks sklearn pipelines          [scikit-learn]
83   simclr keras tensorflow keras implementation s...   [keras, self-supervised-learning, tensorflow]
215  introduction autoencoders look autoencoders re...   [autoencoders, representation-learning]

We can define even more slicing functions and create a slices record array using the PandasSFApplier. The slices array has N (# of data points) items and each item has S (# of slicing functions) items, indicating whether that data point is part of that slice. Think of this record array as a masking layer for each slicing function on our data.

# Slices
slicing_functions = [cv_transformers, short_text]
applier = PandasSFApplier(slicing_functions)
slices = applier.apply(test_df)
print (slices)
rec.array([(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
           (1, 0), (0, 0), (0, 1), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0), (0, 1), (0, 0),
           ...
           (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1),
           (0, 0), (0, 0)],
    dtype=[('cv_transformers', 'i8'), ('short_text', 'i8')])

If our task were multiclass instead of multilabel, we could've used snorkel.analysis.Scorer to retrieve our slice metrics. Since it's multilabel, we've implemented a naive version based on it below.

# Score slices
metrics["slices"] = {}
for slice_name in slices.dtype.names:
    mask = slices[slice_name].astype(bool)
    if sum(mask):
        slice_metrics = precision_recall_fscore_support(
            y_test[mask], y_pred[mask], average="micro"
        )
        metrics["slices"][slice_name] = {}
        metrics["slices"][slice_name]["precision"] = slice_metrics[0]
        metrics["slices"][slice_name]["recall"] = slice_metrics[1]
        metrics["slices"][slice_name]["f1"] = slice_metrics[2]
        metrics["slices"][slice_name]["num_samples"] = len(y_true[mask])
print(json.dumps(metrics["slices"], indent=2))
{
  "cv_transformers": {
    "precision": 0.9230769230769231,
    "recall": 0.8571428571428571,
    "f1": 0.888888888888889,
    "num_samples": 3
  },
  "short_text": {
    "precision": 0.8,
    "recall": 0.5714285714285714,
    "f1": 0.6666666666666666,
    "num_samples": 4
  }
}

Generated slices

Manually creating slices is a massive improvement over coarse-grained evaluation when it comes to identifying problematic subsets of our dataset, but what if there are problematic slices we failed to identify? SliceLine is a recent work that uses a linear-algebra and pruning based technique to identify large slices (we specify a minimum slice size) that result in meaningful errors from the forward pass. Without pruning, automatic slice identification becomes computationally intensive because it involves enumerating through many combinations of data points. With this technique, we can discover hidden underperforming subsets in our dataset that we weren't explicitly looking for!

slicefinder GUI
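SliceLine itself isn't shown here; instead, below is a much simpler, brute-force sketch of the underlying idea (enumerating per-tag slices is our own simplification, and the minimum slice size is an arbitrary choice):

# Naive, brute-force sketch of automatic slice discovery (not the SliceLine algorithm)
min_slice_size = 10
candidate_slices = {}
for tag in label_encoder.classes:
    mask = np.array([tag in tags for tags in test_df.tags])
    if mask.sum() >= min_slice_size:  # keep only sufficiently large slices
        f1 = precision_recall_fscore_support(y_test[mask], y_pred[mask], average="micro")[2]
        candidate_slices[tag] = {"num_samples": int(mask.sum()), "f1": f1}
sorted(candidate_slices.items(), key=lambda item: item[1]["f1"])[:5]  # worst-performing large slices first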

Hidden stratification

What if the features to generate slices on are implicit/hidden?

Subgroup examples

To address this, there are recent clustering-based techniques to identify these hidden slices and improve the system (a minimal clustering sketch follows the steps below).

  1. Estimate implicit subclass labels via unsupervised clustering
  2. Train new more robust model using these clusters
Identifying subgroups via clustering and training on them.
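Here's a minimal sketch of step 1, assuming we have dense feature vectors for our samples (X_embeddings is a hypothetical array that is not defined in this lesson):

# Estimate implicit subgroups within one class via unsupervised clustering
from sklearn.cluster import KMeans
tag_index = label_encoder.class_to_index["transformers"]
class_mask = y_test[:, tag_index] == 1
kmeans = KMeans(n_clusters=3, random_state=42).fit(X_embeddings[class_mask])  # X_embeddings assumed
subgroup_labels = kmeans.labels_  # implicit subclass labels to inspect and retrain with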

Model patching

Another recent work on model patching takes this a step further by learning how to transform between subgroups so we can train models on the augmented data:

  1. Learn subgroups
  2. Learn transformations (ex. CycleGAN) needed to go from one subgroup to another under the same superclass (label)
  3. Augment data with artificially introduced subgroup features
  4. Train new robust model on augmented data
Using learned subgroup transformations to augment data.

Explainability

Besides just comparing predicted outputs with ground truth values, we can also inspect the inputs to our models. What aspects of the input are more influential towards the prediction? If the focus is not on intuitive features of our input, then we need to explore if there is a hidden pattern we're missing or if our model has learned to overfit on the incorrect features. We can use techniques such as SHAP (SHapley Additive exPlanations) or LIME (Local Interpretable Model-agnostic Explanations) to inspect feature importance. On a high level, these techniques learn which features have the most signal by assessing the performance in their absence. These inspections can be performed on a global level (ex. per-class) or on a local level (ex. single prediction).

TODO: Relevant code and results for this section will be added soon.
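In the meantime, here's a minimal sketch using LIME for text, where predict_proba is a hypothetical wrapper (not defined in this lesson) that maps a list of raw texts to an (N, num_classes) array of class probabilities from our model:

# Explain a single prediction with LIME (minimal sketch)
from lime.lime_text import LimeTextExplainer
explainer = LimeTextExplainer(class_names=label_encoder.classes)
explanation = explainer.explain_instance(
    text_instance=X_test_raw[0],
    classifier_fn=predict_proba,  # assumed wrapper: list of texts -> (N, num_classes) probabilities
    labels=[label_encoder.class_to_index["transformers"]],
)
explanation.as_list(label=label_encoder.class_to_index["transformers"])  # (token, weight) pairs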

We can also use model-specific approaches to explainability, as we did in our embeddings lesson, where we used SAME padding to extract the most influential n-grams in our text.

Counterfactuals

Another way to evaluate our systems is to identify counterfactuals -- data points with similar features that belong to another class (classification) or differ by more than a certain amount (regression). These points allow us to evaluate the model's sensitivity to certain features and feature values, which may be a sign of overfitting. A great tool to identify and probe for counterfactuals (also great for slicing and fairness metrics) is the What-if tool.

Identifying counterfactuals using the What-if tool
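As a simple, hypothetical probe (not the What-if tool itself), we can look for the most lexically similar test sample whose labels differ from a sample of interest:

# Naive counterfactual probe: nearest (TF-IDF) test sample with different labels
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
vectorizer = TfidfVectorizer()
X_tfidf = vectorizer.fit_transform(X_test_raw)
i = 0  # sample to probe
similarities = cosine_similarity(X_tfidf[i], X_tfidf).flatten()
candidates = [j for j in np.argsort(similarities)[::-1] if not np.array_equal(y_test[j], y_test[i])]
print (X_test_raw[i], label_encoder.decode([y_test[i]])[0])
print (X_test_raw[candidates[0]], label_encoder.decode([y_test[candidates[0]]])[0])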

Behavioral testing

Besides just looking at metrics, we also want to conduct some behavioral sanity tests. Behavioral testing is the process of testing input data and expected outputs while treating the model as a black box. These tests don't necessarily have to be adversarial in nature; they're more along the lines of the perturbations we'll see in the real world once our model is deployed. A landmark paper on this topic is Beyond Accuracy: Behavioral Testing of NLP Models with CheckList, which breaks down behavioral testing into three types of tests (a small sketch for running them follows the examples below):

  • invariance: Changes should not affect outputs.
    # INVariance via verb injection (changes should not affect outputs)
    tokens = ["revolutionized", "disrupted"]
    tags = [["transformers"], ["transformers"]]
    texts = [f"Transformers have {token} the ML field." for token in tokens]
    
  • directional: Change should affect outputs.
    # DIRectional expectations (changes with known outputs)
    tokens = ["PyTorch", "Huggingface"]
    tags = [
        ["pytorch", "transformers"],
        ["huggingface", "transformers"],
    ]
    texts = [f"A {token} implementation of transformers." for token in tokens]
    
  • minimum functionality: Simple combination of inputs and expected outputs.
    # Minimum Functionality Tests (simple input/output pairs)
    tokens = ["transformers", "graph neural networks"]
    tags = [["transformers"], ["graph-neural-networks"]]
    texts = [f"{token} have revolutionized machine learning." for token in tokens]
    

We'll learn how to systematically create tests in our testing lesson.
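In the meantime, here's a hedged sketch of how the examples above might be executed, where predict_tags is an assumed helper (not defined in this lesson) that maps raw texts to decoded tag lists:

# Hypothetical runner for the behavioral tests above
def run_behavioral_test(texts, tags, predict_tags):
    predictions = predict_tags(texts)  # assumed helper: list of texts -> list of predicted tag lists
    for text, expected, predicted in zip(texts, tags, predictions):
        for tag in expected:
            assert tag in predicted, f"'{tag}' missing from prediction for: {text}"

run_behavioral_test(texts=texts, tags=tags, predict_tags=predict_tags)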

Evaluating evaluations

How can we know if our models and systems are performing better over time? Unfortunately, depending on how often we retrain or how quickly our dataset grows, it won't always be a simple decision where all metrics/slices perform better than the previous version. In these scenarios, it's important to know what our main priorities are and where we can have some leeway:

  • What criteria are most important?
  • What criteria can/cannot regress?
  • How much of a regression can be tolerated?
assert precision > prev_precision  # most important, cannot regress
assert recall >= best_prev_recall - 0.03  # recall cannot regress > 3%
assert metrics["class"]["data_augmentation"]["f1"] > prev_data_augmentation_f1  # priority class
assert metrics["slices"]["class"]["cv_transformers"]["f1"] > prev_cv_transformers_f1  # priority slice

And as we develop these criteria over time, we can systematically enforce them via CI/CD workflows to decrease the manual time in between system updates.

Seems straightforward, doesn't it?

With all these different evaluation methods, how can we choose "the best" version of our model if some versions are better for some evaluation criteria?


You and your team need to agree on which evaluation criteria are most important and what the minimum required performance is for each one. This allows us to filter the different solutions by removing the ones that don't satisfy all the minimum requirements and then ranking the remaining ones by how well they perform on the highest-priority criteria.

Online evaluation

Once we've evaluated our model's ability to perform on a static dataset that is representative of production data, we can run several types of online evaluation techniques to determine performance on actual product data. Online evaluation can be performed using labels or, in the event we don't readily have labels, proxy signals (a small sketch of the first approach follows the list below).

  • manually labeling a subset of incoming data to evaluate periodically.
  • asking the initial set of users viewing newly categorized content if it's correctly classified.
  • allowing users to report content that our model misclassified.
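Here's a small, hypothetical sketch of the first approach, where incoming_predictions and the 5% sampling rate are made up for illustration:

# Queue a small fraction of incoming (text, predicted_tags) pairs for manual labeling
import random
def sample_for_labeling(incoming_predictions, rate=0.05):
    return [prediction for prediction in incoming_predictions if random.random() < rate]
labeling_queue = sample_for_labeling(incoming_predictions)  # incoming_predictions assumed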

Model CI

An effective way to evaluate our systems is to encapsulate the evaluations as a collection (suite) and use them for continuous integration. We would continue to add to our evaluation suites, and they would be executed whenever we experiment with changes to our system (new models, data, etc.). Problematic slices of data identified during monitoring are often added to the evaluation test suite to avoid repeating the same regressions in the future.

Resources


To cite this lesson, please use:

@article{madewithml,
    author       = {Goku Mohandas},
    title        = { Evaluation - Made With ML },
    howpublished = {\url{https://madewithml.com/}},
    year         = {2021}
}