from typing import Any
import numpy as np
import sparse as sp
import xarray as xr
import xcdat as xc
from xcdat.axis import get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds
from xcdat.regridder.grid import create_mask, create_nan_mask
[docs]
class Regrid2Regridder(BaseRegridder):
    def __init__(
        self,
        input_grid: xr.Dataset,
        output_grid: xr.Dataset,
        unmapped_to_nan: bool = True,
        output_weights: bool | str = False,
        create_nan_mask: bool = False,
        **options: Any,
    ):
        """
        Pure python implementation of the regrid2 horizontal regridder from
        CDMS2's regrid2 module.

        Data is regridded from ``input_grid`` onto ``output_grid``.

        Available options: None

        Parameters
        ----------
        input_grid : xr.Dataset
            Dataset describing the source grid.
        output_grid : xr.Dataset
            Dataset describing the destination grid.
        unmapped_to_nan : bool
            When True, destination cells without any contributing source
            cell are NaN. Default is True.
        output_weights : bool | str
            When True, the regridding weights are stored in the output
            dataset under the name ``weights``. When a string, that string
            is used as the variable name instead. Default is False.
        create_nan_mask : bool
            When True, the source mask is derived from the NaN values of
            the source variable and any mask already present in the grid
            dataset is ignored. Default is False.
        **options : Any
            Extra parameters forwarded to the base regridder.

        Examples
        --------
        Import xCDAT:

        >>> import xcdat

        Open a dataset:

        >>> ds = xcdat.open_dataset("...")

        Create output grid:

        >>> output_grid = xcdat.create_gaussian_grid(32)

        Regrid data:

        >>> output_data = ds.regridder.horizontal("ts", output_grid)
        """
        super().__init__(input_grid, output_grid, **options)

        self._unmapped_to_nan = unmapped_to_nan
        self._output_weights = output_weights
        self._create_nan_mask = create_nan_mask

    def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
        """Placeholder for base class; vertical regridding is not implemented."""
        raise NotImplementedError()

    def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
        """See documentation in :py:func:`xcdat.regridder.regrid2.Regrid2Regridder`"""
        try:
            input_data_var = ds[data_var]
        except KeyError:
            raise KeyError(
                f"The data variable {data_var!r} does not exist in the dataset."
            ) from None

        # Bounds of both grids as float32 numpy arrays (latitude, then longitude).
        src_lat_bnds = _get_bounds_ensure_dtype(self._input_grid, "Y")
        src_lon_bnds = _get_bounds_ensure_dtype(self._input_grid, "X")
        dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y")
        dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X")

        if self._create_nan_mask:
            src_mask = create_nan_mask(input_data_var, ["Y", "X"]).values
        else:
            grid_mask = self._input_grid.get("mask", None)

            if grid_mask is None:
                # regrid2 requires a mask, so build one when the grid has none
                grid_mask = create_mask(self._input_grid, ["Y", "X"])

            src_mask = grid_mask.values

        # Look up the fill value from the variable encoding, preferring
        # `_FillValue`, then `missing_value`, then falling back to 1e20.
        encoding = input_data_var.encoding
        nan_replace = encoding.get("_FillValue", None)

        if nan_replace is None:
            nan_replace = encoding.get("missing_value", 1e20)

        # Mask out values that equal the fill value.
        input_data_var = input_data_var.where(input_data_var != nan_replace)

        lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
        lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

        # horizontal regrid
        output_data = _regrid(
            input_data_var,
            lat_mapping,
            lon_mapping,
            lat_weights,
            lon_weights,
            src_mask,
            unmapped_to_nan=self._unmapped_to_nan,
        )

        output_ds = _build_dataset(
            ds,
            data_var,
            output_data,
            self._input_grid,
            self._output_grid,
        )

        if self._output_weights:
            n_src_lat, n_src_lon = len(src_lat_bnds), len(src_lon_bnds)
            n_dst_lat, n_dst_lon = len(dst_lat_bnds), len(dst_lon_bnds)

            weights = _sparse_weights(
                (n_src_lat, n_src_lon),
                (n_dst_lat, n_dst_lon),
                n_src_lon,
                n_dst_lon,
                lat_mapping,
                lon_mapping,
                lat_weights,
                lon_weights,
            )

            # A string value names the weights variable; True uses the default.
            weights_name = (
                self._output_weights
                if isinstance(self._output_weights, str)
                else "weights"
            )
            output_ds[weights_name] = weights

        return output_ds
