Source code for xcdat.regridder.xgcm

from typing import Any, Dict, Hashable, Literal, Optional, Union, get_args

import xarray as xr
from xgcm import Grid

from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
from xcdat.regridder.base import BaseRegridder, _preserve_bounds

XGCMVerticalMethods = Literal["linear", "conservative", "log"]

logger = _setup_custom_logger(__name__)

[docs] class XGCMRegridder(BaseRegridder):
[docs] def __init__( self, input_grid: xr.Dataset, output_grid: xr.Dataset, method: XGCMVerticalMethods = "linear", target_data: Optional[Union[str, xr.DataArray]] = None, grid_positions: Optional[Dict[str, str]] = None, periodic: bool = False, extra_init_options: Optional[Dict[str, Any]] = None, **options: Any, ): """ Extension of ``xgcm`` regridder. The ``XGCMRegridder`` extends ``xgcm`` by automatically constructing the ``Grid`` object, transposing the output data to match the dimensional order of the input data, and ensuring bounds and metadata are preserved in the output dataset. Linear and log methods require a single dimension position, which can usually be automatically derived. A custom position can be specified using the `grid_positions` argument. Conservative regridding requires multiple dimension positions, e.g., {"center": "xc", "left": "xg"} which can be passed using the `grid_positions` argument. ``xgcm.Grid`` can be passed additional arguments using ``extra_init_options``. These arguments can be found on `XGCM's Grid documentation <>`_. ``xgcm.Grid.transform`` can be passed additional arguments using ``options``. These arguments can be found on `XGCM's Grid.transform documentation <>`_. Parameters ---------- input_grid : xr.Dataset Contains source grid coordinates. output_grid : xr.Dataset Contains destination grid coordinates. method : XGCMVerticalMethods Regridding method, by default "linear". Options are - linear (default) - log - conservative target_data : Optional[Union[str, xr.DataArray]] Data to transform target data onto, either the key of a variable in the input dataset or an ``xr.DataArray``, by default None. grid_positions : Optional[Dict[str, str]] Mapping of dimension positions, by default None. If ``None`` then an attempt is made to derive this argument. periodic : Optional[bool] Whether the grid is periodic, by default False. extra_init_options : Optional[Dict[str, Any]] Extra options passed to the ``xgcm.Grid`` constructor, by default None. options : Optional[Dict[str, Any]] Extra options passed to the ``xgcm.Grid.transform`` method. Raises ------ KeyError If data variable does not exist in the Dataset. ValueError If ``method`` is not valid. Examples -------- Import xCDAT: >>> import xcdat Open a dataset: >>> ds = xcdat.open_dataset("...") Create output grid: >>> output_grid = xcdat.create_grid(lev=np.linspace(1000, 1, 5)) Regrid data to ``output_grid``: >>> output_data = ds.regridder.vertical( >>> "so", output_grid, tool="xgcm", method="linear" >>> ) Create pressure variable: >>> ds["pressure"] = (ds["hyam"] * ds["P0"] + ds["hybm"] * ds["PS"]).transpose( >>> **ds["T"].dims >>> ) Regrid data to ``output_grid`` in pressure space: >>> output_data = ds.regridder.vertical( >>> "so", output_grid, tool="xgcm", method="linear", target_data="pressure" >>> ) Passing additional arguments to ``xgcm.Grid`` and ``xgcm.Grid.transform``: >>> regridder = xgcm.XGCMRegridder( >>> ds, >>> output_grid, >>> method="linear", >>> extra_init_options={"boundary": "fill", "fill_value": 1e27}, >>> mask_edges=True >>> ) """ super().__init__(input_grid, output_grid) if method not in get_args(XGCMVerticalMethods): raise ValueError( f"{method!r} is invalid, possible choices: " f"{', '.join(get_args(XGCMVerticalMethods))}" ) self._method = method self._target_data = target_data self._grid_positions = grid_positions if extra_init_options is None: extra_init_options = {} extra_init_options["periodic"] = periodic self._extra_init_options = extra_init_options self._extra_options = options
[docs] def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """Placeholder for base class.""" raise NotImplementedError()
[docs] def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """See documentation in :py:func:`xcdat.regridder.xgcm.XGCMRegridder`""" try: output_coord_z = get_dim_coords(self._output_grid, "Z") except KeyError: raise RuntimeError("Could not determine 'Z' coordinate in output dataset") if self._grid_positions is None: grid_coords = self._get_grid_positions() else: # correctly format argument grid_coords = {"Z": self._grid_positions} grid = Grid(ds, coords=grid_coords, **self._extra_init_options) target_data: Union[str, xr.DataArray, None] = None try: target_data = ds[self._target_data] except ValueError: target_data = self._target_data except KeyError: if self._target_data is not None and isinstance(self._target_data, str): raise RuntimeError( f"Could not find target variable {self._target_data!r} in dataset" ) target_data = None # assumes new verical coordinate has been calculated and stored as `pressure` # TODO: auto calculate pressure reference # cf_xarray only supports two ocean s-coordinate and ocean_sigma_coordinate output_da = grid.transform( ds[data_var], "Z", output_coord_z, target_data=target_data, method=self._method, **self._extra_options, ) # when the order of dimensions are mismatched, the output data will be # transposed to match the input dimension order if output_da.dims != ds[data_var].dims: input_coord_z = get_dim_coords(ds[data_var], "Z") output_order = [ x.replace(, # type: ignore[attr-defined] for x in ds[data_var].dims ] output_da = output_da.transpose(*output_order) if target_data is None: output_da.attrs = ds[data_var].attrs.copy() else: output_da.attrs = target_data.attrs.copy() # type: ignore[union-attr] output_ds = xr.Dataset({data_var: output_da}, attrs=ds.attrs.copy()) output_ds = _preserve_bounds(ds, self._output_grid, output_ds, ["Z"]) return output_ds
[docs] def _get_grid_positions(self) -> Dict[str, Union[Any, Hashable]]: if self._method == "conservative": raise RuntimeError( "Conservative regridding requires a second point position, pass these " "manually." ) try: coord_z = get_dim_coords(self._input_grid, "Z") except KeyError: raise RuntimeError("Could not determine `Z` coordinate in dataset.") if isinstance(coord_z, xr.Dataset): coords = ", ".join(sorted(list(coord_z.coords.keys()))) # type: ignore[arg-type] raise RuntimeError( "Could not determine the `Z` coordinate in the input grid. " f"Found multiple axes ({coords}), ensure there is only a " "single `Z` axis in the input grid.", list(coord_z.coords.keys()), ) try: bounds_z = self._input_grid.bounds.get_bounds("Z") except KeyError: raise RuntimeError("Could not determine `Z` bounds in dataset.") # handle simple point positions based on point and bounds if (coord_z[0] > bounds_z[0][0] and coord_z[0] < bounds_z[0][1]) or ( coord_z[0] < bounds_z[0][0] and coord_z[0] > bounds_z[0][1] ): grid_positions = {"center":} elif coord_z[0] == bounds_z[0][0]: grid_positions = {"left":} elif coord_z[0] == bounds_z[0][1]: grid_positions = {"right":} else: raise RuntimeError( "Could not determine the grid point positions, pass these manually " "using the `grid_positions` argument" ) return {"Z": grid_positions}