1use salmon_eqclass::CollapsedEqClasses;
18
19mod online;
20mod packed;
21pub mod uncertainty;
22
23pub use online::OnlineInference;
24pub use packed::PackedEqClasses;
25pub use uncertainty::{ambiguity_counts, bootstrap, gibbs_sample, GibbsOptions};
26
/// Tuning knobs for the EM / VBEM abundance optimizer.
///
/// Start from [`EmOptions::default`] and override individual fields with
/// struct-update syntax, e.g. `EmOptions { use_vbem: true, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct EmOptions {
    /// Hard upper bound on the number of EM iterations.
    pub max_iter: u32,
    /// Number of iterations to run before convergence is first tested.
    pub min_iter: u32,
    /// Convergence is declared once the largest relative change among the
    /// checked abundances is finite and falls below this tolerance.
    pub rel_diff_tol: f64,
    /// Only abundances whose previous-iteration value exceeds this cutoff (and
    /// whose new value is positive) take part in the convergence test.
    pub alpha_check_cutoff: f64,
    /// Threshold handed to the final truncation/redistribution pass; a value
    /// `<= 0.0` skips that pass entirely.
    pub min_alpha: f64,
    /// Use the variational-Bayes update (VBEM) instead of plain EM.
    pub use_vbem: bool,
    /// Prior weight per transcript, or per nucleotide when
    /// `per_nucleotide_prior` is set.
    pub vb_prior: f64,
    /// When effective lengths of the right size are supplied, scale
    /// `vb_prior` by each transcript's effective length (clamped to >= 1.0).
    pub per_nucleotide_prior: bool,
}

