Skip to main content

sozu_lib/protocol/mux/
connection.rs

1//! Protocol-agnostic frontend/backend connection wrapper.
2//!
3//! [`Connection`] is the H1/H2 dispatch enum used everywhere in the mux
4//! layer. Most of the methods are trivial pass-through forwarders to the
5//! underlying [`ConnectionH1`] or [`ConnectionH2`] implementation — the
6//! local `forward!` macro removes the boilerplate.
7//!
8//! The two `Endpoint` adaptors ([`EndpointServer`], [`EndpointClient`]) are
9//! also defined here: they let a connection call back into either the
10//! frontend connection or the backend [`Router`] map without knowing which
11//! direction it faces.
12//!
13//! Edge-trigger discipline lives in `mux/h2.rs` (`writable`) — the canonical
14//! home for the `signal_pending_write` / `arm_writable` invariant. This
15//! module's abstractions delegate to that discipline through the
16//! protocol-specific writers.
17
18use std::{
19    cell::RefCell,
20    fmt::Debug,
21    rc::{Rc, Weak},
22    time::Instant,
23};
24
25use mio::{Token, net::TcpStream};
26use rusty_ulid::Ulid;
27use sozu_command::{logging::ansi_palette, ready::Ready};
28
29use super::{
30    BackendStatus, ConnectionH1, ConnectionH2, Context, Endpoint, GlobalStreamId, MuxResult,
31    Position, Router,
32    h2::{self, H2StreamId},
33};
34use crate::metrics::names;
35use crate::{
36    L7ListenerHandler, ListenerHandler, Readiness, backends::Backend, pool::Pool,
37    socket::SocketHandler, timer::TimeoutContainer,
38};
39
/// Builds the prefix prepended to every log line emitted from this module.
///
/// The result is the `MUX-CONN` label wrapped in the first two codes returned
/// by `ansi_palette()` (an "open" and a "reset" code), followed by a tab and
/// `>>>`. Presumably those codes are empty when the logger is not colored —
/// confirm against `ansi_palette`.
///
/// The label is the same for every protocol and carries no session context.
/// Most log sites in this module are `Endpoint` adaptor methods that only see
/// the frontend connection or the backend map, never the wrapping
/// [`Connection`], so no richer context is available there.
macro_rules! log_module_context {
    () => {{
        let (label_open, label_reset, ..) = ansi_palette();
        format!("{label_open}MUX-CONN{label_reset}\t >>>")
    }};
}
52
53#[derive(Debug)]
54#[allow(clippy::large_enum_variant)]
55pub enum Connection<Front: SocketHandler> {
56    H1(ConnectionH1<Front>),
57    H2(ConnectionH2<Front>),
58}
59
// Sends a method call or a field borrow to whichever protocol-specific
// connection a `Connection` wraps. Trivial pass-through methods can then skip
// spelling out the two-arm `match` every time.
//
// Forms:
// - `forward!(conn, method(args..))` calls `method` on the inner value.
// - `forward!(&conn, field)` borrows the inner `field` immutably.
// - `forward!(&mut conn, field)` borrows the inner `field` mutably.
//
// Keep the arm order as is. `forward!(&mut x, f)` is matched only because
// the `&` arm rejects `mut x` (it cannot start an expression) and falls
// through to the `&mut` arm.
macro_rules! forward {
    ($conn:expr, $method:ident ( $($arg:tt)* )) => {
        match $conn {
            Connection::H1(inner) => inner.$method($($arg)*),
            Connection::H2(inner) => inner.$method($($arg)*),
        }
    };
    (&$conn:expr, $field:ident) => {
        match $conn {
            Connection::H1(inner) => &inner.$field,
            Connection::H2(inner) => &inner.$field,
        }
    };
    (&mut $conn:expr, $field:ident) => {
        match $conn {
            Connection::H1(inner) => &mut inner.$field,
            Connection::H2(inner) => &mut inner.$field,
        }
    };
}
83
84impl<Front: SocketHandler> Connection<Front> {
85    pub fn new_h1_server(
86        session_ulid: Ulid,
87        front_stream: Front,
88        timeout_container: TimeoutContainer,
89    ) -> Connection<Front> {
90        Connection::H1(ConnectionH1 {
91            socket: front_stream,
92            position: Position::Server,
93            readiness: Readiness {
94                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
95                event: Ready::EMPTY,
96            },
97            requests: 0,
98            stream: Some(0),
99            timeout_container,
100            parked_on_buffer_pressure: false,
101            close_notify_sent: false,
102            session_ulid,
103        })
104    }
105    pub fn new_h1_client(
106        session_ulid: Ulid,
107        front_stream: Front,
108        cluster_id: String,
109        backend: Rc<RefCell<Backend>>,
110        timeout_container: TimeoutContainer,
111    ) -> Connection<Front> {
112        Connection::H1(ConnectionH1 {
113            socket: front_stream,
114            position: Position::Client(
115                cluster_id,
116                backend,
117                BackendStatus::Connecting(Instant::now()),
118            ),
119            readiness: Readiness {
120                interest: Ready::WRITABLE | Ready::READABLE | Ready::HUP | Ready::ERROR,
121                event: Ready::EMPTY,
122            },
123            stream: None,
124            requests: 0,
125            timeout_container,
126            parked_on_buffer_pressure: false,
127            close_notify_sent: false,
128            session_ulid,
129        })
130    }
131
132    #[allow(clippy::too_many_arguments)]
133    pub fn new_h2_server(
134        session_ulid: Ulid,
135        front_stream: Front,
136        pool: Weak<RefCell<Pool>>,
137        timeout_container: TimeoutContainer,
138        flood_config: h2::H2FloodConfig,
139        connection_config: h2::H2ConnectionConfig,
140        stream_idle_timeout: std::time::Duration,
141        graceful_shutdown_deadline: Option<std::time::Duration>,
142    ) -> Option<Connection<Front>> {
143        Some(Connection::H2(ConnectionH2::new(
144            session_ulid,
145            front_stream,
146            Position::Server,
147            pool,
148            flood_config,
149            connection_config,
150            stream_idle_timeout,
151            graceful_shutdown_deadline,
152            timeout_container,
153            Some((H2StreamId::Zero, h2::CLIENT_PREFACE_SIZE)),
154            Ready::READABLE | Ready::HUP | Ready::ERROR,
155        )?))
156    }
157
158    #[allow(clippy::too_many_arguments)]
159    pub fn new_h2_client(
160        session_ulid: Ulid,
161        front_stream: Front,
162        cluster_id: String,
163        backend: Rc<RefCell<Backend>>,
164        pool: Weak<RefCell<Pool>>,
165        timeout_container: TimeoutContainer,
166        flood_config: h2::H2FloodConfig,
167        connection_config: h2::H2ConnectionConfig,
168        stream_idle_timeout: std::time::Duration,
169        graceful_shutdown_deadline: Option<std::time::Duration>,
170    ) -> Option<Connection<Front>> {
171        // Test-only injection point: when set via
172        // [`__test_force_h2_client_failure`], pretend the pool was exhausted
173        // and return `None`. This mirrors the buffer-pool-exhaustion branch
174        // inside [`ConnectionH2::new`] deterministically so E2E tests can
175        // exercise `Router::connect`'s rollback path (FIX-18) without having
176        // to starve the pool in-process.
177        #[cfg(any(test, feature = "e2e-hooks"))]
178        if test_hooks::FORCE_NEW_H2_CLIENT_FAILURE.swap(false, std::sync::atomic::Ordering::SeqCst)
179        {
180            return None;
181        }
182        Some(Connection::H2(ConnectionH2::new(
183            session_ulid,
184            front_stream,
185            Position::Client(
186                cluster_id,
187                backend,
188                BackendStatus::Connecting(Instant::now()),
189            ),
190            pool,
191            flood_config,
192            connection_config,
193            stream_idle_timeout,
194            graceful_shutdown_deadline,
195            timeout_container,
196            None,
197            Ready::WRITABLE | Ready::HUP | Ready::ERROR,
198        )?))
199    }
200
201    pub fn readiness(&self) -> &Readiness {
202        forward!(&self, readiness)
203    }
204    pub fn readiness_mut(&mut self) -> &mut Readiness {
205        forward!(&mut self, readiness)
206    }
207    pub fn position(&self) -> &Position {
208        forward!(&self, position)
209    }
210    pub fn position_mut(&mut self) -> &mut Position {
211        forward!(&mut self, position)
212    }
213    pub fn socket(&self) -> &TcpStream {
214        match self {
215            Connection::H1(c) => c.socket.socket_ref(),
216            Connection::H2(c) => c.socket.socket_ref(),
217        }
218    }
219    pub fn socket_mut(&mut self) -> &mut TcpStream {
220        match self {
221            Connection::H1(c) => c.socket.socket_mut(),
222            Connection::H2(c) => c.socket.socket_mut(),
223        }
224    }
225    pub fn timeout_container(&mut self) -> &mut TimeoutContainer {
226        forward!(&mut self, timeout_container)
227    }
228
229    /// Returns connection-level byte overhead (bin, bout) for H2, (0, 0) for H1.
230    pub fn overhead_bytes(&self) -> (usize, usize) {
231        match self {
232            Connection::H1(_) => (0, 0),
233            Connection::H2(c) => (c.bytes.overhead_bin, c.bytes.overhead_bout),
234        }
235    }
236
237    pub(super) fn readable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
238    where
239        E: Endpoint,
240        L: ListenerHandler + L7ListenerHandler,
241    {
242        forward!(self, readable(context, endpoint))
243    }
244    pub(super) fn writable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
245    where
246        E: Endpoint,
247        L: ListenerHandler + L7ListenerHandler,
248    {
249        forward!(self, writable(context, endpoint))
250    }
251
252    /// Returns true if this connection could not read because its stream's
253    /// kawa buffer was full. Used to prevent the dead-backend check from
254    /// closing a backend that still has data in the OS socket buffer.
255    pub(super) fn has_buffer_pressure<L>(&self, context: &Context<L>) -> bool
256    where
257        L: ListenerHandler + L7ListenerHandler,
258    {
259        match self {
260            Connection::H1(c) => {
261                let Some(stream_id) = c.stream else {
262                    // No stream assigned — no buffer pressure.
263                    return false;
264                };
265                let kawa = match c.position {
266                    Position::Client(..) => &context.streams[stream_id].back,
267                    Position::Server => &context.streams[stream_id].front,
268                };
269                kawa.storage.available_space() == 0
270            }
271            // H2 connections manage their own flow control via expect_read
272            Connection::H2(_) => false,
273        }
274    }
275
276    /// Re-enable READABLE if this connection is parked waiting for buffer space
277    /// and the target stream's buffer now has enough room.
278    ///
279    /// For H1: checks the `parked_on_buffer_pressure` flag set when `readable`
280    /// exits early because the kawa buffer was full. Edge-triggered epoll will
281    /// not re-fire READABLE for data already in the kernel socket buffer, so
282    /// this is the only path that re-arms it after the peer drains space.
283    ///
284    /// For H2: checks the `expect_read` field tracking which stream and how
285    /// many bytes are needed.
286    pub(super) fn try_resume_reading<L>(&mut self, context: &Context<L>) -> bool
287    where
288        L: ListenerHandler + L7ListenerHandler,
289    {
290        match self {
291            Connection::H1(c) => {
292                if !c.parked_on_buffer_pressure {
293                    return false;
294                }
295                let Some(stream_id) = c.stream else {
296                    return false;
297                };
298                let kawa = match c.position {
299                    Position::Client(..) => &context.streams[stream_id].back,
300                    Position::Server => &context.streams[stream_id].front,
301                };
302                if kawa.storage.available_space() > 0 {
303                    trace!(
304                        "{} H1 try_resume_reading: re-arming READABLE",
305                        log_module_context!()
306                    );
307                    // Pre: we only reach here because the connection parked on
308                    // buffer pressure AND the kawa now has room. Pair the
309                    // synthetic READABLE event with that drained-space fact:
310                    // edge-triggered epoll won't re-fire on its own, so this is
311                    // the sole re-arm path.
312                    debug_assert!(
313                        c.parked_on_buffer_pressure,
314                        "re-arm only fires for a connection parked on buffer pressure"
315                    );
316                    c.readiness.signal_pending_read();
317                    debug_assert!(
318                        c.readiness.event.is_readable(),
319                        "signal_pending_read must leave a READABLE event queued"
320                    );
321                    true
322                } else {
323                    false
324                }
325            }
326            Connection::H2(c) => c.try_resume_reading(context),
327        }
328    }
329
330    pub(super) fn graceful_goaway(&mut self) -> MuxResult {
331        match self {
332            Connection::H1(_) => MuxResult::Continue,
333            Connection::H2(c) => c.graceful_goaway(),
334        }
335    }
336
337    pub(super) fn is_draining(&self) -> bool {
338        match self {
339            Connection::H1(_) => false,
340            Connection::H2(c) => c.drain.draining,
341        }
342    }
343
344    /// Proxy-side graceful-shutdown budget exhaustion check. Only H2
345    /// connections carry the timer — H1 has no multiplex to drain, so its
346    /// answer is always `false` and the H1 path continues to fall through
347    /// to the ordinary single-response close. See
348    /// [`h2::ConnectionH2::graceful_shutdown_deadline_elapsed`].
349    pub(super) fn graceful_shutdown_deadline_elapsed(&self) -> bool {
350        match self {
351            Connection::H1(_) => false,
352            Connection::H2(c) => c.graceful_shutdown_deadline_elapsed(),
353        }
354    }
355
356    pub(super) fn has_pending_write(&self) -> bool {
357        forward!(self, has_pending_write())
358    }
359
360    /// Connection-level [`Self::has_pending_write`] extended with a per-stream
361    /// back-buffer probe (LIFECYCLE §9 invariant 16). Only H2 multiplexes
362    /// multiple streams — H1 falls back to [`Self::has_pending_write`] since
363    /// its single-response pipeline already accounts for pending bytes.
364    pub(super) fn has_pending_write_including_streams<L>(&self, context: &super::Context<L>) -> bool
365    where
366        L: ListenerHandler + L7ListenerHandler,
367    {
368        match self {
369            Connection::H1(c) => c.has_pending_write(),
370            Connection::H2(c) => c.has_pending_write_full(context),
371        }
372    }
373
374    pub(super) fn initiate_close_notify(&mut self) -> bool {
375        forward!(self, initiate_close_notify())
376    }
377
378    pub(super) fn flush_zero_buffer(&mut self) {
379        if let Connection::H2(c) = self {
380            c.flush_zero_buffer();
381        }
382    }
383
384    fn pre_close_client_bookkeeping(&self) {
385        if let Position::Client(cluster_id, backend, _) = self.position() {
386            let mut backend_borrow = backend.borrow_mut();
387            // Pair the `dec_connections` with its prior value: the close path
388            // releases exactly one slot. `dec_connections` floors at 0 (a
389            // double-close from a desynced peer must not panic), so the
390            // post-relation is "decreased by one, unless already at zero".
391            let before = backend_borrow.active_connections;
392            backend_borrow.dec_connections();
393            debug_assert_eq!(
394                backend_borrow.active_connections,
395                before.saturating_sub(1),
396                "close must release exactly one backend connection (saturating at 0)"
397            );
398            gauge_add!(names::backend::CONNECTIONS, -1);
399            // Pair with the `+1` at `router.rs::connect` (new-dial path).
400            // This is the graceful-close decrement, used both by the dead
401            // backend path in `mod.rs::back_readable` (which routes through
402            // `client.close()`) and by any explicit Connection::close.
403            gauge_add!(names::backend::POOL_SIZE, -1);
404            gauge_add!(
405                names::backend::CONNECTIONS_PER_BACKEND,
406                -1,
407                Some(cluster_id),
408                Some(&backend_borrow.backend_id)
409            );
410            trace!(
411                "{} connection close: {:#?}",
412                log_module_context!(),
413                backend_borrow
414            );
415        }
416    }
417
418    fn pre_end_stream_client_bookkeeping(&self) {
419        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
420            let mut backend_borrow = backend.borrow_mut();
421            // Pairs with the `+1` in `pre_start_stream_client_bookkeeping`.
422            // `saturating_sub` is the network-safe floor (a desync from the
423            // peer must not panic), so we can only assert the post-relation,
424            // not that `before > 0`.
425            let before = backend_borrow.active_requests;
426            backend_borrow.active_requests = backend_borrow.active_requests.saturating_sub(1);
427            debug_assert_eq!(
428                backend_borrow.active_requests,
429                before.saturating_sub(1),
430                "end_stream bookkeeping must decrement active_requests by one (saturating)"
431            );
432            debug_assert!(
433                backend_borrow.active_requests <= before,
434                "active_requests must not grow on stream end"
435            );
436            trace!(
437                "{} connection end stream: {:#?}",
438                log_module_context!(),
439                backend_borrow
440            );
441        }
442    }
443
444    fn pre_start_stream_client_bookkeeping(&self) {
445        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
446            let mut backend_borrow = backend.borrow_mut();
447            // Pairs with the `saturating_sub(1)` in the end path. Snapshot the
448            // counter so we can assert it advanced by exactly one.
449            let before = backend_borrow.active_requests;
450            backend_borrow.active_requests += 1;
451            debug_assert_eq!(
452                backend_borrow.active_requests,
453                before + 1,
454                "start_stream bookkeeping must increment active_requests by exactly one"
455            );
456            trace!(
457                "{} connection start stream: {:#?}",
458                log_module_context!(),
459                backend_borrow
460            );
461        }
462    }
463
464    pub(super) fn close<E, L>(&mut self, context: &mut Context<L>, endpoint: E)
465    where
466        E: Endpoint,
467        L: ListenerHandler + L7ListenerHandler,
468    {
469        self.pre_close_client_bookkeeping();
470        forward!(self, close(context, endpoint))
471    }
472
473    pub(super) fn end_stream<L>(&mut self, stream: GlobalStreamId, context: &mut Context<L>)
474    where
475        L: ListenerHandler + L7ListenerHandler,
476    {
477        self.pre_end_stream_client_bookkeeping();
478        forward!(self, end_stream(stream, context))
479    }
480
481    pub(super) fn start_stream<L>(
482        &mut self,
483        stream: GlobalStreamId,
484        context: &mut Context<L>,
485    ) -> bool
486    where
487        L: ListenerHandler + L7ListenerHandler,
488    {
489        // Snapshot the backend's in-flight counter so we can assert the
490        // increment/rollback is net-zero on the failure path: a refused
491        // `start_stream` must NOT leak an `active_requests` charge onto the
492        // backend (it would skew least-loaded balancing forever). The snapshot
493        // and its assert are both cfg'd so the `RefCell` borrow never exists in
494        // release (the helper is `#[cfg(debug_assertions)]` too).
495        #[cfg(debug_assertions)]
496        let before = self.backend_active_requests();
497        self.pre_start_stream_client_bookkeeping();
498        let started = forward!(self, start_stream(stream, context));
499        if !started {
500            // Undo active_requests increment on failure
501            self.pre_end_stream_client_bookkeeping();
502            #[cfg(debug_assertions)]
503            debug_assert_eq!(
504                self.backend_active_requests(),
505                before,
506                "a refused start_stream must roll active_requests back to its prior value"
507            );
508        }
509        started
510    }
511
512    /// Snapshot the backend's in-flight request counter, or `None` when this
513    /// connection has no `Connected` backend (server position, or a client
514    /// that is still connecting / already keep-alive). Used by `start_stream`
515    /// to pair-assert the increment/rollback is net-zero on refusal. Cheap
516    /// `RefCell` borrow, compiled only with `debug_assertions`.
517    #[cfg(debug_assertions)]
518    fn backend_active_requests(&self) -> Option<usize> {
519        match self.position() {
520            Position::Client(_, backend, BackendStatus::Connected) => {
521                Some(backend.borrow().active_requests)
522            }
523            _ => None,
524        }
525    }
526}
527
528#[derive(Debug)]
529pub(super) struct EndpointServer<'a, Front: SocketHandler>(pub &'a mut Connection<Front>);
530#[derive(Debug)]
531pub(super) struct EndpointClient<'a>(pub &'a mut Router);
532
533// note: EndpointServer are used by client Connection, they do not know the frontend Token
534// they will use the Stream's Token which is their backend token
535impl<Front: SocketHandler + Debug> Endpoint for EndpointServer<'_, Front> {
536    fn readiness(&self, _token: Token) -> &Readiness {
537        self.0.readiness()
538    }
539    fn readiness_mut(&mut self, _token: Token) -> &mut Readiness {
540        self.0.readiness_mut()
541    }
542    fn socket(&self, _token: Token) -> Option<&TcpStream> {
543        Some(self.0.socket())
544    }
545
546    fn end_stream<L>(&mut self, _token: Token, stream: GlobalStreamId, context: &mut Context<L>)
547    where
548        L: ListenerHandler + L7ListenerHandler,
549    {
550        // this may be used to forward H2<->H2 RstStream
551        // or to handle backend hup
552        self.0.end_stream(stream, context);
553    }
554
555    fn start_stream<L>(
556        &mut self,
557        _token: Token,
558        stream: GlobalStreamId,
559        context: &mut Context<L>,
560    ) -> bool
561    where
562        L: ListenerHandler + L7ListenerHandler,
563    {
564        // Forward stream start to the frontend connection.
565        // This is used when a backend H2 connection starts a new stream
566        // (e.g. for H2<->H2 proxying or PUSH_PROMISE forwarding).
567        self.0.start_stream(stream, context)
568    }
569}
570impl Endpoint for EndpointClient<'_> {
571    fn readiness(&self, token: Token) -> &Readiness {
572        match self.0.backends.get(&token) {
573            Some(backend) => backend.readiness(),
574            None => {
575                error!(
576                    "{} backend token {:?} missing from backends map (readiness)",
577                    log_module_context!(),
578                    token
579                );
580                &self.0.fallback_readiness
581            }
582        }
583    }
584    fn readiness_mut(&mut self, token: Token) -> &mut Readiness {
585        match self.0.backends.get_mut(&token) {
586            Some(backend) => backend.readiness_mut(),
587            None => {
588                error!(
589                    "{} backend token {:?} missing from backends map (readiness_mut)",
590                    log_module_context!(),
591                    token
592                );
593                &mut self.0.fallback_readiness
594            }
595        }
596    }
597    fn socket(&self, token: Token) -> Option<&TcpStream> {
598        self.0.backends.get(&token).map(|c| c.socket())
599    }
600
601    fn end_stream<L>(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context<L>)
602    where
603        L: ListenerHandler + L7ListenerHandler,
604    {
605        match self.0.backends.get_mut(&token) {
606            Some(backend) => backend.end_stream(stream, context),
607            None => {
608                error!(
609                    "{} backend token {:?} missing from backends map (end_stream)",
610                    log_module_context!(),
611                    token
612                );
613            }
614        }
615    }
616
617    fn start_stream<L>(
618        &mut self,
619        token: Token,
620        stream: GlobalStreamId,
621        context: &mut Context<L>,
622    ) -> bool
623    where
624        L: ListenerHandler + L7ListenerHandler,
625    {
626        match self.0.backends.get_mut(&token) {
627            Some(backend) => backend.start_stream(stream, context),
628            None => {
629                error!(
630                    "{} backend token {:?} missing from backends map (start_stream)",
631                    log_module_context!(),
632                    token
633                );
634                false
635            }
636        }
637    }
638}
639
/// Test-only injection hooks for the mux layer.
///
/// Compiled only under `cfg(test)` (i.e. `cargo test`) or when the crate's
/// `e2e-hooks` feature is enabled, so end-to-end test binaries can use them
/// too. Builds without either never contain them, and non-test code must not
/// rely on them. They let tests drive hard-to-reach code paths without
/// reproducing the underlying resource starvation in-process. Here that path
/// is buffer-pool exhaustion while attaching an H2 backend.
#[cfg(any(test, feature = "e2e-hooks"))]
pub mod test_hooks {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// When `true`, the next call to [`super::Connection::new_h2_client`]
    /// returns `None` as if the buffer pool were exhausted. That call resets
    /// the flag to `false`, so each opt-in covers exactly one attempted
    /// backend attach.
    pub static FORCE_NEW_H2_CLIENT_FAILURE: AtomicBool = AtomicBool::new(false);

    /// Arms (`on = true`) or disarms (`on = false`) the `new_h2_client`
    /// failure injection. Returns the flag's previous value so a test can
    /// restore whatever state it found.
    ///
    /// The flag is process-global, so tests arming it concurrently can observe
    /// each other's setting. NOTE(review): confirm that the tests using it run
    /// serially.
    pub fn __test_force_h2_client_failure(on: bool) -> bool {
        FORCE_NEW_H2_CLIENT_FAILURE.swap(on, Ordering::SeqCst)
    }
}