tx5_connection/framed.rs
1use crate::*;
2
3enum Cmd {
4 Recv(Vec<u8>),
5 AwaitPermit {
6 await_registered: tokio::sync::oneshot::Sender<()>,
7 got_permit: tokio::sync::oneshot::Sender<()>,
8 },
9 RemotePermit(tokio::sync::OwnedSemaphorePermit, u32),
10 Close,
11}
12
13/// Receive a framed message on the connection.
14pub struct FramedConnRecv(tokio::sync::mpsc::Receiver<Vec<u8>>);
15
16impl FramedConnRecv {
17 /// Receive a framed message on the connection.
18 pub async fn recv(&mut self) -> Option<Vec<u8>> {
19 self.0.recv().await
20 }
21}
22
/// A framed wrapper that can send and receive larger messages than
/// the base connection.
///
/// If a message fits within the frame limit, it is sent as-is.
///
/// If a message is OVER the frame limit, we first send the remote peer a
/// permit request sized for that message. We only begin sending the
/// message's frames once the peer grants that permit.
///
/// This allows each peer to throttle how much pending-completion message
/// memory it commits to, by only issuing permits up to a configured
/// threshold (the `recv_limit` semaphore passed to [`FramedConn::new`]).
///
/// Messages received on the connection are delivered through the
/// [`FramedConnRecv`] returned alongside this value.
pub struct FramedConn {
    // NOTE: fields are dropped in declaration order, after `Drop::drop`
    // has aborted the background tasks. Reordering them changes drop order.

