# Source code for huracanpy._accessor

import xarray as xr
from metpy.units import units
import pandas as pd

from . import (
    plot,
    tc,
    info,
    calc,
    save,
    interp_time,
    sel_id,
    trackswhere,
)


@xr.register_dataarray_accessor("hrcn")
class HuracanPyDataArrayAccessor:
    """Accessor registered as ``DataArray.hrcn`` exposing huracanpy helpers."""

    def __init__(self, dataarray):
        # Keep a private copy so accessor methods never alter the caller's array.
        self._dataarray = dataarray.copy()

    def nunique(self):
        """
        Count the number of distinct values in the DataArray.

        The array is flattened before counting, so multi-dimensional
        DataArrays are supported (``pandas.Series`` only accepts 1-D data).
        Missing values (NaN/NaT) are not counted, following
        ``pandas.Series.nunique`` defaults.

        Returns
        -------
        int
            Number of unique non-missing elements.
        """
        flat_values = self._dataarray.values.ravel()
        return pd.Series(flat_values).nunique()


@xr.register_dataset_accessor("hrcn")
class HuracanPyDatasetAccessor:
    def __init__(self, dataset):
        """
        Bind the accessor to *dataset*.

        No copy is taken. Methods that assign new variables therefore modify
        the caller's Dataset object directly.
        """
        self._dataset = dataset

    # %% Save
