# Source code for pg_gpu.divergence

"""
GPU-accelerated population divergence statistics.

This module provides efficient computation of population divergence metrics
including FST, Dxy, and related statistics using GPU acceleration.
"""

import warnings

import numpy as np
import cupy as cp
from typing import Union, Tuple, Optional, Dict
from .haplotype_matrix import HaplotypeMatrix
from ._memutil import dac_and_n as _pop_dac_and_n
from .diversity import _apply_span_normalize, pi as _diversity_pi


def dxy_components(pop1_haps, pop2_haps):
    """Return raw between-population difference and comparison totals.

    Intended for callers that aggregate Dxy themselves (for example over
    windows): ``total_diffs / total_comps`` is the mean between-population
    pairwise difference over the usable sites.

    Parameters
    ----------
    pop1_haps, pop2_haps : cupy.ndarray, shape (n_haplotypes, n_variants)
        Haplotype data for each population, with -1 for missing.

    Returns
    -------
    total_diffs : float
        Sum over usable sites of ``d1 * a2 + a1 * d2``: the number of
        cross-population haplotype pairs carrying different alleles.
    total_comps : float
        Sum over usable sites of ``n1 * n2``: the number of
        cross-population haplotype pairs compared.
    n_sites : int
        Number of sites where both populations have at least one called
        haplotype (``n1 > 0`` and ``n2 > 0``).
    """
    def _as_float_counts(haps):
        # (derived count, called count) per site, as returned by
        # _pop_dac_and_n, promoted to float64 before the products below.
        derived, called = _pop_dac_and_n(haps)
        return derived.astype(cp.float64), called.astype(cp.float64)

    d1, n1 = _as_float_counts(pop1_haps)
    d2, n2 = _as_float_counts(pop2_haps)
    a1 = n1 - d1  # non-derived (ancestral) allele counts
    a2 = n2 - d2

    both_called = (n1 > 0) & (n2 > 0)
    diffs = (d1 * a2 + a1 * d2)[both_called]
    comps = (n1 * n2)[both_called]

    return (
        float(cp.sum(diffs).get()),
        float(cp.sum(comps).get()),
        int(cp.sum(both_called).get()),
    )


def fst(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list],
        pop2: Union[str, list], method: str = 'hudson',
        missing_data: str = 'include') -> float:
    """
    Compute FST between two populations with a selectable estimator.

    This is a thin dispatcher. Every argument except ``method`` is
    forwarded unchanged to the estimator function that ``method`` names.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or an explicit list of
        haplotype indices.
    method : str
        ``'hudson'`` selects :func:`fst_hudson`.
        ``'weir_cockerham'`` selects :func:`fst_weir_cockerham`.
        ``'nei'`` selects :func:`fst_nei`.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data.

    Returns
    -------
    float
        The estimate returned by the selected estimator.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names listed above.
    """
    estimators = (
        ('hudson', fst_hudson),
        ('weir_cockerham', fst_weir_cockerham),
        ('nei', fst_nei),
    )
    for name, estimator in estimators:
        if method == name:
            return estimator(haplotype_matrix, pop1, pop2, missing_data)
    raise ValueError(f"Unknown FST method: {method}")


