1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4 core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5 error::{Result, StemError},
6 model::model_manager::ModelHandle,
7 types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14 execution_providers::ExecutionProviderDispatch,
15 session::{
16 builder::{GraphOptimizationLevel, SessionBuilder},
17 Session,
18 },
19 value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
36static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
37static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
38static ORT_INIT: OnceCell<()> = OnceCell::new();
39
// --- Fixed geometry of one Demucs inference window -------------------------

/// FFT size passed to the STFT / iSTFT helpers.
const DEMUCS_NFFT: usize = 4096;
/// Hop size passed to the STFT / iSTFT helpers.
const DEMUCS_HOP: usize = 1024;

/// Required number of samples per channel in one window.
const DEMUCS_T: usize = 343_980;
/// Number of frequency bins the STFT of one window must yield.
const DEMUCS_F: usize = 2048;
/// Number of STFT frames the STFT of one window must yield.
const DEMUCS_FRAMES: usize = 336;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48 let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51 {
52 providers.push(CUDAExecutionProvider::default().build());
53 }
54
55 #[cfg(all(feature = "coreml", target_os = "macos"))]
56 {
57 if std::env::var("ENABLE_COREML").is_ok() {
60 eprintln!("CoreML enabled via ENABLE_COREML environment variable");
61 providers.push(CoreMLExecutionProvider::default().build());
62 } else {
63 eprintln!("CoreML disabled by default (set ENABLE_COREML=1 to enable)");
64 }
65 }
66
67 #[cfg(all(feature = "directml", target_os = "windows"))]
68 {
69 if std::env::var("ENABLE_DIRECTML").is_ok() {
71 eprintln!("DirectML enabled via ENABLE_DIRECTML environment variable");
72 providers.push(DirectMLExecutionProvider::default().build());
73 } else {
74 eprintln!("DirectML disabled by default (set ENABLE_DIRECTML=1 to enable)");
75 }
76 }
77
78 #[cfg(feature = "onednn")]
79 {
80 providers.push(OneDNNExecutionProvider::default().build());
82 }
83
84 providers
85}
86
87#[cfg(not(feature = "engine-mock"))]
88fn commit_cpu_session(model_path: &std::path::Path, num_threads: usize) -> Result<Session> {
89 Ok(SessionBuilder::new()?
90 .with_optimization_level(GraphOptimizationLevel::Level3)?
91 .with_intra_threads(num_threads)?
92 .with_inter_threads(num_threads)?
93 .with_parallel_execution(true)?
94 .commit_from_file(model_path)?)
95}
96
97#[cfg(not(feature = "engine-mock"))]
98fn commit_session_sequential_eps(
99 model_path: &std::path::Path,
100 num_threads: usize,
101 providers: Vec<ExecutionProviderDispatch>,
102) -> Result<Session> {
103 if providers.is_empty() {
104 eprintln!("Using CPU ({} threads) - no GPU features enabled", num_threads);
105 return commit_cpu_session(model_path, num_threads);
106 }
107
108 eprintln!(
109 "Trying execution providers sequentially ({} candidates) with CPU fallback",
110 providers.len()
111 );
112
113 for (idx, ep) in providers.into_iter().enumerate() {
114 let builder_res = SessionBuilder::new()?
115 .with_optimization_level(GraphOptimizationLevel::Level3)?
116 .with_intra_threads(num_threads)?
117 .with_inter_threads(num_threads)?
118 .with_execution_providers(vec![ep]);
119
120 let builder = match builder_res {
121 Ok(b) => b,
122 Err(e) => {
123 eprintln!("EP builder failed (attempt #{}): {}", idx + 1, e);
124 continue;
125 }
126 };
127
128 match builder.commit_from_file(model_path) {
129 Ok(sess) => {
130 eprintln!("Execution provider selected (attempt #{}).", idx + 1);
131 return Ok(sess);
132 }
133 Err(e) => {
134 eprintln!("EP commit failed (attempt #{}): {}", idx + 1, e);
135 continue;
136 }
137 }
138 }
139
140 eprintln!("All EPs failed; falling back to CPU ({} threads)", num_threads);
141 commit_cpu_session(model_path, num_threads)
142}
143
144
145#[cfg(not(feature = "engine-mock"))]
146pub fn preload(h: &ModelHandle) -> Result<()> {
147 ORT_INIT.get_or_try_init::<_, StemError>(|| {
148 ort::init().commit().map_err(StemError::from)?;
149 Ok(())
150 })?;
151
152 let num_threads = std::thread::available_parallelism()
153 .map(|n| n.get())
154 .unwrap_or(4);
155
156 if std::env::var("STEMMER_FORCE_CPU").is_ok() {
158 eprintln!("STEMMER_FORCE_CPU is set: using CPU only");
159 let session = commit_cpu_session(h.local_path.as_path(), num_threads)?;
160 SESSION.set(Mutex::new(session)).ok();
161 MANIFEST.set(h.manifest.clone()).ok();
162 return Ok(());
163 }
164
165 let providers = get_execution_providers();
167
168 #[allow(unused_mut)]
170 let mut provider_names: Vec<&str> = Vec::new();
171 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
172 provider_names.push("CUDA");
173 #[cfg(all(feature = "coreml", target_os = "macos"))]
174 provider_names.push("CoreML");
175 #[cfg(all(feature = "directml", target_os = "windows"))]
176 provider_names.push("DirectML (opt-in)");
177 #[cfg(feature = "onednn")]
178 provider_names.push("oneDNN");
179
180 eprintln!("Configured EP candidates: {:?}", provider_names);
181
182 let session = commit_session_sequential_eps(h.local_path.as_path(), num_threads, providers)?;
183
184 SESSION.set(Mutex::new(session)).ok();
185 MANIFEST.set(h.manifest.clone()).ok();
186 Ok(())
187}
188
189#[cfg(not(feature = "engine-mock"))]
190pub fn manifest() -> &'static ModelManifest {
191 MANIFEST
192 .get()
193 .expect("engine::preload() must be called once before using the engine")
194}
195
196#[cfg(not(feature = "engine-mock"))]
197pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
198 if left.len() != right.len() {
199 return Err(anyhow!("L/R length mismatch").into());
200 }
201 let t = left.len();
202 if t != DEMUCS_T {
203 return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
204 }
205
206 let mut planar = Vec::with_capacity(2 * t);
208 planar.extend_from_slice(left);
209 planar.extend_from_slice(right);
210 let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
211
212 let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
214 if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
215 return Err(anyhow!(
216 "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
217 f_bins,
218 frames,
219 DEMUCS_F,
220 DEMUCS_FRAMES
221 )
222 .into());
223 }
224 let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
225
226 let mut session = SESSION
227 .get()
228 .expect("engine::preload first")
229 .lock()
230 .expect("session poisoned");
231
232 let in_time = session
234 .inputs
235 .iter()
236 .find(|i| i.name == "input")
237 .map(|i| i.name.clone())
238 .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
239
240 let in_spec = session
241 .inputs
242 .iter()
243 .find(|i| i.name == "x")
244 .map(|i| i.name.clone())
245 .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
246
247 let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
249
250 let mut output_freq: Option<Value> = None;
254 let mut output_time: Option<Value> = None;
255
256 for (name, val) in outputs.into_iter() {
257 if name == "output" {
258 output_freq = Some(val);
259 } else if name == "add_67" {
260 output_time = Some(val);
261 }
262 }
263
264 let out_freq =
265 output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
266 let out_time =
267 output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
268
269 let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
271 let num_sources = shape_time[1] as usize;
272
273 let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
275
276 if std::env::var("DEBUG_STEMS").is_ok() {
278 let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
279 let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
280 eprintln!(
281 "Model output stats: time_max={:.6}, freq_max={:.6}",
282 time_max, freq_max
283 );
284 if time_max < 1e-10 && freq_max < 1e-10 {
285 eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
286 }
287 }
288
289 if shape_freq[0] != 1
291 || shape_freq[1] != num_sources as i64
292 || shape_freq[2] != 4
293 || shape_freq[3] != f_bins as i64
294 || shape_freq[4] != frames as i64
295 {
296 return Err(anyhow!(
297 "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
298 shape_freq,
299 num_sources,
300 f_bins,
301 frames
302 )
303 .into());
304 }
305
306 let source_specs: Vec<&[f32]> = (0..num_sources)
307 .map(|src| {
308 let src_freq_offset = src * 4 * f_bins * frames;
309 &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
310 })
311 .collect();
312
313 let istft_results =
314 istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
315
316 if std::env::var("DEBUG_STEMS").is_ok() {
318 for (src_idx, (left, right)) in istft_results.iter().enumerate() {
319 let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
320 let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
321 eprintln!(
322 "iSTFT result [source {}]: left_max={:.6}, right_max={:.6}",
323 src_idx, left_max, right_max
324 );
325 }
326 }
327
328 let mut result = Vec::with_capacity(num_sources * 2 * t);
329
330 for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
331 let src_time_offset = src * 2 * t;
333 let left_time = &data_time[src_time_offset..src_time_offset + t];
334 let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
335
336 for i in 0..t {
338 result.push(left_time[i] + left_freq[i]);
339 }
340 for i in 0..t {
341 result.push(right_time[i] + right_freq[i]);
342 }
343 }
344
345 let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
346 Ok(out)
347}
348
/// Stand-in engine compiled with the `engine-mock` feature: no model is loaded
/// and every "separated" source is simply a copy of the input.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest remembered by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Stores the manifest of `h`; a manifest stored by an earlier call is kept.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        // Ignoring the result keeps the first manifest on repeated calls.
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the stored manifest.
    ///
    /// # Panics
    /// Panics when `preload` has not been called.
    pub fn manifest() -> &'static ModelManifest {
        let Some(stored) = MANIFEST.get() else {
            panic!("preload first (mock)");
        };
        stored
    }

    /// Returns four sources of shape `(4, 2, t)`, each holding the first `t`
    /// samples of `left` and `right`, where `t` is the shorter input length.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;

        let t = left.len().min(right.len());
        let mut echoed: Vec<f32> = Vec::with_capacity(SOURCES * 2 * t);
        for _ in 0..SOURCES {
            echoed.extend_from_slice(&left[..t]);
            echoed.extend_from_slice(&right[..t]);
        }

        let stacked = ndarray::Array3::from_shape_vec((SOURCES, 2, t), echoed)?;
        Ok(stacked)
    }
}

#[cfg(feature = "engine-mock")]
pub use _engine_mock::{manifest, preload, run_window_demucs};