#!/usr/bin/env python
"""
Functions for performing basic operations on nested sampling runs, such as
working out point weights and splitting and combining runs.
Nested sampling runs are stored in a standard format as python dictionaries
(see the ``data_processing`` module docstring for more details).
"""
import copy
import warnings
import numpy as np
import scipy.special
def run_estimators(ns_run, estimator_list, simulate=False):
    """Evaluate a list of estimators on a single nested sampling run.

    The log weights of the samples are computed once with ``get_logw`` and
    the same array is handed to every estimator.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    estimator_list: list of functions
        Functions estimating quantities (such as the Bayesian evidence or
        the mean of a parameter) from a nested sampling run. Each is called
        as ``func(ns_run, logw=logw)``.
    simulate: bool, optional
        Passed on to ``get_logw``; see its docstring.

    Returns
    -------
    output: 1d numpy array
        One value per estimator, in the order of ``estimator_list``.
    """
    # Weights are shared by all estimators, so work them out only once.
    logw = get_logw(ns_run, simulate=simulate)
    values = np.zeros(len(estimator_list))
    for position, estimator in enumerate(estimator_list):
        values[position] = estimator(ns_run, logw=logw)
    return values
def array_given_run(ns_run):
    """Pack the samples of a nested sampling run dict into one 2d array.

    The array form makes it cheap to merge samples from several runs or
    threads and then recompute the number of live points.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details). The keys 'logl', 'thread_labels', 'nlive_array' and
        'theta' are read.

    Returns
    -------
    samples: 2d numpy array
        Float array with one row per sample and columns
        [logl, thread label, change in nlive at sample, (thetas)].
    """
    logl = ns_run['logl']
    theta = ns_run['theta']
    packed = np.zeros((logl.shape[0], theta.shape[1] + 3))
    packed[:, 0] = logl
    packed[:, 1] = ns_run['thread_labels']
    # Column 2 holds the change in nlive *after* each sample, i.e. the
    # difference between consecutive entries of nlive_array.
    packed[:-1, 2] = np.diff(ns_run['nlive_array'])
    # After the final sample no live points remain, so nlive drops by one.
    packed[-1, 2] = -1
    packed[:, 3:] = theta
    return packed
def dict_given_run_array(samples, thread_min_max):
    """Rebuild a nested sampling run dict from a samples array.

    N.B. the returned dict only has the keys 'logl', 'thread_labels',
    'thread_min_max', 'theta' and 'nlive_array'. Any other information
    about the run (such as an 'output' entry) cannot be recovered from the
    arguments and is therefore omitted.

    Parameters
    ----------
    samples: numpy array
        2d array with one row per sample and columns
        [logl, thread label, change in nlive at sample, (thetas)].
    thread_min_max: numpy array
        2d array with one row per thread holding the log-likelihoods at
        which the thread begins and ends. The first column is used to work
        out the initial number of live points.

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).

    Raises
    ------
    AssertionError
        If casting the thread labels to int changes their values, if no
        thread starts at or below the lowest logl, or if the reconstructed
        nlive_array is not strictly positive with a final value of 1.
    """
    logl = samples[:, 0]
    thread_labels = samples[:, 1]
    # The samples array is float; labels are only cast back to int when
    # none of them is NaN.
    if not np.isnan(thread_labels).any():
        int_labels = thread_labels.astype(int)
        assert np.array_equal(samples[:, 1], int_labels), ((
            'Casting thread labels from samples array to int has changed '
            'their values!\nsamples[:, 1]={}\nthread_labels={}').format(
                samples[:, 1], int_labels))
        thread_labels = int_labels
    ns_run = {'logl': logl,
              'thread_labels': thread_labels,
              'thread_min_max': thread_min_max,
              'theta': samples[:, 3:]}
    lowest_logl = logl.min()
    thread_starts = thread_min_max[:, 0]
    # Every thread that begins at or below the lowest logl is live at the
    # first sample.
    nlive_0 = (thread_starts <= lowest_logl).sum()
    assert nlive_0 > 0, 'nlive_0={}'.format(nlive_0)
    # nlive at sample i is nlive_0 plus all changes recorded before i.
    nlive_array = np.full(samples.shape[0], nlive_0, dtype=float)
    nlive_array[1:] += np.cumsum(samples[:-1, 2])
    # Check if there are multiple threads starting on the first logl point
    dup_th_starts = (thread_starts == lowest_logl).sum()
    if dup_th_starts > 1:
        # The true nlive is not really known in this case. Approximate it
        # by shifting the array so its final point is 1 and giving every
        # point with logl = logl.min() the same nlive.
        nlive_array += (1 - nlive_array[-1])
        n_logl_min = (logl == lowest_logl).sum()
        nlive_array[:n_logl_min] = nlive_0
        warnings.warn((
            'duplicate starting logls: {} threads start at logl.min()={}, '
            'and {} points have logl=logl.min(). nlive_array may only be '
            'approximately correct.').format(
                dup_th_starts, lowest_logl, n_logl_min), UserWarning)
    assert nlive_array.min() > 0, ((
        'nlive contains 0s or negative values. nlive_0={}'
        '\nnlive_array = {}\nthread_min_max={}').format(
            nlive_0, nlive_array, thread_min_max))
    assert nlive_array[-1] == 1, (
        'final point in nlive_array != 1.\nnlive_array = ' + str(nlive_array))
    ns_run['nlive_array'] = nlive_array
    return ns_run
