tx5_connection/
conn.rs

1use super::*;
2use std::sync::atomic::Ordering;
3
4pub(crate) enum ConnCmd {
5    SigRecv(tx5_signal::SignalMessage),
6    WebrtcRecv(webrtc::WebrtcEvt),
7    SendMessage(Vec<u8>),
8    WebrtcTimeoutCheck,
9    WebrtcClosed,
10}
11
12/// Receive messages from a tx5 connection.
13pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16    /// Receive up to 16KiB of message data.
17    pub async fn recv(&mut self) -> Option<Vec<u8>> {
18        self.0.recv().await
19    }
20}
21
22/// A tx5 connection.
23pub struct Conn {
24    ready: Arc<tokio::sync::Semaphore>,
25    pub_key: PubKey,
26    cmd_send: CloseSend<ConnCmd>,
27    conn_task: tokio::task::JoinHandle<()>,
28    keepalive_task: tokio::task::JoinHandle<()>,
29    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
30    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
31    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
32    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
33    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
34    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
35}
36
// Emit a tracing event on the "NETAUDIT" target, tagged with
// `m = "tx5-connection"`. The first argument names a `tracing::Level`
// constant (e.g. `DEBUG`, `WARN`); all remaining tokens are passed through
// unchanged as the event's fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49    fn drop(&mut self) {
50        netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52        self.conn_task.abort();
53        self.keepalive_task.abort();
54
55        let hub_cmd_send = self.hub_cmd_send.clone();
56        let pub_key = self.pub_key.clone();
57        tokio::task::spawn(async move {
58            let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59        });
60    }
61}
62
63impl Conn {
64    #[cfg(test)]
65    pub(crate) fn test_kill_keepalive_task(&self) {
66        self.keepalive_task.abort();
67    }
68
69    pub(crate) fn priv_new(
70        webrtc_config: WebRtcConfig,
71        is_polite: bool,
72        pub_key: PubKey,
73        client: Weak<tx5_signal::SignalConnection>,
74        config: Arc<HubConfig>,
75        hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76    ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77        netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);
78
79        // set up some metrics
80        let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
81        let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
82        let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
83        let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
84        let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
85
86        // zero len semaphore.. we actually just wait for the close
87        let ready = Arc::new(tokio::sync::Semaphore::new(0));
88
89        let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
90        let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
91
92        // signal keepalive task
93        let keepalive_dur = config.signal_config.max_idle / 2;
94        let client2 = client.clone();
95        let pub_key2 = pub_key.clone();
96        let keepalive_task = tokio::task::spawn(async move {
97            loop {
98                tokio::time::sleep(keepalive_dur).await;
99
100                if let Some(client) = client2.upgrade() {
101                    if client.send_keepalive(&pub_key2).await.is_err() {
102                        break;
103                    }
104                } else {
105                    break;
106                }
107            }
108        });
109
110        msg_send.set_close_on_drop(true);
111
112        // con_task is the main event loop for a connection
113        let con_task_fut = con_task(
114            is_polite,
115            webrtc_config,
116            TaskCore {
117                client,
118                config,
119                pub_key: pub_key.clone(),
120                cmd_send: cmd_send.clone(),
121                cmd_recv,
122                send_msg_count: send_msg_count.clone(),
123                send_byte_count: send_byte_count.clone(),
124                recv_msg_count: recv_msg_count.clone(),
125                recv_byte_count: recv_byte_count.clone(),
126                msg_send,
127                ready: ready.clone(),
128                is_webrtc: is_webrtc.clone(),
129            },
130        );
131        let conn_task = tokio::task::spawn(con_task_fut);
132
133        let mut cmd_send2 = cmd_send.clone();
134        cmd_send2.set_close_on_drop(true);
135        let this = Self {
136            ready,
137            pub_key,
138            cmd_send: cmd_send2,
139            conn_task,
140            keepalive_task,
141            is_webrtc,
142            send_msg_count,
143            send_byte_count,
144            recv_msg_count,
145            recv_byte_count,
146            hub_cmd_send,
147        };
148
149        (Arc::new(this), ConnRecv(msg_recv), cmd_send)
150    }
151
152    /// Wait until this connection is ready to send / receive data.
153    pub async fn ready(&self) {
154        // this will error when we close the semaphore waking up the task
155        let _ = self.ready.acquire().await;
156    }
157
158    /// Returns `true` if we successfully connected over webrtc.
159    pub fn is_using_webrtc(&self) -> bool {
160        self.is_webrtc.load(Ordering::SeqCst)
161    }
162
163    /// The pub key of the remote peer this is connected to.
164    pub fn pub_key(&self) -> &PubKey {
165        &self.pub_key
166    }
167
168    /// Send up to 16KiB of message data.
169    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
170        self.cmd_send.send(ConnCmd::SendMessage(msg)).await
171    }
172
173    /// Get connection statistics.
174    pub fn get_stats(&self) -> ConnStats {
175        ConnStats {
176            send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
177            send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
178            recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
179            recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
180        }
181    }
182}
183
/// Connection statistics.
///
/// A point-in-time snapshot of a connection's traffic counters, as returned
/// by `Conn::get_stats`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConnStats {
    /// message count sent.
    pub send_msg_count: u64,

    /// byte count sent.
    pub send_byte_count: u64,

    /// message count received.
    pub recv_msg_count: u64,

    /// byte count received.
    pub recv_byte_count: u64,
}
199
200struct TaskCore {
201    config: Arc<HubConfig>,
202    client: Weak<tx5_signal::SignalConnection>,
203    pub_key: PubKey,
204    cmd_send: CloseSend<ConnCmd>,
205    cmd_recv: CloseRecv<ConnCmd>,
206    msg_send: CloseSend<Vec<u8>>,
207    ready: Arc<tokio::sync::Semaphore>,
208    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
209    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
210    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
211    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
212    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
213}
214
215impl TaskCore {
216    async fn handle_recv_msg(
217        &self,
218        msg: Vec<u8>,
219    ) -> std::result::Result<(), ()> {
220        self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
221        self.recv_byte_count
222            .fetch_add(msg.len() as u64, Ordering::Relaxed);
223        if self.msg_send.send(msg).await.is_err() {
224            netaudit!(
225                DEBUG,
226                pub_key = ?self.pub_key,
227                a = "close: msg_send closed",
228            );
229            Err(())
230        } else {
231            Ok(())
232        }
233    }
234
235    fn track_send_msg(&self, len: usize) {
236        self.send_msg_count.fetch_add(1, Ordering::Relaxed);
237        self.send_byte_count
238            .fetch_add(len as u64, Ordering::Relaxed);
239    }
240}
241
242async fn con_task(
243    is_polite: bool,
244    webrtc_config: WebRtcConfig,
245    mut task_core: TaskCore,
246) {
247    // first process the handshake
248    if let Some(client) = task_core.client.upgrade() {
249        let handshake_fut = async {
250            let nonce = client.send_handshake_req(&task_core.pub_key).await?;
251
252            let mut got_peer_res = false;
253            let mut sent_our_res = false;
254
255            while let Some(cmd) = task_core.cmd_recv.recv().await {
256                match cmd {
257                    ConnCmd::SigRecv(sig) => {
258                        use tx5_signal::SignalMessage::*;
259                        match sig {
260                            HandshakeReq(oth_nonce) => {
261                                client
262                                    .send_handshake_res(
263                                        &task_core.pub_key,
264                                        oth_nonce,
265                                    )
266                                    .await?;
267                                sent_our_res = true;
268                            }
269                            HandshakeRes(res_nonce) => {
270                                if res_nonce != nonce {
271                                    return Err(Error::other("nonce mismatch"));
272                                }
273                                got_peer_res = true;
274                            }
275                            // Ignore all other message types...
276                            // they may be from previous sessions
277                            _ => (),
278                        }
279                    }
280                    ConnCmd::SendMessage(_) => {
281                        return Err(Error::other("send before ready"));
282                    }
283                    ConnCmd::WebrtcTimeoutCheck
284                    | ConnCmd::WebrtcRecv(_)
285                    | ConnCmd::WebrtcClosed => {
286                        // only emitted by the webrtc module
287                        // which at this point hasn't yet been initialized
288                        unreachable!()
289                    }
290                }
291                if got_peer_res && sent_our_res {
292                    break;
293                }
294            }
295
296            Result::Ok(())
297        };
298
299        match tokio::time::timeout(
300            task_core.config.signal_config.max_idle,
301            handshake_fut,
302        )
303        .await
304        {
305            Err(_) | Ok(Err(_)) => {
306                client.close_peer(&task_core.pub_key).await;
307                return;
308            }
309            Ok(Ok(_)) => (),
310        }
311    } else {
312        return;
313    }
314
315    // next, attempt webrtc
316    let task_core = match con_task_attempt_webrtc(
317        is_polite,
318        webrtc_config,
319        task_core,
320    )
321    .await
322    {
323        AttemptWebrtcResult::Abort => return,
324        AttemptWebrtcResult::Fallback(task_core) => {
325            if task_core.config.danger_deny_signal_relay {
326                netaudit!(
327                    INFO,
328                    pub_key = ?task_core.pub_key,
329                    a = "webrtc fallback: denied signal relay",
330                );
331                return;
332            }
333
334            task_core
335        }
336    };
337
338    task_core.is_webrtc.store(false, Ordering::SeqCst);
339
340    // if webrtc failed in a way that allows us to fall back to sbd,
341    // use the fallback sbd messaging system
342    con_task_fallback_use_signal(task_core).await;
343}
344
345async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
346    match tokio::time::timeout(
347        task_core.config.signal_config.max_idle,
348        task_core.cmd_recv.recv(),
349    )
350    .await
351    {
352        Err(_) => {
353            netaudit!(
354                DEBUG,
355                pub_key = ?task_core.pub_key,
356                a = "close: connection idle",
357            );
358            None
359        }
360        Ok(None) => {
361            netaudit!(
362                DEBUG,
363                pub_key = ?task_core.pub_key,
364                a = "close: cmd_recv stream complete",
365            );
366            None
367        }
368        Ok(Some(cmd)) => Some(cmd),
369    }
370}
371
372async fn webrtc_task(
373    mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
374    cmd_send: CloseSend<ConnCmd>,
375) {
376    while let Some(evt) = webrtc_recv.recv().await {
377        if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
378            break;
379        }
380    }
381    netaudit!(DEBUG, a = "webrtc task closed, sending WebrtcClosed",);
382    let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
383}
384
385enum AttemptWebrtcResult {
386    Abort,
387    Fallback(TaskCore),
388}
389
390async fn con_task_attempt_webrtc(
391    is_polite: bool,
392    webrtc_config: WebRtcConfig,
393    mut task_core: TaskCore,
394) -> AttemptWebrtcResult {
395    use AttemptWebrtcResult::*;
396
397    let timeout_dur = task_core.config.webrtc_connect_timeout;
398    let timeout_cmd_send = task_core.cmd_send.clone();
399    tokio::task::spawn(async move {
400        tokio::time::sleep(timeout_dur).await;
401        let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
402    });
403
404    let (webrtc, webrtc_recv) = webrtc::new_backend_module(
405        task_core.config.backend_module,
406        is_polite,
407        webrtc_config,
408        // MAYBE - make this configurable
409        4096,
410    );
411
412    struct AbortWebrtc(tokio::task::AbortHandle);
413
414    impl Drop for AbortWebrtc {
415        fn drop(&mut self) {
416            self.0.abort();
417        }
418    }
419
420    // ensure if we exit this loop that the tokio task is stopped
421    let _abort_webrtc = AbortWebrtc(
422        tokio::task::spawn(webrtc_task(
423            webrtc_recv,
424            task_core.cmd_send.clone(),
425        ))
426        .abort_handle(),
427    );
428
429    let mut is_ready = false;
430
431    if task_core.config.danger_force_signal_relay {
432        netaudit!(
433            WARN,
434            pub_key = ?task_core.pub_key,
435            a = "webrtc fallback: test",
436        );
437        return Fallback(task_core);
438    }
439
440    // receive webrtc commands
441    while let Some(cmd) = recv_cmd(&mut task_core).await {
442        use tx5_signal::SignalMessage::*;
443        use webrtc::WebrtcEvt::*;
444        use ConnCmd::*;
445        match cmd {
446            SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
447                netaudit!(
448                    DEBUG,
449                    pub_key = ?task_core.pub_key,
450                    a = "close: unexpected handshake msg",
451                );
452                break;
453            }
454            SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
455                if task_core.handle_recv_msg(msg).await.is_err() {
456                    break;
457                }
458                netaudit!(
459                    WARN,
460                    pub_key = ?task_core.pub_key,
461                    a = "webrtc fallback: remote sent us an sbd message",
462                );
463                // if we get a message from the remote, we have to assume
464                // they are switching to fallback mode, and thus we cannot
465                // use webrtc ourselves.
466                return Fallback(task_core);
467            }
468            SigRecv(Offer(offer)) => {
469                netaudit!(
470                    TRACE,
471                    pub_key = ?task_core.pub_key,
472                    offer = String::from_utf8_lossy(&offer).to_string(),
473                    a = "recv_offer",
474                );
475                if let Err(err) = webrtc.in_offer(offer).await {
476                    netaudit!(
477                        WARN,
478                        pub_key = ?task_core.pub_key,
479                        ?err,
480                        a = "webrtc fallback: failed to parse received offer",
481                    );
482                    return Fallback(task_core);
483                }
484            }
485            SigRecv(Answer(answer)) => {
486                netaudit!(
487                    TRACE,
488                    pub_key = ?task_core.pub_key,
489                    offer = String::from_utf8_lossy(&answer).to_string(),
490                    a = "recv_answer",
491                );
492                if let Err(err) = webrtc.in_answer(answer).await {
493                    netaudit!(
494                        WARN,
495                        pub_key = ?task_core.pub_key,
496                        ?err,
497                        a = "webrtc fallback: failed to parse received answer",
498                    );
499                    return Fallback(task_core);
500                }
501            }
502            SigRecv(Ice(ice)) => {
503                netaudit!(
504                    TRACE,
505                    pub_key = ?task_core.pub_key,
506                    offer = String::from_utf8_lossy(&ice).to_string(),
507                    a = "recv_ice",
508                );
509                if let Err(err) = webrtc.in_ice(ice).await {
510                    netaudit!(
511                        DEBUG,
512                        pub_key = ?task_core.pub_key,
513                        ?err,
514                        a = "ignoring webrtc in_ice error",
515                    );
516                    // ice errors are often benign... just ignore it
517                }
518            }
519            SigRecv(Keepalive) | SigRecv(Unknown) => {
520                // these are no-ops
521            }
522            WebrtcRecv(GeneratedOffer(offer)) => {
523                netaudit!(
524                    TRACE,
525                    pub_key = ?task_core.pub_key,
526                    offer = String::from_utf8_lossy(&offer).to_string(),
527                    a = "send_offer",
528                );
529                if let Some(client) = task_core.client.upgrade() {
530                    if let Err(err) =
531                        client.send_offer(&task_core.pub_key, offer).await
532                    {
533                        netaudit!(
534                            DEBUG,
535                            pub_key = ?task_core.pub_key,
536                            ?err,
537                            a = "webrtc send_offer error",
538                        );
539                        break;
540                    }
541                } else {
542                    break;
543                }
544            }
545            WebrtcRecv(GeneratedAnswer(answer)) => {
546                netaudit!(
547                    TRACE,
548                    pub_key = ?task_core.pub_key,
549                    offer = String::from_utf8_lossy(&answer).to_string(),
550                    a = "send_answer",
551                );
552                if let Some(client) = task_core.client.upgrade() {
553                    if let Err(err) =
554                        client.send_answer(&task_core.pub_key, answer).await
555                    {
556                        netaudit!(
557                            DEBUG,
558                            pub_key = ?task_core.pub_key,
559                            ?err,
560                            a = "webrtc send_answer error",
561                        );
562                        break;
563                    }
564                } else {
565                    break;
566                }
567            }
568            WebrtcRecv(GeneratedIce(ice)) => {
569                netaudit!(
570                    TRACE,
571                    pub_key = ?task_core.pub_key,
572                    offer = String::from_utf8_lossy(&ice).to_string(),
573                    a = "send_ice",
574                );
575                if let Some(client) = task_core.client.upgrade() {
576                    if let Err(err) =
577                        client.send_ice(&task_core.pub_key, ice).await
578                    {
579                        netaudit!(
580                            DEBUG,
581                            pub_key = ?task_core.pub_key,
582                            ?err,
583                            a = "webrtc send_ice error",
584                        );
585                        break;
586                    }
587                } else {
588                    break;
589                }
590            }
591            WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
592                if task_core.handle_recv_msg(msg).await.is_err() {
593                    break;
594                }
595            }
596            WebrtcRecv(Ready) => {
597                is_ready = true;
598                task_core.is_webrtc.store(true, Ordering::SeqCst);
599                task_core.ready.close();
600            }
601            SendMessage(msg) => {
602                let len = msg.len();
603
604                netaudit!(
605                    TRACE,
606                    pub_key = ?task_core.pub_key,
607                    byte_len = len,
608                    a = "queue msg for backend send",
609                );
610                if let Err(err) = webrtc.message(msg).await {
611                    netaudit!(
612                        WARN,
613                        pub_key = ?task_core.pub_key,
614                        ?err,
615                        a = "webrtc fallback: failed to send message",
616                    );
617                    return Fallback(task_core);
618                }
619
620                task_core.track_send_msg(len);
621            }
622            WebrtcTimeoutCheck => {
623                if !is_ready {
624                    netaudit!(
625                        WARN,
626                        pub_key = ?task_core.pub_key,
627                        a = "webrtc fallback: failed to ready within timeout",
628                    );
629                    return Fallback(task_core);
630                }
631            }
632            WebrtcClosed => {
633                netaudit!(
634                    WARN,
635                    pub_key = ?task_core.pub_key,
636                    a = "webrtc processing task closed",
637                );
638                break;
639            }
640        }
641    }
642
643    Abort
644}
645
646async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
647    // closing the semaphore causes all the acquire awaits to end
648    task_core.ready.close();
649
650    while let Some(cmd) = recv_cmd(&mut task_core).await {
651        match cmd {
652            ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
653                if task_core.handle_recv_msg(msg).await.is_err() {
654                    break;
655                }
656            }
657            ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
658                Some(client) => {
659                    let len = msg.len();
660                    if let Err(err) =
661                        client.send_message(&task_core.pub_key, msg).await
662                    {
663                        netaudit!(
664                            DEBUG,
665                            pub_key = ?task_core.pub_key,
666                            ?err,
667                            a = "close: sbd client send error",
668                        );
669                        break;
670                    }
671                    task_core.track_send_msg(len);
672                }
673                None => {
674                    netaudit!(
675                        DEBUG,
676                        pub_key = ?task_core.pub_key,
677                        a = "close: sbd client closed",
678                    );
679                    break;
680                }
681            },
682            _ => (),
683        }
684    }
685}