def fst_hudson(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list],
               pop2: Union[str, list],
               missing_data: str = 'include') -> float:
    """
    Compute Hudson's (1992) FST estimator between two populations.

    Per site, ``Hw`` is the mean within-population pairwise difference and
    ``Hb`` is the between-population pairwise difference. The sites are
    combined as a ratio of sums, ``sum(Hb - Hw) / sum(Hb)``, which is the
    multi-site form of ``1 - Hw / Hb``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. If its ``device`` is ``'CPU'``, its
        ``transfer_to_gpu()`` method is called.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data in either
        population.

    Returns
    -------
    float
        Hudson's FST. Returns 0.0 when ``'exclude'`` leaves no variants or
        when no site has ``Hb > 0`` with data in both populations.

    Notes
    -----
    ``Hw`` and ``Hb`` are both computed at half scale: the factor of 2 in
    the pairwise-difference formulas is dropped from each, and it cancels
    in the ratio. At sites where a population has fewer than two called
    haplotypes, that population's within term is 0.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)

    def _freq_and_half_mpd(haps):
        """Per-site (derived freq, called count, half within-pop MPD)."""
        dac, called = _pop_dac_and_n(haps)
        dac = dac.astype(cp.float64)
        called = called.astype(cp.float64)
        freq = cp.where(called > 0, dac / called, 0.0)
        half_mpd = cp.zeros_like(freq)
        # The n / (n - 1) correction needs at least two called haplotypes.
        enough = called > 1
        f = freq[enough]
        n = called[enough]
        half_mpd[enough] = f * (1 - f) * n / (n - 1)
        return freq, called, half_mpd

    freq_a, n_a, hw_a = _freq_and_half_mpd(
        haplotype_matrix.haplotypes[idx_a, :])
    freq_b, n_b, hw_b = _freq_and_half_mpd(
        haplotype_matrix.haplotypes[idx_b, :])

    hw = (hw_a + hw_b) / 2.0
    hb = (freq_a * (1 - freq_b) + freq_b * (1 - freq_a)) / 2.0

    usable = (hb > 0) & (n_a > 0) & (n_b > 0)
    if not cp.any(usable):
        return 0.0
    return float((cp.sum((hb - hw)[usable]) / cp.sum(hb[usable])).get())


def fst_weir_cockerham(haplotype_matrix, pop1: Union[str, list],
                       pop2: Union[str, list],
                       missing_data: str = 'include') -> float:
    """
    Compute Weir & Cockerham's (1984) FST estimator.

    This is a method-of-moments estimator that accounts for sampling
    variance. The three variance components (a, b, c) are computed per
    site, including the within-individual component from observed
    heterozygosity. They are then combined as ``sum(a) / sum(a + b + c)``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data. If both populations have at least two haplotypes,
        consecutive haplotype pairs within each population's selection
        (rows 0-1, 2-3, ...) are treated as diploid individuals. Those
        pairs supply the heterozygosity component. With an odd haplotype
        count, the unpaired trailing haplotype is ignored and a
        ``RuntimeWarning`` is issued. Otherwise, each haplotype counts as
        one sample and observed heterozygosity is taken to be 0.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    missing_data : str
        ``'include'`` uses per-site sample sizes and a ratio of sums.
        ``'exclude'`` uses only sites with no missing data.

    Returns
    -------
    float
        Weir & Cockerham's FST. Returns 0.0 when ``'exclude'`` leaves no
        variants, when no site has data in both populations, or when
        ``sum(a + b + c) <= 0``.
    """
    if hasattr(haplotype_matrix, 'device') and haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    pop1_idx = _get_population_indices(haplotype_matrix, pop1)
    pop2_idx = _get_population_indices(haplotype_matrix, pop2)
    pop1_haps = haplotype_matrix.haplotypes[pop1_idx, :]
    pop2_haps = haplotype_matrix.haplotypes[pop2_idx, :]

    # Frequencies and heterozygosity must come from the same individuals.
    # Using all valid haplotypes for the counts but only complete pairs
    # for n would inflate frequencies by 1/(1 - miss_rate) under MCAR.
    def _pop_diploid_stats(pop_haps):
        """Per-site (het rate, n complete individuals, allele freq)."""
        n_rows = pop_haps.shape[0]
        n_pairs = n_rows // 2
        if 2 * n_pairs != n_rows:
            warnings.warn(
                f"Population has an odd number of haplotypes ({n_rows}); "
                "the last, unpaired haplotype is ignored for "
                "Weir & Cockerham FST.",
                RuntimeWarning, stacklevel=3)
        # Slicing to 2 * n_pairs rows keeps ha and hb the same shape.
        # Without it, an odd count would broadcast (3 rows) or raise (5+).
        ha = pop_haps[0:2 * n_pairs:2, :]  # first haplotype of each pair
        hb = pop_haps[1:2 * n_pairs:2, :]  # second haplotype of each pair
        both_valid = (ha >= 0) & (hb >= 0)
        het = (ha != hb) & both_valid
        n_called = cp.sum(both_valid, axis=0).astype(cp.float64)
        h_bar = cp.where(
            n_called > 0,
            cp.sum(het, axis=0).astype(cp.float64) / n_called, 0.0)
        # Derived allele count from complete pairs only
        derived = cp.sum(
            cp.where(both_valid, ha, 0) + cp.where(both_valid, hb, 0),
            axis=0).astype(cp.float64)
        # Each complete individual carries two alleles.
        freqs = cp.where(n_called > 0, derived / (2.0 * n_called), 0.0)
        return h_bar, n_called, freqs

    def _pop_haploid_stats(pop_haps):
        """Per-site (zero het, n called haplotypes, allele freq)."""
        derived, n_valid = _pop_dac_and_n(pop_haps)
        derived = derived.astype(cp.float64)
        n_valid = n_valid.astype(cp.float64)
        # n counts haplotypes here, so the frequency is derived / n.
        # The diploid 2n denominator would halve every frequency.
        freqs = cp.where(n_valid > 0, derived / n_valid, 0.0)
        return cp.zeros_like(n_valid), n_valid, freqs

    if len(pop1_idx) >= 2 and len(pop2_idx) >= 2:
        pop_stats = _pop_diploid_stats
    else:
        pop_stats = _pop_haploid_stats
    h_bar1, n1, pop1_freqs = pop_stats(pop1_haps)
    h_bar2, n2, pop2_freqs = pop_stats(pop2_haps)

    r = 2.0  # number of populations
    n_total = n1 + n2
    n_bar = n_total / r

    # n_C: sample size correction factor
    nc = cp.zeros_like(n_total)
    valid = n_total > 0
    nc[valid] = (n_total[valid]
                 - (n1[valid]**2 + n2[valid]**2) / n_total[valid]) / (r - 1)

    # Sample-size-weighted average allele frequency
    p_bar = cp.zeros_like(pop1_freqs)
    valid = n_total > 0
    p_bar[valid] = (n1[valid] * pop1_freqs[valid]
                    + n2[valid] * pop2_freqs[valid]) / n_total[valid]

    # Sample variance of allele frequencies across populations
    s_squared = cp.zeros_like(p_bar)
    valid = n_bar > 0
    s_squared[valid] = (n1[valid] * (pop1_freqs[valid] - p_bar[valid])**2
                        + n2[valid] * (pop2_freqs[valid] - p_bar[valid])**2
                        ) / ((r - 1) * n_bar[valid])

    # Average observed heterozygosity weighted by sample size
    h_bar = cp.zeros_like(p_bar)
    valid = n_total > 0
    h_bar[valid] = (n1[valid] * h_bar1[valid]
                    + n2[valid] * h_bar2[valid]) / n_total[valid]

    # W-C variance components (Eqs 2, 3, 4 from Weir & Cockerham 1984)
    a = cp.zeros_like(p_bar)
    b = cp.zeros_like(p_bar)
    c = cp.zeros_like(p_bar)
    valid = (n_bar > 1) & (nc > 0)
    pq = p_bar[valid] * (1 - p_bar[valid])
    s2 = s_squared[valid]
    nb = n_bar[valid]
    ncc = nc[valid]
    hb = h_bar[valid]
    a[valid] = (nb / ncc) * (s2 - (1.0 / (nb - 1))
                             * (pq - (r - 1) * s2 / r - hb / 4.0))
    b[valid] = (nb / (nb - 1)) * (pq - (r - 1) * s2 / r
                                  - (2 * nb - 1) * hb / (4.0 * nb))
    c[valid] = hb / 2.0

    # Global FST = sum(a) / sum(a + b + c) over sites with data in both
    valid_mask = (n1 > 0) & (n2 > 0)
    if cp.any(valid_mask):
        sum_a = float(cp.sum(a[valid_mask]).get())
        sum_abc = float(cp.sum((a + b + c)[valid_mask]).get())
        if sum_abc > 0:
            return sum_a / sum_abc
    return 0.0


def fst_nei(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include') -> float:
    """
    Compute Nei's (1973) GST between two populations.

    ``GST = (HT - HS) / HT``. Here ``HT = 2 p (1 - p)`` uses the
    sample-size-weighted pooled frequency ``p``, and ``HS`` is the
    sample-size-weighted mean of the per-population values
    ``2 p_i (1 - p_i)``. Sites are combined as a ratio of sums,
    ``(sum HT - sum HS) / sum HT``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. If its ``device`` is ``'CPU'``, its
        ``transfer_to_gpu()`` method is called.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data.

    Returns
    -------
    float
        Nei's GST. Returns 0.0 when ``'exclude'`` leaves no variants, or
        when no site has ``HT > 0`` with data in both populations.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)
    haps_a = haplotype_matrix.haplotypes[idx_a, :]
    haps_b = haplotype_matrix.haplotypes[idx_b, :]

    def _freq_and_n(haps):
        """Per-site (derived allele frequency, called haplotype count)."""
        dac, called = _pop_dac_and_n(haps)
        dac = dac.astype(cp.float64)
        called = called.astype(cp.float64)
        return cp.where(called > 0, dac / called, 0.0), called

    p_a, n_a = _freq_and_n(haps_a)
    p_b, n_b = _freq_and_n(haps_b)

    het_a = 2.0 * p_a * (1 - p_a)
    het_b = 2.0 * p_b * (1 - p_b)

    # Sample-size-weighted averages over the pooled sample
    n_sum = n_a + n_b
    has_data = n_sum > 0
    hs = cp.zeros_like(het_a)
    p_pooled = cp.zeros_like(p_a)
    hs[has_data] = (het_a[has_data] * n_a[has_data]
                    + het_b[has_data] * n_b[has_data]) / n_sum[has_data]
    p_pooled[has_data] = (p_a[has_data] * n_a[has_data]
                          + p_b[has_data] * n_b[has_data]) / n_sum[has_data]
    ht = 2.0 * p_pooled * (1 - p_pooled)

    informative = (ht > 0) & (n_a > 0) & (n_b > 0)
    if not cp.any(informative):
        return 0.0

    total_ht = float(cp.sum(ht[informative]).get())
    total_hs = float(cp.sum(hs[informative]).get())
    if total_ht == 0:
        return 0.0
    return (total_ht - total_hs) / total_ht


