1use serde::Serialize;
19use serde::de::DeserializeOwned;
20use std::fmt;
21use std::io::{self, BufRead, Read};
22use std::sync::OnceLock;
23
24pub const MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
27
28const MAX_RMPV_DEPTH: usize = 128;
31
32static WIRE_CODEC: OnceLock<Codec> = OnceLock::new();
34
/// Wire encoding used for messages exchanged with the peer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Codec {
    /// Newline-delimited JSON: one JSON document per line.
    Json,
    /// MessagePack payloads, each preceded by a 4-byte big-endian length.
    MsgPack,
}

impl fmt::Display for Codec {
    /// Writes the lowercase codec name: `json` or `msgpack`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Json => "json",
            Self::MsgPack => "msgpack",
        };
        f.write_str(name)
    }
}
54
55impl Codec {
56 pub fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, String> {
64 match self {
65 Codec::Json => {
66 let mut bytes =
67 serde_json::to_vec(value).map_err(|e| format!("json encode: {e}"))?;
68 bytes.push(b'\n');
69 Ok(bytes)
70 }
71 Codec::MsgPack => {
72 let payload =
73 rmp_serde::to_vec_named(value).map_err(|e| format!("msgpack encode: {e}"))?;
74 let len = u32::try_from(payload.len()).map_err(|_| {
75 format!(
76 "payload exceeds 4 GiB frame limit ({} bytes)",
77 payload.len()
78 )
79 })?;
80 let mut bytes = Vec::with_capacity(4 + payload.len());
81 bytes.extend_from_slice(&len.to_be_bytes());
82 bytes.extend_from_slice(&payload);
83 Ok(bytes)
84 }
85 }
86 }
87
88 pub fn encode_binary_message(
101 &self,
102 mut map: serde_json::Map<String, serde_json::Value>,
103 binary_field: Option<(&str, &[u8])>,
104 ) -> Result<Vec<u8>, String> {
105 match self {
106 Codec::Json => {
107 if let Some((key, bytes)) = binary_field
108 && !bytes.is_empty()
109 {
110 use base64::Engine;
111 let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
112 map.insert(key.to_string(), serde_json::Value::String(b64));
113 }
114 let val = serde_json::Value::Object(map);
115 let mut bytes =
116 serde_json::to_vec(&val).map_err(|e| format!("json encode: {e}"))?;
117 bytes.push(b'\n');
118 Ok(bytes)
119 }
120 Codec::MsgPack => {
121 use rmpv::Value as V;
122
123 let mut entries: Vec<(V, V)> = map
124 .into_iter()
125 .map(|(k, v)| (V::String(k.into()), json_to_rmpv(v)))
126 .collect();
127
128 if let Some((key, bytes)) = binary_field
129 && !bytes.is_empty()
130 {
131 entries.push((V::String(key.into()), V::Binary(bytes.to_vec())));
132 }
133
134 let msg = V::Map(entries);
135 let mut payload = Vec::new();
136 rmpv::encode::write_value(&mut payload, &msg)
137 .map_err(|e| format!("msgpack encode: {e}"))?;
138 let len = u32::try_from(payload.len()).map_err(|_| {
139 format!(
140 "payload exceeds 4 GiB frame limit ({} bytes)",
141 payload.len()
142 )
143 })?;
144 let mut bytes = Vec::with_capacity(4 + payload.len());
145 bytes.extend_from_slice(&len.to_be_bytes());
146 bytes.extend_from_slice(&payload);
147 Ok(bytes)
148 }
149 }
150 }
151
152 pub fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T, String> {
164 match self {
165 Codec::Json => serde_json::from_slice(bytes).map_err(|e| format!("json decode: {e}")),
166 Codec::MsgPack => {
167 check_msgpack_depth(bytes, MAX_RMPV_DEPTH)
172 .map_err(|e| format!("msgpack depth check: {e}"))?;
173 let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &bytes[..])
174 .map_err(|e| format!("msgpack decode (rmpv): {e}"))?;
175 let json_val = rmpv_to_json(rmpv_val);
176 serde_json::from_value(json_val)
177 .map_err(|e| format!("msgpack decode (tag dispatch): {e}"))
178 }
179 }
180 }
181
182 pub fn read_message<R: BufRead>(&self, reader: &mut R) -> io::Result<Option<Vec<u8>>> {
189 match self {
190 Codec::Json => loop {
191 let mut line = String::new();
192 let limit = (MAX_MESSAGE_SIZE + 1) as u64;
196 let n = (&mut *reader).take(limit).read_line(&mut line)?;
197 if n == 0 {
198 return Ok(None);
199 }
200 if line.len() > MAX_MESSAGE_SIZE {
201 return Err(io::Error::new(
202 io::ErrorKind::InvalidData,
203 format!(
204 "JSON message exceeds {} byte limit ({} bytes)",
205 MAX_MESSAGE_SIZE,
206 line.len()
207 ),
208 ));
209 }
210 let trimmed = line.trim();
211 if trimmed.is_empty() {
212 continue;
213 }
214 return Ok(Some(trimmed.as_bytes().to_vec()));
215 },
216 Codec::MsgPack => {
217 let mut len_buf = [0u8; 4];
218 match reader.read_exact(&mut len_buf) {
219 Ok(()) => {}
220 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
221 Err(e) => return Err(e),
222 }
223 let len = u32::from_be_bytes(len_buf) as usize;
224 if len == 0 {
225 return Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 "empty frame received",
228 ));
229 }
230 if len > MAX_MESSAGE_SIZE {
231 return Err(io::Error::new(
232 io::ErrorKind::InvalidData,
233 format!(
234 "msgpack frame exceeds {} byte limit ({} bytes)",
235 MAX_MESSAGE_SIZE, len
236 ),
237 ));
238 }
239 let mut payload = vec![0u8; len];
240 reader.read_exact(&mut payload)?;
241 Ok(Some(payload))
242 }
243 }
244 }
245
246 pub fn detect_from_first_byte(byte: u8) -> Codec {
251 if byte == b'{' {
252 Codec::Json
253 } else {
254 Codec::MsgPack
255 }
256 }
257
258 pub fn set_global(codec: Codec) {
260 WIRE_CODEC
261 .set(codec)
262 .expect("WIRE_CODEC already initialized");
263 }
264
265 pub fn get_global() -> &'static Codec {
267 WIRE_CODEC.get().unwrap_or(&Codec::MsgPack)
268 }
269}
270
/// Cheaply pre-validates a MessagePack payload before it is fully decoded.
///
/// Walks the encoded bytes without building any values. It rejects input that
/// either:
/// * nests non-empty containers (arrays/maps) more than `max_depth` levels
///   deep, or
/// * has a container declaring more elements than there are bytes left to hold
///   them. Every element occupies at least one byte, so such a count is
///   necessarily forged.
///
/// Truncation is not reported here. The scan stops at the first incomplete
/// length prefix (or after skipping past the end) and returns `Ok(())`,
/// leaving the real decoder to report it.
///
/// Length arithmetic saturates, so hostile 32-bit lengths cannot overflow
/// `usize` (a debug-build panic) on 32-bit targets.
///
/// # Errors
/// Returns a human-readable description of the violated limit.
fn check_msgpack_depth(bytes: &[u8], max_depth: usize) -> Result<(), String> {
    let len = bytes.len();
    let mut pos: usize = 0;
    // Elements still expected by each open container, innermost last. Its
    // length is the current nesting depth.
    let mut remaining: Vec<usize> = Vec::new();

    while pos < len {
        let b = bytes[pos];
        pos += 1;

        // `skip`: bytes following the marker that belong to this value itself
        // (length prefix and/or data). `children`: number of nested elements
        // (0 for scalars and empty containers).
        let step = match b {
            // positive fixint, nil / unused 0xc1 / false / true, negative fixint
            0x00..=0x7f | 0xc0..=0xc3 | 0xe0..=0xff => Some((0, 0)),
            // fixmap: n key/value pairs = 2n elements
            0x80..=0x8f => Some((0, usize::from(b & 0x0f) * 2)),
            // fixarray
            0x90..=0x9f => Some((0, usize::from(b & 0x0f))),
            // fixstr
            0xa0..=0xbf => Some((usize::from(b & 0x1f), 0)),
            // float32, float64
            0xca => Some((4, 0)),
            0xcb => Some((8, 0)),
            // uint8/int8 .. uint64/int64
            0xcc | 0xd0 => Some((1, 0)),
            0xcd | 0xd1 => Some((2, 0)),
            0xce | 0xd2 => Some((4, 0)),
            0xcf | 0xd3 => Some((8, 0)),
            // fixext 1/2/4/8/16: type byte + fixed-size data
            0xd4 => Some((2, 0)),
            0xd5 => Some((3, 0)),
            0xd6 => Some((5, 0)),
            0xd7 => Some((9, 0)),
            0xd8 => Some((17, 0)),
            // bin8/str8, bin16/str16, bin32/str32: length prefix + data
            0xc4 | 0xd9 => read_be_len(bytes, pos, 1).map(|n| (n.saturating_add(1), 0)),
            0xc5 | 0xda => read_be_len(bytes, pos, 2).map(|n| (n.saturating_add(2), 0)),
            0xc6 | 0xdb => read_be_len(bytes, pos, 4).map(|n| (n.saturating_add(4), 0)),
            // ext8/16/32: length prefix + type byte + data
            0xc7 => read_be_len(bytes, pos, 1).map(|n| (n.saturating_add(2), 0)),
            0xc8 => read_be_len(bytes, pos, 2).map(|n| (n.saturating_add(3), 0)),
            0xc9 => read_be_len(bytes, pos, 4).map(|n| (n.saturating_add(5), 0)),
            // array16/32: length prefix, then n elements
            0xdc => read_be_len(bytes, pos, 2).map(|n| (2, n)),
            0xdd => read_be_len(bytes, pos, 4).map(|n| (4, n)),
            // map16/32: length prefix, then n key/value pairs
            0xde => read_be_len(bytes, pos, 2).map(|n| (2, n.saturating_mul(2))),
            0xdf => read_be_len(bytes, pos, 4).map(|n| (4, n.saturating_mul(2))),
        };
        let Some((skip, children)) = step else {
            // Length prefix cut off by the end of input.
            break;
        };

        pos = pos.saturating_add(skip);

        if children > 0 {
            let remaining_bytes = len.saturating_sub(pos);
            if children > remaining_bytes {
                return Err(format!(
                    "msgpack container declares {children} elements but only {remaining_bytes} bytes remain"
                ));
            }
            if remaining.len() >= max_depth {
                return Err(format!("msgpack nesting depth exceeds limit ({max_depth})"));
            }
            remaining.push(children);
        } else {
            // A scalar (or empty container) completes one element of the
            // innermost open container. A container that thereby completes
            // is itself one finished element of its parent, and so on upward.
            while let Some(count) = remaining.last_mut() {
                *count -= 1;
                if *count > 0 {
                    break;
                }
                remaining.pop();
            }
        }
    }

    Ok(())
}

