1use std::{cell::RefCell, io::ErrorKind, net::SocketAddr, rc::Rc, time::Instant};
10
11use mio::{Token, net::TcpStream};
12use rustls::{Error as RustlsError, ServerConnection};
13use rusty_ulid::Ulid;
14use sozu_command::{
15 config::MAX_LOOP_ITERATIONS,
16 logging::{LogContext, ansi_palette},
17};
18
19use crate::metrics::names;
20use crate::{
21 Readiness, Ready, SessionMetrics, SessionResult, StateResult, protocol::SessionState,
22 timer::TimeoutContainer,
23};
24
/// Builds the log-line prefix for a handshake session.
///
/// Expands to a `String` made of the ANSI palette returned by
/// `ansi_palette()`, the session's `log_context()`, and the TLS details known
/// so far: SNI, ALPN protocol, protocol version, peer address, frontend token
/// and frontend readiness. Optional values that are not available are rendered
/// as `<none>`.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        let ctx = $self.log_context();
        // SNI as announced by the client, if any.
        let sni = $self
            .session
            .server_name()
            .map_or_else(|| "<none>".to_string(), |name| name.to_string());
        // ALPN is raw bytes; decode lossily so logging can never fail.
        let alpn = $self.session.alpn_protocol().map_or_else(
            || "<none>".to_string(),
            |proto| String::from_utf8_lossy(proto).into_owned(),
        );
        let version = $self.session.protocol_version();
        let source = $self
            .peer_address
            .map_or_else(|| "<none>".to_string(), |peer| peer.to_string());
        let frontend = $self.frontend_token.0;
        let readiness = &$self.frontend_readiness;
        format!(
            "{gray}{ctx}{reset}\t{open}RUSTLS{reset}\t{grey}Session{reset}({gray}sni{reset}={white}{sni:?}{reset}, {gray}alpn{reset}={white}{alpn}{reset}, {gray}version{reset}={white}{version:?}{reset}, {gray}source{reset}={white}{source:?}{reset}, {gray}frontend{reset}={white}{frontend}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>"
        )
    }};
}
63
/// Coarse phase marker for a TLS session.
///
/// Nothing in this file constructs or matches on this enum; the variant names
/// suggest the usual lifecycle (not started, handshaking, established,
/// failed) — NOTE(review): confirm against the callers that use it.
///
/// The common traits are derived so that values can be logged, compared in
/// tests and assertions, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsState {
    Initial,
    Handshake,
    Established,
    Error,
}
70
/// Session state that drives a rustls server-side handshake over a
/// non-blocking frontend socket until the session can be upgraded.
pub struct TlsHandshake {
    /// Frontend timeout handle; marked as triggered in `timeout()` and
    /// cancelled in `cancel_timeouts()`.
    pub container_frontend_timeout: TimeoutContainer,
    /// Interest/event readiness pair for the frontend socket, updated by
    /// `readable()`, `writable()` and `update_readiness()`.
    pub frontend_readiness: Readiness,
    /// Token identifying the frontend socket; events and timeouts carrying
    /// another token are not applied to this state.
    frontend_token: Token,
    /// Address of the peer, when known; only used for logging in this file.
    pub peer_address: Option<SocketAddr>,
    /// Identifier reported as `session_id` in the log context.
    pub request_id: Ulid,
    /// rustls server connection whose handshake is being driven.
    pub session: ServerConnection,
    /// Frontend TCP stream that TLS records are read from and written to.
    pub stream: TcpStream,
    /// Set on the first `readable()`/`writable()` call and consumed when the
    /// handshake duration is recorded just before an upgrade.
    handshake_started_at: Option<Instant>,
}
85
86impl TlsHandshake {
87 pub fn new(
94 container_frontend_timeout: TimeoutContainer,
95 session: ServerConnection,
96 stream: TcpStream,
97 frontend_token: Token,
98 request_id: Ulid,
99 peer_address: Option<SocketAddr>,
100 ) -> TlsHandshake {
101 TlsHandshake {
102 container_frontend_timeout,
103 frontend_readiness: Readiness {
104 interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
105 event: Ready::EMPTY,
106 },
107 frontend_token,
108 peer_address,
109 request_id,
110 session,
111 stream,
112 handshake_started_at: None,
113 }
114 }
115
116 fn record_handshake_duration_ms(&mut self) -> Option<u128> {
122 let was_anchored = self.handshake_started_at.is_some();
123 let elapsed = self
124 .handshake_started_at
125 .take()
126 .map(|t| t.elapsed().as_millis());
127 debug_assert!(
131 self.handshake_started_at.is_none(),
132 "handshake anchor must be cleared after recording the duration"
133 );
134 debug_assert_eq!(
135 elapsed.is_some(),
136 was_anchored,
137 "a duration is returned iff the handshake had been anchored"
138 );
139 elapsed
140 }
141
142 pub fn readable(&mut self) -> SessionResult {
143 self.handshake_started_at.get_or_insert_with(Instant::now);
148 debug_assert!(
150 self.handshake_started_at.is_some(),
151 "handshake anchor must be set before driving TLS I/O"
152 );
153
154 let was_handshaking = self.session.is_handshaking();
158
159 let mut can_read = true;
160
161 loop {
162 let mut can_work = false;
163
164 if self.session.wants_read() && can_read {
165 can_work = true;
166
167 match self.session.read_tls(&mut self.stream) {
168 Ok(0) => {
169 error!("{} Connection closed during handshake", log_context!(self));
170 return SessionResult::Close;
171 }
172 Ok(_) => {}
173 Err(e) => match e.kind() {
174 ErrorKind::WouldBlock => {
175 self.frontend_readiness.event.remove(Ready::READABLE);
176 can_read = false
177 }
178 _ => {
179 error!(
180 "{} Could not perform handshake: {:?}",
181 log_context!(self),
182 e
183 );
184 return SessionResult::Close;
185 }
186 },
187 }
188
189 if let Err(e) = self.session.process_new_packets() {
190 self.log_handshake_error(&e);
191 return SessionResult::Close;
192 }
193 }
194
195 if !can_work {
196 break;
197 }
198 }
199
200 debug_assert!(
203 was_handshaking || !self.session.is_handshaking(),
204 "rustls handshake must not regress from finished back to handshaking"
205 );
206
207 if !self.session.wants_read() {
210 self.frontend_readiness.interest.remove(Ready::READABLE);
211 }
212 debug_assert!(
213 self.session.wants_read() || !self.frontend_readiness.interest.is_readable(),
214 "READABLE interest must be cleared once rustls stops wanting reads"
215 );
216
217 if self.session.wants_write() {
218 self.frontend_readiness.interest.insert(Ready::WRITABLE);
219 }
220
221 if self.session.is_handshaking() {
222 SessionResult::Continue
223 } else {
224 if self.session.wants_write() {
226 SessionResult::Continue
227 } else {
228 debug_assert!(
231 !self.session.is_handshaking() && !self.session.wants_write(),
232 "Upgrade requires a completed handshake with no pending output"
233 );
234 self.frontend_readiness.interest.insert(Ready::READABLE);
235 self.frontend_readiness.event.insert(Ready::READABLE);
236 self.frontend_readiness.interest.insert(Ready::WRITABLE);
237 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
238 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
239 }
240 SessionResult::Upgrade
241 }
242 }
243 }
244
245 pub fn writable(&mut self) -> SessionResult {
246 self.handshake_started_at.get_or_insert_with(Instant::now);
248 debug_assert!(
249 self.handshake_started_at.is_some(),
250 "handshake anchor must be set before driving TLS I/O"
251 );
252
253 let was_handshaking = self.session.is_handshaking();
255
256 let mut can_write = true;
257
258 loop {
259 let mut can_work = false;
260
261 if self.session.wants_write() && can_write {
262 can_work = true;
263
264 match self.session.write_tls(&mut self.stream) {
265 Ok(_) => {}
266 Err(e) => match e.kind() {
267 ErrorKind::WouldBlock => {
268 self.frontend_readiness.event.remove(Ready::WRITABLE);
269 can_write = false
270 }
271 _ => {
272 error!(
273 "{} Could not perform handshake: {:?}",
274 log_context!(self),
275 e
276 );
277 return SessionResult::Close;
278 }
279 },
280 }
281
282 if let Err(e) = self.session.process_new_packets() {
283 self.log_handshake_error(&e);
284 return SessionResult::Close;
285 }
286 }
287
288 if !can_work {
289 break;
290 }
291 }
292
293 debug_assert!(
296 was_handshaking || !self.session.is_handshaking(),
297 "rustls handshake must not regress from finished back to handshaking"
298 );
299
300 if !self.session.wants_write() {
303 self.frontend_readiness.interest.remove(Ready::WRITABLE);
304 }
305 debug_assert!(
306 self.session.wants_write() || !self.frontend_readiness.interest.is_writable(),
307 "WRITABLE interest must be cleared once rustls stops wanting writes"
308 );
309
310 if self.session.wants_read() {
311 self.frontend_readiness.interest.insert(Ready::READABLE);
312 }
313
314 if self.session.is_handshaking() {
315 SessionResult::Continue
316 } else if self.session.wants_read() {
317 debug_assert!(
320 !self.session.is_handshaking(),
321 "Upgrade requires a completed handshake"
322 );
323 self.frontend_readiness.interest.insert(Ready::READABLE);
324 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
325 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
326 }
327 SessionResult::Upgrade
328 } else {
329 debug_assert!(
330 !self.session.is_handshaking(),
331 "Upgrade requires a completed handshake"
332 );
333 self.frontend_readiness.interest.insert(Ready::WRITABLE);
334 self.frontend_readiness.interest.insert(Ready::READABLE);
335 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
336 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
337 }
338 SessionResult::Upgrade
339 }
340 }
341
342 pub fn log_context(&self) -> LogContext<'_> {
343 LogContext {
344 session_id: self.request_id,
345 request_id: None,
346 cluster_id: None,
347 backend_id: None,
348 }
349 }
350
351 pub fn front_socket(&self) -> &TcpStream {
352 &self.stream
353 }
354
355 fn log_handshake_error(&self, err: &RustlsError) {
372 let reason = handshake_failure_reason(err);
373 debug_assert!(
377 reason.starts_with("tls.handshake.failed."),
378 "handshake failure metric {reason} escaped the tls.handshake.failed. namespace"
379 );
380 match err {
381 RustlsError::AlertReceived(_) => debug!(
382 "{} Could not perform handshake: {:?}",
383 log_context!(self),
384 err
385 ),
386 RustlsError::PeerIncompatible(_)
387 | RustlsError::PeerMisbehaved(_)
388 | RustlsError::InvalidMessage(_)
389 | RustlsError::InappropriateMessage { .. }
390 | RustlsError::InappropriateHandshakeMessage { .. }
391 | RustlsError::PeerSentOversizedRecord
392 | RustlsError::NoApplicationProtocol
393 | RustlsError::InvalidCertificate(_)
394 | RustlsError::DecryptError
395 | RustlsError::NoCertificatesPresented => warn!(
396 "{} Could not perform handshake: {:?}",
397 log_context!(self),
398 err
399 ),
400 _ => error!(
401 "{} Could not perform handshake: {:?}",
402 log_context!(self),
403 err
404 ),
405 }
406 count!(reason, 1);
407 }
408}
409
410fn handshake_failure_reason(err: &RustlsError) -> &'static str {
416 match err {
417 RustlsError::AlertReceived(_) => "tls.handshake.failed.alert_received",
418 RustlsError::PeerIncompatible(_) => "tls.handshake.failed.peer_incompatible",
419 RustlsError::PeerMisbehaved(_) => "tls.handshake.failed.peer_misbehaved",
420 RustlsError::InvalidMessage(_) => "tls.handshake.failed.invalid_message",
421 RustlsError::InappropriateMessage { .. } => "tls.handshake.failed.inappropriate_message",
422 RustlsError::InappropriateHandshakeMessage { .. } => {
423 "tls.handshake.failed.inappropriate_handshake_message"
424 }
425 RustlsError::PeerSentOversizedRecord => "tls.handshake.failed.oversized_record",
426 RustlsError::NoApplicationProtocol => "tls.handshake.failed.no_alpn",
427 RustlsError::InvalidCertificate(_) => "tls.handshake.failed.invalid_certificate",
428 RustlsError::DecryptError => "tls.handshake.failed.decrypt_error",
429 RustlsError::NoCertificatesPresented => "tls.handshake.failed.no_certificates_present",
430 _ => "tls.handshake.failed.other",
431 }
432}
433
434impl SessionState for TlsHandshake {
435 fn ready(
436 &mut self,
437 _session: Rc<RefCell<dyn crate::ProxySession>>,
438 _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
439 _metrics: &mut SessionMetrics,
440 ) -> SessionResult {
441 let mut counter = 0;
442
443 if self.frontend_readiness.event.is_hup() {
444 return SessionResult::Close;
445 }
446
447 while counter < MAX_LOOP_ITERATIONS {
448 let frontend_interest = self.frontend_readiness.filter_interest();
449
450 trace!("{} Interest({:?})", log_context!(self), frontend_interest);
451 if frontend_interest.is_empty() {
452 break;
453 }
454
455 if frontend_interest.is_readable() {
456 let protocol_result = self.readable();
457 if protocol_result != SessionResult::Continue {
458 return protocol_result;
459 }
460 }
461
462 if frontend_interest.is_writable() {
463 let protocol_result = self.writable();
464 if protocol_result != SessionResult::Continue {
465 return protocol_result;
466 }
467 }
468
469 if frontend_interest.is_error() {
470 error!("{} Front socket error, disconnecting", log_context!(self));
471 self.frontend_readiness.interest = Ready::EMPTY;
472 return SessionResult::Close;
473 }
474
475 counter += 1;
476 }
477
478 if counter >= MAX_LOOP_ITERATIONS {
479 error!(
480 "{}\tHandling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
481 log_context!(self),
482 MAX_LOOP_ITERATIONS
483 );
484
485 incr!(names::http::INFINITE_LOOP_ERROR);
486 self.print_state("HTTPS");
487
488 return SessionResult::Close;
489 }
490
491 SessionResult::Continue
492 }
493
494 fn update_readiness(&mut self, token: Token, events: Ready) {
495 if self.frontend_token == token {
496 self.frontend_readiness.event |= events;
497 }
498 }
499
500 fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
501 if self.frontend_token == token {
503 self.container_frontend_timeout.triggered();
504 return StateResult::CloseSession;
505 }
506
507 error!(
508 "{}, Expect state: got timeout for an invalid token: {:?}",
509 log_context!(self),
510 token
511 );
512 StateResult::CloseSession
513 }
514
515 fn cancel_timeouts(&mut self) {
516 self.container_frontend_timeout.cancel();
517 }
518
519 fn print_state(&self, context: &str) {
520 error!(
521 "{} Session(Handshake)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
522 context, self.frontend_token, self.frontend_readiness
523 );
524 }
525}
526
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rustls::{
        AlertDescription, CertificateError, ContentType, Error as RustlsError, HandshakeType,
        InvalidMessage, PeerIncompatible, PeerMisbehaved,
    };

    use super::handshake_failure_reason;

    /// Checks that every explicitly mapped rustls error variant yields its
    /// dedicated metric key, that unmapped variants fall back to the shared
    /// `other` key, and that all keys stay inside the
    /// `tls.handshake.failed.` namespace.
    #[test]
    fn handshake_failure_reason_maps_every_variant_to_unique_namespaced_key() {
        let cases: Vec<(RustlsError, &str)> = vec![
            // Variants with a dedicated key.
            (
                RustlsError::AlertReceived(AlertDescription::HandshakeFailure),
                "tls.handshake.failed.alert_received",
            ),
            (
                RustlsError::PeerIncompatible(PeerIncompatible::NoCipherSuitesInCommon),
                "tls.handshake.failed.peer_incompatible",
            ),
            (
                RustlsError::PeerMisbehaved(PeerMisbehaved::IllegalMiddleboxChangeCipherSpec),
                "tls.handshake.failed.peer_misbehaved",
            ),
            (
                RustlsError::InvalidMessage(InvalidMessage::InvalidContentType),
                "tls.handshake.failed.invalid_message",
            ),
            (
                RustlsError::InappropriateMessage {
                    expect_types: vec![ContentType::Handshake],
                    got_type: ContentType::ApplicationData,
                },
                "tls.handshake.failed.inappropriate_message",
            ),
            (
                RustlsError::InappropriateHandshakeMessage {
                    expect_types: vec![HandshakeType::ClientHello],
                    got_type: HandshakeType::Finished,
                },
                "tls.handshake.failed.inappropriate_handshake_message",
            ),
            (
                RustlsError::PeerSentOversizedRecord,
                "tls.handshake.failed.oversized_record",
            ),
            (
                RustlsError::NoApplicationProtocol,
                "tls.handshake.failed.no_alpn",
            ),
            (
                RustlsError::InvalidCertificate(CertificateError::Expired),
                "tls.handshake.failed.invalid_certificate",
            ),
            (
                RustlsError::DecryptError,
                "tls.handshake.failed.decrypt_error",
            ),
            (
                RustlsError::NoCertificatesPresented,
                "tls.handshake.failed.no_certificates_present",
            ),
            // Variants that share the fallback key.
            (
                RustlsError::General("test".to_owned()),
                "tls.handshake.failed.other",
            ),
            (RustlsError::EncryptError, "tls.handshake.failed.other"),
            (
                RustlsError::FailedToGetCurrentTime,
                "tls.handshake.failed.other",
            ),
            (
                RustlsError::HandshakeNotComplete,
                "tls.handshake.failed.other",
            ),
        ];

        let seen: HashSet<&str> = cases
            .iter()
            .map(|(err, expected)| {
                let got = handshake_failure_reason(err);
                assert_eq!(got, *expected, "variant {err:?} → {got}, want {expected}");
                assert!(
                    got.starts_with("tls.handshake.failed."),
                    "reason {got} missing tls.handshake.failed. namespace"
                );
                got
            })
            .collect();

        // Eleven dedicated keys plus the shared fallback key.
        assert_eq!(seen.len(), 12, "unexpected key set: {seen:?}");
    }
}