1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4 core::{
5 dsp::{
6 istft_cac_stereo_sources_add_into, stft_cac_stereo_centered_into, IstftBatchWorkspace,
7 },
8 ep,
9 },
10 error::{Result, StemError},
11 io::ep_cache,
12 model::model_manager::ModelHandle,
13 types::ModelManifest,
14};
15
16use anyhow::anyhow;
17use ndarray::Array3;
18use once_cell::sync::OnceCell;
19use ort::{
20 session::{
21 builder::{GraphOptimizationLevel, SessionBuilder},
22 Session,
23 },
24 value::{TensorRef, Value},
25};
26use std::{
27 path::PathBuf,
28 sync::{
29 atomic::{AtomicBool, Ordering},
30 Mutex,
31 },
32 time::Instant,
33};
34
/// ONNX Runtime session used for inference. Installed by `preload()`; its
/// contents are replaced with a CPU session by the runtime fallback in
/// `run_window_demucs`.
static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
/// Manifest of the model passed to `preload()` (first value set wins).
static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
/// Guards the one-time `ort::init()` call in `preload()`.
static ORT_INIT: OnceCell<()> = OnceCell::new();
/// Input binding mode (positional vs named) detected for the selected session.
#[cfg(not(feature = "engine-mock"))]
static ENGINE_IO_SPEC: OnceCell<EngineIoSpec> = OnceCell::new();
/// Cached perf-logging switch read from the environment on first use.
#[cfg(not(feature = "engine-mock"))]
static ENGINE_PERF: OnceCell<EnginePerfConfig> = OnceCell::new();
/// Shared, reusable model-input buffers (time branch + spectrogram).
#[cfg(not(feature = "engine-mock"))]
static INPUT_SCRATCH: OnceCell<Mutex<InferenceScratch>> = OnceCell::new();
/// Shared workspace for the batched inverse STFT.
#[cfg(not(feature = "engine-mock"))]
static ISTFT_SCRATCH: OnceCell<Mutex<IstftBatchWorkspace>> = OnceCell::new();
/// Synthetic stereo window fed to candidate sessions by `probe_session_health`.
#[cfg(not(feature = "engine-mock"))]
static PRELOAD_PROBE_INPUT: OnceCell<(Vec<f32>, Vec<f32>)> = OnceCell::new();
/// Model path, thread budget and chosen provider, kept for the runtime CPU fallback.
#[cfg(not(feature = "engine-mock"))]
static ENGINE_CONTEXT: OnceCell<EngineContext> = OnceCell::new();
/// Set once the runtime CPU fallback has been attempted; reset to `false` by `preload()`.
#[cfg(not(feature = "engine-mock"))]
static RUNTIME_EP_FALLBACK_USED: AtomicBool = AtomicBool::new(false);

