#!/usr/bin/env python
r"""Module containing functions for loading and processing output files
produced by nested sampling software.
Background: threads
-------------------
``nestcheck``'s error estimates and diagnostics rely on the decomposition
of a nested sampling run into multiple runs, each with a single live point.
We refer to these constituent single live point runs as *threads*.
See "Sampling Errors In Nested Sampling Parameter Estimation" (Higson et
al. 2018) for a detailed discussion, including an algorithm for dividing
nested sampling runs into their constituent threads.
Nested sampling run format
--------------------------
``nestcheck`` stores nested sampling runs in a standard format as python
dictionaries. For a run with :math:`n_\mathrm{samp}` samples, the keys are:
logl: 1d numpy array
Loglikelihood values (floats) for each sample.
Shape is (:math:`n_\mathrm{samp}`,).
thread_labels: 1d numpy array
Integer label for each point representing which thread each point
belongs to.
Shape is (:math:`n_\mathrm{samp}`,).
For some thread label k, the thread's start (birth)
log-likelihood and end log-likelihood are given by
thread_min_max[k, :].
thread_min_max: 2d numpy array
Shape is (:math:`n_\mathrm{threads}`, 2).
Each row with index k contains the logl from within which the first
point in the thread with label k was sampled (the "birth contour") and
the logl of the final point in the thread.
The birth contour is -inf if the thread began by sampling from the
whole prior.
theta: 2d numpy array
Parameter values for samples - each row represents a sample.
Shape is (:math:`n_\mathrm{samp}`, d) where d is number of dimensions.
nlive_array: 1d numpy array
Number of live points present between the previous point and
this point.
output: dict (optional)
Dict containing extra information about the run.
Samples are arranged in ascending order of logl.
Processing nested sampling software output
------------------------------------------
To process output files for a nested sampling run into the format described
above, the following information is required:
* Samples' loglikelihood values;
* Samples' parameter values;
* Information allowing decomposition into threads and identifying each thread's
birth contour (starting logl).
The first two items are self-explanatory, but the latter is more challenging
as it can take different formats and may not be provided by all nested sampling
software packages.
Sufficient information for thread decomposition and calculating the number of
live points (including for dynamic nested sampling) is provided by a list of
the loglikelihoods from within which each point was sampled (the points'
birth contours). This is output by ``PolyChord`` >= v1.13 and ``MultiNest``
>= v3.11, and is used in the output processing for these packages via the
``birth_inds_given_contours`` and ``threads_given_birth_inds`` functions.
Also sufficient is a list of the indexes of the point which was removed
at the step when each point was sampled ("birth indexes"), as this can be
mapped to the birth contours and vice versa.
``process_dynesty_run`` does not require the ``birth_inds_given_contours`` and
``threads_given_birth_inds`` functions as ``dynesty`` results objects
already include thread labels via their ``samples_id`` property. If the
``dynesty`` run is dynamic, the ``batch_bounds`` property is needed to determine
the threads' starting birth contours.
Adding a new processing function for another nested sampling package
--------------------------------------------------------------------
You can add new functions to process output from other nested sampling
software, provided the output files include the required information for
decomposition into threads. Depending on how this information is provided you
may be able to adapt ``process_polychord_run`` or ``process_dynesty_run``.
If thread decomposition information is provided in a different format, you
will have to write your own helper functions to process the output into the
``nestcheck`` dictionary format described above.
"""
import os
import re
import warnings
import copy
import numpy as np
import nestcheck.io_utils
import nestcheck.ns_run_utils
import nestcheck.parallel_utils
@nestcheck.io_utils.save_load_result
def batch_process_data(file_roots, **kwargs):
    """Process output from many nested sampling runs in parallel with optional
    error handling and caching.

    The result can be cached using the 'save_name', 'save' and 'load' kwargs
    (by default this is not done). See save_load_result docstring for more
    details.

    Remaining kwargs passed to parallel_utils.parallel_apply (see its
    docstring for more details).

    Parameters
    ----------
    file_roots: list of strs
        file_roots for the runs to load.
    base_dir: str, optional
        path to directory containing files.
    process_func: function, optional
        function to use to process the data.
    func_kwargs: dict, optional
        additional keyword arguments for process_func. The dict passed in
        is copied and is not modified by this function.
    errors_to_handle: error or tuple of errors, optional
        which errors to catch when they occur in processing rather than
        raising.
    save_name: str or None, optional
        See nestcheck.io_utils.save_load_result.
    save: bool, optional
        See nestcheck.io_utils.save_load_result.
    load: bool, optional
        See nestcheck.io_utils.save_load_result.
    overwrite_existing: bool, optional
        See nestcheck.io_utils.save_load_result.

    Returns
    -------
    list of ns_run dicts
        List of nested sampling runs in dict format (see the module
        docstring for more details). Runs for which a handled error
        occurred are left out.
    """
    base_dir = kwargs.pop('base_dir', 'chains')
    process_func = kwargs.pop('process_func', process_polychord_run)
    # Take a copy so that adding 'errors_to_handle' below does not modify the
    # dictionary object owned by the caller.
    func_kwargs = dict(kwargs.pop('func_kwargs', {}))
    func_kwargs['errors_to_handle'] = kwargs.pop('errors_to_handle', ())
    data = nestcheck.parallel_utils.parallel_apply(
        process_error_helper, file_roots, func_args=(base_dir, process_func),
        func_kwargs=func_kwargs, **kwargs)
    # Sort processed runs into the same order as file_roots (as parallel_apply
    # does not preserve order)
    data = sorted(data,
                  key=lambda x: file_roots.index(x['output']['file_root']))
    # Group the (zero based) indexes of failed runs by error name
    errors = {}
    for i, run in enumerate(data):
        if 'error' in run:
            errors.setdefault(run['error'], []).append(i)
    # Print one summary line per type of error
    for error_name, index_list in errors.items():
        message = (error_name + ' processing ' + str(len(index_list)) + ' / '
                   + str(len(file_roots)) + ' files')
        if len(index_list) != len(file_roots):
            message += ('. Roots with errors have (zero based) indexes: '
                        + str(index_list))
        print(message)
    # Return runs which did not have errors
    return [run for run in data if 'error' not in run]
