1use super::*;
2use std::sync::atomic::Ordering;
3
4pub(crate) enum ConnCmd {
5 SigRecv(tx5_signal::SignalMessage),
6 WebrtcRecv(webrtc::WebrtcEvt),
7 SendMessage(Vec<u8>),
8 WebrtcTimeoutCheck,
9 WebrtcClosed,
10}
11
12pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16 pub async fn recv(&mut self) -> Option<Vec<u8>> {
18 self.0.recv().await
19 }
20}
21
22pub struct Conn {
24 ready: Arc<tokio::sync::Semaphore>,
25 pub_key: PubKey,
26 cmd_send: CloseSend<ConnCmd>,
27 conn_task: tokio::task::JoinHandle<()>,
28 keepalive_task: tokio::task::JoinHandle<()>,
29 is_webrtc: Arc<std::sync::atomic::AtomicBool>,
30 send_msg_count: Arc<std::sync::atomic::AtomicU64>,
31 send_byte_count: Arc<std::sync::atomic::AtomicU64>,
32 recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
33 recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
34 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
35}
36
// Emit a `tracing` event on the "NETAUDIT" target at the given level,
// always tagged `m = "tx5-connection"`. Every token after the level is
// passed straight through to `tracing::event!` as additional fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49 fn drop(&mut self) {
50 netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52 self.conn_task.abort();
53 self.keepalive_task.abort();
54
55 let hub_cmd_send = self.hub_cmd_send.clone();
56 let pub_key = self.pub_key.clone();
57 tokio::task::spawn(async move {
58 let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59 });
60 }
61}
62
63impl Conn {
64 #[cfg(test)]
65 pub(crate) fn test_kill_keepalive_task(&self) {
66 self.keepalive_task.abort();
67 }
68
69 pub(crate) fn priv_new(
70 webrtc_config: WebRtcConfig,
71 is_polite: bool,
72 pub_key: PubKey,
73 client: Weak<tx5_signal::SignalConnection>,
74 config: Arc<HubConfig>,
75 hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76 ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77 netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);
78
79 let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
80 let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
81 let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
82 let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
83 let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
84
85 let ready = Arc::new(tokio::sync::Semaphore::new(0));
87
88 let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
89 let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
90
91 let keepalive_dur = config.signal_config.max_idle / 2;
92 let client2 = client.clone();
93 let pub_key2 = pub_key.clone();
94 let keepalive_task = tokio::task::spawn(async move {
95 loop {
96 tokio::time::sleep(keepalive_dur).await;
97
98 if let Some(client) = client2.upgrade() {
99 if client.send_keepalive(&pub_key2).await.is_err() {
100 break;
101 }
102 } else {
103 break;
104 }
105 }
106 });
107
108 msg_send.set_close_on_drop(true);
109
110 let con_task_fut = con_task(
111 is_polite,
112 webrtc_config,
113 TaskCore {
114 client,
115 config,
116 pub_key: pub_key.clone(),
117 cmd_send: cmd_send.clone(),
118 cmd_recv,
119 send_msg_count: send_msg_count.clone(),
120 send_byte_count: send_byte_count.clone(),
121 recv_msg_count: recv_msg_count.clone(),
122 recv_byte_count: recv_byte_count.clone(),
123 msg_send,
124 ready: ready.clone(),
125 is_webrtc: is_webrtc.clone(),
126 },
127 );
128 let conn_task = tokio::task::spawn(con_task_fut);
129
130 let mut cmd_send2 = cmd_send.clone();
131 cmd_send2.set_close_on_drop(true);
132 let this = Self {
133 ready,
134 pub_key,
135 cmd_send: cmd_send2,
136 conn_task,
137 keepalive_task,
138 is_webrtc,
139 send_msg_count,
140 send_byte_count,
141 recv_msg_count,
142 recv_byte_count,
143 hub_cmd_send,
144 };
145
146 (Arc::new(this), ConnRecv(msg_recv), cmd_send)
147 }
148
149 pub async fn ready(&self) {
151 let _ = self.ready.acquire().await;
153 }
154
155 pub fn is_using_webrtc(&self) -> bool {
157 self.is_webrtc.load(Ordering::SeqCst)
158 }
159
160 pub fn pub_key(&self) -> &PubKey {
162 &self.pub_key
163 }
164
165 pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
167 self.cmd_send.send(ConnCmd::SendMessage(msg)).await
168 }
169
170 pub fn get_stats(&self) -> ConnStats {
172 ConnStats {
173 send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
174 send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
175 recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
176 recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
177 }
178 }
179}
180
/// Snapshot of a connection's traffic counters, as returned by
/// `Conn::get_stats`.
///
/// Public value type: it derives `Debug` so it can be logged, and the
/// cheap `Copy`/`PartialEq` derives let callers compare snapshots.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConnStats {
    /// Number of messages successfully handed off for sending (over webrtc
    /// or the signal relay).
    pub send_msg_count: u64,

    /// Total payload bytes of those sent messages.
    pub send_byte_count: u64,

    /// Number of messages received from the remote peer.
    pub recv_msg_count: u64,

    /// Total payload bytes of those received messages.
    pub recv_byte_count: u64,
}
196
197struct TaskCore {
198 config: Arc<HubConfig>,
199 client: Weak<tx5_signal::SignalConnection>,
200 pub_key: PubKey,
201 cmd_send: CloseSend<ConnCmd>,
202 cmd_recv: CloseRecv<ConnCmd>,
203 msg_send: CloseSend<Vec<u8>>,
204 ready: Arc<tokio::sync::Semaphore>,
205 is_webrtc: Arc<std::sync::atomic::AtomicBool>,
206 send_msg_count: Arc<std::sync::atomic::AtomicU64>,
207 send_byte_count: Arc<std::sync::atomic::AtomicU64>,
208 recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
209 recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
210}
211
212impl TaskCore {
213 async fn handle_recv_msg(
214 &self,
215 msg: Vec<u8>,
216 ) -> std::result::Result<(), ()> {
217 self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
218 self.recv_byte_count
219 .fetch_add(msg.len() as u64, Ordering::Relaxed);
220 if self.msg_send.send(msg).await.is_err() {
221 netaudit!(
222 DEBUG,
223 pub_key = ?self.pub_key,
224 a = "close: msg_send closed",
225 );
226 Err(())
227 } else {
228 Ok(())
229 }
230 }
231
232 fn track_send_msg(&self, len: usize) {
233 self.send_msg_count.fetch_add(1, Ordering::Relaxed);
234 self.send_byte_count
235 .fetch_add(len as u64, Ordering::Relaxed);
236 }
237}
238
239async fn con_task(
240 is_polite: bool,
241 webrtc_config: WebRtcConfig,
242 mut task_core: TaskCore,
243) {
244 if let Some(client) = task_core.client.upgrade() {
246 let handshake_fut = async {
247 let nonce = client.send_handshake_req(&task_core.pub_key).await?;
248
249 let mut got_peer_res = false;
250 let mut sent_our_res = false;
251
252 while let Some(cmd) = task_core.cmd_recv.recv().await {
253 match cmd {
254 ConnCmd::SigRecv(sig) => {
255 use tx5_signal::SignalMessage::*;
256 match sig {
257 HandshakeReq(oth_nonce) => {
258 client
259 .send_handshake_res(
260 &task_core.pub_key,
261 oth_nonce,
262 )
263 .await?;
264 sent_our_res = true;
265 }
266 HandshakeRes(res_nonce) => {
267 if res_nonce != nonce {
268 return Err(Error::other("nonce mismatch"));
269 }
270 got_peer_res = true;
271 }
272 _ => (),
275 }
276 }
277 ConnCmd::SendMessage(_) => {
278 return Err(Error::other("send before ready"));
279 }
280 ConnCmd::WebrtcTimeoutCheck
281 | ConnCmd::WebrtcRecv(_)
282 | ConnCmd::WebrtcClosed => {
283 unreachable!()
286 }
287 }
288 if got_peer_res && sent_our_res {
289 break;
290 }
291 }
292
293 Result::Ok(())
294 };
295
296 match tokio::time::timeout(
297 task_core.config.signal_config.max_idle,
298 handshake_fut,
299 )
300 .await
301 {
302 Err(_) | Ok(Err(_)) => {
303 client.close_peer(&task_core.pub_key).await;
304 return;
305 }
306 Ok(Ok(_)) => (),
307 }
308 } else {
309 return;
310 }
311
312 let task_core = match con_task_attempt_webrtc(
314 is_polite,
315 webrtc_config,
316 task_core,
317 )
318 .await
319 {
320 AttemptWebrtcResult::Abort => return,
321 AttemptWebrtcResult::Fallback(task_core) => task_core,
322 };
323
324 task_core.is_webrtc.store(false, Ordering::SeqCst);
325
326 con_task_fallback_use_signal(task_core).await;
329}
330
331async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
332 match tokio::time::timeout(
333 task_core.config.signal_config.max_idle,
334 task_core.cmd_recv.recv(),
335 )
336 .await
337 {
338 Err(_) => {
339 netaudit!(
340 DEBUG,
341 pub_key = ?task_core.pub_key,
342 a = "close: connection idle",
343 );
344 None
345 }
346 Ok(None) => {
347 netaudit!(
348 DEBUG,
349 pub_key = ?task_core.pub_key,
350 a = "close: cmd_recv stream complete",
351 );
352 None
353 }
354 Ok(Some(cmd)) => Some(cmd),
355 }
356}
357
358async fn webrtc_task(
359 mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
360 cmd_send: CloseSend<ConnCmd>,
361) {
362 while let Some(evt) = webrtc_recv.recv().await {
363 if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
364 break;
365 }
366 }
367 let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
368}
369
370enum AttemptWebrtcResult {
371 Abort,
372 Fallback(TaskCore),
373}
374
375async fn con_task_attempt_webrtc(
376 is_polite: bool,
377 webrtc_config: WebRtcConfig,
378 mut task_core: TaskCore,
379) -> AttemptWebrtcResult {
380 use AttemptWebrtcResult::*;
381
382 let timeout_dur = task_core.config.signal_config.max_idle;
383 let timeout_cmd_send = task_core.cmd_send.clone();
384 tokio::task::spawn(async move {
385 tokio::time::sleep(timeout_dur).await;
386 let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
387 });
388
389 let (webrtc, webrtc_recv) = webrtc::new_backend_module(
390 task_core.config.backend_module,
391 is_polite,
392 webrtc_config,
393 4096,
395 );
396
397 struct AbortWebrtc(tokio::task::AbortHandle);
398
399 impl Drop for AbortWebrtc {
400 fn drop(&mut self) {
401 self.0.abort();
402 }
403 }
404
405 let _abort_webrtc = AbortWebrtc(
407 tokio::task::spawn(webrtc_task(
408 webrtc_recv,
409 task_core.cmd_send.clone(),
410 ))
411 .abort_handle(),
412 );
413
414 let mut is_ready = false;
415
416 #[cfg(test)]
417 if task_core.config.test_fail_webrtc {
418 netaudit!(
419 WARN,
420 pub_key = ?task_core.pub_key,
421 a = "webrtc fallback: test",
422 );
423 return Fallback(task_core);
424 }
425
426 while let Some(cmd) = recv_cmd(&mut task_core).await {
427 use tx5_signal::SignalMessage::*;
428 use webrtc::WebrtcEvt::*;
429 use ConnCmd::*;
430 match cmd {
431 SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
432 netaudit!(
433 DEBUG,
434 pub_key = ?task_core.pub_key,
435 a = "close: unexpected handshake msg",
436 );
437 return Abort;
438 }
439 SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
440 if task_core.handle_recv_msg(msg).await.is_err() {
441 return Abort;
442 }
443 netaudit!(
444 WARN,
445 pub_key = ?task_core.pub_key,
446 a = "webrtc fallback: remote sent us an sbd message",
447 );
448 return Fallback(task_core);
452 }
453 SigRecv(Offer(offer)) => {
454 netaudit!(
455 TRACE,
456 pub_key = ?task_core.pub_key,
457 offer = String::from_utf8_lossy(&offer).to_string(),
458 a = "recv_offer",
459 );
460 if let Err(err) = webrtc.in_offer(offer).await {
461 netaudit!(
462 WARN,
463 pub_key = ?task_core.pub_key,
464 ?err,
465 a = "webrtc fallback: failed to parse received offer",
466 );
467 return Fallback(task_core);
468 }
469 }
470 SigRecv(Answer(answer)) => {
471 netaudit!(
472 TRACE,
473 pub_key = ?task_core.pub_key,
474 offer = String::from_utf8_lossy(&answer).to_string(),
475 a = "recv_answer",
476 );
477 if let Err(err) = webrtc.in_answer(answer).await {
478 netaudit!(
479 WARN,
480 pub_key = ?task_core.pub_key,
481 ?err,
482 a = "webrtc fallback: failed to parse received answer",
483 );
484 return Fallback(task_core);
485 }
486 }
487 SigRecv(Ice(ice)) => {
488 netaudit!(
489 TRACE,
490 pub_key = ?task_core.pub_key,
491 offer = String::from_utf8_lossy(&ice).to_string(),
492 a = "recv_ice",
493 );
494 if let Err(err) = webrtc.in_ice(ice).await {
495 netaudit!(
496 DEBUG,
497 pub_key = ?task_core.pub_key,
498 ?err,
499 a = "ignoring webrtc in_ice error",
500 );
501 }
503 }
504 SigRecv(Keepalive) | SigRecv(Unknown) => {
505 }
507 WebrtcRecv(GeneratedOffer(offer)) => {
508 netaudit!(
509 TRACE,
510 pub_key = ?task_core.pub_key,
511 offer = String::from_utf8_lossy(&offer).to_string(),
512 a = "send_offer",
513 );
514 if let Some(client) = task_core.client.upgrade() {
515 if let Err(err) =
516 client.send_offer(&task_core.pub_key, offer).await
517 {
518 netaudit!(
519 DEBUG,
520 pub_key = ?task_core.pub_key,
521 ?err,
522 a = "webrtc send_offer error",
523 );
524 return Abort;
525 }
526 } else {
527 return Abort;
528 }
529 }
530 WebrtcRecv(GeneratedAnswer(answer)) => {
531 netaudit!(
532 TRACE,
533 pub_key = ?task_core.pub_key,
534 offer = String::from_utf8_lossy(&answer).to_string(),
535 a = "send_answer",
536 );
537 if let Some(client) = task_core.client.upgrade() {
538 if let Err(err) =
539 client.send_answer(&task_core.pub_key, answer).await
540 {
541 netaudit!(
542 DEBUG,
543 pub_key = ?task_core.pub_key,
544 ?err,
545 a = "webrtc send_answer error",
546 );
547 return Abort;
548 }
549 } else {
550 return Abort;
551 }
552 }
553 WebrtcRecv(GeneratedIce(ice)) => {
554 netaudit!(
555 TRACE,
556 pub_key = ?task_core.pub_key,
557 offer = String::from_utf8_lossy(&ice).to_string(),
558 a = "send_ice",
559 );
560 if let Some(client) = task_core.client.upgrade() {
561 if let Err(err) =
562 client.send_ice(&task_core.pub_key, ice).await
563 {
564 netaudit!(
565 DEBUG,
566 pub_key = ?task_core.pub_key,
567 ?err,
568 a = "webrtc send_ice error",
569 );
570 return Abort;
571 }
572 } else {
573 return Abort;
574 }
575 }
576 WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
577 if task_core.handle_recv_msg(msg).await.is_err() {
578 return Abort;
579 }
580 }
581 WebrtcRecv(Ready) => {
582 is_ready = true;
583 task_core.is_webrtc.store(true, Ordering::SeqCst);
584 task_core.ready.close();
585 }
586 SendMessage(msg) => {
587 let len = msg.len();
588
589 netaudit!(
590 TRACE,
591 pub_key = ?task_core.pub_key,
592 byte_len = len,
593 a = "queue msg for backend send",
594 );
595 if let Err(err) = webrtc.message(msg).await {
596 netaudit!(
597 WARN,
598 pub_key = ?task_core.pub_key,
599 ?err,
600 a = "webrtc fallback: failed to send message",
601 );
602 return Fallback(task_core);
603 }
604
605 task_core.track_send_msg(len);
606 }
607 WebrtcTimeoutCheck => {
608 if !is_ready {
609 netaudit!(
610 WARN,
611 pub_key = ?task_core.pub_key,
612 a = "webrtc fallback: failed to ready within timeout",
613 );
614 return Fallback(task_core);
615 }
616 }
617 WebrtcClosed => {
618 netaudit!(
619 WARN,
620 pub_key = ?task_core.pub_key,
621 a = "webrtc fallback: webrtc processing task closed",
622 );
623 return Fallback(task_core);
624 }
625 }
626 }
627
628 Abort
629}
630
631async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
632 task_core.ready.close();
634
635 while let Some(cmd) = recv_cmd(&mut task_core).await {
636 match cmd {
637 ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
638 if task_core.handle_recv_msg(msg).await.is_err() {
639 break;
640 }
641 }
642 ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
643 Some(client) => {
644 let len = msg.len();
645 if let Err(err) =
646 client.send_message(&task_core.pub_key, msg).await
647 {
648 netaudit!(
649 DEBUG,
650 pub_key = ?task_core.pub_key,
651 ?err,
652 a = "close: sbd client send error",
653 );
654 break;
655 }
656 task_core.track_send_msg(len);
657 }
658 None => {
659 netaudit!(
660 DEBUG,
661 pub_key = ?task_core.pub_key,
662 a = "close: sbd client closed",
663 );
664 break;
665 }
666 },
667 _ => (),
668 }
669 }
670}