def get_run_threads(ns_run):
    """Split a nested sampling run into its individual threads.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).

    Returns
    -------
    threads: list of dicts
        One nested sampling run dict per thread, as built by
        ``dict_given_run_array`` from that thread's rows of the samples
        array and its row of 'thread_min_max'. Threads are ordered by
        ascending thread label.

    Raises
    ------
    AssertionError
        If the number of distinct thread labels differs from the number of
        rows in 'thread_min_max', or if a thread's final logl does not
        equal its recorded maximum.
    """
    samples = array_given_run(ns_run)
    labels = np.unique(ns_run['thread_labels'])
    all_min_max = ns_run['thread_min_max']
    assert all_min_max.shape[0] == labels.shape[0], (
        'some threads have no points! {0} != {1}'.format(
            labels.shape[0], all_min_max.shape[0]))
    threads = []
    for row, label in enumerate(labels):
        # Boolean-mask indexing copies the rows, so the edits below do not
        # touch the full samples array.
        thread_samples = samples[samples[:, 1] == label]
        # Within a single thread nlive is constant until the final point;
        # discard changes in nlive that came from other threads.
        thread_samples[:, 2] = 0
        thread_samples[-1, 2] = -1
        # Row `row` of thread_min_max is taken to belong to the row-th
        # smallest label.
        min_max = np.reshape(all_min_max[row, :], (1, 2))
        assert min_max[0, 1] == thread_samples[-1, 0], (
            'thread max logl should equal logl of its final point!')
        threads.append(dict_given_run_array(thread_samples, min_max))
    return threads
def combine_ns_runs(run_list_in, **kwargs):
    """Merge several complete nested sampling runs into a single run.

    Input runs must not contain any repeated threads (use
    ``combine_threads`` when repeated threads are possible). The inputs are
    deep-copied, so ``run_list_in`` is left unmodified.

    Parameters
    ----------
    run_list_in: list of dicts
        List of nested sampling runs in dict format (see data_processing
        module docstring for more details).
    kwargs: dict, optional
        Options passed to ``check_ns_run``.

    Returns
    -------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details). When more than one run is supplied, its 'output'
        entry holds only the summed 'nlike' and 'ndead' values (each
        included only if present in every input run).
    """
    run_list = copy.deepcopy(run_list_in)
    if len(run_list) == 1:
        run = run_list[0]
    else:
        # Shift each run's thread labels so that labels are unique across
        # the combined run.
        label_offset = 0
        for single_run in run_list:
            check_ns_run(single_run, **kwargs)
            single_run['thread_labels'] += label_offset
            label_offset += single_run['thread_min_max'].shape[0]
        thread_min_max = np.vstack(
            [single_run['thread_min_max'] for single_run in run_list])
        # Stack the per-run sample arrays and order all samples by logl;
        # nlive is then recomputed when the dict is rebuilt.
        stacked = np.vstack(
            [array_given_run(single_run) for single_run in run_list])
        stacked = stacked[np.argsort(stacked[:, 0])]
        run = dict_given_run_array(stacked, thread_min_max)
        # Only the additive properties stored in run['output'] are merged.
        run['output'] = {}
        for key in ['nlike', 'ndead']:
            try:
                per_run = [src['output'][key] for src in run_list_in]
                # Some runs store an iterable rather than a number here
                # (e.g. nlike from PolyChord with fast/slow parameters);
                # collapse those to a single number before adding.
                totals = []
                for value in per_run:
                    try:
                        totals.append(sum(value))
                    except TypeError:
                        totals.append(value)
                run['output'][key] = sum(totals)
            except KeyError:
                # At least one input run lacks this entry; leave it out.
                pass
    check_ns_run(run, **kwargs)
    return run
