Source code for xcdat.mask

from typing import Any, Callable

import numpy as np
import regionmask
import xarray as xr

from xcdat import open_dataset
from xcdat._data import _get_pcmdi_mask_path
from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
from xcdat.regridder.accessor import _obj_to_grid_ds
from xcdat.regridder.grid import create_grid
from xcdat.utils import _as_dataarray

# Module-level logger built by the project's custom logger setup helper.
logger = _setup_custom_logger(__name__)

# Mask-generation methods accepted by ``generate_land_sea_mask``.
VALID_METHODS: list[str] = ["regionmask", "pcmdi"]
# Accepted values for the ``keep`` argument of
# ``generate_and_apply_land_sea_mask``.
VALID_KEEP: list[str] = ["land", "sea"]


def generate_and_apply_land_sea_mask(
    ds: xr.Dataset,
    data_var: str,
    method: str = "regionmask",
    keep: str = "sea",
    threshold: float | None = None,
    mask: xr.DataArray | None = None,
    output_mask: bool | str = False,
    **options: Any,
) -> xr.Dataset:
    """Generate a land-sea mask and apply it to a data variable in a dataset.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset to mask. It is copied with ``ds.copy()`` and the copy is
        modified and returned.
    data_var : str
        The key of the data variable to mask.
    method : str, optional
        The masking method, by default "regionmask".
        Supported methods: "regionmask", "pcmdi". Only used when ``mask`` is
        None.
    keep : str, optional
        Whether to keep "land" or "sea" points, by default "sea".
    threshold : float | None, optional
        The threshold used to determine cell classification. If None (the
        default), the value depends on ``keep``: 0.2 for "sea" and 0.8 for
        "land". With ``keep="sea"`` cells whose mask value is less than or
        equal to the threshold are kept; with ``keep="land"`` cells whose
        mask value is greater than or equal to the threshold are kept. An
        explicit ``0.0`` is honored and is not replaced by the default.
    mask : xr.DataArray | None, optional
        A custom mask to apply, by default None. If None, a mask is
        generated with :func:`generate_land_sea_mask` using the specified
        ``method`` and ``options``.
    output_mask : bool | str, optional
        Whether to add the mask to the returned dataset, by default False.
        If a string, it is used as the name of the mask variable. If True,
        the mask is stored under ``f"{data_var}_mask"``.
    **options : Any
        These options are passed directly to the ``method``. See
        :func:`xcdat.mask.pcmdi_land_sea_mask` for PCMDI options.

    Returns
    -------
    xr.Dataset
        The dataset with the masked data variable.

    Raises
    ------
    ValueError
        If `keep` is not "land" or "sea".

    Examples
    --------

    Mask out land points (keep sea) using the default method (regionmask):

    >>> ds_masked = generate_and_apply_land_sea_mask(ds, "tas", keep="sea")

    Mask out sea points (keep land) using the PCMDI method with a custom
    threshold:

    >>> ds_masked = generate_and_apply_land_sea_mask(
    ...     ds, "tas", method="pcmdi", keep="land", threshold=0.7
    ... )

    Mask out land points using a custom mask and output the mask:

    >>> custom_mask = xr.DataArray(...)  # Define your custom mask here
    >>> ds_masked = generate_and_apply_land_sea_mask(
    ...     ds, "tas", keep="sea", mask=custom_mask, output_mask=True
    ... )

    Mask out sea points and add the mask to the dataset with a custom name:

    >>> ds_masked = generate_and_apply_land_sea_mask(
    ...     ds, "tas", keep="land", output_mask="land_mask"
    ... )
    """
    if keep not in VALID_KEEP:
        raise ValueError(
            f"Keep value {keep!r} is not valid, options are {', '.join(VALID_KEEP)!r}"
        )

    _ds = ds.copy()

    da = _ds[data_var]

    if mask is None:
        mask = generate_land_sea_mask(da, method, **options)

    # Resolve the default with an explicit ``is None`` check: ``threshold or
    # default`` would discard a legitimate user-supplied threshold of 0.0.
    if keep == "sea":
        cutoff = 0.2 if threshold is None else threshold
        _ds[data_var] = da.where(mask <= cutoff)
    else:
        cutoff = 0.8 if threshold is None else threshold
        _ds[data_var] = da.where(mask >= cutoff)

    if output_mask:
        # A string selects the variable name; any other truthy value falls
        # back to a name derived from the masked variable.
        if isinstance(output_mask, str):
            mask_name = output_mask
        else:
            mask_name = f"{data_var}_mask"

        _ds[mask_name] = mask

    return _ds


