Skip to main content

sozu_lib/protocol/
rustls.rs

1//! Rustls handshake driver.
2//!
3//! Owns the per-session `rustls::ServerConnection` during the TLS
4//! handshake: pumps `read_tls`/`write_tls`, surfaces handshake completion
5//! to the parent state, and emits handshake-completion metrics. Cipher /
6//! ALPN / SNI binding decisions live in `lib/src/https.rs`; certificate
7//! resolution and dynamic cert reload live in `lib/src/tls.rs`.
8
9use std::{cell::RefCell, io::ErrorKind, net::SocketAddr, rc::Rc, time::Instant};
10
11use mio::{Token, net::TcpStream};
12use rustls::{Error as RustlsError, ServerConnection};
13use rusty_ulid::Ulid;
14use sozu_command::{
15    config::MAX_LOOP_ITERATIONS,
16    logging::{LogContext, ansi_palette},
17};
18
19use crate::metrics::names;
20use crate::{
21    Readiness, Ready, SessionMetrics, SessionResult, StateResult, protocol::SessionState,
22    timer::TimeoutContainer,
23};
24
/// Builds the log prefix used by every message emitted from this module, so
/// TLS handshake issues can be tracked inside Sōzu.
///
/// Colors come from `ansi_palette()`: when the logger emits to a TTY the
/// protocol label is bold bright-white (uniform across every protocol), the
/// `Session` keyword is light grey, attribute keys are gray and values are
/// bright white. ANSI codes are skipped when output goes to a file or
/// otherwise non-colored sink. The `[ulid - - -]` context prefix comes first
/// to keep column alignment with `MUX-*` and `SOCKET` logs.
///
/// `$self` must expose `log_context()`, `session`, `peer_address`,
/// `frontend_token` and `frontend_readiness`.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        // Shared placeholder for attributes that are not known (yet).
        let missing = || "<none>".to_string();

        let ctx = $self.log_context();
        let sni = $self
            .session
            .server_name()
            .map_or_else(missing, |name| name.to_string());
        let alpn = $self
            .session
            .alpn_protocol()
            .map_or_else(missing, |proto| String::from_utf8_lossy(proto).into_owned());
        let version = $self.session.protocol_version();
        let source = $self
            .peer_address
            .map_or_else(missing, |peer| peer.to_string());
        let frontend = $self.frontend_token.0;
        let readiness = &$self.frontend_readiness;

        format!(
            "{gray}{ctx}{reset}\t{open}RUSTLS{reset}\t{grey}Session{reset}({gray}sni{reset}={white}{sni:?}{reset}, {gray}alpn{reset}={white}{alpn}{reset}, {gray}version{reset}={white}{version:?}{reset}, {gray}source{reset}={white}{source:?}{reset}, {gray}frontend{reset}={white}{frontend}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>"
        )
    }};
}
63
/// Coarse TLS connection state.
///
/// NOTE(review): this enum is not referenced anywhere in this module; the
/// variant meanings below are inferred from their names only — confirm
/// against the call sites before relying on them.
///
/// The common traits are derived so the state can be logged, compared and
/// passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsState {
    /// Presumably: no handshake traffic has been processed yet.
    Initial,
    /// Presumably: the handshake is in progress.
    Handshake,
    /// Presumably: the handshake completed.
    Established,
    /// Presumably: the handshake or connection failed.
    Error,
}
70
/// Session state that drives a rustls server-side handshake over the
/// frontend TCP stream until it can be upgraded to the next state.
pub struct TlsHandshake {
    /// Frontend timeout handle: marked as triggered when a timeout fires for
    /// the frontend token, and cancelled by `cancel_timeouts()`.
    pub container_frontend_timeout: TimeoutContainer,
    /// Interest/event readiness pair for the frontend socket; updated while
    /// pumping the handshake and when new events are reported.
    pub frontend_readiness: Readiness,
    /// Token identifying the frontend socket; events and timeouts carrying a
    /// different token are not applied to this state.
    frontend_token: Token,
    /// Peer address, used for the `source` attribute of the log prefix.
    pub peer_address: Option<SocketAddr>,
    /// Identifier reported as the `session_id` of the log context.
    pub request_id: Ulid,
    /// The rustls server connection whose handshake is being driven.
    pub session: ServerConnection,
    /// Frontend TCP stream the TLS records are read from and written to.
    pub stream: TcpStream,
    /// Wall-clock anchor for the `tls.handshake_ms` histogram. Captured the
    /// first time the handshake state actually does I/O (not at construction,
    /// because the session may sit in the accept queue or in expect-proxy for
    /// an unbounded amount of time before the TLS bytes start flowing).
    handshake_started_at: Option<Instant>,
}
85
86impl TlsHandshake {
87    /// Instantiate a new TlsHandshake SessionState with:
88    ///
89    /// - frontend_interest: READABLE | HUP | ERROR
90    /// - frontend_event: EMPTY
91    ///
92    /// Remember to set the events from the previous State!
93    pub fn new(
94        container_frontend_timeout: TimeoutContainer,
95        session: ServerConnection,
96        stream: TcpStream,
97        frontend_token: Token,
98        request_id: Ulid,
99        peer_address: Option<SocketAddr>,
100    ) -> TlsHandshake {
101        TlsHandshake {
102            container_frontend_timeout,
103            frontend_readiness: Readiness {
104                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
105                event: Ready::EMPTY,
106            },
107            frontend_token,
108            peer_address,
109            request_id,
110            session,
111            stream,
112            handshake_started_at: None,
113        }
114    }
115
116    /// Returns the elapsed handshake duration in milliseconds and clears the
117    /// captured start instant so the histogram is only recorded once. Returns
118    /// `None` when no I/O happened (e.g. the connection closed mid-handshake
119    /// before any bytes were exchanged); callers should not emit
120    /// `tls.handshake_ms` in that case.
121    fn record_handshake_duration_ms(&mut self) -> Option<u128> {
122        let was_anchored = self.handshake_started_at.is_some();
123        let elapsed = self
124            .handshake_started_at
125            .take()
126            .map(|t| t.elapsed().as_millis());
127        // `take()` is idempotent-disarming: the anchor is always cleared so the
128        // histogram is recorded at most once, and a duration is returned iff an
129        // anchor existed.
130        debug_assert!(
131            self.handshake_started_at.is_none(),
132            "handshake anchor must be cleared after recording the duration"
133        );
134        debug_assert_eq!(
135            elapsed.is_some(),
136            was_anchored,
137            "a duration is returned iff the handshake had been anchored"
138        );
139        elapsed
140    }
141
142    pub fn readable(&mut self) -> SessionResult {
143        // Anchor the handshake duration the first time we observe TLS bytes
144        // moving in either direction. Using `get_or_insert_with` keeps the
145        // anchor sticky across `WouldBlock` retries and across the
146        // readable/writable boundary.
147        self.handshake_started_at.get_or_insert_with(Instant::now);
148        // The anchor is sticky once set: this method must never run unanchored.
149        debug_assert!(
150            self.handshake_started_at.is_some(),
151            "handshake anchor must be set before driving TLS I/O"
152        );
153
154        // rustls handshake completion is monotonic (`true → false`, never
155        // back). Snapshot it so the exit assertions can prove we never resurrect
156        // a finished handshake.
157        let was_handshaking = self.session.is_handshaking();
158
159        let mut can_read = true;
160
161        loop {
162            let mut can_work = false;
163
164            if self.session.wants_read() && can_read {
165                can_work = true;
166
167                match self.session.read_tls(&mut self.stream) {
168                    Ok(0) => {
169                        error!("{} Connection closed during handshake", log_context!(self));
170                        return SessionResult::Close;
171                    }
172                    Ok(_) => {}
173                    Err(e) => match e.kind() {
174                        ErrorKind::WouldBlock => {
175                            self.frontend_readiness.event.remove(Ready::READABLE);
176                            can_read = false
177                        }
178                        _ => {
179                            error!(
180                                "{} Could not perform handshake: {:?}",
181                                log_context!(self),
182                                e
183                            );
184                            return SessionResult::Close;
185                        }
186                    },
187                }
188
189                if let Err(e) = self.session.process_new_packets() {
190                    self.log_handshake_error(&e);
191                    return SessionResult::Close;
192                }
193            }
194
195            if !can_work {
196                break;
197            }
198        }
199
200        // Handshake completion is monotonic: a handshake that had already
201        // finished at entry cannot become unfinished by pumping `read_tls`.
202        debug_assert!(
203            was_handshaking || !self.session.is_handshaking(),
204            "rustls handshake must not regress from finished back to handshaking"
205        );
206
207        // Readiness must mirror rustls's own wants: we only drop READABLE
208        // interest when the session no longer wants to read.
209        if !self.session.wants_read() {
210            self.frontend_readiness.interest.remove(Ready::READABLE);
211        }
212        debug_assert!(
213            self.session.wants_read() || !self.frontend_readiness.interest.is_readable(),
214            "READABLE interest must be cleared once rustls stops wanting reads"
215        );
216
217        if self.session.wants_write() {
218            self.frontend_readiness.interest.insert(Ready::WRITABLE);
219        }
220
221        if self.session.is_handshaking() {
222            SessionResult::Continue
223        } else {
224            // handshake might be finished, but we still have something to send
225            if self.session.wants_write() {
226                SessionResult::Continue
227            } else {
228                // Upgrade is only signalled once the handshake is complete and
229                // there is nothing left to flush to the peer.
230                debug_assert!(
231                    !self.session.is_handshaking() && !self.session.wants_write(),
232                    "Upgrade requires a completed handshake with no pending output"
233                );
234                self.frontend_readiness.interest.insert(Ready::READABLE);
235                self.frontend_readiness.event.insert(Ready::READABLE);
236                self.frontend_readiness.interest.insert(Ready::WRITABLE);
237                if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
238                    time!(names::tls::HANDSHAKE_MS, elapsed_ms);
239                }
240                SessionResult::Upgrade
241            }
242        }
243    }
244
245    pub fn writable(&mut self) -> SessionResult {
246        // Same anchor logic as `readable()` — see the comment there.
247        self.handshake_started_at.get_or_insert_with(Instant::now);
248        debug_assert!(
249            self.handshake_started_at.is_some(),
250            "handshake anchor must be set before driving TLS I/O"
251        );
252
253        // Snapshot handshake completion for the monotonicity post-condition.
254        let was_handshaking = self.session.is_handshaking();
255
256        let mut can_write = true;
257
258        loop {
259            let mut can_work = false;
260
261            if self.session.wants_write() && can_write {
262                can_work = true;
263
264                match self.session.write_tls(&mut self.stream) {
265                    Ok(_) => {}
266                    Err(e) => match e.kind() {
267                        ErrorKind::WouldBlock => {
268                            self.frontend_readiness.event.remove(Ready::WRITABLE);
269                            can_write = false
270                        }
271                        _ => {
272                            error!(
273                                "{} Could not perform handshake: {:?}",
274                                log_context!(self),
275                                e
276                            );
277                            return SessionResult::Close;
278                        }
279                    },
280                }
281
282                if let Err(e) = self.session.process_new_packets() {
283                    self.log_handshake_error(&e);
284                    return SessionResult::Close;
285                }
286            }
287
288            if !can_work {
289                break;
290            }
291        }
292
293        // Handshake completion is monotonic: pumping `write_tls` can finish a
294        // handshake but never un-finish one.
295        debug_assert!(
296            was_handshaking || !self.session.is_handshaking(),
297            "rustls handshake must not regress from finished back to handshaking"
298        );
299
300        // Readiness mirrors rustls's wants: WRITABLE interest is only dropped
301        // once the session no longer wants to write.
302        if !self.session.wants_write() {
303            self.frontend_readiness.interest.remove(Ready::WRITABLE);
304        }
305        debug_assert!(
306            self.session.wants_write() || !self.frontend_readiness.interest.is_writable(),
307            "WRITABLE interest must be cleared once rustls stops wanting writes"
308        );
309
310        if self.session.wants_read() {
311            self.frontend_readiness.interest.insert(Ready::READABLE);
312        }
313
314        if self.session.is_handshaking() {
315            SessionResult::Continue
316        } else if self.session.wants_read() {
317            // Upgrade after a completed handshake; the session still wants to
318            // read application data, which the upgraded state will drive.
319            debug_assert!(
320                !self.session.is_handshaking(),
321                "Upgrade requires a completed handshake"
322            );
323            self.frontend_readiness.interest.insert(Ready::READABLE);
324            if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
325                time!(names::tls::HANDSHAKE_MS, elapsed_ms);
326            }
327            SessionResult::Upgrade
328        } else {
329            debug_assert!(
330                !self.session.is_handshaking(),
331                "Upgrade requires a completed handshake"
332            );
333            self.frontend_readiness.interest.insert(Ready::WRITABLE);
334            self.frontend_readiness.interest.insert(Ready::READABLE);
335            if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
336                time!(names::tls::HANDSHAKE_MS, elapsed_ms);
337            }
338            SessionResult::Upgrade
339        }
340    }
341
342    pub fn log_context(&self) -> LogContext<'_> {
343        LogContext {
344            session_id: self.request_id,
345            request_id: None,
346            cluster_id: None,
347            backend_id: None,
348        }
349    }
350
    /// Returns a shared reference to the frontend TCP stream the handshake
    /// is performed on.
    pub fn front_socket(&self) -> &TcpStream {
        &self.stream
    }
