1use std::io::Cursor;
2
3use rmpv::{Integer, Value, decode, encode};
4use rpc_runtime_core::{
5 CapabilityFlags, Envelope, Goodbye, Hello, HelloAck, InstanceId, MessageKind, MethodId,
6 Notification, NotificationId, Options, Request, RequestId, ResponseError, ResponseOk, Role,
7 ServiceGuid,
8};
9use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
10use thiserror::Error;
11
/// Default upper bound, in bytes, on the size of an encoded frame (16 MiB).
///
/// This is the `max_message_size` used by [`CodecLimits::default`].
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;
13
/// Size limits the codec enforces while decoding frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CodecLimits {
    /// Largest frame, in bytes, that [`decode_envelope`] will accept.
    pub max_message_size: usize,
}
18
19impl Default for CodecLimits {
20 fn default() -> Self {
21 Self {
22 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
23 }
24 }
25}
26
27#[derive(Debug, Error)]
28#[error("{error}")]
29pub struct CodecError {
30 pub error: RuntimeError,
31}
32
33impl CodecError {
34 fn protocol(code: RuntimeErrorCode, message: impl Into<String>) -> Self {
35 Self {
36 error: RuntimeError::protocol(code, message),
37 }
38 }
39
40 pub fn into_runtime_error(self) -> RuntimeError {
41 self.error
42 }
43}
44
45pub fn encode_envelope(envelope: &Envelope) -> Result<Vec<u8>, CodecError> {
46 let value = envelope_to_value(envelope);
47 let mut bytes = Vec::new();
48 encode::write_value(&mut bytes, &value).map_err(|err| {
49 CodecError::protocol(
50 RuntimeErrorCode::PayloadEncodeFailed,
51 format!("failed to encode MessagePack envelope: {err}"),
52 )
53 })?;
54 Ok(bytes)
55}
56
57pub fn decode_envelope(bytes: &[u8], limits: CodecLimits) -> Result<Envelope, CodecError> {
58 if bytes.len() > limits.max_message_size {
59 return Err(CodecError::protocol(
60 RuntimeErrorCode::InvalidEnvelope,
61 format!(
62 "frame size {} exceeds limit {}",
63 bytes.len(),
64 limits.max_message_size
65 ),
66 ));
67 }
68
69 let mut cursor = Cursor::new(bytes);
70 let value = decode::read_value(&mut cursor).map_err(|err| {
71 CodecError::protocol(
72 RuntimeErrorCode::PayloadDecodeFailed,
73 format!("failed to decode MessagePack envelope: {err}"),
74 )
75 })?;
76 if cursor.position() != bytes.len() as u64 {
77 return Err(CodecError::protocol(
78 RuntimeErrorCode::InvalidEnvelope,
79 "trailing bytes after top-level envelope",
80 ));
81 }
82 value_to_envelope(value)
83}
84
85pub fn encode_service_guid(guid: ServiceGuid) -> Value {
86 Value::Binary(guid.get().as_bytes().to_vec())
87}
88
89pub fn decode_service_guid(value: &Value) -> Result<ServiceGuid, CodecError> {
90 match value {
91 Value::Binary(bytes) if bytes.len() == 16 => {
92 let uuid = uuid::Uuid::from_slice(bytes).map_err(|err| {
93 CodecError::protocol(
94 RuntimeErrorCode::InvalidEnvelope,
95 format!("invalid service GUID bytes: {err}"),
96 )
97 })?;
98 Ok(ServiceGuid::new(uuid))
99 }
100 Value::Binary(bytes) => Err(CodecError::protocol(
101 RuntimeErrorCode::InvalidEnvelope,
102 format!("service GUID must be exactly 16 bytes, got {}", bytes.len()),
103 )),
104 _ => Err(CodecError::protocol(
105 RuntimeErrorCode::InvalidEnvelope,
106 "service GUID must be MessagePack bin(16)",
107 )),
108 }
109}
110
111fn envelope_to_value(envelope: &Envelope) -> Value {
112 match envelope {
113 Envelope::Hello(message) => Value::Array(vec![
114 u8_value(MessageKind::Hello.as_u8()),
115 u64_value(message.protocol_version as u64),
116 u8_value(message.role.as_u8()),
117 u64_value(message.capability_bits.bits()),
118 u64_value(message.max_message_size),
119 options_to_value(&message.options),
120 ]),
121 Envelope::HelloAck(message) => Value::Array(vec![
122 u8_value(MessageKind::HelloAck.as_u8()),
123 u64_value(message.protocol_version as u64),
124 u64_value(message.accepted_capability_bits.bits()),
125 u64_value(message.max_message_size),
126 options_to_value(&message.options),
127 ]),
128 Envelope::Request(message) => Value::Array(vec![
129 u8_value(MessageKind::Request.as_u8()),
130 u64_value(message.request_id.get()),
131 u64_value(message.instance_id.get()),
132 u64_value(message.method_id.get() as u64),
133 message.payload.clone(),
134 ]),
135 Envelope::ResponseOk(message) => Value::Array(vec![
136 u8_value(MessageKind::ResponseOk.as_u8()),
137 u64_value(message.request_id.get()),
138 message.payload.clone(),
139 ]),
140 Envelope::ResponseError(message) => Value::Array(vec![
141 u8_value(MessageKind::ResponseError.as_u8()),
142 u64_value(message.request_id.get()),
143 i64_value(message.error_code as i64),
144 u8_value(message.error_kind),
145 string_option_to_value(message.error_message.as_deref()),
146 message.error_details.clone(),
147 ]),
148 Envelope::Notification(message) => Value::Array(vec![
149 u8_value(MessageKind::Notification.as_u8()),
150 u64_value(message.instance_id.map_or(0, InstanceId::get)),
151 u64_value(message.notification_id.get() as u64),
152 message.payload.clone(),
153 ]),
154 Envelope::Goodbye(message) => Value::Array(vec![
155 u8_value(MessageKind::Goodbye.as_u8()),
156 u64_value(message.reason_code as u64),
157 string_option_to_value(message.message.as_deref()),
158 ]),
159 }
160}
161
162fn value_to_envelope(value: Value) -> Result<Envelope, CodecError> {
163 let fields = match value {
164 Value::Array(fields) => fields,
165 _ => {
166 return Err(CodecError::protocol(
167 RuntimeErrorCode::InvalidEnvelope,
168 "top-level envelope must be a MessagePack array",
169 ));
170 }
171 };
172 let kind = required_u8(
173 fields.first(),
174 "message_kind",
175 RuntimeErrorCode::UnknownMessageKind,
176 )?;
177 match kind {
178 1 => decode_hello(fields),
179 2 => decode_hello_ack(fields),
180 3 => decode_request(fields),
181 4 => decode_response_ok(fields),
182 5 => decode_response_error(fields),
183 6 => decode_notification(fields),
184 7 => Err(CodecError::protocol(
185 RuntimeErrorCode::RequestCancelUnsupported,
186 "CANCEL is reserved and unsupported in v1",
187 )),
188 8 => decode_goodbye(fields),
189 other => Err(CodecError::protocol(
190 RuntimeErrorCode::UnknownMessageKind,
191 format!("unknown message kind `{other}`"),
192 )),
193 }
194}
195
196fn decode_hello(fields: Vec<Value>) -> Result<Envelope, CodecError> {
197 exact_len(&fields, 6)?;
198 Ok(Envelope::Hello(Hello {
199 protocol_version: required_u32(fields.get(1), "protocol_version")?,
200 role: role(required_u8(
201 fields.get(2),
202 "role",
203 RuntimeErrorCode::InvalidEnvelope,
204 )?)?,
205 capability_bits: CapabilityFlags::from_bits_retain(required_u64(
206 fields.get(3),
207 "capability_bits",
208 RuntimeErrorCode::InvalidEnvelope,
209 )?),
210 max_message_size: required_u64(
211 fields.get(4),
212 "max_message_size",
213 RuntimeErrorCode::InvalidEnvelope,
214 )?,
215 options: required_options(fields.get(5), "options")?,
216 }))
217}
218
219fn decode_hello_ack(fields: Vec<Value>) -> Result<Envelope, CodecError> {
220 exact_len(&fields, 5)?;
221 Ok(Envelope::HelloAck(HelloAck {
222 protocol_version: required_u32(fields.get(1), "protocol_version")?,
223 accepted_capability_bits: CapabilityFlags::from_bits_retain(required_u64(
224 fields.get(2),
225 "accepted_capability_bits",
226 RuntimeErrorCode::InvalidEnvelope,
227 )?),
228 max_message_size: required_u64(
229 fields.get(3),
230 "max_message_size",
231 RuntimeErrorCode::InvalidEnvelope,
232 )?,
233 options: required_options(fields.get(4), "options")?,
234 }))
235}
236
237fn decode_request(fields: Vec<Value>) -> Result<Envelope, CodecError> {
238 exact_len(&fields, 5)?;
239 let instance_id = required_instance_id(fields.get(2))?;
240 Ok(Envelope::Request(Request {
241 request_id: RequestId::new(required_u64(
242 fields.get(1),
243 "request_id",
244 RuntimeErrorCode::InvalidRequestId,
245 )?),
246 instance_id,
247 method_id: MethodId::new(required_u32(fields.get(3), "method_id")?),
248 payload: fields[4].clone(),
249 }))
250}
251
252fn decode_response_ok(fields: Vec<Value>) -> Result<Envelope, CodecError> {
253 exact_len(&fields, 3)?;
254 Ok(Envelope::ResponseOk(ResponseOk {
255 request_id: RequestId::new(required_u64(
256 fields.get(1),
257 "request_id",
258 RuntimeErrorCode::InvalidRequestId,
259 )?),
260 payload: fields[2].clone(),
261 }))
262}
263
264fn decode_response_error(fields: Vec<Value>) -> Result<Envelope, CodecError> {
265 exact_len(&fields, 6)?;
266 Ok(Envelope::ResponseError(ResponseError {
267 request_id: RequestId::new(required_u64(
268 fields.get(1),
269 "request_id",
270 RuntimeErrorCode::InvalidRequestId,
271 )?),
272 error_code: required_i32(fields.get(2), "error_code")?,
273 error_kind: required_u8(
274 fields.get(3),
275 "error_kind",
276 RuntimeErrorCode::InvalidEnvelope,
277 )?,
278 error_message: optional_string(fields.get(4), "error_message")?,
279 error_details: fields[5].clone(),
280 }))
281}
282
283fn decode_notification(fields: Vec<Value>) -> Result<Envelope, CodecError> {
284 exact_len(&fields, 4)?;
285 let raw_instance_id = required_u64(
286 fields.get(1),
287 "instance_id",
288 RuntimeErrorCode::InvalidInstanceId,
289 )?;
290 Ok(Envelope::Notification(Notification {
291 instance_id: if raw_instance_id == 0 {
292 None
293 } else {
294 Some(InstanceId::new(raw_instance_id).ok_or_else(|| {
295 CodecError::protocol(
296 RuntimeErrorCode::InvalidInstanceId,
297 "instance_id must be non-zero",
298 )
299 })?)
300 },
301 notification_id: NotificationId::new(required_u32(fields.get(2), "notification_id")?),
302 payload: fields[3].clone(),
303 }))
304}
305
306fn decode_goodbye(fields: Vec<Value>) -> Result<Envelope, CodecError> {
307 exact_len(&fields, 3)?;
308 Ok(Envelope::Goodbye(Goodbye {
309 reason_code: required_u32(fields.get(1), "reason_code")?,
310 message: optional_string(fields.get(2), "message")?,
311 }))
312}
313
314fn exact_len(fields: &[Value], expected: usize) -> Result<(), CodecError> {
315 if fields.len() == expected {
316 return Ok(());
317 }
318 Err(CodecError::protocol(
319 RuntimeErrorCode::InvalidEnvelope,
320 format!(
321 "invalid envelope field count: expected {expected}, got {}",
322 fields.len()
323 ),
324 ))
325}
326
327fn role(value: u8) -> Result<Role, CodecError> {
328 match value {
329 1 => Ok(Role::Client),
330 2 => Ok(Role::Server),
331 3 => Ok(Role::Peer),
332 other => Err(CodecError::protocol(
333 RuntimeErrorCode::InvalidEnvelope,
334 format!("invalid role `{other}`"),
335 )),
336 }
337}
338
339fn required_instance_id(value: Option<&Value>) -> Result<InstanceId, CodecError> {
340 let value = required_u64(value, "instance_id", RuntimeErrorCode::InvalidInstanceId)?;
341 InstanceId::new(value).ok_or_else(|| {
342 CodecError::protocol(
343 RuntimeErrorCode::InvalidInstanceId,
344 "request instance_id must be non-zero",
345 )
346 })
347}
348
349fn required_u64(
350 value: Option<&Value>,
351 field: &str,
352 code: RuntimeErrorCode,
353) -> Result<u64, CodecError> {
354 value.and_then(Value::as_u64).ok_or_else(|| {
355 CodecError::protocol(code, format!("field `{field}` must be an unsigned integer"))
356 })
357}
358
359fn required_u32(value: Option<&Value>, field: &str) -> Result<u32, CodecError> {
360 let value = required_u64(value, field, RuntimeErrorCode::InvalidEnvelope)?;
361 u32::try_from(value).map_err(|_| {
362 CodecError::protocol(
363 RuntimeErrorCode::InvalidEnvelope,
364 format!("field `{field}` exceeds u32 range"),
365 )
366 })
367}
368
369fn required_u8(
370 value: Option<&Value>,
371 field: &str,
372 code: RuntimeErrorCode,
373) -> Result<u8, CodecError> {
374 let value = required_u64(value, field, code)?;
375 u8::try_from(value)
376 .map_err(|_| CodecError::protocol(code, format!("field `{field}` exceeds u8 range")))
377}
378
379fn required_i32(value: Option<&Value>, field: &str) -> Result<i32, CodecError> {
380 let value = value.and_then(Value::as_i64).ok_or_else(|| {
381 CodecError::protocol(
382 RuntimeErrorCode::InvalidEnvelope,
383 format!("field `{field}` must be a signed integer"),
384 )
385 })?;
386 i32::try_from(value).map_err(|_| {
387 CodecError::protocol(
388 RuntimeErrorCode::InvalidEnvelope,
389 format!("field `{field}` exceeds i32 range"),
390 )
391 })
392}
393
394fn optional_string(value: Option<&Value>, field: &str) -> Result<Option<String>, CodecError> {
395 match value {
396 Some(Value::Nil) => Ok(None),
397 Some(Value::String(value)) => value
398 .as_str()
399 .map(|value| Some(value.to_string()))
400 .ok_or_else(|| {
401 CodecError::protocol(
402 RuntimeErrorCode::InvalidEnvelope,
403 format!("field `{field}` must contain valid UTF-8"),
404 )
405 }),
406 _ => Err(CodecError::protocol(
407 RuntimeErrorCode::InvalidEnvelope,
408 format!("field `{field}` must be string or nil"),
409 )),
410 }
411}
412
413fn required_options(value: Option<&Value>, field: &str) -> Result<Options, CodecError> {
414 match value {
415 Some(Value::Nil) => Ok(Vec::new()),
416 Some(Value::Map(entries)) => entries
417 .iter()
418 .map(|(key, value)| {
419 let key = key.as_str().ok_or_else(|| {
420 CodecError::protocol(
421 RuntimeErrorCode::InvalidEnvelope,
422 format!("field `{field}` option key must be string"),
423 )
424 })?;
425 if !is_option_scalar(value) {
426 return Err(CodecError::protocol(
427 RuntimeErrorCode::InvalidEnvelope,
428 format!("field `{field}` option value must be scalar"),
429 ));
430 }
431 Ok((key.to_string(), value.clone()))
432 })
433 .collect(),
434 _ => Err(CodecError::protocol(
435 RuntimeErrorCode::InvalidEnvelope,
436 format!("field `{field}` must be map or nil"),
437 )),
438 }
439}
440
441fn is_option_scalar(value: &Value) -> bool {
442 matches!(
443 value,
444 Value::Nil | Value::Boolean(_) | Value::Integer(_) | Value::String(_) | Value::Binary(_)
445 )
446}
447
448fn options_to_value(options: &Options) -> Value {
449 if options.is_empty() {
450 return Value::Nil;
451 }
452 Value::Map(
453 options
454 .iter()
455 .map(|(key, value)| (Value::from(key.as_str()), value.clone()))
456 .collect(),
457 )
458}
459
460fn string_option_to_value(value: Option<&str>) -> Value {
461 value.map_or(Value::Nil, Value::from)
462}
463
464fn u8_value(value: u8) -> Value {
465 u64_value(value as u64)
466}
467
468fn u64_value(value: u64) -> Value {
469 Value::Integer(Integer::from(value))
470}
471
472fn i64_value(value: i64) -> Value {
473 Value::Integer(Integer::from(value))
474}
475
#[cfg(test)]
mod tests {
    //! Round-trip and rejection tests for the envelope codec.