def generate_land_sea_mask(
    da: xr.DataArray, method: str = "regionmask", **options: Any
) -> xr.DataArray:
    """Build a land-sea mask on the grid of a DataArray.

    Parameters
    ----------
    da : xr.DataArray
        The DataArray whose X/Y coordinates define the mask grid.
    method : str, optional
        The mask-generation method, by default "regionmask".
        Supported methods: "regionmask", "pcmdi".
    **options : Any
        Forwarded unchanged to :func:`pcmdi_land_sea_mask` when ``method`` is
        "pcmdi". They are not used by the "regionmask" method.

    Returns
    -------
    xr.DataArray
        The land-sea mask.

    Raises
    ------
    ValueError
        If `method` is not "regionmask" or "pcmdi".

    References
    ----------
    .. _PCMDI's report #58: https://pcmdi.llnl.gov/report/ab58.html

    History
    -------
    2023-06 The [original code](https://github.com/CDAT/cdutil/blob/master/
    cdutil/create_landsea_mask.py) was rewritten using xarray and xcdat by Jiwoo Lee

    Examples
    --------

    Generate a land-sea mask using the default method (regionmask):

    >>> import xcdat
    >>> ds = xcdat.open_dataset("/path/to/file")
    >>> mask = xcdat.mask.generate_land_sea_mask(ds["tas"], method="regionmask")

    Generate a land-sea mask using the PCMDI method with custom options:

    >>> mask = xcdat.mask.generate_land_sea_mask(
    ...     ds["tas"], method="pcmdi", threshold1=0.3, threshold2=0.4
    ... )
    """
    if method not in VALID_METHODS:
        raise ValueError(
            f"Method value {method!r} is not valid, options are {', '.join(VALID_METHODS)!r}"
        )

    if method == "pcmdi":
        return pcmdi_land_sea_mask(da, **options)

    # Remaining valid method: "regionmask" (Natural Earth v5.0.0, 1:110m land).
    land_regions = regionmask.defined_regions.natural_earth_v5_0_0.land_110

    x_coord = get_dim_coords(da, "X")
    y_coord = get_dim_coords(da, "Y")
    region_ids = land_regions.mask(x_coord, lat=y_coord)

    # Truthy cells of the regionmask output become 0, falsy cells become 1.
    # NOTE(review): presumably regionmask labels land cells with region number
    # 0 (falsy) and leaves all other cells as NaN (truthy), which would make
    # the result 1 over land and 0 over sea -- confirm against regionmask docs.
    return xr.where(region_ids, 0, 1)