def _regrid(
    input_data_var: xr.DataArray,
    lat_mapping: list[np.ndarray],
    lon_mapping: list[np.ndarray],
    lat_weights: list[np.ndarray],
    lon_weights: list[np.ndarray],
    src_mask: np.ndarray,
    omitted=None,
    unmapped_to_nan=True,
) -> np.ndarray:
    """
    Regrid a data variable using precomputed cell mappings and weights.

    Each destination cell is the weighted mean of the source cells that
    contribute to it, where a weight is the product of the latitude weight,
    the longitude weight and the source mask value.

    Parameters
    ----------
    input_data_var : xr.DataArray
        The variable to regrid. Its "Y" and "X" dimensions may sit at any
        position and in either order.
    lat_mapping : list[np.ndarray]
        For each destination latitude, the contributing source latitude
        indexes.
    lon_mapping : list[np.ndarray]
        For each destination longitude, the contributing source longitude
        indexes.
    lat_weights : list[np.ndarray]
        For each destination latitude, a column of weights with shape
        ``(n, 1)``.
    lon_weights : list[np.ndarray]
        For each destination longitude, a row of weights with shape
        ``(1, m)``.
    src_mask : np.ndarray
        Source mask, indexed with latitude on axis 0 and longitude on axis 1.
    omitted : optional
        Value written to destination cells whose total weight is not
        positive. Defaults to NaN.
    unmapped_to_nan : bool
        When False, destination cells whose total weight is exactly zero are
        set to 0.0 instead.

    Returns
    -------
    np.ndarray
        The regridded values as float32, with axes in the same order as
        ``input_data_var.dims``.
    """
    if omitted is None:
        omitted = np.nan

    # convert to pure numpy
    input_data = input_data_var.astype(np.float32).values

    y_name, y_index = _get_dimension(input_data_var, "Y")
    x_name, x_index = _get_dimension(input_data_var, "X")

    y_length = len(lat_mapping)
    x_length = len(lon_mapping)

    other_dims = {
        x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name)
    }
    other_sizes = list(other_dims.values())

    # Destination cells are stored flattened on the leading axis while looping.
    data_shape = [y_length * x_length] + other_sizes

    # output data is always float32 in original code
    output_data = np.zeros(data_shape, dtype=np.float32)
    output_mask = np.ones(data_shape, dtype=np.float32)

    # Template used to reshape the 2D cell weights so they broadcast against
    # the data segment no matter where the Y and X axes are located.
    weights_shape = [1] * input_data.ndim
    lat_before_lon = y_index < x_index

    # TODO: need to optimize further, investigate using ufuncs and dask arrays
    for y in range(y_length):
        y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
        y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0)

        for x in range(x_length):
            x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")
            x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap")

            # (n_lat, n_lon) weights with masked source cells zeroed out
            cell_weights = np.multiply(
                np.dot(lat_weights[y], lon_weights[x]), x_mask_seg
            )
            cell_weight = np.sum(cell_weights)

            output_seg_index = y * x_length + x

            if cell_weight == 0.0:
                output_mask[output_seg_index] = 0.0

            if cell_weight <= 0.0:
                # Nothing valid contributes to this cell; skip the division so
                # no 0/0 RuntimeWarning is raised for a value that would be
                # overwritten anyway.
                output_data[output_seg_index] = omitted
                continue

            # Place the latitude/longitude weight axes at the positions of the
            # Y/X axes of the data segment.
            weights_shape[y_index] = cell_weights.shape[0]
            weights_shape[x_index] = cell_weights.shape[1]
            oriented = cell_weights if lat_before_lon else cell_weights.T
            aligned_weights = oriented.reshape(weights_shape)

            output_data[output_seg_index] = np.divide(
                np.sum(
                    np.multiply(x_seg, aligned_weights),
                    axis=(y_index, x_index),
                ),
                cell_weight,
            )

    # unmapped cells hold `omitted`; use the output mask to replace them with zero
    if not unmapped_to_nan:
        output_data[output_mask == 0.0] = 0.0

    output_data_shape = [y_length, x_length] + other_sizes
    output_data = output_data.reshape(output_data_shape)

    # temp dimensional ordering
    temp_dims = [y_name, x_name] + list(other_dims.keys())

    # map temp ordering to input ordering
    output_order = [temp_dims.index(x) for x in input_data_var.dims]
    output_data = output_data.transpose(output_order)

    return output_data.astype(np.float32)
