# Source code for pg_gpu.moments_ld

"""
Integration layer: GPU-accelerated LD statistics for moments inference.

Drop-in replacement for moments.LD.Parsing.compute_ld_statistics() using
pg_gpu's GPU kernels. Output format is identical to moments, so downstream
inference (bootstrap_data, optimize_log_lbfgsb, Godambe) works unchanged.

Supports 1-4 populations.

Usage:
    from pg_gpu.moments_ld import compute_ld_statistics

    ld_stats = compute_ld_statistics(
        "data.vcf", rec_map_file="map.txt", pop_file="pops.txt",
        pops=["pop0", "pop1"], r_bins=r_bins,
    )
    mv = moments.LD.Parsing.bootstrap_data({0: ld_stats})
"""

import numpy as np
import cupy as cp

from .haplotype_matrix import HaplotypeMatrix
from .ld_pipeline import (
    iter_pairs_within_distance as _iter_pairs_within_distance,
    compute_counts_for_pairs as _compute_counts_for_pairs,
    compute_genotype_counts_for_pairs as _compute_genotype_counts_for_pairs,
    compute_two_pop_statistics_batch as _compute_two_pop_statistics_batch,
    estimate_ld_chunk_size as _estimate_ld_chunk_size,
    ld_names as _ld_names,
    het_names as _het_names,
    generate_stat_specs as _generate_stat_specs,
)
from .genotype_kernels import compute_multi_pop_statistics_batch_geno
from .haplotype_kernels import compute_multi_pop_statistics_batch_hap
from .genotype_matrix import GenotypeMatrix
from . import ld_statistics


