stem_splitter_core/core/
audio.rs1use std::{fs::File, path::Path};
2
3use anyhow::{Context, Result};
4use symphonia::core::{
5 audio::SampleBuffer, codecs::DecoderOptions, formats::FormatOptions, io::MediaSourceStream,
6 meta::MetadataOptions, probe::Hint,
7};
8use symphonia::default::{get_codecs, get_probe};
9
10use crate::types::AudioData;
11
12pub fn read_audio<P: AsRef<Path>>(path: P) -> Result<AudioData> {
13 let path: &Path = path.as_ref();
14
15 let file: File =
16 File::open(path).with_context(|| format!("Failed to open audio file: {:?}", path))?;
17
18 let mss: MediaSourceStream = MediaSourceStream::new(Box::new(file), Default::default());
19
20 let mut hint: Hint = Hint::new();
21
22 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
23 hint.with_extension(ext);
24 }
25
26 let probed = get_probe().format(
27 &hint,
28 mss,
29 &FormatOptions::default(),
30 &MetadataOptions::default(),
31 )?;
32
33 let mut format = probed.format;
34 let track = format.default_track().context("No default track found")?;
35
36 let mut decoder = get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
37
38 let mut samples: Vec<f32> = Vec::new();
39 let mut sample_rate: u32 = 0;
40 let mut channels: u16 = 0;
41
42 while let Ok(packet) = format.next_packet() {
43 let decoded = decoder.decode(&packet)?;
44 sample_rate = decoded.spec().rate;
45 channels = decoded.spec().channels.count() as u16;
46
47 let mut buffer = SampleBuffer::<f32>::new(decoded.capacity() as u64, *decoded.spec());
48 buffer.copy_interleaved_ref(decoded);
49
50 samples.extend_from_slice(buffer.samples());
51 }
52
53 if std::env::var("DEBUG_STEMS").is_ok() {
54 eprintln!(
55 "🎧 Read audio: sample_rate={} Hz, channels={}, samples={} ({:.2} seconds)",
56 sample_rate,
57 channels,
58 samples.len(),
59 samples.len() as f64 / (sample_rate as f64 * channels as f64)
60 );
61 }
62
63 Ok(AudioData {
64 samples,
65 sample_rate,
66 channels,
67 })
68}
69
70pub fn write_audio(path: &str, audio: &AudioData) -> Result<()> {
71 let path_obj = std::path::Path::new(path);
72 if let Some(parent) = path_obj.parent() {
73 std::fs::create_dir_all(parent)?;
74 }
75
76 let spec = hound::WavSpec {
77 channels: audio.channels,
78 sample_rate: audio.sample_rate,
79 bits_per_sample: 16,
80 sample_format: hound::SampleFormat::Int,
81 };
82
83 let mut writer = hound::WavWriter::create(path, spec)?;
84 for sample in &audio.samples {
85 let s = (sample * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
86 writer.write_sample(s)?;
87 }
88
89 writer.finalize()?;
90 Ok(())
91}