// tx5_connection/framed.rs

1use crate::*;
2
3enum Cmd {
4    Recv(Vec<u8>),
5    AwaitPermit {
6        await_registered: tokio::sync::oneshot::Sender<()>,
7        got_permit: tokio::sync::oneshot::Sender<()>,
8    },
9    RemotePermit(tokio::sync::OwnedSemaphorePermit, u32),
10    Close,
11}
12
13/// Receive a framed message on the connection.
14pub struct FramedConnRecv(tokio::sync::mpsc::Receiver<Vec<u8>>);
15
16impl FramedConnRecv {
17    /// Receive a framed message on the connection.
18    pub async fn recv(&mut self) -> Option<Vec<u8>> {
19        self.0.recv().await
20    }
21}
22
/// A framed wrapper that can send and receive larger messages than
/// the base connection.
///
/// If a message fits within the frame limit, it is just sent as-is.
///
/// If a message is OVER the frame limit, we first send the remote peer a
/// permit request for it, and only begin sending the frames once the peer
/// grants that permit.
///
/// This allows individual peers to throttle the amount of pending-completion
/// message memory they are allocating by only issuing permits up to a
/// configured memory threshold.
///
/// Dropping a `FramedConn` aborts its background receive and command tasks.
pub struct FramedConn {
    /// Public key of the remote peer, as reported by the wrapped `Conn`.
    pub_key: PubKey,
    /// Non-owning handle to the connection, used for read-only queries
    /// (`is_using_webrtc`, `get_stats`) without taking the send lock.
    weak_conn: Weak<Conn>,
    /// Owning handle to the connection. The async mutex is held for the
    /// entirety of each send, so sends (including chunked sends that wait
    /// for a permit) are serialized.
    conn: tokio::sync::Mutex<Arc<Conn>>,
    /// Channel into the command task; senders use it to register that they
    /// are awaiting a permit from the remote peer.
    cmd_send: tokio::sync::mpsc::Sender<Cmd>,
    /// Task forwarding raw messages from the connection into the command task.
    recv_task: tokio::task::JoinHandle<()>,
    /// Task running the main event loop (frame decoding, permit bookkeeping).
    cmd_task: tokio::task::JoinHandle<()>,
}
43
44impl Drop for FramedConn {
45    fn drop(&mut self) {
46        self.recv_task.abort();
47        self.cmd_task.abort();
48    }
49}
50
51impl FramedConn {
52    /// Construct a new framed wrapper around the base connection.
53    pub async fn new(
54        conn: Arc<Conn>,
55        mut conn_recv: ConnRecv,
56        recv_limit: Arc<tokio::sync::Semaphore>,
57    ) -> Result<(Self, FramedConnRecv)> {
58        conn.ready().await;
59
60        // send the protocol header
61        let (a, b, c, d) = crate::proto::PROTO_VER_2.encode()?;
62        conn.send(vec![a, b, c, d]).await?;
63
64        let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32);
65        let (msg_send, msg_recv) = tokio::sync::mpsc::channel(32);
66
67        // set up the receive to just feed straight into the cmd task
68        let cmd_send2 = cmd_send.clone();
69        let recv_task = tokio::task::spawn(async move {
70            while let Some(msg) = conn_recv.recv().await {
71                if cmd_send2.send(Cmd::Recv(msg)).await.is_err() {
72                    break;
73                }
74            }
75
76            let _ = cmd_send2.send(Cmd::Close).await;
77        });
78
79        let pub_key = conn.pub_key().clone();
80
81        // set up the cmd task.
82        // this is the main event loop of the framed wrapper
83        let pub_key2 = pub_key.clone();
84        let cmd_send2 = cmd_send.clone();
85        let weak_conn = Arc::downgrade(&conn);
86        let cmd_task = tokio::task::spawn(async move {
87            // init the stateful protocol decoder
88            let mut dec = crate::proto::ProtoDecoder::default();
89
90            while let Some(cmd) = cmd_recv.recv().await {
91                match cmd {
92                    Cmd::Recv(msg) => {
93                        use crate::proto::ProtoDecodeResult::*;
94                        match dec.decode(&msg) {
95                            Err(_) => break,
96                            Ok(Idle) => (),
97                            Ok(Message(msg)) => {
98                                // received a message, forward to receiver
99                                tracing::trace!(
100                                    target: "NETAUDIT",
101                                    pub_key = ?pub_key2,
102                                    byte_count = msg.len(),
103                                    m = "tx5-connection",
104                                    a = "recv_framed",
105                                );
106                                if msg_send.send(msg).await.is_err() {
107                                    tracing::info!("FramedConnRecv closed, stopping cmd task");
108                                    break;
109                                }
110                            }
111                            Ok(RemotePermitRequest(permit_len)) => {
112                                // receive a permit request,
113                                // await the semaphore outside this loop,
114                                // the semaphore permits will be issued
115                                // in the order the request come in
116                                let recv_limit = recv_limit.clone();
117                                let cmd_send = cmd_send2.clone();
118                                // fire and forget
119                                tokio::task::spawn(async move {
120                                    if let Ok(permit) = recv_limit
121                                        .acquire_many_owned(permit_len)
122                                        .await
123                                    {
124                                        let _ = cmd_send
125                                            .send(Cmd::RemotePermit(
126                                                permit, permit_len,
127                                            ))
128                                            .await;
129                                    }
130                                });
131                            }
132                            Ok(RemotePermitGrant(_)) => (),
133                        }
134                    }
135                    Cmd::AwaitPermit {
136                        await_registered,
137                        got_permit,
138                    } => {
139                        // register a oneshot to be triggered when a
140                        // permit request is responded to
141
142                        // we register the oneshot request with the
143                        // stateful decoder
144                        if dec
145                            .sent_remote_permit_request(Some(got_permit))
146                            .is_err()
147                        {
148                            break;
149                        }
150
151                        // we also notify the caller that we registered it
152                        let _ = await_registered.send(());
153                    }
154                    Cmd::RemotePermit(permit, permit_len) => {
155                        // our semaphore has granted a permit for the
156                        // remote peer to begin sending us data
157                        // now we need to notify them of that fact
158
159                        // our stateful decoder also needs to know about it
160                        if dec.sent_remote_permit_grant(permit).is_err() {
161                            break;
162                        }
163
164                        // now notify our peer
165                        if let Some(conn) = weak_conn.upgrade() {
166                            let (a, b, c, d) =
167                                match crate::proto::ProtoHeader::PermitGrant(
168                                    permit_len,
169                                )
170                                .encode()
171                                {
172                                    Ok(r) => r,
173                                    Err(_) => break,
174                                };
175                            if conn.send(vec![a, b, c, d]).await.is_err() {
176                                break;
177                            }
178                        } else {
179                            break;
180                        }
181                    }
182                    Cmd::Close => break,
183                }
184            }
185        });
186
187        Ok((
188            Self {
189                pub_key,
190                weak_conn: Arc::downgrade(&conn),
191                conn: tokio::sync::Mutex::new(conn),
192                cmd_send,
193                recv_task,
194                cmd_task,
195            },
196            FramedConnRecv(msg_recv),
197        ))
198    }
199
200    /// The pub key of the remote peer this is connected to.
201    pub fn pub_key(&self) -> &PubKey {
202        &self.pub_key
203    }
204
205    /// Returns `true` if we successfully connected over webrtc.
206    pub fn is_using_webrtc(&self) -> bool {
207        if let Some(conn) = self.weak_conn.upgrade() {
208            conn.is_using_webrtc()
209        } else {
210            false
211        }
212    }
213
214    /// Get connection statistics.
215    pub fn get_stats(&self) -> ConnStats {
216        if let Some(conn) = self.weak_conn.upgrade() {
217            conn.get_stats()
218        } else {
219            ConnStats::default()
220        }
221    }
222
223    /// Send a message on the connection.
224    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
225        let byte_count = msg.len();
226        match self.send_inner(msg).await {
227            Ok(_) => {
228                tracing::trace!(
229                    target: "NETAUDIT",
230                    pub_key = ?self.pub_key,
231                    byte_count,
232                    m = "tx5-connection",
233                    a = "send_framed_success",
234                );
235                Ok(())
236            }
237            Err(err) => {
238                tracing::debug!(
239                    target: "NETAUDIT",
240                    pub_key = ?self.pub_key,
241                    byte_count,
242                    ?err,
243                    m = "tx5-connection",
244                    a = "send_framed_error",
245                );
246                Err(err)
247            }
248        }
249    }
250
251    /// Helper to do the sending, breaking up the messages if needed and
252    /// awaiting the permit before sending broken up messages.
253    async fn send_inner(&self, msg: Vec<u8>) -> Result<()> {
254        let conn = self.conn.lock().await;
255
256        match crate::proto::proto_encode(&msg)? {
257            crate::proto::ProtoEncodeResult::OneMessage(msg) => {
258                // it's a small message, just send it as one chunk
259                conn.send(msg).await?;
260            }
261            crate::proto::ProtoEncodeResult::NeedPermit {
262                permit_req,
263                msg_payload,
264            } => {
265                // it's a big message, we've got chunks
266
267                let (s_reg, r_reg) = tokio::sync::oneshot::channel();
268                let (s_perm, r_perm) = tokio::sync::oneshot::channel();
269
270                // coordinate with the cmd task that we need a permit
271                self.cmd_send
272                    .send(Cmd::AwaitPermit {
273                        await_registered: s_reg,
274                        got_permit: s_perm,
275                    })
276                    .await
277                    .map_err(|_| Error::other("closed"))?;
278
279                // wait for the want permit to be registered
280                r_reg.await.map_err(|_| Error::other("closed"))?;
281
282                // send the permit request
283                conn.send(permit_req).await?;
284
285                // wait for the permit to be authorized by the peer
286                r_perm.await.map_err(|_| Error::other("closed"))?;
287
288                // send the chunked messages
289                for msg in msg_payload {
290                    conn.send(msg).await?;
291                }
292            }
293        }
294
295        Ok(())
296    }
297}