1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::{Arc, Weak};
3
4#[cfg(target_os = "linux")]
5use anyhow::Context;
6use anyhow::Result;
7use tokio::sync::{broadcast, mpsc, oneshot};
8use tycho_crypto::ed25519;
9
10use self::config::EndpointConfig;
11pub use self::config::{NetworkConfig, QuicConfig};
12pub use self::connection::{Connection, RecvStream, SendStream};
13use self::connection_manager::{ActivePeers, ConnectionManager, ConnectionManagerRequest};
14pub use self::connection_manager::{
15 KnownPeerHandle, KnownPeers, KnownPeersError, PeerBannedError, WeakKnownPeerHandle,
16};
17use self::endpoint::Endpoint;
18pub use self::peer::Peer;
19use crate::types::{
20 Address, DisconnectReason, PeerEvent, PeerId, PeerInfo, Response, Service, ServiceExt,
21 ServiceRequest,
22};
23
24mod config;
25mod connection;
26mod connection_manager;
27mod crypto;
28mod endpoint;
29mod peer;
30mod request_handler;
31mod wire;
32
/// Type-state builder for [`Network`].
///
/// `MandatoryFields` records whether the private key has been supplied:
/// [`Network::builder`] starts with `((),)`, and `with_private_key` /
/// `with_random_private_key` turn it into `([u8; 32],)` (the default type
/// parameter), which is the only state that exposes `build`.
pub struct NetworkBuilder<MandatoryFields = ([u8; 32],)> {
    mandatory_fields: MandatoryFields,
    optional_fields: BuilderFields,
}
37
/// Optional builder settings; `None` means "use the default chosen in `build`".
#[derive(Default)]
struct BuilderFields {
    /// Network configuration; `build` falls back to `NetworkConfig::default()`.
    config: Option<NetworkConfig>,
    /// Address advertised to peers; `build` falls back to the endpoint's local address.
    remote_addr: Option<Address>,
}
43
44impl<MandatoryFields> NetworkBuilder<MandatoryFields> {
45 pub fn with_config(mut self, config: NetworkConfig) -> Self {
46 self.optional_fields.config = Some(config);
47 self
48 }
49
50 pub fn with_remote_addr<T: Into<Address>>(mut self, addr: T) -> Self {
51 self.optional_fields.remote_addr = Some(addr.into());
52 self
53 }
54}
55
56impl NetworkBuilder<((),)> {
57 pub fn with_private_key(self, private_key: [u8; 32]) -> NetworkBuilder<([u8; 32],)> {
58 NetworkBuilder {
59 mandatory_fields: (private_key,),
60 optional_fields: self.optional_fields,
61 }
62 }
63
64 pub fn with_random_private_key(self) -> NetworkBuilder<([u8; 32],)> {
65 self.with_private_key(rand::random())
66 }
67}
68
69impl NetworkBuilder {
70 pub fn build<T: ToSocket, S>(self, bind_address: T, service: S) -> Result<Network>
71 where
72 S: Send + Sync + Clone + 'static,
73 S: Service<ServiceRequest, QueryResponse = Response>,
74 {
75 let config = self.optional_fields.config.unwrap_or_default();
76 let quic_config = config.quic.clone().unwrap_or_default();
77 let (private_key,) = self.mandatory_fields;
78
79 let keypair = ed25519::KeyPair::from(&ed25519::SecretKey::from_bytes(private_key));
80
81 let endpoint_config = EndpointConfig::builder()
82 .with_private_key(private_key)
83 .with_0rtt_enabled(config.enable_0rtt)
84 .with_transport_config(quic_config.make_transport_config())
85 .with_connection_metrics(config.connection_metrics)
86 .build()?;
87
88 let socket = bind_address.to_socket().map(socket2::Socket::from)?;
89
90 let max_socket_size = MaxBufferSize::read()?;
91
92 set_socket_buffer(
93 &socket,
94 quic_config.socket_send_buffer_size,
95 max_socket_size.map(|m| m.send),
96 |s, size| s.set_send_buffer_size(size),
97 "send",
98 );
99
100 set_socket_buffer(
101 &socket,
102 quic_config.socket_recv_buffer_size,
103 max_socket_size.map(|m| m.recv),
104 |s, size| s.set_recv_buffer_size(size),
105 "recv",
106 );
107
108 let config = Arc::new(config);
109 let endpoint = Arc::new(Endpoint::new(endpoint_config, socket.into())?);
110 let active_peers = ActivePeers::new(config.active_peers_event_channel_capacity);
111 let known_peers = KnownPeers::new();
112
113 let remote_addr = self.optional_fields.remote_addr.unwrap_or_else(|| {
114 let addr = endpoint.local_addr();
115 tracing::debug!(%addr, "using local address as remote address");
116 addr.into()
117 });
118
119 let service = service.boxed_clone();
120
121 let (connection_manager, connection_manager_handle) = ConnectionManager::new(
122 config.clone(),
123 endpoint.clone(),
124 active_peers.clone(),
125 known_peers.clone(),
126 service,
127 );
128
129 tokio::spawn(connection_manager.start());
130
131 Ok(Network(Arc::new(NetworkInner {
132 config,
133 remote_addr,
134 endpoint,
135 active_peers,
136 known_peers,
137 connection_manager_handle,
138 keypair,
139 })))
140 }
141}
142
143fn set_socket_buffer(
144 socket: &socket2::Socket,
145 config_size: Option<usize>,
146 max_size: Option<usize>,
147 set_buffer_fn: impl Fn(&socket2::Socket, usize) -> std::io::Result<()>,
148 buffer_type: &str,
149) {
150 if let Some(size) = config_size {
151 if let Err(e) = set_buffer_fn(socket, size) {
152 tracing::error!(%size, "failed to set socket {} buffer size: {e:?}", buffer_type);
153 }
154 } else if let Some(max) = max_size {
155 if let Err(e) = set_buffer_fn(socket, max) {
156 tracing::error!(%max, "failed to set socket {} buffer size to max value: {e:?}", buffer_type);
157 }
158 tracing::info!(
159 "set socket {} buffer size to max value: {}",
160 buffer_type,
161 max
162 );
163 }
164}
165
/// Non-owning handle to a [`Network`] (wraps a `Weak`), so it does not keep
/// the network alive. Obtain one with `Network::downgrade`.
#[derive(Clone)]
#[repr(transparent)]
pub struct WeakNetwork(Weak<NetworkInner>);
169
170impl WeakNetwork {
171 pub fn upgrade(&self) -> Option<Network> {
172 self.0
173 .upgrade()
174 .map(Network)
175 .and_then(|network| (!network.is_closed()).then_some(network))
176 }
177}
178
/// Handle to a running network node. Cloning is cheap because all clones share
/// one `Arc<NetworkInner>`. Create it with [`Network::builder`].
#[derive(Clone)]
#[repr(transparent)]
pub struct Network(Arc<NetworkInner>);
182
183impl Network {
184 pub fn builder() -> NetworkBuilder<((),)> {
185 NetworkBuilder {
186 mandatory_fields: ((),),
187 optional_fields: Default::default(),
188 }
189 }
190
191 pub fn remote_addr(&self) -> &Address {
193 self.0.remote_addr()
194 }
195
196 pub fn local_addr(&self) -> SocketAddr {
198 self.0.local_addr()
199 }
200
201 pub fn peer_id(&self) -> &PeerId {
203 self.0.peer_id()
204 }
205
206 pub fn is_active(&self, peer_id: &PeerId) -> bool {
208 self.0.active_peers.contains(peer_id)
209 }
210
211 pub fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
213 self.0.peer(peer_id)
214 }
215
216 pub fn known_peers(&self) -> &KnownPeers {
218 &self.0.known_peers
219 }
220
221 pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
223 self.0.active_peers.subscribe()
224 }
225
226 pub async fn connect<T>(&self, addr: T, peer_id: &PeerId) -> Result<Peer, ConnectionError>
228 where
229 T: Into<Address>,
230 {
231 self.0.connect(addr.into(), peer_id).await
232 }
233
234 pub fn disconnect(&self, peer_id: &PeerId) {
235 self.0.disconnect(peer_id);
236 }
237
238 pub async fn shutdown(&self) {
239 self.0.shutdown().await;
240 }
241
242 pub fn is_closed(&self) -> bool {
243 self.0.is_closed()
244 }
245
246 pub fn sign_tl<T: tl_proto::TlWrite>(&self, data: T) -> [u8; 64] {
247 self.0.keypair.sign_tl(data)
248 }
249
250 pub fn sign_raw(&self, data: &[u8]) -> [u8; 64] {
251 self.0.keypair.sign_raw(data)
252 }
253
254 pub fn sign_peer_info(&self, now: u32, ttl: u32) -> PeerInfo {
255 let mut res = PeerInfo {
256 id: *self.0.peer_id(),
257 address_list: vec![self.remote_addr().clone()].into_boxed_slice(),
258 created_at: now,
259 expires_at: now.saturating_add(ttl),
260 signature: Box::new([0; 64]),
261 };
262 *res.signature = self.sign_tl(&res);
263 res
264 }
265
266 pub fn downgrade(this: &Self) -> WeakNetwork {
267 WeakNetwork(Arc::downgrade(&this.0))
268 }
269
270 pub fn max_frame_size(&self) -> usize {
272 self.0.config.max_frame_size.0 as usize
273 }
274}
275
/// Shared state behind [`Network`] and [`WeakNetwork`].
struct NetworkInner {
    /// Network configuration, shared with the connection manager and peers.
    config: Arc<NetworkConfig>,
    /// Address advertised to peers. `build` defaults it to the endpoint's local address.
    remote_addr: Address,
    endpoint: Arc<Endpoint>,
    /// Peers with a live connection. Also the source of `PeerEvent` broadcasts.
    active_peers: ActivePeers,
    known_peers: KnownPeers,
    /// Request channel to the spawned connection-manager task. Its closed
    /// state is what `is_closed` reports.
    connection_manager_handle: mpsc::Sender<ConnectionManagerRequest>,
    /// ed25519 keypair built from the builder's private key, used for signing.
    keypair: ed25519::KeyPair,
}
285
286impl NetworkInner {
287 fn remote_addr(&self) -> &Address {
288 &self.remote_addr
289 }
290
291 fn local_addr(&self) -> SocketAddr {
292 self.endpoint.local_addr()
293 }
294
295 fn peer_id(&self) -> &PeerId {
296 self.endpoint.peer_id()
297 }
298
299 async fn connect(&self, addr: Address, peer_id: &PeerId) -> Result<Peer, ConnectionError> {
300 let (tx, rx) = oneshot::channel();
301 self.connection_manager_handle
302 .send(ConnectionManagerRequest::Connect(addr, *peer_id, tx))
303 .await
304 .map_err(|_e| ConnectionError::Shutdown)?;
305
306 let Ok(res) = rx.await else {
307 return Err(ConnectionError::Shutdown);
308 };
309
310 res.map(|c| Peer::new(c, self.config.clone()))
311 }
312
313 fn disconnect(&self, peer_id: &PeerId) {
314 self.active_peers
315 .remove(peer_id, DisconnectReason::Requested);
316 }
317
318 fn peer(&self, peer_id: &PeerId) -> Option<Peer> {
319 let connection = self.active_peers.get(peer_id)?;
320 Some(Peer::new(connection, self.config.clone()))
321 }
322
323 async fn shutdown(&self) {
324 let (sender, receiver) = oneshot::channel();
325 if self
326 .connection_manager_handle
327 .send(ConnectionManagerRequest::Shutdown(sender))
328 .await
329 .is_err()
330 {
331 return;
332 }
333
334 receiver.await.ok();
335 }
336
337 fn is_closed(&self) -> bool {
338 self.connection_manager_handle.is_closed()
339 }
340}
341
// Runs when the last strong `Network` handle is dropped (a `WeakNetwork` does
// not keep `NetworkInner` alive). It only emits a debug trace.
impl Drop for NetworkInner {
    fn drop(&mut self) {
        tracing::debug!("network dropped");
    }
}
347
/// Conversion into a bound UDP socket, accepted by `NetworkBuilder::build`.
///
/// An existing `std::net::UdpSocket` is passed through unchanged. Address-like
/// types are resolved and bound by trying each resolved address in turn.
pub trait ToSocket {
    /// Produces the bound socket.
    ///
    /// # Errors
    ///
    /// Returns a [`BindError`] if resolution, socket creation or binding fails.
    fn to_socket(self) -> Result<std::net::UdpSocket, BindError>;
}
351
// An already-created socket is used as-is: it is neither rebound nor reconfigured here.
impl ToSocket for std::net::UdpSocket {
    fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
        Ok(self)
    }
}
357
/// Implements [`ToSocket`] for each listed type by delegating to
/// `bind_socket_to_addr`, so every listed type must implement `ToSocketAddrs`.
/// A trailing comma in the type list is accepted.
macro_rules! impl_to_socket_for_addr {
    ($($addr_ty:ty),* $(,)?) => {
        $(
            impl ToSocket for $addr_ty {
                fn to_socket(self) -> Result<std::net::UdpSocket, BindError> {
                    bind_socket_to_addr(self)
                }
            }
        )*
    };
}
367
// Address-like inputs accepted by `NetworkBuilder::build`. Each is resolved
// with `ToSocketAddrs` and bound by `bind_socket_to_addr`. The string forms
// go through `ToSocketAddrs` for strings, which may perform a DNS lookup.
impl_to_socket_for_addr! {
    SocketAddr,
    std::net::SocketAddrV4,
    std::net::SocketAddrV6,
    (std::net::IpAddr, u16),
    (std::net::Ipv4Addr, u16),
    (std::net::Ipv6Addr, u16),
    (&str, u16),
    (String, u16),
    &str,
    String,
    &[SocketAddr],
    Address,
}
382
383fn bind_socket_to_addr<T: ToSocketAddrs>(
384 bind_address: T,
385) -> Result<std::net::UdpSocket, BindError> {
386 use socket2::{Domain, Protocol, Socket, Type};
387
388 let socket_addrs = bind_address
389 .to_socket_addrs()
390 .map_err(BindError::AddressResolution)?;
391
392 let mut last_bind_error: Option<BindError> = None;
393
394 for addr in socket_addrs {
395 let socket = match Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
396 {
397 Ok(s) => s,
398 Err(e) => {
399 if last_bind_error.is_none() {
400 last_bind_error = Some(BindError::SocketCreation { addr, source: e });
401 }
402 continue;
403 }
404 };
405
406 match socket.bind(&socket2::SockAddr::from(addr)) {
408 Ok(()) => return Ok(socket.into()),
409 Err(e) => {
410 tracing::warn!(?e, %addr, "failed to bind, trying next address");
412 last_bind_error = Some(BindError::SocketBind { addr, source: e });
413 }
414 }
415 }
416
417 Err(last_bind_error.unwrap_or(BindError::NoAddressesToBind))
419}
420
/// Errors produced while turning a [`ToSocket`] input into a bound UDP socket.
///
/// NOTE(review): several messages embed `{source}` and also expose it via
/// `#[source]`, so chain-printing formatters (e.g. anyhow's `{:#}`) show the
/// underlying I/O error twice. Consider dropping one of the two.
#[derive(thiserror::Error, Debug)]
pub enum BindError {
    /// `to_socket_addrs` failed (for example, a hostname could not be resolved).
    #[error("failed to resolve socket addresses: {0}")]
    AddressResolution(#[source] std::io::Error),

    /// Returned only when resolution produced no addresses at all. Any
    /// attempted address yields one of the other variants.
    #[error("no suitable addresses found to bind to after attempting all resolved addresses")]
    NoAddressesToBind,

    /// Creating the UDP socket for `addr` failed. It is reported only if it was
    /// the first failure and no later bind attempt failed.
    #[error("failed to create new socket for address {addr:?}: {source}")]
    SocketCreation {
        addr: SocketAddr,
        #[source]
        source: std::io::Error,
    },

    /// Binding to `addr` failed. This is the last bind failure observed.
    #[error("failed to bind socket to address {addr}: {source}")]
    SocketBind {
        addr: SocketAddr,
        #[source]
        source: std::io::Error,
    },
}
443
/// System-wide socket buffer limits (`net.core.wmem_max` / `net.core.rmem_max`
/// on Linux). `NetworkBuilder::build` applies them when the QUIC config does
/// not set an explicit buffer size.
#[derive(Debug, Clone, Copy)]
struct MaxBufferSize {
    /// Maximum send buffer size, read from `wmem_max`.
    send: usize,
    /// Maximum receive buffer size, read from `rmem_max`.
    recv: usize,
}
449
450impl MaxBufferSize {
451 #[cfg(target_os = "linux")]
452 pub fn read() -> Result<Option<Self>> {
453 const WMEM: &str = "wmem_max";
454 const RMEM: &str = "rmem_max";
455
456 #[cfg(any(feature = "test", test))]
457 let proc_path = std::env::var("MOCK_PROC_PATH").unwrap_or_else(|_| "/proc".to_string());
458 #[cfg(not(any(feature = "test", test)))]
459 let proc_path = "/proc";
460 let proc_path = std::path::Path::new(&proc_path).join("sys/net/core");
461
462 let read_and_parse = |file_name: &str| -> Result<Option<usize>> {
463 let path = proc_path.join(file_name);
464 if !path.exists() {
465 tracing::warn!("{} not found", path.display());
466 return Ok(None);
467 }
468 let res = std::fs::read_to_string(&path)
469 .with_context(|| format!("Failed to read {}", path.display()))?
470 .trim()
471 .parse()
472 .with_context(|| format!("Failed to parse {}", path.display()))?;
473 Ok(Some(res))
474 };
475
476 let rmem = read_and_parse(RMEM)?;
477 let wmem = read_and_parse(WMEM)?;
478
479 Ok(rmem.zip(wmem).map(|(recv, send)| Self { send, recv }))
480 }
481
482 #[cfg(not(target_os = "linux"))]
483 pub fn read() -> std::io::Result<Option<Self>> {
484 Ok(None)
485 }
486}
487
/// Reasons a connection attempt can fail; returned by [`Network::connect`].
///
/// Only [`ConnectionError::Shutdown`] is constructed in this file. The other
/// variants come from the connection-establishment code elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum ConnectionError {
    #[error("invalid address")]
    InvalidAddress,
    #[error("connection init failed")]
    ConnectionInitFailed,
    /// The remote's certificate was rejected. For example, the tests expect this
    /// when a known peer entry's id does not match the peer at that address.
    #[error("invalid certificate")]
    InvalidCertificate,
    #[error("handshake failed")]
    HandshakeFailed,
    #[error("connection timeout")]
    Timeout,
    /// The connection-manager request channel is closed, or the manager
    /// dropped the reply sender without answering.
    #[error("network has been shutdown")]
    Shutdown,
}
503
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use futures_util::stream::FuturesUnordered;

    use super::*;
    use crate::types::{BoxCloneService, PeerInfo, Request, service_message_fn, service_query_fn};
    use crate::util::{NetworkExt, UnknownPeerError};

    fn echo_service() -> BoxCloneService<ServiceRequest, Response> {
        let handle = |request: ServiceRequest| async move {
            tracing::trace!("received: {}", request.body.escape_ascii());
            let response = Response {
                version: Default::default(),
                body: request.body,
            };
            Some(response)
        };
        service_query_fn(handle).boxed_clone()
    }

    fn make_network() -> Result<Network> {
        Network::builder()
            .with_config(NetworkConfig {
                enable_0rtt: true,
                ..Default::default()
            })
            .with_random_private_key()
            .build("127.0.0.1:0", echo_service())
    }

    fn make_peer_info(network: &Network) -> Arc<PeerInfo> {
        Arc::new(PeerInfo {
            id: *network.peer_id(),
            address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
            created_at: 0,
            expires_at: u32::MAX,
            signature: Box::new([0; 64]),
        })
    }

    #[tokio::test]
    async fn connection_manager_works() -> Result<()> {
        tycho_util::test::init_logger("connection_manager_works", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        peer1
            .connect(peer2.local_addr(), peer2.peer_id())
            .await
            .unwrap();
        peer2
            .connect(peer1.local_addr(), peer1.peer_id())
            .await
            .unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn invalid_peer_id_detectable() -> Result<()> {
        tycho_util::test::init_logger("invalid_peer_id_detectable", "debug");

        let peer1 = make_network()?;
        let peer2 = make_network()?;

        let make_invalid_peer_info = |network: &Network| {
            Arc::new(PeerInfo {
                id: PeerId([0; 32]),
                address_list: vec![network.remote_addr().clone()].into_boxed_slice(),
                created_at: 0,
                expires_at: u32::MAX,
                signature: Box::new([0; 64]),
            })
        };
        let _handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
        let _handle = peer1
            .known_peers()
            .insert(make_invalid_peer_info(&peer2), false)?;

        let _handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;
        let _handle = peer2
            .known_peers()
            .insert(make_invalid_peer_info(&peer1), false)?;

        let req = Request {
            version: Default::default(),
            body: "hello".into(),
        };

        peer1.query(peer2.peer_id(), req.clone()).await?;
        peer2.query(peer1.peer_id(), req.clone()).await?;

        fn assert_is_invalid_certificate(e: anyhow::Error) {
            let e = (*e).downcast_ref::<ConnectionError>().unwrap();
            assert_eq!(*e, ConnectionError::InvalidCertificate);
        }

        let err = peer1
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        let err = peer2
            .query(&PeerId([0; 32]), req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_invalid_certificate(err);

        fn assert_is_unknown_peer(e: anyhow::Error, peer_id: &PeerId) {
            let e = (*e).downcast_ref::<UnknownPeerError>().unwrap();
            assert_eq!(e, &UnknownPeerError { peer_id: *peer_id });
        }

        let invalid_peer_id = PeerId([0xff; 32]);
        let err = peer1
            .query(&invalid_peer_id, req.clone())
            .await
            .map(|_| ())
            .unwrap_err();
        assert_is_unknown_peer(err, &invalid_peer_id);

        Ok(())
    }

    #[tokio::test]
    async fn simultaneous_queries() -> Result<()> {
        tycho_util::test::init_logger("simultaneous_queries", "debug");

        for _ in 0..10 {
            let peer1 = make_network()?;
            let peer2 = make_network()?;

            let _peer1_peer2_handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?;
            let _peer2_peer1_handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?;

            let req = Request {
                version: Default::default(),
                body: "hello".into(),
            };
            let peer1_fut = std::pin::pin!(peer1.query(peer2.peer_id(), req.clone()));
            let peer2_fut = std::pin::pin!(peer2.query(peer1.peer_id(), req.clone()));

            let (res1, res2) = futures_util::future::join(peer1_fut, peer2_fut).await;
            assert_eq!(res1?.body, req.body);
            assert_eq!(res2?.body, req.body);
        }

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn uni_message_handler() -> Result<()> {
        tycho_util::test::init_logger("uni_message_handler", "debug");

        fn noop_service() -> BoxCloneService<ServiceRequest, Response> {
            let handle = |request: ServiceRequest| async move {
                tracing::trace!("received: {} bytes", request.body.len());
            };
            service_message_fn(handle).boxed_clone()
        }

        fn make_network() -> Result<Network> {
            Network::builder()
                .with_config(NetworkConfig {
                    enable_0rtt: true,
                    ..Default::default()
                })
                .with_random_private_key()
                .build("127.0.0.1:0", noop_service())
        }

        let left = make_network()?;
        let right = make_network()?;

        let _left_to_right = left.known_peers().insert(make_peer_info(&right), false)?;
        let _right_to_left = right.known_peers().insert(make_peer_info(&left), false)?;

        let req = Request {
            version: Default::default(),
            body: vec![0xff; 750 * 1024].into(),
        };

        for _ in 0..10 {
            let mut futures = FuturesUnordered::new();
            for _ in 0..100 {
                futures.push(left.send(right.peer_id(), req.clone()));
            }

            while let Some(res) = futures.next().await {
                res?;
            }
        }

        Ok(())
    }

    /// Only Linux has a real `MaxBufferSize::read`. Elsewhere it always returns
    /// `Ok(None)`, so the `.expect(...)` calls below would panic. That is why
    /// this test is gated to Linux.
    #[cfg(target_os = "linux")]
    #[test]
    fn socket_size_works() {
        if std::path::Path::new("/proc").exists() {
            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");
            assert!(socket_size.send > 0);
            assert!(socket_size.recv > 0);
        } else {
            let procfs = tempfile::tempdir().unwrap();
            // SAFETY: NOTE(review): `set_var` is only sound while no other thread
            // reads or writes the process environment. The test harness runs
            // tests on parallel threads by default, so this relies on no
            // concurrently running test touching the environment. Confirm, or
            // run with `--test-threads=1`.
            unsafe { std::env::set_var("MOCK_PROC_PATH", procfs.path()) };

            std::fs::create_dir_all(procfs.path().join("sys/net/core")).unwrap();
            std::fs::write(procfs.path().join("sys/net/core/wmem_max"), "100000\n").unwrap();
            std::fs::write(procfs.path().join("sys/net/core/rmem_max"), "100000\n").unwrap();

            let socket_size = MaxBufferSize::read()
                .unwrap()
                .expect("socket size not found");

            assert_eq!(socket_size.send, 100000);
            assert_eq!(socket_size.recv, 100000);
        }
    }

    /// Off Linux the limits are never known, so `build` leaves the OS default
    /// buffer sizes unless the QUIC config sets them explicitly.
    #[cfg(not(target_os = "linux"))]
    #[test]
    fn socket_size_unknown_off_linux() {
        assert!(MaxBufferSize::read().unwrap().is_none());
    }
}