from __future__ import annotations
from typing import Any, List, Literal, Tuple
import xarray as xr
from xcdat.axis import CFAxisKey, get_dim_coords
from xcdat.regridder import regrid2, xesmf, xgcm
from xcdat.regridder.grid import _validate_grid_has_single_axis_dim
# Names accepted by ``RegridderAccessor.horizontal(tool=...)`` and the
# regridder class each one maps to. Insertion order is what the "valid
# choices" error message lists.
HorizontalRegridTools = Literal["xesmf", "regrid2"]
HORIZONTAL_REGRID_TOOLS = dict(
    regrid2=regrid2.Regrid2Regridder,
    xesmf=xesmf.XESMFRegridder,
)

# Names accepted by ``RegridderAccessor.vertical(tool=...)`` and the
# regridder class each one maps to.
VerticalRegridTools = Literal["xgcm"]
VERTICAL_REGRID_TOOLS = dict(xgcm=xgcm.XGCMRegridder)
@xr.register_dataset_accessor(name="regridder")
class RegridderAccessor:
    """
    An accessor class that provides regridding attributes and methods for
    xarray Datasets through the ``.regridder`` attribute.

    Examples
    --------

    Import xCDAT:

    >>> import xcdat

    Use RegridderAccessor class:

    >>> ds = xcdat.open_dataset("...")
    >>>
    >>> ds.regridder.<attribute>
    >>> ds.regridder.<method>
    >>> ds.regridder.<property>

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset to attach this accessor.
    """

    def __init__(self, dataset: xr.Dataset):
        self._ds: xr.Dataset = dataset

    @property
    def grid(self) -> xr.Dataset:
        """
        Extract the `X`, `Y`, and `Z` axes from the Dataset and return a new
        ``xr.Dataset``.

        Axes that cannot be found in the Dataset are skipped. Existing bounds
        are copied, and any bounds still missing for the extracted axes are
        then generated with ``ds.bounds.add_missing_bounds``. The source
        Dataset's global attributes are carried over to the grid.

        Returns
        -------
        xr.Dataset
            Containing grid axes.

        Raises
        ------
        ValueError
            If axis dimension coordinate variable is not correctly identified.
        ValueError
            If axis has multiple dimensions (only one is expected).

        Examples
        --------

        Import xCDAT:

        >>> import xcdat

        Open a dataset:

        >>> ds = xcdat.open_dataset("...")

        Extract grid from dataset:

        >>> grid = ds.regridder.grid
        """
        axis_names: List[CFAxisKey] = ["X", "Y", "Z"]

        with xr.set_options(keep_attrs=True):
            coords = {}

            for axis in axis_names:
                try:
                    data, bnds = self._get_axis_data(axis)
                except KeyError:
                    # The axis is not present in the Dataset; leave it out.
                    continue

                coords[data.name] = data.copy()

                if bnds is not None:
                    coords[bnds.name] = bnds.copy()

        ds = xr.Dataset(coords, attrs=self._ds.attrs)
        ds = ds.bounds.add_missing_bounds(axes=axis_names)

        return ds

    def _get_axis_data(
        self, name: CFAxisKey
    ) -> Tuple[xr.DataArray | xr.Dataset, xr.DataArray | None]:
        """
        Get the dimension coordinate variable and its bounds for one axis.

        Parameters
        ----------
        name : CFAxisKey
            The CF axis key to look up (e.g. ``"X"``).

        Returns
        -------
        Tuple[xr.DataArray | xr.Dataset, xr.DataArray | None]
            The dimension coordinate variable and its bounds variable. The
            bounds element is ``None`` when ``ds.bounds.get_bounds`` raises
            ``KeyError``.

        Raises
        ------
        KeyError
            Presumably raised by ``get_dim_coords`` when the axis cannot be
            found. ``grid`` relies on a ``KeyError`` here to skip the axis.
        ValueError
            If the axis has multiple dimensions (checked by
            ``_validate_grid_has_single_axis_dim``).
        """
        coord_var = get_dim_coords(self._ds, name)

        _validate_grid_has_single_axis_dim(name, coord_var)

        try:
            bounds_var = self._ds.bounds.get_bounds(name, coord_var.name)
        except KeyError:
            # Missing bounds are not an error; ``grid`` fills them in later.
            bounds_var = None

        return coord_var, bounds_var

    def horizontal(
        self,
        data_var: str,
        output_grid: xr.Dataset,
        tool: HorizontalRegridTools = "xesmf",
        **options: Any,
    ) -> xr.Dataset:
        """
        Transform ``data_var`` to ``output_grid``.

        When might ``Regrid2`` be preferred over ``xESMF``?

        If performing conservative regridding from a high/medium resolution lat/lon grid to a
        coarse lat/lon target, ``Regrid2`` may provide better results as it assumes grid cells
        with constant latitudes and longitudes while ``xESMF`` assumes the cells are connected
        by Great Circles [1]_.

        Supported tools, methods and grids:

        - xESMF (https://pangeo-xesmf.readthedocs.io/en/latest/)
           - Methods: Bilinear, Conservative, Conservative Normed, Patch, Nearest s2d, or Nearest d2s.
           - Grids: Rectilinear, or Curvilinear.
           - Find options at :py:func:`xcdat.regridder.xesmf.XESMFRegridder`
        - Regrid2
           - Methods: Conservative
           - Grids: Rectilinear
           - Find options at :py:func:`xcdat.regridder.regrid2.Regrid2Regridder`

        Parameters
        ----------
        data_var: str
            Name of the variable to transform.
        output_grid : xr.Dataset
            Grid to transform ``data_var`` to.
        tool : str
            Name of the tool to use.
        **options : Any
            These options are passed directly to the ``tool``. See specific
            regridder for available options.

        Returns
        -------
        xr.Dataset
            With the ``data_var`` transformed to the ``output_grid``.

        Raises
        ------
        ValueError
            If tool is not supported.

        References
        ----------
        .. [1] https://earthsystemmodeling.org/docs/release/ESMF_8_1_0/ESMF_refdoc/node5.html#SECTION05012900000000000000

        Examples
        --------

        Import xCDAT:

        >>> import xcdat

        Open a dataset:

        >>> ds = xcdat.open_dataset("...")

        Create output grid:

        >>> output_grid = xcdat.create_uniform_grid(-90, 90, 4.0, -180, 180, 5.0)

        Regrid variable using "xesmf":

        >>> output_data = ds.regridder.horizontal("ts", output_grid, tool="xesmf", method="bilinear")

        Regrid variable using "regrid2":

        >>> output_data = ds.regridder.horizontal("ts", output_grid, tool="regrid2")
        """
        try:
            regrid_tool = HORIZONTAL_REGRID_TOOLS[tool]
        except KeyError as e:
            # Chain the lookup failure explicitly so the traceback reads as
            # a cause rather than an error raised while handling another.
            raise ValueError(
                f"Tool {e!s} does not exist, valid choices "
                f"{list(HORIZONTAL_REGRID_TOOLS)}"
            ) from e

        input_grid = _get_input_grid(self._ds, data_var, ["X", "Y"])
        regridder = regrid_tool(input_grid, output_grid, **options)
        output_ds = regridder.horizontal(data_var, self._ds)

        return output_ds

    def vertical(
        self,
        data_var: str,
        output_grid: xr.Dataset,
        tool: VerticalRegridTools = "xgcm",
        **options: Any,
    ) -> xr.Dataset:
        """
        Transform ``data_var`` to ``output_grid``.

        Supported tools:

        - xgcm (https://xgcm.readthedocs.io/en/latest/index.html)
           - Methods: Linear, Conservative, Log
           - Find options at :py:func:`xcdat.regridder.xgcm.XGCMRegridder`

        Parameters
        ----------
        data_var: str
            Name of the variable to transform.
        output_grid : xr.Dataset
            Grid to transform ``data_var`` to.
        tool : str
            Name of the tool to use.
        **options : Any
            These options are passed directly to the ``tool``. See specific
            regridder for available options.

        Returns
        -------
        xr.Dataset
            With the ``data_var`` transformed to the ``output_grid``.

        Raises
        ------
        ValueError
            If tool is not supported.

        Examples
        --------

        Import xCDAT:

        >>> import numpy as np
        >>> import xcdat

        Open a dataset:

        >>> ds = xcdat.open_dataset("...")

        Create output grid:

        >>> output_grid = xcdat.create_grid(lev=np.linspace(1000, 1, 20))

        Regrid variable using "xgcm":

        >>> output_data = ds.regridder.vertical("so", output_grid, method="linear")
        """
        try:
            regrid_tool = VERTICAL_REGRID_TOOLS[tool]
        except KeyError as e:
            # Chain the lookup failure explicitly so the traceback reads as
            # a cause rather than an error raised while handling another.
            raise ValueError(
                f"Tool {e!s} does not exist, valid choices "
                f"{list(VERTICAL_REGRID_TOOLS)}"
            ) from e

        input_grid = _get_input_grid(self._ds, data_var, ["Z"])
        regridder = regrid_tool(input_grid, output_grid, **options)
        output_ds = regridder.vertical(data_var, self._ds)

        return output_ds
