Skip to main content

stem_splitter_core/core/
engine.rs

1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4    core::{
5        dsp::{
6            istft_cac_stereo_sources_add_into, stft_cac_stereo_centered_into, IstftBatchWorkspace,
7        },
8        ep,
9    },
10    error::{Result, StemError},
11    io::ep_cache,
12    model::model_manager::ModelHandle,
13    types::ModelManifest,
14};
15
16use anyhow::anyhow;
17use ndarray::Array3;
18use once_cell::sync::OnceCell;
19use ort::{
20    session::{
21        builder::{GraphOptimizationLevel, SessionBuilder},
22        Session,
23    },
24    value::{TensorRef, Value},
25};
26use std::{
27    path::PathBuf,
28    sync::{
29        atomic::{AtomicBool, Ordering},
30        Mutex,
31    },
32    time::Instant,
33};
34
35static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
36static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
37static ORT_INIT: OnceCell<()> = OnceCell::new();
38#[cfg(not(feature = "engine-mock"))]
39static ENGINE_IO_SPEC: OnceCell<EngineIoSpec> = OnceCell::new();
40#[cfg(not(feature = "engine-mock"))]
41static ENGINE_PERF: OnceCell<EnginePerfConfig> = OnceCell::new();
42#[cfg(not(feature = "engine-mock"))]
43static INPUT_SCRATCH: OnceCell<Mutex<InferenceScratch>> = OnceCell::new();
44#[cfg(not(feature = "engine-mock"))]
45static ISTFT_SCRATCH: OnceCell<Mutex<IstftBatchWorkspace>> = OnceCell::new();
46#[cfg(not(feature = "engine-mock"))]
47static PRELOAD_PROBE_INPUT: OnceCell<(Vec<f32>, Vec<f32>)> = OnceCell::new();
48#[cfg(not(feature = "engine-mock"))]
49static ENGINE_CONTEXT: OnceCell<EngineContext> = OnceCell::new();
50#[cfg(not(feature = "engine-mock"))]
51static RUNTIME_EP_FALLBACK_USED: AtomicBool = AtomicBool::new(false);
52
/// Samples per channel in one model window.
const DEMUCS_T: usize = 343_980;
/// STFT size used to build the frequency-branch input.
const DEMUCS_NFFT: usize = 4096;
/// STFT hop size (samples between successive frames).
const DEMUCS_HOP: usize = 1024;
/// Frequency bins the model expects: half the STFT size (4096 / 2 = 2048).
const DEMUCS_F: usize = DEMUCS_NFFT / 2;
/// STFT frames the model expects per window.
const DEMUCS_FRAMES: usize = 336;
58
59#[cfg(not(feature = "engine-mock"))]
60struct EngineContext {
61    model_path: PathBuf,
62    num_threads: usize,
63    selected_kind: ep::EpKind,
64}
65
/// Decoded (but not yet post-processed) model outputs for one window.
#[cfg(not(feature = "engine-mock"))]
struct DemucsRawOutput {
    /// Number of separated sources (dimension 1 of both outputs).
    num_sources: usize,
    /// Flattened time-branch output, shape `[1, sources, 2, t]`.
    data_time: Vec<f32>,
    /// Flattened frequency-branch output, shape `[1, sources, 4, F, frames]`.
    data_freq: Vec<f32>,
    /// Largest absolute value in `data_time`.
    time_max: f32,
    /// Largest absolute value in `data_freq`.
    freq_max: f32,
}
74
/// ORT session threading knobs.
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct OrtThreading {
    /// Threads used inside a single operator.
    intra_threads: usize,
    /// Threads used to run independent operators concurrently.
    inter_threads: usize,
    /// Whether ORT's parallel execution mode is enabled.
    parallel_execution: bool,
}

/// How model inputs are bound when calling `Session::run`.
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct EngineIoSpec {
    /// `true` when the model's inputs are exactly `["input", "x"]`, in that order.
    use_positional_inputs: bool,
}

