1#![cfg_attr(not(feature = "std"), no_std)]
43#![deny(unsafe_code)]
44#![warn(missing_docs)]
45#![cfg_attr(docsrs, feature(doc_cfg))]
46
47extern crate alloc;
48
49use alloc::string::{String, ToString};
50use alloc::vec::Vec;
51use core::fmt;
52use serde::{Serialize, de::DeserializeOwned};
53
54pub use turbomcp_core::error::McpError;
56
/// Error produced when a codec fails to encode or decode a message.
///
/// Carries a human-readable `message` and, optionally, a textual description
/// of the underlying cause in `source`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodecError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Optional description of the underlying error that caused this one.
    pub source: Option<String>,
}

impl fmt::Display for CodecError {
    /// Formats as `codec error: <message>`, followed by `: <source>` when a
    /// source is attached, so the cause is not lost when the error is logged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "codec error: {}", self.message)?;
        if let Some(source) = &self.source {
            write!(f, ": {}", source)?;
        }
        Ok(())
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CodecError {}

impl CodecError {
    /// Creates an error with the given message and no source.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            source: None,
        }
    }

    /// Creates an error with the given message and a description of the
    /// underlying cause.
    pub fn with_source(message: impl Into<String>, source: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            source: Some(source.into()),
        }
    }

    /// Creates an encoding error; the message is prefixed with `encode: `.
    pub fn encode(message: impl Into<String>) -> Self {
        Self::new(alloc::format!("encode: {}", message.into()))
    }

    /// Creates a decoding error; the message is prefixed with `decode: `.
    pub fn decode(message: impl Into<String>) -> Self {
        Self::new(alloc::format!("decode: {}", message.into()))
    }
}
102
103impl From<CodecError> for McpError {
104 fn from(err: CodecError) -> Self {
105 McpError::parse_error(err.message)
106 }
107}
108
109pub type CodecResult<T> = Result<T, CodecError>;
111
112pub trait Codec: Send + Sync {
139 fn encode<T: Serialize>(&self, value: &T) -> CodecResult<Vec<u8>>;
141
142 fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> CodecResult<T>;
144
145 fn content_type(&self) -> &'static str;
147
148 fn supports_streaming(&self) -> bool {
150 false
151 }
152
153 fn name(&self) -> &'static str;
155}
156
/// Codec backed by `serde_json`.
///
/// Produces compact JSON by default. Set the `pretty` field, or build the
/// codec with the `pretty()` constructor, for indented multi-line output.
#[derive(Debug, Clone, Default)]
pub struct JsonCodec {
    /// When `true`, encode as indented, multi-line JSON instead of compact JSON.
    pub pretty: bool,
}

impl JsonCodec {
    /// Returns a codec that emits compact JSON.
    pub fn new() -> Self {
        JsonCodec { pretty: false }
    }

    /// Returns a codec that emits pretty-printed JSON.
    pub fn pretty() -> Self {
        JsonCodec { pretty: true }
    }
}
178
179impl Codec for JsonCodec {
180 fn encode<T: Serialize>(&self, value: &T) -> CodecResult<Vec<u8>> {
181 if self.pretty {
182 serde_json::to_vec_pretty(value)
183 } else {
184 serde_json::to_vec(value)
185 }
186 .map_err(|e| CodecError::encode(e.to_string()))
187 }
188
189 fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> CodecResult<T> {
190 serde_json::from_slice(bytes).map_err(|e| CodecError::decode(e.to_string()))
191 }
192
193 fn content_type(&self) -> &'static str {
194 "application/json"
195 }
196
197 fn supports_streaming(&self) -> bool {
198 true
199 }
200
201 fn name(&self) -> &'static str {
202 "json"
203 }
204}
205
/// JSON codec backed by `sonic_rs`.
///
/// Reports the same `application/json` content type as `JsonCodec`. Only
/// available with the `simd` feature.
#[cfg(feature = "simd")]
#[cfg_attr(docsrs, doc(cfg(feature = "simd")))]
#[derive(Debug, Clone, Default)]
pub struct SimdJsonCodec;

