"""
GPU-accelerated population divergence statistics.
This module provides efficient computation of population divergence metrics
including FST, Dxy, and related statistics using GPU acceleration.
"""
import warnings
import numpy as np
import cupy as cp
from typing import Union, Tuple, Optional, Dict
from .haplotype_matrix import HaplotypeMatrix
from ._memutil import dac_and_n as _pop_dac_and_n
from .diversity import _apply_span_normalize, pi as _diversity_pi
def dxy_components(pop1_haps, pop2_haps):
    """Return raw between-population difference and comparison totals.

    The numerator (cross-population pairs that differ) and denominator
    (all cross-population pairs) are summed separately so that callers can
    aggregate them as they see fit (e.g., windowed analysis).

    Parameters
    ----------
    pop1_haps, pop2_haps : cupy.ndarray, shape (n_haplotypes, n_variants)
        Haplotype data for each population, with -1 for missing.

    Returns
    -------
    total_diffs : float
        Sum of pairwise differences across usable sites.
    total_comps : float
        Sum of pairwise comparisons across usable sites.
    n_sites : int
        Number of sites with data in both populations.
    """
    def _site_counts(haps):
        # Per-site derived, ancestral and called haplotype counts (float64).
        derived, called = _pop_dac_and_n(haps)
        called = called.astype(cp.float64)
        derived = derived.astype(cp.float64)
        return derived, called - derived, called

    der1, anc1, called1 = _site_counts(pop1_haps)
    der2, anc2, called2 = _site_counts(pop2_haps)

    # A site contributes only when both populations have called haplotypes.
    usable = (called1 > 0) & (called2 > 0)

    # Differing pairs: derived-in-1 x ancestral-in-2 plus the reverse.
    diffs_per_site = der1 * anc2 + anc1 * der2
    comps_per_site = called1 * called2

    total_diffs = float(diffs_per_site[usable].sum().get())
    total_comps = float(comps_per_site[usable].sum().get())
    n_sites = int(usable.sum().get())
    return total_diffs, total_comps, n_sites
[docs]
def fst(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        method: str = 'hudson',
        missing_data: str = 'include') -> float:
    """
    Compute FST between two populations with the chosen estimator.

    Thin dispatcher over ``fst_hudson``, ``fst_weir_cockerham`` and
    ``fst_nei``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop1 : str or list
        First population (name from sample_sets or list of indices)
    pop2 : str or list
        Second population (name from sample_sets or list of indices)
    method : str
        FST estimation method ('hudson', 'weir_cockerham', 'nei')
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data

    Returns
    -------
    float
        FST value between populations

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported estimator names.
    """
    # Reject unknown estimators up front, before touching any data.
    if method not in ('hudson', 'weir_cockerham', 'nei'):
        raise ValueError(f"Unknown FST method: {method}")

    if method == 'nei':
        return fst_nei(haplotype_matrix, pop1, pop2, missing_data)
    if method == 'weir_cockerham':
        return fst_weir_cockerham(haplotype_matrix, pop1, pop2, missing_data)
    return fst_hudson(haplotype_matrix, pop1, pop2, missing_data)
