# Source code for xcdat.axis

"""
Axis module for utilities related to axes, including functions to manipulate
coordinates.
"""
from typing import Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import xarray as xr
from dask.array.core import Array

# Type alias for the CF axis names accepted by the functions in this module.
# https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axis-names
CFAxisName = Literal["X", "Y", "T", "Z"]
# Type alias for the CF standard names used as lookup keys in CF_NAME_MAP.
# https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#coordinate-names
CFStandardName = Literal["latitude", "longitude", "time", "height", "pressure"]
# Type alias for the short-hand latitude/longitude names used as lookup keys.
ShortName = Literal["lat", "lon"]

# Maps each axis name accepted as a method/function argument (the key) to the
# list of lookup keys (the value) that are tried, in list order, to locate the
# coordinate variable for that axis: first the CF axis name itself, then the
# CF standard name(s), then any short-hand name.
CF_NAME_MAP: Dict[CFAxisName, List[Union[CFAxisName, CFStandardName, ShortName]]] = {
    "X": ["X", "longitude", "lon"],
    "Y": ["Y", "latitude", "lat"],
    "T": ["T", "time"],
    "Z": ["Z", "height", "pressure"],
}


def get_axis_coord(
    obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisName
) -> xr.DataArray:
    """Look up the coordinate variable that represents an axis.

    The lookup goes through the ``cf`` accessor (``cf_xarray``) of ``obj``.
    Every key listed for ``axis`` in ``CF_NAME_MAP`` is tried in order, and
    the first one that resolves is returned. In practice this means the
    following are checked, in this order:

    - The ``"axis"`` attribute
    - The ``"standard_name"`` attribute
    - The dimension name

      - Must follow the valid short-hand convention
      - For example, ``"lat"`` for latitude and ``"lon"`` for longitude

    Parameters
    ----------
    obj : Union[xr.Dataset, xr.DataArray]
        The Dataset or DataArray object.
    axis : CFAxisName
        The CF-compliant axis name ("X", "Y", "T", "Z").

    Returns
    -------
    xr.DataArray
        The coordinate variable.

    Raises
    ------
    KeyError
        If the coordinate variable was not found.

    Notes
    -----
    Refer to [1]_ for a list of CF-compliant ``"axis"`` and
    ``"standard_name"`` attr names that can be interpreted by ``cf_xarray``.

    References
    ----------
    .. [1] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates
    """
    for candidate in CF_NAME_MAP[axis]:
        try:
            # The first key that resolves wins; later keys are fallbacks.
            return obj.cf[candidate]
        except KeyError:
            continue

    raise KeyError(
        f"A coordinate variable for the {axis} axis was not found. Make sure "
        "the coordinate variable exists and either the (1) 'axis' attr or (2) "
        "'standard_name' attr is set, or (3) the dimension name follows the "
        "short-hand convention (e.g.,'lat')."
    )
def get_axis_dim(obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisName) -> str:
    """Return the name of the dimension that corresponds to an axis.

    The coordinate variable for the axis is expected to share its name with
    the dimension, so the name of the coordinate variable found by
    ``get_axis_coord`` is returned as a string.

    Parameters
    ----------
    obj : Union[xr.Dataset, xr.DataArray]
        The Dataset or DataArray object.
    axis : CFAxisName
        The CF-compliant axis name ("X", "Y", "T", "Z")

    Returns
    -------
    str
        The dimension for an axis.
    """
    coord_var = get_axis_coord(obj, axis)
    dim_name = str(coord_var.name)

    return dim_name
def center_times(dataset: xr.Dataset) -> xr.Dataset:
    """Centers time coordinates using the midpoint between time bounds.

    Time coordinates can be recorded using different intervals, including
    the beginning, middle, or end of the interval. Centering time coordinates
    ensures calculations using these values are performed reliably regardless
    of the recorded interval.

    The time coordinate is located with ``get_axis_coord(ds, "T")`` and the
    centered coordinate is written back under that coordinate's own name and
    dimension(s), so the coordinate does not have to be called ``"time"``.

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset with original time coordinates.

    Returns
    -------
    xr.Dataset
        The Dataset with centered time coordinates.

    Raises
    ------
    KeyError
        If the time coordinate variable was not found (raised by
        ``get_axis_coord``).
    """
    ds = dataset.copy()

    time: xr.DataArray = get_axis_coord(ds, "T")
    time_bounds = ds.bounds.get_bounds("T")

    # The first column of the bounds holds the lower bound of each interval
    # and the second column holds the upper bound.
    lower_bounds, upper_bounds = (time_bounds[:, 0].data, time_bounds[:, 1].data)
    bounds_diffs = (upper_bounds - lower_bounds) / 2
    bounds_mids = lower_bounds + bounds_diffs

    # Rebuild the coordinate under its original name and dimension(s) rather
    # than a hard-coded "time", so differently named time axes are replaced
    # in place instead of gaining a spurious "time" dimension.
    time_centered = xr.DataArray(
        name=time.name,
        data=bounds_mids,
        dims=time.dims,
        coords={time.name: (time.dims, bounds_mids)},
        attrs=time.attrs,
    )
    time_centered.encoding = time.encoding
    ds = ds.assign_coords({time.name: time_centered})

    # Update time bounds with centered time coordinates.
    time_bounds[time_centered.name] = time_centered
    ds[time_bounds.name] = time_bounds

    return ds
