tycho_network/network/
mod.rs

1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::{Arc, Weak};
3
4#[cfg(target_os = "linux")]
5use anyhow::Context;
6use anyhow::Result;
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tycho_crypto::ed25519;
9
10use self::config::EndpointConfig;
11pub use self::config::{NetworkConfig, QuicConfig};
12pub use self::connection::{Connection, RecvStream, SendStream};
13use self::connection_manager::{ActivePeers, ConnectionManager, ConnectionManagerRequest};
14pub use self::connection_manager::{
15    KnownPeerHandle, KnownPeers, KnownPeersError, PeerBannedError, WeakKnownPeerHandle,
16};
17use self::endpoint::Endpoint;
18pub use self::peer::Peer;
19use crate::types::{
20    Address, DisconnectReason, PeerEvent, PeerId, PeerInfo, Response, Service, ServiceExt,
21    ServiceRequest,
22};
23
24mod config;
25mod connection;
26mod connection_manager;
27mod crypto;
28mod endpoint;
29mod peer;
30mod request_handler;
31mod wire;
32
/// Typestate builder for [`Network`].
///
/// `MandatoryFields` records whether the private key has been supplied:
/// [`Network::builder`] starts with `((),)`, and `with_private_key` switches it
/// to `([u8; 32],)`. The latter is the default type parameter and the only
/// state for which `build` is implemented.
pub struct NetworkBuilder<MandatoryFields = ([u8; 32],)> {
    /// Required inputs; `((),)` until a private key has been set.
    mandatory_fields: MandatoryFields,
    /// Optional inputs; unset values are replaced with fallbacks in `build`.
    optional_fields: BuilderFields,
}
37
/// Optional settings collected by [`NetworkBuilder`].
#[derive(Default)]
struct BuilderFields {
    /// Network configuration; `build` uses the default config when unset.
    config: Option<NetworkConfig>,
    /// Address reported as this node's remote address; `build` falls back to
    /// the endpoint's local address when unset.
    remote_addr: Option<Address>,
}
43
44impl<MandatoryFields> NetworkBuilder<MandatoryFields> {
45    pub fn with_config(mut self, config: NetworkConfig) -> Self {
46        self.optional_fields.config = Some(config);
47        self
48    }
49
50    pub fn with_remote_addr<T: Into<Address>>(mut self, addr: T) -> Self {
51        self.optional_fields.remote_addr = Some(addr.into());
52        self
53    }
54}
55
56impl NetworkBuilder<((),)> {
57    pub fn with_private_key(self, private_key: [u8; 32]) -> NetworkBuilder<([u8; 32],)> {
58        NetworkBuilder {
59            mandatory_fields: (private_key,),
60            optional_fields: self.optional_fields,
61        }
62    }
63
64    pub fn with_random_private_key(self) -> NetworkBuilder<([u8; 32],)> {
65        self.with_private_key(rand::random())
66    }
67}
68
69impl NetworkBuilder {
70    pub fn build<T: ToSocket, S>(self, bind_address: T, service: S) -> Result<Network>
71    where
72        S: Send + Sync + Clone + 'static,
73        S: Service<ServiceRequest, QueryResponse = Response>,
74    {
75        let config = self.optional_fields.config.unwrap_or_default();
76        let quic_config = config.quic.clone().unwrap_or_default();
77        let (private_key,) = self.mandatory_fields;
78
79        let keypair = ed25519::KeyPair::from(&ed25519::SecretKey::from_bytes(private_key));
80
81        let endpoint_config = EndpointConfig::builder()
82            .with_private_key(private_key)
83            .with_0rtt_enabled(config.enable_0rtt)
84            .with_transport_config(quic_config.make_transport_config())
85            .with_connection_metrics(config.connection_metrics)
86            .build()?;
87
88        let socket = bind_address.to_socket().map(socket2::Socket::from)?;
89
90        let max_socket_size = MaxBufferSize::read()?;
91
92        set_socket_buffer(
93            &socket,
94            quic_config.socket_send_buffer_size,
95            max_socket_size.map(|m| m.send),
96            |s, size| s.set_send_buffer_size(size),
97            "send",
98        );
99
100        set_socket_buffer(
101            &socket,
102            quic_config.socket_recv_buffer_size,
103            max_socket_size.map(|m| m.recv),
104            |s, size| s.set_recv_buffer_size(size),
105            "recv",
106        );
107
108        let config = Arc::new(config);
109        let endpoint = Arc::new(Endpoint::new(endpoint_config, socket.into())?);
110        let active_peers = ActivePeers::new(config.active_peers_event_channel_capacity);
111        let known_peers = KnownPeers::new();
112
113        let remote_addr = self.optional_fields.remote_addr.unwrap_or_else(|| {
114            let addr = endpoint.local_addr();
115            tracing::debug!(%addr, "using local address as remote address");
116            addr.into()
117        });
118
119        let service = service.boxed_clone();
120
121        let (connection_manager, connection_manager_handle) = ConnectionManager::new(
122            config.clone(),
123            endpoint.clone(),
124            active_peers.clone(),
125            known_peers.clone(),
126            service,
127        );
128
129        tokio::spawn(connection_manager.start());
130
131        Ok(Network(Arc::new(NetworkInner {
132            config,
133            remote_addr,
134            endpoint,
135            active_peers,
136            known_peers,
137            connection_manager_handle,
138            keypair,
139        })))
140    }
141}
142
143fn set_socket_buffer(
144    socket: &socket2::Socket,
145    config_size: Option<usize>,
146    max_size: Option<usize>,
147    set_buffer_fn: impl Fn(&socket2::Socket, usize) -> std::io::Result<()>,
148    buffer_type: &str,
149) {
150    if let Some(size) = config_size {
151        if let Err(e) = set_buffer_fn(socket, size) {
152            tracing::error!(%size, "failed to set socket {} buffer size: {e:?}", buffer_type);
153        }
154    } else if let Some(max) = max_size {
155        if let Err(e) = set_buffer_fn(socket, max) {
156            tracing::error!(%max, "failed to set socket {} buffer size to max value: {e:?}", buffer_type);
157        }
158        tracing::info!(
159            "set socket {} buffer size to max value: {}",
160            buffer_type,
161            max
162        );
163    }
164}
165
/// Non-owning handle to a [`Network`].
///
/// Holds a `Weak` reference to the shared network state, so it does not keep
/// the network alive on its own.
#[derive(Clone)]
#[repr(transparent)]
pub struct WeakNetwork(Weak<NetworkInner>);
169
170impl WeakNetwork {
171    pub fn upgrade(&self) -> Option<Network> {
172        self.0
173            .upgrade()
174            .map(Network)
175            .and_then(|network| (!network.is_closed()).then_some(network))
176    }
177}
178
/// Cheaply clonable handle to a running network instance.
///
/// All clones share the same inner state through an `Arc`.
#[derive(Clone)]
#[repr(transparent)]
pub struct Network(Arc<NetworkInner>);
182
183impl Network {
184    pub fn builder() -> NetworkBuilder<((),)> {
185        NetworkBuilder {
186            mandatory_fields: ((),),
187            optional_fields: Default::default(),
188        }
189    }
190
191    /// The public address of this node.
192    pub fn remote_addr(&self) -> &Address {
193        self.0.remote_addr()
194    }
195
196    /// The listening address of this node.
197    pub fn local_addr(&self) -> SocketAddr {
198        self.0.local_addr()
199    }
200
201    /// The local peer id of this node.
202    pub fn peer_id(&self) -> &PeerId {
203        self.0.peer_id()
204    }
205
206    /// Returns true if the peer is currently connected.
207    pub fn is_active(&self, peer_id: &PeerId) -> bool {
208        self.0.active_peers.contains(peer_id)
209    }
210
211    /// Returns a connection wrapper for the specified peer.
212    pub fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
213        self.0.peer(peer_id)
214    }
215
216    /// A set of known peers.
217    pub fn known_peers(&self) -> &KnownPeers {
218        &self.0.known_peers
219    }
220
221    /// Subscribe to active peer changes.
222    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
223        self.0.active_peers.subscribe()
224    }
225
226    /// Initiate a connection to the specified peer.
227    pub async fn connect<T>(&self, addr: T, peer_id: &PeerId) -> Result<Peer, ConnectionError>
228    where
229        T: Into<Address>,
230    {
231        self.0.connect(addr.into(), peer_id).await
232    }
233
234    pub fn disconnect(&self, peer_id: &PeerId) {
235        self.0.disconnect(peer_id);
236    }
237
238    pub async fn shutdown(&self) {
239        self.0.shutdown().await;
240    }
241
242    pub fn is_closed(&self) -> bool {
243        self.0.is_closed()
244    }
245
246    pub fn sign_tl<T: tl_proto::TlWrite>(&self, data: T) -> [u8; 64] {
247        self.0.keypair.sign_tl(data)
248    }
249
250    pub fn sign_raw(&self, data: &[u8]) -> [u8; 64] {
251        self.0.keypair.sign_raw(data)
252    }
253
254    pub fn sign_peer_info(&self, now: u32, ttl: u32) -> PeerInfo {
255        let mut res = PeerInfo {
256            id: *self.0.peer_id(),
257            address_list: vec![self.remote_addr().clone()].into_boxed_slice(),
258            created_at: now,
259            expires_at: now.saturating_add(ttl),
260            signature: Box::new([0; 64]),
261        };
262        *res.signature = self.sign_tl(&res);
263        res
264    }
265
266    pub fn downgrade(this: &Self) -> WeakNetwork {
267        WeakNetwork(Arc::downgrade(&this.0))
268    }
269
270    /// returns the maximum size which can be potentially sent in a single frame
271    pub fn max_frame_size(&self) -> usize {
272        self.0.config.max_frame_size.0 as usize
273    }
274}
275
/// Shared state behind every [`Network`] handle.
struct NetworkInner {
    /// Network configuration; also handed to every `Peer` created here.
    config: Arc<NetworkConfig>,
    /// Address this node reports as its public address.
    remote_addr: Address,
    /// Endpoint that provides the local socket address and the peer id.
    endpoint: Arc<Endpoint>,
    /// Registry of currently connected peers.
    active_peers: ActivePeers,
    /// Registry of known peers.
    known_peers: KnownPeers,
    /// Request channel to the connection manager task spawned in `build`.
    connection_manager_handle: mpsc::Sender<ConnectionManagerRequest>,
    /// Key pair used for signing.
    keypair: ed25519::KeyPair,
}
285
286impl NetworkInner {
287    fn remote_addr(&self) -> &Address {
288        &self.remote_addr
289    }
290
291    fn local_addr(&self) -> SocketAddr {
292        self.endpoint.local_addr()
293    }
294
295    fn peer_id(&self) -> &PeerId {
296        self.endpoint.peer_id()
297    }
298
299    async fn connect(&self, addr: Address, peer_id: &PeerId) -> Result<Peer, ConnectionError> {
300        let (tx, rx) = oneshot::channel();
301        self.connection_manager_handle
302            .send(ConnectionManagerRequest::Connect(addr, *peer_id, tx))
303            .await
304            .map_err(|_e| ConnectionError::Shutdown)?;
305
306        let Ok(res) = rx.await else {
307            return Err(ConnectionError::Shutdown);
308        };
309
310        res.map(|c| Peer::new(c, self.config.clone()))
311    }
312
313    fn disconnect(&self, peer_id: &PeerId) {
314        self.active_peers
315            .remove(peer_id, DisconnectReason::Requested);
316    }
317
318    fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
319        let connection = self.active_peers.get(peer_id)?;
320        Some(Peer::new(connection, self.config.clone()))
321    }
322
323    async fn shutdown(&self) {
324        let (sender, receiver) = oneshot::channel();
325        if self
326            .connection_manager_handle
327            .send(ConnectionManagerRequest::Shutdown(sender))
328            .await
329            .is_err()
330        {
331            return;
332        }
333
334        receiver.await.ok();
335    }
336
337    fn is_closed(&self) -> bool {
338        self.connection_manager_handle.is_closed()
339    }
340}
341
/// Emits a debug log entry when the shared network state is dropped.
impl Drop for NetworkInner {
    fn drop(&mut self) {
        tracing::debug!("network dropped");
    }
}
347
/// Conversion of a bind target into a UDP socket.
///
/// Implemented for an existing `UdpSocket` and for a set of address types
/// that are bound on demand.
pub trait ToSocket {
    /// Consumes `self` and returns the UDP socket, or a [`BindError`] if one
    /// could not be produced.
    fn to_socket(self) -> Result<std::net::UdpSocket, BindError>;
}
351
/// An existing socket is used as is; this never fails.
impl ToSocket for std::net::UdpSocket {
    fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
        Ok(self)
    }
}
357
// Implements `ToSocket` for each listed type by delegating to
// `bind_socket_to_addr`; accepts a comma-separated list of types with an
// optional trailing comma.
macro_rules! impl_to_socket_for_addr {
    ($($ty:ty),*$(,)?) => {$(
        impl ToSocket for $ty {
            fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
                bind_socket_to_addr(self)
            }
        }
    )*};
}