def process_error_helper(root, base_dir, process_func, errors_to_handle=(),
                         **func_kwargs):
    """Apply process_func to one run, converting selected errors into a
    placeholder dict so one bad run does not spoil the whole batch.

    Useful errors to handle include:

    OSError: if you are not sure if all the files exist
    AssertionError: if some of the many assertions fail for known reasons;
    for example if there are occasional problems decomposing runs into
    threads due to limited numerical precision in logls.

    Parameters
    ----------
    root: str
        File root.
    base_dir: str
        Directory containing file.
    process_func: func
        Function for processing file. Called as
        process_func(root, base_dir, **func_kwargs).
    errors_to_handle: error type or tuple of error types
        Errors to catch without throwing an exception. The default (an
        empty tuple) catches nothing.
    func_kwargs: dict
        Kwargs to pass to process_func.

    Returns
    -------
    run: dict
        Whatever process_func returns or, if one of errors_to_handle
        occurred, a dict containing the error's type name (key 'error')
        and the file root (under run['output']['file_root']).
    """
    try:
        result = process_func(root, base_dir, **func_kwargs)
    except errors_to_handle as caught:
        # Record only the name of the error; the file root is stored in the
        # same place as for successfully processed runs.
        result = {'error': type(caught).__name__,
                  'output': {'file_root': root}}
    return result
def process_polychord_run(file_root, base_dir, process_stats_file=True,
                          **kwargs):
    """Load a PolyChord run into the nestcheck dictionary format.

    N.B. producing the required output file containing information about the
    iso-likelihood contours within which points were sampled (where they were
    "born") requires PolyChord version v1.13 or later and the setting
    write_dead=True.

    Parameters
    ----------
    file_root: str
        Root for run output file names (PolyChord file_root setting).
    base_dir: str
        Directory containing data (PolyChord base_dir setting).
    process_stats_file: bool, optional
        Should PolyChord's <root>.stats file be processed? Set to False if you
        don't have the <root>.stats file (such as if PolyChord was run with
        write_stats=False).
    kwargs: dict, optional
        Passed on to process_samples_array.

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see the module docstring for more details).
    """
    root_path = os.path.join(base_dir, file_root)
    # N.B. the PolyChord dead points file also contains the live points
    # remaining at termination
    ns_run = process_samples_array(
        np.loadtxt(root_path + '_dead-birth.txt'), **kwargs)
    # Minimal output info; replaced below if the .stats file can be read
    ns_run['output'] = {'base_dir': base_dir, 'file_root': file_root}
    if not process_stats_file:
        return ns_run
    try:
        ns_run['output'] = process_polychord_stats(file_root, base_dir)
    except (OSError, IOError, ValueError, IndexError, NameError,
            TypeError) as err:
        # A missing or unparsable .stats file is not fatal: warn and keep
        # the minimal output dictionary.
        warnings.warn(
            ('process_polychord_stats raised {} processing {}.stats file. '
             ' I am proceeding without the .stats file.').format(
                 type(err).__name__, root_path),
            UserWarning)
    return ns_run
