# Source code for geodata.plot

# Copyright 2022 Jiahe Feng (Davidson Lab)
# Copyright 2022 Xiqiang Liu (Davidson Lab)

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import logging
from typing import Literal, Optional, Union

import geopandas as gpd
import matplotlib.animation as anim
import matplotlib.pyplot as plt
import xarray as xr

from ._cutout import ds_reformat_index
from .mask import show  # noqa: F401

# Display matplotlib animations as an HTML/JavaScript ("jshtml") widget when
# they are shown as rich output (e.g. in a notebook).
plt.rcParams.update({"animation.html": "jshtml"})

_logger = logging.getLogger(__name__)

# A pair of coordinate values: either a (min, max) bound such as ``lat_slice`` /
# ``lon_slice``, or a (lat, lon) point such as the values of ``coord_dict``.
CoordinateType = tuple[Union[float, int], Union[float, int]]

# Name of a supported reduction: "mean" or "sum".
ReductionType = Literal["mean", "sum"]
def _reduce(
    obj: Union[xr.Dataset, xr.DataArray],
    method: str,
    dim: Union[str, list[str]],
    option_name: str = "agg_method",
) -> Union[xr.Dataset, xr.DataArray]:
    """Reduce ``obj`` over the named dimension(s) with "mean" or "sum".

    Args:
        obj: The xarray object to reduce.
        method: Reduction name, either "mean" or "sum".
        dim: Dimension name, or list of names, to reduce over.
        option_name: Name of the caller's argument, used in the error message.

    Returns:
        The reduced xarray object.

    Raises:
        NotImplementedError: If ``method`` is neither "mean" nor "sum".
    """
    if method == "mean":
        return obj.mean(dim=dim)
    if method == "sum":
        return obj.sum(dim=dim)
    raise NotImplementedError(f"{option_name} {method} is not supported.")


def _coarsen_window(time_factor: float) -> int:
    """Convert ``time_factor`` into the integer window size ``coarsen`` needs.

    Args:
        time_factor: Number of consecutive time steps aggregated into one.

    Returns:
        The window size as an ``int``.

    Raises:
        ValueError: If ``time_factor`` is not a whole number >= 1.
    """
    window = int(time_factor)
    if window < 1 or window != time_factor:
        raise ValueError(
            f"time_factor must be a whole number >= 1, got {time_factor}."
        )
    return window


def _js_single_quoted(text: str) -> str:
    """Escape ``text`` for embedding inside a single-quoted JavaScript literal.

    ``json.dumps`` escapes backslashes, double quotes, control characters and
    non-ASCII characters. Single quotes are escaped by hand so the literal
    cannot be terminated early, and ``<`` is escaped so a ``</script>`` in the
    text cannot close the surrounding script element.
    """
    import json

    return json.dumps(str(text))[1:-1].replace("'", "\\'").replace("<", "\\u003c")


def ds_ts_aggregate(
    ds: Union[xr.Dataset, xr.DataArray], agg_method: ReductionType
) -> Union[xr.DataArray, xr.Dataset]:
    """Aggregate the xarray.Dataset or xarray.DataArray along the lat and lon dimensions.

    Args:
        ds (Union[xr.Dataset, xr.DataArray]): The xarray Dataset.
        agg_method (Literal["mean", "sum"]): The aggregation method. If "mean", the
            (unweighted) mean of all of the values will be taken. If "sum", the sum
            of all of the values will be taken.

    Returns:
        Union[xr.Dataset, xr.DataArray]: The aggregated xarray object, with the
        lat and lon dimensions removed.

    Raises:
        NotImplementedError: If the aggregation method is not supported.
    """
    # Reduce both spatial dimensions in a single call: a mean of per-row means
    # is not the mean of all values when rows contain different numbers of
    # missing (NaN) cells.
    return _reduce(ds, agg_method, ["lat", "lon"])