def pcmdi_land_sea_mask(
    da: xr.DataArray,
    threshold1: float = 0.2,
    threshold2: float = 0.3,
    source: xr.Dataset | None = None,
    source_data_var: str | None = None,
) -> xr.DataArray:
    """Generate a land-sea mask using the PCMDI method.

    A high-resolution land-sea source field is regridded (with the "regrid2"
    tool) onto the grid of ``da`` and binarized at 0.5. The binary mask is
    then refined repeatedly until it stops changing, for at most 26 passes.

    Parameters
    ----------
    da : xr.DataArray
        The DataArray to generate the mask for.
    threshold1 : float, optional
        The first threshold for improving the mask, by default 0.2.
    threshold2 : float, optional
        The second threshold for improving the mask, by default 0.3.
    source : xr.Dataset | None, optional
        A custom Dataset containing the variable to use as the
        high-resolution source. If not provided, a default high-resolution
        land-sea mask is used.
    source_data_var : str | None, optional
        The name of the variable in `source` to use as the high-resolution
        source. Required when `source` is given. When `source` is not
        provided this value is set to "sftlf".

    Returns
    -------
    xr.DataArray
        The generated land-sea mask.

    Raises
    ------
    ValueError
        If `source` is provided but `source_data_var` is None.

    Notes
    -----
    By default, the ``navy_land.nc`` file is used as the high-resolution
    land-sea mask. This file is distributed by the PCMDI (Program for Climate
    Model Diagnosis and Intercomparison) Metrics Package [1]_, and is derived
    from the U.S. Navy 1/6 degree land-sea mask dataset.

    If ``source`` is not provided, the ``navy_land.nc`` file is automatically
    downloaded and cached from the `xcdat-data` repository:
    https://github.com/xCDAT/xcdat-data. For more information on caching
    behavior, refer to the :py:func:`xcdat._data._get_pcmdi_mask_path()`
    function.

    References
    ----------
    .. [1] https://github.com/PCMDI/pcmdi_metrics/blob/main/

    Examples
    --------

    Generate a land-sea mask using the PCMDI method:

    >>> import xcdat
    >>> ds = xcdat.open_dataset("/path/to/file")
    >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"])

    Generate a land-sea mask using the PCMDI method with custom thresholds:

    >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(
    ...     ds["tas"], threshold1=0.3, threshold2=0.4
    ... )

    Generate a land-sea mask using the PCMDI method with a custom high-res
    source:

    >>> highres_ds = xcdat.open_dataset("/path/to/file")
    >>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(
    ...     ds["tas"], source=highres_ds, source_data_var="highres"
    ... )

    For offline workflows, you can pre-download the mask with:

    >>> from xcdat._data import _get_pcmdi_mask_path
    >>> path = _get_pcmdi_mask_path()
    """
    if source is None:
        source_data_var = "sftlf"
        default_path = _get_pcmdi_mask_path()
        # The default file has no time axis; skipping time decoding avoids a
        # logger warning.
        source = open_dataset(default_path, decode_times=False)
    elif source_data_var is None:
        raise ValueError(
            "The 'source_data_var' value cannot be None when using the 'source' option."
        )

    target_grid = _obj_to_grid_ds(da)
    source_regrid = source.regridder.horizontal(
        source_data_var, target_grid, tool="regrid2"
    )

    # keep_attrs="drop_conflicts" keeps the data variable's existing attributes
    # without copying any from the scalar `x` argument (1). The defaults
    # (None/False) would drop every attribute, and True would take them from x.
    binary = xr.where(
        source_regrid[source_data_var] > 0.5, 1, 0, keep_attrs="drop_conflicts"
    )
    mask = source_regrid.copy()
    mask[source_data_var] = binary.astype("i")

    is_circular = _is_circular(
        mask[source_data_var].cf["X"], mask.bounds.get_bounds("X")
    )
    surrounds = _generate_surrounds(mask[source_data_var], is_circular)

    # Refine until a pass leaves the mask unchanged, capped at 26 passes.
    for iteration in range(1, 27):
        logger.debug("Iteration %i", iteration)

        improved_mask = _improve_mask(
            mask.copy(deep=True),
            source_regrid,
            source_data_var,  # type: ignore[arg-type]
            surrounds,
            is_circular,
            threshold1,
            threshold2,
        )

        if improved_mask.equals(mask):
            break

        mask = improved_mask

    return mask[source_data_var]


def _is_circular(lon: xr.DataArray, lon_bnds: xr.DataArray) -> bool:
    """Report whether a longitude axis wraps around the globe.

    The axis is treated as circular when both of the following hold:

    * stepping one grid spacing past the last point lands on the first point
      plus 360, to within 1% of that spacing; and
    * the span from the first lower bound to the last upper bound is an exact
      multiple of 360.

    Parameters
    ----------
    lon : xr.DataArray
        The longitude coordinates.
    lon_bnds : xr.DataArray
        The longitude bounds.

    Returns
    -------
    bool
        True if the longitude axis is circular, False otherwise.
    """
    first_lon = float(lon[0])
    last_lon = float(lon[-1])
    # Grid spacing is taken from the final two points.
    spacing = float(lon[-1] - lon[-2])

    wrap_error = abs(last_lon + spacing - first_lon - 360.0)
    max_wrap_error = 0.01 * spacing

    bounds_remainder = float(lon_bnds[-1][1] - lon_bnds[0][0]) % 360

    return wrap_error < max_wrap_error and bounds_remainder == 0


