1use std::{cell::RefCell, io::ErrorKind, net::SocketAddr, rc::Rc, time::Instant};
10
11use mio::{Token, net::TcpStream};
12use rustls::{Error as RustlsError, ServerConnection};
13use rusty_ulid::Ulid;
14use sozu_command::{
15 config::MAX_LOOP_ITERATIONS,
16 logging::{LogContext, ansi_palette},
17};
18
19use crate::metrics::names;
20use crate::{
21 Readiness, Ready, SessionMetrics, SessionResult, StateResult, protocol::SessionState,
22 timer::TimeoutContainer,
23};
24
/// Builds the coloured log prefix used by every message of the TLS handshake state.
///
/// Expands to a `String` describing the session: its log context, SNI, ALPN
/// protocol, negotiated TLS version, peer address, frontend token and frontend
/// readiness. Missing SNI / ALPN / peer address are rendered as `<none>`.
/// `$self` must expose `log_context()`, `session`, `peer_address`,
/// `frontend_token` and `frontend_readiness`, as `TlsHandshake` does.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        let ctx = $self.log_context();
        let sni = match $self.session.server_name() {
            Some(name) => name.to_string(),
            None => "<none>".to_string(),
        };
        let alpn = match $self.session.alpn_protocol() {
            Some(bytes) => String::from_utf8_lossy(bytes).into_owned(),
            None => "<none>".to_string(),
        };
        let version = $self.session.protocol_version();
        let source = match &$self.peer_address {
            Some(addr) => addr.to_string(),
            None => "<none>".to_string(),
        };
        let frontend = $self.frontend_token.0;
        // Borrow rather than move: the readiness belongs to the session.
        let readiness = &$self.frontend_readiness;
        format!(
            "{gray}{ctx}{reset}\t{open}RUSTLS{reset}\t{grey}Session{reset}({gray}sni{reset}={white}{sni:?}{reset}, {gray}alpn{reset}={white}{alpn}{reset}, {gray}version{reset}={white}{version:?}{reset}, {gray}source{reset}={white}{source:?}{reset}, {gray}frontend{reset}={white}{frontend}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>",
            open = open,
            reset = reset,
            grey = grey,
            gray = gray,
            white = white,
            ctx = ctx,
            sni = sni,
            alpn = alpn,
            version = version,
            source = source,
            frontend = frontend,
            readiness = readiness,
        )
    }};
}
63
/// Named phases of a TLS connection, judging by the variant names.
///
/// NOTE(review): nothing in this file constructs or matches on these variants;
/// presumably they are used by callers elsewhere — confirm the exact meaning of
/// each state there. Variant order is kept as-is because it fixes the
/// discriminant values.
pub enum TlsState {
    Initial,
    Handshake,
    Established,
    Error,
}
70
/// Session state for a client connection whose TLS handshake is in progress.
///
/// Drives the rustls `ServerConnection` over the non-blocking client socket
/// until `readable()` or `writable()` returns `SessionResult::Upgrade`.
// Field order is kept as-is: it is also the drop order.
pub struct TlsHandshake {
    /// Frontend timeout; `timeout()` marks it triggered and `cancel_timeouts()` cancels it.
    pub container_frontend_timeout: TimeoutContainer,
    /// Interest/event bits for the client socket.
    pub frontend_readiness: Readiness,
    /// mio token of the client socket; `update_readiness()` ignores events for other tokens.
    frontend_token: Token,
    /// Client address, when known; only used for the log prefix in this file.
    pub peer_address: Option<SocketAddr>,
    /// Identifier reported as the session id by `log_context()`.
    pub request_id: Ulid,
    /// rustls server-side connection that performs the handshake.
    pub session: ServerConnection,
    /// Client TCP socket TLS records are read from and written to.
    pub stream: TcpStream,
    /// Set on the first `readable()`/`writable()` call and taken when the
    /// handshake completes, to report the handshake duration metric.
    handshake_started_at: Option<Instant>,
}
85
86impl TlsHandshake {
87 pub fn new(
94 container_frontend_timeout: TimeoutContainer,
95 session: ServerConnection,
96 stream: TcpStream,
97 frontend_token: Token,
98 request_id: Ulid,
99 peer_address: Option<SocketAddr>,
100 ) -> TlsHandshake {
101 TlsHandshake {
102 container_frontend_timeout,
103 frontend_readiness: Readiness {
104 interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
105 event: Ready::EMPTY,
106 },
107 frontend_token,
108 peer_address,
109 request_id,
110 session,
111 stream,
112 handshake_started_at: None,
113 }
114 }
115
116 fn record_handshake_duration_ms(&mut self) -> Option<u128> {
122 self.handshake_started_at
123 .take()
124 .map(|t| t.elapsed().as_millis())
125 }
126
127 pub fn readable(&mut self) -> SessionResult {
128 self.handshake_started_at.get_or_insert_with(Instant::now);
133
134 let mut can_read = true;
135
136 loop {
137 let mut can_work = false;
138
139 if self.session.wants_read() && can_read {
140 can_work = true;
141
142 match self.session.read_tls(&mut self.stream) {
143 Ok(0) => {
144 error!("{} Connection closed during handshake", log_context!(self));
145 return SessionResult::Close;
146 }
147 Ok(_) => {}
148 Err(e) => match e.kind() {
149 ErrorKind::WouldBlock => {
150 self.frontend_readiness.event.remove(Ready::READABLE);
151 can_read = false
152 }
153 _ => {
154 error!(
155 "{} Could not perform handshake: {:?}",
156 log_context!(self),
157 e
158 );
159 return SessionResult::Close;
160 }
161 },
162 }
163
164 if let Err(e) = self.session.process_new_packets() {
165 self.log_handshake_error(&e);
166 return SessionResult::Close;
167 }
168 }
169
170 if !can_work {
171 break;
172 }
173 }
174
175 if !self.session.wants_read() {
176 self.frontend_readiness.interest.remove(Ready::READABLE);
177 }
178
179 if self.session.wants_write() {
180 self.frontend_readiness.interest.insert(Ready::WRITABLE);
181 }
182
183 if self.session.is_handshaking() {
184 SessionResult::Continue
185 } else {
186 if self.session.wants_write() {
188 SessionResult::Continue
189 } else {
190 self.frontend_readiness.interest.insert(Ready::READABLE);
191 self.frontend_readiness.event.insert(Ready::READABLE);
192 self.frontend_readiness.interest.insert(Ready::WRITABLE);
193 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
194 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
195 }
196 SessionResult::Upgrade
197 }
198 }
199 }
200
201 pub fn writable(&mut self) -> SessionResult {
202 self.handshake_started_at.get_or_insert_with(Instant::now);
204
205 let mut can_write = true;
206
207 loop {
208 let mut can_work = false;
209
210 if self.session.wants_write() && can_write {
211 can_work = true;
212
213 match self.session.write_tls(&mut self.stream) {
214 Ok(_) => {}
215 Err(e) => match e.kind() {
216 ErrorKind::WouldBlock => {
217 self.frontend_readiness.event.remove(Ready::WRITABLE);
218 can_write = false
219 }
220 _ => {
221 error!(
222 "{} Could not perform handshake: {:?}",
223 log_context!(self),
224 e
225 );
226 return SessionResult::Close;
227 }
228 },
229 }
230
231 if let Err(e) = self.session.process_new_packets() {
232 self.log_handshake_error(&e);
233 return SessionResult::Close;
234 }
235 }
236
237 if !can_work {
238 break;
239 }
240 }
241
242 if !self.session.wants_write() {
243 self.frontend_readiness.interest.remove(Ready::WRITABLE);
244 }
245
246 if self.session.wants_read() {
247 self.frontend_readiness.interest.insert(Ready::READABLE);
248 }
249
250 if self.session.is_handshaking() {
251 SessionResult::Continue
252 } else if self.session.wants_read() {
253 self.frontend_readiness.interest.insert(Ready::READABLE);
254 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
255 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
256 }
257 SessionResult::Upgrade
258 } else {
259 self.frontend_readiness.interest.insert(Ready::WRITABLE);
260 self.frontend_readiness.interest.insert(Ready::READABLE);
261 if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
262 time!(names::tls::HANDSHAKE_MS, elapsed_ms);
263 }
264 SessionResult::Upgrade
265 }
266 }
267
268 pub fn log_context(&self) -> LogContext<'_> {
269 LogContext {
270 session_id: self.request_id,
271 request_id: None,
272 cluster_id: None,
273 backend_id: None,
274 }
275 }
276
277 pub fn front_socket(&self) -> &TcpStream {
278 &self.stream
279 }
280
281 fn log_handshake_error(&self, err: &RustlsError) {
298 let reason = handshake_failure_reason(err);
299 match err {
300 RustlsError::AlertReceived(_) => debug!(
301 "{} Could not perform handshake: {:?}",
302 log_context!(self),
303 err
304 ),
305 RustlsError::PeerIncompatible(_)
306 | RustlsError::PeerMisbehaved(_)
307 | RustlsError::InvalidMessage(_)
308 | RustlsError::InappropriateMessage { .. }
309 | RustlsError::InappropriateHandshakeMessage { .. }
310 | RustlsError::PeerSentOversizedRecord
311 | RustlsError::NoApplicationProtocol
312 | RustlsError::InvalidCertificate(_)
313 | RustlsError::DecryptError
314 | RustlsError::NoCertificatesPresented => warn!(
315 "{} Could not perform handshake: {:?}",
316 log_context!(self),
317 err
318 ),
319 _ => error!(
320 "{} Could not perform handshake: {:?}",
321 log_context!(self),
322 err
323 ),
324 }
325 count!(reason, 1);
326 }
327}
328
329fn handshake_failure_reason(err: &RustlsError) -> &'static str {
335 match err {
336 RustlsError::AlertReceived(_) => "tls.handshake.failed.alert_received",
337 RustlsError::PeerIncompatible(_) => "tls.handshake.failed.peer_incompatible",
338 RustlsError::PeerMisbehaved(_) => "tls.handshake.failed.peer_misbehaved",
339 RustlsError::InvalidMessage(_) => "tls.handshake.failed.invalid_message",
340 RustlsError::InappropriateMessage { .. } => "tls.handshake.failed.inappropriate_message",
341 RustlsError::InappropriateHandshakeMessage { .. } => {
342 "tls.handshake.failed.inappropriate_handshake_message"
343 }
344 RustlsError::PeerSentOversizedRecord => "tls.handshake.failed.oversized_record",
345 RustlsError::NoApplicationProtocol => "tls.handshake.failed.no_alpn",
346 RustlsError::InvalidCertificate(_) => "tls.handshake.failed.invalid_certificate",
347 RustlsError::DecryptError => "tls.handshake.failed.decrypt_error",
348 RustlsError::NoCertificatesPresented => "tls.handshake.failed.no_certificates_present",
349 _ => "tls.handshake.failed.other",
350 }
351}
352
353impl SessionState for TlsHandshake {
354 fn ready(
355 &mut self,
356 _session: Rc<RefCell<dyn crate::ProxySession>>,
357 _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
358 _metrics: &mut SessionMetrics,
359 ) -> SessionResult {
360 let mut counter = 0;
361
362 if self.frontend_readiness.event.is_hup() {
363 return SessionResult::Close;
364 }
365
366 while counter < MAX_LOOP_ITERATIONS {
367 let frontend_interest = self.frontend_readiness.filter_interest();
368
369 trace!("{} Interest({:?})", log_context!(self), frontend_interest);
370 if frontend_interest.is_empty() {
371 break;
372 }
373
374 if frontend_interest.is_readable() {
375 let protocol_result = self.readable();
376 if protocol_result != SessionResult::Continue {
377 return protocol_result;
378 }
379 }
380
381 if frontend_interest.is_writable() {
382 let protocol_result = self.writable();
383 if protocol_result != SessionResult::Continue {
384 return protocol_result;
385 }
386 }
387
388 if frontend_interest.is_error() {
389 error!("{} Front socket error, disconnecting", log_context!(self));
390 self.frontend_readiness.interest = Ready::EMPTY;
391 return SessionResult::Close;
392 }
393
394 counter += 1;
395 }
396
397 if counter >= MAX_LOOP_ITERATIONS {
398 error!(
399 "{}\tHandling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
400 log_context!(self),
401 MAX_LOOP_ITERATIONS
402 );
403
404 incr!(names::http::INFINITE_LOOP_ERROR);
405 self.print_state("HTTPS");
406
407 return SessionResult::Close;
408 }
409
410 SessionResult::Continue
411 }
412
413 fn update_readiness(&mut self, token: Token, events: Ready) {
414 if self.frontend_token == token {
415 self.frontend_readiness.event |= events;
416 }
417 }
418
419 fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
420 if self.frontend_token == token {
422 self.container_frontend_timeout.triggered();
423 return StateResult::CloseSession;
424 }
425
426 error!(
427 "{}, Expect state: got timeout for an invalid token: {:?}",
428 log_context!(self),
429 token
430 );
431 StateResult::CloseSession
432 }
433
434 fn cancel_timeouts(&mut self) {
435 self.container_frontend_timeout.cancel();
436 }
437
438 fn print_state(&self, context: &str) {
439 error!(
440 "{} Session(Handshake)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
441 context, self.frontend_token, self.frontend_readiness
442 );
443 }
444}
445
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rustls::{
        AlertDescription, CertificateError, ContentType, Error as RustlsError, HandshakeType,
        InvalidMessage, PeerIncompatible, PeerMisbehaved,
    };

    use super::handshake_failure_reason;

    /// Distinct counter keys `handshake_failure_reason` can return: eleven
    /// dedicated keys plus the `other` fallback.
    const DISTINCT_KEYS: usize = 12;

    /// Pins the counter key of every dedicated variant, checks that a few
    /// unlisted variants fall back to `other`, and verifies that every key
    /// shares the `tls.handshake.failed.` namespace.
    #[test]
    fn handshake_failure_reason_maps_every_variant_to_unique_namespaced_key() {
        let cases = vec![
            (
                RustlsError::AlertReceived(AlertDescription::HandshakeFailure),
                "tls.handshake.failed.alert_received",
            ),
            (
                RustlsError::PeerIncompatible(PeerIncompatible::NoCipherSuitesInCommon),
                "tls.handshake.failed.peer_incompatible",
            ),
            (
                RustlsError::PeerMisbehaved(PeerMisbehaved::IllegalMiddleboxChangeCipherSpec),
                "tls.handshake.failed.peer_misbehaved",
            ),
            (
                RustlsError::InvalidMessage(InvalidMessage::InvalidContentType),
                "tls.handshake.failed.invalid_message",
            ),
            (
                RustlsError::InappropriateMessage {
                    expect_types: vec![ContentType::Handshake],
                    got_type: ContentType::ApplicationData,
                },
                "tls.handshake.failed.inappropriate_message",
            ),
            (
                RustlsError::InappropriateHandshakeMessage {
                    expect_types: vec![HandshakeType::ClientHello],
                    got_type: HandshakeType::Finished,
                },
                "tls.handshake.failed.inappropriate_handshake_message",
            ),
            (
                RustlsError::PeerSentOversizedRecord,
                "tls.handshake.failed.oversized_record",
            ),
            (
                RustlsError::NoApplicationProtocol,
                "tls.handshake.failed.no_alpn",
            ),
            (
                RustlsError::InvalidCertificate(CertificateError::Expired),
                "tls.handshake.failed.invalid_certificate",
            ),
            (
                RustlsError::DecryptError,
                "tls.handshake.failed.decrypt_error",
            ),
            (
                RustlsError::NoCertificatesPresented,
                "tls.handshake.failed.no_certificates_present",
            ),
            // Variants without a dedicated key must fall back to `other`.
            (
                RustlsError::General("test".to_owned()),
                "tls.handshake.failed.other",
            ),
            (RustlsError::EncryptError, "tls.handshake.failed.other"),
            (
                RustlsError::FailedToGetCurrentTime,
                "tls.handshake.failed.other",
            ),
            (
                RustlsError::HandshakeNotComplete,
                "tls.handshake.failed.other",
            ),
        ];

        let mut seen = HashSet::new();
        for (err, expected) in &cases {
            let got = handshake_failure_reason(err);
            assert_eq!(got, *expected, "variant {err:?} → {got}, want {expected}");
            assert!(
                got.starts_with("tls.handshake.failed."),
                "reason {got} missing tls.handshake.failed. namespace"
            );
            seen.insert(got);
        }

        assert_eq!(seen.len(), DISTINCT_KEYS, "unexpected key set: {seen:?}");
    }
}