[docs]
def fst_hudson(haplotype_matrix: HaplotypeMatrix,
               pop1: Union[str, list],
               pop2: Union[str, list],
               missing_data: str = 'include') -> float:
    """
    Compute Hudson's FST estimator between two populations.

    Hudson (1992) estimator:
        FST = 1 - (Hw / Hb)
    where Hw is within-population diversity and Hb is between-population
    diversity. The per-site numerators (Hb - Hw) and denominators (Hb) are
    summed over usable sites before dividing (ratio of sums).

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. Moved to the GPU in place if it reports
        ``device == 'CPU'``.
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data

    Returns
    -------
    float
        Hudson's FST estimate; 0.0 when no site is usable.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    def _freq_within_n(haps):
        # Per-site derived frequency, within-population term and called n.
        counts, called = _pop_dac_and_n(haps)
        counts = counts.astype(cp.float64)
        called = called.astype(cp.float64)
        freq = cp.where(called > 0, counts / called, 0.0)
        # p*(1-p)*n/(n-1); defined only with at least two called haplotypes.
        within_term = cp.where(
            called > 1,
            freq * (1 - freq) * called / (called - 1),
            0.0)
        return freq, within_term, called

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)
    p_a, within_a, n_a = _freq_within_n(haplotype_matrix.haplotypes[idx_a, :])
    p_b, within_b, n_b = _freq_within_n(haplotype_matrix.haplotypes[idx_b, :])

    within = (within_a + within_b) / 2.0
    # Between-population term, halved so it is on the same scale as `within`.
    between = (p_a * (1 - p_b) + p_b * (1 - p_a)) / 2.0

    numerator = between - within
    denominator = between

    # Only sites with divergence signal and data in both populations count.
    usable = (denominator > 0) & (n_a > 0) & (n_b > 0)
    if not cp.any(usable):
        return 0.0
    ratio = cp.sum(numerator[usable]) / cp.sum(denominator[usable])
    return float(ratio.get())
[docs]
def fst_weir_cockerham(haplotype_matrix,
                       pop1: Union[str, list],
                       pop2: Union[str, list],
                       missing_data: str = 'include') -> float:
    """
    Compute Weir & Cockerham's (1984) FST estimator.

    This is the method of moments estimator that accounts for sampling
    variance. Computes all three variance components (a, b, c) including
    within-individual heterozygosity from paired haplotypes, and returns
    sum(a) / sum(a + b + c) over sites with data in both populations.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data. Consecutive haplotype pairs are treated as
        diploid individuals for the heterozygosity component. If a
        population has an odd number of haplotypes, the trailing unpaired
        haplotype is ignored. Moved to the GPU in place if it reports
        ``device == 'CPU'``.
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    missing_data : str
        'include' - per-site sample sizes, ratio-of-sums
        'exclude' - Only use sites with no missing data

    Returns
    -------
    float
        Weir & Cockerham's FST estimate; 0.0 when it cannot be computed
        (no usable sites or a non-positive denominator).

    Notes
    -----
    When either population has fewer than two haplotypes no diploid pair
    can be formed. In that case heterozygosity is taken as zero and the
    per-site sample size is half the called haplotype count, so allele
    frequencies remain ``derived / called haplotypes``.
    """
    if hasattr(haplotype_matrix, 'device') and haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    pop1_idx = _get_population_indices(haplotype_matrix, pop1)
    pop2_idx = _get_population_indices(haplotype_matrix, pop2)
    pop1_haps = haplotype_matrix.haplotypes[pop1_idx, :]
    pop2_haps = haplotype_matrix.haplotypes[pop2_idx, :]

    # Compute per-site observed heterozygosity and allele counts from
    # complete diploid pairs only. WC FST requires that frequencies and
    # het are computed from the same set of individuals; using all valid
    # haplotypes for counts but only complete pairs for n inflates
    # frequencies by 1/(1-miss_rate) under MCAR.
    def _pop_diploid_stats(pop_haps):
        # Restrict to whole pairs: with an odd haplotype count the two
        # strided slices would otherwise differ in length and either
        # broadcast incorrectly or raise.
        n_pairs = pop_haps.shape[0] // 2
        ha = pop_haps[0:2 * n_pairs:2, :]  # first haplotype of each diploid
        hb = pop_haps[1:2 * n_pairs:2, :]  # second haplotype
        both_valid = (ha >= 0) & (hb >= 0)
        het = (ha != hb) & both_valid
        n_called = cp.sum(both_valid, axis=0).astype(cp.float64)
        h_bar = cp.where(
            n_called > 0,
            cp.sum(het, axis=0).astype(cp.float64) / n_called,
            0.0)
        # Derived allele count from complete pairs only
        derived = cp.sum(
            cp.where(both_valid, ha, 0) + cp.where(both_valid, hb, 0),
            axis=0).astype(cp.float64)
        return h_bar, n_called, derived

    def _pop_haploid_stats(pop_haps):
        # Fallback without pairing: express the sample size in diploid
        # equivalents (called haplotypes / 2) so that counts / (2 * n)
        # below is the true allele frequency rather than half of it.
        derived, hap_n = _pop_dac_and_n(pop_haps)
        derived = derived.astype(cp.float64)
        n_dip = hap_n.astype(cp.float64) / 2.0
        return cp.zeros_like(n_dip), n_dip, derived

    if len(pop1_idx) >= 2 and len(pop2_idx) >= 2:
        h_bar1, n1, pop1_counts = _pop_diploid_stats(pop1_haps)
        h_bar2, n2, pop2_counts = _pop_diploid_stats(pop2_haps)
    else:
        h_bar1, n1, pop1_counts = _pop_haploid_stats(pop1_haps)
        h_bar2, n2, pop2_counts = _pop_haploid_stats(pop2_haps)

    # Allele frequencies: derived count over 2n allele copies
    pop1_freqs = cp.where(n1 > 0, pop1_counts / (2.0 * n1), 0.0)
    pop2_freqs = cp.where(n2 > 0, pop2_counts / (2.0 * n2), 0.0)

    r = 2.0  # number of populations
    n_total = n1 + n2
    n_bar = n_total / r
    has_data = n_total > 0

    # n_C: sample size correction factor
    nc = cp.zeros_like(n_total)
    nc[has_data] = (n_total[has_data]
                    - (n1[has_data]**2 + n2[has_data]**2) / n_total[has_data]
                    ) / (r - 1)

    # Weighted average allele frequency
    p_bar = cp.zeros_like(pop1_freqs)
    p_bar[has_data] = (n1[has_data] * pop1_freqs[has_data]
                       + n2[has_data] * pop2_freqs[has_data]) / n_total[has_data]

    # Sample variance of allele frequencies
    s_squared = cp.zeros_like(p_bar)
    s_squared[has_data] = (
        n1[has_data] * (pop1_freqs[has_data] - p_bar[has_data])**2
        + n2[has_data] * (pop2_freqs[has_data] - p_bar[has_data])**2
    ) / ((r - 1) * n_bar[has_data])

    # Average observed heterozygosity weighted by sample size
    h_bar = cp.zeros_like(p_bar)
    h_bar[has_data] = (n1[has_data] * h_bar1[has_data]
                       + n2[has_data] * h_bar2[has_data]) / n_total[has_data]

    # W-C variance components (Eqs 2, 3, 4 from Weir & Cockerham 1984)
    a = cp.zeros_like(p_bar)
    b = cp.zeros_like(p_bar)
    c = cp.zeros_like(p_bar)
    valid = (n_bar > 1) & (nc > 0)
    pq = p_bar[valid] * (1 - p_bar[valid])
    s2 = s_squared[valid]
    nb = n_bar[valid]
    ncc = nc[valid]
    hb = h_bar[valid]
    a[valid] = (nb / ncc) * (
        s2 - (1.0 / (nb - 1)) * (pq - (r - 1) * s2 / r - hb / 4.0))
    b[valid] = (nb / (nb - 1)) * (
        pq - (r - 1) * s2 / r - (2 * nb - 1) * hb / (4.0 * nb))
    c[valid] = hb / 2.0

    # Global FST = sum(a) / sum(a + b + c)
    valid_mask = (n1 > 0) & (n2 > 0)
    if cp.any(valid_mask):
        sum_a = float(cp.sum(a[valid_mask]).get())
        sum_abc = float(cp.sum((a + b + c)[valid_mask]).get())
        if sum_abc > 0:
            return sum_a / sum_abc
    return 0.0
[docs]
def fst_nei(haplotype_matrix: HaplotypeMatrix,
            pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include') -> float:
    """
    Compute Nei's GST (1973).

    GST = (HT - HS) / HT
    where HT is total heterozygosity and HS is within-subpopulation
    heterozygosity. HT and HS are summed over usable sites before the
    ratio is taken.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. Moved to the GPU in place if it reports
        ``device == 'CPU'``.
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data

    Returns
    -------
    float
        Nei's GST estimate; 0.0 when no site has positive HT.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            return 0.0

    def _freq_and_n(haps):
        # Per-site derived allele frequency and called haplotype count.
        counts, called = _pop_dac_and_n(haps)
        counts = counts.astype(cp.float64)
        called = called.astype(cp.float64)
        return cp.where(called > 0, counts / called, 0.0), called

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)
    p_a, n_a = _freq_and_n(haplotype_matrix.haplotypes[idx_a, :])
    p_b, n_b = _freq_and_n(haplotype_matrix.haplotypes[idx_b, :])

    # Expected heterozygosity within each population
    het_a = 2.0 * p_a * (1 - p_a)
    het_b = 2.0 * p_b * (1 - p_b)

    # Sample-size-weighted averages; sites with no calls at all stay at 0.
    n_sum = n_a + n_b
    any_called = n_sum > 0
    hs = cp.where(any_called, (het_a * n_a + het_b * n_b) / n_sum, 0.0)
    p_pooled = cp.where(any_called, (p_a * n_a + p_b * n_b) / n_sum, 0.0)
    ht = 2.0 * p_pooled * (1 - p_pooled)

    # Keep polymorphic sites with data in both populations.
    usable = (ht > 0) & (n_a > 0) & (n_b > 0)
    if not cp.any(usable):
        return 0.0

    # Ratio of sums: sum(HT - HS) / sum(HT)
    total_ht = float(cp.sum(ht[usable]).get())
    total_hs = float(cp.sum(hs[usable]).get())
    if total_ht == 0:
        return 0.0
    return (total_ht - total_hs) / total_ht