def _improve_mask(
    mask: xr.Dataset,
    source: xr.Dataset,
    data_var: str,
    surrounds: list[np.ndarray],
    is_circular: bool,
    threshold1=0.2,
    threshold2=0.3,
) -> xr.Dataset:
    """Run one refinement pass over a land-sea mask.

    The mask is approximated with :func:`_map2four` and subtracted from the
    source field. Two conversion passes follow: first towards land using
    ``threshold1`` / ``threshold2``, then towards sea using ``-threshold1`` /
    ``1.0 - threshold2``. The result is written back into ``mask`` as an
    integer variable.

    Parameters
    ----------
    mask : xr.Dataset
        The mask to improve. Its ``data_var`` variable is overwritten.
    source : xr.Dataset
        The source dataset for comparison.
    data_var : str
        The name of the data variable in the mask and source.
    surrounds : list[np.ndarray]
        The eight neighbor arrays passed through to :func:`_convert_points`.
    is_circular : bool
        Whether the longitude axis is circular.
    threshold1 : float, optional
        The first threshold for conversion, by default 0.2.
    threshold2 : float, optional
        The second threshold for conversion, by default 0.3.

    Returns
    -------
    xr.Dataset
        The improved mask (the same object as ``mask``).
    """
    source_values = source[data_var]
    residual = source_values - _map2four(mask, data_var)[data_var]

    # Multiplying by 1.0 hands _convert_points a float copy, so its in-place
    # edits do not touch the integer variable stored in `mask`.
    after_land_pass = _convert_points(
        mask[data_var] * 1.0,
        source_values,
        residual,
        threshold1,
        threshold2,
        is_circular,
        surrounds,
    )
    after_sea_pass = _convert_points(
        after_land_pass,
        source_values,
        residual,
        -threshold1,
        1.0 - threshold2,
        is_circular,
        surrounds,
        convert_land=False,
    )

    mask[data_var] = after_sea_pass.astype("i")

    return mask


def _map2four(mask: xr.Dataset, data_var: str) -> xr.Dataset:
    """Regrid a mask onto four interleaved sub-grids and stitch them together.

    Each sub-grid takes every other latitude and every other longitude point,
    starting at offset 0 or 1 along each axis. The mask is regridded onto each
    sub-grid with "regrid2" and the four results are interleaved back into an
    array with the shape of the original variable.

    Parameters
    ----------
    mask : xr.Dataset
        The mask to process.
    data_var : str
        The name of the data variable in the mask.

    Returns
    -------
    xr.Dataset
        A copy of ``mask`` whose ``data_var`` holds the stitched values.
    """
    result = mask.copy()

    variable = result[data_var]
    lat, lon = variable.cf["Y"], variable.cf["X"]

    stitched = np.zeros(variable.shape, dtype="f")

    # Offsets (0, 0), (0, 1), (1, 0), (1, 1) select the four sub-grids.
    for row_offset in (0, 1):
        for col_offset in (0, 1):
            subgrid = create_grid(
                y=lat[row_offset::2], x=lon[col_offset::2], add_bounds=True
            )
            regridded = result.regridder.horizontal(
                data_var, subgrid, tool="regrid2"
            )
            stitched[row_offset::2, col_offset::2] = regridded[data_var].data

    result[data_var] = (variable.dims, stitched)

    return result