def process_multinest_run(file_root, base_dir, **kwargs):
    """Load a MultiNest run into the nestcheck dictionary format.

    N.B. producing the required output files containing information about the
    iso-likelihood contours within which points were sampled (where they were
    "born") requires MultiNest version 3.11 or later.

    Parameters
    ----------
    file_root: str
        Root name for output files. When running MultiNest, this is determined
        by the nest_root parameter.
    base_dir: str
        Directory containing output files. When running MultiNest, this is
        determined by the nest_root parameter.
    kwargs: dict, optional
        Passed on to process_samples_array.

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see the module docstring for more details).
    """
    root_path = os.path.join(base_dir, file_root)
    dead_raw = np.loadtxt(root_path + 'dead-birth.txt')
    live_raw = np.loadtxt(root_path + 'phys_live-birth.txt')
    # Drop the unnecessary final columns (two for the dead points, one for
    # the live points) so both arrays end with the columns logl, birth logl
    dead = dead_raw[:, :-2]
    live = live_raw[:, :-1]
    assert dead[:, -2].max() < live[:, -2].min(), (
        'final live points should have greater logls than any dead point!',
        dead, live)
    ns_run = process_samples_array(np.vstack((dead, live)), **kwargs)
    assert np.all(ns_run['thread_min_max'][:, 0] == -np.inf), (
        'As MultiNest does not currently perform dynamic nested sampling, all '
        'threads should start by sampling the whole prior.')
    ns_run['output'] = {'file_root': file_root, 'base_dir': base_dir}
    return ns_run
def process_dynesty_run(results):
    """Transform results from a dynesty run into the nestcheck dictionary
    format for analysis. This function has been tested with dynesty v9.2.0.

    Note that the nestcheck point weights and evidence will not be exactly
    the same as the dynesty ones as nestcheck calculates logX volumes more
    precisely (using the trapezium rule).

    This function does not require the birth_inds_given_contours and
    threads_given_birth_inds functions as dynesty results objects
    already include thread labels via their samples_id property. If the
    dynesty run is dynamic, the batch_bounds property is needed to determine
    the threads' starting birth contours.

    Parameters
    ----------
    results: dynesty results object
        N.B. the remaining live points at termination must be included in the
        results (dynesty samplers' run_nested method does this if
        add_live_points=True - its default value).

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see the module docstring for more details).
    """
    nsamp = results.samples.shape[0]
    ndim = results.samples.shape[1]
    # Columns of the run array built here:
    # logl, thread label, change in nlive at the sample, params
    samples = np.zeros((nsamp, ndim + 3))
    samples[:, 0] = results.logl
    samples[:, 1] = results.samples_id
    samples[:, 3:] = results.samples
    unique_th, first_inds = np.unique(results.samples_id, return_index=True)
    nthreads = unique_th.shape[0]
    assert np.array_equal(unique_th, np.arange(nthreads))
    thread_min_max = np.full((nthreads, 2), np.nan)
    is_dynamic_dynesty = False
    try:
        # Try processing standard nested sampling results
        assert nthreads == results.nlive
        assert np.array_equal(
            np.unique(results.samples_id[-results.nlive:]),
            np.asarray(range(results.nlive))), (
                'perhaps the final live points are not included?')
        thread_min_max[:, 0] = -np.inf
    except AttributeError:
        # If results has no nlive attribute, it must be dynamic nested sampling
        assert nthreads == sum(results.batch_nlive)
        # if the object has a samples_n attribute, it is from dynesty
        if hasattr(results, 'samples_n'):
            is_dynamic_dynesty = True
            # results.samples_n holds the number of live points at each
            # iteration. np.diff gives out[i] = samples_n[i+1] - samples_n[i],
            # so element i is stored as the change in nlive at sample i; the
            # final sample's entry is left at zero.
            samples[:-1, 2] = np.diff(results.samples_n)
        # Each thread's birth contour is the lower bound of the batch in which
        # its first point was sampled
        for th_lab, ind in zip(unique_th, first_inds):
            thread_min_max[th_lab, 0] = (
                results.batch_bounds[results.samples_batch[ind], 0])
    for th_lab in unique_th:
        final_ind = np.where(results.samples_id == th_lab)[0][-1]
        thread_min_max[th_lab, 1] = results.logl[final_ind]
        if not is_dynamic_dynesty:
            # Without samples_n, mark the end of each thread as the loss of
            # one live point
            samples[final_ind, 2] = -1
    assert np.all(~np.isnan(thread_min_max))
    run = nestcheck.ns_run_utils.dict_given_run_array(samples, thread_min_max)
    nestcheck.ns_run_utils.check_ns_run(run)
    return run
