web_transport_quinn/
client.rs

1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3
4use tokio::net::lookup_host;
5use url::{Host, Url};
6
7use crate::{ClientError, Provider, Session, ALPN};
8use quinn::{crypto::rustls::QuicClientConfig, rustls};
9use rustls::{client::danger::ServerCertVerifier, pki_types::CertificateDer};
10
/// Allows specifying a class of congestion control algorithm.
///
/// Mirrors the options exposed by the Web API, hiding which concrete algorithm is used
/// underneath.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CongestionControl {
    /// Keep the transport's default congestion controller.
    Default,
    /// Prefer higher throughput.
    Throughput,
    /// Prefer lower latency.
    LowLatency,
}
18
19/// Construct a WebTransport [Client] using sane defaults.
20///
21/// This is optional; advanced users may use [Client::new] directly.
22pub struct ClientBuilder {
23    provider: Arc<rustls::crypto::CryptoProvider>,
24    congestion_controller:
25        Option<Arc<dyn quinn::congestion::ControllerFactory + Send + Sync + 'static>>,
26}
27
28impl ClientBuilder {
29    /// Create a Client builder, which can be used to establish multiple [Session]s.
30    pub fn new() -> Self {
31        Self {
32            provider: Arc::new(Provider::default()),
33            congestion_controller: None,
34        }
35    }
36
37    /// For compatibility with WASM. Panics if `val` is false, but does nothing else.
38    pub fn with_unreliable(self, val: bool) -> Self {
39        if !val {
40            panic!("with_unreliable must be true for quic transport");
41        }
42
43        self
44    }
45
46    /// Enable the specified congestion controller.
47    pub fn with_congestion_control(mut self, algorithm: CongestionControl) -> Self {
48        self.congestion_controller = match algorithm {
49            CongestionControl::LowLatency => {
50                Some(Arc::new(quinn::congestion::BbrConfig::default()))
51            }
52            // TODO BBR is also higher throughput in theory.
53            CongestionControl::Throughput => {
54                Some(Arc::new(quinn::congestion::CubicConfig::default()))
55            }
56            CongestionControl::Default => None,
57        };
58
59        self
60    }
61
62    /// Accept any certificate from the server if it uses a known root CA.
63    pub fn with_system_roots(self) -> Result<Client, ClientError> {
64        let mut roots = rustls::RootCertStore::empty();
65
66        let native = rustls_native_certs::load_native_certs();
67
68        // Log any errors that occurred while loading the native root certificates.
69        for err in native.errors {
70            log::warn!("failed to load root cert: {:?}", err);
71        }
72
73        // Add the platform's native root certificates.
74        for cert in native.certs {
75            if let Err(err) = roots.add(cert) {
76                log::warn!("failed to add root cert: {:?}", err);
77            }
78        }
79
80        let crypto = self
81            .builder()
82            .with_root_certificates(roots)
83            .with_no_client_auth();
84
85        self.build(crypto)
86    }
87
88    /// Supply certificates for accepted servers instead of using root CAs.
89    pub fn with_server_certificates(
90        self,
91        certs: Vec<CertificateDer>,
92    ) -> Result<Client, ClientError> {
93        let hashes = certs
94            .iter()
95            .map(|cert| Provider::sha256(cert).as_ref().to_vec());
96
97        self.with_server_certificate_hashes(hashes.collect())
98    }
99
100    /// Supply sha256 hashes for accepted certificates instead of using root CAs.
101    pub fn with_server_certificate_hashes(
102        self,
103        hashes: Vec<Vec<u8>>,
104    ) -> Result<Client, ClientError> {
105        // Use a custom fingerprint verifier.
106        let fingerprints = Arc::new(ServerFingerprints {
107            provider: self.provider.clone(),
108            fingerprints: hashes,
109        });
110
111        // Configure the crypto client.
112        let crypto = self
113            .builder()
114            .dangerous()
115            .with_custom_certificate_verifier(fingerprints.clone())
116            .with_no_client_auth();
117
118        self.build(crypto)
119    }
120
121    /// Ignore the server's provided certificate, always accepting it.
122    ///
123    /// # Safety
124    /// This makes the connection vulnerable to man-in-the-middle attacks.
125    /// Only use it in secure environments, such as in local development or over a VPN connection.
126    pub unsafe fn with_no_certificate_verification(self) -> Result<Client, ClientError> {
127        let noop = NoCertificateVerification(self.provider.clone());
128
129        let crypto = self
130            .builder()
131            .dangerous()
132            .with_custom_certificate_verifier(Arc::new(noop))
133            .with_no_client_auth();
134
135        self.build(crypto)
136    }
137
138    fn builder(&self) -> rustls::ConfigBuilder<rustls::ClientConfig, rustls::WantsVerifier> {
139        rustls::ClientConfig::builder_with_provider(self.provider.clone())
140            .with_protocol_versions(&[&rustls::version::TLS13])
141            .unwrap()
142    }
143
144    fn build(self, mut crypto: rustls::ClientConfig) -> Result<Client, ClientError> {
145        crypto.alpn_protocols = vec![ALPN.as_bytes().to_vec()];
146
147        let client_config = QuicClientConfig::try_from(crypto).unwrap();
148        let mut client_config = quinn::ClientConfig::new(Arc::new(client_config));
149
150        let mut transport = quinn::TransportConfig::default();
151        if let Some(cc) = &self.congestion_controller {
152            transport.congestion_controller_factory(cc.clone());
153        }
154
155        client_config.transport_config(transport.into());
156
157        let client = quinn::Endpoint::client("[::]:0".parse().unwrap()).unwrap();
158        Ok(Client {
159            endpoint: client,
160            config: client_config,
161        })
162    }
163}
164
165impl Default for ClientBuilder {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171/// A client for connecting to a WebTransport server.
172pub struct Client {
173    endpoint: quinn::Endpoint,
174    config: quinn::ClientConfig,
175}
176
177impl Client {
178    /// Manually create a client via a Quinn endpoint and config.
179    ///
180    /// The ALPN MUST be set to [ALPN].
181    pub fn new(endpoint: quinn::Endpoint, config: quinn::ClientConfig) -> Self {
182        Self { endpoint, config }
183    }
184
185    /// Connect to the server.
186    pub async fn connect(&self, url: Url) -> Result<Session, ClientError> {
187        let port = url.port().unwrap_or(443);
188
189        // TODO error on username:password in host
190        let (host, remote) = match url
191            .host()
192            .ok_or_else(|| ClientError::InvalidDnsName("".to_string()))?
193        {
194            Host::Domain(domain) => {
195                let domain = domain.to_string();
196                // Look up the DNS entry.
197                let mut remotes = match lookup_host((domain.clone(), port)).await {
198                    Ok(remotes) => remotes,
199                    Err(_) => return Err(ClientError::InvalidDnsName(domain)),
200                };
201
202                // Return the first entry.
203                let remote = match remotes.next() {
204                    Some(remote) => remote,
205                    None => return Err(ClientError::InvalidDnsName(domain)),
206                };
207
208                (domain, remote)
209            }
210            Host::Ipv4(ipv4) => (ipv4.to_string(), SocketAddr::new(IpAddr::V4(ipv4), port)),
211            Host::Ipv6(ipv6) => (ipv6.to_string(), SocketAddr::new(IpAddr::V6(ipv6), port)),
212        };
213
214        // Connect to the server using the addr we just resolved.
215        let conn = self
216            .endpoint
217            .connect_with(self.config.clone(), remote, &host)?;
218        let conn = conn.await?;
219
220        // Connect with the connection we established.
221        Session::connect(conn, url).await
222    }
223}
224
225impl Default for Client {
226    fn default() -> Self {
227        ClientBuilder::new().with_system_roots().unwrap()
228    }
229}
230
231#[derive(Debug)]
232struct ServerFingerprints {
233    provider: Arc<rustls::crypto::CryptoProvider>,
234    fingerprints: Vec<Vec<u8>>,
235}
236
237impl ServerCertVerifier for ServerFingerprints {
238    fn verify_server_cert(
239        &self,
240        end_entity: &rustls::pki_types::CertificateDer<'_>,
241        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
242        _server_name: &rustls::pki_types::ServerName<'_>,
243        _ocsp_response: &[u8],
244        _now: rustls::pki_types::UnixTime,
245    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
246        let cert_hash = Provider::sha256(end_entity);
247        if self
248            .fingerprints
249            .iter()
250            .any(|fingerprint| fingerprint == cert_hash.as_ref())
251        {
252            return Ok(rustls::client::danger::ServerCertVerified::assertion());
253        }
254
255        Err(rustls::Error::InvalidCertificate(
256            rustls::CertificateError::UnknownIssuer,
257        ))
258    }
259
260    fn verify_tls12_signature(
261        &self,
262        message: &[u8],
263        cert: &rustls::pki_types::CertificateDer<'_>,
264        dss: &rustls::DigitallySignedStruct,
265    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266        rustls::crypto::verify_tls12_signature(
267            message,
268            cert,
269            dss,
270            &self.provider.signature_verification_algorithms,
271        )
272    }
273
274    fn verify_tls13_signature(
275        &self,
276        message: &[u8],
277        cert: &rustls::pki_types::CertificateDer<'_>,
278        dss: &rustls::DigitallySignedStruct,
279    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
280        rustls::crypto::verify_tls13_signature(
281            message,
282            cert,
283            dss,
284            &self.provider.signature_verification_algorithms,
285        )
286    }
287
288    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
289        self.provider
290            .signature_verification_algorithms
291            .supported_schemes()
292    }
293}
294
295#[derive(Debug)]
296pub struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
297
298impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
299    fn verify_server_cert(
300        &self,
301        _end_entity: &CertificateDer<'_>,
302        _intermediates: &[CertificateDer<'_>],
303        _server_name: &rustls::pki_types::ServerName<'_>,
304        _ocsp: &[u8],
305        _now: rustls::pki_types::UnixTime,
306    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
307        Ok(rustls::client::danger::ServerCertVerified::assertion())
308    }
309
310    fn verify_tls12_signature(
311        &self,
312        message: &[u8],
313        cert: &CertificateDer<'_>,
314        dss: &rustls::DigitallySignedStruct,
315    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
316        rustls::crypto::verify_tls12_signature(
317            message,
318            cert,
319            dss,
320            &self.0.signature_verification_algorithms,
321        )
322    }
323
324    fn verify_tls13_signature(
325        &self,
326        message: &[u8],
327        cert: &CertificateDer<'_>,
328        dss: &rustls::DigitallySignedStruct,
329    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
330        rustls::crypto::verify_tls13_signature(
331            message,
332            cert,
333            dss,
334            &self.0.signature_verification_algorithms,
335        )
336    }
337
338    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
339        self.0.signature_verification_algorithms.supported_schemes()
340    }
341}