# Source code for xcdat.regridder.regrid2

from typing import Any, List, Optional, Tuple

import numpy as np
import xarray as xr

from xcdat.axis import get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds


class Regrid2Regridder(BaseRegridder):
    def __init__(
        self,
        input_grid: xr.Dataset,
        output_grid: xr.Dataset,
        unmapped_to_nan=True,
        **options: Any,
    ):
        """
        Pure python implementation of the regrid2 horizontal regridder from
        CDMS2's regrid2 module.

        Regrids data from ``input_grid`` to ``output_grid``.

        Available options: None

        Parameters
        ----------
        input_grid : xr.Dataset
            Dataset containing the source grid.
        output_grid : xr.Dataset
            Dataset containing the destination grid.
        unmapped_to_nan : bool
            Stored on the instance and forwarded to the regridding routine by
            :py:meth:`horizontal`. Defaults to ``True``.
        options : Any
            Dictionary with extra parameters for the regridder; forwarded
            unchanged to the base class constructor.

        Examples
        --------
        Import xCDAT:

        >>> import xcdat

        Open a dataset:

        >>> ds = xcdat.open_dataset("...")

        Create output grid:

        >>> output_grid = xcdat.create_gaussian_grid(32)

        Regrid data:

        >>> output_data = ds.regridder.horizontal("ts", output_grid)
        """
        super().__init__(input_grid, output_grid, **options)

        self._unmapped_to_nan = unmapped_to_nan

    def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
        """Placeholder for base class."""
        raise NotImplementedError()

    def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
        """See documentation in :py:func:`xcdat.regridder.regrid2.Regrid2Regridder`"""
        try:
            source_var = ds[data_var]
        except KeyError:
            raise KeyError(
                f"The data variable {data_var!r} does not exist in the dataset."
            ) from None

        # Bounds for both grids as float32 numpy arrays.
        in_lat_bnds = _get_bounds_ensure_dtype(self._input_grid, "Y")
        in_lon_bnds = _get_bounds_ensure_dtype(self._input_grid, "X")
        out_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y")
        out_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X")

        # The input grid may carry an optional "mask" variable; when it is
        # absent the lookup yields None, which has no ``values`` attribute.
        mask = getattr(self._input_grid.get("mask", None), "values", None)

        # Prefer the encoded fill value, then the encoded missing value, and
        # finally fall back to 1e20.
        encoding = source_var.encoding
        fill_value = encoding.get("_FillValue", None)

        if fill_value is None:
            fill_value = encoding.get("missing_value", 1e20)

        # Values equal to the fill value become NaN before regridding.
        source_var = source_var.where(source_var != fill_value)

        regridded = _regrid(
            source_var,
            in_lat_bnds,
            in_lon_bnds,
            out_lat_bnds,
            out_lon_bnds,
            mask,
            unmapped_to_nan=self._unmapped_to_nan,
        )

        return _build_dataset(
            ds,
            data_var,
            regridded,
            out_lat_bnds,
            out_lon_bnds,
            self._input_grid,
            self._output_grid,
        )