/// Samples per channel in one model window; other window lengths are rejected.
const DEMUCS_T: usize = 343_980;
/// Frequency bins per frame in the model's spectrogram input and output.
const DEMUCS_F: usize = 2048;
/// STFT frames per window expected by the model.
const DEMUCS_FRAMES: usize = 336;
/// FFT size used for both the forward and the inverse STFT.
const DEMUCS_NFFT: usize = 4096;
/// Hop length, in samples, used for both the forward and the inverse STFT.
const DEMUCS_HOP: usize = 1024;
58
59#[cfg(not(feature = "engine-mock"))]
60struct EngineContext {
61 model_path: PathBuf,
62 num_threads: usize,
63 selected_kind: ep::EpKind,
64}
65
/// Decoded model outputs for one window, copied out of the ORT tensors.
#[cfg(not(feature = "engine-mock"))]
struct DemucsRawOutput {
    /// Number of separated sources (dimension 1 of both outputs).
    num_sources: usize,
    /// Largest absolute value found in `data_time`.
    time_max: f32,
    /// Largest absolute value found in `data_freq`.
    freq_max: f32,
    /// Time-branch output, flattened from `[1, sources, 2, t]`.
    data_time: Vec<f32>,
    /// Frequency-branch output, flattened from `[1, sources, 4, F, frames]`.
    data_freq: Vec<f32>,
}
74
/// Thread-pool settings applied to an ORT `SessionBuilder`.
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct OrtThreading {
    /// Value passed to `with_intra_threads`.
    intra_threads: usize,
    /// Value passed to `with_inter_threads`.
    inter_threads: usize,
    /// Value passed to `with_parallel_execution`.
    parallel_execution: bool,
}
82
/// How model inputs are bound when running a session.
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct EngineIoSpec {
    /// `true`: bind time/spec tensors by position; `false`: bind by name.
    use_positional_inputs: bool,
}
88
/// Process-wide switch for per-window timing logs.
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct EnginePerfConfig {
    /// When set, each window's timing breakdown is printed to stderr.
    enabled: bool,
}
94
/// Per-window timing breakdown in nanoseconds, printed by `log_window_perf`
/// when perf logging is enabled.
#[cfg(not(feature = "engine-mock"))]
#[derive(Default)]
struct WindowPerf {
    /// Input preparation: time-branch fill plus STFT (so it includes `stft_ns`).
    prep_ns: u128,
    /// Forward STFT of the input window.
    stft_ns: u128,
    /// Waiting to acquire the shared session lock.
    lock_wait_ns: u128,
    /// `Session::run` itself.
    run_ns: u128,
    /// Pulling the named outputs out of the run result.
    extract_ns: u128,
    /// Shape validation and copying outputs into owned buffers.
    decode_ns: u128,
    /// Batched inverse STFT over all sources.
    istft_ns: u128,
    /// Mixing step.
    mix_ns: u128,
    /// Whole window, end to end.
    total_ns: u128,
}
108
/// Reusable model-input buffers, so each window does not allocate fresh vectors.
#[cfg(not(feature = "engine-mock"))]
#[derive(Default)]
struct InferenceScratch {
    /// Left channel followed by right channel, fed as the `[1, 2, t]` input.
    time_branch: Vec<f32>,
    /// Stereo spectrogram filled by the STFT, fed as the `[1, 4, F, frames]` input.
    spec_branch: Vec<f32>,
}
115
116#[cfg(not(feature = "engine-mock"))]
117impl InferenceScratch {
118 fn with_demucs_capacity() -> Self {
119 Self {
120 time_branch: Vec::with_capacity(2 * DEMUCS_T),
121 spec_branch: Vec::with_capacity(4 * DEMUCS_F * DEMUCS_FRAMES),
122 }
123 }
124
125 fn fill_time_branch(&mut self, left: &[f32], right: &[f32]) {
126 self.time_branch.clear();
127 self.time_branch.extend_from_slice(left);
128 self.time_branch.extend_from_slice(right);
129 }
130}
131
/// Reads environment variable `name` as a positive integer override.
///
/// Surrounding whitespace is ignored, matching `parse_env_bool`. Returns `None`
/// in any of these cases, so callers keep their computed default:
/// - the variable is unset or not valid Unicode;
/// - it is not a non-negative integer;
/// - it is zero, which is treated as "no override".
#[cfg(not(feature = "engine-mock"))]
fn parse_env_usize(name: &str) -> Option<usize> {
    std::env::var(name)
        .ok()?
        .trim()
        .parse::<usize>()
        .ok()
        .filter(|&value| value != 0)
}
142
/// Reads environment variable `name` as a boolean flag.
///
/// The value is trimmed and compared case-insensitively:
/// - `1`/`true`/`yes`/`on` give `Some(true)`;
/// - `0`/`false`/`no`/`off` give `Some(false)`;
/// - anything else, or an unset/non-Unicode variable, gives `None`.
#[cfg(not(feature = "engine-mock"))]
fn parse_env_bool(name: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let normalized = std::env::var(name).ok()?.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}
152
153#[cfg(not(feature = "engine-mock"))]
154fn apply_thread_overrides(mut cfg: OrtThreading) -> OrtThreading {
155 if let Some(intra) = parse_env_usize("STEMMER_ORT_INTRA_THREADS") {
156 cfg.intra_threads = intra;
157 }
158 if let Some(inter) = parse_env_usize("STEMMER_ORT_INTER_THREADS") {
159 cfg.inter_threads = inter;
160 }
161 if let Some(parallel) = parse_env_bool("STEMMER_ORT_PARALLEL") {
162 cfg.parallel_execution = parallel;
163 }
164 cfg
165}
166
167#[cfg(not(feature = "engine-mock"))]
168fn perf_config() -> &'static EnginePerfConfig {
169 ENGINE_PERF.get_or_init(|| EnginePerfConfig {
170 enabled: std::env::var("STEMMER_PERF").is_ok(),
171 })
172}
173
174#[cfg(not(feature = "engine-mock"))]
175fn input_scratch() -> &'static Mutex<InferenceScratch> {
176 INPUT_SCRATCH.get_or_init(|| Mutex::new(InferenceScratch::with_demucs_capacity()))
177}
178
179#[cfg(not(feature = "engine-mock"))]
180fn istft_scratch() -> &'static Mutex<IstftBatchWorkspace> {
181 ISTFT_SCRATCH.get_or_init(|| Mutex::new(IstftBatchWorkspace::default()))
182}
183
184#[cfg(not(feature = "engine-mock"))]
185fn io_spec() -> &'static EngineIoSpec {
186 ENGINE_IO_SPEC
187 .get()
188 .expect("engine::preload() must initialize input binding")
189}
190
/// True only when the model declares exactly two inputs, `input` then `x`.
/// In that case tensors can be bound by position rather than by name.
#[cfg(not(feature = "engine-mock"))]
fn use_positional_inputs(input_names: &[&str]) -> bool {
    match input_names {
        [first, second] => *first == "input" && *second == "x",
        _ => false,
    }
}
195
196#[cfg(not(feature = "engine-mock"))]
197fn inspect_engine_io(session: &Session) -> Result<EngineIoSpec> {
198 let input_names: Vec<&str> = session.inputs().iter().map(|input| input.name()).collect();
199 let output_names: Vec<&str> = session
200 .outputs()
201 .iter()
202 .map(|output| output.name())
203 .collect();
204
205 if !output_names.contains(&"output") {
206 return Err(anyhow!("Model missing output 'output' (freq domain)").into());
207 }
208 if !output_names.contains(&"add_67") {
209 return Err(anyhow!("Model missing output 'add_67' (time domain)").into());
210 }
211
212 Ok(EngineIoSpec {
213 use_positional_inputs: use_positional_inputs(&input_names),
214 })
215}
216
/// Converts a nanosecond duration into fractional milliseconds for logging.
#[cfg(not(feature = "engine-mock"))]
fn format_ms(ns: u128) -> f64 {
    const NANOS_PER_MILLI: f64 = 1_000_000.0;
    (ns as f64) / NANOS_PER_MILLI
}
221
222#[cfg(not(feature = "engine-mock"))]
223fn log_window_perf(perf: &WindowPerf) {
224 eprintln!(
225 "⏱️ window total={:.2}ms prep={:.2}ms stft={:.2}ms lock={:.2}ms run={:.2}ms extract={:.2}ms decode={:.2}ms istft={:.2}ms mix={:.2}ms",
226 format_ms(perf.total_ns),
227 format_ms(perf.prep_ns),
228 format_ms(perf.stft_ns),
229 format_ms(perf.lock_wait_ns),
230 format_ms(perf.run_ns),
231 format_ms(perf.extract_ns),
232 format_ms(perf.decode_ns),
233 format_ms(perf.istft_ns),
234 format_ms(perf.mix_ns),
235 );
236}
237
238#[cfg(not(feature = "engine-mock"))]
239fn cpu_threading(num_threads: usize) -> OrtThreading {
240 let base = OrtThreading {
241 intra_threads: num_threads.max(1),
242 inter_threads: 1,
243 parallel_execution: false,
244 };
245 apply_thread_overrides(base)
246}
247
248#[cfg(not(feature = "engine-mock"))]
249fn ep_threading(kind: ep::EpKind, num_threads: usize) -> OrtThreading {
250 let base = match kind {
251 ep::EpKind::Cuda | ep::EpKind::CoreML | ep::EpKind::DirectML => OrtThreading {
252 intra_threads: num_threads.clamp(1, 4),
253 inter_threads: 1,
254 parallel_execution: false,
255 },
256 ep::EpKind::OneDNN | ep::EpKind::Cpu => OrtThreading {
257 intra_threads: num_threads.max(1),
258 inter_threads: 1,
259 parallel_execution: false,
260 },
261 ep::EpKind::Xnnpack => OrtThreading {
262 intra_threads: 1,
263 inter_threads: 1,
264 parallel_execution: false,
265 },
266 };
267 apply_thread_overrides(base)
268}
269
270#[cfg(not(feature = "engine-mock"))]
271fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
272 let threading = cpu_threading(num_threads);
273
274 if std::env::var("DEBUG_STEMS").is_ok() {
275 eprintln!(
276 "ℹ️ ORT CPU threading: intra={}, inter={}, parallel={}",
277 threading.intra_threads, threading.inter_threads, threading.parallel_execution
278 );
279 }
280
281 Ok(SessionBuilder::new()?
282 .with_optimization_level(GraphOptimizationLevel::Level3)?
283 .with_intra_threads(threading.intra_threads)?
284 .with_inter_threads(threading.inter_threads)?
285 .with_parallel_execution(threading.parallel_execution)?
286 .commit_from_file(model_path)?)
287}
288
289#[cfg(not(feature = "engine-mock"))]
290fn commit_ep_session(
291 model_path: &std::path::Path,
292 num_threads: usize,
293 kind: ep::EpKind,
294 provider: ort::execution_providers::ExecutionProviderDispatch,
295) -> Result<Session> {
296 let threading = ep_threading(kind, num_threads);
297
298 if std::env::var("DEBUG_STEMS").is_ok() {
299 eprintln!(
300 "ℹ️ ORT EP threading: intra={}, inter={}, parallel={}",
301 threading.intra_threads, threading.inter_threads, threading.parallel_execution
302 );
303 }
304
305 let mut builder = SessionBuilder::new()?
306 .with_optimization_level(GraphOptimizationLevel::Level3)?
307 .with_intra_threads(threading.intra_threads)?
308 .with_inter_threads(threading.inter_threads)?
309 .with_parallel_execution(threading.parallel_execution)?
310 .with_execution_providers(vec![provider])?;
311
312 if matches!(kind, ep::EpKind::Xnnpack) {
313 builder = builder
314 .with_intra_op_spinning(false)?
315 .with_inter_op_spinning(false)?;
316 }
317
318 Ok(builder.commit_from_file(model_path)?)
319}
320
321#[cfg(not(feature = "engine-mock"))]
322fn run_demucs_raw_from_inputs(
323 session: &mut Session,
324 io_spec: &EngineIoSpec,
325 t: usize,
326 f_bins: usize,
327 frames: usize,
328 time_branch: &[f32],
329 spec_branch: &[f32],
330 perf_enabled: bool,
331 perf: &mut WindowPerf,
332) -> Result<(Value, Value)> {
333 let time_value = TensorRef::from_array_view(([1usize, 2, t], time_branch))?;
334 let spec_value = TensorRef::from_array_view(([1usize, 4, f_bins, frames], spec_branch))?;
335
336 let run_start = perf_enabled.then(Instant::now);
337 let mut outputs = if io_spec.use_positional_inputs {
338 session.run(ort::inputs![time_value, spec_value])?
339 } else {
340 session.run(ort::inputs!["input" => time_value, "x" => spec_value])?
341 };
342 if let Some(start) = run_start {
343 perf.run_ns += start.elapsed().as_nanos();
344 }
345
346 let extract_start = perf_enabled.then(Instant::now);
347 let out_freq = outputs
348 .remove("output")
349 .ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
350 let out_time = outputs
351 .remove("add_67")
352 .ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
353 if let Some(start) = extract_start {
354 perf.extract_ns += start.elapsed().as_nanos();
355 }
356
357 Ok((out_time, out_freq))
358}
359
360#[cfg(not(feature = "engine-mock"))]
361fn decode_demucs_outputs(
362 out_time: Value,
363 out_freq: Value,
364 t: usize,
365 f_bins: usize,
366 frames: usize,
367 perf_enabled: bool,
368 perf: &mut WindowPerf,
369) -> Result<DemucsRawOutput> {
370 let decode_start = perf_enabled.then(Instant::now);
371
372 let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
373 if shape_time.len() != 4
374 || shape_time[0] != 1
375 || shape_time[2] != 2
376 || shape_time[3] != t as i64
377 {
378 return Err(anyhow!(
379 "Unexpected time output shape: {:?}, expected [1, sources, 2, {}]",
380 shape_time,
381 t
382 )
383 .into());
384 }
385 let num_sources = shape_time[1] as usize;
386
387 let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
388 if shape_freq.len() != 5
389 || shape_freq[0] != 1
390 || shape_freq[1] != num_sources as i64
391 || shape_freq[2] != 4
392 || shape_freq[3] != f_bins as i64
393 || shape_freq[4] != frames as i64
394 {
395 return Err(anyhow!(
396 "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
397 shape_freq,
398 num_sources,
399 f_bins,
400 frames
401 )
402 .into());
403 }
404
405 let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
406 let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
407
408 let raw = DemucsRawOutput {
409 num_sources,
410 data_time: data_time.to_vec(),
411 data_freq: data_freq.to_vec(),
412 time_max,
413 freq_max,
414 };
415
416 if let Some(start) = decode_start {
417 perf.decode_ns += start.elapsed().as_nanos();
418 }
419
420 Ok(raw)
421}
422
423#[cfg(not(feature = "engine-mock"))]
424fn prepare_demucs_inputs(
425 left: &[f32],
426 right: &[f32],
427 scratch: &mut InferenceScratch,
428 perf_enabled: bool,
429 perf: &mut WindowPerf,
430) -> Result<(usize, usize, usize)> {
431 if left.len() != right.len() {
432 return Err(anyhow!("L/R length mismatch").into());
433 }
434 let t = left.len();
435 if t != DEMUCS_T {
436 return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
437 }
438
439 let prep_start = perf_enabled.then(Instant::now);
440 scratch.fill_time_branch(left, right);
441
442 let stft_start = perf_enabled.then(Instant::now);
443 let (f_bins, frames) = stft_cac_stereo_centered_into(
444 left,
445 right,
446 DEMUCS_NFFT,
447 DEMUCS_HOP,
448 &mut scratch.spec_branch,
449 );
450 if let Some(start) = stft_start {
451 perf.stft_ns += start.elapsed().as_nanos();
452 }
453 if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
454 return Err(anyhow!(
455 "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
456 f_bins,
457 frames,
458 DEMUCS_F,
459 DEMUCS_FRAMES
460 )
461 .into());
462 }
463
464 if let Some(start) = prep_start {
465 perf.prep_ns += start.elapsed().as_nanos();
466 }
467
468 Ok((t, f_bins, frames))
469}
470
471#[cfg(not(feature = "engine-mock"))]
472fn run_demucs_raw_with_session(
473 session: &mut Session,
474 io_spec: &EngineIoSpec,
475 scratch: &mut InferenceScratch,
476 left: &[f32],
477 right: &[f32],
478 perf_enabled: bool,
479 perf: &mut WindowPerf,
480) -> Result<DemucsRawOutput> {
481 let (t, f_bins, frames) = prepare_demucs_inputs(left, right, scratch, perf_enabled, perf)?;
482 let (out_time, out_freq) = run_demucs_raw_from_inputs(
483 session,
484 io_spec,
485 t,
486 f_bins,
487 frames,
488 &scratch.time_branch,
489 &scratch.spec_branch,
490 perf_enabled,
491 perf,
492 )?;
493 decode_demucs_outputs(out_time, out_freq, t, f_bins, frames, perf_enabled, perf)
494}
495
496#[cfg(not(feature = "engine-mock"))]
497pub fn preload(h: &ModelHandle) -> Result<()> {
498 ORT_INIT.get_or_try_init::<_, StemError>(|| {
499 let _ = ort::init().commit();
500 Ok(())
501 })?;
502
503 let num_threads = std::thread::available_parallelism()
504 .map(|n| n.get())
505 .unwrap_or(4);
506
507 let selected = ep::create_best_session(
508 h.local_path.as_path(),
509 num_threads,
510 commit_cpu_session,
511 commit_ep_session,
512 probe_session_health,
513 )?;
514
515 let selected_io = inspect_engine_io(&selected.session)?;
516 ENGINE_IO_SPEC.set(selected_io).ok();
517 ENGINE_PERF.set(*perf_config()).ok();
518 INPUT_SCRATCH
519 .set(Mutex::new(InferenceScratch::with_demucs_capacity()))
520 .ok();
521 ISTFT_SCRATCH
522 .set(Mutex::new(IstftBatchWorkspace::default()))
523 .ok();
524
525 if std::env::var("DEBUG_STEMS").is_ok() {
526 eprintln!(
527 "ℹ️ Engine input binding: {}",
528 if selected_io.use_positional_inputs {
529 "positional"
530 } else {
531 "named"
532 }
533 );
534 }
535
536 ENGINE_CONTEXT
537 .set(EngineContext {
538 model_path: h.local_path.clone(),
539 num_threads,
540 selected_kind: selected.kind,
541 })
542 .ok();
543 RUNTIME_EP_FALLBACK_USED.store(false, Ordering::Relaxed);
544
545 SESSION.set(Mutex::new(selected.session)).ok();
546 MANIFEST.set(h.manifest.clone()).ok();
547 Ok(())
548}
549
550#[cfg(not(feature = "engine-mock"))]
551pub fn manifest() -> &'static ModelManifest {
552 MANIFEST
553 .get()
554 .expect("engine::preload() must be called once before using the engine")
555}
556
/// Leading text of the error raised by `ensure_output_is_not_near_silent`.
/// `near_silent_error` searches error messages for it to decide whether the
/// runtime CPU fallback applies.
#[cfg(not(feature = "engine-mock"))]
const NEAR_SILENT_ERROR_PREFIX: &str = "near-silent execution output";
559
/// What `run_window_demucs` should do after a failed window, as chosen by
/// `runtime_fallback_decision`.
#[cfg(not(feature = "engine-mock"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RuntimeFallbackDecision {
    /// Rebuild the shared session on CPU and retry the window once.
    RetryOnCpu,
    /// A non-CPU provider is forced via `STEMMER_EP_FORCE`; report a dedicated
    /// error instead of falling back.
    ForcedProviderError,
    /// Return the original error unchanged.
    PropagateOriginal,
}
566
/// True when both output peaks are below their silence thresholds:
/// time peak < 1e-6 and frequency peak < 1e-3. A NaN peak never counts as silent.
#[cfg(not(feature = "engine-mock"))]
fn output_is_near_silent(time_max: f32, freq_max: f32) -> bool {
    const TIME_SILENCE: f32 = 1e-6;
    const FREQ_SILENCE: f32 = 1e-3;
    let time_silent = time_max < TIME_SILENCE;
    time_silent && freq_max < FREQ_SILENCE
}
571
/// True when the loudest sample across both channels has magnitude below 1e-4.
/// NaN samples are ignored, and empty input counts as silent.
#[cfg(not(feature = "engine-mock"))]
fn input_is_near_silent(left: &[f32], right: &[f32]) -> bool {
    let peak = left
        .iter()
        .chain(right.iter())
        .map(|sample| sample.abs())
        .fold(0.0f32, f32::max);
    peak < 1e-4
}
578
579#[cfg(not(feature = "engine-mock"))]
580fn build_preload_probe_input() -> (Vec<f32>, Vec<f32>) {
581 use std::f32::consts::TAU;
582
583 let sample_rate = 44_100.0f32;
584 let mut left = Vec::with_capacity(DEMUCS_T);
585 let mut right = Vec::with_capacity(DEMUCS_T);
586
587 for i in 0..DEMUCS_T {
588 let t = i as f32 / sample_rate;
589 left.push(0.22 * (TAU * 220.0 * t).sin() + 0.11 * (TAU * 660.0 * t).sin());
590 right.push(0.20 * (TAU * 330.0 * t).sin() + 0.09 * (TAU * 880.0 * t).cos());
591 }
592
593 (left, right)
594}
595
596#[cfg(not(feature = "engine-mock"))]
597fn preload_probe_input() -> &'static (Vec<f32>, Vec<f32>) {
598 PRELOAD_PROBE_INPUT.get_or_init(build_preload_probe_input)
599}
600
601#[cfg(not(feature = "engine-mock"))]
602fn ensure_output_is_not_near_silent(
603 left: &[f32],
604 right: &[f32],
605 raw: &DemucsRawOutput,
606) -> Result<()> {
607 if !input_is_near_silent(left, right) && output_is_near_silent(raw.time_max, raw.freq_max) {
608 return Err(anyhow!(
609 "{} (time_max={:.3e}, freq_max={:.3e})",
610 NEAR_SILENT_ERROR_PREFIX,
611 raw.time_max,
612 raw.freq_max
613 )
614 .into());
615 }
616
617 Ok(())
618}
619
620#[cfg(not(feature = "engine-mock"))]
621fn probe_session_health(session: &mut Session) -> Result<()> {
622 let (left, right) = preload_probe_input();
623 let io_spec = inspect_engine_io(session)?;
624 let mut scratch = InferenceScratch::with_demucs_capacity();
625 let mut perf = WindowPerf::default();
626 let raw = run_demucs_raw_with_session(
627 session,
628 &io_spec,
629 &mut scratch,
630 left,
631 right,
632 false,
633 &mut perf,
634 )?;
635 ensure_output_is_not_near_silent(left, right, &raw)
636}
637
/// True when `STEMMER_EP_FORCE` names a provider other than CPU.
/// The comparison is trimmed and case-insensitive; empty values do not count.
#[cfg(not(feature = "engine-mock"))]
fn is_forced_non_cpu_ep() -> bool {
    std::env::var("STEMMER_EP_FORCE")
        .map(|raw| {
            let forced = raw.trim().to_ascii_lowercase();
            !forced.is_empty() && forced != "cpu"
        })
        .unwrap_or(false)
}
647
648#[cfg(not(feature = "engine-mock"))]
649fn near_silent_error(message: &str) -> bool {
650 message.contains(NEAR_SILENT_ERROR_PREFIX)
651}
652
653#[cfg(not(feature = "engine-mock"))]
654fn runtime_fallback_decision(
655 error_text: &str,
656 forced_non_cpu_ep: bool,
657 fallback_already_used: bool,
658) -> RuntimeFallbackDecision {
659 if !near_silent_error(error_text) {
660 return RuntimeFallbackDecision::PropagateOriginal;
661 }
662 if forced_non_cpu_ep {
663 return RuntimeFallbackDecision::ForcedProviderError;
664 }
665 if fallback_already_used {
666 return RuntimeFallbackDecision::PropagateOriginal;
667 }
668 RuntimeFallbackDecision::RetryOnCpu
669}
670
/// Separates one stereo window into stems, with a one-shot CPU fallback.
///
/// `left` and `right` must both be exactly `DEMUCS_T` samples long. The result
/// is produced by `postprocess_demucs_output` with shape
/// `(num_sources, 2, DEMUCS_T)`.
///
/// If the first attempt fails, the error text is classified by
/// `runtime_fallback_decision`:
/// - errors that are not near-silent-output errors are returned unchanged;
/// - if `STEMMER_EP_FORCE` names a non-CPU provider, a "refusing CPU fallback"
///   error is returned instead;
/// - if the fallback was already used since `preload()`, the original error
///   is returned;
/// - otherwise a fresh CPU session replaces the shared one, and the window is
///   retried once; the retry's result (or error) is returned.
///
/// Before the retry, a non-CPU provider chosen at preload time is recorded via
/// `ep_cache::mark_unhealthy` for this model. A failure to record it is only
/// logged, and only with `DEBUG_STEMS`. If the preload-time provider was
/// already CPU, the retry simply runs on a freshly built CPU session.
///
/// NOTE(review): `RUNTIME_EP_FALLBACK_USED` is loaded and later stored as two
/// separate steps. Concurrent callers that hit near-silent output together may
/// each build and install a CPU session. The result is still a CPU session, but
/// the work is duplicated; consider `swap`/`compare_exchange` if this matters.
#[cfg(not(feature = "engine-mock"))]
pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
    // Same input checks as `prepare_demucs_inputs`, done before any lock or session work.
    if left.len() != right.len() {
        return Err(anyhow!("L/R length mismatch").into());
    }
    if left.len() != DEMUCS_T {
        return Err(anyhow!("Bad window length {} (expected {})", left.len(), DEMUCS_T).into());
    }

    let debug_enabled = std::env::var("DEBUG_STEMS").is_ok();
    let perf_enabled = perf_config().enabled;

    match run_window_demucs_once(left, right, debug_enabled, perf_enabled) {
        Ok(out) => Ok(out),
        Err(e) => {
            // The decision is made from the rendered error text; see `near_silent_error`.
            let error_text = e.to_string();
            let forced_non_cpu_ep = is_forced_non_cpu_ep();
            let fallback_already_used = RUNTIME_EP_FALLBACK_USED.load(Ordering::SeqCst);

            match runtime_fallback_decision(&error_text, forced_non_cpu_ep, fallback_already_used) {
                RuntimeFallbackDecision::ForcedProviderError => {
                    if debug_enabled {
                        eprintln!(
                            "⚠️ Runtime EP output was near-silent and STEMMER_EP_FORCE is set; refusing CPU fallback"
                        );
                    }
                    return Err(anyhow!(
                        "Forced execution provider produced near-silent runtime output; refusing CPU fallback"
                    )
                    .into());
                }
                RuntimeFallbackDecision::PropagateOriginal => {
                    // Reached for non-silent errors, or for near-silent errors after the fallback was used.
                    if near_silent_error(&error_text) && debug_enabled {
                        eprintln!(
                            "⚠️ Runtime EP output remained near-silent after fallback; propagating original error"
                        );
                    }
                    return Err(e);
                }
                RuntimeFallbackDecision::RetryOnCpu => {}
            }

            // Claimed before attempting the fallback: if anything below fails, later
            // near-silent windows propagate their error instead of retrying again.
            RUNTIME_EP_FALLBACK_USED.store(true, Ordering::SeqCst);

            let ctx = ENGINE_CONTEXT
                .get()
                .ok_or_else(|| anyhow!("engine context missing for runtime fallback"))?;

            // Best effort: remember that the preload-time provider misbehaved for this model.
            if ctx.selected_kind != ep::EpKind::Cpu {
                if let Err(cache_err) = ep_cache::mark_unhealthy(
                    ctx.selected_kind.env_name(),
                    &ctx.model_path,
                    &error_text,
                ) {
                    if debug_enabled {
                        eprintln!(
                            "⚠️ Failed to persist unhealthy EP cache entry: {}",
                            cache_err
                        );
                    }
                } else if debug_enabled {
                    eprintln!(
                        "ℹ️ Marked {} as unhealthy for this model (cached for 7 days)",
                        ctx.selected_kind.label()
                    );
                }
            }

            if debug_enabled {
                eprintln!(
                    "⚠️ Runtime EP output was near-silent; switching to CPU and retrying this chunk"
                );
            }

            // Build the CPU session before taking the lock, then swap it in place.
            let cpu_session = commit_cpu_session(&ctx.model_path, ctx.num_threads)?;
            let mut session = SESSION
                .get()
                .expect("engine::preload first")
                .lock()
                .expect("session poisoned");
            *session = cpu_session;
            // Release the lock before retrying: `run_window_demucs_once` locks SESSION itself.
            drop(session);

            match run_window_demucs_once(left, right, debug_enabled, perf_enabled) {
                Ok(out) => {
                    if debug_enabled {
                        eprintln!("✅ Runtime fallback succeeded: CPU is now active");
                    }
                    Ok(out)
                }
                Err(retry_error) => {
                    if debug_enabled {
                        eprintln!("❌ Runtime fallback to CPU failed: {}", retry_error);
                    }
                    Err(retry_error)
                }
            }
        }
    }
}
771
772#[cfg(not(feature = "engine-mock"))]
773fn run_window_demucs_once(
774 left: &[f32],
775 right: &[f32],
776 debug_enabled: bool,
777 perf_enabled: bool,
778) -> Result<Array3<f32>> {
779 let total_start = perf_enabled.then(Instant::now);
780 let mut perf = WindowPerf::default();
781
782 let raw = {
783 let mut scratch = input_scratch().lock().expect("input scratch poisoned");
784 let (t, f_bins, frames) =
785 prepare_demucs_inputs(left, right, &mut scratch, perf_enabled, &mut perf)?;
786
787 let lock_start = perf_enabled.then(Instant::now);
788 let mut session = SESSION
789 .get()
790 .expect("engine::preload first")
791 .lock()
792 .expect("session poisoned");
793 if let Some(start) = lock_start {
794 perf.lock_wait_ns += start.elapsed().as_nanos();
795 }
796
797 let (out_time, out_freq) = run_demucs_raw_from_inputs(
798 &mut session,
799 io_spec(),
800 t,
801 f_bins,
802 frames,
803 &scratch.time_branch,
804 &scratch.spec_branch,
805 perf_enabled,
806 &mut perf,
807 )?;
808 drop(session);
809 drop(scratch);
810
811 decode_demucs_outputs(
812 out_time,
813 out_freq,
814 t,
815 f_bins,
816 frames,
817 perf_enabled,
818 &mut perf,
819 )?
820 };
821
822 let out = postprocess_demucs_output(raw, left, right, debug_enabled, perf_enabled, &mut perf)?;
823
824 if let Some(start) = total_start {
825 perf.total_ns = start.elapsed().as_nanos();
826 log_window_perf(&perf);
827 }
828
829 Ok(out)
830}
831
832#[cfg(not(feature = "engine-mock"))]
833fn postprocess_demucs_output(
834 mut raw: DemucsRawOutput,
835 left: &[f32],
836 right: &[f32],
837 debug_enabled: bool,
838 perf_enabled: bool,
839 perf: &mut WindowPerf,
840) -> Result<Array3<f32>> {
841 let t = left.len();
842 let num_sources = raw.num_sources;
843
844 if debug_enabled {
845 eprintln!(
846 "Model output stats: time_max={:.6}, freq_max={:.6}",
847 raw.time_max, raw.freq_max
848 );
849 }
850
851 ensure_output_is_not_near_silent(left, right, &raw)?;
852
853 let source_specs: Vec<&[f32]> = (0..num_sources)
854 .map(|src| {
855 let src_freq_offset = src * 4 * DEMUCS_F * DEMUCS_FRAMES;
856 &raw.data_freq[src_freq_offset..src_freq_offset + 4 * DEMUCS_F * DEMUCS_FRAMES]
857 })
858 .collect();
859
860 let istft_start = perf_enabled.then(Instant::now);
861 {
862 let mut istft_ws = istft_scratch().lock().expect("iSTFT scratch poisoned");
863 istft_cac_stereo_sources_add_into(
864 &source_specs,
865 DEMUCS_F,
866 DEMUCS_FRAMES,
867 DEMUCS_NFFT,
868 DEMUCS_HOP,
869 t,
870 &mut istft_ws,
871 &mut raw.data_time,
872 );
873 }
874 if let Some(start) = istft_start {
875 perf.istft_ns += start.elapsed().as_nanos();
876 }
877
878 if debug_enabled {
879 for src_idx in 0..num_sources {
880 let src_time_offset = src_idx * 2 * t;
881 let left_mix = &raw.data_time[src_time_offset..src_time_offset + t];
882 let right_mix = &raw.data_time[src_time_offset + t..src_time_offset + 2 * t];
883 let left_max = left_mix.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
884 let right_max = right_mix.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
885 eprintln!(
886 "Combined output [source {}]: left_max={:.6}, right_max={:.6}",
887 src_idx, left_max, right_max
888 );
889 }
890 }
891
892 let mix_start = perf_enabled.then(Instant::now);
893
894 if let Some(start) = mix_start {
895 perf.mix_ns += start.elapsed().as_nanos();
896 }
897
898 Ok(ndarray::Array3::from_shape_vec(
899 (num_sources, 2, t),
900 raw.data_time,
901 )?)
902}
903
#[cfg(not(feature = "engine-mock"))]
#[cfg(test)]
mod runtime_policy_tests {
    use super::*;

