tx5_connection/
framed.rs

1use crate::*;
2
3enum Cmd {
4    Recv(Vec<u8>),
5    AwaitPermit {
6        await_registered: tokio::sync::oneshot::Sender<()>,
7        got_permit: tokio::sync::oneshot::Sender<()>,
8    },
9    RemotePermit(tokio::sync::OwnedSemaphorePermit, u32),
10    Close,
11}
12
13/// Receive a framed message on the connection.
14pub struct FramedConnRecv(tokio::sync::mpsc::Receiver<Vec<u8>>);
15
16impl FramedConnRecv {
17    /// Receive a framed message on the connection.
18    pub async fn recv(&mut self) -> Option<Vec<u8>> {
19        self.0.recv().await
20    }
21}
22
/// A framed wrapper that can send and receive larger messages than
/// the base connection.
///
/// Dropping this aborts the receive-forwarding and command tasks spawned by
/// [`FramedConn::new`].
pub struct FramedConn {
    /// Public key of the remote peer, cached from the base connection.
    pub_key: PubKey,
    /// Non-owning handle used by the synchronous accessors, which cannot await
    /// the mutex below. NOTE(review): while `conn` holds a strong reference
    /// this always upgrades successfully.
    weak_conn: Weak<Conn>,
    /// Owning handle. The async mutex is held for the whole of a send, so
    /// concurrent sends go out one message at a time.
    conn: tokio::sync::Mutex<Arc<Conn>>,
    /// Command channel into the background command task.
    cmd_send: tokio::sync::mpsc::Sender<Cmd>,
    /// Task forwarding raw messages from the base connection to the command
    /// task.
    recv_task: tokio::task::JoinHandle<()>,
    /// Task running the frame decoder and permit bookkeeping.
    cmd_task: tokio::task::JoinHandle<()>,
}
33
34impl Drop for FramedConn {
35    fn drop(&mut self) {
36        self.recv_task.abort();
37        self.cmd_task.abort();
38    }
39}
40
41impl FramedConn {
42    /// Construct a new framed wrapper around the base connection.
43    pub async fn new(
44        conn: Arc<Conn>,
45        mut conn_recv: ConnRecv,
46        recv_limit: Arc<tokio::sync::Semaphore>,
47    ) -> Result<(Self, FramedConnRecv)> {
48        conn.ready().await;
49
50        let (a, b, c, d) = crate::proto::PROTO_VER_2.encode()?;
51        conn.send(vec![a, b, c, d]).await?;
52
53        let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32);
54        let (msg_send, msg_recv) = tokio::sync::mpsc::channel(32);
55
56        let cmd_send2 = cmd_send.clone();
57        let recv_task = tokio::task::spawn(async move {
58            while let Some(msg) = conn_recv.recv().await {
59                if cmd_send2.send(Cmd::Recv(msg)).await.is_err() {
60                    break;
61                }
62            }
63
64            let _ = cmd_send2.send(Cmd::Close).await;
65        });
66
67        let pub_key = conn.pub_key().clone();
68
69        let pub_key2 = pub_key.clone();
70        let cmd_send2 = cmd_send.clone();
71        let weak_conn = Arc::downgrade(&conn);
72        let cmd_task = tokio::task::spawn(async move {
73            let mut dec = crate::proto::ProtoDecoder::default();
74
75            while let Some(cmd) = cmd_recv.recv().await {
76                match cmd {
77                    Cmd::Recv(msg) => {
78                        use crate::proto::ProtoDecodeResult::*;
79                        match dec.decode(&msg) {
80                            Err(_) => break,
81                            Ok(Idle) => (),
82                            Ok(Message(msg)) => {
83                                tracing::trace!(
84                                    target: "NETAUDIT",
85                                    pub_key = ?pub_key2,
86                                    byte_count = msg.len(),
87                                    m = "tx5-connection",
88                                    a = "recv_framed",
89                                );
90                                if msg_send.send(msg).await.is_err() {
91                                    break;
92                                }
93                            }
94                            Ok(RemotePermitRequest(permit_len)) => {
95                                let recv_limit = recv_limit.clone();
96                                let cmd_send = cmd_send2.clone();
97                                // fire and forget
98                                tokio::task::spawn(async move {
99                                    if let Ok(permit) = recv_limit
100                                        .acquire_many_owned(permit_len)
101                                        .await
102                                    {
103                                        let _ = cmd_send
104                                            .send(Cmd::RemotePermit(
105                                                permit, permit_len,
106                                            ))
107                                            .await;
108                                    }
109                                });
110                            }
111                            Ok(RemotePermitGrant(_)) => (),
112                        }
113                    }
114                    Cmd::AwaitPermit {
115                        await_registered,
116                        got_permit,
117                    } => {
118                        if dec
119                            .sent_remote_permit_request(Some(got_permit))
120                            .is_err()
121                        {
122                            break;
123                        }
124                        let _ = await_registered.send(());
125                    }
126                    Cmd::RemotePermit(permit, permit_len) => {
127                        if dec.sent_remote_permit_grant(permit).is_err() {
128                            break;
129                        }
130                        if let Some(conn) = weak_conn.upgrade() {
131                            let (a, b, c, d) =
132                                match crate::proto::ProtoHeader::PermitGrant(
133                                    permit_len,
134                                )
135                                .encode()
136                                {
137                                    Ok(r) => r,
138                                    Err(_) => break,
139                                };
140                            if conn.send(vec![a, b, c, d]).await.is_err() {
141                                break;
142                            }
143                        } else {
144                            break;
145                        }
146                    }
147                    Cmd::Close => break,
148                }
149            }
150        });
151
152        Ok((
153            Self {
154                pub_key,
155                weak_conn: Arc::downgrade(&conn),
156                conn: tokio::sync::Mutex::new(conn),
157                cmd_send,
158                recv_task,
159                cmd_task,
160            },
161            FramedConnRecv(msg_recv),
162        ))
163    }
164
165    /// The pub key of the remote peer this is connected to.
166    pub fn pub_key(&self) -> &PubKey {
167        &self.pub_key
168    }
169
170    /// Returns `true` if we successfully connected over webrtc.
171    pub fn is_using_webrtc(&self) -> bool {
172        if let Some(conn) = self.weak_conn.upgrade() {
173            conn.is_using_webrtc()
174        } else {
175            false
176        }
177    }
178
179    /// Get connection statistics.
180    pub fn get_stats(&self) -> ConnStats {
181        if let Some(conn) = self.weak_conn.upgrade() {
182            conn.get_stats()
183        } else {
184            ConnStats::default()
185        }
186    }
187
188    /// Send a message on the connection.
189    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
190        let byte_count = msg.len();
191        match self.send_inner(msg).await {
192            Ok(_) => {
193                tracing::trace!(
194                    target: "NETAUDIT",
195                    pub_key = ?self.pub_key,
196                    byte_count,
197                    m = "tx5-connection",
198                    a = "send_framed_success",
199                );
200                Ok(())
201            }
202            Err(err) => {
203                tracing::debug!(
204                    target: "NETAUDIT",
205                    pub_key = ?self.pub_key,
206                    byte_count,
207                    ?err,
208                    m = "tx5-connection",
209                    a = "send_framed_error",
210                );
211                Err(err)
212            }
213        }
214    }
215
216    async fn send_inner(&self, msg: Vec<u8>) -> Result<()> {
217        let conn = self.conn.lock().await;
218
219        match crate::proto::proto_encode(&msg)? {
220            crate::proto::ProtoEncodeResult::OneMessage(msg) => {
221                conn.send(msg).await?;
222            }
223            crate::proto::ProtoEncodeResult::NeedPermit {
224                permit_req,
225                msg_payload,
226            } => {
227                let (s_reg, r_reg) = tokio::sync::oneshot::channel();
228                let (s_perm, r_perm) = tokio::sync::oneshot::channel();
229
230                self.cmd_send
231                    .send(Cmd::AwaitPermit {
232                        await_registered: s_reg,
233                        got_permit: s_perm,
234                    })
235                    .await
236                    .map_err(|_| Error::other("closed"))?;
237
238                r_reg.await.map_err(|_| Error::other("closed"))?;
239
240                conn.send(permit_req).await?;
241
242                r_perm.await.map_err(|_| Error::other("closed"))?;
243
244                for msg in msg_payload {
245                    conn.send(msg).await?;
246                }
247            }
248        }
249
250        Ok(())
251    }
252}