# Source code for satpy.enhancements

#!/usr/bin/env python
# Copyright (c) 2017 Satpy developers
#
# This file is part of satpy.
#
# satpy is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# satpy is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# satpy.  If not, see <http://www.gnu.org/licenses/>.
"""Enhancements."""

import numpy as np
import xarray as xr
import dask
import dask.array as da
import logging

LOG = logging.getLogger(__name__)


def stretch(img, **kwargs):
    """Perform stretch.

    Delegates to ``img.stretch``, forwarding every keyword argument
    unchanged, and returns whatever that method returns.
    """
    stretch_method = img.stretch
    return stretch_method(**kwargs)
def gamma(img, **kwargs):
    """Perform gamma correction.

    Delegates to ``img.gamma``, forwarding every keyword argument
    unchanged, and returns whatever that method returns.
    """
    gamma_method = img.gamma
    return gamma_method(**kwargs)
def invert(img, *args):
    """Perform inversion.

    Delegates to ``img.invert``, forwarding every positional argument
    unchanged, and returns whatever that method returns.
    """
    invert_method = img.invert
    return invert_method(*args)
def apply_enhancement(data, func, exclude=None, separate=False,
                      pass_dask=False):
    """Apply `func` to the bands of `data`, modifying `data` in place.

    Args:
        data (xarray.DataArray): Data with a 'bands' coordinate. Its
            ``.data`` and ``.attrs`` are replaced in place and the same
            object is returned.
        func (callable): Function applied to the selected band data. It
            receives an xarray.DataArray, or the underlying array when
            `pass_dask` is True. When `separate` is True it is also called
            with an ``index=`` keyword giving the band's position.
            Attributes on its result are merged into ``data.attrs``.
        exclude (iterable): Band names in the 'bands' dimension to leave
            untouched. Defaults to ``['A']`` when an 'A' band is present,
            otherwise to nothing.
        separate (bool): Apply `func` one band at a time. Default is False.
        pass_dask (bool): Pass the underlying dask array (``.data``) instead
            of the xarray.DataArray; the result is re-wrapped with the
            selection's dims and coords.

    Returns:
        xarray.DataArray: `data` itself, after modification.
    """
    attrs = data.attrs
    bands = data.coords['bands'].values
    if exclude is None:
        exclude = ['A'] if 'A' in bands else []

    def _call(selection, **extra):
        # Run `func` on one selection, unwrapping/re-wrapping for pass_dask.
        if not pass_dask:
            return func(selection, **extra)
        sel_dims = selection.dims
        sel_coords = selection.coords
        result = func(selection.data, **extra)
        return xr.DataArray(result, dims=sel_dims, coords=sel_coords)

    if separate:
        pieces = []
        for position, name in enumerate(bands):
            selection = data.sel(bands=[name])
            if name not in exclude:
                selection = _call(selection, index=position)
                # func is allowed to add attrs
                attrs.update(selection.attrs)
            pieces.append(selection)
        merged = xr.concat(pieces, dim='bands')
    else:
        kept = [name for name in bands if name not in exclude]
        processed = _call(data.sel(bands=kept))
        attrs.update(processed.attrs)
        # Re-attach the excluded bands and restore the original band order.
        merged = xr.concat([processed, data.sel(bands=exclude)],
                           dim='bands').sel(bands=bands)

    data.data = merged.data
    data.attrs = attrs
    return data
# pointed to by generic.yaml
def crefl_scaling(img, **kwargs):
    """Apply the CREFL piecewise-linear stretch to each non-excluded band.

    Every band value is multiplied by 0.01, then interpolated with
    :func:`numpy.interp` from the ``idx`` control points onto the ``sc``
    control points, and finally clipped to [0, 1].

    Args:
        img: Image object; its ``data`` attribute is passed to
            ``apply_enhancement`` and the result of that call is returned.
        idx (sequence): Input control points, on a 0-255 scale.
        sc (sequence): Output control points, on a 0-255 scale.
    """
    LOG.debug("Applying the crefl_scaling")
    # The control points are the same for every band: compute them once.
    idx = np.array(kwargs['idx']) / 255
    sc = np.array(kwargs['sc']) / 255

    def func(band_data, index=None):
        # `index` is supplied by apply_enhancement in separate mode; unused.
        # Scale out of place: avoids mutating the selection handed in and
        # works for integer-typed data (in-place `*= .01` cannot cast).
        scaled = band_data.data * .01
        # Interpolate band on [0,1] using "lazy" arrays (put calculations
        # off until the end).
        stretched = da.clip(scaled.map_blocks(np.interp, xp=idx, fp=sc), 0, 1)
        return xr.DataArray(stretched, coords=band_data.coords,
                            dims=band_data.dims, name=band_data.name,
                            attrs=band_data.attrs)

    return apply_enhancement(img.data, func, separate=True)