def time_series(
    ds: xr.DataArray,
    lat_slice: Optional[CoordinateType] = None,
    lon_slice: Optional[CoordinateType] = None,
    agg_slice: bool = True,
    agg_slice_method: ReductionType = "mean",
    coord_dict: Optional[dict[str, CoordinateType]] = None,
    time_factor: float = 1.0,
    agg_time_method: ReductionType = "mean",
    figsize: tuple[float, float] = (10.0, 5.0),
    ds_name: Optional[str] = None,
    loc_name: Optional[str] = None,
    title: Optional[str] = None,
    title_size: float = 12.0,
    grid: bool = True,
    legend: bool = True,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """Take in the xarray.DataArray, slice of latitude or longitude, plot the time series.

    When users give lat or lon slices, the values can be mean/sum aggregated. Users
    can also provide a dictionary of name-coordinate pairs instead of lat and lon
    input. By default, the method shows the time series' aggregated mean.

    Args:
        ds (xr.DataArray): The target xarray.DataArray object.
        lat_slice (tuple): The (min, max) slice of latitude values.
        lon_slice (tuple): The (min, max) slice of longitude values.
        agg_slice (bool): Whether to plot the spatially aggregated values of the
            slices (True) or one line per latitude/longitude (False).
        agg_slice_method (str): Reduction method for aggregating the spatial
            slices. This can be either "mean" or "sum".
        coord_dict (dict): The (Name, Coordinate) pairs of different locations.
            Example: `{'Beijing': (40, 116.25), 'Shanghai': (31, 121.25)}`.
            A coordinate that is not on the grid is snapped to the nearest grid
            value below it.
        time_factor (float): The number of time steps to aggregate into one; must
            be a whole number >= 1. Example: for daily values on hourly data,
            time_factor = 24.
        agg_time_method (str): Reduction method for the time dimension. This can
            be either "mean" or "sum".
        figsize (tuple): The size of the plot.
        ds_name (str): Name of the DataArray to be shown on the title. Defaults to
            the DataArray's name.
        loc_name (str): Location of the place to be shown on the title.
        title (str): The title of the result plot. Generated if not provided.
        title_size (float): The size of the title of the result plot.
        grid (bool): Whether to add grid lines to the plot, True by default.
        legend (bool): Add legend to the plot if multiple locations are provided,
            True by default.
        return_fig (bool): Whether to return the figure or not. False by default.
        **kwargs: Other keyword arguments for xarray.DataArray.plot().

    Returns:
        Optional[plt.Figure]: The figure object if `return_fig` is True.

    Raises:
        ValueError: If `time_factor` is not a whole number >= 1, if a slice is
            given as (max, min), or if a `coord_dict` coordinate is out of bounds.
        NotImplementedError: If a reduction method is not supported.
    """
    # Validate arguments before a figure is created, so a bad call does not
    # leave an empty figure behind.
    window = _coarsen_window(time_factor)
    if agg_time_method not in ("mean", "sum"):
        raise NotImplementedError(
            f"agg_time_method {agg_time_method} is not supported."
        )
    if lat_slice and lat_slice[1] < lat_slice[0]:
        raise ValueError(
            "Please give correct latitude slice. "
            "The second value should be larger than the first."
        )
    if lon_slice and lon_slice[1] < lon_slice[0]:
        raise ValueError(
            "Please give correct longitude slice. "
            "The second value should be larger than the first."
        )

    ds = ds_reformat_index(ds)
    if not ds_name:
        ds_name = ds.name

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot()
    create_title = f"{ds_name} Time Series - "

    # Temporal aggregation; "trim" drops a trailing partial window.
    coarse = ds.coarsen(time=window, boundary="trim")
    ds = coarse.mean() if agg_time_method == "mean" else coarse.sum()

    if lat_slice:
        ds = ds.where((ds.lat >= lat_slice[0]) & (ds.lat <= lat_slice[1]), drop=True)
        create_title += f"lat slice {lat_slice} "
    if lon_slice:
        ds = ds.where((ds.lon >= lon_slice[0]) & (ds.lon <= lon_slice[1]), drop=True)
        create_title += f"lon slice {lon_slice} "

    if coord_dict:
        if not loc_name:
            loc_name = ", ".join(coord_dict.keys())
        # The grid does not change inside the loop, so read it once.
        all_lat = ds.lat.data
        all_lon = ds.lon.data
        for key, value in coord_dict.items():
            la, lo = value[0], value[1]
            log_new_coord = False
            if la not in all_lat:
                if not all_lat.min() < la < all_lat.max():
                    raise ValueError(f"Latitude for {key} out of bound.")
                # Largest grid latitude below the request; max() (rather than
                # the last element) is correct for either coordinate ordering.
                la = all_lat[all_lat < la].max()
                log_new_coord = True
            if lo not in all_lon:
                if not all_lon.min() < lo < all_lon.max():
                    raise ValueError(f"Longitude for {key} out of bound.")
                lo = all_lon[all_lon < lo].max()
                log_new_coord = True
            if log_new_coord:
                _logger.info(
                    "Find grid cell containing coordinate for %s at lat = %f, lon = %f.",
                    key,
                    la,
                    lo,
                )
            ds.sel(lat=la, lon=lo).plot(ax=ax, label=key, **kwargs)
        create_title += f"{loc_name} "

    if not coord_dict and (
        agg_slice is True or (lat_slice is None and lon_slice is None)
    ):
        ds = ds_ts_aggregate(ds, agg_slice_method)
        ds.plot(ax=ax, **kwargs)
        create_title += f"spatially {agg_slice_method} aggregated "

    if agg_slice is False and (lat_slice is not None or lon_slice is not None):
        if lat_slice and not lon_slice:
            # One line per latitude, aggregated across longitude.
            ds = _reduce(ds, agg_slice_method, "lon", "agg_slice_method")
            create_title += f"with longitude {agg_slice_method} aggregated "
            for la in ds.lat.values:
                ds.sel(lat=la).plot(ax=ax, label=f"lat {la}", **kwargs)
        elif lon_slice and not lat_slice:
            # One line per longitude, aggregated across latitude.
            ds = _reduce(ds, agg_slice_method, "lat", "agg_slice_method")
            create_title += f"with latitude {agg_slice_method} aggregated "
            for lo in ds.lon.values:
                ds.sel(lon=lo).plot(ax=ax, label=f"lon {lo}", **kwargs)
        elif lat_slice and lon_slice:
            # No spatial aggregation: one line per grid cell.
            for la in ds.lat.values:
                for lo in ds.lon.values:
                    ds.sel(lat=la, lon=lo).plot(
                        ax=ax, label=f"lat {la}, lon {lo}", **kwargs
                    )

    if legend and (agg_slice is False or coord_dict):
        ax.legend()
    if grid:
        ax.grid()
    if time_factor > 1:
        create_title += f"- time aggregated by factor of {time_factor}."
    if not title:
        title = create_title
    ax.set_title(title, size=title_size)

    if return_fig:
        return fig


def heatmap(
    ds: xr.DataArray,
    t: Optional[Union[int, str]] = None,
    agg_method: ReductionType = "mean",
    shape: Optional[gpd.GeoSeries] = None,
    shape_width: float = 0.5,
    shape_color: str = "black",
    map_type: Literal["contour", "colormesh"] = "colormesh",
    cmap: str = "bone_r",
    figsize: tuple[float, float] = (10.0, 6.0),
    title: Optional[str] = None,
    title_size: float = 12,
    grid: bool = True,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """Take an xarray.DataArray and a time index or string, plot contour/colormesh map for its values.

    Args:
        ds (xr.DataArray): The target DataArray object.
        t (Optional[Union[int, str]]): Target timestamp. This can be either a
            numeric time index, or a time string from the xarray.DataArray time
            dimension.
        agg_method (Literal["mean", "sum"]): Aggregation method in the time
            dimension. This is used only if t is not provided. Options are mean
            aggregation or sum aggregation.
        shape (geopandas.GeoSeries): Shapes to be plotted over the raster.
        shape_width (float): Width of lines for plotting shapes. 0.5 by default.
        shape_color (str): Color of the shape line. Black by default.
        map_type (Literal["contour", "colormesh"]): Map type. This can be either
            "contour" or "colormesh".
        cmap (str): The color of the heat map, select one from
            matplotlib.pyplot.colormaps.
        figsize (tuple): The size of the plot.
        title (Optional[str]): The title of the result plot. Optional. If not
            provided, the title will be automatically generated.
        title_size (float): The size of the title of the result plot. 12 by default.
        grid (bool): Whether to add grid lines to the plot, True by default.
        return_fig (bool): Whether to return the plt.Figure object. False by default.
        **kwargs: Additional arguments for xarray.DataArray.plot.pcolormesh or
            xarray.DataArray.plot.contourf, depending on selected `map_type`.

    Returns:
        Optional[plt.Figure]: The figure object if `return_fig` is True.

    Raises:
        ValueError: If the map type or the colormap is not supported.
        NotImplementedError: If `t` is not given and `agg_method` is not supported.
    """
    if map_type not in {"contour", "colormesh"}:
        raise ValueError(f"map_type {map_type} is not supported.")
    if cmap not in plt.colormaps():
        raise ValueError(
            "Please see available colormaps through: matplotlib.pyplot.colormaps() or "
            "https://matplotlib.org/stable/gallery/color/colormap_reference.html"
        )
    # Checked up front so an invalid method fails before a figure is created.
    if t is None and agg_method not in ("mean", "sum"):
        raise NotImplementedError(f"agg_method {agg_method} is not supported.")

    ds = ds_reformat_index(ds)
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot()

    time_idx = None
    if t is None:
        ds = _reduce(ds, agg_method, "time")
    elif isinstance(t, int):
        time_idx = ds.time.data[t]
        ds = ds.isel(time=t)
    else:
        ds = ds.sel(time=t)

    if map_type == "contour":
        ds.plot.contourf("lon", "lat", ax=ax, cmap=cmap, **kwargs)
    else:
        ds.plot.pcolormesh("lon", "lat", ax=ax, cmap=cmap, **kwargs)

    if shape is not None:
        shape.boundary.plot(ax=ax, linewidth=shape_width, color=shape_color)
    # Clip the view to the raster extent (shape boundaries may extend beyond it).
    ax.set_xlim(float(ds.lon.min()), float(ds.lon.max()))
    ax.set_ylim(float(ds.lat.min()), float(ds.lat.max()))

    if not title:
        if t is None:
            title = f"{ds.name} aggregated {agg_method}"
        elif isinstance(t, int):
            title = f"{ds.name} Amount at time index {t} - {time_idx}"
        else:
            title = f"{ds.name} Amount at {t}"
    ax.set_title(title, size=title_size)
    if grid:
        ax.grid()

    if return_fig:
        return fig


def heatmap_animation(
    ds: xr.DataArray,
    time_factor: float = 1,
    agg_method: ReductionType = "mean",
    shape: Optional[gpd.GeoSeries] = None,
    shape_width: float = 0.5,
    shape_color: str = "black",
    cmap: str = "bone_r",
    v_max: Optional[float] = None,
    ds_name: Optional[str] = None,
    figsize: tuple[float, float] = (10, 5),
    title: Optional[str] = None,
    title_size: float = 12,
    grid: bool = True,
    **kwargs,
) -> anim.FuncAnimation:
    """Create an animated heatmap (drawn with imshow) showing how values change over time.

    Each frame is the mean or sum of `time_factor` consecutive time steps.

    Args:
        ds (xarray.DataArray): The target DataArray object.
        time_factor (float): The number of time steps to aggregate into one frame;
            must be a whole number >= 1. Example: for daily means on hourly data,
            time_factor = 24. Defaults to 1.
        agg_method (str): Aggregation method. Can be either `mean` or `sum`.
        shape (geopandas.GeoSeries): Shapes to be plotted over the raster.
        shape_width (float): The line width for plotting shapes. 0.5 by default.
        shape_color (str): Color of the shape line. Black by default.
        cmap (str): The color of the heat map, select one from
            matplotlib.pyplot.colormaps.
        v_max (float): The maximum value of the color scale. Defaults to the
            maximum of the aggregated data.
        ds_name (str): Name of the DataArray to be shown on title.
        figsize (tuple): The size of the plot. (10, 5) by default.
        title (str): The title of the result plot. Optional. If not provided, the
            title will be automatically generated for each frame.
        title_size (float): The size of the title of the result plot.
        grid (bool): Whether to add grid lines to the plot, True by default.
        **kwargs (dict): Additional arguments for xarray.DataArray.plot.imshow.
            `vmin` may be given here; it defaults to 0.

    Returns:
        matplotlib.animation.FuncAnimation: The animation object.

    Raises:
        ValueError: If the DataArray does not contain the time dimension, if the
            colormap is not supported, or if `time_factor` is not a whole number >= 1.
        NotImplementedError: If the aggregation method is not supported.
    """
    if "time" not in ds.dims:
        raise ValueError("The DataArray must contain the time dimension")
    if cmap not in plt.colormaps():
        raise ValueError(
            "Please see available colormaps through: matplotlib.pyplot.colormaps() or "
            "https://matplotlib.org/stable/gallery/color/colormap_reference.html"
        )
    if agg_method not in ("mean", "sum"):
        raise NotImplementedError(f"agg_method {agg_method} is not supported.")
    window = _coarsen_window(time_factor)

    ds = ds_reformat_index(ds)
    if not ds_name:
        ds_name = ds.name

    coarse = ds.coarsen(time=window, boundary="trim")
    ds = coarse.mean() if agg_method == "mean" else coarse.sum()

    if v_max is None:
        v_max = float(ds.max())
    # Taken out of kwargs so it is not passed to imshow twice.
    v_min = kwargs.pop("vmin", 0)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot()
    if shape is not None:
        shape.boundary.plot(ax=ax, linewidth=shape_width, color=shape_color)

    # Initial frame; later frames only swap the image data.
    image = ds.isel(time=0).plot.imshow(
        ax=ax, vmin=v_min, vmax=v_max, cmap=cmap, **kwargs
    )
    if grid:
        ax.grid()

    def update(t):
        """Update the title and image data for time value ``t``."""
        if title is None:
            ax.set_title(f"{ds_name} at time = {t}", size=title_size)
        else:
            ax.set_title(title, size=title_size)
        image.set_array(ds.sel(time=t).values)
        return image

    return anim.FuncAnimation(fig, update, frames=ds.time.values, blit=False)


def save_animation(file_name: str) -> str:
    """Build a JavaScript snippet that downloads an already-rendered animation.

    Intended for an IPython notebook opened in a browser, after an animation
    output has been displayed. The returned ``<script>`` collects the HTML of
    the notebook output cells that contain an element with class ``animation``
    and triggers a browser download of it under ``file_name``. This function
    itself writes nothing; the returned string must be rendered as HTML in the
    notebook (presumably via ``IPython.display.HTML`` — confirm with callers).

    Args:
        file_name (str): The name of the downloaded file.

    Returns:
        str: The HTML ``<script>`` element as a string.
    """
    # The name is embedded in a single-quoted JS string literal, so escape it.
    js_file_name = _js_single_quoted(file_name)
    javascript = (
        """
<script type="text/Javascript">
    function set_value(){
        elements = document.getElementsByClassName('output_subarea output_html rendered_html output_result')
        var var_values = ''
        for (i = 0; i < elements.length; i++){
            if (elements[i].getElementsByClassName('animation').length != 0){
                var_values += elements[i].innerHTML;
            }}
        (function(console){
        /* credit of the console.save function: stackoverflow.com/questions/11849562/*/
            console.save = function(data, filename){
                if(!data) {
                    console.error('Console.save: No data')
                    return;
                }
                if(!filename) filename = 'console.json'
                if(typeof data === "object"){
                    data = JSON.stringify(data, undefined, 4)
                }
                var blob = new Blob([data], {type: 'text/json'}),
                    e    = document.createEvent('MouseEvents'),
                    a    = document.createElement('a')
                a.download = filename
                a.href = window.URL.createObjectURL(blob)
                a.dataset.downloadurl =  ['text/json', a.download, a.href].join(':')
                e.initMouseEvent('click', true, false, window, 0, 0, 0, 0, 0, false, false, false, false, 0, null)
                a.dispatchEvent(e)
            }
        })(console)
        console.save(var_values, '"""  # noqa: E501
        + js_file_name
        + """')
    }
    set_value()
</script>
"""
    )
    return javascript