1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4 core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5 error::{Result, StemError},
6 model::model_manager::ModelHandle,
7 types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14 execution_providers::ExecutionProviderDispatch,
15 session::{
16 builder::{GraphOptimizationLevel, SessionBuilder},
17 Session,
18 },
19 value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
36static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
37static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
38static ORT_INIT: OnceCell<()> = OnceCell::new();
39
/// FFT size handed to the STFT / iSTFT helpers.
const DEMUCS_NFFT: usize = 4_096;
/// Hop length handed to the STFT / iSTFT helpers.
const DEMUCS_HOP: usize = 1_024;

/// Exact number of samples per channel that one inference window must contain.
const DEMUCS_T: usize = 343_980;
/// Number of frequency bins the STFT of one window is required to produce.
const DEMUCS_F: usize = 2_048;
/// Number of frames the STFT of one window is required to produce.
const DEMUCS_FRAMES: usize = 336;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48 let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51 {
52 providers.push(CUDAExecutionProvider::default().build());
53 }
54
55 #[cfg(all(feature = "coreml", target_os = "macos"))]
56 {
57 if std::env::var("ENABLE_COREML").is_ok() {
60 if std::env::var("DEBUG_STEMS").is_ok() {
61 eprintln!("ℹ️ CoreML enabled via ENABLE_COREML environment variable");
62 }
63 providers.push(CoreMLExecutionProvider::default().build());
64 } else if std::env::var("DEBUG_STEMS").is_ok() {
65 eprintln!("ℹ️ CoreML disabled by default (set ENABLE_COREML=1 to enable)");
66 }
67 }
68
69 #[cfg(all(feature = "directml", target_os = "windows"))]
70 {
71 if std::env::var("ENABLE_DIRECTML").is_ok() {
73 if std::env::var("DEBUG_STEMS").is_ok() {
74 eprintln!("ℹ️ DirectML enabled via ENABLE_DIRECTML environment variable");
75 }
76 providers.push(DirectMLExecutionProvider::default().build());
77 } else if std::env::var("DEBUG_STEMS").is_ok() {
78 eprintln!("ℹ️ DirectML disabled by default (set ENABLE_DIRECTML=1 to enable)");
79 }
80 }
81
82 #[cfg(feature = "onednn")]
83 {
84 providers.push(OneDNNExecutionProvider::default().build());
86 }
87
88 providers
89}
90
91#[cfg(not(feature = "engine-mock"))]
92fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
93 Ok(SessionBuilder::new()?
94 .with_optimization_level(GraphOptimizationLevel::Level3)?
95 .with_intra_threads(num_threads)?
96 .with_inter_threads(num_threads)?
97 .with_parallel_execution(true)?
98 .commit_from_file(model_path)?)
99}
100
101#[cfg(not(feature = "engine-mock"))]
102fn commit_session_sequential_eps(
103 model_path: &std::path::Path,
104 num_threads: usize,
105 providers: Vec<ExecutionProviderDispatch>,
106) -> Result<Session> {
107 if providers.is_empty() {
108 if std::env::var("DEBUG_STEMS").is_ok() {
109 eprintln!("ℹ️ Using CPU ({} threads) - no GPU features enabled", num_threads);
110 }
111 return commit_cpu_session(model_path, num_threads);
112 }
113
114 if std::env::var("DEBUG_STEMS").is_ok() {
115 eprintln!(
116 "ℹ️ Trying execution providers sequentially ({} candidates) with CPU fallback",
117 providers.len()
118 );
119 }
120
121 for (idx, ep) in providers.into_iter().enumerate() {
122 let builder_res = SessionBuilder::new()?
123 .with_optimization_level(GraphOptimizationLevel::Level3)?
124 .with_intra_threads(num_threads)?
125 .with_inter_threads(num_threads)?
126 .with_execution_providers(vec![ep]);
127
128 let builder = match builder_res {
129 Ok(b) => b,
130 Err(e) => {
131 if std::env::var("DEBUG_STEMS").is_ok() {
132 eprintln!("⚠️ EP builder failed (attempt #{}): {}", idx + 1, e);
133 }
134 continue;
135 }
136 };
137
138 match builder.commit_from_file(model_path) {
139 Ok(sess) => {
140 if std::env::var("DEBUG_STEMS").is_ok() {
141 eprintln!("✅ Execution provider selected (attempt #{}).", idx + 1);
142 }
143 return Ok(sess);
144 }
145 Err(e) => {
146 if std::env::var("DEBUG_STEMS").is_ok() {
147 eprintln!("⚠️ EP commit failed (attempt #{}): {}", idx + 1, e);
148 }
149 continue;
150 }
151 }
152 }
153
154 if std::env::var("DEBUG_STEMS").is_ok() {
155 eprintln!("⚠️ All EPs failed; falling back to CPU ({} threads)", num_threads);
156 }
157 commit_cpu_session(model_path, num_threads)
158}
159
160
161#[cfg(not(feature = "engine-mock"))]
162pub fn preload(h: &ModelHandle) -> Result<()> {
163 ORT_INIT.get_or_try_init::<_, StemError>(|| {
164 ort::init().commit().map_err(StemError::from)?;
165 Ok(())
166 })?;
167
168 let num_threads = std::thread::available_parallelism()
169 .map(|n| n.get())
170 .unwrap_or(4);
171
172 if std::env::var("STEMMER_FORCE_CPU").is_ok() {
174 if std::env::var("DEBUG_STEMS").is_ok() {
175 eprintln!("ℹ️ STEMMER_FORCE_CPU is set: using CPU only");
176 }
177 let session = commit_cpu_session(h.local_path.as_path(), num_threads)?;
178 SESSION.set(Mutex::new(session)).ok();
179 MANIFEST.set(h.manifest.clone()).ok();
180 return Ok(());
181 }
182
183 let providers = get_execution_providers();
185
186 #[allow(unused_mut)]
188 let mut provider_names: Vec<&str> = Vec::new();
189 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
190 provider_names.push("CUDA");
191 #[cfg(all(feature = "coreml", target_os = "macos"))]
192 provider_names.push("CoreML");
193 #[cfg(all(feature = "directml", target_os = "windows"))]
194 provider_names.push("DirectML (opt-in)");
195 #[cfg(feature = "onednn")]
196 provider_names.push("oneDNN");
197
198 if std::env::var("DEBUG_STEMS").is_ok() {
199 eprintln!("ℹ️ Configured EP candidates: {:?}", provider_names);
200 }
201
202 let session = commit_session_sequential_eps(h.local_path.as_path(), num_threads, providers)?;
203
204 SESSION.set(Mutex::new(session)).ok();
205 MANIFEST.set(h.manifest.clone()).ok();
206 Ok(())
207}
208
209#[cfg(not(feature = "engine-mock"))]
210pub fn manifest() -> &'static ModelManifest {
211 MANIFEST
212 .get()
213 .expect("engine::preload() must be called once before using the engine")
214}
215
216#[cfg(not(feature = "engine-mock"))]
217pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
218 if left.len() != right.len() {
219 return Err(anyhow!("L/R length mismatch").into());
220 }
221 let t = left.len();
222 if t != DEMUCS_T {
223 return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
224 }
225
226 let mut planar = Vec::with_capacity(2 * t);
228 planar.extend_from_slice(left);
229 planar.extend_from_slice(right);
230 let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
231
232 let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
234 if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
235 return Err(anyhow!(
236 "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
237 f_bins,
238 frames,
239 DEMUCS_F,
240 DEMUCS_FRAMES
241 )
242 .into());
243 }
244 let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
245
246 let mut session = SESSION
247 .get()
248 .expect("engine::preload first")
249 .lock()
250 .expect("session poisoned");
251
252 let in_time = session
254 .inputs
255 .iter()
256 .find(|i| i.name == "input")
257 .map(|i| i.name.clone())
258 .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
259
260 let in_spec = session
261 .inputs
262 .iter()
263 .find(|i| i.name == "x")
264 .map(|i| i.name.clone())
265 .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
266
267 let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
269
270 let mut output_freq: Option<Value> = None;
274 let mut output_time: Option<Value> = None;
275
276 for (name, val) in outputs.into_iter() {
277 if name == "output" {
278 output_freq = Some(val);
279 } else if name == "add_67" {
280 output_time = Some(val);
281 }
282 }
283
284 let out_freq =
285 output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
286 let out_time =
287 output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
288
289 let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
291 let num_sources = shape_time[1] as usize;
292
293 let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
295
296 if std::env::var("DEBUG_STEMS").is_ok() {
298 let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
299 let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
300 eprintln!(
301 "Model output stats: time_max={:.6}, freq_max={:.6}",
302 time_max, freq_max
303 );
304 if time_max < 1e-10 && freq_max < 1e-10 {
305 eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
306 }
307 }
308
309 if shape_freq[0] != 1
311 || shape_freq[1] != num_sources as i64
312 || shape_freq[2] != 4
313 || shape_freq[3] != f_bins as i64
314 || shape_freq[4] != frames as i64
315 {
316 return Err(anyhow!(
317 "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
318 shape_freq,
319 num_sources,
320 f_bins,
321 frames
322 )
323 .into());
324 }
325
326 let source_specs: Vec<&[f32]> = (0..num_sources)
327 .map(|src| {
328 let src_freq_offset = src * 4 * f_bins * frames;
329 &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
330 })
331 .collect();
332
333 let istft_results =
334 istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
335
336 if std::env::var("DEBUG_STEMS").is_ok() {
338 for (src_idx, (left, right)) in istft_results.iter().enumerate() {
339 let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
340 let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
341 eprintln!(
342 "iSTFT result [source {}]: left_max={:.6}, right_max={:.6}",
343 src_idx, left_max, right_max
344 );
345 }
346 }
347
348 let mut result = Vec::with_capacity(num_sources * 2 * t);
349
350 for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
351 let src_time_offset = src * 2 * t;
353 let left_time = &data_time[src_time_offset..src_time_offset + t];
354 let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
355
356 for i in 0..t {
358 result.push(left_time[i] + left_freq[i]);
359 }
360 for i in 0..t {
361 result.push(right_time[i] + right_freq[i]);
362 }
363 }
364
365 let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
366 Ok(out)
367}
368
/// Stand-in engine compiled with the `engine-mock` feature: no model is loaded and every
/// "separated" source is a copy of the input window.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest remembered by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Stores a copy of `h.manifest`; if one is already stored, it is kept.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest stored by `preload`.
    ///
    /// # Panics
    /// Panics when `preload` has not been called yet.
    pub fn manifest() -> &'static ModelManifest {
        match MANIFEST.get() {
            Some(stored) => stored,
            None => panic!("preload first (mock)"),
        }
    }

    /// Returns four sources, each holding the first `min(left.len(), right.len())` samples
    /// of `left` followed by the same number of samples of `right`, shaped `(4, 2, len)`.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;

        let len = left.len().min(right.len());
        let mut flat = Vec::with_capacity(SOURCES * 2 * len);
        for _ in 0..SOURCES {
            flat.extend_from_slice(&left[..len]);
            flat.extend_from_slice(&right[..len]);
        }

        let shaped = ndarray::Array3::from_shape_vec((SOURCES, 2, len), flat)?;
        Ok(shaped)
    }
}
398
399#[cfg(feature = "engine-mock")]
400pub use _engine_mock::{manifest, preload, run_window_demucs};