Skip to main content

toddy_core/
codec.rs

1//! Wire codec for the stdin/stdout protocol.
2//!
3//! The renderer communicates with the host process over stdin (incoming
4//! messages) and stdout (outgoing events). Two wire formats are supported:
5//!
6//! - **JSON** -- newline-delimited JSON (JSONL). Each message is a UTF-8
7//!   JSON object terminated by `\n`. Human-readable, easy to debug.
8//!
9//! - **MsgPack** -- 4-byte big-endian length-prefixed MessagePack. Each
10//!   message is `[u32 BE length][msgpack payload]`. Compact, faster to
11//!   parse, supports native binary fields (e.g. pixel data).
12//!
13//! The codec is auto-detected from the first byte of stdin (`{` = JSON,
14//! anything else = MsgPack) and stored in a process-global [`OnceLock`]
15//! so that all emit paths (events, queries, screenshots) use the same
16//! format without threading the codec through every call site.
17
18use serde::Serialize;
19use serde::de::DeserializeOwned;
20use std::fmt;
21use std::io::{self, BufRead, Read};
22use std::sync::OnceLock;
23
/// Upper bound, in bytes, on a single wire message: 64 MiB. The same cap
/// is enforced on JSON line reads and on msgpack length-prefixed frames.
pub const MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Deepest nesting that the `rmpv_to_json` conversion will follow. Guards
/// against stack overflow from deeply nested (or maliciously crafted)
/// msgpack payloads.
const MAX_RMPV_DEPTH: usize = 128;
31
32/// Process-global wire codec, set once at startup via [`Codec::set_global`].
33static WIRE_CODEC: OnceLock<Codec> = OnceLock::new();
34
/// The two wire formats the stdin/stdout protocol can speak.
///
/// Framing and payload details are described in the
/// [module documentation](self).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Codec {
    /// JSONL: one JSON object per `\n`-terminated line.
    Json,
    /// MessagePack payload behind a length prefix.
    MsgPack,
}
45
46impl fmt::Display for Codec {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            Codec::Json => f.write_str("json"),
50            Codec::MsgPack => f.write_str("msgpack"),
51        }
52    }
53}
54
55impl Codec {
56    /// Encode a value to wire bytes ready to write to stdout.
57    ///
58    /// - JSON: `serde_json` serialization + trailing `\n`.
59    /// - MsgPack: 4-byte BE u32 length prefix + `rmp_serde` named serialization.
60    ///
61    /// Allocates a new Vec per call. For hot paths (e.g. rapid event
62    /// emission), consider pre-allocating and reusing a buffer.
63    pub fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, String> {
64        match self {
65            Codec::Json => {
66                let mut bytes =
67                    serde_json::to_vec(value).map_err(|e| format!("json encode: {e}"))?;
68                bytes.push(b'\n');
69                Ok(bytes)
70            }
71            Codec::MsgPack => {
72                let payload =
73                    rmp_serde::to_vec_named(value).map_err(|e| format!("msgpack encode: {e}"))?;
74                let len = u32::try_from(payload.len()).map_err(|_| {
75                    format!(
76                        "payload exceeds 4 GiB frame limit ({} bytes)",
77                        payload.len()
78                    )
79                })?;
80                let mut bytes = Vec::with_capacity(4 + payload.len());
81                bytes.extend_from_slice(&len.to_be_bytes());
82                bytes.extend_from_slice(&payload);
83                Ok(bytes)
84            }
85        }
86    }
87
88    /// Encode a JSON map with an optional binary field to wire bytes.
89    ///
90    /// For MsgPack: binary fields are encoded as native msgpack binary
91    /// (`rmpv::Value::Binary`), avoiding the ~33% size overhead of
92    /// base64. The map is built via `rmpv::Value::Map` to preserve
93    /// the binary type.
94    ///
95    /// For JSON: binary fields are base64-encoded as strings.
96    ///
97    /// Use this instead of [`encode`](Self::encode) when the message
98    /// contains raw byte data (e.g. pixel buffers) that should use
99    /// native binary encoding over msgpack.
100    pub fn encode_binary_message(
101        &self,
102        mut map: serde_json::Map<String, serde_json::Value>,
103        binary_field: Option<(&str, &[u8])>,
104    ) -> Result<Vec<u8>, String> {
105        match self {
106            Codec::Json => {
107                if let Some((key, bytes)) = binary_field
108                    && !bytes.is_empty()
109                {
110                    use base64::Engine;
111                    let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
112                    map.insert(key.to_string(), serde_json::Value::String(b64));
113                }
114                let val = serde_json::Value::Object(map);
115                let mut bytes =
116                    serde_json::to_vec(&val).map_err(|e| format!("json encode: {e}"))?;
117                bytes.push(b'\n');
118                Ok(bytes)
119            }
120            Codec::MsgPack => {
121                use rmpv::Value as V;
122
123                let mut entries: Vec<(V, V)> = map
124                    .into_iter()
125                    .map(|(k, v)| (V::String(k.into()), json_to_rmpv(v)))
126                    .collect();
127
128                if let Some((key, bytes)) = binary_field
129                    && !bytes.is_empty()
130                {
131                    entries.push((V::String(key.into()), V::Binary(bytes.to_vec())));
132                }
133
134                let msg = V::Map(entries);
135                let mut payload = Vec::new();
136                rmpv::encode::write_value(&mut payload, &msg)
137                    .map_err(|e| format!("msgpack encode: {e}"))?;
138                let len = u32::try_from(payload.len()).map_err(|_| {
139                    format!(
140                        "payload exceeds 4 GiB frame limit ({} bytes)",
141                        payload.len()
142                    )
143                })?;
144                let mut bytes = Vec::with_capacity(4 + payload.len());
145                bytes.extend_from_slice(&len.to_be_bytes());
146                bytes.extend_from_slice(&payload);
147                Ok(bytes)
148            }
149        }
150    }
151
152    /// Decode a raw payload (framing already stripped) into a typed value.
153    ///
154    /// For JSON, `bytes` is the UTF-8 JSON text (without the trailing newline).
155    /// For MsgPack, `bytes` is the raw msgpack payload (without the length prefix).
156    ///
157    /// MsgPack decoding routes through `rmpv::Value` as an intermediate. This
158    /// preserves binary data (msgpack's bin type) as JSON arrays of byte values,
159    /// which the `deserialize_binary_field` custom deserializer in protocol.rs
160    /// can reconstruct into `Vec<u8>`. The `serde_json::Value` intermediate is
161    /// still needed for tag dispatch (`#[serde(tag = "type")]`) which rmp-serde
162    /// doesn't handle reliably for externally-produced msgpack.
163    pub fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T, String> {
164        match self {
165            Codec::Json => serde_json::from_slice(bytes).map_err(|e| format!("json decode: {e}")),
166            Codec::MsgPack => {
167                // Pre-check nesting depth before rmpv deserialization.
168                // rmpv::read_value recurses without a depth limit, so a
169                // pathologically nested payload can cause stack overflow
170                // before our depth-limited rmpv_to_json conversion runs.
171                check_msgpack_depth(bytes, MAX_RMPV_DEPTH)
172                    .map_err(|e| format!("msgpack depth check: {e}"))?;
173                let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &bytes[..])
174                    .map_err(|e| format!("msgpack decode (rmpv): {e}"))?;
175                let json_val = rmpv_to_json(rmpv_val);
176                serde_json::from_value(json_val)
177                    .map_err(|e| format!("msgpack decode (tag dispatch): {e}"))
178            }
179        }
180    }
181
182    /// Read one framed message from a buffered reader, returning the raw payload.
183    ///
184    /// - JSON: reads until `\n`, returns the line bytes (without the newline).
185    /// - MsgPack: reads a 4-byte BE u32 length, then reads that many bytes.
186    ///
187    /// Returns `Ok(None)` on EOF (clean shutdown).
188    pub fn read_message<R: BufRead>(&self, reader: &mut R) -> io::Result<Option<Vec<u8>>> {
189        match self {
190            Codec::Json => loop {
191                let mut line = String::new();
192                // Wrap in Take to bound allocation BEFORE the full line is
193                // buffered. Without this, a sender could transmit an arbitrarily
194                // long line without a newline, causing unbounded memory growth.
195                let limit = (MAX_MESSAGE_SIZE + 1) as u64;
196                let n = (&mut *reader).take(limit).read_line(&mut line)?;
197                if n == 0 {
198                    return Ok(None);
199                }
200                if line.len() > MAX_MESSAGE_SIZE {
201                    return Err(io::Error::new(
202                        io::ErrorKind::InvalidData,
203                        format!(
204                            "JSON message exceeds {} byte limit ({} bytes)",
205                            MAX_MESSAGE_SIZE,
206                            line.len()
207                        ),
208                    ));
209                }
210                let trimmed = line.trim();
211                if trimmed.is_empty() {
212                    continue;
213                }
214                return Ok(Some(trimmed.as_bytes().to_vec()));
215            },
216            Codec::MsgPack => {
217                let mut len_buf = [0u8; 4];
218                match reader.read_exact(&mut len_buf) {
219                    Ok(()) => {}
220                    Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
221                    Err(e) => return Err(e),
222                }
223                let len = u32::from_be_bytes(len_buf) as usize;
224                if len == 0 {
225                    return Err(io::Error::new(
226                        io::ErrorKind::InvalidData,
227                        "empty frame received",
228                    ));
229                }
230                if len > MAX_MESSAGE_SIZE {
231                    return Err(io::Error::new(
232                        io::ErrorKind::InvalidData,
233                        format!(
234                            "msgpack frame exceeds {} byte limit ({} bytes)",
235                            MAX_MESSAGE_SIZE, len
236                        ),
237                    ));
238                }
239                let mut payload = vec![0u8; len];
240                reader.read_exact(&mut payload)?;
241                Ok(Some(payload))
242            }
243        }
244    }
245
246    /// Detect codec from the first byte of input.
247    ///
248    /// `{` (0x7B) indicates JSON. Anything else indicates MsgPack (the first
249    /// byte of a 4-byte length prefix).
250    pub fn detect_from_first_byte(byte: u8) -> Codec {
251        if byte == b'{' {
252            Codec::Json
253        } else {
254            Codec::MsgPack
255        }
256    }
257
258    /// Store the negotiated codec in the global slot. Panics if called twice.
259    pub fn set_global(codec: Codec) {
260        WIRE_CODEC
261            .set(codec)
262            .expect("WIRE_CODEC already initialized");
263    }
264
265    /// Get the global wire codec. Returns MsgPack if not yet initialized.
266    pub fn get_global() -> &'static Codec {
267        WIRE_CODEC.get().unwrap_or(&Codec::MsgPack)
268    }
269}
270
271// ---------------------------------------------------------------------------
272// Msgpack nesting depth pre-check
273// ---------------------------------------------------------------------------
274
/// Iteratively scan raw msgpack bytes and reject payloads that would cause
/// problems for `rmpv::read_value`:
///
/// - **Nesting depth** exceeding `max_depth` (prevents stack overflow from
///   rmpv's recursive parser).
/// - **Declared element counts** exceeding the remaining bytes (prevents
///   rmpv from pre-allocating `Vec::with_capacity(billions)` when the
///   declared count is larger than the payload can possibly contain).
///
/// The scan never recurses and never allocates more than one stack entry
/// per open container. It walks the whole slice, so values following the
/// first complete one are checked as well. A length header that is cut off
/// by the end of the slice stops the scan and yields `Ok(())`; this
/// function does not report such truncation itself.
///
/// All length arithmetic saturates, so a forged 32-bit length cannot wrap
/// around on targets where `usize` is 32 bits wide.
///
/// # Errors
///
/// Returns a descriptive message when a container is nested deeper than
/// `max_depth` or declares more elements than there are bytes left.
fn check_msgpack_depth(bytes: &[u8], max_depth: usize) -> Result<(), String> {
    /// Read a `width`-byte big-endian length field starting at `at`.
    /// Returns `None` when the field runs past the end of `bytes`.
    fn be_len(bytes: &[u8], at: usize, width: usize) -> Option<usize> {
        let end = at.checked_add(width)?;
        let field = bytes.get(at..end)?;
        Some(
            field
                .iter()
                .fold(0usize, |acc, &b| (acc << 8) | usize::from(b)),
        )
    }

    let len = bytes.len();
    let mut pos: usize = 0;
    // One entry per open container: how many child elements are still
    // expected at that level. The stack height is the current depth.
    let mut remaining: Vec<usize> = Vec::new();

    while pos < len {
        let marker = bytes[pos];
        pos += 1;

        // Classify the format marker: (bytes_to_skip, child_element_count).
        // `bytes_to_skip` covers any length header plus inline data.
        // Containers with at least one child open a new depth level;
        // everything else consumes one element from the parent.
        let (skip, children) = match marker {
            // positive fixint; nil, (unused), false, true; negative fixint
            0x00..=0x7f | 0xc0..=0xc3 | 0xe0..=0xff => (0, 0),
            // fixmap: N key-value pairs = 2N child elements
            0x80..=0x8f => (0, usize::from(marker & 0x0f) * 2),
            // fixarray
            0x90..=0x9f => (0, usize::from(marker & 0x0f)),
            // fixstr
            0xa0..=0xbf => (usize::from(marker & 0x1f), 0),
            // bin8, str8: 1-byte length + data
            0xc4 | 0xd9 => {
                let Some(n) = be_len(bytes, pos, 1) else { break };
                (n.saturating_add(1), 0)
            }
            // bin16, str16: 2-byte length + data
            0xc5 | 0xda => {
                let Some(n) = be_len(bytes, pos, 2) else { break };
                (n.saturating_add(2), 0)
            }
            // bin32, str32: 4-byte length + data
            0xc6 | 0xdb => {
                let Some(n) = be_len(bytes, pos, 4) else { break };
                (n.saturating_add(4), 0)
            }
            // ext8: 1-byte length + type byte + data
            0xc7 => {
                let Some(n) = be_len(bytes, pos, 1) else { break };
                (n.saturating_add(2), 0)
            }
            // ext16: 2-byte length + type byte + data
            0xc8 => {
                let Some(n) = be_len(bytes, pos, 2) else { break };
                (n.saturating_add(3), 0)
            }
            // ext32: 4-byte length + type byte + data
            0xc9 => {
                let Some(n) = be_len(bytes, pos, 4) else { break };
                (n.saturating_add(5), 0)
            }
            // uint8, int8
            0xcc | 0xd0 => (1, 0),
            // uint16, int16
            0xcd | 0xd1 => (2, 0),
            // float32, uint32, int32
            0xca | 0xce | 0xd2 => (4, 0),
            // float64, uint64, int64
            0xcb | 0xcf | 0xd3 => (8, 0),
            // fixext 1, 2, 4, 8, 16 (type byte + data)
            0xd4 => (2, 0),
            0xd5 => (3, 0),
            0xd6 => (5, 0),
            0xd7 => (9, 0),
            0xd8 => (17, 0),
            // array16: 2-byte count
            0xdc => {
                let Some(n) = be_len(bytes, pos, 2) else { break };
                (2, n)
            }
            // array32: 4-byte count
            0xdd => {
                let Some(n) = be_len(bytes, pos, 4) else { break };
                (4, n)
            }
            // map16: 2-byte pair count, 2N child elements
            0xde => {
                let Some(n) = be_len(bytes, pos, 2) else { break };
                (2, n.saturating_mul(2))
            }
            // map32: 4-byte pair count, 2N child elements
            0xdf => {
                let Some(n) = be_len(bytes, pos, 4) else { break };
                (4, n.saturating_mul(2))
            }
        };

        pos = pos.saturating_add(skip);

        if children > 0 {
            // Each child element needs at least 1 byte. Reject declared
            // counts that exceed the remaining data to prevent rmpv from
            // pre-allocating huge Vecs based on a forged count field.
            let remaining_bytes = len.saturating_sub(pos);
            if children > remaining_bytes {
                return Err(format!(
                    "msgpack container declares {children} elements but only {remaining_bytes} bytes remain"
                ));
            }

            // Opening this container would make the depth `len() + 1`.
            if remaining.len() >= max_depth {
                return Err(format!("msgpack nesting depth exceeds limit ({max_depth})"));
            }
            remaining.push(children);
        } else {
            // Leaf (or empty container) consumed. A container whose last
            // child was just consumed is itself one finished element of
            // its parent, so keep unwinding until a level still has
            // children outstanding.
            while let Some(count) = remaining.last_mut() {
                *count -= 1;
                if *count == 0 {
                    remaining.pop();
                } else {
                    break;
                }
            }
        }
    }

    Ok(())
}
494
495// ---------------------------------------------------------------------------
496// rmpv::Value -> serde_json::Value conversion
497// ---------------------------------------------------------------------------
498
499/// Convert an rmpv::Value to serde_json::Value, preserving binary data as
500/// JSON arrays of byte values (u8). This is the key difference from the old
501/// rmp_serde -> serde_json::Value path, which silently dropped binary data
502/// (serde_json::Value has no binary type).
503///
504/// The `deserialize_binary_field` custom deserializer in protocol.rs knows
505/// how to reconstruct `Vec<u8>` from these byte arrays.
506///
507/// Recursion depth is capped at `MAX_RMPV_DEPTH` to prevent stack overflow
508/// from deeply nested or malicious payloads.
509fn rmpv_to_json(val: rmpv::Value) -> serde_json::Value {
510    rmpv_to_json_inner(val, 0)
511}
512
513fn rmpv_to_json_inner(val: rmpv::Value, depth: usize) -> serde_json::Value {
514    if depth > MAX_RMPV_DEPTH {
515        log::error!("rmpv_to_json: recursion depth exceeded {MAX_RMPV_DEPTH}, replaced with null");
516        return serde_json::Value::Null;
517    }
518
519    match val {
520        rmpv::Value::Nil => serde_json::Value::Null,
521        rmpv::Value::Boolean(b) => serde_json::Value::Bool(b),
522        rmpv::Value::Integer(n) => {
523            if let Some(i) = n.as_i64() {
524                serde_json::Value::Number(i.into())
525            } else if let Some(u) = n.as_u64() {
526                serde_json::Value::Number(u.into())
527            } else {
528                // Fallback: shouldn't happen for msgpack integers
529                serde_json::Value::Null
530            }
531        }
532        rmpv::Value::F32(f) => serde_json::Number::from_f64(f as f64)
533            .map(serde_json::Value::Number)
534            .unwrap_or_else(|| {
535                log::warn!("rmpv_to_json: non-finite f32 ({f}) replaced with 0.0");
536                serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
537            }),
538        rmpv::Value::F64(f) => serde_json::Number::from_f64(f)
539            .map(serde_json::Value::Number)
540            .unwrap_or_else(|| {
541                log::warn!("rmpv_to_json: non-finite f64 ({f}) replaced with 0.0");
542                serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
543            }),
544        rmpv::Value::String(s) => {
545            // rmpv::Utf8String -- may or may not be valid UTF-8.
546            // Use lossy conversion so invalid bytes become U+FFFD instead of
547            // silently mapping to null (which breaks tag dispatch on "type").
548            serde_json::Value::String(String::from_utf8_lossy(s.as_bytes()).into_owned())
549        }
550        rmpv::Value::Binary(bytes) => {
551            // Preserve raw bytes as a JSON array of u8 values.
552            // The deserialize_binary_field custom deserializer reconstructs Vec<u8>.
553            serde_json::Value::Array(
554                bytes
555                    .into_iter()
556                    .map(|b| serde_json::Value::Number(b.into()))
557                    .collect(),
558            )
559        }
560        rmpv::Value::Array(arr) => serde_json::Value::Array(
561            arr.into_iter()
562                .map(|v| rmpv_to_json_inner(v, depth + 1))
563                .collect(),
564        ),
565        rmpv::Value::Map(entries) => {
566            let mut map = serde_json::Map::new();
567            for (k, v) in entries {
568                // Map keys: try to use string representation
569                let key = match k {
570                    rmpv::Value::String(s) => s.into_str().unwrap_or_default().to_string(),
571                    rmpv::Value::Integer(n) => n.to_string(),
572                    other => format!("{other}"),
573                };
574                map.insert(key, rmpv_to_json_inner(v, depth + 1));
575            }
576            serde_json::Value::Object(map)
577        }
578        rmpv::Value::Ext(type_id, _bytes) => {
579            log::warn!(
580                "rmpv_to_json: msgpack ext type {type_id} not supported, replaced with null"
581            );
582            serde_json::Value::Null
583        }
584    }
585}
586
587/// Convert a serde_json::Value to rmpv::Value for msgpack encoding.
588/// Used by `encode_binary_message` to build rmpv maps from JSON maps.
589fn json_to_rmpv(val: serde_json::Value) -> rmpv::Value {
590    match val {
591        serde_json::Value::Null => rmpv::Value::Nil,
592        serde_json::Value::Bool(b) => rmpv::Value::Boolean(b),
593        serde_json::Value::Number(n) => {
594            if let Some(i) = n.as_i64() {
595                rmpv::Value::Integer(i.into())
596            } else if let Some(u) = n.as_u64() {
597                rmpv::Value::Integer(u.into())
598            } else if let Some(f) = n.as_f64() {
599                rmpv::Value::F64(f)
600            } else {
601                rmpv::Value::Nil
602            }
603        }
604        serde_json::Value::String(s) => rmpv::Value::String(s.into()),
605        serde_json::Value::Array(arr) => {
606            rmpv::Value::Array(arr.into_iter().map(json_to_rmpv).collect())
607        }
608        serde_json::Value::Object(map) => rmpv::Value::Map(
609            map.into_iter()
610                .map(|(k, v)| (rmpv::Value::String(k.into()), json_to_rmpv(v)))
611                .collect(),
612        ),
613    }
614}
615
616#[cfg(test)]
617mod tests {
618    use super::*;
619    use serde::{Deserialize, Serialize};
620    use serde_json::json;
621
622    #[derive(Debug, Serialize, Deserialize, PartialEq)]
623    struct Simple {
624        name: String,
625        count: u32,
626    }
627
628    #[derive(Debug, Serialize, Deserialize, PartialEq)]
629    #[serde(tag = "type", rename_all = "snake_case")]
630    enum Tagged {
631        Alpha { value: String },
632        Beta { x: f64, y: f64 },
633    }
634
635    #[derive(Debug, Serialize, Deserialize, PartialEq)]
636    struct WithFlatten {
637        op: String,
638        #[serde(flatten)]
639        rest: serde_json::Value,
640    }
641
642    // -- JSON roundtrips --
643
644    #[test]
645    fn json_roundtrip_simple() {
646        let original = Simple {
647            name: "test".into(),
648            count: 42,
649        };
650        let bytes = Codec::Json.encode(&original).unwrap();
651        assert!(bytes.ends_with(b"\n"));
652        let decoded: Simple = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
653        assert_eq!(decoded, original);
654    }
655
656    #[test]
657    fn json_roundtrip_tagged_enum() {
658        let original = Tagged::Beta { x: 1.5, y: 2.5 };
659        let bytes = Codec::Json.encode(&original).unwrap();
660        let decoded: Tagged = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
661        assert_eq!(decoded, original);
662    }
663
664    // -- MsgPack roundtrips --
665
666    #[test]
667    fn msgpack_roundtrip_simple() {
668        let original = Simple {
669            name: "test".into(),
670            count: 42,
671        };
672        let bytes = Codec::MsgPack.encode(&original).unwrap();
673        // First 4 bytes are length prefix
674        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
675        assert_eq!(len, bytes.len() - 4);
676        let decoded: Simple = Codec::MsgPack.decode(&bytes[4..]).unwrap();
677        assert_eq!(decoded, original);
678    }
679
680    #[test]
681    fn msgpack_roundtrip_tagged_enum() {
682        let original = Tagged::Alpha {
683            value: "hello".into(),
684        };
685        let bytes = Codec::MsgPack.encode(&original).unwrap();
686        let payload = &bytes[4..];
687        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
688        assert_eq!(decoded, original);
689    }
690
691    #[test]
692    fn msgpack_roundtrip_tagged_enum_beta() {
693        let original = Tagged::Beta {
694            x: std::f64::consts::PI,
695            y: -1.0,
696        };
697        let bytes = Codec::MsgPack.encode(&original).unwrap();
698        let payload = &bytes[4..];
699        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
700        assert_eq!(decoded, original);
701    }
702
703    #[test]
704    fn msgpack_flatten_deserialize() {
705        // Flatten on deserialize: encode a map with extra keys, decode into
706        // a struct with #[serde(flatten)] rest: Value.
707        let input = json!({"op": "props", "path": [0, 1], "props": {"label": "hi"}});
708        let bytes = rmp_serde::to_vec_named(&input).unwrap();
709        let decoded: WithFlatten = rmp_serde::from_slice(&bytes).unwrap();
710        assert_eq!(decoded.op, "props");
711        assert_eq!(decoded.rest["path"], json!([0, 1]));
712        assert_eq!(decoded.rest["props"]["label"], "hi");
713    }
714
715    // -- read_message --
716
717    #[test]
718    fn json_read_message_skips_blank_lines() {
719        // Blank lines between messages must be skipped, not treated as EOF.
720        let data = b"\n\n{\"name\":\"a\",\"count\":1}\n\n{\"name\":\"b\",\"count\":2}\n\n";
721        let mut reader = io::BufReader::new(&data[..]);
722
723        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
724        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
725        assert_eq!(s1.name, "a");
726
727        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
728        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
729        assert_eq!(s2.name, "b");
730
731        // Trailing blank lines followed by real EOF should return None.
732        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
733    }
734
735    #[test]
736    fn json_read_message() {
737        let data = b"{\"name\":\"a\",\"count\":1}\n{\"name\":\"b\",\"count\":2}\n";
738        let mut reader = io::BufReader::new(&data[..]);
739
740        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
741        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
742        assert_eq!(s1.name, "a");
743
744        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
745        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
746        assert_eq!(s2.name, "b");
747
748        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
749    }
750
751    #[test]
752    fn msgpack_read_message() {
753        // Build two length-prefixed msgpack messages
754        let s1 = Simple {
755            name: "x".into(),
756            count: 10,
757        };
758        let s2 = Simple {
759            name: "y".into(),
760            count: 20,
761        };
762        let p1 = rmp_serde::to_vec_named(&s1).unwrap();
763        let p2 = rmp_serde::to_vec_named(&s2).unwrap();
764
765        let mut data = Vec::new();
766        data.extend_from_slice(&(p1.len() as u32).to_be_bytes());
767        data.extend_from_slice(&p1);
768        data.extend_from_slice(&(p2.len() as u32).to_be_bytes());
769        data.extend_from_slice(&p2);
770
771        let mut reader = io::BufReader::new(&data[..]);
772
773        let msg1 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
774        let d1: Simple = Codec::MsgPack.decode(&msg1).unwrap();
775        assert_eq!(d1, s1);
776
777        let msg2 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
778        let d2: Simple = Codec::MsgPack.decode(&msg2).unwrap();
779        assert_eq!(d2, s2);
780
781        assert!(Codec::MsgPack.read_message(&mut reader).unwrap().is_none());
782    }
783
784    // -- read_message size limit tests --
785
786    #[test]
787    fn json_read_message_rejects_oversized_line() {
788        // A line longer than MAX_MESSAGE_SIZE must be rejected.
789        // We can't allocate 64 MiB in a test, so use a smaller custom
790        // read_message-like flow. Instead, verify the Take wrapper works
791        // by constructing a line just over the limit.
792        //
793        // Since MAX_MESSAGE_SIZE is 64 MiB (too big for a unit test),
794        // we test the logic indirectly: a line of exactly MAX_MESSAGE_SIZE+1
795        // bytes (no newline) should be rejected. We use a small stand-in
796        // to verify the mechanics.
797        let small_limit = 100;
798        // Construct a line with no newline, longer than small_limit.
799        let long_line: Vec<u8> = vec![b'x'; small_limit + 10];
800        let mut reader = io::BufReader::new(&long_line[..]);
801
802        // Read using Take with the small limit -- simulates what
803        // read_message does, just with a smaller limit.
804        let mut line = String::new();
805        let limit = (small_limit + 1) as u64;
806        let _n = (&mut reader).take(limit).read_line(&mut line).unwrap();
807        // The Take capped the read, so line.len() <= small_limit + 1.
808        assert!(line.len() <= small_limit + 1);
809        // Without the Take, line.len() would be small_limit + 10.
810    }
811
812    #[test]
813    fn msgpack_read_message_rejects_oversized_frame() {
814        // Build a frame with length prefix claiming MAX_MESSAGE_SIZE + 1 bytes.
815        let len = (MAX_MESSAGE_SIZE + 1) as u32;
816        let mut data = Vec::new();
817        data.extend_from_slice(&len.to_be_bytes());
818        // Don't need the actual payload -- the size check fires first.
819        data.extend_from_slice(&[0u8; 64]); // just enough to not EOF
820
821        let mut reader = io::BufReader::new(&data[..]);
822        let result = Codec::MsgPack.read_message(&mut reader);
823        assert!(result.is_err());
824        let err = result.unwrap_err();
825        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
826        assert!(err.to_string().contains("byte limit"));
827    }
828
829    #[test]
830    fn msgpack_read_message_rejects_zero_length_frame() {
831        let mut data = Vec::new();
832        data.extend_from_slice(&0u32.to_be_bytes());
833
834        let mut reader = io::BufReader::new(&data[..]);
835        let result = Codec::MsgPack.read_message(&mut reader);
836        assert!(result.is_err());
837        assert!(result.unwrap_err().to_string().contains("empty frame"));
838    }
839
840    // -- Cross-format: simulate external msgpack (e.g. Msgpax) --
841    //
842    // rmp-serde's own serializer produces bytes that its deserializer can
843    // roundtrip, but external msgpack producers encode maps differently.
844    // These tests build raw msgpack via serde_json::Value -> rmp_serde
845    // (which is format-agnostic, not tagged-enum-aware) to simulate what
846    // an external producer like Msgpax sends. The Codec::decode workaround
847    // (msgpack -> rmpv::Value -> serde_json::Value -> T) must handle these.
848
849    #[test]
850    fn msgpack_external_tagged_enum_alpha() {
851        // Simulate Msgpax encoding {"type": "alpha", "value": "hello"}
852        let external = json!({"type": "alpha", "value": "hello"});
853        let bytes = rmp_serde::to_vec_named(&external).unwrap();
854        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
855        assert_eq!(
856            decoded,
857            Tagged::Alpha {
858                value: "hello".into()
859            }
860        );
861    }
862
863    #[test]
864    fn msgpack_external_tagged_enum_beta() {
865        let external = json!({"type": "beta", "x": 1.5, "y": -2.0});
866        let bytes = rmp_serde::to_vec_named(&external).unwrap();
867        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
868        assert_eq!(decoded, Tagged::Beta { x: 1.5, y: -2.0 });
869    }
870
871    #[test]
872    fn msgpack_external_incoming_settings() {
873        // This is exactly what a host sends: a plain map with "type":"settings".
874        use crate::protocol::IncomingMessage;
875        let external = json!({"type": "settings", "settings": {"antialiasing": false}});
876        let bytes = rmp_serde::to_vec_named(&external).unwrap();
877        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
878        assert!(matches!(decoded, IncomingMessage::Settings { .. }));
879    }
880
881    #[test]
882    fn msgpack_external_incoming_snapshot() {
883        use crate::protocol::IncomingMessage;
884        let external = json!({"type": "snapshot", "tree": {"id": "root", "type": "column", "props": {}, "children": []}});
885        let bytes = rmp_serde::to_vec_named(&external).unwrap();
886        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
887        assert!(matches!(decoded, IncomingMessage::Snapshot { .. }));
888    }
889
890    // -- Binary data preservation through rmpv path --
891
892    #[test]
893    fn msgpack_image_op_with_native_binary() {
894        // Simulate what an external producer sends when using native binary fields.
895        // Build raw msgpack with a binary field using rmpv directly.
896        use rmpv::Value as RmpvValue;
897
898        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255, 0, 255, 0, 255]; // 2 RGBA pixels
899        let msg = RmpvValue::Map(vec![
900            (
901                RmpvValue::String("type".into()),
902                RmpvValue::String("image_op".into()),
903            ),
904            (
905                RmpvValue::String("op".into()),
906                RmpvValue::String("create_image".into()),
907            ),
908            (
909                RmpvValue::String("handle".into()),
910                RmpvValue::String("test_img".into()),
911            ),
912            (
913                RmpvValue::String("pixels".into()),
914                RmpvValue::Binary(pixel_bytes.clone()),
915            ),
916            (
917                RmpvValue::String("width".into()),
918                RmpvValue::Integer(1.into()),
919            ),
920            (
921                RmpvValue::String("height".into()),
922                RmpvValue::Integer(2.into()),
923            ),
924        ]);
925
926        let mut buf = Vec::new();
927        rmpv::encode::write_value(&mut buf, &msg).unwrap();
928
929        let decoded: crate::protocol::IncomingMessage = Codec::MsgPack.decode(&buf).unwrap();
930        match decoded {
931            crate::protocol::IncomingMessage::ImageOp {
932                op,
933                handle,
934                pixels,
935                width,
936                height,
937                data,
938            } => {
939                assert_eq!(op, "create_image");
940                assert_eq!(handle, "test_img");
941                assert_eq!(pixels, Some(pixel_bytes));
942                assert_eq!(width, Some(1));
943                assert_eq!(height, Some(2));
944                assert!(data.is_none());
945            }
946            other => panic!("expected ImageOp, got {other:?}"),
947        }
948    }
949
950    #[test]
951    fn msgpack_image_op_with_base64_string() {
952        // JSON mode: binary data arrives as base64-encoded string.
953        use crate::protocol::IncomingMessage;
954        use base64::Engine as _;
955
956        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255];
957        let b64 = base64::engine::general_purpose::STANDARD.encode(&pixel_bytes);
958
959        let json_msg = json!({
960            "type": "image_op",
961            "op": "create_image",
962            "handle": "test_img",
963            "pixels": b64,
964            "width": 1,
965            "height": 1
966        });
967        let json_str = serde_json::to_string(&json_msg).unwrap();
968
969        let decoded: IncomingMessage = Codec::Json.decode(json_str.as_bytes()).unwrap();
970        match decoded {
971            IncomingMessage::ImageOp { pixels, .. } => {
972                assert_eq!(pixels, Some(pixel_bytes));
973            }
974            other => panic!("expected ImageOp, got {other:?}"),
975        }
976    }
977
978    // -- rmpv_to_json unit tests --
979
980    #[test]
981    fn rmpv_to_json_preserves_binary_as_array() {
982        let binary = rmpv::Value::Binary(vec![1, 2, 3]);
983        let result = rmpv_to_json(binary);
984        assert_eq!(result, json!([1, 2, 3]));
985    }
986
987    #[test]
988    fn rmpv_to_json_handles_nested_map() {
989        let val = rmpv::Value::Map(vec![
990            (
991                rmpv::Value::String("key".into()),
992                rmpv::Value::String("val".into()),
993            ),
994            (
995                rmpv::Value::String("num".into()),
996                rmpv::Value::Integer(42.into()),
997            ),
998        ]);
999        let result = rmpv_to_json(val);
1000        assert_eq!(result, json!({"key": "val", "num": 42}));
1001    }
1002
1003    // -- detect --
1004
1005    #[test]
1006    fn detect_json_from_brace() {
1007        assert_eq!(Codec::detect_from_first_byte(b'{'), Codec::Json);
1008    }
1009
1010    #[test]
1011    fn detect_msgpack_from_zero() {
1012        assert_eq!(Codec::detect_from_first_byte(0x00), Codec::MsgPack);
1013    }
1014
1015    #[test]
1016    fn detect_msgpack_from_fixmap() {
1017        assert_eq!(Codec::detect_from_first_byte(0x85), Codec::MsgPack);
1018    }
1019
1020    #[test]
1021    fn display_format() {
1022        assert_eq!(Codec::Json.to_string(), "json");
1023        assert_eq!(Codec::MsgPack.to_string(), "msgpack");
1024    }
1025
1026    // -- Additional rmpv_to_json coverage --
1027
1028    #[test]
1029    fn rmpv_to_json_deeply_nested_maps() {
1030        // Nested map: {"outer": {"inner": {"deep": 42}}}
1031        let val = rmpv::Value::Map(vec![(
1032            rmpv::Value::String("outer".into()),
1033            rmpv::Value::Map(vec![(
1034                rmpv::Value::String("inner".into()),
1035                rmpv::Value::Map(vec![(
1036                    rmpv::Value::String("deep".into()),
1037                    rmpv::Value::Integer(42.into()),
1038                )]),
1039            )]),
1040        )]);
1041        let result = rmpv_to_json(val);
1042        assert_eq!(result, json!({"outer": {"inner": {"deep": 42}}}));
1043    }
1044
1045    #[test]
1046    fn rmpv_to_json_binary_in_nested_map() {
1047        // Binary data nested inside a map should be preserved as byte arrays.
1048        let val = rmpv::Value::Map(vec![
1049            (
1050                rmpv::Value::String("name".into()),
1051                rmpv::Value::String("img".into()),
1052            ),
1053            (
1054                rmpv::Value::String("pixels".into()),
1055                rmpv::Value::Binary(vec![255, 128, 0, 255]),
1056            ),
1057        ]);
1058        let result = rmpv_to_json(val);
1059        assert_eq!(result["name"], json!("img"));
1060        assert_eq!(result["pixels"], json!([255, 128, 0, 255]));
1061    }
1062
1063    #[test]
1064    fn msgpack_roundtrip_with_binary_field() {
1065        // Encode a message containing binary data via msgpack, decode it,
1066        // and verify the binary field comes through as a byte array.
1067        use rmpv::Value as RmpvValue;
1068
1069        let raw_bytes: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
1070        let msg = RmpvValue::Map(vec![
1071            (
1072                RmpvValue::String("type".into()),
1073                RmpvValue::String("alpha".into()),
1074            ),
1075            (
1076                RmpvValue::String("value".into()),
1077                RmpvValue::String("hello".into()),
1078            ),
1079            (
1080                RmpvValue::String("payload".into()),
1081                RmpvValue::Binary(raw_bytes.clone()),
1082            ),
1083        ]);
1084
1085        // Encode to raw msgpack bytes.
1086        let mut buf = Vec::new();
1087        rmpv::encode::write_value(&mut buf, &msg).unwrap();
1088
1089        // The rmpv_to_json path preserves binary as an array of u8.
1090        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &buf[..]).unwrap();
1091        let json_val = rmpv_to_json(rmpv_val);
1092
1093        // The tagged enum fields decode fine.
1094        assert_eq!(json_val["type"], "alpha");
1095        assert_eq!(json_val["value"], "hello");
1096
1097        // Binary preserved as array of byte values.
1098        let payload = json_val["payload"].as_array().unwrap();
1099        let bytes: Vec<u8> = payload.iter().map(|v| v.as_u64().unwrap() as u8).collect();
1100        assert_eq!(bytes, raw_bytes);
1101    }
1102
1103    #[test]
1104    fn rmpv_to_json_handles_nil_and_bool() {
1105        assert_eq!(rmpv_to_json(rmpv::Value::Nil), json!(null));
1106        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(true)), json!(true));
1107        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(false)), json!(false));
1108    }
1109
1110    // -- check_msgpack_depth --
1111
1112    #[test]
1113    fn msgpack_depth_check_accepts_flat_map() {
1114        let val = json!({"a": 1, "b": "hello", "c": true});
1115        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1116        assert!(check_msgpack_depth(&bytes, 128).is_ok());
1117    }
1118
1119    #[test]
1120    fn msgpack_depth_check_accepts_nested_within_limit() {
1121        // 3 levels: {"outer": {"middle": {"inner": 42}}}
1122        let val = json!({"outer": {"middle": {"inner": 42}}});
1123        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1124        assert!(check_msgpack_depth(&bytes, 3).is_ok());
1125    }
1126
1127    #[test]
1128    fn msgpack_depth_check_rejects_beyond_limit() {
1129        // 3 nested maps exceeds a limit of 2
1130        let val = json!({"a": {"b": {"c": 1}}});
1131        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1132        assert!(check_msgpack_depth(&bytes, 2).is_err());
1133    }
1134
1135    #[test]
1136    fn msgpack_depth_check_accepts_flat_array() {
1137        let val = json!([1, 2, 3, 4, 5]);
1138        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1139        assert!(check_msgpack_depth(&bytes, 1).is_ok());
1140    }
1141
1142    #[test]
1143    fn msgpack_depth_check_nested_arrays() {
1144        let val = json!([[[42]]]);
1145        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1146        assert!(check_msgpack_depth(&bytes, 3).is_ok());
1147        assert!(check_msgpack_depth(&bytes, 2).is_err());
1148    }
1149
1150    #[test]
1151    fn msgpack_depth_check_mixed_containers() {
1152        let val = json!({"list": [{"nested": true}]});
1153        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1154        // depth: map(1) -> array(2) -> map(3) = 3 levels
1155        assert!(check_msgpack_depth(&bytes, 3).is_ok());
1156        assert!(check_msgpack_depth(&bytes, 2).is_err());
1157    }
1158
1159    #[test]
1160    fn msgpack_depth_check_empty_containers() {
1161        let val = json!({"empty_map": {}, "empty_arr": []});
1162        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1163        assert!(check_msgpack_depth(&bytes, 2).is_ok());
1164    }
1165
1166    #[test]
1167    fn msgpack_depth_check_sibling_arrays_dont_add_depth() {
1168        // [[1,2], [3,4]] has depth 2 (outer array -> inner array), not 3
1169        let val = json!([[1, 2], [3, 4]]);
1170        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1171        assert!(check_msgpack_depth(&bytes, 2).is_ok());
1172    }
1173
1174    #[test]
1175    fn msgpack_depth_check_binary_data() {
1176        use rmpv::Value as V;
1177        let val = V::Map(vec![(
1178            V::String("data".into()),
1179            V::Binary(vec![0xDE, 0xAD]),
1180        )]);
1181        let mut bytes = Vec::new();
1182        rmpv::encode::write_value(&mut bytes, &val).unwrap();
1183        assert!(check_msgpack_depth(&bytes, 1).is_ok());
1184    }
1185
1186    #[test]
1187    fn msgpack_depth_check_deeply_nested_rejects() {
1188        // Build a deeply nested msgpack: {a: {a: {a: ... {a: 1} ...}}}
1189        use rmpv::Value as V;
1190        let depth = 200;
1191        let mut val = V::Integer(1.into());
1192        for _ in 0..depth {
1193            val = V::Map(vec![(V::String("a".into()), val)]);
1194        }
1195        let mut bytes = Vec::new();
1196        rmpv::encode::write_value(&mut bytes, &val).unwrap();
1197
1198        assert!(check_msgpack_depth(&bytes, 128).is_err());
1199        assert!(check_msgpack_depth(&bytes, 200).is_ok());
1200    }
1201
1202    #[test]
1203    fn msgpack_decode_rejects_deeply_nested() {
1204        // Verify the full decode path rejects deeply nested payloads.
1205        use rmpv::Value as V;
1206        let mut val = V::Integer(1.into());
1207        for _ in 0..200 {
1208            val = V::Map(vec![(V::String("a".into()), val)]);
1209        }
1210        let mut bytes = Vec::new();
1211        rmpv::encode::write_value(&mut bytes, &val).unwrap();
1212
1213        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
1214        assert!(result.is_err());
1215        assert!(result.unwrap_err().contains("depth"));
1216    }
1217
1218    #[test]
1219    fn msgpack_depth_check_truncated_payload_does_not_panic() {
1220        // Truncated payloads must not panic. They may return Ok (for
1221        // scalars or truncated length fields) or Err (for containers
1222        // whose declared count exceeds remaining bytes).
1223        let val = json!({"a": {"b": [1, 2, 3]}});
1224        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1225        for cut in [1, 3, 5, bytes.len() / 2] {
1226            let _ = check_msgpack_depth(&bytes[..cut], 128);
1227        }
1228        // Truncated containers: declared children > 0 remaining bytes
1229        assert!(check_msgpack_depth(&[0x81], 128).is_err()); // fixmap(1): 2 children, 0 bytes
1230        assert!(check_msgpack_depth(&[0x91], 128).is_err()); // fixarray(1): 1 child, 0 bytes
1231        // Truncated length fields: loop breaks before parsing children
1232        assert!(check_msgpack_depth(&[0xdc], 128).is_ok()); // array16, no length bytes
1233        assert!(check_msgpack_depth(&[0xde, 0x00], 128).is_ok()); // map16, partial length
1234    }
1235
1236    #[test]
1237    fn msgpack_depth_check_empty_input() {
1238        assert!(check_msgpack_depth(&[], 128).is_ok());
1239    }
1240
1241    #[test]
1242    fn msgpack_depth_check_scalars_only() {
1243        // Pure scalar value (no containers) should always pass.
1244        let val = json!(42);
1245        let bytes = rmp_serde::to_vec_named(&val).unwrap();
1246        assert!(check_msgpack_depth(&bytes, 0).is_ok());
1247    }
1248
1249    #[test]
1250    fn msgpack_depth_check_rejects_forged_element_count() {
1251        // map32 declaring 2^32-1 entries but only a few bytes of actual
1252        // data. Without the element count check, rmpv::read_value would
1253        // try Vec::with_capacity(4 billion) and OOM.
1254        let mut bytes = vec![0xdf]; // map32 marker
1255        bytes.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes()); // 4 billion entries
1256        bytes.extend_from_slice(&[0xa1, b'k', 0x01]); // one tiny key-value pair
1257
1258        let result = check_msgpack_depth(&bytes, 128);
1259        assert!(result.is_err());
1260        assert!(result.unwrap_err().contains("elements"));
1261    }
1262
1263    #[test]
1264    fn msgpack_decode_rejects_forged_element_count() {
1265        // Verify the full decode path rejects forged counts.
1266        let mut bytes = vec![0xdd]; // array32 marker
1267        bytes.extend_from_slice(&0x7FFF_FFFFu32.to_be_bytes()); // 2 billion entries
1268        bytes.push(0x01); // one element
1269
1270        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
1271        assert!(result.is_err());
1272        assert!(result.unwrap_err().contains("elements"));
1273    }
1274
1275    // -- json_to_rmpv ---------------------------------------------------------
1276
1277    #[test]
1278    fn json_to_rmpv_scalars() {
1279        assert_eq!(json_to_rmpv(json!(null)), rmpv::Value::Nil);
1280        assert_eq!(json_to_rmpv(json!(true)), rmpv::Value::Boolean(true));
1281        assert_eq!(json_to_rmpv(json!(42)), rmpv::Value::Integer(42.into()));
1282        assert_eq!(json_to_rmpv(json!(2.5)), rmpv::Value::F64(2.5));
1283        assert_eq!(
1284            json_to_rmpv(json!("hello")),
1285            rmpv::Value::String("hello".into())
1286        );
1287    }
1288
1289    #[test]
1290    fn json_to_rmpv_nested() {
1291        let val = json!({"key": [1, "two", null]});
1292        let rmpv = json_to_rmpv(val);
1293        match rmpv {
1294            rmpv::Value::Map(entries) => {
1295                assert_eq!(entries.len(), 1);
1296                let (k, v) = &entries[0];
1297                assert_eq!(k, &rmpv::Value::String("key".into()));
1298                match v {
1299                    rmpv::Value::Array(arr) => {
1300                        assert_eq!(arr.len(), 3);
1301                        assert_eq!(arr[0], rmpv::Value::Integer(1.into()));
1302                        assert_eq!(arr[2], rmpv::Value::Nil);
1303                    }
1304                    other => panic!("expected array, got {other:?}"),
1305                }
1306            }
1307            other => panic!("expected map, got {other:?}"),
1308        }
1309    }
1310
1311    // -- encode_binary_message ------------------------------------------------
1312
1313    #[test]
1314    fn encode_binary_message_json_without_binary() {
1315        let mut map = serde_json::Map::new();
1316        map.insert("type".to_string(), json!("test"));
1317        map.insert("id".to_string(), json!("t1"));
1318
1319        let bytes = Codec::Json.encode_binary_message(map, None).unwrap();
1320        let s = std::str::from_utf8(&bytes).unwrap();
1321        assert!(s.ends_with('\n'));
1322        let parsed: serde_json::Value = serde_json::from_str(s.trim()).unwrap();
1323        assert_eq!(parsed["type"], "test");
1324        assert_eq!(parsed["id"], "t1");
1325        assert!(parsed.get("rgba").is_none());
1326    }
1327
1328    #[test]
1329    fn encode_binary_message_json_with_binary() {
1330        use base64::Engine as _;
1331
1332        let mut map = serde_json::Map::new();
1333        map.insert("type".to_string(), json!("screenshot"));
1334        let pixel_data = vec![255u8, 0, 128, 64];
1335
1336        let bytes = Codec::Json
1337            .encode_binary_message(map, Some(("rgba", &pixel_data)))
1338            .unwrap();
1339        let parsed: serde_json::Value = serde_json::from_slice(&bytes[..bytes.len() - 1]).unwrap();
1340        let b64 = parsed["rgba"].as_str().unwrap();
1341        let decoded = base64::engine::general_purpose::STANDARD
1342            .decode(b64)
1343            .unwrap();
1344        assert_eq!(decoded, pixel_data);
1345    }
1346
1347    #[test]
1348    fn encode_binary_message_msgpack_with_binary() {
1349        let mut map = serde_json::Map::new();
1350        map.insert("type".to_string(), json!("screenshot"));
1351        map.insert("id".to_string(), json!("s1"));
1352        let pixel_data = vec![0xDE, 0xAD, 0xBE, 0xEF];
1353
1354        let bytes = Codec::MsgPack
1355            .encode_binary_message(map, Some(("rgba", &pixel_data)))
1356            .unwrap();
1357
1358        // Strip 4-byte length prefix
1359        let payload = &bytes[4..];
1360        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &payload[..]).unwrap();
1361
1362        // Find the rgba field -- should be native Binary, not a string
1363        match rmpv_val {
1364            rmpv::Value::Map(entries) => {
1365                let rgba_entry = entries
1366                    .iter()
1367                    .find(|(k, _)| k == &rmpv::Value::String("rgba".into()));
1368                match rgba_entry {
1369                    Some((_, rmpv::Value::Binary(data))) => {
1370                        assert_eq!(data, &pixel_data);
1371                    }
1372                    other => panic!("expected Binary rgba field, got {other:?}"),
1373                }
1374            }
1375            other => panic!("expected Map, got {other:?}"),
1376        }
1377    }
1378
1379    #[test]
1380    fn encode_binary_message_msgpack_roundtrip_non_binary_fields() {
1381        let mut map = serde_json::Map::new();
1382        map.insert("type".to_string(), json!("test"));
1383        map.insert("count".to_string(), json!(42));
1384        map.insert("nested".to_string(), json!({"a": [1, 2]}));
1385
1386        let bytes = Codec::MsgPack.encode_binary_message(map, None).unwrap();
1387        let decoded: serde_json::Value = Codec::MsgPack.decode(&bytes[4..]).unwrap();
1388        assert_eq!(decoded["type"], "test");
1389        assert_eq!(decoded["count"], 42);
1390        assert_eq!(decoded["nested"]["a"][0], 1);
1391    }
1392}