Skip to main content

speech_prep/
buffer.rs

1//! Audio buffer types for ML processing
2//!
3//! Provides `AudioBuffer` for batch audio processing in ML pipelines,
4//! complementing streaming audio types.
5//!
6//! ## Architecture
7//!
8//! - **AudioChunk**: Streaming-focused, real-time transport
9//! - **AudioBuffer** (this module): ML-focused, batch processing
10//!
11//! ## Design Principles
12//!
13//! - Zero-copy conversions from AudioChunk where possible
14//! - Temporal type integration (AudioDuration)
15//! - ML framework compatibility (Candle, PyTorch)
16//! - Validation and normalization for ML inference
17
18use crate::error::{Error, Result};
19use crate::time::AudioDuration;
20use crate::types::AudioChunk;
21
/// Sample rates (in Hz) accepted when building an `AudioBuffer`; any other
/// rate is rejected as an unsupported format.
const VALID_SAMPLE_RATES: [u32; 5] = [8_000, 16_000, 22_050, 44_100, 48_000];
24
25/// Audio buffer for ML processing
26///
27/// Optimized for batch audio processing in ML pipelines. Provides
28/// utilities for normalization, resampling, and format conversion.
29///
30/// # Example
31///
32/// ```rust
33/// use speech_prep::buffer::AudioBuffer;
34///
35/// let samples = vec![0.1, 0.2, -0.1, -0.2];
36/// let buffer = AudioBuffer::from_samples(samples, 16000)?;
37///
38/// assert_eq!(buffer.sample_rate(), 16000);
39/// assert_eq!(buffer.len(), 4);
40/// # Ok::<(), speech_prep::error::Error>(())
41/// ```
42#[derive(Debug, Clone)]
43pub struct AudioBuffer {
44    /// Audio samples (f32, normalized to [-1.0, 1.0])
45    samples: Vec<f32>,
46    /// Sample rate in Hz
47    sample_rate: u32,
48    /// Optional metadata
49    metadata: Option<AudioMetadata>,
50}
51
52/// Audio metadata for tracking processing history
53#[derive(Debug, Clone, Default)]
54pub struct AudioMetadata {
55    /// Source identifier (file, stream, etc.)
56    pub source: Option<String>,
57    /// Original sample rate (before resampling)
58    pub original_sr: Option<u32>,
59    /// Duration in seconds
60    pub duration: Option<AudioDuration>,
61    /// Whether audio has been normalized
62    pub normalized: bool,
63    /// Processing operations applied
64    pub processing_chain: Vec<String>,
65}
66
67impl AudioBuffer {
68    /// Create `AudioBuffer` from f32 samples
69    ///
70    /// # Arguments
71    ///
72    /// * `samples` - Audio samples (will be validated)
73    /// * `sample_rate` - Sample rate in Hz
74    ///
75    /// # Errors
76    ///
77    /// Returns error if samples are empty or sample rate is invalid
78    ///
79    /// # Example
80    ///
81    /// ```rust
82    /// use speech_prep::buffer::AudioBuffer;
83    ///
84    /// let samples = vec![0.1, 0.2, -0.1];
85    /// let buffer = AudioBuffer::from_samples(samples, 16000)?;
86    /// assert_eq!(buffer.len(), 3);
87    /// # Ok::<(), speech_prep::error::Error>(())
88    /// ```
89    pub fn from_samples(samples: Vec<f32>, sample_rate: u32) -> Result<Self> {
90        Self::validate_sample_rate(sample_rate)?;
91
92        if samples.is_empty() {
93            return Err(Error::empty_input("audio samples are empty"));
94        }
95
96        Self::validate_sample_values(&samples)?;
97
98        Ok(Self {
99            samples,
100            sample_rate,
101            metadata: Some(AudioMetadata::default()),
102        })
103    }
104
105    /// Create `AudioBuffer` from `AudioChunk` (zero-copy data)
106    ///
107    /// Converts streaming `AudioChunk` to batch `AudioBuffer` for ML
108    /// processing.
109    ///
110    /// # Example
111    ///
112    /// ```rust
113    /// use speech_prep::buffer::AudioBuffer;
114    /// use speech_prep::types::AudioChunk;
115    ///
116    /// let chunk = AudioChunk::new(vec![0.1, 0.2], 0, 0.0, 16000);
117    /// let buffer = AudioBuffer::from_chunk(chunk)?;
118    /// assert_eq!(buffer.len(), 2);
119    /// # Ok::<(), speech_prep::error::Error>(())
120    /// ```
121    pub fn from_chunk(chunk: AudioChunk) -> Result<Self> {
122        let sample_rate = chunk.sample_rate;
123        let samples = chunk.data;
124
125        Self::from_samples(samples, sample_rate)
126    }
127
128    /// Get sample rate in Hz
129    #[must_use]
130    pub fn sample_rate(&self) -> u32 {
131        self.sample_rate
132    }
133
134    /// Get number of samples
135    #[must_use]
136    pub fn len(&self) -> usize {
137        self.samples.len()
138    }
139
140    /// Check if buffer is empty
141    #[must_use]
142    pub fn is_empty(&self) -> bool {
143        self.samples.is_empty()
144    }
145
146    /// Get duration as `AudioDuration`
147    ///
148    /// # Example
149    ///
150    /// ```rust
151    /// use speech_prep::buffer::AudioBuffer;
152    /// use speech_prep::time::AudioDuration;
153    ///
154    /// let buffer = AudioBuffer::from_samples(vec![0.0; 16000], 16000)?;
155    /// let duration = buffer.duration();
156    /// assert_eq!(duration.as_secs(), 1);
157    /// # Ok::<(), speech_prep::error::Error>(())
158    /// ```
159    #[must_use]
160    pub fn duration(&self) -> AudioDuration {
161        let duration_secs = self.samples.len() as f64 / f64::from(self.sample_rate);
162        let duration_nanos = (duration_secs * 1_000_000_000.0) as u64;
163        AudioDuration::from_nanos(duration_nanos)
164    }
165
166    /// Get immutable slice of samples
167    #[must_use]
168    pub fn samples(&self) -> &[f32] {
169        &self.samples
170    }
171
172    /// Get mutable slice of samples
173    pub fn samples_mut(&mut self) -> &mut [f32] {
174        &mut self.samples
175    }
176
177    /// Consume buffer and return samples
178    #[must_use]
179    pub fn into_samples(self) -> Vec<f32> {
180        self.samples
181    }
182
183    /// Normalize samples to [-1.0, 1.0] range
184    ///
185    /// Applies peak normalization to ensure samples are within valid range.
186    ///
187    /// # Example
188    ///
189    /// ```rust
190    /// use speech_prep::buffer::AudioBuffer;
191    ///
192    /// let mut buffer = AudioBuffer::from_samples(vec![2.0, -2.0, 1.0], 16000)?;
193    /// buffer.normalize();
194    ///
195    /// // Samples now scaled to [-1.0, 1.0]
196    /// assert!(buffer.samples().iter().all(|&s| s >= -1.0 && s <= 1.0));
197    /// # Ok::<(), speech_prep::error::Error>(())
198    /// ```
199    pub fn normalize(&mut self) {
200        let max_abs = self.samples.iter().map(|&s| s.abs()).fold(0.0f32, f32::max);
201
202        if max_abs > 0.0 {
203            let scale = 1.0 / max_abs;
204            for sample in &mut self.samples {
205                *sample *= scale;
206            }
207        }
208
209        if let Some(ref mut meta) = self.metadata {
210            meta.normalized = true;
211            meta.processing_chain.push("normalize".to_owned());
212        }
213    }
214
215    /// Validate sample values are within expected range
216    ///
217    /// Checks that all samples are finite and within reasonable bounds.
218    pub fn validate_samples(&self) -> Result<()> {
219        Self::validate_sample_values(&self.samples)
220    }
221
222    fn validate_sample_values(samples: &[f32]) -> Result<()> {
223        for &sample in samples {
224            if !sample.is_finite() {
225                return Err(Error::invalid_format("sample value is not finite"));
226            }
227        }
228
229        Ok(())
230    }
231
232    /// Get metadata
233    #[must_use]
234    pub fn metadata(&self) -> Option<&AudioMetadata> {
235        self.metadata.as_ref()
236    }
237
238    /// Set metadata
239    pub fn set_metadata(&mut self, metadata: AudioMetadata) {
240        self.metadata = Some(metadata);
241    }
242
243    fn validate_sample_rate(sample_rate: u32) -> Result<()> {
244        if VALID_SAMPLE_RATES.contains(&sample_rate) {
245            Ok(())
246        } else {
247            Err(Error::invalid_format(format!(
248                "unsupported sample rate: {sample_rate}"
249            )))
250        }
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_samples_valid() {
        let samples = vec![0.1, 0.2, -0.1, -0.2];
        let buffer = AudioBuffer::from_samples(samples.clone(), 16000);

        assert!(buffer.is_ok());
        let buffer = buffer.expect("buffer creation");
        assert_eq!(buffer.len(), 4);
        assert_eq!(buffer.sample_rate(), 16000);
        assert_eq!(buffer.samples(), &samples[..]);
    }

    #[test]
    fn test_from_samples_rejects_non_finite_values() {
        let samples = vec![0.0, f32::NAN, 0.2];
        let buffer = AudioBuffer::from_samples(samples, 16000);

        assert!(buffer.is_err(), "NaN samples must be rejected");
    }

    #[test]
    fn test_sample_rate_policy() {
        assert!(AudioBuffer::from_samples(vec![0.0; 10], 16000).is_ok());
        assert!(
            AudioBuffer::from_samples(vec![0.0; 10], 32_000).is_err(),
            "Unexpected sample rates must be rejected"
        );
    }

    #[test]
    fn test_from_samples_empty_fails() {
        let samples: Vec<f32> = vec![];
        let result = AudioBuffer::from_samples(samples, 16000);

        assert!(result.is_err());
    }

    #[test]
    fn test_from_samples_invalid_sample_rate() {
        let samples = vec![0.1, 0.2];

        assert!(AudioBuffer::from_samples(samples.clone(), 1000).is_err());

        assert!(AudioBuffer::from_samples(samples, 100_000).is_err());
    }

    #[test]
    fn test_from_chunk() {
        let chunk = AudioChunk::new(vec![0.1, 0.2, -0.1], 0, 0.0, 16000);
        let buffer = AudioBuffer::from_chunk(chunk);

        assert!(buffer.is_ok());
        let buffer = buffer.expect("buffer from chunk");
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.sample_rate(), 16000);
    }

    #[test]
    fn test_duration() {
        let buffer = AudioBuffer::from_samples(vec![0.0; 16000], 16000).expect("buffer creation");
        let duration = buffer.duration();

        assert_eq!(duration.as_secs(), 1);
        assert!(duration.as_millis() >= 1000 && duration.as_millis() <= 1001); // Allow for rounding
    }

    #[test]
    fn test_duration_fractional_second() {
        // Half a second of audio at 16 kHz.
        let buffer = AudioBuffer::from_samples(vec![0.0; 8000], 16000).expect("buffer creation");

        assert_eq!(buffer.duration().as_millis(), 500);
    }

    #[test]
    fn test_normalize() {
        let mut buffer =
            AudioBuffer::from_samples(vec![2.0, -2.0, 1.0, -1.0], 16000).expect("buffer creation");

        buffer.normalize();

        let max_abs = buffer
            .samples()
            .iter()
            .map(|&s| s.abs())
            .fold(0.0f32, f32::max);
        assert!((max_abs - 1.0).abs() < 1e-6);

        let meta = buffer.metadata().expect("metadata");
        assert!(meta.normalized);
        assert_eq!(meta.processing_chain, vec!["normalize".to_owned()]);
    }

    #[test]
    fn test_normalize_zero_samples() {
        let mut buffer =
            AudioBuffer::from_samples(vec![0.0, 0.0, 0.0], 16000).expect("buffer creation");

        buffer.normalize();

        // An all-zero buffer has no peak to scale by: samples must stay zero
        // (not NaN), and the operation is still recorded in the metadata.
        assert!(buffer.samples().iter().all(|&s| s == 0.0));
        assert!(buffer.metadata().expect("metadata").normalized);
    }

    #[test]
    fn test_validate_samples_valid() {
        let buffer =
            AudioBuffer::from_samples(vec![0.1, -0.5, 0.9], 16000).expect("buffer creation");

        assert!(buffer.validate_samples().is_ok());
    }

    #[test]
    fn test_validate_samples_infinite() {
        let mut buffer =
            AudioBuffer::from_samples(vec![0.1, 0.2, 0.9], 16000).expect("buffer creation");
        buffer.samples_mut()[1] = f32::INFINITY;
        assert!(buffer.validate_samples().is_err());
    }

    #[test]
    fn test_validate_samples_nan() {
        let mut buffer =
            AudioBuffer::from_samples(vec![0.1, 0.2, 0.9], 16000).expect("buffer creation");
        buffer.samples_mut()[1] = f32::NAN;
        assert!(buffer.validate_samples().is_err());
    }

    #[test]
    fn test_into_samples() {
        let samples = vec![0.1, 0.2, 0.3];
        let buffer = AudioBuffer::from_samples(samples.clone(), 16000).expect("buffer creation");

        let extracted = buffer.into_samples();
        assert_eq!(extracted, samples);
    }

    #[test]
    fn test_samples_mut() {
        let mut buffer =
            AudioBuffer::from_samples(vec![0.1, 0.2, 0.3], 16000).expect("buffer creation");

        buffer.samples_mut()[0] = 0.5;
        assert_eq!(buffer.samples()[0], 0.5);
    }

    #[test]
    fn test_metadata_operations() {
        let mut buffer = AudioBuffer::from_samples(vec![0.1, 0.2], 16000).expect("buffer creation");

        let metadata = AudioMetadata {
            source: Some("test.wav".to_owned()),
            original_sr: Some(44100),
            duration: Some(AudioDuration::from_millis(125)),
            normalized: true,
            processing_chain: vec!["resample".to_owned(), "normalize".to_owned()],
        };

        buffer.set_metadata(metadata);

        let retrieved = buffer.metadata().expect("metadata");
        assert_eq!(retrieved.source.as_deref(), Some("test.wav"));
        assert_eq!(retrieved.original_sr, Some(44100));
        assert!(retrieved.normalized);
        assert_eq!(retrieved.processing_chain.len(), 2);
    }
}