I have some recent work (A Higher-Order Swiss Army Infinitesimal Jackknife) that is all about calculating Taylor expansions of optima with respect to hyperparameters using automatic differentiation.

In the paper we talk about sensitivity to data weights, but the idea is much more general. Suppose you have a parameter which you’re optimizing, $\theta$, and some hyperparameter $\epsilon$. For a fixed $\epsilon$, you find $\hat\theta(\epsilon)$ to satisfy some vector of first-order conditions,

$$g(\hat\theta(\epsilon), \epsilon) = 0.$$

For example, if you’re optimizing $f(\theta, \epsilon)$ for a fixed $\epsilon$, $g(\theta, \epsilon)$ would be the vector of partial derivatives of $f$ with respect to $\theta$.

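Of course, $\hat\theta(\epsilon)$ depends on $\epsilon$, and you might want to approximately calculate $\hat\theta(\epsilon)$ for different $\epsilon$ without re-solving an optimization problem. Specifically, you might form a Taylor series expansion around some $\epsilon_0$:

$$\hat\theta(\epsilon) \approx \hat\theta(\epsilon_0) + \sum_{k=1}^{K} \frac{1}{k!} \frac{d^k \hat\theta(\epsilon)}{d\epsilon^k}\bigg|_{\epsilon_0} (\epsilon - \epsilon_0)^k.$$

(For vector-valued $\epsilon$, $(\epsilon - \epsilon_0)^k$ means contracting the $k$-th derivative tensor with $\epsilon - \epsilon_0$ in each of its $k$ hyperparameter indices.)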

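For intuition, here is a scalar example where the derivatives can be worked out by hand. Take the made-up estimating equation $g(\theta, \epsilon) = \theta^3 + \theta - \epsilon$ with $\epsilon_0 = 0$, so that $\hat\theta(0) = 0$. Implicit differentiation gives $d\hat\theta/d\epsilon = 1$, $d^2\hat\theta/d\epsilon^2 = 0$ and $d^3\hat\theta/d\epsilon^3 = -6$ at $\epsilon_0$, so the third-order expansion is $\hat\theta(\epsilon) \approx \epsilon - \epsilon^3$. The sketch below compares truncated expansions against a numerically re-solved root.

```python
import math

import scipy.optimize

# Hypothetical scalar estimating equation whose root at epsilon0 = 0 is thetahat = 0.
def g(theta, epsilon):
    return theta ** 3 + theta - epsilon

def thetahat_exact(epsilon):
    # g is strictly increasing in theta, so the root is unique; for |epsilon| < 2
    # it lies in [-1, 1], which brackets a sign change.
    return scipy.optimize.brentq(lambda t: g(t, epsilon), -1.0, 1.0, xtol=1e-14)

# Hand-derived derivatives of thetahat at epsilon0 = 0, for orders 0 through 3.
derivs = [0.0, 1.0, 0.0, -6.0]

def taylor(delta, order):
    # Truncated Taylor series: sum over k of derivs[k] * delta**k / k!.
    return sum(derivs[k] * delta ** k / math.factorial(k) for k in range(order + 1))

delta = 0.1
exact = thetahat_exact(delta)
errors = [abs(taylor(delta, order) - exact) for order in range(4)]
print(errors)
```

With $\epsilon = 0.1$ the first-order error is roughly $10^{-3}$ and the third-order error roughly $3 \times 10^{-5}$. That is consistent with the next term of the series, $3\epsilon^5$.
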
The difficulty is how to calculate the derivatives $d^k \hat\theta(\epsilon) / d\epsilon^k$, which are defined implicitly through the solution of $g(\hat\theta(\epsilon), \epsilon) = 0$. The computation section of our paper describes one way to do so recursively, and I’ve implemented the solution in the ParametricSensitivityTaylorExpansion class of my Python package vittles.

My approach works (and is amenable to theory), but it is a bit complicated to implement. Martin Jankowiak at Uber described to me his idea for an extremely elegant, though unfortunately inefficient, implementation. Let me demonstrate his idea in autograd and discuss why it is inefficient.

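First, we’ll need to use the fact that, by the implicit function theorem, the first derivative is given by

$$\frac{d\hat\theta(\epsilon)}{d\epsilon^T} = -\left(\frac{\partial g(\theta, \epsilon)}{\partial \theta^T}\right)^{-1} \frac{\partial g(\theta, \epsilon)}{\partial \epsilon^T}\bigg|_{\theta = \hat\theta(\epsilon)}.$$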

(Recall that $g$ is a vector of the same length as $\theta$, so $\partial g / \partial \theta^T$ is a square matrix.) Now, suppose we have implemented $g$ in Python and found a solution theta0 at epsilon0:

def g(theta, epsilon):
    ... your estimating equation here ...

g(theta0, epsilon0) # ...is a vector of zeros.

Using this, we can implement the optimum $\hat\theta(\epsilon)$ as the following function, which can only be evaluated at $\epsilon_0$:

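import autograd
from autograd import numpy as np
from autograd.extend import primitive, defvjp

def check_epsilon(epsilon):
    # get_thetahat is only valid at the hyperparameter where we solved the problem.
    assert np.linalg.norm(epsilon - epsilon0) < 1e-8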

@primitive
def get_thetahat(epsilon):
    check_epsilon(epsilon)
    return theta0

As-is, this is a useless function. It only returns what we already know, which is that theta0 is the optimum for epsilon0, and otherwise throws an error. However, we have marked it @primitive, which means we can specify a custom derivative using the formula above. We will only be able to evaluate this derivative at epsilon0, of course, but that’s all we want.

dg_dtheta = autograd.jacobian(g, argnum=0)
dg_depsilon = autograd.jacobian(g, argnum=1)

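# Reverse mode AD is a "vector-Jacobian product", or "vjp".
def get_thetahat_vjp(ans, epsilon):
    def vjp(v):
        # v has the shape of theta; return v^T (d thetahat / d epsilon).
        thetahat = get_thetahat(epsilon)
        # Solve with the transpose of dg/dtheta. (When g is the gradient of an
        # objective, dg/dtheta is a symmetric Hessian and the transpose is harmless.)
        return -1 * (dg_depsilon(thetahat, epsilon).T @
                     np.linalg.solve(dg_dtheta(thetahat, epsilon).T, v))
    return vjp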

# Tell autograd to use get_thetahat_vjp to reverse-mode autodiff get_thetahat.
defvjp(get_thetahat, get_thetahat_vjp)

The function vjp is simply the reverse-mode implementation of the formula above for the first derivative.
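
The following minimal NumPy/SciPy sketch checks the first-derivative formula $d\hat\theta/d\epsilon^T = -(\partial g/\partial\theta^T)^{-1}\, \partial g/\partial\epsilon^T$ and its reverse-mode form against finite differences. It uses a made-up estimating equation $g(\theta, \epsilon) = A\theta + \theta^3 - B\epsilon$, and `A`, `B` and `solve_theta` are hypothetical. Because this $g$ is not the gradient of an objective, $\partial g / \partial \theta^T$ is not symmetric. That makes it visible that the vector-Jacobian product $v^T (d\hat\theta/d\epsilon^T)$ needs a solve with the transpose of $\partial g / \partial \theta^T$.

```python
import numpy as np
import scipy.optimize

# Hypothetical estimating equation (not a gradient, so dg/dtheta is not symmetric).
A = np.array([[2.0, 1.0], [0.0, 1.0]])
B = np.array([[1.0, 0.0, 0.3], [0.2, 1.0, 0.0]])

def g(theta, epsilon):
    return A @ theta + theta ** 3 - B @ epsilon

def dg_dtheta(theta, epsilon):
    return A + np.diag(3 * theta ** 2)

def dg_depsilon(theta, epsilon):
    return -B

def solve_theta(epsilon):
    # Re-solve g(theta, epsilon) = 0 from scratch.
    return scipy.optimize.root(lambda t: g(t, epsilon), np.zeros(2),
                               jac=lambda t: dg_dtheta(t, epsilon), tol=1e-13).x

epsilon0 = np.array([0.1, -0.2, 0.3])
theta0 = solve_theta(epsilon0)

# Full Jacobian d thetahat / d epsilon from the implicit function theorem.
jac = -np.linalg.solve(dg_dtheta(theta0, epsilon0), dg_depsilon(theta0, epsilon0))

# Reverse-mode product v^T (d thetahat / d epsilon): note the transposed solve.
def vjp(v):
    return -dg_depsilon(theta0, epsilon0).T @ np.linalg.solve(
        dg_dtheta(theta0, epsilon0).T, v)

# Central finite differences, re-solving the problem at perturbed hyperparameters.
h = 1e-6
fd_jac = np.column_stack([
    (solve_theta(epsilon0 + h * e) - solve_theta(epsilon0 - h * e)) / (2 * h)
    for e in np.eye(3)])

v = np.array([0.7, -1.3])
print(np.max(np.abs(jac - fd_jac)), vjp(v), v @ jac)
```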

Now, the magic is that this implementation of the derivative of get_thetahat is itself composed of differentiable functions: linear algebra (solve and matrix multiplication), derivatives of g, and … get_thetahat itself, whose derivative we have just defined! Consequently, this single definition suffices for (reverse-mode) automatic differentiation of get_thetahat to all orders. A forward-mode implementation is similar; the notebook below defines one with defjvp.

So how well does it work? It works, but unfortunately it’s quite slow for higher-order derivatives, at least relative to vittles. I believe that our paper makes the reasons for our speedup clear:

  • You’ll need to repeatedly solve systems involving $\partial g / \partial \theta^T$, i.e., calculate np.linalg.solve(dg_dtheta(thetahat, epsilon), ...). If you are implementing the derivatives in closed form (as I do in vittles), you can Cholesky factorize this matrix once, but autograd doesn’t know that, and can’t, because it needs to differentiate through the matrix evaluation for this trick to work.
  • Lower-order derivatives of $\hat\theta$ appear in the expressions for higher-order derivatives. Again, if you are implementing the derivatives in closed form you can avoid recalculating these lower-order derivatives, but autodiff has no way to be aware of this redundant structure.
  • Because every higher-order derivative is contracted only with copies of the same vector $\epsilon - \epsilon_0$, there is a lot of redundancy due to symmetry. Closed-form implementations can recognize this symmetry and avoid redundant calculations but, again, automatic differentiation is not aware of this structure.
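
To make the first two points concrete, here is the second directional derivative obtained by differentiating $g(\hat\theta(\epsilon), \epsilon) = 0$ twice along a direction $\Delta\epsilon$. This is a standard implicit-differentiation step written out for illustration, not necessarily the exact recursion used in the paper. Write $\dot\theta = \frac{d\hat\theta}{d\epsilon^T}\Delta\epsilon$ for the first directional derivative, and evaluate all partial derivatives of $g$ at $(\hat\theta(\epsilon_0), \epsilon_0)$:

$$\frac{d^2\hat\theta}{d\epsilon^2}[\Delta\epsilon, \Delta\epsilon] = -\left(\frac{\partial g}{\partial \theta^T}\right)^{-1}\left( \frac{\partial^2 g}{\partial\theta^2}[\dot\theta, \dot\theta] + 2\,\frac{\partial^2 g}{\partial\theta\,\partial\epsilon}[\dot\theta, \Delta\epsilon] + \frac{\partial^2 g}{\partial\epsilon^2}[\Delta\epsilon, \Delta\epsilon]\right).$$

The same matrix $\partial g / \partial \theta^T$ that appears in the first derivative must be inverted again, and the lower-order derivative $\dot\theta$ reappears inside the bracket. Each further order repeats this pattern with more terms.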

Although vittles overcomes these difficulties, it does so at the cost of considerable complexity. As you might expect, this is because the benefits essentially come from caching, which is always complicated.
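
Here is a minimal sketch of two of these savings, assuming NumPy and SciPy. The random symmetric positive-definite matrix stands in for $\partial g / \partial \theta^T$ at an optimum and is purely hypothetical. The sketch factorizes the matrix once with `scipy.linalg.cho_factor` and reuses the factor for many solves. It also counts how few entries of a symmetric order-$k$ derivative tensor are actually distinct.

```python
import math

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
dim = 50

# Hypothetical symmetric positive-definite dg/dtheta, e.g. the Hessian at an optimum.
m = rng.standard_normal((dim, dim))
hess = m @ m.T + dim * np.eye(dim)

# Caching: factorize once, then reuse the factor for every right-hand side.
hess_chol = scipy.linalg.cho_factor(hess)
rhs = rng.standard_normal((dim, 10))  # e.g. right-hand sides arising at different orders
solutions = scipy.linalg.cho_solve(hess_chol, rhs)

# Symmetry: an order-k derivative with respect to a d-dimensional hyperparameter
# has d**k entries, but only comb(d + k - 1, k) of them are distinct.
def n_distinct_entries(d, k):
    return math.comb(d + k - 1, k)

for k in range(1, 5):
    print(k, 3 ** k, n_distinct_entries(3, k))
```

With $d = 3$ (the hyperparameter dimension in the notebook below), the fourth-order derivative has 81 entries but only 15 distinct ones.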

For more details, you can see the notebook below, which is also available here for download.

from autograd import numpy as np
import autograd
from autograd.core import primitive, defvjp, defjvp
from autograd.numpy.linalg import slogdet, solve, inv

import scipy as sp
import scipy.optimize  # make sure sp.optimize is loaded
import vittles

import time
# Define an objective function.

hyperdim = 3
dim = 4
a_mat = np.random.random((dim, dim)) 
a_mat = a_mat @ a_mat.T + np.eye(dim)
def obj(par, hyperpar):
    return \
        0.5 * np.dot(par, a_mat) @ par + \
        (np.mean(par - 0.5) ** 5) * (np.mean(hyperpar - 0.5) ** 5)

hyperpar0 = np.zeros(hyperdim)

get_obj_hess = autograd.hessian(obj, argnum=0)
get_obj_grad_par = autograd.grad(obj, argnum=0)

# Get the optimum.

opt = sp.optimize.minimize(
    lambda par: obj(par, hyperpar0),
    np.zeros(dim),
    method='trust-ncg',
    jac=lambda par: get_obj_grad_par(par, hyperpar0),
    hess=lambda par: get_obj_hess(par, hyperpar0),
    tol=1e-16, callback=None, options={})
par0 = opt.x

# Specify a new hyperparameter.

delta = 0.1 * np.random.random(hyperdim)
hyperpar1 = hyperpar0 + delta
obj(par0, hyperpar0)

hess0 = get_obj_hess(par0, hyperpar0)
get_obj_cross_hess = autograd.jacobian(get_obj_grad_par, argnum=1)
cross_hess0 = get_obj_cross_hess(par0, hyperpar0)

print(opt)
print('Optimal parameter: ', par0)

# Make sure we're at an optimum.
assert(np.linalg.norm(get_obj_grad_par(par0, hyperpar0)) < 1e-8)
     fun: 0.0009748280108099906
    hess: array([[2.44509221, 1.66677909, 1.72065082, 1.2325689 ],
       [1.66677909, 3.00646763, 1.98422462, 1.30356959],
       [1.72065082, 1.98422462, 3.22225725, 1.44557718],
       [1.2325689 , 1.30356959, 1.44557718, 2.22147862]])
     jac: array([-2.91433544e-16, -2.91867225e-16, -2.91867225e-16, -2.91867225e-16])
 message: 'A bad approximation caused failure to predict improvement.'
    nfev: 4
    nhev: 3
     nit: 2
    njev: 3
  status: 2
 success: False
       x: array([4.59090270e-04, 2.15422310e-04, 8.16495910e-05, 6.64732807e-04])
Optimal parameter:  [4.59090270e-04 2.15422310e-04 8.16495910e-05 6.64732807e-04]
def check_hyperpar(hyperpar):
    if np.linalg.norm(hyperpar - hyperpar0) > 1e-8:
        raise ValueError('Wrong value for hyperpar. ', hyperpar, ' != ', hyperpar0)
# Compare with the linear approximation class, which uses reverse mode.
obj_lin = vittles.HyperparameterSensitivityLinearApproximation(
    obj,
    par0,
    hyperpar0,
    validate_optimum=True,
    factorize_hessian=True)
# Using Martin's trick, higher-order derivatives of the optimum just take a few lines.

@primitive
def get_parhat(hyperpar):
    check_hyperpar(hyperpar)
    return par0

print('Should be zeros:', get_parhat(hyperpar0) - par0)

# Reverse mode
def get_parhat_vjp(ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes    
    def vjp(g):
        return -1 * (get_obj_cross_hess(get_parhat(hyperpar), hyperpar).T @
                     np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar), g)).T
    return vjp
defvjp(get_parhat, get_parhat_vjp)

# Forward mode
def get_parhat_jvp(g, ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes    
    return -1 * (np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar),
                                 get_obj_cross_hess(get_parhat(hyperpar), hyperpar)) @ g)
defjvp(get_parhat, get_parhat_jvp)
Should be zeros: [0. 0. 0. 0.]
# Check reverse mode first derivatives against manual formula.
get_dpar_dhyperpar = autograd.jacobian(get_parhat)
dpar_dhyperpar = get_dpar_dhyperpar(hyperpar0)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar -
    obj_lin.get_dopt_dhyper()) < 1e-8)
# Check forward mode first derivatives.

# I prefer my interface to autograd's for forward mode.  Behind the scenes
# it's the same thing.
from vittles.sensitivity_lib import _append_jvp

get_dpar_dhyperpar_delta = _append_jvp(get_parhat)
dpar_dhyperpar_delta = get_dpar_dhyperpar_delta(hyperpar0, delta)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar_delta -
    obj_lin.get_dopt_dhyper() @ delta) < 1e-8)
# Let's compare against the Taylor expansion class for higher derivatives.
obj_taylor = vittles.ParametricSensitivityTaylorExpansion(
    obj,
    par0,
    hyperpar0,
    order=4)

# Sanity check.
assert(np.linalg.norm(
    obj_taylor.evaluate_input_derivs(delta, max_order=1)[0] -
    dpar_dhyperpar_delta) < 1e-8)
/home/rgiordan/Documents/git_repos/vittles/vittles/sensitivity_lib.py:857: UserWarning: The ParametricSensitivityTaylorExpansion is experimental.
  'The ParametricSensitivityTaylorExpansion is experimental.')
# Check that the second derivative reverse mode matches.
autograd_time = time.time()
get_d2par_dhyperpar2 = autograd.jacobian(get_dpar_dhyperpar)
d2par_dhyperpar2 = get_d2par_dhyperpar2(hyperpar0)
autograd_time = time.time() - autograd_time

# Let's be generous and not count the einsum against autograd.
d2par_dhyperpar2_delta2 = np.einsum('ijk,j,k->i', d2par_dhyperpar2, delta, delta)

vittles_time = time.time()
d2par_dhyperpar2_delta2_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=2)[1]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d2par_dhyperpar2_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('Autograd reverse mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd reverse mode time:  0.3968043327331543
vittles time:            0.0041577816009521484
# Check that forward mode second derivatives match.
autograd_time = time.time()
get_d2par_d2hyperpar_delta = _append_jvp(get_dpar_dhyperpar_delta)
d2par_d2hyperpar_delta2 = get_d2par_d2hyperpar_delta(hyperpar0, delta, delta)
autograd_time = time.time() - autograd_time

assert(np.linalg.norm(
    d2par_d2hyperpar_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd forward mode time:  0.039446353912353516
vittles time:            0.0041577816009521484
# Check that the third derivative matches.

## Reverse mode is super slow even in very low dimensions.
# autograd_time = time.time()
# get_d3par_dhyperpar3 = autograd.jacobian(get_d2par_dhyperpar2)
# d3par_dhyperpar3 = get_d3par_dhyperpar3(hyperpar0)
# autograd_time = time.time() - autograd_time
# print('Autograd time:', autograd_time)
autograd_time = time.time()
get_d3par_d3hyperpar_delta = _append_jvp(get_d2par_d2hyperpar_delta)
d3par_dhyperpar3_delta3 = get_d3par_d3hyperpar_delta(hyperpar0, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d3par_dhyperpar3_delta3_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=3)[2]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d3par_dhyperpar3_delta3 -
    d3par_dhyperpar3_delta3_vittles) < 1e-8)

print('Third order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Third order
Autograd forward mode time:  0.09583282470703125
vittles time:            0.017874479293823242
# Fourth order.

autograd_time = time.time()
get_d4par_d4hyperpar_delta = _append_jvp(get_d3par_d3hyperpar_delta)
d4par_dhyperpar4_delta4 = get_d4par_d4hyperpar_delta(hyperpar0, delta, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d4par_dhyperpar4_delta4_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=4)[3]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d4par_dhyperpar4_delta4 -
    d4par_dhyperpar4_delta4_vittles) < 1e-8)

print('Fourth order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Fourth order
Autograd forward mode time:  0.21596074104309082
vittles time:            0.03277397155761719