Skip to main content

speech_prep/
buffer.rs

//! Audio buffer types for batch audio processing.
//!
//! Provides `AudioBuffer` for owned sample buffers, complementing the streaming
//! audio types.
//!
//! ## Architecture
//!
//! - **AudioChunk**: Streaming-oriented transport
//! - **AudioBuffer** (this module): Owned samples for batch-style APIs
//!
//! ## Design Principles
//!
//! - Zero-copy conversions from `AudioChunk` where possible
//! - Temporal type integration (`AudioDuration`)
//! - Validation and normalization helpers
16
17use crate::error::{Error, Result};
18use crate::time::AudioDuration;
19use crate::types::AudioChunk;
20
/// Sample rates (in Hz) approved for ingestion across the stack.
///
/// Any rate not listed here is rejected when an `AudioBuffer` is constructed.
const VALID_SAMPLE_RATES: [u32; 5] = [8_000, 16_000, 22_050, 44_100, 48_000];
23
24/// Audio buffer for batch processing.
25///
26/// Provides utilities for normalization, resampling, and format conversion.
27///
28/// # Example
29///
30/// ```rust
31/// use speech_prep::buffer::AudioBuffer;
32///
33/// let samples = vec![0.1, 0.2, -0.1, -0.2];
34/// let buffer = AudioBuffer::from_samples(samples, 16000)?;
35///
36/// assert_eq!(buffer.sample_rate(), 16000);
37/// assert_eq!(buffer.len(), 4);
38/// # Ok::<(), speech_prep::error::Error>(())
39/// ```
40#[derive(Debug, Clone)]
41pub struct AudioBuffer {
42    /// Audio samples (f32, normalized to [-1.0, 1.0])
43    samples: Vec<f32>,
44    /// Sample rate in Hz
45    sample_rate: u32,
46    /// Optional metadata
47    metadata: Option<AudioBufferMetadata>,
48}
49
50/// Audio buffer metadata for tracking processing history.
51#[derive(Debug, Clone, Default)]
52pub struct AudioBufferMetadata {
53    /// Source identifier (file, stream, etc.)
54    pub source: Option<String>,
55    /// Original sample rate (before resampling)
56    pub original_sr: Option<u32>,
57    /// Duration in seconds
58    pub duration: Option<AudioDuration>,
59    /// Whether audio has been normalized
60    pub normalized: bool,
61    /// Processing operations applied
62    pub processing_chain: Vec<String>,
63}
64
65impl AudioBuffer {
66    /// Create `AudioBuffer` from f32 samples
67    ///
68    /// # Arguments
69    ///
70    /// * `samples` - Audio samples (will be validated)
71    /// * `sample_rate` - Sample rate in Hz
72    ///
73    /// # Errors
74    ///
75    /// Returns error if samples are empty or sample rate is invalid
76    ///
77    /// # Example
78    ///
79    /// ```rust
80    /// use speech_prep::buffer::AudioBuffer;
81    ///
82    /// let samples = vec![0.1, 0.2, -0.1];
83    /// let buffer = AudioBuffer::from_samples(samples, 16000)?;
84    /// assert_eq!(buffer.len(), 3);
85    /// # Ok::<(), speech_prep::error::Error>(())
86    /// ```
87    pub fn from_samples(samples: Vec<f32>, sample_rate: u32) -> Result<Self> {
88        Self::validate_sample_rate(sample_rate)?;
89
90        if samples.is_empty() {
91            return Err(Error::empty_input("audio samples are empty"));
92        }
93
94        Self::validate_sample_values(&samples)?;
95
96        Ok(Self {
97            samples,
98            sample_rate,
99            metadata: Some(AudioBufferMetadata::default()),
100        })
101    }
102
103    /// Create `AudioBuffer` from `AudioChunk` (zero-copy data)
104    ///
105    /// Converts a streaming `AudioChunk` into an owned `AudioBuffer`.
106    ///
107    /// # Example
108    ///
109    /// ```rust
110    /// use speech_prep::buffer::AudioBuffer;
111    /// use speech_prep::types::AudioChunk;
112    ///
113    /// let chunk = AudioChunk::new(vec![0.1, 0.2], 0, 0.0, 16000);
114    /// let buffer = AudioBuffer::from_chunk(chunk)?;
115    /// assert_eq!(buffer.len(), 2);
116    /// # Ok::<(), speech_prep::error::Error>(())
117    /// ```
118    pub fn from_chunk(chunk: AudioChunk) -> Result<Self> {
119        let sample_rate = chunk.sample_rate;
120        let samples = chunk.data;
121
122        Self::from_samples(samples, sample_rate)
123    }
124
125    /// Get sample rate in Hz
126    #[must_use]
127    pub fn sample_rate(&self) -> u32 {
128        self.sample_rate
129    }
130
131    /// Get number of samples
132    #[must_use]
133    pub fn len(&self) -> usize {
134        self.samples.len()
135    }
136
137    /// Check if buffer is empty
138    #[must_use]
139    pub fn is_empty(&self) -> bool {
140        self.samples.is_empty()
141    }
142
143    /// Get duration as `AudioDuration`
144    ///
145    /// # Example
146    ///
147    /// ```rust
148    /// use speech_prep::buffer::AudioBuffer;
149    /// use speech_prep::time::AudioDuration;
150    ///
151    /// let buffer = AudioBuffer::from_samples(vec![0.0; 16000], 16000)?;
152    /// let duration = buffer.duration();
153    /// assert_eq!(duration.as_secs(), 1);
154    /// # Ok::<(), speech_prep::error::Error>(())
155    /// ```
156    #[must_use]
157    pub fn duration(&self) -> AudioDuration {
158        let duration_secs = self.samples.len() as f64 / f64::from(self.sample_rate);
159        let duration_nanos = (duration_secs * 1_000_000_000.0) as u64;
160        AudioDuration::from_nanos(duration_nanos)
161    }
162
163    /// Get immutable slice of samples
164    #[must_use]
165    pub fn samples(&self) -> &[f32] {
166        &self.samples
167    }
168
169    /// Get mutable slice of samples
170    pub fn samples_mut(&mut self) -> &mut [f32] {
171        &mut self.samples
172    }
173
174    /// Consume buffer and return samples
175    #[must_use]
176    pub fn into_samples(self) -> Vec<f32> {
177        self.samples
178    }
179
180    /// Normalize samples to [-1.0, 1.0] range
181    ///
182    /// Applies peak normalization to ensure samples are within valid range.
183    ///
184    /// # Example
185    ///
186    /// ```rust
187    /// use speech_prep::buffer::AudioBuffer;
188    ///
189    /// let mut buffer = AudioBuffer::from_samples(vec![2.0, -2.0, 1.0], 16000)?;
190    /// buffer.normalize();
191    ///
192    /// // Samples now scaled to [-1.0, 1.0]
193    /// assert!(buffer.samples().iter().all(|&s| s >= -1.0 && s <= 1.0));
194    /// # Ok::<(), speech_prep::error::Error>(())
195    /// ```
196    pub fn normalize(&mut self) {
197        let max_abs = self.samples.iter().map(|&s| s.abs()).fold(0.0f32, f32::max);
198
199        if max_abs > 0.0 {
200            let scale = 1.0 / max_abs;
201            for sample in &mut self.samples {
202                *sample *= scale;
203            }
204        }
205
206        if let Some(ref mut meta) = self.metadata {
207            meta.normalized = true;
208            meta.processing_chain.push("normalize".to_owned());
209        }
210    }
211
212    /// Validate sample values are within expected range
213    ///
214    /// Checks that all samples are finite and within reasonable bounds.
215    pub fn validate_samples(&self) -> Result<()> {
216        Self::validate_sample_values(&self.samples)
217    }
218
219    fn validate_sample_values(samples: &[f32]) -> Result<()> {
220        for &sample in samples {
221            if !sample.is_finite() {
222                return Err(Error::invalid_format("sample value is not finite"));
223            }
224        }
225
226        Ok(())
227    }
228
229    /// Get metadata
230    #[must_use]
231    pub fn metadata(&self) -> Option<&AudioBufferMetadata> {
232        self.metadata.as_ref()
233    }
234
235    /// Set metadata
236    pub fn set_metadata(&mut self, metadata: AudioBufferMetadata) {
237        self.metadata = Some(metadata);
238    }
239
240    fn validate_sample_rate(sample_rate: u32) -> Result<()> {
241        if VALID_SAMPLE_RATES.contains(&sample_rate) {
242            Ok(())
243        } else {
244            Err(Error::invalid_format(format!(
245                "unsupported sample rate: {sample_rate}"
246            )))
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 16 kHz buffer, panicking if construction fails.
    fn buffer_16k(samples: Vec<f32>) -> AudioBuffer {
        AudioBuffer::from_samples(samples, 16000).expect("buffer creation")
    }

    #[test]
    fn test_from_samples_valid() {
        let samples = vec![0.1, 0.2, -0.1, -0.2];
        let buffer = buffer_16k(samples.clone());

        assert_eq!(buffer.len(), 4);
        assert_eq!(buffer.sample_rate(), 16000);
        assert_eq!(buffer.samples(), &samples[..]);
    }

    #[test]
    fn test_from_samples_rejects_non_finite_values() {
        let result = AudioBuffer::from_samples(vec![0.0, f32::NAN, 0.2], 16000);

        assert!(result.is_err(), "NaN samples must be rejected");
    }

    #[test]
    fn test_sample_rate_policy() {
        assert!(AudioBuffer::from_samples(vec![0.0; 10], 16000).is_ok());
        assert!(
            AudioBuffer::from_samples(vec![0.0; 10], 32_000).is_err(),
            "Unexpected sample rates must be rejected"
        );
    }

    #[test]
    fn test_from_samples_empty_fails() {
        let result = AudioBuffer::from_samples(Vec::new(), 16000);

        assert!(result.is_err());
    }

    #[test]
    fn test_from_samples_invalid_sample_rate() {
        let samples = vec![0.1, 0.2];

        // One rate below and one above the supported set.
        assert!(AudioBuffer::from_samples(samples.clone(), 1000).is_err());
        assert!(AudioBuffer::from_samples(samples, 100_000).is_err());
    }

    #[test]
    fn test_from_chunk() {
        let chunk = AudioChunk::new(vec![0.1, 0.2, -0.1], 0, 0.0, 16000);
        let buffer = AudioBuffer::from_chunk(chunk).expect("buffer from chunk");

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.sample_rate(), 16000);
    }

    #[test]
    fn test_duration() {
        let duration = buffer_16k(vec![0.0; 16000]).duration();

        assert_eq!(duration.as_secs(), 1);
        // Allow for rounding in the nanosecond conversion.
        let millis = duration.as_millis();
        assert!((1000..=1001).contains(&millis));
    }

    #[test]
    fn test_normalize() {
        let mut buffer = buffer_16k(vec![2.0, -2.0, 1.0, -1.0]);

        buffer.normalize();

        let peak = buffer
            .samples()
            .iter()
            .fold(0.0_f32, |acc, sample| acc.max(sample.abs()));
        assert!((peak - 1.0).abs() < 1e-6);

        let meta = buffer.metadata().expect("metadata");
        assert!(meta.normalized);
        assert_eq!(meta.processing_chain, vec!["normalize".to_owned()]);
    }

    #[test]
    fn test_normalize_zero_samples() {
        let mut buffer = buffer_16k(vec![0.0, 0.0, 0.0]);

        buffer.normalize();

        // Silence has no peak to scale by: it must stay silent and finite.
        assert!(buffer.samples().iter().all(|&s| s == 0.0));
        assert!(buffer.validate_samples().is_ok());
    }

    #[test]
    fn test_validate_samples_valid() {
        let buffer = buffer_16k(vec![0.1, -0.5, 0.9]);

        assert!(buffer.validate_samples().is_ok());
    }

    #[test]
    fn test_validate_samples_infinite() {
        let mut buffer = buffer_16k(vec![0.1, 0.2, 0.9]);
        buffer.samples_mut()[1] = f32::INFINITY;

        assert!(buffer.validate_samples().is_err());
    }

    #[test]
    fn test_validate_samples_nan() {
        let mut buffer = buffer_16k(vec![0.1, 0.2, 0.9]);
        buffer.samples_mut()[1] = f32::NAN;

        assert!(buffer.validate_samples().is_err());
    }

    #[test]
    fn test_into_samples() {
        let samples = vec![0.1, 0.2, 0.3];
        let buffer = buffer_16k(samples.clone());

        assert_eq!(buffer.into_samples(), samples);
    }

    #[test]
    fn test_samples_mut() {
        let mut buffer = buffer_16k(vec![0.1, 0.2, 0.3]);

        buffer.samples_mut()[0] = 0.5;

        assert_eq!(buffer.samples()[0], 0.5);
    }

    #[test]
    fn test_metadata_operations() {
        let mut buffer = buffer_16k(vec![0.1, 0.2]);

        buffer.set_metadata(AudioBufferMetadata {
            source: Some("test.wav".to_owned()),
            original_sr: Some(44100),
            duration: Some(AudioDuration::from_millis(125)),
            normalized: true,
            processing_chain: vec!["resample".to_owned(), "normalize".to_owned()],
        });

        let retrieved = buffer.metadata().expect("metadata");
        assert_eq!(retrieved.source.as_deref(), Some("test.wav"));
        assert_eq!(retrieved.original_sr, Some(44100));
        assert!(retrieved.normalized);
        assert_eq!(retrieved.processing_chain.len(), 2);
    }
}