1use std::sync::Arc;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use quinn::congestion::{self, ControllerFactory};
6use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
7use rustls::SupportedCipherSuite;
8use rustls::crypto::CryptoProvider;
9use rustls::sign::CertifiedKey;
10use serde::{Deserialize, Serialize};
11use tycho_util::serde_helpers;
12
13use crate::network::crypto::{
14 CertVerifier, CertVerifierWithPeerId, SUPPORTED_SIG_ALGS, generate_cert,
15 peer_id_from_certificate,
16};
17use crate::types::PeerId;
18
/// Network-level settings.
///
/// Fields missing during deserialization are taken from [`NetworkConfig::default`]
/// (`#[serde(default)]`). `Duration` fields are (de)serialized with
/// `serde_helpers::humantime`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct NetworkConfig {
    /// QUIC transport tuning.
    ///
    /// Default: `None`.
    // NOTE(review): how `None` is interpreted (e.g. falling back to `QuicConfig::default()`)
    // is decided by the consumer of this config, which is not visible here — confirm.
    pub quic: Option<QuicConfig>,

    /// Capacity of the connection manager channel.
    ///
    /// Default: 128.
    pub connection_manager_channel_capacity: usize,

    /// Interval between connectivity checks.
    ///
    /// Default: 5 s.
    #[serde(with = "serde_helpers::humantime")]
    pub connectivity_check_interval: Duration,

    /// Maximum frame size.
    ///
    /// Default: 8 MiB.
    pub max_frame_size: bytesize::ByteSize,

    /// Timeout for establishing a connection.
    ///
    /// Default: 10 s.
    #[serde(with = "serde_helpers::humantime")]
    pub connect_timeout: Duration,

    /// Backoff between connection attempts (presumably the initial value, bounded by
    /// `max_connection_backoff` — confirm with the connection manager).
    ///
    /// Default: 10 s.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_backoff: Duration,

    /// Upper bound for the connection backoff.
    ///
    /// Default: 60 s.
    #[serde(with = "serde_helpers::humantime")]
    pub max_connection_backoff: Duration,

    /// Delay applied after a connection error.
    ///
    /// Default: 3 s.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_error_delay: Duration,

    /// Maximum number of concurrent outstanding connections.
    ///
    /// Default: 100.
    pub max_concurrent_outstanding_connections: usize,

    /// Maximum number of concurrent connections.
    ///
    /// Default: `None` (presumably "no limit" — confirm with the consumer).
    pub max_concurrent_connections: Option<usize>,

    /// Capacity of the active peers event channel.
    ///
    /// Default: 128.
    pub active_peers_event_channel_capacity: usize,

    /// Maximum number of concurrent requests per peer.
    ///
    /// Default: 128.
    pub max_concurrent_requests_per_peer: usize,

    /// Idle timeout used during shutdown.
    ///
    /// Default: 60 s.
    #[serde(with = "serde_helpers::humantime")]
    pub shutdown_idle_timeout: Duration,

    /// Whether QUIC 0-RTT (TLS early data) is enabled.
    ///
    /// Default: `false`.
    pub enable_0rtt: bool,

    /// Connection metrics level; only `Detailed` reports
    /// [`ConnectionMetricsLevel::should_export_peer_id`] as `true`.
    ///
    /// Default: `None`.
    pub connection_metrics: Option<ConnectionMetricsLevel>,
}
78
79impl Default for NetworkConfig {
80 fn default() -> Self {
81 Self {
82 quic: None,
83 connection_manager_channel_capacity: 128,
84 connectivity_check_interval: Duration::from_millis(5000),
85 max_frame_size: bytesize::ByteSize::mib(8),
86 connect_timeout: Duration::from_secs(10),
87 connection_backoff: Duration::from_secs(10),
88 max_connection_backoff: Duration::from_secs(60),
89 connection_error_delay: Duration::from_secs(3),
90 max_concurrent_outstanding_connections: 100,
91 max_concurrent_connections: None,
92 active_peers_event_channel_capacity: 128,
93 max_concurrent_requests_per_peer: 128,
94 shutdown_idle_timeout: Duration::from_secs(60),
95 enable_0rtt: false,
96 connection_metrics: None,
97 }
98 }
99}
100
101#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
102pub enum ConnectionMetricsLevel {
103 Brief,
104 Detailed,
105}
106
107impl ConnectionMetricsLevel {
108 pub fn should_export_peer_id(self) -> bool {
109 matches!(self, Self::Detailed)
110 }
111}
112
113#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
114pub enum CongestionAlgorithm {
115 Cubic,
116 Bbr,
117 NewReno,
118}
119
120impl CongestionAlgorithm {
121 pub fn build(self) -> Arc<dyn ControllerFactory + Send + Sync + 'static> {
122 match self {
123 CongestionAlgorithm::Cubic => Arc::new(congestion::CubicConfig::default()),
124 CongestionAlgorithm::Bbr => Arc::new(congestion::BbrConfig::default()),
125 CongestionAlgorithm::NewReno => Arc::new(congestion::NewRenoConfig::default()),
126 }
127 }
128}
129
/// QUIC transport settings, turned into a [`quinn::TransportConfig`] by
/// [`QuicConfig::make_transport_config`].
///
/// Fields missing during deserialization are taken from [`QuicConfig::default`].
/// `Option` fields left as `None` are not applied, so quinn's own defaults stay in effect.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QuicConfig {
    /// Passed to `TransportConfig::max_concurrent_bidi_streams` (saturated to
    /// `VarInt::MAX`). Default: 100.
    pub max_concurrent_bidi_streams: u64,
    /// Passed to `TransportConfig::max_concurrent_uni_streams` (saturated to
    /// `VarInt::MAX`). Default: 100.
    pub max_concurrent_uni_streams: u64,
    /// Per-stream receive window, if overridden. Default: `None`.
    pub stream_receive_window: Option<u64>,
    /// Connection-wide receive window, if overridden. Default: `None`.
    pub receive_window: Option<u64>,
    /// Connection-wide send window, if overridden. Default: `None`.
    pub send_window: Option<u64>,

    /// Passed to `TransportConfig::send_fairness`. Default: `true`.
    pub send_fairness: bool,

    /// Passed to `TransportConfig::enable_segmentation_offload`. Default: `true`.
    pub enable_segmentation_offload: bool,

    /// Socket send buffer size. Default: `None`.
    // NOTE(review): not used by `make_transport_config`; presumably applied to the UDP
    // socket elsewhere — confirm.
    pub socket_send_buffer_size: Option<usize>,
    /// Socket receive buffer size. Default: `None`.
    // NOTE(review): not used by `make_transport_config`; presumably applied to the UDP
    // socket elsewhere — confirm.
    pub socket_recv_buffer_size: Option<usize>,
    /// When `true`, `make_transport_config` installs a default `quinn::MtuDiscoveryConfig`.
    /// Default: `true`.
    pub use_pmtu: bool,

    /// Passed to `TransportConfig::initial_mtu` when set (quinn documents 1200 as the
    /// minimum). Default: `None`.
    pub initial_mtu: Option<u16>,

    /// Congestion controller for all connections. Default: [`CongestionAlgorithm::Bbr`].
    pub congestion_algorithm: CongestionAlgorithm,
}
169
170impl Default for QuicConfig {
171 fn default() -> Self {
172 Self {
173 max_concurrent_bidi_streams: 100,
174 max_concurrent_uni_streams: 100,
175 stream_receive_window: None,
176 receive_window: None,
177 send_window: None,
178 send_fairness: true,
179 enable_segmentation_offload: true,
180 socket_send_buffer_size: None,
181 socket_recv_buffer_size: None,
182 use_pmtu: true,
183 initial_mtu: None,
184 congestion_algorithm: CongestionAlgorithm::Bbr,
185 }
186 }
187}
188
189impl QuicConfig {
190 pub fn make_transport_config(&self) -> quinn::TransportConfig {
191 fn make_varint(value: u64) -> quinn::VarInt {
192 quinn::VarInt::from_u64(value).unwrap_or(quinn::VarInt::MAX)
193 }
194
195 let mut config = quinn::TransportConfig::default();
196 config.max_concurrent_bidi_streams(make_varint(self.max_concurrent_bidi_streams));
197 config.max_concurrent_uni_streams(make_varint(self.max_concurrent_uni_streams));
198
199 config.datagram_receive_buffer_size(None);
200
201 config.enable_segmentation_offload(self.enable_segmentation_offload);
202 config.send_fairness(self.send_fairness);
203
204 if let Some(stream_receive_window) = self.stream_receive_window {
205 config.stream_receive_window(make_varint(stream_receive_window));
206 }
207 if let Some(receive_window) = self.receive_window {
208 config.receive_window(make_varint(receive_window));
209 }
210 if let Some(send_window) = self.send_window {
211 config.send_window(send_window);
212 }
213 if self.use_pmtu {
214 let mtu = quinn::MtuDiscoveryConfig::default();
215 config.mtu_discovery_config(Some(mtu));
216 }
217
218 if let Some(mtu) = self.initial_mtu {
219 config.initial_mtu(mtu);
220 }
221
222 config.congestion_controller_factory(self.congestion_algorithm.build());
223
224 config
225 }
226}
227
/// Fully prepared QUIC endpoint configuration, produced by [`EndpointConfig::builder`].
pub(crate) struct EndpointConfig {
    /// Peer id extracted from this node's generated certificate.
    pub peer_id: PeerId,
    /// Client-side certificate resolver that always presents this node's raw public key.
    pub cert_resolver: Arc<rustls::client::AlwaysResolvesClientRawPublicKeys>,
    /// Server configuration used to accept incoming connections.
    pub quinn_server_config: quinn::ServerConfig,
    /// Transport settings shared by the server config and every client config.
    pub transport_config: Arc<quinn::TransportConfig>,
    /// Endpoint settings, created with a stateless reset key derived from the private key.
    pub quinn_endpoint_config: quinn::EndpointConfig,
    /// Whether client configs allow TLS early data (0-RTT).
    pub enable_early_data: bool,
    /// ring-based provider restricted to the default cipher suites, key-exchange groups and
    /// supported signature algorithms.
    pub crypto_provider: Arc<CryptoProvider>,
    /// Connection metrics level, if enabled.
    pub connection_metrics: Option<ConnectionMetricsLevel>,
}
238
239impl EndpointConfig {
240 pub fn builder() -> EndpointConfigBuilder<((),)> {
241 EndpointConfigBuilder {
242 mandatory_fields: ((),),
243 optional_fields: Default::default(),
244 }
245 }
246
247 pub fn make_client_config_for_peer_id(&self, peer_id: &PeerId) -> quinn::ClientConfig {
248 let mut client_config =
249 rustls::ClientConfig::builder_with_provider(self.crypto_provider.clone())
250 .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
251 .unwrap()
252 .dangerous()
253 .with_custom_certificate_verifier(Arc::new(CertVerifierWithPeerId::new(peer_id)))
254 .with_client_cert_resolver(self.cert_resolver.clone());
255
256 client_config.enable_early_data = self.enable_early_data;
257 let quinn_config =
258 QuicClientConfig::try_from(client_config).expect("cipher suite is always provided");
259
260 let mut client = quinn::ClientConfig::new(Arc::new(quinn_config));
261 client.transport_config(self.transport_config.clone());
262 client
263 }
264}
265
/// Type-state builder for [`EndpointConfig`].
///
/// `MandatoryFields` is `((),)` until a private key is supplied with
/// [`with_private_key`](EndpointConfigBuilder::with_private_key). It then becomes
/// `([u8; 32],)`, which is the default parameter, and `build` becomes callable.
pub(crate) struct EndpointConfigBuilder<MandatoryFields = ([u8; 32],)> {
    /// Required settings, tracked in the type parameter.
    mandatory_fields: MandatoryFields,
    /// Settings that have defaults.
    optional_fields: EndpointConfigBuilderFields,
}
270
/// Optional [`EndpointConfigBuilder`] settings.
///
/// The `Default` value gives: 0-RTT disabled, no custom transport config, no connection
/// metrics.
#[derive(Default)]
struct EndpointConfigBuilderFields {
    /// Enables 0-RTT on the server side (`max_early_data_size`) and on the client side
    /// (`enable_early_data`).
    enable_0rtt: bool,
    /// Custom transport config; `build` falls back to `quinn::TransportConfig::default()`.
    transport_config: Option<quinn::TransportConfig>,
    /// Copied unchanged into `EndpointConfig::connection_metrics`.
    connection_metrics: Option<ConnectionMetricsLevel>,
}
277
278impl<MandatoryFields> EndpointConfigBuilder<MandatoryFields> {
279 pub fn with_0rtt_enabled(mut self, enable_0rtt: bool) -> Self {
280 self.optional_fields.enable_0rtt = enable_0rtt;
281 self
282 }
283
284 pub fn with_transport_config(mut self, transport_config: quinn::TransportConfig) -> Self {
285 self.optional_fields.transport_config = Some(transport_config);
286 self
287 }
288
289 pub fn with_connection_metrics(mut self, metrics: Option<ConnectionMetricsLevel>) -> Self {
290 self.optional_fields.connection_metrics = metrics;
291 self
292 }
293}
294
295impl EndpointConfigBuilder<((),)> {
296 pub fn with_private_key(self, private_key: [u8; 32]) -> EndpointConfigBuilder<([u8; 32],)> {
297 EndpointConfigBuilder {
298 mandatory_fields: (private_key,),
299 optional_fields: self.optional_fields,
300 }
301 }
302}
303
304impl EndpointConfigBuilder {
305 pub fn build(self) -> Result<EndpointConfig> {
306 let (private_key,) = self.mandatory_fields;
307
308 let keypair = ed25519::KeypairBytes {
309 secret_key: private_key,
310 public_key: None,
311 };
312
313 let transport_config = Arc::new(self.optional_fields.transport_config.unwrap_or_default());
314
315 let reset_key = compute_reset_key(&keypair.secret_key);
316 let quinn_endpoint_config = quinn::EndpointConfig::new(reset_key);
317
318 let crypto_provider = Arc::new(CryptoProvider {
319 cipher_suites: DEFAULT_CIPHER_SUITES.to_vec(),
320 kx_groups: DEFAULT_KX_GROUPS.to_vec(),
321 signature_verification_algorithms: SUPPORTED_SIG_ALGS,
322 ..rustls::crypto::ring::default_provider()
323 });
324
325 let certified_key = generate_cert(&keypair, crypto_provider.key_provider)
326 .context("Failed to generate a certificate")?;
327
328 let cert_resolver = Arc::new(rustls::client::AlwaysResolvesClientRawPublicKeys::new(
329 certified_key.clone(),
330 ));
331 let cert_verifier = Arc::new(CertVerifier);
332
333 let quinn_server_config = make_server_config(
334 certified_key.clone(),
335 cert_verifier,
336 transport_config.clone(),
337 crypto_provider.clone(),
338 self.optional_fields.enable_0rtt,
339 )?;
340
341 let peer_id = peer_id_from_certificate(certified_key.end_entity_cert()?)?;
342
343 Ok(EndpointConfig {
344 peer_id,
345 cert_resolver,
346 quinn_server_config,
347 transport_config,
348 quinn_endpoint_config,
349 enable_early_data: self.optional_fields.enable_0rtt,
350 crypto_provider,
351 connection_metrics: self.optional_fields.connection_metrics,
352 })
353 }
354}
355
356fn make_server_config(
357 certified_key: Arc<CertifiedKey>,
358 cert_verifier: Arc<CertVerifier>,
359 transport_config: Arc<quinn::TransportConfig>,
360 crypto_provider: Arc<CryptoProvider>,
361 enable_0rtt: bool,
362) -> Result<quinn::ServerConfig> {
363 let server_cert_resolver =
364 rustls::server::AlwaysResolvesServerRawPublicKeys::new(certified_key);
365
366 let mut server_crypto = rustls::ServerConfig::builder_with_provider(crypto_provider.clone())
367 .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
368 .unwrap()
369 .with_client_cert_verifier(cert_verifier)
370 .with_cert_resolver(Arc::new(server_cert_resolver));
371
372 if enable_0rtt {
373 server_crypto.max_early_data_size = u32::MAX;
374
375 }
378 let server_config = QuicServerConfig::try_from(server_crypto)?;
379
380 let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_config));
381 server.transport = transport_config;
382 Ok(server)
383}
384
385fn compute_reset_key(private_key: &[u8; 32]) -> Arc<ring::hmac::Key> {
386 const STATELESS_RESET_SALT: &[u8] = b"tycho-stateless-reset";
387
388 let salt = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, STATELESS_RESET_SALT);
389 let private_key = salt.extract(private_key);
390 let okm = private_key.expand(&[], ring::hmac::HMAC_SHA256).unwrap();
391
392 let mut reset_key = [0; 32];
393 okm.fill(&mut reset_key).unwrap();
394
395 Arc::new(ring::hmac::Key::new(ring::hmac::HMAC_SHA256, &reset_key))
396}
397
/// TLS 1.3 cipher suites (ring implementations) installed into the crypto provider.
///
/// rustls treats `CryptoProvider::cipher_suites` as a preference list, so the order matters.
/// `TLS13_AES_128_GCM_SHA256` must stay in the list. QUIC Initial packets are protected
/// with AES-128-GCM (RFC 9001 §5.2), and quinn's `QuicClientConfig`/`QuicServerConfig`
/// conversions fail without it.
static DEFAULT_CIPHER_SUITES: &[SupportedCipherSuite] = &[
    rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
    rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
    rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
];

/// Key-exchange groups installed into the crypto provider: X25519 only.
static DEFAULT_KX_GROUPS: &[&dyn rustls::crypto::SupportedKxGroup] =
    &[rustls::crypto::ring::kx_group::X25519];

/// Only TLS 1.3 is enabled, as QUIC requires.
static DEFAULT_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];