#[cfg(feature = "simd")]
impl SimdJsonCodec {
    /// Returns a new `sonic_rs`-backed JSON codec.
    pub fn new() -> Self {
        SimdJsonCodec
    }
}
222
#[cfg(feature = "simd")]
impl Codec for SimdJsonCodec {
    fn name(&self) -> &'static str {
        "simd-json"
    }

    fn content_type(&self) -> &'static str {
        "application/json"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn encode<T: Serialize>(&self, value: &T) -> CodecResult<Vec<u8>> {
        match sonic_rs::to_vec(value) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(CodecError::encode(err.to_string())),
        }
    }

    fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> CodecResult<T> {
        match sonic_rs::from_slice(bytes) {
            Ok(value) => Ok(value),
            Err(err) => Err(CodecError::decode(err.to_string())),
        }
    }
}
245
/// MessagePack codec backed by `rmp_serde`.
///
/// Encoding uses `rmp_serde::to_vec_named`, so structs are written with their
/// field names. Only available with the `msgpack` feature.
#[cfg(feature = "msgpack")]
#[cfg_attr(docsrs, doc(cfg(feature = "msgpack")))]
#[derive(Debug, Clone, Default)]
pub struct MsgPackCodec;

#[cfg(feature = "msgpack")]
impl MsgPackCodec {
    /// Returns a new MessagePack codec.
    pub fn new() -> Self {
        MsgPackCodec
    }
}
294
#[cfg(feature = "msgpack")]
impl Codec for MsgPackCodec {
    fn name(&self) -> &'static str {
        "msgpack"
    }

    fn content_type(&self) -> &'static str {
        "application/msgpack"
    }

    /// MessagePack is binary and has no line-based framing here, so streaming
    /// is reported as unsupported.
    fn supports_streaming(&self) -> bool {
        false
    }

    fn encode<T: Serialize>(&self, value: &T) -> CodecResult<Vec<u8>> {
        // `to_vec_named` writes struct field names rather than positional arrays.
        let bytes =
            rmp_serde::to_vec_named(value).map_err(|err| CodecError::encode(err.to_string()))?;
        Ok(bytes)
    }

    fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> CodecResult<T> {
        let value = rmp_serde::from_slice(bytes).map_err(|err| CodecError::decode(err.to_string()))?;
        Ok(value)
    }
}
318
/// Default upper bound (1 MiB) on bytes buffered by `StreamingJsonDecoder`.
const MAX_STREAMING_BUFFER_SIZE: usize = 1 << 20;
321
/// Incremental decoder for newline-delimited JSON.
///
/// Bytes are appended with `feed`, and complete lines are decoded one at a
/// time with `try_decode`. The number of buffered bytes is bounded by
/// `max_buffer_size`; exceeding it discards the buffer (see `feed`).
#[derive(Debug)]
pub struct StreamingJsonDecoder {
    /// Bytes received but not yet consumed by `try_decode`.
    buffer: Vec<u8>,
    /// Upper bound on `buffer.len()` before `feed` discards the buffer.
    max_buffer_size: usize,
}
338
339impl Default for StreamingJsonDecoder {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
345impl StreamingJsonDecoder {
346 pub fn new() -> Self {
348 Self {
349 buffer: Vec::new(),
350 max_buffer_size: MAX_STREAMING_BUFFER_SIZE,
351 }
352 }
353
354 pub fn with_capacity(capacity: usize) -> Self {
356 Self {
357 buffer: Vec::with_capacity(capacity),
358 max_buffer_size: MAX_STREAMING_BUFFER_SIZE,
359 }
360 }
361
362 pub fn with_max_size(max_size: usize) -> Self {
373 Self {
374 buffer: Vec::new(),
375 max_buffer_size: max_size.min(10 * 1024 * 1024), }
377 }
378
379 pub fn feed(&mut self, data: &[u8]) {
386 self.buffer.extend_from_slice(data);
387
388 if self.buffer.len() > self.max_buffer_size {
390 #[cfg(feature = "std")]
391 tracing::warn!(
392 buffer_size = self.buffer.len(),
393 max_size = self.max_buffer_size,
394 "Streaming buffer exceeded maximum size, clearing buffer"
395 );
396 self.buffer.clear();
397 }
398 }
399
400 pub fn try_decode<T: DeserializeOwned>(&mut self) -> CodecResult<Option<T>> {
405 if let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') {
407 let line = &self.buffer[..pos];
408
409 if line.is_empty() || line.iter().all(|b| b.is_ascii_whitespace()) {
411 self.buffer.drain(..=pos);
412 return Ok(None);
413 }
414
415 let result = serde_json::from_slice(line);
417
418 self.buffer.drain(..=pos);
420
421 match result {
422 Ok(value) => Ok(Some(value)),
423 Err(e) => Err(CodecError::decode(e.to_string())),
424 }
425 } else {
426 Ok(None)
427 }
428 }
429
430 pub fn clear(&mut self) {
432 self.buffer.clear();
433 }
434
435 pub fn is_empty(&self) -> bool {
437 self.buffer.is_empty()
438 }
439
440 pub fn len(&self) -> usize {
442 self.buffer.len()
443 }
444
445 pub fn max_buffer_size(&self) -> usize {
447 self.max_buffer_size
448 }
449}
450
/// Runtime-selectable codec wrapping one of the concrete codecs compiled into
/// this build.
///
/// Useful when the wire format is chosen by name (see `AnyCodec::from_name`).
#[derive(Debug, Clone)]
pub enum AnyCodec {
    /// `serde_json`-backed codec; always available.
    Json(JsonCodec),
    /// `sonic_rs`-backed JSON codec (requires the `simd` feature).
    #[cfg(feature = "simd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "simd")))]
    SimdJson(SimdJsonCodec),
    /// MessagePack codec (requires the `msgpack` feature).
    #[cfg(feature = "msgpack")]
    #[cfg_attr(docsrs, doc(cfg(feature = "msgpack")))]
    MsgPack(MsgPackCodec),
}
468
469impl AnyCodec {
470 pub fn from_name(name: &str) -> Option<Self> {
477 match name {
478 "json" => Some(Self::Json(JsonCodec::new())),
479 #[cfg(feature = "simd")]
480 "simd" | "simd-json" => Some(Self::SimdJson(SimdJsonCodec::new())),
481 #[cfg(feature = "msgpack")]
482 "msgpack" => Some(Self::MsgPack(MsgPackCodec::new())),
483 _ => None,
484 }
485 }
486
487 pub fn available_names() -> &'static [&'static str] {
489 &[
490 "json",
491 #[cfg(feature = "simd")]
492 "simd-json",
493 #[cfg(feature = "msgpack")]
494 "msgpack",
495 ]
496 }
497
498 pub fn encode<T: Serialize>(&self, value: &T) -> CodecResult<Vec<u8>> {
500 match self {
501 Self::Json(c) => c.encode(value),
502 #[cfg(feature = "simd")]
503 Self::SimdJson(c) => c.encode(value),
504 #[cfg(feature = "msgpack")]
505 Self::MsgPack(c) => c.encode(value),
506 }
507 }
508
509 pub fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> CodecResult<T> {
511 match self {
512 Self::Json(c) => c.decode(bytes),
513 #[cfg(feature = "simd")]
514 Self::SimdJson(c) => c.decode(bytes),
515 #[cfg(feature = "msgpack")]
516 Self::MsgPack(c) => c.decode(bytes),
517 }
518 }
519
520 pub fn content_type(&self) -> &'static str {
522 match self {
523 Self::Json(c) => c.content_type(),
524 #[cfg(feature = "simd")]
525 Self::SimdJson(c) => c.content_type(),
526 #[cfg(feature = "msgpack")]
527 Self::MsgPack(c) => c.content_type(),
528 }
529 }
530
531 pub fn name(&self) -> &'static str {
533 match self {
534 Self::Json(c) => c.name(),
535 #[cfg(feature = "simd")]
536 Self::SimdJson(c) => c.name(),
537 #[cfg(feature = "msgpack")]
538 Self::MsgPack(c) => c.name(),
539 }
540 }
541}
542
// Unit tests for the codecs, the streaming decoder and `AnyCodec`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal JSON-RPC-like payload used as a round-trip fixture.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestMessage {
        id: u32,
        method: String,
        params: Option<serde_json::Value>,
    }

    // Compact JSON encode followed by decode yields the original value.
    #[test]
    fn test_json_codec_roundtrip() {
        let codec = JsonCodec::new();
        let msg = TestMessage {
            id: 42,
            method: "test/method".into(),
            params: Some(serde_json::json!({"key": "value"})),
        };

        let encoded = codec.encode(&msg).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(msg, decoded);
    }

    // Pretty mode produces multi-line output.
    #[test]
    fn test_json_codec_pretty() {
        let codec = JsonCodec::pretty();
        let msg = TestMessage {
            id: 1,
            method: "test".into(),
            params: None,
        };

        let encoded = codec.encode(&msg).unwrap();
        let output = String::from_utf8(encoded).unwrap();

        // Pretty-printed JSON spans several lines.
        assert!(output.contains('\n'));
    }

    // Content type and name reported by the JSON codec.
    #[test]
    fn test_codec_content_type() {
        let json = JsonCodec::new();
        assert_eq!(json.content_type(), "application/json");
        assert_eq!(json.name(), "json");
    }

    // A message is only decoded once its terminating newline has arrived.
    #[test]
    fn test_streaming_decoder() {
        let mut decoder = StreamingJsonDecoder::new();

        // Complete JSON object but no newline yet: nothing to decode.
        decoder.feed(br#"{"id":1,"method":"a","params":null}"#);
        assert!(decoder.try_decode::<TestMessage>().unwrap().is_none());

        // The newline completes the frame.
        decoder.feed(b"\n");
        let msg: TestMessage = decoder.try_decode().unwrap().unwrap();
        assert_eq!(msg.id, 1);
        assert_eq!(msg.method, "a");
    }

    // Several newline-delimited messages delivered in one chunk decode in order.
    #[test]
    fn test_streaming_decoder_multiple() {
        let mut decoder = StreamingJsonDecoder::new();

        // The continuation lines of this raw string must stay at column 0:
        // any leading spaces would become part of the fed bytes.
        decoder.feed(
            br#"{"id":1,"method":"a","params":null}
{"id":2,"method":"b","params":null}
"#,
        );

        let msg1: TestMessage = decoder.try_decode().unwrap().unwrap();
        assert_eq!(msg1.id, 1);

        let msg2: TestMessage = decoder.try_decode().unwrap().unwrap();
        assert_eq!(msg2.id, 2);

        // Buffer fully consumed.
        assert!(decoder.try_decode::<TestMessage>().unwrap().is_none());
    }

    // Feeding more than the configured limit discards the buffer.
    #[test]
    fn test_streaming_decoder_buffer_limit() {
        let mut decoder = StreamingJsonDecoder::with_max_size(100);

        // 150 bytes exceeds the 100-byte limit.
        let large_data = vec![b'x'; 150];
        decoder.feed(&large_data);

        assert!(
            decoder.is_empty(),
            "Buffer should be cleared after exceeding limit"
        );
    }

    // Requested limits above 10 MiB are capped.
    #[test]
    fn test_streaming_decoder_max_size_cap() {
        // 100 MiB requested; 10 MiB expected.
        let decoder = StreamingJsonDecoder::with_max_size(100 * 1024 * 1024);
        assert_eq!(decoder.max_buffer_size(), 10 * 1024 * 1024);
    }

    // The default limit is 1 MiB.
    #[test]
    fn test_streaming_decoder_default_limit() {
        let decoder = StreamingJsonDecoder::new();
        assert_eq!(decoder.max_buffer_size(), 1024 * 1024);
    }

    // Name-based lookup: "json" always resolves, unknown names do not.
    #[test]
    fn test_any_codec() {
        let codec = AnyCodec::from_name("json").unwrap();
        assert_eq!(codec.name(), "json");

        assert!(AnyCodec::from_name("unknown").is_none());
        assert!(AnyCodec::available_names().contains(&"json"));
    }

    // Invalid input yields an error whose message carries the `decode` prefix.
    #[test]
    fn test_codec_error() {
        let codec = JsonCodec::new();
        let result: CodecResult<TestMessage> = codec.decode(b"invalid json");
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(err.message.contains("decode"));
    }

    // Round trip through the `sonic_rs`-backed codec.
    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_codec_roundtrip() {
        let codec = SimdJsonCodec::new();
        let msg = TestMessage {
            id: 99,
            method: "simd/test".into(),
            params: Some(serde_json::json!([1, 2, 3])),
        };

        let encoded = codec.encode(&msg).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(msg, decoded);
    }

    // Round trip through the MessagePack codec, plus its content type.
    #[cfg(feature = "msgpack")]
    #[test]
    fn test_msgpack_codec_roundtrip() {
        let codec = MsgPackCodec::new();
        let msg = TestMessage {
            id: 77,
            method: "msgpack/test".into(),
            params: None,
        };

        let encoded = codec.encode(&msg).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(msg, decoded);
        assert_eq!(codec.content_type(), "application/msgpack");
    }
}