// stem_splitter_core/core/engine.rs

1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4    core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5    error::{Result, StemError},
6    model::model_manager::ModelHandle,
7    types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14    execution_providers::ExecutionProviderDispatch,
15    session::{
16        builder::{GraphOptimizationLevel, SessionBuilder},
17        Session,
18    },
19    value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23// CUDA: Linux and Windows only
24#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26// CoreML: macOS only (Apple Silicon)
27#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29// DirectML: Windows only
30#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32// oneDNN: All platforms
33#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
36static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
37static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
38static ORT_INIT: OnceCell<()> = OnceCell::new();
39
/// Window length (samples per channel) that `run_window_demucs` accepts.
const DEMUCS_T: usize = 343_980;
/// STFT size used to build the model's spectrogram input.
const DEMUCS_NFFT: usize = 4096;
/// STFT hop size used to build the model's spectrogram input.
const DEMUCS_HOP: usize = 1024;
/// Frequency bins expected in the spectrogram input (half the FFT size, i.e. 2048).
const DEMUCS_F: usize = DEMUCS_NFFT / 2;
/// Spectrogram frames expected for one `DEMUCS_T`-sample window.
const DEMUCS_FRAMES: usize = 336;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48    let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51    {
52        providers.push(CUDAExecutionProvider::default().build());
53    }
54
55    #[cfg(all(feature = "coreml", target_os = "macos"))]
56    {
57        // CoreML can sometimes produce silent/zero outputs on certain models
58        // Only enable if ENABLE_COREML env var is set
59        if std::env::var("ENABLE_COREML").is_ok() {
60            if std::env::var("DEBUG_STEMS").is_ok() {
61                eprintln!("ℹ️  CoreML enabled via ENABLE_COREML environment variable");
62            }
63            providers.push(CoreMLExecutionProvider::default().build());
64        } else if std::env::var("DEBUG_STEMS").is_ok() {
65            eprintln!("ℹ️  CoreML disabled by default (set ENABLE_COREML=1 to enable)");
66        }
67    }
68
69    #[cfg(all(feature = "directml", target_os = "windows"))]
70    {
71        // DirectML can fail on some models/drivers (init errors). Keep it opt-in.
72        if std::env::var("ENABLE_DIRECTML").is_ok() {
73            if std::env::var("DEBUG_STEMS").is_ok() {
74                eprintln!("ℹ️  DirectML enabled via ENABLE_DIRECTML environment variable");
75            }
76            providers.push(DirectMLExecutionProvider::default().build());
77        } else if std::env::var("DEBUG_STEMS").is_ok() {
78            eprintln!("ℹ️  DirectML disabled by default (set ENABLE_DIRECTML=1 to enable)");
79        }
80    }
81
82    #[cfg(feature = "onednn")]
83    {
84        // oneDNN can improve performance on Intel CPUs
85        providers.push(OneDNNExecutionProvider::default().build());
86    }
87
88    providers
89}
90
91#[cfg(not(feature = "engine-mock"))]
92fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
93    Ok(SessionBuilder::new()?
94        .with_optimization_level(GraphOptimizationLevel::Level3)?
95        .with_intra_threads(num_threads)?
96        .with_inter_threads(num_threads)?
97        .with_parallel_execution(true)?
98        .commit_from_file(model_path)?)
99}
100
101#[cfg(not(feature = "engine-mock"))]
102fn commit_session_sequential_eps(
103    model_path: &std::path::Path,
104    num_threads: usize,
105    providers: Vec<ExecutionProviderDispatch>,
106) -> Result<Session> {
107    if providers.is_empty() {
108        if std::env::var("DEBUG_STEMS").is_ok() {
109            eprintln!("ℹ️  Using CPU ({} threads) - no GPU features enabled", num_threads);
110        }
111        return commit_cpu_session(model_path, num_threads);
112    }
113
114    if std::env::var("DEBUG_STEMS").is_ok() {
115        eprintln!(
116            "ℹ️  Trying execution providers sequentially ({} candidates) with CPU fallback",
117            providers.len()
118        );
119    }
120
121    for (idx, ep) in providers.into_iter().enumerate() {
122        let builder_res = SessionBuilder::new()?
123            .with_optimization_level(GraphOptimizationLevel::Level3)?
124            .with_intra_threads(num_threads)?
125            .with_inter_threads(num_threads)?
126            .with_execution_providers(vec![ep]);
127
128        let builder = match builder_res {
129            Ok(b) => b,
130            Err(e) => {
131                if std::env::var("DEBUG_STEMS").is_ok() {
132                    eprintln!("⚠️  EP builder failed (attempt #{}): {}", idx + 1, e);
133                }
134                continue;
135            }
136        };
137
138        match builder.commit_from_file(model_path) {
139            Ok(sess) => {
140                if std::env::var("DEBUG_STEMS").is_ok() {
141                    eprintln!("✅ Execution provider selected (attempt #{}).", idx + 1);
142                }
143                return Ok(sess);
144            }
145            Err(e) => {
146                if std::env::var("DEBUG_STEMS").is_ok() {
147                    eprintln!("⚠️  EP commit failed (attempt #{}): {}", idx + 1, e);
148                }
149                continue;
150            }
151        }
152    }
153
154    if std::env::var("DEBUG_STEMS").is_ok() {
155        eprintln!("⚠️  All EPs failed; falling back to CPU ({} threads)", num_threads);
156    }
157    commit_cpu_session(model_path, num_threads)
158}
159
160
161#[cfg(not(feature = "engine-mock"))]
162pub fn preload(h: &ModelHandle) -> Result<()> {
163    ORT_INIT.get_or_try_init::<_, StemError>(|| {
164        ort::init().commit().map_err(StemError::from)?;
165        Ok(())
166    })?;
167
168    let num_threads = std::thread::available_parallelism()
169        .map(|n| n.get())
170        .unwrap_or(4);
171
172    // Debug / escape hatch: force CPU
173    if std::env::var("STEMMER_FORCE_CPU").is_ok() {
174        if std::env::var("DEBUG_STEMS").is_ok() {
175            eprintln!("ℹ️  STEMMER_FORCE_CPU is set: using CPU only");
176        }
177        let session = commit_cpu_session(h.local_path.as_path(), num_threads)?;
178        SESSION.set(Mutex::new(session)).ok();
179        MANIFEST.set(h.manifest.clone()).ok();
180        return Ok(());
181    }
182
183    // Build provider list (may be empty)
184    let providers = get_execution_providers();
185
186    // Optional: print provider list names (for logs)
187    #[allow(unused_mut)]
188    let mut provider_names: Vec<&str> = Vec::new();
189    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
190    provider_names.push("CUDA");
191    #[cfg(all(feature = "coreml", target_os = "macos"))]
192    provider_names.push("CoreML");
193    #[cfg(all(feature = "directml", target_os = "windows"))]
194    provider_names.push("DirectML (opt-in)");
195    #[cfg(feature = "onednn")]
196    provider_names.push("oneDNN");
197
198    if std::env::var("DEBUG_STEMS").is_ok() {
199        eprintln!("ℹ️  Configured EP candidates: {:?}", provider_names);
200    }
201
202    let session = commit_session_sequential_eps(h.local_path.as_path(), num_threads, providers)?;
203
204    SESSION.set(Mutex::new(session)).ok();
205    MANIFEST.set(h.manifest.clone()).ok();
206    Ok(())
207}
208
209#[cfg(not(feature = "engine-mock"))]
210pub fn manifest() -> &'static ModelManifest {
211    MANIFEST
212        .get()
213        .expect("engine::preload() must be called once before using the engine")
214}
215
216#[cfg(not(feature = "engine-mock"))]
217pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
218    if left.len() != right.len() {
219        return Err(anyhow!("L/R length mismatch").into());
220    }
221    let t = left.len();
222    if t != DEMUCS_T {
223        return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
224    }
225
226    // Build time branch [1,2,T], planar
227    let mut planar = Vec::with_capacity(2 * t);
228    planar.extend_from_slice(left);
229    planar.extend_from_slice(right);
230    let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
231
232    // Build spec branch [1,4,F,Frames] with center padding, Hann, 4096/1024
233    let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
234    if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
235        return Err(anyhow!(
236            "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
237            f_bins,
238            frames,
239            DEMUCS_F,
240            DEMUCS_FRAMES
241        )
242        .into());
243    }
244    let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
245
246    let mut session = SESSION
247        .get()
248        .expect("engine::preload first")
249        .lock()
250        .expect("session poisoned");
251
252    // Get input names
253    let in_time = session
254        .inputs
255        .iter()
256        .find(|i| i.name == "input")
257        .map(|i| i.name.clone())
258        .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
259
260    let in_spec = session
261        .inputs
262        .iter()
263        .find(|i| i.name == "x")
264        .map(|i| i.name.clone())
265        .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
266
267    // Run inference
268    let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
269
270    // Extract both outputs from the model
271    // "output": frequency domain [1, sources, 4, F, Frames]
272    // "add_67": time domain [1, sources, 2, T]
273    let mut output_freq: Option<Value> = None;
274    let mut output_time: Option<Value> = None;
275
276    for (name, val) in outputs.into_iter() {
277        if name == "output" {
278            output_freq = Some(val);
279        } else if name == "add_67" {
280            output_time = Some(val);
281        }
282    }
283
284    let out_freq =
285        output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
286    let out_time =
287        output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
288
289    // Extract time domain output [1, 4, 2, T] -> [4, 2, T]
290    let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
291    let num_sources = shape_time[1] as usize;
292
293    // Extract frequency domain output [1, sources, 4, F, Frames]
294    let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
295
296    // Debug: Check if model outputs are non-zero
297    if std::env::var("DEBUG_STEMS").is_ok() {
298        let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
299        let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
300        eprintln!(
301            "Model output stats: time_max={:.6}, freq_max={:.6}",
302            time_max, freq_max
303        );
304        if time_max < 1e-10 && freq_max < 1e-10 {
305            eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
306        }
307    }
308
309    // Validate shapes
310    if shape_freq[0] != 1
311        || shape_freq[1] != num_sources as i64
312        || shape_freq[2] != 4
313        || shape_freq[3] != f_bins as i64
314        || shape_freq[4] != frames as i64
315    {
316        return Err(anyhow!(
317            "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
318            shape_freq,
319            num_sources,
320            f_bins,
321            frames
322        )
323        .into());
324    }
325
326    let source_specs: Vec<&[f32]> = (0..num_sources)
327        .map(|src| {
328            let src_freq_offset = src * 4 * f_bins * frames;
329            &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
330        })
331        .collect();
332
333    let istft_results =
334        istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
335
336    // Debug: Check iSTFT results
337    if std::env::var("DEBUG_STEMS").is_ok() {
338        for (src_idx, (left, right)) in istft_results.iter().enumerate() {
339            let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
340            let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
341            eprintln!(
342                "iSTFT result [source {}]: left_max={:.6}, right_max={:.6}",
343                src_idx, left_max, right_max
344            );
345        }
346    }
347
348    let mut result = Vec::with_capacity(num_sources * 2 * t);
349
350    for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
351        // Extract time domain for this source [2, T]
352        let src_time_offset = src * 2 * t;
353        let left_time = &data_time[src_time_offset..src_time_offset + t];
354        let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
355
356        // Combine: output = time_domain + frequency_domain (after iSTFT)
357        for i in 0..t {
358            result.push(left_time[i] + left_freq[i]);
359        }
360        for i in 0..t {
361            result.push(right_time[i] + right_freq[i]);
362        }
363    }
364
365    let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
366    Ok(out)
367}
368
/// Stand-in engine used when the `engine-mock` feature is enabled.
///
/// It exposes the same `preload` / `manifest` / `run_window_demucs` API but
/// never touches ONNX Runtime: every emitted stem is a copy of the input.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest recorded by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Number of identical stems the mock emits.
    const MOCK_SOURCES: usize = 4;

    /// Records the manifest; no model is loaded.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest stored by `preload`; panics if it was never called.
    pub fn manifest() -> &'static ModelManifest {
        MANIFEST.get().expect("preload first (mock)")
    }

    /// Returns `MOCK_SOURCES` "identity" stems of shape `(sources, 2, T)`.
    /// T is the shorter channel's length; extra samples are ignored.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        let len = left.len().min(right.len());
        let (left, right) = (&left[..len], &right[..len]);

        let mut stems = Vec::with_capacity(MOCK_SOURCES * 2 * len);
        for _ in 0..MOCK_SOURCES {
            stems.extend_from_slice(left);
            stems.extend_from_slice(right);
        }
        Ok(ndarray::Array3::from_shape_vec((MOCK_SOURCES, 2, len), stems)?)
    }
}
398
399#[cfg(feature = "engine-mock")]
400pub use _engine_mock::{manifest, preload, run_window_demucs};