[docs]
def dxy(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        per_site: bool = False,
        missing_data: str = 'include',
        span_normalize=True,
        ) -> Union[float, cp.ndarray]:
    """
    Compute absolute divergence (Dxy) between two populations.

    Dxy measures the average number of nucleotide differences between
    sequences from two populations. Per site it is
    ``p1 + p2 - 2 * p1 * p2`` with p1, p2 the derived allele frequencies.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data. Moved to the GPU in place if it reports
        ``device == 'CPU'``.
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    per_site : bool
        If True, return per-site values; if False, return a single value
    missing_data : str
        ``'include'`` (default) uses per-site valid data.
        ``'exclude'`` filters to sites with no missing data.
    span_normalize : bool
        ``True`` (default): auto-detect best denominator.
        ``False``: return raw sum / sites-with-data average.

    Returns
    -------
    float or ndarray
        Summary Dxy, or per-site Dxy values copied back to the host as a
        NumPy array (sites lacking data in either population are 0).
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])
        if haplotype_matrix.num_variants == 0:
            if per_site:
                return np.array([])
            return 0.0

    def _freq_and_n(haps):
        # Frequencies are taken over non-missing haplotypes at each site.
        counts, called = _pop_dac_and_n(haps)
        counts = counts.astype(cp.float64)
        called = called.astype(cp.float64)
        return cp.where(called > 0, counts / called, 0.0), called

    idx_a = _get_population_indices(haplotype_matrix, pop1)
    idx_b = _get_population_indices(haplotype_matrix, pop2)
    p_a, n_a = _freq_and_n(haplotype_matrix.haplotypes[idx_a, :])
    p_b, n_b = _freq_and_n(haplotype_matrix.haplotypes[idx_b, :])

    # Sites need data in both populations; all others contribute zero.
    has_both = (n_a > 0) & (n_b > 0)
    site_dxy = cp.where(has_both, p_a + p_b - 2 * p_a * p_b, 0.0)

    if per_site:
        return site_dxy.get()

    total = cp.sum(site_dxy)
    if span_normalize is False:
        n_used = int(cp.sum(has_both).get())
        if n_used > 0:
            return float(total.get() / n_used)
        return 0.0
    return _apply_span_normalize(total, haplotype_matrix, span_normalize)
[docs]
def da(haplotype_matrix: HaplotypeMatrix,
       pop1: Union[str, list],
       pop2: Union[str, list],
       missing_data: str = 'include',
       span_normalize=True) -> float:
    """
    Compute net divergence (Da) between two populations.

    Da = Dxy - (pi1 + pi2) / 2
    where pi1 and pi2 are within-population diversities. The same
    ``missing_data`` and ``span_normalize`` settings are forwarded to all
    three underlying computations.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data
    span_normalize : bool
        ``True`` (default): auto-detect best denominator.
        ``False``: return raw sum / sites-with-data average.

    Returns
    -------
    float
        Net divergence (Da)
    """
    shared = {'missing_data': missing_data, 'span_normalize': span_normalize}

    between = dxy(haplotype_matrix, pop1, pop2, **shared)
    within_a = _diversity_pi(haplotype_matrix, population=pop1, **shared)
    within_b = _diversity_pi(haplotype_matrix, population=pop2, **shared)

    # Subtract the mean within-population diversity from the divergence.
    return between - (within_a + within_b) / 2.0
def pi_within_population(haplotype_matrix: HaplotypeMatrix,
                         pop: Union[str, list],
                         missing_data: str = 'include',
                         span_normalize=True) -> float:
    """
    Compute nucleotide diversity (pi) within a population.

    Pure pass-through to ``diversity.pi()`` so both entry points behave
    the same.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop : str or list
        Population name or list of indices
    missing_data : str
        'include' or 'exclude'
    span_normalize : bool
        ``True`` (default): auto-detect. ``False``: raw sum.

    Returns
    -------
    float
        Nucleotide diversity
    """
    pi_options = {
        'population': pop,
        'span_normalize': span_normalize,
        'missing_data': missing_data,
    }
    return _diversity_pi(haplotype_matrix, **pi_options)
[docs]
def divergence_stats(haplotype_matrix: HaplotypeMatrix,
                     pop1: Union[str, list],
                     pop2: Union[str, list],
                     statistics: Optional[list] = None,
                     missing_data: str = 'include',
                     span_normalize=True) -> Dict[str, float]:
    """
    Compute multiple divergence statistics between two populations.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop1 : str or list
        First population
    pop2 : str or list
        Second population
    statistics : list, optional
        Names of the statistics to compute. Recognised names are 'fst',
        'fst_hudson', 'fst_wc', 'fst_nei', 'dxy', 'da', 'pi1' and 'pi2'.
        Defaults to ``['fst', 'dxy', 'da']`` when omitted or None.
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data
    span_normalize : bool
        ``True`` (default): auto-detect. ``False``: raw sum / per-site
        average. Only forwarded to 'dxy', 'da', 'pi1' and 'pi2'; the FST
        estimators do not take it.

    Returns
    -------
    dict
        Dictionary of statistic names to values

    Raises
    ------
    ValueError
        If an entry of ``statistics`` is not a recognised name.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if statistics is None:
        statistics = ['fst', 'dxy', 'da']

    results = {}
    for stat in statistics:
        if stat == 'fst':
            value = fst(haplotype_matrix, pop1, pop2,
                        missing_data=missing_data)
        elif stat == 'fst_hudson':
            value = fst_hudson(haplotype_matrix, pop1, pop2,
                               missing_data=missing_data)
        elif stat == 'fst_wc':
            value = fst_weir_cockerham(haplotype_matrix, pop1, pop2,
                                       missing_data=missing_data)
        elif stat == 'fst_nei':
            value = fst_nei(haplotype_matrix, pop1, pop2,
                            missing_data=missing_data)
        elif stat == 'dxy':
            value = dxy(haplotype_matrix, pop1, pop2,
                        missing_data=missing_data,
                        span_normalize=span_normalize)
        elif stat == 'da':
            value = da(haplotype_matrix, pop1, pop2,
                       missing_data=missing_data,
                       span_normalize=span_normalize)
        elif stat == 'pi1':
            value = pi_within_population(haplotype_matrix, pop1,
                                         missing_data=missing_data,
                                         span_normalize=span_normalize)
        elif stat == 'pi2':
            value = pi_within_population(haplotype_matrix, pop2,
                                         missing_data=missing_data,
                                         span_normalize=span_normalize)
        else:
            raise ValueError(f"Unknown statistic: {stat}")
        results[stat] = value
    return results
