stem_splitter_core/core/
engine.rs

1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4    core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5    error::{Result, StemError},
6    model::model_manager::ModelHandle,
7    types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14    execution_providers::ExecutionProviderDispatch,
15    session::{
16        builder::{GraphOptimizationLevel, SessionBuilder},
17        Session,
18    },
19    value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23// CUDA: Linux and Windows only
24#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26// CoreML: macOS only (Apple Silicon)
27#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29// DirectML: Windows only
30#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32// oneDNN: All platforms
33#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
36static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
37static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
38static ORT_INIT: OnceCell<()> = OnceCell::new();
39
40const DEMUCS_T: usize = 343_980;
41const DEMUCS_F: usize = 2048;
42const DEMUCS_FRAMES: usize = 336;
43const DEMUCS_NFFT: usize = 4096;
44const DEMUCS_HOP: usize = 1024;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48    let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51    {
52        providers.push(CUDAExecutionProvider::default().build());
53    }
54
55    #[cfg(all(feature = "coreml", target_os = "macos"))]
56    {
57        // CoreML can sometimes produce silent/zero outputs on certain models
58        // Only enable if ENABLE_COREML env var is set
59        if std::env::var("ENABLE_COREML").is_ok() {
60            eprintln!("CoreML enabled via ENABLE_COREML environment variable");
61            providers.push(CoreMLExecutionProvider::default().build());
62        } else {
63            eprintln!("CoreML disabled by default (set ENABLE_COREML=1 to enable)");
64        }
65    }
66
67    #[cfg(all(feature = "directml", target_os = "windows"))]
68    {
69        // DirectML can fail on some models/drivers (init errors). Keep it opt-in.
70        if std::env::var("ENABLE_DIRECTML").is_ok() {
71            eprintln!("DirectML enabled via ENABLE_DIRECTML environment variable");
72            providers.push(DirectMLExecutionProvider::default().build());
73        } else {
74            eprintln!("DirectML disabled by default (set ENABLE_DIRECTML=1 to enable)");
75        }
76    }
77
78    #[cfg(feature = "onednn")]
79    {
80        // oneDNN can improve performance on Intel CPUs
81        providers.push(OneDNNExecutionProvider::default().build());
82    }
83
84    providers
85}
86
87#[cfg(not(feature = "engine-mock"))]
88fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
89    Ok(SessionBuilder::new()?
90        .with_optimization_level(GraphOptimizationLevel::Level3)?
91        .with_intra_threads(num_threads)?
92        .with_inter_threads(num_threads)?
93        .with_parallel_execution(true)?
94        .commit_from_file(model_path)?)
95}
96
97#[cfg(not(feature = "engine-mock"))]
98fn commit_session_sequential_eps(
99    model_path: &std::path::Path,
100    num_threads: usize,
101    providers: Vec<ExecutionProviderDispatch>,
102) -> Result<Session> {
103    if providers.is_empty() {
104        eprintln!("Using CPU ({} threads) - no GPU features enabled", num_threads);
105        return commit_cpu_session(model_path, num_threads);
106    }
107
108    eprintln!(
109        "Trying execution providers sequentially ({} candidates) with CPU fallback",
110        providers.len()
111    );
112
113    for (idx, ep) in providers.into_iter().enumerate() {
114        let builder_res = SessionBuilder::new()?
115            .with_optimization_level(GraphOptimizationLevel::Level3)?
116            .with_intra_threads(num_threads)?
117            .with_inter_threads(num_threads)?
118            .with_execution_providers(vec![ep]);
119
120        let builder = match builder_res {
121            Ok(b) => b,
122            Err(e) => {
123                eprintln!("EP builder failed (attempt #{}): {}", idx + 1, e);
124                continue;
125            }
126        };
127
128        match builder.commit_from_file(model_path) {
129            Ok(sess) => {
130                eprintln!("Execution provider selected (attempt #{}).", idx + 1);
131                return Ok(sess);
132            }
133            Err(e) => {
134                eprintln!("EP commit failed (attempt #{}): {}", idx + 1, e);
135                continue;
136            }
137        }
138    }
139
140    eprintln!("All EPs failed; falling back to CPU ({} threads)", num_threads);
141    commit_cpu_session(model_path, num_threads)
142}
143
144
145#[cfg(not(feature = "engine-mock"))]
146pub fn preload(h: &ModelHandle) -> Result<()> {
147    ORT_INIT.get_or_try_init::<_, StemError>(|| {
148        ort::init().commit().map_err(StemError::from)?;
149        Ok(())
150    })?;
151
152    let num_threads = std::thread::available_parallelism()
153        .map(|n| n.get())
154        .unwrap_or(4);
155
156    // Debug / escape hatch: force CPU
157    if std::env::var("STEMMER_FORCE_CPU").is_ok() {
158        eprintln!("STEMMER_FORCE_CPU is set: using CPU only");
159        let session = commit_cpu_session(h.local_path.as_path(), num_threads)?;
160        SESSION.set(Mutex::new(session)).ok();
161        MANIFEST.set(h.manifest.clone()).ok();
162        return Ok(());
163    }
164
165    // Build provider list (may be empty)
166    let providers = get_execution_providers();
167
168    // Optional: print provider list names (for logs)
169    #[allow(unused_mut)]
170    let mut provider_names: Vec<&str> = Vec::new();
171    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
172    provider_names.push("CUDA");
173    #[cfg(all(feature = "coreml", target_os = "macos"))]
174    provider_names.push("CoreML");
175    #[cfg(all(feature = "directml", target_os = "windows"))]
176    provider_names.push("DirectML (opt-in)");
177    #[cfg(feature = "onednn")]
178    provider_names.push("oneDNN");
179
180    eprintln!("Configured EP candidates: {:?}", provider_names);
181
182    let session = commit_session_sequential_eps(h.local_path.as_path(), num_threads, providers)?;
183
184    SESSION.set(Mutex::new(session)).ok();
185    MANIFEST.set(h.manifest.clone()).ok();
186    Ok(())
187}
188
189#[cfg(not(feature = "engine-mock"))]
190pub fn manifest() -> &'static ModelManifest {
191    MANIFEST
192        .get()
193        .expect("engine::preload() must be called once before using the engine")
194}
195
196#[cfg(not(feature = "engine-mock"))]
197pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
198    if left.len() != right.len() {
199        return Err(anyhow!("L/R length mismatch").into());
200    }
201    let t = left.len();
202    if t != DEMUCS_T {
203        return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
204    }
205
206    // Build time branch [1,2,T], planar
207    let mut planar = Vec::with_capacity(2 * t);
208    planar.extend_from_slice(left);
209    planar.extend_from_slice(right);
210    let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
211
212    // Build spec branch [1,4,F,Frames] with center padding, Hann, 4096/1024
213    let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
214    if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
215        return Err(anyhow!(
216            "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
217            f_bins,
218            frames,
219            DEMUCS_F,
220            DEMUCS_FRAMES
221        )
222        .into());
223    }
224    let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
225
226    let mut session = SESSION
227        .get()
228        .expect("engine::preload first")
229        .lock()
230        .expect("session poisoned");
231
232    // Get input names
233    let in_time = session
234        .inputs
235        .iter()
236        .find(|i| i.name == "input")
237        .map(|i| i.name.clone())
238        .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
239
240    let in_spec = session
241        .inputs
242        .iter()
243        .find(|i| i.name == "x")
244        .map(|i| i.name.clone())
245        .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
246
247    // Run inference
248    let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
249
250    // Extract both outputs from the model
251    // "output": frequency domain [1, sources, 4, F, Frames]
252    // "add_67": time domain [1, sources, 2, T]
253    let mut output_freq: Option<Value> = None;
254    let mut output_time: Option<Value> = None;
255
256    for (name, val) in outputs.into_iter() {
257        if name == "output" {
258            output_freq = Some(val);
259        } else if name == "add_67" {
260            output_time = Some(val);
261        }
262    }
263
264    let out_freq =
265        output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
266    let out_time =
267        output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
268
269    // Extract time domain output [1, 4, 2, T] -> [4, 2, T]
270    let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
271    let num_sources = shape_time[1] as usize;
272
273    // Extract frequency domain output [1, sources, 4, F, Frames]
274    let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
275
276    // Debug: Check if model outputs are non-zero
277    if std::env::var("DEBUG_STEMS").is_ok() {
278        let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
279        let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
280        eprintln!(
281            "Model output stats: time_max={:.6}, freq_max={:.6}",
282            time_max, freq_max
283        );
284        if time_max < 1e-10 && freq_max < 1e-10 {
285            eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
286        }
287    }
288
289    // Validate shapes
290    if shape_freq[0] != 1
291        || shape_freq[1] != num_sources as i64
292        || shape_freq[2] != 4
293        || shape_freq[3] != f_bins as i64
294        || shape_freq[4] != frames as i64
295    {
296        return Err(anyhow!(
297            "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
298            shape_freq,
299            num_sources,
300            f_bins,
301            frames
302        )
303        .into());
304    }
305
306    let source_specs: Vec<&[f32]> = (0..num_sources)
307        .map(|src| {
308            let src_freq_offset = src * 4 * f_bins * frames;
309            &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
310        })
311        .collect();
312
313    let istft_results =
314        istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
315
316    // Debug: Check iSTFT results
317    if std::env::var("DEBUG_STEMS").is_ok() {
318        for (src_idx, (left, right)) in istft_results.iter().enumerate() {
319            let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
320            let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
321            eprintln!(
322                "iSTFT result [source {}]: left_max={:.6}, right_max={:.6}",
323                src_idx, left_max, right_max
324            );
325        }
326    }
327
328    let mut result = Vec::with_capacity(num_sources * 2 * t);
329
330    for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
331        // Extract time domain for this source [2, T]
332        let src_time_offset = src * 2 * t;
333        let left_time = &data_time[src_time_offset..src_time_offset + t];
334        let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
335
336        // Combine: output = time_domain + frequency_domain (after iSTFT)
337        for i in 0..t {
338            result.push(left_time[i] + left_freq[i]);
339        }
340        for i in 0..t {
341            result.push(right_time[i] + right_freq[i]);
342        }
343    }
344
345    let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
346    Ok(out)
347}
348
/// Stand-in engine compiled with the `engine-mock` feature: it loads no model
/// and simply echoes the input audio back as every stem.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest remembered by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Stores the handle's manifest; a manifest that is already stored is kept.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the stored manifest; panics if `preload` has not run yet.
    pub fn manifest() -> &'static ModelManifest {
        let stored = MANIFEST.get();
        stored.expect("preload first (mock)")
    }

    /// Produces four "identity" stems of shape `[4, 2, t]`, where `t` is the
    /// shorter of the two channel lengths and each stem is a copy of the input.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;
        let t = usize::min(left.len(), right.len());

        // Per source: `t` left samples followed by `t` right samples.
        let mut stems = Vec::with_capacity(SOURCES * 2 * t);
        for _ in 0..SOURCES {
            stems.extend_from_slice(&left[..t]);
            stems.extend_from_slice(&right[..t]);
        }

        let shaped = ndarray::Array3::from_shape_vec((SOURCES, 2, t), stems)?;
        Ok(shaped)
    }
}
378
379#[cfg(feature = "engine-mock")]
380pub use _engine_mock::{manifest, preload, run_window_demucs};