Source code for nestcheck.write_polychord_output

#!/usr/bin/env python
"""
Functions for writing PolyChord-format output files given a nested sampling run
dictionary stored in the nestcheck format.
"""

import copy
import functools
import os
import numpy as np
import nestcheck.estimators as e
import nestcheck.error_analysis
import nestcheck.ns_run_utils


def write_run_output(run, **kwargs):
    r"""Write PolyChord-format output files for the input nested sampling
    run dict.

    The file root is

    .. code-block:: python

        root = os.path.join(run['output']['base_dir'],
                            run['output']['file_root'])

    Output files which can be made with this function (see the PolyChord
    documentation for more information about what each contains):

    * [root].stats
    * [root].txt
    * [root]_equal_weights.txt
    * [root]_dead-birth.txt
    * [root]_dead.txt

    Files produced by PolyChord which are not made by this function:

    * [root].resume: for resuming runs part way through (not relevant for
      a completed run).
    * [root]_phys_live.txt and [root]phys_live-birth.txt: for checking
      runtime progress (not relevant for a completed run).
    * [root].paramnames: for use with getdist (not needed when calling
      getdist from within python).

    Parameters
    ----------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    write_dead: bool, optional
        Whether or not to write [root]_dead.txt and [root]_dead-birth.txt.
    write_stats: bool, optional
        Whether or not to write [root].stats.
    posteriors: bool, optional
        Whether or not to write [root].txt.
    equals: bool, optional
        Whether or not to write [root]_equal_weights.txt.
    stats_means_errs: bool, optional
        Whether or not to calculate mean values of :math:`\log \mathcal{Z}`
        and each parameter, and their uncertainties.
    fmt: str, optional
        Formatting for numbers written by np.savetxt. Default value is set
        to make output files look like the ones produced by PolyChord.
    n_simulate: int, optional
        Number of bootstrap replications to use when estimating uncertainty
        on evidence and parameter means.
    logl_init: float, optional
        Value used to identify the inital live points which were sampled
        from the whole prior. Default value is set to -1e30 as in PolyChord.
    """
    write_dead = kwargs.pop('write_dead', True)
    write_stats = kwargs.pop('write_stats', True)
    posteriors = kwargs.pop('posteriors', False)
    equals = kwargs.pop('equals', False)
    stats_means_errs = kwargs.pop('stats_means_errs', True)
    fmt = kwargs.pop('fmt', '% .14E')
    n_simulate = kwargs.pop('n_simulate', 100)
    logl_init = kwargs.pop('logl_init', -1e30)
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    for key in ['file_root', 'base_dir']:
        assert key in run['output'], key + ' not in run["output"]'
    root = os.path.join(run['output']['base_dir'],
                        run['output']['file_root'])
    if write_dead:
        # The dead-birth array's final column holds birth log-likelihoods;
        # [root]_dead.txt is the same data with that column dropped.
        dead_arr = run_dead_birth_array(run, logl_init=logl_init)
        np.savetxt(root + '_dead-birth.txt', dead_arr, fmt=fmt)
        np.savetxt(root + '_dead.txt', dead_arr[:, :-1], fmt=fmt)
    if equals or posteriors:
        # Columns: relative posterior weight, -2*logl, then the parameters.
        w_rel = nestcheck.ns_run_utils.get_w_rel(run)
        posterior_arr = np.zeros((run['theta'].shape[0],
                                  run['theta'].shape[1] + 2))
        posterior_arr[:, 0] = w_rel
        posterior_arr[:, 1] = -2 * run['logl']
        posterior_arr[:, 2:] = run['theta']
    if posteriors:
        np.savetxt(root + '.txt', posterior_arr, fmt=fmt)
        run['output']['nposterior'] = posterior_arr.shape[0]
    else:
        run['output']['nposterior'] = 0
    if equals:
        # Rejection-sample points with acceptance probability equal to
        # their relative weight to get equally weighted samples.
        keep = np.where(w_rel > np.random.random(w_rel.shape[0]))[0]
        np.savetxt(root + '_equal_weights.txt', posterior_arr[keep, 1:],
                   fmt=fmt)
        run['output']['nequals'] = keep.shape[0]
    else:
        run['output']['nequals'] = 0
    if write_stats:
        run['output']['ndead'] = run['logl'].shape[0]
        if stats_means_errs:
            # Estimate logZ and each parameter mean, with bootstrap errors.
            est_funcs = ([e.logz]
                         + [functools.partial(e.param_mean, param_ind=i)
                            for i in range(run['theta'].shape[1])])
            est_vals = nestcheck.ns_run_utils.run_estimators(run, est_funcs)
            est_stds = nestcheck.error_analysis.run_std_bootstrap(
                run, est_funcs, n_simulate=n_simulate)
            run['output']['logZ'] = est_vals[0]
            run['output']['logZerr'] = est_stds[0]
            run['output']['param_means'] = list(est_vals[1:])
            run['output']['param_mean_errs'] = list(est_stds[1:])
        write_stats_file(run['output'])
