Source code for nestcheck.ns_run_utils

#!/usr/bin/env python
"""
Functions for performing basic operations on nested sampling runs, such as
working out point weights and splitting and combining runs.

Nested sampling runs are stored in a standard format as python dictionaries
(see the ``data_processing`` module docstring for more details).
"""
import copy
import warnings
import numpy as np
import scipy.special


def run_estimators(ns_run, estimator_list, simulate=False):
    """Calculates values of a list of quantities (such as the Bayesian
    evidence or mean of parameters) for a single nested sampling run.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    estimator_list: list of functions
        Functions for estimating quantities from nested sampling runs.
        Example functions can be found in estimators.py. Each should have
        arguments: func(ns_run, logw=None).
    simulate: bool, optional
        See get_logw docstring.

    Returns
    -------
    output: 1d numpy array
        Calculation result for each estimator in estimator_list.
    """
    logw = get_logw(ns_run, simulate=simulate)
    output = np.zeros(len(estimator_list))
    for i, est in enumerate(estimator_list):
        output[i] = est(ns_run, logw=logw)
    return output


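# Example usage: a minimal sketch of evaluating estimators on a run dict.
# The _toy_run helper and logz estimator below are hypothetical names
# introduced for illustration and are not part of nestcheck's API; later
# examples in this module reuse _toy_run.


def _toy_run(nthread=2):
    """Make a small, valid nestcheck-format run dict in which all nthread
    threads start by sampling the whole prior (logl start of -np.inf)."""
    nsample = 2 * nthread
    logl = np.arange(nsample, dtype=float)
    thread_labels = np.tile(np.arange(nthread), 2)  # e.g. [0, 1, 0, 1]
    # Each thread starts at -inf and ends at the logl of its final point
    thread_min_max = np.full((nthread, 2), -np.inf)
    thread_min_max[:, 1] = logl[-nthread:]
    # All nthread threads are live until the first thread ends, after which
    # the number of live points drops by one as each remaining thread ends
    nlive_array = np.concatenate(
        [np.full(nsample - nthread + 1, nthread),
         np.arange(nthread - 1, 0, -1)]).astype(float)
    return {'logl': logl,
            'thread_labels': thread_labels,
            'thread_min_max': thread_min_max,
            'nlive_array': nlive_array,
            'theta': np.random.random((nsample, 2))}


def _example_run_estimators():
    """Sketch: evaluate estimators on a toy run. logz here is a stand-in for
    the estimator functions in nestcheck's estimators module."""
    def logz(ns_run, logw=None):
        # Log evidence estimate: logsumexp of the log posterior weights
        if logw is None:
            logw = get_logw(ns_run)
        return scipy.special.logsumexp(logw)
    return run_estimators(_toy_run(), [logz])

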
def array_given_run(ns_run):
    """Converts information on samples in a nested sampling run dictionary
    into a numpy array representation. This allows fast addition of more
    samples and recalculation of nlive.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).

    Returns
    -------
    samples: 2d numpy array
        Array containing columns
        [logl, thread label, change in nlive at sample, (thetas)]
        with each row representing a single sample.
    """
    samples = np.zeros((ns_run['logl'].shape[0], 3 + ns_run['theta'].shape[1]))
    samples[:, 0] = ns_run['logl']
    samples[:, 1] = ns_run['thread_labels']
    # Calculate 'change in nlive' after each step
    samples[:-1, 2] = np.diff(ns_run['nlive_array'])
    samples[-1, 2] = -1  # nlive drops to zero after final point
    samples[:, 3:] = ns_run['theta']
    return samples


def dict_given_run_array(samples, thread_min_max):
    """
    Converts an array of information about samples back into a nested
    sampling run dictionary (see data_processing module docstring for more
    details).

    N.B. the output dict only contains the following keys: 'logl',
    'thread_labels', 'thread_min_max', 'nlive_array', 'theta'. Any other keys
    giving additional information about the run output cannot be reproduced
    from the function arguments, and are therefore omitted.

    Parameters
    ----------
    samples: numpy array
        Numpy array containing columns
        [logl, thread label, change in nlive at sample, (thetas)]
        with each row representing a single sample.
    thread_min_max: numpy array
        2d array with a row for each thread containing the likelihoods at
        which it begins and ends. Needed to calculate nlive_array.

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    """
    ns_run = {'logl': samples[:, 0],
              'thread_labels': samples[:, 1],
              'thread_min_max': thread_min_max,
              'theta': samples[:, 3:]}
    if np.all(~np.isnan(ns_run['thread_labels'])):
        ns_run['thread_labels'] = ns_run['thread_labels'].astype(int)
        assert np.array_equal(samples[:, 1], ns_run['thread_labels']), ((
            'Casting thread labels from samples array to int has changed '
            'their values!\nsamples[:, 1]={}\nthread_labels={}').format(
                samples[:, 1], ns_run['thread_labels']))
    nlive_0 = (thread_min_max[:, 0] <= ns_run['logl'].min()).sum()
    assert nlive_0 > 0, 'nlive_0={}'.format(nlive_0)
    nlive_array = np.zeros(samples.shape[0]) + nlive_0
    nlive_array[1:] += np.cumsum(samples[:-1, 2])
    # Check if there are multiple threads starting on the first logl point
    dup_th_starts = (thread_min_max[:, 0] == ns_run['logl'].min()).sum()
    if dup_th_starts > 1:
        # In this case we approximate the true nlive (which we don't really
        # know) by making sure the array's final point is 1 and setting all
        # points with logl = logl.min() to have the same nlive
        nlive_array += (1 - nlive_array[-1])
        n_logl_min = (ns_run['logl'] == ns_run['logl'].min()).sum()
        nlive_array[:n_logl_min] = nlive_0
        warnings.warn((
            'duplicate starting logls: {} threads start at logl.min()={}, '
            'and {} points have logl=logl.min(). nlive_array may only be '
            'approximately correct.').format(
                dup_th_starts, ns_run['logl'].min(), n_logl_min), UserWarning)
    assert nlive_array.min() > 0, ((
        'nlive contains 0s or negative values. nlive_0={}'
        '\nnlive_array = {}\nthread_min_max={}').format(
            nlive_0, nlive_array, thread_min_max))
    assert nlive_array[-1] == 1, (
        'final point in nlive_array != 1.\nnlive_array = ' + str(nlive_array))
    ns_run['nlive_array'] = nlive_array
    return ns_run


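# Example usage: array_given_run and dict_given_run_array are inverses of
# each other (up to any optional 'output' key, which the array form does not
# store). Uses the hypothetical _toy_run helper defined above.


def _example_array_round_trip():
    run = _toy_run()
    samples = array_given_run(run)
    recovered = dict_given_run_array(samples, run['thread_min_max'])
    for key in ['logl', 'thread_labels', 'nlive_array', 'theta']:
        assert np.array_equal(recovered[key], run[key]), key
    return recovered

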
def get_run_threads(ns_run):
    """
    Get the individual threads from a nested sampling run.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).

    Returns
    -------
    threads: list of dicts
        Each thread (list element) is a nested sampling run dict containing
        only the samples in that thread (see data_processing module docstring
        for more details).
    """
    samples = array_given_run(ns_run)
    unique_threads = np.unique(ns_run['thread_labels'])
    assert ns_run['thread_min_max'].shape[0] == unique_threads.shape[0], (
        'some threads have no points! {0} != {1}'.format(
            unique_threads.shape[0], ns_run['thread_min_max'].shape[0]))
    threads = []
    for i, th_lab in enumerate(unique_threads):
        thread_array = samples[np.where(samples[:, 1] == th_lab)]
        # delete changes in nlive due to other threads in the run
        thread_array[:, 2] = 0
        thread_array[-1, 2] = -1
        min_max = np.reshape(ns_run['thread_min_max'][i, :], (1, 2))
        assert min_max[0, 1] == thread_array[-1, 0], (
            'thread max logl should equal logl of its final point!')
        threads.append(dict_given_run_array(thread_array, min_max))
    return threads


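# Example usage: splitting a run into threads and recombining them recovers
# the original logl and nlive_array. Uses the hypothetical _toy_run helper
# defined above.


def _example_threads_round_trip():
    run = _toy_run()
    threads = get_run_threads(run)
    recombined = combine_threads(threads)
    for key in ['logl', 'nlive_array']:
        assert np.array_equal(recombined[key], run[key]), key
    return recombined

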
def combine_ns_runs(run_list_in, **kwargs):
    """
    Combine a list of complete nested sampling run dictionaries into a single
    ns run.

    Input runs must not contain any repeated threads.

    Parameters
    ----------
    run_list_in: list of dicts
        List of nested sampling runs in dict format (see data_processing
        module docstring for more details).
    kwargs: dict, optional
        Options for check_ns_run.

    Returns
    -------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    """
    run_list = copy.deepcopy(run_list_in)
    if len(run_list) == 1:
        run = run_list[0]
    else:
        nthread_tot = 0
        for i, _ in enumerate(run_list):
            check_ns_run(run_list[i], **kwargs)
            run_list[i]['thread_labels'] += nthread_tot
            nthread_tot += run_list[i]['thread_min_max'].shape[0]
        thread_min_max = np.vstack(
            [run['thread_min_max'] for run in run_list])
        # construct samples array from the threads, including an updated
        # nlive
        samples_temp = np.vstack([array_given_run(run) for run in run_list])
        samples_temp = samples_temp[np.argsort(samples_temp[:, 0])]
        # Make combined run
        run = dict_given_run_array(samples_temp, thread_min_max)
        # Combine only the additive properties stored in run['output']
        run['output'] = {}
        for key in ['nlike', 'ndead']:
            try:
                to_sum = [run_temp['output'][key] for run_temp in run_list_in]
                # Check if any runs have iterable (rather than float/int)
                # values for nlike or ndead, and sum these to floats/ints
                # where needed. Iterable nlike values are produced when using
                # PolyChord with fast/slow parameters.
                for i, value in enumerate(to_sum):
                    try:
                        to_sum[i] = sum(value)
                    except TypeError:
                        pass
                run['output'][key] = sum(to_sum)
            except KeyError:
                pass
    check_ns_run(run, **kwargs)
    return run


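# Example usage: combining two independent toy runs. Thread labels are offset
# to keep them unique, samples are merged in logl order, and nlive_array is
# recalculated for the combined run. The 0.5 logl shift is just to keep the
# two runs' logl values distinct.


def _example_combine_ns_runs():
    run_1, run_2 = _toy_run(), _toy_run()
    run_2['logl'] += 0.5
    run_2['thread_min_max'][:, 1] += 0.5
    combined = combine_ns_runs([run_1, run_2])
    # The combined run starts with twice as many live points
    assert combined['nlive_array'][0] == 2 * run_1['nlive_array'][0]
    return combined

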
def combine_threads(threads, assert_birth_point=False):
    """
    Combine list of threads into a single ns run.
    This is different to combining runs as repeated threads are allowed, and
    as some threads can start from log-likelihood contours on which no dead
    point in the run is present.

    Note that unless the thread labels are all unique and in ascending order,
    the output will fail check_ns_run. However, provided the thread labels
    are not used, it will work fine for calculations based on nlive, logl and
    theta.

    Parameters
    ----------
    threads: list of dicts
        List of nested sampling run dicts, each representing a single thread.
    assert_birth_point: bool, optional
        Whether or not to assert there is exactly one point present in the
        run with the log-likelihood at which each point was born. This is not
        true for bootstrap resamples of runs, where birth points may be
        repeated or not present at all.

    Returns
    -------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    """
    thread_min_max = np.vstack([td['thread_min_max'] for td in threads])
    assert len(threads) == thread_min_max.shape[0]
    # construct samples array from the threads, including an updated nlive
    samples_temp = np.vstack([array_given_run(thread) for thread in threads])
    samples_temp = samples_temp[np.argsort(samples_temp[:, 0])]
    # update the changes in live points column for threads which start part
    # way through the run. These are only present in dynamic nested sampling.
    logl_starts = thread_min_max[:, 0]
    state = np.random.get_state()  # save random state
    np.random.seed(0)  # seed to make any random assignment reproducible
    for logl_start in logl_starts[logl_starts != -np.inf]:
        ind = np.where(samples_temp[:, 0] == logl_start)[0]
        if assert_birth_point:
            assert ind.shape == (1,), \
                'No unique birth point! ' + str(ind.shape)
        if ind.shape == (1,):
            # If the point at which this thread started is present exactly
            # once in this bootstrap replication:
            samples_temp[ind[0], 2] += 1
        elif ind.shape == (0,):
            # If the point with the likelihood at which the thread started
            # is not present in this particular bootstrap replication,
            # approximate it with the point with the nearest likelihood.
            ind_closest = np.argmin(np.abs(samples_temp[:, 0] - logl_start))
            samples_temp[ind_closest, 2] += 1
        else:
            # If the point at which this thread started is present multiple
            # times in this bootstrap replication, select one at random to
            # increment nlive on. This avoids any systematic bias from e.g.
            # always choosing the first point.
            samples_temp[np.random.choice(ind), 2] += 1
    np.random.set_state(state)
    # make run
    ns_run = dict_given_run_array(samples_temp, thread_min_max)
    try:
        check_ns_run_threads(ns_run)
    except AssertionError:
        # If the threads are not valid (e.g. for bootstrap resamples) then
        # set them to None so they can't be accidentally used
        ns_run['thread_labels'] = None
        ns_run['thread_min_max'] = None
    return ns_run


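# Example usage: combine_threads with a bootstrap resample of threads drawn
# with replacement. When the resampled thread labels are not unique and
# ascending, thread_labels and thread_min_max are set to None, but the run
# remains usable for weight calculations via logl and nlive_array.


def _example_bootstrap_threads():
    threads = get_run_threads(_toy_run())
    inds = np.random.randint(0, len(threads), len(threads))
    boot_run = combine_threads([threads[ind] for ind in inds])
    return get_logw(boot_run)

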
def get_logw(ns_run, simulate=False):
    r"""Calculates the log posterior weights of the samples (using logarithms
    to avoid overflow errors with very large or small values).

    Uses the trapezium rule such that the weight of point i is

    .. math:: w_i = \mathcal{L}_i (X_{i-1} - X_{i+1}) / 2

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    simulate: bool, optional
        Should log prior volumes logx be simulated from their distribution
        (if False their expected values are used).

    Returns
    -------
    logw: 1d numpy array
        Log posterior masses of points.
    """
    try:
        # find logX value for each point
        logx = get_logx(ns_run['nlive_array'], simulate=simulate)
        logw = np.zeros(ns_run['logl'].shape[0])
        # Vectorized trapezium rule: w_i prop to (X_{i-1} - X_{i+1}) / 2
        logw[1:-1] = log_subtract(logx[:-2], logx[2:]) - np.log(2)
        # Assign all prior volume closest to first point X_first to that
        # point: that is from logx=0 to logx=log((X_first + X_second) / 2)
        logw[0] = log_subtract(
            0, scipy.special.logsumexp([logx[0], logx[1]]) - np.log(2))
        # Assign all prior volume closest to final point X_last to that
        # point: that is from logx=log((X_penultimate + X_last) / 2) to
        # logx=-inf
        logw[-1] = scipy.special.logsumexp([logx[-2], logx[-1]]) - np.log(2)
        # multiply by likelihood (add in log space)
        logw += ns_run['logl']
        return logw
    except IndexError:
        if ns_run['logl'].shape[0] == 1:
            # If there is only one point in the run then assign all prior
            # volume X \in (0, 1) to that point, so the weight is just
            # 1 * logl_0 = logl_0
            return copy.deepcopy(ns_run['logl'])
        else:
            raise


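# Example check: get_logw's interior weights agree with a direct trapezium
# rule calculation in linear space. This is safe here as the toy run's logl
# values are small enough for np.exp not to overflow.


def _example_check_logw():
    run = _toy_run()
    logw = get_logw(run)
    x = np.exp(get_logx(run['nlive_array']))
    # w_i = L_i * (X_{i-1} - X_{i+1}) / 2 for the interior points
    w_direct = np.exp(run['logl'][1:-1]) * (x[:-2] - x[2:]) / 2
    assert np.allclose(np.exp(logw[1:-1]), w_direct)
    return logw

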
def get_w_rel(ns_run, simulate=False):
    """Get the relative posterior weights of the samples, normalised so
    the maximum sample weight is 1. This is calculated from get_logw with
    protection against numerical overflows.

    Parameters
    ----------
    ns_run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    simulate: bool, optional
        See the get_logw docstring for more details.

    Returns
    -------
    w_rel: 1d numpy array
        Relative posterior masses of points.
    """
    logw = get_logw(ns_run, simulate=simulate)
    return np.exp(logw - logw.max())


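# Example usage: one common use of relative weights is estimating an
# effective sample size. The Kish formula below is an editorial illustration,
# not something get_w_rel itself provides.


def _example_effective_sample_size():
    w_rel = get_w_rel(_toy_run())
    # Kish effective sample size: (sum w)^2 / sum(w^2)
    return (np.sum(w_rel) ** 2) / np.sum(w_rel ** 2)

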
def get_logx(nlive, simulate=False):
    r"""Returns a logx vector showing the expected or simulated logx positions
    of points. The shrinkage factor between two points

    .. math:: t_i = X_{i} / X_{i-1}

    is distributed as the largest of :math:`n_i` uniform random variables
    between 1 and 0, where :math:`n_i` is the local number of live points.

    We are interested in

    .. math:: \log(t_i) = \log X_{i} - \log X_{i-1}

    which has expected value :math:`-1/n_i`.

    Parameters
    ----------
    nlive: 1d numpy array
        Ordered local number of live points present at each point's
        iso-likelihood contour.
    simulate: bool, optional
        Should log prior volumes logx be simulated from their distribution
        (if False their expected values are used).

    Returns
    -------
    logx: 1d numpy array
        log X values for points.
    """
    assert nlive.min() > 0, (
        'nlive contains zeros or negative values! nlive = ' + str(nlive))
    if simulate:
        logx_steps = np.log(np.random.random(nlive.shape)) / nlive
    else:
        logx_steps = -1 * (nlive.astype(float) ** -1)
    return np.cumsum(logx_steps)


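# Example usage: with a constant number of live points n, the expected logx
# after i steps is -i / n; simulated values scatter around this with standard
# deviation sqrt(i) / n (the log shrinkage steps are independent, each with
# variance 1 / n^2).


def _example_logx():
    nlive = np.full(100, 20)
    expected = get_logx(nlive)                  # [-0.05, -0.10, ..., -5.0]
    simulated = get_logx(nlive, simulate=True)  # a random draw around these
    return expected, simulated

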
def log_subtract(loga, logb):
    r"""Numerically stable method for avoiding overflow errors when
    calculating :math:`\log(a - b)`, given :math:`\log(a)`, :math:`\log(b)`
    and that :math:`a > b`.

    See https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
    for more details.

    Parameters
    ----------
    loga: float
    logb: float
        Must be less than loga.

    Returns
    -------
    log(a - b): float
    """
    return loga + np.log(1 - np.exp(logb - loga))


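# Example check: exp(log_subtract(log(5), log(3))) should equal 5 - 3 = 2.


def _example_log_subtract():
    assert np.isclose(log_subtract(np.log(5.0), np.log(3.0)), np.log(2.0))

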
# Functions for checking nestcheck format nested sampling run dictionaries to
# ensure they have the expected properties.


def check_ns_run(run, dup_assert=False, dup_warn=False):
    """Checks a nestcheck format nested sampling run dictionary has the
    expected properties (see the data_processing module docstring for more
    details).

    Parameters
    ----------
    run: dict
        Nested sampling run to check.
    dup_assert: bool, optional
        See check_ns_run_logls docstring.
    dup_warn: bool, optional
        See check_ns_run_logls docstring.

    Raises
    ------
    AssertionError
        If run does not have expected properties.
    """
    assert isinstance(run, dict)
    check_ns_run_members(run)
    check_ns_run_logls(run, dup_assert=dup_assert, dup_warn=dup_warn)
    check_ns_run_threads(run)


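# Example usage: check_ns_run passes silently on a valid run and raises an
# AssertionError on an invalid one, here a run whose logl values have been
# put out of order.


def _example_check_ns_run():
    run = _toy_run()
    check_ns_run(run)  # a valid run passes without error
    bad_run = copy.deepcopy(run)
    bad_run['logl'] = bad_run['logl'][::-1]  # break the logl ordering
    try:
        check_ns_run(bad_run)
    except AssertionError:
        pass  # raised as expected

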
def check_ns_run_members(run):
    """Check nested sampling run member keys and values.

    Parameters
    ----------
    run: dict
        Nested sampling run to check.

    Raises
    ------
    AssertionError
        If run does not have expected properties.
    """
    run_keys = list(run.keys())
    # Mandatory keys
    for key in ['logl', 'nlive_array', 'theta', 'thread_labels',
                'thread_min_max']:
        assert key in run_keys
        run_keys.remove(key)
    # Optional keys
    for key in ['output']:
        try:
            run_keys.remove(key)
        except ValueError:
            pass
    # Check for unexpected keys
    assert not run_keys, 'Unexpected keys in ns_run: ' + str(run_keys)
    # Check type of mandatory members
    for key in ['logl', 'nlive_array', 'theta', 'thread_labels',
                'thread_min_max']:
        assert isinstance(run[key], np.ndarray), (
            key + ' is type ' + type(run[key]).__name__)
    # check shapes of keys
    assert run['logl'].ndim == 1
    assert run['logl'].shape == run['nlive_array'].shape
    assert run['logl'].shape == run['thread_labels'].shape
    assert run['theta'].ndim == 2
    assert run['logl'].shape[0] == run['theta'].shape[0]


def check_ns_run_logls(run, dup_assert=False, dup_warn=False):
    """Check run logls are unique and in the correct order.

    Parameters
    ----------
    run: dict
        Nested sampling run to check.
    dup_assert: bool, optional
        Whether to raise an AssertionError if there are duplicate logl
        values.
    dup_warn: bool, optional
        Whether to give a UserWarning if there are duplicate logl values
        (only used if dup_assert is False).

    Raises
    ------
    AssertionError
        If run does not have expected properties.
    """
    assert np.array_equal(run['logl'], run['logl'][np.argsort(run['logl'])])
    if dup_assert or dup_warn:
        unique_logls, counts = np.unique(run['logl'], return_counts=True)
        repeat_logls = run['logl'].shape[0] - unique_logls.shape[0]
        msg = ('{} duplicate logl values (out of a total of {}). This may be '
               'caused by limited numerical precision in the output files.'
               '\nrepeated logls = {}\ncounts = {}\npositions in list of {}'
               ' unique logls = {}').format(
                   repeat_logls, run['logl'].shape[0],
                   unique_logls[counts != 1], counts[counts != 1],
                   unique_logls.shape[0], np.where(counts != 1)[0])
        if dup_assert:
            assert repeat_logls == 0, msg
        elif dup_warn:
            if repeat_logls != 0:
                warnings.warn(msg, UserWarning)


def check_ns_run_threads(run):
    """Check thread labels and thread_min_max have expected properties.

    Parameters
    ----------
    run: dict
        Nested sampling run to check.

    Raises
    ------
    AssertionError
        If run does not have expected properties.
    """
    assert run['thread_labels'].dtype == int
    uniq_th = np.unique(run['thread_labels'])
    assert np.array_equal(
        np.asarray(range(run['thread_min_max'].shape[0])), uniq_th), \
        str(uniq_th)
    # Check thread_min_max
    assert np.any(run['thread_min_max'][:, 0] == -np.inf), (
        'Run should have at least one thread which starts by sampling the '
        'whole prior')
    for th_lab in uniq_th:
        inds = np.where(run['thread_labels'] == th_lab)[0]
        th_info = 'thread label={}, first_logl={}, thread_min_max={}'.format(
            th_lab, run['logl'][inds[0]], run['thread_min_max'][th_lab, :])
        assert run['thread_min_max'][th_lab, 0] <= run['logl'][inds[0]], (
            'First point in thread has logl less than thread min logl! ' +
            th_info + ', difference={}'.format(
                run['logl'][inds[0]] - run['thread_min_max'][th_lab, 0]))
        assert run['thread_min_max'][th_lab, 1] == run['logl'][inds[-1]], (
            'Last point in thread logl != thread end logl! ' + th_info)