Skip to main content

codec/
session.rs

1//! Session state machine for compact headers.
2
3use bitstream::{BitReader, BitWriter};
4use schema::schema_hash;
5use wire::{PacketFlags, PacketHeader, SectionTag, WirePacket, WireSection};
6
7use crate::error::{CodecError, CodecResult};
8use crate::limits::CodecLimits;
9use crate::snapshot::write_section;
10use crate::types::SnapshotTick;
11
/// Compact header mode negotiated via session init.
///
/// The discriminant is the byte written on the wire by `encode_session_init_body`
/// (via `as u8`) and parsed back by [`CompactHeaderMode::from_raw`]. `#[repr(u8)]`
/// pins that one-byte representation, so a future variant whose value does not fit
/// in a `u8` is a compile error instead of being silently truncated by the cast.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactHeaderMode {
    /// Compact session header v1 (wire value `1`).
    SessionV1 = 1,
}

impl CompactHeaderMode {
    /// Parses a raw wire byte into a mode.
    ///
    /// Returns `None` for any byte that does not correspond to a known mode.
    fn from_raw(raw: u8) -> Option<Self> {
        match raw {
            1 => Some(Self::SessionV1),
            _ => None,
        }
    }
}
27
28/// Session state for compact headers.
29#[derive(Debug, Clone)]
30pub struct SessionState {
31    pub schema_hash: u64,
32    pub session_id: Option<u64>,
33    pub last_tick: SnapshotTick,
34    pub compact_mode: CompactHeaderMode,
35}
36
37/// Encodes a session init packet.
38pub fn encode_session_init_packet(
39    schema: &schema::Schema,
40    tick: SnapshotTick,
41    session_id: Option<u64>,
42    compact_mode: CompactHeaderMode,
43    limits: &CodecLimits,
44    out: &mut [u8],
45) -> CodecResult<usize> {
46    let mut offset = wire::HEADER_SIZE;
47    let body_len = write_section(
48        SectionTag::SessionInit,
49        &mut out[offset..],
50        limits,
51        |writer| encode_session_init_body(session_id, compact_mode, writer),
52    )?;
53    offset += body_len;
54
55    let payload_len = offset - wire::HEADER_SIZE;
56    let header = PacketHeader {
57        version: wire::VERSION,
58        flags: PacketFlags::session_init(),
59        schema_hash: schema_hash(schema),
60        tick: tick.raw(),
61        baseline_tick: 0,
62        payload_len: payload_len as u32,
63    };
64    wire::encode_header(&header, &mut out[..wire::HEADER_SIZE]).map_err(|_| {
65        CodecError::OutputTooSmall {
66            needed: wire::HEADER_SIZE,
67            available: out.len(),
68        }
69    })?;
70
71    Ok(offset)
72}
73
74fn encode_session_init_body(
75    session_id: Option<u64>,
76    compact_mode: CompactHeaderMode,
77    writer: &mut BitWriter<'_>,
78) -> CodecResult<()> {
79    writer.align_to_byte()?;
80    writer.write_u64_aligned(session_id.unwrap_or(0))?;
81    writer.write_u8_aligned(compact_mode as u8)?;
82    writer.align_to_byte()?;
83    Ok(())
84}
85
86/// Decodes a session init packet into session state.
87pub fn decode_session_init_packet(
88    schema: &schema::Schema,
89    packet: &WirePacket<'_>,
90    limits: &CodecLimits,
91) -> CodecResult<SessionState> {
92    let header = packet.header;
93    if !header.flags.is_session_init() {
94        return Err(CodecError::SessionMissing);
95    }
96    if header.flags.is_full_snapshot() || header.flags.is_delta_snapshot() {
97        return Err(CodecError::SessionInitInvalid);
98    }
99    if header.baseline_tick != 0 {
100        return Err(CodecError::SessionInitInvalid);
101    }
102    let expected_hash = schema_hash(schema);
103    if header.schema_hash != expected_hash {
104        return Err(CodecError::SchemaMismatch {
105            expected: expected_hash,
106            found: header.schema_hash,
107        });
108    }
109
110    let mut init_section: Option<&WireSection<'_>> = None;
111    for section in &packet.sections {
112        match section.tag {
113            SectionTag::SessionInit => {
114                if init_section.is_some() {
115                    return Err(CodecError::SessionInitInvalid);
116                }
117                init_section = Some(section);
118            }
119            _ => {
120                return Err(CodecError::UnexpectedSection {
121                    section: section.tag,
122                });
123            }
124        }
125    }
126    let section = init_section.ok_or(CodecError::SessionInitInvalid)?;
127    let (session_id, compact_mode) = decode_session_init_body(section.body, limits)?;
128
129    Ok(SessionState {
130        schema_hash: header.schema_hash,
131        session_id,
132        last_tick: SnapshotTick::new(header.tick),
133        compact_mode,
134    })
135}
136
137fn decode_session_init_body(
138    body: &[u8],
139    limits: &CodecLimits,
140) -> CodecResult<(Option<u64>, CompactHeaderMode)> {
141    if body.len() > limits.max_section_bytes {
142        return Err(CodecError::LimitsExceeded {
143            kind: crate::error::LimitKind::SectionBytes,
144            limit: limits.max_section_bytes,
145            actual: body.len(),
146        });
147    }
148    let mut reader = BitReader::new(body);
149    reader.align_to_byte()?;
150    let session_id = reader.read_u64_aligned()?;
151    let mode = reader.read_u8_aligned()?;
152    reader.align_to_byte()?;
153    if reader.bits_remaining() != 0 {
154        return Err(CodecError::TrailingSectionData {
155            section: SectionTag::SessionInit,
156            remaining_bits: reader.bits_remaining(),
157        });
158    }
159    let compact_mode =
160        CompactHeaderMode::from_raw(mode).ok_or(CodecError::SessionUnsupportedMode { mode })?;
161    Ok((
162        if session_id == 0 {
163            None
164        } else {
165            Some(session_id)
166        },
167        compact_mode,
168    ))
169}
170
171/// Decodes a compact packet using session state.
172pub fn decode_session_packet<'a>(
173    schema: &schema::Schema,
174    session: &mut SessionState,
175    bytes: &'a [u8],
176    wire_limits: &wire::Limits,
177) -> CodecResult<WirePacket<'a>> {
178    if session.schema_hash != schema_hash(schema) {
179        return Err(CodecError::SchemaMismatch {
180            expected: schema_hash(schema),
181            found: session.schema_hash,
182        });
183    }
184    let header =
185        wire::decode_session_header(bytes, session.last_tick.raw()).map_err(CodecError::Wire)?;
186    if header.tick <= session.last_tick.raw() {
187        return Err(CodecError::SessionOutOfOrder {
188            previous: session.last_tick.raw(),
189            current: header.tick,
190        });
191    }
192
193    let payload_start = header.header_len;
194    let payload_end = payload_start + header.payload_len as usize;
195    if payload_end > bytes.len() {
196        return Err(CodecError::Wire(wire::DecodeError::PayloadLengthMismatch {
197            header_len: header.payload_len,
198            actual_len: bytes.len().saturating_sub(payload_start),
199        }));
200    }
201    let payload = &bytes[payload_start..payload_end];
202    let sections = wire::decode_sections(payload, wire_limits).map_err(CodecError::Wire)?;
203
204    session.last_tick = SnapshotTick::new(header.tick);
205    let flags = if header.flags.is_full_snapshot() {
206        PacketFlags::full_snapshot()
207    } else {
208        PacketFlags::delta_snapshot()
209    };
210    Ok(WirePacket {
211        header: PacketHeader {
212            version: wire::VERSION,
213            flags,
214            schema_hash: session.schema_hash,
215            tick: header.tick,
216            baseline_tick: header.baseline_tick,
217            payload_len: header.payload_len,
218        },
219        sections,
220    })
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::{ComponentSnapshot, EntitySnapshot, FieldValue, Snapshot};
    use crate::types::EntityId;
    use schema::{ComponentDef, FieldCodec, FieldDef, FieldId, Schema};

    /// Schema with a single component (id 1) holding one bool field (id 1).
    fn schema_one_bool() -> Schema {
        let bool_field = FieldDef::new(FieldId::new(1).unwrap(), FieldCodec::bool());
        let component = ComponentDef::new(schema::ComponentId::new(1).unwrap()).field(bool_field);
        Schema::new(vec![component]).unwrap()
    }

    /// Snapshot at `tick` containing entity 1, whose single bool field is `flag`.
    fn single_bool_snapshot(tick: SnapshotTick, flag: bool) -> Snapshot {
        Snapshot {
            tick,
            entities: vec![EntitySnapshot {
                id: EntityId::new(1),
                components: vec![ComponentSnapshot {
                    id: schema::ComponentId::new(1).unwrap(),
                    fields: vec![FieldValue::Bool(flag)],
                }],
            }],
        }
    }

    #[test]
    fn session_init_roundtrip() {
        let schema = schema_one_bool();
        let limits = CodecLimits::for_testing();
        let wire_limits = wire::Limits::for_testing();
        let mut buf = [0u8; 128];

        let written = encode_session_init_packet(
            &schema,
            SnapshotTick::new(5),
            Some(42),
            CompactHeaderMode::SessionV1,
            &limits,
            &mut buf,
        )
        .unwrap();
        let packet = wire::decode_packet(&buf[..written], &wire_limits).unwrap();
        let session = decode_session_init_packet(&schema, &packet, &limits).unwrap();

        assert_eq!(session.session_id, Some(42));
        assert_eq!(session.last_tick.raw(), 5);
    }

    #[test]
    fn session_decode_compact_packet() {
        let schema = schema_one_bool();
        let baseline = single_bool_snapshot(SnapshotTick::new(10), false);
        let current = single_bool_snapshot(SnapshotTick::new(11), true);
        let mut session = SessionState {
            schema_hash: schema_hash(&schema),
            session_id: Some(1),
            last_tick: baseline.tick,
            compact_mode: CompactHeaderMode::SessionV1,
        };
        let limits = CodecLimits::for_testing();
        let mut scratch = crate::scratch::CodecScratch::default();
        let mut buf = [0u8; 256];

        let written = crate::delta::encode_delta_snapshot_for_client_session_with_scratch(
            &schema,
            current.tick,
            baseline.tick,
            &baseline,
            &current,
            &limits,
            &mut scratch,
            &mut session.last_tick,
            &mut buf,
        )
        .unwrap();
        let packet = decode_session_packet(
            &schema,
            &mut session,
            &buf[..written],
            &wire::Limits::for_testing(),
        )
        .unwrap();

        assert!(packet.header.flags.is_delta_snapshot());
    }
}