Source code for foxes.engines.pool

import numpy as np
import xarray as xr
from abc import abstractmethod
from tqdm import tqdm

from foxes.core import Engine
import foxes.constants as FC


def _run(algo, model, data, iterative, chunk_store, i0_t0, **cpars):
    """Helper function for running a chunk calculation in a single process"""
    # install the incoming chunk store, then run the model on this chunk's data:
    algo.reset_chunk_store(chunk_store)
    results = model.calculate(algo, *data, **cpars)
    # iterative algorithms need the updated chunk store back; only the entry
    # belonging to this chunk's (i0_states, i0_targets) key is returned:
    chunk_store = algo.reset_chunk_store() if iterative else {}
    cstore = {i0_t0: chunk_store[i0_t0]} if i0_t0 in chunk_store else {}
    return results, cstore
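
Note that besides the results, _run hands back only the chunk-store entry keyed
by this chunk's (i0_states, i0_targets) pair, which keeps the data returned to
the main process small. A minimal plain-dict illustration of that filtering
(values are hypothetical, not part of foxes):

    chunk_store = {(0, 0): {"data": "a"}, (100, 0): {"data": "b"}}
    i0_t0 = (100, 0)
    cstore = {i0_t0: chunk_store[i0_t0]} if i0_t0 in chunk_store else {}
    assert cstore == {(100, 0): {"data": "b"}}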


def _run_map(func, inputs, *args, **kwargs):
    """Helper function for running the map function on one batch within a process"""
    return [func(x, *args, **kwargs) for x in inputs]
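
Each worker process therefore receives one batch of the input list, produced by
np.array_split in PoolEngine.map below, and returns the per-batch result list.
A quick sketch of the batching (hypothetical values):

    batches = np.array_split(np.arange(5), 2)
    # -> [array([0, 1, 2]), array([3, 4])]
    flat = []
    for batch in batches:
        flat += _run_map(lambda x: int(x) * 10, batch)
    # flat == [0, 10, 20, 30, 40], same order as the inputs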


class PoolEngine(Engine):
    """
    Abstract engine for pool type parallelizations.

    :group: engines

    """

    # a minimal concrete subclass is sketched below, after the class body

    @abstractmethod
    def _create_pool(self):
        """Creates the pool"""
        pass

    @abstractmethod
    def _submit(self, f, *args, **kwargs):
        """
        Submits to the pool

        Parameters
        ----------
        f: Callable
            The function f(*args, **kwargs) to be submitted
        args: tuple, optional
            Arguments for the function
        kwargs: dict, optional
            Keyword arguments for the function

        Returns
        -------
        future: object
            The future object

        """
        pass

    @abstractmethod
    def _result(self, future):
        """
        Waits for the result from a future

        Parameters
        ----------
        future: object
            The future

        Returns
        -------
        result: object
            The calculation result

        """
        pass

    @abstractmethod
    def _shutdown_pool(self):
        """Shuts down the pool"""
        pass

    def __enter__(self):
        self._create_pool()
        return super().__enter__()

    def __exit__(self, *exit_args):
        self._shutdown_pool()
        super().__exit__(*exit_args)
    def map(
        self,
        func,
        inputs,
        *args,
        **kwargs,
    ):
        """
        Runs a function on a list of inputs

        Parameters
        ----------
        func: Callable
            Function to be called on each input,
            func(input, *args, **kwargs) -> data
        inputs: array-like
            The input data list
        args: tuple, optional
            Arguments for func
        kwargs: dict, optional
            Keyword arguments for func

        Returns
        -------
        results: list
            The list of results

        """
        if len(inputs) == 0:
            return []
        elif len(inputs) == 1:
            # a single input is evaluated directly in the calling process:
            return [func(inputs[0], *args, **kwargs)]
        else:
            # split the inputs into at most n_procs batches and submit
            # one _run_map job per batch:
            inptl = np.array_split(inputs, min(self.n_procs, len(inputs)))
            jobs = []
            for subi in inptl:
                jobs.append(self._submit(_run_map, func, subi, *args, **kwargs))

            # concatenate the per-batch result lists, preserving input order:
            results = []
            for j in jobs:
                results += self._result(j)

            return results
    def run_calculation(
        self,
        algo,
        model,
        model_data=None,
        farm_data=None,
        point_data=None,
        out_vars=[],
        chunk_store={},
        sel=None,
        isel=None,
        iterative=False,
        **calc_pars,
    ):
        """
        Runs the model calculation

        Parameters
        ----------
        algo: foxes.core.Algorithm
            The algorithm object
        model: foxes.core.DataCalcModel
            The model whose calculate function should be run
        model_data: xarray.Dataset
            The initial model data
        farm_data: xarray.Dataset
            The initial farm data
        point_data: xarray.Dataset
            The initial point data
        out_vars: list of str, optional
            Names of the output variables
        chunk_store: foxes.utils.Dict
            The chunk store
        sel: dict, optional
            Selection of coordinate subsets
        isel: dict, optional
            Selection of coordinate subsets index values
        iterative: bool
            Flag for use within the iterative algorithm
        calc_pars: dict, optional
            Additional parameters for the model.calculate()

        Returns
        -------
        results: xarray.Dataset
            The model results

        """
        # subset selection:
        model_data, farm_data, point_data = self.select_subsets(
            model_data, farm_data, point_data, sel=sel, isel=isel
        )

        # basic checks:
        super().run_calculation(algo, model, model_data, farm_data, point_data)

        # prepare:
        n_states = model_data.sizes[FC.STATE]
        out_coords = model.output_coords()
        coords = {}
        if FC.STATE in out_coords and FC.STATE in model_data.coords:
            coords[FC.STATE] = model_data[FC.STATE].to_numpy()
        if farm_data is None:
            farm_data = xr.Dataset()
        goal_data = farm_data if point_data is None else point_data

        # DEBUG object mem sizes:
        # from foxes.utils import print_mem
        # for m in [algo] + model.models:
        #     print_mem(m, pre_str="MULTIP CHECKING LARGE DATA", min_csize=9999)

        # calculate chunk sizes:
        n_targets = point_data.sizes[FC.TARGET] if point_data is not None else 0
        chunk_sizes_states, chunk_sizes_targets = self.calc_chunk_sizes(
            n_states, n_targets
        )
        n_chunks_states = len(chunk_sizes_states)
        n_chunks_targets = len(chunk_sizes_targets)
        self.print(
            f"{type(self).__name__}: Selecting n_chunks_states = {n_chunks_states}, n_chunks_targets = {n_chunks_targets}",
            level=2,
        )

        # prepare and submit chunks:
        n_chunks_all = n_chunks_states * n_chunks_targets
        self.print(
            f"{type(self).__name__}: Submitting {n_chunks_all} chunks to {self.n_procs} processes",
            level=2,
        )
        pbar = tqdm(total=n_chunks_all) if self.verbosity > 1 else None
        jobs = {}
        i0_states = 0
        for chunki_states in range(n_chunks_states):
            i1_states = i0_states + chunk_sizes_states[chunki_states]
            i0_targets = 0
            for chunki_points in range(n_chunks_targets):
                i1_targets = i0_targets + chunk_sizes_targets[chunki_points]

                # get this chunk's data:
                data = self.get_chunk_input_data(
                    algo=algo,
                    model_data=model_data,
                    farm_data=farm_data,
                    point_data=point_data,
                    states_i0_i1=(i0_states, i1_states),
                    targets_i0_i1=(i0_targets, i1_targets),
                    out_vars=out_vars,
                )

                # submit model calculation:
                jobs[(chunki_states, chunki_points)] = self._submit(
                    _run,
                    algo,
                    model,
                    data,
                    iterative,
                    chunk_store,
                    (i0_states, i0_targets),
                    **calc_pars,
                )
                del data

                i0_targets = i1_targets

                if pbar is not None:
                    pbar.update()

            i0_states = i1_states

        del calc_pars, farm_data, point_data
        if pbar is not None:
            pbar.close()

        # wait for results:
        if n_chunks_all > 1 or self.verbosity > 1:
            self.print(
                f"{type(self).__name__}: Computing {n_chunks_all} chunks using {self.n_procs} processes"
            )
        pbar = (
            tqdm(total=n_chunks_all)
            if n_chunks_all > 1 and self.verbosity > 0
            else None
        )
        results = {}
        for chunki_states in range(n_chunks_states):
            for chunki_points in range(n_chunks_targets):
                key = (chunki_states, chunki_points)
                results[key] = self._result(jobs.pop(key))
                if pbar is not None:
                    pbar.update()
        if pbar is not None:
            pbar.close()

        # combine the per-chunk results into the final dataset:
        return self.combine_results(
            algo=algo,
            results=results,
            model_data=model_data,
            out_vars=out_vars,
            out_coords=out_coords,
            n_chunks_states=n_chunks_states,
            n_chunks_targets=n_chunks_targets,
            goal_data=goal_data,
            iterative=iterative,
        )
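
To make the abstract interface concrete, a subclass only needs to bind the four
hooks to an actual pool. The following is a minimal sketch based on
concurrent.futures.ProcessPoolExecutor; the class name FuturesEngine and its
use here are illustrative assumptions, not part of foxes:

    from concurrent.futures import ProcessPoolExecutor

    class FuturesEngine(PoolEngine):
        """Hypothetical PoolEngine subclass backed by concurrent.futures"""

        def _create_pool(self):
            # one worker process per configured process count:
            self._pool = ProcessPoolExecutor(max_workers=self.n_procs)

        def _submit(self, f, *args, **kwargs):
            # schedule f(*args, **kwargs) and hand back the future:
            return self._pool.submit(f, *args, **kwargs)

        def _result(self, future):
            # block until the worker has finished this job:
            return future.result()

        def _shutdown_pool(self):
            self._pool.shutdown()

Used as a context manager, such an engine creates its pool on entry and shuts
it down on exit, e.g. (hypothetical call):

    with FuturesEngine(n_procs=4) as engine:
        results = engine.map(my_func, my_inputs)

The submission loop in run_calculation walks a two-dimensional chunk grid,
outer over state chunks and inner over target chunks, accumulating half-open
(i0, i1) index windows from the per-chunk sizes. A self-contained sketch of
that bookkeeping (chunk sizes hardcoded here; foxes derives them via
calc_chunk_sizes):

    chunk_sizes_states = [100, 100, 50]  # e.g. 250 states in three chunks
    chunk_sizes_targets = [8, 8]         # e.g. 16 targets in two chunks

    windows = {}
    i0_states = 0
    for ci_s, n_s in enumerate(chunk_sizes_states):
        i1_states = i0_states + n_s
        i0_targets = 0
        for ci_t, n_t in enumerate(chunk_sizes_targets):
            i1_targets = i0_targets + n_t
            # same keying as the jobs dict in run_calculation:
            windows[(ci_s, ci_t)] = (i0_states, i1_states), (i0_targets, i1_targets)
            i0_targets = i1_targets
        i0_states = i1_states

    assert windows[(1, 1)] == ((100, 200), (8, 16))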