// webrtc_dtls/handshaker.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use log::*;
6
7use crate::cipher_suite::*;
8use crate::config::*;
9use crate::conn::*;
10use crate::content::*;
11use crate::crypto::*;
12use crate::error::*;
13use crate::extension::extension_use_srtp::*;
14use crate::signature_hash_algorithm::*;
15
16use rustls::client::danger::ServerCertVerifier;
17use rustls::pki_types::CertificateDer;
18use rustls::server::danger::ClientCertVerifier;
19
20//use std::io::BufWriter;
21
// [RFC 6347, Section 4.2.4] DTLS timeout and retransmission state machine
//                      +-----------+
//                +---> | PREPARING | <--------------------+
//                |     +-----------+                      |
//                |           |                            |
//                |           | Buffer next flight         |
//                |           |                            |
//                |          \|/                           |
//                |     +-----------+                      |
//                |     |  SENDING  |<------------------+  | Send
//                |     +-----------+                   |  | HelloRequest
//        Receive |           |                         |  |
//           next |           | Send flight             |  | or
//         flight |  +--------+                         |  |
//                |  |        | Set retransmit timer    |  | Receive
//                |  |       \|/                        |  | HelloRequest
//                |  |  +-----------+                   |  | Send
//                +--)--|  WAITING  |-------------------+  | ClientHello
//                |  |  +-----------+   Timer expires   |  |
//                |  |         |                        |  |
//                |  |         +------------------------+  |
//        Receive |  | Send           Read retransmit      |
//           last |  | last                                |
//         flight |  | flight                              |
//                |  |                                     |
//               \|/\|/                                    |
//            +-----------+                                |
//            | FINISHED  | -------------------------------+
//            +-----------+
//                 |  /|\
//                 |   |
//                 +---+
//              Read retransmit
//           Retransmit last flight
56
/// States of the DTLS handshake retransmission state machine (RFC 6347 §4.2.4).
///
/// `Debug` and `Eq` are derived in addition to `PartialEq` so states can be
/// compared with `assert_eq!` and printed with `{:?}` in diagnostics.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HandshakeState {
    /// Error state; `DTLSConn::handshake` rejects it with `ErrInvalidFsmTransition`.
    Errored,
    /// Generating the packets of the current flight.
    Preparing,
    /// Writing the buffered flight to the peer.
    Sending,
    /// Waiting for the peer's next flight or for the retransmit timer.
    Waiting,
    /// This side's handshake flights are complete; late retransmits may still arrive.
    Finished,
}
65
66impl fmt::Display for HandshakeState {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match *self {
69            HandshakeState::Errored => write!(f, "Errored"),
70            HandshakeState::Preparing => write!(f, "Preparing"),
71            HandshakeState::Sending => write!(f, "Sending"),
72            HandshakeState::Waiting => write!(f, "Waiting"),
73            HandshakeState::Finished => write!(f, "Finished"),
74        }
75    }
76}
77
78pub(crate) type VerifyPeerCertificateFn =
79    Arc<dyn (Fn(&[Vec<u8>], &[CertificateDer<'static>]) -> Result<()>) + Send + Sync>;
80
81pub(crate) struct HandshakeConfig {
82    pub(crate) local_psk_callback: Option<PskCallback>,
83    pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
84    pub(crate) local_cipher_suites: Vec<CipherSuiteId>, // Available CipherSuites
85    pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>, // Available signature schemes
86    pub(crate) extended_master_secret: ExtendedMasterSecretType, // Policy for the Extended Master Support extension
87    pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>, // Available SRTPProtectionProfiles, if empty no SRTP support
88    pub(crate) server_name: String,
89    pub(crate) client_auth: ClientAuthType, // If we are a client should we request a client certificate
90    pub(crate) local_certificates: Vec<Certificate>,
91    pub(crate) name_to_certificate: HashMap<String, Certificate>,
92    pub(crate) insecure_skip_verify: bool,
93    pub(crate) insecure_verification: bool,
94    pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
95    pub(crate) server_cert_verifier: Arc<dyn ServerCertVerifier>,
96    pub(crate) client_cert_verifier: Option<Arc<dyn ClientCertVerifier>>,
97    pub(crate) retransmit_interval: tokio::time::Duration,
98    pub(crate) initial_epoch: u16,
99    //log           logging.LeveledLogger
100    //mu sync.Mutex
101}
102
103pub fn gen_self_signed_root_cert() -> rustls::RootCertStore {
104    let mut certs = rustls::RootCertStore::empty();
105    certs
106        .add(
107            rcgen::generate_simple_self_signed(vec![])
108                .unwrap()
109                .cert
110                .der()
111                .to_owned(),
112        )
113        .unwrap();
114    certs
115}
116
117impl Default for HandshakeConfig {
118    fn default() -> Self {
119        HandshakeConfig {
120            local_psk_callback: None,
121            local_psk_identity_hint: None,
122            local_cipher_suites: vec![],
123            local_signature_schemes: vec![],
124            extended_master_secret: ExtendedMasterSecretType::Disable,
125            local_srtp_protection_profiles: vec![],
126            server_name: String::new(),
127            client_auth: ClientAuthType::NoClientCert,
128            local_certificates: vec![],
129            name_to_certificate: HashMap::new(),
130            insecure_skip_verify: false,
131            insecure_verification: false,
132            verify_peer_certificate: None,
133            server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
134                gen_self_signed_root_cert(),
135            ))
136            .build()
137            .unwrap(),
138            client_cert_verifier: None,
139            retransmit_interval: tokio::time::Duration::from_secs(0),
140            initial_epoch: 0,
141        }
142    }
143}
144
145impl HandshakeConfig {
146    pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
147        //TODO
148        /*if self.name_to_certificate.is_empty() {
149            let mut name_to_certificate = HashMap::new();
150            for cert in &self.local_certificates {
151                if let Ok((_rem, x509_cert)) = x509_parser::parse_x509_der(&cert.certificate) {
152                    if let Some(a) = x509_cert.tbs_certificate.subject.iter_common_name().next() {
153                        let common_name = match a.attr_value.as_str() {
154                            Ok(cn) => cn.to_lowercase(),
155                            Err(err) => return Err(Error::new(err.to_string())),
156                        };
157                        name_to_certificate.insert(common_name, cert.clone());
158                    }
159                    if let Some((_, sans)) = x509_cert.tbs_certificate.subject_alternative_name() {
160                        for gn in &sans.general_names {
161                            match gn {
162                                x509_parser::extensions::GeneralName::DNSName(san) => {
163                                    let san = san.to_lowercase();
164                                    name_to_certificate.insert(san, cert.clone());
165                                }
166                                _ => {}
167                            }
168                        }
169                    }
170                } else {
171                    continue;
172                }
173            }
174            self.name_to_certificate = name_to_certificate;
175        }*/
176
177        if self.local_certificates.is_empty() {
178            return Err(Error::ErrNoCertificates);
179        }
180
181        if self.local_certificates.len() == 1 {
182            // There's only one choice, so no point doing any work.
183            return Ok(self.local_certificates[0].clone());
184        }
185
186        if server_name.is_empty() {
187            return Ok(self.local_certificates[0].clone());
188        }
189
190        let lower = server_name.to_lowercase();
191        let name = lower.trim_end_matches('.');
192
193        if let Some(cert) = self.name_to_certificate.get(name) {
194            return Ok(cert.clone());
195        }
196
197        // try replacing labels in the name with wildcards until we get a
198        // match.
199        let mut labels: Vec<&str> = name.split_terminator('.').collect();
200        for i in 0..labels.len() {
201            labels[i] = "*";
202            let candidate = labels.join(".");
203            if let Some(cert) = self.name_to_certificate.get(&candidate) {
204                return Ok(cert.clone());
205            }
206        }
207
208        // If nothing matches, return the first certificate.
209        Ok(self.local_certificates[0].clone())
210    }
211}
212
/// Returns the role label used in handshake log lines: `"client"` or `"server"`.
pub(crate) fn srv_cli_str(is_client: bool) -> String {
    let role = if is_client { "client" } else { "server" };
    role.to_owned()
}
219
220impl DTLSConn {
221    pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<()> {
222        loop {
223            trace!(
224                "[handshake:{}] {}: {}",
225                srv_cli_str(self.state.is_client),
226                self.current_flight,
227                state
228            );
229
230            if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() {
231                self.set_handshake_completed_successfully();
232                self.handshake_done_tx.take(); // drop it by take
233                return Ok(());
234            }
235
236            state = match state {
237                HandshakeState::Preparing => self.prepare().await?,
238                HandshakeState::Sending => self.send().await?,
239                HandshakeState::Waiting => self.wait().await?,
240                HandshakeState::Finished => self.finish().await?,
241                _ => return Err(Error::ErrInvalidFsmTransition),
242            };
243        }
244    }
245
246    async fn prepare(&mut self) -> Result<HandshakeState> {
247        self.flights = None;
248
249        // Prepare flights
250        self.retransmit = self.current_flight.has_retransmit();
251
252        let result = self
253            .current_flight
254            .generate(&mut self.state, &self.cache, &self.cfg)
255            .await;
256
257        match result {
258            Err((a, mut err)) => {
259                if let Some(a) = a {
260                    let alert_err = self.notify(a.alert_level, a.alert_description).await;
261
262                    if let Err(alert_err) = alert_err {
263                        if err.is_some() {
264                            err = Some(alert_err);
265                        }
266                    }
267                }
268                if let Some(err) = err {
269                    return Err(err);
270                }
271            }
272            Ok(pkts) => {
273                /*if !pkts.is_empty() {
274                    let mut s = vec![];
275                    {
276                        let mut writer = BufWriter::<&mut Vec<u8>>::new(s.as_mut());
277                        pkts[0].record.content.marshal(&mut writer)?;
278                    }
279                    trace!(
280                        "[handshake:{}] {}: {:?}",
281                        srv_cli_str(self.state.is_client),
282                        self.current_flight.to_string(),
283                        s,
284                    );
285                }*/
286                self.flights = Some(pkts)
287            }
288        };
289
290        let epoch = self.cfg.initial_epoch;
291        let mut next_epoch = epoch;
292        if let Some(pkts) = &mut self.flights {
293            for p in pkts {
294                p.record.record_layer_header.epoch += epoch;
295                if p.record.record_layer_header.epoch > next_epoch {
296                    next_epoch = p.record.record_layer_header.epoch;
297                }
298                if let Content::Handshake(h) = &mut p.record.content {
299                    h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16;
300                    self.state.handshake_send_sequence += 1;
301                }
302            }
303        }
304        if epoch != next_epoch {
305            trace!(
306                "[handshake:{}] -> changeCipherSpec (epoch: {})",
307                srv_cli_str(self.state.is_client),
308                next_epoch
309            );
310            self.set_local_epoch(next_epoch);
311        }
312
313        Ok(HandshakeState::Sending)
314    }
315    async fn send(&mut self) -> Result<HandshakeState> {
316        // Send flights
317        if let Some(pkts) = self.flights.clone() {
318            self.write_packets(pkts).await?;
319        }
320
321        if self.current_flight.is_last_send_flight() {
322            Ok(HandshakeState::Finished)
323        } else {
324            Ok(HandshakeState::Waiting)
325        }
326    }
327    async fn wait(&mut self) -> Result<HandshakeState> {
328        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
329        tokio::pin!(retransmit_timer);
330
331        loop {
332            tokio::select! {
333                 done_senders = self.handshake_rx.recv() =>{
334                    if done_senders.is_none() {
335                        trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight);
336                        return Err(Error::ErrAlertFatalOrClose);
337                    } else if let Some((rendezvous_tx, done_tx)) = done_senders {
338                        rendezvous_tx.send(()).ok();
339                        //trace!("[handshake:{}] {} received handshake_rx", srv_cli_str(self.state.is_client), self.current_flight);
340                        let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
341                        drop(done_tx);
342                        match result {
343                            Err((alert, mut err)) => {
344                                trace!("[handshake:{}] {} result alert:{:?}, err:{:?}",
345                                        srv_cli_str(self.state.is_client),
346                                        self.current_flight,
347                                        alert,
348                                        err);
349
350                                if let Some(alert) = alert {
351                                    let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
352
353                                    if let Err(alert_err) = alert_err {
354                                        if err.is_some() {
355                                            err = Some(alert_err);
356                                        }
357                                    }
358                                }
359                                if let Some(err) = err {
360                                    return Err(err);
361                                }
362                            }
363                            Ok(next_flight) => {
364                                trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight, next_flight);
365                                if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() {
366                                    return Ok(HandshakeState::Finished);
367                                }
368                                self.current_flight = next_flight;
369                                return Ok(HandshakeState::Preparing);
370                            }
371                        };
372                    }
373                }
374
375                _ = retransmit_timer.as_mut() =>{
376                    trace!("[handshake:{}] {} retransmit_timer", srv_cli_str(self.state.is_client), self.current_flight);
377
378                    if !self.retransmit {
379                        return Ok(HandshakeState::Waiting);
380                    }
381                    return Ok(HandshakeState::Sending);
382                }
383
384                /*_ = self.done_rx.recv() => {
385                    return Err(Error::new("done_rx recv".to_owned()));
386                }*/
387            }
388        }
389    }
390    async fn finish(&mut self) -> Result<HandshakeState> {
391        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
392
393        tokio::select! {
394            done = self.handshake_rx.recv() =>{
395                if done.is_none() {
396                    trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight);
397                    return Err(Error::ErrAlertFatalOrClose);
398                }
399                let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
400                drop(done);
401                match result {
402                    Err((alert, mut err)) => {
403                        if let Some(alert) = alert {
404                            let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
405                            if let Err(alert_err) = alert_err {
406                                if err.is_some() {
407                                    err = Some(alert_err);
408                                }
409                            }
410                        }
411                        if let Some(err) = err {
412                            return Err(err);
413                        }
414                    }
415                    Ok(_) => {
416                        retransmit_timer.await;
417                        // Retransmit last flight
418                        return Ok(HandshakeState::Sending);
419                    }
420                };
421            }
422
423            /*_ = self.done_rx.recv() => {
424                return Err(Error::new("done_rx recv".to_owned()));
425            }*/
426        }
427
428        Ok(HandshakeState::Finished)
429    }
430}