Source code for pg_gpu.plotting

"""
Population genetics visualization functions.

Provides publication-quality plots for SFS, LD, haplotype structure,
PCA, distance matrices, and windowed statistics using matplotlib and seaborn.
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from typing import Optional, Union, Tuple

from .windowed_analysis import CANONICAL_WINDOW_PREFIX


# ---------------------------------------------------------------------------
# Site Frequency Spectrum
# ---------------------------------------------------------------------------

[docs] def plot_sfs(s, ax=None, folded=False, scaled=False, clip_endpoints=True, log_scale=True, color=None, **kwargs): """Plot a site frequency spectrum. Parameters ---------- s : array_like SFS array (from pg_gpu.sfs.sfs, sfs_folded, etc.). ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates a new figure. folded : bool Label x-axis as minor allele count. scaled : bool Label y-axis as scaled count. clip_endpoints : bool If True, exclude the first and last entries (monomorphic classes). log_scale : bool If True, use log scale for y-axis. color : str, optional Bar color. Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(8, 4)) s = np.asarray(s, dtype=float) if clip_endpoints: s = s[1:-1] if not folded else s[1:] x = np.arange(1, len(s) + 1) else: x = np.arange(len(s)) if color is None: color = sns.color_palette()[0] ax.bar(x, s, color=color, edgecolor='none', **kwargs) if log_scale and np.any(s > 0): ax.set_yscale('log') xlabel = 'Minor allele count' if folded else 'Derived allele count' ylabel = 'Scaled count' if scaled else 'Count' ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) sns.despine(ax=ax) return ax
[docs] def plot_joint_sfs(s, ax=None, log_scale=True, cmap='YlOrRd', clip_endpoints=True, **kwargs): """Plot a joint site frequency spectrum as a heatmap. Parameters ---------- s : array_like, 2D Joint SFS array (from pg_gpu.sfs.joint_sfs, etc.). ax : matplotlib.axes.Axes, optional log_scale : bool If True, use log color scale. cmap : str Colormap name. clip_endpoints : bool If True, exclude monomorphic classes (row/col 0 and last). Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(6, 5)) s = np.asarray(s, dtype=float) if clip_endpoints: s = s[1:, 1:] # mask zeros for log scale s_plot = np.ma.masked_equal(s, 0) norm = mcolors.LogNorm(vmin=max(1, s_plot.min()), vmax=s_plot.max()) if log_scale else None im = ax.pcolormesh(s_plot, norm=norm, cmap=cmap, **kwargs) plt.colorbar(im, ax=ax, label='Count') offset = 1 if clip_endpoints else 0 ax.set_xlabel('Derived allele count (pop 2)') ax.set_ylabel('Derived allele count (pop 1)') sns.despine(ax=ax) return ax
# --------------------------------------------------------------------------- # Linkage Disequilibrium # ---------------------------------------------------------------------------
[docs] def plot_pairwise_ld(r2_matrix, ax=None, cmap='Greys', vmin=0, vmax=1, positions=None, **kwargs): """Plot a pairwise LD (r-squared) matrix as a heatmap. Parameters ---------- r2_matrix : array_like, 2D Square r-squared matrix (from HaplotypeMatrix.pairwise_r2()). ax : matplotlib.axes.Axes, optional cmap : str Colormap. vmin, vmax : float Color scale range. positions : array_like, optional Variant positions for axis labels. Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(7, 6)) r2 = np.asarray(r2_matrix) im = ax.imshow(r2, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none', origin='lower', **kwargs) plt.colorbar(im, ax=ax, label='$r^2$') if positions is not None: positions = np.asarray(positions) n_ticks = min(6, len(positions)) tick_idx = np.linspace(0, len(positions) - 1, n_ticks, dtype=int) ax.set_xticks(tick_idx) ax.set_xticklabels([f'{positions[i]/1000:.0f}kb' for i in tick_idx], rotation=45) ax.set_yticks(tick_idx) ax.set_yticklabels([f'{positions[i]/1000:.0f}kb' for i in tick_idx]) ax.set_xlabel('Variant') ax.set_ylabel('Variant') return ax
[docs] def plot_ld_decay(distances, r2_values, ax=None, bins=50, color=None, percentile=50, **kwargs): """Plot LD decay curve (r-squared vs distance). Parameters ---------- distances : array_like Pairwise distances between variants. r2_values : array_like Pairwise r-squared values. ax : matplotlib.axes.Axes, optional bins : int or array_like Number of distance bins or explicit bin edges. color : str, optional percentile : float Percentile of r-squared to plot per bin (default: median). Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(8, 4)) distances = np.asarray(distances) r2_values = np.asarray(r2_values) valid = ~np.isnan(r2_values) distances = distances[valid] r2_values = r2_values[valid] if isinstance(bins, int): bins = np.linspace(distances.min(), distances.max(), bins + 1) bin_idx = np.digitize(distances, bins) - 1 n_bins = len(bins) - 1 bin_centers = (bins[:-1] + bins[1:]) / 2 bin_vals = np.full(n_bins, np.nan) for i in range(n_bins): mask = bin_idx == i if np.sum(mask) > 0: bin_vals[i] = np.percentile(r2_values[mask], percentile) if color is None: color = sns.color_palette()[0] ax.plot(bin_centers, bin_vals, '-o', color=color, markersize=3, **kwargs) ax.set_xlabel('Distance (bp)') ax.set_ylabel('$r^2$') ax.set_ylim(bottom=0) sns.despine(ax=ax) return ax
# --------------------------------------------------------------------------- # PCA / Population Structure # ---------------------------------------------------------------------------
[docs] def plot_pca(coords, ax=None, pc_x=0, pc_y=1, explained_variance=None, labels=None, palette='Set2', s=20, alpha=0.8, **kwargs): """Scatter plot of PCA coordinates. Parameters ---------- coords : array_like, shape (n_samples, n_components) PCA coordinates (from pg_gpu.decomposition.pca). ax : matplotlib.axes.Axes, optional pc_x, pc_y : int Which principal components to plot (0-indexed). explained_variance : array_like, optional Variance explained ratios for axis labels. labels : array_like, optional Population labels per sample for coloring. palette : str Seaborn palette name. s : float Marker size. alpha : float Marker transparency. Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(7, 6)) coords = np.asarray(coords) if labels is not None: labels = np.asarray(labels) unique_labels = np.unique(labels) colors = sns.color_palette(palette, len(unique_labels)) for i, label in enumerate(unique_labels): mask = labels == label ax.scatter(coords[mask, pc_x], coords[mask, pc_y], c=[colors[i]], label=str(label), s=s, alpha=alpha, **kwargs) ax.legend(title='Population', frameon=True) else: color = sns.color_palette(palette)[0] ax.scatter(coords[:, pc_x], coords[:, pc_y], c=[color], s=s, alpha=alpha, **kwargs) if explained_variance is not None: ev = np.asarray(explained_variance) ax.set_xlabel(f'PC{pc_x + 1} ({ev[pc_x]:.1%})') ax.set_ylabel(f'PC{pc_y + 1} ({ev[pc_y]:.1%})') else: ax.set_xlabel(f'PC{pc_x + 1}') ax.set_ylabel(f'PC{pc_y + 1}') sns.despine(ax=ax) return ax
# --------------------------------------------------------------------------- # Distance Matrix # ---------------------------------------------------------------------------
[docs] def plot_pairwise_distance(dist, ax=None, labels=None, cmap='viridis', **kwargs): """Plot a pairwise distance matrix as a heatmap. Parameters ---------- dist : array_like Condensed or square distance matrix. ax : matplotlib.axes.Axes, optional labels : array_like, optional Sample labels for axes. cmap : str Colormap. Returns ------- ax : matplotlib.axes.Axes """ from scipy.spatial.distance import squareform if ax is None: fig, ax = plt.subplots(figsize=(7, 6)) dist = np.asarray(dist, dtype=float) if dist.ndim == 1: D = squareform(dist) else: D = dist im = ax.imshow(D, cmap=cmap, interpolation='none', origin='lower', **kwargs) plt.colorbar(im, ax=ax, label='Distance') if labels is not None: labels = np.asarray(labels) ax.set_xticks(range(len(labels))) ax.set_xticklabels(labels, rotation=90, fontsize=8) ax.set_yticks(range(len(labels))) ax.set_yticklabels(labels, fontsize=8) return ax
# --------------------------------------------------------------------------- # Windowed Statistics # ---------------------------------------------------------------------------
[docs] def plot_windowed(window_start, values, ax=None, stat_name='', color=None, fill_alpha=0.15, **kwargs): """Plot windowed statistics along the genome. Parameters ---------- window_start : array_like Window start positions (from windowed_statistics results). values : array_like Statistic values per window. ax : matplotlib.axes.Axes, optional stat_name : str Label for y-axis. color : str, optional fill_alpha : float Alpha for fill between line and zero. Returns ------- ax : matplotlib.axes.Axes """ if ax is None: fig, ax = plt.subplots(figsize=(10, 3)) window_start = np.asarray(window_start) values = np.asarray(values, dtype=float) if color is None: color = sns.color_palette()[0] # convert to kb or Mb for readability if window_start.max() > 1e6: x = window_start / 1e6 xlabel = 'Position (Mb)' elif window_start.max() > 1e3: x = window_start / 1e3 xlabel = 'Position (kb)' else: x = window_start xlabel = 'Position (bp)' ax.plot(x, values, '-', color=color, linewidth=1, **kwargs) ax.fill_between(x, values, alpha=fill_alpha, color=color) ax.set_xlabel(xlabel) ax.set_ylabel(stat_name) sns.despine(ax=ax) return ax
[docs] def plot_windowed_panel(results, statistics=None, figsize=None, colors=None): """Plot multiple windowed statistics as stacked panels. Parameters ---------- results : dict Output from windowed_statistics() or windowed_statistics_fused(). Must contain 'start' and statistic arrays. statistics : list of str, optional Which statistics to plot. If None, plots all non-metadata keys. figsize : tuple, optional Figure size. colors : list, optional Colors for each panel. Returns ------- fig : matplotlib.figure.Figure axes : list of matplotlib.axes.Axes """ skip_keys = set(CANONICAL_WINDOW_PREFIX) if statistics is None: statistics = [k for k in results if k not in skip_keys] n_panels = len(statistics) if figsize is None: figsize = (10, 2.5 * n_panels) if colors is None: colors = sns.color_palette('Set2', n_panels) fig, axes = plt.subplots(n_panels, 1, figsize=figsize, sharex=True) if n_panels == 1: axes = [axes] window_start = results['start'] for i, stat in enumerate(statistics): plot_windowed(window_start, results[stat], ax=axes[i], stat_name=stat, color=colors[i]) plt.tight_layout() return fig, axes
# --------------------------------------------------------------------------- # Haplotype Visualization # ---------------------------------------------------------------------------
[docs] def plot_haplotype_frequencies(freqs, ax=None, palette='Paired', singleton_color='white'): """Plot haplotype frequencies as a horizontal stacked bar. Parameters ---------- freqs : array_like Sorted haplotype frequencies (descending), e.g. from Garud's H computation. ax : matplotlib.axes.Axes, optional palette : str Seaborn palette for coloring distinct haplotypes. singleton_color : str Color for singleton haplotypes. Returns ------- ax : matplotlib.axes.Axes """ if ax is None: width = plt.rcParams['figure.figsize'][0] fig, ax = plt.subplots(figsize=(width, width / 10)) sns.despine(ax=ax, left=True) freqs = np.asarray(freqs, dtype=float) n_colors = np.sum(freqs > 1.0 / len(freqs)) if len(freqs) > 0 else 0 n_colors = max(n_colors, 1) colors = sns.color_palette(palette, n_colors) x = 0 for i, f in enumerate(freqs): if i < n_colors: color = colors[i] else: color = singleton_color ax.axvspan(x, x + f, color=color) x += f ax.set_xlim(0, 1) ax.set_yticks([]) ax.set_xlabel('Haplotype frequency') return ax
[docs] def plot_variant_locator(positions, ax=None, step=None, color=None): """Plot variant positions showing density along the genome. Parameters ---------- positions : array_like Sorted variant positions. ax : matplotlib.axes.Axes, optional step : int, optional Plot every `step`-th variant (for large datasets). color : str, optional Returns ------- ax : matplotlib.axes.Axes """ if ax is None: width = plt.rcParams['figure.figsize'][0] fig, ax = plt.subplots(figsize=(width, width / 7)) positions = np.asarray(positions) if step is not None: positions = positions[::step] if color is None: color = sns.color_palette()[0] n = len(positions) for i, pos in enumerate(positions): ax.plot([i, pos], [1, 0], '-', color=color, linewidth=0.3, alpha=0.5) ax.set_xlim(0, max(n - 1, positions[-1])) ax.set_ylim(-0.1, 1.1) ax.set_yticks([]) # dual x-axis labels ax.set_xlabel('Genomic position (bp)') sns.despine(ax=ax, left=True) return ax