def compute_ld_statistics(
    vcf_file=None,
    rec_map_file=None,
    pop_file=None,
    pops=None,
    r_bins=None,
    bp_bins=None,
    use_genotypes=False,
    report=True,
    ac_filter=True,
    haplotype_matrix=None,
    genotype_matrix=None,
    accessible_bed=None,
    pop_assignment=None,
):
    """Compute binned LD sums and heterozygosity on the GPU (moments format).

    GPU-accelerated drop-in replacement for
    ``moments.LD.Parsing.compute_ld_statistics``. The keyword arguments mirror
    the moments function, so an existing pipeline only has to swap the
    import::

        # moments (CPU):
        import moments.LD
        ld_stats = moments.LD.Parsing.compute_ld_statistics(
            vcf_file="data.vcf.gz",
            rec_map_file="rec_map.txt",
            pop_file="pops.txt",
            pops=["popA", "popB"],
            r_bins=[0, 1e-6, 2e-6, 5e-6],
        )

        # pg_gpu (GPU, same call signature):
        from pg_gpu.moments_ld import compute_ld_statistics
        ld_stats = compute_ld_statistics(
            vcf_file="data.vcf.gz",
            rec_map_file="rec_map.txt",
            pop_file="pops.txt",
            pops=["popA", "popB"],
            r_bins=[0, 1e-6, 2e-6, 5e-6],
        )

    Parameters
    ----------
    vcf_file : str, optional
        Path to a VCF file. Only consulted when no pre-loaded matrix is
        supplied for the selected mode.
    rec_map_file : str, optional
        Recombination map (tab-delimited: pos, Map(cM)). Required whenever
        ``r_bins`` is given.
    pop_file, pop_assignment : str, dict, numpy.ndarray, list, or False, optional
        Sample-to-population assignment in any form ``normalize_pop_input``
        accepts. The two keywords are aliases (``pop_file`` is the moments
        spelling, ``pop_assignment`` the pg_gpu spelling); supplying both
        raises ``TypeError``.
    pops : list of str, optional
        Between one and four population names. Defaults to
        ``['pop0', 'pop1']``.
    r_bins : array-like, optional
        Recombination-distance bin edges (Morgans). Takes precedence over
        ``bp_bins`` when both are given.
    bp_bins : array-like, optional
        Base-pair distance bin edges, used when ``r_bins`` is None.
    use_genotypes : bool
        If true, use diploid genotype counts (9-way) instead of haplotype
        counts (4-way).
    report : bool
        Print progress messages.
    ac_filter : bool
        Apply the biallelic filter after loading from a VCF.
    haplotype_matrix : HaplotypeMatrix, optional
        Pre-loaded matrix; skips VCF loading. In genotype mode it is
        converted with ``GenotypeMatrix.from_haplotype_matrix``.
    genotype_matrix : GenotypeMatrix, optional
        Pre-loaded genotype matrix; skips VCF loading (genotype mode only).
    accessible_bed : str, optional
        BED file of accessible regions, applied after loading from a VCF
        when the matrix does not already carry an accessible mask.

    Returns
    -------
    dict
        ``{'bins', 'sums', 'stats', 'pops'}``: ``bins`` is a list of
        ``(lo, hi)`` float tuples, ``sums`` holds one row of LD sums per bin
        followed by one array of heterozygosity sums, ``stats`` is
        ``(ld_stat_names, het_stat_names)`` and ``pops`` is the population
        list that was used.

    Raises
    ------
    TypeError
        If both ``pop_file`` and ``pop_assignment`` are given.
    ValueError
        If the number of populations is outside 1-4, no bins are given, the
        inputs needed to load data are missing, or ``r_bins`` is given
        without ``rec_map_file``.
    """
    # ``pop_assignment`` is just another spelling of ``pop_file``.
    if pop_assignment is not None:
        if pop_file is not None:
            raise TypeError(
                "pass only one of pop_file or pop_assignment; they are "
                "aliases for the same argument."
            )
        pop_file = pop_assignment

    if pops is None:
        pops = ['pop0', 'pop1']
    num_pops = len(pops)
    if not 1 <= num_pops <= 4:
        raise ValueError("1-4 populations supported")

    if r_bins is None and bp_bins is None:
        raise ValueError("Either r_bins or bp_bins must be provided")

    if use_genotypes:
        # Genotype (diploid) mode: prefer a ready-made GenotypeMatrix, fall
        # back to converting a HaplotypeMatrix, and only then read the VCF.
        if genotype_matrix is not None or haplotype_matrix is not None:
            if genotype_matrix is not None:
                gm = genotype_matrix
            else:
                gm = GenotypeMatrix.from_haplotype_matrix(haplotype_matrix)
            if gm.device != 'GPU':
                gm.transfer_to_gpu()
        else:
            if vcf_file is None:
                raise ValueError("vcf_file or genotype_matrix required")
            if pop_file is None:
                raise ValueError("pop_file is required when loading from VCF")
            if report:
                print(f"Loading {vcf_file} (genotypes) ...")
            gm = GenotypeMatrix.from_vcf(vcf_file)
            gm.load_pop_file(pop_file, pops=pops)
            if ac_filter:
                gm = gm.apply_biallelic_filter()
            if accessible_bed is not None and not gm.has_accessible_mask:
                gm.set_accessible_mask(accessible_bed)
            gm.transfer_to_gpu()
        mat = gm
        if report:
            print(f" {gm.num_individuals} individuals, {gm.num_variants:,} variants")
    else:
        # Haplotype (phased) mode.
        if haplotype_matrix is None:
            if vcf_file is None:
                raise ValueError("vcf_file or haplotype_matrix is required")
            if pop_file is None:
                raise ValueError("pop_file is required when loading from VCF")
            if report:
                print(f"Loading {vcf_file} ...")
            hm = HaplotypeMatrix.from_vcf(vcf_file)
            hm.load_pop_file(pop_file, pops=pops)
            if ac_filter:
                hm = hm.apply_biallelic_filter()
            if accessible_bed is not None and not hm.has_accessible_mask:
                hm.set_accessible_mask(accessible_bed)
            hm.transfer_to_gpu()
        else:
            hm = haplotype_matrix
            if not isinstance(hm.haplotypes, cp.ndarray):
                hm.transfer_to_gpu()
        mat = hm
        if report:
            print(f" {hm.num_haplotypes} hap, {hm.num_variants:,} variants")

    # Pick the binning metric. Pairs are always enumerated by physical
    # distance, so recombination bins need a base-pair search radius too.
    if r_bins is None:
        bins = np.asarray(bp_bins, dtype=np.float64)
        gen_dists_gpu = None
        max_bp_dist = float(bins[-1])
    else:
        if rec_map_file is None:
            raise ValueError("rec_map_file required with r_bins")
        bins = np.asarray(r_bins, dtype=np.float64)
        if hasattr(mat.positions, 'get'):
            pos_cpu = mat.positions.get()
        else:
            pos_cpu = np.asarray(mat.positions)
        gen_dists = _interpolate_genetic_distances(pos_cpu, rec_map_file)
        gen_dists_gpu = cp.asarray(gen_dists)
        max_bp_dist = _max_bp_for_r_dist(pos_cpu, gen_dists, float(bins[-1]))

    n_bins = len(bins) - 1
    if report:
        print(f" Computing LD ({n_bins} bins, {num_pops} pops) ...")

    ld_stat_names = _ld_names(num_pops)
    het_stat_names = _het_names(num_pops)

    genotype_mode = bool(use_genotypes)
    ld_sums = _compute_ld_sums(
        mat, pops, bins, gen_dists_gpu, max_bp_dist,
        use_genotypes=genotype_mode,
    )
    het = _compute_heterozygosity(mat, pops, use_genotypes=genotype_mode)

    if report:
        print(" Done.")

    # moments layout: one row of LD sums per bin, then the heterozygosity row.
    het_row = np.array([het[name] for name in het_stat_names])
    return {
        'bins': [(float(lo), float(hi)) for lo, hi in zip(bins[:-1], bins[1:])],
        'sums': [*ld_sums[:n_bins], het_row],
        'stats': (ld_stat_names, het_stat_names),
        'pops': pops,
    }
