Skip to main content

rpc_runtime_codec_msgpack/
lib.rs

1use std::io::Cursor;
2
3use rmpv::{Integer, Value, decode, encode};
4use rpc_runtime_core::{
5    CapabilityFlags, Envelope, Goodbye, Hello, HelloAck, InstanceId, MessageKind, MethodId,
6    Notification, NotificationId, Options, Request, RequestId, ResponseError, ResponseOk, Role,
7    ServiceGuid,
8};
9use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
10use thiserror::Error;
11
/// Default upper bound, in bytes, on one encoded envelope accepted by
/// [`decode_envelope`] (16 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;

/// Limits applied while decoding envelope bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CodecLimits {
    /// Largest frame, in bytes, that [`decode_envelope`] will parse; longer
    /// inputs are rejected before any MessagePack decoding happens.
    pub max_message_size: usize,
}

impl Default for CodecLimits {
    /// Uses [`DEFAULT_MAX_MESSAGE_SIZE`] as the frame limit.
    fn default() -> Self {
        let max_message_size = DEFAULT_MAX_MESSAGE_SIZE;
        CodecLimits { max_message_size }
    }
}
26
27#[derive(Debug, Error)]
28#[error("{error}")]
29pub struct CodecError {
30    pub error: RuntimeError,
31}
32
33impl CodecError {
34    fn protocol(code: RuntimeErrorCode, message: impl Into<String>) -> Self {
35        Self {
36            error: RuntimeError::protocol(code, message),
37        }
38    }
39
40    pub fn into_runtime_error(self) -> RuntimeError {
41        self.error
42    }
43}
44
45pub fn encode_envelope(envelope: &Envelope) -> Result<Vec<u8>, CodecError> {
46    let value = envelope_to_value(envelope);
47    let mut bytes = Vec::new();
48    encode::write_value(&mut bytes, &value).map_err(|err| {
49        CodecError::protocol(
50            RuntimeErrorCode::PayloadEncodeFailed,
51            format!("failed to encode MessagePack envelope: {err}"),
52        )
53    })?;
54    Ok(bytes)
55}
56
57pub fn decode_envelope(bytes: &[u8], limits: CodecLimits) -> Result<Envelope, CodecError> {
58    if bytes.len() > limits.max_message_size {
59        return Err(CodecError::protocol(
60            RuntimeErrorCode::InvalidEnvelope,
61            format!(
62                "frame size {} exceeds limit {}",
63                bytes.len(),
64                limits.max_message_size
65            ),
66        ));
67    }
68
69    let mut cursor = Cursor::new(bytes);
70    let value = decode::read_value(&mut cursor).map_err(|err| {
71        CodecError::protocol(
72            RuntimeErrorCode::PayloadDecodeFailed,
73            format!("failed to decode MessagePack envelope: {err}"),
74        )
75    })?;
76    if cursor.position() != bytes.len() as u64 {
77        return Err(CodecError::protocol(
78            RuntimeErrorCode::InvalidEnvelope,
79            "trailing bytes after top-level envelope",
80        ));
81    }
82    value_to_envelope(value)
83}
84
85pub fn encode_service_guid(guid: ServiceGuid) -> Value {
86    Value::Binary(guid.get().as_bytes().to_vec())
87}
88
89pub fn decode_service_guid(value: &Value) -> Result<ServiceGuid, CodecError> {
90    match value {
91        Value::Binary(bytes) if bytes.len() == 16 => {
92            let uuid = uuid::Uuid::from_slice(bytes).map_err(|err| {
93                CodecError::protocol(
94                    RuntimeErrorCode::InvalidEnvelope,
95                    format!("invalid service GUID bytes: {err}"),
96                )
97            })?;
98            Ok(ServiceGuid::new(uuid))
99        }
100        Value::Binary(bytes) => Err(CodecError::protocol(
101            RuntimeErrorCode::InvalidEnvelope,
102            format!("service GUID must be exactly 16 bytes, got {}", bytes.len()),
103        )),
104        _ => Err(CodecError::protocol(
105            RuntimeErrorCode::InvalidEnvelope,
106            "service GUID must be MessagePack bin(16)",
107        )),
108    }
109}
110
111fn envelope_to_value(envelope: &Envelope) -> Value {
112    match envelope {
113        Envelope::Hello(message) => Value::Array(vec![
114            u8_value(MessageKind::Hello.as_u8()),
115            u64_value(message.protocol_version as u64),
116            u8_value(message.role.as_u8()),
117            u64_value(message.capability_bits.bits()),
118            u64_value(message.max_message_size),
119            options_to_value(&message.options),
120        ]),
121        Envelope::HelloAck(message) => Value::Array(vec![
122            u8_value(MessageKind::HelloAck.as_u8()),
123            u64_value(message.protocol_version as u64),
124            u64_value(message.accepted_capability_bits.bits()),
125            u64_value(message.max_message_size),
126            options_to_value(&message.options),
127        ]),
128        Envelope::Request(message) => Value::Array(vec![
129            u8_value(MessageKind::Request.as_u8()),
130            u64_value(message.request_id.get()),
131            u64_value(message.instance_id.get()),
132            u64_value(message.method_id.get() as u64),
133            message.payload.clone(),
134        ]),
135        Envelope::ResponseOk(message) => Value::Array(vec![
136            u8_value(MessageKind::ResponseOk.as_u8()),
137            u64_value(message.request_id.get()),
138            message.payload.clone(),
139        ]),
140        Envelope::ResponseError(message) => Value::Array(vec![
141            u8_value(MessageKind::ResponseError.as_u8()),
142            u64_value(message.request_id.get()),
143            i64_value(message.error_code as i64),
144            u8_value(message.error_kind),
145            string_option_to_value(message.error_message.as_deref()),
146            message.error_details.clone(),
147        ]),
148        Envelope::Notification(message) => Value::Array(vec![
149            u8_value(MessageKind::Notification.as_u8()),
150            u64_value(message.instance_id.map_or(0, InstanceId::get)),
151            u64_value(message.notification_id.get() as u64),
152            message.payload.clone(),
153        ]),
154        Envelope::Goodbye(message) => Value::Array(vec![
155            u8_value(MessageKind::Goodbye.as_u8()),
156            u64_value(message.reason_code as u64),
157            string_option_to_value(message.message.as_deref()),
158        ]),
159    }
160}
161
162fn value_to_envelope(value: Value) -> Result<Envelope, CodecError> {
163    let fields = match value {
164        Value::Array(fields) => fields,
165        _ => {
166            return Err(CodecError::protocol(
167                RuntimeErrorCode::InvalidEnvelope,
168                "top-level envelope must be a MessagePack array",
169            ));
170        }
171    };
172    let kind = required_u8(
173        fields.first(),
174        "message_kind",
175        RuntimeErrorCode::UnknownMessageKind,
176    )?;
177    match kind {
178        1 => decode_hello(fields),
179        2 => decode_hello_ack(fields),
180        3 => decode_request(fields),
181        4 => decode_response_ok(fields),
182        5 => decode_response_error(fields),
183        6 => decode_notification(fields),
184        7 => Err(CodecError::protocol(
185            RuntimeErrorCode::RequestCancelUnsupported,
186            "CANCEL is reserved and unsupported in v1",
187        )),
188        8 => decode_goodbye(fields),
189        other => Err(CodecError::protocol(
190            RuntimeErrorCode::UnknownMessageKind,
191            format!("unknown message kind `{other}`"),
192        )),
193    }
194}
195
196fn decode_hello(fields: Vec<Value>) -> Result<Envelope, CodecError> {
197    exact_len(&fields, 6)?;
198    Ok(Envelope::Hello(Hello {
199        protocol_version: required_u32(fields.get(1), "protocol_version")?,
200        role: role(required_u8(
201            fields.get(2),
202            "role",
203            RuntimeErrorCode::InvalidEnvelope,
204        )?)?,
205        capability_bits: CapabilityFlags::from_bits_retain(required_u64(
206            fields.get(3),
207            "capability_bits",
208            RuntimeErrorCode::InvalidEnvelope,
209        )?),
210        max_message_size: required_u64(
211            fields.get(4),
212            "max_message_size",
213            RuntimeErrorCode::InvalidEnvelope,
214        )?,
215        options: required_options(fields.get(5), "options")?,
216    }))
217}
218
219fn decode_hello_ack(fields: Vec<Value>) -> Result<Envelope, CodecError> {
220    exact_len(&fields, 5)?;
221    Ok(Envelope::HelloAck(HelloAck {
222        protocol_version: required_u32(fields.get(1), "protocol_version")?,
223        accepted_capability_bits: CapabilityFlags::from_bits_retain(required_u64(
224            fields.get(2),
225            "accepted_capability_bits",
226            RuntimeErrorCode::InvalidEnvelope,
227        )?),
228        max_message_size: required_u64(
229            fields.get(3),
230            "max_message_size",
231            RuntimeErrorCode::InvalidEnvelope,
232        )?,
233        options: required_options(fields.get(4), "options")?,
234    }))
235}
236
237fn decode_request(fields: Vec<Value>) -> Result<Envelope, CodecError> {
238    exact_len(&fields, 5)?;
239    let instance_id = required_instance_id(fields.get(2))?;
240    Ok(Envelope::Request(Request {
241        request_id: RequestId::new(required_u64(
242            fields.get(1),
243            "request_id",
244            RuntimeErrorCode::InvalidRequestId,
245        )?),
246        instance_id,
247        method_id: MethodId::new(required_u32(fields.get(3), "method_id")?),
248        payload: fields[4].clone(),
249    }))
250}
251
252fn decode_response_ok(fields: Vec<Value>) -> Result<Envelope, CodecError> {
253    exact_len(&fields, 3)?;
254    Ok(Envelope::ResponseOk(ResponseOk {
255        request_id: RequestId::new(required_u64(
256            fields.get(1),
257            "request_id",
258            RuntimeErrorCode::InvalidRequestId,
259        )?),
260        payload: fields[2].clone(),
261    }))
262}
263
264fn decode_response_error(fields: Vec<Value>) -> Result<Envelope, CodecError> {
265    exact_len(&fields, 6)?;
266    Ok(Envelope::ResponseError(ResponseError {
267        request_id: RequestId::new(required_u64(
268            fields.get(1),
269            "request_id",
270            RuntimeErrorCode::InvalidRequestId,
271        )?),
272        error_code: required_i32(fields.get(2), "error_code")?,
273        error_kind: required_u8(
274            fields.get(3),
275            "error_kind",
276            RuntimeErrorCode::InvalidEnvelope,
277        )?,
278        error_message: optional_string(fields.get(4), "error_message")?,
279        error_details: fields[5].clone(),
280    }))
281}
282
283fn decode_notification(fields: Vec<Value>) -> Result<Envelope, CodecError> {
284    exact_len(&fields, 4)?;
285    let raw_instance_id = required_u64(
286        fields.get(1),
287        "instance_id",
288        RuntimeErrorCode::InvalidInstanceId,
289    )?;
290    Ok(Envelope::Notification(Notification {
291        instance_id: if raw_instance_id == 0 {
292            None
293        } else {
294            Some(InstanceId::new(raw_instance_id).ok_or_else(|| {
295                CodecError::protocol(
296                    RuntimeErrorCode::InvalidInstanceId,
297                    "instance_id must be non-zero",
298                )
299            })?)
300        },
301        notification_id: NotificationId::new(required_u32(fields.get(2), "notification_id")?),
302        payload: fields[3].clone(),
303    }))
304}
305
306fn decode_goodbye(fields: Vec<Value>) -> Result<Envelope, CodecError> {
307    exact_len(&fields, 3)?;
308    Ok(Envelope::Goodbye(Goodbye {
309        reason_code: required_u32(fields.get(1), "reason_code")?,
310        message: optional_string(fields.get(2), "message")?,
311    }))
312}
313
314fn exact_len(fields: &[Value], expected: usize) -> Result<(), CodecError> {
315    if fields.len() == expected {
316        return Ok(());
317    }
318    Err(CodecError::protocol(
319        RuntimeErrorCode::InvalidEnvelope,
320        format!(
321            "invalid envelope field count: expected {expected}, got {}",
322            fields.len()
323        ),
324    ))
325}
326
327fn role(value: u8) -> Result<Role, CodecError> {
328    match value {
329        1 => Ok(Role::Client),
330        2 => Ok(Role::Server),
331        3 => Ok(Role::Peer),
332        other => Err(CodecError::protocol(
333            RuntimeErrorCode::InvalidEnvelope,
334            format!("invalid role `{other}`"),
335        )),
336    }
337}
338
339fn required_instance_id(value: Option<&Value>) -> Result<InstanceId, CodecError> {
340    let value = required_u64(value, "instance_id", RuntimeErrorCode::InvalidInstanceId)?;
341    InstanceId::new(value).ok_or_else(|| {
342        CodecError::protocol(
343            RuntimeErrorCode::InvalidInstanceId,
344            "request instance_id must be non-zero",
345        )
346    })
347}
348
349fn required_u64(
350    value: Option<&Value>,
351    field: &str,
352    code: RuntimeErrorCode,
353) -> Result<u64, CodecError> {
354    value.and_then(Value::as_u64).ok_or_else(|| {
355        CodecError::protocol(code, format!("field `{field}` must be an unsigned integer"))
356    })
357}
358
359fn required_u32(value: Option<&Value>, field: &str) -> Result<u32, CodecError> {
360    let value = required_u64(value, field, RuntimeErrorCode::InvalidEnvelope)?;
361    u32::try_from(value).map_err(|_| {
362        CodecError::protocol(
363            RuntimeErrorCode::InvalidEnvelope,
364            format!("field `{field}` exceeds u32 range"),
365        )
366    })
367}
368
369fn required_u8(
370    value: Option<&Value>,
371    field: &str,
372    code: RuntimeErrorCode,
373) -> Result<u8, CodecError> {
374    let value = required_u64(value, field, code)?;
375    u8::try_from(value)
376        .map_err(|_| CodecError::protocol(code, format!("field `{field}` exceeds u8 range")))
377}
378
379fn required_i32(value: Option<&Value>, field: &str) -> Result<i32, CodecError> {
380    let value = value.and_then(Value::as_i64).ok_or_else(|| {
381        CodecError::protocol(
382            RuntimeErrorCode::InvalidEnvelope,
383            format!("field `{field}` must be a signed integer"),
384        )
385    })?;
386    i32::try_from(value).map_err(|_| {
387        CodecError::protocol(
388            RuntimeErrorCode::InvalidEnvelope,
389            format!("field `{field}` exceeds i32 range"),
390        )
391    })
392}
393
394fn optional_string(value: Option<&Value>, field: &str) -> Result<Option<String>, CodecError> {
395    match value {
396        Some(Value::Nil) => Ok(None),
397        Some(Value::String(value)) => value
398            .as_str()
399            .map(|value| Some(value.to_string()))
400            .ok_or_else(|| {
401                CodecError::protocol(
402                    RuntimeErrorCode::InvalidEnvelope,
403                    format!("field `{field}` must contain valid UTF-8"),
404                )
405            }),
406        _ => Err(CodecError::protocol(
407            RuntimeErrorCode::InvalidEnvelope,
408            format!("field `{field}` must be string or nil"),
409        )),
410    }
411}
412
413fn required_options(value: Option<&Value>, field: &str) -> Result<Options, CodecError> {
414    match value {
415        Some(Value::Nil) => Ok(Vec::new()),
416        Some(Value::Map(entries)) => entries
417            .iter()
418            .map(|(key, value)| {
419                let key = key.as_str().ok_or_else(|| {
420                    CodecError::protocol(
421                        RuntimeErrorCode::InvalidEnvelope,
422                        format!("field `{field}` option key must be string"),
423                    )
424                })?;
425                if !is_option_scalar(value) {
426                    return Err(CodecError::protocol(
427                        RuntimeErrorCode::InvalidEnvelope,
428                        format!("field `{field}` option value must be scalar"),
429                    ));
430                }
431                Ok((key.to_string(), value.clone()))
432            })
433            .collect(),
434        _ => Err(CodecError::protocol(
435            RuntimeErrorCode::InvalidEnvelope,
436            format!("field `{field}` must be map or nil"),
437        )),
438    }
439}
440
441fn is_option_scalar(value: &Value) -> bool {
442    matches!(
443        value,
444        Value::Nil | Value::Boolean(_) | Value::Integer(_) | Value::String(_) | Value::Binary(_)
445    )
446}
447
448fn options_to_value(options: &Options) -> Value {
449    if options.is_empty() {
450        return Value::Nil;
451    }
452    Value::Map(
453        options
454            .iter()
455            .map(|(key, value)| (Value::from(key.as_str()), value.clone()))
456            .collect(),
457    )
458}
459
460fn string_option_to_value(value: Option<&str>) -> Value {
461    value.map_or(Value::Nil, Value::from)
462}
463
464fn u8_value(value: u8) -> Value {
465    u64_value(value as u64)
466}
467
468fn u64_value(value: u64) -> Value {
469    Value::Integer(Integer::from(value))
470}
471
472fn i64_value(value: i64) -> Value {
473    Value::Integer(Integer::from(value))
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{CapabilityFlags, RUNTIME_PROTOCOL_VERSION};
    use uuid::Uuid;

    /// Encodes `envelope`, decodes it with default limits and asserts equality.
    fn roundtrip(envelope: Envelope) {
        let bytes = encode_envelope(&envelope).expect("encode envelope");
        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode envelope");
        assert_eq!(decoded, envelope);
    }

    #[test]
    fn hello_roundtrips() {
        roundtrip(Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
                | CapabilityFlags::GOODBYE,
            max_message_size: 4096,
            options: vec![("implementation".to_string(), Value::from("test"))],
        }));
    }

    #[test]
    fn hello_ack_roundtrips() {
        roundtrip(Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::SERVICE_ACTIVATION,
            max_message_size: 8192,
            options: Vec::new(),
        }));
    }

    #[test]
    fn request_roundtrips() {
        roundtrip(Envelope::Request(Request {
            request_id: RequestId::new(10),
            instance_id: InstanceId::new(22).expect("non-zero instance id"),
            method_id: MethodId::new(3),
            payload: Value::Array(vec![Value::from("card")]),
        }));
    }

    #[test]
    fn response_ok_roundtrips() {
        roundtrip(Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(10),
            payload: Value::Nil,
        }));
    }

    #[test]
    fn response_error_roundtrips() {
        roundtrip(Envelope::ResponseError(ResponseError {
            request_id: RequestId::new(10),
            error_code: 1007,
            error_kind: 3,
            error_message: Some("method missing".to_string()),
            error_details: Value::Nil,
        }));
    }

    #[test]
    fn response_error_roundtrips_with_negative_code_and_details() {
        roundtrip(Envelope::ResponseError(ResponseError {
            request_id: RequestId::new(11),
            error_code: -5,
            error_kind: 1,
            error_message: None,
            error_details: Value::Map(vec![(Value::from("reason"), Value::from("detail"))]),
        }));
    }

    #[test]
    fn notification_roundtrips_with_global_instance() {
        roundtrip(Envelope::Notification(Notification {
            instance_id: None,
            notification_id: NotificationId::new(4),
            payload: Value::from(true),
        }));
    }

    #[test]
    fn notification_roundtrips_with_instance() {
        roundtrip(Envelope::Notification(Notification {
            instance_id: Some(InstanceId::new(5).expect("non-zero instance id")),
            notification_id: NotificationId::new(9),
            payload: Value::Array(vec![Value::from("event"), Value::Nil]),
        }));
    }

    #[test]
    fn goodbye_roundtrips() {
        roundtrip(Envelope::Goodbye(Goodbye {
            reason_code: 1,
            message: Some("shutdown".to_string()),
        }));
    }

    #[test]
    fn spec_request_shape_decodes() {
        let value = Value::Array(vec![
            Value::from(3),
            Value::from(42),
            Value::from(7),
            Value::from(2),
            Value::Nil,
        ]);
        let mut bytes = Vec::new();
        encode::write_value(&mut bytes, &value).expect("encode raw shape");

        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode request");
        assert_eq!(
            decoded,
            Envelope::Request(Request {
                request_id: RequestId::new(42),
                instance_id: InstanceId::new(7).expect("non-zero instance id"),
                method_id: MethodId::new(2),
                payload: Value::Nil,
            })
        );
    }

    #[test]
    fn non_array_fails() {
        let bytes = encode_raw(&Value::from("not-an-array"));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn empty_frame_fails_to_decode() {
        let err = decode_envelope(&[], CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::PayloadDecodeFailed);
    }

    #[test]
    fn empty_array_fails_as_unknown_kind() {
        let bytes = encode_raw(&Value::Array(Vec::new()));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::UnknownMessageKind);
    }

    #[test]
    fn unknown_kind_fails() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(99)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::UnknownMessageKind);
    }

    #[test]
    fn wrong_field_count_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(4),
            Value::from(1),
            Value::Nil,
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn trailing_bytes_fail() {
        let mut bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        // A stray MessagePack nil after the complete envelope.
        bytes.push(0xc0);
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn request_instance_zero_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from(1),
            Value::from(0),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidInstanceId);
    }

    #[test]
    fn scalar_type_mismatch_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from("bad-request-id"),
            Value::from(1),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidRequestId);
    }

    #[test]
    fn non_scalar_option_value_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(1),
            Value::from(1),
            Value::from(1),
            Value::from(0),
            Value::from(4096),
            Value::Map(vec![(Value::from("nested"), Value::Array(Vec::new()))]),
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn oversized_frame_fails() {
        let bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        let err = decode_envelope(
            &bytes,
            CodecLimits {
                max_message_size: bytes.len() - 1,
            },
        )
        .expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn service_guid_uses_raw_16_byte_binary_encoding() {
        let uuid = Uuid::parse_str("d7d0c9a0-2eb4-4d2d-a3f9-7a4b2875b7e1").expect("uuid");
        let guid = ServiceGuid::new(uuid);

        let encoded = encode_service_guid(guid);
        assert_eq!(encoded, Value::Binary(uuid.as_bytes().to_vec()));

        let decoded = decode_service_guid(&encoded).expect("decode guid");
        assert_eq!(decoded, guid);
    }

    #[test]
    fn service_guid_rejects_wrong_binary_length() {
        let err = decode_service_guid(&Value::Binary(vec![1, 2, 3])).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn cancel_fails_as_unsupported() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(7)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::RequestCancelUnsupported);
    }

    #[test]
    fn performance_smoke_for_request_codec_path() {
        let envelope = Envelope::Request(Request {
            request_id: RequestId::new(1),
            instance_id: InstanceId::new(1).expect("non-zero instance id"),
            method_id: MethodId::new(1),
            payload: Value::Nil,
        });
        for _ in 0..10_000 {
            let bytes = encode_envelope(&envelope).expect("encode");
            let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode");
            assert!(matches!(decoded, Envelope::Request(_)));
        }
    }

    /// Serializes an arbitrary `Value` without going through `envelope_to_value`.
    fn encode_raw(value: &Value) -> Vec<u8> {
        let mut bytes = Vec::new();
        encode::write_value(&mut bytes, value).expect("encode raw value");
        bytes
    }
}