354
355    /// Tiered logging for TLS handshake errors surfaced by `process_new_packets`.
356    ///
357    /// - `AlertReceived(_)`: remote peer rejected our cert/config (e.g. old
358    ///   CA bundle, scanner, cert-pinning client). Not actionable per-connection
359    ///   on a public endpoint, so log at `debug!`.
360    /// - Peer protocol violations (`PeerIncompatible`, `PeerMisbehaved`,
361    ///   `InvalidMessage`, inappropriate message / handshake message,
362    ///   oversized record, ALPN mismatch, bad client cert, `DecryptError`,
363    ///   `NoCertificatesPresented`): occasionally useful to spot buggy
364    ///   clients or stale roots, so log at `warn!`.
365    /// - Everything else (local/config/provider failures like `EncryptError`,
366    ///   `General`, `Other`, CRL issues, missing entropy): genuine server-side
367    ///   problems, stay at `error!`.
368    ///
369    /// Each tier additionally bumps `tls.handshake.failed.<reason>` so dashboards
370    /// can split spikes by category without having to grep logs.
371    fn log_handshake_error(&self, err: &RustlsError) {
372        let reason = handshake_failure_reason(err);
373        // Every reason must stay inside the bounded `tls.handshake.failed.*`
374        // namespace so statsd cardinality is predictable — unknown variants
375        // collapse to `.other`, never an unnamespaced key.
376        debug_assert!(
377            reason.starts_with("tls.handshake.failed."),
378            "handshake failure metric {reason} escaped the tls.handshake.failed. namespace"
379        );
380        match err {
381            RustlsError::AlertReceived(_) => debug!(
382                "{} Could not perform handshake: {:?}",
383                log_context!(self),
384                err
385            ),
386            RustlsError::PeerIncompatible(_)
387            | RustlsError::PeerMisbehaved(_)
388            | RustlsError::InvalidMessage(_)
389            | RustlsError::InappropriateMessage { .. }
390            | RustlsError::InappropriateHandshakeMessage { .. }
391            | RustlsError::PeerSentOversizedRecord
392            | RustlsError::NoApplicationProtocol
393            | RustlsError::InvalidCertificate(_)
394            | RustlsError::DecryptError
395            | RustlsError::NoCertificatesPresented => warn!(
396                "{} Could not perform handshake: {:?}",
397                log_context!(self),
398                err
399            ),
400            _ => error!(
401                "{} Could not perform handshake: {:?}",
402                log_context!(self),
403                err
404            ),
405        }
406        count!(reason, 1);
407    }
408}
409
410/// Compile-time literal `tls.handshake.failed.<reason>` keys for every variant
411/// the proxy can observe. Free function (rather than a method) so unit tests
412/// can drive it without constructing a real `ServerConnection`. The set of
413/// suffixes is bounded — anything outside the explicit `match` arms collapses
414/// to `tls.handshake.failed.other` so statsd cardinality stays predictable.
415fn handshake_failure_reason(err: &RustlsError) -> &'static str {
416    match err {
417        RustlsError::AlertReceived(_) => "tls.handshake.failed.alert_received",
418        RustlsError::PeerIncompatible(_) => "tls.handshake.failed.peer_incompatible",
419        RustlsError::PeerMisbehaved(_) => "tls.handshake.failed.peer_misbehaved",
420        RustlsError::InvalidMessage(_) => "tls.handshake.failed.invalid_message",
421        RustlsError::InappropriateMessage { .. } => "tls.handshake.failed.inappropriate_message",
422        RustlsError::InappropriateHandshakeMessage { .. } => {
423            "tls.handshake.failed.inappropriate_handshake_message"
424        }
425        RustlsError::PeerSentOversizedRecord => "tls.handshake.failed.oversized_record",
426        RustlsError::NoApplicationProtocol => "tls.handshake.failed.no_alpn",
427        RustlsError::InvalidCertificate(_) => "tls.handshake.failed.invalid_certificate",
428        RustlsError::DecryptError => "tls.handshake.failed.decrypt_error",
429        RustlsError::NoCertificatesPresented => "tls.handshake.failed.no_certificates_present",
430        _ => "tls.handshake.failed.other",
431    }
432}
433
434impl SessionState for TlsHandshake {
435    fn ready(
436        &mut self,
437        _session: Rc<RefCell<dyn crate::ProxySession>>,
438        _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
439        _metrics: &mut SessionMetrics,
440    ) -> SessionResult {
441        let mut counter = 0;
442
443        if self.frontend_readiness.event.is_hup() {
444            return SessionResult::Close;
445        }
446
447        while counter < MAX_LOOP_ITERATIONS {
448            let frontend_interest = self.frontend_readiness.filter_interest();
449
450            trace!("{} Interest({:?})", log_context!(self), frontend_interest);
451            if frontend_interest.is_empty() {
452                break;
453            }
454
455            if frontend_interest.is_readable() {
456                let protocol_result = self.readable();
457                if protocol_result != SessionResult::Continue {
458                    return protocol_result;
459                }
460            }
461
462            if frontend_interest.is_writable() {
463                let protocol_result = self.writable();
464                if protocol_result != SessionResult::Continue {
465                    return protocol_result;
466                }
467            }
468
469            if frontend_interest.is_error() {
470                error!("{} Front socket error, disconnecting", log_context!(self));
471                self.frontend_readiness.interest = Ready::EMPTY;
472                return SessionResult::Close;
473            }
474
475            counter += 1;
476        }
477
478        if counter >= MAX_LOOP_ITERATIONS {
479            error!(
480                "{}\tHandling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
481                log_context!(self),
482                MAX_LOOP_ITERATIONS
483            );
484
485            incr!(names::http::INFINITE_LOOP_ERROR);
486            self.print_state("HTTPS");
487
488            return SessionResult::Close;
489        }
490
491        SessionResult::Continue
492    }
493
494    fn update_readiness(&mut self, token: Token, events: Ready) {
495        if self.frontend_token == token {
496            self.frontend_readiness.event |= events;
497        }
498    }
499
500    fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
501        // relevant timeout is still stored in the Session as front_timeout.
502        if self.frontend_token == token {
503            self.container_frontend_timeout.triggered();
504            return StateResult::CloseSession;
505        }
506
507        error!(
508            "{}, Expect state: got timeout for an invalid token: {:?}",
509            log_context!(self),
510            token
511        );
512        StateResult::CloseSession
513    }
514
    /// Cancels the frontend timeout, the only timeout held by this state.
    fn cancel_timeouts(&mut self) {
        self.container_frontend_timeout.cancel();
    }
