1use std::sync::Arc;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
6use rustls::crypto::CryptoProvider;
7use rustls::sign::CertifiedKey;
8use rustls::SupportedCipherSuite;
9use serde::{Deserialize, Serialize};
10use tycho_util::serde_helpers;
11
12use crate::network::crypto::{
13 generate_cert, peer_id_from_certificate, CertVerifier, CertVerifierWithPeerId,
14 SUPPORTED_SIG_ALGS,
15};
16use crate::types::PeerId;
17
/// Network-layer configuration.
///
/// Thanks to the container-level `#[serde(default)]`, any field missing from the
/// serialized form is taken from `NetworkConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct NetworkConfig {
    /// Optional QUIC transport settings (see `QuicConfig`).
    /// NOTE(review): the meaning of `None` is decided by the consumer — confirm.
    pub quic: Option<QuicConfig>,

    /// Capacity of the connection manager's channel.
    pub connection_manager_channel_capacity: usize,

    /// Interval between connectivity checks.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub connectivity_check_interval: Duration,

    /// Maximum frame size.
    pub max_frame_size: bytesize::ByteSize,

    /// Timeout for establishing a connection.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub connect_timeout: Duration,

    /// Initial backoff between connection attempts.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_backoff: Duration,

    /// Upper bound for the connection backoff.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub max_connection_backoff: Duration,

    /// Delay applied after a connection error.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub connection_error_delay: Duration,

    /// Limit on concurrently outstanding (in-progress) connections.
    pub max_concurrent_outstanding_connections: usize,

    /// Optional limit on concurrent connections.
    /// NOTE(review): presumably `None` means "no limit" — confirm with the consumer.
    pub max_concurrent_connections: Option<usize>,

    /// Capacity of the active-peers event channel.
    pub active_peers_event_channel_capacity: usize,

    /// Idle timeout used during shutdown.
    /// (De)serialized via `serde_helpers::humantime`.
    #[serde(with = "serde_helpers::humantime")]
    pub shutdown_idle_timeout: Duration,

    /// Enables QUIC 0-RTT.
    /// NOTE(review): presumably forwarded to `EndpointConfigBuilder::with_0rtt_enabled` — confirm.
    pub enable_0rtt: bool,
}
68
69impl Default for NetworkConfig {
70 fn default() -> Self {
71 Self {
72 quic: None,
73 connection_manager_channel_capacity: 128,
74 connectivity_check_interval: Duration::from_millis(5000),
75 max_frame_size: bytesize::ByteSize::mib(8),
76 connect_timeout: Duration::from_secs(10),
77 connection_backoff: Duration::from_secs(10),
78 max_connection_backoff: Duration::from_secs(60),
79 connection_error_delay: Duration::from_secs(3),
80 max_concurrent_outstanding_connections: 100,
81 max_concurrent_connections: None,
82 active_peers_event_channel_capacity: 128,
83 shutdown_idle_timeout: Duration::from_secs(60),
84 enable_0rtt: false,
85 }
86 }
87}
88
/// QUIC transport tuning knobs.
///
/// Most fields are turned into a `quinn::TransportConfig` by
/// `QuicConfig::make_transport_config`. Fields missing from the serialized form
/// are taken from `QuicConfig::default()` (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QuicConfig {
    /// Maximum number of concurrent bidirectional streams.
    pub max_concurrent_bidi_streams: u64,
    /// Maximum number of concurrent unidirectional streams.
    pub max_concurrent_uni_streams: u64,
    /// Optional per-stream receive window; left unset when `None`.
    pub stream_receive_window: Option<u64>,
    /// Optional connection-wide receive window; left unset when `None`.
    pub receive_window: Option<u64>,
    /// Optional send window; only applied in `make_transport_config` when `Some`.
    pub send_window: Option<u64>,

    /// Optional socket send buffer size.
    /// Not read by `make_transport_config`; presumably applied to the UDP socket elsewhere — confirm.
    pub socket_send_buffer_size: Option<usize>,
    /// Optional socket receive buffer size.
    /// Not read by `make_transport_config`; presumably applied to the UDP socket elsewhere — confirm.
    pub socket_recv_buffer_size: Option<usize>,
    /// Controls path MTU discovery configuration in `make_transport_config`.
    pub use_pmtu: bool,
}
111
112impl Default for QuicConfig {
113 fn default() -> Self {
114 Self {
115 max_concurrent_bidi_streams: 100,
116 max_concurrent_uni_streams: 100,
117 stream_receive_window: None,
118 receive_window: None,
119 send_window: None,
120 socket_send_buffer_size: None,
121 socket_recv_buffer_size: None,
122 use_pmtu: true,
123 }
124 }
125}
126
127impl QuicConfig {
128 pub fn make_transport_config(&self) -> quinn::TransportConfig {
129 fn make_varint(value: u64) -> quinn::VarInt {
130 quinn::VarInt::from_u64(value).unwrap_or(quinn::VarInt::MAX)
131 }
132
133 let mut config = quinn::TransportConfig::default();
134 config.max_concurrent_bidi_streams(make_varint(self.max_concurrent_bidi_streams));
135 config.max_concurrent_uni_streams(make_varint(self.max_concurrent_uni_streams));
136
137 if let Some(stream_receive_window) = self.stream_receive_window {
138 config.stream_receive_window(make_varint(stream_receive_window));
139 }
140 if let Some(receive_window) = self.receive_window {
141 config.receive_window(make_varint(receive_window));
142 }
143 if let Some(send_window) = self.send_window {
144 config.receive_window(make_varint(send_window));
145 }
146 if self.use_pmtu {
147 let mtu = quinn::MtuDiscoveryConfig::default();
148 config.mtu_discovery_config(Some(mtu));
149 }
150
151 config
152 }
153}
154
/// Fully assembled endpoint configuration, produced by `EndpointConfigBuilder::build`.
pub(crate) struct EndpointConfig {
    /// Local peer id, extracted from the generated certificate in `build`.
    pub peer_id: PeerId,
    /// Client-side certificate resolver, shared by every client config from
    /// `make_client_config_for_peer_id`.
    pub cert_resolver: Arc<rustls::client::AlwaysResolvesClientRawPublicKeys>,
    /// Server-side quinn configuration (built by `make_server_config`).
    pub quinn_server_config: quinn::ServerConfig,
    /// Transport settings shared by the server config and all client configs.
    pub transport_config: Arc<quinn::TransportConfig>,
    /// Endpoint-level quinn settings, keyed with the stateless-reset key
    /// derived in `build`.
    pub quinn_endpoint_config: quinn::EndpointConfig,
    /// Copied into `rustls::ClientConfig::enable_early_data` for client configs.
    pub enable_early_data: bool,
    /// rustls crypto provider restricted to this module's default cipher suites,
    /// key exchange groups and signature algorithms (see `build`).
    pub crypto_provider: Arc<CryptoProvider>,
}
164
165impl EndpointConfig {
166 pub fn builder() -> EndpointConfigBuilder<((),)> {
167 EndpointConfigBuilder {
168 mandatory_fields: ((),),
169 optional_fields: Default::default(),
170 }
171 }
172
173 pub fn make_client_config_for_peer_id(&self, peer_id: &PeerId) -> Result<quinn::ClientConfig> {
174 let mut client_config =
175 rustls::ClientConfig::builder_with_provider(self.crypto_provider.clone())
176 .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
177 .unwrap()
178 .dangerous()
179 .with_custom_certificate_verifier(Arc::new(CertVerifierWithPeerId::new(peer_id)))
180 .with_client_cert_resolver(self.cert_resolver.clone());
181
182 client_config.enable_early_data = self.enable_early_data;
183 let quinn_config = QuicClientConfig::try_from(client_config)?;
184
185 let mut client = quinn::ClientConfig::new(Arc::new(quinn_config));
186 client.transport_config(self.transport_config.clone());
187 Ok(client)
188 }
189}
190
/// Typestate builder for [`EndpointConfig`].
///
/// `MandatoryFields` is `((),)` until `with_private_key` is called. After that it
/// becomes `([u8; 32],)`, which is the default type parameter and the only
/// instantiation that has `build`.
pub(crate) struct EndpointConfigBuilder<MandatoryFields = ([u8; 32],)> {
    mandatory_fields: MandatoryFields,
    optional_fields: EndpointConfigBuilderFields,
}
195
/// Optional builder settings.
///
/// The derived default disables 0-RTT and leaves the transport config unset.
/// `build` then falls back to `quinn::TransportConfig::default()`.
#[derive(Default)]
struct EndpointConfigBuilderFields {
    enable_0rtt: bool,
    transport_config: Option<quinn::TransportConfig>,
}
201
202impl<MandatoryFields> EndpointConfigBuilder<MandatoryFields> {
203 pub fn with_0rtt_enabled(mut self, enable_0rtt: bool) -> Self {
204 self.optional_fields.enable_0rtt = enable_0rtt;
205 self
206 }
207
208 pub fn with_transport_config(mut self, transport_config: quinn::TransportConfig) -> Self {
209 self.optional_fields.transport_config = Some(transport_config);
210 self
211 }
212}
213
214impl EndpointConfigBuilder<((),)> {
215 pub fn with_private_key(self, private_key: [u8; 32]) -> EndpointConfigBuilder<([u8; 32],)> {
216 EndpointConfigBuilder {
217 mandatory_fields: (private_key,),
218 optional_fields: self.optional_fields,
219 }
220 }
221}
222
223impl EndpointConfigBuilder {
224 pub fn build(self) -> Result<EndpointConfig> {
225 let (private_key,) = self.mandatory_fields;
226
227 let keypair = ed25519::KeypairBytes {
228 secret_key: private_key,
229 public_key: None,
230 };
231
232 let transport_config = Arc::new(self.optional_fields.transport_config.unwrap_or_default());
233
234 let reset_key = compute_reset_key(&keypair.secret_key);
235 let quinn_endpoint_config = quinn::EndpointConfig::new(reset_key);
236
237 let crypto_provider = Arc::new(CryptoProvider {
238 cipher_suites: DEFAULT_CIPHER_SUITES.to_vec(),
239 kx_groups: DEFAULT_KX_GROUPS.to_vec(),
240 signature_verification_algorithms: SUPPORTED_SIG_ALGS,
241 ..rustls::crypto::ring::default_provider()
242 });
243
244 let certified_key = generate_cert(&keypair, crypto_provider.key_provider)
245 .context("Failed to generate a certificate")?;
246
247 let cert_resolver = Arc::new(rustls::client::AlwaysResolvesClientRawPublicKeys::new(
248 certified_key.clone(),
249 ));
250 let cert_verifier = Arc::new(CertVerifier);
251
252 let quinn_server_config = make_server_config(
253 certified_key.clone(),
254 cert_verifier,
255 transport_config.clone(),
256 crypto_provider.clone(),
257 self.optional_fields.enable_0rtt,
258 )?;
259
260 let peer_id = peer_id_from_certificate(certified_key.end_entity_cert()?)?;
261
262 Ok(EndpointConfig {
263 peer_id,
264 cert_resolver,
265 quinn_server_config,
266 transport_config,
267 quinn_endpoint_config,
268 enable_early_data: self.optional_fields.enable_0rtt,
269 crypto_provider,
270 })
271 }
272}
273
274fn make_server_config(
275 certified_key: Arc<CertifiedKey>,
276 cert_verifier: Arc<CertVerifier>,
277 transport_config: Arc<quinn::TransportConfig>,
278 crypto_provider: Arc<CryptoProvider>,
279 enable_0rtt: bool,
280) -> Result<quinn::ServerConfig> {
281 let server_cert_resolver =
282 rustls::server::AlwaysResolvesServerRawPublicKeys::new(certified_key);
283
284 let mut server_crypto = rustls::ServerConfig::builder_with_provider(crypto_provider.clone())
285 .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS)
286 .unwrap()
287 .with_client_cert_verifier(cert_verifier)
288 .with_cert_resolver(Arc::new(server_cert_resolver));
289
290 if enable_0rtt {
291 server_crypto.max_early_data_size = u32::MAX;
292
293 }
296 let server_config = QuicServerConfig::try_from(server_crypto)?;
297
298 let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_config));
299 server.transport = transport_config;
300 Ok(server)
301}
302
303fn compute_reset_key(private_key: &[u8; 32]) -> Arc<ring::hmac::Key> {
304 const STATELESS_RESET_SALT: &[u8] = b"tycho-stateless-reset";
305
306 let salt = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, STATELESS_RESET_SALT);
307 let private_key = salt.extract(private_key);
308 let okm = private_key.expand(&[], ring::hmac::HMAC_SHA256).unwrap();
309
310 let mut reset_key = [0; 32];
311 okm.fill(&mut reset_key).unwrap();
312
313 Arc::new(ring::hmac::Key::new(ring::hmac::HMAC_SHA256, &reset_key))
314}
315
/// TLS 1.3 cipher suites (ring implementations) installed into the endpoint's
/// `CryptoProvider` by `EndpointConfigBuilder::build`.
///
/// NOTE(review): rustls treats `CryptoProvider::cipher_suites` as preference
/// order; keep AES-256-GCM first unless that preference is meant to change.
static DEFAULT_CIPHER_SUITES: &[SupportedCipherSuite] = &[
    rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
    rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
    rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
];

/// Key exchange groups installed into the endpoint's `CryptoProvider`: X25519 only.
static DEFAULT_KX_GROUPS: &[&dyn rustls::crypto::SupportedKxGroup] =
    &[rustls::crypto::ring::kx_group::X25519];

/// Protocol versions passed to both the client and server rustls builders:
/// TLS 1.3 only (QUIC is defined over TLS 1.3).
static DEFAULT_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];