def dxy(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list],
        pop2: Union[str, list], per_site: bool = False,
        missing_data: str = 'include',
        span_normalize=True,
        ) -> Union[float, cp.ndarray]:
    """
    Compute absolute divergence (Dxy) between two populations.

    At each site, Dxy is ``p1 + p2 - 2 p1 p2``: the probability that one
    haplotype drawn from each population carries a different allele. It
    is computed only where both populations have called data; all other
    sites are 0.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. If its ``device`` is ``'CPU'``, its
        ``transfer_to_gpu()`` method is called.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    per_site : bool
        If True, return the per-site values instead of an aggregate.
    missing_data : str
        ``'include'`` (default) uses the available data at every site.
        ``'exclude'`` filters to sites with no missing data.
    span_normalize : bool
        ``False`` returns the sum divided by the number of sites with data
        in both populations (0.0 if there are none). Any other value
        (default ``True``) passes the sum to ``_apply_span_normalize``.

    Returns
    -------
    float or numpy.ndarray
        If ``per_site`` is set, a host (numpy) array with one value per
        site. It is empty when ``'exclude'`` leaves no variants. Otherwise,
        the aggregated Dxy.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return np.array([]) if per_site else 0.0

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)
    haps_a = haplotype_matrix.haplotypes[idx_a, :]
    haps_b = haplotype_matrix.haplotypes[idx_b, :]

    def _freq_and_n(haps):
        """Per-site (derived allele frequency, called haplotype count)."""
        dac, called = _pop_dac_and_n(haps)
        dac = dac.astype(cp.float64)
        called = called.astype(cp.float64)
        return cp.where(called > 0, dac / called, 0.0), called

    p_a, n_a = _freq_and_n(haps_a)
    p_b, n_b = _freq_and_n(haps_b)

    shared = (n_a > 0) & (n_b > 0)
    site_dxy = cp.zeros(haps_a.shape[1])
    pa = p_a[shared]
    pb = p_b[shared]
    site_dxy[shared] = pa + pb - 2 * pa * pb

    if per_site:
        return site_dxy.get()

    total = cp.sum(site_dxy)
    if span_normalize is not False:
        return _apply_span_normalize(total, haplotype_matrix, span_normalize)

    n_shared = int(cp.sum(shared).get())
    if n_shared == 0:
        return 0.0
    return float(total.get() / n_shared)


def da(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list],
       pop2: Union[str, list], missing_data: str = 'include',
       span_normalize=True) -> float:
    """
    Compute net divergence (Da) between two populations.

    ``Da = Dxy - (pi1 + pi2) / 2``. ``Dxy`` comes from :func:`dxy`, and
    ``pi1`` and ``pi2`` come from ``diversity.pi`` for each population.
    ``missing_data`` and ``span_normalize`` are forwarded to all three
    calls.

    NOTE(review): with ``'exclude'``, :func:`dxy` filters on missingness
    in both populations. The ``pi`` calls receive only one population
    each, so they may filter a different site set. Confirm this is
    intended.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data.
    span_normalize : bool
        ``True`` (default): auto-detect the best denominator.
        ``False``: return the raw sum divided by the sites-with-data count.

    Returns
    -------
    float
        Net divergence (Da).
    """
    divergence = dxy(haplotype_matrix, pop1, pop2,
                     missing_data=missing_data,
                     span_normalize=span_normalize)
    within = [
        _diversity_pi(haplotype_matrix, population=pop,
                      missing_data=missing_data,
                      span_normalize=span_normalize)
        for pop in (pop1, pop2)
    ]
    return divergence - (within[0] + within[1]) / 2.0


def pi_within_population(haplotype_matrix: HaplotypeMatrix,
                         pop: Union[str, list],
                         missing_data: str = 'include',
                         span_normalize=True) -> float:
    """
    Compute nucleotide diversity (pi) within one population.

    Thin wrapper around ``diversity.pi``. ``pop`` is passed as the
    ``population`` keyword, and the other options are forwarded
    unchanged, so results match that function exactly.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop : str or list
        Population name or list of indices.
    missing_data : str
        ``'include'`` or ``'exclude'``.
    span_normalize : bool
        ``True`` (default): auto-detect. ``False``: raw sum.

    Returns
    -------
    float
        Nucleotide diversity as returned by ``diversity.pi``.
    """
    options = {
        'population': pop,
        'missing_data': missing_data,
        'span_normalize': span_normalize,
    }
    return _diversity_pi(haplotype_matrix, **options)


def divergence_stats(haplotype_matrix: HaplotypeMatrix,
                     pop1: Union[str, list],
                     pop2: Union[str, list],
                     statistics: Optional[list] = None,
                     missing_data: str = 'include',
                     span_normalize=True) -> Dict[str, float]:
    """
    Compute several divergence statistics between two populations.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        Population name (a key of ``sample_sets``) or a list of indices.
    statistics : list, optional
        Names of the statistics to compute, in order. ``None`` (default)
        means ``['fst', 'dxy', 'da']``. Recognised names:
        ``'fst'`` (Hudson via :func:`fst`), ``'fst_hudson'``,
        ``'fst_wc'``, ``'fst_nei'``, ``'dxy'``, ``'da'``, ``'pi1'``,
        ``'pi2'``.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data.
    span_normalize : bool
        Forwarded to ``dxy``, ``da`` and the ``pi`` statistics.
        ``True`` (default): auto-detect. ``False``: raw sum / per-site
        average.

    Returns
    -------
    dict
        Maps each requested statistic name to its value.

    Raises
    ------
    ValueError
        On the first unrecognised statistic name. Statistics listed
        before it have already been computed by then.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if statistics is None:
        statistics = ['fst', 'dxy', 'da']

    compute = {
        'fst': lambda: fst(haplotype_matrix, pop1, pop2,
                           missing_data=missing_data),
        'fst_hudson': lambda: fst_hudson(haplotype_matrix, pop1, pop2,
                                         missing_data=missing_data),
        'fst_wc': lambda: fst_weir_cockerham(haplotype_matrix, pop1, pop2,
                                             missing_data=missing_data),
        'fst_nei': lambda: fst_nei(haplotype_matrix, pop1, pop2,
                                   missing_data=missing_data),
        'dxy': lambda: dxy(haplotype_matrix, pop1, pop2,
                           missing_data=missing_data,
                           span_normalize=span_normalize),
        'da': lambda: da(haplotype_matrix, pop1, pop2,
                         missing_data=missing_data,
                         span_normalize=span_normalize),
        'pi1': lambda: pi_within_population(haplotype_matrix, pop1,
                                            missing_data=missing_data,
                                            span_normalize=span_normalize),
        'pi2': lambda: pi_within_population(haplotype_matrix, pop2,
                                            missing_data=missing_data,
                                            span_normalize=span_normalize),
    }

    results = {}
    for stat in statistics:
        # The isinstance guard keeps unhashable entries raising ValueError
        # (as an if/elif chain would) instead of TypeError from dict lookup.
        func = compute.get(stat) if isinstance(stat, str) else None
        if func is None:
            raise ValueError(f"Unknown statistic: {stat}")
        results[stat] = func()
    return results


