stem_splitter_core/core/
engine.rs1#![cfg_attr(feature = "engine-mock", allow(dead_code, unused_imports))]
2
3use crate::{
4 core::dsp::{istft_cac_stereo, stft_cac_stereo_centered},
5 error::{Result, StemError},
6 model::model_manager::ModelHandle,
7 types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14 session::{
15 builder::{GraphOptimizationLevel, SessionBuilder},
16 Session,
17 },
18 value::{Tensor, Value},
19};
20use std::sync::Mutex;
21
22static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
23static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
24static ORT_INIT: OnceCell<()> = OnceCell::new();
25
/// FFT size handed to the STFT / iSTFT helpers.
const DEMUCS_NFFT: usize = 4_096;
/// Hop size handed to the STFT / iSTFT helpers.
const DEMUCS_HOP: usize = 1_024;
/// Exact number of samples per channel that one inference window must contain.
const DEMUCS_T: usize = 343_980;
/// Number of frequency bins the forward STFT of one window must report.
const DEMUCS_F: usize = 2_048;
/// Number of STFT frames the forward STFT of one window must report.
const DEMUCS_FRAMES: usize = 336;
31
32#[cfg(not(feature = "engine-mock"))]
33pub fn preload(h: &ModelHandle) -> Result<()> {
34 ORT_INIT.get_or_try_init::<_, StemError>(|| {
35 ort::init().commit().map_err(StemError::from)?;
36 Ok(())
37 })?;
38
39 let num_threads = std::thread::available_parallelism()
41 .map(|n| n.get())
42 .unwrap_or(4);
43
44 let session = SessionBuilder::new()?
45 .with_optimization_level(GraphOptimizationLevel::Level3)?
46 .with_intra_threads(num_threads)?
47 .with_inter_threads(num_threads)?
48 .with_parallel_execution(true)?
49 .commit_from_file(&h.local_path)?;
50
51 SESSION.set(Mutex::new(session)).ok();
52 MANIFEST.set(h.manifest.clone()).ok();
53 Ok(())
54}
55
56#[cfg(not(feature = "engine-mock"))]
57pub fn manifest() -> &'static ModelManifest {
58 MANIFEST
59 .get()
60 .expect("engine::preload() must be called once before using the engine")
61}
62
63#[cfg(not(feature = "engine-mock"))]
64pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
65 if left.len() != right.len() {
66 return Err(anyhow!("L/R length mismatch").into());
67 }
68 let t = left.len();
69 if t != DEMUCS_T {
70 return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
71 }
72
73 let mut planar = Vec::with_capacity(2 * t);
75 planar.extend_from_slice(left);
76 planar.extend_from_slice(right);
77 let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
78
79 let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
81 if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
82 return Err(anyhow!(
83 "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
84 f_bins,
85 frames,
86 DEMUCS_F,
87 DEMUCS_FRAMES
88 )
89 .into());
90 }
91 let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
92
93 let mut session = SESSION
94 .get()
95 .expect("engine::preload first")
96 .lock()
97 .expect("session poisoned");
98
99 let in_time = session
101 .inputs
102 .iter()
103 .find(|i| i.name == "input")
104 .map(|i| i.name.clone())
105 .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
106
107 let in_spec = session
108 .inputs
109 .iter()
110 .find(|i| i.name == "x")
111 .map(|i| i.name.clone())
112 .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
113
114 let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
116
117 let mut output_freq: Option<Value> = None;
121 let mut output_time: Option<Value> = None;
122
123 for (name, val) in outputs.into_iter() {
124 if name == "output" {
125 output_freq = Some(val);
126 } else if name == "add_67" {
127 output_time = Some(val);
128 }
129 }
130
131 let out_freq =
132 output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
133 let out_time =
134 output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
135
136 let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
138 let num_sources = shape_time[1] as usize;
139
140 let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
142
143 if shape_freq[0] != 1
145 || shape_freq[1] != num_sources as i64
146 || shape_freq[2] != 4
147 || shape_freq[3] != f_bins as i64
148 || shape_freq[4] != frames as i64
149 {
150 return Err(anyhow!(
151 "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
152 shape_freq,
153 num_sources,
154 f_bins,
155 frames
156 )
157 .into());
158 }
159
160 let mut result = Vec::with_capacity(num_sources * 2 * t);
163
164 for src in 0..num_sources {
165 let src_freq_offset = src * 4 * f_bins * frames;
167 let src_freq_data = &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames];
168
169 let (left_freq, right_freq) =
171 istft_cac_stereo(src_freq_data, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
172
173 let src_time_offset = src * 2 * t;
175 let left_time = &data_time[src_time_offset..src_time_offset + t];
176 let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
177
178 for i in 0..t {
180 result.push(left_time[i] + left_freq[i]);
181 }
182 for i in 0..t {
183 result.push(right_time[i] + right_freq[i]);
184 }
185 }
186
187 let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
188 Ok(out)
189}
190
/// Stand-in engine compiled when the `engine-mock` feature is enabled; it
/// never loads a model or runs inference.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest stored by the mock `preload`.
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Stores a clone of the handle's manifest; a repeated call leaves the
    /// first stored manifest in place and still returns `Ok(())`.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest stored by the mock `preload`.
    ///
    /// # Panics
    /// Panics if `preload` has not stored a manifest yet.
    pub fn manifest() -> &'static ModelManifest {
        match MANIFEST.get() {
            Some(stored) => stored,
            None => panic!("preload first (mock)"),
        }
    }

    /// Returns four identical "sources", each a copy of the input channels.
    ///
    /// The output has shape `(4, 2, t)` where `t` is the shorter of the two
    /// input lengths; any extra samples in the longer input are ignored.
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        const SOURCES: usize = 4;
        let t = left.len().min(right.len());

        let mut out = Vec::with_capacity(SOURCES * 2 * t);
        for _ in 0..SOURCES {
            out.extend_from_slice(&left[..t]);
            out.extend_from_slice(&right[..t]);
        }
        Ok(ndarray::Array3::from_shape_vec((SOURCES, 2, t), out)?)
    }
}

#[cfg(feature = "engine-mock")]
pub use _engine_mock::{manifest, preload, run_window_demucs};