# Source code for pg_gpu.ld_statistics

"""
GPU-accelerated linkage disequilibrium statistics.

This module provides an API for computing LD statistics
on GPUs with automatic missing data handling.
"""

import numpy as np
import cupy as cp
from typing import Optional, Union, Tuple, List, Dict


def dd(counts: cp.ndarray,
       populations: Optional[Union[Tuple[int, int], int]] = None,
       n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    Compute D² statistic for any population configuration.

    Parameters
    ----------
    counts : cp.ndarray
        Haplotype counts array:
        - Single population: shape (N, 4)
        - Two populations: shape (N, 8)
        - Multi-population: shape (N, 4*P)
    populations : tuple of int or int, optional
        Population indices. None for a single population, ``(i, j)`` for
        D² between populations i and j, and ``(i, i)`` or a bare integer
        ``i`` for D² within population i.
    n_valid : cp.ndarray or tuple of cp.ndarray, optional
        Valid sample counts per population. Shape depends on configuration:
        - Single pop: shape (N,)
        - Two pops: shape (N, 2) or tuple of (N,) arrays

    Returns
    -------
    cp.ndarray
        D² values for each locus
    """

    def _n_valid_for(pop):
        """Select population ``pop``'s valid-count array from ``n_valid``."""
        if n_valid is None:
            return None
        # The tuple/list form has no ``.ndim``, so it must be checked first.
        if isinstance(n_valid, (tuple, list)):
            return n_valid[pop]
        if n_valid.ndim == 2:
            return n_valid[:, pop]
        return n_valid

    if populations is None:
        if counts.shape[1] == 4:
            return _dd_single(counts, n_valid)
        # Multi-population counts without explicit indices: default to the
        # first population's four columns.
        return _dd_single(counts[:, :4], _n_valid_for(0))

    # A bare integer index means "within that population".
    if isinstance(populations, (int, np.integer)):
        populations = (populations, populations)

    pop1, pop2 = populations
    if pop1 == pop2:
        # Within population: slice out that population's 4 count columns.
        start_idx = pop1 * 4
        pop_counts = counts[:, start_idx:start_idx + 4]
        return _dd_single(pop_counts, _n_valid_for(pop1))

    # Between populations
    return _dd_between(counts, pop1, pop2, n_valid)
def dz(counts: cp.ndarray,
       populations: Optional[Tuple[int, int, int]] = None,
       n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    Compute the Dz statistic for any population configuration.

    Parameters
    ----------
    counts : cp.ndarray
        Haplotype counts array; 4 columns per population.
    populations : tuple of int, optional
        Three population indices (i, j, k) selecting Dz(i,j,k). When omitted,
        4-column input is handled as a single population, and wider input
        falls back to population 0, i.e. (0, 0, 0).
    n_valid : cp.ndarray, optional
        Valid sample counts per population, forwarded unchanged.

    Returns
    -------
    cp.ndarray
        Dz value for each locus.
    """
    if populations is not None:
        return _dz_multi(counts, populations, n_valid)
    if counts.shape[1] == 4:
        # Plain single-population counts.
        return _dz_single(counts, n_valid)
    # Multi-population counts without indices: use the first population.
    return _dz_multi(counts, (0, 0, 0), n_valid)
def pi2(counts: cp.ndarray,
        populations: Optional[Tuple[int, int, int, int]] = None,
        n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    Compute the π₂ statistic for any population configuration.

    Parameters
    ----------
    counts : cp.ndarray
        Haplotype counts array; 4 columns per population.
    populations : tuple of int, optional
        Four population indices (i, j, k, l) selecting π₂(i,j,k,l). When
        omitted, 4-column input is handled as a single population, and wider
        input falls back to population 0, i.e. (0, 0, 0, 0).
    n_valid : cp.ndarray, optional
        Valid sample counts per population, forwarded unchanged.

    Returns
    -------
    cp.ndarray
        π₂ value for each locus.
    """
    if populations is not None:
        return _pi2_multi(counts, populations, n_valid)
    if counts.shape[1] == 4:
        # Plain single-population counts.
        return _pi2_single(counts, n_valid)
    # Multi-population counts without indices: use the first population.
    return _pi2_multi(counts, (0, 0, 0, 0), n_valid)
def dd_within(counts: cp.ndarray,
              n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    D² within one population.

    Shorthand for ``dd(counts, populations=None)`` on 4-column counts;
    ``counts`` and ``n_valid`` go straight to the single-population kernel.
    """
    within = _dd_single(counts, n_valid)
    return within
def dd_between(counts: cp.ndarray, pop1_idx: int, pop2_idx: int,
               n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    D² between two populations.

    Shorthand for ``dd(counts, populations=(pop1_idx, pop2_idx))``; all
    arguments go straight to the between-population kernel.
    """
    between = _dd_between(counts, pop1_idx, pop2_idx, n_valid)
    return between
def _hap_count_inputs(counts, n_valid):
    """Split (N, 4) haplotype counts into kernel-ready float64 columns.

    Returns ``(c11, c10, c01, c00, n)``; each entry is a contiguous float64
    array of length N. ``n`` is ``n_valid`` cast to float64 when given,
    otherwise the row-wise total of the four count columns.
    """
    c11, c10, c01, c00 = (
        cp.ascontiguousarray(counts[:, col].astype(cp.float64))
        for col in range(4)
    )
    if n_valid is not None:
        n = cp.ascontiguousarray(n_valid.astype(cp.float64))
    else:
        n = c11 + c10 + c01 + c00
    return c11, c10, c01, c00, n
def r(counts: cp.ndarray, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    Pearson correlation coefficient r per variant pair, from haplotype counts.

    Parameters
    ----------
    counts : cp.ndarray, shape (N, 4)
        Haplotype counts [n11, n10, n01, n00] for each variant pair.
    n_valid : cp.ndarray, optional
        Valid sample counts per pair, shape (N,). Defaults to the row sums
        of ``counts``.

    Returns
    -------
    cp.ndarray, float64, shape (N,)
        Pearson r values. NaN where computation is undefined (monomorphic
        at either locus).
    """
    from .haplotype_kernels import _R_KERN, _launch

    kernel_inputs = _hap_count_inputs(counts, n_valid)
    n_pairs = kernel_inputs[0].shape[0]
    result = cp.empty(n_pairs, dtype=cp.float64)
    # Kernel argument order: c11, c10, c01, c00, n, out, N
    _launch(_R_KERN, (*kernel_inputs, result, n_pairs), n_pairs)
    return result
def r_squared(counts: cp.ndarray,
              n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """
    Squared Pearson correlation (r²) per variant pair, from haplotype counts.

    Parameters
    ----------
    counts : cp.ndarray, shape (N, 4)
        Haplotype counts [n11, n10, n01, n00] for each variant pair.
    n_valid : cp.ndarray, optional
        Valid sample counts per pair, shape (N,). Defaults to the row sums
        of ``counts``.

    Returns
    -------
    cp.ndarray, float64, shape (N,)
        r-squared values. NaN where computation is undefined.
    """
    from .haplotype_kernels import _R_SQUARED_KERN, _launch

    kernel_inputs = _hap_count_inputs(counts, n_valid)
    n_pairs = kernel_inputs[0].shape[0]
    result = cp.empty(n_pairs, dtype=cp.float64)
    # Kernel argument order: c11, c10, c01, c00, n, out, N
    _launch(_R_SQUARED_KERN, (*kernel_inputs, result, n_pairs), n_pairs)
    return result
def d_prime(counts: cp.ndarray, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray:
    """Compute Lewontin's D' (normalized linkage disequilibrium).

    D' = D / D_max, where D_max depends on the sign of D:
        If D >= 0: D_max = min(p_A * q_B, q_A * p_B)
        If D < 0:  D_max = min(p_A * p_B, q_A * q_B)

    Parameters
    ----------
    counts : cp.ndarray, shape (N, 4)
        Haplotype counts [n11, n10, n01, n00] for each variant pair.
    n_valid : cp.ndarray, optional
        Valid sample counts per pair. Shape (N,).

    Returns
    -------
    cp.ndarray, float64, shape (N,)
        D' values in [-1, 1]. NaN where computation is undefined
        (monomorphic at either locus or D_max is zero).
    """
    from .haplotype_kernels import _D_PRIME_KERN, _launch

    c11, c10, c01, c00, n = _hap_count_inputs(counts, n_valid)
    N = c11.shape[0]
    out = cp.empty(N, dtype=cp.float64)
    _launch(_D_PRIME_KERN, (c11, c10, c01, c00, n, out, N), N)
    return out


def _prepare_segregating(mat, missing_data='include'):
    """Filter ``mat`` to segregating sites and return cleaned arrays.

    ``mat`` is moved to the GPU in place when it reports ``device == 'CPU'``.
    With ``missing_data='exclude'`` every site containing a ``-1`` is dropped
    first; with any other value such sites are kept, with missing entries
    zeroed in ``hap_clean`` and masked out in ``valid_mask``.

    Returns
    -------
    (hap_clean, valid_mask, m)
        float64 ``(n_hap, m)`` arrays and the number ``m`` of segregating
        sites, or ``(None, None, 0)`` when fewer than 2 sites remain.
    """
    if hasattr(mat, 'device') and mat.device == 'CPU':
        mat.transfer_to_gpu()

    if missing_data == 'exclude':
        hap = mat.haplotypes
        missing_per_var = cp.sum(hap < 0, axis=0)
        valid = cp.where(missing_per_var == 0)[0]
        mat = mat.get_subset(valid)

    hap = mat.haplotypes
    # Derived-allele count and number of called haplotypes per site;
    # missing entries (-1) contribute to neither.
    dac = cp.sum(cp.maximum(hap, 0).astype(cp.int32), axis=0)
    n_valid_per_site = cp.sum((hap >= 0).astype(cp.int32), axis=0)
    seg = (dac > 0) & (dac < n_valid_per_site)
    seg_idx = cp.where(seg)[0]
    if len(seg_idx) < mat.num_variants:
        mat = mat.get_subset(seg_idx)
        hap = mat.haplotypes

    m = hap.shape[1]
    if m < 2:
        return None, None, 0

    valid_mask = (hap >= 0).astype(cp.float64)
    hap_clean = cp.where(hap >= 0, hap, 0).astype(cp.float64)
    return hap_clean, valid_mask, m


def _tile_counts(hi, vi, hj, vj):
    """Compute 4-way haplotype counts for all pairs in a tile.

    Returns c1, c2, c3, c4, n as (B_i, B_j) matrices where:
        c1 = n_AB (derived at both)
        c2 = n_Ab (derived at i, ancestral at j)
        c3 = n_aB (ancestral at i, derived at j)
        c4 = n_ab (ancestral at both)
        n  = c1+c2+c3+c4 (valid at both sites)
    """
    c1 = hi.T @ hj   # derived at both
    s12 = hi.T @ vj  # derived at i, valid at j (= c1 + c2)
    s13 = vi.T @ hj  # valid at i, derived at j (= c1 + c3)
    n = vi.T @ vj    # valid at both
    c2 = s12 - c1
    c3 = s13 - c1
    c4 = n - c1 - c2 - c3
    return c1, c2, c3, c4, n


def _tile_r2_naive(hi, vi, hj, vj, pi, pqi, pj, pqj):
    """Compute naive r² for a tile (frequency-based, biased).

    ``pi``/``pj`` are per-site derived frequencies and ``pqi``/``pqj`` the
    matching p*(1-p) values. Pairs whose denominator is zero get 0.
    """
    joint_n = vi.T @ vj
    joint_11 = hi.T @ hj
    p_AB = cp.where(joint_n > 0, joint_11 / joint_n, 0.0)
    D = p_AB - cp.outer(pi, pj)
    denom = cp.outer(pqi, pqj)
    return cp.where(denom > 0, (D ** 2) / denom, 0.0)


def _tile_sigma_d2(hi, vi, hj, vj):
    """Compute unbiased D²/π² (sigma_d^2) for a tile.

    Uses multinomial projection estimators (Ragsdale & Gravel 2019):
        D² = [c1(c1-1)c4(c4-1) + c2(c2-1)c3(c3-1) - 2*c1*c2*c3*c4]
             / [n(n-1)(n-2)(n-3)]
        π² = [(c1+c2)(c1+c3)(c2+c4)(c3+c4)
              - c1*c4*(-1+c1+3c2+3c3+c4)
              - c2*c3*(-1+3c1+c2+c3+3c4)]
             / [n(n-1)(n-2)(n-3)]

    The shared falling-factorial denominator cancels in the ratio.

    Returns sigma_d2 tile and valid mask (n >= 4).

    NOTE(review): pairs with n >= 4 but a zero π² numerator (e.g. two
    singleton sites carried by the same haplotype) get sigma_d2 = 0 and are
    still flagged valid, so they enter ZnS averages as zero LD.
    ``_build_sigma_d2_matrix`` instead treats them as undefined -- confirm
    which convention is intended.
    """
    c1, c2, c3, c4, n = _tile_counts(hi, vi, hj, vj)

    # Unbiased D² numerator
    dd_num = (c1 * (c1 - 1) * c4 * (c4 - 1)
              + c2 * (c2 - 1) * c3 * (c3 - 1)
              - 2 * c1 * c2 * c3 * c4)

    # Unbiased π² numerator
    s12 = c1 + c2
    s13 = c1 + c3
    s24 = c2 + c4
    s34 = c3 + c4
    pi2_num = (s12 * s13 * s24 * s34
               - c1 * c4 * (-1 + c1 + 3 * c2 + 3 * c3 + c4)
               - c2 * c3 * (-1 + 3 * c1 + c2 + c3 + 3 * c4))

    valid = n >= 4
    sigma_d2 = cp.where(valid & (pi2_num != 0), dd_num / pi2_num, 0.0)
    return sigma_d2, valid


def _resolve_ld_estimator(estimator: str, is_hap_matrix: bool) -> str:
    """Resolve an LD estimator string, including the ``'auto'`` policy.

    ``'auto'`` resolves to:

    - ``'sigma_d2'`` when ``is_hap_matrix`` is true (the unbiased
      Ragsdale & Gravel 2019 estimator -- the recommended path on phased
      data).
    - ``'rogers_huff'`` otherwise. Only the boolean flag is available here,
      so this covers a ``GenotypeMatrix`` *and* pre-computed r² arrays alike.

    Explicit ``'r2'``, ``'sigma_d2'``, and ``'rogers_huff'`` pass through
    unchanged; any other value raises ``ValueError``.
    """
    if estimator == 'auto':
        return 'sigma_d2' if is_hap_matrix else 'rogers_huff'
    if estimator not in ('r2', 'sigma_d2', 'rogers_huff'):
        raise ValueError(
            f"Unknown estimator: {estimator!r} "
            f"(expected one of 'auto', 'r2', 'sigma_d2', 'rogers_huff')")
    return estimator


def _dosage_from_matrix(matrix) -> "cp.ndarray":
    """Return a ``(n_samples, n_variants)`` float64 dosage array.

    For a ``HaplotypeMatrix`` (n_haplotypes, n_variants) of 0/1, adjacent
    haplotypes are paired into 0/1/2 dosages (sample 0 = haplotypes 0,1;
    sample 1 = haplotypes 2,3; ...). For a ``GenotypeMatrix``
    (n_samples, n_variants), the genotypes are used directly.

    Raises ``ValueError`` if missing values (-1) are present, matching the
    convention of ``scikit-allel.rogers_huff_r``, or if a HaplotypeMatrix
    has an odd number of haplotypes. Raises ``TypeError`` for other inputs.
    """
    from .haplotype_matrix import HaplotypeMatrix
    from .genotype_matrix import GenotypeMatrix

    if isinstance(matrix, HaplotypeMatrix):
        if matrix.device == 'CPU':
            matrix.transfer_to_gpu()
        hap = matrix.haplotypes
        if (hap < 0).any():
            raise ValueError(
                "rogers_huff_r: input HaplotypeMatrix contains missing "
                "values (-1). Rogers-Huff r expects strict 0/1/2 dosage "
                "input; drop or impute missing sites first.")
        n_hap = hap.shape[0]
        if n_hap % 2 != 0:
            raise ValueError(
                f"rogers_huff_r: HaplotypeMatrix has an odd number of "
                f"haplotypes ({n_hap}); cannot pair into diploids.")
        return (hap[0::2, :] + hap[1::2, :]).astype(cp.float64)

    if isinstance(matrix, GenotypeMatrix):
        if matrix.device == 'CPU':
            matrix.transfer_to_gpu()
        g = matrix.genotypes
        if (g < 0).any():
            raise ValueError(
                "rogers_huff_r: input GenotypeMatrix contains missing "
                "values (-1). Rogers-Huff r expects strict 0/1/2 dosage "
                "input; drop or impute missing sites first.")
        return g.astype(cp.float64)

    raise TypeError(
        f"rogers_huff_r: expected HaplotypeMatrix or GenotypeMatrix; "
        f"got {type(matrix).__name__}")


def _tile_rogers_huff_r(g_i: "cp.ndarray", g_j: "cp.ndarray",
                        mu_i: "cp.ndarray", mu_j: "cp.ndarray",
                        ssd_i: "cp.ndarray", ssd_j: "cp.ndarray",
                        n_samples: int) -> "cp.ndarray":
    """Per-tile signed Rogers-Huff r block from dosage tiles.

    Parameters
    ----------
    g_i, g_j : (n_samples, B_i), (n_samples, B_j) float64
        Dosage tiles (uncentered).
    mu_i, mu_j : (B_i,), (B_j,) float64
        Per-column means (precomputed for the full matrix).
    ssd_i, ssd_j : (B_i,), (B_j,) float64
        Per-column sums of squared deviations from the column mean
        (precomputed). Equivalent to ``n_samples * variance``.
    n_samples : int
        Number of samples (rows of the dosage matrix).

    Returns
    -------
    r : (B_i, B_j) float64
        Signed Rogers-Huff r per pair. NaN where either column is constant
        (ssd == 0); matches ``allel.rogers_huff_r``.

    Notes
    -----
    Uses the rank-1 expansion
    ``(g_i - mu_i)^T (g_j - mu_j) = g_i^T g_j - n * mu_i mu_j^T``
    so the centered cross-product is one matmul plus an outer product, no
    per-tile centering of the input.
    """
    cov = g_i.T @ g_j - n_samples * cp.outer(mu_i, mu_j)
    denom = cp.sqrt(cp.outer(ssd_i, ssd_j))
    return cp.where(denom > 0, cov / denom, cp.nan)


def _rogers_huff_pairwise_r(matrix, tile_size: Optional[int] = None
                            ) -> "cp.ndarray":
    """Full ``(n_variants, n_variants)`` Rogers-Huff r matrix.

    Computed tile-by-tile so the per-tile temporaries are ``O(B^2)``; the
    returned matrix itself is still ``O(n^2)``. The diagonal is set to NaN.
    Sub-diagonal tiles are filled by symmetry.

    Parameters
    ----------
    matrix : HaplotypeMatrix or GenotypeMatrix
    tile_size : int, optional
        Block size B (must be >= 1). Defaults to ``min(n_variants, 1024)``
        (at least 1, so an input with no variants is handled).

    Returns
    -------
    r : (n_variants, n_variants) float64 cupy.ndarray
        Symmetric Rogers-Huff r matrix on GPU. NaN on the diagonal and for
        variant pairs where either column is monomorphic.

    Raises
    ------
    ValueError
        If ``tile_size`` is given and is smaller than 1 (a non-positive step
        would otherwise crash or leave ``out`` uninitialised).
    """
    g = _dosage_from_matrix(matrix)
    n_samples, n_var = g.shape
    if tile_size is None:
        # max(1, ...) keeps the loop step positive when n_var == 0.
        tile_size = max(1, min(n_var, 1024))
    elif tile_size < 1:
        raise ValueError(
            f"tile_size must be a positive integer; got {tile_size}")

    mu = g.mean(axis=0)
    ssd = ((g - mu) ** 2).sum(axis=0)

    out = cp.empty((n_var, n_var), dtype=cp.float64)
    for i0 in range(0, n_var, tile_size):
        i1 = min(i0 + tile_size, n_var)
        for j0 in range(i0, n_var, tile_size):
            j1 = min(j0 + tile_size, n_var)
            tile = _tile_rogers_huff_r(
                g[:, i0:i1], g[:, j0:j1],
                mu[i0:i1], mu[j0:j1],
                ssd[i0:i1], ssd[j0:j1],
                n_samples)
            out[i0:i1, j0:j1] = tile
            if i0 != j0:
                # Mirror the upper tile into the lower triangle.
                out[j0:j1, i0:i1] = tile.T
    cp.fill_diagonal(out, cp.nan)
    return out


def rogers_huff_r(matrix, tile_size: Optional[int] = None) -> "cp.ndarray":
    """Pairwise Rogers-Huff (2008) r for all variant pairs.

    Returns the upper-triangle pairwise r values in condensed form,
    matching the layout of :func:`scikit-allel.rogers_huff_r`: pairs are
    ordered ``(0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1)``.

    Parameters
    ----------
    matrix : HaplotypeMatrix or GenotypeMatrix
        Diploid input. ``HaplotypeMatrix`` rows are paired into 0/1/2
        dosages; ``GenotypeMatrix`` genotypes are used directly. Both must
        be free of -1 missing sentinels (raise otherwise).
    tile_size : int, optional
        GPU tile size. Defaults to ``min(n_variants, 1024)``.

    Returns
    -------
    r : cupy.ndarray, shape ``(n_variants * (n_variants - 1) // 2,)``
        Signed Rogers-Huff r per pair. NaN where either variant is
        monomorphic.

    Notes
    -----
    The full square matrix is materialised before the upper triangle is
    extracted, so peak memory is ``O(n_variants^2)``.

    See Also
    --------
    rogers_huff_r_squared : convenience wrapper returning ``r ** 2``.
    """
    r_full = _rogers_huff_pairwise_r(matrix, tile_size=tile_size)
    n = r_full.shape[0]
    iu = cp.triu_indices(n, k=1)
    return r_full[iu]


def rogers_huff_r_squared(matrix, tile_size: Optional[int] = None
                          ) -> "cp.ndarray":
    """Pairwise Rogers-Huff r² for all variant pairs.

    Convenience wrapper around :func:`rogers_huff_r` returning the squared
    values.
    """
    return rogers_huff_r(matrix, tile_size=tile_size) ** 2


def _zns_tile_sum(hc, vm, m, tile_size, use_projection, p=None, pq=None):
    """Accumulate ZnS over B×B tiles of the site-pair matrix.

    Shared by ``_zns_tiled`` and ``_zns_from_precomputed``.

    Parameters
    ----------
    hc, vm : (n_hap, m) float64 arrays
        Cleaned haplotypes (missing -> 0) and validity mask over ``m``
        segregating sites.
    m : int
        Number of columns in ``hc``/``vm``; callers ensure ``m >= 2``.
    tile_size : int
        Tile edge length B; must be >= 1.
    use_projection : bool
        True: mean unbiased sigma_D² over pairs that ``_tile_sigma_d2``
        flags valid. False: mean naive r² over all ``m*(m-1)`` ordered pairs.
    p, pq : (m,) float64 arrays, optional
        Per-site derived frequency and ``p*(1-p)``; required when
        ``use_projection`` is False.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``tile_size`` < 1.
    """
    if tile_size < 1:
        raise ValueError(
            f"tile_size must be a positive integer; got {tile_size}")

    B = tile_size
    total = 0.0
    n_pairs = 0
    for i0 in range(0, m, B):
        i1 = min(i0 + B, m)
        hi = hc[:, i0:i1]
        vi = vm[:, i0:i1]
        for j0 in range(i0, m, B):
            j1 = min(j0 + B, m)
            hj = hc[:, j0:j1]
            vj = vm[:, j0:j1]
            on_diag = i0 == j0
            # Only tiles with j0 >= i0 are visited; an off-diagonal tile also
            # stands in for its mirror image (i, j) <-> (j, i), hence x2.
            mult = 1 if on_diag else 2
            if use_projection:
                tile, valid = _tile_sigma_d2(hi, vi, hj, vj)
                if on_diag:
                    # Exclude self-pairs.
                    cp.fill_diagonal(tile, 0.0)
                    cp.fill_diagonal(valid, False)
                n_pairs += mult * int(cp.sum(valid))
            else:
                tile = _tile_r2_naive(
                    hi, vi, hj, vj,
                    p[i0:i1], pq[i0:i1], p[j0:j1], pq[j0:j1])
                if on_diag:
                    cp.fill_diagonal(tile, 0.0)
            total += mult * float(cp.sum(tile))

    if use_projection:
        return total / n_pairs if n_pairs > 0 else 0.0
    return total / (m * (m - 1))


def _zns_tiled(mat, missing_data='include', tile_size=512):
    """Compute ZnS without materializing the full r² matrix.

    Uses tile-based accumulation: computes r² for B×B blocks and sums per
    tile, keeping memory at O(B²) instead of O(m²).

    When missing_data='project' (set internally via estimator='sigma_d2'),
    uses unbiased multinomial projection estimators (Ragsdale & Gravel 2019)
    computing σ_D² = D²/π² per pair instead of naive r². The 'project'
    value performs no missing-site filtering of its own.
    """
    hap_clean, valid_mask, m = _prepare_segregating(mat, missing_data)
    if m < 2:
        return 0.0

    # internal: mapped from estimator='sigma_d2'
    use_projection = (missing_data == 'project')

    p = pq = None
    if not use_projection:
        n_valid = cp.sum(valid_mask, axis=0).astype(cp.float64)
        p = cp.where(n_valid > 0, cp.sum(hap_clean, axis=0) / n_valid, 0.0)
        pq = p * (1 - p)

    return _zns_tile_sum(hap_clean, valid_mask, m, tile_size,
                         use_projection, p, pq)


def _zns_from_precomputed(hap_clean, valid_mask, col_start, col_end,
                          tile_size=512, use_projection=False):
    """Compute ZnS for a column range using precomputed arrays.

    This avoids creating a HaplotypeMatrix and recomputing
    valid_mask/hap_clean for each window in the windowed_analysis loop.

    Parameters
    ----------
    hap_clean : cupy.ndarray, shape (n_hap, n_variants)
        Haplotype data with missing values set to 0.
    valid_mask : cupy.ndarray, shape (n_hap, n_variants)
        1 where data is valid, 0 where missing.
    col_start, col_end : int
        Column range [col_start, col_end) to compute ZnS over.
    tile_size : int
        Tile size for accumulation; must be >= 1.
    use_projection : bool
        If True, use unbiased multinomial projection estimators.

    Returns
    -------
    float
        ZnS value, or 0.0 if fewer than 2 segregating sites.
    """
    hc = hap_clean[:, col_start:col_end]
    vm = valid_mask[:, col_start:col_end]

    # Filter to segregating sites (among each site's valid haplotypes).
    n_valid = cp.sum(vm, axis=0).astype(cp.float64)
    dac = cp.sum(hc, axis=0)
    seg_idx = cp.where((dac > 0) & (dac < n_valid))[0]
    m = len(seg_idx)
    if m < 2:
        return 0.0

    hc = hc[:, seg_idx]
    vm = vm[:, seg_idx]

    p = pq = None
    if not use_projection:
        n_valid = n_valid[seg_idx]
        p = cp.where(n_valid > 0, cp.sum(hc, axis=0) / n_valid, 0.0)
        pq = p * (1 - p)

    return _zns_tile_sum(hc, vm, m, tile_size, use_projection, p, pq)
def zns(r2_matrix_or_matrix, missing_data='include', estimator='auto'):
    """Kelly's ZnS: mean pairwise r-squared across all SNP pairs.

    Parameters
    ----------
    r2_matrix_or_matrix : ndarray, HaplotypeMatrix, or GenotypeMatrix
        Square r-squared matrix, or a matrix object (dispatches to haploid
        or diploid r-squared computation automatically). When a
        HaplotypeMatrix is passed, uses tiled computation to avoid
        materializing the full m×m r² matrix.
    missing_data : str
        ``'include'`` (default) uses per-site valid data for frequency
        computation. ``'exclude'`` filters to sites with no missing data;
        this is honoured for every estimator, including ``'sigma_d2'``.
    estimator : str
        ``'auto'`` (default) uses the unbiased ``sigma_d2`` estimator when
        the input is a ``HaplotypeMatrix`` and ``'rogers_huff'`` otherwise.
        ``'r2'`` and ``'rogers_huff'`` both take the naive path here: tiled
        frequency-based r² for a ``HaplotypeMatrix``, the supplied array
        for pre-computed r², and genotype-dosage r² for a
        ``GenotypeMatrix``. ``'sigma_d2'`` always uses the unbiased
        multinomial projection estimators (Ragsdale & Gravel 2019),
        computing mean :math:`\\sigma_D^2 = D^2/\\pi_2` per pair with
        falling-factorial corrections. Requires ``HaplotypeMatrix`` input.

    Returns
    -------
    float
        Mean r-squared (or mean sigma_D^2 when sigma_d2 is selected).

    Raises
    ------
    ValueError
        If ``estimator='sigma_d2'`` is requested for a non-HaplotypeMatrix
        input, or ``estimator`` is unknown.
    """
    from .haplotype_matrix import HaplotypeMatrix

    is_hm = isinstance(r2_matrix_or_matrix, HaplotypeMatrix)
    estimator = _resolve_ld_estimator(estimator, is_hm)

    # Streaming path for HaplotypeMatrix: O(B²) memory instead of O(m²)
    if is_hm:
        mat = r2_matrix_or_matrix
        if estimator == 'sigma_d2':
            # The projection path is selected via the internal 'project'
            # mode, which does no missing-site filtering of its own, so
            # apply 'exclude' here instead of silently dropping it.
            if missing_data == 'exclude':
                if mat.device == 'CPU':
                    mat.transfer_to_gpu()
                hap = mat.haplotypes
                complete = cp.where(cp.sum(hap < 0, axis=0) == 0)[0]
                mat = mat.get_subset(complete)
            return _zns_tiled(mat, 'project')
        return _zns_tiled(mat, missing_data)

    if estimator == 'sigma_d2':
        raise ValueError(
            "estimator='sigma_d2' requires a HaplotypeMatrix, "
            "not a pre-computed r² array")

    r2_matrix = _resolve_r2_matrix(r2_matrix_or_matrix, missing_data)
    m = r2_matrix.shape[0]
    if m < 2:
        return 0.0
    # Sum of off-diagonal entries over all m*(m-1) ordered pairs.
    total = cp.sum(r2_matrix) - cp.trace(r2_matrix)
    return float((total / (m * (m - 1))).get())
def _build_sigma_d2_matrix(mat, missing_data='include'):
    """Full m×m matrix of unbiased σ_D² = D²/π₂ over segregating sites.

    Backs ``omega()`` when ``estimator='sigma_d2'``. Entries whose pair has
    fewer than 4 jointly valid haplotypes, or a zero π₂ numerator, are 0,
    as is the diagonal. Returns an empty (0, 0) array when fewer than two
    segregating sites remain.
    """
    hap_clean, valid_mask, m = _prepare_segregating(mat, missing_data)
    if m < 2:
        return cp.zeros((0, 0), dtype=cp.float64)

    n_ab, n_aB_, n_Ab_, n_00, n_joint = _tile_counts(
        hap_clean, valid_mask, hap_clean, valid_mask)

    # Unbiased D² numerator (falling-factorial form).
    d2_num = (n_ab * (n_ab - 1) * n_00 * (n_00 - 1)
              + n_aB_ * (n_aB_ - 1) * n_Ab_ * (n_Ab_ - 1)
              - 2 * n_ab * n_aB_ * n_Ab_ * n_00)

    # Unbiased π₂ numerator built from the four marginal sums.
    marg_12 = n_ab + n_aB_
    marg_13 = n_ab + n_Ab_
    marg_24 = n_aB_ + n_00
    marg_34 = n_Ab_ + n_00
    pi2_num = (marg_12 * marg_13 * marg_24 * marg_34
               - n_ab * n_00 * (-1 + n_ab + 3 * n_aB_ + 3 * n_Ab_ + n_00)
               - n_aB_ * n_Ab_ * (-1 + 3 * n_ab + n_aB_ + n_Ab_ + 3 * n_00))

    defined = (n_joint >= 4) & (pi2_num != 0)
    sigma = cp.where(defined, d2_num / pi2_num, 0.0)
    cp.fill_diagonal(sigma, 0.0)
    return sigma
def _omega_from_r2(r2_matrix):
    """Kim & Nielsen's Omega from a square LD matrix.

    Only the strict upper triangle (i < j) is used, matching diploSHIC.
    Partition points are l = 3..m-2. Returns 0.0 when fewer than 5 sites.
    """
    m = r2_matrix.shape[0]
    if m < 5:
        return 0.0

    # work with upper triangle only (i < j), matching diploSHIC
    r2 = cp.triu(r2_matrix, k=1)

    # 2D prefix sums: S[a, b] = sum of r2[0:a+1, 0:b+1]
    S = cp.cumsum(cp.cumsum(r2, axis=0), axis=1)

    # partition points l = 3..m-2 (matching diploSHIC)
    l_vals = cp.arange(3, m - 1)

    # left block: upper-triangle pairs (i, j) with i < j < l
    left_sum = S[l_vals - 1, l_vals - 1]
    # total upper triangle sum
    total_upper = S[m - 1, m - 1]
    # cross block: pairs (i, j) with i < l <= j
    cross_sum = S[l_vals - 1, m - 1] - left_sum
    # right block: pairs (i, j) with l <= i < j (lower triangle is zero)
    right_sum = total_upper - left_sum - cross_sum

    # pair counts (upper triangle only)
    n_left = l_vals * (l_vals - 1) // 2
    n_right = (m - l_vals) * (m - l_vals - 1) // 2
    n_cross = l_vals * (m - l_vals)

    n_within = n_left + n_right
    within_sum = left_sum + right_sum

    # Partitions with no cross LD (or no pairs) contribute omega = 0.
    valid = (n_within > 0) & (n_cross > 0) & (cross_sum > 0)
    mean_within = cp.where(n_within > 0,
                           within_sum / n_within.astype(cp.float64), 0.0)
    mean_cross = cp.where(n_cross > 0,
                          cross_sum / n_cross.astype(cp.float64), 1.0)
    omega_vals = cp.where(valid, mean_within / mean_cross, 0.0)
    return float(cp.max(omega_vals))


def omega(r2_matrix_or_matrix, missing_data='include', estimator='auto'):
    """Kim and Nielsen's Omega: max ratio of within-partition to
    cross-partition mean LD.

    For each possible SNP partition point l, splits variants into [0:l)
    and [l:m), computes mean r-squared within each block and between
    blocks. Returns max(mean_within / mean_cross).

    Uses GPU prefix sums on the upper triangle to evaluate all partition
    points without a Python loop. Matches diploSHIC's convention of using
    upper-triangle pairs only.

    Parameters
    ----------
    r2_matrix_or_matrix : ndarray, HaplotypeMatrix, or GenotypeMatrix
        Square r-squared matrix, or a matrix object (dispatches to haploid
        or diploid r-squared computation automatically).
    missing_data : str
        ``'include'`` (default) uses per-site valid data for frequency
        computation. ``'exclude'`` filters to sites with no missing data.
    estimator : str
        ``'auto'`` (default) uses the unbiased ``sigma_d2`` estimator when
        the input is a ``HaplotypeMatrix`` and ``'rogers_huff'`` otherwise;
        ``'r2'`` and ``'rogers_huff'`` both go through the r² matrix
        conversion (pre-computed arrays are used as given).
        ``'sigma_d2'`` always uses unbiased
        :math:`\\sigma_D^2 = D^2/\\pi_2` (Ragsdale & Gravel 2019).
        Requires ``HaplotypeMatrix`` input.

    Returns
    -------
    float
        Maximum omega value. Returns 0 if fewer than 5 SNPs.

    Raises
    ------
    ValueError
        If ``estimator='sigma_d2'`` is requested for a non-HaplotypeMatrix
        input, or ``estimator`` is unknown.
    """
    from .haplotype_matrix import HaplotypeMatrix

    is_hm = isinstance(r2_matrix_or_matrix, HaplotypeMatrix)
    estimator = _resolve_ld_estimator(estimator, is_hm)

    if estimator == 'sigma_d2':
        if not is_hm:
            raise ValueError(
                "estimator='sigma_d2' requires a HaplotypeMatrix")
        r2_matrix = _build_sigma_d2_matrix(r2_matrix_or_matrix,
                                           missing_data=missing_data)
    else:
        r2_matrix = _resolve_r2_matrix(r2_matrix_or_matrix, missing_data)

    return _omega_from_r2(r2_matrix)
def mu_ld(haplotype_matrix, missing_data='include'):
    """mu_LD: haplotype pattern exclusivity between left/right halves (RAiSD).

    The variants are split at the midpoint and haplotypes are clustered
    into patterns within each half. A pattern is "exclusive" when it pairs
    with exactly one pattern on the other side. The score is the average of
    the exclusive fractions on the two sides. Elevated at sweep boundaries
    where LD structure changes abruptly.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
    missing_data : str
        'include' - treat missing as wildcard in pattern matching
        'exclude' - filter to sites with no missing data

    Returns
    -------
    float
        0.0 when fewer than two variants (after filtering) or no haplotypes.
    """
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    hap = haplotype_matrix.haplotypes

    if missing_data == 'exclude':
        no_missing = cp.sum(hap < 0, axis=0) == 0
        hap = hap[:, no_missing]

    n_hap, n_var = hap.shape
    if n_var < 2:
        return 0.0

    split = n_var // 2
    left_block = hap[:, :split].get().astype(np.int8)
    right_block = hap[:, split:].get().astype(np.int8)

    from .diversity import _cluster_haplotypes_with_missing
    left_ids = _cluster_haplotypes_with_missing(left_block)
    right_ids = _cluster_haplotypes_with_missing(right_block)

    # Map each pattern on one side to the set of partner patterns it occurs
    # with on the other side.
    partners_of_left = {}
    partners_of_right = {}
    for idx in range(n_hap):
        lid = left_ids[idx]
        rid = right_ids[idx]
        partners_of_left.setdefault(lid, set()).add(rid)
        partners_of_right.setdefault(rid, set()).add(lid)

    n_left = len(partners_of_left)
    n_right = len(partners_of_right)
    if n_left == 0 or n_right == 0:
        return 0.0

    exclusive_left = sum(len(partners) == 1
                         for partners in partners_of_left.values())
    exclusive_right = sum(len(partners) == 1
                          for partners in partners_of_right.values())
    return float((exclusive_left / n_left + exclusive_right / n_right) / 2.0)
def _resolve_r2_matrix(r2_matrix_or_matrix, missing_data='include'):
    """Convert a matrix object to an r^2 matrix, or coerce a raw r^2 array.

    For a HaplotypeMatrix or GenotypeMatrix the r^2 matrix is computed on the
    GPU. For haplotypes, monomorphic variants are excluded first to match the
    diploSHIC/allel convention for ZnS/Omega. Any other input is treated as a
    precomputed r^2 matrix.

    Parameters
    ----------
    r2_matrix_or_matrix : HaplotypeMatrix, GenotypeMatrix or array-like
        Matrix object, or a precomputed r^2 matrix.
    missing_data : str
        'exclude' drops every variant with at least one missing (negative)
        call before computing r^2; any other value keeps all variants.

    Returns
    -------
    cupy.ndarray of float64
        The r^2 matrix. A raw CuPy float64 array is returned without a copy.
        Any other raw input (NumPy array, list, CuPy array of another dtype)
        is converted to float64, matching the matrix-object paths.
    """
    from .haplotype_matrix import HaplotypeMatrix
    from .genotype_matrix import GenotypeMatrix

    if not isinstance(r2_matrix_or_matrix, (GenotypeMatrix, HaplotypeMatrix)):
        # Previously a CuPy array of any dtype was passed through untouched
        # while every other path produced float64; cp.asarray is a no-op for
        # an existing float64 CuPy array, so this only converts when needed.
        return cp.asarray(r2_matrix_or_matrix, dtype=cp.float64)

    mat = r2_matrix_or_matrix
    if hasattr(mat, 'device') and mat.device == 'CPU':
        mat.transfer_to_gpu()

    # Filter missing data sites
    if missing_data == 'exclude':
        is_haploid = isinstance(mat, HaplotypeMatrix)
        calls = mat.haplotypes if is_haploid else mat.genotypes
        missing_per_var = cp.sum(calls < 0, axis=0)
        valid = cp.where(missing_per_var == 0)[0]
        if is_haploid:
            mat = mat.get_subset(valid)
        else:
            # NOTE(review): only genotypes and positions are carried over to
            # the filtered GenotypeMatrix; confirm no other metadata is needed.
            geno = mat.genotypes[:, valid]
            pos = mat.positions[valid]
            mat = GenotypeMatrix(geno, pos)

    if not isinstance(mat, HaplotypeMatrix):
        return _r2_matrix_diploid(mat)

    # Haploid: filter monomorphic sites before r^2 computation.
    # diploSHIC marks monomorphic pairs as -1 and skips them in ZnS/Omega.
    # We match this by excluding monomorphic sites entirely. A site is
    # segregating when its derived count lies strictly between 0 and the
    # number of non-missing calls.
    hap = mat.haplotypes
    dac = cp.sum(cp.maximum(hap, 0).astype(cp.int32), axis=0)
    n_valid = cp.sum((hap >= 0).astype(cp.int32), axis=0)
    seg = (dac > 0) & (dac < n_valid)
    seg_idx = cp.where(seg)[0]
    if len(seg_idx) < mat.num_variants:
        mat = mat.get_subset(seg_idx)
    return mat.pairwise_r2().astype(cp.float64)


def _r2_matrix_diploid(genotype_matrix):
    """Compute r-squared matrix from diploid genotypes (0/1/2) on GPU.

    Treats genotype values as continuous dosages, computes their Pearson
    correlation between every pair of variants, then squares it. Negative
    entries are treated as missing.

    Parameters
    ----------
    genotype_matrix : GenotypeMatrix or cupy.ndarray
        If GenotypeMatrix, uses .genotypes (moving it to the GPU first when
        its device is 'CPU'). If array, shape (n_individuals, n_variants).

    Returns
    -------
    r2 : cupy.ndarray, float64, shape (n_variants, n_variants)
        Diagonal set to 0. Pairs involving a zero-variance variant are 0.
    """
    from .genotype_matrix import GenotypeMatrix
    if isinstance(genotype_matrix, GenotypeMatrix):
        if genotype_matrix.device == 'CPU':
            genotype_matrix.transfer_to_gpu()
        geno = genotype_matrix.genotypes
    else:
        geno = genotype_matrix
        if not isinstance(geno, cp.ndarray):
            geno = cp.asarray(geno)

    # mask missing data: compute per-site mean from valid data only
    valid_mask = (geno >= 0).astype(cp.float64)
    geno_clean = cp.where(geno >= 0, geno, 0).astype(cp.float64)
    n_valid = cp.sum(valid_mask, axis=0).astype(cp.float64)
    mean = cp.where(n_valid > 0, cp.sum(geno_clean, axis=0) / n_valid, 0.0)

    # center, zeroing out missing entries
    gn = (geno_clean - mean[None, :]) * valid_mask

    # Variance uses each variant's own valid samples, while the covariance
    # below only accumulates samples valid at both sites (missing entries are
    # zero in gn). With missing data this approximates the pairwise-complete
    # correlation; |r| <= 1 still holds by Cauchy-Schwarz.
    var = cp.sum(gn ** 2, axis=0)

    # correlation via matrix multiply
    cov = gn.T @ gn  # (n_var, n_var)

    # normalize: r_ij = cov_ij / sqrt(var_i * var_j)
    denom = cp.sqrt(cp.outer(var, var))
    r2 = cp.where(denom > 0, (cov / denom) ** 2, 0.0)
    cp.fill_diagonal(r2, 0.0)

    return r2


# Keep old names as aliases for backward compat
r2_matrix_diploid = _r2_matrix_diploid
zns_diploid = zns
omega_diploid = omega
def compute_ld_statistics(counts: cp.ndarray,
                          statistics: Optional[List[str]] = None,
                          populations: Optional[Dict[str, Union[Tuple, None]]] = None,
                          n_valid: Optional[cp.ndarray] = None) -> Dict[str, cp.ndarray]:
    """
    Compute multiple LD statistics in one pass.

    Parameters
    ----------
    counts : cp.ndarray
        Haplotype counts array
    statistics : list of str, optional
        Statistics to compute. Supported: 'dd', 'dz', 'pi2', 'r',
        'r_squared'. Defaults to ['dd', 'dz', 'pi2'].
    populations : dict, optional
        Population configurations for 'dd', 'dz' and 'pi2'.
        E.g., {'dd': (0, 1), 'dz': (0, 0, 1), 'pi2': (0, 0, 1, 1)}.
        A missing key means the single-population form. 'r' and 'r_squared'
        take no population configuration.
    n_valid : cp.ndarray, optional
        Valid sample counts per population

    Returns
    -------
    dict
        Dictionary mapping statistic names to computed values

    Raises
    ------
    ValueError
        If any requested statistic is not supported. All names are checked
        before any statistic is computed.
    """
    # None sentinel instead of a shared mutable list default.
    if statistics is None:
        statistics = ['dd', 'dz', 'pi2']
    else:
        # Materialize once so generators can be iterated for validation and
        # computation.
        statistics = list(statistics)
    if populations is None:
        populations = {}

    supported = ('dd', 'dz', 'pi2', 'r', 'r_squared')
    # Fail fast: validate every name before spending GPU time on earlier ones.
    for stat in statistics:
        if stat not in supported:
            raise ValueError(f"Unknown statistic: {stat}")

    results = {}
    for stat in statistics:
        if stat == 'dd':
            results['dd'] = dd(counts, populations.get('dd', None), n_valid)
        elif stat == 'dz':
            results['dz'] = dz(counts, populations.get('dz', None), n_valid)
        elif stat == 'pi2':
            results['pi2'] = pi2(counts, populations.get('pi2', None), n_valid)
        elif stat == 'r':
            results['r'] = r(counts, n_valid)
        else:  # 'r_squared'
            results['r_squared'] = r_squared(counts, n_valid)

    return results
# Internal implementation functions def _get_pop_data(counts, n_valid, pop_idx): """Extract counts and valid sample size for one population. Parameters ---------- counts : cp.ndarray, shape (N, 4*P) Concatenated haplotype counts for P populations. n_valid : tuple of cp.ndarray, cp.ndarray with ndim==2, or None Per-population valid sample counts. pop_idx : int Population index (0-based). Returns ------- c1, c2, c3, c4, n : cp.ndarray Haplotype counts and total valid samples for this population. """ start = pop_idx * 4 pop_counts = counts[:, start:start+4] if n_valid is not None: if isinstance(n_valid, tuple): if pop_idx < len(n_valid) and n_valid[pop_idx] is not None: pop_n = n_valid[pop_idx] else: pop_n = cp.sum(pop_counts, axis=1) elif hasattr(n_valid, 'ndim') and n_valid.ndim == 2: pop_n = n_valid[:, pop_idx] else: pop_n = n_valid else: pop_n = cp.sum(pop_counts, axis=1) return pop_counts[:, 0], pop_counts[:, 1], pop_counts[:, 2], pop_counts[:, 3], pop_n def _dd_single(counts: cp.ndarray, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute D² for single population.""" c1, c2, c3, c4 = counts[:, 0], counts[:, 1], counts[:, 2], counts[:, 3] n = n_valid if n_valid is not None else cp.sum(counts, axis=1) numer = c1 * (c1 - 1) * c4 * (c4 - 1) + c2 * (c2 - 1) * c3 * (c3 - 1) - 2 * c1 * c2 * c3 * c4 denom = n * (n - 1) * (n - 2) * (n - 3) valid_mask = n >= 4 result = cp.zeros_like(n, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _dd_between(counts: cp.ndarray, pop1_idx: int, pop2_idx: int, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute D² between two populations.""" # Extract counts for each population start1 = pop1_idx * 4 start2 = pop2_idx * 4 c11, c12, c13, c14 = counts[:, start1], counts[:, start1+1], counts[:, start1+2], counts[:, start1+3] c21, c22, c23, c24 = counts[:, start2], counts[:, start2+1], counts[:, start2+2], counts[:, start2+3] # Get valid sample sizes if n_valid is not None: 
if isinstance(n_valid, tuple): n1 = n_valid[0] if n_valid[0] is not None else cp.sum(counts[:, start1:start1+4], axis=1) n2 = n_valid[1] if n_valid[1] is not None else cp.sum(counts[:, start2:start2+4], axis=1) elif hasattr(n_valid, 'ndim') and n_valid.ndim == 2: n1 = n_valid[:, pop1_idx] n2 = n_valid[:, pop2_idx] else: # Assume n_valid is for between-population pairs n1 = n_valid n2 = n_valid else: n1 = cp.sum(counts[:, start1:start1+4], axis=1) n2 = cp.sum(counts[:, start2:start2+4], axis=1) D1 = c12 * c13 - c11 * c14 D2 = c22 * c23 - c21 * c24 numer = D1 * D2 denom = n1 * (n1 - 1) * n2 * (n2 - 1) valid_mask = (n1 >= 2) & (n2 >= 2) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _dz_single(counts: cp.ndarray, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute Dz for single population.""" c1, c2, c3, c4 = counts[:, 0], counts[:, 1], counts[:, 2], counts[:, 3] n = n_valid if n_valid is not None else cp.sum(counts, axis=1) diff = c1 * c4 - c2 * c3 sum_34_12 = (c3 + c4) - (c1 + c2) sum_24_13 = (c2 + c4) - (c1 + c3) sum_23_14 = (c2 + c3) - (c1 + c4) numer = diff * sum_34_12 * sum_24_13 + diff * sum_23_14 + 2 * (c2 * c3 + c1 * c4) denom = n * (n - 1) * (n - 2) * (n - 3) valid_mask = n >= 4 result = cp.zeros_like(n, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _dz_multi(counts: cp.ndarray, populations: Tuple[int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute Dz for multiple populations.""" pop1, pop2, pop3 = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) if pop1 == pop2 == pop3: # Single population if n_valid is not None and isinstance(n_valid, tuple): # Handle tuple case pop_n_valid = n_valid[pop1] if pop1 < len(n_valid) and n_valid[pop1] is not None else None elif n_valid is not None and hasattr(n_valid, 'ndim') and n_valid.ndim == 2: pop_n_valid = n_valid[:, pop1] 
else: pop_n_valid = n_valid return _dz_single(counts[:, pop1*4:(pop1+1)*4], pop_n_valid) elif pop1 == pop2: # Dz(i,i,j) c11, c12, c13, c14, n1 = get_pop_data(pop1) c21, c22, c23, c24, n2 = get_pop_data(pop3) numer = ( (-c11 - c12 + c13 + c14) * (-(c12 * c13) + c11 * c14) * (-c21 + c22 - c23 + c24) ) denom = n2 * n1 * (n1 - 1) * (n1 - 2) valid_mask = (n1 >= 3) & (n2 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] elif pop1 == pop3: # Dz(i,j,i) c11, c12, c13, c14, n1 = get_pop_data(pop1) c21, c22, c23, c24, n2 = get_pop_data(pop2) numer = ( (-c11 + c12 - c13 + c14) * (-(c12 * c13) + c11 * c14) * (-c21 - c22 + c23 + c24) ) denom = n2 * n1 * (n1 - 1) * (n1 - 2) valid_mask = (n1 >= 3) & (n2 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] elif pop2 == pop3: # Dz(i,j,j) c11, c12, c13, c14, n1 = get_pop_data(pop1) c21, c22, c23, c24, n2 = get_pop_data(pop2) numer = (-(c12 * c13) + c11 * c14) * (-c21 + c22 + c23 - c24) + ( -(c12 * c13) + c11 * c14 ) * (-c21 + c22 - c23 + c24) * (-c21 - c22 + c23 + c24) denom = n1 * (n1 - 1) * n2 * (n2 - 1) valid_mask = (n1 >= 2) & (n2 >= 2) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] else: # Dz(i,j,k) all different populations c11, c12, c13, c14, n1 = get_pop_data(pop1) c21, c22, c23, c24, n2 = get_pop_data(pop2) c31, c32, c33, c34, n3 = get_pop_data(pop3) numer = -( (c12 * c13 - c11 * c14) * (c21 + c22 - c23 - c24) * (c31 - c32 + c33 - c34) ) denom = n1 * (n1 - 1) * n2 * n3 valid_mask = (n1 >= 2) & (n2 >= 1) & (n3 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_single(counts: cp.ndarray, n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute π₂ for single population.""" c1, c2, c3, c4 = counts[:, 0], counts[:, 1], counts[:, 2], counts[:, 3] n = n_valid if n_valid is 
not None else cp.sum(counts, axis=1) s12 = c1 + c2 s13 = c1 + c3 s24 = c2 + c4 s34 = c3 + c4 term_a = s12 * s13 * s24 * s34 term_b = c1 * c4 * (-1 + c1 + 3 * c2 + 3 * c3 + c4) term_c = c2 * c3 * (-1 + 3 * c1 + c2 + c3 + 3 * c4) numer = term_a - term_b - term_c denom = n * (n - 1) * (n - 2) * (n - 3) valid_mask = n >= 4 result = cp.zeros_like(n, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_multi(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Compute π₂ for multiple populations.""" i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) # Count how many times each population index appears pop_list = [i, j, k, l] pop_counts = {} for p in pop_list: pop_counts[p] = pop_counts.get(p, 0) + 1 n_unique = len(pop_counts) max_count = max(pop_counts.values()) if n_unique == 1: # All same population return _pi2_single(counts[:, i*4:(i+1)*4], n_valid[:, i] if n_valid is not None and n_valid.ndim == 2 else n_valid) elif max_count == 3: # Three same, one different -- normalize to (single, triple, triple, triple) triple_pop = [p for p, c in pop_counts.items() if c == 3][0] single_pop = [p for p, c in pop_counts.items() if c == 1][0] result = _pi2_iiij(counts, (single_pop, triple_pop, triple_pop, triple_pop), n_valid) elif i == j and k == l: # pi2(i,i,k,k) -- two pairs c11, c12, c13, c14, n1 = get_pop_data(i) c21, c22, c23, c24, n2 = get_pop_data(k) numer1 = (c11 + c12) * (c13 + c14) * (c21 + c23) * (c22 + c24) numer2 = (c21 + c22) * (c23 + c24) * (c11 + c13) * (c12 + c14) denom = n1 * (n1 - 1) * n2 * (n2 - 1) valid_mask = (n1 >= 2) & (n2 >= 2) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = 0.5 * (numer1[valid_mask] + numer2[valid_mask]) / denom[valid_mask] elif i == j and k != l: # pi2(i,i,k,l) type -- handles both 2 and 3 distinct populations result = _pi2_iikl(counts, populations, n_valid) elif i 
!= j and k == l: # pi2(i,j,k,k) type result = _pi2_ijkk(counts, populations, n_valid) elif (i == k and j == l) or (i == l and j == k): # pi2(i,j,i,j) or pi2(i,j,j,i) type c11, c12, c13, c14, n1 = get_pop_data(i) c21, c22, c23, c24, n2 = get_pop_data(j) numer = ( ((c12 + c14) * (c13 + c14) * (c21 + c22) * (c21 + c23)) / 4.0 + ((c11 + c13) * (c13 + c14) * (c21 + c22) * (c22 + c24)) / 4.0 + ((c11 + c12) * (c12 + c14) * (c21 + c23) * (c23 + c24)) / 4.0 + ((c11 + c12) * (c11 + c13) * (c22 + c24) * (c23 + c24)) / 4.0 + ( -(c12 * c13 * c21) + c14 * c21 - c12 * c14 * c21 - c13 * c14 * c21 - c14 ** 2 * c21 - c14 * c21 ** 2 + c13 * c22 - c11 * c13 * c22 - c13 ** 2 * c22 - c11 * c14 * c22 - c13 * c14 * c22 - c13 * c21 * c22 - c14 * c21 * c22 - c13 * c22 ** 2 + c12 * c23 - c11 * c12 * c23 - c12 ** 2 * c23 - c11 * c14 * c23 - c12 * c14 * c23 - c12 * c21 * c23 - c14 * c21 * c23 - c11 * c22 * c23 - c14 * c22 * c23 - c12 * c23 ** 2 + c11 * c24 - c11 ** 2 * c24 - c11 * c12 * c24 - c11 * c13 * c24 - c12 * c13 * c24 - c12 * c21 * c24 - c13 * c21 * c24 - c11 * c22 * c24 - c13 * c22 * c24 - c11 * c23 * c24 - c12 * c23 * c24 - c11 * c24 ** 2 ) / 4.0 ) denom = n1 * (n1 - 1) * n2 * (n2 - 1) valid_mask = (n1 >= 2) & (n2 >= 2) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] else: if n_unique == 4: result = _pi2_all_different(counts, populations, n_valid) elif n_unique == 3: result = _pi2_shared_pop(counts, populations, n_valid) else: result = cp.zeros_like(get_pop_data(0)[4], dtype=cp.float64) return result def _pi2_iiij(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Helper for pi2(i,j,j,j) configurations.""" i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) # For pi2(i,j,j,j) where j==k==l and i!=j c11, c12, c13, c14, n1 = get_pop_data(j) # The population that appears 3 times c21, c22, c23, c24, n2 = get_pop_data(i) # 
The population that appears once # From moments _pi2_iiij formula numer = ( -((c11 + c12) * c14 * (c21 + c23)) - (c12 * (c13 + c14) * (c21 + c23)) + ((c11 + c12) * (c12 + c14) * (c13 + c14) * (c21 + c23)) + ((c11 + c12) * (c13 + c14) * (-2 * c22 - 2 * c24)) + ((c11 + c12) * c14 * (c22 + c24)) + (c12 * (c13 + c14) * (c22 + c24)) + ((c11 + c12) * (c11 + c13) * (c13 + c14) * (c22 + c24)) ) / 2.0 denom = n2 * n1 * (n1 - 1) * (n1 - 2) valid_mask = (n1 >= 3) & (n2 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_iikl(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Helper for pi2(i,i,k,l) configurations.""" i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) # Get all unique populations involved unique_pops = list(set([i, k, l])) if len(unique_pops) == 2: # Cases like (0,0,0,1) -- delegate to _pi2_iiij pop_minor = k if k != i else l result = _pi2_iiij(counts, (pop_minor, i, i, i), n_valid) else: # 3 distinct populations: pi2(i,i,j,k) where i,j,k all different # cs1 = counts[i], cs2 = counts[k], cs3 = counts[l] c11, c12, c13, c14, n1 = get_pop_data(i) c21, c22, c23, c24, n2 = get_pop_data(k) c31, c32, c33, c34, n3 = get_pop_data(l) numer = ( (c11 + c12) * (c13 + c14) * (c22 * (c31 + c33) + c24 * (c31 + c33) + (c21 + c23) * (c32 + c34)) ) / 2.0 denom = n1 * (n1 - 1) * n2 * n3 valid_mask = (n1 >= 2) & (n2 >= 1) & (n3 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_ijkk(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Helper for pi2(i,j,k,k) configurations.""" i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) # From moments: pi2(i,j,k,k) where pop3 == pop4 c11, c12, c13, c14, n1 = 
get_pop_data(k) # pop3/pop4 (k) c21, c22, c23, c24, n2 = get_pop_data(i) # pop1 (i) # Special case: if j == k, cs3 is the same as cs1 if j == k: c31, c32, c33, c34, n3 = c11, c12, c13, c14, n1 else: c31, c32, c33, c34, n3 = get_pop_data(j) # pop2 (j) # From moments formula numer = ( (c11 + c13) * (c12 + c14) * (c23 * (c31 + c32) + c24 * (c31 + c32) + (c21 + c22) * (c33 + c34)) ) / 2.0 denom = n1 * (n1 - 1) * n2 * n3 valid_mask = (n1 >= 2) & (n2 >= 1) & (n3 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_shared_pop(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Helper for pi2 with one population shared between pairs. Handles pi2(i,j,i,k), pi2(i,j,k,i), pi2(i,j,j,k), pi2(i,j,k,j) where exactly one population appears in both the first and second pair. """ i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) # Map to canonical form: cs1 = shared pop, cs2 = other from first pair, # cs3 = other from second pair if i == k: # pi2(i,j;i,l) shared, other1, other2 = i, j, l elif i == l: # pi2(i,j;k,i) shared, other1, other2 = i, j, k elif j == k: # pi2(i,j;j,l) shared, other1, other2 = j, i, l elif j == l: # pi2(i,j;k,j) shared, other1, other2 = j, i, k else: n1 = get_pop_data(0)[4] return cp.zeros_like(n1, dtype=cp.float64) c11, c12, c13, c14, n1 = get_pop_data(shared) c21, c22, c23, c24, n2 = get_pop_data(other1) c31, c32, c33, c34, n3 = get_pop_data(other2) numer = ( c14 ** 2 * (c21 + c22) * (c31 + c33) + c12 ** 2 * (c23 + c24) * (c31 + c33) + (-1 + c11 + c13) * (c13 * (c21 + c22) + c11 * (c23 + c24)) * (c32 + c34) + c14 * ( c11 * (c23 + c24) * (c31 + c33) + c21 * ( (-1 + c13) * c31 + c13 * c32 - c33 + c13 * c33 + c13 * c34 + c11 * (c32 + c34) ) + c22 * ( (-1 + c13) * c31 + c13 * c32 - c33 + c13 * c33 + c13 * c34 + c11 * (c32 + c34) ) ) + c12 * ( c14 * (c21 + c22 + c23 + c24) 
* (c31 + c33) + c13 * (c21 * (c31 + c33) + c22 * (c31 + c33) + (c23 + c24) * (c32 + c34)) + (c23 + c24) * ((-1 + c11) * c31 - c33 + c11 * (c32 + c33 + c34)) ) ) / 4.0 denom = n1 * (n1 - 1) * n2 * n3 valid_mask = (n1 >= 2) & (n2 >= 1) & (n3 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result def _pi2_all_different(counts: cp.ndarray, populations: Tuple[int, int, int, int], n_valid: Optional[cp.ndarray] = None) -> cp.ndarray: """Helper for pi2(i,j,k,l) where all 4 populations are different.""" i, j, k, l = populations def get_pop_data(pop_idx): return _get_pop_data(counts, n_valid, pop_idx) c11, c12, c13, c14, n1 = get_pop_data(i) c21, c22, c23, c24, n2 = get_pop_data(j) c31, c32, c33, c34, n3 = get_pop_data(k) c41, c42, c43, c44, n4 = get_pop_data(l) numer = ( ((c13 + c14) * (c21 + c22) * (c32 + c34) * (c41 + c43)) / 4.0 + ((c11 + c12) * (c23 + c24) * (c32 + c34) * (c41 + c43)) / 4.0 + ((c13 + c14) * (c21 + c22) * (c31 + c33) * (c42 + c44)) / 4.0 + ((c11 + c12) * (c23 + c24) * (c31 + c33) * (c42 + c44)) / 4.0 ) denom = n1 * n2 * n3 * n4 valid_mask = (n1 >= 1) & (n2 >= 1) & (n3 >= 1) & (n4 >= 1) result = cp.zeros_like(n1, dtype=cp.float64) result[valid_mask] = numer[valid_mask] / denom[valid_mask] return result # Backward compatibility layer def DD(counts, n_valid=None): """Deprecated: Use dd() instead.""" import warnings warnings.warn( "DD() is deprecated. Use ld_statistics.dd() instead.", DeprecationWarning, stacklevel=2 ) return dd_within(counts, n_valid) def DD_two_pops(counts, pop1_idx, pop2_idx, n_valid1=None, n_valid2=None): """Deprecated: Use dd() with populations parameter instead.""" import warnings warnings.warn( "DD_two_pops() is deprecated. 
Use ld_statistics.dd(counts, populations=(pop1_idx, pop2_idx)) instead.", DeprecationWarning, stacklevel=2 ) # Reconstruct the expected format if n_valid1 is not None and n_valid2 is not None: n_valid = (n_valid1, n_valid2) else: n_valid = None return dd(counts, populations=(pop1_idx, pop2_idx), n_valid=n_valid) def Dz_two_pops(counts, pop_indices, n_valid1=None, n_valid2=None): """Deprecated: Use dz() with populations parameter instead.""" import warnings warnings.warn( "Dz_two_pops() is deprecated. Use ld_statistics.dz(counts, populations=pop_indices) instead.", DeprecationWarning, stacklevel=2 ) if n_valid1 is not None and n_valid2 is not None: n_valid = (n_valid1, n_valid2) else: n_valid = None return dz(counts, populations=pop_indices, n_valid=n_valid) def pi2_two_pops(counts, pop_indices, n_valid1=None, n_valid2=None): """Deprecated: Use pi2() with populations parameter instead.""" import warnings warnings.warn( "pi2_two_pops() is deprecated. Use ld_statistics.pi2(counts, populations=pop_indices) instead.", DeprecationWarning, stacklevel=2 ) if n_valid1 is not None and n_valid2 is not None: n_valid = (n_valid1, n_valid2) else: n_valid = None return pi2(counts, populations=pop_indices, n_valid=n_valid)