def pairwise_fst(haplotype_matrix: HaplotypeMatrix,
                 populations: Optional[list] = None,
                 method: str = 'hudson',
                 missing_data: str = 'include') -> Tuple[cp.ndarray, list]:
    """
    Compute a symmetric matrix of pairwise FST values.

    :func:`fst` is called once for each unordered pair ``(i, j)`` with
    ``i < j``. Each value is written to both ``[i, j]`` and ``[j, i]``,
    and the diagonal stays 0.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    populations : list, optional
        Population names (or index lists) in matrix order. ``None`` means
        every key of ``haplotype_matrix.sample_sets``.
    method : str
        FST method, forwarded to :func:`fst`.
    missing_data : str
        ``'include'`` uses the available data at every site.
        ``'exclude'`` keeps only sites with no missing data.

    Returns
    -------
    fst_matrix : cp.ndarray
        ``(k, k)`` pairwise FST matrix, where ``k = len(populations)``.
    pop_names : list
        The population list used for the matrix order. When
        ``populations`` was given, this is the same object that was
        passed in.

    Raises
    ------
    ValueError
        If ``populations`` is None and no ``sample_sets`` are defined.
    """
    if populations is None:
        if haplotype_matrix.sample_sets is None:
            raise ValueError("No populations defined in haplotype matrix")
        populations = list(haplotype_matrix.sample_sets.keys())

    k = len(populations)
    fst_matrix = cp.zeros((k, k))
    for row in range(k):
        for col in range(row + 1, k):
            value = fst(haplotype_matrix, populations[row],
                        populations[col], method, missing_data)
            fst_matrix[row, col] = fst_matrix[col, row] = value
    return fst_matrix, populations


