Skip to main content

zerodds_websocket_bridge/
codec.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 ZeroDDS Contributors
3
//! WebSocket wire codec — RFC 6455 §5.2 (base framing) and §5.3 (client-to-server masking).
5
6use alloc::vec::Vec;
7use core::fmt;
8
9use crate::frame::{Frame, Opcode};
10use crate::masking::apply_mask;
11
/// Errors produced while encoding or decoding a WebSocket frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecError {
    /// Fewer bytes are available than the frame header requires.
    HeaderTooShort,
    /// RFC 6455 §5.2 — the 16-bit length form carried a value <= 125, or
    /// the 64-bit form carried a value <= 65535. The spec demands that
    /// "the minimal number of bytes MUST be used".
    NonMinimalLength,
    /// RFC 6455 §5.2 — the 64-bit payload length has its most significant
    /// bit set.
    PayloadLengthMsbSet,
    /// The frame body extends beyond the available bytes.
    PayloadTruncated,
    /// The masking key extends beyond the available bytes.
    MaskingKeyTruncated,
    /// RFC 6455 §5.5 — a control frame with a payload > 125 bytes is illegal.
    ControlFrameTooLong,
    /// RFC 6455 §5.5 — a control frame must have FIN=1 (no fragmentation).
    FragmentedControlFrame,
}

impl fmt::Display for CodecError {
    /// Writes a short, fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::HeaderTooShort => "header too short",
            Self::NonMinimalLength => "non-minimal payload length encoding",
            Self::PayloadLengthMsbSet => "64-bit payload length MSB set",
            Self::PayloadTruncated => "payload truncated",
            Self::MaskingKeyTruncated => "masking key truncated",
            Self::ControlFrameTooLong => "control frame payload > 125 bytes",
            Self::FragmentedControlFrame => "control frame with FIN=0",
        };
        f.write_str(description)
    }
}

