// tycho_network/network/config.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use quinn::congestion::{self, ControllerFactory};
6use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
7use rustls::SupportedCipherSuite;
8use rustls::crypto::CryptoProvider;
9use rustls::sign::CertifiedKey;
10use serde::{Deserialize, Serialize};
11use tycho_util::serde_helpers;
12
13use crate::network::crypto::{
14    CertVerifier, CertVerifierWithPeerId, SUPPORTED_SIG_ALGS, generate_cert,
15    peer_id_from_certificate,
16};
17use crate::types::PeerId;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(default)]
21#[non_exhaustive]
22pub struct NetworkConfig {
23    pub quic: Option<QuicConfig>,
24
25    /// Default: 128.
26    pub connection_manager_channel_capacity: usize,
27
28    /// Default: 5 seconds.
29    #[serde(with = "serde_helpers::humantime")]
30    pub connectivity_check_interval: Duration,
31
32    /// Default: 8 MiB.
33    pub max_frame_size: bytesize::ByteSize,
34
35    /// Default: 10 seconds.
36    #[serde(with = "serde_helpers::humantime")]
37    pub connect_timeout: Duration,
38
39    /// Default: 10 seconds.
40    #[serde(with = "serde_helpers::humantime")]
41    pub connection_backoff: Duration,
42
43    /// Default: 1 minute.
44    #[serde(with = "serde_helpers::humantime")]
45    pub max_connection_backoff: Duration,
46
47    /// Optimistic guess for some errors that there will be an incoming connection.
48    ///
49    /// Default: 3 seconds.
50    #[serde(with = "serde_helpers::humantime")]
51    pub connection_error_delay: Duration,
52
53    /// Default: 100.
54    pub max_concurrent_outstanding_connections: usize,
55
56    /// Default: unlimited.
57    pub max_concurrent_connections: Option<usize>,
58
59    /// Default: 128.
60    pub active_peers_event_channel_capacity: usize,
61
62    /// Maximum number of concurrent requests (uni and bi streams) allowed from a single peer.
63    /// When this limit is reached, new incoming streams will be rejected.
64    ///
65    /// Default: 128.
66    pub max_concurrent_requests_per_peer: usize,
67
68    /// Default: 1 minute.
69    #[serde(with = "serde_helpers::humantime")]
70    pub shutdown_idle_timeout: Duration,
71
72    /// Default: no.
73    pub enable_0rtt: bool,
74
75    /// Default: disabled.
76    pub connection_metrics: Option<ConnectionMetricsLevel>,
77}
78
79impl Default for NetworkConfig {
80    fn default() -> Self {
81        Self {
82            quic: None,
83            connection_manager_channel_capacity: 128,
84            connectivity_check_interval: Duration::from_millis(5000),
85            max_frame_size: bytesize::ByteSize::mib(8),
86            connect_timeout: Duration::from_secs(10),
87            connection_backoff: Duration::from_secs(10),
88            max_connection_backoff: Duration::from_secs(60),
89            connection_error_delay: Duration::from_secs(3),
90            max_concurrent_outstanding_connections: 100,
91            max_concurrent_connections: None,
92            active_peers_event_channel_capacity: 128,
93            max_concurrent_requests_per_peer: 128,
94            shutdown_idle_timeout: Duration::from_secs(60),
95            enable_0rtt: false,
96            connection_metrics: None,
97        }
98    }
99}
100
101#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
102pub enum ConnectionMetricsLevel {
103    Brief,
104    Detailed,
105}
106
107impl ConnectionMetricsLevel {
108    pub fn should_export_peer_id(self) -> bool {
109        matches!(self, Self::Detailed)
110    }
111}
112
113#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
114pub enum CongestionAlgorithm {
115    Cubic,
116    Bbr,
117    NewReno,
118}
119
120impl CongestionAlgorithm {
121    pub fn build(self) -> Arc<dyn ControllerFactory + Send + Sync + 'static> {
122        match self {
123            CongestionAlgorithm::Cubic => Arc::new(congestion::CubicConfig::default()),
124            CongestionAlgorithm::Bbr => Arc::new(congestion::BbrConfig::default()),
125            CongestionAlgorithm::NewReno => Arc::new(congestion::NewRenoConfig::default()),
126        }
127    }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131#[serde(default)]
132pub struct QuicConfig {
133    /// Default: 100.
134    pub max_concurrent_bidi_streams: u64,
135    /// Default: 100.
136    pub max_concurrent_uni_streams: u64,
137    /// Default: auto.
138    pub stream_receive_window: Option<u64>,
139    /// Default: auto.
140    pub receive_window: Option<u64>,
141    /// Default: auto.
142    pub send_window: Option<u64>,
143
144    /// Whether to implement fair queuing for send streams having the same priority.
145    ///
146    /// Default: true.
147    pub send_fairness: bool,
148
149    /// Whether to use "Generic Segmentation Offload" to accelerate transmits,
150    /// when supported by the environment.
151    ///
152    /// Default: true.
153    pub enable_segmentation_offload: bool,
154
155    // TODO: add all other fields from quin::TransportConfig
156    /// Default: auto.
157    pub socket_send_buffer_size: Option<usize>,
158    /// Default: auto.
159    pub socket_recv_buffer_size: Option<usize>,
160    /// Default: true.
161    pub use_pmtu: bool,
162
163    /// Default: auto.
164    pub initial_mtu: Option<u16>,
165
166    /// Default: `Bbr`.
167    pub congestion_algorithm: CongestionAlgorithm,
168}
169
170impl Default for QuicConfig {
171    fn default() -> Self {
172        Self {
173            max_concurrent_bidi_streams: 100,
174            max_concurrent_uni_streams: 100,
175            stream_receive_window: None,
176            receive_window: None,
177            send_window: None,
178            send_fairness: true,
179            enable_segmentation_offload: true,
180            socket_send_buffer_size: None,
181            socket_recv_buffer_size: None,
182            use_pmtu: true,
183            initial_mtu: None,
184            congestion_algorithm: CongestionAlgorithm::Bbr,
185        }
186    }
187}
188
189impl QuicConfig {
190    pub fn make_transport_config(&self) -> quinn::TransportConfig {
191        fn make_varint(value: u64) -> quinn::VarInt {
192            quinn::VarInt::from_u64(value).unwrap_or(quinn::VarInt::MAX)
193        }
194
195        let mut config = quinn::TransportConfig::default();
196        config.max_concurrent_bidi_streams(make_varint(self.max_concurrent_bidi_streams));
197        config.max_concurrent_uni_streams(make_varint(self.max_concurrent_uni_streams));
198
199        config.datagram_receive_buffer_size(None);
200
201        config.enable_segmentation_offload(self.enable_segmentation_offload);
202        config.send_fairness(self.send_fairness);
203
204        if let Some(stream_receive_window) = self.stream_receive_window {
205            config.stream_receive_window(make_varint(stream_receive_window));
206        }
207        if let Some(receive_window) = self.receive_window {
208            config.receive_window(make_varint(receive_window));
209        }
210        if let Some(send_window) = self.send_window {
211            config.send_window(send_window);
212        }
213        if self.use_pmtu {
214            let mtu = quinn::MtuDiscoveryConfig::default();
215            config.mtu_discovery_config(Some(mtu));
216        }
217
218        if let Some(mtu) = self.initial_mtu {
219            config.initial_mtu(mtu);
220        }
221
222        config.congestion_controller_factory(self.congestion_algorithm.build());
223
224        config
225    }
226}
227
228pub(crate) struct EndpointConfig {
229    pub peer_id: PeerId,
230    pub cert_resolver: Arc<rustls::client::AlwaysResolvesClientRawPublicKeys>,
231    pub quinn_server_config: quinn::ServerConfig,
232    pub transport_config: Arc<quinn::TransportConfig>,
233    pub quinn_endpoint_config: quinn::EndpointConfig,
234    pub enable_early_data: bool,
235    pub crypto_provider: Arc<CryptoProvider>,
236    pub connection_metrics: Option<ConnectionMetricsLevel>,
237}
238
239impl EndpointConfig {
240    pub fn builder() -> EndpointConfigBuilder<((),)> {
241        EndpointConfigBuilder {
242            mandatory_fields: ((),),
243            optional_fields: Default::default(),
244        }
245    }
246
247    pub fn make_client_config_for_peer_id(&self, peer_id: &PeerId) -> quinn::ClientConfig {
248        let mut client_config =
249            rustls::ClientConfig::builder_with_provider(self.crypto_provider.clone())
250                .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
251                .unwrap()
252                .dangerous()
253                .with_custom_certificate_verifier(Arc::new(CertVerifierWithPeerId::new(peer_id)))
254                .with_client_cert_resolver(self.cert_resolver.clone());
255
256        client_config.enable_early_data = self.enable_early_data;
257        let quinn_config =
258            QuicClientConfig::try_from(client_config).expect("cipher suite is always provided");
259
260        let mut client = quinn::ClientConfig::new(Arc::new(quinn_config));
261        client.transport_config(self.transport_config.clone());
262        client
263    }
264}
265
266pub(crate) struct EndpointConfigBuilder<MandatoryFields = ([u8; 32],)> {
267    mandatory_fields: MandatoryFields,
268    optional_fields: EndpointConfigBuilderFields,
269}
270
271#[derive(Default)]
272struct EndpointConfigBuilderFields {
273    enable_0rtt: bool,
274    transport_config: Option<quinn::TransportConfig>,
275    connection_metrics: Option<ConnectionMetricsLevel>,
276}
277
278impl<MandatoryFields> EndpointConfigBuilder<MandatoryFields> {
279    pub fn with_0rtt_enabled(mut self, enable_0rtt: bool) -> Self {
280        self.optional_fields.enable_0rtt = enable_0rtt;
281        self
282    }
283
284    pub fn with_transport_config(mut self, transport_config: quinn::TransportConfig) -> Self {
285        self.optional_fields.transport_config = Some(transport_config);
286        self
287    }
288
289    pub fn with_connection_metrics(mut self, metrics: Option<ConnectionMetricsLevel>) -> Self {
290        self.optional_fields.connection_metrics = metrics;
291        self
292    }
293}
294
295impl EndpointConfigBuilder<((),)> {
296    pub fn with_private_key(self, private_key: [u8; 32]) -> EndpointConfigBuilder<([u8; 32],)> {
297        EndpointConfigBuilder {
298            mandatory_fields: (private_key,),
299            optional_fields: self.optional_fields,
300        }
301    }
302}
303
304impl EndpointConfigBuilder {
305    pub fn build(self) -> Result<EndpointConfig> {
306        let (private_key,) = self.mandatory_fields;
307
308        let keypair = ed25519::KeypairBytes {
309            secret_key: private_key,
310            public_key: None,
311        };
312
313        let transport_config = Arc::new(self.optional_fields.transport_config.unwrap_or_default());
314
315        let reset_key = compute_reset_key(&keypair.secret_key);
316        let quinn_endpoint_config = quinn::EndpointConfig::new(reset_key);
317
318        let crypto_provider = Arc::new(CryptoProvider {
319            cipher_suites: DEFAULT_CIPHER_SUITES.to_vec(),
320            kx_groups: DEFAULT_KX_GROUPS.to_vec(),
321            signature_verification_algorithms: SUPPORTED_SIG_ALGS,
322            ..rustls::crypto::ring::default_provider()
323        });
324
325        let certified_key = generate_cert(&keypair, crypto_provider.key_provider)
326            .context("Failed to generate a certificate")?;
327
328        let cert_resolver = Arc::new(rustls::client::AlwaysResolvesClientRawPublicKeys::new(
329            certified_key.clone(),
330        ));
331        let cert_verifier = Arc::new(CertVerifier);
332
333        let quinn_server_config = make_server_config(
334            certified_key.clone(),
335            cert_verifier,
336            transport_config.clone(),
337            crypto_provider.clone(),
338            self.optional_fields.enable_0rtt,
339        )?;
340
341        let peer_id = peer_id_from_certificate(certified_key.end_entity_cert()?)?;
342
343        Ok(EndpointConfig {
344            peer_id,
345            cert_resolver,
346            quinn_server_config,
347            transport_config,
348            quinn_endpoint_config,
349            enable_early_data: self.optional_fields.enable_0rtt,
350            crypto_provider,
351            connection_metrics: self.optional_fields.connection_metrics,
352        })
353    }
354}
355
356fn make_server_config(
357    certified_key: Arc<CertifiedKey>,
358    cert_verifier: Arc<CertVerifier>,
359    transport_config: Arc<quinn::TransportConfig>,
360    crypto_provider: Arc<CryptoProvider>,
361    enable_0rtt: bool,
362) -> Result<quinn::ServerConfig> {
363    let server_cert_resolver =
364        rustls::server::AlwaysResolvesServerRawPublicKeys::new(certified_key);
365
366    let mut server_crypto = rustls::ServerConfig::builder_with_provider(crypto_provider.clone())
367        .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
368        .unwrap()
369        .with_client_cert_verifier(cert_verifier)
370        .with_cert_resolver(Arc::new(server_cert_resolver));
371
372    if enable_0rtt {
373        server_crypto.max_early_data_size = u32::MAX;
374
375        // TODO: Should we enable this?
376        // server_crypto.send_half_rtt_data = true;
377    }
378    let server_config = QuicServerConfig::try_from(server_crypto)?;
379
380    let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_config));
381    server.transport = transport_config;
382    Ok(server)
383}
384
385fn compute_reset_key(private_key: &[u8; 32]) -> Arc<ring::hmac::Key> {
386    const STATELESS_RESET_SALT: &[u8] = b"tycho-stateless-reset";
387
388    let salt = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, STATELESS_RESET_SALT);
389    let private_key = salt.extract(private_key);
390    let okm = private_key.expand(&[], ring::hmac::HMAC_SHA256).unwrap();
391
392    let mut reset_key = [0; 32];
393    okm.fill(&mut reset_key).unwrap();
394
395    Arc::new(ring::hmac::Key::new(ring::hmac::HMAC_SHA256, &reset_key))
396}
397
398static DEFAULT_CIPHER_SUITES: &[SupportedCipherSuite] = &[
399    // TLS1.3 suites
400    rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
401    rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
402    rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
403];
404
405static DEFAULT_KX_GROUPS: &[&dyn rustls::crypto::SupportedKxGroup] =
406    &[rustls::crypto::ring::kx_group::X25519];
407
408static DEFAULT_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];