1use std::{
19 cell::RefCell,
20 fmt::Debug,
21 rc::{Rc, Weak},
22 time::Instant,
23};
24
25use mio::{Token, net::TcpStream};
26use rusty_ulid::Ulid;
27use sozu_command::{logging::ansi_palette, ready::Ready};
28
29use super::{
30 BackendStatus, ConnectionH1, ConnectionH2, Context, Endpoint, GlobalStreamId, MuxResult,
31 Position, Router,
32 h2::{self, H2StreamId},
33};
34use crate::metrics::names;
35use crate::{
36 L7ListenerHandler, ListenerHandler, Readiness, backends::Backend, pool::Pool,
37 socket::SocketHandler, timer::TimeoutContainer,
38};
39
/// Expands to the log prefix used by this module.
///
/// The prefix is `MUX-CONN` wrapped between the first two entries of the tuple
/// returned by `ansi_palette()` (the "open" and "reset" sequences), followed by
/// a tab, a space and `>>>`.
macro_rules! log_module_context {
    () => {{
        let palette = ansi_palette();
        format!("{}MUX-CONN{}\t >>>", palette.0, palette.1)
    }};
}
52
53#[derive(Debug)]
54#[allow(clippy::large_enum_variant)]
55pub enum Connection<Front: SocketHandler> {
56 H1(ConnectionH1<Front>),
57 H2(ConnectionH2<Front>),
58}
59
/// Dispatches to the protocol-specific value held by a `Connection`.
///
/// Forms:
/// - `forward!(conn, method(args..))` calls `method(args..)` on the inner
///   `ConnectionH1` / `ConnectionH2` and evaluates to its result;
/// - `forward!(&conn, field)` evaluates to `&inner.field`;
/// - `forward!(&mut conn, field)` evaluates to `&mut inner.field`.
///
/// Arms are tried in the order written below; keep the method-call arm first.
macro_rules! forward {
    ($conn:expr, $name:ident ( $($arg:tt)* )) => {
        match $conn {
            Connection::H1(inner) => inner.$name($($arg)*),
            Connection::H2(inner) => inner.$name($($arg)*),
        }
    };
    (&$conn:expr, $name:ident) => {
        match $conn {
            Connection::H1(inner) => &inner.$name,
            Connection::H2(inner) => &inner.$name,
        }
    };
    (&mut $conn:expr, $name:ident) => {
        match $conn {
            Connection::H1(inner) => &mut inner.$name,
            Connection::H2(inner) => &mut inner.$name,
        }
    };
}
83
84impl<Front: SocketHandler> Connection<Front> {
85 pub fn new_h1_server(
86 session_ulid: Ulid,
87 front_stream: Front,
88 timeout_container: TimeoutContainer,
89 ) -> Connection<Front> {
90 Connection::H1(ConnectionH1 {
91 socket: front_stream,
92 position: Position::Server,
93 readiness: Readiness {
94 interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
95 event: Ready::EMPTY,
96 },
97 requests: 0,
98 stream: Some(0),
99 timeout_container,
100 parked_on_buffer_pressure: false,
101 close_notify_sent: false,
102 session_ulid,
103 })
104 }
105 pub fn new_h1_client(
106 session_ulid: Ulid,
107 front_stream: Front,
108 cluster_id: String,
109 backend: Rc<RefCell<Backend>>,
110 timeout_container: TimeoutContainer,
111 ) -> Connection<Front> {
112 Connection::H1(ConnectionH1 {
113 socket: front_stream,
114 position: Position::Client(
115 cluster_id,
116 backend,
117 BackendStatus::Connecting(Instant::now()),
118 ),
119 readiness: Readiness {
120 interest: Ready::WRITABLE | Ready::READABLE | Ready::HUP | Ready::ERROR,
121 event: Ready::EMPTY,
122 },
123 stream: None,
124 requests: 0,
125 timeout_container,
126 parked_on_buffer_pressure: false,
127 close_notify_sent: false,
128 session_ulid,
129 })
130 }
131
132 #[allow(clippy::too_many_arguments)]
133 pub fn new_h2_server(
134 session_ulid: Ulid,
135 front_stream: Front,
136 pool: Weak<RefCell<Pool>>,
137 timeout_container: TimeoutContainer,
138 flood_config: h2::H2FloodConfig,
139 connection_config: h2::H2ConnectionConfig,
140 stream_idle_timeout: std::time::Duration,
141 graceful_shutdown_deadline: Option<std::time::Duration>,
142 ) -> Option<Connection<Front>> {
143 Some(Connection::H2(ConnectionH2::new(
144 session_ulid,
145 front_stream,
146 Position::Server,
147 pool,
148 flood_config,
149 connection_config,
150 stream_idle_timeout,
151 graceful_shutdown_deadline,
152 timeout_container,
153 Some((H2StreamId::Zero, h2::CLIENT_PREFACE_SIZE)),
154 Ready::READABLE | Ready::HUP | Ready::ERROR,
155 )?))
156 }
157
158 #[allow(clippy::too_many_arguments)]
159 pub fn new_h2_client(
160 session_ulid: Ulid,
161 front_stream: Front,
162 cluster_id: String,
163 backend: Rc<RefCell<Backend>>,
164 pool: Weak<RefCell<Pool>>,
165 timeout_container: TimeoutContainer,
166 flood_config: h2::H2FloodConfig,
167 connection_config: h2::H2ConnectionConfig,
168 stream_idle_timeout: std::time::Duration,
169 graceful_shutdown_deadline: Option<std::time::Duration>,
170 ) -> Option<Connection<Front>> {
171 #[cfg(any(test, feature = "e2e-hooks"))]
178 if test_hooks::FORCE_NEW_H2_CLIENT_FAILURE.swap(false, std::sync::atomic::Ordering::SeqCst)
179 {
180 return None;
181 }
182 Some(Connection::H2(ConnectionH2::new(
183 session_ulid,
184 front_stream,
185 Position::Client(
186 cluster_id,
187 backend,
188 BackendStatus::Connecting(Instant::now()),
189 ),
190 pool,
191 flood_config,
192 connection_config,
193 stream_idle_timeout,
194 graceful_shutdown_deadline,
195 timeout_container,
196 None,
197 Ready::WRITABLE | Ready::HUP | Ready::ERROR,
198 )?))
199 }
200
201 pub fn readiness(&self) -> &Readiness {
202 forward!(&self, readiness)
203 }
204 pub fn readiness_mut(&mut self) -> &mut Readiness {
205 forward!(&mut self, readiness)
206 }
207 pub fn position(&self) -> &Position {
208 forward!(&self, position)
209 }
210 pub fn position_mut(&mut self) -> &mut Position {
211 forward!(&mut self, position)
212 }
213 pub fn socket(&self) -> &TcpStream {
214 match self {
215 Connection::H1(c) => c.socket.socket_ref(),
216 Connection::H2(c) => c.socket.socket_ref(),
217 }
218 }
219 pub fn socket_mut(&mut self) -> &mut TcpStream {
220 match self {
221 Connection::H1(c) => c.socket.socket_mut(),
222 Connection::H2(c) => c.socket.socket_mut(),
223 }
224 }
225 pub fn timeout_container(&mut self) -> &mut TimeoutContainer {
226 forward!(&mut self, timeout_container)
227 }
228
229 pub fn overhead_bytes(&self) -> (usize, usize) {
231 match self {
232 Connection::H1(_) => (0, 0),
233 Connection::H2(c) => (c.bytes.overhead_bin, c.bytes.overhead_bout),
234 }
235 }
236
237 pub(super) fn readable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
238 where
239 E: Endpoint,
240 L: ListenerHandler + L7ListenerHandler,
241 {
242 forward!(self, readable(context, endpoint))
243 }
244 pub(super) fn writable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
245 where
246 E: Endpoint,
247 L: ListenerHandler + L7ListenerHandler,
248 {
249 forward!(self, writable(context, endpoint))
250 }
251
252 pub(super) fn has_buffer_pressure<L>(&self, context: &Context<L>) -> bool
256 where
257 L: ListenerHandler + L7ListenerHandler,
258 {
259 match self {
260 Connection::H1(c) => {
261 let Some(stream_id) = c.stream else {
262 return false;
264 };
265 let kawa = match c.position {
266 Position::Client(..) => &context.streams[stream_id].back,
267 Position::Server => &context.streams[stream_id].front,
268 };
269 kawa.storage.available_space() == 0
270 }
271 Connection::H2(_) => false,
273 }
274 }
275
276 pub(super) fn try_resume_reading<L>(&mut self, context: &Context<L>) -> bool
287 where
288 L: ListenerHandler + L7ListenerHandler,
289 {
290 match self {
291 Connection::H1(c) => {
292 if !c.parked_on_buffer_pressure {
293 return false;
294 }
295 let Some(stream_id) = c.stream else {
296 return false;
297 };
298 let kawa = match c.position {
299 Position::Client(..) => &context.streams[stream_id].back,
300 Position::Server => &context.streams[stream_id].front,
301 };
302 if kawa.storage.available_space() > 0 {
303 trace!(
304 "{} H1 try_resume_reading: re-arming READABLE",
305 log_module_context!()
306 );
307 debug_assert!(
313 c.parked_on_buffer_pressure,
314 "re-arm only fires for a connection parked on buffer pressure"
315 );
316 c.readiness.signal_pending_read();
317 debug_assert!(
318 c.readiness.event.is_readable(),
319 "signal_pending_read must leave a READABLE event queued"
320 );
321 true
322 } else {
323 false
324 }
325 }
326 Connection::H2(c) => c.try_resume_reading(context),
327 }
328 }
329
330 pub(super) fn graceful_goaway(&mut self) -> MuxResult {
331 match self {
332 Connection::H1(_) => MuxResult::Continue,
333 Connection::H2(c) => c.graceful_goaway(),
334 }
335 }
336
337 pub(super) fn is_draining(&self) -> bool {
338 match self {
339 Connection::H1(_) => false,
340 Connection::H2(c) => c.drain.draining,
341 }
342 }
343
344 pub(super) fn graceful_shutdown_deadline_elapsed(&self) -> bool {
350 match self {
351 Connection::H1(_) => false,
352 Connection::H2(c) => c.graceful_shutdown_deadline_elapsed(),
353 }
354 }
355
356 pub(super) fn has_pending_write(&self) -> bool {
357 forward!(self, has_pending_write())
358 }
359
360 pub(super) fn has_pending_write_including_streams<L>(&self, context: &super::Context<L>) -> bool
365 where
366 L: ListenerHandler + L7ListenerHandler,
367 {
368 match self {
369 Connection::H1(c) => c.has_pending_write(),
370 Connection::H2(c) => c.has_pending_write_full(context),
371 }
372 }
373
374 pub(super) fn initiate_close_notify(&mut self) -> bool {
375 forward!(self, initiate_close_notify())
376 }
377
378 pub(super) fn flush_zero_buffer(&mut self) {
379 if let Connection::H2(c) = self {
380 c.flush_zero_buffer();
381 }
382 }
383
384 fn pre_close_client_bookkeeping(&self) {
385 if let Position::Client(cluster_id, backend, _) = self.position() {
386 let mut backend_borrow = backend.borrow_mut();
387 let before = backend_borrow.active_connections;
392 backend_borrow.dec_connections();
393 debug_assert_eq!(
394 backend_borrow.active_connections,
395 before.saturating_sub(1),
396 "close must release exactly one backend connection (saturating at 0)"
397 );
398 gauge_add!(names::backend::CONNECTIONS, -1);
399 gauge_add!(names::backend::POOL_SIZE, -1);
404 gauge_add!(
405 names::backend::CONNECTIONS_PER_BACKEND,
406 -1,
407 Some(cluster_id),
408 Some(&backend_borrow.backend_id)
409 );
410 trace!(
411 "{} connection close: {:#?}",
412 log_module_context!(),
413 backend_borrow
414 );
415 }
416 }
417
418 fn pre_end_stream_client_bookkeeping(&self) {
419 if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
420 let mut backend_borrow = backend.borrow_mut();
421 let before = backend_borrow.active_requests;
426 backend_borrow.active_requests = backend_borrow.active_requests.saturating_sub(1);
427 debug_assert_eq!(
428 backend_borrow.active_requests,
429 before.saturating_sub(1),
430 "end_stream bookkeeping must decrement active_requests by one (saturating)"
431 );
432 debug_assert!(
433 backend_borrow.active_requests <= before,
434 "active_requests must not grow on stream end"
435 );
436 trace!(
437 "{} connection end stream: {:#?}",
438 log_module_context!(),
439 backend_borrow
440 );
441 }
442 }
443
444 fn pre_start_stream_client_bookkeeping(&self) {
445 if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
446 let mut backend_borrow = backend.borrow_mut();
447 let before = backend_borrow.active_requests;
450 backend_borrow.active_requests += 1;
451 debug_assert_eq!(
452 backend_borrow.active_requests,
453 before + 1,
454 "start_stream bookkeeping must increment active_requests by exactly one"
455 );
456 trace!(
457 "{} connection start stream: {:#?}",
458 log_module_context!(),
459 backend_borrow
460 );
461 }
462 }
463
464 pub(super) fn close<E, L>(&mut self, context: &mut Context<L>, endpoint: E)
465 where
466 E: Endpoint,
467 L: ListenerHandler + L7ListenerHandler,
468 {
469 self.pre_close_client_bookkeeping();
470 forward!(self, close(context, endpoint))
471 }
472
473 pub(super) fn end_stream<L>(&mut self, stream: GlobalStreamId, context: &mut Context<L>)
474 where
475 L: ListenerHandler + L7ListenerHandler,
476 {
477 self.pre_end_stream_client_bookkeeping();
478 forward!(self, end_stream(stream, context))
479 }
480
481 pub(super) fn start_stream<L>(
482 &mut self,
483 stream: GlobalStreamId,
484 context: &mut Context<L>,
485 ) -> bool
486 where
487 L: ListenerHandler + L7ListenerHandler,
488 {
489 #[cfg(debug_assertions)]
496 let before = self.backend_active_requests();
497 self.pre_start_stream_client_bookkeeping();
498 let started = forward!(self, start_stream(stream, context));
499 if !started {
500 self.pre_end_stream_client_bookkeeping();
502 #[cfg(debug_assertions)]
503 debug_assert_eq!(
504 self.backend_active_requests(),
505 before,
506 "a refused start_stream must roll active_requests back to its prior value"
507 );
508 }
509 started
510 }
511
512 #[cfg(debug_assertions)]
518 fn backend_active_requests(&self) -> Option<usize> {
519 match self.position() {
520 Position::Client(_, backend, BackendStatus::Connected) => {
521 Some(backend.borrow().active_requests)
522 }
523 _ => None,
524 }
525 }
526}
527
528#[derive(Debug)]
529pub(super) struct EndpointServer<'a, Front: SocketHandler>(pub &'a mut Connection<Front>);
530#[derive(Debug)]
531pub(super) struct EndpointClient<'a>(pub &'a mut Router);
532
533impl<Front: SocketHandler + Debug> Endpoint for EndpointServer<'_, Front> {
536 fn readiness(&self, _token: Token) -> &Readiness {
537 self.0.readiness()
538 }
539 fn readiness_mut(&mut self, _token: Token) -> &mut Readiness {
540 self.0.readiness_mut()
541 }
542 fn socket(&self, _token: Token) -> Option<&TcpStream> {
543 Some(self.0.socket())
544 }
545
546 fn end_stream<L>(&mut self, _token: Token, stream: GlobalStreamId, context: &mut Context<L>)
547 where
548 L: ListenerHandler + L7ListenerHandler,
549 {
550 self.0.end_stream(stream, context);
553 }
554
555 fn start_stream<L>(
556 &mut self,
557 _token: Token,
558 stream: GlobalStreamId,
559 context: &mut Context<L>,
560 ) -> bool
561 where
562 L: ListenerHandler + L7ListenerHandler,
563 {
564 self.0.start_stream(stream, context)
568 }
569}
570impl Endpoint for EndpointClient<'_> {
571 fn readiness(&self, token: Token) -> &Readiness {
572 match self.0.backends.get(&token) {
573 Some(backend) => backend.readiness(),
574 None => {
575 error!(
576 "{} backend token {:?} missing from backends map (readiness)",
577 log_module_context!(),
578 token
579 );
580 &self.0.fallback_readiness
581 }
582 }
583 }
584 fn readiness_mut(&mut self, token: Token) -> &mut Readiness {
585 match self.0.backends.get_mut(&token) {
586 Some(backend) => backend.readiness_mut(),
587 None => {
588 error!(
589 "{} backend token {:?} missing from backends map (readiness_mut)",
590 log_module_context!(),
591 token
592 );
593 &mut self.0.fallback_readiness
594 }
595 }
596 }
597 fn socket(&self, token: Token) -> Option<&TcpStream> {
598 self.0.backends.get(&token).map(|c| c.socket())
599 }
600
601 fn end_stream<L>(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context<L>)
602 where
603 L: ListenerHandler + L7ListenerHandler,
604 {
605 match self.0.backends.get_mut(&token) {
606 Some(backend) => backend.end_stream(stream, context),
607 None => {
608 error!(
609 "{} backend token {:?} missing from backends map (end_stream)",
610 log_module_context!(),
611 token
612 );
613 }
614 }
615 }
616
617 fn start_stream<L>(
618 &mut self,
619 token: Token,
620 stream: GlobalStreamId,
621 context: &mut Context<L>,
622 ) -> bool
623 where
624 L: ListenerHandler + L7ListenerHandler,
625 {
626 match self.0.backends.get_mut(&token) {
627 Some(backend) => backend.start_stream(stream, context),
628 None => {
629 error!(
630 "{} backend token {:?} missing from backends map (start_stream)",
631 log_module_context!(),
632 token
633 );
634 false
635 }
636 }
637 }
638}
639
/// Failure-injection hooks, compiled only for tests and the `e2e-hooks`
/// feature.
#[cfg(any(test, feature = "e2e-hooks"))]
pub mod test_hooks {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// When set, the next `Connection::new_h2_client` call returns `None`
    /// and resets this flag to `false`.
    pub static FORCE_NEW_H2_CLIENT_FAILURE: AtomicBool = AtomicBool::new(false);

    /// Sets the forced-failure flag to `on` and returns its previous value.
    pub fn __test_force_h2_client_failure(on: bool) -> bool {
        FORCE_NEW_H2_CLIENT_FAILURE.swap(on, Ordering::SeqCst)
    }
}