// stem_splitter_core/core/audio.rs

use std::{fs::File, io::ErrorKind, path::Path};

use anyhow::{Context, Result};
use symphonia::core::{
    audio::SampleBuffer, codecs::DecoderOptions, errors::Error as SymphoniaError,
    formats::FormatOptions, io::MediaSourceStream, meta::MetadataOptions, probe::Hint,
};
use symphonia::default::{get_codecs, get_probe};

use crate::types::AudioData;

12pub fn read_audio<P: AsRef<Path>>(path: P) -> Result<AudioData> {
13    let path: &Path = path.as_ref();
14
15    let file: File =
16        File::open(path).with_context(|| format!("Failed to open audio file: {:?}", path))?;
17
18    let mss: MediaSourceStream = MediaSourceStream::new(Box::new(file), Default::default());
19
20    let mut hint: Hint = Hint::new();
21
22    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
23        hint.with_extension(ext);
24    }
25
26    let probed = get_probe().format(
27        &hint,
28        mss,
29        &FormatOptions::default(),
30        &MetadataOptions::default(),
31    )?;
32
33    let mut format = probed.format;
34    let track = format.default_track().context("No default track found")?;
35
36    let mut decoder = get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
37
38    let mut samples: Vec<f32> = Vec::new();
39    let mut sample_rate: u32 = 0;
40    let mut channels: u16 = 0;
41
42    while let Ok(packet) = format.next_packet() {
43        let decoded = decoder.decode(&packet)?;
44        sample_rate = decoded.spec().rate;
45        channels = decoded.spec().channels.count() as u16;
46
47        let mut buffer = SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
48        buffer.copy_interleaved_ref(decoded);
49
50        samples.extend_from_slice(buffer.samples());
51    }
52
53    println!(
54        "🎧 Read audio: sample_rate={}, channels={}, samples={}",
55        sample_rate,
56        channels,
57        samples.len()
58    );
59
60    Ok(AudioData {
61        samples,
62        sample_rate,
63        channels,
64    })
65}
66
67pub fn write_audio(path: &str, audio: &AudioData) -> Result<()> {
68    let path_obj = std::path::Path::new(path);
69    if let Some(parent) = path_obj.parent() {
70        std::fs::create_dir_all(parent)?;
71    }
72
73    let spec = hound::WavSpec {
74        channels: audio.channels,
75        sample_rate: audio.sample_rate,
76        bits_per_sample: 16,
77        sample_format: hound::SampleFormat::Int,
78    };
79
80    let mut writer = hound::WavWriter::create(path, spec)?;
81    for sample in &audio.samples {
82        let s = (sample * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
83        writer.write_sample(s)?;
84    }
85
86    writer.finalize()?;
87    Ok(())
88}