def combine_threads(threads, assert_birth_point=False):
    """Combine a list of threads into a single nested sampling run.

    This differs from combining runs in that repeated threads are allowed,
    and that some threads may start from log-likelihood contours on which
    no dead point in the run is present.

    Note that if the thread labels are not all unique and in ascending
    order, the output would fail ``check_ns_run``. Provided the thread
    labels are not used, the result is still fine for calculations based on
    nlive, logl and theta.

    Parameters
    ----------
    threads: list of dicts
        List of nested sampling run dicts, each representing a single
        thread.
    assert_birth_point: bool, optional
        Whether or not to assert that there is exactly one point present in
        the run with the log-likelihood at which each thread was born. This
        is not true for bootstrap resamples of runs, where birth points may
        be repeated or not present at all.

    Returns
    -------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details). If the combined run fails ``check_ns_run_threads``
        its 'thread_labels' and 'thread_min_max' entries are set to None.

    Raises
    ------
    AssertionError
        If ``assert_birth_point`` is True and some thread's birth point is
        missing or repeated.
    """
    thread_min_max = np.vstack([td['thread_min_max'] for td in threads])
    assert len(threads) == thread_min_max.shape[0]
    # construct samples array from the threads, including an updated nlive
    samples_temp = np.vstack([array_given_run(thread) for thread in threads])
    samples_temp = samples_temp[np.argsort(samples_temp[:, 0])]
    # update the changes in live points column for threads which start part
    # way through the run. These are only present in dynamic nested sampling.
    logl_starts = thread_min_max[:, 0]
    # Use a private, fixed-seed generator so any random assignment is
    # reproducible. Unlike seeding and later restoring the global numpy
    # generator, this leaves global random state untouched even when an
    # assertion below raises part way through the loop.
    rng = np.random.RandomState(0)
    for logl_start in logl_starts[logl_starts != -np.inf]:
        ind = np.where(samples_temp[:, 0] == logl_start)[0]
        if assert_birth_point:
            assert ind.shape == (1,), \
                'No unique birth point! ' + str(ind.shape)
        if ind.shape == (1,):
            # If the point at which this thread started is present exactly
            # once in this bootstrap replication:
            samples_temp[ind[0], 2] += 1
        elif ind.shape == (0,):
            # If the point with the likelihood at which the thread started
            # is not present in this particular bootstrap replication,
            # approximate it with the point with the nearest likelihood.
            ind_closest = np.argmin(np.abs(samples_temp[:, 0] - logl_start))
            samples_temp[ind_closest, 2] += 1
        else:
            # If the point at which this thread started is present multiple
            # times in this bootstrap replication, select one at random to
            # increment nlive on. This avoids any systematic bias from e.g.
            # always choosing the first point.
            samples_temp[rng.choice(ind), 2] += 1
    # make run
    ns_run = dict_given_run_array(samples_temp, thread_min_max)
    try:
        check_ns_run_threads(ns_run)
    except AssertionError:
        # If the threads are not valid (e.g. for bootstrap resamples) then
        # set them to None so they can't be accidentally used
        ns_run['thread_labels'] = None
        ns_run['thread_min_max'] = None
    return ns_run
def get_logw(ns_run, simulate=False):
    r"""Compute the log posterior weights of the samples in a run.

    All arithmetic is done with logarithms to avoid overflow errors with
    very large or small values. Weights follow the trapezium rule, so that
    the weight of point i is

    .. math:: w_i = \mathcal{L}_i (X_{i-1} - X_{i+1}) / 2

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    simulate: bool, optional
        Should log prior volumes logx be simulated from their distribution
        (if False their expected values are used).

    Returns
    -------
    logw: 1d numpy array
        Log posterior masses of points.
    """
    try:
        # Log prior volume logX at each point.
        logx = get_logx(ns_run['nlive_array'], simulate=simulate)
        log_two = np.log(2)
        logw = np.zeros(ns_run['logl'].shape[0])
        # Interior points, vectorised: w_i prop to (X_{i-1} - X_{i+1}) / 2
        logw[1:-1] = log_subtract(logx[:-2], logx[2:]) - log_two
        # The first point gets all prior volume closest to it, i.e. from
        # logx=0 down to logx=log((X_first + X_second) / 2).
        log_first_midpoint = (
            scipy.special.logsumexp([logx[0], logx[1]]) - log_two)
        logw[0] = log_subtract(0, log_first_midpoint)
        # The final point gets all prior volume closest to it, i.e. from
        # logx=log((X_penultimate + X_last) / 2) down to logx=-inf.
        logw[-1] = scipy.special.logsumexp([logx[-2], logx[-1]]) - log_two
        # Multiply by the likelihood (addition in log space).
        logw += ns_run['logl']
        return logw
    except IndexError:
        if ns_run['logl'].shape[0] != 1:
            raise
        # With a single point in the run all prior volume X \in (0, 1) is
        # assigned to it, so the weight is just 1 * logl_0 = logl_0.
        return copy.deepcopy(ns_run['logl'])
