1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4 core::dsp::{istft_cac_stereo_parallel, stft_cac_stereo_centered},
5 error::{Result, StemError},
6 model::model_manager::ModelHandle,
7 types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14 execution_providers::ExecutionProviderDispatch,
15 session::{
16 builder::{GraphOptimizationLevel, SessionBuilder},
17 Session,
18 },
19 value::{Tensor, Value},
20};
21use std::sync::Mutex;
22
23#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
25use ort::execution_providers::CUDAExecutionProvider;
26#[cfg(all(feature = "coreml", target_os = "macos"))]
28use ort::execution_providers::CoreMLExecutionProvider;
29#[cfg(all(feature = "directml", target_os = "windows"))]
31use ort::execution_providers::DirectMLExecutionProvider;
32#[cfg(feature = "onednn")]
34use ort::execution_providers::OneDNNExecutionProvider;
35
// Process-wide state of the real (non-mock) engine. `preload()` fills these;
// `OnceCell::set` keeps the first stored value, so later calls cannot replace
// them. Under the `engine-mock` feature they go unused (the mock module keeps
// its own `MANIFEST`), which is what the crate-level `dead_code` allow covers.

/// ONNX Runtime session shared by all callers; `run_window_demucs` locks this
/// mutex to use it.
static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
/// Manifest of the model loaded by `preload()`; read through `manifest()`.
static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
/// Records that `ort::init().commit()` succeeded so `preload()` only runs it once.
static ORT_INIT: OnceCell<()> = OnceCell::new();

// Fixed Demucs window geometry. `run_window_demucs` rejects inputs and STFT
// results that do not match these values (NOTE(review): presumably because the
// exported ONNX graph has static shapes — confirm against the model export).