def swap_lon_axis(
    dataset: xr.Dataset, to: Tuple[float, float], sort_ascending: bool = True
) -> xr.Dataset:
    """Swaps the orientation of a dataset's longitude axis.

    The longitude bounds are swapped together with the longitude coordinates.
    When swapping to (0, 360), a grid cell that ends up spanning the prime
    meridian is split so that all bounds stay within 0 to 360. Afterwards the
    Dataset is optionally sorted by longitude in ascending order.

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset containing a longitude axis.
    to : Tuple[float, float]
        The orientation to swap the Dataset's longitude axis to.

        Supported orientations:

        * (-180, 180): represents [-180, 180) in math notation
        * (0, 360): represents [0, 360) in math notation

    sort_ascending : bool
        After swapping, sort in ascending order (True), or keep existing
        order (False).

    Returns
    -------
    xr.Dataset
        The Dataset with swapped lon axes orientation. If the swapped
        longitude coordinates are identical to the existing ones, the
        original ``dataset`` object is returned.

    Raises
    ------
    ValueError
        If ``to`` is not one of the supported orientations.
    """
    ds = dataset.copy()
    lon: xr.DataArray = get_axis_coord(ds, "X").copy()
    lon_bounds: xr.DataArray = dataset.bounds.get_bounds("X").copy()

    to_180 = to == (-180, 180)
    to_360 = to == (0, 360)

    if not (to_180 or to_360):
        raise ValueError(
            "Currently, only (-180, 180) and (0, 360) are supported longitude axis "
            "orientations."
        )

    # Keep the attributes of the coordinates through the arithmetic below.
    with xr.set_options(keep_attrs=True):
        if to_180:
            new_lon = ((lon + 180) % 360) - 180
            new_lon_bounds = ((lon_bounds + 180) % 360) - 180
        else:
            new_lon = lon % 360
            new_lon_bounds = lon_bounds % 360

        ds = _reassign_lon(ds, new_lon, new_lon_bounds)

        if to_360:
            # Handle cases where a prime meridian cell exists, which can
            # occur after swapping to (0, 360).
            meridian_idx = _get_prime_meridian_index(new_lon_bounds)

            if meridian_idx is not None:
                ds = _align_lon_to_360(ds, meridian_idx)

    # Nothing changed: the axis already had the requested orientation, so
    # hand back the original Dataset.
    if new_lon.identical(lon):
        return dataset

    if not sort_ascending:
        return ds

    return ds.sortby(new_lon.name, ascending=True)
def _reassign_lon(dataset: xr.Dataset, lon: xr.DataArray, lon_bounds: xr.DataArray):
    """
    Write swapped longitude coordinates and bounds back to the Dataset.

    The swapped coordinates are first attached to ``lon`` itself and to
    ``lon_bounds`` (both are modified in place), then both variables are
    assigned to ``dataset`` (also modified in place).

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset.
    lon : xr.DataArray
        The swapped longitude coordinates.
    lon_bounds : xr.DataArray
        The swapped longitude bounds.

    Returns
    -------
    xr.Dataset
        The Dataset with swapped longitude coordinates and bounds.
    """
    lon[lon.name] = lon
    lon_bounds[lon.name] = lon

    dataset[lon.name] = lon
    dataset[lon_bounds.name] = lon_bounds

    return dataset


def _align_lon_to_360(dataset: xr.Dataset, p_meridian_index: np.ndarray) -> xr.Dataset:
    """Handles a prime meridian cell to align longitude axis to (0, 360).

    A grid cell that encompasses the prime meridian (e.g., [359, 1]) would
    leave the domain bounds outside of 0 to 360. This function takes three
    steps to handle that cell:

    1. The cell is split in two within the longitude bounds, one part east
       and one part west of the prime meridian (refer to
       ``_align_lon_bounds_to_360()`` for more information).
    2. The 360 coordinate point is concatenated to the longitude coordinates
       to account for the extra grid cell added by the previous step.
    3. For each data variable that has the longitude dimension, the value of
       the data variable at the prime meridian cell is concatenated to the
       data variable.

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset.
    p_meridian_index : np.ndarray
        An array with a single element representing the index of the prime
        meridian cell.

    Returns
    -------
    xr.Dataset
        The Dataset.
    """
    ds = dataset.copy()
    lon: xr.DataArray = get_axis_coord(ds, "X")
    lon_bounds: xr.DataArray = dataset.bounds.get_bounds("X")

    # If chunking, must convert the xarray data structure from lazy
    # Dask arrays into eager, in-memory NumPy arrays before performing
    # manipulations on the data. Otherwise, it raises `NotImplementedError
    # xarray can't set arrays with multiple array indices to dask yet`.
    if isinstance(lon_bounds.data, Array):
        lon_bounds.load()

    # Align the longitude bounds using the prime meridian index.
    lon_bounds = _align_lon_bounds_to_360(lon_bounds, p_meridian_index)

    # Append the 360 point to the longitude coordinates to match the extra
    # bounds cell, then attach the extended coordinates to the bounds.
    extra_cell = xr.DataArray([360.0], coords={lon.name: [360.0]}, dims=[lon.name])
    lon = xr.concat((lon, extra_cell), dim=lon.name)
    lon_bounds[lon.name] = lon

    # Extend every data variable on the longitude axis (other than the
    # bounds) with its value at the prime meridian cell.
    lon_vars = {}
    for key, value in ds.cf.data_vars.items():
        if key == lon_bounds.name or lon.name not in value.dims:
            continue

        meridian_value = value.isel({lon.name: p_meridian_index})
        extended = xr.concat((value, meridian_value), dim=lon.name)
        extended[lon.name] = lon
        lon_vars[key] = extended

    # Create a Dataset with longitude data vars and merge it to the Dataset
    # without longitude data vars.
    ds_lon = xr.Dataset(data_vars={**lon_vars, lon_bounds.name: lon_bounds})
    no_lon_names = [v for v in ds.data_vars if lon.name not in ds[v].dims]
    ds_no_lon = ds.get(no_lon_names)  # type: ignore

    return xr.merge((ds_no_lon, ds_lon))


