Skip to main content

wire/
session.rs

1//! Compact session header encoding for client replication.
2
3use crate::error::{DecodeError, EncodeError, SectionFramingError, WireResult};
4
/// Upper bound, in bytes, on an encoded session header.
///
/// One flag byte plus three LEB128-encoded `u32` fields (tick delta, baseline
/// delta, payload length), each of which occupies at most five bytes.
pub const SESSION_MAX_HEADER_SIZE: usize = 1 + 3 * 5;
7
/// One-byte flag set that opens every compact session header.
///
/// A valid value has exactly one of [`SessionFlags::FULL_SNAPSHOT`] or
/// [`SessionFlags::DELTA_SNAPSHOT`] set, and every remaining (reserved) bit clear.
/// Values built with [`SessionFlags::from_raw`] are not validated; call
/// [`SessionFlags::is_valid`] to check them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SessionFlags(u8);

impl SessionFlags {
    /// Bit marking a full snapshot packet.
    pub const FULL_SNAPSHOT: u8 = 0b0000_0001;
    /// Bit marking a delta snapshot packet.
    pub const DELTA_SNAPSHOT: u8 = 0b0000_0010;
    /// Every bit not assigned to a snapshot kind; all of these must be zero.
    const RESERVED_MASK: u8 = !(Self::FULL_SNAPSHOT | Self::DELTA_SNAPSHOT);

    /// Wraps raw flag bits as-is, without validation.
    #[must_use]
    pub const fn from_raw(raw: u8) -> Self {
        Self(raw)
    }

    /// Returns the underlying flag byte unchanged.
    #[must_use]
    pub const fn raw(self) -> u8 {
        self.0
    }

    /// Returns `true` when the full-snapshot bit is set (other bits are ignored).
    #[must_use]
    pub const fn is_full_snapshot(self) -> bool {
        (self.0 & Self::FULL_SNAPSHOT) != 0
    }

    /// Returns `true` when the delta-snapshot bit is set (other bits are ignored).
    #[must_use]
    pub const fn is_delta_snapshot(self) -> bool {
        (self.0 & Self::DELTA_SNAPSHOT) != 0
    }

    /// Returns `true` if no reserved bit is set and exactly one snapshot bit is set.
    #[must_use]
    pub const fn is_valid(self) -> bool {
        if (self.0 & Self::RESERVED_MASK) != 0 {
            return false;
        }
        // "Exactly one of two" is the same as "the two bits differ".
        self.is_full_snapshot() != self.is_delta_snapshot()
    }

    /// Flags describing a full snapshot packet.
    #[must_use]
    pub const fn full_snapshot() -> Self {
        Self::from_raw(Self::FULL_SNAPSHOT)
    }