def process_polychord_stats(file_root, base_dir):
    """Read a PolyChord <root>.stats output file into a dictionary.

    The file is parsed by fixed line positions, so it must follow the layout
    this function expects.

    Parameters
    ----------
    file_root: str
        Root for run output file names (PolyChord file_root setting).
    base_dir: str
        Directory containing data (PolyChord base_dir setting).

    Returns
    -------
    output: dict
        Contains base_dir, file_root, logZ, logZerr, logZs, logZerrs,
        ncluster, nposterior, nequals, ndead, nlive, nlike, avnlike,
        avnlikeslice and, when present in the file, param_means and
        param_mean_errs. See PolyChord documentation for more details.
    """
    stats_path = os.path.join(base_dir, file_root) + '.stats'
    output = {'base_dir': base_dir,
              'file_root': file_root}
    with open(stats_path, 'r') as stats_file:
        lines = stats_file.readlines()
    # Global evidence
    global_fields = lines[8].split()
    output['logZ'] = float(global_fields[2])
    output['logZerr'] = float(global_fields[4])
    # Cluster logZs and errors: consecutive lines beginning with 'log(Z'
    output['logZs'] = []
    output['logZerrs'] = []
    for line in lines[14:]:
        if not line.startswith('log(Z'):
            break
        after_equals = re.findall(r'=(.*)', line)[0].split()
        output['logZs'].append(float(after_equals[0]))
        output['logZerrs'].append(float(after_equals[2]))
    # Other output info: positions are shifted by the number of cluster lines
    nclust = len(output['logZs'])
    output['ncluster'] = nclust
    count_keys = ('nposterior', 'nequals', 'ndead', 'nlive')
    for offset, key in enumerate(count_keys):
        output[key] = int(lines[20 + nclust + offset].split()[1])
    try:
        nlike = [int(x) for x in lines[24 + nclust].split()[1:]]
        # If only one value, keep as int
        output['nlike'] = nlike[0] if len(nlike) == 1 else nlike
    except ValueError:
        # if nlike has too many digits, PolyChord just writes ***** to .stats
        # file. This causes a ValueError
        output['nlike'] = np.nan
    # If there are multiple parameter speeds then multiple values are written
    # for avnlike and avnlikeslice; the two groups are separated by '('
    av_fields = lines[25 + nclust].split()
    paren = av_fields.index('(')
    avnlike = [float(x) for x in av_fields[1:paren]]
    # If only one value, keep as float
    output['avnlike'] = avnlike[0] if len(avnlike) == 1 else avnlike
    avnlikeslice = [float(x) for x in av_fields[paren + 1:-3]]
    output['avnlikeslice'] = (
        avnlikeslice[0] if len(avnlikeslice) == 1 else avnlikeslice)
    # Means and stds of dimensions (not produced by PolyChord<=1.13)
    if len(lines) > 29 + nclust:
        output['param_means'] = []
        output['param_mean_errs'] = []
        for line in lines[29 + nclust:]:
            if '------------------' in line:
                # A line of dashes is used to show the start of the derived
                # parameters in the .stats file for later versions of
                # PolyChord
                continue
            param_fields = line.split()
            output['param_means'].append(float(param_fields[1]))
            output['param_mean_errs'].append(float(param_fields[3]))
    return output
