1use flate2::Compression;
44use flate2::write::GzEncoder;
45use std::collections::HashMap;
46use std::fmt;
47use std::sync::Arc;
48
49use bytes::{Bytes, BytesMut};
50use serde::{Deserialize, Serialize};
51use uuid::Uuid;
52
53#[cfg(feature = "messagepack")]
54use msgpacker::Packable;
55
56use crate::types::{ContentType, ProtocolVersion, Timestamp};
57use crate::{McpError as Error, Result};
58
/// Owned mirror of a `serde_json::Value` that can be encoded as MessagePack.
///
/// Instances are produced by `JsonValue::from_serde_json` and packed through the
/// `msgpacker::Packable` impl below. Only compiled with the `messagepack` feature.
#[cfg(feature = "messagepack")]
#[derive(Debug, Clone)]
pub enum JsonValue {
    /// JSON `null`.
    Null,
    /// JSON boolean.
    Bool(bool),
    /// Any JSON number. Integers are stored as `f64` as well, so integers whose magnitude
    /// exceeds 2^53 lose precision.
    Number(f64),
    /// JSON string.
    String(String),
    /// JSON array, in element order.
    Array(Vec<JsonValue>),
    /// JSON object. Backed by a `HashMap`, so key order is not preserved when packing.
    Object(std::collections::HashMap<String, JsonValue>),
}
76
#[cfg(feature = "messagepack")]
impl JsonValue {
    /// Recursively converts a borrowed `serde_json::Value` into an owned `JsonValue`.
    ///
    /// Numbers are read as `i64`, then `u64`, then `f64` (first match wins) and always stored
    /// as `f64`. A number that fits none of them becomes `JsonValue::Null`. Strings and
    /// object keys are cloned.
    pub fn from_serde_json(value: &serde_json::Value) -> Self {
        use serde_json::Value;

        match value {
            Value::Null => Self::Null,
            Value::Bool(flag) => Self::Bool(*flag),
            Value::Number(num) => num
                .as_i64()
                .map(|signed| signed as f64)
                .or_else(|| num.as_u64().map(|unsigned| unsigned as f64))
                .or_else(|| num.as_f64())
                .map_or(Self::Null, Self::Number),
            Value::String(text) => Self::String(text.clone()),
            Value::Array(items) => Self::Array(items.iter().map(Self::from_serde_json).collect()),
            Value::Object(fields) => Self::Object(
                fields
                    .iter()
                    .map(|(key, field)| (key.clone(), Self::from_serde_json(field)))
                    .collect(),
            ),
        }
    }
}
109
#[cfg(feature = "messagepack")]
impl msgpacker::Packable for JsonValue {
    /// Appends the MessagePack encoding of `self` to `buf` and returns the byte count written.
    ///
    /// Scalars delegate to msgpacker's own impls. `Null` is the single byte `0xc0`. Arrays and
    /// objects get a container header followed by their elements (objects: key, then value).
    fn pack<T>(&self, buf: &mut T) -> usize
    where
        T: Extend<u8>,
    {
        /// Writes a container header for `len` entries and returns the header size in bytes.
        ///
        /// `fix` is the base of the one-byte "fix" form (used for up to 15 entries);
        /// `marker16`/`marker32` prefix big-endian 16/32-bit lengths. Lengths above
        /// `u32::MAX` are truncated by the cast.
        fn write_header<B: Extend<u8>>(
            buf: &mut B,
            len: usize,
            fix: u8,
            marker16: u8,
            marker32: u8,
        ) -> usize {
            if len <= 15 {
                buf.extend([fix + len as u8]);
                1
            } else if len <= u16::MAX as usize {
                buf.extend([marker16]);
                buf.extend((len as u16).to_be_bytes());
                3
            } else {
                buf.extend([marker32]);
                buf.extend((len as u32).to_be_bytes());
                5
            }
        }

        match self {
            Self::Null => {
                buf.extend([0xc0]);
                1
            }
            Self::Bool(flag) => flag.pack(buf),
            Self::Number(num) => num.pack(buf),
            Self::String(text) => text.pack(buf),
            Self::Array(items) => {
                // fixarray = 0x90, array16 = 0xdc, array32 = 0xdd
                let header = write_header(buf, items.len(), 0x90, 0xdc, 0xdd);
                items
                    .iter()
                    .fold(header, |written, item| written + item.pack(buf))
            }
            Self::Object(entries) => {
                // fixmap = 0x80, map16 = 0xde, map32 = 0xdf
                let header = write_header(buf, entries.len(), 0x80, 0xde, 0xdf);
                entries.iter().fold(header, |written, (key, value)| {
                    let key_len = key.pack(buf);
                    let value_len = value.pack(buf);
                    written + key_len + value_len
                })
            }
        }
    }
}
181
/// Identifier attached to every [`Message`].
///
/// Serialized with `#[serde(untagged)]`, so it appears as a bare string, number or UUID.
/// Untagged deserialization tries variants in declaration order, so any string input
/// (including one formatted as a UUID) becomes `MessageId::String`, not `MessageId::Uuid`.
/// NOTE(review): confirm that is intended before reordering variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageId {
    /// Free-form string identifier.
    String(String),
    /// Numeric identifier.
    Number(i64),
    /// UUID identifier; messages produced by `Message::deserialize*` get a random v4 UUID.
    Uuid(Uuid),
}
193
/// Descriptive metadata carried alongside a message payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Creation time; `MessageMetadata::new` stamps it with `Timestamp::now()`.
    pub created_at: Timestamp,

    /// Protocol version; `MessageMetadata::new` fills it from `crate::PROTOCOL_VERSION`.
    pub protocol_version: ProtocolVersion,

    /// Transfer encoding of the serialized bytes, if any. `MessageSerializer::serialize`
    /// sets it to `"gzip"` when it compresses a message.
    pub encoding: Option<String>,

    /// Kind of payload (JSON, binary, text).
    pub content_type: ContentType,

    /// Payload length in bytes, as computed by the `Message` constructors.
    pub size: usize,

    /// Optional id linking this message to another (e.g. a request/response pair).
    pub correlation_id: Option<String>,

    /// Arbitrary string key/value headers.
    pub headers: HashMap<String, String>,
}
218
/// A protocol message: an identifier, its metadata and the payload itself.
#[derive(Debug, Clone)]
pub struct Message {
    /// Message identifier.
    pub id: MessageId,