def get_w_rel(ns_run, simulate=False):
    """Return the posterior weights of the samples relative to the largest.

    The weights are obtained from ``get_logw`` and normalised so that the
    maximum sample weight is 1. Subtracting the largest log weight before
    exponentiating protects against numerical overflow.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    simulate: bool, optional
        See the get_logw docstring for more details.

    Returns
    -------
    w_rel: 1d numpy array
        Relative posterior masses of points.
    """
    logw = get_logw(ns_run, simulate=simulate)
    shifted_logw = logw - logw.max()
    return np.exp(shifted_logw)
def get_logx(nlive, simulate=False):
    r"""Return the expected or simulated log prior volume of each point.

    The prior volume shrinks at every point by a factor

    .. math:: t_i = X_{i} / X_{i-1}

    which is distributed as the largest of :math:`n_i` uniform random
    variables between 0 and 1, where :math:`n_i` is the local number of
    live points. The quantity accumulated here is

    .. math:: \log(t_i) = \log X_{i} - \log X_{i-1}

    for which the value :math:`-1/n_i` is used when not simulating.

    Parameters
    ----------
    nlive: 1d numpy array
        Ordered local number of live points present at each point's
        iso-likelihood contour.
    simulate: bool, optional
        Should log prior volumes logx be simulated from their distribution
        (if False their expected values are used).

    Returns
    -------
    logx: 1d numpy array
        log X values for points.

    Raises
    ------
    AssertionError
        If any entry of ``nlive`` is zero or negative.
    """
    assert nlive.min() > 0, (
        'nlive contains zeros or negative values! nlive = ' + str(nlive))
    if not simulate:
        # Deterministic step of -1/n_i at every point.
        return np.cumsum(-1 * (nlive.astype(float) ** -1))
    # log(U) / n is the log of the largest of n uniform draws; draws come
    # from numpy's global random generator.
    uniform_draws = np.random.random(nlive.shape)
    return np.cumsum(np.log(uniform_draws) / nlive)
def log_subtract(loga, logb):
    r"""Numerically stable calculation of :math:`\log (a-b)`, given
    :math:`\log (a)`, :math:`\log (b)` and that :math:`a > b`.

    The result is evaluated as

    .. math:: \log(a) + \log(1 - \exp(\log(b) - \log(a)))

    so that neither :math:`a` nor :math:`b` is ever formed explicitly,
    which avoids overflow for large log values. ``np.log1p`` is used for
    the second term so precision is kept when :math:`b` is much smaller
    than :math:`a` (where ``1 - exp(...)`` would round to exactly 1).

    Parameters
    ----------
    loga: float or numpy array
    logb: float or numpy array
        Must be less than loga.

    Returns
    -------
    log(a - b): float or numpy array
    """
    return loga + np.log1p(-np.exp(logb - loga))
# Functions for checking nestcheck format nested sampling run dictionaries to
# ensure they have the expected properties.
def check_ns_run(run, dup_assert=False, dup_warn=False):
    """Run every consistency check on a nestcheck format nested sampling run
    dictionary (see the data_processing module docstring for more details).

    The run must be a dict; its members, log-likelihoods and threads are
    then checked in that order.

    Parameters
    ----------
    run: dict
        nested sampling run to check.
    dup_assert: bool, optional
        See check_ns_run_logls docstring.
    dup_warn: bool, optional
        See check_ns_run_logls docstring.

    Raises
    ------
    AssertionError
        if run does not have expected properties.
    """
    assert isinstance(run, dict)
    duplicate_options = {'dup_assert': dup_assert, 'dup_warn': dup_warn}
    check_ns_run_members(run)
    check_ns_run_logls(run, **duplicate_options)
    check_ns_run_threads(run)
