stem_splitter_core/core/
audio.rs

1use std::{fs::File, path::Path};
2
3use anyhow::{Context, Result};
4use symphonia::core::{
5    audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
6    meta::MetadataOptions, probe::Hint,
7};
8use symphonia::default::{get_codecs, get_probe};
9
10use crate::types::AudioData;
11
12pub fn read_audio<P: AsRef<Path>>(path: P) -> Result<AudioData> {
13    let path: &Path = path.as_ref();
14
15    let file: File =
16        File::open(path).with_context(|| format!("Failed to open audio file: {:?}", path))?;
17
18    let mss: MediaSourceStream = MediaSourceStream::new(Box::new(file), Default::default());
19
20    let mut hint: Hint = Hint::new();
21
22    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
23        hint.with_extension(ext);
24    }
25
26    let probed = get_probe().format(
27        &hint,
28        mss,
29        &FormatOptions::default(),
30        &MetadataOptions::default(),
31    )?;
32
33    let mut format = probed.format;
34    let track = format.default_track().context("No default track found")?;
35
36    let mut decoder = get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
37
38    let mut samples: Vec<f32> = Vec::new();
39    let mut sample_rate: u32 = 0;
40    let mut channels: u16 = 0;
41
42    while let Ok(packet) = format.next_packet() {
43        let decoded = decoder.decode(&packet)?;
44        sample_rate = decoded.spec().rate;
45        channels = decoded.spec().channels.count() as u16;
46
47        let mut buffer = SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
48        buffer.copy_interleaved_ref(decoded);
49
50        samples.extend_from_slice(buffer.samples());
51    }
52
53    if std::env::var("DEBUG_STEMS").is_ok() {
54        eprintln!(
55            "🎧 Read audio: sample_rate={} Hz, channels={}, samples={} ({:.2} seconds)",
56            sample_rate,
57            channels,
58            samples.len(),
59            samples.len() as f64 / (sample_rate as f64 * channels as f64)
60        );
61    }
62
63    Ok(AudioData {
64        samples,
65        sample_rate,
66        channels,
67    })
68}
69
70pub fn write_audio(path: &str, audio: &AudioData) -> Result<()> {
71    let path_obj = std::path::Path::new(path);
72    if let Some(parent) = path_obj.parent() {
73        std::fs::create_dir_all(parent)?;
74    }
75
76    let spec = hound::WavSpec {
77        channels: audio.channels,
78        sample_rate: audio.sample_rate,
79        bits_per_sample: 16,
80        sample_format: hound::SampleFormat::Int,
81    };
82
83    let mut writer = hound::WavWriter::create(path, spec)?;
84    for sample in &audio.samples {
85        let s = (sample * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
86        writer.write_sample(s)?;
87    }
88
89    writer.finalize()?;
90    Ok(())
91}