def process_samples_array(samples, **kwargs):
    """Convert an array of nested sampling dead and live points of the type
    produced by PolyChord and MultiNest into a nestcheck nested sampling run
    dictionary.

    Parameters
    ----------
    samples: 2d numpy array
        Array of dead points and any remaining live points at termination.
        Has #parameters + 2 columns:
        param_1, param_2, ... , logl, birth_logl
    kwargs: dict, optional
        Options passed to birth_inds_given_contours

    Returns
    -------
    ns_run: dict
        Nested sampling run dict (see the module docstring for more
        details). Only contains information in samples (not additional
        optional output key).
    """
    # Arrange the points in ascending order of logl
    samples = samples[np.argsort(samples[:, -2])]
    ns_run = {'logl': samples[:, -2],
              'theta': samples[:, :-2]}
    birth_inds = birth_inds_given_contours(
        samples[:, -1], ns_run['logl'], **kwargs)
    ns_run['thread_labels'] = threads_given_birth_inds(birth_inds)
    unique_threads = np.unique(ns_run['thread_labels'])
    nthreads = unique_threads.shape[0]
    assert np.array_equal(unique_threads, np.arange(nthreads))
    # Work out nlive_array and thread_min_max logls from thread labels and
    # birth contours
    thread_min_max = np.zeros((nthreads, 2))
    # NB delta_nlive indexes are offset from points' indexes by 1 as we need an
    # element to represent the initial sampling of live points before any dead
    # points are created.
    # I.E. birth on step 1 corresponds to replacing dead point zero
    delta_nlive = np.zeros(samples.shape[0] + 1)
    for label in unique_threads:
        members = np.where(ns_run['thread_labels'] == label)[0]
        first_member, last_member = members[0], members[-1]
        # Max is final logl in thread; one live point is lost after it
        thread_min_max[label, 1] = ns_run['logl'][last_member]
        delta_nlive[last_member + 1] -= 1
        start_birth_ind = birth_inds[first_member]
        if start_birth_ind != birth_inds[0]:
            # Thread starts part way through the run, on the contour of the
            # dead point with index start_birth_ind
            assert start_birth_ind >= 0
            thread_min_max[label, 0] = ns_run['logl'][start_birth_ind]
            delta_nlive[start_birth_ind + 1] += 1
        else:
            # thread minimum is -inf as it starts by sampling from whole prior
            thread_min_max[label, 0] = -np.inf
            delta_nlive[0] += 1
    ns_run['thread_min_max'] = thread_min_max
    ns_run['nlive_array'] = np.cumsum(delta_nlive)[:-1]
    return ns_run