def _build_dataset(
    ds: xr.Dataset,
    data_var: str,
    output_data: np.ndarray,
    input_grid: xr.Dataset,
    output_grid: xr.Dataset,
) -> xr.Dataset:
    """Wrap regridded values in a new dataset.

    Parameters
    ----------
    ds : xr.Dataset
        The input dataset holding the variable that was regridded.
    data_var : str
        Name of the regridded variable in ``ds``.
    output_data : np.ndarray
        The regridded values, ordered like the dimensions of the input
        variable.
    input_grid : xr.Dataset
        The source grid; its attributes are copied onto the result.
    output_grid : xr.Dataset
        The destination grid supplying the new horizontal coordinates.

    Returns
    -------
    xr.Dataset
        A dataset holding the regridded variable, with the input variable's
        attributes and dimension names, plus the bounds returned by
        ``_preserve_bounds``.
    """
    source_var = ds[data_var]

    regridded_var = xr.DataArray(
        output_data,
        dims=source_var.dims,
        coords=_get_output_coords(source_var, output_grid),
        attrs=ds[data_var].attrs.copy(),
        name=data_var,
    )

    result = regridded_var.to_dataset()
    result.attrs = input_grid.attrs.copy()

    return _preserve_bounds(ds, output_grid, result, ["X", "Y"])
def _sparse_weights(
    in_shape: tuple[int, int],
    out_shape: tuple[int, int],
    in_width: int,
    out_width: int,
    lat_mapping: list[np.ndarray],
    lon_mapping: list[np.ndarray],
    lat_weights: list[np.ndarray],
    lon_weights: list[np.ndarray],
) -> xr.DataArray:
    """
    Build the regridding weights as a sparse 4D array.

    Parameters
    ----------
    in_shape : tuple
        Shape of the input grid.
    out_shape : tuple
        Shape of the output grid.
    in_width : int
        Number of cells in one row of the input grid.
    out_width : int
        Number of cells in one row of the output grid.
    lat_mapping : list[np.ndarray]
        Contributing source latitude indexes per destination latitude.
    lon_mapping : list[np.ndarray]
        Contributing source longitude indexes per destination longitude.
    lat_weights : list[np.ndarray]
        Latitude weights per destination latitude.
    lon_weights : list[np.ndarray]
        Longitude weights per destination longitude.

    Returns
    -------
    xr.DataArray
        Sparse weights with dimensions ``(y_out, x_out, y_in, x_in)``.
    """
    # Dense (destination cell, source cell) matrix over the flattened grids.
    dense = np.zeros((np.prod(out_shape), np.prod(in_shape)), dtype=np.float32)

    for lat_idx, src_rows in enumerate(lat_mapping):
        # flattened offset of the first cell of each contributing source row
        row_offsets = (src_rows * in_width).reshape((-1, 1))

        for lon_idx, src_cols in enumerate(lon_mapping):
            # flattened destination index
            dst_flat = lat_idx * out_width + lon_idx

            # flattened indexes of every contributing source cell
            src_flat = (row_offsets + src_cols).flatten()

            cell_weights = np.dot(lat_weights[lat_idx], lon_weights[lon_idx])
            dense[dst_flat, src_flat] = cell_weights.flatten()

    # Convert to sparse, then unflatten from 2D (dest, src) to
    # 4D (dest y, dest x, src y, src x) so the weights and mapping are easy to
    # explore.
    sparse_weights = sp.COO.from_numpy(dense).reshape(out_shape + in_shape)

    dim_names = ["y_out", "x_out", "y_in", "x_in"]
    coords = {
        name: np.arange(size) for name, size in zip(dim_names, out_shape + in_shape)
    }

    return xr.DataArray(sparse_weights, dims=dim_names, coords=coords)