def check_ns_run_members(run):
    """Check the keys of a nested sampling run and the type and shape of
    the values stored under them.

    Parameters
    ----------
    run: dict
        nested sampling run to check.

    Raises
    ------
    AssertionError
        if a mandatory key is missing, an unexpected key is present, a
        mandatory value is not a numpy array, or the array shapes are
        inconsistent with each other.
    """
    mandatory = ['logl', 'nlive_array', 'theta', 'thread_labels',
                 'thread_min_max']
    optional = ['output']
    present = list(run.keys())
    # Every mandatory key must be there...
    for key in mandatory:
        assert key in present
    # ...and nothing beyond the mandatory and optional keys is allowed.
    leftover = [key for key in present
                if key not in mandatory and key not in optional]
    assert not leftover, 'Unexpected keys in ns_run: ' + str(leftover)
    # Mandatory members must all be numpy arrays.
    for key in mandatory:
        value = run[key]
        assert isinstance(value, np.ndarray), (
            key + ' is type ' + type(value).__name__)
    # One entry per sample in each 1d array, one row per sample in theta.
    logl = run['logl']
    theta = run['theta']
    assert logl.ndim == 1
    assert logl.shape == run['nlive_array'].shape
    assert logl.shape == run['thread_labels'].shape
    assert theta.ndim == 2
    assert logl.shape[0] == theta.shape[0]
def check_ns_run_logls(run, dup_assert=False, dup_warn=False):
    """Check the run's logls are in ascending order and, optionally, that
    none of them is repeated.

    Parameters
    ----------
    run: dict
        nested sampling run to check.
    dup_assert: bool, optional
        Whether to raise an AssertionError if there are duplicate logl
        values.
    dup_warn: bool, optional
        Whether to give a UserWarning if there are duplicate logl values
        (only used if dup_assert is False).

    Raises
    ------
    AssertionError
        if the logls are not sorted, or if dup_assert is set and duplicate
        logl values are present.
    """
    logl = run['logl']
    # Sorting an already ordered array must leave it unchanged.
    assert np.array_equal(logl, logl[np.argsort(logl)])
    if not (dup_assert or dup_warn):
        # Duplicates are neither an error nor worth a warning.
        return
    unique_logls, counts = np.unique(logl, return_counts=True)
    n_total = logl.shape[0]
    n_unique = unique_logls.shape[0]
    repeat_logls = n_total - n_unique
    is_repeated = counts != 1
    msg = ('{} duplicate logl values (out of a total of {}). This may be '
           'caused by limited numerical precision in the output files.'
           '\nrepeated logls = {}\ncounts = {}\npositions in list of {}'
           ' unique logls = {}').format(
               repeat_logls, n_total,
               unique_logls[is_repeated], counts[is_repeated],
               n_unique, np.where(is_repeated)[0])
    if dup_assert:
        assert repeat_logls == 0, msg
    elif repeat_logls != 0:
        # Reaching here means dup_warn was requested.
        warnings.warn(msg, UserWarning)
def check_ns_run_threads(run):
    """Check that a run's thread labels and thread_min_max agree with each
    other and with the run's logls.

    Parameters
    ----------
    run: dict
        Nested sampling run to check.

    Raises
    ------
    AssertionError
        If the thread labels are not integers running from 0 to the number
        of threads minus one, if no thread starts at logl=-inf, if a
        thread's first point lies below its recorded start, or if a
        thread's last point differs from its recorded end.
    """
    labels = run['thread_labels']
    assert labels.dtype == int
    uniq_th = np.unique(labels)
    min_max = run['thread_min_max']
    # Labels double as row indices into thread_min_max, so they must be
    # exactly 0, 1, ..., nthreads - 1.
    assert np.array_equal(np.arange(min_max.shape[0]), uniq_th), \
        str(uniq_th)
    # Check thread_min_max
    assert np.any(min_max[:, 0] == -np.inf), (
        'Run should have at least one thread which starts by sampling the '
        'whole prior')
    logl = run['logl']
    for th_lab in uniq_th:
        inds = np.nonzero(labels == th_lab)[0]
        first_logl = logl[inds[0]]
        last_logl = logl[inds[-1]]
        th_start = min_max[th_lab, 0]
        th_end = min_max[th_lab, 1]
        th_info = 'thread label={}, first_logl={}, thread_min_max={}'.format(
            th_lab, first_logl, min_max[th_lab, :])
        assert th_start <= first_logl, (
            'First point in thread has logl less than thread min logl! ' +
            th_info + ', difference={}'.format(first_logl - th_start))
        assert th_end == last_logl, (
            'Last point in thread logl != thread end logl! ' + th_info)