``````from autograd import numpy as np
from autograd.core import primitive, defvjp, defjvp
from autograd.numpy.linalg import slogdet, solve, inv

import scipy as sp
import vittles

import time``````
``````# Define an objective function.

hyperdim = 3
dim = 4
a_mat = np.random.random((dim, dim))
a_mat = a_mat @ a_mat.T + np.eye(dim)
def obj(par, hyperpar):
return \
0.5 * np.dot(par, a_mat) @ par + \
(np.mean(par - 0.5) ** 5) * (np.mean(hyperpar - 0.5) ** 5)

hyperpar0 = np.zeros(hyperdim)

# Get the optimum.

opt = sp.optimize.minimize(
lambda par: obj(par, hyperpar0),
np.zeros(dim),
method='trust-ncg',
hess=lambda par: get_obj_hess(par, hyperpar0),
tol=1e-16, callback=None, options={})
par0 = opt.x

# Specify a new hyperparameter.

delta = 0.1 * np.random.random(hyperdim)
hyperpar1 = hyperpar0 + delta
obj(par0, hyperpar0)

hess0 = get_obj_hess(par0, hyperpar0)
cross_hess0 = get_obj_cross_hess(par0, hyperpar0)

print(opt)
print('Optimal parameter: ', par0)

# Make sure we're at an optimum.
``````     fun: 0.0009748280108099906
hess: array([[2.44509221, 1.66677909, 1.72065082, 1.2325689 ],
[1.66677909, 3.00646763, 1.98422462, 1.30356959],
[1.72065082, 1.98422462, 3.22225725, 1.44557718],
[1.2325689 , 1.30356959, 1.44557718, 2.22147862]])
jac: array([-2.91433544e-16, -2.91867225e-16, -2.91867225e-16, -2.91867225e-16])
message: 'A bad approximation caused failure to predict improvement.'
nfev: 4
nhev: 3
nit: 2
njev: 3
status: 2
success: False
x: array([4.59090270e-04, 2.15422310e-04, 8.16495910e-05, 6.64732807e-04])
Optimal parameter:  [4.59090270e-04 2.15422310e-04 8.16495910e-05 6.64732807e-04]``````
``````def check_hyperpar(hyperpar):
if np.linalg.norm(hyperpar - hyperpar0) > 1e-8:
raise ValueError('Wrong value for hyperpar. ', hyperpar, ' != ', hyperpar0)``````
``````# Compare with the linear approximation class, which uses reverse mode.
obj_lin = vittles.HyperparameterSensitivityLinearApproximation(
obj,
par0,
hyperpar0,
validate_optimum=True,
factorize_hessian=True)``````
``````# Using Martin's trick, higher-order derivatives of the optimum just take a few lines.

@primitive
def get_parhat(hyperpar):
check_hyperpar(hyperpar)
return par0

print('Should be zeros:', get_parhat(hyperpar0) - par0)

# Reverse mode
def get_parhat_vjp(ans, hyperpar):
#check_hyperpar(hyperpar) # Need to find some way to do this with boxes
def vjp(g):
return -1 * (get_obj_cross_hess(get_parhat(hyperpar), hyperpar).T @
np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar), g)).T
return vjp
defvjp(get_parhat, get_parhat_vjp)

# Forward mode
def get_parhat_jvp(g, ans, hyperpar):
#check_hyperpar(hyperpar) # Need to find some way to do this with boxes
return -1 * (np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar),
get_obj_cross_hess(get_parhat(hyperpar), hyperpar)) @ g)
defjvp(get_parhat, get_parhat_jvp)``````
``Should be zeros: [0. 0. 0. 0.]``
``````# Check reverse mode first derivatives against manual formula.
dpar_dhyperpar = get_dpar_dhyperpar(hyperpar0)

# Check that the first derivative matches.
assert(np.linalg.norm(
dpar_dhyperpar -
obj_lin.get_dopt_dhyper()) < 1e-8)``````
``````# Check forward mode first derivatives.

# I prefer my interface to autograd's for forward mode.  Behind the scenes
# it's the same thing.
from vittles.sensitivity_lib import _append_jvp

get_dpar_dhyperpar_delta = _append_jvp(get_parhat)
dpar_dhyperpar_delta = get_dpar_dhyperpar_delta(hyperpar0, delta)

# Check that the first derivative matches.
assert(np.linalg.norm(
dpar_dhyperpar_delta -
obj_lin.get_dopt_dhyper() @ delta) < 1e-8)``````
``````# Let's compare against the Taylor expansion class for higher derivatives.
obj_taylor = vittles.ParametricSensitivityTaylorExpansion(
obj,
par0,
hyperpar0,
order=4)

# Sanity check.
assert(np.linalg.norm(
obj_taylor.evaluate_input_derivs(delta, max_order=1)[0] -
dpar_dhyperpar_delta) < 1e-8)``````
``````/home/rgiordan/Documents/git_repos/vittles/vittles/sensitivity_lib.py:857: UserWarning: The ParametricSensitivityTaylorExpansion is experimental.
'The ParametricSensitivityTaylorExpansion is experimental.')``````
``````# Check that the second derivative reverse mode matches.
d2par_dhyperpar2 = get_d2par_dhyperpar2(hyperpar0)

# Let's be generous and not count the einsum against autograd.
d2par_dhyperpar2_delta2 = np.einsum('ijk,j,k->i', d2par_dhyperpar2, delta, delta)

vittles_time = time.time()
d2par_dhyperpar2_delta2_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=2)[1]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
d2par_dhyperpar2_delta2 -
d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('vittles time:\t\t\t', vittles_time)``````
``````Second order
vittles time:            0.0041577816009521484``````
``````# Check that forward mode second derivatives match.
get_d2par_d2hyperpar_delta = _append_jvp(get_dpar_dhyperpar_delta)
d2par_d2hyperpar_delta2 = get_d2par_d2hyperpar_delta(hyperpar0, delta, delta)

assert(np.linalg.norm(
d2par_dhyperpar2_delta2 -
d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('vittles time:\t\t\t', vittles_time)``````
``````Second order
vittles time:            0.0041577816009521484``````
``````# Check that the third derivative matches.

## Reverse mode is super slow even in very low dimensions.
# d3par_dhyperpar3 = get_d3par_dhyperpar3(hyperpar0)
``````autograd_time = time.time()
get_d3par_d3hyperpar_delta = _append_jvp(get_d2par_d2hyperpar_delta)
d3par_dhyperpar3_delta3 = get_d3par_d3hyperpar_delta(hyperpar0, delta, delta, delta)

vittles_time = time.time()
d3par_dhyperpar3_delta3_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=3)[2]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
d3par_dhyperpar3_delta3 -
d3par_dhyperpar3_delta3_vittles) < 1e-8)

print('Third order')
print('vittles time:\t\t\t', vittles_time)``````
``````Third order
vittles time:            0.017874479293823242``````
``````# Fourth order.

get_d4par_d4hyperpar_delta = _append_jvp(get_d3par_d3hyperpar_delta)
d4par_dhyperpar4_delta4 = get_d4par_d4hyperpar_delta(hyperpar0, delta, delta, delta, delta)

vittles_time = time.time()
d4par_dhyperpar4_delta4_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=4)[3]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
d4par_dhyperpar4_delta4 -
d4par_dhyperpar4_delta4_vittles) < 1e-8)

print('Fourth order')
``````Fourth order