"""
GPU-accelerated diversity and polymorphism statistics.
This module provides efficient computation of within-population genetic diversity
metrics including nucleotide diversity (π), Watterson's theta, Tajima's D, and
related statistics. Includes the FrequencySpectrum class for SFS-based analysis,
custom weight functions, and SFS projection.
"""
import math
import numpy as np
import cupy as cp
from typing import Union, Optional, Dict, Callable
from functools import lru_cache
from .haplotype_matrix import HaplotypeMatrix
from ._utils import get_population_matrix
def _apply_span_normalize(value, matrix, span_normalize):
"""Apply span normalization to a raw statistic value.
Parameters
----------
value : float or cupy scalar
Raw statistic sum.
matrix : HaplotypeMatrix
Source matrix (for get_span).
span_normalize : bool or str
True: auto-detect best span. False: return raw value.
String: explicit mode passed to get_span() (internal use).
"""
if span_normalize is False:
return float(value.get() if hasattr(value, 'get') else value)
mode = 'auto' if span_normalize is True else span_normalize
span = matrix.get_span(mode)
if span > 0:
return float(value / span)
return float('nan')
_gpu_lookup_cache = {}
def _get_a1_inv(n_max):
"""Get cached 1/a1(n) lookup array on GPU."""
key = ('a1_inv', n_max)
if key in _gpu_lookup_cache:
return _gpu_lookup_cache[key]
a1 = np.zeros(n_max + 1, dtype=np.float64)
for n in range(2, n_max + 1):
a1[n] = 1.0 / np.sum(1.0 / np.arange(1, n, dtype=np.float64))
result = cp.asarray(a1)
_gpu_lookup_cache[key] = result
return result
def _get_minus_eta1_norm(n_max):
"""Get cached 1/(a1-1) normalizer for minus_eta1 estimator on GPU."""
key = ('minus_eta1', n_max)
if key in _gpu_lookup_cache:
return _gpu_lookup_cache[key]
arr = np.zeros(n_max + 1, dtype=np.float64)
for ni in range(3, n_max + 1):
a1 = np.sum(1.0 / np.arange(1, ni, dtype=np.float64))
arr[ni] = 1.0 / (a1 - 1.0) if a1 > 1.0 else 0.0
result = cp.asarray(arr)
_gpu_lookup_cache[key] = result
return result
def _get_minus_eta1_star_norm(n_max):
"""Get cached normalizer for minus_eta1_star estimator on GPU."""
key = ('minus_eta1_star', n_max)
if key in _gpu_lookup_cache:
return _gpu_lookup_cache[key]
arr = np.zeros(n_max + 1, dtype=np.float64)
for ni in range(4, n_max + 1):
a1 = np.sum(1.0 / np.arange(1, ni, dtype=np.float64))
denom = a1 - 1.0 - 1.0 / (ni - 1)
arr[ni] = 1.0 / denom if denom > 0 else 0.0
result = cp.asarray(arr)
_gpu_lookup_cache[key] = result
return result
def _achaz_alpha_beta(v1, v2, n):
"""Compute Achaz (2009) Eq. 9 alpha_n, beta_n from two per-xi weight vectors.
O(n) time and memory — computes the bilinear form w^T @ sigma @ w
directly using the structure of Fu (1995) sigma_ij, without
materializing the full (n-1)x(n-1) covariance matrix.
Our weight functions return per-xi weights w[i]. Achaz's v-vectors are
per-u_hat weights where u_hat_i = i * xi_i. The conversion is
v_i/sum(v_j) = w[i]/i, giving V_i = w1[i]/i - w2[i]/i.
"""
H = cp.asarray(_harmonic_sums(n))
an = H[n - 1] # CuPy scalar — stays on GPU
k = cp.arange(1, n, dtype=cp.float64)
k_int = k.astype(cp.int64)
V = cp.asarray((v1[1:n] - v2[1:n])) / k
w = k * V
alpha_n = cp.sum(k * V ** 2)
# Precompute beta(i) for i=1..n-1 (reuse k instead of separate idx_b)
a_b = H[k_int - 1]
beta_vals = cp.zeros(n + 1, dtype=cp.float64)
beta_vals[1:n] = (2.0 * n / ((n - k + 1) * (n - k))
* (an + 1.0 / n - a_b) - 2.0 / (n - k))
# --- Diagonal: sum w_i^2 * sigma_{ii} ---
# Clip indices protect against out-of-bounds; valid masks ensure
# only correct contributions are summed.
diag_sigma = cp.where(
2 * k < n, beta_vals[k_int + 1],
cp.where(2 * k == n,
2.0 * (an - H[k_int - 1]) / (n - k) - 1.0 / (k * k),
beta_vals[k_int] - 1.0 / (k * k)))
diag_sum = cp.sum(w ** 2 * diag_sigma)
# --- Prefix/suffix sums for off-diagonal ---
W_prefix = cp.cumsum(w)
W_suffix = cp.flip(cp.cumsum(cp.flip(w)))
V_suffix = cp.flip(cp.cumsum(cp.flip(V)))
# Case A: i < j, i+j < n => sigma = (beta(j+1) - beta(j)) / 2
j_idx = cp.arange(2, n, dtype=cp.int64)
dbeta = beta_vals[j_idx + 1] - beta_vals[j_idx]
upper = cp.minimum(j_idx - 1, n - j_idx - 1)
valid_a = upper >= 1
psum = cp.where(valid_a,
W_prefix[cp.clip(upper - 1, 0, n - 2).astype(cp.int64)], 0.0)
case_a = cp.sum(dbeta * w[j_idx - 1] * psum * valid_a)
# Case C: i < j, i+j > n => sigma = (beta(i)-beta(i+1))/2 - 1/(i*j)
i_idx = cp.arange(1, n - 1, dtype=cp.int64)
dbeta_i = beta_vals[i_idx] - beta_vals[i_idx + 1]
j_start = cp.maximum(i_idx + 1, n - i_idx + 1)
valid_c = j_start <= n - 1
ws = cp.where(valid_c,
W_suffix[cp.clip(j_start - 1, 0, n - 2).astype(cp.int64)], 0.0)
case_c1 = cp.sum(dbeta_i * w[i_idx - 1] * ws * valid_c)
vs = cp.where(valid_c,
V_suffix[cp.clip(j_start - 1, 0, n - 2).astype(cp.int64)], 0.0)
case_c2 = cp.sum(V[i_idx - 1] * vs * valid_c)
# Case B: i+j == n (anti-diagonal, O(n) terms)
i_b = cp.arange(1, n, dtype=cp.int64)
j_b = n - i_b
valid_b = (j_b > 0) & (j_b < n) & (j_b > i_b)
ai = H[cp.clip(i_b - 1, 0, n).astype(cp.int64)]
aj = H[cp.clip(j_b - 1, 0, n).astype(cp.int64)]
j_f = j_b.astype(cp.float64)
i_f = i_b.astype(cp.float64)
s_b = ((an - aj) / (n - j_f) + (an - ai) / (n - i_f)
- (beta_vals[j_b] + beta_vals[cp.clip(i_b + 1, 0, n)]) / 2.0
- 1.0 / (j_f * i_f))
case_b = cp.sum(
2.0 * w[i_b - 1] * w[cp.clip(j_b - 1, 0, n - 2)] * s_b * valid_b)
# Single GPU->CPU transfer at the end
beta_n = diag_sum + case_a + case_c1 - 2.0 * case_c2 + case_b
return float(alpha_n), float(beta_n)
@lru_cache(maxsize=128)
def _achaz_variance_coefficients(w1_name, w2_name, n):
"""Cached Achaz (2009) Eq. 9 variance coefficients for named weight pairs.
This is the single source of truth for all neutrality test variances.
"""
v1 = WEIGHT_REGISTRY[w1_name](n)
v2 = WEIGHT_REGISTRY[w2_name](n)
return _achaz_alpha_beta(v1, v2, n)
def _achaz_variance(w1_name, w2_name, n, S):
"""Compute the Achaz (2009) null variance for a neutrality test.
Var(T) = alpha_n * theta_est + beta_n * theta_sq_est
where theta_est = S/a1 and theta_sq_est = S(S-1)/(a1^2+a2).
"""
alpha_n, beta_n = _achaz_variance_coefficients(w1_name, w2_name, n)
a1, a2 = _harmonic_a1_a2(n)
return alpha_n * S / a1 + beta_n * S * (S - 1) / (a1 ** 2 + a2)
def _site_contribution(name, d, n_safe, seg, n_valid, n_hap, dac=None):
"""Compute per-site contribution for a theta estimator on GPU.
This is the single source of truth for what each estimator computes.
Both scalar (_compute_thetas) and windowed (_windowed_thetas_scatter)
paths call this function.
Parameters
----------
name : str
Estimator name.
d, n_safe : cupy.ndarray, float64
Derived allele count (float) and safe sample size per site.
seg : cupy.ndarray, bool
Segregating site mask.
n_valid : cupy.ndarray, int64
Per-site valid sample count (for watterson lookup).
n_hap : int
Total haplotype count (for harmonic number lookup).
dac : cupy.ndarray, int64, optional
Integer derived allele count (for exact comparisons like == 1).
If None, uses d cast to int64.
Returns
-------
cupy.ndarray, float64, shape (n_variants,)
Per-site contribution (zero for non-segregating sites).
"""
if dac is None:
dac = d.astype(cp.int64)
if name in ('pi', 'theta_pi'):
return cp.where(seg, 2 * d * (n_safe - d) / (n_safe * (n_safe - 1)), 0.0)
elif name in ('watterson', 'theta_s'):
a1_inv = _get_a1_inv(n_hap)
return cp.where(seg, a1_inv[n_valid], 0.0)
elif name == 'theta_h':
return cp.where(seg, 2 * d * d / (n_safe * (n_safe - 1)), 0.0)
elif name == 'theta_l':
return cp.where(seg, d / (n_safe - 1), 0.0)
elif name == 'eta1':
# Singletons only: dac == 1
a1_inv = _get_a1_inv(n_hap)
return cp.where(seg & (dac == 1), a1_inv[n_valid], 0.0)
elif name == 'eta1_star':
# Singletons + (n-1)-tons
a1_inv = _get_a1_inv(n_hap)
is_edge = (dac == 1) | (dac == n_valid - 1)
return cp.where(seg & is_edge, a1_inv[n_valid], 0.0)
elif name == 'minus_eta1':
not_sing = dac >= 2
a1m1_gpu = _get_minus_eta1_norm(n_hap)
return cp.where(seg & not_sing, a1m1_gpu[n_valid], 0.0)
elif name == 'minus_eta1_star':
interior = (dac >= 2) & (dac <= n_valid - 2)
norm_gpu = _get_minus_eta1_star_norm(n_hap)
return cp.where(seg & interior, norm_gpu[n_valid], 0.0)
else:
raise ValueError(f"Unknown estimator: {name}. Use FrequencySpectrum "
f"for custom weight functions.")
def _prepare_dac(matrix):
"""Compute dac, n_valid, and derived quantities on GPU.
Returns (dac, n_valid, d, n_safe, seg, n_hap) — the shared
intermediate arrays used by all theta estimator paths.
"""
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
from ._memutil import dac_and_n
dac, n_valid = dac_and_n(matrix.haplotypes)
n = n_valid.astype(cp.float64)
d = dac.astype(cp.float64)
seg = (dac > 0) & (dac < n_valid) & (n_valid >= 2)
n_safe = cp.maximum(n, 2.0)
return dac, n_valid, d, n_safe, seg, matrix.num_haplotypes
def _compute_thetas(matrix, estimators=('pi', 'watterson', 'theta_h', 'theta_l')):
"""Compute multiple theta estimators via direct vectorized GPU arithmetic.
Parameters
----------
matrix : HaplotypeMatrix
Population-subsetted, on GPU.
estimators : tuple of str
Estimator names.
Returns
-------
dict with keys:
'thetas': dict of estimator name -> float (raw sum)
'S': int, number of segregating sites
'n_harmonic_mean': int, harmonic mean of per-site sample sizes
"""
dac, n_valid, d, n_safe, seg, n_hap = _prepare_dac(matrix)
thetas = {}
for name in estimators:
val = cp.sum(_site_contribution(name, d, n_safe, seg, n_valid, n_hap, dac=dac))
thetas[name] = float(val.get())
S = int(cp.sum(seg).get())
has_data = n_valid >= 2
if cp.any(has_data):
valid_n = n_valid[has_data].astype(cp.float64)
n_harm = round(float(len(valid_n) / cp.sum(1.0 / valid_n).get()))
else:
n_harm = 0
return {'thetas': thetas, 'S': S, 'n_harmonic_mean': n_harm}
def _compute_neutrality_test(matrix, w1_name, w2_name):
"""Compute a neutrality test statistic using the Achaz (2009) framework.
T = (theta_w1 - theta_w2) / sqrt(alpha_n * theta_est + beta_n * theta_sq_est)
Parameters
----------
matrix : HaplotypeMatrix
Population-subsetted, on GPU.
w1_name, w2_name : str
Weight vector names from WEIGHT_REGISTRY.
Returns
-------
float
"""
result = _compute_thetas(matrix, (w1_name, w2_name))
S = result['S']
n = result['n_harmonic_mean']
if S < 3 or n < 3:
return float('nan')
var = _achaz_variance(w1_name, w2_name, n, S)
if var <= 0:
return float('nan')
num = result['thetas'][w1_name] - result['thetas'][w2_name]
return float(num / math.sqrt(var))
def _prepare_matrix(haplotype_matrix, population=None, missing_data='include'):
"""Extract population subset and apply exclude filtering."""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
if missing_data == 'exclude':
matrix = matrix.exclude_missing_sites()
return matrix
# ---------------------------------------------------------------------------
# SFS projection and covariance (Gutenkunst et al. 2009, Fu 1995)
# ---------------------------------------------------------------------------
@lru_cache(maxsize=64)
def _harmonic_sums(n):
"""Precompute harmonic sums H[i] = sum(1/j for j=1..i) for i=0..n."""
H = np.zeros(n + 1)
H[1:] = np.cumsum(1.0 / np.arange(1, n + 1))
return H
@lru_cache(maxsize=128)
def _harmonic_a1_a2(n):
"""Return (a1, a2) harmonic number sums for sample size n.
a1 = sum(1/i for i=1..n-1), a2 = sum(1/i^2 for i=1..n-1).
"""
k = np.arange(1, n, dtype=np.float64)
return float(np.sum(1.0 / k)), float(np.sum(1.0 / (k * k)))
@lru_cache(maxsize=64)
def _compute_sigma_ij_gpu(n):
"""Compute Fu (1995) sigma_ij on GPU, return CuPy array (cached)."""
H = _harmonic_sums(n)
an = H[n - 1]
# beta(i, n) for i = 1..n-1 (CPU, 1D)
idx_b = np.arange(1, n, dtype=np.float64)
a_b = H[idx_b.astype(int) - 1]
beta_arr = (2.0 * n / ((n - idx_b + 1) * (n - idx_b))
* (an + 1.0 / n - a_b)
- 2.0 / (n - idx_b))
beta_full = np.zeros(n + 1)
beta_full[1:n] = beta_arr
# O(n^2) broadcast on GPU
beta_gpu = cp.asarray(beta_full)
H_gpu = cp.asarray(H)
idx = cp.arange(1, n, dtype=cp.float64)
ii = idx[:, None]
jj = idx[None, :]
i_hi = cp.maximum(ii, jj)
j_lo = cp.minimum(ii, jj)
s = i_hi + j_lo
i_hi_int = i_hi.astype(cp.int64)
j_lo_int = j_lo.astype(cp.int64)
case_lt = (beta_gpu[i_hi_int + 1] - beta_gpu[i_hi_int]) / 2.0
ai_hi = H_gpu[i_hi_int - 1]
aj_lo = H_gpu[j_lo_int - 1]
case_eq = ((an - ai_hi) / (n - i_hi) + (an - aj_lo) / (n - j_lo)
- (beta_gpu[i_hi_int] + beta_gpu[j_lo_int + 1]) / 2.0
- 1.0 / (i_hi * j_lo))
case_gt = ((beta_gpu[j_lo_int] - beta_gpu[j_lo_int + 1]) / 2.0
- 1.0 / (i_hi * j_lo))
sigma = cp.where(s < n, case_lt, cp.where(s == n, case_eq, case_gt))
# Diagonal
i_d = idx
i_d_int = i_d.astype(cp.int64)
diag_vals = cp.where(
2 * i_d < n, beta_gpu[i_d_int + 1],
cp.where(2 * i_d == n,
2.0 * (an - H_gpu[i_d_int - 1]) / (n - i_d) - 1.0 / (i_d * i_d),
beta_gpu[i_d_int] - 1.0 / (i_d * i_d)))
d_idx = cp.arange(n - 1)
sigma[d_idx, d_idx] = diag_vals
return sigma
def compute_sigma_ij(n):
"""Compute the Fu (1995) covariance matrix sigma_ij for sample size n.
sigma_ij = Cov[xi_i, xi_j] / theta^2 for the unfolded SFS under the
standard neutral model.
Parameters
----------
n : int
Sample size (number of haplotypes).
Returns
-------
sigma : ndarray, float64, shape (n-1, n-1)
"""
return _compute_sigma_ij_gpu(n).get()
@lru_cache(maxsize=128)
def _projection_matrix(n_from, n_to):
"""Hypergeometric projection matrix from n_from to n_to."""
from scipy.special import comb
P = np.zeros((n_to + 1, n_from + 1))
for k_from in range(n_from + 1):
for k_to in range(max(0, k_from - (n_from - n_to)),
min(k_from, n_to) + 1):
P[k_to, k_from] = (comb(k_from, k_to, exact=True)
* comb(n_from - k_from, n_to - k_to, exact=True)
/ comb(n_from, n_to, exact=True))
return P
def project_sfs(sfs, n_from, n_to):
"""Project an SFS from sample size n_from down to n_to.
Uses hypergeometric sampling (Gutenkunst et al. 2009).
"""
if n_to > n_from:
raise ValueError(f"Cannot project up: n_to={n_to} > n_from={n_from}")
if n_to == n_from:
return sfs.copy()
return _projection_matrix(n_from, n_to) @ sfs
# ---------------------------------------------------------------------------
# FrequencySpectrum: power-user class for SFS analysis
# ---------------------------------------------------------------------------
# Weight functions for SFS dot-product path (used by FrequencySpectrum.theta
# for custom callables; built-in names use _site_contribution instead).
def _weights_watterson(n):
w = np.zeros(n + 1)
a1 = np.sum(1.0 / np.arange(1, n))
w[1:n] = 1.0 / a1
return w
def _weights_pi(n):
k = np.arange(n + 1, dtype=np.float64)
w = 2.0 * k * (n - k) / (n * (n - 1))
w[0] = w[n] = 0.0
return w
def _weights_theta_h(n):
k = np.arange(n + 1, dtype=np.float64)
w = 2.0 * k ** 2 / (n * (n - 1))
w[0] = w[n] = 0.0
return w
def _weights_theta_l(n):
k = np.arange(n + 1, dtype=np.float64)
w = k / (n - 1)
w[0] = w[n] = 0.0
return w
def _weights_eta1(n):
w = np.zeros(n + 1)
a1 = np.sum(1.0 / np.arange(1, n))
w[1] = 1.0 / a1
return w
def _weights_eta1_star(n):
w = np.zeros(n + 1)
a1 = np.sum(1.0 / np.arange(1, n))
w[1] = w[n - 1] = 1.0 / a1
return w
def _weights_minus_eta1(n):
w = np.zeros(n + 1)
a1 = np.sum(1.0 / np.arange(1, n))
w[2:n] = 1.0 / (a1 - 1.0)
return w
def _weights_minus_eta1_star(n):
w = np.zeros(n + 1)
a1 = np.sum(1.0 / np.arange(1, n))
w[2:n - 1] = 1.0 / (a1 - 1.0 - 1.0 / (n - 1))
return w
WEIGHT_REGISTRY: Dict[str, Callable] = {
'watterson': _weights_watterson, 'theta_s': _weights_watterson,
'pi': _weights_pi, 'theta_pi': _weights_pi,
'theta_h': _weights_theta_h,
'theta_l': _weights_theta_l,
'eta1': _weights_eta1,
'eta1_star': _weights_eta1_star,
'minus_eta1': _weights_minus_eta1,
'minus_eta1_star': _weights_minus_eta1_star,
}
[docs]
class FrequencySpectrum:
"""Site frequency spectrum with support for variable sample sizes.
Computes derived allele counts on GPU, groups by per-site sample
size, and provides theta estimation via weight vector dot products.
For built-in estimators, use the scalar functions (``pi()``, etc.)
which are faster. This class is for custom weight functions,
SFS inspection, projection, and the general Achaz variance framework.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
'include' (default) or 'exclude'
n_total_sites : int, optional
Total callable sites for invariant site correction.
"""
def __init__(self, haplotype_matrix, population=None,
missing_data='include', n_total_sites=None):
if population is not None:
matrix = get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
self._source_matrix = matrix
n_hap = matrix.num_haplotypes
if n_total_sites is None:
n_total_sites = matrix.n_total_sites
from ._memutil import dac_and_n as _dac_n
dac, n_valid = _dac_n(matrix.haplotypes)
if missing_data == 'exclude':
complete = n_valid == n_hap
dac = dac[complete]
n_valid = n_valid[complete]
self.sfs_by_n = {}
if len(dac) == 0:
self.n_max = 0
self.n_segregating = 0
else:
unique_n = cp.unique(n_valid)
self.n_max = int(unique_n[-1].get())
for ni_gpu in unique_n:
ni = int(ni_gpu.get())
if ni < 2:
continue
mask = n_valid == ni
xi = cp.bincount(dac[mask], minlength=ni + 1)[:ni + 1]
self.sfs_by_n[ni] = xi.astype(cp.float64).get()
self.n_segregating = sum(
int(np.sum(xi[1:n])) for n, xi in self.sfs_by_n.items())
self.n_total_sites = n_total_sites
if n_total_sites is not None and self.n_max > 0:
n_invariant = n_total_sites - self.n_segregating
if n_invariant > 0 and self.n_max in self.sfs_by_n:
self.sfs_by_n[self.n_max][0] += n_invariant
[docs]
def theta(self, weights='pi', span_normalize=False, span=None):
"""Compute a theta estimator from the SFS.
Parameters
----------
weights : str or callable
Name of a built-in weight function, or a callable w(n) -> array.
span_normalize : bool
span : float, optional
"""
if isinstance(weights, str):
if weights not in WEIGHT_REGISTRY:
raise ValueError(f"Unknown weight: {weights}")
weights_fn = WEIGHT_REGISTRY[weights]
else:
weights_fn = weights
total = 0.0
for n, xi in self.sfs_by_n.items():
w = weights_fn(n)
total += np.sum(xi[:len(w)] * w[:len(xi)])
if span_normalize is not False:
if span is not None and span > 0:
total /= span
elif self._source_matrix is not None:
mode = 'auto' if span_normalize is True else span_normalize
s = self._source_matrix.get_span(mode)
if s > 0:
total /= s
return total
[docs]
def neutrality_test(self, w1='pi', w2='watterson'):
"""Compute T = (theta1 - theta2) / sqrt(var) using Achaz (2009) Eq. 9."""
theta1 = self.theta(w1)
theta2 = self.theta(w2)
numerator = theta1 - theta2
S = self.n_segregating
if S < 3:
return float('nan')
n_eff = max(self.sfs_by_n.keys(),
key=lambda n: np.sum(self.sfs_by_n[n]))
w1_name = w1 if isinstance(w1, str) else None
w2_name = w2 if isinstance(w2, str) else None
if w1_name and w2_name:
variance = _achaz_variance(w1_name, w2_name, n_eff, S)
else:
w1_fn = WEIGHT_REGISTRY[w1] if isinstance(w1, str) else w1
w2_fn = WEIGHT_REGISTRY[w2] if isinstance(w2, str) else w2
alpha_n, beta_n = _achaz_alpha_beta(w1_fn(n_eff), w2_fn(n_eff), n_eff)
a1, a2 = _harmonic_a1_a2(n_eff)
variance = alpha_n * S / a1 + beta_n * S * (S - 1) / (a1 ** 2 + a2)
if variance <= 0:
return float('nan')
return numerator / math.sqrt(variance)
[docs]
def suggest_projection_n(self, retain_fraction=0.95):
"""Suggest a projection target retaining most sites."""
if len(self.sfs_by_n) <= 1:
return self.n_max
sorted_ns = sorted(self.sfs_by_n.keys(), reverse=True)
total_seg = self.n_segregating
if total_seg == 0:
return self.n_max
cumulative = 0
for ni in sorted_ns:
cumulative += int(np.sum(self.sfs_by_n[ni][1:ni]))
if cumulative / total_seg >= retain_fraction:
return ni
return sorted_ns[-1]
[docs]
def project(self, target_n):
"""Project all SFS groups to a common sample size."""
projected = np.zeros(target_n + 1)
for n, xi in self.sfs_by_n.items():
if n < target_n:
continue
projected += project_sfs(xi, n, target_n)
result = object.__new__(FrequencySpectrum)
result.sfs_by_n = {target_n: projected}
result.n_max = target_n
result.n_segregating = int(np.sum(projected[1:target_n]))
result.n_total_sites = self.n_total_sites
result._source_matrix = self._source_matrix
return result
[docs]
def sfs(self, n=None):
"""Return the SFS, optionally projected."""
if n is not None:
return self.project(n).sfs_by_n[n]
if len(self.sfs_by_n) == 1:
return list(self.sfs_by_n.values())[0]
return self.sfs_by_n.get(self.n_max, np.array([]))
[docs]
def all_thetas(self, span_normalize=False, span=None):
"""Compute all 8 standard theta estimators."""
return {name: self.theta(name, span_normalize=span_normalize, span=span)
for name in ['pi', 'watterson', 'theta_h', 'theta_l',
'eta1', 'eta1_star', 'minus_eta1', 'minus_eta1_star']}
[docs]
def tajimas_d(self):
"""Tajima's D via Achaz (2009) general variance framework."""
return self.neutrality_test('pi', 'watterson')
[docs]
def fay_wu_h(self, normalized=False):
"""Fay & Wu's H = pi - theta_H. Optionally normalized (H*)."""
h = self.theta('pi') - self.theta('theta_h')
if not normalized:
return h
return self.neutrality_test('pi', 'theta_h')
[docs]
def zeng_e(self):
"""Zeng's E via Achaz (2009) general variance framework."""
return self.neutrality_test('theta_l', 'watterson')
[docs]
def all_tests(self):
"""All standard neutrality tests."""
return {
'tajimas_d': self.tajimas_d(),
'fay_wu_h': self.fay_wu_h(),
'normalized_fay_wu_h': self.fay_wu_h(normalized=True),
'zeng_e': self.zeng_e(),
}
# ---------------------------------------------------------------------------
# Pairwise components (power-user API)
# ---------------------------------------------------------------------------
def pi_components(haplotypes, n_total_sites=None, n_haplotypes_full=None):
"""Compute pairwise differences and comparisons across all sites.
For advanced use cases (custom windowed aggregation, etc.).
Parameters
----------
haplotypes : cp.ndarray, shape (n_haplotypes, n_variants)
Haplotype data with -1 for missing.
n_total_sites : int, optional
Total callable sites (variant + invariant). If provided, invariant
sites contribute 0 diffs and C(n_haplotypes_full, 2) comps each.
n_haplotypes_full : int, optional
Full sample size (used for invariant site comps).
Returns
-------
total_diffs : float
total_comps : float
total_missing : float
n_sites : int
"""
dac, n_valid_i = _dac_and_n(haplotypes)
n_valid = n_valid_i.astype(cp.float64)
derived = dac.astype(cp.float64)
ancestral = n_valid - derived
site_diffs = derived * ancestral
site_comps = n_valid * (n_valid - 1) / 2.0
usable = n_valid >= 2
total_diffs = float(cp.sum(site_diffs[usable]).get())
total_comps = float(cp.sum(site_comps[usable]).get())
n_sites = int(cp.sum(usable).get())
if n_total_sites is not None:
n_full = n_haplotypes_full or haplotypes.shape[0]
n_invariant = n_total_sites - n_sites
if n_invariant > 0:
total_comps += n_invariant * (n_full * (n_full - 1) / 2.0)
n_sites += n_invariant
n_full = n_haplotypes_full or haplotypes.shape[0]
total_possible = (n_full * (n_full - 1) / 2.0) * n_sites
total_missing = total_possible - total_comps
return total_diffs, total_comps, total_missing, n_sites
def _dac_and_n(haplotypes):
"""Shared helper: derived allele counts and valid sample counts per site.
Uses adaptive chunking from _memutil for memory safety on large matrices.
Parameters
----------
haplotypes : cupy.ndarray, int8, shape (n_hap, n_var)
Returns
-------
dac : cupy.ndarray, int64, shape (n_var,)
n_valid : cupy.ndarray, int64, shape (n_var,)
"""
from ._memutil import dac_and_n
return dac_and_n(haplotypes)
[docs]
def pi(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
span_normalize=True,
missing_data: str = 'include') -> float:
"""
Calculate nucleotide diversity (pi) for a population.
Nucleotide diversity is the average number of nucleotide differences
per site between two randomly chosen sequences from the population.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
span_normalize : bool
``True`` (default): auto-detect best denominator (accessible
bases if mask set, else genomic span).
``False``: return raw sum.
missing_data : str
``'include'`` (default) uses per-site valid data.
``'exclude'`` filters to sites with no missing data.
Returns
-------
float
Nucleotide diversity value
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return 0.0
result = _compute_thetas(matrix, ('pi',))
return _apply_span_normalize(result['thetas']['pi'], matrix, span_normalize)
[docs]
def theta_w(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
span_normalize=True,
missing_data: str = 'include') -> float:
"""
Calculate Watterson's theta for a population.
Watterson's theta is an estimator of the population mutation rate based
on the number of segregating sites.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
span_normalize : bool
``True`` (default): auto-detect best denominator.
``False``: return raw sum.
missing_data : str
``'include'`` (default) uses per-site valid data.
``'exclude'`` filters to sites with no missing data.
Returns
-------
float
Watterson's theta value
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return 0.0
result = _compute_thetas(matrix, ('watterson',))
return _apply_span_normalize(result['thetas']['watterson'], matrix, span_normalize)
[docs]
def tajimas_d(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""
Calculate Tajima's D statistic.
Tajima's D tests the neutral mutation hypothesis by comparing two
estimates of the population mutation rate.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
``'include'`` (default) uses per-site valid data with harmonic
mean of per-site sample sizes for variance terms.
``'exclude'`` filters to sites with no missing data.
Returns
-------
float
Tajima's D value
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return float("nan")
return _compute_neutrality_test(matrix, 'pi', 'watterson')
[docs]
def allele_frequency_spectrum(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> cp.ndarray:
"""
Calculate the allele frequency spectrum (AFS).
The AFS is a histogram of allele frequencies across all sites.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
'include' - Calculate AFS using available data per site.
'exclude' - Only use sites with no missing data.
Returns
-------
ndarray
Array where element i contains the number of sites with i derived alleles
"""
# Get population subset if specified
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
# Ensure on GPU
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
if missing_data == 'exclude':
matrix = matrix.exclude_missing_sites()
if matrix.num_variants == 0:
return np.zeros(matrix.num_haplotypes + 1, dtype=np.int64)
n_haplotypes = matrix.num_haplotypes
freqs = cp.sum(matrix.haplotypes, axis=0)
else: # missing_data == 'include'
max_n = matrix.num_haplotypes
derived_counts, n_valid_per_site = _dac_and_n(matrix.haplotypes)
sites_with_data = n_valid_per_site > 0
if not cp.any(sites_with_data):
return np.zeros(max_n + 1, dtype=np.int64)
# Filter to sites with valid data and check they're biallelic
valid_sites = cp.where(sites_with_data)[0]
derived_at_valid = derived_counts[valid_sites]
n_valid_at_valid = n_valid_per_site[valid_sites]
# Check biallelic assumption: derived count should be <= n_valid
biallelic_mask = derived_at_valid <= n_valid_at_valid
final_derived = derived_at_valid[biallelic_mask]
# Create AFS histogram
# Use bincount which is more efficient than a loop
if len(final_derived) > 0:
# Ensure derived counts don't exceed max_n
final_derived = cp.minimum(final_derived, max_n)
afs = cp.bincount(final_derived, minlength=max_n + 1)
# Ensure correct size and type
if len(afs) < max_n + 1:
afs_full = cp.zeros(max_n + 1, dtype=cp.int64)
afs_full[:len(afs)] = afs
afs = afs_full
else:
afs = afs[:max_n + 1].astype(cp.int64)
else:
return np.zeros(max_n + 1, dtype=np.int64)
return afs.get()
# For exclude mode, create standard histogram
return cp.histogram(freqs, bins=cp.arange(n_haplotypes + 2))[0].get()
[docs]
def segregating_sites(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> int:
"""
Count the number of segregating sites.
A site is segregating if it has more than one allele present.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
'include' - Count sites as segregating based on non-missing data only.
'exclude' - Only count sites with no missing data.
Returns
-------
int
Number of segregating sites
"""
# Get population subset if specified
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
# Ensure on GPU
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
if missing_data == 'exclude':
matrix = matrix.exclude_missing_sites()
if matrix.num_variants == 0:
return 0
allele_counts = cp.sum(matrix.haplotypes, axis=0)
n_haplotypes = matrix.num_haplotypes
segregating = (allele_counts > 0) & (allele_counts < n_haplotypes)
else: # missing_data == 'include'
derived_counts, n_valid_per_site = _dac_and_n(matrix.haplotypes)
sites_with_data = n_valid_per_site >= 2
if not cp.any(sites_with_data):
return 0
valid_sites = cp.where(sites_with_data)[0]
segregating_mask = (derived_counts[valid_sites] > 0) & (derived_counts[valid_sites] < n_valid_per_site[valid_sites])
return int(cp.sum(segregating_mask).get())
return int(cp.sum(segregating).get())
[docs]
def singleton_count(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> int:
"""
Count the number of singleton variants.
A singleton is a variant present in exactly one haplotype.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
'include' - Count singletons based on non-missing data only.
'exclude' - Only count singletons at sites with no missing data.
Returns
-------
int
Number of singleton variants
"""
# Get population subset if specified
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
# Ensure on GPU
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
if missing_data == 'exclude':
matrix = matrix.exclude_missing_sites()
if matrix.num_variants == 0:
return 0
allele_counts = cp.sum(matrix.haplotypes, axis=0)
else: # missing_data == 'include'
derived_counts, n_valid_per_site = _dac_and_n(matrix.haplotypes)
sites_with_data = n_valid_per_site >= 1
if not cp.any(sites_with_data):
return 0
valid_sites = cp.where(sites_with_data)[0]
return int(cp.sum(derived_counts[valid_sites] == 1).get())
# For exclude mode
return int(cp.sum(allele_counts == 1).get())
[docs]
def diversity_stats(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
statistics: list = ['pi', 'theta_w', 'tajimas_d'],
span_normalize=True,
missing_data: str = 'include') -> Dict[str, float]:
"""
Compute multiple diversity statistics at once.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
statistics : list
List of statistics to compute
span_normalize : bool
``True`` (default): auto-detect best denominator.
``False``: return raw sums.
missing_data : str
'include' - Use all sites, calculate from available data per site
'exclude' - Only use sites with no missing data
Returns
-------
dict
Dictionary mapping statistic names to values
"""
# Map stat names to estimator names for batched computation
theta_stats = {'pi': 'pi', 'theta_w': 'watterson', 'theta_h': 'theta_h',
'theta_l': 'theta_l'}
# Neutrality tests: (w1, w2) weight pairs
test_specs = {
'tajimas_d': ('pi', 'watterson'),
'fay_wus_h': ('pi', 'theta_h'),
'normalized_fay_wus_h': ('pi', 'theta_h'),
'zeng_e': ('theta_l', 'watterson'),
}
needs_thetas = {s for s in statistics if s in theta_stats or s in test_specs}
results = {}
if needs_thetas:
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
for s in needs_thetas:
results[s] = 0.0 if s in theta_stats else float('nan')
else:
# Collect all needed estimators for a single _compute_thetas call
estimators = set()
for s in needs_thetas:
if s in theta_stats:
estimators.add(theta_stats[s])
elif s in test_specs:
estimators.update(test_specs[s])
ct = _compute_thetas(matrix, tuple(estimators))
for s in needs_thetas:
if s in theta_stats:
results[s] = _apply_span_normalize(
ct['thetas'][theta_stats[s]], matrix, span_normalize)
elif s in test_specs:
w1, w2 = test_specs[s]
S = ct['S']
n = ct['n_harmonic_mean']
if S < 3 or n < 3:
results[s] = float('nan')
elif s == 'fay_wus_h':
results[s] = float(ct['thetas'][w1] - ct['thetas'][w2])
else:
var = _achaz_variance(w1, w2, n, S)
num = ct['thetas'][w1] - ct['thetas'][w2]
results[s] = float(num / math.sqrt(var)) if var > 0 else float('nan')
# Non-theta stats
for stat in statistics:
if stat in results:
continue
if stat == 'segregating_sites':
results['segregating_sites'] = segregating_sites(haplotype_matrix, population, missing_data)
elif stat == 'singletons':
results['singletons'] = singleton_count(haplotype_matrix, population, missing_data)
elif stat == 'n_variants':
m = _prepare_matrix(haplotype_matrix, population, missing_data)
results['n_variants'] = m.num_variants
elif stat == 'haplotype_diversity':
results['haplotype_diversity'] = haplotype_diversity(haplotype_matrix, population, missing_data)
elif stat not in ('pi', 'theta_w', 'theta_h', 'theta_l', 'tajimas_d'):
raise ValueError(f"Unknown statistic: {stat}")
return results
[docs]
def fay_wus_h(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""
Calculate Fay and Wu's H statistic.
Tests for an excess of high-frequency derived alleles, which can indicate
positive selection.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
'include' - Use all sites, calculate from available data per site
'exclude' - Only use sites with no missing data
Returns
-------
float
Fay and Wu's H value
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return float("nan")
result = _compute_thetas(matrix, ('pi', 'theta_h'))
return result['thetas']['pi'] - result['thetas']['theta_h']
[docs]
def haplotype_diversity(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""
Calculate haplotype diversity for a population.
Haplotype diversity is defined as 1 - sum(p_i^2) where p_i is the
frequency of the i-th unique haplotype in the population.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices. If None, uses all samples
missing_data : str
'include' - exclude haplotypes with any missing data
'exclude' - filter to sites with no missing data
Returns
-------
float
Haplotype diversity value
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
haplotypes = matrix.haplotypes # (n_hap, n_var)
if missing_data == 'exclude':
missing_per_var = cp.sum(haplotypes < 0, axis=0)
complete = cp.where(missing_per_var == 0)[0]
haplotypes = haplotypes[:, complete]
n_haplotypes = haplotypes.shape[0]
if n_haplotypes <= 1:
return 0.0
has_missing = bool(cp.any(haplotypes < 0).get())
if has_missing:
# Fallback: wildcard matching requires CPU pairwise comparison
haplotypes_cpu = haplotypes.get() if hasattr(haplotypes, 'get') else haplotypes
cluster_id = _cluster_haplotypes_with_missing(haplotypes_cpu)
from collections import Counter
counts = Counter(cluster_id)
frequencies = np.array(list(counts.values())) / n_haplotypes
else:
_, counts_gpu = _count_unique_haplotypes_gpu(haplotypes)
frequencies = (counts_gpu.astype(cp.float64) / n_haplotypes).get()
diversity = (1.0 - np.sum(frequencies ** 2)) * n_haplotypes / (n_haplotypes - 1)
return float(diversity)
_HASH_SEED = 42 # fixed so identical inputs produce identical groupings across calls
_HASH_TOL = 1e-3 # collision-safe for float32 dot products of 0/1 vectors at our n_var
def _count_unique_haplotypes_gpu(haplotypes):
"""Count unique haplotypes on GPU via dot-product hashing.
Caller must guarantee the input contains no missing data (-1).
Returns
-------
n_unique : int
counts : cupy.ndarray of group sizes (unsorted)
"""
n_haplotypes, n_var = haplotypes.shape
rng = cp.random.RandomState(seed=_HASH_SEED)
w1 = rng.standard_normal(n_var, dtype=cp.float32)
w2 = rng.standard_normal(n_var, dtype=cp.float32)
h_f32 = haplotypes.astype(cp.float32)
hash1 = h_f32 @ w1
hash2 = h_f32 @ w2
order = cp.lexsort(cp.stack([hash2, hash1]))
s1 = hash1[order]
s2 = hash2[order]
diff = (cp.abs(s1[1:] - s1[:-1]) > _HASH_TOL) | (cp.abs(s2[1:] - s2[:-1]) > _HASH_TOL)
boundaries = cp.concatenate([cp.ones(1, dtype=cp.bool_), diff])
boundary_idx = cp.where(boundaries)[0]
tail = cp.full(1, n_haplotypes, dtype=boundary_idx.dtype)
counts_gpu = cp.diff(cp.concatenate([boundary_idx, tail]))
return boundary_idx.shape[0], counts_gpu
def _cluster_haplotypes_with_missing(haps):
"""Cluster haplotypes treating -1 as compatible with any allele.
Two haplotypes are in the same cluster if they match at all positions
where both are non-missing. Uses greedy assignment: each haplotype
joins the first compatible cluster.
Parameters
----------
haps : ndarray, shape (n_haplotypes, n_variants)
Returns
-------
labels : list of int, length n_haplotypes
"""
n = haps.shape[0]
has_any_missing = np.any(haps < 0)
if not has_any_missing:
# fast path: no missing data, use string hashing
hap_strings = [''.join(map(str, h)) for h in haps]
label_map = {}
labels = []
next_id = 0
for s in hap_strings:
if s not in label_map:
label_map[s] = next_id
next_id += 1
labels.append(label_map[s])
return labels
# slow path: pairwise comparison with wildcard matching
# representative haplotype per cluster (index into haps)
cluster_reps = [0]
labels = [0]
for i in range(1, n):
matched = False
for c_idx, rep in enumerate(cluster_reps):
# check if haps[i] matches haps[rep] at jointly non-missing sites
both_valid = (haps[i] >= 0) & (haps[rep] >= 0)
if np.all(haps[i][both_valid] == haps[rep][both_valid]):
labels.append(c_idx)
matched = True
break
if not matched:
cluster_reps.append(i)
labels.append(len(cluster_reps) - 1)
return labels
_get_population_matrix = get_population_matrix
[docs]
def theta_h(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
span_normalize=True,
missing_data: str = 'include') -> float:
"""Compute theta_H (homozygosity-based diversity estimator).
theta_H = sum_i [ i^2 * S_i ] * 2 / (n*(n-1)) where S_i is the count
of variants with derived allele count i. Used to compute Fay and Wu's H.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
span_normalize : bool
``True`` (default): auto-detect best denominator.
``False``: return raw sum.
missing_data : str
'include' - per-site sample sizes
'exclude' - only sites with no missing data
Returns
-------
float
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return 0.0
result = _compute_thetas(matrix, ('theta_h',))
return _apply_span_normalize(result['thetas']['theta_h'], matrix, span_normalize)
[docs]
def theta_l(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
span_normalize=True,
missing_data: str = 'include') -> float:
"""Compute theta_L diversity estimator.
theta_L = sum_i(i * xi_i) / (n - 1), where xi_i is the count of
sites with derived allele count i. Weights variants linearly by
derived allele frequency, bridging theta_pi and theta_H.
Reference: Zeng et al. (2006), Genetics 174: 1431-1439, Equation (8).
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
span_normalize : bool
``True`` (default): auto-detect best denominator.
``False``: return raw sum.
missing_data : str
'include' - per-site sample sizes
'exclude' - only sites with no missing data
Returns
-------
float
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return 0.0
result = _compute_thetas(matrix, ('theta_l',))
return _apply_span_normalize(result['thetas']['theta_l'], matrix, span_normalize)
[docs]
def normalized_fay_wus_h(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""Compute normalized Fay and Wu's H (H*).
H = theta_pi - theta_H, normalized by its standard deviation under
the standard neutral model. The normalization allows comparison
across samples with different numbers of segregating sites.
Reference: Zeng et al. (2006), "Statistical Tests for Detecting
Positive Selection by Utilizing High-Frequency Variants",
Genetics 174: 1431-1439, Equation (11).
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
``'include'`` (default) uses per-site sample
sizes with harmonic mean n for variance terms.
``'exclude'`` filters to sites with no missing data.
Returns
-------
float
Normalized H*. Negative values indicate excess high-frequency
derived alleles (directional selection signal).
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return float("nan")
return _compute_neutrality_test(matrix, 'pi', 'theta_h')
[docs]
def zeng_e(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""Compute Zeng's E test statistic.
E = theta_L - theta_W, normalized by its standard deviation.
Reference: Zeng et al. (2006), Genetics 174: 1431-1439, Equation (13).
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
``'include'`` (default) or ``'exclude'``.
Returns
-------
float
"""
matrix = _prepare_matrix(haplotype_matrix, population, missing_data)
if matrix.num_variants == 0:
return float('nan')
return _compute_neutrality_test(matrix, 'theta_l', 'watterson')
[docs]
def zeng_dh(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""Compute Zeng's DH joint test statistic.
Combines Tajima's D and Fay & Wu's H into a single test with
improved power to detect directional selection. Defined as the
product D * H when both are negative, zero otherwise.
Reference: Zeng et al. (2006), "Statistical Tests for Detecting
Positive Selection by Utilizing High-Frequency Variants",
Genetics 174: 1431-1439, Equation (15).
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
Passed through to tajimas_d and fay_wus_h.
Returns
-------
float
DH statistic. Positive when both D and H are negative
(consistent with a selective sweep).
"""
D = tajimas_d(haplotype_matrix, population, missing_data=missing_data)
H = fay_wus_h(haplotype_matrix, population, missing_data=missing_data)
# DH is the product when both are negative (sweep signal)
if D < 0 and H < 0:
return float(D * H)
else:
return 0.0
[docs]
def max_daf(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""Maximum derived allele frequency across all variants.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
'include' - per-site n_valid for frequency
'exclude' - only sites with no missing data
Returns
-------
float
Maximum DAF in [0, 1].
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
dac_i, n_valid_i = _dac_and_n(matrix.haplotypes)
dac = dac_i.astype(cp.float64)
n_valid = n_valid_i.astype(cp.float64)
if missing_data == 'exclude':
complete = n_valid_i == matrix.haplotypes.shape[0]
freqs = cp.where(complete, dac / n_valid, -1.0)
else:
usable = n_valid > 0
freqs = cp.where(usable, dac / n_valid, 0.0)
return float(cp.max(freqs).get())
[docs]
def haplotype_count(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> int:
"""Count distinct haplotypes.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
'include' - exclude haplotypes with any missing
'exclude' - filter to sites with no missing data
Returns
-------
int
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
haplotypes = matrix.haplotypes
excluded = missing_data == 'exclude'
if excluded:
haplotypes = haplotypes[:, cp.sum(haplotypes < 0, axis=0) == 0]
if haplotypes.shape[0] <= 1:
return haplotypes.shape[0]
# 'exclude' already removed every site with a -1, so the remainder is clean
has_missing = False if excluded else bool(cp.any(haplotypes < 0).get())
if has_missing:
# Wildcard matching requires CPU pairwise comparison
hap_cpu = haplotypes.get().astype(np.int8)
labels = _cluster_haplotypes_with_missing(hap_cpu)
return len(set(labels))
n_unique, _ = _count_unique_haplotypes_gpu(haplotypes)
return n_unique
[docs]
def daf_histogram(matrix, n_bins: int = 20,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include'):
"""Normalized histogram of derived allele frequencies.
Accepts HaplotypeMatrix or GenotypeMatrix. For diploid data,
DAF = sum(genotypes) / (2 * n_individuals).
Parameters
----------
matrix : HaplotypeMatrix or GenotypeMatrix
n_bins : int
Number of frequency bins spanning [0, 1].
population : str or list, optional
missing_data : str
'include' - per-site n_valid for frequency
'exclude' - only sites with no missing data
Returns
-------
hist : ndarray, float64, shape (n_bins,)
Normalized counts (sum to 1).
bin_edges : ndarray, float64, shape (n_bins + 1,)
"""
from .genotype_matrix import GenotypeMatrix
if isinstance(matrix, GenotypeMatrix):
return _daf_histogram_diploid(matrix, n_bins, population)
if population is not None:
matrix = _get_population_matrix(matrix, population)
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
dac_i, n_valid_i = _dac_and_n(matrix.haplotypes)
dac = dac_i.astype(cp.float64)
n_valid = n_valid_i.astype(cp.float64)
if missing_data == 'exclude':
complete = n_valid_i == matrix.haplotypes.shape[0]
dafs = (dac / n_valid)[complete]
else:
usable = n_valid > 0
dafs = cp.where(usable, dac / n_valid, 0.0)
return _histogram_from_dafs(dafs, n_bins)
[docs]
def diplotype_frequency_spectrum(genotype_matrix,
population: Optional[Union[str, list]] = None):
"""Count distinct multi-locus genotype patterns (diplotypes).
Parameters
----------
genotype_matrix : GenotypeMatrix
population : str or list, optional
Returns
-------
freqs : ndarray, float64, sorted descending
Diplotype frequencies.
n_diplotypes : int
Number of distinct diplotypes.
"""
if population is not None:
pop_idx = genotype_matrix.sample_sets.get(population)
if pop_idx is None:
raise ValueError(f"Population {population} not found")
geno = genotype_matrix.genotypes[pop_idx, :]
else:
geno = genotype_matrix.genotypes
if isinstance(geno, cp.ndarray):
geno = geno.get()
geno = np.asarray(geno, dtype=np.int8)
n_ind = geno.shape[0]
# treat missing (-1) as wildcard for diplotype identity
labels = _cluster_haplotypes_with_missing(geno)
from collections import Counter
counts = Counter(labels)
freqs = np.array(sorted(counts.values(), reverse=True)) / n_ind
return freqs, len(counts)
def _histogram_from_dafs(dafs, n_bins):
"""Shared: compute normalized histogram from DAF CuPy array."""
bin_edges = cp.linspace(0, 1, n_bins + 1)
hist = cp.histogram(dafs, bins=bin_edges)[0].astype(cp.float64)
total = cp.sum(hist)
if total > 0:
hist = hist / total
return hist.get(), bin_edges.get()
def _daf_histogram_diploid(genotype_matrix, n_bins=20, population=None):
"""DAF histogram from diploid genotypes (internal)."""
if population is not None:
pop_idx = genotype_matrix.sample_sets.get(population)
if pop_idx is None:
raise ValueError(f"Population {population} not found")
geno = genotype_matrix.genotypes[pop_idx, :]
else:
geno = genotype_matrix.genotypes
if not isinstance(geno, cp.ndarray):
geno = cp.asarray(geno)
valid_mask = geno >= 0
geno_clean = cp.where(valid_mask, geno, 0).astype(cp.float64)
n_valid = cp.sum(valid_mask, axis=0).astype(cp.float64)
usable = n_valid > 0
dafs = cp.where(usable, cp.sum(geno_clean, axis=0) / (2.0 * n_valid), 0.0)
return _histogram_from_dafs(dafs, n_bins)
# backward compat alias
daf_histogram_diploid = _daf_histogram_diploid
# Summary statistics combinations commonly used
[docs]
def neutrality_tests(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> Dict[str, float]:
"""
Compute common neutrality test statistics.
Returns Tajima's D, Fay and Wu's H, and related values.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
The haplotype data
population : str or list, optional
Population name or list of sample indices
missing_data : str
'include' - Use all sites, calculate from available data per site
'exclude' - Only use sites with no missing data
Returns
-------
dict
Dictionary with neutrality test results
"""
return {
'tajimas_d': tajimas_d(haplotype_matrix, population, missing_data),
'fay_wus_h': fay_wus_h(haplotype_matrix, population, missing_data),
'pi': pi(haplotype_matrix, population, span_normalize=False, missing_data=missing_data),
'theta_w': theta_w(haplotype_matrix, population, span_normalize=False, missing_data=missing_data),
'segregating_sites': segregating_sites(haplotype_matrix, population, missing_data)
}
[docs]
def heterozygosity_expected(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include'):
"""
Compute expected heterozygosity (gene diversity) per variant.
He = 1 - sum(p_i^2) for each variant, where p_i are allele frequencies.
For biallelic sites this simplifies to He = 2*p*(1-p).
Parameters
----------
haplotype_matrix : HaplotypeMatrix
Haplotype data.
population : str or list, optional
Population name or sample indices.
missing_data : str
'exclude' - NaN at sites with any missing data
'include' - per-site n_valid for frequency calculation
Returns
-------
ndarray, float64, shape (n_variants,)
Expected heterozygosity per variant.
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
dac_i, n_valid_i = _dac_and_n(matrix.haplotypes)
dac = dac_i.astype(cp.float64)
n_valid = n_valid_i.astype(cp.float64)
n = matrix.haplotypes.shape[0]
if missing_data == 'include':
p = cp.where(n_valid > 0, dac / n_valid, 0.0)
he = 2.0 * p * (1.0 - p)
he = cp.where(n_valid >= 2, he, cp.nan)
else:
p = dac / n
he = 2.0 * p * (1.0 - p)
if missing_data == 'exclude':
incomplete = n_valid_i < n
he[incomplete] = cp.nan
return he.get()
[docs]
def heterozygosity_observed(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
ploidy: int = 2,
missing_data: str = 'include'):
"""
Compute observed heterozygosity per variant.
Assumes consecutive haplotypes belong to the same individual
(standard for diploid VCF data). A site is heterozygous in an
individual if the two haplotypes differ.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
Haplotype data.
population : str or list, optional
Population name or sample indices.
ploidy : int
Ploidy level. Default 2 (diploid).
missing_data : str
'include' - skip missing individuals per site (default)
'exclude' - NaN at sites with any missing data
Returns
-------
ndarray, float64, shape (n_variants,)
Observed heterozygosity per variant.
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
hap = matrix.haplotypes # (n_haplotypes, n_variants)
n_hap = hap.shape[0]
if n_hap % ploidy != 0:
raise ValueError(
f"Number of haplotypes ({n_hap}) not divisible by ploidy ({ploidy})")
n_individuals = n_hap // ploidy
if ploidy == 2:
h1 = hap[0::2]
h2 = hap[1::2]
valid = (h1 >= 0) & (h2 >= 0)
het = (h1 != h2) & valid
n_valid = cp.sum(valid, axis=0).astype(cp.float64)
n_het = cp.sum(het, axis=0).astype(cp.float64)
ho = cp.where(n_valid > 0, n_het / n_valid, cp.nan)
else:
n_variants = hap.shape[1]
n_het = cp.zeros(n_variants, dtype=cp.float64)
n_valid_ind = cp.zeros(n_variants, dtype=cp.float64)
for ind in range(n_individuals):
ind_haps = hap[ind * ploidy:(ind + 1) * ploidy]
all_valid = cp.all(ind_haps >= 0, axis=0)
all_same = cp.all(ind_haps == ind_haps[0:1], axis=0)
n_valid_ind += all_valid.astype(cp.float64)
n_het += (all_valid & ~all_same).astype(cp.float64)
ho = cp.where(n_valid_ind > 0, n_het / n_valid_ind, cp.nan)
if missing_data == 'exclude':
has_missing = cp.any(hap < 0)
if has_missing:
missing_per_var = cp.sum(hap < 0, axis=0)
ho[missing_per_var > 0] = cp.nan
return ho.get()
[docs]
def inbreeding_coefficient(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
ploidy: int = 2,
missing_data: str = 'include'):
"""
Compute Wright's inbreeding coefficient F per variant.
F = 1 - Ho/He, where Ho is observed heterozygosity and He is
expected heterozygosity.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
Haplotype data.
population : str or list, optional
Population name or sample indices.
ploidy : int
Ploidy level for observed heterozygosity computation.
missing_data : str
Passed to heterozygosity_expected and heterozygosity_observed.
Returns
-------
ndarray, float64, shape (n_variants,)
Inbreeding coefficient per variant. NaN where He = 0.
"""
ho = cp.asarray(heterozygosity_observed(haplotype_matrix, population,
ploidy, missing_data=missing_data))
he = cp.asarray(heterozygosity_expected(haplotype_matrix, population,
missing_data=missing_data))
f = cp.where(he > 0, 1.0 - ho / he, cp.nan)
return f.get()
[docs]
def mu_var(haplotype_matrix: HaplotypeMatrix,
window_length: Optional[float] = None,
population: Optional[Union[str, list]] = None) -> float:
"""mu_VAR: SNP density statistic (RAiSD).
Number of SNPs per base pair. Elevated near sweeps due to
hitchhiking effects on local variant density.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
window_length : float, optional
Window length in bp. If None, uses chrom_end - chrom_start.
population : str or list, optional
Returns
-------
float
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
n_snps = matrix.num_variants
if window_length is None:
window_length = matrix.chrom_end - matrix.chrom_start
if window_length <= 0:
return 0.0
return float(n_snps / window_length)
[docs]
def mu_sfs(haplotype_matrix: HaplotypeMatrix,
population: Optional[Union[str, list]] = None,
missing_data: str = 'include') -> float:
"""mu_SFS: fraction of SNPs at SFS edges (RAiSD).
Counts singletons (DAC=1) and near-fixed variants (DAC=n-1),
divided by total segregating sites. Elevated near selective sweeps.
Parameters
----------
haplotype_matrix : HaplotypeMatrix
population : str or list, optional
missing_data : str
'include' - per-site n_valid for edge classification
'exclude' - only sites with no missing data
Returns
-------
float
"""
if population is not None:
matrix = _get_population_matrix(haplotype_matrix, population)
else:
matrix = haplotype_matrix
if matrix.device == 'CPU':
matrix.transfer_to_gpu()
dac, n_valid = _dac_and_n(matrix.haplotypes)
if missing_data == 'exclude':
n = matrix.haplotypes.shape[0]
complete = n_valid == n
is_seg = complete & (dac > 0) & (dac < n)
is_edge = complete & ((dac == 1) | (dac == n - 1))
else:
usable = n_valid >= 2
is_seg = usable & (dac > 0) & (dac < n_valid)
is_edge = usable & ((dac == 1) | (dac == n_valid - 1))
n_seg = cp.sum(is_seg)
if int(n_seg.get()) == 0:
return 0.0
n_edge = cp.sum(is_edge)
return float((n_edge.astype(cp.float64) / n_seg.astype(cp.float64)).get())