def cira_stretch(img, **kwargs):
    """Logarithmic stretch adapted to human vision.

    Applicable only for visible channels. Each value is multiplied by 0.01,
    floored at machine epsilon, log10-transformed, then shifted by
    ``log10(0.0223)`` and divided by ``0.75 * (1 - log10(0.0223))``.
    """
    LOG.debug("Applying the cira-stretch")
    # Band-independent constants of the transform.
    log_root = np.log10(0.0223)
    denominator = (1.0 - log_root) * 0.75

    def func(band_data):
        band_data *= 0.01
        logged = np.log10(band_data.clip(np.finfo(float).eps))
        logged -= log_root
        logged /= denominator
        return logged

    return apply_enhancement(img.data, func)
def _lookup_delayed(luts, band_data):
    """Return ``luts`` indexed by the integer array ``band_data``.

    Kept as a plain module-level function so it can be wrapped with
    ``dask.delayed``; the original note says ``luts.__getitem__`` could
    not be used directly for that.
    """
    looked_up = luts[band_data]
    return looked_up
def lookup(img, **kwargs):
    """Assign values to channels based on a table.

    ``kwargs['luts']`` holds table values on a 0-255 scale; they are
    divided by 255 and stored as float32. A 2-D table is indexed per band
    with ``luts[:, index]`` (``index`` being the band position passed by
    ``apply_enhancement``); a 1-D table is shared by all bands. Band values
    are clipped to the valid table positions, converted to unsigned
    integers and used as indices into the table.
    """
    luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0

    def func(band_data, luts=luts, index=-1):
        lut = luts[:, index] if len(luts.shape) == 2 else luts
        # NaN/null values become index 0. Done explicitly because casting
        # NaN to an integer type is undefined in numpy.
        band_data = da.where(da.isnan(band_data), 0, band_data)
        # Smallest unsigned type that holds every valid index: uint8 for
        # tables of up to 256 entries (as before), wider for larger tables
        # so indices above 255 no longer wrap around.
        index_dtype = np.min_scalar_type(lut.size - 1)
        band_data = band_data.clip(0, lut.size - 1).astype(index_dtype)
        new_delay = dask.delayed(_lookup_delayed)(lut, band_data)
        new_data = da.from_delayed(new_delay, shape=band_data.shape,
                                   dtype=luts.dtype)
        return new_data

    return apply_enhancement(img.data, func, separate=True, pass_dask=True)
def colorize(img, **kwargs):
    """Colorize the given image.

    The palettes in ``kwargs['palettes']`` are merged into a single
    colormap that is handed to ``img.colorize``. Nothing is returned.
    """
    img.colorize(_merge_colormaps(kwargs))
def palettize(img, **kwargs):
    """Palettize the given image (no color interpolation).

    The palettes in ``kwargs['palettes']`` are merged into a single
    colormap that is handed to ``img.palettize``. Nothing is returned.
    """
    img.palettize(_merge_colormaps(kwargs))
def _merge_colormaps(kwargs):
    """Merge colormaps listed in kwargs.

    ``kwargs['palettes']`` is either a ready ``Colormap`` (returned as is)
    or an iterable of palette dicts. Each dict is converted with
    ``create_colormap``, rescaled to its ``min_value``/``max_value`` range
    and concatenated (``+``) onto the previous ones.

    Returns:
        The merged colormap, or None if the iterable is empty.

    Raises:
        ValueError: If a palette dict cannot be converted to a colormap.
        KeyError: If a palette dict lacks ``min_value`` or ``max_value``.
    """
    from trollimage.colormap import Colormap

    palette = kwargs['palettes']
    if isinstance(palette, Colormap):
        return palette

    full_cmap = None
    for itm in palette:
        cmap = create_colormap(itm)
        if cmap is None:
            # create_colormap returns None for unrecognised input; fail with
            # a clear message instead of an AttributeError on set_range.
            raise ValueError("Unknown colormap format: {}".format(itm))
        cmap.set_range(itm["min_value"], itm["max_value"])
        full_cmap = cmap if full_cmap is None else full_cmap + cmap
    return full_cmap