/// Performance-logging switch (driven by `STEMMER_PERF`).
#[cfg(not(feature = "engine-mock"))]
#[derive(Clone, Copy)]
struct EnginePerfConfig {
    /// Emit per-window timing breakdowns to stderr.
    enabled: bool,
}
94
/// Per-window timing accumulators in nanoseconds; all fields start at zero.
#[cfg(not(feature = "engine-mock"))]
#[derive(Default)]
struct WindowPerf {
    /// Input preparation (time-branch fill and STFT).
    prep_ns: u128,
    /// STFT alone.
    stft_ns: u128,
    /// Waiting for the session mutex.
    lock_wait_ns: u128,
    /// `Session::run`.
    run_ns: u128,
    /// Pulling named outputs out of the run result.
    extract_ns: u128,
    /// Tensor extraction, shape checks, and copying out.
    decode_ns: u128,
    /// Inverse STFT of the frequency branch.
    istft_ns: u128,
    /// Mixing stage.
    mix_ns: u128,
    /// Whole window, end to end.
    total_ns: u128,
}
108
109#[cfg(not(feature = "engine-mock"))]
110#[derive(Default)]
111struct InferenceScratch {
112    time_branch: Vec<f32>,
113    spec_branch: Vec<f32>,
114}
115
116#[cfg(not(feature = "engine-mock"))]
117impl InferenceScratch {
118    fn with_demucs_capacity() -> Self {
119        Self {
120            time_branch: Vec::with_capacity(2 * DEMUCS_T),
121            spec_branch: Vec::with_capacity(4 * DEMUCS_F * DEMUCS_FRAMES),
122        }
123    }
124
125    fn fill_time_branch(&mut self, left: &[f32], right: &[f32]) {
126        self.time_branch.clear();
127        self.time_branch.extend_from_slice(left);
128        self.time_branch.extend_from_slice(right);
129    }
130}
131
/// Reads env var `name` as a positive integer.
///
/// Surrounding whitespace is ignored, which matches `parse_env_bool`. Returns
/// `None` when the variable is unset, not valid Unicode, not an integer, or
/// zero. Zero is treated as "no override".
#[cfg(not(feature = "engine-mock"))]
fn parse_env_usize(name: &str) -> Option<usize> {
    std::env::var(name)
        .ok()?
        .trim()
        .parse::<usize>()
        .ok()
        .filter(|&n| n != 0)
}
142
/// Reads env var `name` as a boolean flag (case-insensitive, whitespace-trimmed).
///
/// `1/true/yes/on` give `Some(true)` and `0/false/no/off` give `Some(false)`.
/// Anything else, or an unset variable, gives `None`.
#[cfg(not(feature = "engine-mock"))]
fn parse_env_bool(name: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let normalized = std::env::var(name).ok()?.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}
152
153#[cfg(not(feature = "engine-mock"))]
154fn apply_thread_overrides(mut cfg: OrtThreading) -> OrtThreading {
155    if let Some(intra) = parse_env_usize("STEMMER_ORT_INTRA_THREADS") {
156        cfg.intra_threads = intra;
157    }
158    if let Some(inter) = parse_env_usize("STEMMER_ORT_INTER_THREADS") {
159        cfg.inter_threads = inter;
160    }
161    if let Some(parallel) = parse_env_bool("STEMMER_ORT_PARALLEL") {
162        cfg.parallel_execution = parallel;
163    }
164    cfg
165}
166
167#[cfg(not(feature = "engine-mock"))]
168fn perf_config() -> &'static EnginePerfConfig {
169    ENGINE_PERF.get_or_init(|| EnginePerfConfig {
170        enabled: std::env::var("STEMMER_PERF").is_ok(),
171    })
172}
173
174#[cfg(not(feature = "engine-mock"))]
175fn input_scratch() -> &'static Mutex<InferenceScratch> {
176    INPUT_SCRATCH.get_or_init(|| Mutex::new(InferenceScratch::with_demucs_capacity()))
177}
178
179#[cfg(not(feature = "engine-mock"))]
180fn istft_scratch() -> &'static Mutex<IstftBatchWorkspace> {
181    ISTFT_SCRATCH.get_or_init(|| Mutex::new(IstftBatchWorkspace::default()))
182}
183
184#[cfg(not(feature = "engine-mock"))]
185fn io_spec() -> &'static EngineIoSpec {
186    ENGINE_IO_SPEC
187        .get()
188        .expect("engine::preload() must initialize input binding")
189}
190
/// Whether inputs can be bound by position: only when the model declares
/// exactly `input` then `x`, so position and name agree.
#[cfg(not(feature = "engine-mock"))]
fn use_positional_inputs(input_names: &[&str]) -> bool {
    input_names == ["input", "x"]
}
195
196#[cfg(not(feature = "engine-mock"))]
197fn inspect_engine_io(session: &Session) -> Result<EngineIoSpec> {
198    let input_names: Vec<&str> = session.inputs().iter().map(|input| input.name()).collect();
199    let output_names: Vec<&str> = session
200        .outputs()
201        .iter()
202        .map(|output| output.name())
203        .collect();
204
205    if !output_names.contains(&"output") {
206        return Err(anyhow!("Model missing output 'output' (freq domain)").into());
207    }
208    if !output_names.contains(&"add_67") {
209        return Err(anyhow!("Model missing output 'add_67' (time domain)").into());
210    }
211
212    Ok(EngineIoSpec {
213        use_positional_inputs: use_positional_inputs(&input_names),
214    })
215}
216
/// Converts a nanosecond count to fractional milliseconds for logging.
#[cfg(not(feature = "engine-mock"))]
fn format_ms(ns: u128) -> f64 {
    const NS_PER_MS: f64 = 1_000_000.0;
    (ns as f64) / NS_PER_MS
}
221
222#[cfg(not(feature = "engine-mock"))]
223fn log_window_perf(perf: &WindowPerf) {
224    eprintln!(
225        "⏱️  window total={:.2}ms prep={:.2}ms stft={:.2}ms lock={:.2}ms run={:.2}ms extract={:.2}ms decode={:.2}ms istft={:.2}ms mix={:.2}ms",
226        format_ms(perf.total_ns),
227        format_ms(perf.prep_ns),
228        format_ms(perf.stft_ns),
229        format_ms(perf.lock_wait_ns),
230        format_ms(perf.run_ns),
231        format_ms(perf.extract_ns),
232        format_ms(perf.decode_ns),
233        format_ms(perf.istft_ns),
234        format_ms(perf.mix_ns),
235    );
236}
237
238#[cfg(not(feature = "engine-mock"))]
239fn cpu_threading(num_threads: usize) -> OrtThreading {
240    let base = OrtThreading {
241        intra_threads: num_threads.max(1),
242        inter_threads: 1,
243        parallel_execution: false,
244    };
245    apply_thread_overrides(base)
246}
247
248#[cfg(not(feature = "engine-mock"))]
249fn ep_threading(kind: ep::EpKind, num_threads: usize) -> OrtThreading {
250    let base = match kind {
251        ep::EpKind::Cuda | ep::EpKind::CoreML | ep::EpKind::DirectML => OrtThreading {
252            intra_threads: num_threads.clamp(1, 4),
253            inter_threads: 1,
254            parallel_execution: false,
255        },
256        ep::EpKind::OneDNN | ep::EpKind::Cpu => OrtThreading {
257            intra_threads: num_threads.max(1),
258            inter_threads: 1,
259            parallel_execution: false,
260        },
261        ep::EpKind::Xnnpack => OrtThreading {
262            intra_threads: 1,
263            inter_threads: 1,
264            parallel_execution: false,
265        },
266    };
267    apply_thread_overrides(base)
268}
269
270#[cfg(not(feature = "engine-mock"))]
271fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
272    let threading = cpu_threading(num_threads);
273
274    if std::env::var("DEBUG_STEMS").is_ok() {
275        eprintln!(
276            "ℹ️  ORT CPU threading: intra={}, inter={}, parallel={}",
277            threading.intra_threads, threading.inter_threads, threading.parallel_execution
278        );
279    }
280
281    Ok(SessionBuilder::new()?
282        .with_optimization_level(GraphOptimizationLevel::Level3)?
283        .with_intra_threads(threading.intra_threads)?
284        .with_inter_threads(threading.inter_threads)?
285        .with_parallel_execution(threading.parallel_execution)?
286        .commit_from_file(model_path)?)
287}
288
289#[cfg(not(feature = "engine-mock"))]
290fn commit_ep_session(
291    model_path: &std::path::Path,
292    num_threads: usize,
293    kind: ep::EpKind,
294    provider: ort::execution_providers::ExecutionProviderDispatch,
295) -> Result<Session> {
296    let threading = ep_threading(kind, num_threads);
297
298    if std::env::var("DEBUG_STEMS").is_ok() {
299        eprintln!(
300            "ℹ️  ORT EP threading: intra={}, inter={}, parallel={}",
301            threading.intra_threads, threading.inter_threads, threading.parallel_execution
302        );
303    }
304
305    let mut builder = SessionBuilder::new()?
306        .with_optimization_level(GraphOptimizationLevel::Level3)?
307        .with_intra_threads(threading.intra_threads)?
308        .with_inter_threads(threading.inter_threads)?
309        .with_parallel_execution(threading.parallel_execution)?
310        .with_execution_providers(vec![provider])?;
311
312    if matches!(kind, ep::EpKind::Xnnpack) {
313        builder = builder
314            .with_intra_op_spinning(false)?
315            .with_inter_op_spinning(false)?;
316    }
317
318    Ok(builder.commit_from_file(model_path)?)
319}
320
321#[cfg(not(feature = "engine-mock"))]
322fn run_demucs_raw_from_inputs(
323    session: &mut Session,
324    io_spec: &EngineIoSpec,
325    t: usize,
326    f_bins: usize,
327    frames: usize,
328    time_branch: &[f32],
329    spec_branch: &[f32],
330    perf_enabled: bool,
331    perf: &mut WindowPerf,
332) -> Result<(Value, Value)> {
333    let time_value = TensorRef::from_array_view(([1usize, 2, t], time_branch))?;
334    let spec_value = TensorRef::from_array_view(([1usize, 4, f_bins, frames], spec_branch))?;
335
336    let run_start = perf_enabled.then(Instant::now);
337    let mut outputs = if io_spec.use_positional_inputs {
338        session.run(ort::inputs![time_value, spec_value])?
339    } else {
340        session.run(ort::inputs!["input" => time_value, "x" => spec_value])?
341    };
342    if let Some(start) = run_start {
343        perf.run_ns += start.elapsed().as_nanos();
344    }
345
346    let extract_start = perf_enabled.then(Instant::now);
347    let out_freq = outputs
348        .remove("output")
349        .ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
350    let out_time = outputs
351        .remove("add_67")
352        .ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
353    if let Some(start) = extract_start {
354        perf.extract_ns += start.elapsed().as_nanos();
355    }
356
357    Ok((out_time, out_freq))
358}
359
360#[cfg(not(feature = "engine-mock"))]
361fn decode_demucs_outputs(
362    out_time: Value,
363    out_freq: Value,
364    t: usize,
365    f_bins: usize,
366    frames: usize,
367    perf_enabled: bool,
368    perf: &mut WindowPerf,
369) -> Result<DemucsRawOutput> {
370    let decode_start = perf_enabled.then(Instant::now);
371
372    let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
373    if shape_time.len() != 4
374        || shape_time[0] != 1
375        || shape_time[2] != 2
376        || shape_time[3] != t as i64
377    {
378        return Err(anyhow!(
379            "Unexpected time output shape: {:?}, expected [1, sources, 2, {}]",
380            shape_time,
381            t
382        )
383        .into());
384    }
385    let num_sources = shape_time[1] as usize;
386
387    let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
388    if shape_freq.len() != 5
389        || shape_freq[0] != 1
390        || shape_freq[1] != num_sources as i64
391        || shape_freq[2] != 4
392        || shape_freq[3] != f_bins as i64
393        || shape_freq[4] != frames as i64
394    {
395        return Err(anyhow!(
396            "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
397            shape_freq,
398            num_sources,
399            f_bins,
400            frames
401        )
402        .into());
403    }
404
405    let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
406    let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
407
408    let raw = DemucsRawOutput {
409        num_sources,
410        data_time: data_time.to_vec(),
411        data_freq: data_freq.to_vec(),
412        time_max,
413        freq_max,
414    };
415
416    if let Some(start) = decode_start {
417        perf.decode_ns += start.elapsed().as_nanos();
418    }
419
420    Ok(raw)
421}
422
423#[cfg(not(feature = "engine-mock"))]
424fn prepare_demucs_inputs(
425    left: &[f32],
426    right: &[f32],
427    scratch: &mut InferenceScratch,
428    perf_enabled: bool,
429    perf: &mut WindowPerf,
430) -> Result<(usize, usize, usize)> {
431    if left.len() != right.len() {
432        return Err(anyhow!("L/R length mismatch").into());
433    }
434    let t = left.len();
435    if t != DEMUCS_T {
436        return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
437    }
438
439    let prep_start = perf_enabled.then(Instant::now);
440    scratch.fill_time_branch(left, right);
441
442    let stft_start = perf_enabled.then(Instant::now);
443    let (f_bins, frames) = stft_cac_stereo_centered_into(
444        left,
445        right,
446        DEMUCS_NFFT,
447        DEMUCS_HOP,
448        &mut scratch.spec_branch,
449    );
450    if let Some(start) = stft_start {
451        perf.stft_ns += start.elapsed().as_nanos();
452    }
453    if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
454        return Err(anyhow!(
455            "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
456            f_bins,
457            frames,
458            DEMUCS_F,
459            DEMUCS_FRAMES
460        )
461        .into());
462    }
463
464    if let Some(start) = prep_start {
465        perf.prep_ns += start.elapsed().as_nanos();
466    }
467
468    Ok((t, f_bins, frames))
469}
470
471#[cfg(not(feature = "engine-mock"))]
472fn run_demucs_raw_with_session(
473    session: &mut Session,
474    io_spec: &EngineIoSpec,
475    scratch: &mut InferenceScratch,
476    left: &[f32],
477    right: &[f32],
478    perf_enabled: bool,
479    perf: &mut WindowPerf,
480) -> Result<DemucsRawOutput> {
481    let (t, f_bins, frames) = prepare_demucs_inputs(left, right, scratch, perf_enabled, perf)?;
482    let (out_time, out_freq) = run_demucs_raw_from_inputs(
483        session,
484        io_spec,
485        t,
486        f_bins,
487        frames,
488        &scratch.time_branch,
489        &scratch.spec_branch,
490        perf_enabled,
491        perf,
492    )?;
493    decode_demucs_outputs(out_time, out_freq, t, f_bins, frames, perf_enabled, perf)
494}
495
496#[cfg(not(feature = "engine-mock"))]
497pub fn preload(h: &ModelHandle) -> Result<()> {
498    ORT_INIT.get_or_try_init::<_, StemError>(|| {
499        let _ = ort::init().commit();
500        Ok(())
501    })?;
502
503    let num_threads = std::thread::available_parallelism()
504        .map(|n| n.get())
505        .unwrap_or(4);
506
507    let selected = ep::create_best_session(
508        h.local_path.as_path(),
509        num_threads,
510        commit_cpu_session,
511        commit_ep_session,
512        probe_session_health,
513    )?;
514
515    let selected_io = inspect_engine_io(&selected.session)?;
516    ENGINE_IO_SPEC.set(selected_io).ok();
517    ENGINE_PERF.set(*perf_config()).ok();
518    INPUT_SCRATCH
519        .set(Mutex::new(InferenceScratch::with_demucs_capacity()))
520        .ok();
521    ISTFT_SCRATCH
522        .set(Mutex::new(IstftBatchWorkspace::default()))
523        .ok();
524
525    if std::env::var("DEBUG_STEMS").is_ok() {
526        eprintln!(
527            "ℹ️  Engine input binding: {}",
528            if selected_io.use_positional_inputs {
529                "positional"
530            } else {
531                "named"
532            }
533        );
534    }
535
536    ENGINE_CONTEXT
537        .set(EngineContext {
538            model_path: h.local_path.clone(),
539            num_threads,
540            selected_kind: selected.kind,
541        })
542        .ok();
543    RUNTIME_EP_FALLBACK_USED.store(false, Ordering::Relaxed);
544
545    SESSION.set(Mutex::new(selected.session)).ok();
546    MANIFEST.set(h.manifest.clone()).ok();
547    Ok(())
548}
549
550#[cfg(not(feature = "engine-mock"))]
551pub fn manifest() -> &'static ModelManifest {
552    MANIFEST
553        .get()
554        .expect("engine::preload() must be called once before using the engine")
555}
556
/// Marker text for near-silent-output errors. Fallback logic detects such
/// errors by searching for this substring.
#[cfg(not(feature = "engine-mock"))]
const NEAR_SILENT_ERROR_PREFIX: &str = "near-silent execution output";

