Source code for xcdat.regridder.regrid2

from typing import Any, Dict, List, Tuple

import numpy as np
import xarray as xr

from xcdat.regridder.base import BaseRegridder, preserve_bounds


class Regrid2Regridder(BaseRegridder):
    """Horizontal regridder based on CDMS2's regrid2 module."""

    def __init__(
        self,
        input_grid: xr.Dataset,
        output_grid: xr.Dataset,
        **options: Any,
    ):
        """
        Pure python implementation of the regrid2 horizontal regridder from
        CDMS2's regrid2 module.

        Regrid data from ``input_grid`` to ``output_grid``.

        Available options: None

        Parameters
        ----------
        input_grid : xr.Dataset
            Dataset containing the source grid.
        output_grid : xr.Dataset
            Dataset containing the destination grid.
        options : Any
            Extra keyword parameters for the regridder; they are forwarded
            unchanged to the base class constructor.

        Examples
        --------
        Import xCDAT:

        >>> import xcdat
        >>> from xcdat.regridder import regrid2

        Open a dataset:

        >>> ds = xcdat.open_dataset("ts.nc")

        Create output grid:

        >>> output_grid = xcdat.create_gaussian_grid(32)

        Create regridder:

        >>> regridder = regrid2.Regrid2Regridder(ds.grid, output_grid)

        Regrid data:

        >>> data_new_grid = regridder.horizontal("ts", ds)
        """
        super().__init__(input_grid, output_grid, **options)

        # Cell bounds of the source and destination horizontal axes.
        self._src_lat = self._input_grid.bounds.get_bounds("Y")
        self._src_lon = self._input_grid.bounds.get_bounds("X")

        self._dst_lat = self._output_grid.bounds.get_bounds("Y")
        self._dst_lon = self._output_grid.bounds.get_bounds("X")

        # Mappings and weights are computed lazily on the first call to
        # ``horizontal`` and reused afterwards.
        self._lat_mapping: Any = None
        self._lon_mapping: Any = None

        self._lat_weights: Any = None
        self._lon_weights: Any = None

    def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
        """Regrid ``data_var`` in ``ds`` to output grid.

        Mappings and weights between input and output grid are calculated
        on the first call, allowing a regridder to be applied to many input
        datasets.

        Parameters
        ----------
        data_var : str
            The name of the data variable inside the dataset to regrid.
        ds : xr.Dataset
            The dataset containing ``data_var``.

        Returns
        -------
        xr.Dataset
            Dataset with variable on the destination grid.

        Raises
        ------
        KeyError
            If data variable does not exist in the Dataset.

        Examples
        --------
        Create output grid:

        >>> output_grid = xcdat.create_gaussian_grid(32)

        Create regridder:

        >>> regridder = regrid2.Regrid2Regridder(ds, output_grid)

        Regrid data:

        >>> data_new_grid = regridder.horizontal("ts", ds)
        """
        input_data_var = ds.get(data_var, None)

        if input_data_var is None:
            raise KeyError(
                f"The data variable '{data_var}' does not exist in the dataset."
            )

        # Do initial mapping between src/dst latitude and longitude.
        if self._lat_mapping is None and self._lat_weights is None:
            self._lat_mapping, self._lat_weights = _map_latitude(
                self._src_lat, self._dst_lat
            )

        if self._lon_mapping is None and self._lon_weights is None:
            self._lon_mapping, self._lon_weights = _map_longitude(
                self._src_lon, self._dst_lon
            )

        src_mask = self._input_grid.get("mask", None)

        # Apply source mask to input data: only cells where the mask equals
        # 0 keep their value.
        if src_mask is not None:
            input_data_var = input_data_var.where(src_mask == 0.0)

        # Operate on pure numpy.
        input_data = input_data_var.values

        # Axis name ("X", "Y", ...) -> first variable name reported for it.
        axis_variable_name_map = {x: y[0] for x, y in input_data_var.cf.axes.items()}

        output_axis_sizes = self._output_axis_sizes(input_data_var)

        # Dict order follows the dimension order of the data variable.
        ordered_axis_names = list(output_axis_sizes)

        output_data = self._regrid(input_data, output_axis_sizes, ordered_axis_names)

        output_ds = self._create_output_dataset(
            ds, data_var, output_data, axis_variable_name_map, ordered_axis_names
        )

        dst_mask = self._output_grid.get("mask", None)

        if dst_mask is not None:
            output_ds[data_var] = output_ds[data_var].where(dst_mask == 0.0)

        # Preserve non-spatial bounds.
        output_ds = preserve_bounds(ds, self._output_grid, output_ds)

        output_ds = output_ds.bounds.add_missing_bounds()

        return output_ds

    def _output_axis_sizes(self, da: xr.DataArray) -> Dict[str, int]:
        """Maps axes to output array sizes.

        Parameters
        ----------
        da : xr.DataArray
            Data array containing variable to be regridded.

        Returns
        -------
        Dict
            Mapping of axis name e.g. ("X", "Y", etc) to output sizes, in the
            order of the dimensions of ``da``.

        Raises
        ------
        KeyError
            If a dimension of ``da`` is not listed first for any axis in
            ``da.cf.axes``.
        """
        output_sizes = {}

        # Dimension/variable name -> axis name ("X", "Y", ...).
        axis_name_map = {y[0]: x for x, y in da.cf.axes.items()}

        for dim_name, input_size in da.sizes.items():
            axis_name = axis_name_map[dim_name]

            # Dimensions present in the output grid take their size from it;
            # every other dimension keeps its input size.
            if dim_name in self._output_grid:
                output_sizes[axis_name] = self._output_grid.sizes[dim_name]
            else:
                output_sizes[axis_name] = input_size

        return output_sizes

    def _regrid(
        self,
        input_data: np.ndarray,
        axis_sizes: Dict[str, int],
        ordered_axis_names: List[str],
    ) -> np.ndarray:
        """Applies regridding to input data.

        The horizontal axes may appear anywhere in ``ordered_axis_names`` and
        in either order; the data is rearranged internally to ``(..., Y, X)``
        and the result is returned in the caller's original axis order.

        Parameters
        ----------
        input_data : np.ndarray
            Input multi-dimensional array on source grid.
        axis_sizes : Dict[str, int]
            Mapping of axis name e.g. ("X", "Y", etc) to output sizes.
        ordered_axis_names : List[str]
            List of axis name in order of dimensions of ``input_data``.

        Returns
        -------
        np.ndarray
            Multi-dimensional ``float32`` array on destination grid.

        Raises
        ------
        ValueError
            If "X" or "Y" is missing from ``ordered_axis_names``.
        """
        input_lat_index = ordered_axis_names.index("Y")
        input_lon_index = ordered_axis_names.index("X")

        # The flat-index arithmetic below is only valid for a C-ordered
        # (..., Y, X) layout, so view the input with the horizontal axes last.
        work_data = np.moveaxis(
            input_data, (input_lat_index, input_lon_index), (-2, -1)
        )

        extra_axis_names = [x for x in ordered_axis_names if x not in ("X", "Y")]
        work_shape = [axis_sizes[x] for x in extra_axis_names]
        work_shape += [axis_sizes["Y"], axis_sizes["X"]]

        work_output = np.zeros(work_shape, dtype=np.float32)

        # Flat indexes of cell (0, 0) for every combination of extra axes.
        base_put_index = self._base_put_indexes(axis_sizes)

        for lat_index, lat_map in enumerate(self._lat_mapping):
            lat_weight = self._lat_weights[lat_index]

            input_lat_segment = np.take(work_data, lat_map, axis=-2)

            for lon_index, lon_map in enumerate(self._lon_mapping):
                lon_weight = self._lon_weights[lon_index]

                # (n_lat, 1) . (1, n_lon) -> per-source-cell weights.
                dot_weight = np.dot(lat_weight, lon_weight)

                cell_weight = np.sum(dot_weight)

                input_lon_segment = np.take(input_lat_segment, lon_map, axis=-1)

                # Weighted mean over the contributing source cells.
                data = (
                    np.multiply(input_lon_segment, dot_weight).sum(axis=(-2, -1))
                    / cell_weight
                )

                put_index = base_put_index + (
                    (lat_index * axis_sizes["X"]) + lon_index
                )

                np.put(work_output, put_index, data)

        # Restore the caller's axis order (a no-op for (..., Y, X) input).
        output_data = np.moveaxis(
            work_output, (-2, -1), (input_lat_index, input_lon_index)
        )

        return np.ascontiguousarray(output_data)

    def _base_put_indexes(self, axis_sizes: Dict[str, int]) -> np.ndarray:
        """Calculates the base indexes to place cell (0, 0).

        The indexes are flat offsets into a C-ordered array whose last two
        dimensions are the horizontal axes.

        Example:
        For a 3D array (time, lat, lon) with the shape (2, 2, 2) the offsets
        to place cell (0, 0) in each time step would be [0, 4].

        For a 4D array (time, plev, lat, lon) with shape (2, 2, 2, 2) the
        offsets to place cell (0, 0) in each (time, plev) slice would be
        [0, 4, 8, 12].

        Parameters
        ----------
        axis_sizes : Dict[str, int]
            Mapping of axis name e.g. ("X", "Y", etc) to output sizes.

        Returns
        -------
        np.ndarray
            Array containing the base indexes to be used in np.put operations.
        """
        horizontal_axes = {"X", "Y"}

        extra_dims = set(axis_sizes) - horizontal_axes
        horizontal_dims = set(axis_sizes) & horizontal_axes

        # One offset per combination of non-horizontal indexes.
        number_of_offsets = int(
            np.prod([axis_sizes[x] for x in extra_dims], dtype=np.int64)
        )

        # Number of cells in a single horizontal slice.
        offset = int(
            np.prod([axis_sizes[x] for x in horizontal_dims], dtype=np.int64)
        )

        return (np.arange(number_of_offsets) * offset).astype(np.int64)

    def _create_output_dataset(
        self,
        input_ds: xr.Dataset,
        data_var: str,
        output_data: np.ndarray,
        axis_variable_name_map: Dict[str, str],
        ordered_axis_names: List[str],
    ) -> xr.Dataset:
        """
        Creates the output Dataset containing the new variable on the
        destination grid.

        Parameters
        ----------
        input_ds : xr.Dataset
            Input dataset containing coordinates and bounds for unmodified
            axes.
        data_var : str
            The name of the regridded variable.
        output_data : np.ndarray
            Output data array.
        axis_variable_name_map : Dict[str, str]
            Map of axis name e.g. ("X", "Y", etc) to variable name e.g.
            ("lon", "lat", etc).
        ordered_axis_names : List[str]
            List of axis names in the order observed for ``output_data``.

        Returns
        -------
        xr.Dataset
            Dataset containing the variable on the destination grid.
        """
        coords = {}

        # Horizontal coordinates come from the output grid; every other
        # coordinate is copied from the input dataset.
        for axis_name, variable_name in axis_variable_name_map.items():
            if axis_name in ["X", "Y"]:
                coords[variable_name] = self._output_grid[variable_name].copy()
            else:
                coords[variable_name] = input_ds[variable_name].copy()

        output_da = xr.DataArray(
            output_data,
            dims=[axis_variable_name_map[x] for x in ordered_axis_names],
            coords=coords,
        )

        data_vars = {data_var: output_da}

        return xr.Dataset(data_vars, attrs=input_ds.attrs.copy())

