zlink_core/connection/chain/
mod.rs

1//! Chain method calls.
2
3mod reply_stream;
4#[doc(hidden)]
5pub use reply_stream::ReplyStream;
6
7use crate::{connection::Socket, reply, Call, Connection, Result};
8use core::fmt::Debug;
9use futures_util::stream::Stream;
10use serde::{Deserialize, Serialize};
11
12/// A chain of method calls that will be sent together.
13///
14/// Each call must have the same method, reply, and error types for homogeneity. Use
15/// [`Connection::chain_call`] to create a new chain, extend it with [`Chain::append`] and send the
16/// the entire chain using [`Chain::send`].
17///
18/// With `std` feature enabled, this supports unlimited calls. Otherwise it is limited by how many
19/// calls can fit in our fixed-sized buffer.
20///
21/// Oneway calls (where `Call::oneway() == Some(true)`) do not expect replies and are handled
22/// automatically by the chain.
23#[derive(Debug)]
24pub struct Chain<'c, S: Socket, ReplyParams, ReplyError> {
25    pub(super) connection: &'c mut Connection<S>,
26    pub(super) call_count: usize,
27    pub(super) reply_count: usize,
28    _phantom: core::marker::PhantomData<(ReplyParams, ReplyError)>,
29}
30
31impl<'c, S, ReplyParams, ReplyError> Chain<'c, S, ReplyParams, ReplyError>
32where
33    S: Socket,
34    ReplyParams: Deserialize<'c> + Debug,
35    ReplyError: Deserialize<'c> + Debug,
36{
37    /// Create a new chain with the first call.
38    pub(super) fn new<Method>(
39        connection: &'c mut Connection<S>,
40        call: &Call<Method>,
41    ) -> Result<Self>
42    where
43        Method: Serialize + Debug,
44    {
45        connection.write.enqueue_call(call)?;
46        let reply_count = if call.oneway() { 0 } else { 1 };
47        Ok(Chain {
48            connection,
49            call_count: 1,
50            reply_count,
51            _phantom: core::marker::PhantomData,
52        })
53    }
54
55    /// Append another method call to the chain.
56    ///
57    /// The call will be enqueued but not sent until [`Chain::send`] is called. Note that one way
58    /// calls (where `Call::oneway() == Some(true)`) do not receive replies.
59    ///
60    /// Calls with `more == Some(true)` will stream multiple replies until a reply with
61    /// `continues != Some(true)` is received.
62    pub fn append<Method>(mut self, call: &Call<Method>) -> Result<Self>
63    where
64        Method: Serialize + Debug,
65    {
66        self.connection.write.enqueue_call(call)?;
67        if !call.oneway() {
68            self.reply_count += 1;
69        };
70        self.call_count += 1;
71        Ok(self)
72    }
73
74    /// Send all enqueued calls and return a replies stream.
75    ///
76    /// This will flush all enqueued calls in a single write operation and then return a stream
77    /// that allows reading the replies.
78    pub async fn send(
79        self,
80    ) -> Result<impl Stream<Item = Result<reply::Result<ReplyParams, ReplyError>>> + 'c>
81    where
82        ReplyParams: 'c,
83        ReplyError: 'c,
84    {
85        // Flush all enqueued calls.
86        self.connection.write.flush().await?;
87
88        Ok(ReplyStream::new(
89            self.connection.read_mut(),
90            |conn| conn.receive_reply::<ReplyParams, ReplyError>(),
91            self.reply_count,
92        ))
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Call;
    use futures_util::{pin_mut, stream::StreamExt};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct GetUser {
        id: u32,
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct User {
        id: u32,
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct ApiError {
        code: i32,
    }

    // Use consolidated mock socket from test_utils.
    use crate::test_utils::mock_socket::MockSocket;

    #[tokio::test]
    async fn homogeneous_calls() -> crate::Result<()> {
        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let call1 = Call::new(GetUser { id: 1 });
        let call2 = Call::new(GetUser { id: 2 });

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&call1)?
            .append(&call2)?
            .send()
            .await?;
        pin_mut!(replies);

        let user1 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user1.parameters().unwrap().id, 1);

        let user2 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user2.parameters().unwrap().id, 2);

        // No more replies should be available.
        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn different_method_types() -> crate::Result<()> {
        // `append` is generic over the method type, so calls in a chain need not share it.
        #[derive(Debug, Serialize, Deserialize)]
        struct LookupUser {
            user_id: u32,
        }

        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let get_user = Call::new(GetUser { id: 1 });
        let lookup_user = Call::new(LookupUser { user_id: 2 });

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&get_user)?
            .append(&lookup_user)?
            .send()
            .await?;
        pin_mut!(replies);

        let user1 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user1.parameters().unwrap().id, 1);

        let user2 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user2.parameters().unwrap().id, 2);

        // No more replies should be available.
        assert!(replies.next().await.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn oneway_calls_no_reply() -> crate::Result<()> {
        // Only the first call expects a reply; the second is oneway.
        let responses = [r#"{"parameters":{"id":1}}"#];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let get_user = Call::new(GetUser { id: 1 });
        let oneway_call = Call::new(GetUser { id: 2 }).set_oneway(true);

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&get_user)?
            .append(&oneway_call)?
            .send()
            .await?;
        pin_mut!(replies);

        let user = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user.parameters().unwrap().id, 1);

        // No more replies should be available.
        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn oneway_first_call() -> crate::Result<()> {
        // The chain starts with a oneway call, so the only reply belongs to the second call.
        let responses = [r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let oneway_call = Call::new(GetUser { id: 1 }).set_oneway(true);
        let get_user = Call::new(GetUser { id: 2 });

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&oneway_call)?
            .append(&get_user)?
            .send()
            .await?;
        pin_mut!(replies);

        let user = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user.parameters().unwrap().id, 2);

        // No more replies should be available.
        assert!(replies.next().await.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn more_calls_with_streaming() -> crate::Result<()> {
        let responses = [
            r#"{"parameters":{"id":1},"continues":true}"#,
            r#"{"parameters":{"id":2},"continues":true}"#,
            r#"{"parameters":{"id":3},"continues":false}"#,
            r#"{"parameters":{"id":4}}"#,
        ];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let more_call = Call::new(GetUser { id: 1 }).set_more(true);
        let regular_call = Call::new(GetUser { id: 2 });

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&more_call)?
            .append(&regular_call)?
            .send()
            .await?;
        pin_mut!(replies);

        // First call - streaming replies
        let user1 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user1.parameters().unwrap().id, 1);
        assert_eq!(user1.continues(), Some(true));

        let user2 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user2.parameters().unwrap().id, 2);
        assert_eq!(user2.continues(), Some(true));

        let user3 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user3.parameters().unwrap().id, 3);
        assert_eq!(user3.continues(), Some(false));

        // Second call - single reply
        let user4 = replies.next().await.unwrap()?.unwrap();
        assert_eq!(user4.parameters().unwrap().id, 4);
        assert_eq!(user4.continues(), None);

        // No more replies should be available.
        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn stream_interface_works() -> crate::Result<()> {
        let responses = [
            r#"{"parameters":{"id":1}}"#,
            r#"{"parameters":{"id":2}}"#,
            r#"{"parameters":{"id":3}}"#,
        ];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let call1 = Call::new(GetUser { id: 1 });
        let call2 = Call::new(GetUser { id: 2 });
        let call3 = Call::new(GetUser { id: 3 });

        let replies = conn
            .chain_call::<GetUser, User, ApiError>(&call1)?
            .append(&call2)?
            .append(&call3)?
            .send()
            .await?;

        // Use Stream's collect method to gather all results
        pin_mut!(replies);
        let results: mayheap::Vec<_, 16> = replies.collect().await;
        assert_eq!(results.len(), 3);

        // Verify all results are successful
        for (i, result) in results.into_iter().enumerate() {
            let user = result?.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }

        Ok(())
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn heterogeneous_calls() -> crate::Result<()> {
        // Types for heterogeneous calls test
        #[derive(Debug, Serialize, Deserialize)]
        #[serde(tag = "method")]
        enum HeterogeneousMethods {
            GetUser { id: u32 },
            GetPost { post_id: u32 },
            DeleteUser { user_id: u32 },
        }

        #[derive(Debug, Serialize, Deserialize)]
        #[serde(untagged)]
        enum HeterogeneousResponses {
            Post(Post),
            User(User),
            DeleteResult(DeleteResult),
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct DeleteResult {
            success: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Post {
            id: u32,
            title: mayheap::String<32>,
        }

        #[derive(Debug, Serialize, Deserialize)]
        #[serde(untagged)]
        enum HeterogeneousErrors {
            UserError(ApiError),
            PostError(PostError),
            DeleteError(DeleteError),
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct DeleteError {
            reason: mayheap::String<64>,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct PostError {
            message: mayheap::String<64>,
        }

        let responses = [
            r#"{"parameters":{"id":1}}"#,
            r#"{"parameters":{"id":123,"title":"Test Post"}}"#,
            r#"{"parameters":{"success":true}}"#,
        ];
        let socket = MockSocket::new(&responses);
        let mut conn = Connection::new(socket);

        let get_user_call = Call::new(HeterogeneousMethods::GetUser { id: 1 });
        let get_post_call = Call::new(HeterogeneousMethods::GetPost { post_id: 123 });
        let delete_user_call = Call::new(HeterogeneousMethods::DeleteUser { user_id: 456 });

        let replies = conn
            .chain_call::<HeterogeneousMethods, HeterogeneousResponses, HeterogeneousErrors>(
                &get_user_call,
            )?
            .append(&get_post_call)?
            .append(&delete_user_call)?
            .send()
            .await?;
        pin_mut!(replies);

        // First response: User
        let user_response = replies.next().await.unwrap()?.unwrap();
        if let HeterogeneousResponses::User(user) = user_response.parameters().unwrap() {
            assert_eq!(user.id, 1);
        } else {
            panic!("Expected User response");
        }

        // Second response: Post
        let post_response = replies.next().await.unwrap()?.unwrap();
        if let HeterogeneousResponses::Post(post) = post_response.parameters().unwrap() {
            assert_eq!(post.id, 123);
            assert_eq!(post.title, "Test Post");
        } else {
            panic!("Expected Post response");
        }

        // Third response: DeleteResult
        let delete_response = replies.next().await.unwrap()?.unwrap();
        if let HeterogeneousResponses::DeleteResult(result) = delete_response.parameters().unwrap()
        {
            assert!(result.success);
        } else {
            panic!("Expected DeleteResult response");
        }

        // No more replies should be available.
        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }
}