Source code for xcdat.regridder.xesmf

import xarray as xr

from xcdat.regridder.base import BaseRegridder, preserve_bounds
from xcdat.utils import _has_module

# TODO: Test this conditional.
_has_xesmf = _has_module("xesmf")
if _has_xesmf:  # pragma: no cover
    import xesmf as xe
else:  # pragma: no cover
    raise ModuleNotFoundError(
        "The `xesmf` package is required for horizontal regridding with `xesmf`. Make "
        "sure your platform supports `xesmf` and it is installed in your conda "
        "environment."
    )


VALID_METHODS = [
    "bilinear",
    "conservative",
    "conservative_normed",
    "patch",
    "nearest_s2d",
    "nearest_d2s",
]

VALID_EXTRAP_METHODS = ["inverse_dist", "nearest_s2d"]


[docs]class XESMFRegridder(BaseRegridder):
[docs] def __init__( self, input_grid: xr.Dataset, output_grid: xr.Dataset, method: str, periodic: bool = False, extrap_method: str = None, extrap_dist_exponent: float = None, extrap_num_src_pnts: int = None, ignore_degenerate: bool = True, **options, ): """Wrapper class for xESMF regridder class. Parameters ---------- input_grid : xr.Dataset Contains source grid coordinates. output_grid : xr.Dataset Contains desintation grid coordinates. method : str Regridding method. Options are - bilinear - conservative - conservative_normed - patch - nearest_s2d - nearest_d2s periodic : bool Treat longitude as periodic. Used for global grids. extrap_method : str Extrapolation method. Options are - inverse_dist - nearest_s2d extrap_dist_exponent : float The exponent to raise the distance to when calculating weights for the extrapolation method. extrap_num_src_pnts : int The number of source points to use for the extrapolation methods that use more than one source point. ignore_degenerate : bool Ignore degenerate cells when checking the `input_grid` for errors. If set False, a degenerate cell produces an error. This only applies to "conservative" and "conservative_normed" regridding methods. Raises ------ KeyError If data variable does not exist in the Dataset. ValueError If ``method`` is not valid. ValueError If ``extrap_method`` is not valid. Examples -------- Import xCDAT: >>> import xcdat >>> from xcdat.regridder import xesmf Open a dataset: >>> ds = xcdat.open_dataset("ts.nc") Create output grid: >>> output_grid = xcdat.create_gaussian_grid(32) Create regridder: >>> regridder = xesmf.XESMFRegridder(ds, output_grid, method="bilinear") Regrid data: >>> data_new_grid = regridder.horizontal("ts", ds, periodic=True) """ super().__init__(input_grid, output_grid) if method not in VALID_METHODS: raise ValueError( f"{method!r} is not valid, possible options: {', '.join(VALID_METHODS)}" ) if extrap_method is not None and extrap_method not in VALID_EXTRAP_METHODS: raise ValueError( f"{extrap_method!r} is not valid, possible options: {', '.join(VALID_EXTRAP_METHODS)}" ) self._method = method self._periodic = periodic self._extrap_method = extrap_method self._extrap_dist_exponent = extrap_dist_exponent self._extrap_num_src_pnts = extrap_num_src_pnts self._regridder: xe.XESMFRegridder = None self._extra_options = options
[docs] def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """Regrid ``data_var`` in ``ds`` to output grid. Parameters ---------- data_var : str The name of the data variable inside the dataset to regrid. ds : xr.Dataset The dataset containing ``data_var``. Returns ------- xr.Dataset Dataset with variable on the destination grid. Raises ------ KeyError If data variable does not exist in the Dataset. Examples -------- Create output grid: >>> output_grid = xcdat.create_gaussian_grid(32) Create regridder: >>> regridder = xesmf.XESMFRegridder(ds, output_grid, method="bilinear") Regrid data: >>> data_new_grid = regridder.horizontal("ts", ds) """ input_da = ds.get(data_var, None) if input_da is None: raise KeyError( f"The data variable '{data_var}' does not exist in the dataset." ) if self._regridder is None: self._regridder = xe.Regridder( self._input_grid, self._output_grid, method=self._method, periodic=self._periodic, extrap_method=self._extrap_method, extrap_dist_exponent=self._extrap_dist_exponent, extrap_num_src_pnts=self._extrap_num_src_pnts, **self._extra_options, ) output_da = self._regridder(input_da, keep_attrs=True) output_ds = xr.Dataset({data_var: output_da}, attrs=ds.attrs) # preserve non-spatial bounds output_ds = preserve_bounds(ds, self._output_grid, output_ds) output_ds = output_ds.bounds.add_missing_bounds() return output_ds