// tycho_network/network/config.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
6use rustls::crypto::CryptoProvider;
7use rustls::sign::CertifiedKey;
8use rustls::SupportedCipherSuite;
9use serde::{Deserialize, Serialize};
10use tycho_util::serde_helpers;
11
12use crate::network::crypto::{
13    generate_cert, peer_id_from_certificate, CertVerifier, CertVerifierWithPeerId,
14    SUPPORTED_SIG_ALGS,
15};
16use crate::types::PeerId;
17
/// Configuration of the network layer: connection management, timeouts and limits.
///
/// Any field missing from a deserialized config is taken from
/// [`NetworkConfig::default`] (`#[serde(default)]`). `Duration` fields are
/// (de)serialized through `serde_helpers::humantime`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct NetworkConfig {
    /// Optional QUIC transport settings.
    ///
    /// Default: none.
    pub quic: Option<QuicConfig>,

    /// Capacity of the connection manager channel.
    ///
    /// Default: 128.
    pub connection_manager_channel_capacity: usize,

    /// Interval between connectivity checks.
    ///
    /// Default: 5 seconds.
    #[serde(with = "serde_helpers::humantime")]
    pub connectivity_check_interval: Duration,

    /// Maximum frame size.
    ///
    /// Default: 8 MiB.
    pub max_frame_size: bytesize::ByteSize,

    /// Timeout for establishing a connection.
    ///
    /// Default: 10 seconds.
    #[serde(with = "serde_helpers::humantime")]
    pub connect_timeout: Duration,

    /// Backoff after a failed connection attempt
    /// (presumably the starting value, capped by `max_connection_backoff` — TODO confirm).
    ///
    /// Default: 10 seconds.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_backoff: Duration,

    /// Upper bound for the connection backoff.
    ///
    /// Default: 1 minute.
    #[serde(with = "serde_helpers::humantime")]
    pub max_connection_backoff: Duration,

    /// Delay applied after some connection errors, on the optimistic guess that
    /// the peer will establish an incoming connection instead.
    ///
    /// Default: 3 seconds.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_error_delay: Duration,

    /// Limit on concurrently outstanding (pending) connections.
    ///
    /// Default: 100.
    pub max_concurrent_outstanding_connections: usize,

    /// Limit on concurrent connections; `None` means no limit.
    ///
    /// Default: unlimited.
    pub max_concurrent_connections: Option<usize>,

    /// Capacity of the active-peers event channel.
    ///
    /// Default: 128.
    pub active_peers_event_channel_capacity: usize,

    /// Idle timeout used on shutdown (exact semantics defined by the caller — TODO confirm).
    ///
    /// Default: 1 minute.
    #[serde(with = "serde_helpers::humantime")]
    pub shutdown_idle_timeout: Duration,

    /// Whether QUIC 0-RTT (early data) is enabled.
    ///
    /// Default: no.
    pub enable_0rtt: bool,
}
68
69impl Default for NetworkConfig {
70    fn default() -> Self {
71        Self {
72            quic: None,
73            connection_manager_channel_capacity: 128,
74            connectivity_check_interval: Duration::from_millis(5000),
75            max_frame_size: bytesize::ByteSize::mib(8),
76            connect_timeout: Duration::from_secs(10),
77            connection_backoff: Duration::from_secs(10),
78            max_connection_backoff: Duration::from_secs(60),
79            connection_error_delay: Duration::from_secs(3),
80            max_concurrent_outstanding_connections: 100,
81            max_concurrent_connections: None,
82            active_peers_event_channel_capacity: 128,
83            shutdown_idle_timeout: Duration::from_secs(60),
84            enable_0rtt: false,
85        }
86    }
87}
88
/// QUIC tuning knobs.
///
/// The stream, window and PMTU fields are turned into a transport config by
/// [`QuicConfig::make_transport_config`]. Missing fields are taken from
/// [`QuicConfig::default`] when deserializing (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QuicConfig {
    /// Maximum number of concurrent bidirectional streams.
    ///
    /// Default: 100.
    pub max_concurrent_bidi_streams: u64,
    /// Maximum number of concurrent unidirectional streams.
    ///
    /// Default: 100.
    pub max_concurrent_uni_streams: u64,
    /// Per-stream receive window; `None` keeps quinn's default.
    ///
    /// Default: auto.
    pub stream_receive_window: Option<u64>,
    /// Connection-wide receive window; `None` keeps quinn's default.
    ///
    /// Default: auto.
    pub receive_window: Option<u64>,
    /// Send window; `None` keeps quinn's default.
    ///
    /// Default: auto.
    pub send_window: Option<u64>,

    // TODO: add all other fields from quinn::TransportConfig
    /// Socket send buffer size. Not used by `make_transport_config`; presumably
    /// applied to the UDP socket elsewhere — TODO confirm.
    ///
    /// Default: auto.
    pub socket_send_buffer_size: Option<usize>,
    /// Socket receive buffer size. Not used by `make_transport_config`; presumably
    /// applied to the UDP socket elsewhere — TODO confirm.
    ///
    /// Default: auto.
    pub socket_recv_buffer_size: Option<usize>,
    /// Whether path MTU discovery is requested (see `make_transport_config`).
    ///
    /// Default: true.
    pub use_pmtu: bool,
}
111
112impl Default for QuicConfig {
113    fn default() -> Self {
114        Self {
115            max_concurrent_bidi_streams: 100,
116            max_concurrent_uni_streams: 100,
117            stream_receive_window: None,
118            receive_window: None,
119            send_window: None,
120            socket_send_buffer_size: None,
121            socket_recv_buffer_size: None,
122            use_pmtu: true,
123        }
124    }
125}
126
127impl QuicConfig {
128    pub fn make_transport_config(&self) -> quinn::TransportConfig {
129        fn make_varint(value: u64) -> quinn::VarInt {
130            quinn::VarInt::from_u64(value).unwrap_or(quinn::VarInt::MAX)
131        }
132
133        let mut config = quinn::TransportConfig::default();
134        config.max_concurrent_bidi_streams(make_varint(self.max_concurrent_bidi_streams));
135        config.max_concurrent_uni_streams(make_varint(self.max_concurrent_uni_streams));
136
137        if let Some(stream_receive_window) = self.stream_receive_window {
138            config.stream_receive_window(make_varint(stream_receive_window));
139        }
140        if let Some(receive_window) = self.receive_window {
141            config.receive_window(make_varint(receive_window));
142        }
143        if let Some(send_window) = self.send_window {
144            config.receive_window(make_varint(send_window));
145        }
146        if self.use_pmtu {
147            let mtu = quinn::MtuDiscoveryConfig::default();
148            config.mtu_discovery_config(Some(mtu));
149        }
150
151        config
152    }
153}
154
/// Assembled endpoint state, produced by [`EndpointConfigBuilder::build`].
pub(crate) struct EndpointConfig {
    /// Local peer id, extracted from the generated certificate.
    pub peer_id: PeerId,
    /// Resolver that presents our raw public key when we act as a TLS client.
    pub cert_resolver: Arc<rustls::client::AlwaysResolvesClientRawPublicKeys>,
    /// Server-side quinn config (TLS crypto plus transport config).
    pub quinn_server_config: quinn::ServerConfig,
    /// Transport config shared by the server config and every client config.
    pub transport_config: Arc<quinn::TransportConfig>,
    /// Endpoint-level quinn config, built from the stateless-reset key.
    pub quinn_endpoint_config: quinn::EndpointConfig,
    /// Whether client TLS configs enable early data (0-RTT).
    pub enable_early_data: bool,
    /// rustls crypto provider shared by the server and client TLS configs.
    pub crypto_provider: Arc<CryptoProvider>,
}
164
165impl EndpointConfig {
166    pub fn builder() -> EndpointConfigBuilder<((),)> {
167        EndpointConfigBuilder {
168            mandatory_fields: ((),),
169            optional_fields: Default::default(),
170        }
171    }
172
173    pub fn make_client_config_for_peer_id(&self, peer_id: &PeerId) -> Result<quinn::ClientConfig> {
174        let mut client_config =
175            rustls::ClientConfig::builder_with_provider(self.crypto_provider.clone())
176                .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
177                .unwrap()
178                .dangerous()
179                .with_custom_certificate_verifier(Arc::new(CertVerifierWithPeerId::new(peer_id)))
180                .with_client_cert_resolver(self.cert_resolver.clone());
181
182        client_config.enable_early_data = self.enable_early_data;
183        let quinn_config = QuicClientConfig::try_from(client_config)?;
184
185        let mut client = quinn::ClientConfig::new(Arc::new(quinn_config));
186        client.transport_config(self.transport_config.clone());
187        Ok(client)
188    }
189}
190
/// Builder for [`EndpointConfig`].
///
/// The `MandatoryFields` parameter records whether the private key was supplied.
/// It is `((),)` when created by [`EndpointConfig::builder`] and becomes
/// `([u8; 32],)` (the default) after `with_private_key`. `build` is only
/// implemented for the latter.
pub(crate) struct EndpointConfigBuilder<MandatoryFields = ([u8; 32],)> {
    /// Required values; the tuple's type encodes which ones are present.
    mandatory_fields: MandatoryFields,
    /// Settings with defaults, configurable at any stage of the builder.
    optional_fields: EndpointConfigBuilderFields,
}
195
/// Optional builder settings; `Default` means 0-RTT off and no custom transport config.
#[derive(Default)]
struct EndpointConfigBuilderFields {
    /// Enables early data (0-RTT) on both the server and client TLS configs.
    enable_0rtt: bool,
    /// Custom transport config; `None` falls back to `quinn::TransportConfig::default()`.
    transport_config: Option<quinn::TransportConfig>,
}
201
202impl<MandatoryFields> EndpointConfigBuilder<MandatoryFields> {
203    pub fn with_0rtt_enabled(mut self, enable_0rtt: bool) -> Self {
204        self.optional_fields.enable_0rtt = enable_0rtt;
205        self
206    }
207
208    pub fn with_transport_config(mut self, transport_config: quinn::TransportConfig) -> Self {
209        self.optional_fields.transport_config = Some(transport_config);
210        self
211    }
212}
213
214impl EndpointConfigBuilder<((),)> {
215    pub fn with_private_key(self, private_key: [u8; 32]) -> EndpointConfigBuilder<([u8; 32],)> {
216        EndpointConfigBuilder {
217            mandatory_fields: (private_key,),
218            optional_fields: self.optional_fields,
219        }
220    }
221}
222
223impl EndpointConfigBuilder {
224    pub fn build(self) -> Result<EndpointConfig> {
225        let (private_key,) = self.mandatory_fields;
226
227        let keypair = ed25519::KeypairBytes {
228            secret_key: private_key,
229            public_key: None,
230        };
231
232        let transport_config = Arc::new(self.optional_fields.transport_config.unwrap_or_default());
233
234        let reset_key = compute_reset_key(&keypair.secret_key);
235        let quinn_endpoint_config = quinn::EndpointConfig::new(reset_key);
236
237        let crypto_provider = Arc::new(CryptoProvider {
238            cipher_suites: DEFAULT_CIPHER_SUITES.to_vec(),
239            kx_groups: DEFAULT_KX_GROUPS.to_vec(),
240            signature_verification_algorithms: SUPPORTED_SIG_ALGS,
241            ..rustls::crypto::ring::default_provider()
242        });
243
244        let certified_key = generate_cert(&keypair, crypto_provider.key_provider)
245            .context("Failed to generate a certificate")?;
246
247        let cert_resolver = Arc::new(rustls::client::AlwaysResolvesClientRawPublicKeys::new(
248            certified_key.clone(),
249        ));
250        let cert_verifier = Arc::new(CertVerifier);
251
252        let quinn_server_config = make_server_config(
253            certified_key.clone(),
254            cert_verifier,
255            transport_config.clone(),
256            crypto_provider.clone(),
257            self.optional_fields.enable_0rtt,
258        )?;
259
260        let peer_id = peer_id_from_certificate(certified_key.end_entity_cert()?)?;
261
262        Ok(EndpointConfig {
263            peer_id,
264            cert_resolver,
265            quinn_server_config,
266            transport_config,
267            quinn_endpoint_config,
268            enable_early_data: self.optional_fields.enable_0rtt,
269            crypto_provider,
270        })
271    }
272}
273
274fn make_server_config(
275    certified_key: Arc<CertifiedKey>,
276    cert_verifier: Arc<CertVerifier>,
277    transport_config: Arc<quinn::TransportConfig>,
278    crypto_provider: Arc<CryptoProvider>,
279    enable_0rtt: bool,
280) -> Result<quinn::ServerConfig> {
281    let server_cert_resolver =
282        rustls::server::AlwaysResolvesServerRawPublicKeys::new(certified_key);
283
284    let mut server_crypto = rustls::ServerConfig::builder_with_provider(crypto_provider.clone())
285        .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
286        .unwrap()
287        .with_client_cert_verifier(cert_verifier)
288        .with_cert_resolver(Arc::new(server_cert_resolver));
289
290    if enable_0rtt {
291        server_crypto.max_early_data_size = u32::MAX;
292
293        // TODO: Should we enable this?
294        // server_crypto.send_half_rtt_data = true;
295    }
296    let server_config = QuicServerConfig::try_from(server_crypto)?;
297
298    let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_config));
299    server.transport = transport_config;
300    Ok(server)
301}
302
303fn compute_reset_key(private_key: &[u8; 32]) -> Arc<ring::hmac::Key> {
304    const STATELESS_RESET_SALT: &[u8] = b"tycho-stateless-reset";
305
306    let salt = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, STATELESS_RESET_SALT);
307    let private_key = salt.extract(private_key);
308    let okm = private_key.expand(&[], ring::hmac::HMAC_SHA256).unwrap();
309
310    let mut reset_key = [0; 32];
311    okm.fill(&mut reset_key).unwrap();
312
313    Arc::new(ring::hmac::Key::new(ring::hmac::HMAC_SHA256, &reset_key))
314}
315
/// Cipher suites offered by the endpoint: TLS 1.3 only, from the ring provider.
/// The order is kept as-is; rustls treats this list as a preference order.
static DEFAULT_CIPHER_SUITES: &[SupportedCipherSuite] = &[
    // TLS1.3 suites
    rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
    rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
    rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
];

/// Key-exchange groups offered by the endpoint: X25519 only.
static DEFAULT_KX_GROUPS: &[&dyn rustls::crypto::SupportedKxGroup] =
    &[rustls::crypto::ring::kx_group::X25519];

/// TLS protocol versions enabled for both client and server configs: TLS 1.3 only.
static DEFAULT_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];