from autograd import numpy as np
import autograd
from autograd.extend import primitive, defvjp, defjvp
from autograd.numpy.linalg import slogdet, solve, inv
import scipy as sp
import scipy.optimize
import vittles
import time
# Define an objective function.
hyperdim = 3
dim = 4
a_mat = np.random.random((dim, dim))
a_mat = a_mat @ a_mat.T + np.eye(dim)

def obj(par, hyperpar):
    return \
        0.5 * np.dot(par, a_mat) @ par + \
        (np.mean(par - 0.5) ** 5) * (np.mean(hyperpar - 0.5) ** 5)

hyperpar0 = np.zeros(hyperdim)

get_obj_hess = autograd.hessian(obj, argnum=0)
get_obj_grad_par = autograd.grad(obj, argnum=0)
# Get the optimum.
opt = sp.optimize.minimize(
    lambda par: obj(par, hyperpar0),
    np.zeros(dim),
    method='trust-ncg',
    jac=lambda par: get_obj_grad_par(par, hyperpar0),
    hess=lambda par: get_obj_hess(par, hyperpar0),
    tol=1e-16, callback=None, options={})
par0 = opt.x
# Specify a new hyperparameter.
delta = 0.1 * np.random.random(hyperdim)
hyperpar1 = hyperpar0 + delta
obj(par0, hyperpar0)
hess0 = get_obj_hess(par0, hyperpar0)
get_obj_cross_hess = autograd.jacobian(get_obj_grad_par, argnum=1)
cross_hess0 = get_obj_cross_hess(par0, hyperpar0)
print(opt)
print('Optimal parameter: ', par0)
# Make sure we're at an optimum.
assert(np.linalg.norm(get_obj_grad_par(par0, hyperpar0)) < 1e-8)
fun: 0.0009748280108099906
hess: array([[2.44509221, 1.66677909, 1.72065082, 1.2325689 ],
[1.66677909, 3.00646763, 1.98422462, 1.30356959],
[1.72065082, 1.98422462, 3.22225725, 1.44557718],
[1.2325689 , 1.30356959, 1.44557718, 2.22147862]])
jac: array([-2.91433544e-16, -2.91867225e-16, -2.91867225e-16, -2.91867225e-16])
message: 'A bad approximation caused failure to predict improvement.'
nfev: 4
nhev: 3
nit: 2
njev: 3
status: 2
success: False
x: array([4.59090270e-04, 2.15422310e-04, 8.16495910e-05, 6.64732807e-04])
Optimal parameter: [4.59090270e-04 2.15422310e-04 8.16495910e-05 6.64732807e-04]
def check_hyperpar(hyperpar):
    if np.linalg.norm(hyperpar - hyperpar0) > 1e-8:
        raise ValueError('Wrong value for hyperpar. ', hyperpar, ' != ', hyperpar0)
# Compare with the linear approximation class, which uses reverse mode.
obj_lin = vittles.HyperparameterSensitivityLinearApproximation(
    obj,
    par0,
    hyperpar0,
    validate_optimum=True,
    factorize_hessian=True)
# Using Martin's trick, higher-order derivatives of the optimum just take a few lines.
@primitive
def get_parhat(hyperpar):
    check_hyperpar(hyperpar)
    return par0
print('Should be zeros:', get_parhat(hyperpar0) - par0)
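# The derivative rules below are just the implicit function theorem.  Because
# grad_par obj(parhat(hyperpar), hyperpar) = 0 at the optimum, differentiating
# that identity with respect to hyperpar gives
#
#     dparhat / dhyperpar = -H_pp^{-1} H_ph,
#
# where H_pp = d^2 obj / dpar^2 (get_obj_hess) and
# H_ph = d^2 obj / dpar dhyperpar (get_obj_cross_hess).  The VJP and JVP are
# this formula applied to a cotangent or tangent vector, respectively.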
# Reverse mode
def get_parhat_vjp(ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes
    def vjp(g):
        return -1 * (get_obj_cross_hess(get_parhat(hyperpar), hyperpar).T @
                     np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar), g)).T
    return vjp
defvjp(get_parhat, get_parhat_vjp)
# Forward mode
def get_parhat_jvp(g, ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes
    return -1 * (np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar),
                                 get_obj_cross_hess(get_parhat(hyperpar), hyperpar)) @ g)
defjvp(get_parhat, get_parhat_jvp)
Should be zeros: [0. 0. 0. 0.]
# Check reverse mode first derivatives against manual formula.
get_dpar_dhyperpar = autograd.jacobian(get_parhat)
dpar_dhyperpar = get_dpar_dhyperpar(hyperpar0)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar -
    obj_lin.get_dopt_dhyper()) < 1e-8)
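# The same Jacobian can also be checked directly against the implicit function
# theorem formula -H_pp^{-1} H_ph, using the hess0 and cross_hess0 computed
# above (a quick sketch of that extra check):
assert(np.linalg.norm(
    dpar_dhyperpar -
    (-1 * np.linalg.solve(hess0, cross_hess0))) < 1e-8)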
# Check forward mode first derivatives.
# I prefer my interface to autograd's for forward mode. Behind the scenes
# it's the same thing.
from vittles.sensitivity_lib import _append_jvp
get_dpar_dhyperpar_delta = _append_jvp(get_parhat)
dpar_dhyperpar_delta = get_dpar_dhyperpar_delta(hyperpar0, delta)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar_delta -
    obj_lin.get_dopt_dhyper() @ delta) < 1e-8)
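# Forward mode should give the same formula contracted against delta; a short
# sketch of that check using hess0 and cross_hess0:
assert(np.linalg.norm(
    dpar_dhyperpar_delta -
    (-1 * np.linalg.solve(hess0, cross_hess0 @ delta))) < 1e-8)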
# Let's compare against the Taylor expansion class for higher derivatives.
obj_taylor = vittles.ParametricSensitivityTaylorExpansion(
    obj,
    par0,
    hyperpar0,
    order=4)

# Sanity check.
assert(np.linalg.norm(
    obj_taylor.evaluate_input_derivs(delta, max_order=1)[0] -
    dpar_dhyperpar_delta) < 1e-8)
/home/rgiordan/Documents/git_repos/vittles/vittles/sensitivity_lib.py:857: UserWarning: The ParametricSensitivityTaylorExpansion is experimental.
'The ParametricSensitivityTaylorExpansion is experimental.')
# Check that the second derivative reverse mode matches.
autograd_time = time.time()
get_d2par_dhyperpar2 = autograd.jacobian(get_dpar_dhyperpar)
d2par_dhyperpar2 = get_d2par_dhyperpar2(hyperpar0)
autograd_time = time.time() - autograd_time

# Let's be generous and not count the einsum against autograd.
d2par_dhyperpar2_delta2 = np.einsum('ijk,j,k->i', d2par_dhyperpar2, delta, delta)

vittles_time = time.time()
d2par_dhyperpar2_delta2_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=2)[1]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d2par_dhyperpar2_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)
print('Second order')
print('Autograd reverse mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd reverse mode time: 0.3968043327331543
vittles time: 0.0041577816009521484
# Check that forward mode second derivatives match.
autograd_time = time.time()
get_d2par_d2hyperpar_delta = _append_jvp(get_dpar_dhyperpar_delta)
d2par_d2hyperpar_delta2 = get_d2par_d2hyperpar_delta(hyperpar0, delta, delta)
autograd_time = time.time() - autograd_time

# The forward mode result should agree with the vittles result above.
assert(np.linalg.norm(
    d2par_d2hyperpar_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)
print('Second order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd forward mode time: 0.039446353912353516
vittles time: 0.0041577816009521484
# Check that the third derivative matches.
## Reverse mode is super slow even in very low dimensions.
# autograd_time = time.time()
# get_d3par_dhyperpar3 = autograd.jacobian(get_d2par_dhyperpar2)
# d3par_dhyperpar3 = get_d3par_dhyperpar3(hyperpar0)
# autograd_time = time.time() - autograd_time
# print('Autograd time:', autograd_time)
autograd_time = time.time()
get_d3par_d3hyperpar_delta = _append_jvp(get_d2par_d2hyperpar_delta)
d3par_dhyperpar3_delta3 = get_d3par_d3hyperpar_delta(hyperpar0, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d3par_dhyperpar3_delta3_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=3)[2]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d3par_dhyperpar3_delta3 -
    d3par_dhyperpar3_delta3_vittles) < 1e-8)
print('Third order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Third order
Autograd forward mode time: 0.09583282470703125
vittles time: 0.017874479293823242
# Fourth order.
autograd_time = time.time()
get_d4par_d4hyperpar_delta = _append_jvp(get_d3par_d3hyperpar_delta)
d4par_dhyperpar4_delta4 = get_d4par_d4hyperpar_delta(hyperpar0, delta, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d4par_dhyperpar4_delta4_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=4)[3]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d4par_dhyperpar4_delta4 -
    d4par_dhyperpar4_delta4_vittles) < 1e-8)
print('Fourth order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Fourth order
Autograd forward mode time: 0.21596074104309082
vittles time: 0.03277397155761719
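# As a final usage sketch, these directional derivatives are exactly the terms
# of a Taylor series for the optimum in hyperpar.  Assuming (as the checks
# above suggest) that evaluate_input_derivs returns the raw directional
# derivatives d^k parhat / d eps^k for k = 1, ..., max_order, the fourth-order
# prediction at hyperpar1 can be compared against actually re-optimizing there.
# The names input_derivs, par_pred, and opt1 are just local to this sketch.
import math

input_derivs = obj_taylor.evaluate_input_derivs(delta, max_order=4)
par_pred = par0.copy()
for k, deriv in enumerate(input_derivs, start=1):
    par_pred = par_pred + deriv / math.factorial(k)

opt1 = sp.optimize.minimize(
    lambda par: obj(par, hyperpar1),
    par0,
    method='trust-ncg',
    jac=lambda par: get_obj_grad_par(par, hyperpar1),
    hess=lambda par: get_obj_hess(par, hyperpar1),
    tol=1e-16)
print('Taylor prediction error:', np.linalg.norm(par_pred - opt1.x))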