def _align_lon_bounds_to_360(
    bounds: xr.DataArray, p_meridian_index: np.ndarray
) -> xr.DataArray:
    """Handles a prime meridian cell to align longitude bounds axis to (0, 360).

    A grid cell that encompasses the prime meridian (e.g., [359, 1])
    complicates the calculation of longitudinal weights, because the weights
    are determined by the difference of the bounds. This function splits
    that grid cell into two parts (one east and one west of the prime
    meridian) so the domain bounds are within 0 to 360.

    The original grid cell keeps the part extending east of the prime
    meridian, and an extra set of bounds corresponding to the part west of
    the prime meridian is concatenated to the end of ``bounds``. For
    instance, a grid cell spanning -1 to 1 will be broken into a cell from
    0 to 1 and a cell from 359 to 360 (or -1 to 0).

    Note that the lower bound of the prime meridian cell of the input
    ``bounds`` is modified in place before the concatenation creates a new
    array.

    Parameters
    ----------
    bounds : xr.DataArray
        The longitude domain bounds with prime meridian cell.
    p_meridian_index : np.ndarray
        The index of the prime meridian cell.

    Returns
    -------
    xr.DataArray
        The longitude domain bounds with split prime meridian cell.
    """
    # Example array: [[359, 1], [1, 90], [90, 180], [180, 359]]
    # Reorient bound to span across zero (i.e., [359, 1] -> [-1, 1]).
    # Result: [[-1, 1], [1, 90], [90, 180], [180, 359]]
    bounds[p_meridian_index, 0] = bounds[p_meridian_index, 0] - 360.0

    # Extend the array to nlon+1 by concatenating the grid cell that
    # spans the prime meridian to the end.
    # Result: [[-1, 1], [1, 90], [90, 180], [180, 359], [-1, 1]]
    lon_dim = get_axis_dim(bounds, "X")
    meridian_cell = bounds[p_meridian_index, :]
    bounds = xr.concat((bounds, meridian_cell), dim=lon_dim)

    # Add an equivalent bound that spans 360
    # (i.e., [-1, 1] -> [359, 361]) to the end of the array.
    # Result: [[-1, 1], [1, 90], [90, 180], [180, 359], [359, 361]]
    bounds[-1, :] = bounds[p_meridian_index, :][0] + 360.0

    # Update the lower-most min and upper-most max bounds to [0, 360].
    # Result: [[0, 1], [1, 90], [90, 180], [180, 359], [359, 360]]
    bounds[p_meridian_index, 0] = 0.0
    bounds[-1, 1] = 360.0

    return bounds


def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> Optional[np.ndarray]:
    """Gets the index of the prime meridian cell in the longitude bounds.

    A prime meridian cell can exist when converting the axis orientation
    from [-180, 180) to [0, 360). It is identified as the cell whose upper
    bound is smaller than its lower bound.

    Parameters
    ----------
    lon_bounds : xr.DataArray
        The longitude bounds.

    Returns
    -------
    Optional[np.ndarray]
        An array with a single element representing the index of the prime
        meridian cell if it exists. Otherwise, None if the cell does not
        exist.

    Raises
    ------
    ValueError
        If more than one grid cell spans the prime meridian.
    """
    cell_widths = lon_bounds[:, 1] - lon_bounds[:, 0]
    p_meridian_index = np.where(cell_widths < 0)[0]
    n_cells = p_meridian_index.size

    if n_cells > 1:
        raise ValueError("More than one grid cell spans prime meridian.")

    # FIXME: When does this conditional return true? It seems like swapping from
    # (-180, to 180) to (0, 360) always produces a prime meridian cell?
    if n_cells == 0:  # pragma:no cover
        return None

    return p_meridian_index