    /// Representative near-silent error text; `near_silent_error` only looks for
    /// `NEAR_SILENT_ERROR_PREFIX` inside it.
    const NEAR_SILENT_TEXT: &str = "near-silent execution output (time_max=0, freq_max=0)";

    #[test]
    fn fallback_retries_on_cpu_when_near_silent_and_not_forced() {
        let (forced, already_used) = (false, false);
        assert!(matches!(
            runtime_fallback_decision(NEAR_SILENT_TEXT, forced, already_used),
            RuntimeFallbackDecision::RetryOnCpu
        ));
    }

    #[test]
    fn fallback_refuses_when_forced_provider() {
        let (forced, already_used) = (true, false);
        assert!(matches!(
            runtime_fallback_decision(NEAR_SILENT_TEXT, forced, already_used),
            RuntimeFallbackDecision::ForcedProviderError
        ));
    }

    #[test]
    fn fallback_does_not_retry_twice() {
        let (forced, already_used) = (false, true);
        assert!(matches!(
            runtime_fallback_decision(NEAR_SILENT_TEXT, forced, already_used),
            RuntimeFallbackDecision::PropagateOriginal
        ));
    }

    #[test]
    fn fallback_ignores_non_silent_errors() {
        assert!(matches!(
            runtime_fallback_decision("Model missing input 'x'", false, false),
            RuntimeFallbackDecision::PropagateOriginal
        ));
    }

