Skip to main content

snapcast_client/decoder/
mod.rs

1//! Audio decoder trait and implementations.
2
3#[cfg(feature = "f32lz4")]
4pub mod f32lz4;
5pub mod flac;
6pub mod opus;
7pub mod vorbis;
8
9use anyhow::{Result, bail};
10use snapcast_proto::SampleFormat;
11use snapcast_proto::message::codec_header::CodecHeader;
12
13/// Audio decoder trait — matches the C++ `Decoder` interface.
14pub trait Decoder: Send {
15    /// Initialize the decoder from a codec header. Returns the sample format.
16    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat>;
17
18    /// Decode audio data in-place. Returns true if successful.
19    /// For PCM this is a no-op (passthrough).
20    fn decode(&mut self, data: &mut Vec<u8>) -> Result<bool>;
21}
22
/// PCM decoder — passthrough. Parses the RIFF/WAV header to extract the
/// sample format; audio payloads are forwarded unchanged.
#[derive(Debug, Default, Clone)]
pub struct PcmDecoder;

impl PcmDecoder {
    /// Create a new PCM passthrough decoder.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
33
34/// Parse a RIFF/WAV header and return the sample format.
35fn parse_riff_header(payload: &[u8]) -> Result<SampleFormat> {
36    if payload.len() < 44 {
37        bail!("PCM header too small ({} bytes)", payload.len());
38    }
39
40    // RIFF header: "RIFF" + size + "WAVE"
41    if &payload[0..4] != b"RIFF" || &payload[8..12] != b"WAVE" {
42        bail!("not a RIFF/WAVE header");
43    }
44
45    let mut pos = 12;
46    let mut sample_rate: u32 = 0;
47    let mut bits_per_sample: u16 = 0;
48    let mut num_channels: u16 = 0;
49
50    // Walk chunks until we find "fmt " and "data"
51    while pos + 8 <= payload.len() {
52        let chunk_id = &payload[pos..pos + 4];
53        let chunk_size = u32::from_le_bytes(payload[pos + 4..pos + 8].try_into().unwrap()) as usize;
54        pos += 8;
55
56        if chunk_id == b"fmt " {
57            if pos + 16 > payload.len() {
58                bail!("fmt chunk too small");
59            }
60            // audio_format: u16 at +0 (skip)
61            num_channels = u16::from_le_bytes(payload[pos + 2..pos + 4].try_into().unwrap());
62            sample_rate = u32::from_le_bytes(payload[pos + 4..pos + 8].try_into().unwrap());
63            // byte_rate: u32 at +8 (skip)
64            // block_align: u16 at +12 (skip)
65            bits_per_sample = u16::from_le_bytes(payload[pos + 14..pos + 16].try_into().unwrap());
66            pos += chunk_size;
67        } else if chunk_id == b"data" {
68            break;
69        } else {
70            pos += chunk_size;
71        }
72    }
73
74    if sample_rate == 0 {
75        bail!("sample format not found in RIFF header");
76    }
77
78    Ok(SampleFormat::new(
79        sample_rate,
80        bits_per_sample,
81        num_channels,
82    ))
83}
84
85impl Decoder for PcmDecoder {
86    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat> {
87        tracing::trace!(
88            codec = "pcm",
89            payload_len = header.payload.len(),
90            "set_header"
91        );
92        parse_riff_header(&header.payload)
93    }
94
95    fn decode(&mut self, _data: &mut Vec<u8>) -> Result<bool> {
96        // PCM is passthrough — no decoding needed
97        Ok(true)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal canonical 44-byte WAV header for `rate:bits:channels`.
    ///
    /// The header has the `RIFF`/`WAVE` preamble, a 16-byte PCM `fmt ` chunk
    /// and an empty `data` chunk.
    fn wav_header(rate: u32, bits: u16, channels: u16) -> Vec<u8> {
        let block_align = channels * (bits / 8);
        let byte_rate = rate * u32::from(block_align);

        let mut h = Vec::with_capacity(44);
        h.extend_from_slice(b"RIFF");
        h.extend_from_slice(&0u32.to_le_bytes()); // file size (don't care)
        h.extend_from_slice(b"WAVE");
        // fmt chunk
        h.extend_from_slice(b"fmt ");
        h.extend_from_slice(&16u32.to_le_bytes()); // chunk size
        h.extend_from_slice(&1u16.to_le_bytes()); // PCM format
        h.extend_from_slice(&channels.to_le_bytes());
        h.extend_from_slice(&rate.to_le_bytes());
        h.extend_from_slice(&byte_rate.to_le_bytes());
        h.extend_from_slice(&block_align.to_le_bytes());
        h.extend_from_slice(&bits.to_le_bytes());
        // data chunk
        h.extend_from_slice(b"data");
        h.extend_from_slice(&0u32.to_le_bytes()); // data size
        h
    }

    /// Wrap `payload` in a PCM codec header.
    fn pcm_header(payload: Vec<u8>) -> CodecHeader {
        CodecHeader {
            codec: "pcm".into(),
            payload,
        }
    }

    #[test]
    fn parse_wav_header() {
        let mut dec = PcmDecoder::new();
        let sf = dec.set_header(&pcm_header(wav_header(48000, 16, 2))).unwrap();
        assert_eq!(sf.rate(), 48000);
        assert_eq!(sf.bits(), 16);
        assert_eq!(sf.channels(), 2);
    }

    #[test]
    fn parse_wav_header_44100_24_2() {
        let mut dec = PcmDecoder::new();
        let sf = dec.set_header(&pcm_header(wav_header(44100, 24, 2))).unwrap();
        assert_eq!(sf.rate(), 44100);
        assert_eq!(sf.bits(), 24);
        assert_eq!(sf.channels(), 2);
    }

    #[test]
    fn skips_unknown_chunk_before_fmt() {
        let canonical = wav_header(48000, 16, 2);
        // Splice a 4-byte LIST chunk between the RIFF preamble and `fmt `.
        let mut h = canonical[..12].to_vec();
        h.extend_from_slice(b"LIST");
        h.extend_from_slice(&4u32.to_le_bytes());
        h.extend_from_slice(b"INFO");
        h.extend_from_slice(&canonical[12..]);

        let mut dec = PcmDecoder::new();
        let sf = dec.set_header(&pcm_header(h)).unwrap();
        assert_eq!(sf.rate(), 48000);
        assert_eq!(sf.bits(), 16);
        assert_eq!(sf.channels(), 2);
    }

    #[test]
    fn missing_fmt_chunk_fails() {
        let mut h = Vec::new();
        h.extend_from_slice(b"RIFF");
        h.extend_from_slice(&0u32.to_le_bytes());
        h.extend_from_slice(b"WAVE");
        h.extend_from_slice(b"LIST");
        h.extend_from_slice(&24u32.to_le_bytes());
        h.extend_from_slice(&[0u8; 24]);
        h.extend_from_slice(b"data");
        h.extend_from_slice(&0u32.to_le_bytes());

        let mut dec = PcmDecoder::new();
        assert!(dec.set_header(&pcm_header(h)).is_err());
    }

    #[test]
    fn too_small_header_fails() {
        let mut dec = PcmDecoder::new();
        assert!(dec.set_header(&pcm_header(vec![0; 10])).is_err());
    }

    #[test]
    fn not_riff_fails() {
        let mut h = vec![0u8; 44];
        h[0..4].copy_from_slice(b"NOPE");
        let mut dec = PcmDecoder::new();
        assert!(dec.set_header(&pcm_header(h)).is_err());
    }

    #[test]
    fn decode_is_passthrough() {
        let mut dec = PcmDecoder::new();
        let mut data = vec![0xAA, 0xBB, 0xCC];
        assert!(dec.decode(&mut data).unwrap());
        assert_eq!(data, vec![0xAA, 0xBB, 0xCC]);
    }
}