"""
Windowed analysis module for computing population genetics statistics across genomic windows.
This module provides efficient computation of statistics in sliding or non-overlapping windows
with intelligent memory management and GPU acceleration.
"""
import numpy as np
import cupy as cp
import pandas as pd
from typing import List, Dict, Union, Optional, Callable, Iterator, Tuple, Any
from dataclasses import dataclass
import warnings
from tqdm import tqdm
from .haplotype_matrix import HaplotypeMatrix
from . import ld_statistics
from . import divergence
from . import diversity
# Keyword arguments understood only by the 'local_pca' and
# 'local_pca_jackknife' dispatch. The scalar-statistic code paths reject
# them, so they are stripped out before the recursive scalar call.
_LOCAL_PCA_ONLY_KWARGS = frozenset((
    'aggregate',
    'batch_size',
    'k',
    'n_blocks',
    'population',
    'regions',
    'scaler',
    'window_type',
))
def _compute_window_bases(haplotype_matrix, win_starts, win_stops,
                          is_accessible=None):
    """Return the accessible base count of each window (per-base denominator).

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Source matrix. Its ``accessible_mask`` is used when it reports
        ``has_accessible_mask`` and no explicit mask is supplied.
    win_starts, win_stops : array_like
        Window boundaries as host (numpy) arrays, not cupy arrays.
    is_accessible : array_like, optional
        Explicit per-base accessibility mask. Takes precedence over any mask
        attached to ``haplotype_matrix``.

    Returns
    -------
    numpy.ndarray
        Accessible bases per window. With no mask at all this is
        ``win_stops - win_starts`` as float64.
    """
    starts = np.asarray(win_starts, dtype=np.float64)
    stops = np.asarray(win_stops, dtype=np.float64)

    if is_accessible is None and not haplotype_matrix.has_accessible_mask:
        # No mask anywhere: every base inside the window counts.
        return stops - starts

    if is_accessible is not None:
        from .accessible import AccessibleMask
        mask = AccessibleMask(np.asarray(is_accessible, dtype=bool))
    else:
        mask = haplotype_matrix.accessible_mask
    return mask.count_accessible_windows(
        starts.astype(np.int64), stops.astype(np.int64))
# Leading columns of every windowed result table, in output order.
CANONICAL_WINDOW_PREFIX = tuple(
    'chrom start end center n_variants window_id'.split())
def _init_window_results(chrom, win_starts_bp, win_stops_bp, n_variants):
    """Build the canonical leading columns shared by every dispatch path.

    Keys are inserted as ``chrom, start, end, center, n_variants,
    window_id``. Dicts keep insertion order, so this fixes the column order
    of the DataFrame later built from the returned dict.

    ``chrom`` falls back to ``1`` when ``None``. Starts and ends are cast to
    int, and ``center`` is their floor midpoint.
    """
    starts = np.asarray(win_starts_bp).astype(int)
    stops = np.asarray(win_stops_bp).astype(int)
    n_windows = len(starts)
    chrom_value = 1 if chrom is None else chrom

    columns = {}
    columns['chrom'] = np.repeat(chrom_value, n_windows)
    columns['start'] = starts
    columns['end'] = stops
    columns['center'] = (starts + stops) // 2
    columns['n_variants'] = np.asarray(n_variants).astype(int)
    columns['window_id'] = np.arange(n_windows, dtype=int)
    return columns
@dataclass
class WindowParams:
    """Parameters defining genomic windows.

    Field order defines the generated ``__init__`` signature; do not reorder.
    """
    window_type: str  # 'bp', 'snp', or 'regions'
    window_size: int  # window length: bp for 'bp', number of variants for 'snp'
    step_size: int  # distance between consecutive window starts, same unit as window_size
    regions: Optional[pd.DataFrame] = None  # required for 'regions'; rows read via 'start', 'end' and optional 'chrom'
@dataclass
class WindowData:
    """Data for a single genomic window.

    Instances are also built positionally elsewhere in this module, so the
    field order is part of the interface.
    """
    chrom: Union[str, int]
    start: int
    end: int  # exclusive for 'bp'/'regions' windows; position of the last variant for 'snp' windows
    center: int  # (start + end) // 2
    matrix: HaplotypeMatrix  # subset of the parent matrix holding only this window's variants
    n_variants: int
    window_id: int