    #[test]
    fn near_silent_threshold_checks() {
        let cases = [(1e-7, 1e-4, true), (1e-4, 1e-4, false), (1e-7, 1e-2, false)];
        for (time_max, freq_max, expected) in cases {
            assert_eq!(
                output_is_near_silent(time_max, freq_max),
                expected,
                "time_max={}, freq_max={}",
                time_max,
                freq_max
            );
        }
    }

    #[test]
    fn input_silence_threshold_checks() {
        let quiet = [0.0f32; 16];
        let loud = [5e-4f32; 16];
        assert!(input_is_near_silent(&quiet, &quiet));
        assert!(!input_is_near_silent(&loud, &quiet));
    }

    #[test]
    fn preload_probe_input_is_loud_enough_for_health_checks() {
        let (left, right) = build_preload_probe_input();
        assert_eq!((left.len(), right.len()), (DEMUCS_T, DEMUCS_T));
        assert!(!input_is_near_silent(&left, &right));
    }

    #[test]
    fn positional_binding_detection_requires_expected_input_order() {
        assert!(use_positional_inputs(&["input", "x"]));
        for names in [&["x", "input"][..], &["input"][..]] {
            assert!(!use_positional_inputs(names));
        }
    }
}
984
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest captured by the mock `preload`; no model is actually loaded.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Mock preload: records the manifest (the first call wins) and never fails.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Manifest recorded by the mock `preload`.
    ///
    /// # Panics
    /// Panics if the mock `preload` has not been called.
    pub fn manifest() -> &'static ModelManifest {
        MANIFEST.get().expect("preload first (mock)")
    }

    /// Mock separation: returns four identical "sources".
    ///
    /// Each source is a copy of the input, truncated to the shorter channel;
    /// the result has shape `(4, 2, t)`.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;
        let t = left.len().min(right.len());
        let mut out = Vec::with_capacity(SOURCES * 2 * t);
        for _ in 0..SOURCES {
            out.extend_from_slice(&left[..t]);
            out.extend_from_slice(&right[..t]);
        }
        Ok(ndarray::Array3::from_shape_vec((SOURCES, 2, t), out)?)
    }
}
1014
1015#[cfg(feature = "engine-mock")]
1016pub use _engine_mock::{manifest, preload, run_window_demucs};