# --------------------------------------------------------------------------- # Internals # --------------------------------------------------------------------------- def _interpolate_genetic_distances(positions, rec_map_file): """Interpolate per-variant genetic map positions (Morgans) from map file.""" map_pos, map_vals = [], [] with open(rec_map_file) as f: for line in f: parts = line.strip().split() if len(parts) >= 2: try: map_pos.append(float(parts[0])) map_vals.append(float(parts[1]) / 100.0) # cM -> Morgans except ValueError: continue return np.interp(positions, np.array(map_pos), np.array(map_vals)) def _max_bp_for_r_dist(positions, gen_dists, max_r): """Conservative max physical distance for a given recombination distance. Uses minimum local recombination rate with safety margin. Caps at chromosome length to avoid generating excessive pairs. """ chrom_len = float(positions[-1] - positions[0]) if len(positions) < 2: return chrom_len bp_diffs = np.diff(positions).astype(np.float64) r_diffs = np.diff(gen_dists) valid = bp_diffs > 0 if not np.any(valid): return chrom_len rates = r_diffs[valid] / bp_diffs[valid] min_rate = np.min(rates[rates > 0]) if np.any(rates > 0) else max_r / chrom_len result = max_r / min_rate * 1.1 return min(result, chrom_len) def _compute_ld_sums(mat, pops, bins, gen_dists_gpu, max_bp_dist, use_genotypes=False): """Compute LD statistic sums per bin on GPU for N populations. Handles both haplotype (4-way) and genotype (9-way) count modes. 
""" num_pops = len(pops) if use_genotypes: # Filter to variants biallelic across the union of specified populations # (matches moments' behavior in _count_types_sparse lines 545-547) # Sum per-pop to avoid materializing the full union genotype matrix geno = mat.genotypes xp = cp if isinstance(geno, cp.ndarray) else np alt_sum = xp.zeros(mat.num_variants, dtype=xp.int64) n_valid_filter = xp.zeros(mat.num_variants, dtype=xp.int64) seen = set() for pop in pops: for idx in mat.sample_sets[pop]: if idx in seen: continue seen.add(idx) row = geno[idx, :] v = row >= 0 alt_sum += xp.where(v, row, 0).astype(xp.int64) n_valid_filter += v.astype(xp.int64) max_alt = 2 * n_valid_filter keep = (alt_sum > 0) & (alt_sum < max_alt) & (n_valid_filter >= 2) keep_idx = xp.where(keep)[0] pos = mat.positions[keep_idx] data_matrix = mat.genotypes[:, keep_idx] count_fn = _compute_genotype_counts_for_pairs stat_fn = compute_multi_pop_statistics_batch_geno else: pos = mat.positions data_matrix = mat.haplotypes count_fn = _compute_counts_for_pairs stat_fn = None # handled by 2-pop fast path or multi-pop if not isinstance(pos, cp.ndarray): pos = cp.array(pos) n_bins = len(bins) - 1 bins_gpu = cp.asarray(bins) pop_indices = [mat.sample_sets[p] for p in pops] max_samp = max(len(pi) for pi in pop_indices) chunk_size = _estimate_ld_chunk_size(max_samp, num_pops=num_pops) ld_stat_names = _ld_names(num_pops) n_ld = len(ld_stat_names) # Genetic-distance lookup: filter once outside the loop so fancy-indexing # on the keep-mask doesn't repeat per chunk. 
if gen_dists_gpu is not None: gen_dists_lookup = gen_dists_gpu[keep_idx] if use_genotypes else gen_dists_gpu else: gen_dists_lookup = None bin_sums = cp.zeros((n_bins, n_ld), dtype=cp.float64) stat_specs = _generate_stat_specs(num_pops) if (num_pops != 2 or use_genotypes) else None for ci, cj in _iter_pairs_within_distance(pos, max_bp_dist, chunk_size): if gen_dists_lookup is not None: distances = cp.abs(gen_dists_lookup[cj] - gen_dists_lookup[ci]) else: distances = pos[cj] - pos[ci] cb = cp.digitize(distances, bins_gpu) - 1 del distances counts_list = [] n_valid_list = [] for pidx in pop_indices: c, nv = count_fn(data_matrix, ci, cj, pidx) counts_list.append(c) n_valid_list.append(nv) if not use_genotypes and num_pops == 2: stats = _compute_two_pop_statistics_batch( counts_list[0], counts_list[1], n_valid_list[0], n_valid_list[1], ld_statistics) elif use_genotypes: stats = compute_multi_pop_statistics_batch_geno( counts_list, n_valid_list, None, stat_specs) else: stats = compute_multi_pop_statistics_batch_hap( counts_list, n_valid_list, ld_statistics, stat_specs) valid = (cb >= 0) & (cb < n_bins) vb = cb[valid] vs = stats[valid] flat_idx = vb[:, None] * n_ld + cp.arange(n_ld)[None, :] cp.add.at(bin_sums.ravel(), flat_idx.ravel(), vs.ravel()) del counts_list, n_valid_list, stats, cb return bin_sums.get() def _compute_heterozygosity(mat, pops, use_genotypes=False): """Compute H_i_j statistics on GPU for N populations (moments convention). Works with both haplotype and genotype data by converting to allele counts with the haploid sample size convention. 
""" num_pops = len(pops) alt_counts = [] ref_counts = [] hap_sizes = [] for pop in pops: pidx = mat.sample_sets[pop] if use_genotypes: if isinstance(pidx, list): pidx = cp.array(pidx, dtype=cp.int32) pop_data = mat.genotypes[pidx, :] valid = pop_data >= 0 alt = cp.sum(cp.where(valid, pop_data, 0).astype(cp.int32), axis=0).astype(cp.float64) n_hap = 2.0 * cp.sum(valid, axis=0).astype(cp.float64) else: alt = cp.sum(cp.maximum(mat.haplotypes[pidx, :], 0).astype(cp.int32), axis=0).astype(cp.float64) n_hap = cp.float64(len(pidx)) * cp.ones_like(alt) alt_counts.append(alt) ref_counts.append(n_hap - alt) hap_sizes.append(n_hap) result = {} for ii in range(num_pops): for jj in range(ii, num_pops): if ii == jj: val = float(cp.sum( 2.0 * ref_counts[ii] * alt_counts[ii] / (hap_sizes[ii] * (hap_sizes[ii] - 1)) ).get()) else: val = float(cp.sum( (ref_counts[ii] * alt_counts[jj] + alt_counts[ii] * ref_counts[jj]) / (hap_sizes[ii] * hap_sizes[jj]) ).get()) result[f"H_{ii}_{jj}"] = val return result