/// Reads a `width`-byte big-endian unsigned length starting at `bytes[pos]`.
///
/// Returns `None` if the input ends before all `width` bytes are available.
/// Callers pass `width <= 4`, so the value always fits in `usize`.
fn read_be_len(bytes: &[u8], pos: usize, width: usize) -> Option<usize> {
    let end = pos.checked_add(width)?;
    let field = bytes.get(pos..end)?;
    Some(field.iter().fold(0, |acc, &byte| (acc << 8) | usize::from(byte)))
}
494
495fn rmpv_to_json(val: rmpv::Value) -> serde_json::Value {
510 rmpv_to_json_inner(val, 0)
511}
512
513fn rmpv_to_json_inner(val: rmpv::Value, depth: usize) -> serde_json::Value {
514 if depth > MAX_RMPV_DEPTH {
515 log::error!("rmpv_to_json: recursion depth exceeded {MAX_RMPV_DEPTH}, replaced with null");
516 return serde_json::Value::Null;
517 }
518
519 match val {
520 rmpv::Value::Nil => serde_json::Value::Null,
521 rmpv::Value::Boolean(b) => serde_json::Value::Bool(b),
522 rmpv::Value::Integer(n) => {
523 if let Some(i) = n.as_i64() {
524 serde_json::Value::Number(i.into())
525 } else if let Some(u) = n.as_u64() {
526 serde_json::Value::Number(u.into())
527 } else {
528 serde_json::Value::Null
530 }
531 }
532 rmpv::Value::F32(f) => serde_json::Number::from_f64(f as f64)
533 .map(serde_json::Value::Number)
534 .unwrap_or_else(|| {
535 log::warn!("rmpv_to_json: non-finite f32 ({f}) replaced with 0.0");
536 serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
537 }),
538 rmpv::Value::F64(f) => serde_json::Number::from_f64(f)
539 .map(serde_json::Value::Number)
540 .unwrap_or_else(|| {
541 log::warn!("rmpv_to_json: non-finite f64 ({f}) replaced with 0.0");
542 serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
543 }),
544 rmpv::Value::String(s) => {
545 serde_json::Value::String(String::from_utf8_lossy(s.as_bytes()).into_owned())
549 }
550 rmpv::Value::Binary(bytes) => {
551 serde_json::Value::Array(
554 bytes
555 .into_iter()
556 .map(|b| serde_json::Value::Number(b.into()))
557 .collect(),
558 )
559 }
560 rmpv::Value::Array(arr) => serde_json::Value::Array(
561 arr.into_iter()
562 .map(|v| rmpv_to_json_inner(v, depth + 1))
563 .collect(),
564 ),
565 rmpv::Value::Map(entries) => {
566 let mut map = serde_json::Map::new();
567 for (k, v) in entries {
568 let key = match k {
570 rmpv::Value::String(s) => s.into_str().unwrap_or_default().to_string(),
571 rmpv::Value::Integer(n) => n.to_string(),
572 other => format!("{other}"),
573 };
574 map.insert(key, rmpv_to_json_inner(v, depth + 1));
575 }
576 serde_json::Value::Object(map)
577 }
578 rmpv::Value::Ext(type_id, _bytes) => {
579 log::warn!(
580 "rmpv_to_json: msgpack ext type {type_id} not supported, replaced with null"
581 );
582 serde_json::Value::Null
583 }
584 }
585}
586
587fn json_to_rmpv(val: serde_json::Value) -> rmpv::Value {
590 match val {
591 serde_json::Value::Null => rmpv::Value::Nil,
592 serde_json::Value::Bool(b) => rmpv::Value::Boolean(b),
593 serde_json::Value::Number(n) => {
594 if let Some(i) = n.as_i64() {
595 rmpv::Value::Integer(i.into())
596 } else if let Some(u) = n.as_u64() {
597 rmpv::Value::Integer(u.into())
598 } else if let Some(f) = n.as_f64() {
599 rmpv::Value::F64(f)
600 } else {
601 rmpv::Value::Nil
602 }
603 }
604 serde_json::Value::String(s) => rmpv::Value::String(s.into()),
605 serde_json::Value::Array(arr) => {
606 rmpv::Value::Array(arr.into_iter().map(json_to_rmpv).collect())
607 }
608 serde_json::Value::Object(map) => rmpv::Value::Map(
609 map.into_iter()
610 .map(|(k, v)| (rmpv::Value::String(k.into()), json_to_rmpv(v)))
611 .collect(),
612 ),
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Simple {
        name: String,
        count: u32,
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(tag = "type", rename_all = "snake_case")]
    enum Tagged {
        Alpha { value: String },
        Beta { x: f64, y: f64 },
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct WithFlatten {
        op: String,
        #[serde(flatten)]
        rest: serde_json::Value,
    }

    // -- JSON encode/decode --

    #[test]
    fn json_roundtrip_simple() {
        let original = Simple {
            name: "test".into(),
            count: 42,
        };
        let bytes = Codec::Json.encode(&original).unwrap();
        assert!(bytes.ends_with(b"\n"));
        let decoded: Simple = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn json_roundtrip_tagged_enum() {
        let original = Tagged::Beta { x: 1.5, y: 2.5 };
        let bytes = Codec::Json.encode(&original).unwrap();
        let decoded: Tagged = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
        assert_eq!(decoded, original);
    }

    // -- MessagePack encode/decode --

    #[test]
    fn msgpack_roundtrip_simple() {
        let original = Simple {
            name: "test".into(),
            count: 42,
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        // The first 4 bytes are the big-endian payload length.
        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        assert_eq!(len, bytes.len() - 4);
        let decoded: Simple = Codec::MsgPack.decode(&bytes[4..]).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn msgpack_roundtrip_tagged_enum() {
        let original = Tagged::Alpha {
            value: "hello".into(),
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        let payload = &bytes[4..];
        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn msgpack_roundtrip_tagged_enum_beta() {
        let original = Tagged::Beta {
            x: std::f64::consts::PI,
            y: -1.0,
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        let payload = &bytes[4..];
        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn msgpack_flatten_deserialize() {
        let input = json!({"op": "props", "path": [0, 1], "props": {"label": "hi"}});
        let bytes = rmp_serde::to_vec_named(&input).unwrap();
        let decoded: WithFlatten = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(decoded.op, "props");
        assert_eq!(decoded.rest["path"], json!([0, 1]));
        assert_eq!(decoded.rest["props"]["label"], "hi");
    }

    // -- read_message --

    #[test]
    fn json_read_message_skips_blank_lines() {
        let data = b"\n\n{\"name\":\"a\",\"count\":1}\n\n{\"name\":\"b\",\"count\":2}\n\n";
        let mut reader = io::BufReader::new(&data[..]);

        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
        assert_eq!(s1.name, "a");

        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
        assert_eq!(s2.name, "b");

        // Only blank lines remain, so the stream reports EOF.
        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
    }

    #[test]
    fn json_read_message() {
        let data = b"{\"name\":\"a\",\"count\":1}\n{\"name\":\"b\",\"count\":2}\n";
        let mut reader = io::BufReader::new(&data[..]);

        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
        assert_eq!(s1.name, "a");

        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
        assert_eq!(s2.name, "b");

        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
    }

    #[test]
    fn msgpack_read_message() {
        let s1 = Simple {
            name: "x".into(),
            count: 10,
        };
        let s2 = Simple {
            name: "y".into(),
            count: 20,
        };
        let p1 = rmp_serde::to_vec_named(&s1).unwrap();
        let p2 = rmp_serde::to_vec_named(&s2).unwrap();

        let mut data = Vec::new();
        data.extend_from_slice(&(p1.len() as u32).to_be_bytes());
        data.extend_from_slice(&p1);
        data.extend_from_slice(&(p2.len() as u32).to_be_bytes());
        data.extend_from_slice(&p2);

        let mut reader = io::BufReader::new(&data[..]);

        let msg1 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
        let d1: Simple = Codec::MsgPack.decode(&msg1).unwrap();
        assert_eq!(d1, s1);

        let msg2 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
        let d2: Simple = Codec::MsgPack.decode(&msg2).unwrap();
        assert_eq!(d2, s2);

        assert!(Codec::MsgPack.read_message(&mut reader).unwrap().is_none());
    }

    #[test]
    fn json_read_message_rejects_oversized_line() {
        // Drive the real `read_message` with a line longer than the limit and
        // no newline, so the bounded read stops at MAX_MESSAGE_SIZE + 1 bytes.
        // Reading straight from a slice (a `BufRead` whose buffer is the whole
        // input) keeps this fast despite the large buffer.
        let long_line = vec![b'x'; MAX_MESSAGE_SIZE + 10];
        let mut reader: &[u8] = &long_line;

        let err = Codec::Json.read_message(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("byte limit"));
    }

    #[test]
    fn msgpack_read_message_rejects_oversized_frame() {
        let len = (MAX_MESSAGE_SIZE + 1) as u32;
        let mut data = Vec::new();
        data.extend_from_slice(&len.to_be_bytes());
        // A little trailing data; the length check must fire before reading it.
        data.extend_from_slice(&[0u8; 64]);
        let mut reader = io::BufReader::new(&data[..]);

        let result = Codec::MsgPack.read_message(&mut reader);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("byte limit"));
    }

    #[test]
    fn msgpack_read_message_rejects_zero_length_frame() {
        let mut data = Vec::new();
        data.extend_from_slice(&0u32.to_be_bytes());

        let mut reader = io::BufReader::new(&data[..]);
        let result = Codec::MsgPack.read_message(&mut reader);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("empty frame"));
    }

    // -- Messages produced by external encoders (plain maps, not Rust structs) --

    #[test]
    fn msgpack_external_tagged_enum_alpha() {
        let external = json!({"type": "alpha", "value": "hello"});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
        assert_eq!(
            decoded,
            Tagged::Alpha {
                value: "hello".into()
            }
        );
    }

    #[test]
    fn msgpack_external_tagged_enum_beta() {
        let external = json!({"type": "beta", "x": 1.5, "y": -2.0});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
        assert_eq!(decoded, Tagged::Beta { x: 1.5, y: -2.0 });
    }

    #[test]
    fn msgpack_external_incoming_settings() {
        use crate::protocol::IncomingMessage;
        let external = json!({"type": "settings", "settings": {"antialiasing": false}});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
        assert!(matches!(decoded, IncomingMessage::Settings { .. }));
    }

    #[test]
    fn msgpack_external_incoming_snapshot() {
        use crate::protocol::IncomingMessage;
        let external = json!({"type": "snapshot", "tree": {"id": "root", "type": "column", "props": {}, "children": []}});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
        assert!(matches!(decoded, IncomingMessage::Snapshot { .. }));
    }

    #[test]
    fn msgpack_image_op_with_native_binary() {
        use rmpv::Value as RmpvValue;

        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255, 0, 255, 0, 255];
        let msg = RmpvValue::Map(vec![
            (
                RmpvValue::String("type".into()),
                RmpvValue::String("image_op".into()),
            ),
            (
                RmpvValue::String("op".into()),
                RmpvValue::String("create_image".into()),
            ),
            (
                RmpvValue::String("handle".into()),
                RmpvValue::String("test_img".into()),
            ),
            (
                RmpvValue::String("pixels".into()),
                RmpvValue::Binary(pixel_bytes.clone()),
            ),
            (
                RmpvValue::String("width".into()),
                RmpvValue::Integer(1.into()),
            ),
            (
                RmpvValue::String("height".into()),
                RmpvValue::Integer(2.into()),
            ),
        ]);

        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &msg).unwrap();

        let decoded: crate::protocol::IncomingMessage = Codec::MsgPack.decode(&buf).unwrap();
        match decoded {
            crate::protocol::IncomingMessage::ImageOp {
                op,
                handle,
                pixels,
                width,
                height,
                data,
            } => {
                assert_eq!(op, "create_image");
                assert_eq!(handle, "test_img");
                assert_eq!(pixels, Some(pixel_bytes));
                assert_eq!(width, Some(1));
                assert_eq!(height, Some(2));
                assert!(data.is_none());
            }
            other => panic!("expected ImageOp, got {other:?}"),
        }
    }

    #[test]
    fn msgpack_image_op_with_base64_string() {
        use crate::protocol::IncomingMessage;
        use base64::Engine as _;

        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255];
        let b64 = base64::engine::general_purpose::STANDARD.encode(&pixel_bytes);

        let json_msg = json!({
            "type": "image_op",
            "op": "create_image",
            "handle": "test_img",
            "pixels": b64,
            "width": 1,
            "height": 1
        });
        let json_str = serde_json::to_string(&json_msg).unwrap();

        let decoded: IncomingMessage = Codec::Json.decode(json_str.as_bytes()).unwrap();
        match decoded {
            IncomingMessage::ImageOp { pixels, .. } => {
                assert_eq!(pixels, Some(pixel_bytes));
            }
            other => panic!("expected ImageOp, got {other:?}"),
        }
    }

    // -- rmpv_to_json --

    #[test]
    fn rmpv_to_json_preserves_binary_as_array() {
        let binary = rmpv::Value::Binary(vec![1, 2, 3]);
        let result = rmpv_to_json(binary);
        assert_eq!(result, json!([1, 2, 3]));
    }

    #[test]
    fn rmpv_to_json_handles_nested_map() {
        let val = rmpv::Value::Map(vec![
            (
                rmpv::Value::String("key".into()),
                rmpv::Value::String("val".into()),
            ),
            (
                rmpv::Value::String("num".into()),
                rmpv::Value::Integer(42.into()),
            ),
        ]);
        let result = rmpv_to_json(val);
        assert_eq!(result, json!({"key": "val", "num": 42}));
    }

    // -- Codec detection and display --

    #[test]
    fn detect_json_from_brace() {
        assert_eq!(Codec::detect_from_first_byte(b'{'), Codec::Json);
    }

    #[test]
    fn detect_msgpack_from_zero() {
        assert_eq!(Codec::detect_from_first_byte(0x00), Codec::MsgPack);
    }

    #[test]
    fn detect_msgpack_from_fixmap() {
        assert_eq!(Codec::detect_from_first_byte(0x85), Codec::MsgPack);
    }

    #[test]
    fn display_format() {
        assert_eq!(Codec::Json.to_string(), "json");
        assert_eq!(Codec::MsgPack.to_string(), "msgpack");
    }

    // -- More rmpv_to_json coverage --

    #[test]
    fn rmpv_to_json_deeply_nested_maps() {
        let val = rmpv::Value::Map(vec![(
            rmpv::Value::String("outer".into()),
            rmpv::Value::Map(vec![(
                rmpv::Value::String("inner".into()),
                rmpv::Value::Map(vec![(
                    rmpv::Value::String("deep".into()),
                    rmpv::Value::Integer(42.into()),
                )]),
            )]),
        )]);
        let result = rmpv_to_json(val);
        assert_eq!(result, json!({"outer": {"inner": {"deep": 42}}}));
    }

    #[test]
    fn rmpv_to_json_binary_in_nested_map() {
        let val = rmpv::Value::Map(vec![
            (
                rmpv::Value::String("name".into()),
                rmpv::Value::String("img".into()),
            ),
            (
                rmpv::Value::String("pixels".into()),
                rmpv::Value::Binary(vec![255, 128, 0, 255]),
            ),
        ]);
        let result = rmpv_to_json(val);
        assert_eq!(result["name"], json!("img"));
        assert_eq!(result["pixels"], json!([255, 128, 0, 255]));
    }

    #[test]
    fn msgpack_roundtrip_with_binary_field() {
        use rmpv::Value as RmpvValue;

        let raw_bytes: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = RmpvValue::Map(vec![
            (
                RmpvValue::String("type".into()),
                RmpvValue::String("alpha".into()),
            ),
            (
                RmpvValue::String("value".into()),
                RmpvValue::String("hello".into()),
            ),
            (
                RmpvValue::String("payload".into()),
                RmpvValue::Binary(raw_bytes.clone()),
            ),
        ]);

        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &msg).unwrap();

        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &buf[..]).unwrap();
        let json_val = rmpv_to_json(rmpv_val);

        assert_eq!(json_val["type"], "alpha");
        assert_eq!(json_val["value"], "hello");

        // Binary arrives as an array of byte values.
        let payload = json_val["payload"].as_array().unwrap();
        let bytes: Vec<u8> = payload.iter().map(|v| v.as_u64().unwrap() as u8).collect();
        assert_eq!(bytes, raw_bytes);
    }

    #[test]
    fn rmpv_to_json_handles_nil_and_bool() {
        assert_eq!(rmpv_to_json(rmpv::Value::Nil), json!(null));
        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(true)), json!(true));
        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(false)), json!(false));
    }

    // -- check_msgpack_depth --

    #[test]
    fn msgpack_depth_check_accepts_flat_map() {
        let val = json!({"a": 1, "b": "hello", "c": true});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 128).is_ok());
    }

    #[test]
    fn msgpack_depth_check_accepts_nested_within_limit() {
        let val = json!({"outer": {"middle": {"inner": 42}}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
    }

    #[test]
    fn msgpack_depth_check_rejects_beyond_limit() {
        let val = json!({"a": {"b": {"c": 1}}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    #[test]
    fn msgpack_depth_check_accepts_flat_array() {
        let val = json!([1, 2, 3, 4, 5]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 1).is_ok());
    }

    #[test]
    fn msgpack_depth_check_nested_arrays() {
        let val = json!([[[42]]]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    #[test]
    fn msgpack_depth_check_mixed_containers() {
        let val = json!({"list": [{"nested": true}]});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    #[test]
    fn msgpack_depth_check_empty_containers() {
        let val = json!({"empty_map": {}, "empty_arr": []});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_ok());
    }

    #[test]
    fn msgpack_depth_check_sibling_arrays_dont_add_depth() {
        let val = json!([[1, 2], [3, 4]]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_ok());
    }

    #[test]
    fn msgpack_depth_check_binary_data() {
        use rmpv::Value as V;
        let val = V::Map(vec![(
            V::String("data".into()),
            V::Binary(vec![0xDE, 0xAD]),
        )]);
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();
        assert!(check_msgpack_depth(&bytes, 1).is_ok());
    }

    #[test]
    fn msgpack_depth_check_deeply_nested_rejects() {
        use rmpv::Value as V;
        let depth = 200;
        let mut val = V::Integer(1.into());
        for _ in 0..depth {
            val = V::Map(vec![(V::String("a".into()), val)]);
        }
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();

        assert!(check_msgpack_depth(&bytes, 128).is_err());
        assert!(check_msgpack_depth(&bytes, 200).is_ok());
    }

    #[test]
    fn msgpack_decode_rejects_deeply_nested() {
        use rmpv::Value as V;
        let mut val = V::Integer(1.into());
        for _ in 0..200 {
            val = V::Map(vec![(V::String("a".into()), val)]);
        }
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();

        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("depth"));
    }

    #[test]
    fn msgpack_depth_check_truncated_payload_does_not_panic() {
        let val = json!({"a": {"b": [1, 2, 3]}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        for cut in [1, 3, 5, bytes.len() / 2] {
            let _ = check_msgpack_depth(&bytes[..cut], 128);
        }
        // Containers whose declared elements are missing entirely.
        assert!(check_msgpack_depth(&[0x81], 128).is_err());
        assert!(check_msgpack_depth(&[0x91], 128).is_err());
        // Length prefixes cut short: left for the real decoder to reject.
        assert!(check_msgpack_depth(&[0xdc], 128).is_ok());
        assert!(check_msgpack_depth(&[0xde, 0x00], 128).is_ok());
    }

    #[test]
    fn msgpack_depth_check_huge_length_prefix_does_not_panic() {
        // str32 claiming u32::MAX bytes: the scan must skip past the end
        // without overflowing.
        assert!(check_msgpack_depth(&[0xdb, 0xff, 0xff, 0xff, 0xff], 128).is_ok());
    }

    #[test]
    fn msgpack_depth_check_empty_input() {
        assert!(check_msgpack_depth(&[], 128).is_ok());
    }

    #[test]
    fn msgpack_depth_check_scalars_only() {
        let val = json!(42);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 0).is_ok());
    }

    #[test]
    fn msgpack_depth_check_rejects_forged_element_count() {
        // map32 claiming u32::MAX pairs, followed by a single tiny pair.
        let mut bytes = vec![0xdf];
        bytes.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes());
        bytes.extend_from_slice(&[0xa1, b'k', 0x01]);
        let result = check_msgpack_depth(&bytes, 128);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("elements"));
    }

    #[test]
    fn msgpack_decode_rejects_forged_element_count() {
        // array32 claiming 2^31 - 1 elements, followed by a single element.
        let mut bytes = vec![0xdd];
        bytes.extend_from_slice(&0x7FFF_FFFFu32.to_be_bytes());
        bytes.push(0x01);
        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("elements"));
    }

    // -- json_to_rmpv --

    #[test]
    fn json_to_rmpv_scalars() {
        assert_eq!(json_to_rmpv(json!(null)), rmpv::Value::Nil);
        assert_eq!(json_to_rmpv(json!(true)), rmpv::Value::Boolean(true));
        assert_eq!(json_to_rmpv(json!(42)), rmpv::Value::Integer(42.into()));
        assert_eq!(json_to_rmpv(json!(2.5)), rmpv::Value::F64(2.5));
        assert_eq!(
            json_to_rmpv(json!("hello")),
            rmpv::Value::String("hello".into())
        );
    }

    #[test]
    fn json_to_rmpv_nested() {
        let val = json!({"key": [1, "two", null]});
        let rmpv = json_to_rmpv(val);
        match rmpv {
            rmpv::Value::Map(entries) => {
                assert_eq!(entries.len(), 1);
                let (k, v) = &entries[0];
                assert_eq!(k, &rmpv::Value::String("key".into()));
                match v {
                    rmpv::Value::Array(arr) => {
                        assert_eq!(arr.len(), 3);
                        assert_eq!(arr[0], rmpv::Value::Integer(1.into()));
                        assert_eq!(arr[2], rmpv::Value::Nil);
                    }
                    other => panic!("expected array, got {other:?}"),
                }
            }
            other => panic!("expected map, got {other:?}"),
        }
    }

    // -- encode_binary_message --

    #[test]
    fn encode_binary_message_json_without_binary() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("test"));
        map.insert("id".to_string(), json!("t1"));

        let bytes = Codec::Json.encode_binary_message(map, None).unwrap();
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.ends_with('\n'));
        let parsed: serde_json::Value = serde_json::from_str(s.trim()).unwrap();
        assert_eq!(parsed["type"], "test");
        assert_eq!(parsed["id"], "t1");
        assert!(parsed.get("rgba").is_none());
    }

    #[test]
    fn encode_binary_message_json_with_binary() {
        use base64::Engine as _;

        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("screenshot"));
        let pixel_data = vec![255u8, 0, 128, 64];

        let bytes = Codec::Json
            .encode_binary_message(map, Some(("rgba", &pixel_data)))
            .unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&bytes[..bytes.len() - 1]).unwrap();
        let b64 = parsed["rgba"].as_str().unwrap();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(b64)
            .unwrap();
        assert_eq!(decoded, pixel_data);
    }

    #[test]
    fn encode_binary_message_msgpack_with_binary() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("screenshot"));
        map.insert("id".to_string(), json!("s1"));
        let pixel_data = vec![0xDE, 0xAD, 0xBE, 0xEF];

        let bytes = Codec::MsgPack
            .encode_binary_message(map, Some(("rgba", &pixel_data)))
            .unwrap();

        // Skip the 4-byte length prefix.
        let payload = &bytes[4..];
        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &payload[..]).unwrap();

        match rmpv_val {
            rmpv::Value::Map(entries) => {
                let rgba_entry = entries
                    .iter()
                    .find(|(k, _)| k == &rmpv::Value::String("rgba".into()));
                match rgba_entry {
                    Some((_, rmpv::Value::Binary(data))) => {
                        assert_eq!(data, &pixel_data);
                    }
                    other => panic!("expected Binary rgba field, got {other:?}"),
                }
            }
            other => panic!("expected Map, got {other:?}"),
        }
    }

    #[test]
    fn encode_binary_message_msgpack_roundtrip_non_binary_fields() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("test"));
        map.insert("count".to_string(), json!(42));
        map.insert("nested".to_string(), json!({"a": [1, 2]}));

        let bytes = Codec::MsgPack.encode_binary_message(map, None).unwrap();
        let decoded: serde_json::Value = Codec::MsgPack.decode(&bytes[4..]).unwrap();
        assert_eq!(decoded["type"], "test");
        assert_eq!(decoded["count"], 42);
        assert_eq!(decoded["nested"]["a"][0], 1);
    }
}