/// What `run_window_demucs` should do after a failed window.
#[cfg(not(feature = "engine-mock"))]
enum RuntimeFallbackDecision {
    /// Replace the session with a CPU session and retry the window.
    RetryOnCpu,
    /// A non-CPU provider was forced; fail instead of falling back.
    ForcedProviderError,
    /// Return the original error unchanged.
    PropagateOriginal,
}
566
/// True when both output peaks are below their silence thresholds:
/// time `< 1e-6` and freq `< 1e-3`. A NaN peak counts as not silent.
#[cfg(not(feature = "engine-mock"))]
fn output_is_near_silent(time_max: f32, freq_max: f32) -> bool {
    let time_quiet = time_max < 1e-6;
    let freq_quiet = freq_max < 1e-3;
    time_quiet && freq_quiet
}
571
/// True when the peak absolute sample across both channels is below `1e-4`.
///
/// NaN samples are ignored by `f32::max`, and empty input counts as silent.
#[cfg(not(feature = "engine-mock"))]
fn input_is_near_silent(left: &[f32], right: &[f32]) -> bool {
    let peak = left
        .iter()
        .chain(right)
        .map(|sample| sample.abs())
        .fold(0.0f32, f32::max);
    peak < 1e-4
}
578
579#[cfg(not(feature = "engine-mock"))]
580fn build_preload_probe_input() -> (Vec<f32>, Vec<f32>) {
581    use std::f32::consts::TAU;
582
583    let sample_rate = 44_100.0f32;
584    let mut left = Vec::with_capacity(DEMUCS_T);
585    let mut right = Vec::with_capacity(DEMUCS_T);
586
587    for i in 0..DEMUCS_T {
588        let t = i as f32 / sample_rate;
589        left.push(0.22 * (TAU * 220.0 * t).sin() + 0.11 * (TAU * 660.0 * t).sin());
590        right.push(0.20 * (TAU * 330.0 * t).sin() + 0.09 * (TAU * 880.0 * t).cos());
591    }
592
593    (left, right)
594}
595
596#[cfg(not(feature = "engine-mock"))]
597fn preload_probe_input() -> &'static (Vec<f32>, Vec<f32>) {
598    PRELOAD_PROBE_INPUT.get_or_init(build_preload_probe_input)
599}
600
601#[cfg(not(feature = "engine-mock"))]
602fn ensure_output_is_not_near_silent(
603    left: &[f32],
604    right: &[f32],
605    raw: &DemucsRawOutput,
606) -> Result<()> {
607    if !input_is_near_silent(left, right) && output_is_near_silent(raw.time_max, raw.freq_max) {
608        return Err(anyhow!(
609            "{} (time_max={:.3e}, freq_max={:.3e})",
610            NEAR_SILENT_ERROR_PREFIX,
611            raw.time_max,
612            raw.freq_max
613        )
614        .into());
615    }
616
617    Ok(())
618}
619
620#[cfg(not(feature = "engine-mock"))]
621fn probe_session_health(session: &mut Session) -> Result<()> {
622    let (left, right) = preload_probe_input();
623    let io_spec = inspect_engine_io(session)?;
624    let mut scratch = InferenceScratch::with_demucs_capacity();
625    let mut perf = WindowPerf::default();
626    let raw = run_demucs_raw_with_session(
627        session,
628        &io_spec,
629        &mut scratch,
630        left,
631        right,
632        false,
633        &mut perf,
634    )?;
635    ensure_output_is_not_near_silent(left, right, &raw)
636}
637
/// True when `STEMMER_EP_FORCE` names a provider other than CPU.
///
/// The value is trimmed and compared case-insensitively. An empty value or
/// `cpu` does not count as forced.
#[cfg(not(feature = "engine-mock"))]
fn is_forced_non_cpu_ep() -> bool {
    match std::env::var("STEMMER_EP_FORCE") {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            !(normalized.is_empty() || normalized == "cpu")
        }
        Err(_) => false,
    }
}
647
648#[cfg(not(feature = "engine-mock"))]
649fn near_silent_error(message: &str) -> bool {
650    message.contains(NEAR_SILENT_ERROR_PREFIX)
651}
652
653#[cfg(not(feature = "engine-mock"))]
654fn runtime_fallback_decision(
655    error_text: &str,
656    forced_non_cpu_ep: bool,
657    fallback_already_used: bool,
658) -> RuntimeFallbackDecision {
659    if !near_silent_error(error_text) {
660        return RuntimeFallbackDecision::PropagateOriginal;
661    }
662    if forced_non_cpu_ep {
663        return RuntimeFallbackDecision::ForcedProviderError;
664    }
665    if fallback_already_used {
666        return RuntimeFallbackDecision::PropagateOriginal;
667    }
668    RuntimeFallbackDecision::RetryOnCpu
669}
670
671#[cfg(not(feature = "engine-mock"))]
672pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
673    if left.len() != right.len() {
674        return Err(anyhow!("L/R length mismatch").into());
675    }
676    if left.len() != DEMUCS_T {
677        return Err(anyhow!("Bad window length {} (expected {})", left.len(), DEMUCS_T).into());
678    }
679
680    let debug_enabled = std::env::var("DEBUG_STEMS").is_ok();
681    let perf_enabled = perf_config().enabled;
682
683    match run_window_demucs_once(left, right, debug_enabled, perf_enabled) {
684        Ok(out) => Ok(out),
685        Err(e) => {
686            let error_text = e.to_string();
687            let forced_non_cpu_ep = is_forced_non_cpu_ep();
688            let fallback_already_used = RUNTIME_EP_FALLBACK_USED.load(Ordering::SeqCst);
689
690            match runtime_fallback_decision(&error_text, forced_non_cpu_ep, fallback_already_used) {
691                RuntimeFallbackDecision::ForcedProviderError => {
692                    if debug_enabled {
693                        eprintln!(
694                            "⚠️  Runtime EP output was near-silent and STEMMER_EP_FORCE is set; refusing CPU fallback"
695                        );
696                    }
697                    return Err(anyhow!(
698                        "Forced execution provider produced near-silent runtime output; refusing CPU fallback"
699                    )
700                    .into());
701                }
702                RuntimeFallbackDecision::PropagateOriginal => {
703                    if near_silent_error(&error_text) && debug_enabled {
704                        eprintln!(
705                            "⚠️  Runtime EP output remained near-silent after fallback; propagating original error"
706                        );
707                    }
708                    return Err(e);
709                }
710                RuntimeFallbackDecision::RetryOnCpu => {}
711            }
712
713            RUNTIME_EP_FALLBACK_USED.store(true, Ordering::SeqCst);
714
715            let ctx = ENGINE_CONTEXT
716                .get()
717                .ok_or_else(|| anyhow!("engine context missing for runtime fallback"))?;
718
719            if ctx.selected_kind != ep::EpKind::Cpu {
720                if let Err(cache_err) = ep_cache::mark_unhealthy(
721                    ctx.selected_kind.env_name(),
722                    &ctx.model_path,
723                    &error_text,
724                ) {
725                    if debug_enabled {
726                        eprintln!(
727                            "⚠️  Failed to persist unhealthy EP cache entry: {}",
728                            cache_err
729                        );
730                    }
731                } else if debug_enabled {
732                    eprintln!(
733                        "ℹ️  Marked {} as unhealthy for this model (cached for 7 days)",
734                        ctx.selected_kind.label()
735                    );
736                }
737            }
738
739            if debug_enabled {
740                eprintln!(
741                    "⚠️  Runtime EP output was near-silent; switching to CPU and retrying this chunk"
742                );
743            }
744
745            let cpu_session = commit_cpu_session(&ctx.model_path, ctx.num_threads)?;
746            let mut session = SESSION
747                .get()
748                .expect("engine::preload first")
749                .lock()
750                .expect("session poisoned");
751            *session = cpu_session;
752            drop(session);
753
754            match run_window_demucs_once(left, right, debug_enabled, perf_enabled) {
755                Ok(out) => {
756                    if debug_enabled {
757                        eprintln!("✅ Runtime fallback succeeded: CPU is now active");
758                    }
759                    Ok(out)
760                }
761                Err(retry_error) => {
762                    if debug_enabled {
763                        eprintln!("❌ Runtime fallback to CPU failed: {}", retry_error);
764                    }
765                    Err(retry_error)
766                }
767            }
768        }
769    }
770}
771
772#[cfg(not(feature = "engine-mock"))]
773fn run_window_demucs_once(
774    left: &[f32],
775    right: &[f32],
776    debug_enabled: bool,
777    perf_enabled: bool,
778) -> Result<Array3<f32>> {
779    let total_start = perf_enabled.then(Instant::now);
780    let mut perf = WindowPerf::default();
781
782    let raw = {
783        let mut scratch = input_scratch().lock().expect("input scratch poisoned");
784        let (t, f_bins, frames) =
785            prepare_demucs_inputs(left, right, &mut scratch, perf_enabled, &mut perf)?;
786
787        let lock_start = perf_enabled.then(Instant::now);
788        let mut session = SESSION
789            .get()
790            .expect("engine::preload first")
791            .lock()
792            .expect("session poisoned");
793        if let Some(start) = lock_start {
794            perf.lock_wait_ns += start.elapsed().as_nanos();
795        }
796
797        let (out_time, out_freq) = run_demucs_raw_from_inputs(
798            &mut session,
799            io_spec(),
800            t,
801            f_bins,
802            frames,
803            &scratch.time_branch,
804            &scratch.spec_branch,
805            perf_enabled,
806            &mut perf,
807        )?;
808        drop(session);
809        drop(scratch);
810
811        decode_demucs_outputs(
812            out_time,
813            out_freq,
814            t,
815            f_bins,
816            frames,
817            perf_enabled,
818            &mut perf,
819        )?
820    };
821
822    let out = postprocess_demucs_output(raw, left, right, debug_enabled, perf_enabled, &mut perf)?;
823
824    if let Some(start) = total_start {
825        perf.total_ns = start.elapsed().as_nanos();
826        log_window_perf(&perf);
827    }
828
829    Ok(out)
830}
831
832#[cfg(not(feature = "engine-mock"))]
833fn postprocess_demucs_output(
834    mut raw: DemucsRawOutput,
835    left: &[f32],
836    right: &[f32],
837    debug_enabled: bool,
838    perf_enabled: bool,
839    perf: &mut WindowPerf,
840) -> Result<Array3<f32>> {
841    let t = left.len();
842    let num_sources = raw.num_sources;
843
844    if debug_enabled {
845        eprintln!(
846            "Model output stats: time_max={:.6}, freq_max={:.6}",
847            raw.time_max, raw.freq_max
848        );
849    }
850
851    ensure_output_is_not_near_silent(left, right, &raw)?;
852
853    let source_specs: Vec<&[f32]> = (0..num_sources)
854        .map(|src| {
855            let src_freq_offset = src * 4 * DEMUCS_F * DEMUCS_FRAMES;
856            &raw.data_freq[src_freq_offset..src_freq_offset + 4 * DEMUCS_F * DEMUCS_FRAMES]
857        })
858        .collect();
859
860    let istft_start = perf_enabled.then(Instant::now);
861    {
862        let mut istft_ws = istft_scratch().lock().expect("iSTFT scratch poisoned");
863        istft_cac_stereo_sources_add_into(
864            &source_specs,
865            DEMUCS_F,
866            DEMUCS_FRAMES,
867            DEMUCS_NFFT,
868            DEMUCS_HOP,
869            t,
870            &mut istft_ws,
871            &mut raw.data_time,
872        );
873    }
874    if let Some(start) = istft_start {
875        perf.istft_ns += start.elapsed().as_nanos();
876    }
877
878    if debug_enabled {
879        for src_idx in 0..num_sources {
880            let src_time_offset = src_idx * 2 * t;
881            let left_mix = &raw.data_time[src_time_offset..src_time_offset + t];
882            let right_mix = &raw.data_time[src_time_offset + t..src_time_offset + 2 * t];
883            let left_max = left_mix.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
884            let right_max = right_mix.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
885            eprintln!(
886                "Combined output [source {}]: left_max={:.6}, right_max={:.6}",
887                src_idx, left_max, right_max
888            );
889        }
890    }
891
892    let mix_start = perf_enabled.then(Instant::now);
893
894    if let Some(start) = mix_start {
895        perf.mix_ns += start.elapsed().as_nanos();
896    }
897
898    Ok(ndarray::Array3::from_shape_vec(
899        (num_sources, 2, t),
900        raw.data_time,
901    )?)
902}
903
#[cfg(not(feature = "engine-mock"))]
#[cfg(test)]
mod runtime_policy_tests {
    use super::*;

