sl_mpc_mate/coord/buffered.rs

1// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
2// This software is licensed under the Silence Laboratories License Agreement.
3
4use std::{
5    ops::{Deref, DerefMut},
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use crate::coord::*;
11
12pub struct BufferedMsgRelay<R: Relay> {
13    relay: R,
14    in_buf: Vec<Vec<u8>>,
15}
16
17impl<R: Relay> BufferedMsgRelay<R> {
18    /// Construct a BufferedMsgRelay by wrapping up a Relay object
19    pub fn new(relay: R) -> Self {
20        Self {
21            relay,
22            in_buf: vec![],
23        }
24    }
25
26    pub fn with_capacity(relay: R, capacity: usize) -> Self {
27        Self {
28            relay,
29            in_buf: Vec::with_capacity(capacity),
30        }
31    }
32
33    /// Wait for particular messages based on predicate.
34    pub async fn wait_for(
35        &mut self,
36        predicate: impl Fn(&MsgId) -> bool,
37    ) -> Option<Vec<u8>> {
38        // first, look into the input buffer
39        if let Some(idx) = self.in_buf.iter().position(|msg| {
40            <&MsgHdr>::try_from(msg.as_slice())
41                .ok()
42                .filter(|hdr| predicate(hdr.id()))
43                .is_some()
44        }) {
45            // there is a buffered message matching the predicate.
46            return Some(self.in_buf.swap_remove(idx));
47        }
48
49        // flush output message messages.
50        self.relay.flush().await.ok()?;
51
52        loop {
53            let msg = self.relay.next().await?;
54
55            if let Ok(hdr) = <&MsgHdr>::try_from(msg.as_slice()) {
56                if predicate(hdr.id()) {
57                    // good, return it
58                    return Some(msg);
59                } else {
60                    // push into the buffer and try again
61                    self.in_buf.push(msg);
62                }
63            }
64        }
65    }
66
67    /// Function to receive message based on certain ID
68    pub async fn recv(&mut self, id: &MsgId, ttl: u32) -> Option<Vec<u8>> {
69        self.relay.ask(id, ttl).await.ok()?;
70        self.wait_for(|msg| msg.eq(id)).await
71    }
72
73    /// Return all buffered messages
74    pub fn buffered(&self) -> impl Iterator<Item = &[u8]> {
75        self.in_buf.iter().map(|m| m.as_ref())
76    }
77
78    /// Return all buffered messages and allow change
79    pub fn buffered_mut(&mut self) -> impl Iterator<Item = &mut [u8]> {
80        self.in_buf.iter_mut().map(|m| m.as_mut())
81    }
82}
83
84impl<R: Relay> Stream for BufferedMsgRelay<R> {
85    type Item = R::Item;
86
87    fn poll_next(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90    ) -> Poll<Option<Self::Item>> {
91        let this = self.get_mut();
92        if let Some(msg) = this.in_buf.pop() {
93            Poll::Ready(Some(msg))
94        } else {
95            this.relay.poll_next_unpin(cx)
96        }
97    }
98}
99
100impl<R: Relay> Sink<Vec<u8>> for BufferedMsgRelay<R> {
101    type Error = R::Error;
102
103    fn poll_ready(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106    ) -> Poll<Result<(), Self::Error>> {
107        self.get_mut().relay.poll_ready_unpin(cx)
108    }
109
110    fn start_send(
111        self: Pin<&mut Self>,
112        item: Vec<u8>,
113    ) -> Result<(), Self::Error> {
114        self.get_mut().relay.start_send_unpin(item)
115    }
116
117    fn poll_flush(
118        self: Pin<&mut Self>,
119        cx: &mut Context<'_>,
120    ) -> Poll<Result<(), Self::Error>> {
121        self.get_mut().relay.poll_flush_unpin(cx)
122    }
123
124    fn poll_close(
125        self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127    ) -> Poll<Result<(), Self::Error>> {
128        self.get_mut().relay.poll_close_unpin(cx)
129    }
130}
131
132impl<R: Relay> Relay for BufferedMsgRelay<R> {}
133
134impl<R: Relay> Deref for BufferedMsgRelay<R> {
135    type Target = R;
136
137    fn deref(&self) -> &Self::Target {
138        &self.relay
139    }
140}
141
142impl<R: Relay> DerefMut for BufferedMsgRelay<R> {
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        &mut self.relay
145    }
146}