def _map_latitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]:
    """
    Build the latitude mapping from source cells to destination cells.

    For every destination cell, the overlapping source cells are located and
    weighted by the difference of the sines of the overlap's northern and
    southern edges (edges are converted from degrees to radians first).

    Parameters
    ----------
    src : xr.DataArray
        DataArray containing the source latitude bounds.
    dst : xr.DataArray
        DataArray containing the destination latitude bounds.

    Returns
    -------
    Tuple[List, List]
        A tuple of cell mappings and cell weights. Each weight entry is a
        column array with shape ``(n_contributing_cells, 1)``.
    """
    src_south, src_north = _extract_bounds(src)
    dst_south, dst_north = _extract_bounds(dst)

    mapping = []
    weights = []

    for south_edge, north_edge in zip(dst_south, dst_north):
        # Source cells that overlap the destination cell.
        overlapping = np.where(
            np.logical_and(src_south < north_edge, src_north > south_edge)
        )[0]

        mapping.append(overlapping)

        # Clip the source cells to the destination cell.
        clipped_north = np.minimum(north_edge, src_north[overlapping])
        clipped_south = np.maximum(south_edge, src_south[overlapping])

        sin_north = np.sin(np.deg2rad(clipped_north))
        sin_south = np.sin(np.deg2rad(clipped_south))

        cell_weights = sin_north - sin_south

        weights.append(cell_weights.values.reshape((overlapping.size, 1)))

    return mapping, weights


