Source code for nestcheck.parallel_utils

#!/usr/bin/env python
"""
Parallel wrapper functions using the concurrent.futures module.
"""

import concurrent.futures
import functools
import warnings
import tqdm


def parallel_map(func, *arg_iterable, **kwargs):
    """Apply function to iterable with parallel map, and hence returns
    results in order. functools.partial is used to freeze func_pre_args and
    func_kwargs, meaning that the iterable argument must be the last
    positional argument.

    Roughly equivalent to

    >>> [func(*func_pre_args, x, **func_kwargs) for x in arg_iterable]

    Parameters
    ----------
    func: function
        Function to apply to list of args.
    arg_iterable: iterable
        argument to iterate over. If several iterables are passed they are
        zipped together (as with the builtin ``map``) and func receives one
        element from each as its trailing positional arguments.
    chunksize: int, optional
        Perform function in batches (only used when parallel=True).
    func_pre_args: tuple, optional
        Positional arguments to place before the iterable argument in func.
    func_kwargs: dict, optional
        Additional keyword arguments for func.
    parallel: bool, optional
        To turn off parallelisation if needed.
    parallel_warning: bool, optional
        To turn off warning for no parallelisation if needed.
    max_workers: int or None, optional
        Number of processes.
        If max_workers is None then concurrent.futures.ProcessPoolExecutor
        defaults to using the number of processors of the machine.
        N.B. If max_workers=None and running on supercomputer clusters with
        multiple nodes, this may default to the number of processors on a
        single node.

    Returns
    -------
    results_list: list of function outputs

    Raises
    ------
    TypeError
        If an unexpected keyword argument is supplied.
    """
    chunksize = kwargs.pop('chunksize', 1)
    func_pre_args = kwargs.pop('func_pre_args', ())
    func_kwargs = kwargs.pop('func_kwargs', {})
    max_workers = kwargs.pop('max_workers', None)
    parallel = kwargs.pop('parallel', True)
    parallel_warning = kwargs.pop('parallel_warning', True)
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    func_to_map = functools.partial(func, *func_pre_args, **func_kwargs)
    if parallel:
        # Use the executor as a context manager so its worker processes are
        # shut down once all results have been collected, instead of
        # lingering until the executor object happens to be garbage
        # collected.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=max_workers) as pool:
            return list(pool.map(func_to_map, *arg_iterable,
                                 chunksize=chunksize))
    if parallel_warning:
        warnings.warn(('parallel_map has parallel=False - turn on '
                       'parallelisation for faster processing'), UserWarning)
    return list(map(func_to_map, *arg_iterable))
def parallel_apply(func, arg_iterable, **kwargs):
    """Apply function to iterable with parallelisation and a tqdm progress
    bar.

    Roughly equivalent to

    >>> [func(*func_pre_args, x, *func_args, **func_kwargs) for x in
         arg_iterable]

    but will **not** necessarily return results in input order (when
    parallel=True results are collected in completion order).

    Parameters
    ----------
    func: function
        Function to apply to list of args.
    arg_iterable: iterable
        argument to iterate over. Need not support ``len`` (e.g. a generator
        is accepted).
    func_args: tuple, optional
        Additional positional arguments for func.
    func_pre_args: tuple, optional
        Positional arguments to place before the iterable argument in func.
    func_kwargs: dict, optional
        Additional keyword arguments for func.
    tqdm_kwargs: dict, optional
        Keyword arguments for the progress bar. ``leave`` defaults to False.
        The dict passed in is not modified.
    parallel: bool, optional
        To turn off parallelisation if needed.
    parallel_warning: bool, optional
        To turn off warning for no parallelisation if needed.
    max_workers: int or None, optional
        Number of processes.
        If max_workers is None then concurrent.futures.ProcessPoolExecutor
        defaults to using the number of processors of the machine.
        N.B. If max_workers=None and running on supercomputer clusters with
        multiple nodes, this may default to the number of processors on a
        single node.

    Returns
    -------
    results_list: list of function outputs

    Raises
    ------
    TypeError
        If an unexpected keyword argument is supplied.
    """
    max_workers = kwargs.pop('max_workers', None)
    parallel = kwargs.pop('parallel', True)
    parallel_warning = kwargs.pop('parallel_warning', True)
    func_args = kwargs.pop('func_args', ())
    func_pre_args = kwargs.pop('func_pre_args', ())
    func_kwargs = kwargs.pop('func_kwargs', {})
    # Copy so that filling in defaults below never mutates a dict owned by
    # the caller (which they may reuse across calls).
    tqdm_kwargs = dict(kwargs.pop('tqdm_kwargs', {}))
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    tqdm_kwargs.setdefault('leave', False)  # default to leave=False
    assert isinstance(func_args, tuple), (
        str(func_args) + ' is type ' + str(type(func_args)))
    assert isinstance(func_pre_args, tuple), (
        str(func_pre_args) + ' is type ' + str(type(func_pre_args)))
    progress = select_tqdm()
    if not parallel:
        if parallel_warning:
            warnings.warn(('parallel_apply has parallel=False - turn on '
                           'parallelisation for faster processing'),
                          UserWarning)
        return [func(*(func_pre_args + (x,) + func_args), **func_kwargs)
                for x in progress(arg_iterable, **tqdm_kwargs)]
    # The context manager shuts the worker processes down once we are done.
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_workers) as pool:
        futures = [
            pool.submit(func, *(func_pre_args + (element,) + func_args),
                        **func_kwargs)
            for element in arg_iterable]
        try:
            # The bar's total comes from the submitted futures rather than
            # len(arg_iterable), so arg_iterable may be any iterable.
            return [fut.result() for fut in progress(
                concurrent.futures.as_completed(futures),
                total=len(futures), **tqdm_kwargs)]
        except BaseException:
            # On an error or interrupt, cancel tasks that have not started
            # yet so leaving the ``with`` block (which waits for the pool)
            # only waits for tasks that are already running.
            for fut in futures:
                fut.cancel()
            raise
def select_tqdm():
    """If running in a jupyter notebook, then returns tqdm_notebook.
    Otherwise returns a regular tqdm progress bar.

    Returns
    -------
    progress: function
    """
    try:
        # get_ipython is only available as a builtin when running under
        # IPython; anywhere else looking it up raises NameError.
        shell = get_ipython()
    except NameError:
        return tqdm.tqdm
    # Kernel-backed shells (e.g. Jupyter notebooks) expose a 'kernel' trait;
    # the plain IPython terminal does not. An explicit check is used rather
    # than an assert so the selection still works when Python runs with -O.
    if shell.has_trait('kernel'):
        return tqdm.tqdm_notebook
    return tqdm.tqdm