[docs]
def pairwise_fst(haplotype_matrix: HaplotypeMatrix,
                 populations: Optional[list] = None,
                 method: str = 'hudson',
                 missing_data: str = 'include') -> Tuple[cp.ndarray, list]:
    """
    Compute pairwise FST matrix for multiple populations.

    Each unordered pair is evaluated once with ``fst`` and mirrored across
    the diagonal; the diagonal itself stays at zero.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    populations : list, optional
        List of population names. If None, uses all populations in
        sample_sets
    method : str
        FST method to use
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data

    Returns
    -------
    fst_matrix : cp.ndarray
        Pairwise FST matrix
    pop_names : list
        Population names in matrix order

    Raises
    ------
    ValueError
        If ``populations`` is None and the matrix has no sample_sets.
    """
    if populations is None:
        sample_sets = haplotype_matrix.sample_sets
        if sample_sets is None:
            raise ValueError("No populations defined in haplotype matrix")
        populations = list(sample_sets.keys())

    n_pops = len(populations)
    fst_matrix = cp.zeros((n_pops, n_pops))

    # Walk the strict upper triangle and write each value symmetrically.
    for row, pop_a in enumerate(populations):
        for col in range(row + 1, n_pops):
            value = fst(haplotype_matrix, pop_a, populations[col],
                        method, missing_data)
            fst_matrix[row, col] = value
            fst_matrix[col, row] = value

    return fst_matrix, populations
def _get_population_indices(haplotype_matrix: HaplotypeMatrix,
                            pop: Union[str, list]) -> list:
    """
    Resolve a population given by name or by explicit indices.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data
    pop : str or list
        Population name or list of indices

    Returns
    -------
    list
        Population indices. For a name this is the object stored in
        ``sample_sets`` itself (not a copy); for anything else it is a new
        list built from ``pop``.

    Raises
    ------
    ValueError
        If ``pop`` is a name and the matrix has no sample_sets, or the
        name is not present in them.
    """
    # Anything that is not a name is taken to be the indices themselves.
    if not isinstance(pop, str):
        return list(pop)

    sample_sets = haplotype_matrix.sample_sets
    if sample_sets is None:
        raise ValueError("No sample_sets defined in haplotype matrix")
    if pop not in sample_sets:
        raise ValueError(f"Population {pop} not found in sample_sets")
    return sample_sets[pop]