def _get_population_indices(haplotype_matrix: HaplotypeMatrix,
                            pop: Union[str, list]) -> list:
    """
    Get population indices from name or list.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop : str or list
        Population name (a key of ``haplotype_matrix.sample_sets``) or an
        iterable of haplotype indices

    Returns
    -------
    list
        Population indices. For a name, the ``sample_sets`` entry is returned
        as stored (not copied); any other input is converted with ``list()``.

    Raises
    ------
    ValueError
        If ``pop`` is a name and sample_sets is undefined or lacks that name.
    """
    if isinstance(pop, str):
        if haplotype_matrix.sample_sets is None:
            raise ValueError("No sample_sets defined in haplotype matrix")
        if pop not in haplotype_matrix.sample_sets:
            raise ValueError(f"Population {pop} not found in sample_sets")
        return haplotype_matrix.sample_sets[pop]
    else:
        return list(pop)


def _pop_allele_counts(haplotype_matrix, pop, missing_data='include'):
    """Compute per-variant allele counts for a population on GPU.

    Returns
    -------
    ac_0, ac_1, n : cupy.ndarray, float64
        Per-site counts: ``ac_1`` is the derived-allele count from
        ``dac_and_n``, ``n`` the number of called haplotypes, and
        ``ac_0 = n - ac_1``.

    Notes
    -----
    With ``missing_data='exclude'`` the missing-site filter is applied for
    ``pop`` alone. Arrays returned for different populations are therefore
    not guaranteed to cover the same sites. Callers combining populations
    should filter the matrix jointly first and then call this with
    ``'include'``.
    """
    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop])

    pop_idx = _get_population_indices(haplotype_matrix, pop)
    h = haplotype_matrix.haplotypes[pop_idx, :]
    ac_1, n = _pop_dac_and_n(h)
    n = n.astype(cp.float64)
    ac_1 = ac_1.astype(cp.float64)
    ac_0 = n - ac_1
    return ac_0, ac_1, n


def _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac2_0, ac2_1, n2):
    """Compute per-variant Hudson FST numerator/denominator from allele counts.

    The within-population term is the mean over the two populations of the
    mean pairwise difference among unordered haplotype pairs. The
    between-population term is the mean difference over all ``n1 * n2``
    cross pairs. A site with fewer than two called haplotypes in a
    population contributes 0 to that population's within term. A site with
    no calls in either population contributes 0 to the between term.

    Returns
    -------
    num, den : cupy.ndarray
        ``num = between - within`` and ``den = between``, per variant.
    """
    n1_pairs = n1 * (n1 - 1) / 2
    n1_same = (ac1_0 * (ac1_0 - 1) + ac1_1 * (ac1_1 - 1)) / 2
    mpd1 = cp.where(n1_pairs > 0, (n1_pairs - n1_same) / n1_pairs, 0.0)

    n2_pairs = n2 * (n2 - 1) / 2
    n2_same = (ac2_0 * (ac2_0 - 1) + ac2_1 * (ac2_1 - 1)) / 2
    mpd2 = cp.where(n2_pairs > 0, (n2_pairs - n2_same) / n2_pairs, 0.0)

    within = (mpd1 + mpd2) / 2.0

    n_between = n1 * n2
    n_between_same = ac1_0 * ac2_0 + ac1_1 * ac2_1
    between = cp.where(n_between > 0,
                       (n_between - n_between_same) / n_between, 0.0)

    return between - within, between


def _windowed_fst(num, den, size, start=0, stop=None, step=None):
    """Compute windowed FST from per-variant numerator/denominator on GPU.

    Each window covers ``size`` consecutive variants. Window starts run from
    ``start`` with stride ``step`` (default ``size``, i.e. non-overlapping).
    Only windows that fit completely before ``stop`` are produced.
    Non-finite per-variant values are treated as 0. A window whose summed
    denominator is 0 yields NaN.

    Uses cumulative sums for O(n) windowed reduction.

    Parameters
    ----------
    num, den : cupy.ndarray, shape (n_variants,)
    size : int
        Variants per window.
    start : int
        First variant index.
    stop : int, optional
        End of the range (exclusive). Defaults to, and is clamped to,
        ``len(num)``.
    step : int, optional
        Stride between window starts.

    Returns
    -------
    cupy.ndarray, float64, shape (n_windows,)
    """
    n = len(num)
    if stop is None:
        stop = n
    else:
        # Clamp to the data: CuPy integer-array indexing wraps out-of-bounds
        # indices instead of raising, so a stop past the end would silently
        # read wrapped-around prefix sums.
        stop = min(stop, n)
    if step is None:
        step = size

    num_clean = cp.where(cp.isfinite(num), num, 0.0)
    den_clean = cp.where(cp.isfinite(den), den, 0.0)

    # Leading zero so that cs[b] - cs[a] is the sum over [a, b).
    cs_num = cp.concatenate([cp.zeros(1, dtype=cp.float64), cp.cumsum(num_clean)])
    cs_den = cp.concatenate([cp.zeros(1, dtype=cp.float64), cp.cumsum(den_clean)])

    w_starts = cp.arange(start, stop - size + 1, step)
    num_sum = cs_num[w_starts + size] - cs_num[w_starts]
    den_sum = cs_den[w_starts + size] - cs_den[w_starts]

    return cp.where(den_sum != 0, num_sum / den_sum, cp.nan)