def _get_output_coords(
    dv_input: xr.DataArray, output_grid: xr.Dataset
) -> dict[str, xr.DataArray]:
    """
    Collect the coordinates of the regridded variable.

    Parameters
    ----------
    dv_input : xr.DataArray
        The input data variable carrying the original coordinates.
    output_grid : xr.Dataset
        The dataset carrying the destination grid coordinates.

    Returns
    -------
    dict[str, xr.DataArray]
        One entry per dimension of ``dv_input``, in dimension order. The X
        and Y entries are keyed by the input coordinate names but hold the
        output grid coordinates; every other entry is taken from
        ``dv_input``.
    """
    # Horizontal axes: keep the input coordinate name, use the output grid values.
    grid_coords: dict[str, xr.DataArray] = {}

    for axis in ("X", "Y"):
        source_coord = xc.get_dim_coords(dv_input, axis)  # type: ignore
        target_coord = xc.get_dim_coords(output_grid, axis)  # type: ignore
        grid_coords[str(source_coord.name)] = target_coord  # type: ignore

    # Walk the input dims in order so the result lines up with the data,
    # falling back to the input coordinate for non-horizontal axes
    # (e.g., "time").
    ordered: dict[str, xr.DataArray] = {}

    for dim in dv_input.dims:
        key = str(dim)
        ordered[key] = grid_coords[key] if dim in grid_coords else dv_input[dim]

    return ordered
def _map_latitude(
    src: np.ndarray, dst: np.ndarray
) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """
    Map source to destination latitude.

    For each destination cell, the overlapping source cells are collected.
    Each overlapping source cell is clipped to the destination cell by taking
    the minimum of the northern bounds and the maximum of the southern
    bounds. The weight of a source cell is the difference of the sines of
    its clipped bounds.

    Parameters
    ----------
    src : np.ndarray
        Array containing the source latitude bounds.
    dst : np.ndarray
        Array containing the destination latitude bounds.

    Returns
    -------
    tuple[list[np.ndarray], list[np.ndarray]]
        A tuple of cell mappings and cell weights.
    """
    src_south, src_north = _extract_bounds(src)
    dst_south, dst_north = _extract_bounds(dst)

    mapping = []
    bounds = []

    for south, north in zip(dst_south, dst_north):
        # A source cell contributes when it starts below the destination's
        # northern bound and ends above its southern bound.
        overlaps = (src_south < north) & (src_north > south)
        contributing = np.where(overlaps)[0]
        mapping.append(contributing)

        # Clip each contributing source cell to the destination cell.
        clipped_north = np.minimum(north, src_north[contributing])
        clipped_south = np.maximum(south, src_south[contributing])
        bounds.append((clipped_north, clipped_south))

    # convert latitude to cell weight (difference of height above/below equator)
    weights = _get_latitude_weights(bounds)

    return mapping, weights
def _get_latitude_weights(
    bounds: list[tuple[np.ndarray, np.ndarray]],
) -> list[np.ndarray]:
    """
    Convert pairs of latitude bounds (in degrees) into column weights.

    Parameters
    ----------
    bounds : list[tuple[np.ndarray, np.ndarray]]
        One ``(first, second)`` pair of bound arrays per destination cell.

    Returns
    -------
    list[np.ndarray]
        For each pair, ``sin(first) - sin(second)`` reshaped to a column of
        shape ``(n, 1)``.
    """
    return [
        (np.sin(np.deg2rad(upper)) - np.sin(np.deg2rad(lower))).reshape((-1, 1))
        for upper, lower in bounds
    ]
