Source code for pg_gpu.selection

"""
GPU-accelerated selection scan statistics.

This module provides efficient computation of haplotype-based selection scan
statistics including iHS, XP-EHH, nSL, XP-nSL, EHH decay, and Garud's H
statistics. All core computations are parallelized over haplotype pairs
using CuPy for GPU acceleration.
"""

import numpy as np
import cupy as cp
from typing import Union, Optional, Tuple
from .haplotype_matrix import HaplotypeMatrix
from ._utils import get_population_matrix as _get_population_matrix


# ---------------------------------------------------------------------------
# Public API: Standardization utilities
# ---------------------------------------------------------------------------

[docs] def standardize(score): """Centre scores to zero mean and scale to unit variance. Parameters ---------- score : array_like, float, shape (n_variants,) Raw scores (e.g., unstandardized IHS or nSL). Returns ------- ndarray, float, shape (n_variants,) Standardized scores. """ score = np.asarray(score, dtype='f8') return (score - np.nanmean(score)) / np.nanstd(score)
[docs] def standardize_by_allele_count(score, aac, bins=None, n_bins=None): """Standardize scores within allele frequency bins. Parameters ---------- score : array_like, float, shape (n_variants,) Raw scores (e.g., unstandardized IHS or nSL). aac : array_like, int, shape (n_variants,) Alternate allele counts for each variant. bins : array_like, int, optional Allele count bin edges. Overrides n_bins. n_bins : int, optional Number of bins to create. Default: max(aac) // 2. Returns ------- score_standardized : ndarray, float, shape (n_variants,) Standardized scores. bins : ndarray Bin edges used. """ from scipy.stats import binned_statistic score = np.asarray(score, dtype='f8') aac = np.asarray(aac) nonan = ~np.isnan(score) score_nonan = score[nonan] aac_nonan = aac[nonan] if bins is None: if n_bins is None: n_bins = max(1, int(np.max(aac) // 2)) bins = _make_similar_sized_bins(aac_nonan, n_bins) else: bins = np.asarray(bins) mean_score, _, _ = binned_statistic(aac_nonan, score_nonan, statistic=np.mean, bins=bins) std_score, _, _ = binned_statistic(aac_nonan, score_nonan, statistic=np.std, bins=bins) score_standardized = np.full_like(score, np.nan) for i in range(len(bins) - 1): x1 = bins[i] x2 = bins[i + 1] if i == 0: loc = aac < x2 elif i == len(bins) - 2: loc = aac >= x1 else: loc = (aac >= x1) & (aac < x2) m = mean_score[i] s = std_score[i] if s > 0: score_standardized[loc] = (score[loc] - m) / s return score_standardized, bins
# --------------------------------------------------------------------------- # Public API: Garud's H statistics # ---------------------------------------------------------------------------
[docs] def garud_h(matrix, population=None, missing_data='include'): """Compute Garud's H1, H12, H123, and H2/H1 statistics. Accepts either HaplotypeMatrix (uses haplotype frequencies) or GenotypeMatrix (uses diplotype frequencies) and dispatches automatically. Parameters ---------- matrix : HaplotypeMatrix or GenotypeMatrix Haplotype or diploid genotype data. population : str or list, optional Population name or list of sample indices. If None, uses all samples. missing_data : str 'include' - treat missing as wildcard (compatible with any allele). 'exclude' - filter to sites with no missing data. Returns ------- h1 : float Sum of squared haplotype/diplotype frequencies. h12 : float H12 statistic (top two combined). h123 : float H123 statistic (top three combined). h2_h1 : float H2/H1 ratio indicating sweep softness. """ from .genotype_matrix import GenotypeMatrix if isinstance(matrix, GenotypeMatrix): return _garud_h_diploid(matrix, population) haplotype_matrix = matrix if population is not None: haplotype_matrix = _get_population_matrix(haplotype_matrix, population) if haplotype_matrix.device == 'CPU': haplotype_matrix.transfer_to_gpu() hap = haplotype_matrix.haplotypes if missing_data == 'exclude': missing_per_var = cp.sum(hap < 0, axis=0) hap = hap[:, missing_per_var == 0] f = _distinct_haplotype_frequencies_missing(hap) return _garud_from_freqs(f)
[docs] def moving_garud_h(haplotype_matrix: HaplotypeMatrix, size: int, start: int = 0, stop: Optional[int] = None, step: Optional[int] = None, population: Optional[Union[str, list]] = None, missing_data: str = 'include'): """Compute Garud's H statistics in moving windows of variants. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data. size : int Window size (number of variants). start : int, optional Starting variant index. stop : int, optional Stopping variant index. step : int, optional Step between windows. Defaults to size (non-overlapping). population : str or list, optional Population name or list of sample indices. missing_data : str 'include' - treat missing as wildcard in pattern matching 'exclude' - filter to sites with no missing data Returns ------- h1 : ndarray, float h12 : ndarray, float h123 : ndarray, float h2_h1 : ndarray, float """ if population is not None: matrix = _get_population_matrix(haplotype_matrix, population) else: matrix = haplotype_matrix if matrix.device == 'CPU': matrix.transfer_to_gpu() hap = matrix.haplotypes if missing_data == 'exclude': missing_per_var = cp.sum(hap < 0, axis=0) hap = hap[:, missing_per_var == 0] n_variants = hap.shape[1] if stop is None: stop = n_variants if step is None: step = size has_missing = bool(cp.any(hap < 0).get()) if has_missing: # Fallback: per-window wildcard matching results = [] for w_start in range(start, stop - size + 1, step): w_end = w_start + size hap_window = hap[:, w_start:w_end] f = _distinct_haplotype_frequencies_missing(hap_window) results.append(_garud_from_freqs(f)) results = np.array(results, dtype='f8') return results[:, 0], results[:, 1], results[:, 2], results[:, 3] # GPU fast path: precompute cumulative weighted sums on GPU, # compute per-window hashes on GPU, transfer only the small hash # arrays to CPU for sorting/counting. # float64 for prefix-sum hashing to avoid accumulation error over long ranges n_hap = hap.shape[0] h_f64 = hap.astype(cp.float64) rng = cp.random.RandomState(seed=42) w1 = rng.standard_normal(n_variants, dtype=cp.float64) w2 = rng.standard_normal(n_variants, dtype=cp.float64) hw1 = h_f64 * w1[cp.newaxis, :] hw2 = h_f64 * w2[cp.newaxis, :] cs1 = cp.zeros((n_hap, n_variants + 1), dtype=cp.float64) cs2 = cp.zeros((n_hap, n_variants + 1), dtype=cp.float64) cp.cumsum(hw1, axis=1, out=cs1[:, 1:]) cp.cumsum(hw2, axis=1, out=cs2[:, 1:]) windows = list(range(start, stop - size + 1, step)) n_windows = len(windows) # Compute all window hashes on GPU at once: (n_windows, n_hap) w_starts = cp.array(windows, dtype=cp.int64) w_ends = w_starts + size # Gather prefix sums at window boundaries: (n_hap, n_windows) all_hash1 = cs1[:, w_ends] - cs1[:, w_starts] # (n_hap, n_windows) all_hash2 = cs2[:, w_ends] - cs2[:, w_starts] # Transfer to CPU: (n_hap, n_windows) * 2 * 8 bytes -- small ah1 = all_hash1.get().T # (n_windows, n_hap) ah2 = all_hash2.get().T results = np.empty((n_windows, 4), dtype='f8') for wi in range(n_windows): order = np.lexsort((ah2[wi], ah1[wi])) s1 = ah1[wi, order] s2 = ah2[wi, order] diff = (np.abs(s1[1:] - s1[:-1]) > 1e-6) | (np.abs(s2[1:] - s2[:-1]) > 1e-6) boundaries = np.concatenate([[True], diff]) boundary_idx = np.where(boundaries)[0] counts = np.diff(np.concatenate([boundary_idx, [n_hap]])) freqs = np.sort(counts)[::-1].astype(np.float64) / n_hap results[wi] = _garud_from_freqs(freqs) return results[:, 0], results[:, 1], results[:, 2], results[:, 3]
def _garud_from_freqs(f): """Compute H1/H12/H123/H2H1 from sorted frequency array.""" h1 = float(np.sum(f ** 2)) h12 = float(np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)) h123 = float(np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)) h2 = h1 - float(f[0] ** 2) h2_h1 = h2 / h1 if h1 > 0 else 0.0 return h1, h12, h123, h2_h1 def _garud_h_diploid(genotype_matrix, population=None): """Garud's H from diplotype frequencies (internal).""" from .diversity import diplotype_frequency_spectrum freqs, _ = diplotype_frequency_spectrum(genotype_matrix, population) return _garud_from_freqs(freqs) # backward compat alias garud_h_diploid = _garud_h_diploid # --------------------------------------------------------------------------- # Public API: nSL / XP-nSL # ---------------------------------------------------------------------------
[docs] def nsl(haplotype_matrix: HaplotypeMatrix, population: Optional[Union[str, list]] = None, missing_data: str = 'include'): """Compute the unstandardized nSL statistic for each variant. Compares the mean shared haplotype length around the reference (0) vs alternate (1) allele at each site. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data. population : str or list, optional Population name or list of sample indices. missing_data : str 'include' - missing extends shared suffix length (default) 'exclude' - filter to sites with no missing data before scan Returns ------- score : ndarray, float, shape (n_variants,) Unstandardized nSL scores: log(nSL1 / nSL0). """ if population is not None: matrix = _get_population_matrix(haplotype_matrix, population) else: matrix = haplotype_matrix if matrix.device == 'CPU': matrix.transfer_to_gpu() if missing_data == 'exclude': missing_per_var = matrix.count_missing(axis=0) valid = cp.where(missing_per_var == 0)[0] if len(valid) == 0: return np.full(matrix.num_variants, np.nan) matrix = matrix.get_subset(valid) # pg_gpu: (n_haplotypes, n_variants) -> allel convention: (n_variants, n_haplotypes) hap = matrix.haplotypes.T # now (n_variants, n_haplotypes) # forward scan nsl0_fwd, nsl1_fwd = _nsl01_scan_gpu(hap) # backward scan nsl0_rev, nsl1_rev = _nsl01_scan_gpu(hap[::-1]) nsl0_rev = nsl0_rev[::-1] nsl1_rev = nsl1_rev[::-1] nsl0 = nsl0_fwd + nsl0_rev nsl1 = nsl1_fwd + nsl1_rev score = cp.log(nsl1 / nsl0) return score.get()
[docs] def xpnsl(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list], pop2: Union[str, list], missing_data: str = 'include'): """Compute the unstandardized cross-population nSL statistic. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data containing both populations. pop1 : str or list First population name or sample indices. pop2 : str or list Second population name or sample indices. missing_data : str 'include' - missing extends shared suffix length (default) 'exclude' - filter to sites with no missing data before scan Returns ------- score : ndarray, float, shape (n_variants,) Unstandardized XP-nSL scores: log(nSL_pop1 / nSL_pop2). """ if missing_data == 'exclude': if haplotype_matrix.device == 'CPU': haplotype_matrix.transfer_to_gpu() missing_per_var = cp.sum(haplotype_matrix.haplotypes < 0, axis=0) valid = cp.where(missing_per_var == 0)[0] if len(valid) == 0: return np.full(haplotype_matrix.num_variants, np.nan) haplotype_matrix = haplotype_matrix.get_subset(valid) m1 = _get_population_matrix(haplotype_matrix, pop1) m2 = _get_population_matrix(haplotype_matrix, pop2) if m1.device == 'CPU': m1.transfer_to_gpu() if m2.device == 'CPU': m2.transfer_to_gpu() h1 = m1.haplotypes.T # (n_variants, n_haplotypes1) h2 = m2.haplotypes.T # forward nsl1_fwd = _nsl_scan_gpu(h1) nsl2_fwd = _nsl_scan_gpu(h2) # backward nsl1_rev = _nsl_scan_gpu(h1[::-1])[::-1] nsl2_rev = _nsl_scan_gpu(h2[::-1])[::-1] nsl1 = nsl1_fwd + nsl1_rev nsl2 = nsl2_fwd + nsl2_rev score = cp.log(nsl1 / nsl2) return score.get()
# --------------------------------------------------------------------------- # Public API: IHS / XP-EHH # ---------------------------------------------------------------------------
[docs] def ihs(haplotype_matrix: HaplotypeMatrix, pos=None, map_pos=None, min_ehh: float = 0.05, min_maf: float = 0.05, include_edges: bool = False, gap_scale: int = 20000, max_gap: int = 200000, is_accessible=None, population: Optional[Union[str, list]] = None, missing_data: str = 'include'): """Compute the unstandardized integrated haplotype score (iHS). Compares the integrated EHH for the reference vs alternate allele at each variant. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data. pos : array_like, optional Variant positions. If None, uses haplotype_matrix.positions. map_pos : array_like, optional Genetic map positions. If None, uses physical positions. min_ehh : float, optional Minimum EHH below which to truncate IHH computation. min_maf : float, optional Minimum minor allele frequency to compute score. include_edges : bool, optional If True, report scores even when EHH does not decay below min_ehh. gap_scale : int, optional Rescale gaps larger than this value. max_gap : int, optional Mark gaps larger than this as invalid. is_accessible : array_like, optional Genome accessibility mask. population : str or list, optional Population name or sample indices. missing_data : str 'include' - missing extends shared suffix length (default) 'exclude' - filter to sites with no missing data before scan Returns ------- score : ndarray, float, shape (n_variants,) Unstandardized iHS scores: log(IHH1 / IHH0). """ if population is not None: matrix = _get_population_matrix(haplotype_matrix, population) else: matrix = haplotype_matrix if matrix.device == 'CPU': matrix.transfer_to_gpu() if missing_data == 'exclude': missing_per_var = matrix.count_missing(axis=0) valid = cp.where(missing_per_var == 0)[0] if len(valid) == 0: return np.full(matrix.num_variants, np.nan) matrix = matrix.get_subset(valid) if pos is not None: pos = np.asarray(pos)[valid.get()] if pos is None: pos = matrix.positions if hasattr(pos, 'get'): pos = pos.get() pos = np.asarray(pos) hap = matrix.haplotypes.T # (n_variants, n_haplotypes) n_haplotypes = hap.shape[1] # Fall back to matrix's accessible mask if not provided explicitly if is_accessible is None and haplotype_matrix.has_accessible_mask: is_accessible = haplotype_matrix.accessible_mask.mask gaps = _compute_gaps(pos, map_pos, gap_scale, max_gap, is_accessible) gaps_gpu = cp.asarray(gaps) # Histogram kernel with warp-aggregated atomics, stride-based access, # and on-the-fly pair computation. No contiguous copies needed. scan_fn = _ihh01_scan_hist_gpu # forward scan ihh0_fwd, ihh1_fwd = scan_fn(hap, gaps_gpu, min_ehh, min_maf, include_edges) # backward scan (reversed view, no copy) ihh0_rev, ihh1_rev = scan_fn( hap[::-1], gaps_gpu[::-1], min_ehh, min_maf, include_edges) ihh0_rev = ihh0_rev[::-1] ihh1_rev = ihh1_rev[::-1] # Combine and score on GPU, single D2H transfer at end ihh0 = ihh0_fwd + ihh0_rev ihh1 = ihh1_fwd + ihh1_rev score = cp.log(ihh1 / ihh0) return score.get()
[docs] def xpehh(haplotype_matrix: HaplotypeMatrix, pop1: Union[str, list], pop2: Union[str, list], pos=None, map_pos=None, min_ehh: float = 0.05, include_edges: bool = False, gap_scale: int = 20000, max_gap: int = 200000, is_accessible=None, missing_data: str = 'include'): """Compute the unstandardized cross-population EHH (XP-EHH) statistic. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data containing both populations. pop1 : str or list First population name or sample indices. pop2 : str or list Second population name or sample indices. pos : array_like, optional Variant positions. If None, uses haplotype_matrix.positions. map_pos : array_like, optional Genetic map positions. min_ehh : float, optional Minimum EHH threshold. include_edges : bool, optional Report scores even when EHH does not decay below min_ehh. gap_scale : int, optional Rescale gaps larger than this value. max_gap : int, optional Mark gaps larger than this as invalid. is_accessible : array_like, optional Genome accessibility mask. missing_data : str 'include' - missing extends shared suffix length (default) 'exclude' - filter to sites with no missing data before scan Returns ------- score : ndarray, float, shape (n_variants,) Unstandardized XP-EHH scores: log(IHH_pop1 / IHH_pop2). """ if missing_data == 'exclude': if haplotype_matrix.device == 'CPU': haplotype_matrix.transfer_to_gpu() missing_per_var = cp.sum(haplotype_matrix.haplotypes < 0, axis=0) valid = cp.where(missing_per_var == 0)[0] if len(valid) == 0: return np.full(haplotype_matrix.num_variants, np.nan) haplotype_matrix = haplotype_matrix.get_subset(valid) if pos is not None: pos = np.asarray(pos)[valid.get()] m1 = _get_population_matrix(haplotype_matrix, pop1) m2 = _get_population_matrix(haplotype_matrix, pop2) if m1.device == 'CPU': m1.transfer_to_gpu() if m2.device == 'CPU': m2.transfer_to_gpu() if pos is None: pos = haplotype_matrix.positions if hasattr(pos, 'get'): pos = pos.get() pos = np.asarray(pos) h1 = m1.haplotypes.T h2 = m2.haplotypes.T # Fall back to matrix's accessible mask if not provided explicitly if is_accessible is None and haplotype_matrix.has_accessible_mask: is_accessible = haplotype_matrix.accessible_mask.mask gaps = _compute_gaps(pos, map_pos, gap_scale, max_gap, is_accessible) gaps_gpu = cp.asarray(gaps) # forward ihh1_fwd = _ihh_scan_gpu(h1, gaps_gpu, min_ehh, include_edges) ihh2_fwd = _ihh_scan_gpu(h2, gaps_gpu, min_ehh, include_edges) # backward ihh1_rev = _ihh_scan_gpu(h1[::-1], gaps_gpu[::-1], min_ehh, include_edges)[::-1] ihh2_rev = _ihh_scan_gpu(h2[::-1], gaps_gpu[::-1], min_ehh, include_edges)[::-1] ihh1 = ihh1_fwd + ihh1_rev ihh2 = ihh2_fwd + ihh2_rev score = np.log(ihh1 / ihh2) return score
# --------------------------------------------------------------------------- # Public API: EHH decay # ---------------------------------------------------------------------------
[docs] def ehh_decay(haplotype_matrix: HaplotypeMatrix, truncate: bool = False, population: Optional[Union[str, list]] = None, missing_data: str = 'include'): """Compute EHH decay from the first variant. Parameters ---------- haplotype_matrix : HaplotypeMatrix Haplotype data. truncate : bool, optional If True, exclude trailing zeros from the result. population : str or list, optional Population name or sample indices. missing_data : str 'include' - missing sites skipped in prefix comparison (default) 'exclude' - filter to sites with no missing data Returns ------- ehh : ndarray, float, shape (n_variants,) EHH values at each variant position from the first. """ if population is not None: matrix = _get_population_matrix(haplotype_matrix, population) else: matrix = haplotype_matrix if matrix.device == 'CPU': matrix.transfer_to_gpu() if missing_data == 'exclude': missing_per_var = matrix.count_missing(axis=0) valid = cp.where(missing_per_var == 0)[0] if len(valid) < matrix.num_variants: matrix = matrix.get_subset(valid) hap = matrix.haplotypes # (n_haplotypes, n_variants) n_hap = hap.shape[0] n_variants = hap.shape[1] n_pairs = (n_hap * (n_hap - 1)) // 2 # compute pairwise shared prefix lengths on GPU spl = _pairwise_shared_prefix_lengths_gpu(hap) # Bound the bincount range to the largest observed shared-prefix length. # Past max_spl, EHH is trivially zero; allocating + scanning + transferring # the (n_variants - max_spl) zero tail dominated runtime on whole-genome # scans (issue #94). max_spl reduction is a few microseconds. max_spl = int(spl.max()) minlength = 1 if truncate else max_spl + 1 b = cp.bincount(spl, minlength=minlength) c = cp.cumsum(b[::-1])[:-1] ehh_short = (c / n_pairs)[::-1].get() if truncate: return ehh_short out = np.zeros(n_variants, dtype=ehh_short.dtype) out[:ehh_short.shape[0]] = ehh_short return out
# --------------------------------------------------------------------------- # Private helpers: haplotype frequency computation (for Garud's H) # --------------------------------------------------------------------------- def _distinct_haplotype_frequencies(hap): """Compute distinct haplotype frequencies, sorted descending. Returns ------- freqs : ndarray, float64, sorted descending (CPU) """ from .diversity import _count_unique_haplotypes_gpu n_hap = hap.shape[0] _, counts = _count_unique_haplotypes_gpu(hap) freqs = cp.sort(counts)[::-1].astype(cp.float64) / n_hap return freqs.get() def _distinct_haplotype_frequencies_missing(hap): """Compute distinct haplotype frequencies treating -1 as wildcard. Uses GPU dot-product hashing when no missing data is present, otherwise falls back to CPU wildcard matching. Parameters ---------- hap : cupy.ndarray, shape (n_haplotypes, n_variants) Returns ------- freqs : ndarray, float64, sorted descending (CPU) """ if isinstance(hap, cp.ndarray): has_missing = bool(cp.any(hap < 0).get()) else: has_missing = bool(np.any(hap < 0)) if not has_missing: if not isinstance(hap, cp.ndarray): hap = cp.asarray(hap) return _distinct_haplotype_frequencies(hap) # Fallback: wildcard matching on CPU from .diversity import _cluster_haplotypes_with_missing from collections import Counter n_hap = hap.shape[0] hap_cpu = hap.get().astype(np.int8) if isinstance(hap, cp.ndarray) else hap labels = _cluster_haplotypes_with_missing(hap_cpu) counts = Counter(labels) freqs = np.array(sorted(counts.values(), reverse=True)) / n_hap return freqs # --------------------------------------------------------------------------- # CUDA RawKernels for SSL-based scans # --------------------------------------------------------------------------- _nsl01_kernel = cp.RawKernel(r''' extern "C" __global__ void nsl01_scan_kernel(const signed char* h, int n_variants, int n_haplotypes, long long stride_var, long long stride_hap, const int* pair_j, const int* pair_k, int n_pairs, double* ssl_sum_00, double* ssl_sum_11, int* count_00, int* count_11) { // Each thread handles one haplotype pair across all variants. // Warp-level reduction reduces atomic contention by 32x. int pid = blockDim.x * blockIdx.x + threadIdx.x; if (pid >= n_pairs) return; int j = pair_j[pid]; int k = pair_k[pid]; int ssl = 0; unsigned mask = 0xFFFFFFFF; int lane = threadIdx.x & 31; for (int i = 0; i < n_variants; i++) { signed char a1 = h[i * stride_var + j * stride_hap]; signed char a2 = h[i * stride_var + k * stride_hap]; double sv00 = 0.0, sv11 = 0.0; int c00 = 0, c11 = 0; if (a1 < 0 || a2 < 0) { ssl += 1; } else if (a1 == a2) { ssl += 1; if (a1 == 0) { sv00 = (double)ssl; c00 = 1; } else { sv11 = (double)ssl; c11 = 1; } } else { ssl = 0; } for (int off = 16; off > 0; off >>= 1) { sv00 += __shfl_down_sync(mask, sv00, off); sv11 += __shfl_down_sync(mask, sv11, off); c00 += __shfl_down_sync(mask, c00, off); c11 += __shfl_down_sync(mask, c11, off); } if (lane == 0) { if (sv00 != 0.0) atomicAdd(&ssl_sum_00[i], sv00); if (sv11 != 0.0) atomicAdd(&ssl_sum_11[i], sv11); if (c00 != 0) atomicAdd(&count_00[i], c00); if (c11 != 0) atomicAdd(&count_11[i], c11); } } } ''', 'nsl01_scan_kernel') _nsl_kernel = cp.RawKernel(r''' extern "C" __global__ void nsl_scan_kernel(const signed char* h, int n_variants, int n_haplotypes, long long stride_var, long long stride_hap, const int* pair_j, const int* pair_k, int n_pairs, double* ssl_sum) { int pid = blockDim.x * blockIdx.x + threadIdx.x; if (pid >= n_pairs) return; int j = pair_j[pid]; int k = pair_k[pid]; int ssl = 0; unsigned mask = 0xFFFFFFFF; int lane = threadIdx.x & 31; for (int i = 0; i < n_variants; i++) { signed char a1 = h[i * stride_var + j * stride_hap]; signed char a2 = h[i * stride_var + k * stride_hap]; if ((a1 != a2) && (a1 >= 0) && (a2 >= 0)) { ssl = 0; } else { ssl += 1; } double val = (double)ssl; for (int off = 16; off > 0; off >>= 1) val += __shfl_down_sync(mask, val, off); if (lane == 0) atomicAdd(&ssl_sum[i], val); } } ''', 'nsl_scan_kernel') _pair_index_cache = {} def _get_pair_indices(n_haplotypes): """Pre-compute pair indices as contiguous int32 arrays (cached).""" if n_haplotypes in _pair_index_cache: return _pair_index_cache[n_haplotypes] idx_j, idx_k = cp.triu_indices(n_haplotypes, k=1) result = idx_j.astype(cp.int32).copy(), idx_k.astype(cp.int32).copy() _pair_index_cache[n_haplotypes] = result return result # --------------------------------------------------------------------------- # Private helpers: SSL-based scans for nSL # --------------------------------------------------------------------------- def _nsl01_scan_gpu(h): """Forward scan computing mean SSL for ref (0) and alt (1) allele classes. Uses a stride-aware CUDA kernel with one thread per haplotype pair. Reads transposed/reversed views directly without copying. Parameters ---------- h : cupy.ndarray, shape (n_variants, n_haplotypes) Can be a non-contiguous view (e.g., from .T or [::-1]). Returns ------- nsl0 : cupy.ndarray, float64, shape (n_variants,) nsl1 : cupy.ndarray, float64, shape (n_variants,) """ n_variants, n_haplotypes = h.shape n_pairs = (n_haplotypes * (n_haplotypes - 1)) // 2 pair_j, pair_k = _get_pair_indices(n_haplotypes) ssl_sum_00 = cp.zeros(n_variants, dtype=cp.float64) ssl_sum_11 = cp.zeros(n_variants, dtype=cp.float64) count_00 = cp.zeros(n_variants, dtype=cp.int32) count_11 = cp.zeros(n_variants, dtype=cp.int32) # Ensure int8 dtype; keep as view (no contiguous copy needed) h_view = h if h.dtype == cp.int8 else h.astype(cp.int8, copy=False) s0, s1 = h_view.strides # bytes per element (int8 = 1 byte) block = 256 grid = (n_pairs + block - 1) // block _nsl01_kernel((grid,), (block,), (h_view, np.int32(n_variants), np.int32(n_haplotypes), np.int64(s0), np.int64(s1), pair_j, pair_k, np.int32(n_pairs), ssl_sum_00, ssl_sum_11, count_00, count_11)) nsl0 = cp.where(count_00 > 0, ssl_sum_00 / count_00, cp.nan) nsl1 = cp.where(count_11 > 0, ssl_sum_11 / count_11, cp.nan) return nsl0, nsl1 def _nsl_scan_gpu(h): """Forward scan computing mean SSL across all pairs (for cross-pop stats). Parameters ---------- h : cupy.ndarray, shape (n_variants, n_haplotypes) Returns ------- vnsl : cupy.ndarray, float64, shape (n_variants,) """ n_variants, n_haplotypes = h.shape n_pairs = (n_haplotypes * (n_haplotypes - 1)) // 2 pair_j, pair_k = _get_pair_indices(n_haplotypes) ssl_sum = cp.zeros(n_variants, dtype=cp.float64) h_view = h if h.dtype == cp.int8 else h.astype(cp.int8, copy=False) s0, s1 = h_view.strides block = 256 grid = (n_pairs + block - 1) // block _nsl_kernel((grid,), (block,), (h_view, np.int32(n_variants), np.int32(n_haplotypes), np.int64(s0), np.int64(s1), pair_j, pair_k, np.int32(n_pairs), ssl_sum)) vnsl = ssl_sum / n_pairs return vnsl # --------------------------------------------------------------------------- # Private helpers: gap computation for IHS/XP-EHH # --------------------------------------------------------------------------- def _compute_gaps(pos, map_pos=None, gap_scale=20000, max_gap=200000, is_accessible=None): """Compute gaps between variants for IHH integration. Parameters ---------- pos : ndarray, int, shape (n_variants,) Physical positions. map_pos : ndarray, float, optional Genetic map positions. If None, uses physical positions. gap_scale : int Rescale gaps larger than this. max_gap : int Mark gaps larger than this as -1 (invalid). is_accessible : ndarray, bool, optional Genome accessibility array. Returns ------- gaps : ndarray, float64, shape (n_variants - 1,) """ if hasattr(pos, 'get'): pos = pos.get() pos = np.asarray(pos) if map_pos is None: map_pos = pos if hasattr(map_pos, 'get'): map_pos = map_pos.get() map_pos = np.asarray(map_pos, dtype='f8') gaps = np.diff(map_pos) physical_gaps = np.diff(pos) if is_accessible is not None: is_accessible = np.asarray(is_accessible, dtype=bool) accessible_gaps = np.zeros(len(gaps), dtype='f8') for i in range(len(gaps)): n_access = np.count_nonzero(is_accessible[pos[i]-1:pos[i+1]-1]) accessible_gaps[i] = n_access scaling = accessible_gaps / physical_gaps gaps = gaps * scaling elif gap_scale is not None and gap_scale > 0: scaling = np.ones(gaps.shape, dtype='f8') loc_scale = physical_gaps > gap_scale scaling[loc_scale] = gap_scale / physical_gaps[loc_scale] gaps = gaps * scaling if max_gap is not None and max_gap > 0: gaps[physical_gaps > max_gap] = -1 return gaps # --------------------------------------------------------------------------- # Histogram-based IHH (one thread per pair, chunked) # # Kernels: # _ihh01_ssl_pervar_kernel_v2 -- per-variant allele-class histograms (iHS) # _ihh_ssl_kernel -- single histogram for all pairs (xpEHH) # --------------------------------------------------------------------------- _ihh01_ssl_pervar_kernel_v2 = cp.RawKernel(r''' extern "C" __global__ void ihh01_ssl_pervar_kernel_v2( const signed char* h, // haplotype matrix base pointer long long stride_var, // byte stride between variant rows long long stride_hap, // byte stride between haplotype columns int chunk_len, // number of variants in this chunk int n_haplotypes, int n_pairs, int* hist00, // (chunk_len, hist_size) for 0-0 pairs int* hist11, // (chunk_len, hist_size) for 1-1 pairs int* ssl_state, // (n_pairs,) persistent across chunks int hist_size, int chunk_start // offset into global variant index ) { int pid = blockDim.x * blockIdx.x + threadIdx.x; if (pid >= n_pairs) return; // On-the-fly pair index computation from linear pid // Inverse of: pid = j * n - j*(j+1)/2 + (k - j - 1) int n = n_haplotypes; long long np_total = (long long)n * (n - 1) / 2; int j = n - 2 - (int)floor(sqrt((double)(2 * (np_total - 1 - pid)) + 0.25) - 0.5); int k = pid - (int)((long long)j * n - (long long)j * (j + 1) / 2) + j + 1; int ssl = ssl_state[pid]; int lane = threadIdx.x & 31; for (int ci = 0; ci < chunk_len; ci++) { long long vi = chunk_start + ci; // global variant index signed char h1 = h[vi * stride_var + j * stride_hap]; signed char h2 = h[vi * stride_var + k * stride_hap]; if ((h1 != h2) && (h1 >= 0) && (h2 >= 0)) { ssl = 0; } else { ssl += 1; } int bucket = ssl < hist_size ? ssl : hist_size - 1; int is_class00 = (h1 == 0 && h2 == 0) ? 1 : 0; int is_class11 = (h1 == 1 && h2 == 1) ? 1 : 0; // Warp-aggregated histogram deposits using __match_any_sync unsigned match_bucket = __match_any_sync(0xFFFFFFFF, bucket); unsigned c00_ballot = __ballot_sync(0xFFFFFFFF, is_class00); unsigned c11_ballot = __ballot_sync(0xFFFFFFFF, is_class11); int c00_in_group = __popc(match_bucket & c00_ballot); int c11_in_group = __popc(match_bucket & c11_ballot); // Only the lowest-lane thread in each bucket group does atomics if (lane == (__ffs(match_bucket) - 1)) { if (c00_in_group > 0) atomicAdd(&hist00[ci * hist_size + bucket], c00_in_group); if (c11_in_group > 0) atomicAdd(&hist11[ci * hist_size + bucket], c11_in_group); } } ssl_state[pid] = ssl; } ''', 'ihh01_ssl_pervar_kernel_v2') _count_alleles_kernel = cp.RawKernel(r''' extern "C" __global__ void count_alleles_kernel( const signed char* h, // haplotype matrix base pointer long long stride_var, // byte stride between variant rows long long stride_hap, // byte stride between haplotype columns int chunk_len, int n_haplotypes, int chunk_start, int* n0, // (chunk_len,) output count of 0 alleles int* n1 // (chunk_len,) output count of 1 alleles ) { int ci = blockDim.x * blockIdx.x + threadIdx.x; if (ci >= chunk_len) return; long long vi = chunk_start + ci; int c0 = 0, c1 = 0; for (int j = 0; j < n_haplotypes; j++) { signed char v = h[vi * stride_var + j * stride_hap]; if (v == 0) c0++; else if (v == 1) c1++; } n0[ci] = c0; n1[ci] = c1; } ''', 'count_alleles_kernel') _ihh_ssl_kernel = cp.RawKernel(r''' extern "C" __global__ void ihh_ssl_kernel(const signed char* h, int n_variants, int n_haplotypes, const int* pair_j, const int* pair_k, int n_pairs, int* hist, int* ssl_state, int hist_size) { int pid = blockDim.x * blockIdx.x + threadIdx.x; if (pid >= n_pairs) return; int j = pair_j[pid]; int k = pair_k[pid]; int ssl = ssl_state[pid]; for (int i = 0; i < n_variants; i++) { signed char a1 = h[(long long)i * n_haplotypes + j]; signed char a2 = h[(long long)i * n_haplotypes + k]; if ((a1 != a2) && (a1 >= 0) && (a2 >= 0)) { ssl = 0; } else { ssl += 1; } int bucket = ssl < hist_size ? ssl : hist_size - 1; atomicAdd(&hist[i * hist_size + bucket], 1); } ssl_state[pid] = ssl; } ''', 'ihh_ssl_kernel') _ihh_integrate_kernel = cp.RawKernel(r''' extern "C" __global__ void ihh_integrate_kernel(const int* hist, // (chunk_len, hist_size) const double* gaps, // (total_n_variants - 1,) const int* n_pairs_arr, // (chunk_len,) int chunk_len, int chunk_start, // offset into global variant index int hist_size, double min_ehh, int include_edges, double* out_ihh) { // (chunk_len,) int ci = blockDim.x * blockIdx.x + threadIdx.x; if (ci >= chunk_len) return; int n_pairs = n_pairs_arr[ci]; int variant_idx = chunk_start + ci; // global variant index if (n_pairs == 0) { out_ihh[ci] = 0.0 / 0.0; // NaN return; } double inv_n = 1.0 / (double)n_pairs; // pairs with ssl=0 are no longer homozygous at this site int n_ident = n_pairs - hist[ci * hist_size + 0]; double ehh_prv = (double)n_ident * inv_n; if (ehh_prv <= min_ehh) { out_ihh[ci] = 0.0; return; } double ihh = 0.0; int max_steps = variant_idx < hist_size ? variant_idx : hist_size - 1; for (int i = 1; i <= max_steps; i++) { n_ident -= hist[ci * hist_size + i]; double ehh_cur = (double)n_ident * inv_n; int gap_idx = variant_idx - i; double gap = gaps[gap_idx]; if (gap < 0.0) { out_ihh[ci] = 0.0 / 0.0; // NaN return; } ihh += gap * (ehh_cur + ehh_prv) * 0.5; if (ehh_cur <= min_ehh) { out_ihh[ci] = ihh; return; } ehh_prv = ehh_cur; } // If we exhausted the histogram but variant_idx > hist_size, // the remaining bins were clamped -- all those pairs have already // been subtracted via the last bin. Continue with ehh_prv stable // (no more pairs to subtract) until we either hit min_ehh via a // gap or run out of variants. for (int i = max_steps + 1; i <= variant_idx; i++) { // n_ident unchanged (no more histogram bins) double ehh_cur = (double)n_ident * inv_n; int gap_idx = variant_idx - i; double gap = gaps[gap_idx]; if (gap < 0.0) { out_ihh[ci] = 0.0 / 0.0; // NaN return; } ihh += gap * (ehh_cur + ehh_prv) * 0.5; if (ehh_cur <= min_ehh) { out_ihh[ci] = ihh; return; } ehh_prv = ehh_cur; } if (include_edges) { out_ihh[ci] = ihh; } else { out_ihh[ci] = 0.0 / 0.0; // NaN } } ''', 'ihh_integrate_kernel') def _integrate_ihh_gpu(hist, gaps, n_pairs_arr, chunk_start, min_ehh, include_edges): """Integrate SSL histograms into IHH values on GPU. Parameters ---------- hist : cp.ndarray, int32, shape (chunk_len, hist_size) gaps : cp.ndarray, float64, shape (total_n_variants - 1,) n_pairs_arr : cp.ndarray, int32, shape (chunk_len,) Number of pairs per variant (allele-class-specific or total). chunk_start : int Global variant index offset for this chunk. min_ehh : float include_edges : bool Returns ------- ihh : cp.ndarray, float64, shape (chunk_len,) """ chunk_len, hist_size = hist.shape out_ihh = cp.empty(chunk_len, dtype=cp.float64) hist = cp.ascontiguousarray(hist) gaps = cp.ascontiguousarray(gaps) n_pairs_arr = cp.ascontiguousarray(n_pairs_arr) block = 256 grid = (chunk_len + block - 1) // block _ihh_integrate_kernel((grid,), (block,), (hist, gaps, n_pairs_arr, np.int32(chunk_len), np.int32(chunk_start), np.int32(hist_size), np.float64(min_ehh), np.int32(int(include_edges)), out_ihh)) return out_ihh # --------------------------------------------------------------------------- # Private helpers: SSL-based scans for IHS # --------------------------------------------------------------------------- def _ihh01_scan_hist_gpu(h, gaps, min_ehh=0.05, min_maf=0.05, include_edges=False, max_ssl_cap=None): """Forward scan computing IHH for ref/alt allele classes via histograms. Uses warp-aggregated atomics, on-the-fly pair index computation, and stride-based access to eliminate pair arrays and contiguous copies. All intermediate results stay on GPU. Parameters ---------- h : cupy.ndarray, shape (n_variants, n_haplotypes) Can be a non-contiguous view (e.g., from [::-1]). gaps : cupy.ndarray, float64, shape (n_variants - 1,) min_ehh, min_maf : float include_edges : bool max_ssl_cap : int or None Returns ------- ihh0 : cp.ndarray, float64, shape (n_variants,) ihh1 : cp.ndarray, float64, shape (n_variants,) """ n_variants, n_haplotypes = h.shape n_pairs = (n_haplotypes * (n_haplotypes - 1)) // 2 if max_ssl_cap is None: max_ssl_cap = min(n_variants, 10_000) hist_size = min(max_ssl_cap + 1, n_variants + 1) # Ensure int8 dtype; keep as view (no contiguous copy needed) h_view = h if h.dtype == cp.int8 else h.astype(cp.int8, copy=False) s0, s1 = h_view.strides # bytes per element (int8 = 1 byte) gaps_contig = cp.ascontiguousarray(gaps) # Chunk size: target ~2 GB for two histograms target_bytes = 2 * 1024**3 bytes_per_variant = hist_size * 4 * 2 # two int32 histograms chunk_size = max(1, target_bytes // bytes_per_variant) chunk_size = min(chunk_size, n_variants) block_ssl = 256 grid_ssl = (n_pairs + block_ssl - 1) // block_ssl block_ac = 256 grid_ac_fn = lambda c: (c + block_ac - 1) // block_ac # noqa: E731 ihh0_out = cp.empty(n_variants, dtype=cp.float64) ihh1_out = cp.empty(n_variants, dtype=cp.float64) ssl_state = cp.zeros(n_pairs, dtype=cp.int32) # Pre-allocate per-chunk buffers (reused across chunks; ~2 GB hists) n0_buf = cp.empty(chunk_size, dtype=cp.int32) n1_buf = cp.empty(chunk_size, dtype=cp.int32) hist00_buf = cp.empty((chunk_size, hist_size), dtype=cp.int32) hist11_buf = cp.empty((chunk_size, hist_size), dtype=cp.int32) for chunk_start in range(0, n_variants, chunk_size): chunk_end = min(chunk_start + chunk_size, n_variants) c_len = chunk_end - chunk_start hist00 = hist00_buf[:c_len] hist11 = hist11_buf[:c_len] hist00.fill(0) hist11.fill(0) _ihh01_ssl_pervar_kernel_v2((grid_ssl,), (block_ssl,), (h_view, np.int64(s0), np.int64(s1), np.int32(c_len), np.int32(n_haplotypes), np.int32(n_pairs), hist00, hist11, ssl_state, np.int32(hist_size), np.int32(chunk_start))) n0_chunk = n0_buf[:c_len] n1_chunk = n1_buf[:c_len] _count_alleles_kernel((grid_ac_fn(c_len),), (block_ac,), (h_view, np.int64(s0), np.int64(s1), np.int32(c_len), np.int32(n_haplotypes), np.int32(chunk_start), n0_chunk, n1_chunk)) n00 = (n0_chunk * (n0_chunk - 1) // 2).astype(cp.int32) n11 = (n1_chunk * (n1_chunk - 1) // 2).astype(cp.int32) ihh0_chunk = _integrate_ihh_gpu(hist00, gaps_contig, n00, chunk_start, min_ehh, include_edges) ihh1_chunk = _integrate_ihh_gpu(hist11, gaps_contig, n11, chunk_start, min_ehh, include_edges) total = (n0_chunk + n1_chunk).astype(cp.float64) minor = cp.minimum(n0_chunk, n1_chunk).astype(cp.float64) maf_fail = (total == 0) | (minor / cp.maximum(total, 1) < min_maf) ihh0_chunk[maf_fail] = cp.nan ihh1_chunk[maf_fail] = cp.nan ihh0_out[chunk_start:chunk_end] = ihh0_chunk ihh1_out[chunk_start:chunk_end] = ihh1_chunk return ihh0_out, ihh1_out def _ihh_scan_gpu(h, gaps, min_ehh=0.05, include_edges=False, max_ssl_cap=None): """Forward scan computing IHH across all pairs (for cross-pop stats). Uses fused kernel for <= 256 haplotypes (fast), falls back to chunked histogram approach for larger populations. Parameters ---------- h : cupy.ndarray, shape (n_variants, n_haplotypes) gaps : cupy.ndarray, float64, shape (n_variants - 1,) min_ehh : float include_edges : bool max_ssl_cap : int or None Maximum SSL tracked in histograms. If None, caps at 50K variants (shared haplotype segments rarely extend further in real data). Returns ------- vihh : np.ndarray, float64, shape (n_variants,) """ n_variants, n_haplotypes = h.shape n_pairs = (n_haplotypes * (n_haplotypes - 1)) // 2 # Histogram-based approach for all haplotype counts if max_ssl_cap is None: max_ssl_cap = min(n_variants, 10_000) hist_size = min(max_ssl_cap + 1, n_variants + 1) pair_j, pair_k = _get_pair_indices(n_haplotypes) # Determine chunk size: target ~2 GB per chunk for histogram target_bytes = 2 * 1024**3 bytes_per_variant = hist_size * 4 # one histogram, int32 chunk_size = max(1, target_bytes // bytes_per_variant) chunk_size = min(chunk_size, n_variants) h_contig = cp.ascontiguousarray(h.astype(cp.int8)) gaps_contig = cp.ascontiguousarray(gaps) block_ssl = 256 grid_ssl = (n_pairs + block_ssl - 1) // block_ssl vihh = np.empty(n_variants, dtype=np.float64) ssl_state = cp.zeros(n_pairs, dtype=cp.int32) hist_buf = cp.empty((chunk_size, hist_size), dtype=cp.int32) n_pairs_arr_buf = cp.full(chunk_size, n_pairs, dtype=cp.int32) for chunk_start in range(0, n_variants, chunk_size): chunk_end = min(chunk_start + chunk_size, n_variants) c_len = chunk_end - chunk_start hist = hist_buf[:c_len] hist.fill(0) h_chunk = h_contig[chunk_start:chunk_end] _ihh_ssl_kernel((grid_ssl,), (block_ssl,), (h_chunk, np.int32(c_len), np.int32(n_haplotypes), pair_j, pair_k, np.int32(n_pairs), hist, ssl_state, np.int32(hist_size))) ihh_chunk = _integrate_ihh_gpu(hist, gaps_contig, n_pairs_arr_buf[:c_len], chunk_start, min_ehh, include_edges) vihh[chunk_start:chunk_end] = ihh_chunk.get() return vihh # --------------------------------------------------------------------------- # Private helpers: pairwise shared prefix lengths (for EHH decay) # --------------------------------------------------------------------------- _spl_kernel = cp.RawKernel(r''' extern "C" __global__ void shared_prefix_length_kernel(const signed char* hap, int n_haplotypes, int n_variants, const int* pair_j, const int* pair_k, int n_pairs, int* spl) { int pid = blockDim.x * blockIdx.x + threadIdx.x; if (pid >= n_pairs) return; int j = pair_j[pid]; int k = pair_k[pid]; for (int v = 0; v < n_variants; v++) { signed char a = hap[(long long)j * n_variants + v]; signed char b = hap[(long long)k * n_variants + v]; // skip missing (-1): no information, prefix continues if (a < 0 || b < 0) continue; if (a != b) { spl[pid] = v; return; } } spl[pid] = n_variants; } ''', 'shared_prefix_length_kernel') def _pairwise_shared_prefix_lengths_gpu(hap): """Compute shared prefix length for all haplotype pairs. Parameters ---------- hap : cupy.ndarray, shape (n_haplotypes, n_variants) Returns ------- spl : cupy.ndarray, int32, shape (n_pairs,) """ n_hap, n_var = hap.shape n_pairs = (n_hap * (n_hap - 1)) // 2 pair_j, pair_k = _get_pair_indices(n_hap) spl = cp.empty(n_pairs, dtype=cp.int32) hap_contig = cp.ascontiguousarray(hap.astype(cp.int8)) block = 256 grid = (n_pairs + block - 1) // block _spl_kernel((grid,), (block,), (hap_contig, np.int32(n_hap), np.int32(n_var), pair_j, pair_k, np.int32(n_pairs), spl)) return spl # --------------------------------------------------------------------------- # Private helpers: binning utility # --------------------------------------------------------------------------- def _make_similar_sized_bins(x, n): """Create bins with approximately equal numbers of values. Parameters ---------- x : array_like Values to bin. n : int Target number of bins. Returns ------- bins : ndarray Bin edges. """ y = np.array(x).flatten() y.sort() bins = [y[0]] step = len(y) // n for i in range(step, len(y), step): v = y[i] if v > bins[-1]: bins.append(v) bins[-1] = y[-1] return np.array(bins)