import numpy as np
import numpy as np
from cuqi.sampler import Sampler
from cuqi.array import CUQIarray
from numbers import Number
[docs]
class NUTS(Sampler):
"""No-U-Turn Sampler (Hoffman and Gelman, 2014).
Samples a distribution given its logpdf and gradient using a Hamiltonian
Monte Carlo (HMC) algorithm with automatic parameter tuning.
For more details see: See Hoffman, M. D., & Gelman, A. (2014). The no-U-turn
sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. Journal
of Machine Learning Research, 15, 1593-1623.
Parameters
----------
target : `cuqi.distribution.Distribution`
The target distribution to sample. Must have logpdf and gradient method.
Custom logpdfs and gradients are supported by using a
:class:`cuqi.distribution.UserDefinedDistribution`.
initial_point : ndarray
Initial parameters. *Optional*. If not provided, the initial point is
an array of ones.
max_depth : int
Maximum depth of the tree >=0 and the default is 15.
step_size : None or float
If step_size is provided (as positive float), it will be used as initial
step size. If None, the step size will be estimated by the sampler.
opt_acc_rate : float
The optimal acceptance rate to reach if using adaptive step size.
Suggested values are 0.6 (default) or 0.8 (as in stan). In principle,
opt_acc_rate should be in (0, 1), however, choosing a value that is very
close to 1 or 0 might lead to poor performance of the sampler.
callback : callable, optional
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
Example
-------
.. code-block:: python
# Import cuqi
import cuqi
# Define a target distribution
tp = cuqi.testproblem.WangCubic()
target = tp.posterior
# Set up sampler
sampler = cuqi.sampler.NUTS(target)
# Sample
sampler.warmup(5000)
sampler.sample(10000)
# Get samples
samples = sampler.get_samples()
# Plot samples
samples.plot_pair()
After running the NUTS sampler, run diagnostics can be accessed via the
following attributes:
.. code-block:: python
# Number of tree nodes created each NUTS iteration
sampler.num_tree_node_list
# Step size used in each NUTS iteration
sampler.epsilon_list
# Suggested step size during adaptation (the value of this step size is
# only used after adaptation).
sampler.epsilon_bar_list
"""
_STATE_KEYS = Sampler._STATE_KEYS.union({'_epsilon', '_epsilon_bar',
'_H_bar',
'current_target_logd',
'current_target_grad',
'max_depth'})
_HISTORY_KEYS = Sampler._HISTORY_KEYS.union({'num_tree_node_list',
'epsilon_list',
'epsilon_bar_list'})
[docs]
def __init__(self, target=None, initial_point=None, max_depth=None,
step_size=None, opt_acc_rate=0.6, **kwargs):
super().__init__(target, initial_point=initial_point, **kwargs)
# Assign parameters as attributes
self.max_depth = max_depth
self.step_size = step_size
self.opt_acc_rate = opt_acc_rate
def _initialize(self):
self._current_alpha_ratio = np.nan # Current alpha ratio will be set to some
# value (other than np.nan) before
# being used
self.current_target_logd, self.current_target_grad = self._nuts_target(self.current_point)
# Parameters dual averaging
# Initialize epsilon and epsilon_bar
# epsilon is the step size used in the current iteration
# after warm up and one sampling step, epsilon is updated
# to epsilon_bar for the remaining sampling steps.
if self.step_size is None:
self._epsilon = self._FindGoodEpsilon()
self.step_size = self._epsilon
else:
self._epsilon = self.step_size
self._epsilon_bar = "unset"
# Parameter mu, does not change during the run
self._mu = np.log(10*self._epsilon)
self._H_bar = 0
# NUTS run diagnostics
# number of tree nodes created each NUTS iteration
self._num_tree_node = 0
# Create lists to store NUTS run diagnostics
self._create_run_diagnostic_attributes()
#=========================================================================
#============================== Properties ===============================
#=========================================================================
@property
def max_depth(self):
return self._max_depth
@max_depth.setter
def max_depth(self, value):
if value is None:
value = 15 # default value
if not isinstance(value, int):
raise TypeError('max_depth must be an integer.')
if value < 0:
raise ValueError('max_depth must be >= 0.')
self._max_depth = value
@property
def step_size(self):
return self._step_size
@step_size.setter
def step_size(self, value):
if value is None:
pass # NUTS will adapt the step size
# step_size must be a positive float, raise error otherwise
elif isinstance(value, bool)\
or not isinstance(value, Number)\
or value <= 0:
raise TypeError('step_size must be a positive float or None.')
self._step_size = value
@property
def opt_acc_rate(self):
return self._opt_acc_rate
@opt_acc_rate.setter
def opt_acc_rate(self, value):
if not isinstance(value, Number) or value <= 0 or value >= 1:
raise ValueError('opt_acc_rate must be a float in (0, 1).')
self._opt_acc_rate = value
#=========================================================================
#================== Implement methods required by Sampler =============
#=========================================================================
[docs]
def validate_target(self):
# Check if the target has logd and gradient methods
try:
current_target_logd, current_target_grad =\
self._nuts_target(np.ones(self.dim))
except:
raise ValueError('Target must have logd and gradient methods.')
[docs]
def reinitialize(self):
# Call the parent reset method
super().reinitialize()
# Reset NUTS run diagnostic attributes
self._reset_run_diagnostic_attributes()
[docs]
def step(self):
if isinstance(self._epsilon_bar, str) and self._epsilon_bar == "unset":
self._epsilon_bar = self._epsilon
# Convert current_point, logd, and grad to numpy arrays
# if they are CUQIarray objects
if isinstance(self.current_point, CUQIarray):
self.current_point = self.current_point.to_numpy()
if isinstance(self.current_target_logd, CUQIarray):
self.current_target_logd = self.current_target_logd.to_numpy()
if isinstance(self.current_target_grad, CUQIarray):
self.current_target_grad = self.current_target_grad.to_numpy()
# reset number of tree nodes for each iteration
self._num_tree_node = 0
# copy current point, logd, and grad in local variables
point_k = self.current_point # initial position (parameters)
logd_k = self.current_target_logd
grad_k = self.current_target_grad # initial gradient
# compute r_k and Hamiltonian
r_k = self._Kfun(1, 'sample') # resample momentum vector
Ham = logd_k - self._Kfun(r_k, 'eval') # Hamiltonian
# slice variable
log_u = Ham - np.random.exponential(1, size=1)
# initialization
j, s, n = 0, 1, 1
point_minus, point_plus = point_k.copy(), point_k.copy()
grad_minus, grad_plus = grad_k.copy(), grad_k.copy()
r_minus, r_plus = r_k.copy(), r_k.copy()
# run NUTS
acc = 0
while (s == 1) and (j <= self.max_depth):
# sample a direction
v = int(2*(np.random.rand() < 0.5)-1)
# build tree: doubling procedure
if (v == -1):
point_minus, r_minus, grad_minus, _, _, _, \
point_prime, logd_prime, grad_prime,\
n_prime, s_prime, alpha, n_alpha = \
self._BuildTree(point_minus, r_minus, grad_minus,
Ham, log_u, v, j, self._epsilon)
else:
_, _, _, point_plus, r_plus, grad_plus, \
point_prime, logd_prime, grad_prime,\
n_prime, s_prime, alpha, n_alpha = \
self._BuildTree(point_plus, r_plus, grad_plus,
Ham, log_u, v, j, self._epsilon)
# Metropolis step
alpha2 = min(1, (n_prime/n)) #min(0, np.log(n_p) - np.log(n))
if (s_prime == 1) and \
(np.random.rand() <= alpha2) and \
(not np.isnan(logd_prime)) and \
(not np.isinf(logd_prime)):
self.current_point = point_prime.copy()
# copy if array, else assign if scalar
self.current_target_logd = (
logd_prime.copy()
if isinstance(logd_prime, np.ndarray)
else logd_prime
)
self.current_target_grad = grad_prime.copy()
acc = 1
# update number of particles, tree level, and stopping criterion
n += n_prime
dpoints = point_plus - point_minus
s = s_prime *\
int((dpoints @ r_minus.T) >= 0) * int((dpoints @ r_plus.T) >= 0)
j += 1
self._current_alpha_ratio = alpha/n_alpha
# update run diagnostic attributes
self._update_run_diagnostic_attributes(
self._num_tree_node, self._epsilon, self._epsilon_bar)
self._epsilon = self._epsilon_bar
if np.isnan(self.current_target_logd):
raise NameError('NaN potential func')
return acc
[docs]
def tune(self, skip_len, update_count):
""" adapt epsilon during burn-in using dual averaging"""
if isinstance(self._epsilon_bar, str) and self._epsilon_bar == "unset":
self._epsilon_bar = 1
k = update_count+1
# Fixed parameters that do not change during the run
gamma, t_0, kappa = 0.05, 10, 0.75 # kappa in (0.5, 1]
eta1 = 1/(k + t_0)
self._H_bar = (1-eta1)*self._H_bar +\
eta1*(self.opt_acc_rate - (self._current_alpha_ratio))
self._epsilon = np.exp(self._mu - (np.sqrt(k)/gamma)*self._H_bar)
eta = k**(-kappa)
self._epsilon_bar =\
np.exp(eta*np.log(self._epsilon) +(1-eta)*np.log(self._epsilon_bar))
#=========================================================================
def _nuts_target(self, x): # returns logposterior tuple evaluation-gradient
return self.target.logd(x), self.target.gradient(x)
#=========================================================================
# auxiliary standard Gaussian PDF: kinetic energy function
# d_log_2pi = d*np.log(2*np.pi)
def _Kfun(self, r, flag):
if flag == 'eval': # evaluate
return 0.5*(r.T @ r) #+ d_log_2pi
if flag == 'sample': # sample
return np.random.standard_normal(size=self.dim)
#=========================================================================
def _FindGoodEpsilon(self, epsilon=1):
point_k = self.current_point
self.current_target_logd, self.current_target_grad = self._nuts_target(
point_k)
logd = self.current_target_logd
grad = self.current_target_grad
r = self._Kfun(1, 'sample') # resample a momentum
Ham = logd - self._Kfun(r, 'eval') # initial Hamiltonian
_, r_prime, logd_prime, grad_prime = self._Leapfrog(
point_k, r, grad, epsilon)
# trick to make sure the step is not huge, leading to infinite values of
# the likelihood
k = 1
while np.isinf(logd_prime) or np.isinf(grad_prime).any():
k *= 0.5
_, r_prime, logd_prime, grad_prime = self._Leapfrog(
point_k, r, grad, epsilon*k)
epsilon = 0.5*k*epsilon
# doubles/halves the value of epsilon until the accprob of the Langevin
# proposal crosses 0.5
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
log_ratio = Ham_prime - Ham
a = 1 if log_ratio > np.log(0.5) else -1
while (a*log_ratio > -a*np.log(2)):
epsilon = (2**a)*epsilon
_, r_prime, logd_prime, _ = self._Leapfrog(
point_k, r, grad, epsilon)
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
log_ratio = Ham_prime - Ham
return epsilon
#=========================================================================
def _Leapfrog(self, point_old, r_old, grad_old, epsilon):
# symplectic integrator: trajectories preserve phase space volumen
r_new = r_old + 0.5*epsilon*grad_old # half-step
point_new = point_old + epsilon*r_new # full-step
logd_new, grad_new = self._nuts_target(point_new) # new gradient
r_new += 0.5*epsilon*grad_new # half-step
return point_new, r_new, logd_new, grad_new
#=========================================================================
def _BuildTree(
self, point_k, r, grad, Ham, log_u, v, j, epsilon, Delta_max=1000):
# Increment the number of tree nodes counter
self._num_tree_node += 1
if (j == 0): # base case
# single leapfrog step in the direction v
point_prime, r_prime, logd_prime, grad_prime = self._Leapfrog(
point_k, r, grad, v*epsilon)
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval') # Hamiltonian
# eval
n_prime = int(log_u <= Ham_prime) # if particle is in the slice
s_prime = int(log_u < Delta_max + Ham_prime) # check U-turn
#
diff_Ham = Ham_prime - Ham
# Compute the acceptance probability
# alpha_prime = min(1, np.exp(diff_Ham))
# written in a stable way to avoid overflow when computing
# exp(diff_Ham) for large values of diff_Ham
alpha_prime = 1 if diff_Ham > 0 else np.exp(diff_Ham)
n_alpha_prime = 1
#
point_minus, point_plus = point_prime, point_prime
r_minus, r_plus = r_prime, r_prime
grad_minus, grad_plus = grad_prime, grad_prime
else:
# recursion: build the left/right subtrees
point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus, \
point_prime, logd_prime, grad_prime,\
n_prime, s_prime, alpha_prime, n_alpha_prime = \
self._BuildTree(point_k, r, grad,
Ham, log_u, v, j-1, epsilon)
if (s_prime == 1): # do only if the stopping criteria does not
# verify at the first subtree
if (v == -1):
point_minus, r_minus, grad_minus, _, _, _, \
point_2prime, logd_2prime, grad_2prime,\
n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
self._BuildTree(point_minus, r_minus, grad_minus,
Ham, log_u, v, j-1, epsilon)
else:
_, _, _, point_plus, r_plus, grad_plus, \
point_2prime, logd_2prime, grad_2prime,\
n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
self._BuildTree(point_plus, r_plus, grad_plus,
Ham, log_u, v, j-1, epsilon)
# Metropolis step
alpha2 = n_2prime / max(1, (n_prime + n_2prime))
if (np.random.rand() <= alpha2):
point_prime = point_2prime.copy()
# copy if array, else assign if scalar
logd_prime = (
logd_2prime.copy()
if isinstance(logd_2prime, np.ndarray)
else logd_2prime
)
grad_prime = grad_2prime.copy()
# update number of particles and stopping criterion
alpha_prime += alpha_2prime
n_alpha_prime += n_alpha_2prime
dpoints = point_plus - point_minus
s_prime = s_2prime *\
int((dpoints@r_minus.T)>=0) * int((dpoints@r_plus.T)>=0)
n_prime += n_2prime
return point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus,\
point_prime, logd_prime, grad_prime,\
n_prime, s_prime, alpha_prime, n_alpha_prime
#=========================================================================
#======================== Diagnostic methods =============================
#=========================================================================
def _create_run_diagnostic_attributes(self):
"""A method to create attributes to store NUTS run diagnostic."""
self._reset_run_diagnostic_attributes()
def _reset_run_diagnostic_attributes(self):
"""A method to reset attributes to store NUTS run diagnostic."""
# List to store number of tree nodes created each NUTS iteration
self.num_tree_node_list = []
# List of step size used in each NUTS iteration
self.epsilon_list = []
# List of burn-in step size suggestion during adaptation
# only used when adaptation is done
# remains fixed after adaptation (after burn-in)
self.epsilon_bar_list = []
def _update_run_diagnostic_attributes(self, n_tree, eps, eps_bar):
"""A method to update attributes to store NUTS run diagnostic."""
# Store the number of tree nodes created in iteration k
self.num_tree_node_list.append(n_tree)
# Store the step size used in iteration k
self.epsilon_list.append(eps)
# Store the step size suggestion during adaptation in iteration k
self.epsilon_bar_list.append(eps_bar)