Skip to main content

snapcast_client/decoder/
vorbis.rs

1//! Vorbis decoder using symphonia.
2//!
3//! The snapserver sends Ogg/Vorbis stream headers as the CodecHeader,
4//! then Ogg pages containing Vorbis audio data as WireChunk payloads.
5
6use anyhow::{Result, bail};
7use snapcast_proto::SampleFormat;
8use snapcast_proto::message::codec_header::CodecHeader;
9use symphonia::core::audio::SampleBuffer;
10use symphonia::core::codecs::{CODEC_TYPE_VORBIS, CodecParameters, DecoderOptions};
11use symphonia::core::formats::Packet;
12
13use crate::decoder::Decoder;
14
15/// Parse the Vorbis identification header from an Ogg bitstream.
16///
17/// The CodecHeader payload is a complete Ogg bitstream containing the 3 Vorbis
18/// header packets. We parse the first Ogg page to find the Vorbis identification
19/// header and extract sample rate and channels.
20///
21/// Default bit depth is 16 (Vorbis is float internally; the C++ code defaults to 16).
22fn parse_vorbis_header(payload: &[u8]) -> Result<(SampleFormat, Vec<u8>)> {
23    // Find OggS capture pattern
24    if payload.len() < 28 || &payload[0..4] != b"OggS" {
25        bail!("not an Ogg bitstream");
26    }
27
28    // Ogg page header: 27 bytes fixed + segment table
29    let num_segments = payload[26] as usize;
30    let header_size = 27 + num_segments;
31    if payload.len() < header_size {
32        bail!("Ogg page header truncated");
33    }
34
35    // First packet starts after the page header
36    let packet_start = header_size;
37    let remaining = &payload[packet_start..];
38
39    // Vorbis identification header: type(1) + "vorbis"(6) + version(4) + channels(1) + rate(4)
40    if remaining.len() < 16 {
41        bail!("Vorbis identification header too small");
42    }
43    if remaining[0] != 1 || &remaining[1..7] != b"vorbis" {
44        bail!("not a Vorbis identification header");
45    }
46
47    let channels = remaining[11] as u16;
48    let sample_rate = u32::from_le_bytes(remaining[12..16].try_into().unwrap());
49
50    if sample_rate == 0 || channels == 0 {
51        bail!("invalid Vorbis header: rate={sample_rate}, channels={channels}");
52    }
53
54    // Default to 16-bit (Vorbis is float internally, C++ defaults to 16)
55    let sf = SampleFormat::new(sample_rate, 16, channels);
56
57    Ok((sf, payload.to_vec()))
58}
59
60/// Vorbis audio decoder using symphonia.
61pub struct VorbisDecoder {
62    decoder: Box<dyn symphonia::core::codecs::Decoder>,
63    sample_format: SampleFormat,
64    packet_id: u64,
65}
66
67impl VorbisDecoder {
68    fn from_header(header: &CodecHeader) -> Result<Self> {
69        let (sf, extra_data) = parse_vorbis_header(&header.payload)?;
70        let mut params = CodecParameters::new();
71        params
72            .for_codec(CODEC_TYPE_VORBIS)
73            .with_sample_rate(sf.rate())
74            .with_channels(
75                symphonia::core::audio::Channels::from_bits(((1u64 << sf.channels()) - 1) as u32)
76                    .unwrap_or(symphonia::core::audio::Channels::FRONT_LEFT),
77            )
78            .with_extra_data(extra_data.into_boxed_slice());
79        let decoder = symphonia::default::get_codecs()
80            .make(&params, &DecoderOptions::default())
81            .map_err(|e| anyhow::anyhow!("failed to create Vorbis decoder: {e}"))?;
82        Ok(Self {
83            decoder,
84            sample_format: sf,
85            packet_id: 0,
86        })
87    }
88}
89
90impl Decoder for VorbisDecoder {
91    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat> {
92        *self = Self::from_header(header)?;
93        Ok(self.sample_format)
94    }
95
96    fn decode(&mut self, data: &mut Vec<u8>) -> Result<bool> {
97        if data.is_empty() {
98            return Ok(false);
99        }
100
101        tracing::trace!(
102            codec = "vorbis",
103            input_bytes = data.len(),
104            packet_id = self.packet_id,
105            "decode"
106        );
107
108        let packet = Packet::new_from_slice(0, self.packet_id, 0, data);
109        self.packet_id += 1;
110
111        let decoded = match self.decoder.decode(&packet) {
112            Ok(buf) => buf,
113            Err(e) => {
114                tracing::warn!(codec = "vorbis", error = %e, "decode failed");
115                return Ok(false);
116            }
117        };
118
119        let spec = *decoded.spec();
120        let frames = decoded.frames() as u64;
121
122        let mut sample_buf = SampleBuffer::<i16>::new(frames, spec);
123        sample_buf.copy_interleaved_ref(decoded);
124
125        let mut out = Vec::with_capacity(sample_buf.samples().len() * 2);
126        for &s in sample_buf.samples() {
127            out.extend_from_slice(&s.to_le_bytes());
128        }
129
130        *data = out;
131        Ok(true)
132    }
133}
134
135/// Create a VorbisDecoder from a CodecHeader.
136pub fn create(header: &CodecHeader) -> Result<VorbisDecoder> {
137    VorbisDecoder::from_header(header)
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a complete 30-byte Vorbis identification header packet with the
    /// given channel count and sample rate.
    fn vorbis_ident_packet(channels: u8, sample_rate: u32) -> Vec<u8> {
        let mut packet = Vec::with_capacity(30);
        packet.push(1u8); // packet type = identification
        packet.extend_from_slice(b"vorbis");
        packet.extend_from_slice(&0u32.to_le_bytes()); // version
        packet.push(channels);
        packet.extend_from_slice(&sample_rate.to_le_bytes());
        packet.extend_from_slice(&0i32.to_le_bytes()); // bitrate max
        packet.extend_from_slice(&128_000i32.to_le_bytes()); // bitrate nominal
        packet.extend_from_slice(&0i32.to_le_bytes()); // bitrate min
        // Block sizes are stored as log2 exponents: blocksize_0 = 2^8 = 256 in
        // the low nibble, blocksize_1 = 2^11 = 2048 in the high nibble,
        // i.e. (11 << 4) | 8 = 0xb8.
        packet.push(0xb8);
        packet.push(1); // framing bit
        packet
    }

    /// Wrap `packet` in a single-segment Ogg page flagged as beginning of
    /// stream. The CRC is left at zero; the parser under test does not check it.
    fn ogg_page(packet: &[u8]) -> Vec<u8> {
        // A lacing value of 255 would mean "packet continues", so the packet
        // must be shorter than that to fit in one segment.
        assert!(packet.len() < 255, "test packet must fit in one segment");

        let mut page = Vec::with_capacity(28 + packet.len());
        page.extend_from_slice(b"OggS"); // capture pattern
        page.push(0); // version
        page.push(0x02); // header type: beginning of stream
        page.extend_from_slice(&0u64.to_le_bytes()); // granule position
        page.extend_from_slice(&1u32.to_le_bytes()); // serial number
        page.extend_from_slice(&0u32.to_le_bytes()); // page sequence
        page.extend_from_slice(&0u32.to_le_bytes()); // CRC (skipped for tests)
        page.push(1); // 1 segment
        page.push(packet.len() as u8); // segment size
        page.extend_from_slice(packet);
        page
    }

    #[test]
    fn parse_header_44100_2() {
        let payload = ogg_page(&vorbis_ident_packet(2, 44100));
        let (sf, _) = parse_vorbis_header(&payload).unwrap();
        assert_eq!(sf.rate(), 44100);
        assert_eq!(sf.channels(), 2);
        assert_eq!(sf.bits(), 16); // default
    }

    #[test]
    fn parse_header_48000_6() {
        let payload = ogg_page(&vorbis_ident_packet(6, 48000));
        let (sf, _) = parse_vorbis_header(&payload).unwrap();
        assert_eq!(sf.rate(), 48000);
        assert_eq!(sf.channels(), 6);
    }

    #[test]
    fn not_ogg_fails() {
        assert!(parse_vorbis_header(b"NOPE_not_ogg_data_at_all!!!!!").is_err());
    }

    #[test]
    fn not_vorbis_fails() {
        // A well-formed page whose packet is all zeroes: wrong packet type and
        // no "vorbis" magic.
        let payload = ogg_page(&[0u8; 16]);
        assert!(parse_vorbis_header(&payload).is_err());
    }

    #[test]
    fn truncated_segment_table_fails() {
        let mut payload = ogg_page(&vorbis_ident_packet(2, 44100));
        // Claim 255 segments although the payload is far too short to hold
        // a segment table of that size.
        payload[26] = 255;
        assert!(parse_vorbis_header(&payload).is_err());
    }

    #[test]
    fn short_ident_packet_fails() {
        // Cut the packet one byte before the end of the sample-rate field.
        let packet = vorbis_ident_packet(2, 44100);
        let payload = ogg_page(&packet[..15]);
        assert!(parse_vorbis_header(&payload).is_err());
    }

    #[test]
    fn zero_rate_or_channels_fails() {
        assert!(parse_vorbis_header(&ogg_page(&vorbis_ident_packet(2, 0))).is_err());
        assert!(parse_vorbis_header(&ogg_page(&vorbis_ident_packet(0, 44100))).is_err());
    }
}