    use super::*;
    use rpc_runtime_core::{CapabilityFlags, RUNTIME_PROTOCOL_VERSION};
    use uuid::Uuid;

    /// Encodes `envelope`, decodes the bytes with default limits, and asserts equality.
    fn roundtrip(envelope: Envelope) {
        let bytes = encode_envelope(&envelope).expect("encode envelope");
        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode envelope");
        assert_eq!(decoded, envelope);
    }

    /// Encodes an arbitrary MessagePack value, bypassing the envelope encoder, so tests can
    /// feed hand-built (possibly malformed) frames to the decoder.
    fn encode_raw(value: &Value) -> Vec<u8> {
        let mut bytes = Vec::new();
        encode::write_value(&mut bytes, value).expect("encode raw value");
        bytes
    }

    #[test]
    fn hello_roundtrips() {
        roundtrip(Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
                | CapabilityFlags::GOODBYE,
            max_message_size: 4096,
            options: vec![("implementation".to_string(), Value::from("test"))],
        }));
    }

    #[test]
    fn hello_ack_roundtrips() {
        roundtrip(Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::SERVICE_ACTIVATION,
            max_message_size: 8192,
            options: Vec::new(),
        }));
    }

    #[test]
    fn request_roundtrips() {
        roundtrip(Envelope::Request(Request {
            request_id: RequestId::new(10),
            instance_id: InstanceId::new(22).expect("non-zero instance id"),
            method_id: MethodId::new(3),
            payload: Value::Array(vec![Value::from("card")]),
        }));
    }

    #[test]
    fn response_ok_roundtrips() {
        roundtrip(Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(10),
            payload: Value::Nil,
        }));
    }

    #[test]
    fn response_error_roundtrips() {
        roundtrip(Envelope::ResponseError(ResponseError {
            request_id: RequestId::new(10),
            error_code: 1007,
            error_kind: 3,
            error_message: Some("method missing".to_string()),
            error_details: Value::Nil,
        }));
    }

    #[test]
    fn notification_roundtrips_with_global_instance() {
        roundtrip(Envelope::Notification(Notification {
            instance_id: None,
            notification_id: NotificationId::new(4),
            payload: Value::from(true),
        }));
    }

    #[test]
    fn notification_roundtrips_with_specific_instance() {
        // Exercises the non-zero instance id branch of the notification decoder.
        roundtrip(Envelope::Notification(Notification {
            instance_id: Some(InstanceId::new(5).expect("non-zero instance id")),
            notification_id: NotificationId::new(4),
            payload: Value::from("event"),
        }));
    }

    #[test]
    fn goodbye_roundtrips() {
        roundtrip(Envelope::Goodbye(Goodbye {
            reason_code: 1,
            message: Some("shutdown".to_string()),
        }));
    }

    #[test]
    fn spec_request_shape_decodes() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from(42),
            Value::from(7),
            Value::from(2),
            Value::Nil,
        ]));

        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode request");
        assert_eq!(
            decoded,
            Envelope::Request(Request {
                request_id: RequestId::new(42),
                instance_id: InstanceId::new(7).expect("non-zero instance id"),
                method_id: MethodId::new(2),
                payload: Value::Nil,
            })
        );
    }

    #[test]
    fn non_array_fails() {
        let bytes = encode_raw(&Value::from("not-an-array"));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn unknown_kind_fails() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(99)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::UnknownMessageKind);
    }

    #[test]
    fn wrong_field_count_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(4),
            Value::from(1),
            Value::Nil,
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn request_instance_zero_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from(1),
            Value::from(0),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidInstanceId);
    }

    #[test]
    fn scalar_type_mismatch_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from("bad-request-id"),
            Value::from(1),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidRequestId);
    }

    #[test]
    fn oversized_frame_fails() {
        let bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        let err = decode_envelope(
            &bytes,
            CodecLimits {
                max_message_size: bytes.len() - 1,
            },
        )
        .expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn trailing_bytes_fail() {
        let mut bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        // Append one extra byte after the complete top-level array.
        bytes.push(0xc0);

        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn service_guid_uses_raw_16_byte_binary_encoding() {
        let uuid = Uuid::parse_str("d7d0c9a0-2eb4-4d2d-a3f9-7a4b2875b7e1").expect("uuid");
        let guid = ServiceGuid::new(uuid);

        let encoded = encode_service_guid(guid);
        assert_eq!(encoded, Value::Binary(uuid.as_bytes().to_vec()));

        let decoded = decode_service_guid(&encoded).expect("decode guid");
        assert_eq!(decoded, guid);
    }

    #[test]
    fn service_guid_rejects_wrong_binary_length() {
        let err = decode_service_guid(&Value::Binary(vec![1, 2, 3])).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn cancel_fails_as_unsupported() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(7)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::RequestCancelUnsupported);
    }

    #[test]
    fn performance_smoke_for_request_codec_path() {
        let envelope = Envelope::Request(Request {
            request_id: RequestId::new(1),
            instance_id: InstanceId::new(1).expect("non-zero instance id"),
            method_id: MethodId::new(1),
            payload: Value::Nil,
        });
        for _ in 0..10_000 {
            let bytes = encode_envelope(&envelope).expect("encode");
            let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode");
            assert!(matches!(decoded, Envelope::Request(_)));
        }
    }
}