stem_splitter_core/core/engine.rs

// With the `engine-mock` feature the ORT-backed `preload`/`manifest`/`run_window_demucs` are
// compiled out, which leaves the ORT/DSP imports, the statics and the DEMUCS_* constants
// unused in that build; silence those warnings there only.
#![cfg_attr(feature = "engine-mock", allow(unused_imports, dead_code))]
2
3use crate::{
4    core::dsp::{istft_cac_stereo, stft_cac_stereo_centered},
5    error::{Result, StemError},
6    model::model_manager::ModelHandle,
7    types::ModelManifest,
8};
9
10use anyhow::anyhow;
11use ndarray::Array3;
12use once_cell::sync::OnceCell;
13use ort::{
14    session::{
15        builder::{GraphOptimizationLevel, SessionBuilder},
16        Session,
17    },
18    value::{Tensor, Value},
19};
20use std::sync::Mutex;
21
22static SESSION: OnceCell<Mutex<Session>> = OnceCell::new();
23static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();
24static ORT_INIT: OnceCell<()> = OnceCell::new();
25
/// FFT size (samples) of the STFT fed to / inverted from the model.
const DEMUCS_NFFT: usize = 4096;
/// Hop size (samples) between consecutive STFT frames.
const DEMUCS_HOP: usize = 1024;

/// Samples per channel in one inference window.
/// NOTE(review): 343_980 = 7.8 s at 44.1 kHz; sample rate assumed — confirm against the model.
const DEMUCS_T: usize = 343_980;
/// Frequency bins the model expects (numerically `DEMUCS_NFFT / 2`).
const DEMUCS_F: usize = 2048;
/// STFT frames the model expects per window (numerically `ceil(DEMUCS_T / DEMUCS_HOP)`).
const DEMUCS_FRAMES: usize = 336;
31
32#[cfg(not(feature = "engine-mock"))]
33pub fn preload(h: &ModelHandle) -> Result<()> {
34    ORT_INIT.get_or_try_init::<_, StemError>(|| {
35        ort::init().commit().map_err(StemError::from)?;
36        Ok(())
37    })?;
38
39    // Use more threads for better performance
40    let num_threads = std::thread::available_parallelism()
41        .map(|n| n.get())
42        .unwrap_or(4);
43
44    let session = SessionBuilder::new()?
45        .with_optimization_level(GraphOptimizationLevel::Level3)?
46        .with_intra_threads(num_threads)?
47        .with_inter_threads(num_threads)?
48        .with_parallel_execution(true)?
49        .commit_from_file(&h.local_path)?;
50
51    SESSION.set(Mutex::new(session)).ok();
52    MANIFEST.set(h.manifest.clone()).ok();
53    Ok(())
54}
55
56#[cfg(not(feature = "engine-mock"))]
57pub fn manifest() -> &'static ModelManifest {
58    MANIFEST
59        .get()
60        .expect("engine::preload() must be called once before using the engine")
61}
62
63#[cfg(not(feature = "engine-mock"))]
64pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
65    if left.len() != right.len() {
66        return Err(anyhow!("L/R length mismatch").into());
67    }
68    let t = left.len();
69    if t != DEMUCS_T {
70        return Err(anyhow!("Bad window length {} (expected {})", t, DEMUCS_T).into());
71    }
72
73    // Build time branch [1,2,T], planar
74    let mut planar = Vec::with_capacity(2 * t);
75    planar.extend_from_slice(left);
76    planar.extend_from_slice(right);
77    let time_value: Value = Tensor::from_array((vec![1, 2, t], planar))?.into_dyn();
78
79    // Build spec branch [1,4,F,Frames] with center padding, Hann, 4096/1024
80    let (spec_cac, f_bins, frames) = stft_cac_stereo_centered(left, right, DEMUCS_NFFT, DEMUCS_HOP);
81    if f_bins != DEMUCS_F || frames != DEMUCS_FRAMES {
82        return Err(anyhow!(
83            "Spec dims mismatch: got F={},Frames={}, expected F={},Frames={}",
84            f_bins,
85            frames,
86            DEMUCS_F,
87            DEMUCS_FRAMES
88        )
89        .into());
90    }
91    let spec_value: Value = Tensor::from_array((vec![1, 4, f_bins, frames], spec_cac))?.into_dyn();
92
93    let mut session = SESSION
94        .get()
95        .expect("engine::preload first")
96        .lock()
97        .expect("session poisoned");
98
99    // Get input names
100    let in_time = session
101        .inputs
102        .iter()
103        .find(|i| i.name == "input")
104        .map(|i| i.name.clone())
105        .ok_or_else(|| anyhow!("Model missing input 'input'"))?;
106
107    let in_spec = session
108        .inputs
109        .iter()
110        .find(|i| i.name == "x")
111        .map(|i| i.name.clone())
112        .ok_or_else(|| anyhow!("Model missing input 'x'"))?;
113
114    // Run inference
115    let outputs = session.run(vec![(in_time, time_value), (in_spec, spec_value)])?;
116
117    // Extract both outputs from the model
118    // "output": frequency domain [1, sources, 4, F, Frames]
119    // "add_67": time domain [1, sources, 2, T]
120    let mut output_freq: Option<Value> = None;
121    let mut output_time: Option<Value> = None;
122
123    for (name, val) in outputs.into_iter() {
124        if name == "output" {
125            output_freq = Some(val);
126        } else if name == "add_67" {
127            output_time = Some(val);
128        }
129    }
130
131    let out_freq =
132        output_freq.ok_or_else(|| anyhow!("Model did not return 'output' (freq domain)"))?;
133    let out_time =
134        output_time.ok_or_else(|| anyhow!("Model did not return 'add_67' (time domain)"))?;
135
136    // Extract time domain output [1, 4, 2, T] -> [4, 2, T]
137    let (shape_time, data_time) = out_time.try_extract_tensor::<f32>()?;
138    let num_sources = shape_time[1] as usize;
139
140    // Extract frequency domain output [1, sources, 4, F, Frames]
141    let (shape_freq, data_freq) = out_freq.try_extract_tensor::<f32>()?;
142
143    // Validate shapes
144    if shape_freq[0] != 1
145        || shape_freq[1] != num_sources as i64
146        || shape_freq[2] != 4
147        || shape_freq[3] != f_bins as i64
148        || shape_freq[4] != frames as i64
149    {
150        return Err(anyhow!(
151            "Unexpected freq output shape: {:?}, expected [1, {}, 4, {}, {}]",
152            shape_freq,
153            num_sources,
154            f_bins,
155            frames
156        )
157        .into());
158    }
159
160    // Combine frequency and time domain outputs
161    // According to demucs.onnx: final = time_domain + istft(frequency_domain)
162    let mut result = Vec::with_capacity(num_sources * 2 * t);
163
164    for src in 0..num_sources {
165        // Extract frequency domain for this source [4, F, Frames]
166        let src_freq_offset = src * 4 * f_bins * frames;
167        let src_freq_data = &data_freq[src_freq_offset..src_freq_offset + 4 * f_bins * frames];
168
169        // Apply iSTFT to convert frequency domain to time domain
170        let (left_freq, right_freq) =
171            istft_cac_stereo(src_freq_data, f_bins, frames, DEMUCS_NFFT, DEMUCS_HOP, t);
172
173        // Extract time domain for this source [2, T]
174        let src_time_offset = src * 2 * t;
175        let left_time = &data_time[src_time_offset..src_time_offset + t];
176        let right_time = &data_time[src_time_offset + t..src_time_offset + 2 * t];
177
178        // Combine: output = time_domain + frequency_domain (after iSTFT)
179        for i in 0..t {
180            result.push(left_time[i] + left_freq[i]);
181        }
182        for i in 0..t {
183            result.push(right_time[i] + right_freq[i]);
184        }
185    }
186
187    let out = ndarray::Array3::from_shape_vec((num_sources, 2, t), result)?;
188    Ok(out)
189}
190
/// Stand-in for the ORT-backed engine, compiled with the `engine-mock` feature.
///
/// No model is loaded: `run_window_demucs` returns four "identity" stems that each copy the
/// input channels, so callers can be exercised without ONNX Runtime.
#[cfg(feature = "engine-mock")]
mod _engine_mock {
    use super::*;
    use once_cell::sync::OnceCell;

