Skip to main content

snapcast_client/decoder/
vorbis.rs

1//! Vorbis decoder using symphonia.
2//!
3//! The snapserver sends Ogg/Vorbis stream headers as the CodecHeader,
4//! then Ogg pages containing Vorbis audio data as WireChunk payloads.
5
6use anyhow::{Result, bail};
7use snapcast_proto::SampleFormat;
8use snapcast_proto::message::codec_header::CodecHeader;
9use symphonia::core::audio::SampleBuffer;
10use symphonia::core::codecs::{CODEC_TYPE_VORBIS, CodecParameters, DecoderOptions};
11use symphonia::core::formats::Packet;
12
13use crate::decoder::Decoder;
14use crate::stream::SampleEncoding;
15
16/// Parse the Vorbis identification header from an Ogg bitstream.
17///
18/// The CodecHeader payload is a complete Ogg bitstream containing the 3 Vorbis
19/// header packets. We parse the first Ogg page to find the Vorbis identification
20/// header and extract sample rate and channels.
21///
22/// Default bit depth is 16 (Vorbis is float internally; the C++ code defaults to 16).
23fn parse_vorbis_header(payload: &[u8]) -> Result<(SampleFormat, Vec<u8>)> {
24    // Find OggS capture pattern
25    if payload.len() < 28 || &payload[0..4] != b"OggS" {
26        bail!("not an Ogg bitstream");
27    }
28
29    // Ogg page header: 27 bytes fixed + segment table
30    let num_segments = payload[26] as usize;
31    let header_size = 27 + num_segments;
32    if payload.len() < header_size {
33        bail!("Ogg page header truncated");
34    }
35
36    // First packet starts after the page header
37    let packet_start = header_size;
38    let remaining = &payload[packet_start..];
39
40    // Vorbis identification header: type(1) + "vorbis"(6) + version(4) + channels(1) + rate(4)
41    if remaining.len() < 16 {
42        bail!("Vorbis identification header too small");
43    }
44    if remaining[0] != 1 || &remaining[1..7] != b"vorbis" {
45        bail!("not a Vorbis identification header");
46    }
47
48    let channels = remaining[11] as u16;
49    let sample_rate = u32::from_le_bytes(remaining[12..16].try_into().unwrap());
50
51    if sample_rate == 0 || channels == 0 {
52        bail!("invalid Vorbis header: rate={sample_rate}, channels={channels}");
53    }
54
55    // Vorbis is float internally; output as f32 to preserve full precision.
56    let sf = SampleFormat::new(sample_rate, 32, channels);
57
58    Ok((sf, payload.to_vec()))
59}
60
61/// Vorbis audio decoder using symphonia.
62pub struct VorbisDecoder {
63    decoder: Box<dyn symphonia::core::codecs::Decoder>,
64    sample_format: SampleFormat,
65    packet_id: u64,
66}
67
68impl VorbisDecoder {
69    fn from_header(header: &CodecHeader) -> Result<Self> {
70        let (sf, extra_data) = parse_vorbis_header(&header.payload)?;
71        let mut params = CodecParameters::new();
72        params
73            .for_codec(CODEC_TYPE_VORBIS)
74            .with_sample_rate(sf.rate())
75            .with_channels(
76                symphonia::core::audio::Channels::from_bits(((1u64 << sf.channels()) - 1) as u32)
77                    .unwrap_or(symphonia::core::audio::Channels::FRONT_LEFT),
78            )
79            .with_extra_data(extra_data.into_boxed_slice());
80        let decoder = symphonia::default::get_codecs()
81            .make(&params, &DecoderOptions::default())
82            .map_err(|e| anyhow::anyhow!("failed to create Vorbis decoder: {e}"))?;
83        Ok(Self {
84            decoder,
85            sample_format: sf,
86            packet_id: 0,
87        })
88    }
89}
90
91impl Decoder for VorbisDecoder {
92    fn set_header(&mut self, header: &CodecHeader) -> Result<SampleFormat> {
93        *self = Self::from_header(header)?;
94        Ok(self.sample_format)
95    }
96
97    fn decode(&mut self, data: &mut Vec<u8>) -> Result<bool> {
98        if data.is_empty() {
99            return Ok(false);
100        }
101
102        tracing::trace!(
103            codec = "vorbis",
104            input_bytes = data.len(),
105            packet_id = self.packet_id,
106            "decode"
107        );
108
109        let packet = Packet::new_from_slice(0, self.packet_id, 0, data);
110        self.packet_id += 1;
111
112        let decoded = match self.decoder.decode(&packet) {
113            Ok(buf) => buf,
114            Err(e) => {
115                tracing::warn!(codec = "vorbis", error = %e, "decode failed");
116                return Ok(false);
117            }
118        };
119
120        let spec = *decoded.spec();
121        let frames = decoded.frames() as u64;
122
123        let mut sample_buf = SampleBuffer::<f32>::new(frames, spec);
124        sample_buf.copy_interleaved_ref(decoded);
125
126        let mut out = Vec::with_capacity(sample_buf.samples().len() * 4);
127        for &s in sample_buf.samples() {
128            out.extend_from_slice(&s.to_le_bytes());
129        }
130
131        *data = out;
132        Ok(true)
133    }
134
135    fn output_encoding(&self) -> SampleEncoding {
136        SampleEncoding::Float32
137    }
138}
139
140/// Create a VorbisDecoder from a CodecHeader.
141pub fn create(header: &CodecHeader) -> Result<VorbisDecoder> {
142    VorbisDecoder::from_header(header)
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single Ogg page whose only packet is a 30-byte Vorbis
    /// identification header declaring `channels` channels at `sample_rate` Hz.
    ///
    /// The page CRC is left at zero; `parse_vorbis_header` does not check it.
    fn ogg_vorbis_id_page(channels: u8, sample_rate: u32) -> Vec<u8> {
        // -- Vorbis identification header packet --
        let mut vorbis_id = Vec::with_capacity(30);
        vorbis_id.push(1u8); // packet type = identification
        vorbis_id.extend_from_slice(b"vorbis");
        vorbis_id.extend_from_slice(&0u32.to_le_bytes()); // version
        vorbis_id.push(channels);
        vorbis_id.extend_from_slice(&sample_rate.to_le_bytes());
        vorbis_id.extend_from_slice(&0i32.to_le_bytes()); // bitrate max
        vorbis_id.extend_from_slice(&128000i32.to_le_bytes()); // bitrate nominal
        vorbis_id.extend_from_slice(&0i32.to_le_bytes()); // bitrate min
        // Block sizes as exponents: low nibble 8 (2^8 = 256), high nibble 11
        // (2^11 = 2048) → (11 << 4) | 8 = 0xb8.
        vorbis_id.push(0xb8);
        vorbis_id.push(1); // framing bit

        let segment_len =
            u8::try_from(vorbis_id.len()).expect("identification header fits in one segment");

        // -- Ogg page header --
        let mut page = Vec::with_capacity(28 + vorbis_id.len());
        page.extend_from_slice(b"OggS"); // capture pattern
        page.push(0); // version
        page.push(0x02); // header type: beginning of stream
        page.extend_from_slice(&0u64.to_le_bytes()); // granule position
        page.extend_from_slice(&1u32.to_le_bytes()); // serial number
        page.extend_from_slice(&0u32.to_le_bytes()); // page sequence
        page.extend_from_slice(&0u32.to_le_bytes()); // CRC (skipped for tests)
        page.push(1); // 1 segment
        page.push(segment_len); // segment size

        // -- Packet data --
        page.extend_from_slice(&vorbis_id);
        page
    }

    /// Minimal Ogg page with a Vorbis identification header: 44100 Hz, 2 channels.
    fn ogg_vorbis_header_44100_2() -> Vec<u8> {
        ogg_vorbis_id_page(2, 44100)
    }

    #[test]
    fn parse_header_44100_2() {
        let payload = ogg_vorbis_header_44100_2();
        let (sf, _) = parse_vorbis_header(&payload).unwrap();
        assert_eq!(sf.rate(), 44100);
        assert_eq!(sf.channels(), 2);
        assert_eq!(sf.bits(), 32);
    }

    #[test]
    fn parse_header_48000_6() {
        let page = ogg_vorbis_id_page(6, 48000);
        let (sf, _) = parse_vorbis_header(&page).unwrap();
        assert_eq!(sf.rate(), 48000);
        assert_eq!(sf.channels(), 6);
    }

    #[test]
    fn extra_data_is_the_whole_payload() {
        let payload = ogg_vorbis_header_44100_2();
        let (_, extra) = parse_vorbis_header(&payload).unwrap();
        assert_eq!(extra, payload);
    }

    #[test]
    fn not_ogg_fails() {
        assert!(parse_vorbis_header(b"NOPE_not_ogg_data_at_all!!!!!").is_err());
    }

    #[test]
    fn not_vorbis_fails() {
        let mut page = Vec::new();
        page.extend_from_slice(b"OggS");
        page.push(0);
        page.push(0);
        page.extend_from_slice(&[0; 20]); // rest of ogg header
        page.push(1); // 1 segment
        page.push(16); // segment size
        page.extend_from_slice(&[0; 16]); // not a vorbis packet
        assert!(parse_vorbis_header(&page).is_err());
    }

    #[test]
    fn empty_and_short_payloads_fail() {
        assert!(parse_vorbis_header(&[]).is_err());

        // One byte short of the smallest page header the parser accepts.
        let mut page = ogg_vorbis_header_44100_2();
        page.truncate(27);
        assert!(parse_vorbis_header(&page).is_err());
    }

    #[test]
    fn truncated_segment_table_fails() {
        // Claim four segments but end the payload after the first table entry.
        let mut page = ogg_vorbis_header_44100_2();
        page[26] = 4;
        page.truncate(28);
        assert!(parse_vorbis_header(&page).is_err());
    }

    #[test]
    fn short_identification_header_fails() {
        // Keep only 15 of the 16 bytes needed to reach the sample rate.
        let mut page = ogg_vorbis_header_44100_2();
        page.truncate(28 + 15);
        assert!(parse_vorbis_header(&page).is_err());
    }

    #[test]
    fn zero_channels_or_rate_fails() {
        assert!(parse_vorbis_header(&ogg_vorbis_id_page(0, 44100)).is_err());
        assert!(parse_vorbis_header(&ogg_vorbis_id_page(2, 0)).is_err());
    }
}
241}