/// Samples per channel in one window; `run_window_demucs` requires exactly this length.
const DEMUCS_T: usize = 343_980;
/// Frequency bins the STFT must produce. NOTE(review): equals `DEMUCS_NFFT / 2`,
/// presumably a one-sided spectrum with the top bin dropped — confirm in
/// `stft_cac_stereo_centered`.
const DEMUCS_F: usize = 2048;
/// STFT frames the STFT must produce for a `DEMUCS_T`-sample window.
const DEMUCS_FRAMES: usize = 336;
/// FFT size passed to both the STFT and the iSTFT.
const DEMUCS_NFFT: usize = 4096;
/// Hop length (in samples) passed to both the STFT and the iSTFT.
const DEMUCS_HOP: usize = 1024;
45
46#[allow(unused_mut)]
47fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
48 let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
49
50 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
51 {
52 providers.push(
53 CUDAExecutionProvider::default()
54 .build()
55 );
56 }
57
58 #[cfg(all(feature = "coreml", target_os = "macos"))]
59 {
60 if std::env::var("ENABLE_COREML").is_ok() {
63 eprintln!("CoreML enabled via ENABLE_COREML environment variable");
64 providers.push(
65 CoreMLExecutionProvider::default()
66 .build()
67 );
68 } else {
69 eprintln!("CoreML disabled by default (set ENABLE_COREML=1 to enable)");
70 }
71 }
72
73 #[cfg(all(feature = "directml", target_os = "windows"))]
74 {
75 providers.push(
76 DirectMLExecutionProvider::default()
77 .build()
78 );
79 }
80
81 #[cfg(feature = "onednn")]
82 {
83 providers.push(
85 OneDNNExecutionProvider::default()
86 .build()
87 );
88 }
89
90 providers
91}
92
93#[cfg(not(feature = "engine-mock"))]
94pub fn preload(h: &ModelHandle) -> Result<()> {
95 ORT_INIT.get_or_try_init::<_, StemError>(|| {
96 ort::init().commit().map_err(StemError::from)?;
97 Ok(())
98 })?;
99
100 let num_threads = std::thread::available_parallelism()
101 .map(|n| n.get())
102 .unwrap_or(4);
103
104 let providers = get_execution_providers();
105
106 let session = if providers.is_empty() {
107 eprintln!("Using CPU ({} threads) - no GPU features enabled", num_threads);
108 SessionBuilder::new()?
109 .with_optimization_level(GraphOptimizationLevel::Level3)?
110 .with_intra_threads(num_threads)?
111 .with_inter_threads(num_threads)?
112 .with_parallel_execution(true)?
113 .commit_from_file(&h.local_path)?
114 } else {
115 #[allow(unused_mut)]
116 let mut provider_names: Vec<&str> = Vec::new();
117 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
118 provider_names.push("CUDA");
119 #[cfg(all(feature = "coreml", target_os = "macos"))]
120 provider_names.push("CoreML");
121 #[cfg(all(feature = "directml", target_os = "windows"))]
122 provider_names.push("DirectML");
123 #[cfg(feature = "onednn")]
124 provider_names.push("oneDNN");
125
126 eprintln!("Trying execution providers: {:?} (with CPU fallback)", provider_names);
127
128 match SessionBuilder::new()?
129 .with_optimization_level(GraphOptimizationLevel::Level3)?
130 .with_execution_providers(providers)
131 {
132 Ok(builder) => {
133 builder
134 .with_intra_threads(num_threads)?
135 .with_inter_threads(num_threads)?
136 .commit_from_file(&h.local_path)?
137 }
138 Err(e) => {
139 eprintln!("GPU providers failed ({}), using CPU ({} threads)", e, num_threads);
140 SessionBuilder::new()?
141 .with_optimization_level(GraphOptimizationLevel::Level3)?
142 .with_intra_threads(num_threads)?
143 .with_inter_threads(num_threads)?
144 .with_parallel_execution(true)?
145 .commit_from_file(&h.local_path)?
146 }
147 }
148 };
149
150 SESSION.set(Mutex::new(session)).ok();
151 MANIFEST.set(h.manifest.clone()).ok();
152 Ok(())
153}
154
155#[cfg(not(feature = "engine-mock"))]
156pub fn manifest() -> &'static ModelManifest {
157 MANIFEST
158 .get()
159 .expect("engine::preload() must be called once before using the engine")
160}
161
162#[cfg(not(feature = "engine-mock"))]
163pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
164 if left.len() != right.len() {
165 return Err(anyhow!("L/R length mismatch").into());
166 }
167 let t = left.len();
168 if t != DEMUCS_T {
169 return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
170 }
171
172 let mut planar = Vec::with_capacity(2 * t);
174 planar.extend_from_slice(left);
175 planar.extend_from_slice(right);
176 let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
177
178 let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
180 if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
181 return Err(anyhow!(
182 "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
183 f_bins,
184 frames,
185 DEMUCS_F,
186 DEMUCS_FRAMES
187 )
188 .into());
189 }
190 let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
191
192 let mut session = SESSION
193 .get()
194 .expect("engine::preload first")
195 .lock()
196 .expect("session poisoned");
197
198 let in_time = session
200 .inputs
201 .iter()
202 .find(|i| i.name == "input")
203 .map(|i| i.name.clone())
204 .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
205
206 let in_spec = session
207 .inputs
208 .iter()
209 .find(|i| i.name == "x")
210 .map(|i| i.name.clone())
211 .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
212
213 let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
215
216 let mut output_freq: Option<Value> = None;
220 let mut output_time: Option<Value> = None;
221
222 for (name, val) in outputs.into_iter() {
223 if name == "output" {
224 output_freq = Some(val);
225 } else if name == "add_67" {
226 output_time = Some(val);
227 }
228 }
229
230 let out_freq =
231 output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
232 let out_time =
233 output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
234
235 let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
237 let num_sources = shape_time[1] as usize;
238
239 let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
241
242 if std::env::var("DEBUG_STEMS").is_ok() {
244 let time_max = data_time.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
245 let freq_max = data_freq.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
246 eprintln!("Model output stats: time_max={:.6}, freq_max={:.6}", time_max, freq_max);
247 if time_max < 1e-10 && freq_max < 1e-10 {
248 eprintln!("WARNING: Model outputs are all zeros! This indicates a problem with the execution provider.");
249 }
250 }
251
252 if shape_freq[0] != 1
254 || shape_freq[1] != num_sources as i64
255 || shape_freq[2] != 4
256 || shape_freq[3] != f_bins as i64
257 || shape_freq[4] != frames as i64
258 {
259 return Err(anyhow!(
260 "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
261 shape_freq,
262 num_sources,
263 f_bins,
264 frames
265 )
266 .into());
267 }
268
269 let source_specs: Vec<&[f32]> = (0..num_sources)
270 .map(|src| {
271 let src_freq_offset = src * 4 * f_bins * frames;
272 &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames]
273 })
274 .collect();
275
276 let istft_results = istft_cac_stereo_parallel(&source_specs, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
277
278 if std::env::var("DEBUG_STEMS").is_ok() {
280 for (src_idx, (left, right)) in istft_results.iter().enumerate() {
281 let left_max = left.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
282 let right_max = right.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
283 eprintln!("iSTFT result [source {}]: left_max={:.6}, right_max={:.6}", src_idx, left_max, right_max);
284 }
285 }
286
287 let mut result = Vec::with_capacity(num_sources * 2 * t);
288
289 for (src, (left_freq, right_freq)) in istft_results.into_iter().enumerate() {
290 let src_time_offset = src * 2 * t;
292 let left_time = &data_time[src_time_offset..src_time_offset + t];
293 let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
294
295 for i in 0..t {
297 result.push(left_time[i] + left_freq[i]);
298 }
299 for i in 0..t {
300 result.push(right_time[i] + right_freq[i]);
301 }
302 }
303
304 let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
305 Ok(out)
306}
307
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    //! Stand-in engine for `engine-mock` builds: no ONNX Runtime, no model file.
    use super::*;
    use once_cell::sync::OnceCell;

    /// Number of stems the mock reports; each one is a copy of the input.
    const MOCK_SOURCES: usize = 4;

    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Records `h.manifest`; only the first call has any effect.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest recorded by the first `preload` call.
    ///
    /// # Panics
    /// Panics if `preload` has not been called.
    pub fn manifest() -> &'static ModelManifest {
        match MANIFEST.get() {
            Some(recorded) => recorded,
            None => panic!("preload first (mock)"),
        }
    }

    /// Echoes the input as every stem.
    ///
    /// Returns shape `(4, 2, n)`, where `n` is the shorter of the two channel
    /// lengths. Channel 0 of each stem is `left[..n]` and channel 1 is `right[..n]`.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        let n = left.len().min(right.len());
        let (l, r) = (&left[..n], &right[..n]);
        let one_source: Vec<f32> = l.iter().chain(r).copied().collect();
        let samples = one_source.repeat(MOCK_SOURCES);
        Ok(ndarray::Array3::from_shape_vec((MOCK_SOURCES, 2, n), samples)?)
    }
}
337
338#[cfg(feature = "engine-mock")]
339pub use _engine_mock::{manifest, preload, run_window_demucs};