stem_splitter_core/core/
splitter.rs1use crate::{
2 core::{
3 audio::{read_audio, write_audio},
4 dsp::to_planar_stereo,
5 engine,
6 },
7 error::Result,
8 io::progress::{emit_split_progress, SplitProgress},
9 model::model_manager::ensure_model,
10 types::{AudioData, SplitOptions, SplitResult},
11};
12
13use std::{
14 collections::HashMap,
15 fs,
16 path::{Path, PathBuf},
17};
18use tempfile::tempdir;
19
20pub fn split_file(input_path: &str, opts: SplitOptions) -> Result<SplitResult> {
21 emit_split_progress(SplitProgress::Stage("resolve_model"));
22 let handle = ensure_model(&opts.model_name, opts.manifest_url_override.as_deref())?;
23
24 emit_split_progress(SplitProgress::Stage("engine_preload"));
25 engine::preload(&handle)?;
26
27 let mf = engine::manifest();
28
29 if mf.sample_rate != 44100 {
30 return Err(anyhow::anyhow!("Currently expecting 44.1k model").into());
31 }
32
33 emit_split_progress(SplitProgress::Stage("read_audio"));
34 let audio = read_audio(input_path)?;
35 let stereo = to_planar_stereo(&audio.samples, audio.channels);
36 let n = stereo.len();
37
38 if n == 0 {
39 return Err(anyhow::anyhow!("Empty audio").into());
40 }
41
42 let win = mf.window;
43 let hop = mf.hop;
44
45 if !(win > 0 && hop > 0 && hop <= win) {
46 return Err(anyhow::anyhow!("Bad win/hop in manifest").into());
47 }
48
49 if std::env::var("DEBUG_STEMS").is_ok() {
50 eprintln!("Window settings: win={}, hop={}, overlap={}", win, hop, win - hop);
51 }
52
53 let stems_names = mf.stems.clone();
54 let mut stems_count = stems_names.len().max(1);
55
56 let tmp = tempdir()?;
57 let tmp_dir = tmp.path().to_path_buf();
58
59 let mut left_raw = vec![0f32; win];
60 let mut right_raw = vec![0f32; win];
61
62 let mut acc: Vec<Vec<[f32; 2]>> = Vec::new();
64
65 let mut pos = 0usize;
66 let mut first_chunk = true;
67
68 emit_split_progress(SplitProgress::Stage("infer"));
69 while pos < n {
70 for i in 0..win {
72 let idx = pos + i;
73 if idx < n {
74 left_raw[i] = stereo[idx][0];
75 right_raw[i] = stereo[idx][1];
76 } else {
77 left_raw[i] = 0.0;
78 right_raw[i] = 0.0;
79 }
80 }
81
82 let out = engine::run_window_demucs(&left_raw, &right_raw)?;
84 let (s_count, _, t_out) = (out.shape()[0], out.shape()[1], out.shape()[2]);
85
86 if first_chunk {
87 stems_count = s_count;
88 acc = vec![vec![[0f32; 2]; n]; stems_count];
89 first_chunk = false;
90 }
91
92 let copy_len = hop.min(t_out).min(n - pos);
95 for st in 0..stems_count {
96 for i in 0..copy_len {
97 acc[st][pos + i][0] = out[(st, 0, i)];
98 acc[st][pos + i][1] = out[(st, 1, i)];
99 }
100 }
101
102 if pos + hop >= n {
103 break;
104 }
105 pos += hop;
106 }
107
108 let names = if stems_names.is_empty() {
109 vec![
110 "vocals".into(),
111 "drums".into(),
112 "bass".into(),
113 "other".into(),
114 ]
115 } else {
116 stems_names
117 };
118
119 let mut name_idx: HashMap<String, usize> = HashMap::new();
120 for (i, name) in names.iter().enumerate() {
121 name_idx.insert(name.to_lowercase(), i);
122 }
123
124 fs::create_dir_all(&opts.output_dir)?;
125
126 emit_split_progress(SplitProgress::Stage("write_stems"));
127
128 if std::env::var("DEBUG_STEMS").is_ok() {
129 for st in 0..stems_count {
130 let max_val = acc[st].iter()
131 .map(|s| s[0].abs().max(s[1].abs()))
132 .fold(0.0f32, f32::max);
133 eprintln!("Accumulator [stem {}]: max_value={:.6}, samples={}", st, max_val, acc[st].len());
134 }
135 }
136
137 let stem_to_wav = |st: usize, base: &str| -> Result<String> {
138 let mut inter = Vec::with_capacity(n * 2);
139
140 for sample in &acc[st][..n] {
141 inter.push(sample[0]);
142 inter.push(sample[1]);
143 }
144
145 emit_split_progress(SplitProgress::Writing {
146 stem: base.to_string(),
147 done: n,
148 total: n,
149 percent: 100.0,
150 });
151
152 let data = AudioData {
153 samples: inter,
154 sample_rate: mf.sample_rate,
155 channels: 2,
156 };
157
158 let p = tmp_dir.join(format!("{base}.wav"));
159 write_audio(p.to_str().unwrap(), &data)?;
160
161 Ok(p.to_string_lossy().into())
162 };
163
164 let get_idx = |key: &str, fallback: usize| -> usize {
165 name_idx
166 .get(key)
167 .copied()
168 .unwrap_or(fallback.min(stems_count.saturating_sub(1)))
169 };
170
171 let v_path = stem_to_wav(get_idx("vocals", 0), "vocals")?;
172 let d_path = stem_to_wav(get_idx("drums", 1), "drums")?;
173 let b_path = stem_to_wav(get_idx("bass", 2), "bass")?;
174 let o_path = stem_to_wav(get_idx("other", 3), "other")?;
175
176 emit_split_progress(SplitProgress::Stage("finalize"));
177
178 let file_stem = Path::new(input_path)
179 .file_stem()
180 .and_then(|s| s.to_str())
181 .unwrap_or("output");
182 let base = PathBuf::from(&opts.output_dir).join(file_stem);
183
184 let vocals_out = copy_to(&v_path, &format!("{}_vocals.wav", base.to_string_lossy()))?;
185 let drums_out = copy_to(&d_path, &format!("{}_drums.wav", base.to_string_lossy()))?;
186 let bass_out = copy_to(&b_path, &format!("{}_bass.wav", base.to_string_lossy()))?;
187 let other_out = copy_to(&o_path, &format!("{}_other.wav", base.to_string_lossy()))?;
188
189 emit_split_progress(SplitProgress::Finished);
190
191 Ok(SplitResult {
192 vocals_path: vocals_out,
193 drums_path: drums_out,
194 bass_path: bass_out,
195 other_path: other_out,
196 })
197}
198
199fn copy_to(src: &str, dst: &str) -> Result<String> {
200 fs::copy(src, dst)?;
201 Ok(dst.to_string())
202}