tx5_connection/
conn.rs

1use super::*;
2use std::sync::atomic::Ordering;
3
/// Upper bound on the number of webrtc messages held back while waiting for
/// the remote peer to announce (over sbd) that it has switched to webrtc.
///
/// Holding them back keeps messages that travel quickly over webrtc from
/// overtaking slower sbd messages that are still in flight; if more than this
/// many pile up, the connection is closed to bound memory use.
const MAX_WEBRTC_BUF: usize = 512;
8
9pub(crate) enum ConnCmd {
10    SigRecv(tx5_signal::SignalMessage),
11    SendMessage(Vec<u8>),
12    WebrtcMessage(Vec<u8>),
13    WebrtcReady,
14    WebrtcClosed,
15}
16
17/// Receive messages from a tx5 connection.
18pub struct ConnRecv(CloseRecv<Vec<u8>>);
19
20impl ConnRecv {
21    /// Receive up to 16KiB of message data.
22    pub async fn recv(&mut self) -> Option<Vec<u8>> {
23        self.0.recv().await
24    }
25}
26
27/// A tx5 connection.
28pub struct Conn {
29    ready: Arc<tokio::sync::Semaphore>,
30    pub_key: PubKey,
31    cmd_send: CloseSend<ConnCmd>,
32    conn_task: tokio::task::JoinHandle<()>,
33    keepalive_task: tokio::task::JoinHandle<()>,
34    webrtc_ready: Arc<tokio::sync::Semaphore>,
35    send_msg_count: Arc<std::sync::atomic::AtomicU64>,
36    send_byte_count: Arc<std::sync::atomic::AtomicU64>,
37    recv_msg_count: Arc<std::sync::atomic::AtomicU64>,
38    recv_byte_count: Arc<std::sync::atomic::AtomicU64>,
39    hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
40}
41
/// Emit a `tracing` event on the `NETAUDIT` target.
///
/// The first argument names a `tracing::Level` variant (e.g. `DEBUG`,
/// `TRACE`). Every remaining token is forwarded unchanged as event fields.
/// Each event is additionally tagged with `m = "tx5-connection"`.
macro_rules! netaudit {
    ($level:ident, $($fields:tt)*) => {
        ::tracing::event!(
            target: "NETAUDIT",
            ::tracing::Level::$level,
            m = "tx5-connection",
            $($fields)*
        );
    };
}
52
53impl Drop for Conn {
54    fn drop(&mut self) {
55        netaudit!(DEBUG, pub_key = ?self.pub_key, a = "drop");
56
57        self.conn_task.abort();
58        self.keepalive_task.abort();
59
60        let hub_cmd_send = self.hub_cmd_send.clone();
61        let pub_key = self.pub_key.clone();
62        tokio::task::spawn(async move {
63            let _ = hub_cmd_send.send(HubCmd::Disconnect(pub_key)).await;
64        });
65    }
66}
67
68impl Conn {
69    #[cfg(test)]
70    pub(crate) fn test_kill_keepalive_task(&self) {
71        self.keepalive_task.abort();
72    }
73
74    pub(crate) fn priv_new(
75        webrtc_config: Vec<u8>,
76        is_polite: bool,
77        pub_key: PubKey,
78        client: Weak<tx5_signal::SignalConnection>,
79        config: Arc<HubConfig>,
80        hub_cmd_send: tokio::sync::mpsc::Sender<HubCmd>,
81    ) -> (Arc<Self>, ConnRecv, CloseSend<ConnCmd>) {
82        netaudit!(
83            DEBUG,
84            webrtc_config = String::from_utf8_lossy(&webrtc_config).to_string(),
85            ?pub_key,
86            ?is_polite,
87            a = "open",
88        );
89
90        let send_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
91        let send_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
92        let recv_msg_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
93        let recv_byte_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
94
95        // zero len semaphore.. we actually just wait for the close
96        let ready = Arc::new(tokio::sync::Semaphore::new(0));
97        let webrtc_ready = Arc::new(tokio::sync::Semaphore::new(0));
98
99        let (mut msg_send, msg_recv) = CloseSend::channel();
100        let (cmd_send, mut cmd_recv) = CloseSend::channel();
101
102        let keepalive_dur = config.signal_config.max_idle / 2;
103        let client2 = client.clone();
104        let pub_key2 = pub_key.clone();
105        let keepalive_task = tokio::task::spawn(async move {
106            loop {
107                tokio::time::sleep(keepalive_dur).await;
108
109                if let Some(client) = client2.upgrade() {
110                    if client.send_keepalive(&pub_key2).await.is_err() {
111                        break;
112                    }
113                } else {
114                    break;
115                }
116            }
117        });
118
119        let send_msg_count2 = send_msg_count.clone();
120        let send_byte_count2 = send_byte_count.clone();
121        let recv_msg_count2 = recv_msg_count.clone();
122        let recv_byte_count2 = recv_byte_count.clone();
123        let webrtc_ready2 = webrtc_ready.clone();
124        let ready2 = ready.clone();
125        let client2 = client.clone();
126        let pub_key2 = pub_key.clone();
127        let mut cmd_send3 = cmd_send.clone();
128        let conn_task = tokio::task::spawn(async move {
129            let client = match client2.upgrade() {
130                Some(client) => client,
131                None => return,
132            };
133
134            let mut webrtc_message_buffer = Vec::new();
135
136            let handshake_fut = async {
137                let nonce = client.send_handshake_req(&pub_key2).await?;
138
139                let mut got_peer_res = false;
140                let mut sent_our_res = false;
141
142                while let Some(cmd) = cmd_recv.recv().await {
143                    match cmd {
144                        ConnCmd::SigRecv(sig) => {
145                            use tx5_signal::SignalMessage::*;
146                            match sig {
147                                HandshakeReq(oth_nonce) => {
148                                    client
149                                        .send_handshake_res(
150                                            &pub_key2, oth_nonce,
151                                        )
152                                        .await?;
153                                    sent_our_res = true;
154                                }
155                                HandshakeRes(res_nonce) => {
156                                    if res_nonce != nonce {
157                                        return Err(Error::other(
158                                            "nonce mismatch",
159                                        ));
160                                    }
161                                    got_peer_res = true;
162                                }
163                                // Ignore all other message types...
164                                // they may be from previous sessions
165                                _ => (),
166                            }
167                        }
168                        ConnCmd::SendMessage(_) => {
169                            return Err(Error::other("send before ready"));
170                        }
171                        ConnCmd::WebrtcMessage(msg) => {
172                            webrtc_message_buffer.push(msg);
173                        }
174                        ConnCmd::WebrtcReady => {
175                            webrtc_ready2.close();
176                        }
177                        ConnCmd::WebrtcClosed => {
178                            // only emitted by the webrtc module
179                            // which at this point hasn't yet been initialized
180                            unreachable!()
181                        }
182                    }
183                    if got_peer_res && sent_our_res {
184                        break;
185                    }
186                }
187
188                Result::Ok(())
189            };
190
191            match tokio::time::timeout(
192                config.signal_config.max_idle,
193                handshake_fut,
194            )
195            .await
196            {
197                Err(_) | Ok(Err(_)) => {
198                    client.close_peer(&pub_key2).await;
199                    return;
200                }
201                Ok(Ok(_)) => (),
202            }
203
204            drop(client);
205
206            // closing the semaphore causes all the acquire awaits to end
207            ready2.close();
208
209            let (webrtc, mut webrtc_recv) = webrtc::new_backend_module(
210                config.backend_module,
211                is_polite,
212                webrtc_config,
213                // MAYBE - make this configurable
214                4096,
215            );
216
217            let client3 = client2.clone();
218            let pub_key3 = pub_key2.clone();
219            let _webrtc_task = AbortTask(tokio::task::spawn(async move {
220                use webrtc::WebrtcEvt::*;
221                cmd_send3.set_close_on_drop(true);
222                while let Some(evt) = webrtc_recv.recv().await {
223                    match evt {
224                        GeneratedOffer(offer) => {
225                            netaudit!(
226                                TRACE,
227                                pub_key = ?pub_key3,
228                                offer = String::from_utf8_lossy(&offer).to_string(),
229                                a = "send_offer",
230                            );
231                            if let Some(client) = client3.upgrade() {
232                                if let Err(err) =
233                                    client.send_offer(&pub_key3, offer).await
234                                {
235                                    netaudit!(
236                                        DEBUG,
237                                        pub_key = ?pub_key3,
238                                        ?err,
239                                        a = "webrtc send_offer error",
240                                    );
241                                    break;
242                                }
243                            } else {
244                                break;
245                            }
246                        }
247                        GeneratedAnswer(answer) => {
248                            netaudit!(
249                                TRACE,
250                                pub_key = ?pub_key3,
251                                offer = String::from_utf8_lossy(&answer).to_string(),
252                                a = "send_answer",
253                            );
254                            if let Some(client) = client3.upgrade() {
255                                if let Err(err) =
256                                    client.send_answer(&pub_key3, answer).await
257                                {
258                                    netaudit!(
259                                        DEBUG,
260                                        pub_key = ?pub_key3,
261                                        ?err,
262                                        a = "webrtc send_answer error",
263                                    );
264                                    break;
265                                }
266                            } else {
267                                break;
268                            }
269                        }
270                        GeneratedIce(ice) => {
271                            netaudit!(
272                                TRACE,
273                                pub_key = ?pub_key3,
274                                offer = String::from_utf8_lossy(&ice).to_string(),
275                                a = "send_ice",
276                            );
277                            if let Some(client) = client3.upgrade() {
278                                if let Err(err) =
279                                    client.send_ice(&pub_key3, ice).await
280                                {
281                                    netaudit!(
282                                        DEBUG,
283                                        pub_key = ?pub_key3,
284                                        ?err,
285                                        a = "webrtc send_ice error",
286                                    );
287                                    break;
288                                }
289                            } else {
290                                break;
291                            }
292                        }
293                        Message(msg) => {
294                            if cmd_send3
295                                .send(ConnCmd::WebrtcMessage(msg))
296                                .await
297                                .is_err()
298                            {
299                                netaudit!(
300                                    DEBUG,
301                                    pub_key = ?pub_key3,
302                                    a = "webrtc cmd closed",
303                                );
304                                break;
305                            }
306                        }
307                        Ready => {
308                            if cmd_send3
309                                .send(ConnCmd::WebrtcReady)
310                                .await
311                                .is_err()
312                            {
313                                netaudit!(
314                                    DEBUG,
315                                    pub_key = ?pub_key3,
316                                    a = "webrtc cmd closed",
317                                );
318                                break;
319                            }
320                        }
321                    }
322                }
323
324                let _ = cmd_send3.send(ConnCmd::WebrtcClosed).await;
325            }));
326
327            msg_send.set_close_on_drop(true);
328
329            let mut recv_over_webrtc = false;
330            let mut send_over_webrtc = false;
331
332            loop {
333                let cmd = tokio::time::timeout(
334                    config.signal_config.max_idle,
335                    cmd_recv.recv(),
336                )
337                .await;
338
339                let cmd = match cmd {
340                    Err(_) => {
341                        netaudit!(
342                            DEBUG,
343                            pub_key = ?pub_key2,
344                            a = "close: connection idle",
345                        );
346                        break;
347                    }
348                    Ok(cmd) => cmd,
349                };
350
351                let cmd = match cmd {
352                    None => {
353                        netaudit!(
354                            DEBUG,
355                            pub_key = ?pub_key2,
356                            a = "close: cmd_recv stream complete",
357                        );
358                        break;
359                    }
360                    Some(cmd) => cmd,
361                };
362
363                let mut slow_task = "unknown";
364
365                if breakable_timeout!(match cmd {
366                    ConnCmd::SigRecv(sig) => {
367                        use tx5_signal::SignalMessage::*;
368                        match sig {
369                            // invalid
370                            HandshakeReq(_) | HandshakeRes(_) => {
371                                slow_task = "sig-handshake";
372                                netaudit!(
373                                    DEBUG,
374                                    pub_key = ?pub_key2,
375                                    a = "close: unexpected handshake msg",
376                                );
377                                break;
378                            }
379                            Message(msg) => {
380                                slow_task = "sig-message";
381                                if recv_over_webrtc {
382                                    netaudit!(
383                                        DEBUG,
384                                        pub_key = ?pub_key2,
385                                        a = "close: unexpected sbd msg after recv_over_webrtc",
386                                    );
387                                    break;
388                                } else {
389                                    recv_msg_count2
390                                        .fetch_add(1, Ordering::Relaxed);
391                                    recv_byte_count2.fetch_add(
392                                        msg.len() as u64,
393                                        Ordering::Relaxed,
394                                    );
395                                    if msg_send.send(msg).await.is_err() {
396                                        netaudit!(
397                                            DEBUG,
398                                            pub_key = ?pub_key2,
399                                            a = "close: msg_send closed",
400                                        );
401                                        break;
402                                    }
403                                }
404                            }
405                            Offer(offer) => {
406                                slow_task = "sig-offer";
407                                netaudit!(
408                                    TRACE,
409                                    pub_key = ?pub_key2,
410                                    offer = String::from_utf8_lossy(&offer).to_string(),
411                                    a = "recv_offer",
412                                );
413                                if let Err(err) = webrtc.in_offer(offer).await {
414                                    netaudit!(
415                                        DEBUG,
416                                        pub_key = ?pub_key2,
417                                        ?err,
418                                        a = "close: webrtc in_offer error",
419                                    );
420                                    break;
421                                }
422                            }
423                            Answer(answer) => {
424                                slow_task = "sig-answer";
425                                netaudit!(
426                                    TRACE,
427                                    pub_key = ?pub_key2,
428                                    offer = String::from_utf8_lossy(&answer).to_string(),
429                                    a = "recv_answer",
430                                );
431                                if let Err(err) = webrtc.in_answer(answer).await
432                                {
433                                    netaudit!(
434                                        DEBUG,
435                                        pub_key = ?pub_key2,
436                                        ?err,
437                                        a = "close: webrtc in_answer error",
438                                    );
439                                    break;
440                                }
441                            }
442                            Ice(ice) => {
443                                slow_task = "sig-ice";
444                                netaudit!(
445                                    TRACE,
446                                    pub_key = ?pub_key2,
447                                    offer = String::from_utf8_lossy(&ice).to_string(),
448                                    a = "recv_ice",
449                                );
450                                if let Err(err) = webrtc.in_ice(ice).await {
451                                    netaudit!(
452                                        DEBUG,
453                                        pub_key = ?pub_key2,
454                                        ?err,
455                                        a = "close: webrtc in_ice error",
456                                    );
457                                    break;
458                                }
459                            }
460                            WebrtcReady => {
461                                slow_task = "sig-webrtc-ready";
462                                recv_over_webrtc = true;
463                                netaudit!(
464                                    DEBUG,
465                                    pub_key = ?pub_key2,
466                                    a = "recv_over_webrtc",
467                                );
468                                for msg in webrtc_message_buffer.drain(..) {
469                                    // don't bump send metrics here,
470                                    // we bumped them on receive
471                                    if msg_send.send(msg).await.is_err() {
472                                        netaudit!(
473                                            DEBUG,
474                                            pub_key = ?pub_key2,
475                                            a = "close: msg_send closed",
476                                        );
477                                        break;
478                                    }
479                                }
480                            }
481                            _ => (),
482                        }
483                    }
484                    ConnCmd::SendMessage(msg) => {
485                        slow_task = "send-message";
486                        send_msg_count2.fetch_add(1, Ordering::Relaxed);
487                        send_byte_count2
488                            .fetch_add(msg.len() as u64, Ordering::Relaxed);
489                        if send_over_webrtc {
490                            if let Err(err) = webrtc.message(msg).await {
491                                netaudit!(
492                                    DEBUG,
493                                    pub_key = ?pub_key2,
494                                    ?err,
495                                    a = "close: webrtc message error",
496                                );
497                                break;
498                            }
499                        } else if let Some(client) = client2.upgrade() {
500                            if let Err(err) =
501                                client.send_message(&pub_key2, msg).await
502                            {
503                                netaudit!(
504                                    DEBUG,
505                                    pub_key = ?pub_key2,
506                                    ?err,
507                                    a = "close: sbd client send error",
508                                );
509                                break;
510                            }
511                        } else {
512                            netaudit!(
513                                DEBUG,
514                                pub_key = ?pub_key2,
515                                a = "close: sbd client closed",
516                            );
517                            break;
518                        }
519                    }
520                    ConnCmd::WebrtcMessage(msg) => {
521                        slow_task = "webrtc-message";
522                        recv_msg_count2.fetch_add(1, Ordering::Relaxed);
523                        recv_byte_count2
524                            .fetch_add(msg.len() as u64, Ordering::Relaxed);
525                        if recv_over_webrtc {
526                            if msg_send.send(msg).await.is_err() {
527                                netaudit!(
528                                    DEBUG,
529                                    pub_key = ?pub_key2,
530                                    a = "close: msg_send closed",
531                                );
532                                break;
533                            }
534                        } else {
535                            webrtc_message_buffer.push(msg);
536                            if webrtc_message_buffer.len() > MAX_WEBRTC_BUF {
537                                // prevent memory fillup
538                                netaudit!(
539                                    DEBUG,
540                                    pub_key = ?pub_key2,
541                                    a = "close: webrtc buffer overflow",
542                                );
543                                break;
544                            }
545                        }
546                    }
547                    ConnCmd::WebrtcReady => {
548                        slow_task = "webrtc-ready";
549                        if let Some(client) = client2.upgrade() {
550                            if let Err(err) =
551                                client.send_webrtc_ready(&pub_key2).await
552                            {
553                                netaudit!(
554                                    DEBUG,
555                                    pub_key = ?pub_key2,
556                                    ?err,
557                                    a = "close: sbd client ready error",
558                                );
559                                break;
560                            }
561                        } else {
562                            netaudit!(
563                                DEBUG,
564                                pub_key = ?pub_key2,
565                                a = "close: sbd client closed",
566                            );
567                            break;
568                        }
569                        send_over_webrtc = true;
570                        netaudit!(
571                            DEBUG,
572                            pub_key = ?pub_key2,
573                            a = "send_over_webrtc",
574                        );
575                        webrtc_ready2.close();
576                    }
577                    ConnCmd::WebrtcClosed => {
578                        slow_task = "webrtc-closed";
579                        netaudit!(
580                            DEBUG,
581                            pub_key = ?pub_key2,
582                            a = "close: webrtc closed",
583                        );
584                        break;
585                    }
586                }).is_err() {
587                    tracing::warn!(slow_task, "slow app on conn loop task");
588                    break;
589                }
590            }
591
592            // explicitly close the peer
593            if let Some(client) = client2.upgrade() {
594                client.close_peer(&pub_key2).await;
595            };
596
597            // the receiver side is closed because msg_send is dropped.
598        });
599
600        let mut cmd_send2 = cmd_send.clone();
601        cmd_send2.set_close_on_drop(true);
602        let this = Self {
603            ready,
604            pub_key,
605            cmd_send: cmd_send2,
606            conn_task,
607            keepalive_task,
608            webrtc_ready,
609            send_msg_count,
610            send_byte_count,
611            recv_msg_count,
612            recv_byte_count,
613            hub_cmd_send,
614        };
615
616        (Arc::new(this), ConnRecv(msg_recv), cmd_send)
617    }
618
619    /// Wait until this connection is ready to send / receive data.
620    pub async fn ready(&self) {
621        // this will error when we close the semaphore waking up the task
622        let _ = self.ready.acquire().await;
623    }
624
625    /// Wait until this connection is connected via webrtc.
626    /// Note, this will never resolve if we never successfully
627    /// connect over webrtc.
628    pub async fn webrtc_ready(&self) {
629        let _ = self.webrtc_ready.acquire().await;
630    }
631
632    /// Returns `true` if we sucessfully connected over webrtc.
633    pub fn is_using_webrtc(&self) -> bool {
634        self.webrtc_ready.is_closed()
635    }
636
637    /// The pub key of the remote peer this is connected to.
638    pub fn pub_key(&self) -> &PubKey {
639        &self.pub_key
640    }
641
642    /// Send up to 16KiB of message data.
643    pub async fn send(&self, msg: Vec<u8>) -> Result<()> {
644        self.cmd_send.send(ConnCmd::SendMessage(msg)).await
645    }
646
647    /// Get connection statistics.
648    pub fn get_stats(&self) -> ConnStats {
649        ConnStats {
650            send_msg_count: self.send_msg_count.load(Ordering::Relaxed),
651            send_byte_count: self.send_byte_count.load(Ordering::Relaxed),
652            recv_msg_count: self.recv_msg_count.load(Ordering::Relaxed),
653            recv_byte_count: self.recv_byte_count.load(Ordering::Relaxed),
654        }
655    }
656}
657
/// Connection statistics.
///
/// A point-in-time snapshot of a connection's traffic counters, as returned
/// by `Conn::get_stats`. The counters are read independently, so the four
/// values are not guaranteed to be mutually consistent while traffic is
/// flowing.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ConnStats {
    /// message count sent.
    pub send_msg_count: u64,

    /// byte count sent.
    pub send_byte_count: u64,

    /// message count received.
    pub recv_msg_count: u64,

    /// byte count received.
    pub recv_byte_count: u64,
}