Current machine learning methods provide unprecedented accuracy across a range of domains, from computer vision to natural language processing. However, in many important high-stakes applications, such as medical diagnosis or autonomous driving, rare mistakes can be extremely costly, and thus effective deployment of learned models requires not only high accuracy, but also a way to measure the certainty in a model’s predictions. Reliable uncertainty quantification is especially important when faced with out-of-distribution inputs, as model accuracy tends to degrade heavily on inputs that differ significantly from those seen during training. In this blog post, we will discuss how we can obtain reliable uncertainty estimates with a strategy that does not simply rely on a learned model to extrapolate to out-of-distribution inputs, but instead asks: “given my training data, which labels would make sense for this input?”
To illustrate how this can allow for more reasonable predictions on out-of-distribution data, consider the following example of classifying automobiles, in which all class 1 training examples are sedans and all class 2 training examples are large buses.
Figure 1: Given previously seen examples, it is uncertain what the label for the new query point should be. Different classifiers that work well on the training set can give different predictions on the query point.
A classifier could potentially fit the training labels correctly based on several different explanations; for example, it could notice that buses are all longer than sedans and classify accordingly, or it could perhaps pay attention to the height of the vehicle instead. However, if we try to simply extrapolate to an out-of-distribution image of a limousine, the classifier’s output could be unpredictable and arbitrary. A classifier based on length could note that the limousine is similar to the buses in its length and confidently predict class 2, while a classifier utilizing the height could confidently predict class 1. Based only on the training set, there is not enough information to accurately decide which class a limousine should fit into, so we would ideally want our classifier to indicate uncertainty instead of providing arbitrary confident predictions for either class. On the other hand, if we explicitly try to find models that explain each potential label, we would find reasonable explanations for either label, suggesting that we should be uncertain about predicting which class the limousine belongs to.
We can instantiate this reasoning with an algorithm that, for every possible label, explicitly updates the model to try to explain that label for the query point and combines the different models to obtain well-calibrated predictions for out-of-distribution inputs. In this blog post, we will motivate and introduce amortized conditional normalized maximum likelihood (ACNML), a practical instantiation of this idea that enables reliable uncertainty estimation with deep neural networks.
Our method extends a prediction strategy from the minimum description length (MDL) literature known as conditional normalized maximum likelihood (CNML), which has been studied for its theoretical properties but is computationally intractable for all but the simplest problems. We will first review CNML and discuss how its predictions can lead to conservative uncertainty estimates. We will then describe our method, which allows for a practical application of CNML to obtain uncertainty estimates for deep neural networks.
The CNML distribution is derived from a notion of minimax optimality: we define a notion of regret for each label and choose the distribution that minimizes the worst-case regret over labels. Given a training set $\mathcal{D}_{\rm train}$, a query input $x$, and a set of potential models $\Theta$, we define the regret for each label to be the difference between the negative log-likelihood loss for our distribution and the loss under the model that best fits the training dataset together with the query point and label.
Intuitively, minimizing the worst-case regret over labels then ensures our predictive distribution is conservative, as it cannot assign low probabilities to any labels that appear consistent with our training data, where consistency is determined by the model class.
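In symbols, one way to write this definition is the following (the notation is ours and assumes the model defines a conditional distribution $p_\theta(y \mid x)$). For each candidate label $y$, the best-fitting parameters and the regret of a predictive distribution $q$ are

$$\hat\theta_y = \arg\max_{\theta \in \Theta} \Big[\log p_\theta(y \mid x) + \sum_{(x_i, y_i) \in \mathcal{D}_{\rm train}} \log p_\theta(y_i \mid x_i)\Big], \qquad \mathrm{Regret}(q, y) = \log p_{\hat\theta_y}(y \mid x) - \log q(y \mid x),$$

and CNML picks the distribution with the smallest worst-case regret:

$$p_{\rm CNML}(\cdot \mid x) = \arg\min_{q} \max_{y} \mathrm{Regret}(q, y).$$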
The minimax optimal distribution given a particular input $x$ and training set $\mathcal{D}_{\rm train}$ can be explicitly computed as follows:

1. For each label $y$, append $(x, y)$ to the training set and compute the new optimal parameters $\hat\theta_y$ for this modified training set.

2. Use $\hat\theta_y$ to assign the probability $p_{\hat\theta_y}(y \mid x)$ to that label.

3. Since each $\hat\theta_y$ is chosen to favor its own label, these probabilities sum to at least 1 (and typically more), so we normalize them to obtain a valid distribution over labels.
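The three steps above can be sketched for binary logistic regression in a few lines of NumPy. This is a toy illustration under our own assumptions, not the code behind the figures: the helper names `fit_logreg` and `cnml_predict` are hypothetical, the model is fit by plain gradient descent, and a tiny L2 penalty `lam` is included only so the weights stay finite when the augmented data is linearly separable (pure maximum likelihood would diverge there). The normalized output is $p_{\rm CNML}(y \mid x) = p_{\hat\theta_y}(y \mid x) / \sum_{y'} p_{\hat\theta_{y'}}(y' \mid x)$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_logreg(X, y, lam=1e-3, steps=3000, lr=0.1):
    """Fit p(y=1 | x) = sigmoid(w . [x, 1]) by full-batch gradient descent on the
    mean negative log-likelihood plus lam * ||w||^2 (the bias is the last weight)."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    w = np.zeros(Xb.shape[1])
    for _ in range(steps):
        p = sigmoid(Xb @ w)
        grad = Xb.T @ (p - y) / len(y) + 2 * lam * w
        w -= lr * grad
    return w

def cnml_predict(X, y, x_query):
    """CNML over labels {0, 1}: refit once per candidate label, then normalize."""
    unnormalized = np.zeros(2)
    for label in (0, 1):
        # Step 1: append (x_query, label) to the data and refit -> hat theta_label.
        w_label = fit_logreg(np.vstack([X, x_query]), np.append(y, label))
        # Step 2: probability that hat theta_label assigns to its own label.
        p1 = sigmoid(np.append(x_query, 1.0) @ w_label)
        unnormalized[label] = p1 if label == 1 else 1.0 - p1
    # Step 3: the raw values typically sum to more than 1, so normalize.
    return unnormalized / unnormalized.sum()

# Toy data: class 0 on the left, class 1 on the right (mirror images of each other).
X_train = np.array([[-2.0, 0.0], [-2.5, 0.5], [-2.0, -0.5],
                    [2.0, 0.0], [2.5, 0.5], [2.0, -0.5]])
y_train = np.array([0, 0, 0, 1, 1, 1])

p_in = cnml_predict(X_train, y_train, np.array([-2.2, 0.0]))   # inside the class-0 cluster
p_ood = cnml_predict(X_train, y_train, np.array([0.0, 10.0]))  # far from both clusters
print(p_in, p_ood)
```

On this toy data the far-away query lies on the mirror axis, so fitting label 0 and fitting label 1 are mirror-image problems and `p_ood` comes out as [0.5, 0.5] up to rounding. `p_in` still favors class 0, because no linear model can give the query label 1 without also raising the class-1 probability of the surrounding class-0 points.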
CNML has the interesting property that it explicitly optimizes the model to make predictions on the query input, which can lead to more reasonable predictions than simply extrapolating with a model obtained only from the training set. It can also lead to more conservative predictions on out-of-distribution inputs, since it is easier to fit different labels to out-of-distribution points without hurting performance on the training set.
We illustrate CNML with a 2-dimensional logistic regression example. We compare heatmaps of CNML probabilities with those of the maximum likelihood classifier to illustrate how CNML provides conservative uncertainty estimates on points away from the data. With this model class, CNML expresses uncertainty and assigns a uniform distribution to any query point for which the dataset remains linearly separable (meaning there exists a linear decision boundary that correctly classifies all data points) regardless of which label is assigned to the query point.
Figure 2: Here, we show the heatmap of CNML predictions (left) and the predictions of the training set MLE $\hat\theta_{\text{train}}$ (right). The training inputs are shown with blue (class 0) and orange (class 1) dots. Blue shading indicates higher probability for class 0 and red shading indicates higher probability for class 1, with darker colors indicating more confident predictions. While the original classifier makes confident predictions for most inputs, CNML assigns close-to-uniform probabilities to most points between the two clusters of training points, indicating high uncertainty about these ambiguous inputs.
We illustrate how CNML computes these probabilities by plotting the base classifier predictions under parameters $\hat\theta_{\text{train}}$ (the training set MLE), as well as $\hat\theta_0$ and $\hat\theta_1$, the parameters computed by CNML after assigning the labels 0 and 1, respectively, to a query point.
We first consider an out-of-distribution query point far away from any of the training inputs (shown in pink at the bottom of the leftmost image). In the left image, we see that the original decision boundary for $\hat\theta_{\text{train}}$ confidently classifies the query point as class 0. In the middle, we see that the decision boundary of $\hat\theta_0$ similarly classifies the query point as class 0. However, we see in the rightmost image that $\hat\theta_1$ is able to confidently classify the query point as class 1. Since $\hat\theta_0$ confidently predicts class 0 for the query point and $\hat\theta_1$ confidently predicts class 1, CNML normalizes the two predictions to assign roughly equal probability to either label.
Figure 3: Query point shown in pink. Both $\hat\theta_{\text{train}}$ and $\hat\theta_0$ classify the query point as class 0, but $\hat\theta_1$ is able to classify it as class 1.
On the other hand, for an in-distribution query point (again shown in pink) in the middle of the class 0 training inputs, no linear classifier can fit a label of 1 to the query point while still accurately fitting the rest of the training data, so the CNML distribution still confidently predicts class 0 on this query point.
Figure 4: Query point shown in pink. All parameters are forced to classify the query point as class 0 since it is in the middle of the class 0 training points.
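To make the normalization step concrete, here is the arithmetic with illustrative numbers we picked ourselves (they are not read off the figures). Write $x_{\rm ood}$ for the far-away query and $x_{\rm in}$ for the one inside the class 0 cluster. Suppose $\hat\theta_0$ assigns probability 0.99 to label 0 at either point, $\hat\theta_1$ reaches 0.98 for label 1 at $x_{\rm ood}$, but only 0.05 at $x_{\rm in}$:

$$p_{\rm CNML}(0 \mid x_{\rm ood}) = \frac{0.99}{0.99 + 0.98} \approx 0.503, \qquad p_{\rm CNML}(0 \mid x_{\rm in}) = \frac{0.99}{0.99 + 0.05} \approx 0.952.$$

When both labels can be explained about equally well, CNML is nearly uniform; when only one can, it stays confident.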
Controlling Conservativeness via Regularization
We see in Figure 2 that CNML probabilities are uniform on most of the input space, which is arguably too conservative. In this case, the model class is in some sense too expressive: linear predictors with large coefficients can assign arbitrarily high probabilities to each label as long as the data remains linearly separable. This problem is exacerbated with even more expressive model classes like deep neural networks, which can potentially fit arbitrary labels. To have CNML give more useful predictions, we need to constrain the allowed set of models to better reflect our notion of reasonable models.
To accomplish this, we generalize CNML to incorporate regularization via a prior term, resulting in conditional normalized maximum a posteriori (CNMAP). Instead of computing maximum likelihood parameters for the training dataset together with the new input and label, we compute maximum a posteriori (MAP) solutions, with the prior term $p(\theta)$ serving as a regularizer to limit the complexity of the selected model.
Going back to the logistic regression example, we add different levels of L2 regularization to the parameters (corresponding to Gaussian priors) and plot CNMAP probabilities in Figure 5 below. As regularization increases, CNMAP becomes less conservative, with the assigned probabilities transitioning much more smoothly as one moves away from the training points.
Figure 5: Heatmaps of CNMAP probabilities under varying amounts of regularization $\lambda \|w\|_2^2$. Increasing regularization leads to less conservative predictions.
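Written out in the same notation as CNML (the notation is ours), CNMAP replaces each maximum likelihood fit with a MAP fit and then normalizes exactly as before:

$$\hat\theta_y = \arg\max_{\theta} \Big[\log p_\theta(y \mid x) + \sum_{(x_i, y_i) \in \mathcal{D}_{\rm train}} \log p_\theta(y_i \mid x_i) + \log p(\theta)\Big], \qquad p_{\rm CNMAP}(y \mid x) = \frac{p_{\hat\theta_y}(y \mid x)}{\sum_{y'} p_{\hat\theta_{y'}}(y' \mid x)}.$$

With a constant (flat) prior, the $\log p(\theta)$ term does not affect the maximizer and this reduces to CNML.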
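The correspondence between the L2 penalty and a Gaussian prior is a one-line calculation, assuming the penalty is added to the summed (not averaged) negative log-likelihood. For an isotropic Gaussian prior $w \sim \mathcal{N}(0, \sigma^2 I)$ over $d$ parameters,

$$\log p(w) = -\frac{\|w\|_2^2}{2\sigma^2} - \frac{d}{2}\log(2\pi\sigma^2) \quad\Longrightarrow\quad -\log p(w) = \lambda \|w\|_2^2 + \text{const}, \qquad \lambda = \frac{1}{2\sigma^2}.$$

A larger $\lambda$ therefore corresponds to a narrower prior, which restricts how far each $\hat\theta_y$ can move to explain its label.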
Computational Intractability
While we see that CNML is able to provide conservative predictions for out-of-distribution (OOD) inputs, computing CNML predictions requires retraining the model on the entire training set once per candidate label for each test input, which can be very computationally expensive. While explicitly computing CNML distributions was feasible in our logistic regression example with small datasets, it would be computationally intractable to compute CNML naively on datasets of thousands of images with deep convolutional neural networks, as retraining the model just once could already take many hours. Even initializing from the solution for the training set and fine-tuning for several epochs after adding the query input and label could still take several minutes per input, rendering it impractical to use with deep neural networks and large datasets.
Since exactly computing CNML or CNMAP distributions is computationally infeasible in deep learning settings, due to the need to optimize over large datasets for each new input and label, we need a tractable approximation. In our method, amortized conditional normalized maximum likelihood (ACNML), we use approximate Bayesian posteriors to capture the necessary information about the training data, so that the MAP/MLE solutions for each query point and label can be computed efficiently. ACNML amortizes the cost of repeatedly optimizing over the training set by first computing an approximate Bayesian posterior, which serves as a compact approximation to the training losses.
CNMAP and Bayesian Posteriors
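We note that the main computational bottleneck is the need to optimize over the entire training set for each query point. To sidestep this issue, we first note a relationship between the MAP parameters needed in CNMAP and Bayesian posterior densities. By Bayes’ rule, $p(\theta \mid \mathcal{D}_{\rm train}) \propto p(\mathcal{D}_{\rm train} \mid \theta)\, p(\theta)$, and the normalizing constant does not depend on $\theta$, so
$$\hat\theta_y = \arg\max_\theta \Big[\log p_\theta(y \mid x) + \log p(\mathcal{D}_{\rm train} \mid \theta) + \log p(\theta)\Big] = \arg\max_\theta \Big[\log p_\theta(y \mid x) + \log p(\theta \mid \mathcal{D}_{\rm train})\Big].$$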
Rather than computing optimal parameters for the new query point and the training set, we can reformulate CNMAP as optimizing over just the query point and a posterior density. With a uniform prior (equivalent to having no regularizer), we can recover the maximum likelihood parameters to perform CNML if desired.
ACNML then uses approximate Bayesian inference to replace the exact Bayesian posterior density with a tractable density $q(\theta)$. Since many methods have been proposed for approximate Bayesian inference in deep learning, ACNML can use any approximate posterior that provides tractable densities, though we focus on Gaussian approximate posteriors for simplicity and computational efficiency. After computing the approximate posterior once during training, the test-time optimization procedure becomes much simpler, as we only need to optimize over our approximate posterior instead of the training set. When we instantiate ACNML and initialize from the MAP solution, we find that it typically takes only a handful of gradient updates to compute new (approximate) optimal parameters for each label, resulting in much faster test-time inference than a naive CNML instantiation that fine-tunes on the whole training set.
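Below is a toy sketch of the ACNML recipe for a 2-D logistic regression model, assuming a Gaussian approximate posterior. To keep it short and self-contained we swap in a plain full-covariance Laplace approximation, $q(\theta) = \mathcal{N}(\mu, H^{-1})$ with $\mu$ the MAP weights and $H$ the Hessian of the negative log posterior at $\mu$, rather than the SWAG-Diagonal or KFAC-Laplace posteriors used in the experiments. Because the model has only three parameters, we take a few Newton steps per label; a deep network would use a handful of gradient steps instead. The names `fit_map_laplace` and `acnml_predict` and the prior variance of 1 are our own illustrative choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_map_laplace(X, y, prior_var=1.0, newton_steps=50):
    """MAP logistic regression under a N(0, prior_var * I) prior (bias is the last
    weight), plus a Laplace approximation q(theta) = N(mu, H^{-1}) of the posterior."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    prior_prec = 1.0 / prior_var
    w = np.zeros(Xb.shape[1])
    for _ in range(newton_steps):
        p = sigmoid(Xb @ w)
        grad = Xb.T @ (y - p) - prior_prec * w  # gradient of the log posterior
        H = (Xb * (p * (1 - p))[:, None]).T @ Xb + prior_prec * np.eye(len(w))
        w = w + np.linalg.solve(H, grad)        # Newton ascent step
    p = sigmoid(Xb @ w)
    H = (Xb * (p * (1 - p))[:, None]).T @ Xb + prior_prec * np.eye(len(w))
    return w, H  # mean and precision of the Gaussian approximate posterior

def acnml_predict(mu, H, x_query, newton_steps=10):
    """For each label, maximize log p_theta(label | x) + log q(theta) starting from
    the posterior mean, then normalize. The training data is never touched here."""
    xb = np.append(x_query, 1.0)
    unnormalized = np.zeros(2)
    for label in (0, 1):
        theta = mu.copy()
        for _ in range(newton_steps):
            s = sigmoid(theta @ xb)
            grad = (label - s) * xb - H @ (theta - mu)
            neg_hess = s * (1 - s) * np.outer(xb, xb) + H
            theta = theta + np.linalg.solve(neg_hess, grad)
        s = sigmoid(theta @ xb)
        unnormalized[label] = s if label == 1 else 1.0 - s
    return unnormalized / unnormalized.sum()

X_train = np.array([[-2.0, 0.0], [-2.5, 0.5], [-2.0, -0.5],
                    [2.0, 0.0], [2.5, 0.5], [2.0, -0.5]])
y_train = np.array([0, 0, 0, 1, 1, 1])

mu, H = fit_map_laplace(X_train, y_train)  # done once, "during training"
p_in = acnml_predict(mu, H, np.array([-2.2, 0.0]))
p_ood = acnml_predict(mu, H, np.array([0.0, 10.0]))
print(p_in, p_ood)
```

The design point this illustrates is the amortization: after `fit_map_laplace` runs once, each prediction only needs `mu`, `H`, and the query point, so its cost does not grow with the size of the training set.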
In our paper, we analyze the approximation error incurred by using a particular Gaussian posterior in place of the exact training data likelihoods, and show that under certain assumptions, the approximation is accurate when the training set is large.
Experiments
We instantiate ACNML with two different Gaussian posterior approximations, SWAG-Diagonal and KFAC-Laplace, and train models on the CIFAR-10 image classification dataset. To evaluate out-of-distribution performance, we then evaluate on the CIFAR-10 Corrupted (CIFAR-10-C) datasets, which apply a diverse range of common image corruptions at different intensities, allowing us to see how well methods perform under varying levels of distribution shift. We compare against methods using Bayesian marginalization, which average predictions across different models sampled from the approximate posterior. All methods provide very similar accuracy both in-distribution and out-of-distribution, so we focus on comparing uncertainty estimates.
Figure 7: Reliability diagrams comparing ACNML against the corresponding Bayesian model averaging method (SWAG-D) and the MAP solution (SWA). ACNML generally predicts with lower confidence than other methods, leading to comparatively better uncertainty estimation as the data becomes more out-of-distribution.
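For readers unfamiliar with the baselines, here is a minimal sketch of the SWAG-Diagonal idea and of Bayesian model averaging with the resulting Gaussian. SWAG-Diagonal keeps running averages of the parameters and of their squares along the SGD trajectory and uses them as the mean and diagonal variance of a Gaussian; Bayesian marginalization then averages predictions over parameters sampled from it. This is our simplified illustration for a logistic-regression model (bias as the last weight), not the implementation used in the experiments; the class name `SwagDiagonal`, the sample count, and the example iterates are hypothetical.

```python
import numpy as np

class SwagDiagonal:
    """Running first and second moments of SGD iterates -> diagonal Gaussian."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim)
        self.sq_mean = np.zeros(dim)

    def collect(self, theta):
        # Incremental averages of theta and theta**2 over the collected iterates.
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta ** 2 - self.sq_mean) / self.n

    def variance(self):
        # E[theta^2] - E[theta]^2, clipped at 0 to absorb rounding error.
        return np.maximum(self.sq_mean - self.mean ** 2, 0.0)


def bma_prob_class1(swag, x, num_samples=200, seed=0):
    """Bayesian model averaging: average sigmoid predictions over sampled weights."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(swag.variance())
    thetas = swag.mean + std * rng.standard_normal((num_samples, swag.mean.size))
    logits = thetas @ np.append(x, 1.0)
    return float(np.mean(1.0 / (1.0 + np.exp(-logits))))


# Hypothetical weight vectors [w1, w2, bias] visited late in SGD training.
swag = SwagDiagonal(dim=3)
for theta in ([1.0, 0.2, 0.0], [1.4, -0.2, 0.1], [1.2, 0.0, -0.1]):
    swag.collect(np.array(theta))
prob = bma_prob_class1(swag, np.array([2.0, 0.0]))
print(swag.mean, swag.variance(), prob)
```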
We first examine ACNML’s predictions using reliability diagrams, which aggregate the test data points into buckets based on how confident the model’s predictions are, then plot the average confidence in a bucket against the actual accuracy of the predictions. These diagrams show the distribution of predicted confidences and can capture how effectively a model’s confidence reflects the actual uncertainty over the prediction.
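The bucketing behind a reliability diagram is simple to compute. The sketch below is ours, not the evaluation code from the paper: it assumes equal-width bins on [0, 1] and takes each prediction's confidence to be its largest predicted class probability; the function name `reliability_bins` and the toy inputs are hypothetical.

```python
import numpy as np

def reliability_bins(confidences, correct, num_bins=10):
    """Return (average confidence, accuracy, count) for each non-empty bucket."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    # Equal-width buckets; a confidence of exactly 1.0 goes into the last bucket.
    idx = np.minimum((confidences * num_bins).astype(int), num_bins - 1)
    rows = []
    for b in range(num_bins):
        in_bucket = idx == b
        if in_bucket.any():
            rows.append((confidences[in_bucket].mean(),
                         correct[in_bucket].mean(),
                         int(in_bucket.sum())))
    return rows

# Toy example: confidences of four predictions and whether each was correct.
rows = reliability_bins([0.95, 0.91, 0.55, 0.65], [1, 0, 1, 1])
for conf, acc, count in rows:
    print(f"avg confidence {conf:.2f}  accuracy {acc:.2f}  n={count}")
```

In this toy example the top bucket has average confidence 0.93 but accuracy 0.50, which a reliability diagram would show as a bar well below the diagonal (overconfidence), while the two lower buckets are underconfident.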
As we would expect from our earlier discussion about CNML, we find that ACNML reliably gives more conservative (less confident) predictions than other methods, to the point where its predictions are actually underconfident on the in-distribution CIFAR-10 test set, where all methods provide very accurate predictions. However, on the out-of-distribution CIFAR-10-C tasks where classifier accuracy degrades, ACNML’s conservativeness provides much more reliable confidence estimates, while other methods tend to be severely overconfident.
Figure 8: ECE comparisons: We compare instantiations of ACNML with two different approximate posteriors against their Bayesian counterparts.
We quantitatively measure calibration using the Expected Calibration Error (ECE), which uses the same buckets as the reliability diagrams and computes the calibration error (absolute difference between average confidence and accuracy) within each bucket, averaged over all buckets and weighted by the fraction of test points in each bucket. We see that ACNML instantiations provide much better calibration than their Bayesian counterparts and the deterministic baseline as the corruption intensities increase and the data becomes more out-of-distribution.
In this post, we discussed how we can obtain reliable uncertainty estimates on out-of-distribution data by explicitly optimizing on the data we wish to make predictions on, instead of relying on trained models to extrapolate from the training data. We then showed that this can be done concretely with the CNML prediction strategy, a scheme that has been studied theoretically but is computationally intractable to apply in practice. Finally, we presented our method, ACNML, a practical approximation to CNML that enables reliable uncertainty estimation with deep neural networks. We hope that this line of work will help enable broader applicability of large-scale machine learning systems, especially in safety-critical domains where uncertainty estimation is a necessity.
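As a companion sketch (again ours, with equal-width buckets and a hypothetical function name), the standard ECE weights each bucket's gap between average confidence and accuracy by the fraction of points that fall into that bucket:

```python
import numpy as np

def expected_calibration_error(confidences, correct, num_bins=10):
    """Weighted average over buckets of |average confidence - accuracy|."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    idx = np.minimum((confidences * num_bins).astype(int), num_bins - 1)
    ece = 0.0
    for b in range(num_bins):
        in_bucket = idx == b
        if in_bucket.any():
            gap = abs(confidences[in_bucket].mean() - correct[in_bucket].mean())
            ece += in_bucket.mean() * gap  # weight = fraction of points in bucket
    return ece

ece = expected_calibration_error([0.95, 0.91, 0.55, 0.65], [1, 0, 1, 1])
print(ece)  # 0.415 for this toy input, up to floating-point rounding
```

Here the two points in the top bucket contribute 0.5 × 0.43, and the two singleton buckets contribute 0.25 × 0.45 and 0.25 × 0.35.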
We thank Sergey Levine and Dibya Ghosh for providing valuable feedback on this post.
This post is based on the following paper: