Skip to main content

specter/transport/h2/
driver.rs

1//! HTTP/2 connection driver - background task that reads frames and routes them to streams.
2//!
3//! The driver owns the raw H2Connection and continuously reads frames from the socket,
4//! routing them to the appropriate stream channels. This allows multiple requests
5//! to be multiplexed without blocking each other.
6
7use bytes::{Bytes, BytesMut};
8use http::{Method, Uri};
9use std::collections::{HashMap, VecDeque};
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tracing;
13
14pub type StreamingHeadersResult = Result<(u16, Vec<(String, String)>)>;
15
16use crate::error::{Error, Result};
17use crate::transport::h2::connection::{
18    ControlAction, H2Connection as RawH2Connection, StreamResponse,
19};
20use crate::transport::h2::frame::{flags, ErrorCode, FrameHeader, FrameType};
21use crate::transport::h2::tunnel::{H2Tunnel, H2TunnelEvent, H2TunnelOutbound};
22
23/// Command sent from handle to driver
24#[derive(Debug)]
25pub enum DriverCommand {
26    /// Send a request and get response via oneshot
27    /// Driver allocates stream_id
28    SendRequest {
29        method: http::Method,
30        uri: http::Uri,
31        headers: Vec<(String, String)>,
32        body: Option<bytes::Bytes>,
33        response_tx: oneshot::Sender<Result<StreamResponse>>,
34    },
35    /// Send a request with a streaming body
36    SendStreamingRequest {
37        method: Method,
38        uri: Uri,
39        headers: Vec<(String, String)>,
40        body_tx: mpsc::Sender<Result<Bytes>>,
41        headers_tx: oneshot::Sender<StreamingHeadersResult>,
42    },
43    /// Open an RFC 8441 WebSocket tunnel on a pooled HTTP/2 stream.
44    OpenWebSocketTunnel {
45        uri: Uri,
46        headers: Vec<(String, String)>,
47        response_tx: oneshot::Sender<Result<H2Tunnel>>,
48    },
49    /// Queue outbound DATA for an open RFC 8441 tunnel.
50    SendTunnelData {
51        stream_id: u32,
52        outbound: H2TunnelOutbound,
53    },
54}
55
56/// Per-stream state tracked by driver
57struct DriverStreamState {
58    /// Oneshot sender for response completion
59    response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
60    /// Accumulated response status
61    status: Option<u16>,
62    /// Accumulated response headers
63    headers: Vec<(String, String)>,
64    /// Accumulated response body
65    body: BytesMut,
66    /// Pending request body to be sent (flow control buffer)
67    pending_body: Bytes,
68    /// Offset of pending body already sent
69    body_offset: usize,
70}
71
72impl DriverStreamState {
73    fn new(response_tx: oneshot::Sender<Result<StreamResponse>>, pending_body: Bytes) -> Self {
74        Self {
75            response_tx: Some(response_tx),
76            status: None,
77            headers: Vec::new(),
78            body: BytesMut::new(),
79            pending_body,
80            body_offset: 0,
81        }
82    }
83}
84
85struct DriverTunnelState {
86    inbound_tx: mpsc::Sender<Result<H2TunnelEvent>>,
87    pending_outbound: VecDeque<H2TunnelOutbound>,
88}
89
90/// HTTP/2 connection driver that runs in a background task
91pub struct H2Driver<S> {
92    /// Channel for receiving commands from handles
93    command_rx: mpsc::Receiver<DriverCommand>,
94    /// Sender back into the driver command queue, used by tunnel outbound forwarders.
95    command_tx: mpsc::Sender<DriverCommand>,
96    /// Raw H2 connection (owned by driver)
97    connection: RawH2Connection<S>,
98    /// Per-stream state for routing responses
99    streams: HashMap<u32, DriverStreamState>,
100    /// Per-stream state for open RFC 8441 tunnels.
101    tunnels: HashMap<u32, DriverTunnelState>,
102    /// Queue for pending requests when max streams reached
103    pending_requests: std::collections::VecDeque<DriverCommand>,
104}
105
106impl<S> H2Driver<S>
107where
108    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
109{
110    /// Create a new driver from an established connection
111    pub fn new(
112        connection: RawH2Connection<S>,
113        command_tx: mpsc::Sender<DriverCommand>,
114        command_rx: mpsc::Receiver<DriverCommand>,
115    ) -> Self {
116        Self {
117            command_rx,
118            command_tx,
119            connection,
120            streams: HashMap::new(),
121            tunnels: HashMap::new(),
122            pending_requests: std::collections::VecDeque::new(),
123        }
124    }
125
126    /// Run the driver loop - processes commands and reads frames
127    pub async fn drive(mut self) -> Result<()> {
128        loop {
129            // Processing pending requests if slots available
130            self.process_pending_requests().await?;
131
132            // Try to flush any pending data (flow control)
133            self.flush_pending_data().await?;
134            self.flush_tunnel_data().await?;
135
136            tokio::select! {
137                // Handle incoming commands (send requests)
138                command = self.command_rx.recv() => {
139                    match command {
140                        Some(cmd) => {
141                             match cmd {
142                                DriverCommand::SendRequest { .. } => {
143                                    self.handle_send_request(cmd).await?;
144                                }
145                                DriverCommand::SendStreamingRequest { .. } => {
146                                    tracing::warn!("Streaming requests not yet implemented in driver");
147                                }
148                                DriverCommand::OpenWebSocketTunnel { uri, headers, response_tx } => {
149                                    self.handle_open_websocket_tunnel(uri, headers, response_tx).await?;
150                                 }
151                                DriverCommand::SendTunnelData { stream_id, outbound } => {
152                                    self.queue_tunnel_outbound(stream_id, outbound).await?;
153                                }
154                             }
155                        }
156                        None => {
157                            // Channel closed - driver should shutdown
158                            break;
159                        }
160                    }
161                }
162
163                // Handle incoming frames
164                read_res = self.connection.read_next_frame() => {
165                    match read_res {
166                        Ok((header, payload)) => {
167                            if let Err(e) = self.handle_frame(header, payload).await {
168                                tracing::error!("H2Driver frame error: {:?}", e);
169                                // Protocol errors are fatal and require connection termination.
170                                // The connection state may be inconsistent after this error.
171                                return Err(e);
172                            }
173                        }
174                        Err(e) => {
175                             // Connection error
176                            tracing::error!("H2Driver read error: {:?}", e);
177                            return Err(e);
178                        }
179                    }
180                }
181            }
182        }
183        Ok(())
184    }
185
186    /// Handle SendRequest command
187    async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
188        if !self.has_available_stream_slot() {
189            // Queue request
190            self.pending_requests.push_back(cmd);
191        } else {
192            // Send immediately
193            self.send_request_internal(cmd).await?;
194        }
195        Ok(())
196    }
197
198    /// Process pending requests if slots available
199    async fn process_pending_requests(&mut self) -> Result<()> {
200        while self.has_available_stream_slot() {
201            if let Some(cmd) = self.pending_requests.pop_front() {
202                match cmd {
203                    DriverCommand::SendRequest { .. } => {
204                        self.send_request_internal(cmd).await?;
205                    }
206                    DriverCommand::OpenWebSocketTunnel {
207                        uri,
208                        headers,
209                        response_tx,
210                    } => {
211                        self.open_websocket_tunnel_internal(uri, headers, response_tx)
212                            .await?;
213                    }
214                    DriverCommand::SendStreamingRequest { .. } => {
215                        tracing::warn!("Streaming requests not yet implemented in driver");
216                    }
217                    DriverCommand::SendTunnelData {
218                        stream_id,
219                        outbound,
220                    } => {
221                        self.queue_tunnel_outbound(stream_id, outbound).await?;
222                    }
223                }
224            } else {
225                break;
226            }
227        }
228        Ok(())
229    }
230
231    fn active_stream_count(&self) -> usize {
232        self.streams.len() + self.tunnels.len()
233    }
234
235    fn has_available_stream_slot(&self) -> bool {
236        let max_streams = self.connection.peer_settings().max_concurrent_streams as usize;
237        self.active_stream_count() < max_streams
238    }
239
240    /// Internal helper to send request
241    async fn send_request_internal(&mut self, cmd: DriverCommand) -> Result<()> {
242        if let DriverCommand::SendRequest {
243            method,
244            uri,
245            headers,
246            body,
247            response_tx,
248        } = cmd
249        {
250            // Construct request
251            let mut req_builder = http::Request::builder().method(method).uri(uri);
252
253            for (k, v) in headers {
254                req_builder = req_builder.header(k, v);
255            }
256
257            // Body
258            let body_bytes = body.unwrap_or_default();
259            let has_body = !body_bytes.is_empty();
260
261            let req = match req_builder.body(body_bytes.clone()) {
262                Ok(r) => r,
263                Err(e) => {
264                    if response_tx
265                        .send(Err(Error::HttpProtocol(format!("Invalid request: {}", e))))
266                        .is_err()
267                    {
268                        tracing::debug!("Response channel closed while sending error");
269                    }
270                    return Ok(());
271                }
272            };
273
274            // Send HEADERS frame (non-blocking write)
275            // If body is present, end_stream=false (DATA frames will be sent separately)
276            let end_stream = !has_body;
277
278            match self.connection.send_headers(&req, end_stream).await {
279                Ok(stream_id) => {
280                    // Register stream state
281                    self.streams
282                        .insert(stream_id, DriverStreamState::new(response_tx, body_bytes));
283
284                    // Trigger flush to try sending body immediately
285                    self.flush_pending_data().await?;
286                }
287                Err(e) => {
288                    // Notify error immediately
289                    if response_tx.send(Err(e)).is_err() {
290                        tracing::debug!("Response channel closed while sending error");
291                    }
292                }
293            }
294        }
295        Ok(())
296    }
297
298    async fn handle_open_websocket_tunnel(
299        &mut self,
300        uri: Uri,
301        headers: Vec<(String, String)>,
302        response_tx: oneshot::Sender<Result<H2Tunnel>>,
303    ) -> Result<()> {
304        if !self.has_available_stream_slot() {
305            self.pending_requests
306                .push_back(DriverCommand::OpenWebSocketTunnel {
307                    uri,
308                    headers,
309                    response_tx,
310                });
311            return Ok(());
312        }
313
314        self.open_websocket_tunnel_internal(uri, headers, response_tx)
315            .await
316    }
317
318    async fn open_websocket_tunnel_internal(
319        &mut self,
320        uri: Uri,
321        headers: Vec<(String, String)>,
322        response_tx: oneshot::Sender<Result<H2Tunnel>>,
323    ) -> Result<()> {
324        match self
325            .connection
326            .open_extended_connect_websocket_with_end_stream(&uri, headers)
327            .await
328        {
329            Ok((stream_id, end_stream)) => {
330                let (outbound_tx, outbound_rx) = mpsc::channel(32);
331                let (inbound_tx, inbound_rx) = mpsc::channel(32);
332                if end_stream {
333                    let _ = inbound_tx.send(Ok(H2TunnelEvent::EndStream)).await;
334                    self.connection.remove_stream(stream_id);
335                } else {
336                    let command_tx = self.command_tx.clone();
337                    tokio::spawn(async move {
338                        let mut outbound_rx = outbound_rx;
339                        while let Some(outbound) = outbound_rx.recv().await {
340                            if command_tx
341                                .send(DriverCommand::SendTunnelData {
342                                    stream_id,
343                                    outbound,
344                                })
345                                .await
346                                .is_err()
347                            {
348                                break;
349                            }
350                        }
351                    });
352                    self.tunnels.insert(
353                        stream_id,
354                        DriverTunnelState {
355                            inbound_tx,
356                            pending_outbound: VecDeque::new(),
357                        },
358                    );
359                }
360
361                if response_tx
362                    .send(Ok(H2Tunnel::new(outbound_tx, inbound_rx)))
363                    .is_err()
364                {
365                    tracing::debug!("Tunnel response channel closed after open");
366                    self.tunnels.remove(&stream_id);
367                }
368            }
369            Err(e) => {
370                if response_tx.send(Err(e)).is_err() {
371                    tracing::debug!("Tunnel response channel closed while sending open error");
372                }
373            }
374        }
375        Ok(())
376    }
377
378    async fn queue_tunnel_outbound(
379        &mut self,
380        stream_id: u32,
381        outbound: H2TunnelOutbound,
382    ) -> Result<()> {
383        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
384            tunnel.pending_outbound.push_back(outbound);
385            self.flush_tunnel_data().await?;
386        }
387
388        Ok(())
389    }
390
391    /// Iterate all active streams and try to send pending body data
392    async fn flush_pending_data(&mut self) -> Result<()> {
393        // Collect IDs to avoid borrow conflict
394        let stream_ids: Vec<u32> = self.streams.keys().cloned().collect();
395
396        for stream_id in stream_ids {
397            // Keep sending chunks for this stream until blocked or done
398            loop {
399                // Check if we have data to send
400                let (has_data, offset) = if let Some(stream) = self.streams.get(&stream_id) {
401                    (
402                        stream.body_offset < stream.pending_body.len(),
403                        stream.body_offset,
404                    )
405                } else {
406                    (false, 0)
407                };
408
409                if !has_data {
410                    break;
411                }
412
413                // Prepare arguments for send_data
414                // We clone the Bytes handle which is cheap
415                let pending_body = {
416                    let s = self.streams.get(&stream_id).unwrap();
417                    s.pending_body.clone()
418                };
419
420                let remaining = &pending_body[offset..];
421                let is_last_chunk = true;
422
423                // send_data returns bytes sent. If 0, it means blocked.
424                let sent = self
425                    .connection
426                    .send_data(stream_id, remaining, is_last_chunk)
427                    .await?;
428
429                if sent > 0 {
430                    if let Some(stream) = self.streams.get_mut(&stream_id) {
431                        stream.body_offset += sent;
432                    }
433                    // Loop again to send next chunk
434                } else {
435                    // Blocked by flow control
436                    break;
437                }
438            }
439        }
440        Ok(())
441    }
442
443    async fn flush_tunnel_data(&mut self) -> Result<()> {
444        let stream_ids: Vec<u32> = self.tunnels.keys().copied().collect();
445
446        for stream_id in stream_ids {
447            loop {
448                let outbound = match self
449                    .tunnels
450                    .get_mut(&stream_id)
451                    .and_then(|tunnel| tunnel.pending_outbound.pop_front())
452                {
453                    Some(outbound) => outbound,
454                    None => break,
455                };
456
457                let sent = self
458                    .connection
459                    .send_data(stream_id, &outbound.bytes, outbound.end_stream)
460                    .await?;
461
462                if outbound.bytes.is_empty() {
463                    continue;
464                }
465
466                if sent == 0 {
467                    if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
468                        tunnel.pending_outbound.push_front(outbound);
469                    }
470                    break;
471                }
472
473                if sent < outbound.bytes.len() {
474                    if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
475                        tunnel.pending_outbound.push_front(H2TunnelOutbound {
476                            bytes: outbound.bytes.slice(sent..),
477                            end_stream: outbound.end_stream,
478                        });
479                    }
480                    break;
481                }
482            }
483        }
484
485        Ok(())
486    }
487
488    /// Handle a single frame
489    async fn handle_frame(&mut self, header: FrameHeader, mut payload: Bytes) -> Result<()> {
490        // 1. Check control frames that modify connection state
491        match self
492            .connection
493            .handle_control_frame(&header, payload.clone())
494            .await?
495        {
496            ControlAction::RstStream(sid, code) => {
497                if let Some(tunnel) = self.tunnels.remove(&sid) {
498                    let _ = tunnel
499                        .inbound_tx
500                        .send(Ok(H2TunnelEvent::Reset(format!("{:?}", code))))
501                        .await;
502                }
503                // Notify stream of reset
504                if let Some(mut stream) = self.streams.remove(&sid) {
505                    if let Some(tx) = stream.response_tx.take() {
506                        if tx
507                            .send(Err(Error::HttpProtocol(format!(
508                                "Stream reset by peer: {:?}",
509                                code
510                            ))))
511                            .is_err()
512                        {
513                            tracing::debug!("Response channel closed while notifying stream reset");
514                        }
515                    }
516                }
517                // Stream slot freed, try to process pending
518                self.process_pending_requests().await?;
519                return Ok(());
520            }
521            ControlAction::GoAway(last_sid) => {
522                let tunnel_ids: Vec<u32> = self.tunnels.keys().copied().collect();
523                for sid in tunnel_ids {
524                    if sid > last_sid {
525                        if let Some(tunnel) = self.tunnels.remove(&sid) {
526                            let _ = tunnel
527                                .inbound_tx
528                                .send(Ok(H2TunnelEvent::GoAway {
529                                    last_stream_id: last_sid,
530                                }))
531                                .await;
532                        }
533                    }
534                }
535                // Close all streams > last_sid
536                let sids: Vec<u32> = self.streams.keys().cloned().collect();
537                for sid in sids {
538                    if sid > last_sid {
539                        if let Some(mut stream) = self.streams.remove(&sid) {
540                            if let Some(tx) = stream.response_tx.take() {
541                                if tx
542                                    .send(Err(Error::HttpProtocol("GOAWAY received".into())))
543                                    .is_err()
544                                {
545                                    tracing::debug!(
546                                        "Response channel closed while notifying GOAWAY"
547                                    );
548                                }
549                            }
550                        }
551                    }
552                }
553                // Driver continues processing existing streams until they complete.
554                // A future enhancement could implement immediate shutdown on GOAWAY.
555                return Ok(());
556            }
557            ControlAction::RefusePush(_stream_id, promised_id) => {
558                // Send RST_STREAM for the promised stream
559                // RFC 9113 8.4: RST_STREAM with REFUSED_STREAM
560                if let Err(e) = self
561                    .connection
562                    .send_rst_stream(promised_id, ErrorCode::RefusedStream)
563                    .await
564                {
565                    tracing::warn!(
566                        "Failed to send RST_STREAM for refused push promise: {:?}",
567                        e
568                    );
569                }
570            }
571            ControlAction::None => {
572                // Continue to specific processing
573            }
574        }
575
576        // 2. Data / Headers routing
577        match header.frame_type {
578            FrameType::Headers => {
579                let stream_id = header.stream_id;
580
581                // Handle CONTINUATION frames if needed (END_HEADERS flag not set).
582                // CONTINUATION frames are collected in the loop below; this branch handles
583                // the initial HEADERS frame that starts a header block.
584                if (header.flags & flags::END_HEADERS) == 0 {
585                    // Loop to read CONTINUATION frames
586                    // This inner loop blocks the driver select! loop, which is expected
587                    // per RFC 9113 Section 6.2 (CONTINUATION frames must be processed sequentially).
588                    let mut block = BytesMut::from(payload);
589                    loop {
590                        let (next_header, next_payload) = self.connection.read_next_frame().await?;
591                        if next_header.frame_type != FrameType::Continuation {
592                            return Err(Error::HttpProtocol("Expected CONTINUATION frame".into()));
593                        }
594                        if next_header.stream_id != stream_id {
595                            return Err(Error::HttpProtocol(
596                                "CONTINUATION frame stream ID mismatch".into(),
597                            ));
598                        }
599                        block.extend_from_slice(&next_payload);
600                        if (next_header.flags & flags::END_HEADERS) != 0 {
601                            break;
602                        }
603                    }
604                    payload = block.freeze();
605                }
606
607                let decoded = self.connection.decode_header_block(payload)?;
608
609                // Parse pseudo-headers
610                let mut status = 0u16;
611                let mut regular_headers = Vec::new();
612
613                for (name, value) in decoded {
614                    if name == ":status" {
615                        status = value.parse().unwrap_or(0);
616                    } else if !name.starts_with(':') {
617                        regular_headers.push((name, value));
618                    }
619                }
620
621                if let Some(stream) = self.streams.get_mut(&stream_id) {
622                    stream.status = Some(status);
623                    stream.headers = regular_headers;
624
625                    if (header.flags & flags::END_STREAM) != 0 {
626                        self.complete_stream(stream_id);
627                    }
628                }
629            }
630            FrameType::Data => {
631                let stream_id = header.stream_id;
632                let end_stream = (header.flags & flags::END_STREAM) != 0;
633
634                // Process flow control for inbound DATA frame.
635                // The process_inbound_data_frame method takes stream_id, flags, and payload
636                // to handle window updates and flow control state.
637                let data = self
638                    .connection
639                    .process_inbound_data_frame(stream_id, header.flags, payload)
640                    .await?;
641
642                if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
643                    if !data.is_empty() {
644                        let _ = tunnel.inbound_tx.send(Ok(H2TunnelEvent::Data(data))).await;
645                    }
646                    if end_stream {
647                        let _ = tunnel.inbound_tx.send(Ok(H2TunnelEvent::EndStream)).await;
648                        self.tunnels.remove(&stream_id);
649                    }
650                    return Ok(());
651                }
652
653                if let Some(stream) = self.streams.get_mut(&stream_id) {
654                    stream.body.extend_from_slice(&data);
655
656                    if end_stream {
657                        self.complete_stream(stream_id);
658                    }
659                }
660            }
661            FrameType::WindowUpdate => {
662                // Window update received and processed by handle_control_frame,
663                // which updates the connection/stream window in self.connection.
664                // Flush any pending data that was previously blocked by flow control.
665                self.flush_pending_data().await?;
666                self.flush_tunnel_data().await?;
667            }
668            _ => {} // Other frames handled by handle_control_frame (or ignored)
669        }
670
671        Ok(())
672    }
673
674    /// Complete a stream: build response and send
675    fn complete_stream(&mut self, stream_id: u32) {
676        if let Some(mut stream) = self.streams.remove(&stream_id) {
677            if let Some(tx) = stream.response_tx.take() {
678                // If no status was received, this is a protocol violation
679                // Return an error rather than defaulting to 200
680                let response = match stream.status {
681                    Some(status) => Ok(StreamResponse {
682                        status,
683                        headers: stream.headers,
684                        body: stream.body.freeze(),
685                    }),
686                    None => Err(Error::HttpProtocol(format!(
687                        "Stream {} completed without status code",
688                        stream_id
689                    ))),
690                };
691                if tx.send(response).is_err() {
692                    tracing::debug!("Response channel closed while completing stream");
693                }
694            }
695        }
696        // Stream slot is now available. The main loop will call process_pending_requests
697        // to process any queued requests waiting for available stream slots.
698    }
699}