def _pop_allele_counts(haplotype_matrix, pop, missing_data='include'):
    """Compute per-variant allele counts for one population on GPU.

    With ``missing_data='exclude'`` the matrix is first reduced to the
    sites that have no missing data in ``pop``; otherwise all sites are
    kept.

    Returns ``(ac_0, ac_1, n)`` as float64 CuPy arrays with one entry per
    (remaining) site: ancestral count, derived count and number of called
    haplotypes.
    """
    matrix = haplotype_matrix
    if missing_data == 'exclude':
        matrix = matrix.exclude_missing_sites(populations=[pop])

    rows = _get_population_indices(matrix, pop)
    derived, called = _pop_dac_and_n(matrix.haplotypes[rows, :])
    called = called.astype(cp.float64)
    derived = derived.astype(cp.float64)
    # Whatever is called but not derived is ancestral.
    return called - derived, derived, called
def _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac2_0, ac2_1, n2):
    """Per-variant Hudson FST numerator/denominator from allele counts.

    For each population the within term is the fraction of haplotype pairs
    that differ; the between term is the fraction of cross-population
    pairs that differ. Terms with no pairs to compare are set to 0.

    Returns ``(between - within, between)`` as arrays on the GPU, where
    ``within`` is the mean of the two populations' within terms.
    """
    def _within_mismatch(ac_0, ac_1, n):
        # Fraction of unordered same-population pairs carrying two
        # different alleles.
        pairs = n * (n - 1) / 2
        same = (ac_0 * (ac_0 - 1) + ac_1 * (ac_1 - 1)) / 2
        return cp.where(pairs > 0, (pairs - same) / pairs, 0.0)

    within = (_within_mismatch(ac1_0, ac1_1, n1)
              + _within_mismatch(ac2_0, ac2_1, n2)) / 2.0

    cross_pairs = n1 * n2
    cross_same = ac1_0 * ac2_0 + ac1_1 * ac2_1
    between = cp.where(cross_pairs > 0,
                       (cross_pairs - cross_same) / cross_pairs, 0.0)

    return between - within, between
def _windowed_fst(num, den, size, start=0, stop=None, step=None):
    """Compute windowed FST from per-variant numerator/denominator on GPU.

    Uses cumulative sums for O(n) windowed reduction. Non-finite entries
    of ``num`` and ``den`` are treated as 0.

    Parameters
    ----------
    num, den : array, shape (n_variants,)
        Per-variant FST numerator and denominator.
    size : int
        Number of variants per window.
    start : int, optional
        Index of the first variant of the first window.
    stop : int, optional
        Exclusive upper bound on variant indices. Defaults to the number
        of variants; larger values are clamped to it.
    step : int, optional
        Stride between window starts. Defaults to ``size``.

    Returns
    -------
    array, shape (n_windows,)
        ``sum(num) / sum(den)`` per window, NaN where the denominator sum
        is zero. Only complete windows that fit before ``stop`` are
        produced.
    """
    n = len(num)
    # Clamp stop: window ends index the length n+1 cumulative sums, and
    # CuPy wraps out-of-range fancy indices instead of raising, so an
    # oversized stop would silently yield wrong window sums.
    stop = n if stop is None else min(stop, n)
    if step is None:
        step = size

    num_clean = cp.where(cp.isfinite(num), num, 0.0)
    den_clean = cp.where(cp.isfinite(den), den, 0.0)

    # Leading zero so that cs[j] - cs[i] is the sum over [i, j).
    cs_num = cp.concatenate(
        [cp.zeros(1, dtype=cp.float64), cp.cumsum(num_clean)])
    cs_den = cp.concatenate(
        [cp.zeros(1, dtype=cp.float64), cp.cumsum(den_clean)])

    w_starts = cp.arange(start, stop - size + 1, step)
    num_sum = cs_num[w_starts + size] - cs_num[w_starts]
    den_sum = cs_den[w_starts + size] - cs_den[w_starts]
    return cp.where(den_sum != 0, num_sum / den_sum, cp.nan)