def _map_longitude(src: np.ndarray, dst: np.ndarray) -> tuple[list, list]:
    """
    Map source to destination longitude.

    The source bounds are first aligned with the destination axis. For each
    destination cell, the overlapping (aligned) source cells are collected
    and each one is weighted by the width of its overlap: the minimum of the
    eastern bounds minus the maximum of the western bounds. The collected
    indexes are then shifted back so they index the original source axis.

    Parameters
    ----------
    src : np.ndarray
        Array containing source longitude bounds.
    dst : np.ndarray
        Array containing destination longitude bounds.

    Returns
    -------
    tuple[list, list]
        A tuple of cell mappings and cell weights.
    """
    src_west, src_east = _extract_bounds(src)
    dst_west, dst_east = _extract_bounds(dst)

    # align source and destination longitude
    shifted_src_west, shifted_src_east, shift = _align_axis(
        src_west,
        src_east,
        dst_west,
    )

    src_length = src_west.shape[0]

    mapping = []
    weights = []

    for west, east in zip(dst_west, dst_east):
        # indexes (into the aligned source axis) of the overlapping cells
        overlaps = (shifted_src_west < east) & (shifted_src_east > west)
        contributing = np.where(overlaps)[0]

        # overlap width of each contributing cell, stored as a row vector
        overlap = np.minimum(east, shifted_src_east[contributing]) - np.maximum(
            west, shifted_src_west[contributing]
        )
        weights.append(overlap.reshape((1, -1)))

        # undo the alignment shift so the indexes refer to the original
        # source axis, wrapping any index that runs past its end
        contributing += shift
        contributing[contributing > src_length - 1] -= src_length
        mapping.append(contributing)

    return mapping, weights