[docs] def save(self, filename): """ Save dataset as filename. The file type (NetCDF or csv supported) is detected based on filename extension. Parameters ---------- filename : str Must end in ".nc" or ".csv" Returns ------- None. """ save(self._dataset, filename)
[docs] def sel_id(self, track_id, track_id_name="track_id"): return sel_id(self._dataset, self._dataset[track_id_name], track_id)
[docs] def trackswhere(self, condition, track_id_name="track_id"): return trackswhere(self._dataset, self._dataset[track_id_name], condition)
# %% utils # ---- geography
[docs] def get_hemisphere(self, lat_name="lat"): return info.hemisphere(self._dataset[lat_name])
[docs] def add_hemisphere(self, lat_name="lat"): self._dataset["hemisphere"] = self.get_hemisphere(lat_name=lat_name) return self._dataset
[docs] def get_basin(self, lon_name="lon", lat_name="lat", convention="WMO-TC", crs=None): return info.basin( self._dataset[lon_name], self._dataset[lat_name], convention=convention, crs=crs, )
[docs] def add_basin(self, lon_name="lon", lat_name="lat", convention="WMO-TC", crs=None): self._dataset["basin"] = self.get_basin(lon_name, lat_name, convention, crs) return self._dataset
[docs] def get_is_land(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): return info.is_land( self._dataset[lon_name], self._dataset[lat_name], resolution=resolution, crs=crs, )
[docs] def add_is_land(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): self._dataset["is_land"] = self.get_is_land(lon_name, lat_name, resolution, crs) return self._dataset
[docs] def get_is_ocean(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): return info.is_ocean( self._dataset[lon_name], self._dataset[lat_name], resolution=resolution, crs=crs, )
[docs] def add_is_ocean(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): self._dataset["is_ocean"] = self.get_is_ocean( lon_name, lat_name, resolution, crs ) return self._dataset
[docs] def get_country(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): return info.country( self._dataset[lon_name], self._dataset[lat_name], resolution=resolution, crs=crs, )
[docs] def add_country(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): self._dataset["country"] = self.get_country(lon_name, lat_name, resolution, crs) return self._dataset
[docs] def get_continent(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): return info.continent( self._dataset[lon_name], self._dataset[lat_name], resolution=resolution, crs=crs, )
[docs] def add_continent(self, lon_name="lon", lat_name="lat", resolution="10m", crs=None): self._dataset["continent"] = self.get_continent( lon_name, lat_name, resolution, crs ) return self._dataset
# ---- ACE & PACE
[docs] def get_ace( self, wind_name="wind", sum_by=None, threshold=34 * units("knots"), wind_units="m s-1", ): """ Calculate Accumulated Cyclone Energy (ACE) for each point. """ if sum_by is not None: sum_by = self._dataset[sum_by] return tc.ace( self._dataset[wind_name], sum_by=sum_by, threshold=threshold, wind_units=wind_units, )
[docs] def add_ace( self, wind_name="wind", threshold=34 * units("knots"), wind_units="m s-1" ): """ Add ACE calculation to the dataset. """ self._dataset["ace"] = self.get_ace( wind_name, sum_by=None, threshold=threshold, wind_units=wind_units ) return self._dataset
[docs] def get_pace( self, pressure_name="slp", wind_name=None, model=None, sum_by=None, threshold_wind=None, threshold_pressure=None, wind_units="m s-1", **kwargs, ): """ Calculate Pressure-based Accumulated Cyclone Energy (PACE) for each point. """ pace_values, model = tc.pace( self._dataset[pressure_name], wind=self._dataset[wind_name] if wind_name else None, model=model, sum_by=self._dataset[sum_by] if sum_by else None, threshold_wind=threshold_wind, threshold_pressure=threshold_pressure, wind_units=wind_units, **kwargs, ) return pace_values, model
[docs] def add_pace( self, pressure_name="slp", wind_name=None, model=None, threshold_wind=None, threshold_pressure=None, wind_units="m s-1", **kwargs, ): """ Add PACE calculation to the dataset. """ pace_values, model = self.get_pace( pressure_name=pressure_name, wind_name=wind_name, model=model, sum_by=None, threshold_wind=threshold_wind, threshold_pressure=threshold_pressure, wind_units=wind_units, **kwargs, ) self._dataset["pace"] = pace_values return self._dataset, model
# ---- time
[docs] def get_time_components(self, time_name="time"): """ Expand the time variable into year, month, day, and hour. """ return info.time_components(self._dataset[time_name])
[docs] def add_time_components(self, time_name="time"): """ Add year, month, day, and hour as new variables to the dataset. """ components = self.get_time_components(time_name) return xr.merge([self._dataset, *components])
[docs] def get_season( self, track_id_name="track_id", lat_name="lat", time_name="time", convention="tc-short", ): """ Derive the season for each track based on latitude and time. """ return info.season( self._dataset[track_id_name], self._dataset[lat_name], self._dataset[time_name], convention=convention, )
[docs] def add_season( self, track_id_name="track_id", lat_name="lat", time_name="time", convention="tc-short", ): """ Add the season as a new variable to the dataset. """ self._dataset["season"] = self.get_season( track_id_name, lat_name, time_name, convention ) return self._dataset
# --- utils def get_inferred_track_id(self, *variable_names): return info.inferred_track_id(*[self._dataset[var] for var in variable_names]) def add_inferred_track_id(self, *variable_names, track_id_name="track_id"): self._dataset[track_id_name] = self.get_inferred_track_id(*variable_names) return self._dataset # --- category
[docs] def get_category( self, var_name, bins=None, labels=None, variable_units=None, ): """ Calculate a generic category from a variable and a set of thresholds. """ return info.category( self._dataset[var_name], bins=bins, labels=labels, variable_units=variable_units, )
[docs] def add_category( self, var_name, new_var_name=None, bins=None, labels=None, variable_units=None, ): """ Add a generic category to the dataset as a new variable. """ if new_var_name is None: new_var_name = f"category_{var_name}" self._dataset[new_var_name] = self.get_category( var_name, bins=bins, labels=labels, variable_units=variable_units, ) return self._dataset
[docs] def get_saffir_simpson_category( self, wind_name="wind", convention="Saffir-Simpson", wind_units="m s-1" ): """ Determine the Saffir-Simpson Hurricane Scale (SSHS) category. """ return tc.saffir_simpson_category( self._dataset[wind_name], convention=convention, wind_units=wind_units )
[docs] def add_saffir_simpson_category( self, wind_name="wind", convention="Saffir-Simpson", wind_units="m s-1" ): """ Add the SSHS category to the dataset. """ self._dataset["saffir_simpson_category"] = self.get_saffir_simpson_category( wind_name, convention, wind_units ) return self._dataset
[docs] def get_pressure_category( self, slp_name="slp", convention="Klotzbach", slp_units=None ): """ Determine the pressure category based on the selected convention. """ return tc.pressure_category( self._dataset[slp_name], convention=convention, slp_units=slp_units )
[docs] def add_pressure_category( self, slp_name="slp", convention="Klotzbach", slp_units=None ): """ Add the pressure category to the dataset. """ self._dataset["pressure_category"] = self.get_pressure_category( slp_name, convention, slp_units ) return self._dataset
def get_beta_drift(self, lat_name="lat", wind_name="wind", rmw_name="rmw"): return tc.beta_drift( self._dataset[lat_name], self._dataset[wind_name], self._dataset[rmw_name] ) def add_beta_drift(self, lat_name="lat", wind_name="wind", rmw_name="rmw"): v_drift, theta_drift = self.get_beta_drift(lat_name, wind_name, rmw_name) self._dataset["v_drift"] = v_drift self._dataset["theta_drift"] = theta_drift return self._dataset # ---- translation
[docs] def get_azimuth( self, lon_name="lon", lat_name="lat", track_id_name="track_id", ellps="WGS84", ): """ Compute the azimuth between points along a track """ if track_id_name in list(self._dataset.variables): return calc.azimuth( self._dataset[lon_name], self._dataset[lat_name], track_id=self._dataset[track_id_name], ellps=ellps, ) if (track_id_name is None) or ( track_id_name not in list(self._dataset.variables) ): return calc.azimuth( self._dataset[lon_name], self._dataset[lat_name], track_id=None, ellps=ellps, )
[docs] def add_azimuth( self, lon_name="lon", lat_name="lat", track_id_name="track_id", ellps="WGS84", ): """ Add the azimuth calculation to the dataset. """ self._dataset["azimuth"] = self.get_azimuth( lon_name, lat_name, track_id_name, ellps ) return self._dataset
[docs] def get_distance( self, lon_name="lon", lat_name="lat", track_id_name="track_id", method="geod", ellps="WGS84", ): """ Compute the distance between points along a track. """ if track_id_name in list(self._dataset.variables): return calc.distance( self._dataset[lon_name], self._dataset[lat_name], track_id=self._dataset[track_id_name], method=method, ellps=ellps, ) if (track_id_name is None) or ( track_id_name not in list(self._dataset.variables) ): return calc.distance( self._dataset[lon_name], self._dataset[lat_name], track_id=None, method=method, ellps=ellps, )
[docs] def add_distance( self, lon_name="lon", lat_name="lat", track_id_name="track_id", method="geod", ellps="WGS84", ): """ Add the distance calculation to the dataset. """ self._dataset["distance"] = self.get_distance( lon_name, lat_name, track_id_name, method, ellps ) return self._dataset
[docs] def get_translation_speed( self, lon_name="lon", lat_name="lat", time_name="time", track_id_name="track_id", method="geod", ellps="WGS84", ): """ Compute the translation speed along tracks. """ if track_id_name in list(self._dataset.variables): return calc.translation_speed( self._dataset[lon_name], self._dataset[lat_name], self._dataset[time_name], track_id=self._dataset[track_id_name], method=method, ellps=ellps, ) if (track_id_name is None) or ( track_id_name not in list(self._dataset.variables) ): return calc.translation_speed( self._dataset[lon_name], self._dataset[lat_name], self._dataset[time_name], track_id=None, method=method, ellps=ellps, )
[docs] def add_translation_speed( self, lon_name="lon", lat_name="lat", time_name="time", track_id_name="track_id", method="geod", ellps="WGS84", ): """ Add the translation speed calculation to the dataset. """ self._dataset["translation_speed"] = self.get_translation_speed( lon_name, lat_name, time_name, track_id_name, method, ellps ) return self._dataset
# ---- rates
[docs] def get_delta(self, var_name="wind10", track_id_name="track_id", **kwargs): if track_id_name in list(self._dataset.variables): return calc.delta( self._dataset[var_name], track_ids=self._dataset[track_id_name], **kwargs, ) if (track_id_name is None) or ( track_id_name not in list(self._dataset.variables) ): return calc.delta(self._dataset[var_name], track_ids=None, **kwargs)
[docs] def add_delta(self, var_name="wind10", track_id_name="track_id", **kwargs): """ Add the distance calculation to the dataset. """ self._dataset["delta_" + var_name] = self.get_delta( var_name, track_id_name, **kwargs ) return self._dataset
[docs] def get_rate( self, var_name="wind10", time_name="time", track_id_name="track_id", **kwargs ): if track_id_name in list(self._dataset.variables): return calc.rate( self._dataset[var_name], self._dataset[time_name], track_ids=self._dataset[track_id_name], **kwargs, ) if (track_id_name is None) or ( track_id_name not in list(self._dataset.variables) ): return calc.rate( self._dataset[var_name], self._dataset[time_name], track_ids=None, **kwargs, )
[docs] def add_rate( self, var_name="wind10", time_name="time", track_id_name="track_id", **kwargs ): self._dataset["rate_" + var_name] = self.get_rate( var_name, time_name, track_id_name, **kwargs ) return self._dataset
# ---- interp
[docs] def interp_time(self, freq="1h", track_id_name="track_id", prog_bar=False): """ Interpolate track data at a given frequency. """ return interp_time( self._dataset, self._dataset[track_id_name], freq=freq, prog_bar=prog_bar )
# ---- lifecycle
[docs] def get_time_from_genesis(self, time_name="time", track_id_name="track_id"): return calc.time_from_genesis( self._dataset[time_name], self._dataset[track_id_name] )
[docs] def add_time_from_genesis(self, time_name="time", track_id_name="track_id"): self._dataset["time_from_genesis"] = self.get_time_from_genesis( time_name, track_id_name ) return self._dataset
[docs] def get_time_from_apex( self, time_name="time", track_id_name="track_id", intensity_var_name="wind", stat="max", ): return calc.time_from_apex( self._dataset[time_name], self._dataset[track_id_name], self._dataset[intensity_var_name], stat=stat, )
[docs] def add_time_from_apex( self, time_name="time", track_id_name="track_id", intensity_var_name="wind", stat="max", ): self._dataset["time_from_apex"] = self.get_time_from_apex( time_name, track_id_name, intensity_var_name, stat ) return self._dataset
# %% plot
[docs] def plot_tracks( self, lon_name="lon", lat_name="lat", intensity_var_name=None, **kwargs ): if intensity_var_name in list(self._dataset.variables): intensity_var = self._dataset[intensity_var_name] else: intensity_var = None return plot.tracks( self._dataset[lon_name], self._dataset[lat_name], intensity_var, **kwargs )
[docs] def plot_density( self, lon_name="lon", lat_name="lat", density_kws=dict(), **kwargs ): d = self.get_density(lon_name=lon_name, lat_name=lat_name, **density_kws) return plot.density(d, **kwargs)
def plot_fancyline( self, lon_name="lon", lat_name="lat", track_id_name="track_id", colors=None, linewidths=None, alphas=None, linestyles=None, **kwargs, ): extra_names = dict( colors=colors, linewidths=linewidths, alphas=alphas, linestyles=linestyles ) output = [] for track_id, track in self._dataset.groupby(track_id_name): # Allow the other variables to be passed as variable names or constant # strings # e.g. colors can be a variable on the track or could just be "red" extra_variables = { key: (track[name] if name in track else name) for key, name in extra_names.items() } output.append( plot.fancyline( track[lon_name], track[lat_name], **extra_variables, **kwargs ) ) return output # %% diags # ---- density
[docs] def get_density(self, lon_name="lon", lat_name="lat", method="histogram", **kwargs): return calc.density( self._dataset[lon_name], self._dataset[lat_name], method=method, **kwargs )
# ---- track stats
[docs] def get_track_duration(self, time_name="time", track_id_name="track_id"): return calc.track_duration( self._dataset[time_name], self._dataset[track_id_name] )
[docs] def get_gen_vals(self, time_name="time", track_id_name="track_id"): return calc.gen_vals( self._dataset, self._dataset[time_name], self._dataset[track_id_name] )
[docs] def get_apex_vals(self, var_name, track_id_name="track_id", stat="max"): return calc.apex_vals( self._dataset, variable=self._dataset[var_name], track_id=self._dataset[track_id_name], stat=stat, )