[docs]
def pbs(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        pop3: Union[str, list],
        window_size: int,
        window_start: int = 0,
        window_stop: Optional[int] = None,
        window_step: Optional[int] = None,
        normed: bool = True,
        missing_data: str = 'include'):
    """Compute the Population Branch Statistic (PBS).

    PBS detects genomic regions unusually differentiated in pop1 relative
    to pop2 and pop3, using pairwise Hudson FST estimates. Windowed FST
    values are clipped to [0, 0.99999] and converted to branch lengths
    ``T = -log(1 - FST)``; PBS is ``(T12 + T13 - T23) / 2``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data containing all three populations. Moved to the GPU
        in place if it reports ``device == 'CPU'``.
    pop1, pop2, pop3 : str or list
        Population names or sample indices.
    window_size : int
        Number of variants per window.
    window_start : int, optional
        Starting variant index.
    window_stop : int, optional
        Stopping variant index.
    window_step : int, optional
        Stride between windows. Defaults to window_size (non-overlapping).
    normed : bool, optional
        If True (default), return normalized PBS (PBSn1), i.e. PBS divided
        by ``1 + (T12 + T13 + T23) / 2``.
    missing_data : str
        'include' - per-site sample sizes
        'exclude' - only use sites with no missing data in any of the
        three populations; window indices then refer to the remaining
        sites.

    Returns
    -------
    ndarray, float64, shape (n_windows,)
        PBS values per window, copied back to the host.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    # Filter jointly, once. Filtering each population on its own would
    # leave the three count arrays covering different site sets, so the
    # pairwise FST terms would be misaligned or of different lengths.
    count_mode = missing_data
    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2, pop3])
        # The remaining sites are complete for all three populations, so
        # the per-population counter must not filter again.
        count_mode = 'include'

    # precompute allele counts once per population
    ac1_0, ac1_1, n1 = _pop_allele_counts(haplotype_matrix, pop1, count_mode)
    ac2_0, ac2_1, n2 = _pop_allele_counts(haplotype_matrix, pop2, count_mode)
    ac3_0, ac3_1, n3 = _pop_allele_counts(haplotype_matrix, pop3, count_mode)

    # compute all three pairwise FST num/den from shared counts
    num12, den12 = _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac2_0, ac2_1, n2)
    num13, den13 = _hudson_fst_from_counts(ac1_0, ac1_1, n1, ac3_0, ac3_1, n3)
    num23, den23 = _hudson_fst_from_counts(ac2_0, ac2_1, n2, ac3_0, ac3_1, n3)

    fst12 = _windowed_fst(num12, den12, window_size, window_start,
                          window_stop, window_step)
    fst13 = _windowed_fst(num13, den13, window_size, window_start,
                          window_stop, window_step)
    fst23 = _windowed_fst(num23, den23, window_size, window_start,
                          window_stop, window_step)

    # Keep FST inside [0, 1) so the log transform below stays finite.
    cp.clip(fst12, 0, 0.99999, out=fst12)
    cp.clip(fst13, 0, 0.99999, out=fst13)
    cp.clip(fst23, 0, 0.99999, out=fst23)

    t12 = -cp.log(1 - fst12)
    t13 = -cp.log(1 - fst13)
    t23 = -cp.log(1 - fst23)

    ret = (t12 + t13 - t23) / 2
    if normed:
        norm = 1 + (t12 + t13 + t23) / 2
        ret = ret / norm
    return ret.get()
# ---------------------------------------------------------------------------
# Two-population distance-based statistics
# ---------------------------------------------------------------------------
def _snn_one_pop(within, between):
    """Score one population block for Hudson's Snn on GPU.

    Each row (haplotype) scores 1 when its smallest within-population
    distance is strictly below its smallest between-population distance,
    0 when it is strictly above, and on a tie the share of tied nearest
    neighbours that are within-population. Self-distances on the diagonal
    of ``within`` are ignored.

    Returns the sum of the row scores as a Python float.
    """
    # Work on a copy so the caller's matrix keeps its diagonal.
    masked_within = within.copy()
    cp.fill_diagonal(masked_within, cp.inf)

    nearest_within = masked_within.min(axis=1)
    nearest_between = between.min(axis=1)

    scores = (nearest_within < nearest_between).astype(cp.float64)
    is_tie = nearest_within == nearest_between
    if cp.any(is_tie):
        # Count how many neighbours on each side sit at the minimum.
        ties_within = cp.sum(
            masked_within == nearest_within[:, None], axis=1)
        ties_between = cp.sum(
            between == nearest_between[:, None], axis=1)
        fractional = ties_within / (ties_within + ties_between)
        scores = cp.where(is_tie, fractional, scores)

    return float(scores.sum().get())
def _resolve_distance_matrices(haplotype_matrix, pop1, pop2,
                               missing_data='include',
                               distance_matrices=None):
    """Return validated pre-computed distance matrices, or compute them.

    When ``distance_matrices`` is given, its three shapes are checked
    against the sizes of the two populations and the matrices are returned
    unchanged. Otherwise ``pairwise_distance_matrix`` is called.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple of (dist_between, dist_within1, dist_within2), optional
        Pre-computed cupy distance matrices from a prior call to
        pairwise_distance_matrix or distance_based_stats.

    Returns
    -------
    dist_between, dist_within1, dist_within2 : cupy.ndarray

    Raises
    ------
    ValueError
        If a supplied matrix does not have the shape implied by the
        population sizes.
    """
    if distance_matrices is None:
        return pairwise_distance_matrix(
            haplotype_matrix, pop1, pop2, missing_data)

    size1 = len(_get_population_indices(haplotype_matrix, pop1))
    size2 = len(_get_population_indices(haplotype_matrix, pop2))
    between, within1, within2 = distance_matrices

    # (matrix, expected shape, label, what it is compared against)
    checks = (
        (between, (size1, size2), "between-pop", "populations"),
        (within1, (size1, size1), "within-pop1", "pop1 size"),
        (within2, (size2, size2), "within-pop2", "pop2 size"),
    )
    for matrix, expected, label, target in checks:
        if matrix.shape != expected:
            raise ValueError(
                f"distance_matrices {label} shape {matrix.shape} does not "
                f"match {target} ({expected[0]}, {expected[1]})")

    return between, within1, within2
[docs]
def pairwise_distance_matrix(haplotype_matrix, pop1, pop2,
                             missing_data='include'):
    """Compute pairwise distance matrices between and within two populations.

    Returns three cupy distance matrices: between-population,
    within-pop1, and within-pop2. Pre-compute once and pass to
    individual stats via the ``distance_matrices`` parameter to
    avoid redundant GPU work.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Moved to the GPU in place if it reports ``device == 'CPU'``, as
        the other entry points of this module do.
    pop1, pop2 : str or list
    missing_data : str
        ``'exclude'`` first drops sites with missing data in either
        population; any other value keeps all sites.

    Returns
    -------
    dist_between : cupy.ndarray, float64, shape (n1, n2)
    dist_within1 : cupy.ndarray, float64, shape (n1, n1)
    dist_within2 : cupy.ndarray, float64, shape (n2, n2)
    """
    from .distance_stats import _pairwise_diffs_matrix_gpu

    if getattr(haplotype_matrix, 'device', None) == 'CPU':
        haplotype_matrix.transfer_to_gpu()

    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1, pop2])

    # Named populations come back as stored in sample_sets, which need not
    # be a list. Convert before `+` so the two index sets are concatenated
    # rather than added element-wise when they are arrays.
    pop1_idx = list(_get_population_indices(haplotype_matrix, pop1))
    pop2_idx = list(_get_population_indices(haplotype_matrix, pop2))
    all_idx = pop1_idx + pop2_idx
    n1 = len(pop1_idx)

    hap_sub = haplotype_matrix.haplotypes[all_idx, :]

    # Raw Hamming distances (not normalized), appropriate for ratio/rank
    # stats. 'include' is always passed because any exclusion of missing
    # sites has already been applied above.
    diffs = _pairwise_diffs_matrix_gpu(hap_sub, missing_data='include')

    # Rows/columns [0, n1) belong to pop1 and [n1, n1 + n2) to pop2.
    return diffs[:n1, n1:], diffs[:n1, :n1], diffs[n1:, n1:]
[docs]
def snn(haplotype_matrix: HaplotypeMatrix,
        pop1: Union[str, list],
        pop2: Union[str, list],
        missing_data: str = 'include',
        distance_matrices=None) -> float:
    """Hudson's nearest-neighbor statistic (Snn).

    For each haplotype, determines whether its nearest neighbor is from
    the same population. Snn is the fraction of haplotypes whose nearest
    neighbor is conspecific (ties score fractionally). Under panmixia
    Snn ~ 0.5; under population structure Snn -> 1.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed (dist_between, dist_within1, dist_within2) from
        pairwise_distance_matrix. Avoids recomputation when calling
        multiple stats on the same population pair.

    Returns
    -------
    float
        Snn statistic in [0, 1].

    References
    ----------
    Hudson, R.R. (2000). A New Statistic for Detecting Genetic
    Differentiation. Genetics, 155(4), 2011-2014.
    """
    between, within_a, within_b = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    size_a, size_b = between.shape

    # Score both populations; the second sees the between matrix transposed
    # so that its own haplotypes are the rows.
    conspecific = (_snn_one_pop(within_a, between)
                   + _snn_one_pop(within_b, between.T))
    return conspecific / (size_a + size_b)
[docs]
def dxy_min(haplotype_matrix: HaplotypeMatrix,
            pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include',
            distance_matrices=None) -> float:
    """Minimum pairwise distance between two populations.

    The smallest entry of the between-population distance matrix, i.e.
    the distance of the closest pair of haplotypes across the two
    populations. Used in Gmin (Geneva et al.) and dd (Schrider et al.)
    statistics.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    float
        Minimum pairwise distance.

    References
    ----------
    Geneva, A.J. et al. (2015). A new method to scan genomes for
    introgression in a secondary contact model. PLoS ONE, 10(4).
    """
    matrices = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    # Only the between-population matrix is needed here.
    between = matrices[0]
    return float(between.min().get())
[docs]
def gmin(haplotype_matrix: HaplotypeMatrix,
         pop1: Union[str, list],
         pop2: Union[str, list],
         missing_data: str = 'include',
         distance_matrices=None) -> float:
    """Geneva's Gmin: ratio of minimum to mean between-population distance.

    Gmin = Dxy_min / Dxy_mean. Low values indicate unusually similar
    haplotypes across populations relative to average divergence,
    suggesting recent gene flow or shared ancestral variation.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    float
        Gmin ratio, or NaN when the mean between-population distance is 0.

    References
    ----------
    Geneva, A.J. et al. (2015). A new method to scan genomes for
    introgression in a secondary contact model. PLoS ONE, 10(4).
    """
    between, _, _ = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)

    mean_dist = float(cp.mean(between).get())
    min_dist = float(cp.min(between).get())

    # No divergence at all: the ratio is undefined.
    return min_dist / mean_dist if mean_dist != 0 else float('nan')
[docs]
def dd(haplotype_matrix: HaplotypeMatrix,
       pop1: Union[str, list],
       pop2: Union[str, list],
       missing_data: str = 'include',
       distance_matrices=None) -> Tuple[float, float]:
    """Relative minimum divergence (dd1, dd2).

    dd1 = Dxy_min / pi1, dd2 = Dxy_min / pi2. Low values indicate
    that the closest between-population pair is unusually similar
    relative to within-population diversity. The diversities are obtained
    from ``diversity.pi`` with ``span_normalize=False``.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    pop1, pop2 : str or list
    missing_data : str
    distance_matrices : tuple, optional
        Pre-computed distance matrices.

    Returns
    -------
    dd1 : float
        Dxy_min / pi(pop1), or NaN when pi(pop1) is not positive.
    dd2 : float
        Dxy_min / pi(pop2), or NaN when pi(pop2) is not positive.

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes
    of Drosophila simulans and D. sechellia. PLoS Genetics, 14(4),
    e1007341. https://doi.org/10.1371/journal.pgen.1007341
    """
    between, _, _ = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    closest = float(cp.min(between).get())

    def _relative_to(diversity_value):
        # Undefined when the population shows no diversity.
        if diversity_value > 0:
            return closest / diversity_value
        return float('nan')

    pi_a = _diversity_pi(haplotype_matrix, population=pop1,
                         span_normalize=False, missing_data=missing_data)
    pi_b = _diversity_pi(haplotype_matrix, population=pop2,
                         span_normalize=False, missing_data=missing_data)
    return _relative_to(pi_a), _relative_to(pi_b)
[docs]
def dd_rank(haplotype_matrix: HaplotypeMatrix,
            pop1: Union[str, list],
            pop2: Union[str, list],
            missing_data: str = 'include',
            distance_matrices=None) -> Tuple[float, float]:
    """Position of Dxy_min within each population's pairwise distances.

    For each population, reports the share of within-population pairwise
    distances (strict upper triangle, diagonal excluded) that are less
    than or equal to the minimum between-population distance. A low
    share means the closest cross-population pair is more similar than
    most within-population pairs, which suggests introgression.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        The two populations to compare.
    missing_data : str
        Forwarded to the distance-matrix resolution.
    distance_matrices : tuple, optional
        Pre-computed distance matrices, forwarded to
        ``_resolve_distance_matrices``.

    Returns
    -------
    rank1 : float
        Fraction of pop1 within-pop distances <= Dxy_min, or NaN when
        pop1 has no within-population pairs.
    rank2 : float
        Fraction of pop2 within-pop distances <= Dxy_min, or NaN when
        pop2 has no within-population pairs.

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes
    of Drosophila simulans and D. sechellia. PLoS Genetics, 14(4),
    e1007341. https://doi.org/10.1371/journal.pgen.1007341
    """
    between, square1, square2 = _resolve_distance_matrices(
        haplotype_matrix, pop1, pop2, missing_data, distance_matrices)
    # Left as a CuPy scalar so the comparisons below need no host transfer.
    closest = cp.min(between)

    def _share_at_or_below(square):
        # k=1 selects the strict upper triangle: each unordered pair once,
        # self-comparisons on the diagonal omitted.
        rows, cols = cp.triu_indices(square.shape[0], k=1)
        pair_dists = square[rows, cols]
        if not len(pair_dists) > 0:
            return float('nan')
        hits = (pair_dists <= closest).astype(cp.float64)
        return float(cp.mean(hits).get())

    return _share_at_or_below(square1), _share_at_or_below(square2)
[docs]
def zx(haplotype_matrix: HaplotypeMatrix,
       pop1: Union[str, list],
       pop2: Union[str, list],
       missing_data: str = 'include') -> float:
    """Ratio of within-population ZnS to ZnS of the full matrix.

    ``Zx = (ZnS_pop1 + ZnS_pop2) / (2 * ZnS_total)``, where ZnS_total is
    computed on ``haplotype_matrix`` as passed in. Values above 1 indicate
    stronger LD within the populations than across the combined sample,
    consistent with population structure or recent admixture.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        The two populations whose sub-matrices are extracted.
    missing_data : str
        Forwarded to every ``ld_statistics.zns`` call.

    Returns
    -------
    float
        Zx ratio, or NaN when ZnS of the full matrix equals zero.

    References
    ----------
    Schrider, D.R., Ayroles, J., Matute, D.R. & Kern, A.D. (2018).
    Supervised machine learning reveals introgressed loci in the genomes
    of Drosophila simulans and D. sechellia. PLoS Genetics, 14(4),
    e1007341. https://doi.org/10.1371/journal.pgen.1007341
    """
    from . import ld_statistics
    from ._utils import get_population_matrix

    # Extract both population sub-matrices before any LD computation.
    subsets = [get_population_matrix(haplotype_matrix, population)
               for population in (pop1, pop2)]
    per_pop = [ld_statistics.zns(subset, missing_data=missing_data)
               for subset in subsets]
    combined = ld_statistics.zns(haplotype_matrix, missing_data=missing_data)

    if combined == 0:
        # The ratio is undefined when the denominator is exactly zero.
        return float('nan')
    return (per_pop[0] + per_pop[1]) / (2 * combined)
[docs]
def distance_based_stats(haplotype_matrix: HaplotypeMatrix,
                         pop1: Union[str, list],
                         pop2: Union[str, list],
                         missing_data: str = 'include') -> Dict[str, float]:
    """Evaluate every distance-based two-population statistic in one pass.

    The pairwise distance matrices are built a single time and reused for
    Snn, Gmin, dd and dd_rank, so the GPU work is not repeated per
    statistic.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        The haplotype data.
    pop1, pop2 : str or list
        The two populations to compare.
    missing_data : str
        Forwarded to ``pairwise_distance_matrix``.

    Returns
    -------
    dict
        Keys: snn, dxy_min, gmin, dd1, dd2, dd_rank1, dd_rank2.
        ``gmin`` is NaN when the mean between-population distance is not
        positive; ``dd1``/``dd2`` are NaN when the corresponding mean
        within-population distance is not positive; ``dd_rank1``/
        ``dd_rank2`` are NaN when a population has no within-population
        pairs.
    """
    between, square1, square2 = pairwise_distance_matrix(
        haplotype_matrix, pop1, pop2, missing_data)
    size1, size2 = between.shape

    # Keep a device-side copy of the minimum for the rank comparisons and
    # a host-side float for the ratios.
    closest_dev = cp.min(between)
    closest = float(closest_dev.get())
    average = float(cp.mean(between).get())

    nn_total = (_snn_one_pop(square1, between)
                + _snn_one_pop(square2, between.T))
    snn_val = nn_total / (size1 + size2)

    def _within_summary(square, size):
        # Strict upper triangle: each within-population pair exactly once.
        pair_dists = square[cp.triu_indices(size, k=1)]
        if not len(pair_dists) > 0:
            # No pairs: rank is undefined and the mean falls back to 0.0.
            return float('nan'), 0.0
        hits = (pair_dists <= closest_dev).astype(cp.float64)
        share = float(cp.mean(hits).get())
        mean_dist = float(cp.mean(pair_dists).get())
        return share, mean_dist

    rank1, pi1 = _within_summary(square1, size1)
    rank2, pi2 = _within_summary(square2, size2)

    results = {}
    results['snn'] = snn_val
    results['dxy_min'] = closest
    results['gmin'] = closest / average if average > 0 else float('nan')
    results['dd1'] = closest / pi1 if pi1 > 0 else float('nan')
    results['dd2'] = closest / pi2 if pi2 > 0 else float('nan')
    results['dd_rank1'] = rank1
    results['dd_rank2'] = rank2
    return results