Source code for odc.geo._rgba

""" Helpers for dealing with RGB(A) images.
"""

import functools
from typing import Any, List, Optional, Tuple

import numpy as np
import xarray as xr

from ._interop import is_dask_collection
from .types import Nodata

# pylint: disable=import-outside-toplevel


def is_rgb(x: xr.DataArray):
    """
    Tell whether ``x`` looks like an RGB(A) image.

    The answer is ``True`` only when all of the following hold:

    * dtype is ``uint8``
    * there are at least three dimensions
    * the trailing dimension has 3 (RGB) or 4 (RGBA) entries

    :param x: Array to inspect, only ``dtype``, ``ndim`` and ``shape`` are consulted
    :return: ``True`` or ``False``
    """
    # ``and`` short-circuits, so ``shape[-1]`` is only consulted once the
    # dimensionality check has passed.
    return bool(x.dtype == "uint8" and x.ndim >= 3 and x.shape[-1] in (3, 4))


def _guess_rgb_names(bands: List[str]) -> Tuple[str, str, str]:
    def _candidate(color: str) -> str:
        candidates = [name for name in bands if color in name]
        n = len(candidates)
        if n == 1:
            return candidates[0]

        if n == 0:
            raise ValueError(f'Found no candidate for color "{color}"')
        raise ValueError(f'Found too many candidates for color "{color}"')

    r, g, b = [_candidate(c) for c in ("red", "green", "blue")]
    return (r, g, b)


def _auto_guess_clamp(ds: xr.Dataset) -> Tuple[float, float]:
    # TODO: deal with nodata > 0 case
    return (float(0), max(x.data.max() for x in ds.data_vars.values()))


