1use crate::{
2 core::{
3 audio::{create_wav_writer, read_audio, sample_to_i16, WavWriter},
4 engine,
5 },
6 error::Result,
7 io::progress::{emit_split_progress, SplitProgress},
8 model::model_manager::ensure_model,
9 types::{SplitOptions, SplitResult},
10};
11
12use std::{
13 collections::HashMap,
14 path::{Path, PathBuf},
15};
16
/// One output stem: its position in the inference output plus the writer that
/// receives its samples.
struct StemOutput {
    /// Index of this stem along the first axis of the inference output.
    stem_idx: usize,
    /// Stem label reported in the `SplitProgress::Writing` progress events.
    stem_name: String,
    /// Writer obtained from `create_wav_writer` for this stem's output file.
    writer: WavWriter,
}
22
/// Returns how many complete frames an interleaved sample buffer contains.
///
/// A frame holds one sample per channel. A channel count of `0` is treated as
/// mono so the division can never be by zero; a trailing partial frame is not
/// counted.
fn audio_frame_count(samples: &[f32], channels: u16) -> usize {
    let samples_per_frame = usize::from(channels).max(1);
    samples.len() / samples_per_frame
}
27
/// Copies one window of audio out of an interleaved buffer into separate
/// left/right buffers.
///
/// * `samples` - interleaved samples, `channels` values per frame.
/// * `channels` - channel count of `samples`; `0` is treated as mono.
/// * `start_frame` - index of the first frame to copy.
/// * `left_raw` / `right_raw` - destination buffers; one frame is written per
///   slot, for as many slots as both buffers have.
///
/// Mono input is duplicated into both outputs. For inputs with two or more
/// channels only the first two channels of each frame are used; if the buffer
/// ends right after a frame's first sample, that sample is used for both sides.
/// Frames that start past the end of `samples` (including frame offsets that
/// would overflow `usize`) are written as silence.
fn fill_stereo_window(
    samples: &[f32],
    channels: u16,
    start_frame: usize,
    left_raw: &mut [f32],
    right_raw: &mut [f32],
) {
    let channels = usize::from(channels.max(1));

    let slots = left_raw.iter_mut().zip(right_raw.iter_mut());
    for (offset, (left_out, right_out)) in slots.enumerate() {
        // Checked arithmetic: an offset that cannot be represented is, by
        // definition, beyond the end of the buffer and therefore silence.
        let frame_tail = start_frame
            .checked_add(offset)
            .and_then(|frame| frame.checked_mul(channels))
            .and_then(|base| samples.get(base..));

        let (left, right) = match frame_tail {
            // Two or more channels and both samples are present.
            Some([l, r, ..]) if channels > 1 => (*l, *r),
            // Mono input, or a truncated final frame: mirror the left sample.
            Some([l, ..]) => (*l, *l),
            // Past the end of the input: pad with silence.
            _ => (0.0, 0.0),
        };

        *left_out = left;
        *right_out = right;
    }
}
57
/// Derives the four stem output paths for an input file.
///
/// The input's file stem (file name without its last extension) is joined onto
/// `output_dir`, and `_vocals.wav`, `_drums.wav`, `_bass.wav` and `_other.wav`
/// are appended to it. When the input has no file stem, or the stem is not
/// valid UTF-8, `"output"` is used instead.
///
/// Returns the paths in the order `(vocals, drums, bass, other)`.
fn build_output_paths(input_path: &str, output_dir: &str) -> (String, String, String, String) {
    let stem_name = match Path::new(input_path).file_stem().and_then(|s| s.to_str()) {
        Some(name) => name,
        None => "output",
    };

    let prefix_path = PathBuf::from(output_dir).join(stem_name);
    // Convert once; every output shares the same prefix.
    let prefix = prefix_path.to_string_lossy();
    let with_suffix = |suffix: &str| format!("{}{}", prefix, suffix);

    (
        with_suffix("_vocals.wav"),
        with_suffix("_drums.wav"),
        with_suffix("_bass.wav"),
        with_suffix("_other.wav"),
    )
}
72
73fn build_stem_outputs(
74 names: &[String],
75 stems_count: usize,
76 sample_rate: u32,
77 vocals_out: String,
78 drums_out: String,
79 bass_out: String,
80 other_out: String,
81) -> Result<Vec<StemOutput>> {
82 let mut name_idx: HashMap<String, usize> = HashMap::new();
83 for (i, name) in names.iter().enumerate() {
84 name_idx.insert(name.to_lowercase(), i);
85 }
86
87 let get_idx = |key: &str, fallback: usize| -> usize {
88 name_idx
89 .get(key)
90 .copied()
91 .unwrap_or(fallback.min(stems_count.saturating_sub(1)))
92 };
93
94 Ok(vec![
95 StemOutput {
96 stem_idx: get_idx("vocals", 0),
97 stem_name: "vocals".to_string(),
98 writer: create_wav_writer(&vocals_out, sample_rate, 2)?,
99 },
100 StemOutput {
101 stem_idx: get_idx("drums", 1),
102 stem_name: "drums".to_string(),
103 writer: create_wav_writer(&drums_out, sample_rate, 2)?,
104 },
105 StemOutput {
106 stem_idx: get_idx("bass", 2),
107 stem_name: "bass".to_string(),
108 writer: create_wav_writer(&bass_out, sample_rate, 2)?,
109 },
110 StemOutput {
111 stem_idx: get_idx("other", 3),
112 stem_name: "other".to_string(),
113 writer: create_wav_writer(&other_out, sample_rate, 2)?,
114 },
115 ])
116}
117
118pub fn split_file(input_path: &str, opts: SplitOptions) -> Result<SplitResult> {
119 emit_split_progress(SplitProgress::Stage("resolve_model"));
120 let handle = ensure_model(&opts.model_name, opts.manifest_url_override.as_deref())?;
121
122 emit_split_progress(SplitProgress::Stage("engine_preload"));
123 engine::preload(&handle)?;
124
125 let mf = engine::manifest();
126
127 if mf.sample_rate != 44100 {
128 return Err(anyhow::anyhow!("Currently expecting 44.1k model").into());
129 }
130
131 emit_split_progress(SplitProgress::Stage("read_audio"));
132 let audio = read_audio(input_path)?;
133 let n = audio_frame_count(&audio.samples, audio.channels);
134
135 if n == 0 {
136 return Err(anyhow::anyhow!("Empty audio").into());
137 }
138
139 let win = mf.window;
140 let hop = mf.hop;
141
142 if !(win > 0 && hop > 0 && hop <= win) {
143 return Err(anyhow::anyhow!("Bad win/hop in manifest").into());
144 }
145
146 if std::env::var("DEBUG_STEMS").is_ok() {
147 eprintln!(
148 "Window settings: win={}, hop={}, overlap={}",
149 win,
150 hop,
151 win - hop
152 );
153 }
154
155 let names = if mf.stems.is_empty() {
156 vec![
157 "vocals".into(),
158 "drums".into(),
159 "bass".into(),
160 "other".into(),
161 ]
162 } else {
163 mf.stems.clone()
164 };
165
166 let (vocals_out, drums_out, bass_out, other_out) =
167 build_output_paths(input_path, &opts.output_dir);
168
169 let mut left_raw = vec![0f32; win];
170 let mut right_raw = vec![0f32; win];
171 let mut stem_outputs: Vec<StemOutput> = Vec::new();
172
173 let mut pos = 0usize;
174 let mut chunk_done = 0usize;
175 let total_chunks = if n <= hop { 1 } else { (n - 1) / hop + 1 };
176 let mut first_chunk = true;
177
178 emit_split_progress(SplitProgress::Stage("infer"));
179 while pos < n {
180 fill_stereo_window(
181 &audio.samples,
182 audio.channels,
183 pos,
184 &mut left_raw,
185 &mut right_raw,
186 );
187
188 let out = engine::run_window_demucs(&left_raw, &right_raw)?;
189 let (stems_count, _, t_out) = (out.shape()[0], out.shape()[1], out.shape()[2]);
190
191 if first_chunk {
192 stem_outputs = build_stem_outputs(
193 &names,
194 stems_count,
195 mf.sample_rate,
196 vocals_out.clone(),
197 drums_out.clone(),
198 bass_out.clone(),
199 other_out.clone(),
200 )?;
201 first_chunk = false;
202 }
203
204 let copy_len = hop.min(t_out).min(n - pos);
205 for stem_output in &mut stem_outputs {
206 for i in 0..copy_len {
207 stem_output
208 .writer
209 .write_sample(sample_to_i16(out[(stem_output.stem_idx, 0, i)]))
210 .map_err(anyhow::Error::from)?;
211 stem_output
212 .writer
213 .write_sample(sample_to_i16(out[(stem_output.stem_idx, 1, i)]))
214 .map_err(anyhow::Error::from)?;
215 }
216 }
217
218 chunk_done += 1;
219 emit_split_progress(SplitProgress::Chunks {
220 done: chunk_done,
221 total: total_chunks,
222 percent: chunk_done as f32 / total_chunks as f32 * 100.0,
223 });
224
225 if pos + hop >= n {
226 break;
227 }
228 pos += hop;
229 }
230
231 emit_split_progress(SplitProgress::Stage("write_stems"));
232 for (idx, stem_output) in stem_outputs.into_iter().enumerate() {
233 emit_split_progress(SplitProgress::Writing {
234 stem: stem_output.stem_name,
235 done: idx + 1,
236 total: 4,
237 percent: (idx + 1) as f32 / 4.0 * 100.0,
238 });
239 stem_output.writer.finalize().map_err(anyhow::Error::from)?;
240 }
241
242 emit_split_progress(SplitProgress::Stage("finalize"));
243 emit_split_progress(SplitProgress::Finished);
244
245 Ok(SplitResult {
246 vocals_path: vocals_out,
247 drums_path: drums_out,
248 bass_path: bass_out,
249 other_path: other_out,
250 })
251}