Skip to main content

ryo_app/
codec.rs

//! tarpc transport codec for RYO RPC communication.
//!
//! # Why MessagePackNamed?
//!
//! RYO uses MessagePack for RPC serialization via tarpc. There are two serialization modes:
//!
//! | Mode | Serialization | `skip_serializing_if` |
//! |------|--------------|----------------------|
//! | Array-based (`rmp_serde::to_vec`, the default) | `[value1, value2, ...]` | **Incompatible** |
//! | Named, map-based (`rmp_serde::to_vec_named`) | `{"field1": value1, ...}` | Compatible |
//!
//! Many response types use `#[serde(skip_serializing_if = "...")]` to reduce payload size.
//! This requires **named serialization**, where fields are identified by name, not position.
//!
//! Using array-based serialization with `skip_serializing_if` causes deserialization failures,
//! because a skipped field shifts every later value into the wrong position:
//! ```text
//! invalid type: boolean `false`, expected a sequence
//! ```
//!
//! # Usage
//!
//! Always create transports with the helper functions in this module:
//!
//! ```ignore
//! use ryo_app::codec::create_client_transport;
//! use tokio::net::UnixStream;
//!
//! let stream = UnixStream::connect(socket_path).await?;
//! let transport = create_client_transport(stream);
//! let client = RyoServiceClient::new(config, transport).spawn();
//! ```
//!
//! # Important
//!
//! **DO NOT** use `tokio_serde::formats::MessagePack::default()` directly.
//! It uses array-based serialization, which is incompatible with `skip_serializing_if`.
37
38use serde::{de::DeserializeOwned, Serialize};
39use std::io;
40use std::marker::PhantomData;
41use std::pin::Pin;
42use tokio_util::bytes::{Bytes, BytesMut};
43
/// MessagePack codec with named (map-based) serialization.
///
/// Serialization goes through `rmp_serde::to_vec_named`, which encodes structs as maps keyed by
/// field name. That is what makes `#[serde(skip_serializing_if = "...")]` safe on the wire: an
/// omitted field is a missing key instead of a gap that shifts every later positional value.
///
/// `Item` is the type decoded from incoming frames and `SinkItem` the type encoded into outgoing
/// ones. Neither is stored. Both are tracked through `PhantomData` over `fn` pointer types, so
/// the codec's auto traits (`Send`, `Sync`, `Unpin`) do not depend on them.
///
/// # Example
///
/// ```ignore
/// let transport = tarpc::serde_transport::new(
///     tokio_util::codec::LengthDelimitedCodec::builder().new_framed(stream),
///     MessagePackNamed::default(),
/// );
/// ```
pub struct MessagePackNamed<Item, SinkItem> {
    _item: PhantomData<fn() -> Item>,
    _sink_item: PhantomData<fn(SinkItem)>,
}

// `Debug`, `Default` and `Clone` are written by hand instead of derived. A derive would require
// `Item` and `SinkItem` to implement the trait too, even though no value of either is stored.
impl<Item, SinkItem> std::fmt::Debug for MessagePackNamed<Item, SinkItem> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessagePackNamed").finish()
    }
}

impl<Item, SinkItem> Default for MessagePackNamed<Item, SinkItem> {
    fn default() -> Self {
        Self {
            _item: PhantomData,
            _sink_item: PhantomData,
        }
    }
}