    /// Creation time, content type, size, headers, etc.
    pub metadata: MessageMetadata,

    /// The message body.
    pub payload: MessagePayload,
}
231
/// Body of a [`Message`].
#[derive(Debug, Clone)]
pub enum MessagePayload {
    /// JSON document, kept as raw bytes plus an optional parsed tree.
    Json(JsonPayload),

    /// Opaque binary data tagged with its encoding.
    Binary(BinaryPayload),

    /// Plain text.
    Text(String),

    /// No body at all.
    Empty,
}
247
/// JSON payload stored both as encoded bytes and (optionally) as a parsed value.
#[derive(Debug, Clone)]
pub struct JsonPayload {
    /// Encoded JSON bytes; this is what JSON serialization emits verbatim.
    pub raw: Bytes,

    /// Cached parse of `raw`, shared cheaply via `Arc`. `Message::json` fills it; the
    /// `Message::deserialize*` JSON paths leave it `None`.
    pub parsed: Option<Arc<serde_json::Value>>,

    /// Whether `raw` parsed successfully as JSON when the payload was built.
    pub is_valid: bool,
}
260
/// Binary payload together with the format its bytes are encoded in.
#[derive(Debug, Clone)]
pub struct BinaryPayload {
    /// The encoded bytes.
    pub data: Bytes,