// Only available with the `std` feature: `std::error::Error` is referenced
// through the `std` path, which a build without that feature does not use.
#[cfg(feature = "std")]
impl std::error::Error for CodecError {}
48
49/// Encodiert einen [`Frame`] zum WebSocket-Wire-Byte-Slice.
50///
51/// Wenn `frame.masking_key` gesetzt ist, wird das Payload waehrend
52/// des Encode XOR-maskiert (Spec §5.3).
53///
54/// # Errors
55/// * [`CodecError::ControlFrameTooLong`] wenn Control-Frame mit
56///   Payload > 125 Bytes (Spec §5.5).
57/// * [`CodecError::FragmentedControlFrame`] wenn Control-Frame mit
58///   FIN=0 (Spec §5.5).
59pub fn encode(frame: &Frame) -> Result<Vec<u8>, CodecError> {
60    if frame.opcode.is_control() {
61        if !frame.fin {
62            return Err(CodecError::FragmentedControlFrame);
63        }
64        if frame.payload.len() > 125 {
65            return Err(CodecError::ControlFrameTooLong);
66        }
67    }
68    let mut out = Vec::with_capacity(2 + 8 + 4 + frame.payload.len());
69
70    // Byte 0: FIN | RSV1 | RSV2 | RSV3 | Opcode (4 bits).
71    let mut byte0 = frame.opcode.to_bits() & 0x0F;
72    if frame.fin {
73        byte0 |= 0x80;
74    }
75    if frame.rsv1 {
76        byte0 |= 0x40;
77    }
78    if frame.rsv2 {
79        byte0 |= 0x20;
80    }
81    if frame.rsv3 {
82        byte0 |= 0x10;
83    }
84    out.push(byte0);
85
86    // Byte 1: MASK | Payload-Length (7 bits).
87    let payload_len = frame.payload.len();
88    let masked = frame.masking_key.is_some();
89    let (len7, ext_len) = encode_payload_length(payload_len);
90    let byte1 = (if masked { 0x80 } else { 0x00 }) | (len7 & 0x7F);
91    out.push(byte1);
92    out.extend_from_slice(&ext_len);
93
94    // Masking-Key (4 bytes wenn MASK=1).
95    if let Some(key) = frame.masking_key {
96        out.extend_from_slice(&key);
97        // Payload-Daten XOR-maskiert ausgeben.
98        let mut masked_payload = frame.payload.clone();
99        apply_mask(&mut masked_payload, key);
100        out.extend_from_slice(&masked_payload);
101    } else {
102        out.extend_from_slice(&frame.payload);
103    }
104
105    Ok(out)
106}
107
/// RFC 6455 §5.2 — encodes a payload length using the minimal number of
/// bytes.
///
/// Returns `(seven_bit_value, extended_length_bytes)`:
/// * `len <= 125`: the length itself and no extended bytes,
/// * `len <= 0xFFFF`: marker `126` plus the length as 2 big-endian bytes,
/// * otherwise: marker `127` plus the length as 8 big-endian bytes.
// The narrowing casts cannot truncate: each match arm bounds `len` to the
// range of its target type.
#[allow(clippy::cast_possible_truncation)]
fn encode_payload_length(len: usize) -> (u8, Vec<u8>) {
    match len {
        0..=125 => (len as u8, Vec::new()),
        126..=0xFFFF => (126, (len as u16).to_be_bytes().to_vec()),
        // 64-bit form; the most significant bit stays 0.
        _ => (127, (len as u64).to_be_bytes().to_vec()),
    }
}
123
124/// Decodiert einen [`Frame`] aus dem WebSocket-Wire-Byte-Slice.
125/// Wenn das MASK-Bit gesetzt ist, wird das Payload waehrend des Decode
126/// automatisch demaskiert.
127///
128/// Liefert `(frame, consumed_bytes)`.
129///
130/// # Errors
131/// Siehe [`CodecError`].
132pub fn decode(bytes: &[u8]) -> Result<(Frame, usize), CodecError> {
133    if bytes.len() < 2 {
134        return Err(CodecError::HeaderTooShort);
135    }
136    let byte0 = bytes[0];
137    let fin = (byte0 & 0x80) != 0;
138    let rsv1 = (byte0 & 0x40) != 0;
139    let rsv2 = (byte0 & 0x20) != 0;
140    let rsv3 = (byte0 & 0x10) != 0;
141    let opcode = Opcode::from_bits(byte0 & 0x0F);
142
143    let byte1 = bytes[1];
144    let masked = (byte1 & 0x80) != 0;
145    let len7 = byte1 & 0x7F;
146
147    let mut cursor = 2usize;
148    let payload_len = match len7.cmp(&126) {
149        core::cmp::Ordering::Less => usize::from(len7),
150        core::cmp::Ordering::Equal => {
151            if bytes.len() < cursor + 2 {
152                return Err(CodecError::HeaderTooShort);
153            }
154            let v = u16::from_be_bytes([bytes[cursor], bytes[cursor + 1]]);
155            cursor += 2;
156            if v <= 125 {
157                return Err(CodecError::NonMinimalLength);
158            }
159            usize::from(v)
160        }
161        core::cmp::Ordering::Greater => {
162            // len7 == 127.
163            if bytes.len() < cursor + 8 {
164                return Err(CodecError::HeaderTooShort);
165            }
166            let mut buf = [0u8; 8];
167            buf.copy_from_slice(&bytes[cursor..cursor + 8]);
168            let v = u64::from_be_bytes(buf);
169            cursor += 8;
170            if (v & 0x8000_0000_0000_0000) != 0 {
171                return Err(CodecError::PayloadLengthMsbSet);
172            }
173            if v <= 0xFFFF {
174                return Err(CodecError::NonMinimalLength);
175            }
176            usize::try_from(v).map_err(|_| CodecError::PayloadTruncated)?
177        }
178    };
179
180    if opcode.is_control() {
181        if !fin {
182            return Err(CodecError::FragmentedControlFrame);
183        }
184        if payload_len > 125 {
185            return Err(CodecError::ControlFrameTooLong);
186        }
187    }
188
189    let masking_key = if masked {
190        if bytes.len() < cursor + 4 {
191            return Err(CodecError::MaskingKeyTruncated);
192        }
193        let key = [
194            bytes[cursor],
195            bytes[cursor + 1],
196            bytes[cursor + 2],
197            bytes[cursor + 3],
198        ];
199        cursor += 4;
200        Some(key)
201    } else {
202        None
203    };
204
205    if bytes.len() < cursor + payload_len {
206        return Err(CodecError::PayloadTruncated);
207    }
208    let mut payload = bytes[cursor..cursor + payload_len].to_vec();
209    cursor += payload_len;
210
211    if let Some(key) = masking_key {
212        apply_mask(&mut payload, key);
213    }
214
215    Ok((
216        Frame {
217            fin,
218            rsv1,
219            rsv2,
220            rsv3,
221            opcode,
222            masking_key,
223            payload,
224        },
225        cursor,
226    ))
227}
228
// Unit tests for the wire codec. `expect`/`unwrap`/`panic` are allowed here
// because a failing expectation is exactly how a test should fail.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn smallest_text_frame_encodes_to_2_byte_header_plus_payload() {
        // RFC 6455 §5.2 — payload-len <= 125 ⇒ 2 byte header.
        let bytes = encode(&Frame::text("hi")).expect("encode");
        assert_eq!(bytes.len(), 4);
        // FIN=1, opcode=1.
        assert_eq!(bytes[0], 0x81);
        // MASK=0, len=2.
        assert_eq!(bytes[1], 0x02);
        assert_eq!(&bytes[2..], b"hi");
    }

    #[test]
    fn medium_payload_uses_extended_16_bit_length() {
        // Spec §5.2 — len 126..=65535 ⇒ marker 126 + 2 bytes big-endian.
        let payload = alloc::vec![0xAA; 200];
        let f = Frame::binary(payload.clone());
        let bytes = encode(&f).expect("encode");
        assert_eq!(bytes[0], 0x82);
        assert_eq!(bytes[1] & 0x7F, 126);
        assert_eq!(&bytes[2..4], &200u16.to_be_bytes());
        assert_eq!(&bytes[4..], &payload[..]);
    }

    #[test]
    fn large_payload_uses_extended_64_bit_length() {
        // Spec §5.2 — len > 65535 ⇒ marker 127 + 8 bytes big-endian with MSB=0.
        let payload = alloc::vec![0xBB; 70_000];
        let f = Frame::binary(payload.clone());
        let bytes = encode(&f).expect("encode");
        assert_eq!(bytes[1] & 0x7F, 127);
        let mut len_buf = [0u8; 8];
        len_buf.copy_from_slice(&bytes[2..10]);
        assert_eq!(u64::from_be_bytes(len_buf), 70_000);
        // MSB=0.
        assert_eq!(bytes[2] & 0x80, 0);
    }

    #[test]
    fn round_trip_unmasked_text() {
        let f = Frame::text("hello world");
        let bytes = encode(&f).expect("encode");
        let (parsed, consumed) = decode(&bytes).expect("decode");
        assert_eq!(parsed, f);
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn round_trip_masked_payload_unmasked_on_decode() {
        // Spec §5.3 — payload XOR'd with key on encode, XOR'd back on
        // decode.
        let f = Frame::text("masked!").with_mask([0x12, 0x34, 0x56, 0x78]);
        let bytes = encode(&f).expect("encode");
        // Wire bytes differ from the plaintext.
        assert_ne!(&bytes[6..], b"masked!");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert_eq!(parsed.payload, b"masked!");
        assert_eq!(parsed.masking_key, Some([0x12, 0x34, 0x56, 0x78]));
    }

    #[test]
    fn round_trip_medium_and_large_payloads() {
        for size in [126, 200, 65535, 65536, 100_000] {
            let f = Frame::binary(alloc::vec![0xAB; size]);
            let bytes = encode(&f).expect("encode");
            let (parsed, _) = decode(&bytes).expect("decode");
            assert_eq!(parsed.payload.len(), size);
        }
    }

    #[test]
    fn ping_frame_round_trip() {
        let f = Frame::ping(alloc::vec![1, 2, 3]);
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert_eq!(parsed.opcode, Opcode::Ping);
        assert_eq!(parsed.payload, alloc::vec![1, 2, 3]);
    }

    #[test]
    fn close_frame_carries_status_code() {
        let f = Frame::close(1000, "");
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert_eq!(parsed.opcode, Opcode::Close);
        assert_eq!(&parsed.payload[..2], &1000u16.to_be_bytes());
    }

    #[test]
    fn header_too_short_decode_fails() {
        assert_eq!(decode(&[]), Err(CodecError::HeaderTooShort));
        assert_eq!(decode(&[0x81]), Err(CodecError::HeaderTooShort));
    }

    #[test]
    fn extended_16_bit_length_truncated_fails() {
        // Marker 126 without the 2 length bytes.
        assert_eq!(decode(&[0x81, 0x7E]), Err(CodecError::HeaderTooShort));
    }

    #[test]
    fn extended_64_bit_length_msb_set_rejected() {
        // Spec §5.2 — MSB MUST be 0.
        let bytes = [0x82u8, 0x7F, 0x80, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(decode(&bytes), Err(CodecError::PayloadLengthMsbSet));
    }

    #[test]
    fn non_minimal_16_bit_length_rejected() {
        // Spec §5.2 — "minimal number of bytes MUST be used".
        // Marker 126 with value 100 is non-minimal.
        let bytes = [0x82u8, 0x7E, 0, 100, 0xAA, 0xBB];
        assert_eq!(decode(&bytes), Err(CodecError::NonMinimalLength));
    }

    #[test]
    fn non_minimal_64_bit_length_rejected() {
        // Marker 127 with value 65000 is non-minimal (fits the 16-bit form).
        let mut bytes = alloc::vec![0x82u8, 0x7F];
        bytes.extend_from_slice(&65000u64.to_be_bytes());
        assert_eq!(decode(&bytes), Err(CodecError::NonMinimalLength));
    }

    #[test]
    fn control_frame_with_long_payload_rejected_on_encode() {
        // Spec §5.5 — control frame payload <= 125.
        let f = Frame::ping(alloc::vec![0; 200]);
        assert_eq!(encode(&f), Err(CodecError::ControlFrameTooLong));
    }

    #[test]
    fn fragmented_control_frame_rejected_on_encode() {
        // Spec §5.5 — control frames MUST NOT be fragmented (FIN=1).
        let mut f = Frame::ping(alloc::vec![1, 2]);
        f.fin = false;
        assert_eq!(encode(&f), Err(CodecError::FragmentedControlFrame));
    }

    #[test]
    fn control_frame_with_long_payload_rejected_on_decode() {
        // Spec §5.5 — reuse the first byte of a valid ping frame and
        // announce 200 payload bytes through the 16-bit length form.
        let ping = encode(&Frame::ping(alloc::vec![1, 2])).expect("encode");
        let bytes = [ping[0], 0x7E, 0x00, 0xC8];
        assert_eq!(decode(&bytes), Err(CodecError::ControlFrameTooLong));
    }

    #[test]
    fn fragmented_control_frame_rejected_on_decode() {
        // Spec §5.5 — clear the FIN bit of an otherwise valid ping frame.
        let mut bytes = encode(&Frame::ping(alloc::vec![1, 2])).expect("encode");
        bytes[0] &= 0x7F;
        assert_eq!(decode(&bytes), Err(CodecError::FragmentedControlFrame));
    }

    #[test]
    fn masked_frame_without_key_bytes_decode_fails() {
        // MASK=1 but only the 2 header bytes are present.
        let bytes = [0x81u8, 0x80];
        assert_eq!(decode(&bytes), Err(CodecError::MaskingKeyTruncated));
    }

    #[test]
    fn payload_truncation_decode_fails() {
        // FIN + text + len=10, but only 2 payload bytes.
        let bytes = [0x81u8, 0x0A, 0xAA, 0xBB];
        assert_eq!(decode(&bytes), Err(CodecError::PayloadTruncated));
    }

    #[test]
    fn decode_reports_consumed_bytes_with_trailing_data() {
        // Bytes after the first frame must be neither consumed nor
        // allowed to leak into the payload.
        let mut bytes = encode(&Frame::text("hi")).expect("encode");
        let frame_len = bytes.len();
        bytes.extend_from_slice(&[0xDE, 0xAD]);
        let (parsed, consumed) = decode(&bytes).expect("decode");
        assert_eq!(consumed, frame_len);
        assert_eq!(parsed.payload, b"hi");
    }

    #[test]
    fn rsv_bits_propagate_to_decoded_frame() {
        // Spec §5.2 — RSV1-3 are not validated by the codec (that is an
        // extension-negotiation concern), but they must be passed through
        // unchanged.
        let mut f = Frame::binary(alloc::vec![1]);
        f.rsv1 = true;
        f.rsv3 = true;
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert!(parsed.rsv1);
        assert!(!parsed.rsv2);
        assert!(parsed.rsv3);
    }

    #[test]
    fn fin_zero_text_frame_round_trip() {
        // Spec §5.4 — fragmentation: FIN=0 + text opcode; the caller sends
        // continuation frames and sets FIN=1 on the last one.
        let mut f = Frame::text("part-1");
        f.fin = false;
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert!(!parsed.fin);
        assert_eq!(parsed.opcode, Opcode::Text);
    }
}