    /// Public key of the remote peer this wrapper is connected to.
    pub_key: PubKey,
    /// Non-owning handle to the base connection, used for read-only
    /// queries (webrtc status, stats) that must not wait on `conn`'s lock.
    weak_conn: Weak<Conn>,
    /// Owning handle to the base connection. `send_inner` holds this lock
    /// for an entire send, so the frames of one chunked message are not
    /// interleaved with other calls to `send`.
    conn: tokio::sync::Mutex<Arc<Conn>>,
    /// Channel into the command task (the framed event loop).
    cmd_send: tokio::sync::mpsc::Sender<Cmd>,
    /// Task forwarding raw messages from the base connection into the
    /// command task; aborted on drop.
    recv_task: tokio::task::JoinHandle<()>,
    /// The command task running the protocol decoder and permit handling;
    /// aborted on drop.
    cmd_task: tokio::task::JoinHandle<()>,
}
43
44impl Drop for FramedConn {
45 fn drop(&mut self) {
46 self.recv_task.abort();
47 self.cmd_task.abort();
48 }
49}
50
51impl FramedConn {
52 /// Construct a new framed wrapper around the base connection.
53 pub async fn new(
54 conn: Arc<Conn>,
55 mut conn_recv: ConnRecv,
56 recv_limit: Arc<tokio::sync::Semaphore>,
57 ) -> Result<(Self, FramedConnRecv)> {
58 conn.ready().await;
59
60 // send the protocol header
61 let (a, b, c, d) = crate::proto::PROTO_VER_2.encode()?;
62 conn.send(vec![a, b, c, d]).await?;
63
64 let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32);
65 let (msg_send, msg_recv) = tokio::sync::mpsc::channel(32);
66
67 // set up the receive to just feed straight into the cmd task
68 let cmd_send2 = cmd_send.clone();
69 let recv_task = tokio::task::spawn(async move {
70 while let Some(msg) = conn_recv.recv().await {
71 if cmd_send2.send(Cmd::Recv(msg)).await.is_err() {
72 break;
73 }
74 }
75
76 let _ = cmd_send2.send(Cmd::Close).await;
77 });
78
79 let pub_key = conn.pub_key().clone();
80
81 // set up the cmd task.
82 // this is the main event loop of the framed wrapper
83 let pub_key2 = pub_key.clone();
84 let cmd_send2 = cmd_send.clone();
85 let weak_conn = Arc::downgrade(&conn);
86 let cmd_task = tokio::task::spawn(async move {
87 // init the stateful protocol decoder
88 let mut dec = crate::proto::ProtoDecoder::default();
89
90 while let Some(cmd) = cmd_recv.recv().await {
91 match cmd {
92 Cmd::Recv(msg) => {
93 use crate::proto::ProtoDecodeResult::*;
94 match dec.decode(&msg) {
95 Err(_) => break,
96 Ok(Idle) => (),
97 Ok(Message(msg)) => {
98 // received a message, forward to receiver
99 tracing::trace!(
100 target: "NETAUDIT",
101 pub_key = ?pub_key2,
102 byte_count = msg.len(),
103 m = "tx5-connection",
104 a = "recv_framed",
105 );
106 if msg_send.send(msg).await.is_err() {
107 tracing::info!("FramedConnRecv closed, stopping cmd task");
108 break;
109 }
110 }
111 Ok(RemotePermitRequest(permit_len)) => {
112 // receive a permit request,
113 // await the semaphore outside this loop,
114 // the semaphore permits will be issued
115 // in the order the request come in
116 let recv_limit = recv_limit.clone();
117 let cmd_send = cmd_send2.clone();
118 // fire and forget
119 tokio::task::spawn(async move {
120 if let Ok(permit) = recv_limit
121 .acquire_many_owned(permit_len)
122 .await
123 {
124 let _ = cmd_send
125 .send(Cmd::RemotePermit(
126 permit, permit_len,
127 ))
128 .await;
129 }
130 });
131 }
132 Ok(RemotePermitGrant(_)) => (),
133 }
134 }
135 Cmd::AwaitPermit {
136 await_registered,
137 got_permit,
138 } => {
139 // register a oneshot to be triggered when a
140 // permit request is responded to
141
142 // we register the oneshot request with the
143 // stateful decoder
144 if dec
145 .sent_remote_permit_request(Some(got_permit))
146 .is_err()
147 {
148 break;
149 }
150
151 // we also notify the caller that we registered it
152 let _ = await_registered.send(());
153 }
154 Cmd::RemotePermit(permit, permit_len) => {
155 // our semaphore has granted a permit for the
156 // remote peer to begin sending us data
157 // now we need to notify them of that fact
158
159 // our stateful decoder also needs to know about it
160 if dec.sent_remote_permit_grant(permit).is_err() {
161 break;
162 }
163
164 // now notify our peer
165 if let Some(conn) = weak_conn.upgrade() {
166 let (a, b, c, d) =
167 match crate::proto::ProtoHeader::PermitGrant(
168 permit_len,
169 )
170 .encode()
171 {
172 Ok(r) => r,
173 Err(_) => break,
174 };
175 if conn.send(vec![a, b, c, d]).await.is_err() {
176 break;
177 }
178 } else {
179 break;
180 }
181 }
182 Cmd::Close => break,
183 }
184 }
185 });
186
187 Ok((
188 Self {
189 pub_key,
190 weak_conn: Arc::downgrade(&conn),
191 conn: tokio::sync::Mutex::new(conn),
192 cmd_send,
193 recv_task,
194 cmd_task,
195 },
196 FramedConnRecv(msg_recv),
197 ))
198 }
199
200 /// The pub key of the remote peer this is connected to.
201 pub fn pub_key(&self) -> &PubKey {
202 &self.pub_key
203 }
204
205 /// Returns `true` if we successfully connected over webrtc.
206 pub fn is_using_webrtc(&self) -> bool {
207 if let Some(conn) = self.weak_conn.upgrade() {
208 conn.is_using_webrtc()
209 } else {
210 false
211 }
212 }
213
214 /// Get connection statistics.
215 pub fn get_stats(&self) -> ConnStats {
216 if let Some(conn) = self.weak_conn.upgrade() {
217 conn.get_stats()
218 } else {
219 ConnStats::default()
220 }
221 }
222
223 /// Send a message on the connection.
224 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
225 let byte_count = msg.len();
226 match self.send_inner(msg).await {
227 Ok(_) => {
228 tracing::trace!(
229 target: "NETAUDIT",
230 pub_key = ?self.pub_key,
231 byte_count,
232 m = "tx5-connection",
233 a = "send_framed_success",
234 );
235 Ok(())
236 }
237 Err(err) => {
238 tracing::debug!(
239 target: "NETAUDIT",
240 pub_key = ?self.pub_key,
241 byte_count,
242 ?err,
243 m = "tx5-connection",
244 a = "send_framed_error",
245 );
246 Err(err)
247 }
248 }
249 }
250
251 /// Helper to do the sending, breaking up the messages if needed and
252 /// awaiting the permit before sending broken up messages.
253 async fn send_inner(&self, msg: Vec<u8>) -> Result<()> {
254 let conn = self.conn.lock().await;
255
256 match crate::proto::proto_encode(&msg)? {
257 crate::proto::ProtoEncodeResult::OneMessage(msg) => {
258 // it's a small message, just send it as one chunk
259 conn.send(msg).await?;
260 }
261 crate::proto::ProtoEncodeResult::NeedPermit {
262 permit_req,
263 msg_payload,
264 } => {
265 // it's a big message, we've got chunks
266
267 let (s_reg, r_reg) = tokio::sync::oneshot::channel();
268 let (s_perm, r_perm) = tokio::sync::oneshot::channel();
269
270 // coordinate with the cmd task that we need a permit
271 self.cmd_send
272 .send(Cmd::AwaitPermit {
273 await_registered: s_reg,
274 got_permit: s_perm,
275 })
276 .await
277 .map_err(|_| Error::other("closed"))?;
278
279 // wait for the want permit to be registered
280 r_reg.await.map_err(|_| Error::other("closed"))?;
281
282 // send the permit request
283 conn.send(permit_req).await?;
284
285 // wait for the permit to be authorized by the peer
286 r_perm.await.map_err(|_| Error::other("closed"))?;
287
288 // send the chunked messages
289 for msg in msg_payload {
290 conn.send(msg).await?;
291 }
292 }
293 }
294
295 Ok(())
296 }
297}