def _extract_bounds(bounds: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Split a 2D bounds array into its lower and upper columns.

    The first row decides which column is the lower one.

    Parameters
    ----------
    bounds : np.ndarray
        A numpy array of bounds values, one ``(a, b)`` pair per row.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The lower and upper bounds for the axis, both as float32.
    """
    first_is_lower = bounds[0, 0] < bounds[0, 1]
    lower_col, upper_col = (0, 1) if first_is_lower else (1, 0)

    lower = bounds[:, lower_col].astype(np.float32)
    upper = bounds[:, upper_col].astype(np.float32)

    return lower, upper
def _align_axis(
    src_west: np.ndarray,
    src_east: np.ndarray,
    dst_west: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Aligns a source and destination longitude axis.

    The source cells are rotated by ``shift`` positions and offset by whole
    multiples of 360 degrees relative to the most western destination bound.
    The returned arrays hold ``len(src_west) + 1`` entries because the
    rotation index list is one longer than the source axis and wraps around,
    so one source cell appears twice.

    Parameters
    ----------
    src_west : np.ndarray
        Array containing the western source bounds.
    src_east : np.ndarray
        Array containing the eastern source bounds.
    dst_west : np.ndarray
        Array containing the western destination bounds.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, int]
        A tuple containing the shifted western source bounds, the shifted eastern
        source bounds, and the number of places shifted to align axis.

    Raises
    ------
    IndexError
        If every source cell has a relative position equal to the alignment
        index, since no first differing cell can be selected.
    """
    # find smallest western bounds (only the first and last destination cells
    # are compared)
    west_most = np.minimum(dst_west[0], dst_west[-1])

    # perturbed number of whole 360 degree turns between the last source
    # western bound and the most western destination bound
    alignment_index = _vpertub((west_most - src_west[-1]) / 360.0)

    # shift index depending on first/last source bounds: +1 when the source
    # western bounds increase from first to last, otherwise -1
    alignment_index = (
        alignment_index + 1 if src_west[0] < src_west[-1] else alignment_index - 1
    )

    # find relative indexes for each source cell to the destinations most western cell
    relative_postition = _vpertub((west_most - src_west) / 360.0)

    # index of the first source cell whose relative position differs from the
    # alignment index
    src_alignment_index = np.where(relative_postition != alignment_index)[0][0]

    # determine the shift factor required to align source and destination bounds
    if src_west[0] < src_west[-1]:
        if west_most == src_west[src_alignment_index]:
            shift = src_alignment_index
        else:
            shift = src_alignment_index - 1

            # stepping back from index 0 wraps around to the last source cell
            if shift < 0:
                shift = src_west.shape[0] - 1
    else:
        shift = src_alignment_index

    src_length = src_west.shape[0]

    # shift the source index values; note there are src_length + 1 of them
    shifted_indexes = np.arange(src_length + 1) + shift

    # find index values that need to be shifted to be within 0 - src_length
    wrapped = np.where(shifted_indexes > src_length - 1)

    # shift the indexes
    shifted_indexes[wrapped] -= src_length

    # reorder src_west and add portion to align
    shifted_src_west = (
        src_west[shifted_indexes] + 360.0 * relative_postition[shifted_indexes]
    )

    # reorder src_east and add portion to align
    shifted_src_east = (
        src_east[shifted_indexes] + 360.0 * relative_postition[shifted_indexes]
    )

    # handle ends of each interval: pull the first cell (increasing source) or
    # the last cell (otherwise) back by 360 degrees when it starts east of the
    # most western destination bound
    if src_west[-1] > src_west[0]:
        if shifted_src_west[0] > west_most:
            shifted_src_west[0] += -360.0
            shifted_src_east[0] += -360.0
    else:
        if shifted_src_west[-1] > west_most:
            shifted_src_west[-1] += -360.0
            shifted_src_east[-1] += -360.0

    return shifted_src_west, shifted_src_east, shift
def _pertub(value: np.ndarray) -> np.ndarray:
    """
    Pertub a value.

    Nudges the value away from zero by a small constant and rounds it to a
    whole number: non-negative values are rounded up, negative values are
    rounded down and then increased by one.

    Parameters
    ----------
    value : np.ndarray
        Value to pertub.

    Returns
    -------
    np.ndarray
        Value that's been pertubed.
    """
    nudge = 0.000001

    if value >= 0.0:
        return np.ceil(value + nudge)

    return np.floor(value - nudge) + 1.0
# Vectorized version of `_pertub`, so it can be applied to every element of an
# array (the scalar version branches on the sign of a single value).
_vpertub = np.vectorize(_pertub)
def _get_dimension(input_data_var, cf_axis_name):
    """
    Look up the dimension of a data variable for a CF axis.

    Parameters
    ----------
    input_data_var
        The data variable to inspect; its ``dims`` are searched.
    cf_axis_name
        The CF axis name passed to ``get_dim_keys`` (e.g., "X" or "Y").

    Returns
    -------
    tuple
        The dimension name and its position within ``input_data_var.dims``.
    """
    dim_name = get_dim_keys(input_data_var, cf_axis_name)

    return dim_name, input_data_var.dims.index(dim_name)
def _get_bounds_ensure_dtype(ds, axis):
    """
    Fetch the bounds of an axis as float32 values.

    Parameters
    ----------
    ds
        Object exposing ``ds.bounds.get_bounds(axis)``.
    axis
        The axis key whose bounds are requested (e.g., "X" or "Y").

    Returns
    -------
    The ``values`` of the bounds, cast to float32 when they have another
    dtype.

    Raises
    ------
    RuntimeError
        If the bounds lookup raises ``KeyError`` or returns ``None``.
    """
    try:
        bounds = ds.bounds.get_bounds(axis)
    except KeyError:
        # a failed lookup is reported the same way as missing bounds
        bounds = None

    if bounds is None:
        raise RuntimeError(f"Could not determine {axis!r} bounds")

    already_float32 = not (bounds.dtype != np.float32)

    return (bounds if already_float32 else bounds.astype(np.float32)).values