def _get_input_grid(
    ds: xr.Dataset, data_var: str, dup_check_dims: List[CFAxisKey]
) -> xr.Dataset:
    """
    Extract the grid from ``ds``.

    This function removes any duplicate dimensions, leaving only the
    dimensions used by ``data_var``. All extraneous dimensions and variables
    are dropped, returning only the grid.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to extract grid from.
    data_var : str
        Name of target data variable.
    dup_check_dims : List[CFAxisKey]
        List of axes to check for duplicate dimensions.

    Returns
    -------
    xr.Dataset
        Dataset containing the grid. If ``ds`` has a ``mask`` variable, a
        copy of it is included.
    """
    all_coords = set(ds.coords.keys())
    to_drop = []

    for axis in dup_check_dims:
        axis_coords = get_dim_coords(ds, axis)

        # ``get_dim_coords`` hands back a Dataset (rather than a single
        # DataArray) only when several dimension coordinates match the axis;
        # otherwise there is nothing to de-duplicate.
        if not isinstance(axis_coords, xr.Dataset):
            continue

        # Keep only the dimension coordinate that ``data_var`` actually uses.
        keep = {get_dim_coords(ds[data_var], axis).name}
        axis_names = set(ds.cf[[axis]].coords.keys())

        # Intersect with the Dataset's coordinates afterwards because
        # ``ds.cf[["Z"]]`` can also hand back data variables.
        to_drop.extend((axis_names - keep) & all_coords)

    input_grid = ds.drop_dims(to_drop)

    # ``.regridder.grid`` keeps only the axis coordinates and their bounds,
    # which drops any remaining extra dimensions and variables.
    grid = input_grid.regridder.grid

    # ``grid`` does not carry data variables, so copy the mask over explicitly.
    if "mask" in ds:
        grid["mask"] = ds["mask"].copy()

    return grid