stem_splitter_core/core/
engine.rs

1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4    core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5    error::{Result, StemError},
6    model::model_manager::ModelHandle,
7    types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14    execution_providers::ExecutionProviderDispatch,
15    session::{
16        builder::{GraphOptimizationLevel, SessionBuilder},
17        Session,
18    },
19    value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23// CUDA: Linux and Windows only
24#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26// CoreML: macOS only (Apple Silicon)
27#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29// DirectML: Windows only
30#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32// oneDNN: All platforms
33#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
// Process-wide ONNX Runtime session; set by `preload()` and locked by
// `run_window_demucs()` for the duration of one inference call.
static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
// Manifest of the preloaded model, handed out by `manifest()`.
static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
// Marker recording that `ort::init().commit()` already succeeded, so the
// runtime is only initialised once per process.
static ORT_INIT: OnceCell<()> = OnceCell::new();

// Exact number of samples per channel that `run_window_demucs()` accepts.
const DEMUCS_T: usize = 343_980;
// Number of frequency bins the STFT of one window must produce.
const DEMUCS_F: usize = 2048;
// Number of STFT frames one window must produce.
const DEMUCS_FRAMES: usize = 336;
// Passed to the STFT/iSTFT helpers as the FFT size (see the "4096/1024" note
// in `run_window_demucs()`).
const DEMUCS_NFFT: usize = 4096;
// Passed to the STFT/iSTFT helpers as the hop length.
const DEMUCS_HOP: usize = 1024;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48    let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51    {
52        providers.push(
53            CUDAExecutionProvider::default()
54                .build()
55        );
56    }
57
58    #[cfg(all(feature = "coreml", target_os = "macos"))]
59    {
60        // CoreML can sometimes produce silent/zero outputs on certain models
61        // Only enable if ENABLE_COREML env var is set
62        if std::env::var("ENABLE_COREML").is_ok() {
63            eprintln!("CoreML enabled via ENABLE_COREML environment variable");
64            providers.push(
65                CoreMLExecutionProvider::default()
66                    .build()
67            );
68        } else {
69            eprintln!("CoreML disabled by default (set ENABLE_COREML=1 to enable)");
70        }
71    }
72
73    #[cfg(all(feature = "directml", target_os = "windows"))]
74    {
75        providers.push(
76            DirectMLExecutionProvider::default()
77                .build()
78        );
79    }
80
81    #[cfg(feature = "onednn")]
82    {
83        // oneDNN can improve performance on Intel CPUs
84        providers.push(
85            OneDNNExecutionProvider::default()
86                .build()
87        );
88    }
89
90    providers
91}
92
93#[cfg(not(feature = "engine-mock"))]
94pub fn preload(h: &ModelHandle) -> Result<()> {
95    ORT_INIT.get_or_try_init::<_, StemError>(|| {
96        ort::init().commit().map_err(StemError::from)?;
97        Ok(())
98    })?;
99
100    let num_threads = std::thread::available_parallelism()
101        .map(|n| n.get())
102        .unwrap_or(4);
103
104    let providers = get_execution_providers();
105    
106    let session = if providers.is_empty() {
107        eprintln!("Using CPU ({} threads) - no GPU features enabled", num_threads);
108        SessionBuilder::new()?
109            .with_optimization_level(GraphOptimizationLevel::Level3)?
110            .with_intra_threads(num_threads)?
111            .with_inter_threads(num_threads)?
112            .with_parallel_execution(true)?
113            .commit_from_file(&h.local_path)?
114    } else {
115        #[allow(unused_mut)]
116        let mut provider_names: Vec<&str> = Vec::new();
117        #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
118        provider_names.push("CUDA");
119        #[cfg(all(feature = "coreml", target_os = "macos"))]
120        provider_names.push("CoreML");
121        #[cfg(all(feature = "directml", target_os = "windows"))]
122        provider_names.push("DirectML");
123        #[cfg(feature = "onednn")]
124        provider_names.push("oneDNN");
125        
126        eprintln!("Trying execution providers: {:?} (with CPU fallback)", provider_names);
127        
128        match SessionBuilder::new()?
129            .with_optimization_level(GraphOptimizationLevel::Level3)?
130            .with_execution_providers(providers)
131        {
132            Ok(builder) => {
133                builder
134                    .with_intra_threads(num_threads)?
135                    .with_inter_threads(num_threads)?
136                    .commit_from_file(&h.local_path)?
137            }
138            Err(e) => {
139                eprintln!("GPU providers failed ({}), using CPU ({} threads)", e, num_threads);
140                SessionBuilder::new()?
141                    .with_optimization_level(GraphOptimizationLevel::Level3)?
142                    .with_intra_threads(num_threads)?
143                    .with_inter_threads(num_threads)?
144                    .with_parallel_execution(true)?
145                    .commit_from_file(&h.local_path)?
146            }
147        }
148    };
149
150    SESSION.set(Mutex::new(session)).ok();
151    MANIFEST.set(h.manifest.clone()).ok();
152    Ok(())
153}
154
155#[cfg(not(feature = "engine-mock"))]
156pub fn manifest() -> &'static ModelManifest {
157    MANIFEST
158        .get()
159        .expect("engine::preload() must be called once before using the engine")
160}
161
162#[cfg(not(feature = "engine-mock"))]
163pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
164    if left.len() != right.len() {
165        return Err(anyhow!("L/R length mismatch").into());
166    }
167    let t = left.len();
168    if t != DEMUCS_T {
169        return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
170    }
171
172    // Build time branch [1,2,T], planar
173    let mut planar = Vec::with_capacity(2 * t);
174    planar.extend_from_slice(left);
175    planar.extend_from_slice(right);
176    let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
177
178    // Build spec branch [1,4,F,Frames] with center padding, Hann, 4096/1024
179    let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
180    if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
181        return Err(anyhow!(
182            "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
183            f_bins,
184            frames,
185            DEMUCS_F,
186            DEMUCS_FRAMES
187        )
188        .into());
189    }
190    let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
191
192    let mut session = SESSION
193        .get()
194        .expect("engine::preload first")
195        .lock()
196        .expect("session poisoned");
197
198    // Get input names
199    let in_time = session
200        .inputs
201        .iter()
202        .find(|i| i.name == "input")
203        .map(|i| i.name.clone())
204        .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
205
206    let in_spec = session
207        .inputs
208        .iter()
209        .find(|i| i.name == "x")
210        .map(|i| i.name.clone())
211        .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
212
213    // Run inference
214    let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
215
216    // Extract both outputs from the model
217    // "output": frequency domain [1, sources, 4, F, Frames]
218    // "add_67": time domain [1, sources, 2, T]
219    let mut output_freq: Option<Value> = None;
220    let mut output_time: Option<Value> = None;
221
222    for (name, val) in outputs.into_iter() {
223        if name == "output" {
224            output_freq = Some(val);
225        } else if name == "add_67" {
226            output_time = Some(val);
227        }
228    }
229
230    let out_freq =
231        output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
232    let out_time =
233        output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
234
235    // Extract time domain output [1, 4, 2, T] -> [4, 2, T]
236    let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
237    let num_sources = shape_time[1] as usize;
238
239    // Extract frequency domain output [1, sources, 4, F, Frames]
240    let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
241
242    // Debug: Check if model outputs are non-zero
243    if std::env::var("DEBUG_STEMS").is_ok() {
244        let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
245        let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
246        eprintln!("Model output stats: time_max={:.6}, freq_max={:.6}", time_max, freq_max);
247        if time_max < 1e-10 && freq_max < 1e-10 {
248            eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
249        }
250    }
251
252    // Validate shapes
253    if shape_freq[0] != 1
254        || shape_freq[1] != num_sources as i64
255        || shape_freq[2] != 4
256        || shape_freq[3] != f_bins as i64
257        || shape_freq[4] != frames as i64
258    {
259        return Err(anyhow!(
260            "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
261            shape_freq,
262            num_sources,
263            f_bins,
264            frames
265        )
266        .into());
267    }
268
269    let source_specs: Vec<&[f32]> = (0..num_sources)
270        .map(|src| {
271            let src_freq_offset = src * 4 * f_bins * frames;
272            &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
273        })
274        .collect();
275
276    let istft_results = istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
277
278    // Debug: Check iSTFT results
279    if std::env::var("DEBUG_STEMS").is_ok() {
280        for (src_idx, (left, right)) in istft_results.iter().enumerate() {
281            let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
282            let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
283            eprintln!("iSTFT result [source {}]: left_max={:.6}, right_max={:.6}", src_idx, left_max, right_max);
284        }
285    }
286
287    let mut result = Vec::with_capacity(num_sources * 2 * t);
288
289    for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
290        // Extract time domain for this source [2, T]
291        let src_time_offset = src * 2 * t;
292        let left_time = &data_time[src_time_offset..src_time_offset + t];
293        let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
294
295        // Combine: output = time_domain + frequency_domain (after iSTFT)
296        for i in 0..t {
297            result.push(left_time[i] + left_freq[i]);
298        }
299        for i in 0..t {
300            result.push(right_time[i] + right_freq[i]);
301        }
302    }
303
304    let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
305    Ok(out)
306}
307
/// Stand-in engine compiled with the `engine-mock` feature: it never builds
/// a session and returns copies of the input as every stem.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;
    // Module-local manifest slot, separate from the one in the parent module.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Stores the manifest from `h`; a manifest that is already set is kept.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the stored manifest; panics if `preload` has not been called.
    pub fn manifest() -> &'static ModelManifest {
        match MANIFEST.get() {
            Some(stored) => stored,
            None => panic!("preload first (mock)"),
        }
    }

    /// Produces four "identity" stems of shape `[4, 2, t]`, where `t` is the
    /// shorter of the two input lengths: every stem is a copy of the input.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;
        let t = usize::min(left.len(), right.len());

        let mut stems = Vec::with_capacity(SOURCES * 2 * t);
        for _ in 0..SOURCES {
            stems.extend_from_slice(&left[..t]); // L
            stems.extend_from_slice(&right[..t]); // R
        }

        let shaped = Array3::from_shape_vec((SOURCES, 2, t), stems)?;
        Ok(shaped)
    }
}
337
338#[cfg(feature = "engine-mock")]
339pub use _engine_mock::{manifest, preload, run_window_demucs};