def _convert_points(
    mask: xr.DataArray,
    source: xr.DataArray,
    diff: xr.DataArray,
    threshold1: float,
    threshold2: float,
    is_circular: bool,
    surrounds: list[np.ndarray],
    convert_land=True,
) -> xr.DataArray:
    """Flip selected interior mask cells to land or to sea.

    A cell is a *candidate* when ``diff`` and ``source`` both pass the
    comparison (greater-than for land, less-than for sea) against
    ``threshold1`` and ``threshold2`` respectively. A candidate is converted
    only if its source value also passes the same comparison against a
    reference value for each of its eight neighbors. That reference is the
    neighbor's entry in ``surrounds`` when the neighbor is itself a candidate,
    and the opposite mask value otherwise.

    Only interior cells are modified: the first and last rows are always
    skipped, and the first and last columns too when the axis is not circular.

    Parameters
    ----------
    mask : xr.DataArray
        The mask to modify. It is updated in place.
    source : xr.DataArray
        The source data for comparison.
    diff : xr.DataArray
        The difference between the source and an approximated mask.
    threshold1 : float
        Threshold for points in the `diff` DataArray.
    threshold2 : float
        Threshold for points in the `source` DataArray.
    is_circular : bool
        Whether the longitude axis is circular.
    surrounds : list[np.ndarray]
        The eight neighbor arrays, ordered UL, UC, UR, ML, MR, LL, LC, LR.
    convert_land : bool, optional
        Whether to convert points to land (True) or sea (False), by default
        True.

    Returns
    -------
    xr.DataArray
        The modified mask (the same object as ``mask``).
    """
    ul, uc, ur, ml, mr, ll, lc, lr = surrounds
    neighbor_values = (ul, uc, ur, ml, mr, ll, lc, lr)

    compare: Callable
    if convert_land:
        compare, target_value = np.greater, 1.0
    else:
        compare, target_value = np.less, 0.0
    # Reference value used for neighbors that are not candidates themselves.
    fallback_value = abs(target_value - 1.0)

    candidates = _as_dataarray(
        np.logical_and(compare(diff, threshold1), compare(source, threshold2))
    )
    neighbor_is_candidate = _generate_surrounds(candidates, is_circular)

    # Interior selection: rows only for a circular axis, rows and columns
    # otherwise.
    if is_circular:
        interior: tuple = (slice(1, -1),)
    else:
        interior = (slice(1, -1), slice(1, -1))

    convertible = candidates[interior]
    source_interior = source.data[interior]

    for is_candidate, value in zip(neighbor_is_candidate, neighbor_values):
        reference = np.where(is_candidate, value, fallback_value)
        convertible = np.logical_and(
            convertible, compare(source_interior, reference)
        )

    mask[interior] = xr.where(convertible, target_value, mask[interior])

    return mask


def _generate_surrounds(da: xr.DataArray, is_circular: bool) -> list[np.ndarray]:
    """Collect the eight neighbor arrays of a DataArray's interior cells.

    Rows at offsets +1, 0 and -1 (labelled upper, middle and lower) are
    combined with columns at offsets -1, 0 and +1 (left, center, right). For a
    circular axis the column offsets are applied with ``np.roll`` along axis 1
    so every column is kept; otherwise the columns are sliced and the first
    and last are dropped.

    Parameters
    ----------
    da : xr.DataArray
        The input DataArray.
    is_circular : bool
        Whether the longitude axis is circular.

    Returns
    -------
    list[np.ndarray]
        Eight arrays ordered UL, UC, UR, ML, MR, LL, LC, LR.
    """
    values = da.data

    upper = values[2:, :]
    middle = values[1:-1, :]
    lower = values[:-2, :]

    if is_circular:
        # Wrap around in longitude instead of dropping the edge columns.
        def left(rows):
            return np.roll(rows, 1, axis=1)

        def center(rows):
            return rows

        def right(rows):
            return np.roll(rows, -1, axis=1)

    else:

        def left(rows):
            return rows[:, :-2]

        def center(rows):
            return rows[:, 1:-1]

        def right(rows):
            return rows[:, 2:]

    return [
        left(upper),
        center(upper),
        right(upper),
        left(middle),
        right(middle),
        left(lower),
        center(lower),
        right(lower),
    ]