def _map_longitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]:
    """
    Build the longitude mapping from source cells to destination cells.

    The source axis is first aligned to the destination axis. For every
    destination cell, the overlapping (aligned) source cells are located and
    weighted by the width of the overlap. The returned mappings are
    translated back to indexes into the original, unshifted source axis.

    Parameters
    ----------
    src : xr.DataArray
        DataArray containing source longitude bounds.
    dst : xr.DataArray
        DataArray containing destination longitude bounds.

    Returns
    -------
    Tuple[List, List]
        A tuple of cell mappings and cell weights. Each weight entry is a
        row array with shape ``(1, n_contributing_cells)``.
    """
    src_west, src_east = _extract_bounds(src)
    dst_west, dst_east = _extract_bounds(dst)

    aligned_west, aligned_east, shift = _align_axis(src_west, src_east, dst_west)

    src_length = src_west.shape[0]

    mapping = []
    weights = []

    for west_edge, east_edge in zip(dst_west, dst_east):
        # Aligned source cells that overlap the destination cell.
        overlapping = np.where(
            np.logical_and(aligned_west < east_edge, aligned_east > west_edge)
        )[0]

        # Width of the overlap between each source cell and the destination.
        clipped_east = np.minimum(east_edge, aligned_east[overlapping])
        clipped_west = np.maximum(west_edge, aligned_west[overlapping])

        overlap_width = clipped_east - clipped_west

        weights.append(overlap_width.values.reshape((1, overlapping.size)))

        # Undo the alignment shift, wrapping indexes past the end of the axis.
        source_cells = overlapping + shift
        source_cells = np.where(
            source_cells >= src_length, source_cells - src_length, source_cells
        )

        mapping.append(source_cells)

    return mapping, weights


