Source code for huracanpy.plot._venn
"""Venn diagrams for visualising how tracks match between datasets."""
from matplotlib_venn import venn2, venn2_circles, venn3, venn3_circles
import numpy as np
[docs]
def venn(datasets, match, labels, colors=None, circle_color="k"):
    """
    Plot a Venn diagram comparing two or three track datasets.

    Parameters
    ----------
    datasets : list of xr.Dataset
        The datasets to compare (two or three of them). Each must expose a
        ``track_id`` variable; its unique values are counted as tracks.
    match : pd.DataFrame
        Match dataframe issued from match_pair (two datasets) or
        match_multiple (three datasets).
    labels : list of str
        Labels of the datasets, one per dataset.
    colors : list of str, optional
        Fill colour for each dataset. The default is None, which uses
        white ("w") for every dataset.
    circle_color : str, optional
        Colour of the overlaid circle outlines. The default is "k".

    Raises
    ------
    NotImplementedError
        If fewer than two or more than three datasets are given.
    ValueError
        If ``labels`` or ``colors`` does not have one entry per dataset.

    Returns
    -------
    None.
    """
    n_datasets = len(datasets)
    if n_datasets not in (2, 3):
        # The message must cover both the "too few" and "too many" cases.
        raise NotImplementedError(
            "Venn diagrams can only be plotted for 2 or 3 datasets, "
            f"got {n_datasets}."
        )
    if len(labels) != n_datasets:
        raise ValueError("datasets and labels must have the same length")
    if colors is None:
        colors = ["w"] * n_datasets
    elif len(colors) != n_datasets:
        raise ValueError("datasets and colors must have the same length")

    # Validate everything before dispatching, so no drawing starts on bad input.
    plot = _venn_2datasets if n_datasets == 2 else _venn_3datasets
    plot(*datasets, match, colors, labels, circle_color)
def _venn_2datasets(data1, data2, match, colors, labels=None, circle_color="k"):
    """
    Draw a two-set Venn diagram of matched tracks.

    Every row of ``match`` is counted as one track shared by both datasets.
    The remaining unique ``track_id`` values of each dataset fill that
    dataset's exclusive region.
    """
    n_shared = len(match)
    only_first, only_second = (
        len(np.unique(ds.track_id.values)) - n_shared for ds in (data1, data2)
    )
    # Region sizes in the order venn2 expects: (only 1, only 2, both).
    regions = (only_first, only_second, n_shared)
    venn2(regions, set_colors=colors, set_labels=labels)
    venn2_circles(regions, color=circle_color)
def _venn_3datasets(data1, data2, data3, M, colors, labels=None, circle_color="k"):
    """
    Draw a three-set Venn diagram of matched tracks.

    Parameters
    ----------
    data1, data2, data3 : xr.Dataset
        Datasets whose unique ``track_id`` values are counted.
    M : pd.DataFrame
        Match table. Columns 0, 1 and 2 hold the matching track ids from
        data1, data2 and data3 respectively, with NaN where that dataset has
        no matching track. Any further columns are ignored.
    colors : list of str
        Fill colour for each dataset.
    labels : list of str, optional
        Labels of the datasets.
    circle_color : str, optional
        Colour of the overlaid circle outlines.
    """
    n1 = len(np.unique(data1.track_id.values))  # Number of tracks in dataset 1
    n2 = len(np.unique(data2.track_id.values))  # Number of tracks in dataset 2
    n3 = len(np.unique(data3.track_id.values))  # Number of tracks in dataset 3

    # Per-row presence flags, restricted to the three id columns so that
    # NaN in any extra column cannot affect the counts.
    present = M.iloc[:, :3].notna()
    in1, in2, in3 = (present.iloc[:, i] for i in range(3))

    # Exclusive regions. A row is only counted in a pairwise region when
    # exactly those two datasets have an id, so no row is counted twice.
    in_12_only = int((in1 & in2 & ~in3).sum())
    in_13_only = int((in1 & ~in2 & in3).sum())
    in_23_only = int((~in1 & in2 & in3).sum())
    in_all = int((in1 & in2 & in3).sum())

    # Order expected by venn3: (100, 010, 110, 001, 101, 011, 111).
    subsets = (
        n1 - in_all - in_12_only - in_13_only,
        n2 - in_all - in_12_only - in_23_only,
        in_12_only,
        n3 - in_all - in_13_only - in_23_only,
        in_13_only,
        in_23_only,
        in_all,
    )
    venn3(
        subsets,
        set_labels=labels,
        set_colors=colors,
    )
    venn3_circles(subsets, color=circle_color)