Three versions of risk-controlling prediction sets

conformal
Published June 30, 2023

I had the honor and good fortune to present the “Distribution-Free, Risk-Controlling Prediction Sets” paper (RCPS, [1]) at the Jordan symposium this month. The paper is a bit more general than it first appears: its key contribution, the use of more complex and general loss functions, actually plays well with more traditional conformal inference methods.

I got to talk to Stephen and Anastasios, and they (and presumably the other authors) are all well aware of the variants I’m about to describe; other sophisticated readers of the conformal literature will have seen these variants, too. But working them out took me some thought, so they seemed worth writing up here.

Setup

As with classical conformal inference, RCPS data takes the form of IID pairs \(Z_m = (X_m, Y_m)\), and, given a new datapoint \(Z_n = (X_n, Y_n)\), we want to form a set \(S(X_n)\) such that \(Y_n \in S(X_n)\). We want this to happen with high probability according to some notion thereof; we’ll explore a few different versions below.

I will assume (as in RCPS) that we have a family of sets parameterized by a scalar \(\lambda\), written \(S_\lambda(\cdot)\). The sets must be nested, growing with \(\lambda\) (that is, \(\lambda \le \lambda'\) implies \(S_\lambda(x) \subseteq S_{\lambda'}(x)\)), so the task is to choose a sufficiently large \(\hat{\lambda}\) with the help of a “calibration” data set \(\mathcal{Z} := \{Z_1, \ldots, Z_N\}\) so that the sets are large enough (but not too large). I’ll write \(S_{\hat{\lambda}}(\cdot)\), with the hat emphasizing that the value of \(\hat{\lambda}\) depends on a randomly selected calibration set.

Classical conformal inference (CI) produces sets such that

\[ \underset{\mathcal{Z},Z_n}{\mathcal{P}}\left( Y_n \in S_{\hat{\lambda}}(X_n) \right) \ge 1 - \varepsilon, \quad\quad\textrm{Eq. 1 (traditional CI)} \]

for some target error \(\varepsilon\). That is, there is high probability that a new datapoint’s response lies within the given set, where the probability is taken jointly over the new datapoint and the calibration set.

RCPS does something formally different. It takes a loss function \(L(Y_n, S)\) that is non-increasing as the set \(S\) grows, and controls

\[ \underset{\mathcal{Z}}{\mathcal{P}}\left( \underset{Z_n}{\mathbb{E}}\left[ L(Y_n, S_{\hat{\lambda}}(X_n)) \right] \le \alpha \right) \ge 1 - \delta, \quad\quad\textrm{Eq. 2 (RCPS)} \]

for some target risk level \(\alpha\) and error level \(\delta\).
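To make Eq. 2 concrete, here is a minimal Python sketch of an RCPS-style calibration step, which picks \(\hat{\lambda}\) so that, with probability at least \(1 - \delta\) over the calibration set, the expected loss is at most \(\alpha\). It assumes a finite, increasing grid of candidate \(\lambda\) values and losses bounded in \([0, 1]\), and for illustration it uses the simple Hoeffding upper confidence bound; [1] also considers tighter bounds. The clipped-distance loss, the interval family \(S_\lambda(x) = [x - \lambda, x + \lambda]\), the toy data, and all function names are hypothetical.

```python
import math
import random

def hoeffding_ucb(losses, delta):
    """(1 - delta) upper confidence bound on the mean of losses in [0, 1]."""
    n = len(losses)
    return sum(losses) / n + math.sqrt(math.log(1.0 / delta) / (2.0 * n))

def rcps_lambda_hat(calib, loss_fn, lambdas, alpha, delta):
    """Smallest lambda in the increasing grid `lambdas` such that the UCB on the
    risk is below alpha at that lambda and at every larger grid value.
    Returns None if even the largest lambda fails."""
    lam_hat = None
    for lam in reversed(lambdas):  # scan from the largest (most conservative) set down
        losses = [loss_fn(z, lam) for z in calib]
        if hoeffding_ucb(losses, delta) >= alpha:
            break  # stop at the first lambda whose bound fails
        lam_hat = lam
    return lam_hat

def toy_loss(z, lam):
    """Hypothetical loss for S_lam(x) = [x - lam, x + lam]: distance from y to the
    interval, clipped to [0, 1]. Non-increasing in lam because the sets are nested."""
    x, y = z
    return min(1.0, max(0.0, abs(y - x) - lam))

rng = random.Random(0)
calib = []
for _ in range(2000):
    x = rng.uniform(-1.0, 1.0)
    calib.append((x, x + rng.gauss(0.0, 1.0)))  # toy data: Y = X + N(0, 1) noise

lambdas = [0.01 * k for k in range(401)]  # candidate grid 0.00, 0.01, ..., 4.00
lam_hat = rcps_lambda_hat(calib, toy_loss, lambdas, alpha=0.1, delta=0.1)
print("lambda_hat:", lam_hat)
```

Scanning down from the largest \(\lambda\) and stopping at the first failure reflects the requirement that the bound hold for every \(\lambda' \ge \hat{\lambda}\), not just at \(\hat{\lambda}\) itself.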

Differences from traditional CI

Superficially, there are three differences between RCPS and CI:

  1. The use of a generic loss function
  2. Separately controlling the randomness from the calibration set and the new datapoint
  3. Controlling a “point estimate” (the expected loss) rather than producing an interval (e.g., guaranteeing that the loss falls below some threshold with some probability)

A naive reader (e.g., me, on my first read) might wonder whether all three differences are somehow tied together. But each of these features can be achieved separately. In particular, we can provide interval-like guarantees with generic losses, and we can separately control the randomness in the calibration set and the new datapoint.

Traditional CI with a generic loss

First, let’s do something like traditional CI but with a generic loss function. That might mean choosing \(\hat{\lambda}\) so that

\[ \underset{\mathcal{Z},Z_n}{\mathcal{P}}\left( L(Y_n, S_{\hat{\lambda}}(X_n)) \le \beta \right) \ge 1 - \varepsilon, \quad\quad\textrm{Eq. 3 (loss)} \]

for some \(\beta\) and some \(\varepsilon\). Here we have retained difference (1), but not (2) or (3): we provide an instance-wise interval for the loss rather than a point estimate (3), and we have not separately controlled the randomness in the calibration set and test point (2). To achieve Eq. 3, we can invert the map \(\lambda \mapsto L(Y_m, S_{\lambda}(X_m))\). Define

\[ \lambda(Z_m) := \inf \, \{ \lambda: L(Y_m, S_{\lambda}(X_m)) \le \beta \}. \]

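The values \(\lambda(Z_m)\) on the calibration set are exchangeable with the value \(\lambda(Z_n)\) on a new datapoint. Because the loss is non-increasing in \(\lambda\), we have \(L(Y_n, S_{\hat{\lambda}}(X_n)) \le \beta\) whenever \(\lambda(Z_n) \le \hat{\lambda}\) (assuming the infimum is attained). Taking the \(\lambda(Z_m)\) as our “conformity scores” and applying traditional CI, i.e., setting \(\hat{\lambda}\) to the \(\lceil (N+1)(1-\varepsilon) \rceil\)-th smallest score, thus gives Eq. 3.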
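The recipe for Eq. 3 fits in a few lines of Python. This is a minimal sketch under stated assumptions: the infimum defining \(\lambda(Z_m)\) is approximated on a finite increasing grid (by bisection, which relies on the loss being non-increasing in \(\lambda\)), and the clipped-distance loss, the interval family \(S_\lambda(x) = [x - \lambda, x + \lambda]\), the toy data, and the function names are hypothetical.

```python
import math
import random

def score(z, loss_fn, lambdas, beta):
    """lambda(Z_m): smallest value in the increasing grid `lambdas` with
    loss <= beta, a grid approximation of the infimum. Bisection is valid
    because the loss is non-increasing in lambda. Returns inf if none works."""
    lo, hi = 0, len(lambdas)
    while lo < hi:
        mid = (lo + hi) // 2
        if loss_fn(z, lambdas[mid]) <= beta:
            hi = mid
        else:
            lo = mid + 1
    return lambdas[lo] if lo < len(lambdas) else math.inf

def conformal_lambda_hat(calib, loss_fn, lambdas, beta, eps):
    """Split-conformal threshold: the ceil((N + 1)(1 - eps))-th smallest score."""
    scores = sorted(score(z, loss_fn, lambdas, beta) for z in calib)
    n = len(scores)
    k = math.ceil((n + 1) * (1 - eps))
    return scores[k - 1] if k <= n else math.inf

def toy_loss(z, lam):
    """Hypothetical loss: distance from y to [x - lam, x + lam], clipped to [0, 1]."""
    x, y = z
    return min(1.0, max(0.0, abs(y - x) - lam))

def draw(rng, n):
    """Hypothetical toy data: X ~ Uniform(-1, 1), Y = X + N(0, 1) noise."""
    data = []
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        data.append((x, x + rng.gauss(0.0, 1.0)))
    return data

rng = random.Random(0)
calib, test = draw(rng, 2000), draw(rng, 5000)
lambdas = [0.01 * k for k in range(401)]  # grid 0.00, 0.01, ..., 4.00
beta, eps = 0.2, 0.1
lam_hat = conformal_lambda_hat(calib, toy_loss, lambdas, beta, eps)

# Fraction of fresh points with loss <= beta; Eq. 3 promises 1 - eps or more on average.
coverage = sum(toy_loss(z, lam_hat) <= beta for z in test) / len(test)
print("lambda_hat:", lam_hat, "empirical coverage:", coverage)
```

A calibration point whose \(\lambda(Z_m)\) is infinite (no grid value brings its loss down to \(\beta\)) simply sorts last, which is the conservative choice.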

Traditional CI with separate control on the randomness

Similarly, we can achieve (2) without (1) or (3), producing traditional CI sets but with separate control over the randomness in the new datapoint and the calibration dataset. Specifically, we’d like to find a \(\hat{\lambda}\) such that

\[ \underset{\mathcal{Z}}{\mathcal{P}}\left( \underset{Z_n}{\mathcal{P}}\left( Y_n \in S_{\hat{\lambda}}(X_n) \right) \ge 1 - \gamma \right) \ge 1 - \delta. \quad\quad\textrm{Eq. 4 (separate randomness control)} \]

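Eq. 4 concerns the same event as Eq. 1, but separates out the two sources of randomness. It is the stronger statement: since the inner probability is at least \(1 - \gamma\) except on a calibration event of probability at most \(\delta\), Eq. 4 implies Eq. 1 with \(\varepsilon = \gamma + \delta\). Eq. 4 can be achieved by running standard RCPS with \(\alpha = \gamma\) and the indicator loss function: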

\[ L(Y_n, S) = \mathbb{I}\left( Y_n \notin S \right). \]

As pointed out by [1] (see Proposition 4 and the surrounding discussion), the calibration losses here are Bernoulli, so their sum is binomial, and the Bentkus bound is nearly tight.
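Below is a minimal sketch of Eq. 4 via RCPS with the indicator loss, with the target risk \(\alpha\) set to \(\gamma\). Because the losses are Bernoulli, this sketch swaps in the exact binomial tail bound (a one-sided Clopper-Pearson upper bound, found by bisection) in place of the Bentkus bound mentioned above. The interval family, the grid, the toy data, and the function names are hypothetical.

```python
import math
import random

def binom_cdf(k, n, p):
    """P(Binomial(n, p) <= k), summed in log space to avoid overflow."""
    if p <= 0.0:
        return 1.0
    if p >= 1.0:
        return 1.0 if k >= n else 0.0
    lp, lq, lgn = math.log(p), math.log1p(-p), math.lgamma(n + 1)
    total = sum(
        math.exp(lgn - math.lgamma(i + 1) - math.lgamma(n - i + 1) + i * lp + (n - i) * lq)
        for i in range(k + 1)
    )
    return min(total, 1.0)

def binom_ucb(k, n, delta, tol=1e-10):
    """Exact one-sided (1 - delta) upper confidence bound on a miss probability
    after k misses in n draws: the p solving P(Binomial(n, p) <= k) = delta."""
    if k >= n:
        return 1.0
    lo, hi = 0.0, 1.0  # the cdf is 1 at p = 0, 0 at p = 1, and decreasing in p
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if binom_cdf(k, n, mid) > delta:
            lo = mid
        else:
            hi = mid
    return hi

def rcps_indicator_lambda_hat(calib, covers, lambdas, gamma, delta):
    """RCPS with the indicator loss I(Y not in S_lam(X)) and alpha = gamma."""
    n = len(calib)
    lam_hat = None
    for lam in reversed(lambdas):  # scan down from the largest set
        misses = sum(1 for z in calib if not covers(z, lam))
        if binom_ucb(misses, n, delta) >= gamma:
            break  # bound fails here; lam_hat is the last passing value
        lam_hat = lam
    return lam_hat

def covers(z, lam):
    """Hypothetical set family: is y in S_lam(x) = [x - lam, x + lam]?"""
    x, y = z
    return abs(y - x) <= lam

rng = random.Random(0)
calib = []
for _ in range(1000):
    x = rng.uniform(-1.0, 1.0)
    calib.append((x, x + rng.gauss(0.0, 1.0)))  # toy data: Y = X + N(0, 1) noise

lambdas = [0.05 * k for k in range(81)]  # grid 0.00, 0.05, ..., 4.00
lam_hat = rcps_indicator_lambda_hat(calib, covers, lambdas, gamma=0.1, delta=0.1)
print("lambda_hat:", lam_hat)
```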

Generic loss functions and high-probability interval bounds

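Finally, we can achieve (1) and (2) but not (3). By combining the previous two ideas (for instance, running RCPS with \(\alpha = \gamma\) on the binary loss \(\mathbb{I}\left( L(Y_n, S) > \beta \right)\), which inherits monotonicity from \(L\)), we can find sets satisfying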

\[ \underset{\mathcal{Z}}{\mathcal{P}}\left( \underset{Z_n}{\mathcal{P}}\left( L(Y_n, S_{\hat{\lambda}}(X_n)) \le \beta \right) \ge 1 - \gamma \right) \ge 1 - \delta. \quad\quad\textrm{Eq. 5 (loss, intervals, separate control)} \]
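One way to compute a \(\hat{\lambda}\) with this kind of guarantee is sketched below, assuming the scores \(\lambda(Z_m)\) from the Eq. 3 construction are available and the infimum defining them is attained. By monotonicity, the inner probability is at least \(1 - \gamma\) whenever \(\hat{\lambda}\) is at least the \((1-\gamma)\)-quantile of the score distribution, and for continuous scores the \(k\)-th smallest of \(N\) scores falls below that quantile with probability \(\mathcal{P}\left(\textrm{Binomial}(N, 1-\gamma) \ge k\right)\). The sketch therefore takes the smallest \(k\) that makes this at most the outer error level. This order-statistic (tolerance-bound) argument is a swapped-in technique, not the RCPS procedure of [1]; the clipped-distance loss, the interval family \(S_\lambda(x) = [x - \lambda, x + \lambda]\), the toy data, and the function names are hypothetical.

```python
import math
import random

def binom_pmf(n, p):
    """[P(Binomial(n, p) = i) for i = 0..n], computed in log space (0 < p < 1)."""
    lp, lq, lgn = math.log(p), math.log1p(-p), math.lgamma(n + 1)
    return [
        math.exp(lgn - math.lgamma(i + 1) - math.lgamma(n - i + 1) + i * lp + (n - i) * lq)
        for i in range(n + 1)
    ]

def tolerance_rank(n, gamma, delta):
    """Smallest rank k with P(Binomial(n, 1 - gamma) >= k) <= delta (n + 1 if none)."""
    pmf = binom_pmf(n, 1.0 - gamma)
    tail, k = 0.0, n + 1
    for i in range(n, 0, -1):  # accumulate the upper tail P(X >= i) from the top
        tail += pmf[i]
        if tail > delta:
            break
        k = i
    return k

def eq5_lambda_hat(scores, gamma, delta):
    """k-th smallest score: with prob. >= 1 - delta over the calibration set,
    P(lambda(Z_n) <= lambda_hat) >= 1 - gamma (ties only make this conservative)."""
    k = tolerance_rank(len(scores), gamma, delta)
    return sorted(scores)[k - 1] if k <= len(scores) else math.inf

# Toy data: X ~ Uniform(-1, 1), Y = X + N(0, 1) noise.
rng = random.Random(0)
beta, gamma, delta = 0.2, 0.1, 0.1
scores = []
for _ in range(2000):
    x = rng.uniform(-1.0, 1.0)
    y = x + rng.gauss(0.0, 1.0)
    # Closed-form lambda(Z_m) for the hypothetical loss min(1, max(0, |y - x| - lam))
    # with S_lam(x) = [x - lam, x + lam] and beta < 1.
    scores.append(max(0.0, abs(y - x) - beta))

lam_hat = eq5_lambda_hat(scores, gamma, delta)
print("rank:", tolerance_rank(len(scores), gamma, delta), "lambda_hat:", lam_hat)
```

Because the rank \(k\) depends only on \(N\), \(\gamma\), and the outer error level, it can be computed once, before looking at the data.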

Comparison

To me, Eq. 5 looks like the best sort of guarantee: interval control over a generic loss, with separate control of the randomness. In fact, I rather wish the original RCPS paper had studied bounds of the form of Eq. 5 instead of Eq. 2.

Of course, computing the \(\lambda(Z_m)\) for each calibration point requires inverting \(N\) loss functions rather than a single empirical loss as in the RCPS paper. Furthermore, Anastasios pointed out to me that the concentration bounds used by RCPS to separately control the randomness (difference (2)) converge at rate \(1/\sqrt{N}\), while traditional CI is based on quantile estimates which are accurate at rate \(1/N\). So there is both a computational and theoretical price to be paid.
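To get a feel for the two rates, the short Python sketch below compares the Hoeffding slack \(\sqrt{\log(1/\delta) / (2N)}\) (a simple instance of the concentration bounds RCPS can use) with the \(1/(N+1)\) by which split conformal can over-cover when scores are continuous and untied. The two numbers measure different things, so only their scaling with \(N\) is meant to be compared; the function names are hypothetical.

```python
import math

def hoeffding_slack(n, delta):
    """Width Hoeffding adds to an empirical mean of [0, 1] losses: O(1 / sqrt(n))."""
    return math.sqrt(math.log(1.0 / delta) / (2.0 * n))

def conformal_excess(n):
    """Upper bound on split conformal's over-coverage, 1 / (n + 1): O(1 / n).
    Assumes continuous scores with no ties."""
    return 1.0 / (n + 1)

for n in (100, 1000, 10000):
    print(f"N={n:>6}  Hoeffding slack={hoeffding_slack(n, 0.1):.4f}  "
          f"conformal excess={conformal_excess(n):.6f}")
```

Each hundredfold increase in \(N\) shrinks the Hoeffding slack only tenfold, but shrinks the conformal excess roughly a hundredfold.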

But the main point is that the machinery developed in [1] allows you to pick and choose what you need for your particular problem. In this sense, [1] represents an even richer set of techniques than might appear at first glance.

References

[1] Bates, S., Angelopoulos, A., Lei, L., Malik, J. and Jordan, M., 2021. Distribution-free, risk-controlling prediction sets. Journal of the ACM (JACM), 68(6), pp.1-34.