def _extract_bounds(bounds: xr.DataArray) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Split an axis' bounds into its lower and upper edges.

    The orientation is decided from the first cell only: whichever column
    holds the smaller value there is treated as the lower edge.

    Parameters
    ----------
    bounds : xr.DataArray
        Two-column array of cell bounds for an axis.

    Returns
    -------
    Tuple[xr.DataArray, xr.DataArray]
        A tuple containing the lower and upper bounds for the axis.
    """
    first_column_is_lower = bounds[0, 0] < bounds[0, 1]

    lower_column, upper_column = (0, 1) if first_column_is_lower else (1, 0)

    return bounds[:, lower_column], bounds[:, upper_column]


def _align_axis(
    src_west: xr.DataArray, src_east: xr.DataArray, dst_west: xr.DataArray
) -> Tuple[xr.DataArray, xr.DataArray, int]:
    """
    Aligns a longitudinal source axis to the destination axis.

    Parameters
    ----------
    src_west : xr.DataArray
        DataArray containing the western source bounds.
    src_east : xr.DataArray
        DataArray containing the eastern source bounds.
    dst_west : xr.DataArray
        DataArray containing the western destination bounds.

    Returns
    -------
    Tuple[xr.DataArray, xr.DataArray, int]
        A tuple containing the shifted western source bounds, the shifted
        eastern source bounds, and the number of places shifted to align
        axis. The shifted bounds hold one more cell than the source axis
        because the first cell is repeated at the end.
    """
    src_length = src_west.shape[0]

    # Orientation of the source axis, evaluated once and reused below.
    west_to_east = bool(src_west[0] < src_west[-1])

    west_most = np.minimum(dst_west[0], dst_west[-1])

    alignment_index = _vpertub((west_most - src_west[-1]) / 360.0)
    alignment_index += 1 if west_to_east else -1

    # First source cell whose perturbed offset differs from the reference.
    src_alignment_index = np.where(
        _vpertub((west_most - src_west) / 360.0) != alignment_index
    )[0][0]

    shift = src_alignment_index

    if west_to_east and west_most != src_west[src_alignment_index]:
        shift = src_alignment_index - 1

        if shift < 0:
            shift = src_length - 1

    # Rotate the source indexes by ``shift``, wrapping around the axis end.
    shifted_indexes = np.arange(src_length + 1) + shift
    shifted_indexes = np.where(
        shifted_indexes >= src_length, shifted_indexes - src_length, shifted_indexes
    )

    # Whole multiples of 360 degrees moving each cell next to ``west_most``.
    wrap_offset = 360.0 * _vpertub((west_most - src_west[shifted_indexes]) / 360.0)

    shifted_src_west = src_west[shifted_indexes] + wrap_offset
    shifted_src_east = src_east[shifted_indexes] + wrap_offset

    # Pull the leading (or trailing) cell back one revolution when it starts
    # east of the western-most destination bound.
    edge = 0 if west_to_east else -1

    if shifted_src_west[edge] > west_most:
        shifted_src_west[edge] -= 360.0
        shifted_src_east[edge] -= 360.0

    return shifted_src_west, shifted_src_east, shift


def _pertub(value: xr.DataArray) -> xr.DataArray:
    """
    Perturb a value and round it to a whole number.

    Non-negative values are nudged up by a small constant and rounded up;
    negative values are nudged down, rounded down and then increased by one.

    Parameters
    ----------
    value : xr.DataArray
        Value to perturb.

    Returns
    -------
    xr.DataArray
        The perturbed value.
    """
    epsilon = 0.000001

    offset = (
        np.ceil(value + epsilon) if value >= 0.0 else np.floor(value - epsilon) + 1.0
    )

    return xr.DataArray(offset)


# Element-wise version of ``_pertub``.
_vpertub = np.vectorize(_pertub)