    /// A representative message produced by the near-silence check.
    const SILENT_ERR: &str = "near-silent execution output (time_max=0, freq_max=0)";

    #[test]
    fn fallback_retries_on_cpu_when_near_silent_and_not_forced() {
        let decision = runtime_fallback_decision(SILENT_ERR, false, false);
        assert!(matches!(decision, RuntimeFallbackDecision::RetryOnCpu));
    }

    #[test]
    fn fallback_refuses_when_forced_provider() {
        let decision = runtime_fallback_decision(SILENT_ERR, true, false);
        assert!(matches!(decision, RuntimeFallbackDecision::ForcedProviderError));
    }

    #[test]
    fn fallback_does_not_retry_twice() {
        let decision = runtime_fallback_decision(SILENT_ERR, false, true);
        assert!(matches!(decision, RuntimeFallbackDecision::PropagateOriginal));
    }

    #[test]
    fn fallback_ignores_non_silent_errors() {
        let decision = runtime_fallback_decision("Model missing input 'x'", false, false);
        assert!(matches!(decision, RuntimeFallbackDecision::PropagateOriginal));
    }

    #[test]
    fn near_silent_threshold_checks() {
        let cases = [(1e-7, 1e-4, true), (1e-4, 1e-4, false), (1e-7, 1e-2, false)];
        for (time_max, freq_max, expected) in cases {
            assert_eq!(output_is_near_silent(time_max, freq_max), expected);
        }
    }

