1use super::*;
2use std::sync::atomic::Ordering;
3
4pub(crate) enum ConnCmd {
5 SigRecv(tx5_signal::SignalMessage),
6 WebrtcRecv(webrtc::WebrtcEvt),
7 SendMessage(Vec<u8>),
8 WebrtcTimeoutCheck,
9 WebrtcClosed,
10}
11
12pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16 pub async fn recv(&mut self) -> Option<Vec<u8>> {
18 self.0.recv().await
19 }
20}
21
22pub struct Conn {
24 ready: Arc<tokio::sync::Semaphore>,
25 pub_key: PubKey,
26 cmd_send: CloseSend<ConnCmd>,
27 conn_task: tokio::task::JoinHandle<()>,
28 keepalive_task: tokio::task::JoinHandle<()>,
29 is_webrtc: Arc<std::sync::atomic::AtomicBool>,
30 send_msg_count: Arc<std::sync::atomic::AtomicU64>,
31 send_byte_count: Arc<std::sync::atomic::AtomicU64>,
32 recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
33 recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
34 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
35}
36
// Emit a `tracing` event on the "NETAUDIT" target at the given level.
//
// Every event carries `m = "tx5-connection"` so audit consumers can tell
// which module produced it; the remaining tokens are passed through as
// ordinary `tracing` fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49 fn drop(&mut self) {
50 netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52 self.conn_task.abort();
53 self.keepalive_task.abort();
54
55 let hub_cmd_send = self.hub_cmd_send.clone();
56 let pub_key = self.pub_key.clone();
57 tokio::task::spawn(async move {
58 let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59 });
60 }
61}
62
63impl Conn {
64 #[cfg(test)]
65 pub(crate) fn test_kill_keepalive_task(&self) {
66 self.keepalive_task.abort();
67 }
68
69 pub(crate) fn priv_new(
70 webrtc_config: WebRtcConfig,
71 is_polite: bool,
72 pub_key: PubKey,
73 client: Weak<tx5_signal::SignalConnection>,
74 config: Arc<HubConfig>,
75 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76 ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77 netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);
78
79 let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
81 let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
82 let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
83 let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
84 let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
85
86 let ready = Arc::new(tokio::sync::Semaphore::new(0));
88
89 let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
90 let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
91
92 let keepalive_dur = config.signal_config.max_idle / 2;
94 let client2 = client.clone();
95 let pub_key2 = pub_key.clone();
96 let keepalive_task = tokio::task::spawn(async move {
97 loop {
98 tokio::time::sleep(keepalive_dur).await;
99
100 if let Some(client) = client2.upgrade() {
101 if client.send_keepalive(&pub_key2).await.is_err() {
102 break;
103 }
104 } else {
105 break;
106 }
107 }
108 });
109
110 msg_send.set_close_on_drop(true);
111
112 let con_task_fut = con_task(
114 is_polite,
115 webrtc_config,
116 TaskCore {
117 client,
118 config,
119 pub_key: pub_key.clone(),
120 cmd_send: cmd_send.clone(),
121 cmd_recv,
122 send_msg_count: send_msg_count.clone(),
123 send_byte_count: send_byte_count.clone(),
124 recv_msg_count: recv_msg_count.clone(),
125 recv_byte_count: recv_byte_count.clone(),
126 msg_send,
127 ready: ready.clone(),
128 is_webrtc: is_webrtc.clone(),
129 },
130 );
131 let conn_task = tokio::task::spawn(con_task_fut);
132
133 let mut cmd_send2 = cmd_send.clone();
134 cmd_send2.set_close_on_drop(true);
135 let this = Self {
136 ready,
137 pub_key,
138 cmd_send: cmd_send2,
139 conn_task,
140 keepalive_task,
141 is_webrtc,
142 send_msg_count,
143 send_byte_count,
144 recv_msg_count,
145 recv_byte_count,
146 hub_cmd_send,
147 };
148
149 (Arc::new(this), ConnRecv(msg_recv), cmd_send)
150 }
151
152 pub async fn ready(&self) {
154 let _ = self.ready.acquire().await;
156 }
157
158 pub fn is_using_webrtc(&self) -> bool {
160 self.is_webrtc.load(Ordering::SeqCst)
161 }
162
163 pub fn pub_key(&self) -> &PubKey {
165 &self.pub_key
166 }
167
168 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
170 self.cmd_send.send(ConnCmd::SendMessage(msg)).await
171 }
172
173 pub fn get_stats(&self) -> ConnStats {
175 ConnStats {
176 send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
177 send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
178 recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
179 recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
180 }
181 }
182}
183
/// Point-in-time message and byte counters for a connection.
///
/// This is plain data, so it derives the usual value-type traits: callers
/// can log it (`Debug`), keep copies of earlier snapshots (`Clone`/`Copy`)
/// and compare snapshots (`PartialEq`/`Eq`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConnStats {
    /// Number of messages sent.
    pub send_msg_count: u64,

    /// Total number of bytes sent.
    pub send_byte_count: u64,

    /// Number of messages received.
    pub recv_msg_count: u64,

    /// Total number of bytes received.
    pub recv_byte_count: u64,
}
199
200struct TaskCore {
201 config: Arc<HubConfig>,
202 client: Weak<tx5_signal::SignalConnection>,
203 pub_key: PubKey,
204 cmd_send: CloseSend<ConnCmd>,
205 cmd_recv: CloseRecv<ConnCmd>,
206 msg_send: CloseSend<Vec<u8>>,
207 ready: Arc<tokio::sync::Semaphore>,
208 is_webrtc: Arc<std::sync::atomic::AtomicBool>,
209 send_msg_count: Arc<std::sync::atomic::AtomicU64>,
210 send_byte_count: Arc<std::sync::atomic::AtomicU64>,
211 recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
212 recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
213}
214
215impl TaskCore {
216 async fn handle_recv_msg(
217 &self,
218 msg: Vec<u8>,
219 ) -> std::result::Result<(), ()> {
220 self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
221 self.recv_byte_count
222 .fetch_add(msg.len() as u64, Ordering::Relaxed);
223 if self.msg_send.send(msg).await.is_err() {
224 netaudit!(
225 DEBUG,
226 pub_key = ?self.pub_key,
227 a = "close: msg_send closed",
228 );
229 Err(())
230 } else {
231 Ok(())
232 }
233 }
234
235 fn track_send_msg(&self, len: usize) {
236 self.send_msg_count.fetch_add(1, Ordering::Relaxed);
237 self.send_byte_count
238 .fetch_add(len as u64, Ordering::Relaxed);
239 }
240}
241
242async fn con_task(
243 is_polite: bool,
244 webrtc_config: WebRtcConfig,
245 mut task_core: TaskCore,
246) {
247 if let Some(client) = task_core.client.upgrade() {
249 let handshake_fut = async {
250 let nonce = client.send_handshake_req(&task_core.pub_key).await?;
251
252 let mut got_peer_res = false;
253 let mut sent_our_res = false;
254
255 while let Some(cmd) = task_core.cmd_recv.recv().await {
256 match cmd {
257 ConnCmd::SigRecv(sig) => {
258 use tx5_signal::SignalMessage::*;
259 match sig {
260 HandshakeReq(oth_nonce) => {
261 client
262 .send_handshake_res(
263 &task_core.pub_key,
264 oth_nonce,
265 )
266 .await?;
267 sent_our_res = true;
268 }
269 HandshakeRes(res_nonce) => {
270 if res_nonce != nonce {
271 return Err(Error::other("nonce mismatch"));
272 }
273 got_peer_res = true;
274 }
275 _ => (),
278 }
279 }
280 ConnCmd::SendMessage(_) => {
281 return Err(Error::other("send before ready"));
282 }
283 ConnCmd::WebrtcTimeoutCheck
284 | ConnCmd::WebrtcRecv(_)
285 | ConnCmd::WebrtcClosed => {
286 unreachable!()
289 }
290 }
291 if got_peer_res && sent_our_res {
292 break;
293 }
294 }
295
296 Result::Ok(())
297 };
298
299 match tokio::time::timeout(
300 task_core.config.signal_config.max_idle,
301 handshake_fut,
302 )
303 .await
304 {
305 Err(_) | Ok(Err(_)) => {
306 client.close_peer(&task_core.pub_key).await;
307 return;
308 }
309 Ok(Ok(_)) => (),
310 }
311 } else {
312 return;
313 }
314
315 let task_core = match con_task_attempt_webrtc(
317 is_polite,
318 webrtc_config,
319 task_core,
320 )
321 .await
322 {
323 AttemptWebrtcResult::Abort => return,
324 AttemptWebrtcResult::Fallback(task_core) => task_core,
325 };
326
327 task_core.is_webrtc.store(false, Ordering::SeqCst);
328
329 con_task_fallback_use_signal(task_core).await;
332}
333
334async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
335 match tokio::time::timeout(
336 task_core.config.signal_config.max_idle,
337 task_core.cmd_recv.recv(),
338 )
339 .await
340 {
341 Err(_) => {
342 netaudit!(
343 DEBUG,
344 pub_key = ?task_core.pub_key,
345 a = "close: connection idle",
346 );
347 None
348 }
349 Ok(None) => {
350 netaudit!(
351 DEBUG,
352 pub_key = ?task_core.pub_key,
353 a = "close: cmd_recv stream complete",
354 );
355 None
356 }
357 Ok(Some(cmd)) => Some(cmd),
358 }
359}
360
361async fn webrtc_task(
362 mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
363 cmd_send: CloseSend<ConnCmd>,
364) {
365 while let Some(evt) = webrtc_recv.recv().await {
366 if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
367 break;
368 }
369 }
370 let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
371}
372
373enum AttemptWebrtcResult {
374 Abort,
375 Fallback(TaskCore),
376}
377
378async fn con_task_attempt_webrtc(
379 is_polite: bool,
380 webrtc_config: WebRtcConfig,
381 mut task_core: TaskCore,
382) -> AttemptWebrtcResult {
383 use AttemptWebrtcResult::*;
384
385 let timeout_dur = task_core.config.signal_config.max_idle;
386 let timeout_cmd_send = task_core.cmd_send.clone();
387 tokio::task::spawn(async move {
388 tokio::time::sleep(timeout_dur).await;
389 let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
390 });
391
392 let (webrtc, webrtc_recv) = webrtc::new_backend_module(
393 task_core.config.backend_module,
394 is_polite,
395 webrtc_config,
396 4096,
398 );
399
400 struct AbortWebrtc(tokio::task::AbortHandle);
401
402 impl Drop for AbortWebrtc {
403 fn drop(&mut self) {
404 self.0.abort();
405 }
406 }
407
408 let _abort_webrtc = AbortWebrtc(
410 tokio::task::spawn(webrtc_task(
411 webrtc_recv,
412 task_core.cmd_send.clone(),
413 ))
414 .abort_handle(),
415 );
416
417 let mut is_ready = false;
418
419 #[cfg(any(test, feature = "test-utils"))]
420 if task_core.config.test_fail_webrtc {
421 netaudit!(
422 WARN,
423 pub_key = ?task_core.pub_key,
424 a = "webrtc fallback: test",
425 );
426 return Fallback(task_core);
427 }
428
429 while let Some(cmd) = recv_cmd(&mut task_core).await {
431 use tx5_signal::SignalMessage::*;
432 use webrtc::WebrtcEvt::*;
433 use ConnCmd::*;
434 match cmd {
435 SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
436 netaudit!(
437 DEBUG,
438 pub_key = ?task_core.pub_key,
439 a = "close: unexpected handshake msg",
440 );
441 return Abort;
442 }
443 SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
444 if task_core.handle_recv_msg(msg).await.is_err() {
445 return Abort;
446 }
447 netaudit!(
448 WARN,
449 pub_key = ?task_core.pub_key,
450 a = "webrtc fallback: remote sent us an sbd message",
451 );
452 return Fallback(task_core);
456 }
457 SigRecv(Offer(offer)) => {
458 netaudit!(
459 TRACE,
460 pub_key = ?task_core.pub_key,
461 offer = String::from_utf8_lossy(&offer).to_string(),
462 a = "recv_offer",
463 );
464 if let Err(err) = webrtc.in_offer(offer).await {
465 netaudit!(
466 WARN,
467 pub_key = ?task_core.pub_key,
468 ?err,
469 a = "webrtc fallback: failed to parse received offer",
470 );
471 return Fallback(task_core);
472 }
473 }
474 SigRecv(Answer(answer)) => {
475 netaudit!(
476 TRACE,
477 pub_key = ?task_core.pub_key,
478 offer = String::from_utf8_lossy(&answer).to_string(),
479 a = "recv_answer",
480 );
481 if let Err(err) = webrtc.in_answer(answer).await {
482 netaudit!(
483 WARN,
484 pub_key = ?task_core.pub_key,
485 ?err,
486 a = "webrtc fallback: failed to parse received answer",
487 );
488 return Fallback(task_core);
489 }
490 }
491 SigRecv(Ice(ice)) => {
492 netaudit!(
493 TRACE,
494 pub_key = ?task_core.pub_key,
495 offer = String::from_utf8_lossy(&ice).to_string(),
496 a = "recv_ice",
497 );
498 if let Err(err) = webrtc.in_ice(ice).await {
499 netaudit!(
500 DEBUG,
501 pub_key = ?task_core.pub_key,
502 ?err,
503 a = "ignoring webrtc in_ice error",
504 );
505 }
507 }
508 SigRecv(Keepalive) | SigRecv(Unknown) => {
509 }
511 WebrtcRecv(GeneratedOffer(offer)) => {
512 netaudit!(
513 TRACE,
514 pub_key = ?task_core.pub_key,
515 offer = String::from_utf8_lossy(&offer).to_string(),
516 a = "send_offer",
517 );
518 if let Some(client) = task_core.client.upgrade() {
519 if let Err(err) =
520 client.send_offer(&task_core.pub_key, offer).await
521 {
522 netaudit!(
523 DEBUG,
524 pub_key = ?task_core.pub_key,
525 ?err,
526 a = "webrtc send_offer error",
527 );
528 return Abort;
529 }
530 } else {
531 return Abort;
532 }
533 }
534 WebrtcRecv(GeneratedAnswer(answer)) => {
535 netaudit!(
536 TRACE,
537 pub_key = ?task_core.pub_key,
538 offer = String::from_utf8_lossy(&answer).to_string(),
539 a = "send_answer",
540 );
541 if let Some(client) = task_core.client.upgrade() {
542 if let Err(err) =
543 client.send_answer(&task_core.pub_key, answer).await
544 {
545 netaudit!(
546 DEBUG,
547 pub_key = ?task_core.pub_key,
548 ?err,
549 a = "webrtc send_answer error",
550 );
551 return Abort;
552 }
553 } else {
554 return Abort;
555 }
556 }
557 WebrtcRecv(GeneratedIce(ice)) => {
558 netaudit!(
559 TRACE,
560 pub_key = ?task_core.pub_key,
561 offer = String::from_utf8_lossy(&ice).to_string(),
562 a = "send_ice",
563 );
564 if let Some(client) = task_core.client.upgrade() {
565 if let Err(err) =
566 client.send_ice(&task_core.pub_key, ice).await
567 {
568 netaudit!(
569 DEBUG,
570 pub_key = ?task_core.pub_key,
571 ?err,
572 a = "webrtc send_ice error",
573 );
574 return Abort;
575 }
576 } else {
577 return Abort;
578 }
579 }
580 WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
581 if task_core.handle_recv_msg(msg).await.is_err() {
582 return Abort;
583 }
584 }
585 WebrtcRecv(Ready) => {
586 is_ready = true;
587 task_core.is_webrtc.store(true, Ordering::SeqCst);
588 task_core.ready.close();
589 }
590 SendMessage(msg) => {
591 let len = msg.len();
592
593 netaudit!(
594 TRACE,
595 pub_key = ?task_core.pub_key,
596 byte_len = len,
597 a = "queue msg for backend send",
598 );
599 if let Err(err) = webrtc.message(msg).await {
600 netaudit!(
601 WARN,
602 pub_key = ?task_core.pub_key,
603 ?err,
604 a = "webrtc fallback: failed to send message",
605 );
606 return Fallback(task_core);
607 }
608
609 task_core.track_send_msg(len);
610 }
611 WebrtcTimeoutCheck => {
612 if !is_ready {
613 netaudit!(
614 WARN,
615 pub_key = ?task_core.pub_key,
616 a = "webrtc fallback: failed to ready within timeout",
617 );
618 return Fallback(task_core);
619 }
620 }
621 WebrtcClosed => {
622 netaudit!(
623 WARN,
624 pub_key = ?task_core.pub_key,
625 a = "webrtc fallback: webrtc processing task closed",
626 );
627 return Fallback(task_core);
628 }
629 }
630 }
631
632 Abort
633}
634
635async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
636 task_core.ready.close();
638
639 while let Some(cmd) = recv_cmd(&mut task_core).await {
640 match cmd {
641 ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
642 if task_core.handle_recv_msg(msg).await.is_err() {
643 break;
644 }
645 }
646 ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
647 Some(client) => {
648 let len = msg.len();
649 if let Err(err) =
650 client.send_message(&task_core.pub_key, msg).await
651 {
652 netaudit!(
653 DEBUG,
654 pub_key = ?task_core.pub_key,
655 ?err,
656 a = "close: sbd client send error",
657 );
658 break;
659 }
660 task_core.track_send_msg(len);
661 }
662 None => {
663 netaudit!(
664 DEBUG,
665 pub_key = ?task_core.pub_key,
666 a = "close: sbd client closed",
667 );
668 break;
669 }
670 },
671 _ => (),
672 }
673 }
674}