// tx5_connection/conn.rs

1use super::*;
2use std::sync::atomic::Ordering;
3
4pub(crate) enum ConnCmd {
5    SigRecv(tx5_signal::SignalMessage),
6    WebrtcRecv(webrtc::WebrtcEvt),
7    SendMessage(Vec<u8>),
8    WebrtcTimeoutCheck,
9    WebrtcClosed,
10}
11
12/// Receive messages from a tx5 connection.
13pub struct ConnRecv(CloseRecv<Vec<u8>>);
14
15impl ConnRecv {
16    /// Receive up to 16KiB of message data.
17    pub async fn recv(&mut self) -> Option<Vec<u8>> {
18        self.0.recv().await
19    }
20}
21
22/// A tx5 connection.
23pub struct Conn {
24    ready: Arc<tokio::sync::Semaphore>,
25    pub_key: PubKey,
26    cmd_send: CloseSend<ConnCmd>,
27    conn_task: tokio::task::JoinHandle<()>,
28    keepalive_task: tokio::task::JoinHandle<()>,
29    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
30    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
31    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
32    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
33    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
34    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
35}
36
/// Emit a `tracing` event on the `NETAUDIT` target, tagged with
/// `m = "tx5-connection"`.
///
/// The first argument names a `tracing::Level` variant (e.g. `DEBUG`).
/// Everything after the comma is forwarded verbatim as event fields.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
47
48impl Drop for Conn {
49    fn drop(&mut self) {
50        netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
51
52        self.conn_task.abort();
53        self.keepalive_task.abort();
54
55        let hub_cmd_send = self.hub_cmd_send.clone();
56        let pub_key = self.pub_key.clone();
57        tokio::task::spawn(async move {
58            let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
59        });
60    }
61}
62
63impl Conn {
64    #[cfg(test)]
65    pub(crate) fn test_kill_keepalive_task(&self) {
66        self.keepalive_task.abort();
67    }
68
69    pub(crate) fn priv_new(
70        webrtc_config: WebRtcConfig,
71        is_polite: bool,
72        pub_key: PubKey,
73        client: Weak<tx5_signal::SignalConnection>,
74        config: Arc<HubConfig>,
75        hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
76    ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
77        netaudit!(DEBUG, ?webrtc_config, ?pub_key, ?is_polite, a = "open",);
78
79        let is_webrtc = Arc::new(std::sync::atomic::AtomicBool::new(false));
80        let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
81        let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
82        let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
83        let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
84
85        // zero len semaphore.. we actually just wait for the close
86        let ready = Arc::new(tokio::sync::Semaphore::new(0));
87
88        let (mut msg_send, msg_recv) = CloseSend::sized_channel(1024);
89        let (cmd_send, cmd_recv) = CloseSend::sized_channel(1024);
90
91        let keepalive_dur = config.signal_config.max_idle / 2;
92        let client2 = client.clone();
93        let pub_key2 = pub_key.clone();
94        let keepalive_task = tokio::task::spawn(async move {
95            loop {
96                tokio::time::sleep(keepalive_dur).await;
97
98                if let Some(client) = client2.upgrade() {
99                    if client.send_keepalive(&pub_key2).await.is_err() {
100                        break;
101                    }
102                } else {
103                    break;
104                }
105            }
106        });
107
108        msg_send.set_close_on_drop(true);
109
110        let con_task_fut = con_task(
111            is_polite,
112            webrtc_config,
113            TaskCore {
114                client,
115                config,
116                pub_key: pub_key.clone(),
117                cmd_send: cmd_send.clone(),
118                cmd_recv,
119                send_msg_count: send_msg_count.clone(),
120                send_byte_count: send_byte_count.clone(),
121                recv_msg_count: recv_msg_count.clone(),
122                recv_byte_count: recv_byte_count.clone(),
123                msg_send,
124                ready: ready.clone(),
125                is_webrtc: is_webrtc.clone(),
126            },
127        );
128        let conn_task = tokio::task::spawn(con_task_fut);
129
130        let mut cmd_send2 = cmd_send.clone();
131        cmd_send2.set_close_on_drop(true);
132        let this = Self {
133            ready,
134            pub_key,
135            cmd_send: cmd_send2,
136            conn_task,
137            keepalive_task,
138            is_webrtc,
139            send_msg_count,
140            send_byte_count,
141            recv_msg_count,
142            recv_byte_count,
143            hub_cmd_send,
144        };
145
146        (Arc::new(this), ConnRecv(msg_recv), cmd_send)
147    }
148
149    /// Wait until this connection is ready to send / receive data.
150    pub async fn ready(&self) {
151        // this will error when we close the semaphore waking up the task
152        let _ = self.ready.acquire().await;
153    }
154
155    /// Returns `true` if we sucessfully connected over webrtc.
156    pub fn is_using_webrtc(&self) -> bool {
157        self.is_webrtc.load(Ordering::SeqCst)
158    }
159
160    /// The pub key of the remote peer this is connected to.
161    pub fn pub_key(&self) -> &PubKey {
162        &self.pub_key
163    }
164
165    /// Send up to 16KiB of message data.
166    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
167        self.cmd_send.send(ConnCmd::SendMessage(msg)).await
168    }
169
170    /// Get connection statistics.
171    pub fn get_stats(&self) -> ConnStats {
172        ConnStats {
173            send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
174            send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
175            recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
176            recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
177        }
178    }
179}
180
/// Connection statistics.
///
/// A point-in-time snapshot of the counters kept by a `Conn`, as returned
/// by `Conn::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnStats {
    /// message count sent.
    pub send_msg_count: u64,

    /// byte count sent.
    pub send_byte_count: u64,

    /// message count received.
    pub recv_msg_count: u64,

    /// byte count received.
    pub recv_byte_count: u64,
}
196
197struct TaskCore {
198    config: Arc<HubConfig>,
199    client: Weak<tx5_signal::SignalConnection>,
200    pub_key: PubKey,
201    cmd_send: CloseSend<ConnCmd>,
202    cmd_recv: CloseRecv<ConnCmd>,
203    msg_send: CloseSend<Vec<u8>>,
204    ready: Arc<tokio::sync::Semaphore>,
205    is_webrtc: Arc<std::sync::atomic::AtomicBool>,
206    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
207    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
208    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
209    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
210}
211
212impl TaskCore {
213    async fn handle_recv_msg(
214        &self,
215        msg: Vec<u8>,
216    ) -> std::result::Result<(), ()> {
217        self.recv_msg_count.fetch_add(1, Ordering::Relaxed);
218        self.recv_byte_count
219            .fetch_add(msg.len() as u64, Ordering::Relaxed);
220        if self.msg_send.send(msg).await.is_err() {
221            netaudit!(
222                DEBUG,
223                pub_key = ?self.pub_key,
224                a = "close: msg_send closed",
225            );
226            Err(())
227        } else {
228            Ok(())
229        }
230    }
231
232    fn track_send_msg(&self, len: usize) {
233        self.send_msg_count.fetch_add(1, Ordering::Relaxed);
234        self.send_byte_count
235            .fetch_add(len as u64, Ordering::Relaxed);
236    }
237}
238
239async fn con_task(
240    is_polite: bool,
241    webrtc_config: WebRtcConfig,
242    mut task_core: TaskCore,
243) {
244    // first process the handshake
245    if let Some(client) = task_core.client.upgrade() {
246        let handshake_fut = async {
247            let nonce = client.send_handshake_req(&task_core.pub_key).await?;
248
249            let mut got_peer_res = false;
250            let mut sent_our_res = false;
251
252            while let Some(cmd) = task_core.cmd_recv.recv().await {
253                match cmd {
254                    ConnCmd::SigRecv(sig) => {
255                        use tx5_signal::SignalMessage::*;
256                        match sig {
257                            HandshakeReq(oth_nonce) => {
258                                client
259                                    .send_handshake_res(
260                                        &task_core.pub_key,
261                                        oth_nonce,
262                                    )
263                                    .await?;
264                                sent_our_res = true;
265                            }
266                            HandshakeRes(res_nonce) => {
267                                if res_nonce != nonce {
268                                    return Err(Error::other("nonce mismatch"));
269                                }
270                                got_peer_res = true;
271                            }
272                            // Ignore all other message types...
273                            // they may be from previous sessions
274                            _ => (),
275                        }
276                    }
277                    ConnCmd::SendMessage(_) => {
278                        return Err(Error::other("send before ready"));
279                    }
280                    ConnCmd::WebrtcTimeoutCheck
281                    | ConnCmd::WebrtcRecv(_)
282                    | ConnCmd::WebrtcClosed => {
283                        // only emitted by the webrtc module
284                        // which at this point hasn't yet been initialized
285                        unreachable!()
286                    }
287                }
288                if got_peer_res && sent_our_res {
289                    break;
290                }
291            }
292
293            Result::Ok(())
294        };
295
296        match tokio::time::timeout(
297            task_core.config.signal_config.max_idle,
298            handshake_fut,
299        )
300        .await
301        {
302            Err(_) | Ok(Err(_)) => {
303                client.close_peer(&task_core.pub_key).await;
304                return;
305            }
306            Ok(Ok(_)) => (),
307        }
308    } else {
309        return;
310    }
311
312    // next, attempt webrtc
313    let task_core = match con_task_attempt_webrtc(
314        is_polite,
315        webrtc_config,
316        task_core,
317    )
318    .await
319    {
320        AttemptWebrtcResult::Abort => return,
321        AttemptWebrtcResult::Fallback(task_core) => task_core,
322    };
323
324    task_core.is_webrtc.store(false, Ordering::SeqCst);
325
326    // if webrtc failed in a way that allows us to fall back to sbd,
327    // use the fallback sbd messaging system
328    con_task_fallback_use_signal(task_core).await;
329}
330
331async fn recv_cmd(task_core: &mut TaskCore) -> Option<ConnCmd> {
332    match tokio::time::timeout(
333        task_core.config.signal_config.max_idle,
334        task_core.cmd_recv.recv(),
335    )
336    .await
337    {
338        Err(_) => {
339            netaudit!(
340                DEBUG,
341                pub_key = ?task_core.pub_key,
342                a = "close: connection idle",
343            );
344            None
345        }
346        Ok(None) => {
347            netaudit!(
348                DEBUG,
349                pub_key = ?task_core.pub_key,
350                a = "close: cmd_recv stream complete",
351            );
352            None
353        }
354        Ok(Some(cmd)) => Some(cmd),
355    }
356}
357
358async fn webrtc_task(
359    mut webrtc_recv: CloseRecv<webrtc::WebrtcEvt>,
360    cmd_send: CloseSend<ConnCmd>,
361) {
362    while let Some(evt) = webrtc_recv.recv().await {
363        if cmd_send.send(ConnCmd::WebrtcRecv(evt)).await.is_err() {
364            break;
365        }
366    }
367    let _ = cmd_send.send(ConnCmd::WebrtcClosed).await;
368}
369
370enum AttemptWebrtcResult {
371    Abort,
372    Fallback(TaskCore),
373}
374
375async fn con_task_attempt_webrtc(
376    is_polite: bool,
377    webrtc_config: WebRtcConfig,
378    mut task_core: TaskCore,
379) -> AttemptWebrtcResult {
380    use AttemptWebrtcResult::*;
381
382    let timeout_dur = task_core.config.signal_config.max_idle;
383    let timeout_cmd_send = task_core.cmd_send.clone();
384    tokio::task::spawn(async move {
385        tokio::time::sleep(timeout_dur).await;
386        let _ = timeout_cmd_send.send(ConnCmd::WebrtcTimeoutCheck).await;
387    });
388
389    let (webrtc, webrtc_recv) = webrtc::new_backend_module(
390        task_core.config.backend_module,
391        is_polite,
392        webrtc_config,
393        // MAYBE - make this configurable
394        4096,
395    );
396
397    struct AbortWebrtc(tokio::task::AbortHandle);
398
399    impl Drop for AbortWebrtc {
400        fn drop(&mut self) {
401            self.0.abort();
402        }
403    }
404
405    // ensure if we exit this loop that the tokio task is stopped
406    let _abort_webrtc = AbortWebrtc(
407        tokio::task::spawn(webrtc_task(
408            webrtc_recv,
409            task_core.cmd_send.clone(),
410        ))
411        .abort_handle(),
412    );
413
414    let mut is_ready = false;
415
416    #[cfg(test)]
417    if task_core.config.test_fail_webrtc {
418        netaudit!(
419            WARN,
420            pub_key = ?task_core.pub_key,
421            a = "webrtc fallback: test",
422        );
423        return Fallback(task_core);
424    }
425
426    while let Some(cmd) = recv_cmd(&mut task_core).await {
427        use tx5_signal::SignalMessage::*;
428        use webrtc::WebrtcEvt::*;
429        use ConnCmd::*;
430        match cmd {
431            SigRecv(HandshakeReq(_)) | SigRecv(HandshakeRes(_)) => {
432                netaudit!(
433                    DEBUG,
434                    pub_key = ?task_core.pub_key,
435                    a = "close: unexpected handshake msg",
436                );
437                return Abort;
438            }
439            SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
440                if task_core.handle_recv_msg(msg).await.is_err() {
441                    return Abort;
442                }
443                netaudit!(
444                    WARN,
445                    pub_key = ?task_core.pub_key,
446                    a = "webrtc fallback: remote sent us an sbd message",
447                );
448                // if we get a message from the remote, we have to assume
449                // they are switching to fallback mode, and thus we cannot
450                // use webrtc ourselves.
451                return Fallback(task_core);
452            }
453            SigRecv(Offer(offer)) => {
454                netaudit!(
455                    TRACE,
456                    pub_key = ?task_core.pub_key,
457                    offer = String::from_utf8_lossy(&offer).to_string(),
458                    a = "recv_offer",
459                );
460                if let Err(err) = webrtc.in_offer(offer).await {
461                    netaudit!(
462                        WARN,
463                        pub_key = ?task_core.pub_key,
464                        ?err,
465                        a = "webrtc fallback: failed to parse received offer",
466                    );
467                    return Fallback(task_core);
468                }
469            }
470            SigRecv(Answer(answer)) => {
471                netaudit!(
472                    TRACE,
473                    pub_key = ?task_core.pub_key,
474                    offer = String::from_utf8_lossy(&answer).to_string(),
475                    a = "recv_answer",
476                );
477                if let Err(err) = webrtc.in_answer(answer).await {
478                    netaudit!(
479                        WARN,
480                        pub_key = ?task_core.pub_key,
481                        ?err,
482                        a = "webrtc fallback: failed to parse received answer",
483                    );
484                    return Fallback(task_core);
485                }
486            }
487            SigRecv(Ice(ice)) => {
488                netaudit!(
489                    TRACE,
490                    pub_key = ?task_core.pub_key,
491                    offer = String::from_utf8_lossy(&ice).to_string(),
492                    a = "recv_ice",
493                );
494                if let Err(err) = webrtc.in_ice(ice).await {
495                    netaudit!(
496                        DEBUG,
497                        pub_key = ?task_core.pub_key,
498                        ?err,
499                        a = "ignoring webrtc in_ice error",
500                    );
501                    // ice errors are often benign... just ignore it
502                }
503            }
504            SigRecv(Keepalive) | SigRecv(Unknown) => {
505                // these are no-ops
506            }
507            WebrtcRecv(GeneratedOffer(offer)) => {
508                netaudit!(
509                    TRACE,
510                    pub_key = ?task_core.pub_key,
511                    offer = String::from_utf8_lossy(&offer).to_string(),
512                    a = "send_offer",
513                );
514                if let Some(client) = task_core.client.upgrade() {
515                    if let Err(err) =
516                        client.send_offer(&task_core.pub_key, offer).await
517                    {
518                        netaudit!(
519                            DEBUG,
520                            pub_key = ?task_core.pub_key,
521                            ?err,
522                            a = "webrtc send_offer error",
523                        );
524                        return Abort;
525                    }
526                } else {
527                    return Abort;
528                }
529            }
530            WebrtcRecv(GeneratedAnswer(answer)) => {
531                netaudit!(
532                    TRACE,
533                    pub_key = ?task_core.pub_key,
534                    offer = String::from_utf8_lossy(&answer).to_string(),
535                    a = "send_answer",
536                );
537                if let Some(client) = task_core.client.upgrade() {
538                    if let Err(err) =
539                        client.send_answer(&task_core.pub_key, answer).await
540                    {
541                        netaudit!(
542                            DEBUG,
543                            pub_key = ?task_core.pub_key,
544                            ?err,
545                            a = "webrtc send_answer error",
546                        );
547                        return Abort;
548                    }
549                } else {
550                    return Abort;
551                }
552            }
553            WebrtcRecv(GeneratedIce(ice)) => {
554                netaudit!(
555                    TRACE,
556                    pub_key = ?task_core.pub_key,
557                    offer = String::from_utf8_lossy(&ice).to_string(),
558                    a = "send_ice",
559                );
560                if let Some(client) = task_core.client.upgrade() {
561                    if let Err(err) =
562                        client.send_ice(&task_core.pub_key, ice).await
563                    {
564                        netaudit!(
565                            DEBUG,
566                            pub_key = ?task_core.pub_key,
567                            ?err,
568                            a = "webrtc send_ice error",
569                        );
570                        return Abort;
571                    }
572                } else {
573                    return Abort;
574                }
575            }
576            WebrtcRecv(webrtc::WebrtcEvt::Message(msg)) => {
577                if task_core.handle_recv_msg(msg).await.is_err() {
578                    return Abort;
579                }
580            }
581            WebrtcRecv(Ready) => {
582                is_ready = true;
583                task_core.is_webrtc.store(true, Ordering::SeqCst);
584                task_core.ready.close();
585            }
586            SendMessage(msg) => {
587                let len = msg.len();
588
589                netaudit!(
590                    TRACE,
591                    pub_key = ?task_core.pub_key,
592                    byte_len = len,
593                    a = "queue msg for backend send",
594                );
595                if let Err(err) = webrtc.message(msg).await {
596                    netaudit!(
597                        WARN,
598                        pub_key = ?task_core.pub_key,
599                        ?err,
600                        a = "webrtc fallback: failed to send message",
601                    );
602                    return Fallback(task_core);
603                }
604
605                task_core.track_send_msg(len);
606            }
607            WebrtcTimeoutCheck => {
608                if !is_ready {
609                    netaudit!(
610                        WARN,
611                        pub_key = ?task_core.pub_key,
612                        a = "webrtc fallback: failed to ready within timeout",
613                    );
614                    return Fallback(task_core);
615                }
616            }
617            WebrtcClosed => {
618                netaudit!(
619                    WARN,
620                    pub_key = ?task_core.pub_key,
621                    a = "webrtc fallback: webrtc processing task closed",
622                );
623                return Fallback(task_core);
624            }
625        }
626    }
627
628    Abort
629}
630
631async fn con_task_fallback_use_signal(mut task_core: TaskCore) {
632    // closing the semaphore causes all the acquire awaits to end
633    task_core.ready.close();
634
635    while let Some(cmd) = recv_cmd(&mut task_core).await {
636        match cmd {
637            ConnCmd::SigRecv(tx5_signal::SignalMessage::Message(msg)) => {
638                if task_core.handle_recv_msg(msg).await.is_err() {
639                    break;
640                }
641            }
642            ConnCmd::SendMessage(msg) => match task_core.client.upgrade() {
643                Some(client) => {
644                    let len = msg.len();
645                    if let Err(err) =
646                        client.send_message(&task_core.pub_key, msg).await
647                    {
648                        netaudit!(
649                            DEBUG,
650                            pub_key = ?task_core.pub_key,
651                            ?err,
652                            a = "close: sbd client send error",
653                        );
654                        break;
655                    }
656                    task_core.track_send_msg(len);
657                }
658                None => {
659                    netaudit!(
660                        DEBUG,
661                        pub_key = ?task_core.pub_key,
662                        a = "close: sbd client closed",
663                    );
664                    break;
665                }
666            },
667            _ => (),
668        }
669    }
670}