1use super::*;
2use std::sync::atomic::Ordering;
3
/// Commands processed by a connection's background task (`con_task`).
pub(crate) enum ConnCmd {
    /// A message received over the signal connection for this peer.
    // NOTE(review): presumably forwarded by the hub through the sender
    // returned from `Conn::priv_new` — confirm against the caller.
    SigRecv(tx5_signal::SignalMessage),

    /// An event emitted by the webrtc backend, forwarded by `webrtc_task`.
    WebrtcRecv(webrtc::WebrtcEvt),

    /// An outgoing application message, queued by `Conn::send`.
    SendMessage(Vec<u8>),

    /// Sent by the timer spawned in `con_task_attempt_webrtc` once the
    /// webrtc connect timeout has elapsed.
    WebrtcTimeoutCheck,

    /// Sent by `webrtc_task` when it stops forwarding webrtc events.
    WebrtcClosed,
}
11
/// Receiving half of a connection: yields the byte messages that the
/// connection task forwards from the remote peer.
pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16 pub async fn recv(&mut self) -> Option<Vec<u8>> {
18 self.0.recv().await
19 }
20}
21
/// Handle to a connection with a single remote peer.
///
/// Dropping the handle aborts both background tasks and asks the hub to
/// disconnect the peer (see the `Drop` impl).
pub struct Conn {
    /// Created with zero permits; the connection task *closes* it to wake
    /// everything waiting in `Conn::ready`.
    ready: Arc<tokio::sync::Semaphore>,
    /// Public key identifying the remote peer.
    pub_key: PubKey,
    /// Sender used to queue commands for the connection task.
    cmd_send: CloseSend<ConnCmd>,
    /// Task running `con_task` (handshake, webrtc attempt, relay fallback).
    conn_task: tokio::task::JoinHandle<()>,
    /// Task that periodically sends keepalives over the signal client.
    keepalive_task: tokio::task::JoinHandle<()>,
    /// Set to `true` by the connection task when webrtc reports ready.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    /// Number of messages sent.
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes sent.
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of messages received.
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes received.
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Channel to the hub, used on drop to send `HubCmd::Disconnect`.
    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
}
36
// Emit a `tracing` event on the "NETAUDIT" target at the given level.
//
// Every event is tagged with `m = "tx5-connection"`; the remaining tokens
// are passed through to `tracing::event!` unchanged as its fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49 fn drop(&mut self) {
50 netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52 self.conn_task.abort();
53 self.keepalive_task.abort();
54
55 let hub_cmd_send = self.hub_cmd_send.clone();
56 let pub_key = self.pub_key.clone();
57 tokio::task::spawn(async move {
58 let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59 });
60 }
61}
62
63impl Conn {
64 #[cfg(test)]
65 pub(crate) fn test_kill_keepalive_task(&self) {
66 self.keepalive_task.abort();
67 }
68
69 pub(crate) fn priv_new(
70 webrtc_config: WebRtcConfig,
71 is_polite: bool,
72 pub_key: PubKey,
73 client: Weak<tx5_signal::SignalConnection>,
74 config: Arc<HubConfig>,
75 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76 ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77 netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);
78
79 let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
81 let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
82 let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
83 let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
84 let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
85
86 let ready = Arc::new(tokio::sync::Semaphore::new(0));
88
89 let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
90 let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
91
92 let keepalive_dur = config.signal_config.max_idle / 2;
94 let client2 = client.clone();
95 let pub_key2 = pub_key.clone();
96 let keepalive_task = tokio::task::spawn(async move {
97 loop {
98 tokio::time::sleep(keepalive_dur).await;
99
100 if let Some(client) = client2.upgrade() {
101 if client.send_keepalive(&pub_key2).await.is_err() {
102 break;
103 }
104 } else {
105 break;
106 }
107 }
108 });
109
110 msg_send.set_close_on_drop(true);
111
112 let con_task_fut = con_task(
114 is_polite,
115 webrtc_config,
116 TaskCore {
117 client,
118 config,
119 pub_key: pub_key.clone(),
120 cmd_send: cmd_send.clone(),
121 cmd_recv,
122 send_msg_count: send_msg_count.clone(),
123 send_byte_count: send_byte_count.clone(),
124 recv_msg_count: recv_msg_count.clone(),
125 recv_byte_count: recv_byte_count.clone(),
126 msg_send,
127 ready: ready.clone(),
128 is_webrtc: is_webrtc.clone(),
129 },
130 );
131 let conn_task = tokio::task::spawn(con_task_fut);
132
133 let mut cmd_send2 = cmd_send.clone();
134 cmd_send2.set_close_on_drop(true);
135 let this = Self {
136 ready,
137 pub_key,
138 cmd_send: cmd_send2,
139 conn_task,
140 keepalive_task,
141 is_webrtc,
142 send_msg_count,
143 send_byte_count,
144 recv_msg_count,
145 recv_byte_count,
146 hub_cmd_send,
147 };
148
149 (Arc::new(this), ConnRecv(msg_recv), cmd_send)
150 }
151
152 pub async fn ready(&self) {
154 let _ = self.ready.acquire().await;
156 }
157
158 pub fn is_using_webrtc(&self) -> bool {
160 self.is_webrtc.load(Ordering::SeqCst)
161 }
162
163 pub fn pub_key(&self) -> &PubKey {
165 &self.pub_key
166 }
167
168 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
170 self.cmd_send.send(ConnCmd::SendMessage(msg)).await
171 }
172
173 pub fn get_stats(&self) -> ConnStats {
175 ConnStats {
176 send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
177 send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
178 recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
179 recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
180 }
181 }
182}
183
/// Snapshot of the message and byte counters of a connection.
///
/// Plain data: cheap to copy, printable for logging, and comparable so
/// successive snapshots can be diffed or asserted on.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConnStats {
    /// Number of messages sent.
    pub send_msg_count: u64,

    /// Number of bytes sent.
    pub send_byte_count: u64,

    /// Number of messages received.
    pub recv_msg_count: u64,

    /// Number of bytes received.
    pub recv_byte_count: u64,
}
199
/// State owned by the connection task and passed between its phases
/// (handshake, webrtc attempt, signal relay fallback).
struct TaskCore {
    /// Hub configuration (timeouts and the `danger_*` relay switches).
    config: Arc<HubConfig>,
    /// Weak reference to the signal client; upgraded each time it is used.
    client: Weak<tx5_signal::SignalConnection>,
    /// Public key identifying the remote peer.
    pub_key: PubKey,
    /// Sender side of the command channel, cloned for helper tasks.
    cmd_send: CloseSend<ConnCmd>,
    /// Receiver side of the command channel driving the task loops.
    cmd_recv: CloseRecv<ConnCmd>,
    /// Delivers received messages to the `ConnRecv` side.
    msg_send: CloseSend<Vec<u8>>,
    /// Closed by the task to release waiters in `Conn::ready`.
    ready: Arc<tokio::sync::Semaphore>,
    /// Flag shared with `Conn::is_using_webrtc`.
    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
    /// Number of messages sent.
    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes sent.
    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of messages received.
    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
    /// Number of bytes received.
    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
}
214
215impl TaskCore {
216 async fn handle_recv_msg(
217 &self,
218 msg: Vec<u8>,
219 ) -> std::result::Result<(), ()> {
220 self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
221 self.recv_byte_count
222 .fetch_add(msg.len() as u64, Ordering::Relaxed);
223 if self.msg_send.send(msg).await.is_err() {
224 netaudit!(
225 DEBUG,
226 pub_key = ?self.pub_key,
227 a = "close: msg_send closed",
228 );
229 Err(())
230 } else {
231 Ok(())
232 }
233 }
234
235 fn track_send_msg(&self, len: usize) {
236 self.send_msg_count.fetch_add(1, Ordering::Relaxed);
237 self.send_byte_count
238 .fetch_add(len as u64, Ordering::Relaxed);
239 }
240}
241
242async fn con_task(
243 is_polite: bool,
244 webrtc_config: WebRtcConfig,
245 mut task_core: TaskCore,
246) {
247 if let Some(client) = task_core.client.upgrade() {
249 let handshake_fut = async {
250 let nonce = client.send_handshake_req(&task_core.pub_key).await?;
251
252 let mut got_peer_res = false;
253 let mut sent_our_res = false;
254
255 while let Some(cmd) = task_core.cmd_recv.recv().await {
256 match cmd {
257 ConnCmd::SigRecv(sig) => {
258 use tx5_signal::SignalMessage::*;
259 match sig {
260 HandshakeReq(oth_nonce) => {
261 client
262 .send_handshake_res(
263 &task_core.pub_key,
264 oth_nonce,
265 )
266 .await?;
267 sent_our_res = true;
268 }
269 HandshakeRes(res_nonce) => {
270 if res_nonce != nonce {
271 return Err(Error::other("nonce mismatch"));
272 }
273 got_peer_res = true;
274 }
275 _ => (),
278 }
279 }
280 ConnCmd::SendMessage(_) => {
281 return Err(Error::other("send before ready"));
282 }
283 ConnCmd::WebrtcTimeoutCheck
284 | ConnCmd::WebrtcRecv(_)
285 | ConnCmd::WebrtcClosed => {
286 unreachable!()
289 }
290 }
291 if got_peer_res && sent_our_res {
292 break;
293 }
294 }
295
296 Result::Ok(())
297 };
298
299 match tokio::time::timeout(
300 task_core.config.signal_config.max_idle,
301 handshake_fut,
302 )
303 .await
304 {
305 Err(_) | Ok(Err(_)) => {
306 client.close_peer(&task_core.pub_key).await;
307 return;
308 }
309 Ok(Ok(_)) => (),
310 }
311 } else {
312 return;
313 }
314
315 let task_core = match con_task_attempt_webrtc(
317 is_polite,
318 webrtc_config,
319 task_core,
320 )
321 .await
322 {
323 AttemptWebrtcResult::Abort => return,
324 AttemptWebrtcResult::Fallback(task_core) => {
325 if task_core.config.danger_deny_signal_relay {
326 netaudit!(
327 INFO,
328 pub_key = ?task_core.pub_key,
329 a = "webrtc fallback: denied signal relay",
330 );
331 return;
332 }
333
334 task_core
335 }
336 };
337
338 task_core.is_webrtc.store(false, Ordering::SeqCst);
339
340 con_task_fallback_use_signal(task_core).await;
343}
344
345async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
346 match tokio::time::timeout(
347 task_core.config.signal_config.max_idle,
348 task_core.cmd_recv.recv(),
349 )
350 .await
351 {
352 Err(_) => {
353 netaudit!(
354 DEBUG,
355 pub_key = ?task_core.pub_key,
356 a = "close: connection idle",
357 );
358 None
359 }
360 Ok(None) => {
361 netaudit!(
362 DEBUG,
363 pub_key = ?task_core.pub_key,
364 a = "close: cmd_recv stream complete",
365 );
366 None
367 }
368 Ok(Some(cmd)) => Some(cmd),
369 }
370}
371
372async fn webrtc_task(
373 mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
374 cmd_send: CloseSend<ConnCmd>,
375) {
376 while let Some(evt) = webrtc_recv.recv().await {
377 if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
378 break;
379 }
380 }
381 netaudit!(DEBUG, a = "webrtc task closed, sending WebrtcClosed",);
382 let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
383}
384
/// Outcome of `con_task_attempt_webrtc`.
enum AttemptWebrtcResult {
    /// The connection is finished; the task must end without a fallback.
    Abort,
    /// Webrtc could not be used; the task state is handed back so the
    /// caller can continue over the signal relay.
    Fallback(TaskCore),
}
389
390async fn con_task_attempt_webrtc(
391 is_polite: bool,
392 webrtc_config: WebRtcConfig,
393 mut task_core: TaskCore,
394) -> AttemptWebrtcResult {
395 use AttemptWebrtcResult::*;
396
397 let timeout_dur = task_core.config.webrtc_connect_timeout;
398 let timeout_cmd_send = task_core.cmd_send.clone();
399 tokio::task::spawn(async move {
400 tokio::time::sleep(timeout_dur).await;
401 let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
402 });
403
404 let (webrtc, webrtc_recv) = webrtc::new_backend_module(
405 task_core.config.backend_module,
406 is_polite,
407 webrtc_config,
408 4096,
410 );
411
412 struct AbortWebrtc(tokio::task::AbortHandle);
413
414 impl Drop for AbortWebrtc {
415 fn drop(&mut self) {
416 self.0.abort();
417 }
418 }
419
420 let _abort_webrtc = AbortWebrtc(
422 tokio::task::spawn(webrtc_task(
423 webrtc_recv,
424 task_core.cmd_send.clone(),
425 ))
426 .abort_handle(),
427 );
428
429 let mut is_ready = false;
430
431 if task_core.config.danger_force_signal_relay {
432 netaudit!(
433 WARN,
434 pub_key = ?task_core.pub_key,
435 a = "webrtc fallback: test",
436 );
437 return Fallback(task_core);
438 }
439
440 while let Some(cmd) = recv_cmd(&mut task_core).await {
442 use tx5_signal::SignalMessage::*;
443 use webrtc::WebrtcEvt::*;
444 use ConnCmd::*;
445 match cmd {
446 SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
447 netaudit!(
448 DEBUG,
449 pub_key = ?task_core.pub_key,
450 a = "close: unexpected handshake msg",
451 );
452 break;
453 }
454 SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
455 if task_core.handle_recv_msg(msg).await.is_err() {
456 break;
457 }
458 netaudit!(
459 WARN,
460 pub_key = ?task_core.pub_key,
461 a = "webrtc fallback: remote sent us an sbd message",
462 );
463 return Fallback(task_core);
467 }
468 SigRecv(Offer(offer)) => {
469 netaudit!(
470 TRACE,
471 pub_key = ?task_core.pub_key,
472 offer = String::from_utf8_lossy(&offer).to_string(),
473 a = "recv_offer",
474 );
475 if let Err(err) = webrtc.in_offer(offer).await {
476 netaudit!(
477 WARN,
478 pub_key = ?task_core.pub_key,
479 ?err,
480 a = "webrtc fallback: failed to parse received offer",
481 );
482 return Fallback(task_core);
483 }
484 }
485 SigRecv(Answer(answer)) => {
486 netaudit!(
487 TRACE,
488 pub_key = ?task_core.pub_key,
489 offer = String::from_utf8_lossy(&answer).to_string(),
490 a = "recv_answer",
491 );
492 if let Err(err) = webrtc.in_answer(answer).await {
493 netaudit!(
494 WARN,
495 pub_key = ?task_core.pub_key,
496 ?err,
497 a = "webrtc fallback: failed to parse received answer",
498 );
499 return Fallback(task_core);
500 }
501 }
502 SigRecv(Ice(ice)) => {
503 netaudit!(
504 TRACE,
505 pub_key = ?task_core.pub_key,
506 offer = String::from_utf8_lossy(&ice).to_string(),
507 a = "recv_ice",
508 );
509 if let Err(err) = webrtc.in_ice(ice).await {
510 netaudit!(
511 DEBUG,
512 pub_key = ?task_core.pub_key,
513 ?err,
514 a = "ignoring webrtc in_ice error",
515 );
516 }
518 }
519 SigRecv(Keepalive) | SigRecv(Unknown) => {
520 }
522 WebrtcRecv(GeneratedOffer(offer)) => {
523 netaudit!(
524 TRACE,
525 pub_key = ?task_core.pub_key,
526 offer = String::from_utf8_lossy(&offer).to_string(),
527 a = "send_offer",
528 );
529 if let Some(client) = task_core.client.upgrade() {
530 if let Err(err) =
531 client.send_offer(&task_core.pub_key, offer).await
532 {
533 netaudit!(
534 DEBUG,
535 pub_key = ?task_core.pub_key,
536 ?err,
537 a = "webrtc send_offer error",
538 );
539 break;
540 }
541 } else {
542 break;
543 }
544 }
545 WebrtcRecv(GeneratedAnswer(answer)) => {
546 netaudit!(
547 TRACE,
548 pub_key = ?task_core.pub_key,
549 offer = String::from_utf8_lossy(&answer).to_string(),
550 a = "send_answer",
551 );
552 if let Some(client) = task_core.client.upgrade() {
553 if let Err(err) =
554 client.send_answer(&task_core.pub_key, answer).await
555 {
556 netaudit!(
557 DEBUG,
558 pub_key = ?task_core.pub_key,
559 ?err,
560 a = "webrtc send_answer error",
561 );
562 break;
563 }
564 } else {
565 break;
566 }
567 }
568 WebrtcRecv(GeneratedIce(ice)) => {
569 netaudit!(
570 TRACE,
571 pub_key = ?task_core.pub_key,
572 offer = String::from_utf8_lossy(&ice).to_string(),
573 a = "send_ice",
574 );
575 if let Some(client) = task_core.client.upgrade() {
576 if let Err(err) =
577 client.send_ice(&task_core.pub_key, ice).await
578 {
579 netaudit!(
580 DEBUG,
581 pub_key = ?task_core.pub_key,
582 ?err,
583 a = "webrtc send_ice error",
584 );
585 break;
586 }
587 } else {
588 break;
589 }
590 }
591 WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
592 if task_core.handle_recv_msg(msg).await.is_err() {
593 break;
594 }
595 }
596 WebrtcRecv(Ready) => {
597 is_ready = true;
598 task_core.is_webrtc.store(true, Ordering::SeqCst);
599 task_core.ready.close();
600 }
601 SendMessage(msg) => {
602 let len = msg.len();
603
604 netaudit!(
605 TRACE,
606 pub_key = ?task_core.pub_key,
607 byte_len = len,
608 a = "queue msg for backend send",
609 );
610 if let Err(err) = webrtc.message(msg).await {
611 netaudit!(
612 WARN,
613 pub_key = ?task_core.pub_key,
614 ?err,
615 a = "webrtc fallback: failed to send message",
616 );
617 return Fallback(task_core);
618 }
619
620 task_core.track_send_msg(len);
621 }
622 WebrtcTimeoutCheck => {
623 if !is_ready {
624 netaudit!(
625 WARN,
626 pub_key = ?task_core.pub_key,
627 a = "webrtc fallback: failed to ready within timeout",
628 );
629 return Fallback(task_core);
630 }
631 }
632 WebrtcClosed => {
633 netaudit!(
634 WARN,
635 pub_key = ?task_core.pub_key,
636 a = "webrtc processing task closed",
637 );
638 break;
639 }
640 }
641 }
642
643 Abort
644}
645
646async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
647 task_core.ready.close();
649
650 while let Some(cmd) = recv_cmd(&mut task_core).await {
651 match cmd {
652 ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
653 if task_core.handle_recv_msg(msg).await.is_err() {
654 break;
655 }
656 }
657 ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
658 Some(client) => {
659 let len = msg.len();
660 if let Err(err) =
661 client.send_message(&task_core.pub_key, msg).await
662 {
663 netaudit!(
664 DEBUG,
665 pub_key = ?task_core.pub_key,
666 ?err,
667 a = "close: sbd client send error",
668 );
669 break;
670 }
671 task_core.track_send_msg(len);
672 }
673 None => {
674 netaudit!(
675 DEBUG,
676 pub_key = ?task_core.pub_key,
677 a = "close: sbd client closed",
678 );
679 break;
680 }
681 },
682 _ => (),
683 }
684 }
685}