def _regrid(
    input_data_var: xr.DataArray,
    src_lat_bnds: np.ndarray,
    src_lon_bnds: np.ndarray,
    dst_lat_bnds: np.ndarray,
    dst_lon_bnds: np.ndarray,
    src_mask: Optional[np.ndarray],
    omitted=None,
    unmapped_to_nan=True,
) -> np.ndarray:
    """
    Regrid a data variable onto the destination grid with weighted averaging.

    Every destination cell is the weighted mean of the source cells that
    overlap it. The weight of a source cell is the product of its latitude
    weight, its longitude weight and its value in ``src_mask``.

    Parameters
    ----------
    input_data_var : xr.DataArray
        Variable to regrid. It is converted to ``float32`` before processing.
    src_lat_bnds : np.ndarray
        Source latitude bounds.
    src_lon_bnds : np.ndarray
        Source longitude bounds.
    dst_lat_bnds : np.ndarray
        Destination latitude bounds.
    dst_lon_bnds : np.ndarray
        Destination longitude bounds.
    src_mask : Optional[np.ndarray]
        Mask multiplied into the cell weights, indexed as (lat, lon). When
        ``None`` a mask of ones is used.
    omitted : Any
        Value written to destination cells whose total weight is not
        positive. Defaults to ``np.nan``.
    unmapped_to_nan : bool
        When ``False``, destination cells whose total weight is exactly zero
        are set to ``0.0`` instead of ``omitted``.

    Returns
    -------
    np.ndarray
        ``float32`` array whose last two axes are the destination latitude
        and longitude, preceded by the remaining dimensions of the input.
    """
    # NOTE(review): the weight broadcasting and the final transpose appear to
    # assume that Y and X are the last two dimensions of the input, in that
    # order -- confirm before using with lon-by-lat data.
    if omitted is None:
        omitted = np.nan

    lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
    lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

    # convert to pure numpy
    input_data = input_data_var.astype(np.float32).values

    y_name, y_index = _get_dimension(input_data_var, "Y")
    x_name, x_index = _get_dimension(input_data_var, "X")

    y_length = len(lat_mapping)
    x_length = len(lon_mapping)

    if src_mask is None:
        src_mask = np.ones((input_data.shape[y_index], input_data.shape[x_index]))

    # sizes of every dimension that is not being regridded
    other_sizes = [
        size
        for name, size in input_data_var.sizes.items()
        if name not in (y_name, x_name)
    ]

    # destination cells are flattened into the leading axis while looping
    data_shape = [y_length * x_length] + other_sizes

    # output data is always float32 in original code
    output_data = np.zeros(data_shape, dtype=np.float32)
    output_mask = np.ones(data_shape, dtype=np.float32)

    is_2d = input_data_var.ndim <= 2

    # TODO: need to optimize further, investigate using ufuncs and dask arrays
    # TODO: how common is lon by lat data? may need to reshape
    for y in range(y_length):
        y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
        y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0)

        for x in range(x_length):
            x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")
            x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap")

            # outer product of the latitude column and longitude row weights,
            # with masked source cells zeroed out
            cell_weights = np.multiply(
                np.dot(lat_weights[y], lon_weights[x]), x_mask_seg
            )
            cell_weight = np.sum(cell_weights)

            output_seg_index = y * x_length + x

            if cell_weight == 0.0:
                output_mask[output_seg_index] = 0.0

            if cell_weight <= 0.0:
                # Nothing usable contributes to this cell. Skip the division
                # (it would divide by zero and emit a RuntimeWarning) and
                # store the placeholder directly.
                output_data[output_seg_index] = omitted
                continue

            weighted_sum = np.sum(
                np.multiply(x_seg, cell_weights),
                axis=(y_index, x_index),
            )

            # using the `out` argument is more performant, places data directly
            # into array memory rather than allocating a new variable. wasn't
            # working for single element output, needs further investigation as
            # we may not need branch
            if is_2d:
                output_data[output_seg_index] = np.divide(weighted_sum, cell_weight)
            else:
                np.divide(
                    weighted_sum, cell_weight, out=output_data[output_seg_index]
                )

    # unmapped cells hold `omitted` (NaN by default); use the output mask to
    # replace them with zero when requested
    if not unmapped_to_nan:
        output_data[output_mask == 0.0] = 0.0

    # unflatten the destination cells into (lat, lon, *other)
    output_data = output_data.reshape([y_length, x_length] + other_sizes)

    # move the (lat, lon) axes behind the remaining axes
    output_order = [x + 2 for x in range(input_data_var.ndim - 2)] + [0, 1]
    output_data = output_data.transpose(output_order)

    return output_data.astype(np.float32)


def _build_dataset(
    ds: xr.Dataset,
    data_var: str,
    output_data: np.ndarray,
    dst_lat_bnds,
    dst_lon_bnds,
    input_grid: xr.Dataset,
    output_grid: xr.Dataset,
) -> xr.Dataset:
    """
    Wrap regridded data in a new dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding the original data variable; its dimension names and
        variable attributes are reused for the output variable.
    data_var : str
        Name of the regridded data variable.
    output_data : np.ndarray
        Regridded values.
    dst_lat_bnds, dst_lon_bnds
        Destination bounds. Currently unused; kept for interface stability.
    input_grid : xr.Dataset
        Source grid; its attributes are copied onto the output dataset.
    output_grid : xr.Dataset
        Destination grid, handed to ``_preserve_bounds``.

    Returns
    -------
    xr.Dataset
        Dataset returned by ``_preserve_bounds`` for the "X" and "Y" axes.
    """
    input_data_var = ds[data_var]

    output_da = xr.DataArray(
        output_data,
        dims=list(input_data_var.dims),
        coords={},
        attrs=input_data_var.attrs.copy(),
        name=data_var,
    )

    output_ds = xr.Dataset(
        {data_var: output_da},
        attrs=input_grid.attrs.copy(),
    )

    return _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])


