1use xxhash_rust::xxh3::xxh3_64;
19
20use crate::error::WireError;
21use crate::message::{
22 AuthKind, BatchPayload, ChosenMode, ClientHello, HandshakeStatus, Kind, Message, ServerHello,
23 StreamErrorCode,
24};
25use crate::{MAGIC, MAX_MESSAGE_BYTES, VERSION};
26
/// Framing overhead of a checksummed frame: kind byte (1) + little-endian payload length (4)
/// + xxh3-64 trailer (8). This equals the size of such a frame with an empty payload.
///
/// NOTE(review): kinds without a checksum carry no trailer, so their frames can be as short as
/// 5 bytes — confirm callers only apply this bound to checksummed kinds.
pub const MIN_FRAME_BYTES: usize = 1 /* kind */ + 4 /* payload_len */ + 8 /* xxh3 */;
29
30pub fn encode(msg: &Message) -> Vec<u8> {
33 let mut out = Vec::new();
34 encode_into(msg, &mut out);
35 out
36}
37
38pub fn encode_into(msg: &Message, out: &mut Vec<u8>) {
41 let kind = msg.kind();
42 let start = out.len();
43 out.push(kind as u8);
46 out.extend_from_slice(&[0u8; 4]);
47 let payload_start = out.len();
48 write_payload(msg, out);
49 let payload_len = out.len() - payload_start;
50 debug_assert!(payload_len <= u32::MAX as usize);
51 out[start + 1..start + 5].copy_from_slice(&(payload_len as u32).to_le_bytes());
52 if kind.has_checksum() {
53 let h = xxh3_64(&out[start..payload_start + payload_len]);
54 out.extend_from_slice(&h.to_le_bytes());
55 }
56}
57
58pub fn encode_handshake(msg: &Message) -> Vec<u8> {
62 debug_assert!(matches!(
63 msg,
64 Message::ClientHello(_) | Message::ServerHello(_)
65 ));
66 encode(msg)
67}
68
69fn write_payload(msg: &Message, out: &mut Vec<u8>) {
70 match msg {
71 Message::ClientHello(h) => {
72 out.extend_from_slice(&MAGIC);
73 out.push(VERSION);
74 out.push(h.capability_flags);
75 out.push(h.auth_kind as u8);
76 let auth_len: u16 = u16::try_from(h.auth.len()).unwrap_or(u16::MAX);
77 out.extend_from_slice(&auth_len.to_le_bytes());
78 out.extend_from_slice(&h.auth[..auth_len as usize]);
79 let os_len: u32 = u32::try_from(h.open_stream.len()).unwrap_or(u32::MAX);
80 out.extend_from_slice(&os_len.to_le_bytes());
81 out.extend_from_slice(&h.open_stream[..os_len as usize]);
82 }
83 Message::ServerHello(s) => {
84 out.push(s.status as u8);
85 out.push(match s.chosen_mode {
86 Some(m) => m as u8,
87 None => 0,
88 });
89 out.extend_from_slice(&s.initial_credit.to_le_bytes());
90 out.push(s.server_version);
91 out.extend_from_slice(&s.max_message_bytes.to_le_bytes());
92 let so_len: u32 = u32::try_from(s.stream_opened.len()).unwrap_or(u32::MAX);
93 out.extend_from_slice(&so_len.to_le_bytes());
94 out.extend_from_slice(&s.stream_opened[..so_len as usize]);
95 }
96 Message::RawFrame {
97 frame_id,
98 perm_seed,
99 zstd_bytes,
100 } => {
101 out.extend_from_slice(&frame_id.to_le_bytes());
102 out.extend_from_slice(perm_seed);
103 out.extend_from_slice(zstd_bytes);
104 }
105 Message::ZstdBatch {
106 batch_id,
107 epoch,
108 n_records,
109 zstd_bytes,
110 } => {
111 out.extend_from_slice(&batch_id.to_le_bytes());
112 out.extend_from_slice(&epoch.to_le_bytes());
113 out.extend_from_slice(&n_records.to_le_bytes());
114 out.extend_from_slice(zstd_bytes);
115 }
116 Message::PlainBatch(b) => {
117 out.extend_from_slice(&b.batch_id.to_le_bytes());
118 out.extend_from_slice(&b.epoch.to_le_bytes());
119 let n: u32 = u32::try_from(b.records.len()).unwrap_or(u32::MAX);
120 out.extend_from_slice(&n.to_le_bytes());
121 for rec in &b.records {
122 let len: u32 = u32::try_from(rec.len()).unwrap_or(u32::MAX);
123 out.extend_from_slice(&len.to_le_bytes());
124 out.extend_from_slice(&rec[..len as usize]);
125 }
126 }
127 Message::EpochBoundary {
128 completed_epoch,
129 records_in_epoch,
130 } => {
131 out.extend_from_slice(&completed_epoch.to_le_bytes());
132 out.extend_from_slice(&records_in_epoch.to_le_bytes());
133 }
134 Message::StreamError {
135 code,
136 fatal,
137 detail,
138 } => {
139 out.push(*code as u8);
140 out.push(if *fatal { 1 } else { 0 });
141 out.extend_from_slice(detail);
142 }
143 Message::StreamClosed {
144 total_records,
145 epochs_completed,
146 } => {
147 out.extend_from_slice(&total_records.to_le_bytes());
148 out.extend_from_slice(&epochs_completed.to_le_bytes());
149 }
150 Message::Heartbeat { now_unix_nanos } | Message::Pong { now_unix_nanos } => {
151 out.extend_from_slice(&now_unix_nanos.to_le_bytes());
152 }
153 Message::AddCredit { add_bytes } => {
154 out.extend_from_slice(&add_bytes.to_le_bytes());
155 }
156 Message::Cancel { reason } => {
157 out.extend_from_slice(reason);
158 }
159 }
160}
161
/// Selects how the decoder interprets `Kind::Handshake` payloads.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Always parse handshake payloads as a `ClientHello` (the server side of a connection).
    ExpectClientHello,
    /// Always parse handshake payloads as a `ServerHello` (the client side of a connection).
    ExpectServerHello,
    /// Decide per payload. A leading 8-byte `MAGIC` means `ClientHello`; anything else is
    /// parsed as `ServerHello`.
    Either,
}
177
178#[derive(Debug, Clone, Copy)]
180pub struct DecodeOptions {
181 pub verify_xxh3: bool,
184 pub max_payload: u32,
187 pub role: HandshakeRole,
189}
190
191impl Default for DecodeOptions {
192 fn default() -> Self {
193 Self {
194 verify_xxh3: true,
195 max_payload: MAX_MESSAGE_BYTES,
196 role: HandshakeRole::Either,
197 }
198 }
199}
200
201#[derive(Debug, Default)]
204pub struct Decoder {
205 buf: Vec<u8>,
206 opts: DecodeOptions,
207}
208
209impl Decoder {
210 pub fn new(opts: DecodeOptions) -> Self {
211 Self {
212 buf: Vec::with_capacity(8 * 1024),
213 opts,
214 }
215 }
216
217 pub fn feed(&mut self, bytes: &[u8]) {
220 self.buf.extend_from_slice(bytes);
221 }
222
223 pub fn try_next(&mut self) -> Result<Option<Message>, WireError> {
228 if self.buf.is_empty() {
229 return Ok(None);
230 }
231 let kind_byte = self.buf[0];
232 let kind = Kind::from_u8(kind_byte).ok_or(WireError::UnknownKind { got: kind_byte })?;
233 if self.buf.len() < 5 {
234 return Ok(None);
235 }
236 let len_bytes: [u8; 4] = self.buf[1..5].try_into().unwrap_or([0u8; 4]);
238 let payload_len = u32::from_le_bytes(len_bytes);
239 if payload_len > self.opts.max_payload {
240 return Err(WireError::PayloadTooLarge {
241 got: payload_len,
242 max: self.opts.max_payload,
243 });
244 }
245 let need = 5 + payload_len as usize + if kind.has_checksum() { 8 } else { 0 };
246 if self.buf.len() < need {
247 return Ok(None);
248 }
249 if kind.has_checksum() {
251 let checked = 5 + payload_len as usize;
252 let xxh_bytes: [u8; 8] = self.buf[checked..checked + 8]
254 .try_into()
255 .unwrap_or([0u8; 8]);
256 let wire_h = u64::from_le_bytes(xxh_bytes);
257 if self.opts.verify_xxh3 {
258 let computed = xxh3_64(&self.buf[..checked]);
259 if computed != wire_h {
260 return Err(WireError::Xxh3Mismatch {
261 kind: kind_byte,
262 wire: wire_h,
263 computed,
264 });
265 }
266 }
267 }
268 let payload = &self.buf[5..5 + payload_len as usize];
269 let msg = parse_payload(kind, payload, self.opts.role)?;
270 self.buf.drain(..need);
271 Ok(Some(msg))
272 }
273
274 pub fn buffered_bytes(&self) -> usize {
276 self.buf.len()
277 }
278}
279
280fn parse_payload(kind: Kind, payload: &[u8], role: HandshakeRole) -> Result<Message, WireError> {
281 match kind {
282 Kind::Handshake => parse_handshake(payload, role),
283 Kind::RawFrame => parse_raw_frame(payload),
284 Kind::ZstdBatch => parse_zstd_batch(payload),
285 Kind::PlainBatch => parse_plain_batch(payload),
286 Kind::EpochBoundary => parse_epoch_boundary(payload),
287 Kind::StreamError => parse_stream_error(payload),
288 Kind::StreamClosed => parse_stream_closed(payload),
289 Kind::Heartbeat => parse_u64_kind(payload, kind, |ns| Message::Heartbeat {
290 now_unix_nanos: ns,
291 }),
292 Kind::AddCredit => parse_u64_kind(payload, kind, |n| Message::AddCredit { add_bytes: n }),
293 Kind::Cancel => Ok(Message::Cancel {
294 reason: payload.to_vec(),
295 }),
296 Kind::Pong => parse_u64_kind(payload, kind, |ns| Message::Pong { now_unix_nanos: ns }),
297 }
298}
299
300fn parse_handshake(p: &[u8], role: HandshakeRole) -> Result<Message, WireError> {
301 match role {
302 HandshakeRole::ExpectClientHello => parse_client_hello(p),
303 HandshakeRole::ExpectServerHello => parse_server_hello(p),
304 HandshakeRole::Either => {
305 if p.len() >= 8 && p[..8] == MAGIC {
306 parse_client_hello(p)
307 } else {
308 parse_server_hello(p)
309 }
310 }
311 }
312}
313
314fn parse_client_hello(p: &[u8]) -> Result<Message, WireError> {
315 let min = 8 + 1 + 1 + 1 + 2 + 4;
317 if p.len() < min {
318 return Err(WireError::TruncatedPayload {
319 kind: 0,
320 expected: min,
321 got: p.len(),
322 });
323 }
324 if p[..8] != MAGIC {
325 return Err(WireError::BadMagic {
326 got: p[..8].try_into().unwrap_or([0u8; 8]),
327 });
328 }
329 let version = p[8];
330 if version != VERSION {
331 return Err(WireError::BadVersion { got: version });
332 }
333 let caps = p[9];
334 let auth_kind_raw = p[10];
335 let auth_kind = AuthKind::from_u8(auth_kind_raw).ok_or(WireError::Malformed {
336 kind: 0,
337 detail: format!("bad auth_kind={auth_kind_raw}"),
338 })?;
339 let auth_len = u16::from_le_bytes(p[11..13].try_into().unwrap_or([0; 2])) as usize;
340 let auth_end = 13 + auth_len;
341 if p.len() < auth_end + 4 {
342 return Err(WireError::TruncatedPayload {
343 kind: 0,
344 expected: auth_end + 4,
345 got: p.len(),
346 });
347 }
348 let auth = p[13..auth_end].to_vec();
349 let os_len =
350 u32::from_le_bytes(p[auth_end..auth_end + 4].try_into().unwrap_or([0; 4])) as usize;
351 let os_start = auth_end + 4;
352 if p.len() < os_start + os_len {
353 return Err(WireError::TruncatedPayload {
354 kind: 0,
355 expected: os_start + os_len,
356 got: p.len(),
357 });
358 }
359 let open_stream = p[os_start..os_start + os_len].to_vec();
360 Ok(Message::ClientHello(ClientHello {
361 capability_flags: caps,
362 auth_kind,
363 auth,
364 open_stream,
365 }))
366}
367
368fn parse_server_hello(p: &[u8]) -> Result<Message, WireError> {
369 let min = 1 + 1 + 8 + 1 + 4 + 4;
371 if p.len() < min {
372 return Err(WireError::TruncatedPayload {
373 kind: 0,
374 expected: min,
375 got: p.len(),
376 });
377 }
378 let status = HandshakeStatus::from_u8(p[0]);
379 let chosen_mode = if p[1] == 0 {
380 None
381 } else {
382 ChosenMode::from_u8(p[1])
383 };
384 let initial_credit = u64::from_le_bytes(p[2..10].try_into().unwrap_or([0; 8]));
385 let server_version = p[10];
386 if server_version != VERSION {
387 return Err(WireError::BadVersion {
388 got: server_version,
389 });
390 }
391 let max_msg = u32::from_le_bytes(p[11..15].try_into().unwrap_or([0; 4]));
392 let so_len = u32::from_le_bytes(p[15..19].try_into().unwrap_or([0; 4])) as usize;
393 if p.len() < 19 + so_len {
394 return Err(WireError::TruncatedPayload {
395 kind: 0,
396 expected: 19 + so_len,
397 got: p.len(),
398 });
399 }
400 let stream_opened = p[19..19 + so_len].to_vec();
401 Ok(Message::ServerHello(ServerHello {
402 status,
403 chosen_mode,
404 initial_credit,
405 server_version,
406 max_message_bytes: max_msg,
407 stream_opened,
408 }))
409}
410
411fn parse_raw_frame(p: &[u8]) -> Result<Message, WireError> {
412 const HDR: usize = 4 + 32;
414 if p.len() < HDR {
415 return Err(WireError::TruncatedPayload {
416 kind: Kind::RawFrame as u8,
417 expected: HDR,
418 got: p.len(),
419 });
420 }
421 let frame_id = u32::from_le_bytes(p[0..4].try_into().unwrap_or([0; 4]));
422 let mut perm_seed = [0u8; 32];
423 perm_seed.copy_from_slice(&p[4..36]);
424 let zstd_bytes = p[36..].to_vec();
425 Ok(Message::RawFrame {
426 frame_id,
427 perm_seed,
428 zstd_bytes,
429 })
430}
431
432fn parse_zstd_batch(p: &[u8]) -> Result<Message, WireError> {
433 if p.len() < 16 {
434 return Err(WireError::TruncatedPayload {
435 kind: Kind::ZstdBatch as u8,
436 expected: 16,
437 got: p.len(),
438 });
439 }
440 let batch_id = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
441 let epoch = u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4]));
442 let n_records = u32::from_le_bytes(p[12..16].try_into().unwrap_or([0; 4]));
443 let zstd_bytes = p[16..].to_vec();
444 Ok(Message::ZstdBatch {
445 batch_id,
446 epoch,
447 n_records,
448 zstd_bytes,
449 })
450}
451
452fn parse_plain_batch(p: &[u8]) -> Result<Message, WireError> {
453 let min = 8 + 4 + 4;
454 if p.len() < min {
455 return Err(WireError::TruncatedPayload {
456 kind: Kind::PlainBatch as u8,
457 expected: min,
458 got: p.len(),
459 });
460 }
461 let batch_id = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
462 let epoch = u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4]));
463 let n = u32::from_le_bytes(p[12..16].try_into().unwrap_or([0; 4])) as usize;
464
465 let mut records = Vec::with_capacity(n);
466 let mut cursor = 16usize;
467 for _ in 0..n {
468 if p.len() < cursor + 4 {
469 return Err(WireError::TruncatedPayload {
470 kind: Kind::PlainBatch as u8,
471 expected: cursor + 4,
472 got: p.len(),
473 });
474 }
475 let len = u32::from_le_bytes(p[cursor..cursor + 4].try_into().unwrap_or([0; 4])) as usize;
476 cursor += 4;
477 if p.len() < cursor + len {
478 return Err(WireError::TruncatedPayload {
479 kind: Kind::PlainBatch as u8,
480 expected: cursor + len,
481 got: p.len(),
482 });
483 }
484 records.push(p[cursor..cursor + len].to_vec());
485 cursor += len;
486 }
487 Ok(Message::PlainBatch(BatchPayload {
488 batch_id,
489 epoch,
490 records,
491 }))
492}
493
494fn parse_epoch_boundary(p: &[u8]) -> Result<Message, WireError> {
495 if p.len() != 12 {
496 return Err(WireError::TruncatedPayload {
497 kind: Kind::EpochBoundary as u8,
498 expected: 12,
499 got: p.len(),
500 });
501 }
502 Ok(Message::EpochBoundary {
503 completed_epoch: u32::from_le_bytes(p[0..4].try_into().unwrap_or([0; 4])),
504 records_in_epoch: u64::from_le_bytes(p[4..12].try_into().unwrap_or([0; 8])),
505 })
506}
507
508fn parse_stream_error(p: &[u8]) -> Result<Message, WireError> {
509 if p.len() < 2 {
510 return Err(WireError::TruncatedPayload {
511 kind: Kind::StreamError as u8,
512 expected: 2,
513 got: p.len(),
514 });
515 }
516 Ok(Message::StreamError {
517 code: StreamErrorCode::from_u8(p[0]),
518 fatal: p[1] != 0,
519 detail: p[2..].to_vec(),
520 })
521}
522
523fn parse_stream_closed(p: &[u8]) -> Result<Message, WireError> {
524 if p.len() != 12 {
525 return Err(WireError::TruncatedPayload {
526 kind: Kind::StreamClosed as u8,
527 expected: 12,
528 got: p.len(),
529 });
530 }
531 Ok(Message::StreamClosed {
532 total_records: u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8])),
533 epochs_completed: u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4])),
534 })
535}
536
537fn parse_u64_kind(
538 p: &[u8],
539 kind: Kind,
540 build: impl FnOnce(u64) -> Message,
541) -> Result<Message, WireError> {
542 if p.len() != 8 {
543 return Err(WireError::TruncatedPayload {
544 kind: kind as u8,
545 expected: 8,
546 got: p.len(),
547 });
548 }
549 let v = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
550 Ok(build(v))
551}