    #[test]
    fn input_silence_threshold_checks() {
        let quiet = [0.0f32; 16];
        let loud = [5e-4f32; 16];
        assert!(input_is_near_silent(&quiet, &quiet));
        assert!(!input_is_near_silent(&loud, &quiet));
    }

    #[test]
    fn preload_probe_input_is_loud_enough_for_health_checks() {
        let (left, right) = build_preload_probe_input();
        assert_eq!((left.len(), right.len()), (DEMUCS_T, DEMUCS_T));
        assert!(!input_is_near_silent(&left, &right));
    }

    #[test]
    fn positional_binding_detection_requires_expected_input_order() {
        assert!(use_positional_inputs(&["input", "x"]));
        assert!(!use_positional_inputs(&["x", "input"]));
        assert!(!use_positional_inputs(&["input"]));
    }
}
984
/// Test double used when the `engine-mock` feature is on: it runs no model and
/// echoes the input into every stem.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest recorded by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Records the manifest (the first call wins) and always succeeds.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest stored by `preload`; panics if it was never called.
    pub fn manifest() -> &'static ModelManifest {
        MANIFEST.get().expect("preload first (mock)")
    }

    /// Returns 4 "identity" stems, each a copy of the input truncated to the
    /// shorter channel, shaped `(4, 2, t)`.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;
        let t = left.len().min(right.len());

        // One stereo stem laid out as [L..., R...], then repeated per source.
        let mut stem = Vec::with_capacity(2 * t);
        stem.extend_from_slice(&left[..t]);
        stem.extend_from_slice(&right[..t]);
        let out = stem.repeat(SOURCES);

        Ok(ndarray::Array3::from_shape_vec((SOURCES, 2, t), out)?)
    }
}

#[cfg(feature = "engine-mock")]
pub use _engine_mock::{manifest, preload, run_window_demucs};