A simple and clever (but inefficient) way to calculate M-estimator sensitivity with automatic differentiation.
I have some recent work (A Higher-Order Swiss Army Infinitesimal Jackknife) that is all about calculating Taylor expansions of optima with respect to hyperparameters using automatic differentiation.
In the paper we talk about sensitivity to data weights, but the idea is much more general. Suppose you have a parameter which you’re optimizing, \(\theta\), and some hyperparameter \(\epsilon\). For a fixed \(\epsilon\), you find \(\hat\theta(\epsilon)\) to satisfy some vector of first-order conditions,
\[G(\hat\theta(\epsilon), \epsilon) = 0.\]
For example, if you’re optimizing \(F(\theta, \epsilon)\) for a fixed \(\epsilon\), \(G\) would be the vector of partial derivatives of \(F\) with respect to \(\theta\).
Of course, \(\hat\theta(\epsilon)\) depends on \(\epsilon\), and you might want to approximately calculate \(\hat\theta(\epsilon)\) for different \(\epsilon\) without re-solving an optimization problem. Specifically, you might form a Taylor series expansion around some \(\epsilon_0\):
\[\hat\theta(\epsilon) \approx \hat\theta(\epsilon_0) + \left.\frac{d \hat\theta(\epsilon)}{d\epsilon}\right|_{\epsilon_0} (\epsilon - \epsilon_0) + \frac{1}{2} \left.\frac{d^2 \hat\theta(\epsilon)}{d\epsilon^2}\right|_{\epsilon_0} (\epsilon - \epsilon_0)(\epsilon - \epsilon_0) + ...\]
The difficulty is how to calculate the derivatives, which are defined implicitly through the solution of \(G(\hat\theta(\epsilon), \epsilon) = 0\). The computation section of our paper describes one way to do so recursively, and I’ve implemented the solution in the ParametricSensitivityTaylorExpansion class of my Python package vittles.
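As a minimal sketch of the expansion itself, here is a scalar toy problem of my own choosing (not from the paper): \(G(\theta, \epsilon) = e^\theta - 1 - \epsilon\), whose solution \(\hat\theta(\epsilon) = \log(1 + \epsilon)\) is known in closed form, so its derivatives at \(\epsilon_0 = 0\) are \(1, -1, 2, -6\). The hypothetical helper taylor_series takes each derivative already contracted with copies of \(\epsilon - \epsilon_0\) and applies the \(1/k!\) weights from the formula above.

```python
import math

import numpy as np


def taylor_series(theta0, directional_derivs):
    """Return theta0 + sum_k d_k / k!, where directional_derivs[k - 1] is the
    k-th derivative of thetahat at epsilon0 contracted with k copies of
    (epsilon - epsilon0). Hypothetical helper for illustration."""
    result = np.array(theta0, dtype=float)
    for k, d_k in enumerate(directional_derivs, start=1):
        result = result + np.asarray(d_k, dtype=float) / math.factorial(k)
    return result


# Toy problem: G(theta, eps) = exp(theta) - 1 - eps, so thetahat(eps) = log(1 + eps).
theta0 = 0.0                           # thetahat(0)
derivs_at_0 = [1.0, -1.0, 2.0, -6.0]   # k-th derivatives of log(1 + eps) at eps = 0
delta = 0.1                            # epsilon - epsilon0

for order in range(1, 5):
    # Contract the k-th derivative with k copies of delta (a scalar power here).
    d = [derivs_at_0[k] * delta ** (k + 1) for k in range(order)]
    approx = float(taylor_series(theta0, d))
    print(order, approx, abs(approx - math.log1p(delta)))
```

With \(\epsilon - \epsilon_0 = 0.1\), each additional term should cut the error against the exact \(\log(1.1)\) by roughly an order of magnitude.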
My approach works (and is amenable to theory), but is a bit complicated to implement. Martin Jankowiak at Uber described to me his idea for an extremely elegant, though unfortunately inefficient, implementation. Let me demonstrate his idea in autograd and discuss why it is inefficient.
First, we’ll need to use the fact that the first derivative is given by
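\[\left.\frac{d \hat\theta(\epsilon)}{d\epsilon^T}\right|_{\epsilon_0} = -\left.\frac{\partial G(\theta, \epsilon)}{\partial \theta^T} \right|_{\epsilon_0} ^{-1} \left.\frac{\partial G(\theta, \epsilon)}{\partial \epsilon^T} \right|_{\epsilon_0},\]
where the partial derivatives are evaluated at \(\theta = \hat\theta(\epsilon_0)\) and \(\epsilon = \epsilon_0\). This follows from differentiating the identity \(G(\hat\theta(\epsilon), \epsilon) = 0\) with the chain rule, which gives \(\frac{\partial G}{\partial \theta^T} \frac{d \hat\theta(\epsilon)}{d\epsilon^T} + \frac{\partial G}{\partial \epsilon^T} = 0\). (Recall that \(G(\theta, \epsilon)\) is a vector of the same length as \(\theta\), so \(\partial G / \partial \theta^T\) is square.) Now, suppose we have implemented \(G(\theta, \epsilon)\) in Python and found a solution theta0 at epsilon0: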
def g(theta, epsilon):
    # ... your estimating equation here; it returns a vector the same length as theta ...
    raise NotImplementedError

# g(theta0, epsilon0) is a vector of zeros.
Using this, we can implement the optimal \(\hat\theta(\epsilon)\) as the following function, which can only be evaluated at \(\epsilon_0\):
from autograd import numpy as np
from autograd.extend import primitive

def check_epsilon(epsilon):
    assert np.linalg.norm(epsilon - epsilon0) < 1e-8

@primitive
def get_thetahat(epsilon):
    check_epsilon(epsilon)
    return theta0
As-is, this is a useless function. It returns theta0, which we already know is the optimum at epsilon0, and throws an error for any other input. However, we have marked it @primitive, which means we can specify a custom derivative using the formula above. We will only be able to evaluate this derivative at epsilon0, of course, but that’s all we want.
import autograd
from autograd.extend import defvjp

dg_dtheta = autograd.jacobian(g, argnum=0)
dg_depsilon = autograd.jacobian(g, argnum=1)

# Reverse mode AD is a "vector-Jacobian product", or "vjp".
def get_thetahat_vjp(ans, epsilon):
    def vjp(v):
        # v^T dthetahat / depsilon^T = -(dG/depsilon^T)^T (dG/dtheta^T)^{-T} v.
        # The transpose on dg_dtheta matters when G is not a symmetric Hessian.
        thetahat = get_thetahat(epsilon)
        return -1 * (dg_depsilon(thetahat, epsilon).T @
                     np.linalg.solve(dg_dtheta(thetahat, epsilon).T, v))
    return vjp

# Tell autograd to use get_thetahat_vjp to reverse-mode autodiff get_thetahat.
defvjp(get_thetahat, get_thetahat_vjp)
The function vjp is simply the reverse-mode implementation of the formula above for the first derivative.
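The first-derivative formula, and the vector-Jacobian product built from it, can be sanity-checked numerically without any automatic differentiation. The sketch below uses plain NumPy and a made-up estimating equation \(G(\theta, \epsilon) = A\theta + \theta^3 - B\epsilon\) (cube taken elementwise), with a deliberately non-symmetric \(A\) so that \(G\) is not the gradient of any objective; the matrices, the Newton helper solve_g, and the step size are illustrative assumptions. It compares the implicit-function-theorem Jacobian with central finite differences of re-solved optima, and shows that the vjp has to solve against the transpose of \(\partial G / \partial \theta^T\) once that matrix is not symmetric.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, hyperdim = 4, 3
# Non-symmetric but well-conditioned A, so G is a general estimating equation.
a_mat = 3.0 * np.eye(dim) + 0.3 * rng.standard_normal((dim, dim))
b_mat = rng.standard_normal((dim, hyperdim))


def g(theta, eps):
    """G(theta, eps) = A theta + theta**3 - B eps (cube taken elementwise)."""
    return a_mat @ theta + theta ** 3 - b_mat @ eps


def dg_dtheta(theta, eps):
    return a_mat + np.diag(3.0 * theta ** 2)


def dg_deps(theta, eps):
    return -b_mat


def solve_g(eps, theta_init, tol=1e-13, max_iter=100):
    """Newton's method for G(theta, eps) = 0 (hypothetical helper)."""
    theta = theta_init.copy()
    for _ in range(max_iter):
        newton_step = np.linalg.solve(dg_dtheta(theta, eps), g(theta, eps))
        theta = theta - newton_step
        if np.linalg.norm(newton_step) < tol:
            break
    return theta


eps0 = 0.5 * rng.standard_normal(hyperdim)
theta0 = solve_g(eps0, np.zeros(dim))

# Implicit function theorem: dthetahat/deps^T = -(dG/dtheta^T)^{-1} dG/deps^T.
h0 = dg_dtheta(theta0, eps0)
c0 = dg_deps(theta0, eps0)
dtheta_deps = -np.linalg.solve(h0, c0)

# Central finite differences of the re-solved optimum, one eps coordinate at a time.
fd_step = 1e-5
fd = np.empty((dim, hyperdim))
for j in range(hyperdim):
    e_j = np.zeros(hyperdim)
    e_j[j] = fd_step
    fd[:, j] = (solve_g(eps0 + e_j, theta0) - solve_g(eps0 - e_j, theta0)) / (2 * fd_step)

# Reverse mode: v -> v^T dthetahat/deps^T = -(dG/deps^T)^T (dG/dtheta^T)^{-T} v.
v = rng.standard_normal(dim)
vjp = -c0.T @ np.linalg.solve(h0.T, v)
vjp_without_transpose = -c0.T @ np.linalg.solve(h0, v)

print('IFT vs finite differences:', np.max(np.abs(dtheta_deps - fd)))
print('vjp error with transpose:', np.max(np.abs(vjp - v @ dtheta_deps)))
print('vjp error without transpose:', np.max(np.abs(vjp_without_transpose - v @ dtheta_deps)))
```

When \(G\) is the gradient of an objective, as in the notebook below, \(\partial G / \partial \theta^T\) is a symmetric Hessian and the transpose makes no difference.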
Now, the magic is that this implementation of the derivative of get_thetahat is itself composed of differentiable functions: linear algebra (solve and matrix multiplication), derivatives of g, and … get_thetahat itself, whose derivative we have just defined! Consequently, this single definition suffices for (reverse-mode) automatic differentiation of get_thetahat to all orders.
A forward-mode implementation is analogous: it defines a Jacobian-vector product with defjvp, as get_parhat_jvp does in the notebook below.
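To see why a single derivative rule is enough for every order, here is a forward-mode sketch of the same idea that swaps autograd for a hand-rolled nested dual-number class (Dual, D, nth_derivative, and thetahat are hypothetical names). It uses the scalar toy equation \(G(\theta, \epsilon) = e^\theta - 1 - \epsilon\), whose solution is \(\log(1 + \epsilon)\). Like get_thetahat, thetahat can only be evaluated at \(\epsilon_0 = 0\), but its derivative rule \(-G_\epsilon / G_\theta\) calls thetahat on the inner (possibly dual) value, so nesting D produces derivatives of every order. This naive nesting is only safe for plain compositions like D(D(f)); it does not guard against perturbation confusion in more general code.

```python
import math


class Dual:
    """Forward-mode dual number val + der * e; val and der may themselves be Duals."""

    def __init__(self, val, der):
        self.val = val
        self.der = der

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.val + other.val, self.der + other.der)
        return Dual(self.val + other, self.der)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, Dual):
            return Dual(self.val * other.val,
                        self.val * other.der + self.der * other.val)
        return Dual(self.val * other, self.der * other)

    __rmul__ = __mul__

    def __truediv__(self, other):
        return self * reciprocal(other)

    def __rtruediv__(self, other):
        return other * reciprocal(self)


def reciprocal(x):
    # d(1/v) = -dv / v^2, applied recursively through nested duals.
    if isinstance(x, Dual):
        r = reciprocal(x.val)
        return Dual(r, -1.0 * x.der * r * r)
    return 1.0 / x


def exp(x):
    if isinstance(x, Dual):
        e = exp(x.val)
        return Dual(e, e * x.der)
    return math.exp(x)


# Partial derivatives of the toy estimating equation G(theta, eps) = exp(theta) - 1 - eps.
def g_theta(theta, eps):
    return exp(theta)


def g_eps(theta, eps):
    return -1.0


EPS0, THETA0 = 0.0, 0.0  # G(THETA0, EPS0) = 0


def thetahat(eps):
    """Only evaluable at EPS0, like get_thetahat, but differentiable to any order."""
    if isinstance(eps, Dual):
        theta = thetahat(eps.val)  # the derivative rule calls thetahat itself
        dtheta = -g_eps(theta, eps.val) / g_theta(theta, eps.val)
        return Dual(theta, dtheta * eps.der)
    if abs(eps - EPS0) > 1e-8:
        raise ValueError('thetahat can only be evaluated at EPS0')
    return THETA0


def D(f):
    """Derivative operator: D(f)(x) = f'(x), via one dual perturbation."""
    return lambda x: f(Dual(x, 1.0)).der


def nth_derivative(f, n):
    for _ in range(n):
        f = D(f)
    return f


print([nth_derivative(thetahat, n)(EPS0) for n in range(1, 5)])
```

The printed values should match the derivatives of \(\log(1 + \epsilon)\) at zero, namely \(1, -1, 2, -6\).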
So how well does it work? It gives the right answers, but unfortunately it’s quite slow for higher-order derivatives, at least relative to vittles. I believe that our paper makes the reasons for our speedup clear.
- You’ll need to repeatedly apply \(\left.\frac{\partial G(\theta, \epsilon)}{\partial \theta^T} \right|_{\epsilon_0} ^{-1}\), i.e., calculate np.linalg.solve(dg_dtheta(thetahat, epsilon), ...). If you are implementing the derivatives in closed form (as I do in vittles), you can Cholesky factorize this matrix once, but autograd doesn’t know that, and can’t, because it needs to differentiate through the matrix evaluation for this trick to work.
- Lower-order derivatives of \(\hat\theta(\epsilon)\) appear in the expressions for higher-order derivatives. Again, if you are implementing the derivatives in closed form you can avoid re-calculating these lower-order derivatives, but autodiff has no way to be aware of this redundant structure.
- Because every higher-order derivative is contracted only with copies of \(\epsilon - \epsilon_0\), and is symmetric in its \(\epsilon\) indices, there is a lot of redundancy. Closed-form implementations can recognize this symmetry and avoid redundant calculations but, again, automatic differentiation is not aware of this structure.
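The first two points can be made concrete with a small NumPy sketch of the closed-form second-order term along a direction \(\delta = \epsilon - \epsilon_0\). The estimating equation \(G(\theta, \epsilon) = A\theta + \theta^3 - \theta \circ B\epsilon - B\epsilon\) (elementwise products) and all names are illustrative assumptions, not vittles code. Writing \(H = \partial G / \partial \theta^T\) and differentiating \(G(\hat\theta(\epsilon_0 + t\delta), \epsilon_0 + t\delta) = 0\) twice in \(t\) gives
\[H \theta_2 + G_{\theta\theta}[\theta_1, \theta_1] + 2 G_{\theta\epsilon}[\theta_1, \delta] + G_{\epsilon\epsilon}[\delta, \delta] = 0,\]
where \(\theta_1\) and \(\theta_2\) are the first and second derivatives of \(\hat\theta(\epsilon_0 + t\delta)\) at \(t = 0\). The second-order solve reuses both \(\theta_1\) and the same inverse of \(H\); a careful implementation would cache a factorization rather than form the inverse explicitly.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, hyperdim = 4, 3
a_mat = 3.0 * np.eye(dim) + 0.3 * rng.standard_normal((dim, dim))
b_mat = rng.standard_normal((dim, hyperdim))


def g(theta, eps):
    """G(theta, eps) = A theta + theta**3 - theta * (B eps) - B eps, elementwise."""
    b_eps = b_mat @ eps
    return a_mat @ theta + theta ** 3 - theta * b_eps - b_eps


def dg_dtheta(theta, eps):
    return a_mat + np.diag(3.0 * theta ** 2 - b_mat @ eps)


def solve_g(eps, theta_init, tol=1e-13, max_iter=100):
    """Newton's method for G(theta, eps) = 0 (hypothetical helper)."""
    theta = theta_init.copy()
    for _ in range(max_iter):
        newton_step = np.linalg.solve(dg_dtheta(theta, eps), g(theta, eps))
        theta = theta - newton_step
        if np.linalg.norm(newton_step) < tol:
            break
    return theta


eps0 = 0.2 * rng.standard_normal(hyperdim)
theta0 = solve_g(eps0, np.zeros(dim))
delta = rng.standard_normal(hyperdim)
b_delta = b_mat @ delta

# "Factorize" H once; the explicit inverse stands in for a cached LU or Cholesky factor.
h_inv = np.linalg.inv(dg_dtheta(theta0, eps0))

# First order: H theta1 + (dG/deps^T) delta = 0, with (dG/deps^T) delta = -(theta0 + 1) * B delta.
theta1 = h_inv @ ((theta0 + 1.0) * b_delta)

# Second order reuses theta1 and h_inv. Here G_tt[u, u] = 6 theta0 * u**2,
# G_te[u, delta] = -u * B delta, and G_ee = 0 because G is linear in eps.
theta2 = -h_inv @ (6.0 * theta0 * theta1 ** 2 - 2.0 * theta1 * b_delta)

# Compare with finite differences of the re-solved optimum along delta.
fd_step = 1e-3
theta_plus = solve_g(eps0 + fd_step * delta, theta0)
theta_minus = solve_g(eps0 - fd_step * delta, theta0)
fd1 = (theta_plus - theta_minus) / (2 * fd_step)
fd2 = (theta_plus - 2.0 * theta0 + theta_minus) / fd_step ** 2
print('first order error: ', np.max(np.abs(theta1 - fd1)))
print('second order error:', np.max(np.abs(theta2 - fd2)))
```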
Although vittles overcomes these difficulties, it does so at the cost of considerable complexity. That is to be expected: the benefits essentially come from caching, which is always complicated.
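To put a number on the symmetry point: the \(k\)-th derivative of \(\hat\theta\) with respect to a \(d\)-dimensional \(\epsilon\) has \(d^k\) entries for each component of \(\theta\), but because it is symmetric in its \(\epsilon\) indices only \(\binom{d + k - 1}{k}\) of them are distinct. The snippet below only does that counting (with \(d = 3\), matching hyperdim in the notebook below); it is not a description of how vittles exploits the symmetry.

```python
from math import comb


def full_entries(d, k):
    """Entries of the full k-th derivative tensor over d hyperparameters (per theta component)."""
    return d ** k


def distinct_entries(d, k):
    """Distinct entries of a symmetric k-tensor in d dimensions: multisets of size k."""
    return comb(d + k - 1, k)


d = 3  # hyperdim in the notebook below
for k in range(1, 5):
    print(k, full_entries(d, k), distinct_entries(d, k))
```

At fourth order with \(d = 3\), that is 81 entries versus 15 distinct ones.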
For more details, you can see the notebook below, which is also available here for download.
from autograd import numpy as np
import autograd
from autograd.extend import primitive, defvjp, defjvp
import scipy as sp
import scipy.optimize  # Makes sp.optimize available.
import vittles
import time
# Define an objective function.
hyperdim = 3
dim = 4
a_mat = np.random.random((dim, dim))
a_mat = a_mat @ a_mat.T + np.eye(dim)
def obj(par, hyperpar):
    return \
        0.5 * np.dot(par, a_mat) @ par + \
        (np.mean(par - 0.5) ** 5) * (np.mean(hyperpar - 0.5) ** 5)
hyperpar0 = np.zeros(hyperdim)
get_obj_hess = autograd.hessian(obj, argnum=0)
get_obj_grad_par = autograd.grad(obj, argnum=0)
# Get the optimum.
opt = sp.optimize.minimize(
    lambda par: obj(par, hyperpar0),
    np.zeros(dim),
    method='trust-ncg',
    jac=lambda par: get_obj_grad_par(par, hyperpar0),
    hess=lambda par: get_obj_hess(par, hyperpar0),
    tol=1e-16, callback=None, options={})
par0 = opt.x
# Specify a new hyperparameter.
delta = 0.1 * np.random.random(hyperdim)
hyperpar1 = hyperpar0 + delta
obj(par0, hyperpar0)
hess0 = get_obj_hess(par0, hyperpar0)
get_obj_cross_hess = autograd.jacobian(get_obj_grad_par, argnum=1)
cross_hess0 = get_obj_cross_hess(par0, hyperpar0)
print(opt)
print('Optimal parameter: ', par0)
# Make sure we're at an optimum.
assert(np.linalg.norm(get_obj_grad_par(par0, hyperpar0)) < 1e-8)
fun: 0.0009748280108099906
hess: array([[2.44509221, 1.66677909, 1.72065082, 1.2325689 ],
[1.66677909, 3.00646763, 1.98422462, 1.30356959],
[1.72065082, 1.98422462, 3.22225725, 1.44557718],
[1.2325689 , 1.30356959, 1.44557718, 2.22147862]])
jac: array([-2.91433544e-16, -2.91867225e-16, -2.91867225e-16, -2.91867225e-16])
message: 'A bad approximation caused failure to predict improvement.'
nfev: 4
nhev: 3
nit: 2
njev: 3
status: 2
success: False
x: array([4.59090270e-04, 2.15422310e-04, 8.16495910e-05, 6.64732807e-04])
Optimal parameter: [4.59090270e-04 2.15422310e-04 8.16495910e-05 6.64732807e-04]
def check_hyperpar(hyperpar):
    if np.linalg.norm(hyperpar - hyperpar0) > 1e-8:
        raise ValueError('Wrong value for hyperpar. ', hyperpar, ' != ', hyperpar0)
# Compare with the linear approximation class, which uses reverse mode.
obj_lin = vittles.HyperparameterSensitivityLinearApproximation(
    obj,
    par0,
    hyperpar0,
    validate_optimum=True,
    factorize_hessian=True)
# Using Martin's trick, higher-order derivatives of the optimum just take a few lines.
@primitive
def get_parhat(hyperpar):
    check_hyperpar(hyperpar)
    return par0
print('Should be zeros:', get_parhat(hyperpar0) - par0)
# Reverse mode
def get_parhat_vjp(ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes
    def vjp(g):
        return -1 * (get_obj_cross_hess(get_parhat(hyperpar), hyperpar).T @
                     np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar), g)).T
    return vjp
defvjp(get_parhat, get_parhat_vjp)
# Forward mode
def get_parhat_jvp(g, ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes
    return -1 * (np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar),
                                 get_obj_cross_hess(get_parhat(hyperpar), hyperpar)) @ g)
defjvp(get_parhat, get_parhat_jvp)
Should be zeros: [0. 0. 0. 0.]
# Check reverse mode first derivatives against manual formula.
get_dpar_dhyperpar = autograd.jacobian(get_parhat)
dpar_dhyperpar = get_dpar_dhyperpar(hyperpar0)
# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar -
    obj_lin.get_dopt_dhyper()) < 1e-8)
# Check forward mode first derivatives.
# I prefer my interface to autograd's for forward mode. Behind the scenes
# it's the same thing.
from vittles.sensitivity_lib import _append_jvp
get_dpar_dhyperpar_delta = _append_jvp(get_parhat)
dpar_dhyperpar_delta = get_dpar_dhyperpar_delta(hyperpar0, delta)
# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar_delta -
    obj_lin.get_dopt_dhyper() @ delta) < 1e-8)
# Let's compare against the Taylor expansion class for higher derivatives.
obj_taylor = vittles.ParametricSensitivityTaylorExpansion(
    obj,
    par0,
    hyperpar0,
    order=4)
# Sanity check.
assert(np.linalg.norm(
    obj_taylor.evaluate_input_derivs(delta, max_order=1)[0] -
    dpar_dhyperpar_delta) < 1e-8)
/home/rgiordan/Documents/git_repos/vittles/vittles/sensitivity_lib.py:857: UserWarning: The ParametricSensitivityTaylorExpansion is experimental.
'The ParametricSensitivityTaylorExpansion is experimental.')
# Check that the second derivative reverse mode matches.
autograd_time = time.time()
get_d2par_dhyperpar2 = autograd.jacobian(get_dpar_dhyperpar)
d2par_dhyperpar2 = get_d2par_dhyperpar2(hyperpar0)
autograd_time = time.time() - autograd_time
# Let's be generous and not count the einsum against autograd.
d2par_dhyperpar2_delta2 = np.einsum('ijk,j,k->i', d2par_dhyperpar2, delta, delta)
vittles_time = time.time()
d2par_dhyperpar2_delta2_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=2)[1]
vittles_time = time.time() - vittles_time
assert(np.linalg.norm(
    d2par_dhyperpar2_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)
print('Second order')
print('Autograd reverse mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd reverse mode time: 0.3968043327331543
vittles time: 0.0041577816009521484
# Check that forward mode second derivatives match.
autograd_time = time.time()
get_d2par_d2hyperpar_delta = _append_jvp(get_dpar_dhyperpar_delta)
d2par_d2hyperpar_delta2 = get_d2par_d2hyperpar_delta(hyperpar0, delta, delta)
autograd_time = time.time() - autograd_time
assert(np.linalg.norm(
    d2par_d2hyperpar_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)
print('Second order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd forward mode time: 0.039446353912353516
vittles time: 0.0041577816009521484
# Check that the third derivative matches.
## Reverse mode is super slow even in very low dimensions.
# autograd_time = time.time()
# get_d3par_dhyperpar3 = autograd.jacobian(get_d2par_dhyperpar2)
# d3par_dhyperpar3 = get_d3par_dhyperpar3(hyperpar0)
# autograd_time = time.time() - autograd_time
# print('Autograd time:', autograd_time)
autograd_time = time.time()
get_d3par_d3hyperpar_delta = _append_jvp(get_d2par_d2hyperpar_delta)
d3par_dhyperpar3_delta3 = get_d3par_d3hyperpar_delta(hyperpar0, delta, delta, delta)
autograd_time = time.time() - autograd_time
vittles_time = time.time()
d3par_dhyperpar3_delta3_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=3)[2]
vittles_time = time.time() - vittles_time
assert(np.linalg.norm(
    d3par_dhyperpar3_delta3 -
    d3par_dhyperpar3_delta3_vittles) < 1e-8)
print('Third order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Third order
Autograd forward mode time: 0.09583282470703125
vittles time: 0.017874479293823242
# Fourth order.
autograd_time = time.time()
get_d4par_d4hyperpar_delta = _append_jvp(get_d3par_d3hyperpar_delta)
d4par_dhyperpar4_delta4 = get_d4par_d4hyperpar_delta(hyperpar0, delta, delta, delta, delta)
autograd_time = time.time() - autograd_time
vittles_time = time.time()
d4par_dhyperpar4_delta4_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=4)[3]
vittles_time = time.time() - vittles_time
assert(np.linalg.norm(
    d4par_dhyperpar4_delta4 -
    d4par_dhyperpar4_delta4_vittles) < 1e-8)
print('Fourth order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Fourth order
Autograd forward mode time: 0.21596074104309082
vittles time: 0.03277397155761719