class MemoryManager:
    """Manages GPU memory allocation and chunking strategy."""

    # Unit suffixes, checked longest-first so 'GB' is never read as 'B'.
    _UNITS = (('TB', 1024**4), ('GB', 1024**3), ('MB', 1024**2),
              ('KB', 1024), ('B', 1))

    def __init__(self, gpu_memory_limit: Union[str, int] = 'auto'):
        self.gpu_memory_limit = self._parse_memory_limit(gpu_memory_limit)
        self._gpu_memory_info = {}

    def _parse_memory_limit(self, limit: Union[str, int]) -> int:
        """Parse a memory limit into a number of bytes.

        Parameters
        ----------
        limit : str or int
            ``'auto'`` budgets 80% of the current device's total memory.
            Other strings are a number with an optional, case-insensitive
            unit suffix (``'512MB'``, ``'8 GB'``, ``'1TB'``, ``'1024'``).
            Non-strings are converted with ``int()``.

        Raises
        ------
        ValueError
            If a string limit cannot be parsed.
        """
        if limit == 'auto':
            # CuPy's Device.mem_info is (free, total); use 80% of the total.
            return int(cp.cuda.Device().mem_info[1] * 0.8)
        if not isinstance(limit, str):
            return int(limit)
        text = limit.strip().upper()
        number, multiplier = text, 1
        for unit, factor in self._UNITS:
            if text.endswith(unit):
                number, multiplier = text[:-len(unit)], factor
                break
        try:
            return int(float(number) * multiplier)
        except (ValueError, OverflowError):
            raise ValueError(f"Invalid memory limit format: {limit}") from None

    def estimate_window_memory(self, n_variants: int, n_samples: int,
                               statistics: List[str]) -> int:
        """Estimate the bytes needed to process one window.

        Heuristic: a float32 (n_variants x n_samples) matrix times an
        overhead factor that grows with the number of statistics, plus an
        extra n_variants / 1000 term when any statistic name contains 'ld'.
        """
        matrix_memory = n_variants * n_samples * 4  # float32
        overhead_multiplier = 1.5 + 0.2 * len(statistics)
        # LD statistics need pairwise computations.
        if any('ld' in stat.lower() for stat in statistics):
            overhead_multiplier += n_variants / 1000
        return int(matrix_memory * overhead_multiplier)

    def determine_chunk_size(self, total_variants: int, n_samples: int,
                             window_params: "WindowParams",
                             statistics: List[str]) -> int:
        """Choose how many variants to process per chunk.

        The chunk is sized so that the windows it holds fit in
        ``gpu_memory_limit``. It is clamped to ``total_variants`` but never
        goes below two windows' worth; the lower bound wins, so the result
        can exceed ``total_variants`` for small inputs.
        """
        window_size = window_params.window_size
        step_size = window_params.step_size
        # A zero estimate (no samples or a zero-size window) would otherwise
        # raise ZeroDivisionError below.
        window_memory = max(1, self.estimate_window_memory(
            window_size, n_samples, statistics))
        windows_per_chunk = max(1, self.gpu_memory_limit // window_memory)
        if step_size < window_size:
            # Overlapping windows share all but step_size variants, so N
            # windows cover roughly N * step_size + window_size variants.
            chunk_variants = int(windows_per_chunk * step_size + window_size)
        else:
            chunk_variants = windows_per_chunk * window_size
        chunk_variants = min(chunk_variants, total_variants)
        chunk_variants = max(chunk_variants, window_size * 2)
        return chunk_variants
class StatisticsComputer:
    """Computes population genetics statistics for windows.

    String statistics are resolved against three tables:
    ``SINGLE_POP_STATS`` and ``TWO_POP_STATS``, which are built per instance
    by ``_categorize_statistics`` so their lambdas capture this instance's
    ``missing_data`` / ``span_normalize``, and the class-level ``LD_STATS``.
    Non-string entries are custom callables invoked as
    ``stat(window, **kwargs)``.
    """

    # LD-based statistics. 'mean_r2' takes a required ``max_dist``; it has
    # to be supplied through custom_stat_kwargs['mean_r2'].
    LD_STATS = {
        'ld_decay': lambda w, **kwargs: _compute_mean_r2(w.matrix, **kwargs),
        'mean_r2': lambda w, max_dist: _compute_mean_r2(w.matrix, max_dist),
    }

    def __init__(self, statistics: List[Union[str, Callable]],
                 populations: Optional[List[str]] = None,
                 custom_stat_kwargs: Optional[Dict] = None,
                 ld_bins: Optional[List[int]] = None,
                 missing_data: str = 'include',
                 span_normalize=True):
        """
        Parameters
        ----------
        statistics : list of str or callable
            Statistic names and/or custom callables.
        populations : list of str, optional
            Keys of ``matrix.sample_sets``. Single-population statistics are
            computed per population, two-population statistics per pair.
        custom_stat_kwargs : dict, optional
            Per-statistic keyword arguments keyed by statistic name (or
            ``__name__`` for callables). This dict is never modified.
        ld_bins : list of int, optional
            Default distance bins for 'ld_decay'.
        missing_data : str
            Forwarded to the diversity / divergence functions.
        span_normalize : bool
            Forwarded to the rate statistics (pi, theta_w, dxy, da).

        Raises
        ------
        ValueError
            If a string statistic is not recognised.
        """
        self.statistics = statistics
        self.populations = populations or []
        self.custom_stat_kwargs = custom_stat_kwargs or {}
        self.ld_bins = ld_bins or [0, 1000, 5000, 10000, 50000]
        self.missing_data = missing_data
        self.span_normalize = span_normalize
        self._categorize_statistics()

    def _categorize_statistics(self):
        """Build the statistic tables and sort requested stats by category."""
        sn = self.span_normalize
        self.SINGLE_POP_STATS = {
            'pi': lambda w: diversity.pi(
                w.matrix, span_normalize=sn,
                missing_data=self.missing_data),
            'theta_w': lambda w: diversity.theta_w(
                w.matrix, span_normalize=sn,
                missing_data=self.missing_data),
            'tajimas_d': lambda w: diversity.tajimas_d(w.matrix, missing_data=self.missing_data),
            'n_variants': lambda w: w.n_variants,
            'n_singletons': lambda w: diversity.singleton_count(w.matrix, missing_data=self.missing_data),
            'segregating_sites': lambda w: diversity.segregating_sites(w.matrix, missing_data=self.missing_data),
        }
        self.TWO_POP_STATS = {
            'dxy': lambda w, p1, p2: divergence.dxy(
                w.matrix, p1, p2,
                missing_data=self.missing_data,
                span_normalize=sn),
            'fst': lambda w, p1, p2: divergence.fst(w.matrix, p1, p2, missing_data=self.missing_data),
            'fst_hudson': lambda w, p1, p2: divergence.fst_hudson(w.matrix, p1, p2, missing_data=self.missing_data),
            'fst_wc': lambda w, p1, p2: divergence.fst_weir_cockerham(w.matrix, p1, p2, missing_data=self.missing_data),
            'da': lambda w, p1, p2: divergence.da(w.matrix, p1, p2,
                                                  missing_data=self.missing_data,
                                                  span_normalize=sn),
        }
        self.single_pop_stats = []
        self.two_pop_stats = []
        self.ld_stats = []
        self.custom_stats = []
        for stat in self.statistics:
            if isinstance(stat, str):
                if stat in self.SINGLE_POP_STATS:
                    self.single_pop_stats.append(stat)
                elif stat in self.TWO_POP_STATS:
                    self.two_pop_stats.append(stat)
                elif stat in self.LD_STATS:
                    self.ld_stats.append(stat)
                else:
                    raise ValueError(f"Unknown statistic: {stat}")
            else:
                # Custom callable
                self.custom_stats.append(stat)

    def _result_keys(self) -> List[str]:
        """Statistic column names ``compute`` emits for a non-empty window."""
        keys = []
        for stat in self.single_pop_stats:
            if self.populations:
                keys.extend(f"{stat}_{pop}" for pop in self.populations)
            else:
                keys.append(stat)
        if len(self.populations) >= 2:
            for stat in self.two_pop_stats:
                for i, pop1 in enumerate(self.populations):
                    for pop2 in self.populations[i + 1:]:
                        keys.append(f"{stat}_{pop1}_{pop2}")
        keys.extend(self.ld_stats)
        keys.extend(stat.__name__ for stat in self.custom_stats)
        return keys

    def compute(self, window: "WindowData") -> Dict[str, float]:
        """Compute all requested statistics for one window.

        Returns a dict starting with the window's chrom/start/end/center/
        n_variants/window_id, followed by one entry per statistic (suffixed
        with population names when populations are given). Two-population
        statistics are only computed with at least two populations.
        """
        results = {
            'chrom': window.chrom,
            'start': window.start,
            'end': window.end,
            'center': window.center,
            'n_variants': window.n_variants,
            'window_id': window.window_id,
        }
        if window.n_variants == 0:
            # NaN under exactly the keys a non-empty window would produce so
            # column sets agree across windows; setdefault keeps prefix
            # columns such as the 'n_variants' count intact.
            for key in self._result_keys():
                results.setdefault(key, np.nan)
            return results

        # Single population statistics
        for stat in self.single_pop_stats:
            if self.populations:
                for pop in self.populations:
                    pop_matrix = self._get_population_matrix(window.matrix, pop)
                    val = self.SINGLE_POP_STATS[stat](
                        WindowData(window.chrom, window.start, window.end,
                                   window.center, pop_matrix, pop_matrix.num_variants,
                                   window.window_id)
                    )
                    self._store_result(results, f"{stat}_{pop}", val)
            else:
                val = self.SINGLE_POP_STATS[stat](window)
                self._store_result(results, stat, val)

        # Two population statistics (each unordered pair once)
        if len(self.populations) >= 2:
            for stat in self.two_pop_stats:
                for i, pop1 in enumerate(self.populations):
                    for pop2 in self.populations[i + 1:]:
                        key = f"{stat}_{pop1}_{pop2}"
                        val = self.TWO_POP_STATS[stat](window, pop1, pop2)
                        self._store_result(results, key, val)

        # LD statistics
        for stat in self.ld_stats:
            # Copy so the defaulted 'bins' never leaks into the caller's
            # custom_stat_kwargs.
            kwargs = dict(self.custom_stat_kwargs.get(stat, {}))
            if stat == 'ld_decay' and 'bins' not in kwargs:
                kwargs['bins'] = self.ld_bins
            results[stat] = self.LD_STATS[stat](window, **kwargs)

        # Custom statistics
        for stat in self.custom_stats:
            kwargs = self.custom_stat_kwargs.get(stat.__name__, {})
            results[stat.__name__] = stat(window, **kwargs)
        return results

    @staticmethod
    def _store_result(results: Dict, key: str, val):
        """Store a scalar result into the results dict."""
        results[key] = val

    def _get_population_matrix(self, matrix: "HaplotypeMatrix",
                               pop: str) -> "HaplotypeMatrix":
        """Return a new matrix holding only ``pop``'s haplotypes.

        Positions, chromosome bounds and ``n_total_sites`` are carried over;
        the new matrix has a single sample set ``'all'``.

        Raises
        ------
        ValueError
            If ``pop`` is not a key of ``matrix.sample_sets``.
        """
        if pop not in matrix.sample_sets:
            raise ValueError(f"Population {pop} not found in sample_sets")
        pop_indices = matrix.sample_sets[pop]
        pop_haplotypes = matrix.haplotypes[pop_indices, :]
        return HaplotypeMatrix(
            pop_haplotypes,
            matrix.positions,
            matrix.chrom_start,
            matrix.chrom_end,
            sample_sets={'all': list(range(len(pop_indices)))},
            n_total_sites=matrix.n_total_sites,
        )
class WindowIterator:
    """Iterates over genomic windows of a haplotype matrix.

    'bp' and 'regions' windows that contain no variants are skipped, so for
    those types ``count_windows`` is an upper bound on the windows yielded.
    """

    def __init__(self, haplotype_matrix: "HaplotypeMatrix",
                 window_params: "WindowParams"):
        """
        Raises
        ------
        ValueError
            For 'bp'/'snp' windows whose window_size or step_size is not
            positive (the window loops would not advance or be empty).
        """
        self._parent_mask = haplotype_matrix.accessible_mask
        self.matrix = haplotype_matrix
        self.params = window_params
        self.positions = self.matrix.positions
        # Host copy of the positions for boundary searches. numpy is checked
        # first so host-resident matrices never touch cupy.
        if isinstance(self.positions, np.ndarray):
            self.positions_np = self.positions
        elif isinstance(self.positions, cp.ndarray):
            self.positions_np = self.positions.get()
        else:
            self.positions_np = np.asarray(self.positions)
        if window_params.window_type in ('bp', 'snp') and (
                window_params.window_size <= 0
                or window_params.step_size <= 0):
            raise ValueError(
                "window_size and step_size must be positive for "
                f"'{window_params.window_type}' windows")

    def _attach_window_mask(self, window_matrix, start, end):
        """Set per-window n_total_sites from the parent's accessible mask.

        Callers pass a half-open ``[start, end)`` range. No-op when the
        parent matrix has no accessible mask.
        """
        if self._parent_mask is not None:
            window_matrix.n_total_sites = \
                self._parent_mask.count_accessible(start, end)

    def __iter__(self) -> "Iterator[WindowData]":
        """Iterate over windows based on window type."""
        if self.params.window_type == 'bp':
            return self._iter_bp_windows()
        elif self.params.window_type == 'snp':
            return self._iter_snp_windows()
        elif self.params.window_type == 'regions':
            return self._iter_region_windows()
        else:
            raise ValueError(f"Unknown window type: {self.params.window_type}")

    def _bp_grid(self):
        """Return ``(first_start, last_start)`` bounding the bp window starts.

        Starts are ``first_pos + k * step_size`` while ``<= last_start``. A
        window starting exactly at the last variant is included only when
        that variant would otherwise be in no window (a single variant, or
        ``step_size >= window_size``). With overlapping windows the last
        variant is already covered, so starts stop strictly before it.
        Requires at least one position.
        """
        first = int(self.positions_np[0])
        last = int(self.positions_np[-1])
        if last == first or self.params.step_size >= self.params.window_size:
            return first, last
        return first, last - 1

    def _iter_bp_windows(self) -> "Iterator[WindowData]":
        """Yield fixed-width windows ``[start, start + window_size)`` in bp.

        ``window_id`` is the grid index, so it skips the ids of empty
        windows that are not yielded.
        """
        if len(self.positions_np) == 0:
            return
        start, last_start = self._bp_grid()
        window_id = 0
        while start <= last_start:
            end = start + self.params.window_size
            center = (start + end) // 2
            mask = (self.positions_np >= start) & (self.positions_np < end)
            variant_indices = np.where(mask)[0]
            if len(variant_indices) > 0:
                window_matrix = self.matrix.get_subset(variant_indices)
                # Window bounds for span normalization (chrom_end inclusive).
                window_matrix.chrom_start = start
                window_matrix.chrom_end = end - 1
                self._attach_window_mask(window_matrix, start, end)
                yield WindowData(
                    chrom=1,  # TODO: Handle multiple chromosomes
                    start=start,
                    end=end,
                    center=center,
                    matrix=window_matrix,
                    n_variants=len(variant_indices),
                    window_id=window_id
                )
            window_id += 1
            start += self.params.step_size

    def _iter_snp_windows(self) -> "Iterator[WindowData]":
        """Yield windows of exactly ``window_size`` consecutive variants.

        Trailing variants that do not fill a whole window are not yielded.
        ``start``/``end`` are the first and last variant positions.
        """
        n_variants = len(self.positions_np)
        window_id = 0
        start_idx = 0
        while start_idx + self.params.window_size <= n_variants:
            end_idx = start_idx + self.params.window_size
            window_start = int(self.positions_np[start_idx])
            window_end = int(self.positions_np[end_idx - 1])
            center = (window_start + window_end) // 2
            variant_indices = np.arange(start_idx, end_idx)
            window_matrix = self.matrix.get_subset(variant_indices)
            # Window bounds for span normalization (chrom_end inclusive).
            window_matrix.chrom_start = window_start
            window_matrix.chrom_end = window_end
            self._attach_window_mask(window_matrix, window_start, window_end + 1)
            yield WindowData(
                chrom=1,  # TODO: Handle multiple chromosomes
                start=window_start,
                end=window_end,
                center=center,
                matrix=window_matrix,
                n_variants=len(variant_indices),
                window_id=window_id
            )
            window_id += 1
            start_idx += self.params.step_size

    def _iter_region_windows(self) -> "Iterator[WindowData]":
        """Yield one window per row of ``params.regions`` (``[start, end)``).

        ``window_id`` is the row's index label. Regions without variants
        are skipped.

        Raises
        ------
        ValueError
            If no regions were provided.
        """
        if self.params.regions is None:
            raise ValueError("Regions must be provided for region window type")
        for window_id, region in self.params.regions.iterrows():
            start = region['start']
            end = region['end']
            center = (start + end) // 2
            mask = (self.positions_np >= start) & (self.positions_np < end)
            variant_indices = np.where(mask)[0]
            if len(variant_indices) > 0:
                window_matrix = self.matrix.get_subset(variant_indices)
                # Same convention as bp windows: span normalization must see
                # the region's own bounds (end exclusive -> inclusive end - 1).
                window_matrix.chrom_start = start
                window_matrix.chrom_end = end - 1
                self._attach_window_mask(window_matrix, start, end)
                yield WindowData(
                    chrom=region.get('chrom', 1),
                    start=start,
                    end=end,
                    center=center,
                    matrix=window_matrix,
                    n_variants=len(variant_indices),
                    window_id=window_id
                )

    def count_windows(self) -> Optional[int]:
        """Return the number of windows iteration will visit.

        Exact for 'snp'. For 'bp' and 'regions' this includes windows that
        are skipped because they hold no variants. Returns ``None`` for an
        unrecognised window type.
        """
        window_type = self.params.window_type
        if window_type == 'regions':
            regions = self.params.regions
            return 0 if regions is None else len(regions)
        if window_type not in ('bp', 'snp'):
            return None
        n_variants = len(self.positions_np)
        if window_type == 'snp':
            # Only complete windows are yielded.
            if n_variants < self.params.window_size:
                return 0
            return ((n_variants - self.params.window_size)
                    // self.params.step_size) + 1
        if n_variants == 0:
            return 0
        first, last_start = self._bp_grid()
        return (last_start - first) // self.params.step_size + 1
class WindowedAnalyzer:
    """Main class for windowed analysis of genomic data."""

    def __init__(self,
                 window_type: str = 'bp',
                 window_size: int = 50000,
                 step_size: Optional[int] = None,
                 statistics: Optional[List[Union[str, Callable]]] = None,
                 populations: Optional[List[str]] = None,
                 regions: Optional[pd.DataFrame] = None,
                 ld_max_distance: int = 10000,
                 ld_bins: Optional[List[int]] = None,
                 gpu_memory_limit: Union[str, int] = 'auto',
                 chunk_size: Union[str, int] = 'auto',
                 n_jobs: int = 1,
                 progress_bar: bool = True,
                 custom_stat_kwargs: Optional[Dict] = None,
                 missing_data: str = 'include',
                 span_normalize=True):
        """
        Initialize windowed analyzer.

        Parameters
        ----------
        window_type : str
            Type of windows: 'bp' (base pairs), 'snp' (SNP count), or 'regions'
        window_size : int
            Size of windows (in bp or SNP count)
        step_size : int, optional
            Step between windows. If None (or 0), uses window_size
            (non-overlapping)
        statistics : list, optional
            Statistics to compute. Can be strings or callable functions.
            Defaults to ``['pi']``.
        populations : list, optional
            Population names for population-specific statistics
        regions : DataFrame, optional
            Custom regions for window_type='regions'
        ld_max_distance : int
            Maximum distance for LD calculations; used as the default
            ``max_dist`` for 'mean_r2' unless custom_stat_kwargs sets one.
        ld_bins : list, optional
            Distance bins for LD decay
        gpu_memory_limit : str or int
            GPU memory limit ('auto', '8GB', or bytes)
        chunk_size : str or int
            Chunk size for processing ('auto' or number of variants)
        n_jobs : int
            Number of CPU threads
        progress_bar : bool
            Show progress bar
        custom_stat_kwargs : dict
            Keyword arguments for custom statistics. Not modified.
        missing_data : str
            'include' - Use all sites, calculate from available data per site
            'exclude' - Only use sites with no missing data
        span_normalize : bool
            ``True`` (default): auto-detect best denominator per window.
            ``False``: return raw sums.
        """
        if statistics is None:
            statistics = ['pi']
        self.window_params = WindowParams(
            window_type=window_type,
            window_size=window_size,
            step_size=step_size or window_size,
            regions=regions
        )
        self.statistics = statistics
        self.populations = populations
        self.ld_max_distance = ld_max_distance
        self.ld_bins = ld_bins or [0, 1000, 5000, 10000, 50000]
        self.gpu_memory_limit = gpu_memory_limit
        self.chunk_size = chunk_size
        self.n_jobs = n_jobs
        self.progress_bar = progress_bar
        self.missing_data = missing_data
        self.span_normalize = span_normalize

        # Shallow-copy so the caller's dict is never modified.
        stat_kwargs = dict(custom_stat_kwargs or {})
        if 'mean_r2' in statistics:
            # 'mean_r2' requires max_dist; default it from ld_max_distance
            # while letting an explicit custom value win.
            stat_kwargs['mean_r2'] = {'max_dist': ld_max_distance,
                                      **stat_kwargs.get('mean_r2', {})}

        # Initialize components
        self.memory_manager = MemoryManager(gpu_memory_limit)
        self.stats_computer = StatisticsComputer(
            statistics, populations, stat_kwargs, ld_bins=self.ld_bins,
            missing_data=missing_data, span_normalize=span_normalize
        )

    def _iter_window_results(self, haplotype_matrix):
        """Yield one result dict per window while driving the progress bar."""
        window_iter = WindowIterator(haplotype_matrix, self.window_params)
        total_windows = window_iter.count_windows()
        with tqdm(total=total_windows, disable=not self.progress_bar,
                  desc="Computing windows") as pbar:
            for window in window_iter:
                row = self.stats_computer.compute(window)
                pbar.update(1)
                yield row

    def compute(self, haplotype_matrix: HaplotypeMatrix) -> pd.DataFrame:
        """
        Compute windowed statistics for entire haplotype matrix.

        Parameters
        ----------
        haplotype_matrix : HaplotypeMatrix
            Input haplotype data

        Returns
        -------
        pd.DataFrame
            Results with statistics for each window
        """
        if self.populations and not haplotype_matrix.sample_sets:
            warnings.warn("Populations specified but haplotype_matrix has no sample_sets",
                          stacklevel=2)
        return pd.DataFrame(list(self._iter_window_results(haplotype_matrix)))

    def compute_region(self, haplotype_matrix: HaplotypeMatrix,
                       chrom: Union[str, int],
                       start: int,
                       end: int) -> pd.DataFrame:
        """
        Compute statistics for a specific genomic region.

        Parameters
        ----------
        haplotype_matrix : HaplotypeMatrix
            Input haplotype data
        chrom : str or int
            Chromosome identifier. Currently not used to select data; the
            region is taken from ``haplotype_matrix`` by position only.
        start : int
            Region start position
        end : int
            Region end position

        Returns
        -------
        pd.DataFrame
            Results for windows in the specified region
        """
        region_matrix = haplotype_matrix.get_subset_from_range(start, end)
        return self.compute(region_matrix)

    def compute_streaming(self, haplotype_matrix: HaplotypeMatrix,
                          batch_size: int = 100) -> Iterator[pd.DataFrame]:
        """
        Compute statistics in batches for memory efficiency.

        Parameters
        ----------
        haplotype_matrix : HaplotypeMatrix
            Input haplotype data
        batch_size : int
            Number of windows per batch

        Yields
        ------
        pd.DataFrame
            Batch of results; the final batch may be shorter.
        """
        batch = []
        for row in self._iter_window_results(haplotype_matrix):
            batch.append(row)
            if len(batch) >= batch_size:
                yield pd.DataFrame(batch)
                batch = []
        if batch:
            yield pd.DataFrame(batch)
def _scatter_window_grid(pos_cpu, chrom_start, chrom_end, window_size,
                         step_size):
    """Host (numpy-only) part of ``_build_scatter_indices``.

    Returns ``(win_starts, win_stops, n_windows, n_per_var, k_safe,
    contains)`` with the meanings documented on ``_build_scatter_indices``.
    """
    pos_cpu = np.asarray(pos_cpu)
    win_starts = np.arange(int(chrom_start), int(chrom_end), step_size,
                           dtype=np.float64)
    if len(win_starts) == 0:
        # Degenerate span (chrom_end <= chrom_start, e.g. a single-variant
        # matrix without explicit bounds). With zero windows the clip below
        # would yield index -1 into an empty array, so keep one window
        # anchored at chrom_start instead.
        win_starts = np.array([float(int(chrom_start))])
    win_stops = win_starts + window_size
    n_windows = len(win_starts)
    n_per_var = int(np.ceil(window_size / step_size))
    # Last window whose start is <= the variant, then walk back n_per_var.
    k_hi = np.searchsorted(win_starts, pos_cpu, side='right') - 1
    offsets = np.arange(n_per_var, dtype=np.int64)[None, :]
    k_cand = k_hi[:, None] - offsets
    in_range = (k_cand >= 0) & (k_cand < n_windows)
    k_safe = np.clip(k_cand, 0, n_windows - 1)
    contains = in_range & (pos_cpu[:, None] < win_stops[k_safe])
    return win_starts, win_stops, n_windows, n_per_var, k_safe, contains


def _build_scatter_indices(pos_cpu, chrom_start, chrom_end,
                           window_size, step_size):
    """Build window indices for scatter-add over (possibly overlapping) windows.

    When step_size < window_size, each variant falls inside up to
    n_per_var = ceil(window_size / step_size) consecutive windows. We
    replicate each variant's slot into those candidate windows, then mask
    out candidates that are either out of range or whose window stop is
    <= the variant position (which can happen at the low end when
    window_size is not a multiple of step_size, or at chromosome edges).

    Window starts are ``chrom_start, chrom_start + step_size, ...`` strictly
    below ``chrom_end``; if that grid is empty, a single window starting at
    ``chrom_start`` is used.

    Returns
    -------
    win_starts, win_stops : np.ndarray, float64, shape (n_windows,)
    n_windows : int
    n_per_var : int
    k_safe : np.ndarray, int64, shape (n_variants, n_per_var)
        Clipped candidate window indices per variant, safe for indexing
        into arrays of length n_windows. Entries where `contains` is False
        must be ignored.
    contains : np.ndarray, bool, shape (n_variants, n_per_var)
        True where the candidate window actually contains the variant.
    win_idx_gpu : cp.ndarray, int64, shape (n_variants * n_per_var,)
    mask_gpu : cp.ndarray, bool, shape (n_variants * n_per_var,)
        The 2D arrays are raveled in row-major (C) order so that broadcasting
        `values[:, None]` across n_per_var and raveling yields values aligned
        with win_idx_gpu / mask_gpu.
    """
    (win_starts, win_stops, n_windows, n_per_var,
     k_safe, contains) = _scatter_window_grid(
        pos_cpu, chrom_start, chrom_end, window_size, step_size)
    win_idx_gpu = cp.asarray(k_safe.ravel())
    mask_gpu = cp.asarray(contains.ravel())
    return (win_starts, win_stops, n_windows, n_per_var,
            k_safe, contains, win_idx_gpu, mask_gpu)
def _windowed_thetas_scatter(haplotype_matrix, window_size, step_size,
                             statistics, populations, missing_data,
                             span_normalize, chrom=None):
    """Compute windowed theta estimators via scatter-add on GPU.

    Uses dac_and_n fused kernel + direct vectorized arithmetic + scatter_add
    for per-window accumulation. Handles variable sample sizes per site.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Input data; moved to the GPU in place if its device is 'CPU'.
    window_size, step_size : int
        Window width and stride in bp.
    statistics : iterable of str
        Requested outputs. Recognised: 'pi', 'theta_w', 'theta_h',
        'theta_l', 'segregating_sites', 'singletons', 'max_daf',
        'fay_wu_h', 'tajimas_d', 'normalized_fay_wu_h', 'zeng_e',
        'zeng_dh'.
    populations : list of str or None
        Only the first entry is used; falsy means all samples.
    missing_data : str
        'exclude' drops sites with missing calls before computing.
    span_normalize : bool
        Anything other than ``False`` divides theta-type sums by the
        per-window span (accessible bases when a mask is attached).
    chrom : optional
        Value for the ``chrom`` column.

    Returns
    -------
    pandas.DataFrame
        Canonical window columns followed by the requested statistics, or
        an empty DataFrame when there is nothing to window.
    """
    from ._utils import get_population_matrix
    from cupyx import scatter_add
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    pop = populations[0] if populations else None
    if pop is not None:
        matrix = get_population_matrix(haplotype_matrix, pop)
    else:
        matrix = haplotype_matrix
    if missing_data == 'exclude':
        matrix = matrix.exclude_missing_sites()
        if matrix.num_variants == 0:
            return pd.DataFrame()
    if matrix.num_variants == 0 and (matrix.chrom_start is None
                                     or matrix.chrom_end is None):
        # No variants and no explicit bounds: the window grid is undefined.
        return pd.DataFrame()
    from .diversity import _prepare_dac, _site_contribution
    dac, n_valid, d, n_safe, seg, n_hap = _prepare_dac(matrix)
    pos = matrix.positions
    if isinstance(pos, cp.ndarray):
        pos_cpu = pos.get()
    else:
        pos_cpu = np.asarray(pos)
    chrom_start = (matrix.chrom_start if matrix.chrom_start is not None
                   else int(pos_cpu[0]))
    chrom_end = (matrix.chrom_end if matrix.chrom_end is not None
                 else int(pos_cpu[-1]))
    (win_starts, win_stops, n_windows, n_per_var,
     k_safe, contains, win_idx_gpu, mask_gpu) = _build_scatter_indices(
        pos_cpu, chrom_start, chrom_end, window_size, step_size)

    def scatter_sum(values):
        """Sum per-variant ``values`` into every window containing the variant."""
        out = cp.zeros(n_windows, dtype=cp.float64)
        if n_per_var == 1:
            scatter_add(out, win_idx_gpu, values * mask_gpu)
        else:
            # Replicate each value over its n_per_var candidate windows; the
            # row-major ravel lines up with win_idx_gpu / mask_gpu.
            rep = cp.broadcast_to(values[:, None],
                                  (values.shape[0], n_per_var)).ravel()
            scatter_add(out, win_idx_gpu, rep * mask_gpu)
        return out

    n_variants_per_window = np.bincount(
        k_safe[contains], minlength=n_windows)
    # Windows are reported (and measured) truncated at chrom_end.
    window_stops = np.minimum(win_stops, chrom_end)
    results = _init_window_results(
        chrom, win_starts, window_stops,
        n_variants=n_variants_per_window)

    # Span for normalization
    if span_normalize is not False:
        if haplotype_matrix.has_accessible_mask:
            am = haplotype_matrix.accessible_mask
            spans = am.count_accessible_windows(
                results['start'], results['end'])
        elif matrix.n_total_sites is not None and chrom_end > chrom_start:
            # Share n_total_sites in proportion to each window's truncated
            # length. Using the clipped stop keeps the last window's span
            # from being overstated, and makes non-overlapping spans sum to
            # n_total_sites.
            total_span = chrom_end - chrom_start
            spans = (window_stops - win_starts) * matrix.n_total_sites / total_span
        else:
            spans = window_stops - win_starts
        # Windows with zero span (fully inaccessible) divide to NaN rather
        # than 0 — the per-base rate is undefined, not zero. Clamp for
        # safe division and carry the zero-span mask for post-processing.
        zero_span = spans <= 0
        spans = np.maximum(spans, 1.0)
    else:
        spans = np.ones(n_windows)
        zero_span = np.zeros(n_windows, dtype=bool)

    stats_set = set(statistics)
    # Theta estimators needed directly or as intermediates of composite stats.
    needs = {
        'pi': stats_set & {'pi', 'tajimas_d', 'fay_wu_h', 'normalized_fay_wu_h', 'zeng_dh'},
        'theta_h': stats_set & {'theta_h', 'fay_wu_h', 'normalized_fay_wu_h', 'zeng_dh'},
        'theta_l': stats_set & {'theta_l', 'zeng_e'},
        'watterson': stats_set & {'theta_w', 'tajimas_d', 'zeng_e', 'zeng_dh'},
    }
    # Per-variant contributions via _site_contribution (single source of truth)
    raw = {}
    for est_name, dependents in needs.items():
        if dependents:
            raw[est_name] = scatter_sum(
                _site_contribution(est_name, d, n_safe, seg, n_valid, n_hap, dac=dac))
    if stats_set & {'segregating_sites', 'tajimas_d', 'normalized_fay_wu_h', 'zeng_e', 'zeng_dh'}:
        seg_count = scatter_sum(seg.astype(cp.float64))
    if 'singletons' in stats_set:
        is_sing = seg & ((dac == 1) | (dac == n_valid - 1))
        sing_count = scatter_sum(is_sing.astype(cp.float64))
    if 'max_daf' in stats_set:
        dafs = cp.where(seg, d / n_safe, 0.0).get()
        rep_dafs = np.broadcast_to(dafs[:, None], contains.shape) * contains
        max_daf_arr = np.zeros(n_windows)
        np.maximum.at(max_daf_arr, k_safe.ravel(), rep_dafs.ravel())

    # Build output — theta estimators as per-base rates
    for est_name, out_name in [('pi', 'pi'), ('watterson', 'theta_w'),
                               ('theta_h', 'theta_h'), ('theta_l', 'theta_l')]:
        if out_name in stats_set and est_name in raw:
            results[out_name] = raw[est_name].get() / spans
    if 'segregating_sites' in stats_set:
        results['segregating_sites'] = seg_count.get()
    if 'singletons' in stats_set:
        results['singletons'] = sing_count.get()
    if 'max_daf' in stats_set:
        results['max_daf'] = max_daf_arr

    # Composite stats from raw theta sums
    if 'fay_wu_h' in stats_set:
        results['fay_wu_h'] = (raw['pi'] - raw['theta_h']).get() / spans

    # Neutrality tests — unified Achaz (2009) variance framework
    need_variance = stats_set & {'tajimas_d', 'normalized_fay_wu_h', 'zeng_e', 'zeng_dh'}
    if need_variance:
        from .diversity import _achaz_variance_coefficients
        S = seg_count.get()
        a1 = sum(1.0 / i for i in range(1, n_hap))
        a2 = sum(1.0 / (i ** 2) for i in range(1, n_hap))
        theta_est = S / a1
        theta_sq_est = S * (S - 1) / (a1 ** 2 + a2)

        def windowed_test(w1, w2, numerator_arr):
            """Normalize a theta difference by its Achaz variance (NaN if S < 3)."""
            alpha, beta = _achaz_variance_coefficients(w1, w2, n_hap)
            var = alpha * theta_est + beta * theta_sq_est
            with np.errstate(invalid='ignore', divide='ignore'):
                return np.where((var > 0) & (S >= 3),
                                numerator_arr / np.sqrt(var), np.nan)

        if stats_set & {'tajimas_d', 'zeng_dh'}:
            tajd = windowed_test('pi', 'watterson',
                                 raw['pi'].get() - raw['watterson'].get())
            if 'tajimas_d' in stats_set:
                results['tajimas_d'] = tajd
        if 'normalized_fay_wu_h' in stats_set:
            results['normalized_fay_wu_h'] = windowed_test(
                'pi', 'theta_h', raw['pi'].get() - raw['theta_h'].get())
        if 'zeng_e' in stats_set:
            results['zeng_e'] = windowed_test(
                'theta_l', 'watterson',
                raw['theta_l'].get() - raw['watterson'].get())
        if 'zeng_dh' in stats_set:
            H = (raw['pi'] - raw['theta_h']).get() / spans
            results['zeng_dh'] = np.where(
                (tajd < 0) & (H < 0), tajd * H, 0.0)

    # Windows with no accessible bases get NaN for every per-base rate.
    if zero_span.any():
        for name in ('pi', 'theta_w', 'theta_h', 'theta_l',
                     'fay_wu_h', 'zeng_dh'):
            if name in results:
                results[name] = np.where(zero_span, np.nan, results[name])
    return pd.DataFrame(results)
def _twopop_site_components(hap1, hap2):
    """Compute per-site two-population mean pairwise differences.

    Returns ``(mpd1, mpd2, between)`` where:
        mpd1    = within-pop1 mean pairwise difference per site
        mpd2    = within-pop2 mean pairwise difference per site
        between = between-pop mean pairwise difference per site
    All quantities use per-site valid counts (missing-data aware).

    Parameters
    ----------
    hap1, hap2 : array, shape (n_haplotypes, n_variants)
        Haplotype calls per population over the same variant columns.
        Negative values are treated as missing. The formulas treat valid
        calls as 0/1 allele indicators (assumes biallelic coding — confirm
        for multi-allelic input). cupy inputs are processed on the GPU;
        numpy inputs on the host, without touching cupy.

    Returns
    -------
    tuple of three float64 arrays, each shape (n_variants,)
        A within-pop value is 0.0 where that population has fewer than two
        valid calls; ``between`` is 0.0 where either population has none.

    Notes
    -----
    Counts are reduced over the haplotype axis one block of variant columns
    at a time (about 20 blocks), so only (n_haplotypes, block)
    temporaries are materialized instead of full (n_haplotypes, n_variants)
    float64 intermediates.
    """
    # Array module follows the input so host arrays stay on the host.
    xp = np if isinstance(hap1, np.ndarray) else cp
    n_var = hap1.shape[1]
    ac1 = xp.zeros(n_var, dtype=xp.float64)
    n1 = xp.zeros(n_var, dtype=xp.float64)
    ac2 = xp.zeros(n_var, dtype=xp.float64)
    n2 = xp.zeros(n_var, dtype=xp.float64)
    chunk = max(1, n_var // 20)
    for s in range(0, n_var, chunk):
        e = min(s + chunk, n_var)
        for hap, n_out, ac_out in ((hap1, n1, ac1), (hap2, n2, ac2)):
            cols = hap[:, s:e]
            valid = cols >= 0
            n_out[s:e] = xp.sum(valid, axis=0)
            ac_out[s:e] = xp.sum(xp.where(valid, cols, 0), axis=0)
            del cols, valid

    def within(n, ac):
        """Fraction of differing pairs among the n valid calls of one pop."""
        pairs = n * (n - 1) / 2
        same = ((n - ac) * (n - ac - 1) + ac * (ac - 1)) / 2
        return xp.where(pairs > 0, (pairs - same) / pairs, 0.0)

    # 0/0 at sites without enough valid calls is masked by where(); silence
    # numpy's warning for it (no effect on cupy).
    with np.errstate(divide='ignore', invalid='ignore'):
        mpd1 = within(n1, ac1)
        mpd2 = within(n2, ac2)
        n_between = n1 * n2
        n_between_same = (n1 - ac1) * (n2 - ac2) + ac1 * ac2
        between = xp.where(n_between > 0,
                           (n_between - n_between_same) / n_between, 0.0)
    return mpd1, mpd2, between
def _windowed_twopop_scatter(haplotype_matrix, window_size, step_size,
                             statistics, populations, missing_data,
                             span_normalize, chrom=None):
    """Compute windowed two-population stats via scatter-add on GPU.

    Same pattern as _windowed_thetas_scatter but for fst, dxy, da.
    Uses per-site valid counts for correct missing data handling.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Input data; moved to the GPU in place if its device is 'CPU'.
    window_size, step_size : int
        Window width and stride in bp.
    statistics : iterable of str
        Recognised: 'fst', 'fst_hudson' (same value), 'dxy', 'da'.
    populations : sequence of str
        The first two entries are the compared populations.
    missing_data : str
        'exclude' drops sites with missing calls in either population.
    span_normalize : bool
        Anything other than ``False`` divides dxy/da sums by the window span.
    chrom : optional
        Value for the ``chrom`` column.

    Returns
    -------
    pandas.DataFrame
        Canonical window columns followed by the requested statistics, or
        an empty DataFrame when there is nothing to window.
    """
    from ._utils import get_population_matrix
    from cupyx import scatter_add
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    pop1_name, pop2_name = populations[0], populations[1]
    mat1 = get_population_matrix(haplotype_matrix, pop1_name)
    mat2 = get_population_matrix(haplotype_matrix, pop2_name)
    if missing_data == 'exclude':
        haplotype_matrix = haplotype_matrix.exclude_missing_sites(
            populations=[pop1_name, pop2_name])
        if haplotype_matrix.num_variants == 0:
            return pd.DataFrame()
        mat1 = get_population_matrix(haplotype_matrix, pop1_name)
        mat2 = get_population_matrix(haplotype_matrix, pop2_name)
    if haplotype_matrix.num_variants == 0 and (
            haplotype_matrix.chrom_start is None
            or haplotype_matrix.chrom_end is None):
        # No variants and no explicit bounds: the window grid is undefined.
        return pd.DataFrame()
    hap1 = mat1.haplotypes
    hap2 = mat2.haplotypes
    pos = haplotype_matrix.positions
    if isinstance(pos, cp.ndarray):
        pos_cpu = pos.get()
    else:
        pos_cpu = np.asarray(pos)
    chrom_start = (haplotype_matrix.chrom_start
                   if haplotype_matrix.chrom_start is not None
                   else int(pos_cpu[0]))
    chrom_end = (haplotype_matrix.chrom_end
                 if haplotype_matrix.chrom_end is not None
                 else int(pos_cpu[-1]))
    (win_starts, win_stops, n_windows, n_per_var,
     k_safe, contains, win_idx_gpu, mask_gpu) = _build_scatter_indices(
        pos_cpu, chrom_start, chrom_end, window_size, step_size)

    def scatter_sum(values):
        """Sum per-variant ``values`` into every window containing the variant."""
        out = cp.zeros(n_windows, dtype=cp.float64)
        if n_per_var == 1:
            scatter_add(out, win_idx_gpu, values * mask_gpu)
        else:
            # Replicate each value over its n_per_var candidate windows; the
            # row-major ravel lines up with win_idx_gpu / mask_gpu.
            rep = cp.broadcast_to(values[:, None],
                                  (values.shape[0], n_per_var)).ravel()
            scatter_add(out, win_idx_gpu, rep * mask_gpu)
        return out

    n_variants_per_window = np.bincount(
        k_safe[contains], minlength=n_windows)
    # Windows are reported (and measured) truncated at chrom_end.
    window_stops = np.minimum(win_stops, chrom_end)
    results = _init_window_results(
        chrom, win_starts, window_stops,
        n_variants=n_variants_per_window)

    # Span normalization
    if span_normalize is not False:
        if haplotype_matrix.has_accessible_mask:
            am = haplotype_matrix.accessible_mask
            spans = am.count_accessible_windows(
                results['start'], results['end'])
        elif haplotype_matrix.n_total_sites is not None and chrom_end > chrom_start:
            # Share n_total_sites in proportion to each window's truncated
            # length; the clipped stop keeps the last window's span from
            # being overstated.
            total_span = chrom_end - chrom_start
            spans = (window_stops - win_starts) * haplotype_matrix.n_total_sites / total_span
        else:
            spans = window_stops - win_starts
        # Windows with zero span (fully inaccessible) divide to NaN rather
        # than 0 — the per-base rate is undefined, not zero. Clamp for
        # safe division and carry the zero-span mask for post-processing.
        zero_span = spans <= 0
        spans = np.maximum(spans, 1.0)
    else:
        spans = np.ones(n_windows)
        zero_span = np.zeros(n_windows, dtype=bool)

    # Per-site components (single pass over the data)
    mpd1, mpd2, between = _twopop_site_components(hap1, hap2)
    stats_set = set(statistics)

    # Scatter-add per-site components into windows (deduplicated)
    need_between = stats_set & {'fst', 'fst_hudson', 'dxy', 'da'}
    between_sum = scatter_sum(between) if need_between else None
    if stats_set & {'fst', 'fst_hudson', 'da'}:
        within = (mpd1 + mpd2) / 2.0
        fst_num = scatter_sum(between - within)
    if 'da' in stats_set:
        pi1_sum = scatter_sum(mpd1)
        pi2_sum = scatter_sum(mpd2)

    # Build output — stay on GPU, single .get() per result
    spans_gpu = cp.asarray(spans)
    if stats_set & {'fst', 'fst_hudson'}:
        # Hudson-style ratio of window sums; NaN where no between-pop diversity.
        fst_vals = cp.where(between_sum > 0, fst_num / between_sum, cp.nan).get()
        if 'fst' in stats_set:
            results['fst'] = fst_vals
        if 'fst_hudson' in stats_set:
            results['fst_hudson'] = fst_vals
    if 'dxy' in stats_set:
        results['dxy'] = (between_sum / spans_gpu).get()
    if 'da' in stats_set:
        results['da'] = ((between_sum - (pi1_sum + pi2_sum) / 2.0) / spans_gpu).get()

    # Windows with no accessible bases get NaN for every per-base rate.
    if zero_span.any():
        for name in ('dxy', 'da'):
            if name in results:
                results[name] = np.where(zero_span, np.nan, results[name])
    return pd.DataFrame(results)
# Convenience function for simple usage
def _stream_windowed_analysis(streaming_hm, *, window_size, step_size,
statistics, populations, missing_data,
span_normalize, accessible_bed, chrom,
**kwargs) -> pd.DataFrame:
"""Run windowed_analysis chunk-by-chunk over a StreamingHaplotypeMatrix.
The per-chunk DataFrames are concatenated row-wise. Each chunk's windows
are computed in isolation and the chunk boundaries are aligned so a
window never straddles two chunks; the StreamingHaplotypeMatrix
constructor picked an ``align_bp`` for this reason, so the only
contract we have to enforce here is ``window_size`` divides it.
Sliding windows (``step_size`` != ``window_size``) and the local-PCA
dispatch both need cross-chunk state and are not yet supported on the
streaming path; both raise rather than silently returning wrong
results.
"""
if step_size is None:
step_size = window_size
if step_size != window_size:
raise NotImplementedError(
"sliding windows (step_size != window_size) over a "
"StreamingHaplotypeMatrix would straddle chunk boundaries; "
"supply non-overlapping windows or materialize the region "
"eagerly first."
)
align_bp = streaming_hm.align_bp
if window_size > align_bp or align_bp % window_size != 0:
raise ValueError(
f"window_size={window_size} must divide the streaming matrix's "
f"chunk alignment ({align_bp}); pass a smaller window_size or "
f"re-open the store with a matching chunk_bp."
)
if any(s in ("local_pca", "local_pca_jackknife") for s in statistics):
raise NotImplementedError(
"local_pca requires a chromosome-wide reference and is not "
"available on the StreamingHaplotypeMatrix path; materialize "
"the region eagerly to run it."
)
# Parse the BED mask once up front and inject the resolved
# AccessibleMask into each chunk -- otherwise windowed_analysis would
# re-open and re-parse the BED for every chunk via
# haplotype_matrix.set_accessible_mask.
shared_mask = None
if accessible_bed is not None:
from .accessible import resolve_accessible_mask
shared_mask = resolve_accessible_mask(
accessible_bed, streaming_hm.chrom_start, streaming_hm.chrom_end,
chrom=chrom,
)
parts = []
for left, right, chunk_hm in streaming_hm.iter_gpu_chunks():
if shared_mask is not None:
chunk_hm.accessible_mask = shared_mask
df = windowed_analysis(
chunk_hm,
window_size=window_size, step_size=step_size,
statistics=statistics, populations=populations,
missing_data=missing_data, span_normalize=span_normalize,
accessible_bed=None, chrom=chrom,
**kwargs,
)
if df is not None and len(df):
parts.append(df)
if not parts:
# Empty source (e.g. mappable region contains no variants). Return
# an empty frame; the eager path would have raised, but for
# streaming this is a legitimate "no data in any chunk" outcome.
return pd.DataFrame()
out = pd.concat(parts, ignore_index=True)
# Per-chunk window_id columns restart at 0 inside each chunk; re-number
# so concatenated results carry a globally-unique window_id matching
# what the eager path produces.
if "window_id" in out.columns:
out["window_id"] = np.arange(len(out), dtype=out["window_id"].dtype)
return out
# Public entry point: windowed_analysis
def windowed_analysis(haplotype_matrix: HaplotypeMatrix,
                      window_size: int = 50000,
                      step_size: Optional[int] = None,
                      statistics: Optional[List[str]] = None,
                      populations: Optional[List[str]] = None,
                      missing_data: str = 'include',
                      span_normalize=True,
                      accessible_bed: Optional[str] = None,
                      chrom: Optional[str] = None,
                      **kwargs) -> pd.DataFrame:
    """
    Convenience function for windowed analysis.

    Dispatches, in order, to: the chunked streaming path (for a
    StreamingHaplotypeMatrix), the local-PCA path, the scatter-add paths
    for single/two-population estimators, the fused CUDA kernel path, and
    finally the per-window ``WindowedAnalyzer`` loop.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix or StreamingHaplotypeMatrix
        Input haplotype data.
    window_size : int
        Window size in base pairs.
    step_size : int, optional
        Step size. If None, uses window_size (non-overlapping windows).
    statistics : list of str, optional
        Statistics to compute. Defaults to ``['pi']``.
    populations : list, optional
        Population names. Single-population statistics use the first entry;
        the two-population fast paths require exactly two.
    missing_data : str
        'include' - Use all sites, calculate from available data per site
        'exclude' - Only use sites with no missing data
    span_normalize : bool
        ``True`` (default): auto-detect best denominator per window.
        ``False``: return raw sums.
    accessible_bed : str, optional
        Path to a BED file defining accessible/callable regions.
        If provided and the matrix has no mask, loads the mask.
    chrom : str, optional
        Chromosome label, used for the BED lookup and the ``chrom`` column.
    **kwargs
        Local-PCA options (``k``, ``scaler``, ``population``, ``batch_size``,
        ``window_type``, ``regions``, ``n_blocks``, ``aggregate``) and extra
        arguments for the ``WindowedAnalyzer`` fallback. The scatter-add and
        fused fast paths do not use them.

    Returns
    -------
    pd.DataFrame
        Windowed statistics results. When 'local_pca' or
        'local_pca_jackknife' is requested, the local-PCA result object is
        returned instead, with any scalar statistics merged into its
        ``windows`` table.

    Raises
    ------
    ValueError
        On the fused path, if the matrix has no variants and no explicit
        chrom_start/chrom_end from which to build windows.
    """
    # None sentinel instead of a shared mutable list default.
    if statistics is None:
        statistics = ['pi']

    from .streaming_matrix import StreamingHaplotypeMatrix
    if isinstance(haplotype_matrix, StreamingHaplotypeMatrix):
        return _stream_windowed_analysis(
            haplotype_matrix,
            window_size=window_size, step_size=step_size,
            statistics=statistics, populations=populations,
            missing_data=missing_data, span_normalize=span_normalize,
            accessible_bed=accessible_bed, chrom=chrom,
            **kwargs,
        )

    if accessible_bed is not None and not haplotype_matrix.has_accessible_mask:
        haplotype_matrix.set_accessible_mask(accessible_bed, chrom=chrom)
    if step_size is None:
        step_size = window_size

    # Local PCA dispatch: vector-valued per-window output; cannot live in the
    # scalar-stat DataFrame pipeline. Return a LocalPCAResult (with scalar
    # stats merged into .windows if requested alongside).
    # 'local_pca_jackknife' implies 'local_pca' -- the SE is meaningless
    # without the base eigendecomposition.
    if 'local_pca' in statistics or 'local_pca_jackknife' in statistics:
        from . import decomposition
        want_jackknife = 'local_pca_jackknife' in statistics
        local_pca_kwargs = {
            'k': kwargs.get('k', 2),
            'scaler': kwargs.get('scaler', None),
            'missing_data': missing_data,
            'population': kwargs.get('population', None),
            'batch_size': kwargs.get('batch_size', None),
            'window_size': window_size,
            'step_size': step_size,
            'window_type': kwargs.get('window_type', 'bp'),
            'regions': kwargs.get('regions', None),
        }
        if want_jackknife:
            local_pca_kwargs['n_blocks'] = kwargs.get('n_blocks', 10)
            local_pca_kwargs['aggregate'] = kwargs.get('aggregate', 'mean')
            result = decomposition._local_pca_with_jackknife(
                haplotype_matrix, **local_pca_kwargs)
        else:
            result = decomposition.local_pca(haplotype_matrix, **local_pca_kwargs)
        other_stats = [s for s in statistics
                       if s not in ('local_pca', 'local_pca_jackknife')]
        if other_stats:
            scalar_df = windowed_analysis(
                haplotype_matrix,
                window_size=window_size,
                step_size=step_size,
                statistics=other_stats,
                populations=populations,
                missing_data=missing_data,
                span_normalize=span_normalize,
                accessible_bed=None,  # already applied above
                chrom=chrom,
                **{k: v for k, v in kwargs.items()
                   if k not in _LOCAL_PCA_ONLY_KWARGS},
            )
            # result.windows already carries the canonical prefix; drop it
            # from scalar_df and merge on window_id to avoid duplicates.
            duplicates = [c for c in CANONICAL_WINDOW_PREFIX
                          if c != 'window_id' and c in scalar_df.columns]
            result.windows = result.windows.merge(
                scalar_df.drop(columns=duplicates),
                on='window_id', how='left')
        return result

    requested = set(statistics)

    # Scatter-add path: single-pop theta estimators via dac_and_n + scatter
    scatter_single = {'pi', 'theta_w', 'tajimas_d', 'segregating_sites',
                      'theta_h', 'theta_l', 'fay_wu_h', 'singletons',
                      'normalized_fay_wu_h', 'zeng_e', 'zeng_dh', 'max_daf'}
    scatter_twopop = {'fst', 'fst_hudson', 'dxy', 'da'}
    if missing_data in ('include', 'exclude'):
        # Pure single-pop request
        if requested <= scatter_single and len(populations or []) <= 1:
            result = _windowed_thetas_scatter(
                haplotype_matrix, window_size, step_size,
                statistics, populations, missing_data, span_normalize,
                chrom=chrom)
            if result is not None:
                return result
        # Pure two-pop request
        if requested <= scatter_twopop and len(populations or []) == 2:
            result = _windowed_twopop_scatter(
                haplotype_matrix, window_size, step_size,
                statistics, populations, missing_data, span_normalize,
                chrom=chrom)
            if result is not None:
                return result
        # Mixed single + two-pop request
        single_stats = sorted(requested & scatter_single)
        twopop_stats = sorted(requested & scatter_twopop)
        if (single_stats and twopop_stats
                and requested <= (scatter_single | scatter_twopop)
                and len(populations or []) == 2):
            df1 = _windowed_thetas_scatter(
                haplotype_matrix, window_size, step_size,
                single_stats, [populations[0]], missing_data, span_normalize,
                chrom=chrom)
            df2 = _windowed_twopop_scatter(
                haplotype_matrix, window_size, step_size,
                twopop_stats, populations, missing_data, span_normalize,
                chrom=chrom)
            if df1 is not None and df2 is not None:
                # Both DataFrames share the canonical prefix; copy only the
                # per-stat columns from df2 into df1.
                for col in df2.columns:
                    if col not in CANONICAL_WINDOW_PREFIX:
                        df1[col] = df2[col].values
                return df1

    # Fused CUDA kernel path for more complex stat combinations.
    fused_single = {'pi', 'theta_w', 'tajimas_d', 'segregating_sites',
                    'singletons', 'theta_h', 'fay_wu_h', 'max_daf'}
    fused_two = {'fst', 'fst_hudson', 'fst_wc', 'dxy', 'da'}
    fused_garud = {'garud_h1', 'garud_h12', 'garud_h123', 'garud_h2h1',
                   'haplotype_count'}
    fused_selection = {'mean_nsl'}
    fused_diploshic = {'snp_dist_mean', 'snp_dist_var', 'snp_dist_min',
                       'snp_dist_max', 'mu_var', 'mu_sfs', 'mu_ld',
                       'daf_hist', 'zns', 'omega',
                       'dist_var', 'dist_skew', 'dist_kurt'}
    fused_all = (fused_single | fused_two | fused_garud | fused_selection
                 | fused_diploshic)
    can_fuse = missing_data == 'include' and requested <= fused_all
    if can_fuse:
        if haplotype_matrix.device == 'CPU':
            haplotype_matrix.transfer_to_gpu()
        positions = haplotype_matrix.positions
        if hasattr(positions, 'get'):
            positions = positions.get()
        positions = np.asarray(positions)
        chrom_start = haplotype_matrix.chrom_start
        chrom_end = haplotype_matrix.chrom_end
        if (chrom_start is None or chrom_end is None) and positions.size == 0:
            raise ValueError(
                "cannot infer window bounds: the matrix has no variants and "
                "no chrom_start/chrom_end")
        if chrom_start is None:
            chrom_start = int(positions[0])
        if chrom_end is None:
            chrom_end = int(positions[-1])
        chrom_start = int(chrom_start)
        chrom_end = int(chrom_end)
        # Build window start/stop arrays (supports overlapping windows)
        win_starts = np.arange(chrom_start, chrom_end, step_size,
                               dtype=np.float64)
        if win_starts.size == 0:
            # chrom_end <= chrom_start (e.g. a single variant with inferred
            # bounds): emit one window at chrom_start instead of crashing on
            # win_stops[-1] below.
            win_starts = np.array([chrom_start], dtype=np.float64)
        win_stops = win_starts + window_size
        # Build equivalent bp_bins for _compute_window_ranges
        bp_bins = np.concatenate([win_starts, [win_stops[-1]]])
        pop1 = populations[0] if populations and len(populations) >= 1 else None
        pop2 = populations[1] if populations and len(populations) >= 2 else None
        population = pop1  # single-pop stats use the first population
        # Choose chunked or single-shot fused based on memory: the transposed
        # int8 copy plus the original must fit in ~70% of free device memory.
        n_hap = haplotype_matrix.num_haplotypes
        n_var = haplotype_matrix.num_variants
        transpose_bytes = n_var * n_hap  # int8
        free_mem = cp.cuda.Device().mem_info[0]
        use_chunked = transpose_bytes * 2 > free_mem * 0.7
        fused_fn = (windowed_statistics_fused_chunked if use_chunked
                    else windowed_statistics_fused)
        result_dict = fused_fn(
            haplotype_matrix,
            bp_bins=bp_bins,
            statistics=tuple(statistics),
            population=population,
            pop1=pop1,
            pop2=pop2,
            per_base=(span_normalize is not False),
            _win_starts=win_starts,
            _win_stops=win_stops,
            missing_data=missing_data,
            chrom=chrom,
        )
        return pd.DataFrame(result_dict)

    # Fallback: per-window Python loop
    analyzer = WindowedAnalyzer(
        window_type='bp',
        window_size=window_size,
        step_size=step_size,
        statistics=statistics,
        populations=populations,
        missing_data=missing_data,
        span_normalize=span_normalize,
        **kwargs
    )
    return analyzer.compute(haplotype_matrix)
# Helper functions for built-in statistics
def _compute_mean_r2(matrix: HaplotypeMatrix, max_distance: int,
                     **kwargs) -> float:
    """Compute mean r² for variant pairs within a genomic distance.

    Parameters
    ----------
    matrix : HaplotypeMatrix
        Source matrix. ``pairwise_r2()`` must return a square matrix aligned
        with ``matrix.positions``.
    max_distance : int
        Maximum distance in base pairs (inclusive) between the two variants
        of a pair.
    **kwargs
        Accepted but unused.

    Returns
    -------
    float
        Mean r² over all ordered pairs (i, j) with
        ``0 < |pos_j - pos_i| <= max_distance``, or NaN if no pair qualifies.
    """
    r2_matrix = matrix.pairwise_r2()
    # asarray is a no-op for device arrays and lets host positions work too.
    positions = cp.asarray(matrix.positions)
    # Broadcast |pos_j - pos_i| directly instead of materializing two (n, n)
    # meshgrid copies; the values are identical.
    distances = cp.abs(positions[None, :] - positions[:, None])
    # distance > 0 excludes the diagonal (and pairs at the same position).
    mask = (distances > 0) & (distances <= max_distance)
    if cp.any(mask):
        return float(cp.mean(r2_matrix[mask]).get())
    return np.nan
# ---------------------------------------------------------------------------
# Fused CUDA kernel: one block per window, all stats in one pass
# ---------------------------------------------------------------------------
# Haplotype data is transposed before kernel launch so variants are the
# leading dimension: hap_t is indexed as hap_t[v * n_hap + h], i.e. a
# (n_total_var, n_hap) int8 matrix. Each thread owns whole variants
# (vi += blockDim.x) and walks that variant's n_hap bytes sequentially.
# NOTE(review): an earlier comment claimed this layout gives coalesced reads.
# With one thread per variant, consecutive threads in a warp read addresses
# n_hap bytes apart, so accesses are sequential per thread but not
# warp-coalesced. Worth benchmarking against the untransposed layout.
#
# Kernel contract (fused_windowed_stats_v2):
#   grid   : one block per window (blockIdx.x is the window index).
#   block  : must be a power of two <= 256. Partial results live in a fixed
#            6 * 256 shared array and are tree-reduced by halving
#            blockDim.x. The launcher in windowed_statistics_fused uses 256.
#   win_start / win_stop : per-window variant index range [start, stop).
#   Outputs, one double per window:
#     out_mpd_sum     - sum over sites of 2 p (1 - p) n / (n - 1)
#     out_seg_count   - sites with 0 < dac < n_hap
#     out_sing_count  - sites with dac == 1 or dac == n_hap - 1
#     out_var_count   - number of sites in the window
#     out_theta_h_sum - sum over segregating sites of 2 dac^2 / (n (n - 1))
#     out_max_daf     - max over sites of dac / n_hap (max-reduced, not summed)
# NOTE(review): dac counts only values > 0 and the denominator is always
# n_hap, so negative (missing) calls behave like reference alleles here,
# unlike the two-population kernel, which excludes them. n_hap == 1 divides
# by zero via (dn - 1.0). The in-kernel comment "Sum reduction for 5
# accumulators" is stale: there are 5 sums plus 1 max.
_fused_windowed_kernel_v2 = cp.RawKernel(r'''
extern "C" __global__
void fused_windowed_stats_v2(const signed char* hap_t,
const long long* win_start,
const long long* win_stop,
int n_hap, int n_total_var, int n_windows,
double* out_mpd_sum,
double* out_seg_count,
double* out_sing_count,
double* out_var_count,
double* out_theta_h_sum,
double* out_max_daf) {
int wid = blockIdx.x;
if (wid >= n_windows) return;
int v_start = (int)win_start[wid];
int v_stop = (int)win_stop[wid];
int n_vars = v_stop - v_start;
if (n_vars <= 0) {
if (threadIdx.x == 0) {
out_mpd_sum[wid] = 0.0;
out_seg_count[wid] = 0.0;
out_sing_count[wid] = 0.0;
out_var_count[wid] = 0.0;
out_theta_h_sum[wid] = 0.0;
out_max_daf[wid] = 0.0;
}
return;
}
double dn = (double)n_hap;
double t_mpd = 0.0, t_seg = 0.0, t_sing = 0.0;
double t_count = 0.0, t_theta_h = 0.0, t_max_daf = 0.0;
for (int vi = threadIdx.x; vi < n_vars; vi += blockDim.x) {
int v = v_start + vi;
int dac = 0;
const signed char* row = hap_t + (long long)v * n_hap;
for (int h = 0; h < n_hap; h++) {
if (row[h] > 0) dac++;
}
double p = (double)dac / dn;
t_mpd += 2.0 * p * (1.0 - p) * dn / (dn - 1.0);
int is_seg = (dac > 0 && dac < n_hap) ? 1 : 0;
t_seg += is_seg;
t_sing += (dac == 1 || dac == n_hap - 1) ? 1.0 : 0.0;
t_count += 1.0;
if (is_seg) {
t_theta_h += 2.0 * (double)dac * (double)dac / (dn * (dn - 1.0));
}
if (p > t_max_daf) t_max_daf = p;
}
// Sum reduction for 5 accumulators
__shared__ double smem[6 * 256];
int tid = threadIdx.x;
smem[tid] = t_mpd;
smem[256 + tid] = t_seg;
smem[512 + tid] = t_sing;
smem[768 + tid] = t_count;
smem[1024 + tid] = t_theta_h;
smem[1280 + tid] = t_max_daf; // will be max-reduced
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
smem[256 + tid] += smem[256 + tid + s];
smem[512 + tid] += smem[512 + tid + s];
smem[768 + tid] += smem[768 + tid + s];
smem[1024 + tid] += smem[1024 + tid + s];
// max for max_daf
if (smem[1280 + tid + s] > smem[1280 + tid])
smem[1280 + tid] = smem[1280 + tid + s];
}
__syncthreads();
}
if (tid == 0) {
out_mpd_sum[wid] = smem[0];
out_seg_count[wid] = smem[256];
out_sing_count[wid] = smem[512];
out_var_count[wid] = smem[768];
out_theta_h_sum[wid] = smem[1024];
out_max_daf[wid] = smem[1280];
}
}
''', 'fused_windowed_stats_v2')
# Two-population fused kernel for FST and Dxy.
# Kernel contract (fused_windowed_twopop):
#   grid   : one block per window (blockIdx.x is the window index).
#   block  : must be a power of two <= 256 (fixed 7 * 256 shared reduction
#            buffer, halved each step). The launcher in
#            windowed_statistics_fused uses 256.
#   hap1_t / hap2_t : indexed as [v * n_hap1 + h] / [v * n_hap2 + h], i.e.
#            transposed (n_total_var, n_hap) int8 matrices that share one
#            variant axis. Values < 0 are treated as missing and excluded
#            from the per-site sample sizes. A site is skipped entirely when
#            either population has no valid calls.
#   win_start / win_stop : per-window variant index range [start, stop).
# Per window it accumulates, summed over sites:
#   out_fst_num   - between - (mpd1 + mpd2) / 2   (Hudson numerator)
#   out_fst_den   - between                        (Hudson denominator)
#   out_dxy_sum   - between (the same value as out_fst_den)
#   out_pi1_sum / out_pi2_sum - within-population mean pairwise difference
#   out_wc_a_sum / out_wc_ab_sum - Weir & Cockerham a and a + b components
#                   (haploid form: r = 2, h_bar = 0, per-site sample sizes)
# Callers form ratio-of-sums estimators from these, e.g. fst = num / den.
_fused_windowed_twopop_kernel = cp.RawKernel(r'''
extern "C" __global__
void fused_windowed_twopop(const signed char* hap1_t,
const signed char* hap2_t,
const long long* win_start,
const long long* win_stop,
int n_hap1, int n_hap2,
int n_total_var, int n_windows,
double* out_fst_num,
double* out_fst_den,
double* out_dxy_sum,
double* out_pi1_sum,
double* out_pi2_sum,
double* out_wc_a_sum,
double* out_wc_ab_sum) {
int wid = blockIdx.x;
if (wid >= n_windows) return;
int v_start = (int)win_start[wid];
int v_stop = (int)win_stop[wid];
int n_vars = v_stop - v_start;
if (n_vars <= 0) {
if (threadIdx.x == 0) {
out_fst_num[wid] = 0.0; out_fst_den[wid] = 0.0;
out_dxy_sum[wid] = 0.0; out_pi1_sum[wid] = 0.0;
out_pi2_sum[wid] = 0.0; out_wc_a_sum[wid] = 0.0;
out_wc_ab_sum[wid] = 0.0;
}
return;
}
double t_fst_num = 0.0, t_fst_den = 0.0;
double t_dxy = 0.0, t_pi1 = 0.0, t_pi2 = 0.0;
double t_wc_a = 0.0, t_wc_ab = 0.0;
for (int vi = threadIdx.x; vi < n_vars; vi += blockDim.x) {
int v = v_start + vi;
// Count valid (non-missing) samples and alt alleles per site
int ac1_1 = 0, valid1 = 0;
const signed char* row1 = hap1_t + (long long)v * n_hap1;
for (int h = 0; h < n_hap1; h++) {
signed char a = row1[h];
if (a >= 0) { valid1++; if (a > 0) ac1_1++; }
}
int ac2_1 = 0, valid2 = 0;
const signed char* row2 = hap2_t + (long long)v * n_hap2;
for (int h = 0; h < n_hap2; h++) {
signed char a = row2[h];
if (a >= 0) { valid2++; if (a > 0) ac2_1++; }
}
if (valid1 == 0 || valid2 == 0) continue;
double dn1 = (double)valid1;
double dn2 = (double)valid2;
double ac1_0 = dn1 - ac1_1;
double ac2_0 = dn2 - ac2_1;
double d_ac1_1 = (double)ac1_1;
double d_ac2_1 = (double)ac2_1;
// Hudson: within-pop mean pairwise difference
double n1_pairs = dn1 * (dn1 - 1.0) / 2.0;
double n1_same = (ac1_0 * (ac1_0 - 1.0) + d_ac1_1 * (d_ac1_1 - 1.0)) / 2.0;
double mpd1 = (n1_pairs > 0) ? (n1_pairs - n1_same) / n1_pairs : 0.0;
double n2_pairs = dn2 * (dn2 - 1.0) / 2.0;
double n2_same = (ac2_0 * (ac2_0 - 1.0) + d_ac2_1 * (d_ac2_1 - 1.0)) / 2.0;
double mpd2 = (n2_pairs > 0) ? (n2_pairs - n2_same) / n2_pairs : 0.0;
double within = (mpd1 + mpd2) / 2.0;
// Between-pop mean pairwise difference
double n_between = dn1 * dn2;
double n_between_same = ac1_0 * ac2_0 + d_ac1_1 * d_ac2_1;
double between = (n_between > 0) ? (n_between - n_between_same) / n_between : 0.0;
t_fst_num += between - within;
t_fst_den += between;
t_dxy += between;
t_pi1 += mpd1;
t_pi2 += mpd2;
// Weir-Cockerham (haploid, h_bar=0, r=2, per-site sample sizes)
double n_total = dn1 + dn2;
double n_bar = n_total / 2.0;
double n_C = (n_total - (dn1*dn1 + dn2*dn2) / n_total);
double p1 = d_ac1_1 / dn1;
double p2 = d_ac2_1 / dn2;
double p_bar = (dn1 * p1 + dn2 * p2) / n_total;
double s2 = (dn1 * (p1 - p_bar) * (p1 - p_bar) +
dn2 * (p2 - p_bar) * (p2 - p_bar)) / n_bar;
double pq = p_bar * (1.0 - p_bar);
double a_val = 0.0, b_val = 0.0;
if (n_bar > 1.0 && n_C > 0.0) {
a_val = (n_bar / n_C) * (s2 - (1.0 / (n_bar - 1.0)) * (pq - s2 / 2.0));
b_val = (n_bar / (n_bar - 1.0)) * (pq - s2 / 2.0);
}
t_wc_a += a_val;
t_wc_ab += a_val + b_val;
}
// Block reduction (7 values)
__shared__ double smem[7 * 256];
int tid = threadIdx.x;
smem[tid] = t_fst_num;
smem[256 + tid] = t_fst_den;
smem[512 + tid] = t_dxy;
smem[768 + tid] = t_pi1;
smem[1024 + tid] = t_pi2;
smem[1280 + tid] = t_wc_a;
smem[1536 + tid] = t_wc_ab;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
smem[256 + tid] += smem[256 + tid + s];
smem[512 + tid] += smem[512 + tid + s];
smem[768 + tid] += smem[768 + tid + s];
smem[1024 + tid] += smem[1024 + tid + s];
smem[1280 + tid] += smem[1280 + tid + s];
smem[1536 + tid] += smem[1536 + tid + s];
}
__syncthreads();
}
if (tid == 0) {
out_fst_num[wid] = smem[0];
out_fst_den[wid] = smem[256];
out_dxy_sum[wid] = smem[512];
out_pi1_sum[wid] = smem[768];
out_pi2_sum[wid] = smem[1024];
out_wc_a_sum[wid] = smem[1280];
out_wc_ab_sum[wid]= smem[1536];
}
}
''', 'fused_windowed_twopop')
# Garud's H fused kernel: one block per window, sorts haplotype hashes
# in shared memory to count unique patterns and compute H statistics.
# Kernel contract (fused_garud_h):
#   hash1 / hash2 : per-window, per-haplotype hash pair, indexed as
#                   [wid * n_hap + i] (row-major (n_windows, n_hap)).
#   dynamic shared memory : the launch must request 2 * n_hap doubles
#                   (s_h1 and s_h2 are carved from one extern buffer).
#   blockDim.x    : must be >= ceil(n_hap / 2) so every odd-even
#                   compare-and-swap pair has a thread. The sort runs n_hap
#                   phases, enough for odd-even transposition sort of
#                   n_hap elements.
#   tol           : after sorting, adjacent entries start a new haplotype
#                   class when either hash component differs by more than tol.
# Thread 0 then walks the sorted runs serially. It writes H1 = sum f_i^2,
# H12, H123, H2/H1 (0 when H1 is 0) and the number of distinct haplotypes.
# NOTE(review): wid * n_hap is computed in 32-bit int and could overflow for
# very large n_windows * n_hap. Confirm the bounds the launcher allows.
_fused_garud_h_kernel = cp.RawKernel(r'''
extern "C" __global__
void fused_garud_h(const double* hash1, // (n_windows, n_hap)
const double* hash2, // (n_windows, n_hap)
int n_hap, int n_windows, double tol,
double* out_h1, double* out_h12,
double* out_h123, double* out_h2h1,
double* out_n_distinct) {
int wid = blockIdx.x;
if (wid >= n_windows) return;
int tid = threadIdx.x;
// Load hashes into shared memory for sorting. The launcher picks
// blockDim.x = ceil(n_hap / 2) rounded up to a power of two so the
// odd-even sort has enough threads for every compare-and-swap, but
// that leaves blockDim.x < n_hap for many n_hap values -- a strided
// load is needed to cover all elements.
extern __shared__ double shm[];
double* s_h1 = shm; // n_hap doubles
double* s_h2 = shm + n_hap; // n_hap doubles
for (int i = tid; i < n_hap; i += blockDim.x) {
s_h1[i] = hash1[wid * n_hap + i];
s_h2[i] = hash2[wid * n_hap + i];
}
__syncthreads();
// Simple odd-even sort on hash1 (secondary on hash2 for ties)
// For n_hap <= 256 this is fast in shared memory
for (int phase = 0; phase < n_hap; phase++) {
int i = 2 * tid + (phase & 1);
if (i + 1 < n_hap) {
bool do_swap = false;
if (s_h1[i] > s_h1[i + 1]) {
do_swap = true;
} else if (s_h1[i] == s_h1[i + 1] && s_h2[i] > s_h2[i + 1]) {
do_swap = true;
}
if (do_swap) {
double tmp;
tmp = s_h1[i]; s_h1[i] = s_h1[i+1]; s_h1[i+1] = tmp;
tmp = s_h2[i]; s_h2[i] = s_h2[i+1]; s_h2[i+1] = tmp;
}
}
__syncthreads();
}
// Thread 0: count unique haplotypes, compute frequencies, derive H stats
if (tid == 0) {
// Count distinct haplotypes and collect top-3 frequencies
double inv_n = 1.0 / (double)n_hap;
// Walk sorted array, count runs
// We need: sum(f_i^2), and the top 3 frequencies
double sum_f2 = 0.0;
double top3[3] = {0.0, 0.0, 0.0};
int run_len = 1;
int n_distinct = 0;
for (int i = 1; i <= n_hap; i++) {
bool boundary = (i == n_hap);
if (!boundary) {
double d1 = s_h1[i] - s_h1[i-1];
double d2 = s_h2[i] - s_h2[i-1];
if (d1 < 0) d1 = -d1;
if (d2 < 0) d2 = -d2;
boundary = (d1 > tol) || (d2 > tol);
}
if (boundary) {
n_distinct++;
double f = (double)run_len * inv_n;
sum_f2 += f * f;
if (f > top3[0]) {
top3[2] = top3[1]; top3[1] = top3[0]; top3[0] = f;
} else if (f > top3[1]) {
top3[2] = top3[1]; top3[1] = f;
} else if (f > top3[2]) {
top3[2] = f;
}
run_len = 1;
} else {
run_len++;
}
}
double h1_val = sum_f2;
double h12_val = (top3[0] + top3[1]) * (top3[0] + top3[1])
+ (sum_f2 - top3[0]*top3[0] - top3[1]*top3[1]);
double h123_val = (top3[0] + top3[1] + top3[2]) * (top3[0] + top3[1] + top3[2])
+ (sum_f2 - top3[0]*top3[0] - top3[1]*top3[1] - top3[2]*top3[2]);
double h2_val = h1_val - top3[0] * top3[0];
double h2h1_val = (h1_val > 0.0) ? h2_val / h1_val : 0.0;
out_h1[wid] = h1_val;
out_h12[wid] = h12_val;
out_h123[wid] = h123_val;
out_h2h1[wid] = h2h1_val;
out_n_distinct[wid] = (double)n_distinct;
}
}
''', 'fused_garud_h')
def _compute_window_ranges(positions, bp_bins):
    """Map window edges to variant index ranges using searchsorted.

    Windows are half-open: window ``i`` covers ``[bp_bins[i], bp_bins[i+1])``,
    so a variant exactly on an edge belongs to the window starting there.
    ``positions`` must be sorted ascending, as searchsorted requires.

    Parameters
    ----------
    positions : cupy.ndarray
        Variant positions.
    bp_bins : cupy.ndarray
        ``N + 1`` window edges defining ``N`` windows.

    Returns
    -------
    tuple of cupy.ndarray (int64)
        ``(win_start, win_stop)`` where ``win_start[i]:win_stop[i]`` are the
        variant indices in window ``i``.
    """
    left_edges = bp_bins[:-1]
    right_edges = bp_bins[1:]
    win_start = cp.searchsorted(positions, left_edges, side='left')
    win_stop = cp.searchsorted(positions, right_edges, side='left')
    return win_start.astype(cp.int64), win_stop.astype(cp.int64)
# Public entry point: windowed_statistics_fused
def windowed_statistics_fused(haplotype_matrix: HaplotypeMatrix,
                              bp_bins,
                              statistics=('pi', 'theta_w', 'tajimas_d',
                                          'segregating_sites', 'singletons'),
                              population=None,
                              pop1=None,
                              pop2=None,
                              per_base: bool = True,
                              is_accessible=None,
                              _win_starts=None,
                              _win_stops=None,
                              missing_data='include',
                              chrom=None):
    """GPU-native windowed statistics using fused CUDA kernels.

    One kernel launch processes ALL windows in parallel. Each thread block
    handles one window, with threads cooperatively reducing over variants.
    Reads the haplotype matrix once and computes all statistics simultaneously.

    Windows are half-open ``[start, end)`` in base pairs: a variant at
    position ``end`` belongs to the next window, both for the kernel-based
    statistics and for the per-site binned ones (mean_nsl, daf_hist, mu_sfs).

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data.
    bp_bins : array_like
        Window edges in base pairs. N+1 edges define N windows. Ignored when
        both ``_win_starts`` and ``_win_stops`` are given.
    statistics : tuple of str
        Statistics to compute. Single-pop: 'pi', 'theta_w', 'tajimas_d',
        'segregating_sites', 'singletons', 'theta_h', 'fay_wu_h', 'max_daf'.
        Two-pop: 'fst', 'fst_hudson', 'fst_wc', 'dxy', 'da'. Garud:
        'garud_h1', 'garud_h12', 'garud_h123', 'garud_h2h1',
        'haplotype_count'. Also 'mean_nsl', 'snp_dist_*', 'mu_var',
        'mu_sfs', 'mu_ld', 'daf_hist', 'zns', 'omega', 'dist_var',
        'dist_skew', 'dist_kurt'.
    population : str or list, optional
        Population for single-pop statistics.
    pop1, pop2 : str or list, optional
        Populations for two-pop statistics.
    per_base : bool
        Normalize by window size in base pairs.
    is_accessible : array_like, optional
        Accessibility mask for per-base normalization.
    _win_starts, _win_stops : array_like, optional
        Explicit (possibly overlapping) window starts/stops in bp.
    missing_data : str
        Forwarded to the per-window omega computation.
    chrom : str or int, optional
        Chromosome label emitted in the ``chrom`` output column.

    Returns
    -------
    dict
        Maps column names to numpy arrays of shape (n_windows,). The first
        six columns are always ``chrom, start, end, center, n_variants,
        window_id``, followed by one column per requested statistic.

    Raises
    ------
    ValueError
        If a two-population statistic is requested without pop1 and pop2.
    """
    from ._utils import get_population_matrix

    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    if population is not None:
        matrix = get_population_matrix(haplotype_matrix, population)
    else:
        matrix = haplotype_matrix
    if matrix.device == 'CPU':
        matrix.transfer_to_gpu()

    hap_raw = matrix.haplotypes
    n_hap = hap_raw.shape[0]
    n_total_var = hap_raw.shape[1]
    # Transpose to (n_total_var, n_hap) so each variant's haplotypes are
    # contiguous for the fused kernel's per-variant inner loop.
    hap = cp.ascontiguousarray(hap_raw.T.astype(cp.int8))
    positions = matrix.positions
    if not isinstance(positions, cp.ndarray):
        positions = cp.asarray(positions)

    # Support overlapping windows via explicit start/stop arrays
    if _win_starts is not None and _win_stops is not None:
        ws_gpu = cp.asarray(_win_starts, dtype=cp.float64)
        we_gpu = cp.asarray(_win_stops, dtype=cp.float64)
        n_windows = len(ws_gpu)
        win_start = cp.searchsorted(positions, ws_gpu, side='left').astype(cp.int64)
        win_stop = cp.searchsorted(positions, we_gpu, side='left').astype(cp.int64)
    else:
        bp_bins_gpu = cp.asarray(bp_bins, dtype=cp.float64)
        n_windows = len(bp_bins_gpu) - 1
        win_start, win_stop = _compute_window_ranges(positions, bp_bins_gpu)
        ws_gpu = bp_bins_gpu[:-1]
        we_gpu = bp_bins_gpu[1:]
    ws_cpu = ws_gpu.get() if hasattr(ws_gpu, 'get') else ws_gpu
    we_cpu = we_gpu.get() if hasattr(we_gpu, 'get') else we_gpu

    # Raw variant-index span is the fallback n_variants; when single-pop
    # stats are requested it is overwritten with the kernel's per-window
    # variant count.
    results = _init_window_results(
        chrom, ws_cpu, we_cpu,
        n_variants=(win_stop - win_start).get())

    if per_base:
        window_bases = _compute_window_bases(
            haplotype_matrix, results['start'],
            results['end'], is_accessible)

    # single-pop stats via fused kernel
    single_pop_stats = {'pi', 'theta_w', 'tajimas_d', 'segregating_sites',
                        'singletons', 'theta_h', 'fay_wu_h', 'max_daf'}
    single_pop_requested = any(s in statistics for s in single_pop_stats)
    if single_pop_requested:
        out_mpd = cp.zeros(n_windows, dtype=cp.float64)
        out_seg = cp.zeros(n_windows, dtype=cp.float64)
        out_sing = cp.zeros(n_windows, dtype=cp.float64)
        out_count = cp.zeros(n_windows, dtype=cp.float64)
        out_theta_h = cp.zeros(n_windows, dtype=cp.float64)
        out_max_daf = cp.zeros(n_windows, dtype=cp.float64)
        block = 256
        grid = n_windows
        _fused_windowed_kernel_v2(
            (grid,), (block,),
            (hap, win_start, win_stop,
             np.int32(n_hap), np.int32(n_total_var), np.int32(n_windows),
             out_mpd, out_seg, out_sing, out_count, out_theta_h, out_max_daf))
        mpd_sum = out_mpd.get()
        seg_count = out_seg.get()
        sing_count = out_sing.get()
        var_count = out_count.get()
        theta_h_sum = out_theta_h.get()
        max_daf = out_max_daf.get()
        results['n_variants'] = var_count.astype(int)
        if 'pi' in statistics:
            if per_base:
                results['pi'] = np.where(window_bases > 0,
                                         mpd_sum / window_bases, np.nan)
            else:
                results['pi'] = mpd_sum
        if 'theta_w' in statistics:
            a1 = np.sum(1.0 / np.arange(1, n_hap))
            theta_abs = seg_count / a1
            if per_base:
                results['theta_w'] = np.where(window_bases > 0,
                                              theta_abs / window_bases, np.nan)
            else:
                results['theta_w'] = theta_abs
        if 'segregating_sites' in statistics:
            results['segregating_sites'] = seg_count.astype(int)
        if 'singletons' in statistics:
            results['singletons'] = sing_count.astype(int)
        if 'tajimas_d' in statistics:
            n = n_hap
            a1 = np.sum(1.0 / np.arange(1, n))
            a2 = np.sum(1.0 / np.arange(1, n) ** 2)
            b1 = (n + 1) / (3 * (n - 1))
            b2 = 2 * (n ** 2 + n + 3) / (9 * n * (n - 1))
            c1 = b1 - 1 / a1
            c2 = b2 - (n + 2) / (a1 * n) + a2 / a1 ** 2
            e1 = c1 / a1
            e2 = c2 / (a1 ** 2 + a2)
            S = seg_count
            d_num = mpd_sum - S / a1
            d_var = e1 * S + e2 * S * (S - 1)
            d_std = np.sqrt(np.maximum(d_var, 0))
            tajd = np.where(d_std > 0, d_num / d_std, np.nan)
            tajd[S < 3] = np.nan
            results['tajimas_d'] = tajd
        if 'theta_h' in statistics:
            if per_base:
                results['theta_h'] = np.where(window_bases > 0,
                                              theta_h_sum / window_bases, np.nan)
            else:
                results['theta_h'] = theta_h_sum
        if 'fay_wu_h' in statistics:
            # H = pi - theta_H (absolute, unnormalized)
            results['fay_wu_h'] = mpd_sum - theta_h_sum
        if 'max_daf' in statistics:
            results['max_daf'] = max_daf

    # two-pop stats via fused kernel
    two_pop_stats = {'fst', 'fst_hudson', 'fst_wc', 'dxy', 'da'}
    two_pop_requested = any(s in statistics for s in two_pop_stats)
    if two_pop_requested:
        if pop1 is None or pop2 is None:
            raise ValueError("pop1 and pop2 required for fst/dxy/da")
        # Use the original (unsubsetted) matrix for population lookup
        m1 = get_population_matrix(haplotype_matrix, pop1)
        m2 = get_population_matrix(haplotype_matrix, pop2)
        if m1.device == 'CPU':
            m1.transfer_to_gpu()
        if m2.device == 'CPU':
            m2.transfer_to_gpu()
        n1 = m1.haplotypes.shape[0]
        n2 = m2.haplotypes.shape[0]
        hap1 = cp.ascontiguousarray(m1.haplotypes.T.astype(cp.int8))
        hap2 = cp.ascontiguousarray(m2.haplotypes.T.astype(cp.int8))
        out_fst_num = cp.zeros(n_windows, dtype=cp.float64)
        out_fst_den = cp.zeros(n_windows, dtype=cp.float64)
        out_dxy = cp.zeros(n_windows, dtype=cp.float64)
        out_pi1 = cp.zeros(n_windows, dtype=cp.float64)
        out_pi2 = cp.zeros(n_windows, dtype=cp.float64)
        out_wc_a = cp.zeros(n_windows, dtype=cp.float64)
        out_wc_ab = cp.zeros(n_windows, dtype=cp.float64)
        block = 256
        grid = n_windows
        _fused_windowed_twopop_kernel(
            (grid,), (block,),
            (hap1, hap2, win_start, win_stop,
             np.int32(n1), np.int32(n2),
             np.int32(n_total_var), np.int32(n_windows),
             out_fst_num, out_fst_den, out_dxy, out_pi1, out_pi2,
             out_wc_a, out_wc_ab))
        # Post-process on GPU, single .get() per result
        if 'fst' in statistics or 'fst_hudson' in statistics:
            hudson_fst = cp.where(out_fst_den > 0,
                                  out_fst_num / out_fst_den, cp.nan).get()
            if 'fst' in statistics:
                results['fst'] = hudson_fst
            if 'fst_hudson' in statistics:
                results['fst_hudson'] = hudson_fst
        if 'fst_wc' in statistics:
            results['fst_wc'] = cp.where(out_wc_ab > 0,
                                         out_wc_a / out_wc_ab, cp.nan).get()
        if 'dxy' in statistics:
            if per_base:
                wb = cp.asarray(window_bases)
                results['dxy'] = cp.where(wb > 0,
                                          out_dxy / wb, cp.nan).get()
            else:
                results['dxy'] = out_dxy.get()
        if 'da' in statistics:
            da_sum = out_dxy - (out_pi1 + out_pi2) / 2.0
            if per_base:
                wb = cp.asarray(window_bases)
                results['da'] = cp.where(wb > 0,
                                         da_sum / wb, cp.nan).get()
            else:
                results['da'] = da_sum.get()

    # Garud's H via fused kernel (SNP windows using prefix-sum hashing)
    garud_stats = {'garud_h1', 'garud_h12', 'garud_h123', 'garud_h2h1',
                   'haplotype_count'}
    garud_requested = any(s in statistics for s in garud_stats)
    if garud_requested:
        # `matrix` is already population-subsetted at the top of this
        # function; pass population=None to avoid a double lookup that would
        # fail because the subsetted matrix has sample_sets={'all': ...},
        # not the original population names.
        _compute_fused_garud_h(matrix, None,
                               win_start, win_stop, n_windows, statistics,
                               results)

    # Per-site stats binned into windows via scatter_add. Windows are
    # half-open [start, end), matching win_start/win_stop above, so a site
    # at position p belongs to the first window whose end is > p
    # (side='right'), and only if p is also >= that window's start. The
    # start check drops sites before the first window or in gaps between
    # non-contiguous windows.
    # NOTE: each site is assigned to a single window, so with overlapping
    # windows these per-site stats count a site only in the earliest window
    # that contains it.
    bin_idx = cp.searchsorted(we_gpu, positions, side='right')
    in_range = bin_idx < n_windows
    if n_windows > 0:
        safe_idx = cp.minimum(bin_idx, n_windows - 1)
        in_range &= positions >= ws_gpu[safe_idx]

    # Shared DAC computation (used by daf_hist and mu_sfs)
    dac_gpu = None
    if any(s in statistics for s in ('daf_hist', 'mu_sfs')):
        dac_gpu = cp.sum(cp.maximum(matrix.haplotypes, 0).astype(cp.int32), axis=0)

    if 'mean_nsl' in statistics:
        from . import selection as sel
        # matrix is already population-subsetted above; don't re-subset
        nsl_gpu = cp.asarray(sel.nsl(matrix))
        valid = cp.isfinite(nsl_gpu) & in_range
        results['mean_nsl'] = _windowed_mean(nsl_gpu, bin_idx, valid, n_windows)

    # SNP distance stats per window
    snp_dist_stats = {'snp_dist_mean', 'snp_dist_var', 'snp_dist_min',
                      'snp_dist_max', 'mu_var'}
    if any(s in statistics for s in snp_dist_stats):
        ws_np, we_np = win_start.get(), win_stop.get()
        pos_cpu = positions.get() if hasattr(positions, 'get') else np.asarray(positions)
        sd_mean = np.full(n_windows, np.nan)
        sd_var = np.full(n_windows, np.nan)
        sd_min = np.full(n_windows, np.nan)
        sd_max = np.full(n_windows, np.nan)
        mu_var_arr = np.full(n_windows, np.nan)
        for wi in range(n_windows):
            s, e = int(ws_np[wi]), int(we_np[wi])
            if e - s < 2:
                # Gap statistics need >= 2 SNPs; mu_var still reports a
                # lone SNP.
                if e - s == 1:
                    if per_base:
                        mu_var_arr[wi] = (1.0 / window_bases[wi]
                                          if window_bases[wi] > 0 else np.nan)
                    else:
                        mu_var_arr[wi] = 1.0
                continue
            win_pos = pos_cpu[s:e]
            gaps = np.diff(win_pos).astype(np.float64)
            sd_mean[wi] = np.mean(gaps)
            sd_var[wi] = np.var(gaps)
            sd_min[wi] = np.min(gaps)
            sd_max[wi] = np.max(gaps)
            if per_base:
                mu_var_arr[wi] = (len(win_pos) / window_bases[wi]
                                  if window_bases[wi] > 0 else np.nan)
            else:
                mu_var_arr[wi] = float(len(win_pos))
        for stat, arr in [('snp_dist_mean', sd_mean), ('snp_dist_var', sd_var),
                          ('snp_dist_min', sd_min), ('snp_dist_max', sd_max),
                          ('mu_var', mu_var_arr)]:
            if stat in statistics:
                results[stat] = arr

    # DAF histogram per window (GPU scatter)
    if 'daf_hist' in statistics:
        n_daf_bins = 20
        daf = dac_gpu.astype(cp.float64) / n_hap
        daf_bin = cp.minimum((daf * n_daf_bins).astype(cp.int32), n_daf_bins - 1)
        # Composite index: window * n_daf_bins + daf_bin
        composite = bin_idx * n_daf_bins + daf_bin
        valid_daf = in_range
        flat = _scatter_sum(cp.ones_like(composite[valid_daf], dtype=cp.float64),
                            composite[valid_daf], n_windows * n_daf_bins)
        hist_matrix = flat.get().reshape(n_windows, n_daf_bins)
        for b in range(n_daf_bins):
            results[f'daf_bin_{b}'] = hist_matrix[:, b]

    # muSFS: fraction of SNPs at SFS edges
    if 'mu_sfs' in statistics:
        is_edge = ((dac_gpu == 1) | (dac_gpu == n_hap - 1)).astype(cp.float64)
        edge_sum = _scatter_sum(is_edge[in_range], bin_idx[in_range], n_windows)
        total_count = _bin_counts(bin_idx[in_range], n_windows)
        edge_cpu = edge_sum.get()
        count_cpu = total_count.get()
        mu_sfs = np.where(count_cpu > 0, edge_cpu / count_cpu, np.nan)
        results['mu_sfs'] = mu_sfs

    # Per-window pairwise stats (LD, distance moments)
    ld_pairwise = {'zns', 'omega', 'mu_ld'}
    dist_pairwise = {'dist_var', 'dist_skew', 'dist_kurt'}
    perwin_stats = {s for s in (ld_pairwise | dist_pairwise) if s in statistics}
    if perwin_stats:
        from . import ld_statistics
        from . import distance_stats
        ws_np, we_np = win_start.get(), win_stop.get()
        stat_arrays = {s: np.full(n_windows, np.nan) for s in perwin_stats}
        need_dist = bool(perwin_stats & dist_pairwise)
        need_winmat = ('omega' in stat_arrays or 'mu_ld' in stat_arrays
                       or need_dist)
        # Precompute for fused ZnS path
        if 'zns' in stat_arrays:
            hap_ld = matrix.haplotypes
            hap_clean = cp.where(hap_ld >= 0, hap_ld, 0).astype(cp.float64)
            valid_mask = (hap_ld >= 0).astype(cp.float64)
        for wi in range(n_windows):
            s, e = int(ws_np[wi]), int(we_np[wi])
            if e - s < 4:
                continue
            if 'zns' in stat_arrays:
                # Per-window zns defaults to the unbiased sigma_d2
                # estimator (Ragsdale & Gravel 2019), matching the new
                # default of ld_statistics.zns(estimator='auto') on
                # HaplotypeMatrix inputs.
                stat_arrays['zns'][wi] = ld_statistics._zns_from_precomputed(
                    hap_clean, valid_mask, s, e,
                    use_projection=True)
            if need_winmat:
                win_mat = HaplotypeMatrix(matrix.haplotypes[:, s:e],
                                          matrix.positions[s:e])
                if 'omega' in stat_arrays:
                    stat_arrays['omega'][wi] = ld_statistics.omega(
                        win_mat, missing_data=missing_data)
                if 'mu_ld' in stat_arrays:
                    stat_arrays['mu_ld'][wi] = ld_statistics.mu_ld(win_mat)
                if need_dist:
                    v, sk, ku = distance_stats.dist_moments(win_mat)
                    if 'dist_var' in stat_arrays:
                        stat_arrays['dist_var'][wi] = v
                    if 'dist_skew' in stat_arrays:
                        stat_arrays['dist_skew'][wi] = sk
                    if 'dist_kurt' in stat_arrays:
                        stat_arrays['dist_kurt'][wi] = ku
        results.update(stat_arrays)

    return results
def windowed_statistics_fused_chunked(haplotype_matrix: HaplotypeMatrix,
                                      bp_bins,
                                      statistics=('pi', 'theta_w', 'tajimas_d'),
                                      population=None,
                                      pop1=None,
                                      pop2=None,
                                      per_base: bool = True,
                                      is_accessible=None,
                                      _win_starts=None,
                                      _win_stops=None,
                                      missing_data='include',
                                      chrom=None):
    """Chunked fused windowed statistics for data too large for a single pass.

    Same interface and results as windowed_statistics_fused(), but splits the
    variant axis into memory-safe chunks. Each chunk is transposed and fed to
    the existing fused CUDA kernels. Partial results are accumulated across
    chunks (all kernel outputs are additive sums, except max_daf which uses
    element-wise max).

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Full input matrix; moved to the GPU if it is on the CPU.
    bp_bins : array_like
        Window edges in base pairs (N+1 edges define N windows). Unused
        when both ``_win_starts`` and ``_win_stops`` are supplied.
    statistics : sequence of str
        Statistics to compute. Single-population and two-population fused
        statistics are computed here; Garud H, scatter-add and per-window
        LD/distance statistics are delegated to
        ``windowed_statistics_fused``.
    population : str or list, optional
        Sample subset used for the single-population statistics.
    pop1, pop2 : optional
        Keys of ``haplotype_matrix.sample_sets`` used for the
        two-population statistics.
    per_base : bool
        If True, divide pi, theta_w, theta_h, dxy and da by the accessible
        base count of each window.
    is_accessible : array_like, optional
        Explicit accessibility mask for per-base normalization.
    _win_starts, _win_stops : array_like, optional
        Private: explicit window bounds in base pairs, overriding
        ``bp_bins``.
    missing_data : str
        Forwarded to ``windowed_statistics_fused`` for delegated statistics.
    chrom : str or int, optional
        Label emitted in the ``chrom`` column.

    Returns
    -------
    dict
        Column name -> per-window array (length n_windows).

    Raises
    ------
    ValueError
        If a two-population statistic is requested without both ``pop1``
        and ``pop2``.
    """
    from ._utils import get_population_matrix
    from ._memutil import estimate_fused_chunk_size, free_gpu_pool
    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    if population is not None:
        matrix = get_population_matrix(haplotype_matrix, population)
    else:
        matrix = haplotype_matrix
    if matrix.device == 'CPU':
        matrix.transfer_to_gpu()
    hap_raw = matrix.haplotypes
    n_hap = hap_raw.shape[0]
    n_total_var = hap_raw.shape[1]
    positions = matrix.positions
    if not isinstance(positions, cp.ndarray):
        positions = cp.asarray(positions)
    # Window ranges (same setup as non-chunked)
    if _win_starts is not None and _win_stops is not None:
        ws_gpu = cp.asarray(_win_starts, dtype=cp.float64)
        we_gpu = cp.asarray(_win_stops, dtype=cp.float64)
        n_windows = len(ws_gpu)
        win_start = cp.searchsorted(positions, ws_gpu, side='left').astype(cp.int64)
        win_stop = cp.searchsorted(positions, we_gpu, side='left').astype(cp.int64)
    else:
        bp_bins_gpu = cp.asarray(bp_bins, dtype=cp.float64)
        n_windows = len(bp_bins_gpu) - 1
        win_start, win_stop = _compute_window_ranges(positions, bp_bins_gpu)
        ws_gpu = bp_bins_gpu[:-1]
        we_gpu = bp_bins_gpu[1:]
    ws_cpu = ws_gpu.get() if hasattr(ws_gpu, 'get') else ws_gpu
    we_cpu = we_gpu.get() if hasattr(we_gpu, 'get') else we_gpu
    # Raw variant-index span is the fallback n_variants; single-pop kernels
    # overwrite below with a missing-data-aware count.
    results = _init_window_results(
        chrom, ws_cpu, we_cpu,
        n_variants=(win_stop - win_start).get())
    if per_base:
        window_bases = _compute_window_bases(
            haplotype_matrix, results['start'],
            results['end'], is_accessible)
    # Determine chunk size
    chunk_size = estimate_fused_chunk_size(n_hap)
    # Ensure chunk is at least large enough for the largest window
    max_win_variants = int((win_stop - win_start).max().get()) if n_windows > 0 else 0
    chunk_size = max(chunk_size, max_win_variants + 1)
    # ── Single-pop stats via chunked fused kernel ────────────────────────
    single_pop_stats = {'pi', 'theta_w', 'tajimas_d', 'segregating_sites',
                        'singletons', 'theta_h', 'fay_wu_h', 'max_daf'}
    single_pop_requested = any(s in statistics for s in single_pop_stats)
    if single_pop_requested:
        # Accumulators
        acc_mpd = cp.zeros(n_windows, dtype=cp.float64)
        acc_seg = cp.zeros(n_windows, dtype=cp.float64)
        acc_sing = cp.zeros(n_windows, dtype=cp.float64)
        acc_count = cp.zeros(n_windows, dtype=cp.float64)
        acc_theta_h = cp.zeros(n_windows, dtype=cp.float64)
        acc_max_daf = cp.zeros(n_windows, dtype=cp.float64)
        for c_start in range(0, n_total_var, chunk_size):
            c_end = min(c_start + chunk_size, n_total_var)
            # Find windows overlapping this chunk
            overlap = (win_start < c_end) & (win_stop > c_start)
            w_idx = cp.where(overlap)[0]
            if len(w_idx) == 0:
                continue
            # Clip window ranges to chunk boundaries
            clipped_start = cp.maximum(win_start[w_idx], c_start) - c_start
            clipped_stop = cp.minimum(win_stop[w_idx], c_end) - c_start
            n_overlap = len(w_idx)
            # Transpose chunk
            hap_chunk_t = cp.ascontiguousarray(
                hap_raw[:, c_start:c_end].T.astype(cp.int8))
            n_chunk_var = c_end - c_start
            # Per-chunk outputs
            out_mpd = cp.zeros(n_overlap, dtype=cp.float64)
            out_seg = cp.zeros(n_overlap, dtype=cp.float64)
            out_sing = cp.zeros(n_overlap, dtype=cp.float64)
            out_count = cp.zeros(n_overlap, dtype=cp.float64)
            out_theta_h = cp.zeros(n_overlap, dtype=cp.float64)
            out_max_daf = cp.zeros(n_overlap, dtype=cp.float64)
            _fused_windowed_kernel_v2(
                (int(n_overlap),), (256,),
                (hap_chunk_t, clipped_start, clipped_stop,
                 np.int32(n_hap), np.int32(n_chunk_var), np.int32(n_overlap),
                 out_mpd, out_seg, out_sing, out_count, out_theta_h,
                 out_max_daf))
            # Accumulate (all additive except max_daf)
            cp.add.at(acc_mpd, w_idx, out_mpd)
            cp.add.at(acc_seg, w_idx, out_seg)
            cp.add.at(acc_sing, w_idx, out_sing)
            cp.add.at(acc_count, w_idx, out_count)
            cp.add.at(acc_theta_h, w_idx, out_theta_h)
            acc_max_daf[w_idx] = cp.maximum(acc_max_daf[w_idx], out_max_daf)
            del hap_chunk_t
            free_gpu_pool()
        # Post-processing (identical to non-chunked)
        mpd_sum = acc_mpd.get()
        seg_count = acc_seg.get()
        sing_count = acc_sing.get()
        var_count = acc_count.get()
        theta_h_sum = acc_theta_h.get()
        max_daf_arr = acc_max_daf.get()
        results['n_variants'] = var_count.astype(int)
        if 'pi' in statistics:
            if per_base:
                results['pi'] = np.where(window_bases > 0,
                                         mpd_sum / window_bases, np.nan)
            else:
                results['pi'] = mpd_sum
        if 'theta_w' in statistics:
            a1 = np.sum(1.0 / np.arange(1, n_hap))
            theta_abs = seg_count / a1
            if per_base:
                results['theta_w'] = np.where(window_bases > 0,
                                              theta_abs / window_bases, np.nan)
            else:
                results['theta_w'] = theta_abs
        if 'segregating_sites' in statistics:
            results['segregating_sites'] = seg_count.astype(int)
        if 'singletons' in statistics:
            results['singletons'] = sing_count.astype(int)
        if 'tajimas_d' in statistics:
            n = n_hap
            a1 = np.sum(1.0 / np.arange(1, n))
            a2 = np.sum(1.0 / np.arange(1, n) ** 2)
            b1 = (n + 1) / (3 * (n - 1))
            b2 = 2 * (n ** 2 + n + 3) / (9 * n * (n - 1))
            c1 = b1 - 1 / a1
            c2 = b2 - (n + 2) / (a1 * n) + a2 / a1 ** 2
            e1 = c1 / a1
            e2 = c2 / (a1 ** 2 + a2)
            S = seg_count
            d_num = mpd_sum - S / a1
            d_var = e1 * S + e2 * S * (S - 1)
            d_std = np.sqrt(np.maximum(d_var, 0))
            tajd = np.where(d_std > 0, d_num / d_std, np.nan)
            # Too few segregating sites for a meaningful D
            tajd[S < 3] = np.nan
            results['tajimas_d'] = tajd
        if 'theta_h' in statistics:
            if per_base:
                results['theta_h'] = np.where(window_bases > 0,
                                              theta_h_sum / window_bases, np.nan)
            else:
                results['theta_h'] = theta_h_sum
        if 'fay_wu_h' in statistics:
            results['fay_wu_h'] = mpd_sum - theta_h_sum
        if 'max_daf' in statistics:
            results['max_daf'] = max_daf_arr
    # ── Two-pop stats via chunked fused kernel ───────────────────────────
    two_pop_stats = {'fst', 'fst_hudson', 'fst_wc', 'dxy', 'da'}
    two_pop_requested = any(s in statistics for s in two_pop_stats)
    if two_pop_requested:
        if pop1 is None or pop2 is None:
            raise ValueError("pop1 and pop2 required for fst/dxy/da")
        # Get population haplotype indices from the original (unsubsetted)
        # matrix. These indices address rows of the FULL matrix, so the
        # haplotypes must be sliced from it too: `hap_raw` is the
        # `population` subset when one was given and its rows would no
        # longer line up with pop1_idx / pop2_idx.
        # NOTE(review): the variant-index window ranges above come from
        # the (possibly subset) matrix; this assumes population subsetting
        # keeps the variant axis unchanged -- confirm in get_population_matrix.
        hap_full = haplotype_matrix.haplotypes
        pop1_idx = haplotype_matrix.sample_sets[pop1]
        pop2_idx = haplotype_matrix.sample_sets[pop2]
        n1 = len(pop1_idx)
        n2 = len(pop2_idx)
        # Chunk size based on the larger population
        twopop_chunk = estimate_fused_chunk_size(max(n1, n2))
        twopop_chunk = max(twopop_chunk, max_win_variants + 1)
        acc_fst_num = cp.zeros(n_windows, dtype=cp.float64)
        acc_fst_den = cp.zeros(n_windows, dtype=cp.float64)
        acc_dxy = cp.zeros(n_windows, dtype=cp.float64)
        acc_pi1 = cp.zeros(n_windows, dtype=cp.float64)
        acc_pi2 = cp.zeros(n_windows, dtype=cp.float64)
        acc_wc_a = cp.zeros(n_windows, dtype=cp.float64)
        acc_wc_ab = cp.zeros(n_windows, dtype=cp.float64)
        for c_start in range(0, n_total_var, twopop_chunk):
            c_end = min(c_start + twopop_chunk, n_total_var)
            overlap = (win_start < c_end) & (win_stop > c_start)
            w_idx = cp.where(overlap)[0]
            if len(w_idx) == 0:
                continue
            clipped_start = cp.maximum(win_start[w_idx], c_start) - c_start
            clipped_stop = cp.minimum(win_stop[w_idx], c_end) - c_start
            n_overlap = len(w_idx)
            hap1_t = cp.ascontiguousarray(
                hap_full[pop1_idx, c_start:c_end].T.astype(cp.int8))
            hap2_t = cp.ascontiguousarray(
                hap_full[pop2_idx, c_start:c_end].T.astype(cp.int8))
            n_chunk_var = c_end - c_start
            out_fst_num = cp.zeros(n_overlap, dtype=cp.float64)
            out_fst_den = cp.zeros(n_overlap, dtype=cp.float64)
            out_dxy = cp.zeros(n_overlap, dtype=cp.float64)
            out_pi1 = cp.zeros(n_overlap, dtype=cp.float64)
            out_pi2 = cp.zeros(n_overlap, dtype=cp.float64)
            out_wc_a = cp.zeros(n_overlap, dtype=cp.float64)
            out_wc_ab = cp.zeros(n_overlap, dtype=cp.float64)
            _fused_windowed_twopop_kernel(
                (int(n_overlap),), (256,),
                (hap1_t, hap2_t, clipped_start, clipped_stop,
                 np.int32(n1), np.int32(n2),
                 np.int32(n_chunk_var), np.int32(n_overlap),
                 out_fst_num, out_fst_den, out_dxy, out_pi1, out_pi2,
                 out_wc_a, out_wc_ab))
            cp.add.at(acc_fst_num, w_idx, out_fst_num)
            cp.add.at(acc_fst_den, w_idx, out_fst_den)
            cp.add.at(acc_dxy, w_idx, out_dxy)
            cp.add.at(acc_pi1, w_idx, out_pi1)
            cp.add.at(acc_pi2, w_idx, out_pi2)
            cp.add.at(acc_wc_a, w_idx, out_wc_a)
            cp.add.at(acc_wc_ab, w_idx, out_wc_ab)
            del hap1_t, hap2_t
            free_gpu_pool()
        # Post-process on GPU, single .get() per result
        if 'fst' in statistics or 'fst_hudson' in statistics:
            hudson_fst = cp.where(acc_fst_den > 0,
                                  acc_fst_num / acc_fst_den, cp.nan).get()
            if 'fst' in statistics:
                results['fst'] = hudson_fst
            if 'fst_hudson' in statistics:
                results['fst_hudson'] = hudson_fst
        if 'fst_wc' in statistics:
            results['fst_wc'] = cp.where(acc_wc_ab > 0,
                                         acc_wc_a / acc_wc_ab, cp.nan).get()
        if 'dxy' in statistics:
            if per_base:
                wb = cp.asarray(window_bases)
                results['dxy'] = cp.where(wb > 0,
                                          acc_dxy / wb, cp.nan).get()
            else:
                results['dxy'] = acc_dxy.get()
        if 'da' in statistics:
            da_sum = acc_dxy - (acc_pi1 + acc_pi2) / 2.0
            if per_base:
                wb = cp.asarray(window_bases)
                results['da'] = cp.where(wb > 0,
                                         da_sum / wb, cp.nan).get()
            else:
                results['da'] = da_sum.get()
    # Delegate Garud H / scatter-add stats / per-window LD to the
    # non-chunked function (they already handle large data or operate
    # per-window).
    garud_stats = {'garud_h1', 'garud_h12', 'garud_h123', 'garud_h2h1',
                   'haplotype_count'}
    scatter_stats = {'mean_nsl', 'daf_hist', 'mu_sfs', 'snp_dist_mean',
                     'snp_dist_var', 'snp_dist_min', 'snp_dist_max',
                     'mu_var', 'zns', 'omega', 'mu_ld', 'dist_var',
                     'dist_skew', 'dist_kurt'}
    remaining = set(statistics) & (garud_stats | scatter_stats)
    if remaining:
        # Call the non-chunked function for just these stats;
        # they don't need the full transposed matrix.
        extra = windowed_statistics_fused(
            haplotype_matrix, bp_bins=bp_bins,
            statistics=tuple(remaining),
            population=population, pop1=pop1, pop2=pop2,
            per_base=per_base, is_accessible=is_accessible,
            _win_starts=_win_starts, _win_stops=_win_stops,
            missing_data=missing_data, chrom=chrom)
        # Columns already produced above (window prefix, fused stats) win.
        for k, v in extra.items():
            if k not in results:
                results[k] = v
    return results
def _position_weights(positions, salt):
    """Map each variant position to a deterministic float64 weight in [-1, 1).

    The weight depends only on (position, salt): the position plus the
    64-bit salt is passed through the splitmix64 finalizer, and the low 52
    bits of the result become the mantissa of a float64. Because nothing
    depends on the matrix shape or an RNG seed, the same variant receives
    the same weight whether its window is hashed in one eager pass or
    assembled from several streaming chunks.

    Two different salts give the two independent weight columns used by
    the Garud H hash, which makes collisions between distinct haplotype
    patterns vanishingly rare. Uniform weights suffice: the kernel only
    uses them to separate haplotype patterns by sorted hash value.
    """
    state = positions.astype(cp.uint64) + cp.uint64(salt & 0xFFFFFFFFFFFFFFFF)
    # splitmix64 finalizer: two xorshift-multiply rounds, then a final xorshift.
    for shift, multiplier in ((30, 0xBF58476D1CE4E5B9),
                              (27, 0x94D049BB133111EB)):
        state = (state ^ (state >> cp.uint64(shift))) * cp.uint64(multiplier)
    state = state ^ (state >> cp.uint64(31))
    # Keep 52 random bits as the mantissa with exponent 0x3FF -> value in
    # [1, 2). Shift to [0, 1), then stretch to [-1, 1).
    low_bits = state & cp.uint64(0x000FFFFFFFFFFFFF)
    packed = low_bits | cp.uint64(0x3FF0000000000000)
    unit = packed.view(cp.float64) - 1.0
    return 2.0 * unit - 1.0
# Salts for the two independent Garud H weight columns produced by
# _position_weights (one call per salt).
_GARUD_SALT1 = 0x9E3779B97F4A7C15  # golden-ratio constant (2**64 / phi)
_GARUD_SALT2 = 0xC6BC279692B5C323  # second, distinct salt; NOTE(review): "phi^-2-based" origin unverified
def _windowed_mean(values, bin_idx, valid_mask, n_bins):
    """Per-bin mean of the selected values; bins with no entries become NaN.

    ``valid_mask`` selects which entries of ``values`` / ``bin_idx`` take
    part. Sums and counts are accumulated on the GPU and the division is
    done on the host.
    """
    chosen_bins = bin_idx[valid_mask]
    totals = _scatter_sum(values[valid_mask], chosen_bins, n_bins).get()
    counts = _bin_counts(chosen_bins, n_bins).get()
    return np.where(counts > 0, totals / counts, np.nan)
def _compute_fused_garud_h(haplotype_matrix, population,
                           win_start, win_stop, n_windows, statistics,
                           results):
    """Dispatch windowed Garud's H to the appropriate GPU hashing path.

    * Tiling windows (every ``win_stop[k] == win_start[k+1]``) go to the
      per-window ``reduceat`` path. Each window's hash is summed from its
      own variants in index order, so eager and streaming runs agree
      bit-for-bit and the kernel can bucket haplotypes with a tight
      tolerance.
    * Overlapping (sliding) windows use the prefix-sum trick. It is done
      in a single pass when 30% of free device memory holds four
      (n_hap, n_var) float64 buffers, and window group by window group
      otherwise. Its ULP-scale drift is absorbed by the looser ``tol=1e-3``.

    Results are written into ``results`` in place.
    """
    from ._utils import get_population_matrix
    matrix = (haplotype_matrix if population is None
              else get_population_matrix(haplotype_matrix, population))
    if matrix.device == 'CPU':
        matrix.transfer_to_gpu()
    hap = matrix.haplotypes  # (n_hap, n_var)
    n_hap, n_var = hap.shape
    pos = matrix.positions
    if not isinstance(pos, cp.ndarray):
        pos = cp.asarray(pos)

    def _on_host(bounds):
        # Window bounds may arrive as CuPy or NumPy; compare them on the host.
        if isinstance(bounds, cp.ndarray):
            return bounds.get()
        return np.asarray(bounds)

    starts_host = _on_host(win_start)
    stops_host = _on_host(win_stop)
    tiles = n_windows == 0 or bool(np.all(stops_host[:-1] == starts_host[1:]))
    if tiles:
        _garud_h_per_window_reduceat(hap, pos, win_start, win_stop,
                                     n_hap, n_windows, statistics,
                                     results)
        return

    # Budget for the sliding path: hw1, hw2, cs1, cs2 are all float64,
    # so each variant column costs n_hap * 8 bytes * 4 arrays.
    free_bytes = cp.cuda.Device().mem_info[0]
    bytes_per_variant = n_hap * 8 * 4
    max_span = max(1, int(free_bytes * 0.3) // bytes_per_variant)
    if n_var > max_span:
        _garud_h_chunked(hap, pos, n_hap, n_var, win_start, win_stop,
                         n_windows, max_span, statistics, results)
    else:
        _garud_h_single_pass(hap, pos, n_hap, n_var, win_start, win_stop,
                             n_windows, statistics, results)
def _garud_h_per_window_reduceat(hap, pos, win_start, win_stop, n_hap,
                                 n_windows, statistics, results):
    """Garud H assembling each window's hash by per-window scatter-reduce.

    Cost is O(n_hap * n_var) -- the same as the prefix-sum approach --
    but each window's hash is summed from its own variants in index
    order, so two chunkings covering the same window produce
    bit-identical results. With deterministic hashes the kernel can
    bucket haplotypes by a much tighter tolerance (1e-12) than the
    prefix-sum path's 1e-3.

    Requires non-overlapping tile windows (``win_stop[k] ==
    win_start[k+1]``), so each variant is in at most one window. The
    caller dispatches sliding windows to the prefix-sum path.

    Empty windows (``win_stop[k] == win_start[k]``) get an all-zero hash
    for every haplotype, which is what the prefix-sum path produces
    (``cs[:, e] - cs[:, s] == 0``). ``reduceat`` cannot express an empty
    segment: for ``indices[i] >= indices[i+1]`` it returns the single
    element ``a[indices[i]]``, and it raises when an index equals the
    axis length. Only the non-empty windows are therefore reduced, and the
    results are scattered back into a zero-filled array.

    Parameters
    ----------
    hap : cupy.ndarray, shape (n_hap, n_var)
    pos : cupy.ndarray
        Variant positions used to derive the hash weights.
    win_start, win_stop : numpy or cupy integer array
        Variant-index bounds per window (stop exclusive).
    n_hap, n_windows : int
    statistics : collection of str
    results : dict
        Receives the requested Garud columns in place.
    """
    if n_windows == 0:
        return
    if isinstance(win_start, cp.ndarray):
        ws = win_start.get()
        we = win_stop.get()
    else:
        ws = np.asarray(win_start)
        we = np.asarray(win_stop)

    h1 = cp.zeros((n_hap, n_windows), dtype=cp.float64)
    h2 = cp.zeros((n_hap, n_windows), dtype=cp.float64)

    nonempty = np.flatnonzero(we > ws)
    if nonempty.size > 0:
        # Slice to the variant range covered by any window. Anything left
        # of win_start[0] and right of win_stop[-1] is not in any window
        # and must not contribute. For a tiling, leading and trailing empty
        # windows sit exactly on these boundaries, so v_lo / v_hi equal the
        # first non-empty start and the last non-empty stop.
        v_lo = int(ws[0])
        v_hi = int(we[-1])
        # Weights are a pure function of position, so hashing only the
        # used slice gives the same values as slicing full-length weights.
        pos_slice = pos[v_lo:v_hi]
        w1 = _position_weights(pos_slice, _GARUD_SALT1)
        w2 = _position_weights(pos_slice, _GARUD_SALT2)
        hap_slice = hap[:, v_lo:v_hi].astype(cp.float64)
        hw1 = hap_slice * w1[cp.newaxis, :]
        hw2 = hap_slice * w2[cp.newaxis, :]
        # In a tiling, the empty windows between two non-empty ones have
        # start == stop == the next non-empty start. Each non-empty
        # window's reduceat segment [rel[i], rel[i+1]) therefore ends
        # exactly at its own stop, and the last segment runs to v_hi.
        rel = (ws[nonempty] - v_lo).astype(np.int64)
        cols = cp.asarray(nonempty)
        h1[:, cols] = cp.add.reduceat(hw1, rel, axis=1)
        h2[:, cols] = cp.add.reduceat(hw2, rel, axis=1)

    all_h1 = cp.ascontiguousarray(h1.T)  # kernel wants (n_windows, n_hap)
    all_h2 = cp.ascontiguousarray(h2.T)
    _launch_garud_kernel(all_h1, all_h2, n_hap, n_windows, statistics,
                         results, tol=1e-12)
def _garud_h_single_pass(hap, pos, n_hap, n_var, win_start, win_stop,
                         n_windows, statistics, results):
    """Garud H using whole-matrix prefix sums (the caller checked memory).

    For each salt, the weighted haplotype matrix is cumulatively summed
    along the variant axis, with a leading zero column. Each window's hash
    is then the difference of two prefix columns. The two hash matrices
    are handed to the kernel with the prefix-sum tolerance of 1e-3.
    """
    hap_f = hap.astype(cp.float64)
    window_hashes = []
    for salt in (_GARUD_SALT1, _GARUD_SALT2):
        weighted = hap_f * _position_weights(pos, salt)[cp.newaxis, :]
        prefix = cp.zeros((n_hap, n_var + 1), dtype=cp.float64)
        cp.cumsum(weighted, axis=1, out=prefix[:, 1:])
        per_window = prefix[:, win_stop] - prefix[:, win_start]
        window_hashes.append(cp.ascontiguousarray(per_window.T))
    _launch_garud_kernel(window_hashes[0], window_hashes[1], n_hap,
                         n_windows, statistics, results, tol=1e-3)
def _garud_h_chunked(hap, pos, n_hap, n_var, win_start, win_stop,
                     n_windows, max_span, statistics, results):
    """Garud H processing windows in groups to limit memory.

    Consecutive windows are greedily packed into groups whose variant
    span (first window's start to last window's stop) is at most
    ``max_span``. A single window wider than ``max_span`` still forms a
    group of one. Each group builds prefix sums over just its span and
    launches the kernel once.

    Parameters
    ----------
    hap : cupy.ndarray, shape (n_hap, n_var)
    pos : cupy.ndarray
        Variant positions used to derive the hash weights.
    n_hap, n_var : int
    win_start, win_stop : numpy or cupy integer array
        Variant-index bounds per window (stop exclusive). Presumably
        sorted by start -- the grouping logic assumes it.
    n_windows, max_span : int
    statistics : collection of str
    results : dict
        Receives the requested Garud columns in place.
    """
    from ._memutil import free_gpu_pool
    out_h1 = np.empty(n_windows, dtype=np.float64)
    out_h12 = np.empty(n_windows, dtype=np.float64)
    out_h123 = np.empty(n_windows, dtype=np.float64)
    out_h2h1 = np.empty(n_windows, dtype=np.float64)
    out_n_distinct = np.empty(n_windows, dtype=np.float64)
    # cp.asnumpy accepts both device and host arrays. The dispatcher
    # (_compute_fused_garud_h) accepts either kind of window bounds,
    # and a bare .get() would fail on a NumPy input.
    ws_cpu = cp.asnumpy(win_start)
    we_cpu = cp.asnumpy(win_stop)
    wi = 0
    while wi < n_windows:
        # Find a group of consecutive windows that fit in max_span
        group_var_start = int(ws_cpu[wi])
        group_var_end = int(we_cpu[wi])
        group_end = wi + 1
        while group_end < n_windows:
            candidate_end = int(we_cpu[group_end])
            if candidate_end - group_var_start > max_span:
                break
            group_var_end = candidate_end
            group_end += 1
        span = group_var_end - group_var_start
        n_group = group_end - wi
        # Compute prefix sums over just this variant span. Weights are
        # derived from absolute variant positions so the per-span basis
        # matches the equivalent full-matrix span -- different chunkings
        # produce identical Garud H values on the same window.
        hap_span = hap[:, group_var_start:group_var_end].astype(cp.float64)
        pos_span = pos[group_var_start:group_var_end]
        w1 = _position_weights(pos_span, _GARUD_SALT1)
        w2 = _position_weights(pos_span, _GARUD_SALT2)
        hw1 = hap_span * w1[cp.newaxis, :]
        hw2 = hap_span * w2[cp.newaxis, :]
        cs1 = cp.zeros((n_hap, span + 1), dtype=cp.float64)
        cs2 = cp.zeros((n_hap, span + 1), dtype=cp.float64)
        cp.cumsum(hw1, axis=1, out=cs1[:, 1:])
        cp.cumsum(hw2, axis=1, out=cs2[:, 1:])
        # Local window indices relative to span start
        local_ws = cp.asarray(ws_cpu[wi:group_end] - group_var_start)
        local_we = cp.asarray(we_cpu[wi:group_end] - group_var_start)
        all_h1 = (cs1[:, local_we] - cs1[:, local_ws]).T
        all_h2 = (cs2[:, local_we] - cs2[:, local_ws]).T
        all_h1 = cp.ascontiguousarray(all_h1)
        all_h2 = cp.ascontiguousarray(all_h2)
        # Launch kernel for this group
        grp_results = {}
        _launch_garud_kernel(all_h1, all_h2, n_hap, n_group,
                             statistics, grp_results, tol=1e-3)
        # Store group results
        for stat_name, out_arr in [('garud_h1', out_h1), ('garud_h12', out_h12),
                                   ('garud_h123', out_h123), ('garud_h2h1', out_h2h1),
                                   ('haplotype_count', out_n_distinct)]:
            if stat_name in grp_results:
                out_arr[wi:group_end] = grp_results[stat_name]
        # Release this group's buffers before the next group allocates.
        del hap_span, hw1, hw2, cs1, cs2, all_h1, all_h2
        free_gpu_pool()
        wi = group_end
    if 'garud_h1' in statistics:
        results['garud_h1'] = out_h1
    if 'garud_h12' in statistics:
        results['garud_h12'] = out_h12
    if 'garud_h123' in statistics:
        results['garud_h123'] = out_h123
    if 'garud_h2h1' in statistics:
        results['garud_h2h1'] = out_h2h1
    if 'haplotype_count' in statistics:
        results['haplotype_count'] = out_n_distinct.astype(int)
def _launch_garud_kernel(all_h1, all_h2, n_hap, n_windows, statistics,
                         results, *, tol):
    """Run the fused Garud H kernel and copy requested outputs to ``results``.

    ``all_h1`` / ``all_h2`` hold the per-window haplotype hashes. ``tol``
    is the float64 distance below which two sorted hashes count as the
    same haplotype. The scatter-reduce path passes 1e-12 because its hashes
    are bit-identical for equal haplotypes. The prefix-sum paths pass 1e-3
    to absorb the ULP-scale drift of the cumsum-then-subtract trick.

    One block per window is launched. The block size is roughly n_hap / 2
    threads, rounded up to a power of two and capped at 1024. Dynamic
    shared memory is sized at 2 * n_hap float64 values.
    """
    # Dict order fixes the kernel's output-argument order.
    outputs = {
        name: cp.empty(n_windows, dtype=cp.float64)
        for name in ('garud_h1', 'garud_h12', 'garud_h123',
                     'garud_h2h1', 'haplotype_count')
    }
    half = max(1, (n_hap + 1) // 2)
    threads = min(1 << (half - 1).bit_length(), 1024)
    _fused_garud_h_kernel(
        (n_windows,), (threads,),
        (all_h1, all_h2, np.int32(n_hap), np.int32(n_windows),
         np.float64(tol), *outputs.values()),
        shared_mem=2 * n_hap * 8)
    for name, device_arr in outputs.items():
        if name not in statistics:
            continue
        host_arr = device_arr.get()
        # The distinct-haplotype count is reported as an integer column.
        results[name] = (host_arr.astype(int) if name == 'haplotype_count'
                         else host_arr)
# ---------------------------------------------------------------------------
# GPU-native windowed statistics: compute once, bin everywhere
# ---------------------------------------------------------------------------
def _scatter_sum(values, bin_idx, n_bins):
    """Add each value into its bin on the GPU, returning float64 bin totals.

    Entries whose bin index is negative or >= ``n_bins`` are ignored.
    """
    in_bounds = (bin_idx >= 0) & (bin_idx < n_bins)
    totals = cp.zeros(n_bins, dtype=cp.float64)
    cp.add.at(totals, bin_idx[in_bounds], values[in_bounds])
    return totals
def _bin_counts(bin_idx, n_bins):
    """Number of entries per bin (length ``n_bins``), ignoring out-of-range indices."""
    kept = bin_idx[(bin_idx >= 0) & (bin_idx < n_bins)]
    return cp.bincount(kept, minlength=n_bins)
def _allele_sum_and_n(hap, has_missing=None):
    """Sum of alleles and valid count per variant, skipping missing (-1).

    Parameters
    ----------
    hap : cupy.ndarray
        Haplotype matrix, shape (n_haplotypes, n_variants); negative
        entries mark missing calls.
    has_missing : bool, optional
        If known, skip the check. If None, checks for negatives. An empty
        matrix is treated as having no missing calls, because a min-reduction
        over a zero-size array would raise.

    Returns
    -------
    dac : cupy.ndarray — derived allele count per variant
    n_valid : cupy.ndarray or int — valid haplotypes per variant (a plain
        int equal to the number of rows when there is no missing data)
    """
    if has_missing is None:
        has_missing = hap.size > 0 and bool(int(cp.min(hap)) < 0)
    if has_missing:
        valid_mask = hap >= 0
        # Zero out missing calls so they do not subtract from the count.
        dac = cp.sum(cp.where(valid_mask, hap, 0), axis=0)
        n_valid = cp.sum(valid_mask, axis=0)
    else:
        dac = cp.sum(hap, axis=0)
        n_valid = hap.shape[0]
    return dac, n_valid
def _per_variant_mpd(hap, n_hap):
    """Mean pairwise difference per variant (GPU).

    Computes ``2 p (1 - p) n / (n - 1)`` per variant, where ``n`` is the
    number of called haplotypes and ``p`` the derived-allele frequency
    among them. Variants with ``n <= 1`` get 0.

    ``n_hap`` is unused and kept for interface compatibility.
    """
    dac, n_valid = _allele_sum_and_n(hap)
    dac = dac.astype(cp.float64)
    if isinstance(n_valid, int):
        # No missing data: every variant has the same call count. Broadcast
        # it to a device array so that boolean masking below works. A NumPy
        # scalar indexed by a scalar bool yields a host ndarray, and CuPy
        # refuses to combine that with device arrays.
        n = cp.full(dac.shape, float(n_valid), dtype=cp.float64)
    else:
        n = n_valid.astype(cp.float64)
    usable = n > 1
    p = cp.where(usable, dac / n, 0.0)
    mpd = cp.zeros_like(dac)
    mpd[usable] = 2.0 * p[usable] * (1.0 - p[usable]) * n[usable] / (n[usable] - 1)
    return mpd
def _per_variant_is_seg(hap, n_hap_int):
    """Boolean mask of segregating variants (GPU).

    A variant segregates when both alleles are present among the called
    haplotypes: the derived count lies strictly between 0 and the number
    of valid calls. ``n_hap_int`` is unused.
    """
    derived, called = _allele_sum_and_n(hap)
    below_all = derived < called
    above_none = derived > 0
    return above_none & below_all
def _per_variant_is_singleton(hap, n_hap_int):
    """Boolean: is variant a singleton (GPU).

    A variant is a singleton when exactly one called haplotype carries
    either allele (derived count 1 or n_valid - 1). Variants with fewer
    than two valid calls cannot segregate and are never singletons. For
    n_valid >= 2 this is the same test as before. ``n_hap_int`` is unused.
    """
    dac, n_valid = _allele_sum_and_n(hap)
    at_edge = (dac == 1) | (dac == n_valid - 1)
    # Without this guard, n_valid == 1 makes every site a "singleton"
    # (dac is 0 or 1, both matching one of the edge conditions).
    return at_edge & (n_valid > 1)
def _per_variant_fst_hudson_components(hap1, hap2, n1, n2):
    """Per-variant Hudson FST numerator and denominator (GPU).

    The numerator is between-population diversity minus the mean of the
    two within-population diversities. The denominator is the
    between-population diversity. Missing calls (-1) are handled by the
    per-site valid counts inside ``_twopop_site_components``. ``n1`` and
    ``n2`` are unused.

    Returns
    -------
    (num, den) : tuple of cupy.ndarray
    """
    within_a, within_b, between = _twopop_site_components(hap1, hap2)
    mean_within = (within_a + within_b) / 2.0
    numerator = between - mean_within
    return numerator, between
def _per_variant_dxy(hap1, hap2, n1, n2):
    """Per-variant between-population mean pairwise difference, dxy (GPU).

    Returns the third component of ``_twopop_site_components``.
    ``n1`` and ``n2`` are unused.
    """
    _within_a, _within_b, dxy = _twopop_site_components(hap1, hap2)
    return dxy
def windowed_statistics(haplotype_matrix: HaplotypeMatrix,
                        bp_bins,
                        statistics=('pi', 'theta_w', 'tajimas_d'),
                        population=None,
                        pop1=None,
                        pop2=None,
                        per_base: bool = True,
                        is_accessible=None,
                        chrom=None):
    """GPU-native windowed statistics with no Python loop over windows.

    Computes per-variant values once, then aggregates them into windows
    using GPU scatter-add operations. This is much faster than per-window
    computation when there are many windows.

    Parameters
    ----------
    haplotype_matrix : HaplotypeMatrix
        Haplotype data. If it is on the CPU, it is moved to the GPU in place.
    bp_bins : array_like
        Window edges in base pairs, 1-D and non-decreasing. N+1 edges define
        N windows. Window ``i`` holds variants with
        ``bp_bins[i] <= position < bp_bins[i + 1]``. Variants outside
        ``[bp_bins[0], bp_bins[-1])`` are not assigned to any window.
    statistics : str or sequence of str
        Statistic(s) to compute. A single name may be given as a plain
        string. Supported:
        Single-population: 'pi', 'theta_w', 'tajimas_d', 'segregating_sites',
        'singletons', 'het_expected'
        Two-population: 'fst', 'dxy'
        Unsupported names are ignored with a ``UserWarning``.
    population : str or list, optional
        Population for single-pop statistics. The default is all haplotypes.
    pop1, pop2 : str or list, optional
        Populations for two-pop statistics (fst, dxy).
    per_base : bool
        If True, 'pi', 'theta_w' and 'dxy' are divided by the number of
        (accessible) bases in each window. Windows with zero bases get NaN.
    is_accessible : array_like, optional
        Boolean accessibility mask for per-base normalization.
    chrom : str or int, optional
        Chromosome label emitted in the ``chrom`` output column.

    Returns
    -------
    dict
        Maps column names to numpy arrays of shape (n_windows,). The first
        six columns are ``chrom, start, end, center, n_variants, window_id``,
        followed by one column per requested (supported) statistic.
        'het_expected' is the mean of 2p(1-p) over the window's variants.
        'tajimas_d' is NaN for windows with fewer than 3 segregating sites.
        'fst' is the ratio of summed Hudson numerators to summed
        denominators.

    Raises
    ------
    ValueError
        If ``bp_bins`` is not 1-D, has fewer than two edges, or is
        decreasing, or if 'fst'/'dxy' is requested without both ``pop1``
        and ``pop2``.
    """
    from ._utils import get_population_matrix

    single_pop_names = ('pi', 'theta_w', 'tajimas_d', 'segregating_sites',
                        'singletons', 'het_expected')
    two_pop_names = ('fst', 'dxy')

    # A bare string would otherwise be iterated character by character
    # ('fst' -> 'f', 's', 't'), so the statistic would be skipped silently.
    if isinstance(statistics, str):
        statistics = (statistics,)
    statistics = tuple(statistics)
    unsupported = [s for s in statistics
                   if s not in single_pop_names and s not in two_pop_names]
    if unsupported:
        warnings.warn(
            f"windowed_statistics: ignoring unsupported statistics "
            f"{unsupported}", stacklevel=2)
    want_single = any(s in single_pop_names for s in statistics)
    want_two = any(s in two_pop_names for s in statistics)
    # Fail fast, before any GPU transfer or computation.
    if want_two and (pop1 is None or pop2 is None):
        raise ValueError("pop1 and pop2 required for fst/dxy")

    if haplotype_matrix.device == 'CPU':
        haplotype_matrix.transfer_to_gpu()
    # The population subset drives window assignment and single-pop stats.
    if population is not None:
        matrix = get_population_matrix(haplotype_matrix, population)
    else:
        matrix = haplotype_matrix
    if matrix.device == 'CPU':
        matrix.transfer_to_gpu()

    # --- validate window edges ---
    bp_bins = cp.asarray(bp_bins, dtype=cp.float64)
    if bp_bins.ndim != 1 or bp_bins.size < 2:
        raise ValueError(
            "bp_bins must be a 1-D sequence of at least two window edges")
    # searchsorted below is only meaningful on sorted edges.
    if bool(cp.any(bp_bins[1:] < bp_bins[:-1])):
        raise ValueError("bp_bins must be non-decreasing")
    n_windows = bp_bins.size - 1

    # --- assign variants to windows (GPU parallel binary search) ---
    positions = matrix.positions
    if not isinstance(positions, cp.ndarray):
        positions = cp.asarray(positions)
    bin_idx = cp.searchsorted(bp_bins, positions,
                              side='right').astype(cp.int64) - 1
    bin_idx = cp.clip(bin_idx, 0, n_windows - 1)
    # Out-of-range variants get index -1. Assumption: _bin_counts and
    # _scatter_sum skip negative indices (TODO confirm).
    out_of_range = (positions < bp_bins[0]) | (positions >= bp_bins[-1])
    bin_idx[out_of_range] = -1

    variant_counts = _bin_counts(bin_idx, n_windows)
    results = _init_window_results(
        chrom, bp_bins[:-1].get(), bp_bins[1:].get(),
        n_variants=variant_counts.get())

    window_bases = None
    if per_base:
        window_bases = cp.asarray(
            _compute_window_bases(haplotype_matrix, results['start'],
                                  results['end'], is_accessible),
            dtype=cp.float64)

    def to_host(totals):
        """Copy window totals to the host, dividing by bases if per_base."""
        if not per_base:
            return totals.get()
        return cp.where(window_bases > 0, totals / window_bases,
                        cp.nan).get()

    # --- single-population statistics ---
    if want_single:
        hap = matrix.haplotypes  # (n_haplotypes, n_variants)
        n_hap = int(hap.shape[0])
        # Detect missing calls (-1) once. An empty matrix has none, and
        # cp.min would raise on a zero-size array.
        has_missing = hap.size > 0 and bool(int(cp.min(hap)) < 0)
        dac, n_valid = _allele_sum_and_n(hap, has_missing=has_missing)
        dac = dac.astype(cp.float64)
        if isinstance(n_valid, int):
            # No missing data: one call count shared by every site.
            n_v = cp.float64(n_valid)
            usable = cp.ones(dac.shape, dtype=bool)
        else:
            n_v = n_valid.astype(cp.float64)
            usable = n_v > 1
        p = cp.where(usable, dac / n_v, 0.0)
        is_seg = (dac > 0) & (dac < n_v)

        pi_sum = None
        if 'pi' in statistics or 'tajimas_d' in statistics:
            # Unbiased per-site diversity 2p(1-p) * n / (n - 1). It is zero
            # at sites with fewer than two valid calls.
            denom = cp.where(usable, n_v - 1.0, 1.0)
            mpd = cp.where(usable, 2.0 * p * (1.0 - p) * n_v / denom, 0.0)
            pi_sum = _scatter_sum(mpd, bin_idx, n_windows)

        seg_counts = None
        if any(s in statistics
               for s in ('theta_w', 'tajimas_d', 'segregating_sites')):
            seg_counts = _scatter_sum(is_seg.astype(cp.float64), bin_idx,
                                      n_windows)

        # Harmonic number a1 over the full sample size (empty sum -> 0).
        a1 = np.sum(1.0 / np.arange(1, n_hap))

        if 'pi' in statistics:
            results['pi'] = to_host(pi_sum)
        if 'theta_w' in statistics:
            results['theta_w'] = to_host(seg_counts / a1)
        if 'segregating_sites' in statistics:
            results['segregating_sites'] = seg_counts.get().astype(int)
        if 'singletons' in statistics:
            # Only segregating sites can be singletons. This also excludes
            # sites with a single valid call.
            is_sing = is_seg & ((dac == 1) | (dac == n_v - 1))
            sing_counts = _scatter_sum(is_sing.astype(cp.float64), bin_idx,
                                       n_windows)
            results['singletons'] = sing_counts.get().astype(int)
        if 'het_expected' in statistics:
            he_sum = _scatter_sum(2.0 * p * (1.0 - p), bin_idx, n_windows)
            results['het_expected'] = cp.where(
                variant_counts > 0,
                he_sum / variant_counts.astype(cp.float64), cp.nan).get()
        if 'tajimas_d' in statistics:
            S = seg_counts.get()
            tajd = np.full(n_windows, np.nan)
            # The constants divide by (n - 1). With fewer than two
            # haplotypes nothing can segregate, so D stays NaN.
            if n_hap >= 2:
                n = n_hap
                a2 = np.sum(1.0 / np.arange(1, n) ** 2)
                b1 = (n + 1) / (3 * (n - 1))
                b2 = 2 * (n ** 2 + n + 3) / (9 * n * (n - 1))
                c1 = b1 - 1 / a1
                c2 = b2 - (n + 2) / (a1 * n) + a2 / a1 ** 2
                e1 = c1 / a1
                e2 = c2 / (a1 ** 2 + a2)
                d_num = pi_sum.get() - S / a1
                d_var = e1 * S + e2 * S * (S - 1)
                d_std = np.sqrt(np.maximum(d_var, 0))
                with np.errstate(divide='ignore', invalid='ignore'):
                    tajd = np.where(d_std > 0, d_num / d_std, np.nan)
                tajd[S < 3] = np.nan
            results['tajimas_d'] = tajd

    # --- two-population statistics ---
    if want_two:
        # Populations are looked up on the original (unsubsetted) matrix.
        m1 = get_population_matrix(haplotype_matrix, pop1)
        m2 = get_population_matrix(haplotype_matrix, pop2)
        for pop_matrix in (m1, m2):
            if pop_matrix.device == 'CPU':
                pop_matrix.transfer_to_gpu()
        hap1 = m1.haplotypes
        hap2 = m2.haplotypes
        n1 = cp.float64(hap1.shape[0])
        n2 = cp.float64(hap2.shape[0])
        if 'fst' in statistics:
            fst_num, fst_den = _per_variant_fst_hudson_components(
                hap1, hap2, n1, n2)
            num_sum = _scatter_sum(fst_num, bin_idx, n_windows)
            den_sum = _scatter_sum(fst_den, bin_idx, n_windows)
            results['fst'] = cp.where(den_sum > 0,
                                      num_sum / den_sum, cp.nan).get()
        if 'dxy' in statistics:
            dxy_vals = _per_variant_dxy(hap1, hap2, n1, n2)
            results['dxy'] = to_host(
                _scatter_sum(dxy_vals, bin_idx, n_windows))
    return results