    /// Flags describing a delta snapshot packet.
    #[must_use]
    pub const fn delta_snapshot() -> Self {
        Self::from_raw(Self::DELTA_SNAPSHOT)
    }
}
65
66/// Decoded session header (compact format).
67#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub struct SessionHeader {
69    pub flags: SessionFlags,
70    pub tick: u32,
71    pub baseline_tick: u32,
72    pub payload_len: u32,
73    pub header_len: usize,
74}
75
76/// Encodes a compact session header into the provided buffer.
77pub fn encode_session_header(
78    out: &mut [u8],
79    flags: SessionFlags,
80    tick_delta: u32,
81    baseline_delta: u32,
82    payload_len: u32,
83) -> Result<usize, EncodeError> {
84    if out.len() < SESSION_MAX_HEADER_SIZE {
85        return Err(EncodeError::BufferTooSmall {
86            needed: SESSION_MAX_HEADER_SIZE,
87            available: out.len(),
88        });
89    }
90    if !flags.is_valid() {
91        return Err(EncodeError::LengthOverflow { length: 0 });
92    }
93
94    let mut offset = 0;
95    out[offset] = flags.raw();
96    offset += 1;
97    offset += write_varu32(tick_delta, &mut out[offset..]);
98    offset += write_varu32(baseline_delta, &mut out[offset..]);
99    offset += write_varu32(payload_len, &mut out[offset..]);
100    Ok(offset)
101}
102
103/// Decodes a compact session header from the provided buffer.
104pub fn decode_session_header(buf: &[u8], last_tick: u32) -> WireResult<SessionHeader> {
105    if buf.is_empty() {
106        return Err(DecodeError::PacketTooSmall {
107            actual: buf.len(),
108            required: 1,
109        });
110    }
111
112    let flags = SessionFlags::from_raw(buf[0]);
113    if !flags.is_valid() {
114        return Err(DecodeError::InvalidFlags {
115            flags: flags.raw() as u16,
116        });
117    }
118
119    let mut offset = 1;
120    let (tick_delta, new_offset) = read_varu32(buf, offset)?;
121    offset = new_offset;
122    if tick_delta == 0 {
123        return Err(DecodeError::InvalidFlags {
124            flags: flags.raw() as u16,
125        });
126    }
127    let tick = last_tick
128        .checked_add(tick_delta)
129        .ok_or(DecodeError::InvalidFlags {
130            flags: flags.raw() as u16,
131        })?;
132
133    let (baseline_delta, new_offset) = read_varu32(buf, offset)?;
134    offset = new_offset;
135    let baseline_tick =
136        tick.checked_sub(baseline_delta)
137            .ok_or(DecodeError::InvalidBaselineTick {
138                baseline_tick: baseline_delta,
139                flags: flags.raw() as u16,
140            })?;
141    if flags.is_full_snapshot() && baseline_delta != 0 {
142        return Err(DecodeError::InvalidBaselineTick {
143            baseline_tick,
144            flags: flags.raw() as u16,
145        });
146    }
147    if flags.is_delta_snapshot() && baseline_tick == 0 {
148        return Err(DecodeError::InvalidBaselineTick {
149            baseline_tick,
150            flags: flags.raw() as u16,
151        });
152    }
153
154    let (payload_len, new_offset) = read_varu32(buf, offset)?;
155    offset = new_offset;
156
157    Ok(SessionHeader {
158        flags,
159        tick,
160        baseline_tick,
161        payload_len,
162        header_len: offset,
163    })
164}
165
166fn read_varu32(buf: &[u8], mut offset: usize) -> Result<(u32, usize), DecodeError> {
167    let mut value: u64 = 0;
168    let mut shift = 0;
169    let mut count = 0;
170    loop {
171        if offset >= buf.len() {
172            return Err(DecodeError::SectionFraming(
173                SectionFramingError::Truncated {
174                    needed: 1,
175                    available: buf.len().saturating_sub(offset),
176                },
177            ));
178        }
179        let byte = buf[offset];
180        offset += 1;
181        count += 1;
182        value |= ((byte & 0x7F) as u64) << shift;
183        if (byte & 0x80) == 0 {
184            break;
185        }
186        shift += 7;
187        if count >= 5 {
188            return Err(DecodeError::SectionFraming(
189                SectionFramingError::InvalidVarint,
190            ));
191        }
192    }
193    Ok((value as u32, offset))
194}
195
/// Writes `value` to the start of `out` as an unsigned LEB128 varint and returns
/// the number of bytes written (between 1 and 5).
///
/// # Panics
///
/// Panics if `out` is shorter than the encoding of `value`.
fn write_varu32(value: u32, out: &mut [u8]) -> usize {
    let mut remaining = value;
    let mut written = 0;
    // Emit low-order 7-bit groups with the continuation bit set while more remain.
    while remaining >= 0x80 {
        out[written] = ((remaining & 0x7F) as u8) | 0x80;
        remaining >>= 7;
        written += 1;
    }
    // The final group fits in 7 bits and carries no continuation bit.
    out[written] = remaining as u8;
    written + 1
}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_header_roundtrip_delta() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let len =
            encode_session_header(&mut buf, SessionFlags::delta_snapshot(), 2, 1, 123).unwrap();
        let decoded = decode_session_header(&buf[..len], 10).unwrap();
        assert!(decoded.flags.is_delta_snapshot());
        assert_eq!(decoded.tick, 12);
        assert_eq!(decoded.baseline_tick, 11);
        assert_eq!(decoded.payload_len, 123);
        assert_eq!(decoded.header_len, len);
    }

    #[test]
    fn session_header_roundtrip_full_with_multibyte_payload_len() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let len =
            encode_session_header(&mut buf, SessionFlags::full_snapshot(), 5, 0, 300).unwrap();
        // Flag byte + 1-byte tick delta + 1-byte baseline delta + 2-byte payload length.
        assert_eq!(len, 5);
        let decoded = decode_session_header(&buf[..len], 0).unwrap();
        assert!(decoded.flags.is_full_snapshot());
        assert_eq!(decoded.tick, 5);
        assert_eq!(decoded.baseline_tick, 5);
        assert_eq!(decoded.payload_len, 300);
        assert_eq!(decoded.header_len, len);
    }

    #[test]
    fn session_header_rejects_zero_tick_delta() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let len =
            encode_session_header(&mut buf, SessionFlags::delta_snapshot(), 0, 1, 10).unwrap();
        let err = decode_session_header(&buf[..len], 1).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn encode_rejects_small_buffer_and_invalid_flags() {
        let mut small = [0u8; SESSION_MAX_HEADER_SIZE - 1];
        let err = encode_session_header(&mut small, SessionFlags::full_snapshot(), 1, 0, 0)
            .unwrap_err();
        assert!(matches!(
            err,
            EncodeError::BufferTooSmall { needed, available }
                if needed == SESSION_MAX_HEADER_SIZE && available == SESSION_MAX_HEADER_SIZE - 1
        ));

        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let err =
            encode_session_header(&mut buf, SessionFlags::from_raw(0), 1, 0, 0).unwrap_err();
        assert!(matches!(err, EncodeError::LengthOverflow { .. }));
    }

    #[test]
    fn decode_rejects_empty_buffer_and_invalid_flags() {
        let err = decode_session_header(&[], 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::PacketTooSmall {
                actual: 0,
                required: 1
            }
        ));

        for &raw in &[0u8, 0b11, 0b101, 0x80 | SessionFlags::DELTA_SNAPSHOT] {
            let err = decode_session_header(&[raw, 1, 0, 0], 0).unwrap_err();
            assert!(matches!(err, DecodeError::InvalidFlags { .. }));
        }
    }

    #[test]
    fn decode_rejects_tick_overflow() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let len =
            encode_session_header(&mut buf, SessionFlags::delta_snapshot(), 1, 0, 0).unwrap();
        let err = decode_session_header(&buf[..len], u32::MAX).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baselines() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];

        // Full snapshot carrying a non-zero baseline delta.
        let len =
            encode_session_header(&mut buf, SessionFlags::full_snapshot(), 5, 1, 0).unwrap();
        let err = decode_session_header(&buf[..len], 0).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));

        // Baseline delta larger than the resolved tick.
        let len =
            encode_session_header(&mut buf, SessionFlags::delta_snapshot(), 1, 5, 0).unwrap();
        let err = decode_session_header(&buf[..len], 0).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));

        // Delta snapshot resolving to baseline tick 0.
        let len =
            encode_session_header(&mut buf, SessionFlags::delta_snapshot(), 1, 1, 0).unwrap();
        let err = decode_session_header(&buf[..len], 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::InvalidBaselineTick {
                baseline_tick: 0,
                ..
            }
        ));
    }

    #[test]
    fn decode_reports_truncated_and_overlong_varints() {
        let mut buf = [0u8; SESSION_MAX_HEADER_SIZE];
        let len =
            encode_session_header(&mut buf, SessionFlags::full_snapshot(), 5, 0, 300).unwrap();
        // Drop the last byte of the 2-byte payload-length varint.
        let err = decode_session_header(&buf[..len - 1], 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated { .. })
        ));

        // A tick delta whose continuation bit is still set after five bytes.
        let overlong = [
            SessionFlags::FULL_SNAPSHOT,
            0x80,
            0x80,
            0x80,
            0x80,
            0x80,
            0x00,
        ];
        let err = decode_session_header(&overlong, 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::InvalidVarint)
        ));
    }
}