Skip to main content

sozu_lib/protocol/mux/
connection.rs

1//! Protocol-agnostic frontend/backend connection wrapper.
2//!
3//! [`Connection`] is the H1/H2 dispatch enum used everywhere in the mux
4//! layer. Most of the methods are trivial pass-through forwarders to the
5//! underlying [`ConnectionH1`] or [`ConnectionH2`] implementation — the
6//! local `forward!` macro removes the boilerplate.
7//!
8//! The two `Endpoint` adaptors ([`EndpointServer`], [`EndpointClient`]) are
9//! also defined here: they let a connection call back into either the
10//! frontend connection or the backend [`Router`] map without knowing which
11//! direction it faces.
12//!
13//! Edge-trigger discipline lives in `mux/h2.rs` (`writable`) — the canonical
14//! home for the `signal_pending_write` / `arm_writable` invariant. This
15//! module's abstractions delegate to that discipline through the
16//! protocol-specific writers.
17
18use std::{
19    cell::RefCell,
20    fmt::Debug,
21    rc::{Rc, Weak},
22    time::Instant,
23};
24
25use mio::{Token, net::TcpStream};
26use rusty_ulid::Ulid;
27use sozu_command::{logging::ansi_palette, ready::Ready};
28
29use super::{
30    BackendStatus, ConnectionH1, ConnectionH2, Context, Endpoint, GlobalStreamId, MuxResult,
31    Position, Router,
32    h2::{self, H2StreamId},
33};
34use crate::metrics::names;
35use crate::{
36    L7ListenerHandler, ListenerHandler, Readiness, backends::Backend, pool::Pool,
37    socket::SocketHandler, timer::TimeoutContainer,
38};
39
/// Builds the prefix put in front of every log line emitted from this module.
///
/// The label is always `MUX-CONN`, wrapped between the first two values
/// returned by `ansi_palette()` (the remaining three are ignored) and followed
/// by a tab and `>>>`. When the logger is in colored mode this renders as a
/// bold bright-white label, uniform across every protocol.
///
/// The prefix carries no session-specific context on purpose: most log sites
/// are inside `Endpoint` adapter methods that only see the backend/frontend
/// maps, not the wrapping [`Connection`].
macro_rules! log_module_context {
    () => {{
        let (style_on, style_off, ..) = ansi_palette();
        format!("{}MUX-CONN{}\t >>>", style_on, style_off)
    }};
}
52
/// A mux-layer connection that speaks either H1 or H2.
///
/// Each variant wraps the protocol-specific implementation; the methods on
/// this enum dispatch to whichever one is active.
#[derive(Debug)]
// NOTE(review): the size difference between the two variants is accepted
// instead of boxing the larger one — presumably to avoid an indirection on
// every dispatch; confirm before changing.
#[allow(clippy::large_enum_variant)]
pub enum Connection<Front: SocketHandler> {
    /// Connection handled by [`ConnectionH1`].
    H1(ConnectionH1<Front>),
    /// Connection handled by [`ConnectionH2`].
    H2(ConnectionH2<Front>),
}
59
// Dispatches a method call or field access to the inner H1/H2 connection.
// Used by trivial pass-through methods on Connection<Front> to avoid
// repeating the two-arm match.
//
// Three invocation forms are accepted:
//   forward!(conn, method(args...))  calls `method(args...)` on the inner value
//   forward!(&conn, field)           yields a shared borrow of `field`
//   forward!(&mut conn, field)       yields an exclusive borrow of `field`
//
// Both inner types must expose the same method/field name for an invocation
// to compile, since each arm expands to the same tokens.
macro_rules! forward {
    // Method call: the argument tokens are forwarded untouched to both arms.
    ($self:expr, $method:ident ( $($args:tt)* )) => {
        match $self {
            Connection::H1(c) => c.$method($($args)*),
            Connection::H2(c) => c.$method($($args)*),
        }
    };
    // Shared field access.
    (&$self:expr, $field:ident) => {
        match $self {
            Connection::H1(c) => &c.$field,
            Connection::H2(c) => &c.$field,
        }
    };
    // Exclusive field access.
    (&mut $self:expr, $field:ident) => {
        match $self {
            Connection::H1(c) => &mut c.$field,
            Connection::H2(c) => &mut c.$field,
        }
    };
}
83
84impl<Front: SocketHandler> Connection<Front> {
85    pub fn new_h1_server(
86        session_ulid: Ulid,
87        front_stream: Front,
88        timeout_container: TimeoutContainer,
89    ) -> Connection<Front> {
90        Connection::H1(ConnectionH1 {
91            socket: front_stream,
92            position: Position::Server,
93            readiness: Readiness {
94                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
95                event: Ready::EMPTY,
96            },
97            requests: 0,
98            stream: Some(0),
99            timeout_container,
100            parked_on_buffer_pressure: false,
101            close_notify_sent: false,
102            session_ulid,
103        })
104    }
105    pub fn new_h1_client(
106        session_ulid: Ulid,
107        front_stream: Front,
108        cluster_id: String,
109        backend: Rc<RefCell<Backend>>,
110        timeout_container: TimeoutContainer,
111    ) -> Connection<Front> {
112        Connection::H1(ConnectionH1 {
113            socket: front_stream,
114            position: Position::Client(
115                cluster_id,
116                backend,
117                BackendStatus::Connecting(Instant::now()),
118            ),
119            readiness: Readiness {
120                interest: Ready::WRITABLE | Ready::READABLE | Ready::HUP | Ready::ERROR,
121                event: Ready::EMPTY,
122            },
123            stream: None,
124            requests: 0,
125            timeout_container,
126            parked_on_buffer_pressure: false,
127            close_notify_sent: false,
128            session_ulid,
129        })
130    }
131
132    #[allow(clippy::too_many_arguments)]
133    pub fn new_h2_server(
134        session_ulid: Ulid,
135        front_stream: Front,
136        pool: Weak<RefCell<Pool>>,
137        timeout_container: TimeoutContainer,
138        flood_config: h2::H2FloodConfig,
139        connection_config: h2::H2ConnectionConfig,
140        stream_idle_timeout: std::time::Duration,
141        graceful_shutdown_deadline: Option<std::time::Duration>,
142    ) -> Option<Connection<Front>> {
143        Some(Connection::H2(ConnectionH2::new(
144            session_ulid,
145            front_stream,
146            Position::Server,
147            pool,
148            flood_config,
149            connection_config,
150            stream_idle_timeout,
151            graceful_shutdown_deadline,
152            timeout_container,
153            Some((H2StreamId::Zero, h2::CLIENT_PREFACE_SIZE)),
154            Ready::READABLE | Ready::HUP | Ready::ERROR,
155        )?))
156    }
157
158    #[allow(clippy::too_many_arguments)]
159    pub fn new_h2_client(
160        session_ulid: Ulid,
161        front_stream: Front,
162        cluster_id: String,
163        backend: Rc<RefCell<Backend>>,
164        pool: Weak<RefCell<Pool>>,
165        timeout_container: TimeoutContainer,
166        flood_config: h2::H2FloodConfig,
167        connection_config: h2::H2ConnectionConfig,
168        stream_idle_timeout: std::time::Duration,
169        graceful_shutdown_deadline: Option<std::time::Duration>,
170    ) -> Option<Connection<Front>> {
171        // Test-only injection point: when set via
172        // [`__test_force_h2_client_failure`], pretend the pool was exhausted
173        // and return `None`. This mirrors the buffer-pool-exhaustion branch
174        // inside [`ConnectionH2::new`] deterministically so E2E tests can
175        // exercise `Router::connect`'s rollback path (FIX-18) without having
176        // to starve the pool in-process.
177        #[cfg(any(test, feature = "e2e-hooks"))]
178        if test_hooks::FORCE_NEW_H2_CLIENT_FAILURE.swap(false, std::sync::atomic::Ordering::SeqCst)
179        {
180            return None;
181        }
182        Some(Connection::H2(ConnectionH2::new(
183            session_ulid,
184            front_stream,
185            Position::Client(
186                cluster_id,
187                backend,
188                BackendStatus::Connecting(Instant::now()),
189            ),
190            pool,
191            flood_config,
192            connection_config,
193            stream_idle_timeout,
194            graceful_shutdown_deadline,
195            timeout_container,
196            None,
197            Ready::WRITABLE | Ready::HUP | Ready::ERROR,
198        )?))
199    }
200
201    pub fn readiness(&self) -> &Readiness {
202        forward!(&self, readiness)
203    }
204    pub fn readiness_mut(&mut self) -> &mut Readiness {
205        forward!(&mut self, readiness)
206    }
207    pub fn position(&self) -> &Position {
208        forward!(&self, position)
209    }
210    pub fn position_mut(&mut self) -> &mut Position {
211        forward!(&mut self, position)
212    }
213    pub fn socket(&self) -> &TcpStream {
214        match self {
215            Connection::H1(c) => c.socket.socket_ref(),
216            Connection::H2(c) => c.socket.socket_ref(),
217        }
218    }
219    pub fn socket_mut(&mut self) -> &mut TcpStream {
220        match self {
221            Connection::H1(c) => c.socket.socket_mut(),
222            Connection::H2(c) => c.socket.socket_mut(),
223        }
224    }
225    pub fn timeout_container(&mut self) -> &mut TimeoutContainer {
226        forward!(&mut self, timeout_container)
227    }
228
229    /// Returns connection-level byte overhead (bin, bout) for H2, (0, 0) for H1.
230    pub fn overhead_bytes(&self) -> (usize, usize) {
231        match self {
232            Connection::H1(_) => (0, 0),
233            Connection::H2(c) => (c.bytes.overhead_bin, c.bytes.overhead_bout),
234        }
235    }
236
237    pub(super) fn readable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
238    where
239        E: Endpoint,
240        L: ListenerHandler + L7ListenerHandler,
241    {
242        forward!(self, readable(context, endpoint))
243    }
244    pub(super) fn writable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
245    where
246        E: Endpoint,
247        L: ListenerHandler + L7ListenerHandler,
248    {
249        forward!(self, writable(context, endpoint))
250    }
251
252    /// Returns true if this connection could not read because its stream's
253    /// kawa buffer was full. Used to prevent the dead-backend check from
254    /// closing a backend that still has data in the OS socket buffer.
255    pub(super) fn has_buffer_pressure<L>(&self, context: &Context<L>) -> bool
256    where
257        L: ListenerHandler + L7ListenerHandler,
258    {
259        match self {
260            Connection::H1(c) => {
261                let Some(stream_id) = c.stream else {
262                    // No stream assigned — no buffer pressure.
263                    return false;
264                };
265                let kawa = match c.position {
266                    Position::Client(..) => &context.streams[stream_id].back,
267                    Position::Server => &context.streams[stream_id].front,
268                };
269                kawa.storage.available_space() == 0
270            }
271            // H2 connections manage their own flow control via expect_read
272            Connection::H2(_) => false,
273        }
274    }
275
276    /// Re-enable READABLE if this connection is parked waiting for buffer space
277    /// and the target stream's buffer now has enough room.
278    ///
279    /// For H1: checks the `parked_on_buffer_pressure` flag set when `readable`
280    /// exits early because the kawa buffer was full. Edge-triggered epoll will
281    /// not re-fire READABLE for data already in the kernel socket buffer, so
282    /// this is the only path that re-arms it after the peer drains space.
283    ///
284    /// For H2: checks the `expect_read` field tracking which stream and how
285    /// many bytes are needed.
286    pub(super) fn try_resume_reading<L>(&mut self, context: &Context<L>) -> bool
287    where
288        L: ListenerHandler + L7ListenerHandler,
289    {
290        match self {
291            Connection::H1(c) => {
292                if !c.parked_on_buffer_pressure {
293                    return false;
294                }
295                let Some(stream_id) = c.stream else {
296                    return false;
297                };
298                let kawa = match c.position {
299                    Position::Client(..) => &context.streams[stream_id].back,
300                    Position::Server => &context.streams[stream_id].front,
301                };
302                if kawa.storage.available_space() > 0 {
303                    trace!(
304                        "{} H1 try_resume_reading: re-arming READABLE",
305                        log_module_context!()
306                    );
307                    c.readiness.signal_pending_read();
308                    true
309                } else {
310                    false
311                }
312            }
313            Connection::H2(c) => c.try_resume_reading(context),
314        }
315    }
316
317    pub(super) fn graceful_goaway(&mut self) -> MuxResult {
318        match self {
319            Connection::H1(_) => MuxResult::Continue,
320            Connection::H2(c) => c.graceful_goaway(),
321        }
322    }
323
324    pub(super) fn is_draining(&self) -> bool {
325        match self {
326            Connection::H1(_) => false,
327            Connection::H2(c) => c.drain.draining,
328        }
329    }
330
331    /// Proxy-side graceful-shutdown budget exhaustion check. Only H2
332    /// connections carry the timer — H1 has no multiplex to drain, so its
333    /// answer is always `false` and the H1 path continues to fall through
334    /// to the ordinary single-response close. See
335    /// [`h2::ConnectionH2::graceful_shutdown_deadline_elapsed`].
336    pub(super) fn graceful_shutdown_deadline_elapsed(&self) -> bool {
337        match self {
338            Connection::H1(_) => false,
339            Connection::H2(c) => c.graceful_shutdown_deadline_elapsed(),
340        }
341    }
342
343    pub(super) fn has_pending_write(&self) -> bool {
344        forward!(self, has_pending_write())
345    }
346
347    /// Connection-level [`Self::has_pending_write`] extended with a per-stream
348    /// back-buffer probe (LIFECYCLE §9 invariant 16). Only H2 multiplexes
349    /// multiple streams — H1 falls back to [`Self::has_pending_write`] since
350    /// its single-response pipeline already accounts for pending bytes.
351    pub(super) fn has_pending_write_including_streams<L>(&self, context: &super::Context<L>) -> bool
352    where
353        L: ListenerHandler + L7ListenerHandler,
354    {
355        match self {
356            Connection::H1(c) => c.has_pending_write(),
357            Connection::H2(c) => c.has_pending_write_full(context),
358        }
359    }
360
361    pub(super) fn initiate_close_notify(&mut self) -> bool {
362        forward!(self, initiate_close_notify())
363    }
364
365    pub(super) fn flush_zero_buffer(&mut self) {
366        if let Connection::H2(c) = self {
367            c.flush_zero_buffer();
368        }
369    }
370
371    fn pre_close_client_bookkeeping(&self) {
372        if let Position::Client(cluster_id, backend, _) = self.position() {
373            let mut backend_borrow = backend.borrow_mut();
374            backend_borrow.dec_connections();
375            gauge_add!(names::backend::CONNECTIONS, -1);
376            // Pair with the `+1` at `router.rs::connect` (new-dial path).
377            // This is the graceful-close decrement, used both by the dead
378            // backend path in `mod.rs::back_readable` (which routes through
379            // `client.close()`) and by any explicit Connection::close.
380            gauge_add!(names::backend::POOL_SIZE, -1);
381            gauge_add!(
382                names::backend::CONNECTIONS_PER_BACKEND,
383                -1,
384                Some(cluster_id),
385                Some(&backend_borrow.backend_id)
386            );
387            trace!(
388                "{} connection close: {:#?}",
389                log_module_context!(),
390                backend_borrow
391            );
392        }
393    }
394
395    fn pre_end_stream_client_bookkeeping(&self) {
396        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
397            let mut backend_borrow = backend.borrow_mut();
398            backend_borrow.active_requests = backend_borrow.active_requests.saturating_sub(1);
399            trace!(
400                "{} connection end stream: {:#?}",
401                log_module_context!(),
402                backend_borrow
403            );
404        }
405    }
406
407    fn pre_start_stream_client_bookkeeping(&self) {
408        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
409            let mut backend_borrow = backend.borrow_mut();
410            backend_borrow.active_requests += 1;
411            trace!(
412                "{} connection start stream: {:#?}",
413                log_module_context!(),
414                backend_borrow
415            );
416        }
417    }
418
419    pub(super) fn close<E, L>(&mut self, context: &mut Context<L>, endpoint: E)
420    where
421        E: Endpoint,
422        L: ListenerHandler + L7ListenerHandler,
423    {
424        self.pre_close_client_bookkeeping();
425        forward!(self, close(context, endpoint))
426    }
427
428    pub(super) fn end_stream<L>(&mut self, stream: GlobalStreamId, context: &mut Context<L>)
429    where
430        L: ListenerHandler + L7ListenerHandler,
431    {
432        self.pre_end_stream_client_bookkeeping();
433        forward!(self, end_stream(stream, context))
434    }
435
436    pub(super) fn start_stream<L>(
437        &mut self,
438        stream: GlobalStreamId,
439        context: &mut Context<L>,
440    ) -> bool
441    where
442        L: ListenerHandler + L7ListenerHandler,
443    {
444        self.pre_start_stream_client_bookkeeping();
445        let started = forward!(self, start_stream(stream, context));
446        if !started {
447            // Undo active_requests increment on failure
448            self.pre_end_stream_client_bookkeeping();
449        }
450        started
451    }
452}
453
/// [`Endpoint`] adapter wrapping an exclusive borrow of a single
/// [`Connection`]. Its `Endpoint` implementation answers every call from
/// that one connection and ignores the token argument.
#[derive(Debug)]
pub(super) struct EndpointServer<'a, Front: SocketHandler>(pub &'a mut Connection<Front>);
/// [`Endpoint`] adapter wrapping an exclusive borrow of the [`Router`]. Its
/// `Endpoint` implementation looks the target connection up in the router's
/// `backends` map using the token argument.
#[derive(Debug)]
pub(super) struct EndpointClient<'a>(pub &'a mut Router);
458
459// note: EndpointServer are used by client Connection, they do not know the frontend Token
460// they will use the Stream's Token which is their backend token
461impl<Front: SocketHandler + Debug> Endpoint for EndpointServer<'_, Front> {
462    fn readiness(&self, _token: Token) -> &Readiness {
463        self.0.readiness()
464    }
465    fn readiness_mut(&mut self, _token: Token) -> &mut Readiness {
466        self.0.readiness_mut()
467    }
468    fn socket(&self, _token: Token) -> Option<&TcpStream> {
469        Some(self.0.socket())
470    }
471
472    fn end_stream<L>(&mut self, _token: Token, stream: GlobalStreamId, context: &mut Context<L>)
473    where
474        L: ListenerHandler + L7ListenerHandler,
475    {
476        // this may be used to forward H2<->H2 RstStream
477        // or to handle backend hup
478        self.0.end_stream(stream, context);
479    }
480
481    fn start_stream<L>(
482        &mut self,
483        _token: Token,
484        stream: GlobalStreamId,
485        context: &mut Context<L>,
486    ) -> bool
487    where
488        L: ListenerHandler + L7ListenerHandler,
489    {
490        // Forward stream start to the frontend connection.
491        // This is used when a backend H2 connection starts a new stream
492        // (e.g. for H2<->H2 proxying or PUSH_PROMISE forwarding).
493        self.0.start_stream(stream, context)
494    }
495}
496impl Endpoint for EndpointClient<'_> {
497    fn readiness(&self, token: Token) -> &Readiness {
498        match self.0.backends.get(&token) {
499            Some(backend) => backend.readiness(),
500            None => {
501                error!(
502                    "{} backend token {:?} missing from backends map (readiness)",
503                    log_module_context!(),
504                    token
505                );
506                &self.0.fallback_readiness
507            }
508        }
509    }
510    fn readiness_mut(&mut self, token: Token) -> &mut Readiness {
511        match self.0.backends.get_mut(&token) {
512            Some(backend) => backend.readiness_mut(),
513            None => {
514                error!(
515                    "{} backend token {:?} missing from backends map (readiness_mut)",
516                    log_module_context!(),
517                    token
518                );
519                &mut self.0.fallback_readiness
520            }
521        }
522    }
523    fn socket(&self, token: Token) -> Option<&TcpStream> {
524        self.0.backends.get(&token).map(|c| c.socket())
525    }
526
527    fn end_stream<L>(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context<L>)
528    where
529        L: ListenerHandler + L7ListenerHandler,
530    {
531        match self.0.backends.get_mut(&token) {
532            Some(backend) => backend.end_stream(stream, context),
533            None => {
534                error!(
535                    "{} backend token {:?} missing from backends map (end_stream)",
536                    log_module_context!(),
537                    token
538                );
539            }
540        }
541    }
542
543    fn start_stream<L>(
544        &mut self,
545        token: Token,
546        stream: GlobalStreamId,
547        context: &mut Context<L>,
548    ) -> bool
549    where
550        L: ListenerHandler + L7ListenerHandler,
551    {
552        match self.0.backends.get_mut(&token) {
553            Some(backend) => backend.start_stream(stream, context),
554            None => {
555                error!(
556                    "{} backend token {:?} missing from backends map (start_stream)",
557                    log_module_context!(),
558                    token
559                );
560                false
561            }
562        }
563    }
564}
565
/// Failure-injection hooks for the mux layer.
///
/// The module is compiled when `cfg(test)` is active **or** when the
/// `e2e-hooks` cargo feature is enabled; builds with neither do not contain
/// it, so downstream code must not rely on it. The hooks let end-to-end
/// tests reach code paths that are otherwise hard to trigger (such as
/// buffer-pool exhaustion during backend attach) without reproducing the
/// underlying resource starvation in-process.
#[cfg(any(test, feature = "e2e-hooks"))]
pub mod test_hooks {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// While `true`, the next call to [`super::Connection::new_h2_client`]
    /// returns `None` as if the buffer pool were exhausted. That call resets
    /// the flag to `false`, so each arming affects exactly one attempted
    /// backend attach.
    pub static FORCE_NEW_H2_CLIENT_FAILURE: AtomicBool = AtomicBool::new(false);

    /// Arms (`on == true`) or disarms (`on == false`) the `new_h2_client`
    /// failure injection.
    ///
    /// Returns the value the flag held before the call, so a test can save
    /// it and restore it afterwards. The flag is a single process-wide
    /// atomic: tests that arm it concurrently share it.
    pub fn __test_force_h2_client_failure(on: bool) -> bool {
        FORCE_NEW_H2_CLIENT_FAILURE.swap(on, Ordering::SeqCst)
    }
}