tycho_network/network/
mod.rs

1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::{Arc, Weak};
3
4#[cfg(target_os = "linux")]
5use anyhow::Context;
6use anyhow::Result;
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tycho_crypto::ed25519;
9
10use self::config::EndpointConfig;
11pub use self::config::{CongestionAlgorithm, ConnectionMetricsLevel, NetworkConfig, QuicConfig};
12pub use self::connection::{Connection, RecvStream, SendStream};
13use self::connection_manager::{ActivePeers, ConnectionManager, ConnectionManagerRequest};
14pub use self::connection_manager::{
15    KnownPeerHandle, KnownPeers, KnownPeersError, PeerBannedError, WeakKnownPeerHandle,
16};
17use self::endpoint::Endpoint;
18pub use self::peer::Peer;
19use crate::types::{
20    Address, DisconnectReason, PeerEvent, PeerId, PeerInfo, Response, Service, ServiceExt,
21    ServiceRequest,
22};
23
24mod config;
25mod connection;
26mod connection_manager;
27mod crypto;
28mod endpoint;
29mod peer;
30mod request_handler;
31mod wire;
32
33pub struct NetworkBuilder<MandatoryFields = ([u8; 32],)> {
34    mandatory_fields: MandatoryFields,
35    optional_fields: BuilderFields,
36}
37
38#[derive(Default)]
39struct BuilderFields {
40    config: Option<NetworkConfig>,
41    remote_addr: Option<Address>,
42}
43
44impl<MandatoryFields> NetworkBuilder<MandatoryFields> {
45    pub fn with_config(mut self, config: NetworkConfig) -> Self {
46        self.optional_fields.config = Some(config);
47        self
48    }
49
50    pub fn with_remote_addr<T: Into<Address>>(mut self, addr: T) -> Self {
51        self.optional_fields.remote_addr = Some(addr.into());
52        self
53    }
54}
55
56impl NetworkBuilder<((),)> {
57    pub fn with_private_key(self, private_key: [u8; 32]) -> NetworkBuilder<([u8; 32],)> {
58        NetworkBuilder {
59            mandatory_fields: (private_key,),
60            optional_fields: self.optional_fields,
61        }
62    }
63
64    pub fn with_random_private_key(self) -> NetworkBuilder<([u8; 32],)> {
65        self.with_private_key(rand::random())
66    }
67}
68
69impl NetworkBuilder {
70    pub fn build<T: ToSocket, S>(self, bind_address: T, service: S) -> Result<Network>
71    where
72        S: Send + Sync + Clone + 'static,
73        S: Service<ServiceRequest, QueryResponse = Response>,
74    {
75        let config = self.optional_fields.config.unwrap_or_default();
76        let quic_config = config.quic.clone().unwrap_or_default();
77        let (private_key,) = self.mandatory_fields;
78
79        let keypair = ed25519::KeyPair::from(&ed25519::SecretKey::from_bytes(private_key));
80
81        let endpoint_config = EndpointConfig::builder()
82            .with_private_key(private_key)
83            .with_0rtt_enabled(config.enable_0rtt)
84            .with_transport_config(quic_config.make_transport_config())
85            .with_connection_metrics(config.connection_metrics)
86            .build()?;
87
88        let socket = bind_address.to_socket().map(socket2::Socket::from)?;
89
90        let max_socket_size = MaxBufferSize::read()?;
91
92        set_socket_buffer(
93            &socket,
94            quic_config.socket_send_buffer_size,
95            max_socket_size.map(|m| m.send),
96            |s, size| s.set_send_buffer_size(size),
97            "send",
98        );
99
100        set_socket_buffer(
101            &socket,
102            quic_config.socket_recv_buffer_size,
103            max_socket_size.map(|m| m.recv),
104            |s, size| s.set_recv_buffer_size(size),
105            "recv",
106        );
107
108        let config = Arc::new(config);
109        let endpoint = Arc::new(Endpoint::new(endpoint_config, socket.into())?);
110        let active_peers = ActivePeers::new(config.active_peers_event_channel_capacity);
111        let known_peers = KnownPeers::new();
112
113        let mut remote_addr = self.optional_fields.remote_addr.unwrap_or_else(|| {
114            let addr = endpoint.local_addr();
115            tracing::debug!(%addr, "using local address as remote address");
116            addr.into()
117        });
118        if remote_addr.port() == 0 {
119            remote_addr.set_port(endpoint.local_addr().port());
120        }
121
122        let service = service.boxed_clone();
123
124        let (connection_manager, connection_manager_handle) = ConnectionManager::new(
125            config.clone(),
126            endpoint.clone(),
127            active_peers.clone(),
128            known_peers.clone(),
129            service,
130        );
131
132        tokio::spawn(connection_manager.start());
133
134        Ok(Network(Arc::new(NetworkInner {
135            config,
136            remote_addr,
137            endpoint,
138            active_peers,
139            known_peers,
140            connection_manager_handle,
141            keypair,
142        })))
143    }
144}
145
146fn set_socket_buffer(
147    socket: &socket2::Socket,
148    config_size: Option<usize>,
149    max_size: Option<usize>,
150    set_buffer_fn: impl Fn(&socket2::Socket, usize) -> std::io::Result<()>,
151    buffer_type: &str,
152) {
153    if let Some(size) = config_size {
154        if let Err(e) = set_buffer_fn(socket, size) {
155            tracing::error!(%size, "failed to set socket {} buffer size: {e:?}", buffer_type);
156        }
157    } else if let Some(max) = max_size {
158        if let Err(e) = set_buffer_fn(socket, max) {
159            tracing::error!(%max, "failed to set socket {} buffer size to max value: {e:?}", buffer_type);
160        }
161        tracing::info!(
162            "set socket {} buffer size to max value: {}",
163            buffer_type,
164            max
165        );
166    }
167}
168
169#[derive(Clone)]
170#[repr(transparent)]
171pub struct WeakNetwork(Weak<NetworkInner>);
172
173impl WeakNetwork {
174    pub fn upgrade(&self) -> Option<Network> {
175        self.0
176            .upgrade()
177            .map(Network)
178            .and_then(|network| (!network.is_closed()).then_some(network))
179    }
180}
181
182#[derive(Clone)]
183#[repr(transparent)]
184pub struct Network(Arc<NetworkInner>);
185
186impl Network {
187    pub fn builder() -> NetworkBuilder<((),)> {
188        NetworkBuilder {
189            mandatory_fields: ((),),
190            optional_fields: Default::default(),
191        }
192    }
193
194    /// The public address of this node.
195    pub fn remote_addr(&self) -> &Address {
196        self.0.remote_addr()
197    }
198
199    /// The listening address of this node.
200    pub fn local_addr(&self) -> SocketAddr {
201        self.0.local_addr()
202    }
203
204    /// The local peer id of this node.
205    pub fn peer_id(&self) -> &PeerId {
206        self.0.peer_id()
207    }
208
209    /// Returns true if the peer is currently connected.
210    pub fn is_active(&self, peer_id: &PeerId) -> bool {
211        self.0.active_peers.contains(peer_id)
212    }
213
214    /// Returns a connection wrapper for the specified peer.
215    pub fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
216        self.0.peer(peer_id)
217    }
218
219    /// A set of known peers.
220    pub fn known_peers(&self) -> &KnownPeers {
221        &self.0.known_peers
222    }
223
224    /// Subscribe to active peer changes.
225    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
226        self.0.active_peers.subscribe()
227    }
228
229    /// Initiate a connection to the specified peer.
230    pub async fn connect<T>(&self, addr: T, peer_id: &PeerId) -> Result<Peer, ConnectionError>
231    where
232        T: Into<Address>,
233    {
234        self.0.connect(addr.into(), peer_id).await
235    }
236
237    pub fn disconnect(&self, peer_id: &PeerId) {
238        self.0.disconnect(peer_id);
239    }
240
241    pub async fn shutdown(&self) {
242        self.0.shutdown().await;
243    }
244
245    pub fn is_closed(&self) -> bool {
246        self.0.is_closed()
247    }
248
249    pub fn sign_tl<T: tl_proto::TlWrite>(&self, data: T) -> [u8; 64] {
250        self.0.keypair.sign_tl(data)
251    }
252
253    pub fn sign_raw(&self, data: &[u8]) -> [u8; 64] {
254        self.0.keypair.sign_raw(data)
255    }
256
257    pub fn sign_peer_info(&self, now: u32, ttl: u32) -> PeerInfo {
258        let mut res = PeerInfo {
259            id: *self.0.peer_id(),
260            address_list: vec![self.remote_addr().clone()].into_boxed_slice(),
261            created_at: now,
262            expires_at: now.saturating_add(ttl),
263            signature: Box::new([0; 64]),
264        };
265        *res.signature = self.sign_tl(&res);
266        res
267    }
268
269    pub fn downgrade(this: &Self) -> WeakNetwork {
270        WeakNetwork(Arc::downgrade(&this.0))
271    }
272
273    /// returns the maximum size which can be potentially sent in a single frame
274    pub fn max_frame_size(&self) -> usize {
275        self.0.config.max_frame_size.0 as usize
276    }
277}
278
279struct NetworkInner {
280    config: Arc<NetworkConfig>,
281    remote_addr: Address,
282    endpoint: Arc<Endpoint>,
283    active_peers: ActivePeers,
284    known_peers: KnownPeers,
285    connection_manager_handle: mpsc::Sender<ConnectionManagerRequest>,
286    keypair: ed25519::KeyPair,
287}
288
289impl NetworkInner {
290    fn remote_addr(&self) -> &Address {
291        &self.remote_addr
292    }
293
294    fn local_addr(&self) -> SocketAddr {
295        self.endpoint.local_addr()
296    }
297
298    fn peer_id(&self) -> &PeerId {
299        self.endpoint.peer_id()
300    }
301
302    async fn connect(&self, addr: Address, peer_id: &PeerId) -> Result<Peer, ConnectionError> {
303        let (tx, rx) = oneshot::channel();
304        self.connection_manager_handle
305            .send(ConnectionManagerRequest::Connect(addr, *peer_id, tx))
306            .await
307            .map_err(|_e| ConnectionError::Shutdown)?;
308
309        let Ok(res) = rx.await else {
310            return Err(ConnectionError::Shutdown);
311        };
312
313        res.map(|c| Peer::new(c, self.config.clone()))
314    }
315
316    fn disconnect(&self, peer_id: &PeerId) {
317        self.active_peers
318            .remove(peer_id, DisconnectReason::Requested);
319    }
320
321    fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
322        let connection = self.active_peers.get(peer_id)?;
323        Some(Peer::new(connection, self.config.clone()))
324    }
325
326    async fn shutdown(&self) {
327        let (sender, receiver) = oneshot::channel();
328        if self
329            .connection_manager_handle
330            .send(ConnectionManagerRequest::Shutdown(sender))
331            .await
332            .is_err()
333        {
334            return;
335        }
336
337        receiver.await.ok();
338    }
339
340    fn is_closed(&self) -> bool {
341        self.connection_manager_handle.is_closed()
342    }
343}
344
345impl Drop for NetworkInner {
346    fn drop(&mut self) {
347        tracing::debug!("network dropped");
348    }
349}
350
351pub trait ToSocket {
352    fn to_socket(self) -> Result<std::net::UdpSocket, BindError>;
353}
354
355impl ToSocket for std::net::UdpSocket {
356    fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
357        Ok(self)
358    }
359}
360
361macro_rules! impl_to_socket_for_addr {
362    ($($ty:ty),*$(,)?) => {$(
363        impl ToSocket for $ty {
364            fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
365                bind_socket_to_addr(self)
366            }
367        }
368    )*};
369}
370
371impl_to_socket_for_addr! {
372    SocketAddr,
373    std::net::SocketAddrV4,
374    std::net::SocketAddrV6,
375    (std::net::IpAddr, u16),
376    (std::net::Ipv4Addr, u16),
377    (std::net::Ipv6Addr, u16),
378    (&str, u16),
379    (String, u16),
380    &str,
381    String,
382    &[SocketAddr],
383    Address,
384}
385
386fn bind_socket_to_addr<T: ToSocketAddrs>(
387    bind_address: T,
388) -> Result<std::net::UdpSocket, BindError> {
389    use socket2::{Domain, Protocol, Socket, Type};
390
391    let socket_addrs = bind_address
392        .to_socket_addrs()
393        .map_err(BindError::AddressResolution)?;
394
395    let mut last_bind_error: Option<BindError> = None;
396
397    for addr in socket_addrs {
398        let socket = match Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
399        {
400            Ok(s) => s,
401            Err(e) => {
402                if last_bind_error.is_none() {
403                    last_bind_error = Some(BindError::SocketCreation { addr, source: e });
404                }
405                continue;
406            }
407        };
408
409        // Attempt to bind the socket
410        match socket.bind(&socket2::SockAddr::from(addr)) {
411            Ok(()) => return Ok(socket.into()),
412            Err(e) => {
413                // Bind failed for this address, store the error and try the next
414                tracing::warn!(?e, %addr, "failed to bind, trying next address");
415                last_bind_error = Some(BindError::SocketBind { addr, source: e });
416            }
417        }
418    }
419
420    // If we've looped through all addresses and none succeeded
421    Err(last_bind_error.unwrap_or(BindError::NoAddressesToBind))
422}
423
424#[derive(thiserror::Error, Debug)]
425pub enum BindError {
426    #[error("failed to resolve socket addresses: {0}")]
427    AddressResolution(#[source] std::io::Error),
428
429    #[error("no suitable addresses found to bind to after attempting all resolved addresses")]
430    NoAddressesToBind,
431
432    #[error("failed to create new socket for address {addr:?}: {source}")]
433    SocketCreation {
434        addr: SocketAddr,
435        #[source]
436        source: std::io::Error,
437    },
438
439    #[error("failed to bind socket to address {addr}: {source}")]
440    SocketBind {
441        addr: SocketAddr,
442        #[source]
443        source: std::io::Error,
444    },
445}
446
/// Kernel upper bounds for socket buffer sizes, as read from
/// `net.core.wmem_max` (`send`) and `net.core.rmem_max` (`recv`).
#[derive(Debug, Clone, Copy)]
struct MaxBufferSize {
    /// Maximum send buffer size (`wmem_max`).
    send: usize,
    /// Maximum receive buffer size (`rmem_max`).
    recv: usize,
}
452
453impl MaxBufferSize {
454    #[cfg(target_os = "linux")]
455    pub fn read() -> Result<Option<Self>> {
456        const WMEM: &str = "wmem_max";
457        const RMEM: &str = "rmem_max";
458
459        #[cfg(any(feature = "test", test))]
460        let proc_path = std::env::var("MOCK_PROC_PATH").unwrap_or_else(|_| "/proc".to_string());
461        #[cfg(not(any(feature = "test", test)))]
462        let proc_path = "/proc";
463        let proc_path = std::path::Path::new(&proc_path).join("sys/net/core");
464
465        let read_and_parse = |file_name: &str| -> Result<Option<usize>> {
466            let path = proc_path.join(file_name);
467            if !path.exists() {
468                tracing::warn!("{} not found", path.display());
469                return Ok(None);
470            }
471            let res = std::fs::read_to_string(&path)
472                .with_context(|| format!("Failed to read {}", path.display()))?
473                .trim()
474                .parse()
475                .with_context(|| format!("Failed to parse {}", path.display()))?;
476            Ok(Some(res))
477        };
478
479        let rmem = read_and_parse(RMEM)?;
480        let wmem = read_and_parse(WMEM)?;
481
482        Ok(rmem.zip(wmem).map(|(recv, send)| Self { send, recv }))
483    }
484
485    #[cfg(not(target_os = "linux"))]
486    pub fn read() -> std::io::Result<Option<Self>> {
487        Ok(None)
488    }
489}
490
491#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
492pub enum ConnectionError {
493    #[error("invalid address")]
494    InvalidAddress,
495    #[error("connection init failed")]
496    ConnectionInitFailed,
497    #[error("invalid certificate")]
498    InvalidCertificate,
499    #[error("handshake failed")]
500    HandshakeFailed,
501    #[error("connection timeout")]
502    Timeout,
503    #[error("network has been shutdown")]
504    Shutdown,
505}
506
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use futures_util::stream::FuturesUnordered;

    use super::*;
    use crate::types::{BoxCloneService, PeerInfo, Request, service_message_fn, service_query_fn};
    use crate::util::{NetworkExt, UnknownPeerError};

    /// Query service that answers every request with its own body.
    fn echo_service() -> BoxCloneService<ServiceRequest, Response> {
        let handle = |request: ServiceRequest| async move {
            tracing::trace!("received: {}", request.body.escape_ascii());
            let response = Response {
                version: Default::default(),
                body: request.body,
            };
            Some(response)
        };
        service_query_fn(handle).boxed_clone()
    }

    /// Builds a localhost network with 0-RTT enabled and the echo service.
    fn make_network() -> Result<Network> {
        Network::builder()
            .with_config(NetworkConfig {
                enable_0rtt: true,
                ..Default::default()
            })
            .with_random_private_key()
            .build("127.0.0.1:0", echo_service())
    }

    /// Unsigned, never-expiring peer info pointing at `network`.
    fn make_peer_info(network: &Network) -> Arc<PeerInfo> {
        Arc::new(PeerInfo {
            id: *network.peer_id(),
            address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
            created_at: 0,
            expires_at: u32::MAX,
            signature: Box::new([0; 64]),
        })
    }

    #[tokio::test]
    async fn connection_manager_works() -> Result<()> {
        tycho_util::test::init_logger("connection_manager_works", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        peer1
            .connect(peer2.local_addr(), peer2.peer_id())
            .await
            .unwrap();
        peer2
            .connect(peer1.local_addr(), peer1.peer_id())
            .await
            .unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn invalid_peer_id_detectable() -> Result<()> {
        tycho_util::test::init_logger("invalid_peer_id_detectable", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        let make_invalid_peer_info = |network: &Network| {
            Arc::new(PeerInfo {
                id: PeerId([0; 32]),
                address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
                created_at: 0,
                expires_at: u32::MAX,
                signature: Box::new([0; 64]),
            })
        };
        let _handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
        let _handle = peer1
            .known_peers()
            .insert(make_invalid_peer_info(&peer2), false)?;

        let _handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;
        let _handle = peer2
            .known_peers()
            .insert(make_invalid_peer_info(&peer1), false)?;

        let req = Request {
            version: Default::default(),
            body: "hello".into(),
        };

        peer1.query(peer2.peer_id(), req.clone()).await?;
        peer2.query(peer1.peer_id(), req.clone()).await?;

        fn assert_is_invalid_certificate(e: anyhow::Error) {
            // A non-recursive downcast to find a connection error
            let e = (*e).downcast_ref::<ConnectionError>().unwrap();
            assert_eq!(*e, ConnectionError::InvalidCertificate);
        }

        let err = peer1
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        let err = peer2
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        fn assert_is_unknown_peer(e: anyhow::Error, peer_id: &PeerId) {
            // A non-recursive downcast to find an error
            let e = (*e).downcast_ref::<UnknownPeerError>().unwrap();
            assert_eq!(e, &UnknownPeerError { peer_id: *peer_id });
        }

        let invalid_peer_id = PeerId([0xff; 32]);
        let err = peer1
            .query(&invalid_peer_id, req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_unknown_peer(err, &invalid_peer_id);

        Ok(())
    }

    #[tokio::test]
    async fn simultaneous_queries() -> Result<()> {
        tycho_util::test::init_logger("simultaneous_queries", "debug");

        for _ in 0..10 {
            let peer1 = make_network()?;
            let peer2 = make_network()?;

            let _peer1_peer2_handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
            let _peer2_peer1_handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;

            let req = Request {
                version: Default::default(),
                body: "hello".into(),
            };
            let peer1_fut = std::pin::pin!(peer1.query(peer2.peer_id(), req.clone()));
            let peer2_fut = std::pin::pin!(peer2.query(peer1.peer_id(), req.clone()));

            let (res1, res2) = futures_util::future::join(peer1_fut, peer2_fut).await;
            assert_eq!(res1?.body, req.body);
            assert_eq!(res2?.body, req.body);
        }

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn uni_message_handler() -> Result<()> {
        tycho_util::test::init_logger("uni_message_handler", "debug");

        fn noop_service() -> BoxCloneService<ServiceRequest, Response> {
            let handle = |request: ServiceRequest| async move {
                tracing::trace!("received: {} bytes", request.body.len());
            };
            service_message_fn(handle).boxed_clone()
        }

        fn make_network() -> Result<Network> {
            Network::builder()
                .with_config(NetworkConfig {
                    enable_0rtt: true,
                    ..Default::default()
                })
                .with_random_private_key()
                .build("127.0.0.1:0", noop_service())
        }

        let left = make_network()?;
        let right = make_network()?;

        let _left_to_right = left.known_peers().insert(make_peer_info(&right), false)?;
        let _right_to_left = right.known_peers().insert(make_peer_info(&left), false)?;

        let req = Request {
            version: Default::default(),
            body: vec![0xff; 750 * 1024].into(),
        };

        for _ in 0..10 {
            let mut futures = FuturesUnordered::new();
            for _ in 0..100 {
                futures.push(left.send(right.peer_id(), req.clone()));
            }

            while let Some(res) = futures.next().await {
                res?;
            }
        }

        Ok(())
    }

    // `MaxBufferSize::read` only consults procfs on Linux. On every other target it always
    // returns `Ok(None)`, so the `expect` calls below would panic there.
    #[cfg(target_os = "linux")]
    #[test]
    fn socket_size_works() {
        if std::path::Path::new("/proc").exists() {
            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");
            assert!(socket_size.send > 0);
            assert!(socket_size.recv > 0);
        } else {
            // Some environments (e.g. sandboxed CI runners) do not expose /proc, so point
            // `MaxBufferSize::read` at a mock procfs via `MOCK_PROC_PATH` instead.
            let procfs = tempfile::tempdir().unwrap();
            // SAFETY(review): `set_var` is only sound when no other thread reads the
            // environment concurrently. Other tests in this module may do so via
            // `make_network` -> `MaxBufferSize::read`, so this branch is best-effort.
            unsafe { std::env::set_var("MOCK_PROC_PATH", procfs.path()) };

            std::fs::create_dir_all(procfs.path().join("sys/net/core")).unwrap();
            std::fs::write(procfs.path().join("sys/net/core/wmem_max"), "100000\n").unwrap();
            std::fs::write(procfs.path().join("sys/net/core/rmem_max"), "100000\n").unwrap();

            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");

            assert_eq!(socket_size.send, 100000);
            assert_eq!(socket_size.recv, 100000);
        }
    }
}