Skip to main content

salmon_infer/
lib.rs

1//! Collapsed EM / VBEM abundance estimation over equivalence classes.
2//!
3//! Ports salmon's `CollapsedEMOptimizer` (`src/inference/CollapsedEMOptimizer.cpp`):
4//! given a finalized set of equivalence classes (each a transcript label, a
5//! count, and per-transcript `combined_weights`), iteratively estimate the
6//! expected number of fragments originating from each transcript.
7//!
8//! The update rules match the C++ exactly:
9//! - **EM**: `alphaOut[t] += count * (alphaIn[t] * w_t) / sum_j(alphaIn[j] * w_j)`,
10//!   with single-transcript classes assigned their full count.
11//! - **VBEM**: replaces `alphaIn[t]` with `expTheta[t] = exp(digamma(alphaIn[t] +
12//!   prior_t) - digamma(sum_j(alphaIn[j] + prior_j)))`.
13//!
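//! The M-step is rayon-parallel for the main optimization run and sequential
//! for callers that parallelize at a higher level (e.g. bootstrap replicates).
//! SQUAREM acceleration is deferred; plain iteration converges to the same
//! fixpoint.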
16
17use salmon_eqclass::CollapsedEqClasses;
18
19mod online;
20mod packed;
21pub mod uncertainty;
22
23pub use online::OnlineInference;
24pub use packed::PackedEqClasses;
25pub use uncertainty::{ambiguity_counts, bootstrap, gibbs_sample, GibbsOptions};
26
/// Optimizer configuration. Defaults mirror salmon's command-line defaults.
///
/// Derives `PartialEq` so two configurations can be compared directly (for
/// example in tests, or to detect whether a setting changed between runs).
#[derive(Debug, Clone, PartialEq)]
pub struct EmOptions {
    /// hard upper bound on the number of EM/VBEM iterations
    pub max_iter: u32,
    /// number of iterations that must have run before the convergence check is
    /// evaluated for the first time
    pub min_iter: u32,
    /// relative-difference convergence tolerance
    pub rel_diff_tol: f64,
    /// only transcripts with `alpha` above this participate in the convergence check
    pub alpha_check_cutoff: f64,
    /// abundances below this are truncated to zero on output; a non-positive
    /// value disables the truncation step entirely
    pub min_alpha: f64,
    /// use Variational Bayes EM instead of plain EM
    pub use_vbem: bool,
    /// per-transcript Dirichlet prior (VBEM only)
    pub vb_prior: f64,
    /// interpret `vb_prior` as a per-nucleotide prior (`vb_prior * effLen`)
    /// instead of a flat per-transcript prior (salmon's `--perNucleotidePrior`;
    /// salmon's default is the per-transcript interpretation).
    pub per_nucleotide_prior: bool,
}

impl Default for EmOptions {
    /// Plain EM (VBEM off) with a flat per-transcript prior of `1e-2`, at most
    /// 10 000 iterations, and convergence checks starting at iteration 50.
    fn default() -> Self {
        Self {
            max_iter: 10_000,
            min_iter: 50,
            rel_diff_tol: 0.01,
            alpha_check_cutoff: 1e-2,
            min_alpha: 1e-8,
            use_vbem: false,
            vb_prior: 1e-2,
            per_nucleotide_prior: false,
        }
    }
}
62
/// Result of an optimization run.
///
/// Derives `PartialEq` so results can be compared field-by-field (e.g. to
/// assert that two runs produced the same estimate). Note that the comparison
/// is exact floating-point equality on `alphas` and `dropped_mass`.
#[derive(Debug, Clone, PartialEq)]
pub struct EmResult {
    /// estimated fragment counts per transcript (indexed by transcript id)
    pub alphas: Vec<f64>,
    /// iterations actually run
    pub iters: u32,
    /// whether the relative-difference criterion was met before `max_iter`
    pub converged: bool,
    /// total count in equivalence classes whose every transcript was truncated
    /// below `min_alpha` in the final redistribution step — mass that could not be
    /// reassigned to any surviving transcript (reported, not rescaled away). 0 in
    /// the normal case.
    pub dropped_mass: f64,
}
78
/// Largest relative change between two successive abundance vectors.
///
/// For every index whose *previous* abundance (`alpha_in`) is strictly above
/// `cutoff` and whose *new* abundance (`alpha_out`) is strictly positive, the
/// change is `|alpha_out - alpha_in| / alpha_out`; the maximum of those values
/// is returned. When no index qualifies the result is `f64::NEG_INFINITY`, which
/// lets the caller tell "nothing was checked" apart from "nothing moved".
///
/// `alpha_out` is only indexed at positions that pass the `alpha_in` filter, so
/// it must cover at least those positions.
fn max_rel_diff(alpha_in: &[f64], alpha_out: &[f64], cutoff: f64) -> f64 {
    alpha_in
        .iter()
        .enumerate()
        // Transcripts at or below the cutoff do not take part in the check.
        .filter(|&(_, &prev)| prev > cutoff)
        .filter_map(|(idx, &prev)| {
            let next = alpha_out[idx];
            // A non-positive denominator would make the ratio meaningless.
            (next > 0.0).then(|| (next - prev).abs() / next)
        })
        // Explicit `>` (rather than `f64::max`) so a NaN ratio never wins.
        .fold(f64::NEG_INFINITY, |worst, rel| if rel > worst { rel } else { worst })
}
94
95/// Run the optimizer to convergence (parallel EM/VBEM over the packed layout).
96///
97/// `eq` must already have `combined_weights` populated (call
98/// [`CollapsedEqClasses::update_eff_lengths`](salmon_eqclass::CollapsedEqClasses::update_eff_lengths)).
99/// `num_txps` is the total transcript count (output length). Abundances are
100/// initialized uniformly over the total fragment count. Internally builds a
101/// flat CSR [`PackedEqClasses`] and uses rayon-parallel M-steps.
102pub fn optimize(
103    eq: &CollapsedEqClasses,
104    num_txps: usize,
105    opts: &EmOptions,
106    eff_lens: Option<&[f64]>,
107) -> EmResult {
108    let packed = PackedEqClasses::from_collapsed(eq, num_txps);
109    optimize_packed_with_init(&packed, opts, true, None, eff_lens)
110}
111
112/// As [`optimize`], but warm-starts the abundances from `init_alphas` (per
113/// transcript id) when its length matches `num_txps` — used to seed the EM with
114/// salmon's count-blended initialization (online estimates blended with uniform),
115/// which reduces the iteration count to convergence.
116pub fn optimize_with_init(
117    eq: &CollapsedEqClasses,
118    num_txps: usize,
119    opts: &EmOptions,
120    init_alphas: Option<&[f64]>,
121    eff_lens: Option<&[f64]>,
122) -> EmResult {
123    let packed = PackedEqClasses::from_collapsed(eq, num_txps);
124    optimize_packed_with_init(&packed, opts, true, init_alphas, eff_lens)
125}
126
127/// Core convergence loop over a [`PackedEqClasses`]. `parallel` selects the
128/// rayon M-step (for the single main run) vs. the sequential one (used by
129/// bootstrap, which parallelizes across replicates instead). The per-class
130/// `counts` are the packed structure's own (bootstrap passes resampled counts
131/// through [`run_em_counts`]).
132pub fn optimize_packed(p: &PackedEqClasses, opts: &EmOptions, parallel: bool) -> EmResult {
133    optimize_packed_with_init(p, opts, parallel, None, None)
134}
135
136/// As [`optimize_packed`], but seeds the abundances from `init_alphas` (a warm
137/// start, e.g. salmon's online-estimate-blended-with-uniform initialization)
138/// when supplied; otherwise starts uniform.
139pub fn optimize_packed_with_init(
140    p: &PackedEqClasses,
141    opts: &EmOptions,
142    parallel: bool,
143    init_alphas: Option<&[f64]>,
144    eff_lens: Option<&[f64]>,
145) -> EmResult {
146    let (alphas, iters, converged) = run_em_counts(
147        p,
148        &p.counts,
149        opts,
150        parallel,
151        opts.min_iter,
152        init_alphas,
153        eff_lens,
154    );
155    // Truncate negligible abundances (matches salmon's cutoff), but rather than
156    // zero-and-rescale, redistribute the truncated mass to eq-class co-members via
157    // one masked final M-step (see `redistribute_truncated`). `min_alpha <= 0`
158    // (e.g. the bias warm-up) keeps the continuous vector untouched.
159    let (alphas, dropped_mass) =
160        finalize_truncate_redistribute(p, &p.counts, alphas, opts, eff_lens);
161    EmResult {
162        alphas,
163        iters,
164        converged,
165        dropped_mass,
166    }
167}
168
169/// VBEM Dirichlet prior per transcript: flat `vb_prior`, or `vb_prior·max(1,effLen)`
170/// under `--perNucleotidePrior`.
171pub(crate) fn prior_alphas_vec(
172    opts: &EmOptions,
173    eff_lens: Option<&[f64]>,
174    num_txps: usize,
175) -> Vec<f64> {
176    match (opts.per_nucleotide_prior, eff_lens) {
177        (true, Some(el)) if el.len() == num_txps => {
178            el.iter().map(|&l| opts.vb_prior * l.max(1.0)).collect()
179        }
180        _ => vec![opts.vb_prior; num_txps],
181    }
182}
183
184/// Apply the post-convergence min-alpha truncation as a mass-preserving
185/// redistribution (not a rescale). Returns the finalized alphas and the mass that
186/// could not be redistributed (fully-truncated classes). A non-positive
187/// `min_alpha` is a no-op (used by the bias warm-up, whose alphas continue into
188/// the warm-started EM).
189pub(crate) fn finalize_truncate_redistribute(
190    p: &PackedEqClasses,
191    counts: &[u64],
192    alphas: Vec<f64>,
193    opts: &EmOptions,
194    eff_lens: Option<&[f64]>,
195) -> (Vec<f64>, f64) {
196    if opts.min_alpha <= 0.0 {
197        return (alphas, 0.0);
198    }
199    let prior = prior_alphas_vec(opts, eff_lens, p.num_txps);
200    packed::redistribute_truncated(p, counts, &alphas, &prior, opts.min_alpha, opts.use_vbem)
201}
202
203/// Run EM/VBEM to convergence on `p` with explicit per-class `counts`, returning
204/// `(alphas, iters, converged)` *without* the final min-alpha truncation (so
205/// bootstrap can apply its own scaling first). `min_iter` is the minimum number
206/// of iterations before the convergence check engages.
207pub(crate) fn run_em_counts(
208    p: &PackedEqClasses,
209    counts: &[u64],
210    opts: &EmOptions,
211    parallel: bool,
212    min_iter: u32,
213    init_alphas: Option<&[f64]>,
214    eff_lens: Option<&[f64]>,
215) -> (Vec<f64>, u32, bool) {
216    let num_txps = p.num_txps;
217    let total: u64 = counts.iter().sum();
218    let init = if num_txps > 0 {
219        total as f64 / num_txps as f64
220    } else {
221        0.0
222    };
223    // Warm start from a supplied initialization (e.g. the online-phase abundance
224    // estimates blended with uniform, matching salmon's count-blended init) when
225    // its length matches; otherwise start uniform over the total fragment count.
226    let mut alphas = match init_alphas {
227        Some(a) if a.len() == num_txps => a.to_vec(),
228        _ => vec![init; num_txps],
229    };
230    let mut alphas_prime = vec![0.0f64; num_txps];
231    // VBEM prior: flat per-transcript `vb_prior` (salmon's default), or — under
232    // `--perNucleotidePrior` — `vb_prior * effLen` per transcript.
233    let prior_alphas = prior_alphas_vec(opts, eff_lens, num_txps);
234    let mut exp_theta = vec![0.0f64; num_txps];
235    let mut scratch: Vec<f64> = Vec::with_capacity(64);
236    // Per-shard dense accumulators reused across all parallel M-steps (allocated
237    // once here, not per-task per-iteration). Each shard processes a contiguous
238    // slice of the classes with plain adds, then they are summed into `alpha_out`
239    // — avoiding the cross-thread CAS contention of a single shared atomic array.
240    // Capped at 64 shards: beyond that, the per-iteration zero/reduce overhead
241    // outweighs the extra accumulation parallelism.
242    let mut shards: Vec<Vec<f64>> = if parallel {
243        let nshards = rayon::current_num_threads().clamp(1, 64);
244        vec![vec![0.0f64; num_txps]; nshards]
245    } else {
246        Vec::new()
247    };
248
249    let mut converged = false;
250    let mut it = 0u32;
251    while it < opts.max_iter {
252        match (opts.use_vbem, parallel) {
253            (false, true) => {
254                packed::em_step_par(p, counts, &alphas, &mut alphas_prime, &mut shards)
255            }
256            (false, false) => {
257                packed::em_step_seq(p, counts, &alphas, &mut alphas_prime, &mut scratch)
258            }
259            (true, true) => packed::vbem_step_par(
260                p,
261                counts,
262                &prior_alphas,
263                &alphas,
264                &mut alphas_prime,
265                &mut exp_theta,
266                &mut shards,
267            ),
268            (true, false) => packed::vbem_step_seq(
269                p,
270                counts,
271                &prior_alphas,
272                &alphas,
273                &mut alphas_prime,
274                &mut exp_theta,
275                &mut scratch,
276            ),
277        }
278        it += 1;
279        if it >= min_iter {
280            let d = max_rel_diff(&alphas, &alphas_prime, opts.alpha_check_cutoff);
281            std::mem::swap(&mut alphas, &mut alphas_prime);
282            if d.is_finite() && d < opts.rel_diff_tol {
283                converged = true;
284                break;
285            }
286        } else {
287            std::mem::swap(&mut alphas, &mut alphas_prime);
288        }
289    }
290    (alphas, it, converged)
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use salmon_eqclass::{EquivalenceClassBuilder, TranscriptGroup};

    /// Assemble collapsed equivalence classes from `(transcripts, count)` pairs,
    /// giving every member a unit weight, then apply unit effective lengths for
    /// `num_txps` transcripts.
    fn build(classes: &[(Vec<u32>, u64)], num_txps: usize) -> CollapsedEqClasses {
        let builder = EquivalenceClassBuilder::new();
        for (members, hits) in classes {
            let unit_weights = vec![1.0; members.len()];
            builder.add_group(TranscriptGroup::new(members.clone()), unit_weights, *hits);
        }
        let mut collapsed = builder.finish();
        collapsed.update_eff_lengths(&vec![1.0; num_txps]);
        collapsed
    }

    #[test]
    fn unique_classes_recover_exact_counts() {
        // With nothing but uniquely-assigned fragments there is no ambiguity to
        // resolve, so each transcript must end up with exactly its own count.
        let fixture = build(&[(vec![0], 30), (vec![1], 70)], 2);
        let opts = EmOptions::default();
        let res = optimize(&fixture, 2, &opts, None);
        assert!((res.alphas[0] - 30.0).abs() < 1e-6);
        assert!((res.alphas[1] - 70.0).abs() < 1e-6);
    }

    #[test]
    fn shared_class_splits_by_unique_evidence() {
        // Unique evidence: 10 for t0, 90 for t1; another 100 fragments are
        // compatible with both. With equal effective lengths the fixpoint
        // satisfies a0 = 10 + 100 * a0 / 200, i.e. a0 = 20 and a1 = 180 — the
        // shared class follows the 1:9 ratio of the unique evidence.
        let fixture = build(&[(vec![0], 10), (vec![1], 90), (vec![0, 1], 100)], 2);
        let opts = EmOptions::default();
        let res = optimize(&fixture, 2, &opts, None);
        let total = res.alphas[0] + res.alphas[1];
        assert!((total - 200.0).abs() < 1e-6, "total={total}");
        assert!((res.alphas[0] - 20.0).abs() < 1e-2, "a0={}", res.alphas[0]);
        assert!((res.alphas[1] - 180.0).abs() < 1e-2, "a1={}", res.alphas[1]);
    }

    #[test]
    fn conserves_total_count() {
        // 50 + 30 + 20 fragments go in; the abundances must add up to 100.
        let fixture = build(&[(vec![0, 1, 2], 50), (vec![1, 2], 30), (vec![2], 20)], 3);
        let opts = EmOptions::default();
        let res = optimize(&fixture, 3, &opts, None);
        let total: f64 = res.alphas.iter().sum();
        assert!((total - 100.0).abs() < 1e-6, "total={total}");
    }

    #[test]
    fn vbem_runs_and_conserves_approximately() {
        let fixture = build(&[(vec![0], 30), (vec![1], 70), (vec![0, 1], 100)], 2);
        let opts = EmOptions {
            use_vbem: true,
            ..Default::default()
        };
        let res = optimize(&fixture, 2, &opts, None);
        let total: f64 = res.alphas.iter().sum();
        // The prior is tiny, so the VBEM total should sit within one fragment
        // of the EM total, and the better-supported transcript still wins.
        assert!((total - 200.0).abs() < 1.0, "total={total}");
        assert!(res.alphas[1] > res.alphas[0]);
    }

    #[test]
    fn effective_length_shifts_allocation() {
        // A single shared class with equal alignment weights; t0 is three times
        // as long as t1, so the shorter t1 should receive more of the mass.
        let builder = EquivalenceClassBuilder::new();
        builder.add_group(TranscriptGroup::new(vec![0, 1]), vec![1.0, 1.0], 100);
        let mut collapsed = builder.finish();
        collapsed.update_eff_lengths(&[300.0, 100.0]);
        let opts = EmOptions::default();
        let res = optimize(&collapsed, 2, &opts, None);
        assert!(res.alphas[1] > res.alphas[0], "{:?}", res.alphas);
    }
}