sl_mpc_mate/coord/
buffered.rs1use std::{
5 ops::{Deref, DerefMut},
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use crate::coord::*;
11
/// A relay wrapper that keeps incoming messages which arrived while waiting
/// for a different message, so they can be handed out later.
///
/// The `R: Relay` bound is placed on the `impl` blocks rather than on the
/// struct itself, so the type can be named without repeating the bound.
pub struct BufferedMsgRelay<R> {
    /// The wrapped relay; all sending and receiving goes through it.
    relay: R,
    /// Received messages that have not been handed out yet.
    in_buf: Vec<Vec<u8>>,
}
16
17impl<R: Relay> BufferedMsgRelay<R> {
18 pub fn new(relay: R) -> Self {
20 Self {
21 relay,
22 in_buf: vec![],
23 }
24 }
25
26 pub fn with_capacity(relay: R, capacity: usize) -> Self {
27 Self {
28 relay,
29 in_buf: Vec::with_capacity(capacity),
30 }
31 }
32
33 pub async fn wait_for(
35 &mut self,
36 predicate: impl Fn(&MsgId) -> bool,
37 ) -> Option<Vec<u8>> {
38 if let Some(idx) = self.in_buf.iter().position(|msg| {
40 <&MsgHdr>::try_from(msg.as_slice())
41 .ok()
42 .filter(|hdr| predicate(hdr.id()))
43 .is_some()
44 }) {
45 return Some(self.in_buf.swap_remove(idx));
47 }
48
49 self.relay.flush().await.ok()?;
51
52 loop {
53 let msg = self.relay.next().await?;
54
55 if let Ok(hdr) = <&MsgHdr>::try_from(msg.as_slice()) {
56 if predicate(hdr.id()) {
57 return Some(msg);
59 } else {
60 self.in_buf.push(msg);
62 }
63 }
64 }
65 }
66
67 pub async fn recv(&mut self, id: &MsgId, ttl: u32) -> Option<Vec<u8>> {
69 self.relay.ask(id, ttl).await.ok()?;
70 self.wait_for(|msg| msg.eq(id)).await
71 }
72
73 pub fn buffered(&self) -> impl Iterator<Item = &[u8]> {
75 self.in_buf.iter().map(|m| m.as_ref())
76 }
77
78 pub fn buffered_mut(&mut self) -> impl Iterator<Item = &mut [u8]> {
80 self.in_buf.iter_mut().map(|m| m.as_mut())
81 }
82}
83
84impl<R: Relay> Stream for BufferedMsgRelay<R> {
85 type Item = R::Item;
86
87 fn poll_next(
88 self: Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 ) -> Poll<Option<Self::Item>> {
91 let this = self.get_mut();
92 if let Some(msg) = this.in_buf.pop() {
93 Poll::Ready(Some(msg))
94 } else {
95 this.relay.poll_next_unpin(cx)
96 }
97 }
98}
99
100impl<R: Relay> Sink<Vec<u8>> for BufferedMsgRelay<R> {
101 type Error = R::Error;
102
103 fn poll_ready(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 ) -> Poll<Result<(), Self::Error>> {
107 self.get_mut().relay.poll_ready_unpin(cx)
108 }
109
110 fn start_send(
111 self: Pin<&mut Self>,
112 item: Vec<u8>,
113 ) -> Result<(), Self::Error> {
114 self.get_mut().relay.start_send_unpin(item)
115 }
116
117 fn poll_flush(
118 self: Pin<&mut Self>,
119 cx: &mut Context<'_>,
120 ) -> Poll<Result<(), Self::Error>> {
121 self.get_mut().relay.poll_flush_unpin(cx)
122 }
123
124 fn poll_close(
125 self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 ) -> Poll<Result<(), Self::Error>> {
128 self.get_mut().relay.poll_close_unpin(cx)
129 }
130}
131
132impl<R: Relay> Relay for BufferedMsgRelay<R> {}
133
134impl<R: Relay> Deref for BufferedMsgRelay<R> {
135 type Target = R;
136
137 fn deref(&self) -> &Self::Target {
138 &self.relay
139 }
140}
141
142impl<R: Relay> DerefMut for BufferedMsgRelay<R> {
143 fn deref_mut(&mut self) -> &mut Self::Target {
144 &mut self.relay
145 }
146}