1use crate::*;
2
3enum Cmd {
4 Recv(Vec<u8>),
5 AwaitPermit {
6 await_registered: tokio::sync::oneshot::Sender<()>,
7 got_permit: tokio::sync::oneshot::Sender<()>,
8 },
9 RemotePermit(tokio::sync::OwnedSemaphorePermit, u32),
10 Close,
11}
12
13pub struct FramedConnRecv(tokio::sync::mpsc::Receiver<Vec<u8>>);
15
16impl FramedConnRecv {
17 pub async fn recv(&mut self) -> Option<Vec<u8>> {
19 self.0.recv().await
20 }
21}
22
23pub struct FramedConn {
26 pub_key: PubKey,
27 weak_conn: Weak<Conn>,
28 conn: tokio::sync::Mutex<Arc<Conn>>,
29 cmd_send: tokio::sync::mpsc::Sender<Cmd>,
30 recv_task: tokio::task::JoinHandle<()>,
31 cmd_task: tokio::task::JoinHandle<()>,
32}
33
34impl Drop for FramedConn {
35 fn drop(&mut self) {
36 self.recv_task.abort();
37 self.cmd_task.abort();
38 }
39}
40
41impl FramedConn {
42 pub async fn new(
44 conn: Arc<Conn>,
45 mut conn_recv: ConnRecv,
46 recv_limit: Arc<tokio::sync::Semaphore>,
47 ) -> Result<(Self, FramedConnRecv)> {
48 conn.ready().await;
49
50 let (a, b, c, d) = crate::proto::PROTO_VER_2.encode()?;
51 conn.send(vec![a, b, c, d]).await?;
52
53 let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32);
54 let (msg_send, msg_recv) = tokio::sync::mpsc::channel(32);
55
56 let cmd_send2 = cmd_send.clone();
57 let recv_task = tokio::task::spawn(async move {
58 while let Some(msg) = conn_recv.recv().await {
59 if cmd_send2.send(Cmd::Recv(msg)).await.is_err() {
60 break;
61 }
62 }
63
64 let _ = cmd_send2.send(Cmd::Close).await;
65 });
66
67 let pub_key = conn.pub_key().clone();
68
69 let pub_key2 = pub_key.clone();
70 let cmd_send2 = cmd_send.clone();
71 let weak_conn = Arc::downgrade(&conn);
72 let cmd_task = tokio::task::spawn(async move {
73 let mut dec = crate::proto::ProtoDecoder::default();
74
75 while let Some(cmd) = cmd_recv.recv().await {
76 match cmd {
77 Cmd::Recv(msg) => {
78 use crate::proto::ProtoDecodeResult::*;
79 match dec.decode(&msg) {
80 Err(_) => break,
81 Ok(Idle) => (),
82 Ok(Message(msg)) => {
83 tracing::trace!(
84 target: "NETAUDIT",
85 pub_key = ?pub_key2,
86 byte_count = msg.len(),
87 m = "tx5-connection",
88 a = "recv_framed",
89 );
90 if msg_send.send(msg).await.is_err() {
91 break;
92 }
93 }
94 Ok(RemotePermitRequest(permit_len)) => {
95 let recv_limit = recv_limit.clone();
96 let cmd_send = cmd_send2.clone();
97 tokio::task::spawn(async move {
99 if let Ok(permit) = recv_limit
100 .acquire_many_owned(permit_len)
101 .await
102 {
103 let _ = cmd_send
104 .send(Cmd::RemotePermit(
105 permit, permit_len,
106 ))
107 .await;
108 }
109 });
110 }
111 Ok(RemotePermitGrant(_)) => (),
112 }
113 }
114 Cmd::AwaitPermit {
115 await_registered,
116 got_permit,
117 } => {
118 if dec
119 .sent_remote_permit_request(Some(got_permit))
120 .is_err()
121 {
122 break;
123 }
124 let _ = await_registered.send(());
125 }
126 Cmd::RemotePermit(permit, permit_len) => {
127 if dec.sent_remote_permit_grant(permit).is_err() {
128 break;
129 }
130 if let Some(conn) = weak_conn.upgrade() {
131 let (a, b, c, d) =
132 match crate::proto::ProtoHeader::PermitGrant(
133 permit_len,
134 )
135 .encode()
136 {
137 Ok(r) => r,
138 Err(_) => break,
139 };
140 if conn.send(vec![a, b, c, d]).await.is_err() {
141 break;
142 }
143 } else {
144 break;
145 }
146 }
147 Cmd::Close => break,
148 }
149 }
150 });
151
152 Ok((
153 Self {
154 pub_key,
155 weak_conn: Arc::downgrade(&conn),
156 conn: tokio::sync::Mutex::new(conn),
157 cmd_send,
158 recv_task,
159 cmd_task,
160 },
161 FramedConnRecv(msg_recv),
162 ))
163 }
164
165 pub fn pub_key(&self) -> &PubKey {
167 &self.pub_key
168 }
169
170 pub fn is_using_webrtc(&self) -> bool {
172 if let Some(conn) = self.weak_conn.upgrade() {
173 conn.is_using_webrtc()
174 } else {
175 false
176 }
177 }
178
179 pub fn get_stats(&self) -> ConnStats {
181 if let Some(conn) = self.weak_conn.upgrade() {
182 conn.get_stats()
183 } else {
184 ConnStats::default()
185 }
186 }
187
188 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
190 let byte_count = msg.len();
191 match self.send_inner(msg).await {
192 Ok(_) => {
193 tracing::trace!(
194 target: "NETAUDIT",
195 pub_key = ?self.pub_key,
196 byte_count,
197 m = "tx5-connection",
198 a = "send_framed_success",
199 );
200 Ok(())
201 }
202 Err(err) => {
203 tracing::debug!(
204 target: "NETAUDIT",
205 pub_key = ?self.pub_key,
206 byte_count,
207 ?err,
208 m = "tx5-connection",
209 a = "send_framed_error",
210 );
211 Err(err)
212 }
213 }
214 }
215
216 async fn send_inner(&self, msg: Vec<u8>) -> Result<()> {
217 let conn = self.conn.lock().await;
218
219 match crate::proto::proto_encode(&msg)? {
220 crate::proto::ProtoEncodeResult::OneMessage(msg) => {
221 conn.send(msg).await?;
222 }
223 crate::proto::ProtoEncodeResult::NeedPermit {
224 permit_req,
225 msg_payload,
226 } => {
227 let (s_reg, r_reg) = tokio::sync::oneshot::channel();
228 let (s_perm, r_perm) = tokio::sync::oneshot::channel();
229
230 self.cmd_send
231 .send(Cmd::AwaitPermit {
232 await_registered: s_reg,
233 got_permit: s_perm,
234 })
235 .await
236 .map_err(|_| Error::other("closed"))?;
237
238 r_reg.await.map_err(|_| Error::other("closed"))?;
239
240 conn.send(permit_req).await?;
241
242 r_perm.await.map_err(|_| Error::other("closed"))?;
243
244 for msg in msg_payload {
245 conn.send(msg).await?;
246 }
247 }
248 }
249
250 Ok(())
251 }
252}