from autograd import numpy as np
import autograd
from autograd.extend import primitive, defvjp, defjvp
from autograd.numpy.linalg import slogdet, solve, inv

import scipy as sp
import scipy.optimize
import vittles

import time
# Define an objective function.

hyperdim = 3
dim = 4
# A random symmetric positive-definite matrix for the quadratic term.
a_mat = np.random.random((dim, dim))
a_mat = a_mat @ a_mat.T + np.eye(dim)
def obj(par, hyperpar):
    # A strictly convex quadratic in par plus a term coupling par and
    # hyperpar, so the optimum depends smoothly on hyperpar.
    return \
        0.5 * par @ a_mat @ par + \
        (np.mean(par - 0.5) ** 5) * (np.mean(hyperpar - 0.5) ** 5)

hyperpar0 = np.zeros(hyperdim)

get_obj_hess = autograd.hessian(obj, argnum=0)
get_obj_grad_par = autograd.grad(obj, argnum=0)

# Get the optimum.

opt = sp.optimize.minimize(
    lambda par: obj(par, hyperpar0),
    np.zeros(dim),
    method='trust-ncg',
    jac=lambda par: get_obj_grad_par(par, hyperpar0),
    hess=lambda par: get_obj_hess(par, hyperpar0),
    tol=1e-16, callback=None, options={})
par0 = opt.x

# Specify a new hyperparameter.

delta = 0.1 * np.random.random(hyperdim)
hyperpar1 = hyperpar0 + delta
print('Objective at the optimum:', obj(par0, hyperpar0))

hess0 = get_obj_hess(par0, hyperpar0)
get_obj_cross_hess = autograd.jacobian(get_obj_grad_par, argnum=1)
cross_hess0 = get_obj_cross_hess(par0, hyperpar0)
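
# For reference, hess0 is the Hessian of obj in par and cross_hess0 is the
# matrix of mixed partials d^2 obj / dpar dhyperpar; a quick shape check:
assert(hess0.shape == (dim, dim))
assert(cross_hess0.shape == (dim, hyperdim))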

print(opt)
print('Optimal parameter: ', par0)

# Make sure we're actually at an optimum.  (The optimizer reports
# success=False below only because it stopped when it could no longer
# improve; the gradient is numerically zero, which is what matters here.)
assert(np.linalg.norm(get_obj_grad_par(par0, hyperpar0)) < 1e-8)
     fun: 0.0009748280108099906
    hess: array([[2.44509221, 1.66677909, 1.72065082, 1.2325689 ],
       [1.66677909, 3.00646763, 1.98422462, 1.30356959],
       [1.72065082, 1.98422462, 3.22225725, 1.44557718],
       [1.2325689 , 1.30356959, 1.44557718, 2.22147862]])
     jac: array([-2.91433544e-16, -2.91867225e-16, -2.91867225e-16, -2.91867225e-16])
 message: 'A bad approximation caused failure to predict improvement.'
    nfev: 4
    nhev: 3
     nit: 2
    njev: 3
  status: 2
 success: False
       x: array([4.59090270e-04, 2.15422310e-04, 8.16495910e-05, 6.64732807e-04])
Optimal parameter:  [4.59090270e-04 2.15422310e-04 8.16495910e-05 6.64732807e-04]
def check_hyperpar(hyperpar):
    # The optimum par0 is only valid at hyperpar0, so complain otherwise.
    if np.linalg.norm(hyperpar - hyperpar0) > 1e-8:
        raise ValueError(
            'Wrong value for hyperpar: {} != {}'.format(hyperpar, hyperpar0))
# Compare with the linear approximation class, which uses reverse mode.
obj_lin = vittles.HyperparameterSensitivityLinearApproximation(
    obj,
    par0,
    hyperpar0,
    validate_optimum=True,
    factorize_hessian=True)
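
# For example, the linear approximation's first-order prediction of the
# optimum at hyperpar1 = hyperpar0 + delta is just a matrix-vector product
# (par1_pred_lin is an illustrative name, not part of vittles):
par1_pred_lin = par0 + obj_lin.get_dopt_dhyper() @ delta
print('First-order prediction of the optimum at hyperpar1:', par1_pred_lin)
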
# Using Martin's trick, higher-order derivatives of the optimum just take a few lines.
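# The idea: par0 = parhat(hyperpar0) satisfies the first-order condition
# grad_par obj(parhat(hyperpar), hyperpar) = 0.  Differentiating in hyperpar
# gives H dparhat/dhyperpar + C = 0, i.e. dparhat/dhyperpar = -H^{-1} C,
# where H is the Hessian in par and C is the cross Hessian.  The VJP and JVP
# defined below are just this solve evaluated at the current point; as a
# sketch, at (par0, hyperpar0) it is:
dpar_dhyperpar_manual = -np.linalg.solve(hess0, cross_hess0)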

@primitive
def get_parhat(hyperpar):
    check_hyperpar(hyperpar)
    return par0

print('Should be zeros:', get_parhat(hyperpar0) - par0)

# Reverse mode
def get_parhat_vjp(ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes    
    def vjp(g):
        return -1 * (get_obj_cross_hess(get_parhat(hyperpar), hyperpar).T @
                     np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar), g)).T
    return vjp
defvjp(get_parhat, get_parhat_vjp)

# Forward mode
def get_parhat_jvp(g, ans, hyperpar):
    #check_hyperpar(hyperpar) # Need to find some way to do this with boxes    
    return -1 * (np.linalg.solve(get_obj_hess(get_parhat(hyperpar), hyperpar),
                                 get_obj_cross_hess(get_parhat(hyperpar), hyperpar)) @ g)
defjvp(get_parhat, get_parhat_jvp)
Should be zeros: [0. 0. 0. 0.]
# Check reverse mode first derivatives against the linear approximation class.
get_dpar_dhyperpar = autograd.jacobian(get_parhat)
dpar_dhyperpar = get_dpar_dhyperpar(hyperpar0)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar -
    obj_lin.get_dopt_dhyper()) < 1e-8)
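
# The same first derivative from the manual formula -H^{-1} C:
assert(np.linalg.norm(
    dpar_dhyperpar + np.linalg.solve(hess0, cross_hess0)) < 1e-8)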
# Check forward mode first derivatives.

# I prefer my interface to autograd's for forward mode.  Behind the scenes
# it's the same thing.
from vittles.sensitivity_lib import _append_jvp

get_dpar_dhyperpar_delta = _append_jvp(get_parhat)
dpar_dhyperpar_delta = get_dpar_dhyperpar_delta(hyperpar0, delta)

# Check that the first derivative matches.
assert(np.linalg.norm(
    dpar_dhyperpar_delta -
    obj_lin.get_dopt_dhyper() @ delta) < 1e-8)
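
# The forward-mode directional derivative should also equal the reverse-mode
# Jacobian applied to delta:
assert(np.linalg.norm(
    dpar_dhyperpar_delta - dpar_dhyperpar @ delta) < 1e-8)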
# Let's compare against the Taylor expansion class for higher derivatives.
obj_taylor = vittles.ParametricSensitivityTaylorExpansion(
    obj,
    par0,
    hyperpar0,
    order=4)

# Sanity check.
assert(np.linalg.norm(
    obj_taylor.evaluate_input_derivs(delta, max_order=1)[0] -
    dpar_dhyperpar_delta) < 1e-8)
/home/rgiordan/Documents/git_repos/vittles/vittles/sensitivity_lib.py:857: UserWarning: The ParametricSensitivityTaylorExpansion is experimental.
  'The ParametricSensitivityTaylorExpansion is experimental.')
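
# The Taylor expansion class can also sum the series to predict the optimum
# at a new hyperparameter directly.  (A sketch: I am assuming the
# evaluate_taylor_series method and that it takes the new hyperparameter
# value; check the vittles documentation for the exact interface.)
par1_pred_taylor = obj_taylor.evaluate_taylor_series(hyperpar1)
print('Taylor series prediction of the optimum at hyperpar1:', par1_pred_taylor)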
# Check that the second derivative reverse mode matches.
autograd_time = time.time()
get_d2par_dhyperpar2 = autograd.jacobian(get_dpar_dhyperpar)
d2par_dhyperpar2 = get_d2par_dhyperpar2(hyperpar0)
autograd_time = time.time() - autograd_time

# Let's be generous and not count the einsum against autograd.
d2par_dhyperpar2_delta2 = np.einsum('ijk,j,k->i', d2par_dhyperpar2, delta, delta)

vittles_time = time.time()
d2par_dhyperpar2_delta2_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=2)[1]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d2par_dhyperpar2_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('Autograd reverse mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd reverse mode time:  0.3968043327331543
vittles time:            0.0041577816009521484
# Check that forward mode second derivatives match.
autograd_time = time.time()
get_d2par_d2hyperpar_delta = _append_jvp(get_dpar_dhyperpar_delta)
d2par_d2hyperpar_delta2 = get_d2par_d2hyperpar_delta(hyperpar0, delta, delta)
autograd_time = time.time() - autograd_time

assert(np.linalg.norm(
    d2par_d2hyperpar_delta2 -
    d2par_dhyperpar2_delta2_vittles) < 1e-8)

print('Second order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Second order
Autograd forward mode time:  0.039446353912353516
vittles time:            0.0041577816009521484
# Check that the third derivative matches.

## Reverse mode is super slow even in very low dimensions.
# autograd_time = time.time()
# get_d3par_dhyperpar3 = autograd.jacobian(get_d2par_dhyperpar2)
# d3par_dhyperpar3 = get_d3par_dhyperpar3(hyperpar0)
# autograd_time = time.time() - autograd_time
# print('Autograd time:', autograd_time)
autograd_time = time.time()
get_d3par_d3hyperpar_delta = _append_jvp(get_d2par_d2hyperpar_delta)
d3par_dhyperpar3_delta3 = get_d3par_d3hyperpar_delta(hyperpar0, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d3par_dhyperpar3_delta3_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=3)[2]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d3par_dhyperpar3_delta3 -
    d3par_dhyperpar3_delta3_vittles) < 1e-8)

print('Third order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Third order
Autograd forward mode time:  0.09583282470703125
vittles time:            0.017874479293823242
# Fourth order.

autograd_time = time.time()
get_d4par_d4hyperpar_delta = _append_jvp(get_d3par_d3hyperpar_delta)
d4par_dhyperpar4_delta4 = get_d4par_d4hyperpar_delta(hyperpar0, delta, delta, delta, delta)
autograd_time = time.time() - autograd_time

vittles_time = time.time()
d4par_dhyperpar4_delta4_vittles = obj_taylor.evaluate_input_derivs(delta, max_order=4)[3]
vittles_time = time.time() - vittles_time

assert(np.linalg.norm(
    d4par_dhyperpar4_delta4 -
    d4par_dhyperpar4_delta4_vittles) < 1e-8)

print('Fourth order')
print('Autograd forward mode time:\t', autograd_time)
print('vittles time:\t\t\t', vittles_time)
Fourth order
Autograd forward mode time:  0.21596074104309082
vittles time:            0.03277397155761719
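
# Putting the pieces together: a sketch of the fourth-order Taylor prediction
# of the optimum at hyperpar1 = hyperpar0 + delta, built from the directional
# derivatives computed above (par1_pred is an illustrative name).
par1_pred = (
    par0 +
    dpar_dhyperpar_delta +
    d2par_dhyperpar2_delta2_vittles / 2 +
    d3par_dhyperpar3_delta3_vittles / 6 +
    d4par_dhyperpar4_delta4_vittles / 24)
print('Fourth-order prediction of the optimum at hyperpar1:', par1_pred)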