def run_dead_birth_array(run, logl_init=-1e30, **kwargs):
    """Convert the input run into an array with the format of a PolyChord
    <root>_dead-birth.txt file.

    Note that this in fact includes live points remaining at termination as
    well as dead points.

    Parameters
    ----------
    run: dict
        Nested sampling run dict (see data_processing module docstring for
        more details).
    logl_init: float, optional
        Value used to identify the inital live points which were sampled
        from the whole prior. Default value is set to -1e30 as in PolyChord.
    kwargs: dict, optional
        Options for check_ns_run.

    Returns
    -------
    samples: 2d numpy array
        Array of dead points and any remaining live points at termination.
        Has #parameters + 2 columns:
        param_1, param_2, ... , logl, birth_logl
    """
    nestcheck.ns_run_utils.check_ns_run(run, **kwargs)
    ndim = run['theta'].shape[1]
    pieces = []
    for thread in nestcheck.ns_run_utils.get_run_threads(run):
        npoint = thread['theta'].shape[0]
        arr = np.zeros((npoint, ndim + 2))
        arr[:, :ndim] = thread['theta']
        arr[:, ndim] = thread['logl']
        # Each point is born at the logl of the point it replaced; the
        # thread's first point is born at the thread's starting logl, or
        # at logl_init when the thread began from the whole prior (-inf).
        arr[1:, ndim + 1] = thread['logl'][:-1]
        birth = thread['thread_min_max'][0, 0]
        arr[0, ndim + 1] = logl_init if birth == -np.inf else birth
        pieces.append(arr)
    stacked = np.vstack(pieces)
    # Sort all points in order of increasing logl, as PolyChord does.
    return stacked[np.argsort(stacked[:, ndim]), :]
def write_stats_file(run_output_dict):
    """Write a dummy PolyChord format .stats file for testing functions
    which process stats files. This is written to:
    base_dir/file_root.stats

    Also returns the data in the file as a dict for comparison.

    Parameters
    ----------
    run_output_dict: dict
        Output information to write to .stats file. Must contain file_root
        and base_dir. If other settings are not specified, default values
        are used.

    Returns
    -------
    output: dict
        The expected output of
        nestcheck.process_polychord_stats(file_root, base_dir)
    """
    mandatory_keys = ['file_root', 'base_dir']
    for key in mandatory_keys:
        assert key in run_output_dict, key + ' not in run_output_dict'
    default_output = {'logZ': 0.0,
                      'logZerr': 0.0,
                      'logZs': [0.0],
                      'logZerrs': [0.0],
                      'ncluster': 1,
                      'nposterior': 0,
                      'nequals': 0,
                      'ndead': 0,
                      'nlike': 0,
                      'nlive': 0,
                      'avnlike': 0.0,
                      'avnlikeslice': 0.0,
                      'param_means': [0.0, 0.0, 0.0],
                      'param_mean_errs': [0.0, 0.0, 0.0]}
    allowed_keys = set(mandatory_keys) | set(default_output.keys())
    assert set(run_output_dict.keys()).issubset(allowed_keys), (
        'Input dict contains unexpected keys: {}'.format(
            set(run_output_dict.keys()) - allowed_keys))
    # Fill in any unspecified settings with defaults; the input dict is
    # deep-copied so the caller's dict is never mutated.
    output = copy.deepcopy(run_output_dict)
    for key, value in default_output.items():
        output.setdefault(key, value)

    def render(val):
        # Some quantities may be stored as per-core lists/tuples; join
        # those with spaces, and pass scalars through unchanged.
        if isinstance(val, (list, tuple)):
            return " ".join([str(x) for x in val])
        return val

    # Build a PolyChord format .stats file corresponding to output.
    lines = [
        'Evidence estimates:',
        '===================',
        (' - The evidence Z is a log-normally distributed, with location '
         'and scale parameters mu and sigma.'),
        ' - We denote this as log(Z) = mu +/- sigma.',
        '',
        'Global evidence:',
        '----------------',
        '',
        'log(Z) = {0} +/- {1}'.format(output['logZ'], output['logZerr']),
        '',
        '',
        'Local evidences:',
        '----------------',
        '']
    for i, (lz, lzerr) in enumerate(zip(output['logZs'],
                                        output['logZerrs'])):
        lines.append('log(Z_ {0}) = {1} +/- {2}'.format(
            str(i + 1).rjust(2), lz, lzerr))
    lines += [
        '',
        '',
        'Run-time information:',
        '---------------------',
        '',
        ' ncluster: 0 / 1',
        ' nposterior: {0}'.format(output['nposterior']),
        ' nequals: {0}'.format(output['nequals']),
        ' ndead: {0}'.format(output['ndead']),
        ' nlive: {0}'.format(output['nlive']),
        ' nlike: {0}'.format(render(output['nlike'])),
        ' <nlike>: {0} ( {1} per slice )'.format(
            render(output['avnlike']), render(output['avnlikeslice'])),
        '',
        '',
        'Dim No. Mean Sigma']
    for i, (mean, meanerr) in enumerate(zip(output['param_means'],
                                            output['param_mean_errs'])):
        lines.append('{0} {1} +/- {2}'.format(
            str(i + 1).ljust(3), mean, meanerr))
    lines.append('-------------------------------')
    file_path = os.path.join(output['base_dir'],
                             output['file_root'] + '.stats')
    with open(file_path, 'w') as stats_file:
        stats_file.writelines('{}\n'.format(line) for line in lines)
    return output