1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use log::*;
6
7use crate::cipher_suite::*;
8use crate::config::*;
9use crate::conn::*;
10use crate::content::*;
11use crate::crypto::*;
12use crate::error::*;
13use crate::extension::extension_use_srtp::*;
14use crate::signature_hash_algorithm::*;
15
16use rustls::client::danger::ServerCertVerifier;
17use rustls::pki_types::CertificateDer;
18use rustls::server::danger::ClientCertVerifier;
19
/// States of the DTLS handshake state machine driven by `DTLSConn::handshake`.
///
/// Transitions produced by the state handlers:
/// - `Preparing` -> `Sending` once the current flight has been generated;
/// - `Sending` -> `Finished` after the last flight this side sends, otherwise `Waiting`;
/// - `Waiting` -> `Preparing` (new flight parsed), `Finished` (final flight received
///   again), `Sending` (retransmit timer fired, flight retransmittable) or `Waiting`;
/// - `Finished` -> `Sending` (peer retransmitted, resend last flight) or `Finished`.
///
/// None of `prepare`/`send`/`wait`/`finish` returns `Errored`; `handshake` rejects it
/// with `Error::ErrInvalidFsmTransition`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HandshakeState {
    Errored,
    Preparing,
    Sending,
    Waiting,
    Finished,
}
65
66impl fmt::Display for HandshakeState {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 match *self {
69 HandshakeState::Errored => write!(f, "Errored"),
70 HandshakeState::Preparing => write!(f, "Preparing"),
71 HandshakeState::Sending => write!(f, "Sending"),
72 HandshakeState::Waiting => write!(f, "Waiting"),
73 HandshakeState::Finished => write!(f, "Finished"),
74 }
75 }
76}
77
78pub(crate) type VerifyPeerCertificateFn =
79 Arc<dyn (Fn(&[Vec<u8>], &[CertificateDer<'static>]) -> Result<()>) + Send + Sync>;
80
81pub(crate) struct HandshakeConfig {
82 pub(crate) local_psk_callback: Option<PskCallback>,
83 pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
84 pub(crate) local_cipher_suites: Vec<CipherSuiteId>, pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>, pub(crate) extended_master_secret: ExtendedMasterSecretType, pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>, pub(crate) server_name: String,
89 pub(crate) client_auth: ClientAuthType, pub(crate) local_certificates: Vec<Certificate>,
91 pub(crate) name_to_certificate: HashMap<String, Certificate>,
92 pub(crate) insecure_skip_verify: bool,
93 pub(crate) insecure_verification: bool,
94 pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
95 pub(crate) server_cert_verifier: Arc<dyn ServerCertVerifier>,
96 pub(crate) client_cert_verifier: Option<Arc<dyn ClientCertVerifier>>,
97 pub(crate) retransmit_interval: tokio::time::Duration,
98 pub(crate) initial_epoch: u16,
99 }
102
103pub fn gen_self_signed_root_cert() -> rustls::RootCertStore {
104 let mut certs = rustls::RootCertStore::empty();
105 certs
106 .add(
107 rcgen::generate_simple_self_signed(vec![])
108 .unwrap()
109 .cert
110 .der()
111 .to_owned(),
112 )
113 .unwrap();
114 certs
115}
116
117impl Default for HandshakeConfig {
118 fn default() -> Self {
119 HandshakeConfig {
120 local_psk_callback: None,
121 local_psk_identity_hint: None,
122 local_cipher_suites: vec![],
123 local_signature_schemes: vec![],
124 extended_master_secret: ExtendedMasterSecretType::Disable,
125 local_srtp_protection_profiles: vec![],
126 server_name: String::new(),
127 client_auth: ClientAuthType::NoClientCert,
128 local_certificates: vec![],
129 name_to_certificate: HashMap::new(),
130 insecure_skip_verify: false,
131 insecure_verification: false,
132 verify_peer_certificate: None,
133 server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
134 gen_self_signed_root_cert(),
135 ))
136 .build()
137 .unwrap(),
138 client_cert_verifier: None,
139 retransmit_interval: tokio::time::Duration::from_secs(0),
140 initial_epoch: 0,
141 }
142 }
143}
144
145impl HandshakeConfig {
146 pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
147 if self.local_certificates.is_empty() {
178 return Err(Error::ErrNoCertificates);
179 }
180
181 if self.local_certificates.len() == 1 {
182 return Ok(self.local_certificates[0].clone());
184 }
185
186 if server_name.is_empty() {
187 return Ok(self.local_certificates[0].clone());
188 }
189
190 let lower = server_name.to_lowercase();
191 let name = lower.trim_end_matches('.');
192
193 if let Some(cert) = self.name_to_certificate.get(name) {
194 return Ok(cert.clone());
195 }
196
197 let mut labels: Vec<&str> = name.split_terminator('.').collect();
200 for i in 0..labels.len() {
201 labels[i] = "*";
202 let candidate = labels.join(".");
203 if let Some(cert) = self.name_to_certificate.get(&candidate) {
204 return Ok(cert.clone());
205 }
206 }
207
208 Ok(self.local_certificates[0].clone())
210 }
211}
212
/// Returns the role label used in handshake trace output: `"client"` when `is_client`
/// is true, otherwise `"server"`.
pub(crate) fn srv_cli_str(is_client: bool) -> String {
    let role = if is_client { "client" } else { "server" };
    role.to_owned()
}
219
220impl DTLSConn {
221 pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<()> {
222 loop {
223 trace!(
224 "[handshake:{}] {}: {}",
225 srv_cli_str(self.state.is_client),
226 self.current_flight,
227 state
228 );
229
230 if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() {
231 self.set_handshake_completed_successfully();
232 self.handshake_done_tx.take(); return Ok(());
234 }
235
236 state = match state {
237 HandshakeState::Preparing => self.prepare().await?,
238 HandshakeState::Sending => self.send().await?,
239 HandshakeState::Waiting => self.wait().await?,
240 HandshakeState::Finished => self.finish().await?,
241 _ => return Err(Error::ErrInvalidFsmTransition),
242 };
243 }
244 }
245
246 async fn prepare(&mut self) -> Result<HandshakeState> {
247 self.flights = None;
248
249 self.retransmit = self.current_flight.has_retransmit();
251
252 let result = self
253 .current_flight
254 .generate(&mut self.state, &self.cache, &self.cfg)
255 .await;
256
257 match result {
258 Err((a, mut err)) => {
259 if let Some(a) = a {
260 let alert_err = self.notify(a.alert_level, a.alert_description).await;
261
262 if let Err(alert_err) = alert_err {
263 if err.is_some() {
264 err = Some(alert_err);
265 }
266 }
267 }
268 if let Some(err) = err {
269 return Err(err);
270 }
271 }
272 Ok(pkts) => {
273 self.flights = Some(pkts)
287 }
288 };
289
290 let epoch = self.cfg.initial_epoch;
291 let mut next_epoch = epoch;
292 if let Some(pkts) = &mut self.flights {
293 for p in pkts {
294 p.record.record_layer_header.epoch += epoch;
295 if p.record.record_layer_header.epoch > next_epoch {
296 next_epoch = p.record.record_layer_header.epoch;
297 }
298 if let Content::Handshake(h) = &mut p.record.content {
299 h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16;
300 self.state.handshake_send_sequence += 1;
301 }
302 }
303 }
304 if epoch != next_epoch {
305 trace!(
306 "[handshake:{}] -> changeCipherSpec (epoch: {})",
307 srv_cli_str(self.state.is_client),
308 next_epoch
309 );
310 self.set_local_epoch(next_epoch);
311 }
312
313 Ok(HandshakeState::Sending)
314 }
315 async fn send(&mut self) -> Result<HandshakeState> {
316 if let Some(pkts) = self.flights.clone() {
318 self.write_packets(pkts).await?;
319 }
320
321 if self.current_flight.is_last_send_flight() {
322 Ok(HandshakeState::Finished)
323 } else {
324 Ok(HandshakeState::Waiting)
325 }
326 }
327 async fn wait(&mut self) -> Result<HandshakeState> {
328 let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
329 tokio::pin!(retransmit_timer);
330
331 loop {
332 tokio::select! {
333 done_senders = self.handshake_rx.recv() =>{
334 if done_senders.is_none() {
335 trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight);
336 return Err(Error::ErrAlertFatalOrClose);
337 } else if let Some((rendezvous_tx, done_tx)) = done_senders {
338 rendezvous_tx.send(()).ok();
339 let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
341 drop(done_tx);
342 match result {
343 Err((alert, mut err)) => {
344 trace!("[handshake:{}] {} result alert:{:?}, err:{:?}",
345 srv_cli_str(self.state.is_client),
346 self.current_flight,
347 alert,
348 err);
349
350 if let Some(alert) = alert {
351 let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
352
353 if let Err(alert_err) = alert_err {
354 if err.is_some() {
355 err = Some(alert_err);
356 }
357 }
358 }
359 if let Some(err) = err {
360 return Err(err);
361 }
362 }
363 Ok(next_flight) => {
364 trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight, next_flight);
365 if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() {
366 return Ok(HandshakeState::Finished);
367 }
368 self.current_flight = next_flight;
369 return Ok(HandshakeState::Preparing);
370 }
371 };
372 }
373 }
374
375 _ = retransmit_timer.as_mut() =>{
376 trace!("[handshake:{}] {} retransmit_timer", srv_cli_str(self.state.is_client), self.current_flight);
377
378 if !self.retransmit {
379 return Ok(HandshakeState::Waiting);
380 }
381 return Ok(HandshakeState::Sending);
382 }
383
384 }
388 }
389 }
390 async fn finish(&mut self) -> Result<HandshakeState> {
391 let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
392
393 tokio::select! {
394 done = self.handshake_rx.recv() =>{
395 if done.is_none() {
396 trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight);
397 return Err(Error::ErrAlertFatalOrClose);
398 }
399 let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
400 drop(done);
401 match result {
402 Err((alert, mut err)) => {
403 if let Some(alert) = alert {
404 let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
405 if let Err(alert_err) = alert_err {
406 if err.is_some() {
407 err = Some(alert_err);
408 }
409 }
410 }
411 if let Some(err) = err {
412 return Err(err);
413 }
414 }
415 Ok(_) => {
416 retransmit_timer.await;
417 return Ok(HandshakeState::Sending);
419 }
420 };
421 }
422
423 }
427
428 Ok(HandshakeState::Finished)
429 }
430}