def birth_inds_given_contours(birth_logl_arr, logl_arr, **kwargs):
    """Maps the iso-likelihood contours on which points were born to the
    index of the dead point on this contour.

    MultiNest and PolyChord use different values to identify the initial live
    points which were sampled from the whole prior (PolyChord uses -1e+30
    and MultiNest -0.179769313486231571E+309). However in each case the first
    dead point must have been sampled from the whole prior, so for either
    package we can use

        init_birth = birth_logl_arr[0]

    If there are many points with the same logl_arr and dup_assert is False,
    these points are randomly assigned an order (to ensure results are
    consistent, random seeding is used). The global numpy random state is
    saved before seeding and is restored when the function exits, including
    when it exits by raising an exception.

    Parameters
    ----------
    logl_arr: 1d numpy array
        logl values of each point.
    birth_logl_arr: 1d numpy array
        Birth contours - i.e. logl values of the iso-likelihood contour from
        within each point was sampled (on which it was born).
    dup_assert: bool, optional
        See ns_run_utils.check_ns_run_logls docstring.
    dup_warn: bool, optional
        See ns_run_utils.check_ns_run_logls docstring.

    Returns
    -------
    birth_inds: 1d numpy array of ints
        Step at which each element of logl_arr was sampled. Points sampled from
        the whole prior are assigned value -1.

    Raises
    ------
    TypeError
        If unexpected keyword arguments are given.
    ValueError
        If points which are born and die on the same contour cannot be
        given a valid ordering.
    """
    dup_assert = kwargs.pop('dup_assert', False)
    dup_warn = kwargs.pop('dup_warn', False)
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    assert logl_arr.ndim == 1, logl_arr.ndim
    assert birth_logl_arr.ndim == 1, birth_logl_arr.ndim
    # Check for duplicate logl values (if specified by dup_assert or dup_warn)
    nestcheck.ns_run_utils.check_ns_run_logls(
        {'logl': logl_arr}, dup_assert=dup_assert, dup_warn=dup_warn)
    # Random seed so results are consistent if there are duplicate logls
    state = np.random.get_state()  # Save random state before seeding
    np.random.seed(0)
    try:
        # Calculate birth inds
        init_birth = birth_logl_arr[0]
        assert np.all(birth_logl_arr <= logl_arr), (
            logl_arr[birth_logl_arr > logl_arr])
        birth_inds = np.full(birth_logl_arr.shape, np.nan)
        birth_inds[birth_logl_arr == init_birth] = -1
        for i, birth_logl in enumerate(birth_logl_arr):
            if not np.isnan(birth_inds[i]):
                # birth ind has already been assigned
                continue
            dup_deaths = np.where(logl_arr == birth_logl)[0]
            if dup_deaths.shape == (1,):
                # death index is unique
                birth_inds[i] = dup_deaths[0]
                continue
            # The remainder of this loop deals with the case that multiple
            # points have the same logl value (=birth_logl). This can occur
            # due to limited precision, or for likelihoods with constant
            # regions. In this case we randomly assign the duplicates birth
            # steps in a manner that provides a valid division into nested
            # sampling runs
            dup_births = np.where(birth_logl_arr == birth_logl)[0]
            assert dup_deaths.shape[0] > 1, dup_deaths
            if np.all(birth_logl_arr[dup_deaths] != birth_logl):
                # If no points both are born and die on this contour, we can
                # just randomly assign an order
                np.random.shuffle(dup_deaths)
                inds_to_use = dup_deaths
            else:
                # If some points are both born and die on the contour, we need
                # to take care that the assigned birth inds do not result in
                # some points dying before they are born
                try:
                    inds_to_use = sample_less_than_condition(
                        dup_deaths, dup_births)
                except ValueError:
                    raise ValueError((
                        'There is no way to allocate indexes dup_deaths={} '
                        'such that each is less than dup_births={}.').format(
                            dup_deaths, dup_births))
            try:
                # Add our selected inds_to_use values to the birth_inds array
                # Note that dup_deaths (and hence inds to use) may have more
                # members than dup_births, because one of the duplicates may
                # be the final point in a thread. We therefore include only
                # the first dup_births.shape[0] elements
                birth_inds[dup_births] = inds_to_use[:dup_births.shape[0]]
            except ValueError:
                warnings.warn((
                    'for logl={}, the number of points born (indexes='
                    '{}) is bigger than the number of points dying '
                    '(indexes={}). This indicates a problem with your '
                    'nested sampling software - it may be caused by '
                    'a bug in PolyChord which was fixed in PolyChord '
                    'v1.14, so try upgrading. I will try to give an '
                    'approximate allocation of threads but this may '
                    'fail.').format(
                        birth_logl, dup_births, inds_to_use), UserWarning)
                # Pad with randomly repeated death indexes so there is one
                # for every birth
                extra_inds = np.random.choice(
                    inds_to_use,
                    size=dup_births.shape[0] - inds_to_use.shape[0])
                inds_to_use = np.concatenate((inds_to_use, extra_inds))
                np.random.shuffle(inds_to_use)
                birth_inds[dup_births] = inds_to_use[:dup_births.shape[0]]
        assert np.all(~np.isnan(birth_inds)), np.isnan(birth_inds).sum()
    finally:
        # Reset random state even if an assertion or error interrupted the
        # calculation, so the caller's random numbers are never left seeded
        np.random.set_state(state)
    return birth_inds.astype(int)
def sample_less_than_condition(choices_in, condition):
    """Draw a random sample from choices_in without replacement, subject to
    the condition that each element of the output is less than the
    corresponding element of the condition array.

    condition should be in ascending order.

    Parameters
    ----------
    choices_in: 1d numpy array
        Values to sample from. Not modified.
    condition: 1d numpy array
        Upper bounds; output[i] is strictly less than condition[i].

    Returns
    -------
    output: 1d numpy array of floats
        Selected values. Its length is the smaller of the lengths of
        condition and choices_in.

    Raises
    ------
    ValueError
        From np.random.choice, if at some step no remaining choice is less
        than the corresponding condition element.
    """
    n_out = min(condition.shape[0], choices_in.shape[0])
    output = np.zeros(n_out)
    remaining = copy.deepcopy(choices_in)
    for i in range(n_out):
        # positions (within remaining) of the values meeting the condition
        allowed = np.where(remaining < condition[i])[0]
        picked = np.random.choice(allowed)
        output[i] = remaining[picked]
        # the chosen value cannot be used again
        remaining = np.delete(remaining, picked)
    return output