impl<Item, SinkItem> Clone for MessagePackNamed<Item, SinkItem> {
    fn clone(&self) -> Self {
        // Zero-sized and stateless, so a fresh value is an exact copy.
        Self::default()
    }
}
77
78impl<Item, SinkItem> tokio_serde::Deserializer<Item> for MessagePackNamed<Item, SinkItem>
79where
80    Item: DeserializeOwned,
81{
82    type Error = io::Error;
83
84    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
85        rmp_serde::from_slice(src).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
86    }
87}
88
89impl<Item, SinkItem> tokio_serde::Serializer<SinkItem> for MessagePackNamed<Item, SinkItem>
90where
91    SinkItem: Serialize,
92{
93    type Error = io::Error;
94
95    fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
96        rmp_serde::to_vec_named(item)
97            .map(Into::into)
98            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
99    }
100}
101
// ============================================================================
// Transport Factory Functions
// ============================================================================
105
106use tokio::io::{AsyncRead, AsyncWrite};
107use tokio_util::codec::LengthDelimitedCodec;
108
109/// Create a framed transport with LengthDelimitedCodec.
110///
111/// Helper to reduce boilerplate when creating transports.
112fn framed<T: AsyncRead + AsyncWrite>(
113    stream: T,
114) -> tokio_util::codec::Framed<T, LengthDelimitedCodec> {
115    LengthDelimitedCodec::builder().new_framed(stream)
116}
117
118/// Create a tarpc transport with the correct codec for client-side use.
119///
120/// This is the recommended way to create a client transport. It ensures:
121/// - Named MessagePack serialization (compatible with `skip_serializing_if`)
122/// - Proper framing with `LengthDelimitedCodec`
123///
124/// # Example
125///
126/// ```ignore
127/// use ryo_app::codec::create_client_transport;
128/// use tokio::net::UnixStream;
129///
130/// let stream = UnixStream::connect(socket_path).await?;
131/// let transport = create_client_transport(stream);
132/// let client = RyoServiceClient::new(config, transport).spawn();
133/// ```
134pub fn create_client_transport<T: AsyncRead + AsyncWrite + Unpin>(
135    stream: T,
136) -> impl futures::Stream<
137    Item = Result<tarpc::Response<crate::service::RyoServiceResponse>, std::io::Error>,
138> + futures::Sink<
139    tarpc::ClientMessage<crate::service::RyoServiceRequest>,
140    Error = std::io::Error,
141> {
142    tarpc::serde_transport::new(framed(stream), MessagePackNamed::default())
143}
144
145/// Create a tarpc transport with the correct codec for server-side use.
146///
147/// This is the recommended way to create a server transport. It ensures:
148/// - Named MessagePack serialization (compatible with `skip_serializing_if`)
149/// - Proper framing with `LengthDelimitedCodec`
150///
151/// # Example
152///
153/// ```ignore
154/// use ryo_app::codec::create_server_transport;
155/// use tokio::net::UnixListener;
156///
157/// let (stream, _) = listener.accept().await?;
158/// let transport = create_server_transport(stream);
159/// let channel = tarpc::server::BaseChannel::with_defaults(transport);
160/// ```
161pub fn create_server_transport<T: AsyncRead + AsyncWrite + Unpin>(
162    stream: T,
163) -> impl futures::Stream<
164    Item = Result<tarpc::ClientMessage<crate::service::RyoServiceRequest>, std::io::Error>,
165> + futures::Sink<tarpc::Response<crate::service::RyoServiceResponse>, Error = std::io::Error> {
166    tarpc::serde_transport::new(framed(stream), MessagePackNamed::default())
167}
168
#[cfg(test)]
mod tests {
    use super::MessagePackNamed;
    use serde::{de::DeserializeOwned, Deserialize, Serialize};
    use std::pin::Pin;
    use tokio_util::bytes::BytesMut;

    /// A struct with an optional list that is omitted from the wire when empty.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        name: String,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        items: Vec<String>,
        #[serde(default)]
        count: usize,
    }

    /// Simulates the shape of `SuggestGenerateResponse`: a skippable list followed by plain
    /// fields.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Response {
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        patterns: Vec<String>,
        #[serde(default)]
        applied: bool,
        #[serde(default)]
        files_modified: usize,
    }

    /// Encode `value` and decode it again through the real `MessagePackNamed` codec
    /// (its `tokio_serde` trait impls), returning the wire bytes and the decoded value.
    fn roundtrip<T>(value: &T) -> (Vec<u8>, T)
    where
        T: Serialize + DeserializeOwned,
    {
        let mut codec = MessagePackNamed::<T, T>::default();
        let encoded = tokio_serde::Serializer::serialize(Pin::new(&mut codec), value)
            .expect("encoding with MessagePackNamed should succeed");
        let decoded = tokio_serde::Deserializer::deserialize(
            Pin::new(&mut codec),
            &BytesMut::from(&encoded[..]),
        )
        .expect("decoding with MessagePackNamed should succeed");
        (encoded.to_vec(), decoded)
    }

    /// Returns true if `needle` (non-empty) occurs anywhere in `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn test_messagepack_named_roundtrip() {
        let original = TestStruct {
            name: "test".to_string(),
            items: vec![], // Will be skipped during serialization
            count: 42,
        };

        let (encoded, decoded) = roundtrip(&original);

        // Named encoding: a MessagePack fixmap with two entries (`name`, `count`).
        assert_eq!(encoded[0], 0x82);
        assert!(contains_bytes(&encoded, b"name"));
        assert!(contains_bytes(&encoded, b"count"));
        assert!(!contains_bytes(&encoded, b"items"));
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_skip_serializing_if_with_named() {
        // List present: all three fields are encoded and survive the round trip.
        let response = Response {
            patterns: vec!["pattern1".to_string()],
            applied: false,
            files_modified: 0,
        };
        let (encoded, decoded) = roundtrip(&response);
        assert_eq!(encoded[0], 0x83);
        assert_eq!(decoded, response);

        // List empty: `patterns` is skipped on the wire and restored by `#[serde(default)]`.
        let empty_response = Response {
            patterns: vec![],
            applied: false,
            files_modified: 0,
        };
        let (encoded, decoded) = roundtrip(&empty_response);
        assert_eq!(encoded[0], 0x82);
        assert!(!contains_bytes(&encoded, b"patterns"));
        assert_eq!(decoded, empty_response);
    }

    #[test]
    fn test_array_based_encoding_breaks_skip_serializing_if() {
        // Documents why array-based encoding must not be used. With `patterns` skipped, the
        // first array element (`applied`) lands in the `patterns` slot and decoding fails.
        let empty_response = Response {
            patterns: vec![],
            applied: false,
            files_modified: 0,
        };
        let encoded = rmp_serde::to_vec(&empty_response).unwrap();
        assert!(rmp_serde::from_slice::<Response>(&encoded).is_err());
    }
}