518
519    fn print_state(&self, context: &str) {
520        error!(
521            "{} Session(Handshake)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
522            context, self.frontend_token, self.frontend_readiness
523        );
524    }
525}
526
527// -----------------------------------------------------------------------------
528// Unit tests
529
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rustls::{
        AlertDescription, CertificateError, ContentType, Error as RustlsError, HandshakeType,
        InvalidMessage, PeerIncompatible, PeerMisbehaved,
    };

    use super::handshake_failure_reason;

    /// Checks the full error → metric key table: each explicitly handled
    /// rustls variant gets its own compile-time literal
    /// `tls.handshake.failed.<reason>` key, while variants without an arm
    /// (`General`, `EncryptError`, future rustls additions, ...) share
    /// `tls.handshake.failed.other` so statsd cardinality stays bounded.
    /// Counting the distinct keys also guards against accidental duplicates.
    #[test]
    fn handshake_failure_reason_maps_every_variant_to_unique_namespaced_key() {
        let cases: Vec<(RustlsError, &str)> = vec![
            // Variants with a dedicated bucket.
            (
                RustlsError::AlertReceived(AlertDescription::HandshakeFailure),
                "tls.handshake.failed.alert_received",
            ),
            (
                RustlsError::PeerIncompatible(PeerIncompatible::NoCipherSuitesInCommon),
                "tls.handshake.failed.peer_incompatible",
            ),
            (
                RustlsError::PeerMisbehaved(PeerMisbehaved::IllegalMiddleboxChangeCipherSpec),
                "tls.handshake.failed.peer_misbehaved",
            ),
            (
                RustlsError::InvalidMessage(InvalidMessage::InvalidContentType),
                "tls.handshake.failed.invalid_message",
            ),
            (
                RustlsError::InappropriateMessage {
                    expect_types: vec![ContentType::Handshake],
                    got_type: ContentType::ApplicationData,
                },
                "tls.handshake.failed.inappropriate_message",
            ),
            (
                RustlsError::InappropriateHandshakeMessage {
                    expect_types: vec![HandshakeType::ClientHello],
                    got_type: HandshakeType::Finished,
                },
                "tls.handshake.failed.inappropriate_handshake_message",
            ),
            (
                RustlsError::PeerSentOversizedRecord,
                "tls.handshake.failed.oversized_record",
            ),
            (
                RustlsError::NoApplicationProtocol,
                "tls.handshake.failed.no_alpn",
            ),
            (
                RustlsError::InvalidCertificate(CertificateError::Expired),
                "tls.handshake.failed.invalid_certificate",
            ),
            (
                RustlsError::DecryptError,
                "tls.handshake.failed.decrypt_error",
            ),
            (
                RustlsError::NoCertificatesPresented,
                "tls.handshake.failed.no_certificates_present",
            ),
            // Variants without an explicit arm all land in the `other` bucket.
            (
                RustlsError::General("test".to_owned()),
                "tls.handshake.failed.other",
            ),
            (RustlsError::EncryptError, "tls.handshake.failed.other"),
            (
                RustlsError::FailedToGetCurrentTime,
                "tls.handshake.failed.other",
            ),
            (
                RustlsError::HandshakeNotComplete,
                "tls.handshake.failed.other",
            ),
        ];

        // Check each mapping while gathering the set of distinct keys.
        let seen: HashSet<&str> = cases
            .iter()
            .map(|(err, expected)| {
                let got = handshake_failure_reason(err);
                assert_eq!(got, *expected, "variant {err:?} → {got}, want {expected}");
                assert!(
                    got.starts_with("tls.handshake.failed."),
                    "reason {got} missing tls.handshake.failed. namespace"
                );
                got
            })
            .collect();

        // 11 explicit buckets + 1 shared `other` bucket = 12 distinct keys.
        assert_eq!(seen.len(), 12, "unexpected key set: {seen:?}");
    }
}
625
626        // 11 explicit buckets + 1 shared `other` bucket = 12 distinct keys.
627        assert_eq!(seen.len(), 12, "unexpected key set: {seen:?}");
628    }
629}