1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::{Arc, Weak};
3
4#[cfg(target_os = "linux")]
5use anyhow::Context;
6use anyhow::Result;
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tycho_crypto::ed25519;
9
10use self::config::EndpointConfig;
11pub use self::config::{CongestionAlgorithm, ConnectionMetricsLevel, NetworkConfig, QuicConfig};
12pub use self::connection::{Connection, RecvStream, SendStream};
13use self::connection_manager::{ActivePeers, ConnectionManager, ConnectionManagerRequest};
14pub use self::connection_manager::{
15 KnownPeerHandle, KnownPeers, KnownPeersError, PeerBannedError, WeakKnownPeerHandle,
16};
17use self::endpoint::Endpoint;
18pub use self::peer::Peer;
19use crate::types::{
20 Address, DisconnectReason, PeerEvent, PeerId, PeerInfo, Response, Service, ServiceExt,
21 ServiceRequest,
22};
23
24mod config;
25mod connection;
26mod connection_manager;
27mod crypto;
28mod endpoint;
29mod peer;
30mod request_handler;
31mod wire;
32
33pub struct NetworkBuilder<MandatoryFields = ([u8; 32],)> {
34 mandatory_fields: MandatoryFields,
35 optional_fields: BuilderFields,
36}
37
38#[derive(Default)]
39struct BuilderFields {
40 config: Option<NetworkConfig>,
41 remote_addr: Option<Address>,
42}
43
44impl<MandatoryFields> NetworkBuilder<MandatoryFields> {
45 pub fn with_config(mut self, config: NetworkConfig) -> Self {
46 self.optional_fields.config = Some(config);
47 self
48 }
49
50 pub fn with_remote_addr<T: Into<Address>>(mut self, addr: T) -> Self {
51 self.optional_fields.remote_addr = Some(addr.into());
52 self
53 }
54}
55
56impl NetworkBuilder<((),)> {
57 pub fn with_private_key(self, private_key: [u8; 32]) -> NetworkBuilder<([u8; 32],)> {
58 NetworkBuilder {
59 mandatory_fields: (private_key,),
60 optional_fields: self.optional_fields,
61 }
62 }
63
64 pub fn with_random_private_key(self) -> NetworkBuilder<([u8; 32],)> {
65 self.with_private_key(rand::random())
66 }
67}
68
69impl NetworkBuilder {
70 pub fn build<T: ToSocket, S>(self, bind_address: T, service: S) -> Result<Network>
71 where
72 S: Send + Sync + Clone + 'static,
73 S: Service<ServiceRequest, QueryResponse = Response>,
74 {
75 let config = self.optional_fields.config.unwrap_or_default();
76 let quic_config = config.quic.clone().unwrap_or_default();
77 let (private_key,) = self.mandatory_fields;
78
79 let keypair = ed25519::KeyPair::from(&ed25519::SecretKey::from_bytes(private_key));
80
81 let endpoint_config = EndpointConfig::builder()
82 .with_private_key(private_key)
83 .with_0rtt_enabled(config.enable_0rtt)
84 .with_transport_config(quic_config.make_transport_config())
85 .with_connection_metrics(config.connection_metrics)
86 .build()?;
87
88 let socket = bind_address.to_socket().map(socket2::Socket::from)?;
89
90 let max_socket_size = MaxBufferSize::read()?;
91
92 set_socket_buffer(
93 &socket,
94 quic_config.socket_send_buffer_size,
95 max_socket_size.map(|m| m.send),
96 |s, size| s.set_send_buffer_size(size),
97 "send",
98 );
99
100 set_socket_buffer(
101 &socket,
102 quic_config.socket_recv_buffer_size,
103 max_socket_size.map(|m| m.recv),
104 |s, size| s.set_recv_buffer_size(size),
105 "recv",
106 );
107
108 let config = Arc::new(config);
109 let endpoint = Arc::new(Endpoint::new(endpoint_config, socket.into())?);
110 let active_peers = ActivePeers::new(config.active_peers_event_channel_capacity);
111 let known_peers = KnownPeers::new();
112
113 let mut remote_addr = self.optional_fields.remote_addr.unwrap_or_else(|| {
114 let addr = endpoint.local_addr();
115 tracing::debug!(%addr, "using local address as remote address");
116 addr.into()
117 });
118 if remote_addr.port() == 0 {
119 remote_addr.set_port(endpoint.local_addr().port());
120 }
121
122 let service = service.boxed_clone();
123
124 let (connection_manager, connection_manager_handle) = ConnectionManager::new(
125 config.clone(),
126 endpoint.clone(),
127 active_peers.clone(),
128 known_peers.clone(),
129 service,
130 );
131
132 tokio::spawn(connection_manager.start());
133
134 Ok(Network(Arc::new(NetworkInner {
135 config,
136 remote_addr,
137 endpoint,
138 active_peers,
139 known_peers,
140 connection_manager_handle,
141 keypair,
142 })))
143 }
144}
145
146fn set_socket_buffer(
147 socket: &socket2::Socket,
148 config_size: Option<usize>,
149 max_size: Option<usize>,
150 set_buffer_fn: impl Fn(&socket2::Socket, usize) -> std::io::Result<()>,
151 buffer_type: &str,
152) {
153 if let Some(size) = config_size {
154 if let Err(e) = set_buffer_fn(socket, size) {
155 tracing::error!(%size, "failed to set socket {} buffer size: {e:?}", buffer_type);
156 }
157 } else if let Some(max) = max_size {
158 if let Err(e) = set_buffer_fn(socket, max) {
159 tracing::error!(%max, "failed to set socket {} buffer size to max value: {e:?}", buffer_type);
160 }
161 tracing::info!(
162 "set socket {} buffer size to max value: {}",
163 buffer_type,
164 max
165 );
166 }
167}
168
169#[derive(Clone)]
170#[repr(transparent)]
171pub struct WeakNetwork(Weak<NetworkInner>);
172
173impl WeakNetwork {
174 pub fn upgrade(&self) -> Option<Network> {
175 self.0
176 .upgrade()
177 .map(Network)
178 .and_then(|network| (!network.is_closed()).then_some(network))
179 }
180}
181
182#[derive(Clone)]
183#[repr(transparent)]
184pub struct Network(Arc<NetworkInner>);
185
186impl Network {
187 pub fn builder() -> NetworkBuilder<((),)> {
188 NetworkBuilder {
189 mandatory_fields: ((),),
190 optional_fields: Default::default(),
191 }
192 }
193
194 pub fn remote_addr(&self) -> &Address {
196 self.0.remote_addr()
197 }
198
199 pub fn local_addr(&self) -> SocketAddr {
201 self.0.local_addr()
202 }
203
204 pub fn peer_id(&self) -> &PeerId {
206 self.0.peer_id()
207 }
208
209 pub fn is_active(&self, peer_id: &PeerId) -> bool {
211 self.0.active_peers.contains(peer_id)
212 }
213
214 pub fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
216 self.0.peer(peer_id)
217 }
218
219 pub fn known_peers(&self) -> &KnownPeers {
221 &self.0.known_peers
222 }
223
224 pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
226 self.0.active_peers.subscribe()
227 }
228
229 pub async fn connect<T>(&self, addr: T, peer_id: &PeerId) -> Result<Peer, ConnectionError>
231 where
232 T: Into<Address>,
233 {
234 self.0.connect(addr.into(), peer_id).await
235 }
236
237 pub fn disconnect(&self, peer_id: &PeerId) {
238 self.0.disconnect(peer_id);
239 }
240
241 pub async fn shutdown(&self) {
242 self.0.shutdown().await;
243 }
244
245 pub fn is_closed(&self) -> bool {
246 self.0.is_closed()
247 }
248
249 pub fn sign_tl<T: tl_proto::TlWrite>(&self, data: T) -> [u8; 64] {
250 self.0.keypair.sign_tl(data)
251 }
252
253 pub fn sign_raw(&self, data: &[u8]) -> [u8; 64] {
254 self.0.keypair.sign_raw(data)
255 }
256
257 pub fn sign_peer_info(&self, now: u32, ttl: u32) -> PeerInfo {
258 let mut res = PeerInfo {
259 id: *self.0.peer_id(),
260 address_list: vec![self.remote_addr().clone()].into_boxed_slice(),
261 created_at: now,
262 expires_at: now.saturating_add(ttl),
263 signature: Box::new([0; 64]),
264 };
265 *res.signature = self.sign_tl(&res);
266 res
267 }
268
269 pub fn downgrade(this: &Self) -> WeakNetwork {
270 WeakNetwork(Arc::downgrade(&this.0))
271 }
272
273 pub fn max_frame_size(&self) -> usize {
275 self.0.config.max_frame_size.0 as usize
276 }
277}
278
279struct NetworkInner {
280 config: Arc<NetworkConfig>,
281 remote_addr: Address,
282 endpoint: Arc<Endpoint>,
283 active_peers: ActivePeers,
284 known_peers: KnownPeers,
285 connection_manager_handle: mpsc::Sender<ConnectionManagerRequest>,
286 keypair: ed25519::KeyPair,
287}
288
289impl NetworkInner {
290 fn remote_addr(&self) -> &Address {
291 &self.remote_addr
292 }
293
294 fn local_addr(&self) -> SocketAddr {
295 self.endpoint.local_addr()
296 }
297
298 fn peer_id(&self) -> &PeerId {
299 self.endpoint.peer_id()
300 }
301
302 async fn connect(&self, addr: Address, peer_id: &PeerId) -> Result<Peer, ConnectionError> {
303 let (tx, rx) = oneshot::channel();
304 self.connection_manager_handle
305 .send(ConnectionManagerRequest::Connect(addr, *peer_id, tx))
306 .await
307 .map_err(|_e| ConnectionError::Shutdown)?;
308
309 let Ok(res) = rx.await else {
310 return Err(ConnectionError::Shutdown);
311 };
312
313 res.map(|c| Peer::new(c, self.config.clone()))
314 }
315
316 fn disconnect(&self, peer_id: &PeerId) {
317 self.active_peers
318 .remove(peer_id, DisconnectReason::Requested);
319 }
320
321 fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
322 let connection = self.active_peers.get(peer_id)?;
323 Some(Peer::new(connection, self.config.clone()))
324 }
325
326 async fn shutdown(&self) {
327 let (sender, receiver) = oneshot::channel();
328 if self
329 .connection_manager_handle
330 .send(ConnectionManagerRequest::Shutdown(sender))
331 .await
332 .is_err()
333 {
334 return;
335 }
336
337 receiver.await.ok();
338 }
339
340 fn is_closed(&self) -> bool {
341 self.connection_manager_handle.is_closed()
342 }
343}
344
345impl Drop for NetworkInner {
346 fn drop(&mut self) {
347 tracing::debug!("network dropped");
348 }
349}
350
351pub trait ToSocket {
352 fn to_socket(self) -> Result<std::net::UdpSocket, BindError>;
353}
354
355impl ToSocket for std::net::UdpSocket {
356 fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
357 Ok(self)
358 }
359}
360
/// Implements [`ToSocket`] for each listed type by delegating to
/// `bind_socket_to_addr`, which resolves the value via [`ToSocketAddrs`] and
/// binds to the first address that succeeds.
macro_rules! impl_to_socket_for_addr {
    ($($addr_ty:ty),* $(,)?) => {
        $(
            impl ToSocket for $addr_ty {
                fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
                    bind_socket_to_addr(self)
                }
            }
        )*
    };
}
370
371impl_to_socket_for_addr! {
372 SocketAddr,
373 std::net::SocketAddrV4,
374 std::net::SocketAddrV6,
375 (std::net::IpAddr, u16),
376 (std::net::Ipv4Addr, u16),
377 (std::net::Ipv6Addr, u16),
378 (&str, u16),
379 (String, u16),
380 &str,
381 String,
382 &[SocketAddr],
383 Address,
384}
385
386fn bind_socket_to_addr<T: ToSocketAddrs>(
387 bind_address: T,
388) -> Result<std::net::UdpSocket, BindError> {
389 use socket2::{Domain, Protocol, Socket, Type};
390
391 let socket_addrs = bind_address
392 .to_socket_addrs()
393 .map_err(BindError::AddressResolution)?;
394
395 let mut last_bind_error: Option<BindError> = None;
396
397 for addr in socket_addrs {
398 let socket = match Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
399 {
400 Ok(s) => s,
401 Err(e) => {
402 if last_bind_error.is_none() {
403 last_bind_error = Some(BindError::SocketCreation { addr, source: e });
404 }
405 continue;
406 }
407 };
408
409 match socket.bind(&socket2::SockAddr::from(addr)) {
411 Ok(()) => return Ok(socket.into()),
412 Err(e) => {
413 tracing::warn!(?e, %addr, "failed to bind, trying next address");
415 last_bind_error = Some(BindError::SocketBind { addr, source: e });
416 }
417 }
418 }
419
420 Err(last_bind_error.unwrap_or(BindError::NoAddressesToBind))
422}
423
424#[derive(thiserror::Error, Debug)]
425pub enum BindError {
426 #[error("failed to resolve socket addresses: {0}")]
427 AddressResolution(#[source] std::io::Error),
428
429 #[error("no suitable addresses found to bind to after attempting all resolved addresses")]
430 NoAddressesToBind,
431
432 #[error("failed to create new socket for address {addr:?}: {source}")]
433 SocketCreation {
434 addr: SocketAddr,
435 #[source]
436 source: std::io::Error,
437 },
438
439 #[error("failed to bind socket to address {addr}: {source}")]
440 SocketBind {
441 addr: SocketAddr,
442 #[source]
443 source: std::io::Error,
444 },
445}
446
/// System-wide socket buffer size limits, as reported by procfs on Linux.
#[derive(Debug, Clone, Copy)]
struct MaxBufferSize {
    /// Send-side limit (`wmem_max`).
    send: usize,
    /// Receive-side limit (`rmem_max`).
    recv: usize,
}
452
453impl MaxBufferSize {
454 #[cfg(target_os = "linux")]
455 pub fn read() -> Result<Option<Self>> {
456 const WMEM: &str = "wmem_max";
457 const RMEM: &str = "rmem_max";
458
459 #[cfg(any(feature = "test", test))]
460 let proc_path = std::env::var("MOCK_PROC_PATH").unwrap_or_else(|_| "/proc".to_string());
461 #[cfg(not(any(feature = "test", test)))]
462 let proc_path = "/proc";
463 let proc_path = std::path::Path::new(&proc_path).join("sys/net/core");
464
465 let read_and_parse = |file_name: &str| -> Result<Option<usize>> {
466 let path = proc_path.join(file_name);
467 if !path.exists() {
468 tracing::warn!("{} not found", path.display());
469 return Ok(None);
470 }
471 let res = std::fs::read_to_string(&path)
472 .with_context(|| format!("Failed to read {}", path.display()))?
473 .trim()
474 .parse()
475 .with_context(|| format!("Failed to parse {}", path.display()))?;
476 Ok(Some(res))
477 };
478
479 let rmem = read_and_parse(RMEM)?;
480 let wmem = read_and_parse(WMEM)?;
481
482 Ok(rmem.zip(wmem).map(|(recv, send)| Self { send, recv }))
483 }
484
485 #[cfg(not(target_os = "linux"))]
486 pub fn read() -> std::io::Result<Option<Self>> {
487 Ok(None)
488 }
489}
490
491#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
492pub enum ConnectionError {
493 #[error("invalid address")]
494 InvalidAddress,
495 #[error("connection init failed")]
496 ConnectionInitFailed,
497 #[error("invalid certificate")]
498 InvalidCertificate,
499 #[error("handshake failed")]
500 HandshakeFailed,
501 #[error("connection timeout")]
502 Timeout,
503 #[error("network has been shutdown")]
504 Shutdown,
505}
506
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use futures_util::stream::FuturesUnordered;

    use super::*;
    use crate::types::{BoxCloneService, PeerInfo, Request, service_message_fn, service_query_fn};
    use crate::util::{NetworkExt, UnknownPeerError};

    /// Query service that answers every request with its own body.
    fn echo_service() -> BoxCloneService<ServiceRequest, Response> {
        let handle = |request: ServiceRequest| async move {
            tracing::trace!("received: {}", request.body.escape_ascii());
            let response = Response {
                version: Default::default(),
                body: request.body,
            };
            Some(response)
        };
        service_query_fn(handle).boxed_clone()
    }

    /// Local-loopback network with 0-RTT enabled, a random key and the echo service.
    fn make_network() -> Result<Network> {
        Network::builder()
            .with_config(NetworkConfig {
                enable_0rtt: true,
                ..Default::default()
            })
            .with_random_private_key()
            .build("127.0.0.1:0", echo_service())
    }

    /// Unsigned (zero-signature) peer info describing `network`.
    fn make_peer_info(network: &Network) -> Arc<PeerInfo> {
        Arc::new(PeerInfo {
            id: *network.peer_id(),
            address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
            created_at: 0,
            expires_at: u32::MAX,
            signature: Box::new([0; 64]),
        })
    }

    #[tokio::test]
    async fn connection_manager_works() -> Result<()> {
        tycho_util::test::init_logger("connection_manager_works", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        peer1
            .connect(peer2.local_addr(), peer2.peer_id())
            .await
            .unwrap();
        peer2
            .connect(peer1.local_addr(), peer1.peer_id())
            .await
            .unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn invalid_peer_id_detectable() -> Result<()> {
        tycho_util::test::init_logger("invalid_peer_id_detectable", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        let make_invalid_peer_info = |network: &Network| {
            Arc::new(PeerInfo {
                id: PeerId([0; 32]),
                address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
                created_at: 0,
                expires_at: u32::MAX,
                signature: Box::new([0; 64]),
            })
        };
        let _handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
        let _handle = peer1
            .known_peers()
            .insert(make_invalid_peer_info(&peer2), false)?;

        let _handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;
        let _handle = peer2
            .known_peers()
            .insert(make_invalid_peer_info(&peer1), false)?;

        let req = Request {
            version: Default::default(),
            body: "hello".into(),
        };

        peer1.query(peer2.peer_id(), req.clone()).await?;
        peer2.query(peer1.peer_id(), req.clone()).await?;

        fn assert_is_invalid_certificate(e: anyhow::Error) {
            let e = (*e).downcast_ref::<ConnectionError>().unwrap();
            assert_eq!(*e, ConnectionError::InvalidCertificate);
        }

        let err = peer1
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        let err = peer2
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        fn assert_is_unknown_peer(e: anyhow::Error, peer_id: &PeerId) {
            let e = (*e).downcast_ref::<UnknownPeerError>().unwrap();
            assert_eq!(e, &UnknownPeerError { peer_id: *peer_id });
        }

        let invalid_peer_id = PeerId([0xff; 32]);
        let err = peer1
            .query(&invalid_peer_id, req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_unknown_peer(err, &invalid_peer_id);

        Ok(())
    }

    #[tokio::test]
    async fn simultaneous_queries() -> Result<()> {
        tycho_util::test::init_logger("simultaneous_queries", "debug");

        for _ in 0..10 {
            let peer1 = make_network()?;
            let peer2 = make_network()?;

            let _peer1_peer2_handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
            let _peer2_peer1_handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;

            let req = Request {
                version: Default::default(),
                body: "hello".into(),
            };
            let peer1_fut = std::pin::pin!(peer1.query(peer2.peer_id(), req.clone()));
            let peer2_fut = std::pin::pin!(peer2.query(peer1.peer_id(), req.clone()));

            let (res1, res2) = futures_util::future::join(peer1_fut, peer2_fut).await;
            assert_eq!(res1?.body, req.body);
            assert_eq!(res2?.body, req.body);
        }

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn uni_message_handler() -> Result<()> {
        tycho_util::test::init_logger("uni_message_handler", "debug");

        fn noop_service() -> BoxCloneService<ServiceRequest, Response> {
            let handle = |request: ServiceRequest| async move {
                tracing::trace!("received: {} bytes", request.body.len());
            };
            service_message_fn(handle).boxed_clone()
        }

        fn make_network() -> Result<Network> {
            Network::builder()
                .with_config(NetworkConfig {
                    enable_0rtt: true,
                    ..Default::default()
                })
                .with_random_private_key()
                .build("127.0.0.1:0", noop_service())
        }

        let left = make_network()?;
        let right = make_network()?;

        let _left_to_right = left.known_peers().insert(make_peer_info(&right), false)?;
        let _right_to_left = right.known_peers().insert(make_peer_info(&left), false)?;

        let req = Request {
            version: Default::default(),
            body: vec![0xff; 750 * 1024].into(),
        };

        for _ in 0..10 {
            let mut futures = FuturesUnordered::new();
            for _ in 0..100 {
                futures.push(left.send(right.peer_id(), req.clone()));
            }

            while let Some(res) = futures.next().await {
                res?;
            }
        }

        Ok(())
    }

    /// Reads the real procfs limits, or mocked ones when `/proc` is missing.
    ///
    /// Linux-only: `MaxBufferSize::read` always reports `None` elsewhere, so the
    /// mocked branch could never succeed on other targets.
    #[test]
    #[cfg(target_os = "linux")]
    fn socket_size_works() {
        if std::path::Path::new("/proc").exists() {
            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");
            assert!(socket_size.send > 0);
            assert!(socket_size.recv > 0);
        } else {
            let procfs = tempfile::tempdir().unwrap();
            // SAFETY: `set_var` is only sound when no other thread accesses the
            // environment concurrently. NOTE(review): other tests in this binary
            // may be running and `MaxBufferSize::read` reads `MOCK_PROC_PATH`;
            // this branch is only reached when `/proc` is absent.
            unsafe { std::env::set_var("MOCK_PROC_PATH", procfs.path()) };

            std::fs::create_dir_all(procfs.path().join("sys/net/core")).unwrap();
            std::fs::write(procfs.path().join("sys/net/core/wmem_max"), "100000\n").unwrap();
            std::fs::write(procfs.path().join("sys/net/core/rmem_max"), "100000\n").unwrap();

            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");

            assert_eq!(socket_size.send, 100000);
            assert_eq!(socket_size.recv, 100000);
        }
    }

    /// Off Linux there is no procfs to consult, so no limits are reported.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn socket_size_unavailable_off_linux() {
        assert!(MaxBufferSize::read().unwrap().is_none());
    }
}