def _map_latitude(
    src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Map source to destination latitude.

    Source cells are grouped by the contribution to each output cell.

    Source cells have new boundaries calculated by finding minimum northern
    and maximum southern boundary between each source cell and the
    destination cell it contributes to.

    The source cell weights are calculated by taking the difference of sin's
    between these new boundary pairs.

    Parameters
    ----------
    src : np.ndarray
        Array containing the source latitude bounds.
    dst : np.ndarray
        Array containing the destination latitude bounds.

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        A tuple of cell mappings and cell weights.
    """
    src_south, src_north = _extract_bounds(src)
    dst_south, dst_north = _extract_bounds(dst)

    # Find the contributing source cells for each destination cell: a source
    # cell contributes when its bounds overlap the destination bounds.
    # e.g. with src_south = [-90, -45, 0, 45], src_north = [-45, 0, 45, 90],
    # dst_south = -30 and dst_north = 30 the contributing cells are [1, 2].
    mapping = [
        np.where(np.logical_and(src_south < north, src_north > south))[0]
        for south, north in zip(dst_south, dst_north)
    ]

    # Clip every contributing source cell to the destination cell.
    bounds = [
        (
            np.minimum(dst_north[x], src_north[cells]),
            np.maximum(dst_south[x], src_south[cells]),
        )
        for x, cells in enumerate(mapping)
    ]

    # convert latitude to cell weight (difference of height above/below equator)
    weights = _get_latitude_weights(bounds)

    return mapping, weights


def _get_latitude_weights(
    bounds: List[Tuple[np.ndarray, np.ndarray]]
) -> List[np.ndarray]:
    """
    Convert clipped latitude bounds (in degrees) to cell weights.

    Parameters
    ----------
    bounds : List[Tuple[np.ndarray, np.ndarray]]
        One ``(northern, southern)`` pair of bound arrays per destination
        cell.

    Returns
    -------
    List[np.ndarray]
        One column vector of shape ``(n, 1)`` per destination cell holding
        ``sin(northern) - sin(southern)`` for each contributing source cell.
    """
    return [
        (np.sin(np.deg2rad(north)) - np.sin(np.deg2rad(south))).reshape((-1, 1))
        for north, south in bounds
    ]


def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
    """
    Map source to destination longitude.

    Source boundaries are aligned to the most western destination cell.

    Source cells are grouped by the contribution to each output cell.

    The source cell weights are calculated by finding the difference of the
    following min/max for each input cell: the minimum of the eastern source
    bounds and the eastern bounds of the destination cell it contributes to,
    and the maximum of the western source bounds and the western bounds of
    the destination cell it contributes to.

    The mappings are then shifted to index the original (unaligned) source
    longitude.

    Parameters
    ----------
    src : np.ndarray
        Array containing source longitude bounds.
    dst : np.ndarray
        Array containing destination longitude bounds.

    Returns
    -------
    Tuple[List, List]
        A tuple of cell mappings and cell weights. Each weight is a row
        vector of shape ``(1, n)``.
    """
    src_west, src_east = _extract_bounds(src)
    dst_west, dst_east = _extract_bounds(dst)

    # align source and destination longitude
    shifted_src_west, shifted_src_east, shift = _align_axis(
        src_west,
        src_east,
        dst_west,
    )

    src_length = src_west.shape[0]

    # Find the contributing (aligned) source cells for each destination cell
    # based on overlapping bounds.
    aligned_mapping = [
        np.where(np.logical_and(shifted_src_west < east, shifted_src_east > west))[0]
        for west, east in zip(dst_west, dst_east)
    ]

    # weights are just the difference between minimum and maximum of
    # contributing bounds for each destination cell
    weights = [
        (
            np.minimum(dst_east[x], shifted_src_east[cells])
            - np.maximum(dst_west[x], shifted_src_west[cells])
        ).reshape((1, -1))
        for x, cells in enumerate(aligned_mapping)
    ]

    # Translate positions on the aligned axis back to source indexes. Aligned
    # positions lie in [0, src_length] and the shift in [0, src_length), so
    # the modulo performs the single wrap-around needed to keep every index
    # below src_length.
    mapping = [(cells + shift) % src_length for cells in aligned_mapping]

    return mapping, weights


def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract lower and upper bounds from an axis.

    The ordering of the two columns is decided from the first cell only.

    Parameters
    ----------
    bounds : np.ndarray
        A numpy array of bounds values with one ``(a, b)`` row per cell.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        A tuple containing the lower and upper bounds for the axis, both as
        ``float32``.
    """
    if bounds[0, 0] < bounds[0, 1]:
        lower, upper = bounds[:, 0], bounds[:, 1]
    else:
        lower, upper = bounds[:, 1], bounds[:, 0]

    return lower.astype(np.float32), upper.astype(np.float32)


def _align_axis(
    src_west: np.ndarray,
    src_east: np.ndarray,
    dst_west: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Aligns a source and destination longitude axis.

    Parameters
    ----------
    src_west : np.ndarray
        Array containing the western source bounds.
    src_east : np.ndarray
        Array containing the eastern source bounds.
    dst_west : np.ndarray
        Array containing the western destination bounds.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, int]
        A tuple containing the shifted western source bounds, the shifted
        eastern source bounds, and the number of places shifted to align
        axis. The shifted arrays hold one more element than the source axis.
    """
    # find smallest western bounds
    west_most = np.minimum(dst_west[0], dst_west[-1])

    # find cell index required to align bounds
    alignment_index = _vpertub((west_most - src_west[-1]) / 360.0)

    # shift index depending on first/last source bounds
    alignment_index = (
        alignment_index + 1 if src_west[0] < src_west[-1] else alignment_index - 1
    )

    # number of whole revolutions separating each source cell from the
    # destination's most western cell
    relative_position = _vpertub((west_most - src_west) / 360.0)

    # first source cell whose relative position differs from the alignment index
    src_alignment_index = np.where(relative_position != alignment_index)[0][0]

    # determine the shift factor required to align source and destination bounds
    if src_west[0] < src_west[-1]:
        if west_most == src_west[src_alignment_index]:
            shift = src_alignment_index
        else:
            shift = src_alignment_index - 1

            if shift < 0:
                shift = src_west.shape[0] - 1
    else:
        shift = src_alignment_index

    src_length = src_west.shape[0]

    # shift the source index values; one extra element closes the circle
    shifted_indexes = np.arange(src_length + 1) + shift

    # find index values that need to be shifted to be within 0 - src_length
    wrapped = np.where(shifted_indexes > src_length - 1)

    # shift the indexes
    shifted_indexes[wrapped] -= src_length

    # reorder src_west and add portion to align
    shifted_src_west = (
        src_west[shifted_indexes] + 360.0 * relative_position[shifted_indexes]
    )

    # reorder src_east and add portion to align
    shifted_src_east = (
        src_east[shifted_indexes] + 360.0 * relative_position[shifted_indexes]
    )

    # handle ends of each interval
    if src_west[-1] > src_west[0]:
        if shifted_src_west[0] > west_most:
            shifted_src_west[0] += -360.0
            shifted_src_east[0] += -360.0
    else:
        if shifted_src_west[-1] > west_most:
            shifted_src_west[-1] += -360.0
            shifted_src_east[-1] += -360.0

    return shifted_src_west, shifted_src_east, shift


def _pertub(value: np.ndarray) -> np.ndarray:
    """
    Pertub a value.

    Modifies value with a small constant and returns nearest whole number.
    Non-negative values are rounded up after adding ``1e-6``; negative values
    are rounded down after subtracting ``1e-6`` and then incremented by one.

    Parameters
    ----------
    value : np.ndarray
        Value to pertub. The comparison requires a single scalar value; use
        ``_vpertub`` for arrays.

    Returns
    -------
    np.ndarray
        Value that's been pertubed.
    """
    if value >= 0.0:
        return np.ceil(value + 0.000001)

    return np.floor(value - 0.000001) + 1.0


# vectorize version of pertub
_vpertub = np.vectorize(_pertub)


def _get_dimension(input_data_var, cf_axis_name):
    """
    Look up the name and position of the dimension for a CF axis.

    Parameters
    ----------
    input_data_var : xr.DataArray
        Data variable to inspect.
    cf_axis_name : str
        CF axis name; this module passes "X" or "Y".

    Returns
    -------
    tuple
        The dimension name returned by ``get_dim_keys`` and its index in
        ``input_data_var.dims``.
    """
    # NOTE(review): assumes get_dim_keys returns a single dimension name for
    # the axis -- confirm against its implementation.
    name = get_dim_keys(input_data_var, cf_axis_name)

    index = input_data_var.dims.index(name)

    return name, index


def _get_bounds_ensure_dtype(ds, axis):
    """
    Return the bounds of an axis as a ``float32`` numpy array.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding the bounds variable; the first entry of
        ``ds.cf.bounds[axis]`` is used.
    axis : str
        Axis key looked up in ``ds.cf.bounds``.

    Returns
    -------
    np.ndarray
        Values of the bounds variable, cast to ``float32`` if necessary.

    Raises
    ------
    RuntimeError
        If no bounds are registered for ``axis``.
    """
    try:
        name = ds.cf.bounds[axis][0]
    except (KeyError, IndexError) as err:
        raise RuntimeError(f"Could not determine {axis!r} bounds") from err

    bounds = ds[name]

    if bounds.dtype != np.float32:
        bounds = bounds.astype(np.float32)

    return bounds.values