stem_splitter_core/core/
splitter.rs

1use crate::{
2    core::{
3        audio::{read_audio, write_audio},
4        dsp::to_planar_stereo,
5        engine,
6    },
7    error::Result,
8    io::progress::{emit_split_progress, SplitProgress},
9    model::model_manager::ensure_model,
10    types::{AudioData, SplitOptions, SplitResult},
11};
12
13use std::{
14    collections::HashMap,
15    fs,
16    path::{Path, PathBuf},
17};
18use tempfile::tempdir;
19
20pub fn split_file(input_path: &str, opts: SplitOptions) -> Result<SplitResult> {
21    emit_split_progress(SplitProgress::Stage("resolve_model"));
22    let handle = ensure_model(&opts.model_name, opts.manifest_url_override.as_deref())?;
23
24    emit_split_progress(SplitProgress::Stage("engine_preload"));
25    engine::preload(&handle)?;
26
27    let mf = engine::manifest();
28
29    if mf.sample_rate != 44100 {
30        return Err(anyhow::anyhow!("Currently expecting 44.1k model").into());
31    }
32
33    emit_split_progress(SplitProgress::Stage("read_audio"));
34    let audio = read_audio(input_path)?;
35    let stereo = to_planar_stereo(&audio.samples, audio.channels);
36    let n = stereo.len();
37
38    if n == 0 {
39        return Err(anyhow::anyhow!("Empty audio").into());
40    }
41
42    let win = mf.window;
43    let hop = mf.hop;
44
45    if !(win > 0 && hop > 0 && hop <= win) {
46        return Err(anyhow::anyhow!("Bad win/hop in manifest").into());
47    }
48
49    if std::env::var("DEBUG_STEMS").is_ok() {
50        eprintln!("Window settings: win={}, hop={}, overlap={}", win, hop, win - hop);
51    }
52
53    let stems_names = mf.stems.clone();
54    let mut stems_count = stems_names.len().max(1);
55
56    let tmp = tempdir()?;
57    let tmp_dir = tmp.path().to_path_buf();
58
59    let mut left_raw = vec![0f32; win];
60    let mut right_raw = vec![0f32; win];
61
62    // Accumulator for each stem - no windowing needed since model outputs are already processed
63    let mut acc: Vec<Vec<[f32; 2]>> = Vec::new();
64
65    let mut pos = 0usize;
66    let mut first_chunk = true;
67
68    emit_split_progress(SplitProgress::Stage("infer"));
69    while pos < n {
70        // Extract audio chunk
71        for i in 0..win {
72            let idx = pos + i;
73            if idx < n {
74                left_raw[i] = stereo[idx][0];
75                right_raw[i] = stereo[idx][1];
76            } else {
77                left_raw[i] = 0.0;
78                right_raw[i] = 0.0;
79            }
80        }
81
82        // Run inference - model already handles windowing internally
83        let out = engine::run_window_demucs(&left_raw, &right_raw)?;
84        let (s_count, _, t_out) = (out.shape()[0], out.shape()[1], out.shape()[2]);
85
86        if first_chunk {
87            stems_count = s_count;
88            acc = vec![vec![[0f32; 2]; n]; stems_count];
89            first_chunk = false;
90        }
91
92        // Copy only the non-overlapping part (first 'hop' samples of each window)
93        // This avoids overwriting data from previous windows
94        let copy_len = hop.min(t_out).min(n - pos);
95        for st in 0..stems_count {
96            for i in 0..copy_len {
97                acc[st][pos + i][0] = out[(st, 0, i)];
98                acc[st][pos + i][1] = out[(st, 1, i)];
99            }
100        }
101
102        if pos + hop >= n {
103            break;
104        }
105        pos += hop;
106    }
107
108    let names = if stems_names.is_empty() {
109        vec![
110            "vocals".into(),
111            "drums".into(),
112            "bass".into(),
113            "other".into(),
114        ]
115    } else {
116        stems_names
117    };
118
119    let mut name_idx: HashMap<String, usize> = HashMap::new();
120    for (i, name) in names.iter().enumerate() {
121        name_idx.insert(name.to_lowercase(), i);
122    }
123
124    fs::create_dir_all(&opts.output_dir)?;
125
126    emit_split_progress(SplitProgress::Stage("write_stems"));
127    
128    if std::env::var("DEBUG_STEMS").is_ok() {
129        for st in 0..stems_count {
130            let max_val = acc[st].iter()
131                .map(|s| s[0].abs().max(s[1].abs()))
132                .fold(0.0f32, f32::max);
133            eprintln!("Accumulator [stem {}]: max_value={:.6}, samples={}", st, max_val, acc[st].len());
134        }
135    }
136    
137    let stem_to_wav = |st: usize, base: &str| -> Result<String> {
138        let mut inter = Vec::with_capacity(n * 2);
139
140        for sample in &acc[st][..n] {
141            inter.push(sample[0]);
142            inter.push(sample[1]);
143        }
144
145        emit_split_progress(SplitProgress::Writing {
146            stem: base.to_string(),
147            done: n,
148            total: n,
149            percent: 100.0,
150        });
151
152        let data = AudioData {
153            samples: inter,
154            sample_rate: mf.sample_rate,
155            channels: 2,
156        };
157
158        let p = tmp_dir.join(format!("{base}.wav"));
159        write_audio(p.to_str().unwrap(), &data)?;
160
161        Ok(p.to_string_lossy().into())
162    };
163
164    let get_idx = |key: &str, fallback: usize| -> usize {
165        name_idx
166            .get(key)
167            .copied()
168            .unwrap_or(fallback.min(stems_count.saturating_sub(1)))
169    };
170
171    let v_path = stem_to_wav(get_idx("vocals", 0), "vocals")?;
172    let d_path = stem_to_wav(get_idx("drums", 1), "drums")?;
173    let b_path = stem_to_wav(get_idx("bass", 2), "bass")?;
174    let o_path = stem_to_wav(get_idx("other", 3), "other")?;
175
176    emit_split_progress(SplitProgress::Stage("finalize"));
177
178    let file_stem = Path::new(input_path)
179        .file_stem()
180        .and_then(|s| s.to_str())
181        .unwrap_or("output");
182    let base = PathBuf::from(&opts.output_dir).join(file_stem);
183
184    let vocals_out = copy_to(&v_path, &format!("{}_vocals.wav", base.to_string_lossy()))?;
185    let drums_out = copy_to(&d_path, &format!("{}_drums.wav", base.to_string_lossy()))?;
186    let bass_out = copy_to(&b_path, &format!("{}_bass.wav", base.to_string_lossy()))?;
187    let other_out = copy_to(&o_path, &format!("{}_other.wav", base.to_string_lossy()))?;
188
189    emit_split_progress(SplitProgress::Finished);
190
191    Ok(SplitResult {
192        vocals_path: vocals_out,
193        drums_path: drums_out,
194        bass_path: bass_out,
195        other_path: other_out,
196    })
197}
198
199fn copy_to(src: &str, dst: &str) -> Result<String> {
200    fs::copy(src, dst)?;
201    Ok(dst.to_string())
202}