    /// Manifest recorded by the first `preload()` call (mock-local copy).
    static MANIFEST: OnceCell<ModelManifest> = OnceCell::new();

    /// Number of stems the mock always produces.
    const MOCK_SOURCES: usize = 4;

    /// Records `h.manifest` on the first call; later calls leave it unchanged. Never fails.
    pub fn preload(h: &ModelHandle) -> Result<()> {
        let _ = MANIFEST.set(h.manifest.clone());
        Ok(())
    }

    /// Returns the manifest recorded by `preload()`.
    ///
    /// # Panics
    ///
    /// Panics if `preload()` has not been called.
    pub fn manifest() -> &'static ModelManifest {
        match MANIFEST.get() {
            Some(recorded) => recorded,
            None => panic!("preload first (mock)"),
        }
    }

    /// Returns `(4, 2, len)` stems, each a copy of `[left, right]`, where `len` is the shorter
    /// channel length (trailing samples of the longer channel are dropped).
    pub fn run_window_demucs(left: &[f32], right: &[f32]) -> Result<Array3<f32>> {
        let len = left.len().min(right.len());
        let (left, right) = (&left[..len], &right[..len]);

        let mut stems = Vec::with_capacity(MOCK_SOURCES * 2 * len);
        for _ in 0..MOCK_SOURCES {
            // "identity" stem: channel 0 = left, channel 1 = right
            stems.extend_from_slice(left);
            stems.extend_from_slice(right);
        }
        Ok(Array3::from_shape_vec((MOCK_SOURCES, 2, len), stems)?)
    }
}

#[cfg(feature = "engine-mock")]
pub use _engine_mock::{manifest, preload, run_window_demucs};