def threads_given_birth_inds(birth_inds):
    """Divides a nested sampling run into threads, using info on the indexes
    at which points were sampled. See "Sampling errors in nested sampling
    parameter estimation" (Higson et al. 2018) for more information.

    Parameters
    ----------
    birth_inds: 1d numpy array
        Indexes of the iso-likelihood contours from within which each point was
        sampled ("born").

    Returns
    -------
    thread_labels: 1d numpy array of ints
        labels of the thread each point belongs to.
    """
    contours, n_born = np.unique(birth_inds, return_counts=True)
    # First get all the indexes on which threads start and how many start
    # there. This is every point born on the lowest contour (the initial
    # sampling from the prior), plus any later contour on which more than one
    # point is born - one of those points continues an existing thread, so
    # the number of new threads is one fewer than the number born.
    multiple_births = n_born[1:] > 1
    thread_start_inds = np.concatenate((
        contours[:1], contours[1:][multiple_births]))
    thread_start_counts = np.concatenate((
        n_born[:1], n_born[1:][multiple_births] - 1))
    thread_labels = np.full(birth_inds.shape, np.nan)
    thread_num = 0
    for contour_num, contour in enumerate(thread_start_inds):
        born_here = np.where(birth_inds == contour)[0]
        # unless this is the first contour, the first point born on it is
        # already assigned to a thread
        new_thread_starts = born_here if contour_num == 0 else born_here[1:]
        for start_ind in new_thread_starts:
            # check point has not already been assigned
            assert np.isnan(thread_labels[start_ind])
            thread_labels[start_ind] = thread_num
            # follow the chain of points which replaced one another
            replacements = np.where(birth_inds == start_ind)[0]
            while replacements.shape[0] != 0:
                current = replacements[0]
                # check point has not already been assigned
                assert np.isnan(thread_labels[current])
                thread_labels[current] = thread_num
                replacements = np.where(birth_inds == current)[0]
            thread_num += 1
    unlabelled = np.isnan(thread_labels)
    if np.any(unlabelled):
        unlabelled_inds = np.where(unlabelled)[0]
        warnings.warn((
            '{} points (out of a total of {}) were not given a thread label! '
            'This is likely due to small numerical errors in your nested '
            'sampling software while running the calculation or writing the '
            'input files. '
            'I will try to give an approximate answer by randomly assigning '
            'these points to threads.'
            '\nIndexes without labels are {}'
            '\nIndexes on which threads start are {} with {} threads '
            'starting on each.').format(
                unlabelled.sum(), birth_inds.shape[0],
                unlabelled_inds,
                thread_start_inds, thread_start_counts))
        state = np.random.get_state()  # Save random state before seeding
        np.random.seed(0)  # make the thread decomposition reproducible
        for ind in unlabelled_inds:
            # Get the set of threads with members both before and after ind to
            # ensure we don't change nlive_array by extending a thread
            candidates = np.intersect1d(  # N.B. this removes nans too
                thread_labels[:ind], thread_labels[ind + 1:])
            if candidates.shape[0] == 0:
                # In edge case that there is no intersection, just randomly
                # select from non-nan thread labels
                candidates = np.unique(
                    thread_labels[~np.isnan(thread_labels)])
            thread_labels[ind] = np.random.choice(candidates)
        np.random.set_state(state)  # Reset random state
    assert np.all(~np.isnan(thread_labels)), (
        '{} points still do not have thread labels'.format(
            (np.isnan(thread_labels)).sum()))
    assert np.array_equal(thread_labels, thread_labels.astype(int)), (
        'Thread labels should all be ints!')
    thread_labels = thread_labels.astype(int)
    # Check unique thread labels are a sequence from 0 to nthreads-1
    assert np.array_equal(
        np.unique(thread_labels),
        np.asarray(range(sum(thread_start_counts)))), (
            str(np.unique(thread_labels)) + ' is not equal to range('
            + str(sum(thread_start_counts)) + ')')
    return thread_labels