def _to_u8(x: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    x = np.clip(x, vmin, vmax)

    if x.dtype.kind == "f":
        x = (x - vmin) * (255.0 / (vmax - vmin))
    else:
        x = (x - vmin).astype("uint32") * 255 // (vmax - vmin)
    return x.astype("uint8")


def _np_to_rgba(
    r: np.ndarray,
    g: np.ndarray,
    b: np.ndarray,
    nodata: Nodata,
    vmin: float,
    vmax: float,
) -> np.ndarray:
    """
    Stack three bands into a ``uint8`` RGBA array with a trailing axis of 4.

    Colour planes are produced by rescaling each band with ``_to_u8``. The
    alpha plane is computed from the red band only: a pixel is opaque (``255``)
    unless red is ``NaN`` (floating point input) or equal to ``nodata``, in
    which case it is fully transparent (``0``).

    :param r: Red band
    :param g: Green band
    :param b: Blue band
    :param nodata: Value marking missing pixels, ``None`` when there is none
    :param vmin: Input value mapped to ``0``
    :param vmax: Input value mapped to ``255``
    :return: Array of shape ``(*r.shape, 4)``
    """
    if r.dtype.kind != "f":
        if nodata is None:
            # integer data without nodata: every pixel is valid
            opaque = np.ones(r.shape, dtype=np.bool_)
        else:
            opaque = r != nodata
    else:
        opaque = ~np.isnan(r)
        # NaN nodata is already covered by the isnan test above
        if nodata is not None and not np.isnan(nodata):
            opaque = opaque & (r != nodata)

    out = np.zeros((*r.shape, 4), dtype="uint8")
    for plane, src in enumerate((r, g, b)):
        out[..., plane] = _to_u8(src, vmin, vmax)
    out[..., 3] = opaque.astype("uint8") * (0xFF)

    return out


def to_rgba(
    ds: Any,
    bands: Optional[Tuple[str, str, str]] = None,
    *,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> xr.DataArray:
    """
    Convert dataset to RGBA image.

    Given :py:class:`xarray.Dataset` with bands ``red,green,blue`` construct
    :py:class:`xarray.DataArray` containing ``uint8`` rgba image. Output has an
    extra trailing ``band`` dimension labelled ``r, g, b, a``; nodata value and
    coordinates are taken from the first (red) band.

    :param ds: :py:class:`xarray.Dataset`
    :param vmin: Defaults to ``0`` when ``vmax`` is supplied.
    :param vmax: Configure range, must be supplied for Dask inputs. When not
       configured ``vmin=0, vmax=max(r,g,b)`` is used.
    :param bands: Which bands to use, order should be red,green,blue
    :raises ValueError: When ``bands`` is not supplied and can not be guessed
       from band names, or when ``vmax`` is missing for Dask inputs
    """
    # pylint: disable=too-many-locals
    assert isinstance(ds, xr.Dataset)

    if bands is None:
        try:
            bands = _guess_rgb_names(list(ds.data_vars))
        except ValueError as e:
            raise ValueError(
                f"Unable to automatically guess RGB colours ({e}). "
                f"Manually specify red, green and blue bands using the "
                f"`bands` parameter."
            ) from e

    is_dask = is_dask_collection(ds)

    if vmin is None:
        if vmax is not None:
            vmin = 0

    if vmax is None:
        if is_dask:
            raise ValueError("Must specify clamp for Dask inputs (e.g. vmax, vmin)")

        _vmin, vmax = _auto_guess_clamp(ds[list(bands)])
        # user supplied vmin takes precedence over the guessed one
        vmin = _vmin if vmin is None else vmin

    assert vmin is not None
    assert vmax is not None

    _b = ds[bands[0]]
    nodata = _b.odc.nodata
    dims = (*_b.dims, "band")

    r, g, b = (ds[name].data for name in bands)
    if is_dask:
        # pylint: disable=import-outside-toplevel
        from dask import array as da
        from dask.base import tokenize

        assert _b.chunks is not None
        # Everything that changes the output goes into the token, otherwise
        # two calls on the same bands with a different clamp/nodata would
        # share graph keys and could be computed as one and the same result.
        # NOTE(review): ``ro_rgba`` prefix reads like a typo of ``to_rgba``,
        # left unchanged here.
        data = da.map_blocks(
            _np_to_rgba,
            r,
            g,
            b,
            nodata,
            vmin,
            vmax,
            name=f"ro_rgba-{tokenize(r, g, b, nodata, vmin, vmax)}",
            dtype=np.uint8,
            chunks=(*_b.chunks, (4,)),
            new_axis=[r.ndim],
        )
    else:
        data = _np_to_rgba(r, g, b, nodata, vmin, vmax)

    coords = dict(_b.coords.items())
    coords.update(band=xr.DataArray(data=["r", "g", "b", "a"], dims=("band",)))

    rgba = xr.DataArray(data, coords=coords, dims=dims)
    return rgba
def _np_colorize(x, cmap, clip): if x.dtype == "bool": x = x.astype("uint8") if clip: x = np.clip(x, 0, cmap.shape[0] - 1) return cmap[x] def _matplotlib_colorize( x, cmap, vmin=None, vmax=None, nodata: Nodata = None, robust=False, ): from matplotlib import colormaps from matplotlib.colors import Normalize if cmap is None or isinstance(cmap, str): # None is a valid input, maps to default cmap cmap = colormaps.get_cmap(cmap) # type: ignore if nodata is not None: x = np.where(x == nodata, np.float32("nan"), x) if robust: if x.dtype.kind != "f": x = x.astype("float32") _vmin, _vmax = np.nanpercentile(x, [2, 98]) # do not override configured values if vmin is None: vmin = _vmin if vmax is None: vmax = _vmax elif x.dtype.kind == "f": if vmin is None: vmin = np.nanmin(x) if vmax is None: vmax = np.nanmax(x) return cmap(Normalize(vmin=vmin, vmax=vmax)(x), bytes=True)
def colorize(
    x: Any,
    cmap=None,
    attrs=None,
    *,
    clip: bool = False,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    robust: Optional[bool] = None,
) -> xr.DataArray:
    """
    Apply colormap to data.

    There are two modes of operation:

    * Map categorical values from ``x`` to RGBA according to ``cmap`` lookup table.
    * Interpolate into RGBA using matplotlib colormaps (needs matplotlib installed)

    .. note::

       When using matplotlib colormaps with Dask inputs one must configure
       vmin/vmax to ensure all chunks are colorized consistently.

    :param x: Input xarray data array (can be Dask)
    :param cmap: Lookup table ``cmap[x] -> RGB(A)`` or matplotlib colormap
    :param vmin: Valid range to colorize
    :param vmax: Valid range to colorize
    :param robust: Use percentiles for clamping ``vmin=2%, vmax=98%``
    :param attrs: xarray attributes table, if not supplied input attributes are copied across
    :param clip: If ``True`` clip values from ``x`` to be in the safe range for ``cmap``.
    :return: Array with an extra trailing ``band`` dimension, 3 or 4 bands for
       lookup tables (second dimension of ``cmap``), 4 bands for matplotlib
       colormaps
    """
    # pylint: disable=too-many-locals
    from ._xr_interop import ODCExtensionDa

    assert isinstance(x, xr.DataArray)
    assert isinstance(x.odc, ODCExtensionDa)

    _is_dask = is_dask_collection(x.data)

    if isinstance(cmap, np.ndarray):
        # lookup table mode
        assert cmap.ndim == 2
        assert cmap.shape[1] in (3, 4)
        cmap_dtype = cmap.dtype
        _impl = functools.partial(_np_colorize, clip=clip)
        nc = cmap.shape[1]
    else:
        # Assume matplotlib
        # default robust=True for float, non-dask inputs when vmin/vmax/robust are not configured
        if (
            vmin is None
            and vmax is None
            and robust is None
            and x.dtype.kind == "f"
            and not _is_dask
        ):
            robust = True
        elif robust is None:
            robust = False

        _impl = functools.partial(
            _matplotlib_colorize,
            vmin=vmin,
            vmax=vmax,
            nodata=x.odc.nodata,
            robust=robust,
        )
        nc, cmap_dtype = 4, "uint8"

    if attrs is None:
        # copy input attributes, minus ``nodata``
        attrs = {**x.attrs}
        attrs.pop("nodata", None)

    dims = (*x.dims, "band")
    coords = dict(x.coords.items())
    coords["band"] = xr.DataArray(data=["r", "g", "b", "a"][:nc], dims=("band",))

    if _is_dask:
        from dask import array as da
        from dask import delayed
        from dask.base import tokenize

        # lookup table is wrapped with ``delayed`` before being handed to
        # ``map_blocks``, anything else is passed through as is
        _cmap = delayed(cmap) if isinstance(cmap, np.ndarray) else cmap
        assert x.chunks is not None
        data = da.map_blocks(
            _impl,
            x.data,
            _cmap,
            name=f"colorize-{tokenize(x, _cmap, clip, vmin, vmax, robust)}",
            meta=np.ndarray((), cmap_dtype),
            chunks=(*x.chunks, (nc,)),
            new_axis=[x.data.ndim],
        )
    else:
        data = _impl(x.data, cmap)

    return xr.DataArray(data=data, dims=dims, coords=coords, attrs=attrs)
def replace_transparent_pixels(
    rgba: np.ndarray, color: Tuple[int, int, int] = (255, 0, 255)
) -> np.ndarray:
    """
    Convert RGBA to RGB.

    Replaces transparent pixels with a given color. Only pixels with alpha
    equal to ``0`` count as transparent, any other alpha value is dropped
    without blending.

    :param rgba: 3 dimensional array with 4 channels on the last axis
    :param color: Replacement ``(r, g, b)`` for transparent pixels
    :return: New array with 3 channels on the last axis, input is not modified
    """
    assert rgba.ndim == 3
    assert rgba.shape[-1] == 4

    m = rgba[..., -1] == 0
    # copy, so that the caller's array is left untouched
    rgb = rgba[..., :3].copy()
    rgb[m] = color
    return rgb