Skip to main content

stem_splitter_core/core/
splitter.rs

1use crate::{
2    core::{
3        audio::{create_wav_writer, read_audio, sample_to_i16, WavWriter},
4        engine,
5    },
6    error::Result,
7    io::progress::{emit_split_progress, SplitProgress},
8    model::model_manager::ensure_model,
9    types::{SplitOptions, SplitResult},
10};
11
12use std::{
13    collections::HashMap,
14    path::{Path, PathBuf},
15};
16
17struct StemOutput {
18    stem_idx: usize,
19    stem_name: String,
20    writer: WavWriter,
21}
22
/// Returns the number of complete interleaved frames in `samples`.
///
/// A `channels` value of 0 is treated as mono so the division is always
/// defined; a trailing partial frame is not counted.
fn audio_frame_count(samples: &[f32], channels: u16) -> usize {
    match channels {
        0 => samples.len(),
        per_frame => samples.len() / usize::from(per_frame),
    }
}
27
/// De-interleaves `left_raw.len()` frames of `samples`, starting at frame
/// `start_frame`, into separate left and right buffers.
///
/// - Mono input (or `channels == 0`, treated as mono) is copied to both sides.
/// - For more than two channels, only the first two samples of each frame are used.
/// - Frames past the end of `samples` are zero-filled.
/// - A truncated final frame that lacks its second sample reuses the first
///   sample for the right side.
fn fill_stereo_window(
    samples: &[f32],
    channels: u16,
    start_frame: usize,
    left_raw: &mut [f32],
    right_raw: &mut [f32],
) {
    let stride = usize::from(channels.max(1));
    let is_mono = stride == 1;

    for (offset, left_slot) in left_raw.iter_mut().enumerate() {
        let base = (start_frame + offset) * stride;
        let (l, r) = match samples.get(base) {
            None => (0.0, 0.0),
            Some(&first) if is_mono => (first, first),
            Some(&first) => (first, samples.get(base + 1).copied().unwrap_or(first)),
        };
        *left_slot = l;
        right_raw[offset] = r;
    }
}
57
/// Builds the four stem output paths for `input_path` inside `output_dir`.
///
/// Each path is `<output_dir>/<input file stem>_<stem>.wav`. The tuple is
/// ordered (vocals, drums, bass, other). When the input has no file stem, or
/// the stem is not valid UTF-8, `output` is used as the base name.
fn build_output_paths(input_path: &str, output_dir: &str) -> (String, String, String, String) {
    let name = match Path::new(input_path).file_stem().and_then(|s| s.to_str()) {
        Some(stem) => stem,
        None => "output",
    };
    let joined = PathBuf::from(output_dir).join(name);
    let prefix = joined.to_string_lossy();
    let stem_path = |suffix: &str| format!("{}_{}.wav", prefix, suffix);

    (
        stem_path("vocals"),
        stem_path("drums"),
        stem_path("bass"),
        stem_path("other"),
    )
}
72
73fn build_stem_outputs(
74    names: &[String],
75    stems_count: usize,
76    sample_rate: u32,
77    vocals_out: String,
78    drums_out: String,
79    bass_out: String,
80    other_out: String,
81) -> Result<Vec<StemOutput>> {
82    let mut name_idx: HashMap<String, usize> = HashMap::new();
83    for (i, name) in names.iter().enumerate() {
84        name_idx.insert(name.to_lowercase(), i);
85    }
86
87    let get_idx = |key: &str, fallback: usize| -> usize {
88        name_idx
89            .get(key)
90            .copied()
91            .unwrap_or(fallback.min(stems_count.saturating_sub(1)))
92    };
93
94    Ok(vec![
95        StemOutput {
96            stem_idx: get_idx("vocals", 0),
97            stem_name: "vocals".to_string(),
98            writer: create_wav_writer(&vocals_out, sample_rate, 2)?,
99        },
100        StemOutput {
101            stem_idx: get_idx("drums", 1),
102            stem_name: "drums".to_string(),
103            writer: create_wav_writer(&drums_out, sample_rate, 2)?,
104        },
105        StemOutput {
106            stem_idx: get_idx("bass", 2),
107            stem_name: "bass".to_string(),
108            writer: create_wav_writer(&bass_out, sample_rate, 2)?,
109        },
110        StemOutput {
111            stem_idx: get_idx("other", 3),
112            stem_name: "other".to_string(),
113            writer: create_wav_writer(&other_out, sample_rate, 2)?,
114        },
115    ])
116}
117
118pub fn split_file(input_path: &str, opts: SplitOptions) -> Result<SplitResult> {
119    emit_split_progress(SplitProgress::Stage("resolve_model"));
120    let handle = ensure_model(&opts.model_name, opts.manifest_url_override.as_deref())?;
121
122    emit_split_progress(SplitProgress::Stage("engine_preload"));
123    engine::preload(&handle)?;
124
125    let mf = engine::manifest();
126
127    if mf.sample_rate != 44100 {
128        return Err(anyhow::anyhow!("Currently expecting 44.1k model").into());
129    }
130
131    emit_split_progress(SplitProgress::Stage("read_audio"));
132    let audio = read_audio(input_path)?;
133    let n = audio_frame_count(&audio.samples, audio.channels);
134
135    if n == 0 {
136        return Err(anyhow::anyhow!("Empty audio").into());
137    }
138
139    let win = mf.window;
140    let hop = mf.hop;
141
142    if !(win > 0 && hop > 0 && hop <= win) {
143        return Err(anyhow::anyhow!("Bad win/hop in manifest").into());
144    }
145
146    if std::env::var("DEBUG_STEMS").is_ok() {
147        eprintln!(
148            "Window settings: win={}, hop={}, overlap={}",
149            win,
150            hop,
151            win - hop
152        );
153    }
154
155    let names = if mf.stems.is_empty() {
156        vec![
157            "vocals".into(),
158            "drums".into(),
159            "bass".into(),
160            "other".into(),
161        ]
162    } else {
163        mf.stems.clone()
164    };
165
166    let (vocals_out, drums_out, bass_out, other_out) =
167        build_output_paths(input_path, &opts.output_dir);
168
169    let mut left_raw = vec![0f32; win];
170    let mut right_raw = vec![0f32; win];
171    let mut stem_outputs: Vec<StemOutput> = Vec::new();
172
173    let mut pos = 0usize;
174    let mut chunk_done = 0usize;
175    let total_chunks = if n <= hop { 1 } else { (n - 1) / hop + 1 };
176    let mut first_chunk = true;
177
178    emit_split_progress(SplitProgress::Stage("infer"));
179    while pos < n {
180        fill_stereo_window(
181            &audio.samples,
182            audio.channels,
183            pos,
184            &mut left_raw,
185            &mut right_raw,
186        );
187
188        let out = engine::run_window_demucs(&left_raw, &right_raw)?;
189        let (stems_count, _, t_out) = (out.shape()[0], out.shape()[1], out.shape()[2]);
190
191        if first_chunk {
192            stem_outputs = build_stem_outputs(
193                &names,
194                stems_count,
195                mf.sample_rate,
196                vocals_out.clone(),
197                drums_out.clone(),
198                bass_out.clone(),
199                other_out.clone(),
200            )?;
201            first_chunk = false;
202        }
203
204        let copy_len = hop.min(t_out).min(n - pos);
205        for stem_output in &mut stem_outputs {
206            for i in 0..copy_len {
207                stem_output
208                    .writer
209                    .write_sample(sample_to_i16(out[(stem_output.stem_idx, 0, i)]))
210                    .map_err(anyhow::Error::from)?;
211                stem_output
212                    .writer
213                    .write_sample(sample_to_i16(out[(stem_output.stem_idx, 1, i)]))
214                    .map_err(anyhow::Error::from)?;
215            }
216        }
217
218        chunk_done += 1;
219        emit_split_progress(SplitProgress::Chunks {
220            done: chunk_done,
221            total: total_chunks,
222            percent: chunk_done as f32 / total_chunks as f32 * 100.0,
223        });
224
225        if pos + hop >= n {
226            break;
227        }
228        pos += hop;
229    }
230
231    emit_split_progress(SplitProgress::Stage("write_stems"));
232    for (idx, stem_output) in stem_outputs.into_iter().enumerate() {
233        emit_split_progress(SplitProgress::Writing {
234            stem: stem_output.stem_name,
235            done: idx + 1,
236            total: 4,
237            percent: (idx + 1) as f32 / 4.0 * 100.0,
238        });
239        stem_output.writer.finalize().map_err(anyhow::Error::from)?;
240    }
241
242    emit_split_progress(SplitProgress::Stage("finalize"));
243    emit_split_progress(SplitProgress::Finished);
244
245    Ok(SplitResult {
246        vocals_path: vocals_out,
247        drums_path: drums_out,
248        bass_path: bass_out,
249        other_path: other_out,
250    })
251}