1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3
4use tokio::net::lookup_host;
5use url::{Host, Url};
6
7use crate::{ClientError, Provider, Session, ALPN};
8use quinn::{crypto::rustls::QuicClientConfig, rustls};
9use rustls::{client::danger::ServerCertVerifier, pki_types::CertificateDer};
10
/// Congestion-control algorithm selectable via [`ClientBuilder::with_congestion_control`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CongestionControl {
    /// Install no override, so quinn's transport default congestion controller is used.
    Default,
    /// Favor throughput; the builder installs quinn's CUBIC controller.
    Throughput,
    /// Favor latency; the builder installs quinn's BBR controller.
    LowLatency,
}
18
19pub struct ClientBuilder {
23 provider: Arc<rustls::crypto::CryptoProvider>,
24 congestion_controller:
25 Option<Arc<dyn quinn::congestion::ControllerFactory + Send + Sync + 'static>>,
26}
27
28impl ClientBuilder {
29 pub fn new() -> Self {
31 Self {
32 provider: Arc::new(Provider::default()),
33 congestion_controller: None,
34 }
35 }
36
37 pub fn with_unreliable(self, val: bool) -> Self {
39 if !val {
40 panic!("with_unreliable must be true for quic transport");
41 }
42
43 self
44 }
45
46 pub fn with_congestion_control(mut self, algorithm: CongestionControl) -> Self {
48 self.congestion_controller = match algorithm {
49 CongestionControl::LowLatency => {
50 Some(Arc::new(quinn::congestion::BbrConfig::default()))
51 }
52 CongestionControl::Throughput => {
54 Some(Arc::new(quinn::congestion::CubicConfig::default()))
55 }
56 CongestionControl::Default => None,
57 };
58
59 self
60 }
61
62 pub fn with_system_roots(self) -> Result<Client, ClientError> {
64 let mut roots = rustls::RootCertStore::empty();
65
66 let native = rustls_native_certs::load_native_certs();
67
68 for err in native.errors {
70 log::warn!("failed to load root cert: {:?}", err);
71 }
72
73 for cert in native.certs {
75 if let Err(err) = roots.add(cert) {
76 log::warn!("failed to add root cert: {:?}", err);
77 }
78 }
79
80 let crypto = self
81 .builder()
82 .with_root_certificates(roots)
83 .with_no_client_auth();
84
85 self.build(crypto)
86 }
87
88 pub fn with_server_certificates(
90 self,
91 certs: Vec<CertificateDer>,
92 ) -> Result<Client, ClientError> {
93 let hashes = certs
94 .iter()
95 .map(|cert| Provider::sha256(cert).as_ref().to_vec());
96
97 self.with_server_certificate_hashes(hashes.collect())
98 }
99
100 pub fn with_server_certificate_hashes(
102 self,
103 hashes: Vec<Vec<u8>>,
104 ) -> Result<Client, ClientError> {
105 let fingerprints = Arc::new(ServerFingerprints {
107 provider: self.provider.clone(),
108 fingerprints: hashes,
109 });
110
111 let crypto = self
113 .builder()
114 .dangerous()
115 .with_custom_certificate_verifier(fingerprints.clone())
116 .with_no_client_auth();
117
118 self.build(crypto)
119 }
120
121 pub unsafe fn with_no_certificate_verification(self) -> Result<Client, ClientError> {
127 let noop = NoCertificateVerification(self.provider.clone());
128
129 let crypto = self
130 .builder()
131 .dangerous()
132 .with_custom_certificate_verifier(Arc::new(noop))
133 .with_no_client_auth();
134
135 self.build(crypto)
136 }
137
138 fn builder(&self) -> rustls::ConfigBuilder<rustls::ClientConfig, rustls::WantsVerifier> {
139 rustls::ClientConfig::builder_with_provider(self.provider.clone())
140 .with_protocol_versions(&[&rustls::version::TLS13])
141 .unwrap()
142 }
143
144 fn build(self, mut crypto: rustls::ClientConfig) -> Result<Client, ClientError> {
145 crypto.alpn_protocols = vec![ALPN.as_bytes().to_vec()];
146
147 let client_config = QuicClientConfig::try_from(crypto).unwrap();
148 let mut client_config = quinn::ClientConfig::new(Arc::new(client_config));
149
150 let mut transport = quinn::TransportConfig::default();
151 if let Some(cc) = &self.congestion_controller {
152 transport.congestion_controller_factory(cc.clone());
153 }
154
155 client_config.transport_config(transport.into());
156
157 let client = quinn::Endpoint::client("[::]:0".parse().unwrap()).unwrap();
158 Ok(Client {
159 endpoint: client,
160 config: client_config,
161 })
162 }
163}
164
165impl Default for ClientBuilder {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171pub struct Client {
173 endpoint: quinn::Endpoint,
174 config: quinn::ClientConfig,
175}
176
177impl Client {
178 pub fn new(endpoint: quinn::Endpoint, config: quinn::ClientConfig) -> Self {
182 Self { endpoint, config }
183 }
184
185 pub async fn connect(&self, url: Url) -> Result<Session, ClientError> {
187 let port = url.port().unwrap_or(443);
188
189 let (host, remote) = match url
191 .host()
192 .ok_or_else(|| ClientError::InvalidDnsName("".to_string()))?
193 {
194 Host::Domain(domain) => {
195 let domain = domain.to_string();
196 let mut remotes = match lookup_host((domain.clone(), port)).await {
198 Ok(remotes) => remotes,
199 Err(_) => return Err(ClientError::InvalidDnsName(domain)),
200 };
201
202 let remote = match remotes.next() {
204 Some(remote) => remote,
205 None => return Err(ClientError::InvalidDnsName(domain)),
206 };
207
208 (domain, remote)
209 }
210 Host::Ipv4(ipv4) => (ipv4.to_string(), SocketAddr::new(IpAddr::V4(ipv4), port)),
211 Host::Ipv6(ipv6) => (ipv6.to_string(), SocketAddr::new(IpAddr::V6(ipv6), port)),
212 };
213
214 let conn = self
216 .endpoint
217 .connect_with(self.config.clone(), remote, &host)?;
218 let conn = conn.await?;
219
220 Session::connect(conn, url).await
222 }
223}
224
225impl Default for Client {
226 fn default() -> Self {
227 ClientBuilder::new().with_system_roots().unwrap()
228 }
229}
230
231#[derive(Debug)]
232struct ServerFingerprints {
233 provider: Arc<rustls::crypto::CryptoProvider>,
234 fingerprints: Vec<Vec<u8>>,
235}
236
237impl ServerCertVerifier for ServerFingerprints {
238 fn verify_server_cert(
239 &self,
240 end_entity: &rustls::pki_types::CertificateDer<'_>,
241 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
242 _server_name: &rustls::pki_types::ServerName<'_>,
243 _ocsp_response: &[u8],
244 _now: rustls::pki_types::UnixTime,
245 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
246 let cert_hash = Provider::sha256(end_entity);
247 if self
248 .fingerprints
249 .iter()
250 .any(|fingerprint| fingerprint == cert_hash.as_ref())
251 {
252 return Ok(rustls::client::danger::ServerCertVerified::assertion());
253 }
254
255 Err(rustls::Error::InvalidCertificate(
256 rustls::CertificateError::UnknownIssuer,
257 ))
258 }
259
260 fn verify_tls12_signature(
261 &self,
262 message: &[u8],
263 cert: &rustls::pki_types::CertificateDer<'_>,
264 dss: &rustls::DigitallySignedStruct,
265 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266 rustls::crypto::verify_tls12_signature(
267 message,
268 cert,
269 dss,
270 &self.provider.signature_verification_algorithms,
271 )
272 }
273
274 fn verify_tls13_signature(
275 &self,
276 message: &[u8],
277 cert: &rustls::pki_types::CertificateDer<'_>,
278 dss: &rustls::DigitallySignedStruct,
279 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
280 rustls::crypto::verify_tls13_signature(
281 message,
282 cert,
283 dss,
284 &self.provider.signature_verification_algorithms,
285 )
286 }
287
288 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
289 self.provider
290 .signature_verification_algorithms
291 .supported_schemes()
292 }
293}
294
295#[derive(Debug)]
296pub struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
297
298impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
299 fn verify_server_cert(
300 &self,
301 _end_entity: &CertificateDer<'_>,
302 _intermediates: &[CertificateDer<'_>],
303 _server_name: &rustls::pki_types::ServerName<'_>,
304 _ocsp: &[u8],
305 _now: rustls::pki_types::UnixTime,
306 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
307 Ok(rustls::client::danger::ServerCertVerified::assertion())
308 }
309
310 fn verify_tls12_signature(
311 &self,
312 message: &[u8],
313 cert: &CertificateDer<'_>,
314 dss: &rustls::DigitallySignedStruct,
315 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
316 rustls::crypto::verify_tls12_signature(
317 message,
318 cert,
319 dss,
320 &self.0.signature_verification_algorithms,
321 )
322 }
323
324 fn verify_tls13_signature(
325 &self,
326 message: &[u8],
327 cert: &CertificateDer<'_>,
328 dss: &rustls::DigitallySignedStruct,
329 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
330 rustls::crypto::verify_tls13_signature(
331 message,
332 cert,
333 dss,
334 &self.0.signature_verification_algorithms,
335 )
336 }
337
338 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
339 self.0.signature_verification_algorithms.supported_schemes()
340 }
341}