tx5_connection/
conn.rs

1use super::*;
2use std::sync::atomic::Ordering;
3
4pub(crate) enum ConnCmd {
5    SigRecv(tx5_signal::SignalMessage),
6    WebrtcRecv(webrtc::WebrtcEvt),
7    SendMessage(Vec<u8>),
8    WebrtcTimeoutCheck,
9    WebrtcClosed,
10}
11
12/// Receive messages from a tx5 connection.
13pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16    /// Receive up to 16KiB of message data.
17    pub async fn recv(&mut self) -> Option<Vec<u8>> {
18        self.0.recv().await
19    }
20}
21
22/// A tx5 connection.
23pub struct Conn {
24    ready: Arc<tokio::sync::Semaphore>,
25    pub_key: PubKey,
26    cmd_send: CloseSend<ConnCmd>,
27    conn_task: tokio::task::JoinHandle<()>,
28    keepalive_task: tokio::task::JoinHandle<()>,
29    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
30    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
31    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
32    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
33    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
34    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
35}
36
// Emit a structured `tracing` event on the "NETAUDIT" target, always
// tagged with `m = "tx5-connection"`.
//
// Usage: `netaudit!(DEBUG, field = value, ..., a = "action")` where the
// first argument names a `tracing::Level` constant.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49    fn drop(&mut self) {
50        netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52        self.conn_task.abort();
53        self.keepalive_task.abort();
54
55        let hub_cmd_send = self.hub_cmd_send.clone();
56        let pub_key = self.pub_key.clone();
57        tokio::task::spawn(async move {
58            let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59        });
60    }
61}
62
63impl Conn {
64    #[cfg(test)]
65    pub(crate) fn test_kill_keepalive_task(&self) {
66        self.keepalive_task.abort();
67    }
68
69    pub(crate) fn priv_new(
70        webrtc_config: Vec<u8>,
71        is_polite: bool,
72        pub_key: PubKey,
73        client: Weak<tx5_signal::SignalConnection>,
74        config: Arc<HubConfig>,
75        hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76    ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77        netaudit!(
78            DEBUG,
79            webrtc_config = String::from_utf8_lossy(&webrtc_config).to_string(),
80            ?pub_key,
81            ?is_polite,
82            a = "open",
83        );
84
85        let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
86        let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
87        let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
88        let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
89        let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
90
91        // zero len semaphore.. we actually just wait for the close
92        let ready = Arc::new(tokio::sync::Semaphore::new(0));
93
94        let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
95        let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
96
97        let keepalive_dur = config.signal_config.max_idle / 2;
98        let client2 = client.clone();
99        let pub_key2 = pub_key.clone();
100        let keepalive_task = tokio::task::spawn(async move {
101            loop {
102                tokio::time::sleep(keepalive_dur).await;
103
104                if let Some(client) = client2.upgrade() {
105                    if client.send_keepalive(&pub_key2).await.is_err() {
106                        break;
107                    }
108                } else {
109                    break;
110                }
111            }
112        });
113
114        msg_send.set_close_on_drop(true);
115
116        let con_task_fut = con_task(
117            is_polite,
118            webrtc_config,
119            TaskCore {
120                client,
121                config,
122                pub_key: pub_key.clone(),
123                cmd_send: cmd_send.clone(),
124                cmd_recv,
125                send_msg_count: send_msg_count.clone(),
126                send_byte_count: send_byte_count.clone(),
127                recv_msg_count: recv_msg_count.clone(),
128                recv_byte_count: recv_byte_count.clone(),
129                msg_send,
130                ready: ready.clone(),
131                is_webrtc: is_webrtc.clone(),
132            },
133        );
134        let conn_task = tokio::task::spawn(con_task_fut);
135
136        let mut cmd_send2 = cmd_send.clone();
137        cmd_send2.set_close_on_drop(true);
138        let this = Self {
139            ready,
140            pub_key,
141            cmd_send: cmd_send2,
142            conn_task,
143            keepalive_task,
144            is_webrtc,
145            send_msg_count,
146            send_byte_count,
147            recv_msg_count,
148            recv_byte_count,
149            hub_cmd_send,
150        };
151
152        (Arc::new(this), ConnRecv(msg_recv), cmd_send)
153    }
154
155    /// Wait until this connection is ready to send / receive data.
156    pub async fn ready(&self) {
157        // this will error when we close the semaphore waking up the task
158        let _ = self.ready.acquire().await;
159    }
160
161    /// Returns `true` if we sucessfully connected over webrtc.
162    pub fn is_using_webrtc(&self) -> bool {
163        self.is_webrtc.load(Ordering::SeqCst)
164    }
165
166    /// The pub key of the remote peer this is connected to.
167    pub fn pub_key(&self) -> &PubKey {
168        &self.pub_key
169    }
170
171    /// Send up to 16KiB of message data.
172    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
173        self.cmd_send.send(ConnCmd::SendMessage(msg)).await
174    }
175
176    /// Get connection statistics.
177    pub fn get_stats(&self) -> ConnStats {
178        ConnStats {
179            send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
180            send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
181            recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
182            recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
183        }
184    }
185}
186
/// Connection statistics.
///
/// A point-in-time snapshot of the counters maintained by a connection;
/// all fields are plain `u64` values, so the snapshot is cheap to copy.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConnStats {
    /// message count sent.
    pub send_msg_count: u64,

    /// byte count sent.
    pub send_byte_count: u64,

    /// message count received.
    pub recv_msg_count: u64,

    /// byte count received.
    pub recv_byte_count: u64,
}
202
203struct TaskCore {
204    config: Arc<HubConfig>,
205    client: Weak<tx5_signal::SignalConnection>,
206    pub_key: PubKey,
207    cmd_send: CloseSend<ConnCmd>,
208    cmd_recv: CloseRecv<ConnCmd>,
209    msg_send: CloseSend<Vec<u8>>,
210    ready: Arc<tokio::sync::Semaphore>,
211    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
212    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
213    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
214    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
215    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
216}
217
218impl TaskCore {
219    async fn handle_recv_msg(
220        &self,
221        msg: Vec<u8>,
222    ) -> std::result::Result<(), ()> {
223        self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
224        self.recv_byte_count
225            .fetch_add(msg.len() as u64, Ordering::Relaxed);
226        if self.msg_send.send(msg).await.is_err() {
227            netaudit!(
228                DEBUG,
229                pub_key = ?self.pub_key,
230                a = "close: msg_send closed",
231            );
232            Err(())
233        } else {
234            Ok(())
235        }
236    }
237
238    fn track_send_msg(&self, len: usize) {
239        self.send_msg_count.fetch_add(1, Ordering::Relaxed);
240        self.send_byte_count
241            .fetch_add(len as u64, Ordering::Relaxed);
242    }
243}
244
245async fn con_task(
246    is_polite: bool,
247    webrtc_config: Vec<u8>,
248    mut task_core: TaskCore,
249) {
250    // first process the handshake
251    if let Some(client) = task_core.client.upgrade() {
252        let handshake_fut = async {
253            let nonce = client.send_handshake_req(&task_core.pub_key).await?;
254
255            let mut got_peer_res = false;
256            let mut sent_our_res = false;
257
258            while let Some(cmd) = task_core.cmd_recv.recv().await {
259                match cmd {
260                    ConnCmd::SigRecv(sig) => {
261                        use tx5_signal::SignalMessage::*;
262                        match sig {
263                            HandshakeReq(oth_nonce) => {
264                                client
265                                    .send_handshake_res(
266                                        &task_core.pub_key,
267                                        oth_nonce,
268                                    )
269                                    .await?;
270                                sent_our_res = true;
271                            }
272                            HandshakeRes(res_nonce) => {
273                                if res_nonce != nonce {
274                                    return Err(Error::other("nonce mismatch"));
275                                }
276                                got_peer_res = true;
277                            }
278                            // Ignore all other message types...
279                            // they may be from previous sessions
280                            _ => (),
281                        }
282                    }
283                    ConnCmd::SendMessage(_) => {
284                        return Err(Error::other("send before ready"));
285                    }
286                    ConnCmd::WebrtcTimeoutCheck
287                    | ConnCmd::WebrtcRecv(_)
288                    | ConnCmd::WebrtcClosed => {
289                        // only emitted by the webrtc module
290                        // which at this point hasn't yet been initialized
291                        unreachable!()
292                    }
293                }
294                if got_peer_res && sent_our_res {
295                    break;
296                }
297            }
298
299            Result::Ok(())
300        };
301
302        match tokio::time::timeout(
303            task_core.config.signal_config.max_idle,
304            handshake_fut,
305        )
306        .await
307        {
308            Err(_) | Ok(Err(_)) => {
309                client.close_peer(&task_core.pub_key).await;
310                return;
311            }
312            Ok(Ok(_)) => (),
313        }
314    } else {
315        return;
316    }
317
318    // next, attempt webrtc
319    let task_core = match con_task_attempt_webrtc(
320        is_polite,
321        webrtc_config,
322        task_core,
323    )
324    .await
325    {
326        AttemptWebrtcResult::Abort => return,
327        AttemptWebrtcResult::Fallback(task_core) => task_core,
328    };
329
330    task_core.is_webrtc.store(false, Ordering::SeqCst);
331
332    // if webrtc failed in a way that allows us to fall back to sbd,
333    // use the fallback sbd messaging system
334    con_task_fallback_use_signal(task_core).await;
335}
336
337async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
338    match tokio::time::timeout(
339        task_core.config.signal_config.max_idle,
340        task_core.cmd_recv.recv(),
341    )
342    .await
343    {
344        Err(_) => {
345            netaudit!(
346                DEBUG,
347                pub_key = ?task_core.pub_key,
348                a = "close: connection idle",
349            );
350            None
351        }
352        Ok(None) => {
353            netaudit!(
354                DEBUG,
355                pub_key = ?task_core.pub_key,
356                a = "close: cmd_recv stream complete",
357            );
358            None
359        }
360        Ok(Some(cmd)) => Some(cmd),
361    }
362}
363
364async fn webrtc_task(
365    mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
366    cmd_send: CloseSend<ConnCmd>,
367) {
368    while let Some(evt) = webrtc_recv.recv().await {
369        if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
370            break;
371        }
372    }
373    let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
374}
375
376enum AttemptWebrtcResult {
377    Abort,
378    Fallback(TaskCore),
379}
380
381async fn con_task_attempt_webrtc(
382    is_polite: bool,
383    webrtc_config: Vec<u8>,
384    mut task_core: TaskCore,
385) -> AttemptWebrtcResult {
386    use AttemptWebrtcResult::*;
387
388    let timeout_dur = task_core.config.signal_config.max_idle;
389    let timeout_cmd_send = task_core.cmd_send.clone();
390    tokio::task::spawn(async move {
391        tokio::time::sleep(timeout_dur).await;
392        let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
393    });
394
395    let (webrtc, webrtc_recv) = webrtc::new_backend_module(
396        task_core.config.backend_module,
397        is_polite,
398        webrtc_config.clone(),
399        // MAYBE - make this configurable
400        4096,
401    );
402
403    struct AbortWebrtc(tokio::task::AbortHandle);
404
405    impl Drop for AbortWebrtc {
406        fn drop(&mut self) {
407            self.0.abort();
408        }
409    }
410
411    // ensure if we exit this loop that the tokio task is stopped
412    let _abort_webrtc = AbortWebrtc(
413        tokio::task::spawn(webrtc_task(
414            webrtc_recv,
415            task_core.cmd_send.clone(),
416        ))
417        .abort_handle(),
418    );
419
420    let mut is_ready = false;
421
422    #[cfg(test)]
423    if task_core.config.test_fail_webrtc {
424        netaudit!(
425            WARN,
426            pub_key = ?task_core.pub_key,
427            a = "webrtc fallback: test",
428        );
429        return Fallback(task_core);
430    }
431
432    while let Some(cmd) = recv_cmd(&mut task_core).await {
433        use tx5_signal::SignalMessage::*;
434        use webrtc::WebrtcEvt::*;
435        use ConnCmd::*;
436        match cmd {
437            SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
438                netaudit!(
439                    DEBUG,
440                    pub_key = ?task_core.pub_key,
441                    a = "close: unexpected handshake msg",
442                );
443                return Abort;
444            }
445            SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
446                if task_core.handle_recv_msg(msg).await.is_err() {
447                    return Abort;
448                }
449                netaudit!(
450                    WARN,
451                    pub_key = ?task_core.pub_key,
452                    a = "webrtc fallback: remote sent us an sbd message",
453                );
454                // if we get a message from the remote, we have to assume
455                // they are switching to fallback mode, and thus we cannot
456                // use webrtc ourselves.
457                return Fallback(task_core);
458            }
459            SigRecv(Offer(offer)) => {
460                netaudit!(
461                    TRACE,
462                    pub_key = ?task_core.pub_key,
463                    offer = String::from_utf8_lossy(&offer).to_string(),
464                    a = "recv_offer",
465                );
466                if let Err(err) = webrtc.in_offer(offer).await {
467                    netaudit!(
468                        WARN,
469                        pub_key = ?task_core.pub_key,
470                        ?err,
471                        a = "webrtc fallback: failed to parse received offer",
472                    );
473                    return Fallback(task_core);
474                }
475            }
476            SigRecv(Answer(answer)) => {
477                netaudit!(
478                    TRACE,
479                    pub_key = ?task_core.pub_key,
480                    offer = String::from_utf8_lossy(&answer).to_string(),
481                    a = "recv_answer",
482                );
483                if let Err(err) = webrtc.in_answer(answer).await {
484                    netaudit!(
485                        WARN,
486                        pub_key = ?task_core.pub_key,
487                        ?err,
488                        a = "webrtc fallback: failed to parse received answer",
489                    );
490                    return Fallback(task_core);
491                }
492            }
493            SigRecv(Ice(ice)) => {
494                netaudit!(
495                    TRACE,
496                    pub_key = ?task_core.pub_key,
497                    offer = String::from_utf8_lossy(&ice).to_string(),
498                    a = "recv_ice",
499                );
500                if let Err(err) = webrtc.in_ice(ice).await {
501                    netaudit!(
502                        DEBUG,
503                        pub_key = ?task_core.pub_key,
504                        ?err,
505                        a = "ignoring webrtc in_ice error",
506                    );
507                    // ice errors are often benign... just ignore it
508                }
509            }
510            SigRecv(Keepalive) | SigRecv(Unknown) => {
511                // these are no-ops
512            }
513            WebrtcRecv(GeneratedOffer(offer)) => {
514                netaudit!(
515                    TRACE,
516                    pub_key = ?task_core.pub_key,
517                    offer = String::from_utf8_lossy(&offer).to_string(),
518                    a = "send_offer",
519                );
520                if let Some(client) = task_core.client.upgrade() {
521                    if let Err(err) =
522                        client.send_offer(&task_core.pub_key, offer).await
523                    {
524                        netaudit!(
525                            DEBUG,
526                            pub_key = ?task_core.pub_key,
527                            ?err,
528                            a = "webrtc send_offer error",
529                        );
530                        return Abort;
531                    }
532                } else {
533                    return Abort;
534                }
535            }
536            WebrtcRecv(GeneratedAnswer(answer)) => {
537                netaudit!(
538                    TRACE,
539                    pub_key = ?task_core.pub_key,
540                    offer = String::from_utf8_lossy(&answer).to_string(),
541                    a = "send_answer",
542                );
543                if let Some(client) = task_core.client.upgrade() {
544                    if let Err(err) =
545                        client.send_answer(&task_core.pub_key, answer).await
546                    {
547                        netaudit!(
548                            DEBUG,
549                            pub_key = ?task_core.pub_key,
550                            ?err,
551                            a = "webrtc send_answer error",
552                        );
553                        return Abort;
554                    }
555                } else {
556                    return Abort;
557                }
558            }
559            WebrtcRecv(GeneratedIce(ice)) => {
560                netaudit!(
561                    TRACE,
562                    pub_key = ?task_core.pub_key,
563                    offer = String::from_utf8_lossy(&ice).to_string(),
564                    a = "send_ice",
565                );
566                if let Some(client) = task_core.client.upgrade() {
567                    if let Err(err) =
568                        client.send_ice(&task_core.pub_key, ice).await
569                    {
570                        netaudit!(
571                            DEBUG,
572                            pub_key = ?task_core.pub_key,
573                            ?err,
574                            a = "webrtc send_ice error",
575                        );
576                        return Abort;
577                    }
578                } else {
579                    return Abort;
580                }
581            }
582            WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
583                if task_core.handle_recv_msg(msg).await.is_err() {
584                    return Abort;
585                }
586            }
587            WebrtcRecv(Ready) => {
588                is_ready = true;
589                task_core.is_webrtc.store(true, Ordering::SeqCst);
590                task_core.ready.close();
591            }
592            SendMessage(msg) => {
593                let len = msg.len();
594
595                netaudit!(
596                    TRACE,
597                    pub_key = ?task_core.pub_key,
598                    byte_len = len,
599                    a = "queue msg for backend send",
600                );
601                if let Err(err) = webrtc.message(msg).await {
602                    netaudit!(
603                        WARN,
604                        pub_key = ?task_core.pub_key,
605                        ?err,
606                        a = "webrtc fallback: failed to send message",
607                    );
608                    return Fallback(task_core);
609                }
610
611                task_core.track_send_msg(len);
612            }
613            WebrtcTimeoutCheck => {
614                if !is_ready {
615                    netaudit!(
616                        WARN,
617                        pub_key = ?task_core.pub_key,
618                        a = "webrtc fallback: failed to ready within timeout",
619                    );
620                    return Fallback(task_core);
621                }
622            }
623            WebrtcClosed => {
624                netaudit!(
625                    WARN,
626                    pub_key = ?task_core.pub_key,
627                    a = "webrtc fallback: webrtc processing task closed",
628                );
629                return Fallback(task_core);
630            }
631        }
632    }
633
634    Abort
635}
636
637async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
638    // closing the semaphore causes all the acquire awaits to end
639    task_core.ready.close();
640
641    while let Some(cmd) = recv_cmd(&mut task_core).await {
642        match cmd {
643            ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
644                if task_core.handle_recv_msg(msg).await.is_err() {
645                    break;
646                }
647            }
648            ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
649                Some(client) => {
650                    let len = msg.len();
651                    if let Err(err) =
652                        client.send_message(&task_core.pub_key, msg).await
653                    {
654                        netaudit!(
655                            DEBUG,
656                            pub_key = ?task_core.pub_key,
657                            ?err,
658                            a = "close: sbd client send error",
659                        );
660                        break;
661                    }
662                    task_core.track_send_msg(len);
663                }
664                None => {
665                    netaudit!(
666                        DEBUG,
667                        pub_key = ?task_core.pub_key,
668                        a = "close: sbd client closed",
669                    );
670                    break;
671                }
672            },
673            _ => (),
674        }
675    }
676}