1mod reply_stream;
4pub use reply_stream::ReplyStream;
5
6use crate::{connection::Socket, Call, Connection, Result};
7use core::fmt::Debug;
8use serde::{de::DeserializeOwned, Serialize};
9
10#[derive(Debug)]
21pub struct Chain<'c, S: Socket> {
22 pub(super) connection: &'c mut Connection<S>,
23 pub(super) call_count: usize,
24 pub(super) reply_count: usize,
25}
26
27impl<'c, S> Chain<'c, S>
28where
29 S: Socket,
30{
31 pub(super) fn new<Method>(
33 connection: &'c mut Connection<S>,
34 call: &Call<Method>,
35 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
36 ) -> Result<Self>
37 where
38 Method: Serialize + Debug,
39 {
40 #[cfg(feature = "std")]
41 connection.write.enqueue_call(call, fds)?;
42 #[cfg(not(feature = "std"))]
43 connection.write.enqueue_call(call)?;
44
45 let reply_count = if call.oneway() { 0 } else { 1 };
46 Ok(Chain {
47 connection,
48 call_count: 1,
49 reply_count,
50 })
51 }
52
53 pub fn append<Method>(
63 mut self,
64 call: &Call<Method>,
65 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
66 ) -> Result<Self>
67 where
68 Method: Serialize + Debug,
69 {
70 #[cfg(feature = "std")]
71 self.connection.write.enqueue_call(call, fds)?;
72 #[cfg(not(feature = "std"))]
73 self.connection.write.enqueue_call(call)?;
74
75 if !call.oneway() {
76 self.reply_count += 1;
77 };
78 self.call_count += 1;
79 Ok(self)
80 }
81
82 pub async fn send<ReplyParams, ReplyError>(
89 self,
90 ) -> Result<ReplyStream<'c, ReplyParams, ReplyError>>
91 where
92 ReplyParams: DeserializeOwned + Debug,
93 ReplyError: DeserializeOwned + Debug,
94 {
95 self.connection.write.flush().await?;
97
98 Ok(ReplyStream::new(
99 self.connection.read_mut(),
100 self.reply_count,
101 ))
102 }
103
104 pub async fn send_ignore_replies(self) -> Result<()> {
111 use futures_util::StreamExt;
112 use serde::de::IgnoredAny;
113
114 let replies = self.send::<IgnoredAny, IgnoredAny>().await?;
115 futures_util::pin_mut!(replies);
116
117 while let Some(result) = replies.next().await {
118 let _ = result?;
119 }
120
121 Ok(())
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Call;
    use futures_util::pin_mut;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct GetUser {
        id: u32,
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct User {
        id: u32,
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct ApiError {
        code: i32,
    }

    use crate::test_utils::mock_socket::MockSocket;

    /// Two ordinary calls yield exactly two replies, in call order, then end.
    #[tokio::test]
    async fn homogeneous_calls() -> crate::Result<()> {
        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let call1 = Call::new(GetUser { id: 1 });
        let call2 = Call::new(GetUser { id: 2 });

        #[cfg(feature = "std")]
        let replies = conn
            .chain_call::<GetUser>(&call1, vec![])?
            .append(&call2, vec![])?
            .send::<User, ApiError>()
            .await?;
        #[cfg(not(feature = "std"))]
        let replies = conn
            .chain_call::<GetUser>(&call1)?
            .append(&call2)?
            .send::<User, ApiError>()
            .await?;

        use futures_util::stream::StreamExt;
        pin_mut!(replies);

        #[cfg(feature = "std")]
        {
            let (user1, _fds) = replies.next().await.unwrap()?;
            let user1 = user1.unwrap();
            assert_eq!(user1.parameters().unwrap().id, 1);

            let (user2, _fds) = replies.next().await.unwrap()?;
            let user2 = user2.unwrap();
            assert_eq!(user2.parameters().unwrap().id, 2);
        }
        #[cfg(not(feature = "std"))]
        {
            let user1 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user1.parameters().unwrap().id, 1);

            let user2 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user2.parameters().unwrap().id, 2);
        }

        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    /// A one-way call in the chain does not add a reply to the stream.
    #[tokio::test]
    async fn oneway_calls_no_reply() -> crate::Result<()> {
        let responses = [r#"{"parameters":{"id":1}}"#];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let get_user = Call::new(GetUser { id: 1 });
        let oneway_call = Call::new(GetUser { id: 2 }).set_oneway(true);

        #[cfg(feature = "std")]
        let replies = conn
            .chain_call::<GetUser>(&get_user, vec![])?
            .append(&oneway_call, vec![])?
            .send::<User, ApiError>()
            .await?;
        #[cfg(not(feature = "std"))]
        let replies = conn
            .chain_call::<GetUser>(&get_user)?
            .append(&oneway_call)?
            .send::<User, ApiError>()
            .await?;

        use futures_util::stream::StreamExt;
        pin_mut!(replies);

        #[cfg(feature = "std")]
        {
            let (user, _fds) = replies.next().await.unwrap()?;
            let user = user.unwrap();
            assert_eq!(user.parameters().unwrap().id, 1);
        }
        #[cfg(not(feature = "std"))]
        {
            let user = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user.parameters().unwrap().id, 1);
        }

        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    /// A `more` call may yield several replies (`continues: true`) before the next call's reply.
    #[tokio::test]
    async fn more_calls_with_streaming() -> crate::Result<()> {
        let responses = [
            r#"{"parameters":{"id":1},"continues":true}"#,
            r#"{"parameters":{"id":2},"continues":true}"#,
            r#"{"parameters":{"id":3},"continues":false}"#,
            r#"{"parameters":{"id":4}}"#,
        ];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let more_call = Call::new(GetUser { id: 1 }).set_more(true);
        let regular_call = Call::new(GetUser { id: 2 });

        #[cfg(feature = "std")]
        let replies = conn
            .chain_call::<GetUser>(&more_call, vec![])?
            .append(&regular_call, vec![])?
            .send::<User, ApiError>()
            .await?;
        #[cfg(not(feature = "std"))]
        let replies = conn
            .chain_call::<GetUser>(&more_call)?
            .append(&regular_call)?
            .send::<User, ApiError>()
            .await?;

        use futures_util::stream::StreamExt;
        pin_mut!(replies);

        #[cfg(feature = "std")]
        {
            let (user1, _fds) = replies.next().await.unwrap()?;
            let user1 = user1.unwrap();
            assert_eq!(user1.parameters().unwrap().id, 1);
            assert_eq!(user1.continues(), Some(true));

            let (user2, _fds) = replies.next().await.unwrap()?;
            let user2 = user2.unwrap();
            assert_eq!(user2.parameters().unwrap().id, 2);
            assert_eq!(user2.continues(), Some(true));

            let (user3, _fds) = replies.next().await.unwrap()?;
            let user3 = user3.unwrap();
            assert_eq!(user3.parameters().unwrap().id, 3);
            assert_eq!(user3.continues(), Some(false));

            let (user4, _fds) = replies.next().await.unwrap()?;
            let user4 = user4.unwrap();
            assert_eq!(user4.parameters().unwrap().id, 4);
            assert_eq!(user4.continues(), None);
        }
        #[cfg(not(feature = "std"))]
        {
            let user1 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user1.parameters().unwrap().id, 1);
            assert_eq!(user1.continues(), Some(true));

            let user2 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user2.parameters().unwrap().id, 2);
            assert_eq!(user2.continues(), Some(true));

            let user3 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user3.parameters().unwrap().id, 3);
            assert_eq!(user3.continues(), Some(false));

            let user4 = replies.next().await.unwrap()?.unwrap();
            assert_eq!(user4.parameters().unwrap().id, 4);
            assert_eq!(user4.continues(), None);
        }

        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    /// The reply stream can be consumed with ordinary `Stream` combinators such as `collect`.
    #[tokio::test]
    async fn stream_interface_works() -> crate::Result<()> {
        use futures_util::stream::StreamExt;

        let responses = [
            r#"{"parameters":{"id":1}}"#,
            r#"{"parameters":{"id":2}}"#,
            r#"{"parameters":{"id":3}}"#,
        ];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let call1 = Call::new(GetUser { id: 1 });
        let call2 = Call::new(GetUser { id: 2 });
        let call3 = Call::new(GetUser { id: 3 });

        #[cfg(feature = "std")]
        let replies = conn
            .chain_call::<GetUser>(&call1, vec![])?
            .append(&call2, vec![])?
            .append(&call3, vec![])?
            .send::<User, ApiError>()
            .await?;
        #[cfg(not(feature = "std"))]
        let replies = conn
            .chain_call::<GetUser>(&call1)?
            .append(&call2)?
            .append(&call3)?
            .send::<User, ApiError>()
            .await?;

        pin_mut!(replies);
        let results: Vec<_> = replies.collect().await;
        assert_eq!(results.len(), 3);

        #[cfg(feature = "std")]
        for (i, result) in results.into_iter().enumerate() {
            let (reply, _fds) = result?;
            let user = reply.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }
        #[cfg(not(feature = "std"))]
        for (i, result) in results.into_iter().enumerate() {
            let user = result?.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }

        Ok(())
    }

    /// Calls of different methods in one chain decode into an untagged response enum.
    #[tokio::test]
    async fn heterogeneous_calls() -> crate::Result<()> {
        #[derive(Debug, Serialize, Deserialize)]
        #[serde(tag = "method")]
        enum HeterogeneousMethods {
            GetUser { id: u32 },
            GetPost { post_id: u32 },
            DeleteUser { user_id: u32 },
        }

        #[derive(Debug, Serialize, Deserialize)]
        #[serde(untagged)]
        enum HeterogeneousResponses {
            Post(Post),
            User(User),
            DeleteResult(DeleteResult),
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct DeleteResult {
            success: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Post {
            id: u32,
            title: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        #[serde(untagged)]
        enum HeterogeneousErrors {
            UserError(ApiError),
            PostError(PostError),
            DeleteError(DeleteError),
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct DeleteError {
            reason: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct PostError {
            message: String,
        }

        let responses = [
            r#"{"parameters":{"id":1}}"#,
            r#"{"parameters":{"id":123,"title":"Test Post"}}"#,
            r#"{"parameters":{"success":true}}"#,
        ];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let get_user_call = Call::new(HeterogeneousMethods::GetUser { id: 1 });
        let get_post_call = Call::new(HeterogeneousMethods::GetPost { post_id: 123 });
        let delete_user_call = Call::new(HeterogeneousMethods::DeleteUser { user_id: 456 });

        #[cfg(feature = "std")]
        let replies = conn
            .chain_call::<HeterogeneousMethods>(&get_user_call, vec![])?
            .append(&get_post_call, vec![])?
            .append(&delete_user_call, vec![])?
            .send::<HeterogeneousResponses, HeterogeneousErrors>()
            .await?;
        #[cfg(not(feature = "std"))]
        let replies = conn
            .chain_call::<HeterogeneousMethods>(&get_user_call)?
            .append(&get_post_call)?
            .append(&delete_user_call)?
            .send::<HeterogeneousResponses, HeterogeneousErrors>()
            .await?;

        use futures_util::stream::StreamExt;
        pin_mut!(replies);

        #[cfg(feature = "std")]
        {
            let (user_response, _fds) = replies.next().await.unwrap()?;
            let user_response = user_response.unwrap();
            if let HeterogeneousResponses::User(user) = user_response.parameters().unwrap() {
                assert_eq!(user.id, 1);
            } else {
                panic!("Expected User response");
            }

            let (post_response, _fds) = replies.next().await.unwrap()?;
            let post_response = post_response.unwrap();
            if let HeterogeneousResponses::Post(post) = post_response.parameters().unwrap() {
                assert_eq!(post.id, 123);
                assert_eq!(post.title, "Test Post");
            } else {
                panic!("Expected Post response");
            }

            let (delete_response, _fds) = replies.next().await.unwrap()?;
            let delete_response = delete_response.unwrap();
            if let HeterogeneousResponses::DeleteResult(result) =
                delete_response.parameters().unwrap()
            {
                assert!(result.success);
            } else {
                panic!("Expected DeleteResult response");
            }
        }
        #[cfg(not(feature = "std"))]
        {
            let user_response = replies.next().await.unwrap()?.unwrap();
            if let HeterogeneousResponses::User(user) = user_response.parameters().unwrap() {
                assert_eq!(user.id, 1);
            } else {
                panic!("Expected User response");
            }

            let post_response = replies.next().await.unwrap()?.unwrap();
            if let HeterogeneousResponses::Post(post) = post_response.parameters().unwrap() {
                assert_eq!(post.id, 123);
                assert_eq!(post.title, "Test Post");
            } else {
                panic!("Expected Post response");
            }

            let delete_response = replies.next().await.unwrap()?.unwrap();
            if let HeterogeneousResponses::DeleteResult(result) =
                delete_response.parameters().unwrap()
            {
                assert!(result.success);
            } else {
                panic!("Expected DeleteResult response");
            }
        }

        let no_reply = replies.next().await;
        assert!(no_reply.is_none());
        Ok(())
    }

    /// A chain built from an iterator of method values yields one reply per method, in order.
    #[tokio::test]
    async fn chain_from_iter() -> crate::Result<()> {
        use futures_util::stream::StreamExt;

        let responses = [
            r#"{"parameters":{"id":1}}"#,
            r#"{"parameters":{"id":2}}"#,
            r#"{"parameters":{"id":3}}"#,
        ];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let replies = conn
            .chain_from_iter::<GetUser, _, _>((1..=3).map(|id| GetUser { id }))?
            .send::<User, ApiError>()
            .await?;

        pin_mut!(replies);
        let results: Vec<_> = replies.collect().await;
        assert_eq!(results.len(), 3);

        #[cfg(feature = "std")]
        for (i, result) in results.into_iter().enumerate() {
            let (reply, _fds) = result?;
            let user = reply.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }
        #[cfg(not(feature = "std"))]
        for (i, result) in results.into_iter().enumerate() {
            let user = result?.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }

        Ok(())
    }

    /// A chain built from an iterator of `Call`s yields one reply per call, in order.
    #[tokio::test]
    async fn chain_from_iter_with_calls() -> crate::Result<()> {
        use futures_util::stream::StreamExt;

        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let calls = vec![Call::new(GetUser { id: 1 }), Call::new(GetUser { id: 2 })];

        let replies = conn
            .chain_from_iter::<GetUser, _, _>(calls)?
            .send::<User, ApiError>()
            .await?;

        pin_mut!(replies);
        let results: Vec<_> = replies.collect().await;
        assert_eq!(results.len(), 2);

        // Check contents and ordering too, not just the number of replies.
        #[cfg(feature = "std")]
        for (i, result) in results.into_iter().enumerate() {
            let (reply, _fds) = result?;
            let user = reply.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }
        #[cfg(not(feature = "std"))]
        for (i, result) in results.into_iter().enumerate() {
            let user = result?.unwrap();
            assert_eq!(user.parameters().unwrap().id, (i + 1) as u32);
        }

        Ok(())
    }

    /// Building a chain from an empty iterator is rejected with `Error::EmptyChain`.
    #[tokio::test]
    async fn chain_from_empty_iter_fails() -> crate::Result<()> {
        let socket = MockSocket::with_responses(&[]);
        let mut conn = Connection::new(socket);

        let methods: Vec<GetUser> = vec![];

        let result = conn.chain_from_iter::<GetUser, _, _>(methods);

        assert!(matches!(result, Err(crate::Error::EmptyChain)));
        Ok(())
    }

    /// File descriptors attached to each call are written alongside that call.
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn chain_from_iter_with_fds() -> crate::Result<()> {
        use crate::{
            connection::socket::{ReadHalf, WriteHalf},
            test_utils::mock_socket::MockWriteHalf,
        };
        use futures_util::stream::StreamExt;
        use rustix::{fd::AsFd, io::write};
        use std::os::unix::net::UnixStream;

        // Each pair's read end is sent as an FD; the written marker identifies it later.
        let (send1_r, send1_w) = UnixStream::pair().unwrap();
        let (send2_r, send2_w) = UnixStream::pair().unwrap();
        write(send1_w.as_fd(), b"send1").unwrap();
        write(send2_w.as_fd(), b"send2").unwrap();

        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::new(&responses, vec![]);
        let (read_half, write_half) = socket.split();

        #[derive(Debug)]
        struct TrackingSocket<R, W> {
            read: R,
            write: W,
        }

        impl<R: ReadHalf, W: WriteHalf> crate::connection::Socket for TrackingSocket<R, W> {
            type ReadHalf = R;
            type WriteHalf = W;

            fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
                (self.read, self.write)
            }
        }

        #[derive(Debug)]
        struct TrackingWriteHalf {
            mock: MockWriteHalf,
        }

        impl WriteHalf for TrackingWriteHalf {
            async fn write(&mut self, buf: &[u8], fds: &[impl AsFd]) -> crate::Result<()> {
                self.mock.write(buf, fds).await
            }
        }

        let tracking_write = TrackingWriteHalf { mock: write_half };
        let mut conn = Connection::new(TrackingSocket {
            read: read_half,
            write: tracking_write,
        });

        let calls_with_fds: Vec<(GetUser, Vec<std::os::fd::OwnedFd>)> = vec![
            (GetUser { id: 1 }, vec![send1_r.into()]),
            (GetUser { id: 2 }, vec![send2_r.into()]),
        ];

        let replies = conn
            .chain_from_iter_with_fds::<GetUser, _, _>(calls_with_fds)?
            .send::<User, ApiError>()
            .await?;

        // Scope the pinned stream so its borrow of `conn` ends before `write_mut` below.
        let reply_results: Vec<_> = {
            pin_mut!(replies);
            replies.collect().await
        };

        let fds_written = conn.write_mut().socket.mock.fds_written();
        assert_eq!(fds_written.len(), 2, "Should have written FDs twice");
        assert_eq!(fds_written[0].len(), 1, "First call should send 1 FD");
        assert_eq!(fds_written[1].len(), 1, "Second call should send 1 FD");

        let mut buf = [0u8; 5];
        rustix::io::read(fds_written[0][0].as_fd(), &mut buf).unwrap();
        assert_eq!(&buf, b"send1");
        rustix::io::read(fds_written[1][0].as_fd(), &mut buf).unwrap();
        assert_eq!(&buf, b"send2");

        assert_eq!(reply_results.len(), 2);
        let (reply1, _) = reply_results[0].as_ref().unwrap();
        assert_eq!(reply1.as_ref().unwrap().parameters().unwrap().id, 1);
        let (reply2, _) = reply_results[1].as_ref().unwrap();
        assert_eq!(reply2.as_ref().unwrap().parameters().unwrap().id, 2);

        Ok(())
    }

    /// `send_ignore_replies` drains all replies and succeeds.
    #[tokio::test]
    async fn ignore_replies() -> crate::Result<()> {
        let responses = [r#"{"parameters":{"id":1}}"#, r#"{"parameters":{"id":2}}"#];
        let socket = MockSocket::with_responses(&responses);
        let mut conn = Connection::new(socket);

        let call1 = Call::new(GetUser { id: 1 });
        let call2 = Call::new(GetUser { id: 2 });

        #[cfg(feature = "std")]
        conn.chain_call::<GetUser>(&call1, vec![])?
            .append(&call2, vec![])?
            .send_ignore_replies()
            .await?;
        #[cfg(not(feature = "std"))]
        conn.chain_call::<GetUser>(&call1)?
            .append(&call2)?
            .send_ignore_replies()
            .await?;

        Ok(())
    }
}