Skip to main content

snapcast_client/decoder/
mod.rs

1//! Audio decoder trait and implementations.
2
3#[cfg(feature = "f32lz4")]
4pub mod f32lz4;
5pub mod flac;
6pub mod opus;
7pub mod vorbis;
8
9use anyhow::{Result, bail};
10use snapcast_proto::SampleFormat;
11use snapcast_proto::message::codec_header::CodecHeader;
12
13use crate::stream::SampleEncoding;
14
/// Audio decoder trait — matches the C++ `Decoder` interface.
///
/// A decoder is initialized from a codec header via
/// [`set_header`](Self::set_header) and then processes audio buffers via
/// [`decode`](Self::decode). The `Send` supertrait requires every implementor
/// to be transferable to another thread.
pub trait Decoder: Send {
    /// Initialize the decoder from a codec header. Returns the sample format.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot derive a sample format
    /// from `header`.
    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat>;

    /// Decode audio data in-place. Returns true if successful.
    /// For PCM this is a no-op (passthrough).
    ///
    /// `data` serves as both input and output; because it is a `Vec` rather
    /// than a slice, an implementation is able to change its length.
    ///
    /// NOTE(review): the distinction between `Ok(false)` and `Err(_)` is not
    /// defined in this file — confirm against the implementors and callers.
    fn decode(&mut self, data: &mut Vec<u8>) -> Result<bool>;

    /// Encoding produced by [`decode`](Self::decode).
    ///
    /// The default implementation reports [`SampleEncoding::PcmInt`];
    /// implementors that produce a different encoding override this method.
    fn output_encoding(&self) -> SampleEncoding {
        SampleEncoding::PcmInt
    }
}
29
/// PCM decoder — passthrough. Parses the RIFF/WAV header to extract sample format.
///
/// The type is a stateless unit struct. It derives `Debug` so that values
/// (and types containing them) can be formatted with `{:?}` in logs and test
/// assertions, and `Default` so it can be built generically.
#[derive(Debug, Default)]
pub struct PcmDecoder;

impl PcmDecoder {
    /// Create a new PCM passthrough decoder.
    ///
    /// Equivalent to [`PcmDecoder::default`].
    pub fn new() -> Self {
        Self
    }
}
40
41/// Parse a RIFF/WAV header and return the sample format.
42fn parse_riff_header(payload: &[u8]) -> Result<SampleFormat> {
43    if payload.len() < 44 {
44        bail!("PCM header too small ({} bytes)", payload.len());
45    }
46
47    // RIFF header: "RIFF" + size + "WAVE"
48    if &payload[0..4] != b"RIFF" || &payload[8..12] != b"WAVE" {
49        bail!("not a RIFF/WAVE header");
50    }
51
52    let mut pos = 12;
53    let mut sample_rate: u32 = 0;
54    let mut bits_per_sample: u16 = 0;
55    let mut num_channels: u16 = 0;
56
57    // Walk chunks until we find "fmt " and "data"
58    while pos + 8 <= payload.len() {
59        let chunk_id = &payload[pos..pos + 4];
60        let chunk_size = u32::from_le_bytes(payload[pos + 4..pos + 8].try_into().unwrap()) as usize;
61        pos += 8;
62
63        if chunk_id == b"fmt " {
64            if pos + 16 > payload.len() {
65                bail!("fmt chunk too small");
66            }
67            // audio_format: u16 at +0 (skip)
68            num_channels = u16::from_le_bytes(payload[pos + 2..pos + 4].try_into().unwrap());
69            sample_rate = u32::from_le_bytes(payload[pos + 4..pos + 8].try_into().unwrap());
70            // byte_rate: u32 at +8 (skip)
71            // block_align: u16 at +12 (skip)
72            bits_per_sample = u16::from_le_bytes(payload[pos + 14..pos + 16].try_into().unwrap());
73            pos += chunk_size;
74        } else if chunk_id == b"data" {
75            break;
76        } else {
77            pos += chunk_size;
78        }
79    }
80
81    if sample_rate == 0 {
82        bail!("sample format not found in RIFF header");
83    }
84
85    Ok(SampleFormat::new(
86        sample_rate,
87        bits_per_sample,
88        num_channels,
89    ))
90}
91
92impl Decoder for PcmDecoder {
93    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat> {
94        tracing::trace!(
95            codec = "pcm",
96            payload_len = header.payload.len(),
97            "set_header"
98        );
99        parse_riff_header(&header.payload)
100    }
101
102    fn decode(&mut self, _data: &mut Vec<u8>) -> Result<bool> {
103        // PCM is passthrough — no decoding needed
104        Ok(true)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal 44-byte WAV header (RIFF preamble, 16-byte `fmt `
    /// chunk, empty `data` chunk) describing integer PCM with the given
    /// parameters.
    fn wav_header(sample_rate: u32, bits_per_sample: u16, channels: u16) -> Vec<u8> {
        let block_align = channels * (bits_per_sample / 8);
        let byte_rate = sample_rate * u32::from(block_align);

        let mut bytes = Vec::with_capacity(44);
        // RIFF preamble; the file size field is not inspected by the parser.
        bytes.extend_from_slice(b"RIFF");
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(b"WAVE");
        // fmt chunk
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&16u32.to_le_bytes()); // chunk size
        bytes.extend_from_slice(&1u16.to_le_bytes()); // PCM format
        bytes.extend_from_slice(&channels.to_le_bytes());
        bytes.extend_from_slice(&sample_rate.to_le_bytes());
        bytes.extend_from_slice(&byte_rate.to_le_bytes());
        bytes.extend_from_slice(&block_align.to_le_bytes());
        bytes.extend_from_slice(&bits_per_sample.to_le_bytes());
        // data chunk with zero length
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes
    }

    /// Wrap raw header bytes in a `pcm` codec header.
    fn pcm_codec_header(payload: Vec<u8>) -> CodecHeader {
        CodecHeader {
            codec: "pcm".into(),
            payload,
        }
    }

    /// Feed `payload` to a fresh PCM decoder and return the outcome.
    fn set_pcm_header(payload: Vec<u8>) -> Result<SampleFormat> {
        PcmDecoder::new().set_header(&pcm_codec_header(payload))
    }

    #[test]
    fn parse_wav_header() {
        let format = set_pcm_header(wav_header(48000, 16, 2)).unwrap();
        assert_eq!(format.rate(), 48000);
        assert_eq!(format.bits(), 16);
        assert_eq!(format.channels(), 2);
    }

    #[test]
    fn parse_wav_header_44100_24_2() {
        let format = set_pcm_header(wav_header(44100, 24, 2)).unwrap();
        assert_eq!(format.rate(), 44100);
        assert_eq!(format.bits(), 24);
        assert_eq!(format.channels(), 2);
    }

    #[test]
    fn too_small_header_fails() {
        let outcome = set_pcm_header(vec![0; 10]);
        assert!(outcome.is_err());
    }

    #[test]
    fn not_riff_fails() {
        let mut payload = vec![0u8; 44];
        payload[0..4].copy_from_slice(b"NOPE");
        let outcome = set_pcm_header(payload);
        assert!(outcome.is_err());
    }

    #[test]
    fn decode_is_passthrough() {
        let original = vec![0xAA, 0xBB, 0xCC];
        let mut buffer = original.clone();
        let mut decoder = PcmDecoder::new();
        assert!(decoder.decode(&mut buffer).unwrap());
        assert_eq!(buffer, original);
    }
}
203}