    /// Encoding of `data`.
    pub format: BinaryFormat,
}
270
/// Encoding of a [`BinaryPayload`].
///
/// Serialized as the lowercased variant name (`"messagepack"`, `"protobuf"`, `"cbor"`,
/// `"custom"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BinaryFormat {
    /// MessagePack.
    MessagePack,

    /// Protocol Buffers.
    ProtoBuf,

    /// CBOR.
    Cbor,

    /// Application-defined encoding.
    Custom,
}
287
/// Serializes messages in a configurable format with optional gzip compression.
#[derive(Debug)]
pub struct MessageSerializer {
    /// Format passed to `Message::serialize` (JSON by default).
    default_format: SerializationFormat,

    /// Whether outputs larger than `compression_threshold` are gzip-compressed.
    enable_compression: bool,

    /// Size in bytes that serialized output must exceed before compression kicks in.
    compression_threshold: usize,
}
300
/// Wire formats understood by `Message::serialize` / `Message::deserialize_with_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
    /// JSON via `serde_json`.
    Json,

    /// JSON via the SIMD-accelerated parsers; only available with the `simd` feature.
    #[cfg(feature = "simd")]
    SimdJson,

    /// MessagePack (encoding requires the `messagepack` feature).
    MessagePack,

    /// CBOR via `ciborium`.
    Cbor,
}
317
318impl Message {
319 pub fn json(id: MessageId, value: impl Serialize) -> Result<Self> {
325 let json_bytes = Self::serialize_json(&value)?;
326 let payload = MessagePayload::Json(JsonPayload {
327 raw: json_bytes.freeze(),
328 parsed: Some(Arc::new(serde_json::to_value(value)?)),
329 is_valid: true,
330 });
331
332 Ok(Self {
333 id,
334 metadata: MessageMetadata::new(ContentType::Json, payload.size()),
335 payload,
336 })
337 }
338
339 pub fn binary(id: MessageId, data: Bytes, format: BinaryFormat) -> Self {
341 let size = data.len();
342 let payload = MessagePayload::Binary(BinaryPayload { data, format });
343
344 Self {
345 id,
346 metadata: MessageMetadata::new(ContentType::Binary, size),
347 payload,
348 }
349 }
350
351 #[must_use]
353 pub fn text(id: MessageId, text: String) -> Self {
354 let size = text.len();
355 let payload = MessagePayload::Text(text);
356
357 Self {
358 id,
359 metadata: MessageMetadata::new(ContentType::Text, size),
360 payload,
361 }
362 }
363
364 #[must_use]
366 pub fn empty(id: MessageId) -> Self {
367 Self {
368 id,
369 metadata: MessageMetadata::new(ContentType::Json, 0),
370 payload: MessagePayload::Empty,
371 }
372 }
373
374 pub const fn size(&self) -> usize {
376 self.metadata.size
377 }
378
379 pub const fn is_empty(&self) -> bool {
381 matches!(self.payload, MessagePayload::Empty)
382 }
383
384 pub fn serialize(&self, format: SerializationFormat) -> Result<Bytes> {
390 match format {
391 SerializationFormat::Json => self.serialize_json_format(),
392 #[cfg(feature = "simd")]
393 SerializationFormat::SimdJson => self.serialize_simd_json(),
394 SerializationFormat::MessagePack => self.serialize_messagepack(),
395 SerializationFormat::Cbor => self.serialize_cbor(),
396 }
397 }
398
399 pub fn deserialize(bytes: Bytes) -> Result<Self> {
405 let format = Self::detect_format(&bytes);
407 Self::deserialize_with_format(bytes, format)
408 }
409
410 pub fn deserialize_with_format(bytes: Bytes, format: SerializationFormat) -> Result<Self> {
412 match format {
413 SerializationFormat::Json => Ok(Self::deserialize_json(bytes)),
414 #[cfg(feature = "simd")]
415 SerializationFormat::SimdJson => Ok(Self::deserialize_simd_json(bytes)),
416 SerializationFormat::MessagePack => Ok(Self::deserialize_messagepack(bytes)),
417 SerializationFormat::Cbor => Self::deserialize_cbor(bytes),
418 }
419 }
420
421 pub fn parse_json<T>(&self) -> Result<T>
423 where
424 T: for<'de> Deserialize<'de>,
425 {
426 match &self.payload {
427 MessagePayload::Json(json_payload) => json_payload.parsed.as_ref().map_or_else(
428 || {
429 #[cfg(feature = "simd")]
430 {
431 let mut json_bytes = json_payload.raw.to_vec();
432 simd_json::from_slice(&mut json_bytes).map_err(|e| {
433 Error::serialization(format!("SIMD JSON parsing failed: {e}"))
434 })
435 }
436 #[cfg(not(feature = "simd"))]
437 {
438 serde_json::from_slice(&json_payload.raw).map_err(|e| {
439 Error::serialization(format!("JSON parsing failed: {}", e))
440 })
441 }
442 },
443 |parsed| {
444 serde_json::from_value((**parsed).clone())
445 .map_err(|e| Error::serialization(format!("JSON parsing failed: {e}")))
446 },
447 ),
448 _ => Err(Error::invalid_params("Message payload is not JSON")),
449 }
450 }
451
452 fn serialize_json(value: &impl Serialize) -> Result<BytesMut> {
455 #[cfg(feature = "simd")]
456 {
457 sonic_rs::to_vec(value)
458 .map(|v| BytesMut::from(v.as_slice()))
459 .map_err(|e| Error::serialization(format!("SIMD JSON serialization failed: {e}")))
460 }
461 #[cfg(not(feature = "simd"))]
462 {
463 serde_json::to_vec(value)
464 .map(|v| BytesMut::from(v.as_slice()))
465 .map_err(|e| Error::serialization(format!("JSON serialization failed: {}", e)))
466 }
467 }
468
469 fn serialize_json_format(&self) -> Result<Bytes> {
470 match &self.payload {
471 MessagePayload::Json(json_payload) => Ok(json_payload.raw.clone()),
472 MessagePayload::Text(text) => Ok(Bytes::from(text.clone())),
473 MessagePayload::Empty => Ok(Bytes::from_static(b"{}")),
474 MessagePayload::Binary(_) => Err(Error::invalid_params(
475 "Cannot serialize non-JSON payload as JSON",
476 )),
477 }
478 }
479
480 #[cfg(feature = "simd")]
481 fn serialize_simd_json(&self) -> Result<Bytes> {
482 match &self.payload {
483 MessagePayload::Json(json_payload) => {
484 if json_payload.is_valid {
485 Ok(json_payload.raw.clone())
486 } else {
487 Err(Error::serialization("Invalid JSON payload"))
488 }
489 }
490 _ => Err(Error::invalid_params(
491 "Cannot serialize non-JSON payload with SIMD JSON",
492 )),
493 }
494 }
495
496 fn serialize_messagepack(&self) -> Result<Bytes> {
497 #[cfg(feature = "messagepack")]
498 {
499 match &self.payload {
500 MessagePayload::Binary(binary) if binary.format == BinaryFormat::MessagePack => {
501 Ok(binary.data.clone())
502 }
503 MessagePayload::Json(json_payload) => json_payload.parsed.as_ref().map_or_else(
504 || {
505 Err(Error::serialization(
506 "Cannot serialize unparsed JSON to MessagePack",
507 ))
508 },
509 |parsed| {
510 let packable_value = JsonValue::from_serde_json(parsed.as_ref());
512 let mut buffer = Vec::new();
513 packable_value.pack(&mut buffer);
514 Ok(Bytes::from(buffer))
515 },
516 ),
517 _ => Err(Error::invalid_params(
518 "Cannot serialize payload as MessagePack",
519 )),
520 }
521 }
522 #[cfg(not(feature = "messagepack"))]
523 {
524 let _ = self; Err(Error::invalid_params(
526 "MessagePack serialization not available",
527 ))
528 }
529 }
530
531 fn serialize_cbor(&self) -> Result<Bytes> {
532 match &self.payload {
533 MessagePayload::Binary(binary) if binary.format == BinaryFormat::Cbor => {
534 Ok(binary.data.clone())
535 }
536 MessagePayload::Json(json_payload) => {
537 if let Some(parsed) = &json_payload.parsed {
538 {
539 let mut buffer = Vec::new();
540 ciborium::into_writer(parsed.as_ref(), &mut buffer)
541 .map(|_| Bytes::from(buffer))
542 .map_err(|e| {
543 Error::serialization(format!("CBOR serialization failed: {e}"))
544 })
545 }
546 } else {
547 #[cfg(feature = "simd")]
549 {
550 let mut json_bytes = json_payload.raw.to_vec();
551 let value: serde_json::Value = simd_json::from_slice(&mut json_bytes)
552 .map_err(|e| {
553 Error::serialization(format!(
554 "SIMD JSON parsing failed before CBOR: {e}"
555 ))
556 })?;
557 {
558 let mut buffer = Vec::new();
559 ciborium::into_writer(&value, &mut buffer)
560 .map(|_| Bytes::from(buffer))
561 .map_err(|e| {
562 Error::serialization(format!("CBOR serialization failed: {e}"))
563 })
564 }
565 }
566 #[cfg(not(feature = "simd"))]
567 {
568 let value: serde_json::Value = serde_json::from_slice(&json_payload.raw)
569 .map_err(|e| {
570 Error::serialization(format!(
571 "JSON parsing failed before CBOR: {}",
572 e
573 ))
574 })?;
575 let mut buf = Vec::new();
576 ciborium::ser::into_writer(&value, &mut buf).map_err(|e| {
577 Error::serialization(format!("CBOR serialization failed: {}", e))
578 })?;
579 Ok(Bytes::from(buf))
580 }
581 }
582 }
583 _ => Err(Error::invalid_params("Cannot serialize payload as CBOR")),
584 }
585 }
586
587 fn deserialize_json(bytes: Bytes) -> Self {
588 let is_valid = serde_json::from_slice::<serde_json::Value>(&bytes).is_ok();
590
591 let payload = MessagePayload::Json(JsonPayload {
592 raw: bytes,
593 parsed: None, is_valid,
595 });
596
597 Self {
598 id: MessageId::Uuid(Uuid::new_v4()),
599 metadata: MessageMetadata::new(ContentType::Json, payload.size()),
600 payload,
601 }
602 }
603
604 #[cfg(feature = "simd")]
605 fn deserialize_simd_json(bytes: Bytes) -> Self {
606 let mut json_bytes = bytes.to_vec();
607 let is_valid = simd_json::from_slice::<serde_json::Value>(&mut json_bytes).is_ok();
608
609 let payload = MessagePayload::Json(JsonPayload {
610 raw: bytes,
611 parsed: None,
612 is_valid,
613 });
614
615 Self {
616 id: MessageId::Uuid(Uuid::new_v4()),
617 metadata: MessageMetadata::new(ContentType::Json, payload.size()),
618 payload,
619 }
620 }
621
622 fn deserialize_messagepack(bytes: Bytes) -> Self {
623 let payload = MessagePayload::Binary(BinaryPayload {
624 data: bytes,
625 format: BinaryFormat::MessagePack,
626 });
627
628 Self {
629 id: MessageId::Uuid(Uuid::new_v4()),
630 metadata: MessageMetadata::new(ContentType::Binary, payload.size()),
631 payload,
632 }
633 }
634
635 fn deserialize_cbor(bytes: Bytes) -> Result<Self> {
636 if let Ok(value) = ciborium::from_reader::<serde_json::Value, _>(&bytes[..]) {
638 let raw = serde_json::to_vec(&value)
639 .map(Bytes::from)
640 .map_err(|e| Error::serialization(format!("JSON re-encode failed: {e}")))?;
641 let payload = MessagePayload::Json(JsonPayload {
642 raw,
643 parsed: Some(Arc::new(value)),
644 is_valid: true,
645 });
646 return Ok(Self {
647 id: MessageId::Uuid(Uuid::new_v4()),
648 metadata: MessageMetadata::new(ContentType::Json, payload.size()),
649 payload,
650 });
651 }
652
653 let payload = MessagePayload::Binary(BinaryPayload {
655 data: bytes,
656 format: BinaryFormat::Cbor,
657 });
658 Ok(Self {
659 id: MessageId::Uuid(Uuid::new_v4()),
660 metadata: MessageMetadata::new(ContentType::Binary, payload.size()),
661 payload,
662 })
663 }
664
665 fn detect_format(bytes: &[u8]) -> SerializationFormat {
666 if bytes.is_empty() {
667 return SerializationFormat::Json;
668 }
669
670 if matches!(bytes[0], b'{' | b'[') {
672 #[cfg(feature = "simd")]
673 {
674 return SerializationFormat::SimdJson;
675 }
676 #[cfg(not(feature = "simd"))]
677 {
678 return SerializationFormat::Json;
679 }
680 }
681
682 if bytes.len() >= 2 && (bytes[0] == 0x82 || bytes[0] == 0x83) {
684 return SerializationFormat::MessagePack;
685 }
686
687 #[cfg(feature = "simd")]
689 {
690 SerializationFormat::SimdJson
691 }
692 #[cfg(not(feature = "simd"))]
693 {
694 SerializationFormat::Json
695 }
696 }
697}
698
699impl MessagePayload {
700 pub const fn size(&self) -> usize {
702 match self {
703 Self::Json(json) => json.raw.len(),
704 Self::Binary(binary) => binary.data.len(),
705 Self::Text(text) => text.len(),
706 Self::Empty => 0,
707 }
708 }
709}
710
711impl MessageMetadata {
712 #[must_use]
714 pub fn new(content_type: ContentType, size: usize) -> Self {
715 Self {
716 created_at: Timestamp::now(),
717 protocol_version: crate::PROTOCOL_VERSION.to_string(),
718 encoding: None,
719 content_type,
720 size,
721 correlation_id: None,
722 headers: HashMap::new(),
723 }
724 }
725
726 #[must_use]
728 pub fn with_header(mut self, key: String, value: String) -> Self {
729 self.headers.insert(key, value);
730 self
731 }
732
733 #[must_use]
735 pub fn with_correlation_id(mut self, correlation_id: String) -> Self {
736 self.correlation_id = Some(correlation_id);
737 self
738 }
739
740 #[must_use]
742 pub fn with_encoding(mut self, encoding: String) -> Self {
743 self.encoding = Some(encoding);
744 self
745 }
746}
747
748impl MessageSerializer {
749 #[must_use]
751 pub const fn new() -> Self {
752 Self {
753 default_format: SerializationFormat::Json,
754 enable_compression: false,
755 compression_threshold: 1024, }
757 }
758
759 #[must_use]
761 pub const fn with_format(mut self, format: SerializationFormat) -> Self {
762 self.default_format = format;
763 self
764 }
765
766 #[must_use]
768 pub const fn with_compression(mut self, enable: bool, threshold: usize) -> Self {
769 self.enable_compression = enable;
770 self.compression_threshold = threshold;
771 self
772 }
773
774 pub fn serialize(&self, message: &mut Message) -> Result<Bytes> {
776 let serialized = message.serialize(self.default_format)?;
777
778 if self.enable_compression && serialized.len() > self.compression_threshold {
780 message.metadata.encoding = Some("gzip".to_string()); Ok(self.compress(serialized))
782 } else {
783 Ok(serialized)
784 }
785 }
786
787 fn compress(&self, data: Bytes) -> Bytes {
790 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
791 if let Err(e) = std::io::Write::write_all(&mut encoder, &data) {
792 eprintln!("Failed to compress data: {}", e);
793 return data; }
795 match encoder.finish() {
796 Ok(compressed_data) => Bytes::from(compressed_data),
797 Err(e) => {
798 eprintln!("Failed to finish compression: {}", e);
799 data }
801 }
802 }
803}
804
805impl Default for MessageSerializer {
806 fn default() -> Self {
807 Self::new()
808 }
809}
810
811impl fmt::Display for MessageId {
812 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
813 match self {
814 Self::String(s) => write!(f, "{s}"),
815 Self::Number(n) => write!(f, "{n}"),
816 Self::Uuid(u) => write!(f, "{u}"),
817 }
818 }
819}
820
821impl From<String> for MessageId {
822 fn from(s: String) -> Self {
823 Self::String(s)
824 }
825}
826
827impl From<&str> for MessageId {
828 fn from(s: &str) -> Self {
829 Self::String(s.to_string())
830 }
831}
832
833impl From<i64> for MessageId {
834 fn from(n: i64) -> Self {
835 Self::Number(n)
836 }
837}
838
839impl From<Uuid> for MessageId {
840 fn from(u: Uuid) -> Self {
841 Self::Uuid(u)
842 }
843}
844
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal typed target for `parse_json`.
    #[derive(Deserialize, PartialEq, Debug)]
    struct TestData {
        number: i32,
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::json(MessageId::from("test"), json!({"key": "value"})).unwrap();

        assert_eq!(msg.id.to_string(), "test");
        assert!(!msg.is_empty());
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::json(MessageId::from(1), json!({"test": true})).unwrap();

        let wire = msg.serialize(SerializationFormat::Json).unwrap();
        assert!(!wire.is_empty());
    }

    #[test]
    fn test_message_parsing() {
        let msg = Message::json(MessageId::from("test"), json!({"number": 42})).unwrap();

        let decoded: TestData = msg.parse_json().unwrap();
        assert_eq!(decoded.number, 42);
    }

    #[test]
    fn test_format_detection() {
        let input = Bytes::from(r#"{"test": true}"#);
        let detected = Message::detect_format(&input);

        #[cfg(feature = "simd")]
        assert_eq!(detected, SerializationFormat::SimdJson);
        #[cfg(not(feature = "simd"))]
        assert_eq!(detected, SerializationFormat::Json);
    }

    #[test]
    fn test_message_metadata() {
        let meta = MessageMetadata::new(ContentType::Json, 100)
            .with_header("custom".to_string(), "value".to_string())
            .with_correlation_id("corr-123".to_string());

        assert_eq!(meta.size, 100);
        assert_eq!(meta.headers.get("custom"), Some(&"value".to_string()));
        assert_eq!(meta.correlation_id, Some("corr-123".to_string()));
    }

    #[test]
    fn test_message_serializer_compression() {
        use flate2::read::GzDecoder;
        use std::io::Read;

        let threshold = 10;
        let serializer = MessageSerializer::new().with_compression(true, threshold);
        let payload = json!({ "data": "a".repeat(100) });
        let mut msg = Message::json(MessageId::from("compressed_test"), payload.clone()).unwrap();

        let uncompressed_len = msg.size();
        assert!(
            uncompressed_len > threshold,
            "Original message size should be greater than compression threshold"
        );

        let gz_bytes = serializer.serialize(&mut msg).unwrap();
        assert_eq!(msg.metadata.encoding, Some("gzip".to_string()));
        assert!(
            gz_bytes.len() < uncompressed_len,
            "Compressed size should be smaller than original"
        );

        let mut inflated = Vec::new();
        GzDecoder::new(&gz_bytes[..])
            .read_to_end(&mut inflated)
            .unwrap();

        let roundtrip = Message::deserialize(Bytes::from(inflated)).unwrap();
        let decoded: serde_json::Value = roundtrip.parse_json().unwrap();
        assert_eq!(decoded, payload);
    }
}