def pbs(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        pop3: Union[str, list],
        window_size: int,
        window_start: int = 0,
        window_stop: Optional[int] = None,
        window_step: Optional[int] = None,
        normed: bool = True,
        missing_data: str = 'include'):
    """Compute the Population Branch Statistic (PBS).

    PBS detects genomic regions unusually differentiated in pop1 relative
    to pop2 and pop3, using pairwise Hudson FST estimates.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data containing all three populations. If it is on the
        CPU it is moved to the GPU in place.
    pop1, pop2, pop3 : str or list
        Population names or sample indices.
    window_size : int
        Number of variants per window.
    window_start : int, optional
        Starting variant index.
    window_stop : int, optional
        Stopping variant index.
    window_step : int, optional
        Stride between windows. Defaults to window_size (non-overlapping).
    normed : bool, optional
        If True (default), return normalized PBS (PBSn1).
    missing_data : str
        'include' - per-site sample sizes
        'exclude' - only use sites with no missing data in any of the
        three populations

    Returns
    -------
    ndarray, float64, shape (n_windows,)
        PBS values per window.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    count_mode = missing_data
    if missing_data == 'exclude':
        # Filter jointly on all three populations. Per-population filtering
        # (what _pop_allele_counts does in 'exclude' mode) can keep a
        # different site set for each population, leaving the count arrays
        # mismatched in length or silently misaligned site-by-site. After
        # the joint filter no missing calls remain, so 'include' counting is
        # equivalent.
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2, pop3])
        count_mode = 'include'

    # Precompute allele counts once per population.
    ac1_0, ac1_1, n1 = _pop_allele_counts(haplotype_matrix, pop1, count_mode)
    ac2_0, ac2_1, n2 = _pop_allele_counts(haplotype_matrix, pop2, count_mode)
    ac3_0, ac3_1, n3 = _pop_allele_counts(haplotype_matrix, pop3, count_mode)

    # Compute all three pairwise FST num/den from the shared counts.
    num12, den12 = _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac2_0, ac2_1, n2)
    num13, den13 = _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac3_0, ac3_1, n3)
    num23, den23 = _hudson_fst_from_counts(ac2_0, ac2_1, n2, ac3_0, ac3_1, n3)

    fst12 = _windowed_fst(num12, den12, window_size,
                          window_start, window_stop, window_step)
    fst13 = _windowed_fst(num13, den13, window_size,
                          window_start, window_stop, window_step)
    fst23 = _windowed_fst(num23, den23, window_size,
                          window_start, window_stop, window_step)

    # Keep 1 - FST strictly positive so the log transform stays finite.
    cp.clip(fst12, 0, 0.99999, out=fst12)
    cp.clip(fst13, 0, 0.99999, out=fst13)
    cp.clip(fst23, 0, 0.99999, out=fst23)

    t12 = -cp.log(1 - fst12)
    t13 = -cp.log(1 - fst13)
    t23 = -cp.log(1 - fst23)

    ret = (t12 + t13 - t23) / 2
    if normed:
        norm = 1 + (t12 + t13 + t23) / 2
        ret = ret / norm

    return ret.get()
# ---------------------------------------------------------------------------
# Two-population distance-based statistics
# ---------------------------------------------------------------------------


def _snn_one_pop(within, between):
    """Score one population block for Hudson's Snn on GPU.

    For each haplotype (row), compare its closest within-population distance
    (self-distance excluded) with its closest between-population distance:

    - score 1 if the within-population neighbor is strictly closer;
    - score 0 if the between-population neighbor is strictly closer;
    - on an exact tie, score the fraction of tied nearest neighbors that are
      within-population.

    Parameters
    ----------
    within : cupy.ndarray, shape (n, n)
        Pairwise distances among this population's haplotypes.
    between : cupy.ndarray, shape (n, m)
        Distances from this population's haplotypes to the other population.

    Returns
    -------
    float
        Sum of the per-haplotype scores.
    """
    w = within.copy()
    # Exclude self-matches so a haplotype is never its own nearest neighbor.
    cp.fill_diagonal(w, cp.inf)
    min_within = cp.min(w, axis=1)
    min_between = cp.min(between, axis=1)
    score = (min_within < min_between).astype(cp.float64)

    tied = min_within == min_between
    if cp.any(tied):
        n_within_ties = cp.sum(w == min_within[:, None], axis=1)
        n_between_ties = cp.sum(between == min_between[:, None], axis=1)
        tie_score = n_within_ties / (n_within_ties + n_between_ties)
        score = cp.where(tied, tie_score, score)

    return float(cp.sum(score).get())


def _resolve_distance_matrices(haplotype_matrix, pop1, pop2,
                               missing_data='include', distance_matrices=None):
    """Resolve or compute pairwise distance matrices.

    If distance_matrices is provided, its shapes are validated against the
    populations and the matrices are returned unchanged. ``missing_data`` is
    not applied to supplied matrices. Otherwise the matrices are computed
    from scratch with ``pairwise_distance_matrix``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
        Only used when the matrices are computed here.
    distance_matrices : tuple of (dist_between, dist_within1, dist_within2), optional
        Pre-computed cupy distance matrices, e.g. the tuple returned by
        ``pairwise_distance_matrix``. (``distance_based_stats`` returns a
        dict of scalars, not matrices, so it cannot be used as a source.)

    Returns
    -------
    dist_between, dist_within1, dist_within2 : cupy.ndarray

    Raises
    ------
    ValueError
        If a supplied matrix shape does not match the population sizes.
    """
    if distance_matrices is not None:
        n1 = len(_get_population_indices(haplotype_matrix, pop1))
        n2 = len(_get_population_indices(haplotype_matrix, pop2))
        db, dw1, dw2 = distance_matrices
        if db.shape != (n1, n2):
            raise ValueError(
                f"distance_matrices between-pop shape {db.shape} does not "
                f"match populations ({n1}, {n2})")
        if dw1.shape != (n1, n1):
            raise ValueError(
                f"distance_matrices within-pop1 shape {dw1.shape} does not "
                f"match pop1 size ({n1}, {n1})")
        if dw2.shape != (n2, n2):
            raise ValueError(
                f"distance_matrices within-pop2 shape {dw2.shape} does not "
                f"match pop2 size ({n2}, {n2})")
        return db, dw1, dw2
    return pairwise_distance_matrix(haplotype_matrix, pop1, pop2, missing_data)
def pairwise_distance_matrix(haplotype_matrix, pop1, pop2,
                             missing_data='include'):
    """Compute pairwise distance matrices between and within two populations.

    Returns three cupy distance matrices: between-population, within-pop1,
    and within-pop2. Pre-compute once and pass to individual stats via the
    ``distance_matrices`` parameter to avoid redundant GPU work.

    Distances are raw (unnormalized) per-pair difference counts. All three
    matrices are sliced from one combined matrix, so they are computed over
    the same sites.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
        'exclude' first restricts the data to sites with no missing data in
        either population. Any other value uses all sites.

    Returns
    -------
    dist_between : cupy.ndarray, float64, shape (n1, n2)
    dist_within1 : cupy.ndarray, float64, shape (n1, n1)
    dist_within2 : cupy.ndarray, float64, shape (n2, n2)
    """
    from .distance_stats import _pairwise_diffs_matrix_gpu

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])

    pop1_idx = _get_population_indices(haplotype_matrix, pop1)
    pop2_idx = _get_population_indices(haplotype_matrix, pop2)
    # Concatenate explicitly as a list: _get_population_indices returns a
    # sample_sets entry as stored. If that entry is an array rather than a
    # list, ``pop1_idx + pop2_idx`` would add element-wise instead of
    # concatenating.
    all_idx = [*pop1_idx, *pop2_idx]
    n1 = len(pop1_idx)

    hap_sub = haplotype_matrix.haplotypes[all_idx, :]

    # Raw Hamming distances (not normalized) — appropriate for ratio/rank stats
    diffs = _pairwise_diffs_matrix_gpu(hap_sub, missing_data='include')

    # Rows/cols [:n1] are pop1, [n1:] are pop2.
    return diffs[:n1, n1:], diffs[:n1, :n1], diffs[n1:, n1:]
def snn(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        missing_data: str = 'include',
        distance_matrices=None) -> float:
    """Hudson's nearest-neighbor statistic (Snn).

    Snn is the fraction of all haplotypes (both populations pooled) whose
    nearest neighbor comes from their own population. Ties between a
    within- and a between-population neighbor are scored fractionally.
    Under panmixia Snn ~ 0.5; under population structure Snn -> 1.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed (dist_between, dist_within1, dist_within2) from
        pairwise_distance_matrix. Avoids recomputation when calling multiple
        stats on the same population pair.

    Returns
    -------
    float
        Snn statistic in [0, 1].

    References
    ----------
    Hudson, R.R. (2000). A New Statistic for Detecting Genetic
    Differentiation. Genetics, 155(4), 2011-2014.
    """
    between, within_a, within_b = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)

    # Pop1 haplotypes look across via ``between``; pop2 via its transpose.
    conspecific = (_snn_one_pop(within_a, between)
                   + _snn_one_pop(within_b, between.T))
    total_haplotypes = sum(between.shape)
    return conspecific / total_haplotypes
def dxy_min(haplotype_matrix: HaplotypeMatrix,
            pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include',
            distance_matrices=None) -> float:
    """Minimum pairwise distance between two populations.

    Returns the smallest entry of the between-population distance matrix,
    i.e. the distance of the closest cross-population haplotype pair. Used
    in the Gmin (Geneva et al.) and dd (Schrider et al.) statistics.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    float
        Minimum pairwise distance.

    References
    ----------
    Geneva, A.J. et al. (2015). A new method to scan genomes for
    introgression in a secondary contact model. PLoS ONE, 10(4).
    """
    matrices = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    closest = cp.min(matrices[0])
    return float(closest.get())
def gmin(haplotype_matrix: HaplotypeMatrix,
         pop1: Union[str, list],
         pop2: Union[str, list],
         missing_data: str = 'include',
         distance_matrices=None) -> float:
    """Geneva's Gmin: ratio of minimum to mean between-population distance.

    Gmin = Dxy_min / Dxy_mean, both taken over the between-population
    distance matrix. Low values indicate unusually similar haplotypes across
    populations relative to average divergence, suggesting recent gene flow
    or shared ancestral variation.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    float
        Gmin ratio, or NaN when the mean between-population distance is 0.

    References
    ----------
    Geneva, A.J. et al. (2015). A new method to scan genomes for
    introgression in a secondary contact model. PLoS ONE, 10(4).
    """
    between = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)[0]

    average = float(cp.mean(between).get())
    smallest = float(cp.min(between).get())
    return float('nan') if average == 0 else smallest / average
def dd(haplotype_matrix: HaplotypeMatrix,
       pop1: Union[str, list],
       pop2: Union[str, list],
       missing_data: str = 'include',
       distance_matrices=None) -> Tuple[float, float]:
    """Relative minimum divergence (dd1, dd2).

    dd1 = Dxy_min / pi1, dd2 = Dxy_min / pi2. Low values indicate that the
    closest between-population pair is unusually similar relative to
    within-population diversity. Dxy_min comes from the between-population
    distance matrix; each pi comes from ``diversity.pi`` with
    ``span_normalize=False``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    dd1 : float
        Dxy_min / pi(pop1), or NaN when pi(pop1) is not positive.
    dd2 : float
        Dxy_min / pi(pop2), or NaN when pi(pop2) is not positive.

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes of
    Drosophila simulans and D. sechellia. PLoS Genetics, 14(4), e1007341.
    https://doi.org/10.1371/journal.pgen.1007341
    """
    from . import diversity

    between = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)[0]
    closest = float(cp.min(between).get())

    def _relative_to(population):
        # Diversity of one population; non-positive diversity gives NaN.
        diversity_value = diversity.pi(haplotype_matrix, population=population,
                                       span_normalize=False,
                                       missing_data=missing_data)
        if diversity_value > 0:
            return closest / diversity_value
        return float('nan')

    return _relative_to(pop1), _relative_to(pop2)
def dd_rank(haplotype_matrix: HaplotypeMatrix,
            pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include',
            distance_matrices=None) -> Tuple[float, float]:
    """Rank of Dxy_min in within-population pairwise distance distributions.

    For each population, returns the fraction of its within-population
    pairwise distances (each unordered pair once, self-pairs excluded) that
    are <= Dxy_min. Low values indicate the closest between-population pair
    is more similar than most within-population pairs, suggesting
    introgression.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    rank1 : float
        Fraction of pop1 within-pop distances <= Dxy_min (NaN if pop1 has
        fewer than two haplotypes).
    rank2 : float
        Fraction of pop2 within-pop distances <= Dxy_min (NaN if pop2 has
        fewer than two haplotypes).

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes of
    Drosophila simulans and D. sechellia. PLoS Genetics, 14(4), e1007341.
    https://doi.org/10.1371/journal.pgen.1007341
    """
    between, within_a, within_b = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    threshold = cp.min(between)

    def _share_at_or_below(square):
        # Strict upper triangle: unique pairs, diagonal (self) excluded.
        pair_dists = square[cp.triu_indices(square.shape[0], k=1)]
        if len(pair_dists) == 0:
            return float('nan')
        hits = (pair_dists <= threshold).astype(cp.float64)
        return float(cp.mean(hits).get())

    return _share_at_or_below(within_a), _share_at_or_below(within_b)
def zx(haplotype_matrix: HaplotypeMatrix,
       pop1: Union[str, list],
       pop2: Union[str, list],
       missing_data: str = 'include') -> float:
    """ZnS ratio: within-population LD relative to total LD.

    Zx = (ZnS_pop1 + ZnS_pop2) / (2 * ZnS_total). Values > 1 indicate
    stronger LD within populations than across the combined sample,
    consistent with population structure or recent admixture.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str

    Returns
    -------
    float
        Zx ratio, or NaN when ZnS over the full matrix is 0.

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes of
    Drosophila simulans and D. sechellia. PLoS Genetics, 14(4), e1007341.
    https://doi.org/10.1371/journal.pgen.1007341
    """
    from . import ld_statistics
    from ._utils import get_population_matrix

    sub_a, sub_b = [get_population_matrix(haplotype_matrix, p)
                    for p in (pop1, pop2)]

    zns_a = ld_statistics.zns(sub_a, missing_data=missing_data)
    zns_b = ld_statistics.zns(sub_b, missing_data=missing_data)
    zns_all = ld_statistics.zns(haplotype_matrix, missing_data=missing_data)

    return float('nan') if zns_all == 0 else (zns_a + zns_b) / (2 * zns_all)
def distance_based_stats(haplotype_matrix: HaplotypeMatrix,
                         pop1: Union[str, list],
                         pop2: Union[str, list],
                         missing_data: str = 'include') -> Dict[str, float]:
    """Compute all distance-based two-population statistics in one pass.

    A single call to ``pairwise_distance_matrix`` feeds Snn, Dxy_min, Gmin,
    dd and dd_rank, so the distance matrices are built only once. Here the
    per-population diversity used for dd1/dd2 is the mean within-population
    pairwise distance taken from the same matrices.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str

    Returns
    -------
    dict
        Keys: snn, dxy_min, gmin, dd1, dd2, dd_rank1, dd_rank2.
    """
    between, within_a, within_b = pairwise_distance_matrix(
        haplotype_matrix, pop1, pop2, missing_data)
    size_a, size_b = between.shape

    closest_gpu = cp.min(between)
    closest = float(closest_gpu.get())
    average = float(cp.mean(between).get())

    nearest_score = (_snn_one_pop(within_a, between)
                     + _snn_one_pop(within_b, between.T))

    def _rank_and_diversity(square, size):
        # Unique off-diagonal pairs. A population of fewer than two
        # haplotypes has no pairs: rank is NaN and diversity 0.
        pair_dists = square[cp.triu_indices(size, k=1)]
        if len(pair_dists) == 0:
            return float('nan'), 0.0
        share = float(cp.mean((pair_dists <= closest_gpu).astype(cp.float64)).get())
        return share, float(cp.mean(pair_dists).get())

    rank_a, diversity_a = _rank_and_diversity(within_a, size_a)
    rank_b, diversity_b = _rank_and_diversity(within_b, size_b)

    def _scaled(denominator):
        # Dxy_min relative to a positive scale; NaN otherwise.
        return closest / denominator if denominator > 0 else float('nan')

    return {
        'snn': nearest_score / (size_a + size_b),
        'dxy_min': closest,
        'gmin': _scaled(average),
        'dd1': _scaled(diversity_a),
        'dd2': _scaled(diversity_b),
        'dd_rank1': rank_a,
        'dd_rank2': rank_b,
    }