s2n_quic_rustls/
client.rs1use crate::{certificate, cipher_suite::default_crypto_provider, session::Session, Error};
5use core::convert::TryFrom;
6use rustls::{ClientConfig, ConfigBuilder, WantsVerifier};
7use s2n_codec::EncoderValue;
8use s2n_quic_core::{application::ServerName, crypto::tls};
9use std::sync::Arc;
10
11fn default_config_builder() -> Result<ConfigBuilder<ClientConfig, WantsVerifier>, rustls::Error> {
15 let tls13_cipher_suite_crypto_provider = default_crypto_provider()?;
16 ClientConfig::builder_with_provider(tls13_cipher_suite_crypto_provider.into())
17 .with_protocol_versions(crate::PROTOCOL_VERSIONS)
18}
19
20#[derive(Clone)]
21pub struct Client {
22 config: Arc<ClientConfig>,
23}
24
25impl Client {
26 #[deprecated = "client and server builders should be used instead"]
34 pub fn new(config: ClientConfig) -> Self {
35 Self {
36 config: Arc::new(config),
37 }
38 }
39
40 pub fn builder() -> Builder {
41 Builder::new()
42 }
43}
44
45impl Default for Client {
46 fn default() -> Self {
47 Self::builder()
49 .build()
50 .expect("could not create default client")
51 }
52}
53
54impl From<ClientConfig> for Client {
56 fn from(config: ClientConfig) -> Self {
57 Self::from(Arc::new(config))
58 }
59}
60
61impl From<Arc<ClientConfig>> for Client {
63 fn from(config: Arc<ClientConfig>) -> Self {
64 Self { config }
65 }
66}
67
68impl tls::Endpoint for Client {
69 type Session = Session;
70
71 fn new_server_session<Params: EncoderValue>(
72 &mut self,
73 _transport_parameters: &Params,
74 ) -> Self::Session {
75 panic!("cannot create a server session from a client config");
76 }
77
78 fn new_client_session<Params: EncoderValue>(
79 &mut self,
80 transport_parameters: &Params,
81 server_name: ServerName,
82 ) -> Self::Session {
83 let transport_parameters = transport_parameters.encode_to_vec();
86
87 let rustls_server_name = rustls::pki_types::ServerName::try_from(server_name.to_string())
88 .expect("invalid server name");
89
90 let session = rustls::quic::ClientConnection::new(
91 self.config.clone(),
92 crate::QUIC_VERSION,
93 rustls_server_name,
94 transport_parameters,
95 )
96 .expect("could not create rustls client session");
97
98 Session::new(session.into(), Some(server_name))
99 }
100
101 fn max_tag_length(&self) -> usize {
102 s2n_quic_crypto::MAX_TAG_LEN
103 }
104}
105
106pub struct Builder {
107 cert_store: rustls::RootCertStore,
108 application_protocols: Vec<Vec<u8>>,
109 key_log: Option<Arc<dyn rustls::KeyLog>>,
110}
111
112impl Default for Builder {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl Builder {
119 pub fn new() -> Self {
120 Self {
121 cert_store: rustls::RootCertStore::empty(),
122 application_protocols: vec![b"h3".to_vec()],
123 key_log: None,
124 }
125 }
126
127 pub fn with_certificate<C: certificate::IntoCertificate>(
128 mut self,
129 certificate: C,
130 ) -> Result<Self, Error> {
131 let certificates = certificate.into_certificate()?;
132 let root_certificate = certificates.0.first().ok_or_else(|| {
133 rustls::Error::General("Certificate chain needs to have at least one entry".to_string())
134 })?;
135 self.cert_store
136 .add(root_certificate.to_owned())
137 .map_err(|err| rustls::Error::General(err.to_string()))?;
138 Ok(self)
139 }
140
141 pub fn with_max_cert_chain_depth(self, len: u16) -> Result<Self, Error> {
142 let _ = len;
144 Ok(self)
145 }
146
147 pub fn with_application_protocols<P: Iterator<Item = I>, I: AsRef<[u8]>>(
148 mut self,
149 protocols: P,
150 ) -> Result<Self, rustls::Error> {
151 self.application_protocols = protocols.map(|p| p.as_ref().to_vec()).collect();
152 Ok(self)
153 }
154
155 pub fn with_key_logging(mut self) -> Result<Self, Error> {
156 self.key_log = Some(Arc::new(rustls::KeyLogFile::new()));
157 Ok(self)
158 }
159
160 pub fn build(self) -> Result<Client, Error> {
161 if self.cert_store.is_empty() {
163 return Err(
166 rustls::Error::General("missing trusted root certificate(s)".to_string()).into(),
167 );
168 }
169
170 let mut config = default_config_builder()?
171 .with_root_certificates(self.cert_store)
172 .with_no_client_auth();
173
174 config.max_fragment_size = None;
175 config.alpn_protocols = self.application_protocols;
176
177 if let Some(key_log) = self.key_log {
178 config.key_log = key_log;
179 }
180
181 #[allow(deprecated)]
182 Ok(Client::new(config))
183 }
184}