impl Default for EmOptions {
    /// Plain EM, between 50 and 10 000 iterations, 1 % relative tolerance,
    /// truncation at 1e-8 and a VB prior of 0.01 per transcript.
    fn default() -> Self {
        // Iteration limits and convergence settings, grouped by concern.
        let (min_iter, max_iter) = (50, 10_000);
        let (rel_diff_tol, alpha_check_cutoff) = (0.01, 0.01);
        EmOptions {
            min_iter,
            max_iter,
            rel_diff_tol,
            alpha_check_cutoff,
            min_alpha: 1.0e-8,
            use_vbem: false,
            per_nucleotide_prior: false,
            vb_prior: 0.01,
        }
    }
}
62
/// Output of a single EM / VBEM optimization run.
#[derive(Debug, Clone)]
pub struct EmResult {
    /// Estimated abundance (alpha) for each transcript, after the optional
    /// truncation/redistribution pass.
    pub alphas: Vec<f64>,
    /// Number of EM iterations actually performed.
    pub iters: u32,
    /// `true` if the relative-difference convergence criterion was met before
    /// the iteration limit was reached.
    pub converged: bool,
    /// Second value returned by the truncation/redistribution pass; `0.0` when
    /// that pass is skipped (`min_alpha <= 0.0`). Presumably the total
    /// abundance removed by truncation — TODO confirm against
    /// `packed::redistribute_truncated`.
    pub dropped_mass: f64,
}
78
/// Largest relative change `|new - old| / new` between two abundance vectors.
///
/// Only entries whose old value (`alpha_in`) exceeds `cutoff` and whose new
/// value (`alpha_out`) is strictly positive are considered. NaN ratios are
/// ignored. Returns `f64::NEG_INFINITY` when no entry qualifies.
fn max_rel_diff(alpha_in: &[f64], alpha_out: &[f64], cutoff: f64) -> f64 {
    alpha_in
        .iter()
        .enumerate()
        // `alpha_out` is only indexed for entries that pass this filter.
        .filter(|&(_, &prev)| prev > cutoff)
        .filter_map(|(idx, &prev)| {
            let next = alpha_out[idx];
            if next > 0.0 {
                Some((next - prev).abs() / next)
            } else {
                None
            }
        })
        // `f64::max` discards a NaN operand, matching a strict `>` comparison.
        .fold(f64::NEG_INFINITY, f64::max)
}
94
95pub fn optimize(
103 eq: &CollapsedEqClasses,
104 num_txps: usize,
105 opts: &EmOptions,
106 eff_lens: Option<&[f64]>,
107) -> EmResult {
108 let packed = PackedEqClasses::from_collapsed(eq, num_txps);
109 optimize_packed_with_init(&packed, opts, true, None, eff_lens)
110}
111
112pub fn optimize_with_init(
117 eq: &CollapsedEqClasses,
118 num_txps: usize,
119 opts: &EmOptions,
120 init_alphas: Option<&[f64]>,
121 eff_lens: Option<&[f64]>,
122) -> EmResult {
123 let packed = PackedEqClasses::from_collapsed(eq, num_txps);
124 optimize_packed_with_init(&packed, opts, true, init_alphas, eff_lens)
125}
126
127pub fn optimize_packed(p: &PackedEqClasses, opts: &EmOptions, parallel: bool) -> EmResult {
133 optimize_packed_with_init(p, opts, parallel, None, None)
134}
135
136pub fn optimize_packed_with_init(
140 p: &PackedEqClasses,
141 opts: &EmOptions,
142 parallel: bool,
143 init_alphas: Option<&[f64]>,
144 eff_lens: Option<&[f64]>,
145) -> EmResult {
146 let (alphas, iters, converged) = run_em_counts(
147 p,
148 &p.counts,
149 opts,
150 parallel,
151 opts.min_iter,
152 init_alphas,
153 eff_lens,
154 );
155 let (alphas, dropped_mass) =
160 finalize_truncate_redistribute(p, &p.counts, alphas, opts, eff_lens);
161 EmResult {
162 alphas,
163 iters,
164 converged,
165 dropped_mass,
166 }
167}
168
169pub(crate) fn prior_alphas_vec(
172 opts: &EmOptions,
173 eff_lens: Option<&[f64]>,
174 num_txps: usize,
175) -> Vec<f64> {
176 match (opts.per_nucleotide_prior, eff_lens) {
177 (true, Some(el)) if el.len() == num_txps => {
178 el.iter().map(|&l| opts.vb_prior * l.max(1.0)).collect()
179 }
180 _ => vec![opts.vb_prior; num_txps],
181 }
182}
183
184pub(crate) fn finalize_truncate_redistribute(
190 p: &PackedEqClasses,
191 counts: &[u64],
192 alphas: Vec<f64>,
193 opts: &EmOptions,
194 eff_lens: Option<&[f64]>,
195) -> (Vec<f64>, f64) {
196 if opts.min_alpha <= 0.0 {
197 return (alphas, 0.0);
198 }
199 let prior = prior_alphas_vec(opts, eff_lens, p.num_txps);
200 packed::redistribute_truncated(p, counts, &alphas, &prior, opts.min_alpha, opts.use_vbem)
201}
202
203pub(crate) fn run_em_counts(
208 p: &PackedEqClasses,
209 counts: &[u64],
210 opts: &EmOptions,
211 parallel: bool,
212 min_iter: u32,
213 init_alphas: Option<&[f64]>,
214 eff_lens: Option<&[f64]>,
215) -> (Vec<f64>, u32, bool) {
216 let num_txps = p.num_txps;
217 let total: u64 = counts.iter().sum();
218 let init = if num_txps > 0 {
219 total as f64 / num_txps as f64
220 } else {
221 0.0
222 };
223 let mut alphas = match init_alphas {
227 Some(a) if a.len() == num_txps => a.to_vec(),
228 _ => vec![init; num_txps],
229 };
230 let mut alphas_prime = vec![0.0f64; num_txps];
231 let prior_alphas = prior_alphas_vec(opts, eff_lens, num_txps);
234 let mut exp_theta = vec![0.0f64; num_txps];
235 let mut scratch: Vec<f64> = Vec::with_capacity(64);
236 let mut shards: Vec<Vec<f64>> = if parallel {
243 let nshards = rayon::current_num_threads().clamp(1, 64);
244 vec![vec![0.0f64; num_txps]; nshards]
245 } else {
246 Vec::new()
247 };
248
249 let mut converged = false;
250 let mut it = 0u32;
251 while it < opts.max_iter {
252 match (opts.use_vbem, parallel) {
253 (false, true) => {
254 packed::em_step_par(p, counts, &alphas, &mut alphas_prime, &mut shards)
255 }
256 (false, false) => {
257 packed::em_step_seq(p, counts, &alphas, &mut alphas_prime, &mut scratch)
258 }
259 (true, true) => packed::vbem_step_par(
260 p,
261 counts,
262 &prior_alphas,
263 &alphas,
264 &mut alphas_prime,
265 &mut exp_theta,
266 &mut shards,
267 ),
268 (true, false) => packed::vbem_step_seq(
269 p,
270 counts,
271 &prior_alphas,
272 &alphas,
273 &mut alphas_prime,
274 &mut exp_theta,
275 &mut scratch,
276 ),
277 }
278 it += 1;
279 if it >= min_iter {
280 let d = max_rel_diff(&alphas, &alphas_prime, opts.alpha_check_cutoff);
281 std::mem::swap(&mut alphas, &mut alphas_prime);
282 if d.is_finite() && d < opts.rel_diff_tol {
283 converged = true;
284 break;
285 }
286 } else {
287 std::mem::swap(&mut alphas, &mut alphas_prime);
288 }
289 }
290 (alphas, it, converged)
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use salmon_eqclass::{EquivalenceClassBuilder, TranscriptGroup};

    /// Build collapsed classes from `(transcripts, count)` pairs, with unit
    /// weights and unit effective lengths for all `num_txps` transcripts.
    fn build(classes: &[(Vec<u32>, u64)], num_txps: usize) -> CollapsedEqClasses {
        let builder = EquivalenceClassBuilder::new();
        classes.iter().for_each(|(txps, count)| {
            let weights = vec![1.0; txps.len()];
            builder.add_group(TranscriptGroup::new(txps.clone()), weights, *count);
        });
        let mut collapsed = builder.finish();
        let unit_lengths = vec![1.0; num_txps];
        collapsed.update_eff_lengths(&unit_lengths);
        collapsed
    }

    #[test]
    fn unique_classes_recover_exact_counts() {
        let eq = build(&[(vec![0], 30), (vec![1], 70)], 2);
        let res = optimize(&eq, 2, &EmOptions::default(), None);
        let (a0, a1) = (res.alphas[0], res.alphas[1]);
        assert!((a0 - 30.0).abs() < 1e-6);
        assert!((a1 - 70.0).abs() < 1e-6);
    }

    #[test]
    fn shared_class_splits_by_unique_evidence() {
        let eq = build(&[(vec![0], 10), (vec![1], 90), (vec![0, 1], 100)], 2);
        let res = optimize(&eq, 2, &EmOptions::default(), None);
        let (a0, a1) = (res.alphas[0], res.alphas[1]);
        let total = a0 + a1;
        assert!((total - 200.0).abs() < 1e-6, "total={total}");
        assert!((a0 - 20.0).abs() < 1e-2, "a0={}", a0);
        assert!((a1 - 180.0).abs() < 1e-2, "a1={}", a1);
    }

    #[test]
    fn conserves_total_count() {
        let eq = build(&[(vec![0, 1, 2], 50), (vec![1, 2], 30), (vec![2], 20)], 3);
        let res = optimize(&eq, 3, &EmOptions::default(), None);
        let total: f64 = res.alphas.iter().copied().sum();
        assert!((total - 100.0).abs() < 1e-6, "total={total}");
    }

    #[test]
    fn vbem_runs_and_conserves_approximately() {
        let eq = build(&[(vec![0], 30), (vec![1], 70), (vec![0, 1], 100)], 2);
        let vb_opts = EmOptions {
            use_vbem: true,
            ..EmOptions::default()
        };
        let res = optimize(&eq, 2, &vb_opts, None);
        let total: f64 = res.alphas.iter().copied().sum();
        assert!((total - 200.0).abs() < 1.0, "total={total}");
        assert!(res.alphas[1] > res.alphas[0]);
    }

    #[test]
    fn effective_length_shifts_allocation() {
        let builder = EquivalenceClassBuilder::new();
        builder.add_group(TranscriptGroup::new(vec![0, 1]), vec![1.0, 1.0], 100);
        let mut eq = builder.finish();
        eq.update_eff_lengths(&[300.0, 100.0]);
        let res = optimize(&eq, 2, &EmOptions::default(), None);
        assert!(res.alphas[1] > res.alphas[0], "{:?}", res.alphas);
    }
}
366}