Source code for foxes.utils.xarray_utils

import numpy as np
from xarray import Dataset, SerializationWarning
from pathlib import Path
import warnings

import foxes.variables as FV


def compute_scale_and_offset(min, max, n, hasnan=True):
    """
    Computes scale_factor and add_offset for packing data into n-bit integers.

    Parameters
    ----------
    min: float
        Minimum value of the data
    max: float
        Maximum value of the data
    n: int
        Number of bits for packing
    hasnan: bool
        NaN present in the data

    Returns
    -------
    scale_factor: float
        The scale factor
    add_offset: float
        The add offset
    fill_value: float
        The fill value for NaN, None if hasnan is False

    Notes
    -----
    Source: https://docs.unidata.ucar.edu/nug/current/best_practices.html

    :group: utils

    """
    # degenerate range: widen it so the scale factor is nonzero
    if min == max:
        max = min + 1

    span = max - min
    if hasnan:
        # reserve the most negative n-bit integer as the NaN fill value,
        # leaving 2**n - 2 usable steps centered on the data midpoint
        scale_factor = span / (2**n - 2)
        add_offset = 0.5 * (max + min)
        fill_value = -(2 ** (n - 1))
        return scale_factor, add_offset, fill_value

    # no fill value needed: all 2**n integer values encode data
    scale_factor = span / (2**n - 1)
    add_offset = min + 2 ** (n - 1) * scale_factor
    return scale_factor, add_offset, None
def pack_value(unpacked_value, scale_factor, add_offset, dtype, fill_value):
    """
    Pack a floating point value into an integer representation.

    Parameters
    ----------
    unpacked_value: float or np.ndarray
        The floating point value(s) to be packed
    scale_factor: float
        The scale factor
    add_offset: float
        The add offset
    dtype: numpy.dtype
        The dtype of packed values
    fill_value: float
        The fill value for NaN

    Returns
    -------
    packed_value: int or np.ndarray
        The packed integer value(s)

    :group: utils

    """
    # scale into integer steps; NaN entries stay NaN through np.floor
    packed = np.floor((unpacked_value - add_offset) / scale_factor)
    if fill_value is None:
        return packed.astype(dtype)
    # map NaN entries to the reserved fill value before the integer cast
    return np.where(np.isnan(unpacked_value), fill_value, packed).astype(dtype)
def unpack_value(packed_value, scale_factor, add_offset, fill_value):
    """
    Unpack an integer representation back into a floating point value.

    Parameters
    ----------
    packed_value: int or np.ndarray
        The packed integer value(s) to be unpacked
    scale_factor: float
        The scale factor, its dtype determines the output dtype
    add_offset: float
        The add offset
    fill_value: float
        The fill value for NaN

    Returns
    -------
    unpacked_value: float or np.ndarray
        The unpacked floating point value(s)

    :group: utils

    """
    # np.asarray handles plain Python floats, which have no .dtype attribute
    # (the original `scale_factor.dtype` failed for non-NumPy scalars)
    out_dtype = np.asarray(scale_factor).dtype
    unpacked = packed_value * scale_factor + add_offset
    if fill_value is None:
        return np.asarray(unpacked).astype(out_dtype)
    # restore NaN where the reserved fill value was stored
    return np.where(packed_value == fill_value, np.nan, unpacked).astype(out_dtype)
[docs] def get_encoding(data, complevel=5): """ Get the encoding parameters for a numpy array. Parameters ---------- data: np.ndarray The numpy array for which to get the encoding information. complevel: int The compression level (1-9) Returns ------- encoding: dict The encoding information of the numpy array. :group: utils """ enc = {"zlib": True, "complevel": complevel} if np.issubdtype(data.dtype, np.integer): for t in [np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32]: if np.all(data == data.astype(t)): enc["dtype"] = t.__name__ elif np.issubdtype(data.dtype, np.floating): min = np.min(data) max = np.max(data) hasnan = np.any(np.isnan(data)) for t, n in zip([np.int8, np.int16], [8, 16]): scale_factor, add_offset, fill_value = compute_scale_and_offset( min, max, n, hasnan ) packed = pack_value(data, scale_factor, add_offset, t, fill_value) unpacked = unpack_value(packed, scale_factor, add_offset, fill_value) try: np.testing.assert_allclose(data, unpacked, atol=scale_factor) enc["dtype"] = t.__name__ enc["scale_factor"] = scale_factor enc["add_offset"] = add_offset enc["_FillValue"] = fill_value break except AssertionError: continue return enc
def write_nc(
    ds,
    fpath,
    round={},
    complevel=5,
    nc_engine="netcdf4",
    verbosity=1,
    **kwargs,
):
    """
    Writes a dataset to netCDF file

    Parameters
    ----------
    ds: xarray.Dataset
        The dataset to write
    fpath: str
        Path to the output file, should be nc
    round: dict or int, optional
        The rounding digits, falling back to defaults if variable
        not found. If int, applies to all variables. If None, no
        rounding is applied.
    complevel: int
        The compression level
    nc_engine: str
        The NetCDF engine to use
    verbosity: int
        The verbosity level, 0 = silent
    kwargs: dict, optional
        Additional parameters for xarray.to_netcdf

    :group: utils

    """
    fpath = Path(fpath)

    def _digits(v):
        """Helper: rounding digits for variable v, None disables rounding"""
        if round is None:
            # previously round=None crashed with NameError below;
            # now it simply disables rounding
            return None
        if isinstance(round, int):
            return round
        # note: the default dict is only read, never mutated
        return round.get(v, FV.get_default_digits(v))

    def _round(x, v, d):
        """Helper function to round float values to d decimals"""
        if d is not None:
            if np.issubdtype(x.dtype, np.integer):
                return x
            elif np.issubdtype(x.dtype, np.floating):
                if verbosity > 1:
                    print(f"File {fpath.name}: Rounding {v} to {d} decimals")
                return np.round(x, d)
        return x

    # round coordinates and data, collecting per-variable encodings
    enc = {}
    crds = {}
    for v, x in ds.coords.items():
        crds[v] = _round(x.to_numpy(), v, _digits(v))
        enc[v] = get_encoding(crds[v], complevel=complevel)

    dvrs = {}
    for v, x in ds.data_vars.items():
        if v != FV.WEIGHT:
            dvrs[v] = (x.dims, _round(x.to_numpy(), v, _digits(v)))
        else:
            # weights are written without rounding
            dvrs[v] = (x.dims, x.to_numpy())
        enc[v] = get_encoding(dvrs[v][1], complevel=complevel)

    ds = Dataset(coords=crds, data_vars=dvrs)

    if verbosity > 0:
        print("Writing file", fpath)

    kw = dict(encoding=enc, engine=nc_engine)
    kw.update(kwargs)

    # silencing a warning about _FillValue = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=SerializationWarning)
        ds.to_netcdf(fpath, **kw)