// Every type listed here is passed to `bind_socket_to_addr`, which requires
// `ToSocketAddrs`.
impl_to_socket_for_addr! {
    SocketAddr,
    std::net::SocketAddrV4,
    std::net::SocketAddrV6,
    (std::net::IpAddr, u16),
    (std::net::Ipv4Addr, u16),
    (std::net::Ipv6Addr, u16),
    (&str, u16),
    (String, u16),
    &str,
    String,
    &[SocketAddr],
    Address,
}
382
383fn bind_socket_to_addr<T: ToSocketAddrs>(
384    bind_address: T,
385) -> Result<std::net::UdpSocket, BindError> {
386    use socket2::{Domain, Protocol, Socket, Type};
387
388    let socket_addrs = bind_address
389        .to_socket_addrs()
390        .map_err(BindError::AddressResolution)?;
391
392    let mut last_bind_error: Option<BindError> = None;
393
394    for addr in socket_addrs {
395        let socket = match Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
396        {
397            Ok(s) => s,
398            Err(e) => {
399                if last_bind_error.is_none() {
400                    last_bind_error = Some(BindError::SocketCreation { addr, source: e });
401                }
402                continue;
403            }
404        };
405
406        // Attempt to bind the socket
407        match socket.bind(&socket2::SockAddr::from(addr)) {
408            Ok(()) => return Ok(socket.into()),
409            Err(e) => {
410                // Bind failed for this address, store the error and try the next
411                tracing::warn!(?e, %addr, "failed to bind, trying next address");
412                last_bind_error = Some(BindError::SocketBind { addr, source: e });
413            }
414        }
415    }
416
417    // If we've looped through all addresses and none succeeded
418    Err(last_bind_error.unwrap_or(BindError::NoAddressesToBind))
419}
420
/// Errors produced while turning a bind target into a UDP socket.
#[derive(thiserror::Error, Debug)]
pub enum BindError {
    /// Resolving the bind target into socket addresses failed.
    #[error("failed to resolve socket addresses: {0}")]
    AddressResolution(#[source] std::io::Error),

    /// No address could be bound and no more specific error was recorded.
    #[error("no suitable addresses found to bind to after attempting all resolved addresses")]
    NoAddressesToBind,

    /// Creating a socket for `addr` failed.
    #[error("failed to create new socket for address {addr:?}: {source}")]
    SocketCreation {
        addr: SocketAddr,
        #[source]
        source: std::io::Error,
    },

    /// Binding a freshly created socket to `addr` failed.
    #[error("failed to bind socket to address {addr}: {source}")]
    SocketBind {
        addr: SocketAddr,
        #[source]
        source: std::io::Error,
    },
}
443
/// Maximum socket buffer sizes, as read from `wmem_max` / `rmem_max` on Linux.
#[derive(Debug, Clone, Copy)]
struct MaxBufferSize {
    /// Value parsed from `wmem_max`.
    send: usize,
    /// Value parsed from `rmem_max`.
    recv: usize,
}
449
450impl MaxBufferSize {
451    #[cfg(target_os = "linux")]
452    pub fn read() -> Result<Option<Self>> {
453        const WMEM: &str = "wmem_max";
454        const RMEM: &str = "rmem_max";
455
456        #[cfg(any(feature = "test", test))]
457        let proc_path = std::env::var("MOCK_PROC_PATH").unwrap_or_else(|_| "/proc".to_string());
458        #[cfg(not(any(feature = "test", test)))]
459        let proc_path = "/proc";
460        let proc_path = std::path::Path::new(&proc_path).join("sys/net/core");
461
462        let read_and_parse = |file_name: &str| -> Result<Option<usize>> {
463            let path = proc_path.join(file_name);
464            if !path.exists() {
465                tracing::warn!("{} not found", path.display());
466                return Ok(None);
467            }
468            let res = std::fs::read_to_string(&path)
469                .with_context(|| format!("Failed to read {}", path.display()))?
470                .trim()
471                .parse()
472                .with_context(|| format!("Failed to parse {}", path.display()))?;
473            Ok(Some(res))
474        };
475
476        let rmem = read_and_parse(RMEM)?;
477        let wmem = read_and_parse(WMEM)?;
478
479        Ok(rmem.zip(wmem).map(|(recv, send)| Self { send, recv }))
480    }
481
482    #[cfg(not(target_os = "linux"))]
483    pub fn read() -> std::io::Result<Option<Self>> {
484        Ok(None)
485    }
486}
487
/// Reasons why establishing a connection to a peer can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum ConnectionError {
    /// The address was rejected as invalid.
    #[error("invalid address")]
    InvalidAddress,
    /// The connection attempt could not be initiated.
    #[error("connection init failed")]
    ConnectionInitFailed,
    /// The certificate was rejected as invalid.
    #[error("invalid certificate")]
    InvalidCertificate,
    /// The handshake did not complete successfully.
    #[error("handshake failed")]
    HandshakeFailed,
    /// The connection attempt timed out.
    #[error("connection timeout")]
    Timeout,
    /// The connection manager is unreachable or dropped the request without
    /// answering.
    #[error("network has been shutdown")]
    Shutdown,
}
503
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use futures_util::stream::FuturesUnordered;

    use super::*;
    use crate::types::{BoxCloneService, PeerInfo, Request, service_message_fn, service_query_fn};
    use crate::util::{NetworkExt, UnknownPeerError};

    /// Query service that answers every request with its own body.
    fn echo_service() -> BoxCloneService<ServiceRequest, Response> {
        let handle = |request: ServiceRequest| async move {
            tracing::trace!("received: {}", request.body.escape_ascii());
            let response = Response {
                version: Default::default(),
                body: request.body,
            };
            Some(response)
        };
        service_query_fn(handle).boxed_clone()
    }

    /// Network on an ephemeral localhost port with 0-RTT enabled, a random
    /// key and the echo service.
    fn make_network() -> Result<Network> {
        Network::builder()
            .with_config(NetworkConfig {
                enable_0rtt: true,
                ..Default::default()
            })
            .with_random_private_key()
            .build("127.0.0.1:0", echo_service())
    }

    /// Peer info for `network` with a zeroed placeholder signature and the
    /// widest possible validity window.
    fn make_peer_info(network: &Network) -> Arc<PeerInfo> {
        Arc::new(PeerInfo {
            id: *network.peer_id(),
            address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
            created_at: 0,
            expires_at: u32::MAX,
            signature: Box::new([0; 64]),
        })
    }

    /// Two networks can connect to each other in both directions.
    #[tokio::test]
    async fn connection_manager_works() -> Result<()> {
        tycho_util::test::init_logger("connection_manager_works", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        peer1
            .connect(peer2.local_addr(), peer2.peer_id())
            .await
            .unwrap();
        peer2
            .connect(peer1.local_addr(), peer1.peer_id())
            .await
            .unwrap();

        Ok(())
    }

    /// A known peer registered under the wrong id fails with
    /// `InvalidCertificate`; an entirely unknown id fails with
    /// `UnknownPeerError`.
    #[tokio::test]
    async fn invalid_peer_id_detectable() -> Result<()> {
        tycho_util::test::init_logger("invalid_peer_id_detectable", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        // Same address as the real peer, but registered under an all-zero id.
        let make_invalid_peer_info = |network: &Network| {
            Arc::new(PeerInfo {
                id: PeerId([0; 32]),
                address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
                created_at: 0,
                expires_at: u32::MAX,
                signature: Box::new([0; 64]),
            })
        };
        let _handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
        let _handle = peer1
            .known_peers()
            .insert(make_invalid_peer_info(&peer2), false)?;

        let _handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;
        let _handle = peer2
            .known_peers()
            .insert(make_invalid_peer_info(&peer1), false)?;

        let req = Request {
            version: Default::default(),
            body: "hello".into(),
        };

        // Queries addressed to the real ids succeed.
        peer1.query(peer2.peer_id(), req.clone()).await?;
        peer2.query(peer1.peer_id(), req.clone()).await?;

        fn assert_is_invalid_certificate(e: anyhow::Error) {
            // A non-recursive downcast to find a connection error
            let e = (*e).downcast_ref::<ConnectionError>().unwrap();
            assert_eq!(*e, ConnectionError::InvalidCertificate);
        }

        let err = peer1
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        let err = peer2
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        fn assert_is_unknown_peer(e: anyhow::Error, peer_id: &PeerId) {
            // A non-recursive downcast to find an error
            let e = (*e).downcast_ref::<UnknownPeerError>().unwrap();
            assert_eq!(e, &UnknownPeerError { peer_id: *peer_id });
        }

        // This id was never inserted into the known peers.
        let invalid_peer_id = PeerId([0xff; 32]);
        let err = peer1
            .query(&invalid_peer_id, req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_unknown_peer(err, &invalid_peer_id);

        Ok(())
    }

    /// Both sides query each other at the same time; repeated with fresh
    /// networks on every iteration.
    #[tokio::test]
    async fn simultaneous_queries() -> Result<()> {
        tycho_util::test::init_logger("simultaneous_queries", "debug");

        for _ in 0..10 {
            let peer1 = make_network()?;
            let peer2 = make_network()?;

            let _peer1_peer2_handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
            let _peer2_peer1_handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;

            let req = Request {
                version: Default::default(),
                body: "hello".into(),
            };
            let peer1_fut = std::pin::pin!(peer1.query(peer2.peer_id(), req.clone()));
            let peer2_fut = std::pin::pin!(peer2.query(peer1.peer_id(), req.clone()));

            let (res1, res2) = futures_util::future::join(peer1_fut, peer2_fut).await;
            assert_eq!(res1?.body, req.body);
            assert_eq!(res2?.body, req.body);
        }

        Ok(())
    }

    /// Sends 10 batches of 100 concurrent 750 KiB one-way messages and
    /// expects every send to succeed.
    #[tokio::test(flavor = "multi_thread")]
    async fn uni_message_handler() -> Result<()> {
        tycho_util::test::init_logger("uni_message_handler", "debug");

        // Message-only service that just logs the payload size.
        fn noop_service() -> BoxCloneService<ServiceRequest, Response> {
            let handle = |request: ServiceRequest| async move {
                tracing::trace!("received: {} bytes", request.body.len());
            };
            service_message_fn(handle).boxed_clone()
        }

        // Shadows the module-level helper so the no-op service is used.
        fn make_network() -> Result<Network> {
            Network::builder()
                .with_config(NetworkConfig {
                    enable_0rtt: true,
                    ..Default::default()
                })
                .with_random_private_key()
                .build("127.0.0.1:0", noop_service())
        }

        let left = make_network()?;
        let right = make_network()?;

        let _left_to_right = left.known_peers().insert(make_peer_info(&right), false)?;
        let _right_to_left = right.known_peers().insert(make_peer_info(&left), false)?;

        let req = Request {
            version: Default::default(),
            body: vec![0xff; 750 * 1024].into(),
        };

        for _ in 0..10 {
            let mut futures = FuturesUnordered::new();
            for _ in 0..100 {
                futures.push(left.send(right.peer_id(), req.clone()));
            }

            while let Some(res) = futures.next().await {
                res?;
            }
        }

        Ok(())
    }

    /// `MaxBufferSize::read` returns positive limits from the real `/proc`,
    /// or the mocked values when `/proc` is unavailable.
    #[test]
    fn socket_size_works() {
        // On non-Linux targets `read` is a stub that always yields `None` and
        // ignores `MOCK_PROC_PATH`, so the mock below cannot work there; it
        // used to panic on the `expect`.
        if !cfg!(target_os = "linux") {
            assert!(MaxBufferSize::read().unwrap().is_none());
            return;
        }

        if std::path::Path::new("/proc").exists() {
            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");
            assert!(socket_size.send > 0);
            assert!(socket_size.recv > 0);
        } else {
            // Some environments do not expose /proc; point the reader at a
            // temporary directory instead.
            let procfs = tempfile::tempdir().unwrap();
            // SAFETY: NOTE(review): `set_var` is only sound while no other
            // thread accesses the environment outside of Rust's `std::env`.
            // Other tests in this module run in parallel and read this
            // variable through `std::env::var` — confirm nothing in the test
            // binary touches the environment through libc directly.
            unsafe { std::env::set_var("MOCK_PROC_PATH", procfs.path()) };

            std::fs::create_dir_all(procfs.path().join("sys/net/core")).unwrap();
            std::fs::write(procfs.path().join("sys/net/core/wmem_max"), "100000\n").unwrap();
            std::fs::write(procfs.path().join("sys/net/core/rmem_max"), "100000\n").unwrap();

            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");

            assert_eq!(socket_size.send, 100000);
            assert_eq!(socket_size.recv, 100000);
        }
    }
}