def create_colormap(palette):
    """Create colormap of the given numpy file, color vector or colormap.

    `palette` is a dict-like object; its keys are checked in this order:

    * ``filename``: loaded with :func:`numpy.load`. Row ``i`` supplies the
      color ``(row[0], row[1], row[2]) / 255`` placed at ``i / n``, where
      ``n`` is the number of rows.
    * ``colors`` given as a list/tuple: each entry becomes a color placed at
      the matching entry of ``values`` when that key is present, otherwise
      at ``i / (len(colors) - 1)``.
    * ``colors`` given as a string: the name of a colormap attribute of
      ``trollimage.colormap``; a shallow copy of it is returned.

    Returns:
        A ``trollimage.colormap.Colormap``, or None when no case applies.
    """
    from trollimage.colormap import Colormap

    fname = palette.get('filename', None)
    if fname:
        table = np.load(fname)
        row_count = 1.0 * table.shape[0]
        entries = [
            (row / row_count,
             (table[row, 0] / 255., table[row, 1] / 255., table[row, 2] / 255.))
            for row in range(int(row_count))
        ]
        return Colormap(*entries)

    colors = palette.get('colors', None)
    if isinstance(colors, (tuple, list)):
        values = palette.get('values', None)
        entries = []
        for position, color in enumerate(colors):
            if values is None:
                value = position / float(len(colors) - 1)
            else:
                value = values[position]
            entries.append((value, tuple(color)))
        return Colormap(*entries)

    if not isinstance(colors, str):
        return None

    from trollimage import colormap
    import copy
    return copy.copy(getattr(colormap, colors))
def _three_d_effect_delayed(band_data, kernel, mode):
    """Convolve a single-band ``(1, y, x)`` array with `kernel`.

    The leading band axis is dropped, the 2-D plane is convolved with
    :func:`scipy.signal.convolve2d` using `mode`, and a leading axis of
    length 1 is added back to the result.
    """
    from scipy.signal import convolve2d
    plane = band_data.reshape(band_data.shape[1:])
    new_data = convolve2d(plane, kernel, mode=mode)
    # Use the result's own shape: for 'full'/'valid' modes it differs from
    # the input's, and reshaping to the input shape would raise.
    return new_data.reshape((1,) + new_data.shape)
def three_d_effect(img, **kwargs):
    """Create 3D effect using convolution.

    Each non-excluded band is convolved with the 3x3 kernel
    ``[[-w, 0, w], [-w, 1, w], [-w, 0, w]]``.

    Args:
        img: Image object; its ``data`` attribute is passed to
            ``apply_enhancement`` and the result of that call is returned.
        weight (float): ``w`` in the kernel above. Defaults to 1.
        convolve_mode (str): Mode passed to ``scipy.signal.convolve2d``.
            Defaults to 'same'. NOTE(review): the lazy result is declared
            with the input's shape, so presumably only 'same' works end to
            end.
    """
    w = kwargs.get('weight', 1)
    LOG.debug("Applying 3D effect with weight %.2f", w)
    kernel = np.array([[-w, 0, w],
                       [-w, 1, w],
                       [-w, 0, w]])
    mode = kwargs.get('convolve_mode', 'same')

    def func(band_data, kernel=kernel, mode=mode, index=None):
        del index
        # convolve2d returns the common type of the data and the kernel
        # (e.g. float64 for float32 data with an integer kernel); declare
        # that so the dask metadata matches the computed chunks.
        out_dtype = np.result_type(band_data.dtype, kernel.dtype)
        delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode)
        new_data = da.from_delayed(delay, shape=band_data.shape, dtype=out_dtype)
        return new_data

    return apply_enhancement(img.data, func, separate=True, pass_dask=True)
def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs):
    """Scale data linearly in two separate regions.

    This enhancement scales the input data linearly by splitting the data
    into two regions; min_in to threshold and threshold to max_in. These
    regions are mapped to 1 to threshold_out and threshold_out to 0
    respectively, resulting in the data being "flipped" around the
    threshold. A default threshold_out is set to `176.0 / 255.0` to match
    the behavior of the US National Weather Service's forecasting tool
    called AWIPS.

    Args:
        img (XRImage): Image object to be scaled
        min_in (float): Minimum input value to scale
        max_in (float): Maximum input value to scale
        threshold (float): Input value where to split data in to two regions
        threshold_out (float): Output value to map the input `threshold`
            to. Optional, defaults to 176.0 / 255.0.
    """
    if threshold_out is None:
        threshold_out = 176 / 255.0

    # Below the threshold: min_in -> 1, threshold -> threshold_out.
    low_slope = (threshold_out - 1.) / (min_in - threshold)
    low_intercept = 1. + (low_slope * min_in)
    # At/above the threshold: threshold -> threshold_out, max_in -> 0.
    high_slope = threshold_out / (max_in - threshold)
    high_intercept = high_slope * max_in

    def _bt_threshold(band_data):
        # expects dask array to be passed
        is_high = band_data >= threshold
        return da.where(is_high,
                        high_intercept - high_slope * band_data,
                        low_intercept - low_slope * band_data)

    return apply_enhancement(img.data, _bt_threshold, pass_dask=True)