// specter/transport/h2/driver.rs
1//! HTTP/2 connection driver - background task that reads frames and routes them to streams.
2//!
3//! The driver owns the raw H2Connection and continuously reads frames from the socket,
4//! routing them to the appropriate stream channels. This allows multiple requests
5//! to be multiplexed without blocking each other.
6
7use bytes::{Bytes, BytesMut};
8use http::{Method, Uri};
9use std::collections::HashMap;
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tracing;
13
/// Outcome delivered on a streaming request's `headers_tx` channel:
/// the response status code and header list on success, or the error
/// that terminated the request.
pub type StreamingHeadersResult = Result<(u16, Vec<(String, String)>)>;
15
16use crate::error::{Error, Result};
17use crate::transport::h2::connection::{
18    ControlAction, H2Connection as RawH2Connection, StreamResponse,
19};
20use crate::transport::h2::frame::{flags, ErrorCode, FrameHeader, FrameType};
21
/// Command sent from a handle to the driver task over the command channel.
#[derive(Debug)]
pub enum DriverCommand {
    /// Send a fully-buffered request and receive the complete response
    /// via the oneshot channel.
    /// Driver allocates stream_id.
    SendRequest {
        method: http::Method,
        uri: http::Uri,
        /// Regular (non-pseudo) request headers as name/value pairs.
        headers: Vec<(String, String)>,
        /// Optional request body; absent or empty body means the HEADERS
        /// frame ends the stream.
        body: Option<bytes::Bytes>,
        /// Receives the final response, or the error that ended the stream.
        response_tx: oneshot::Sender<Result<StreamResponse>>,
    },
    /// Send a request with a streaming body.
    /// NOTE(review): the driver does not implement this variant yet
    /// (see `drive`); handles should not rely on it.
    SendStreamingRequest {
        method: Method,
        uri: Uri,
        headers: Vec<(String, String)>,
        /// Channel over which response body chunks would be streamed.
        body_tx: mpsc::Sender<Result<Bytes>>,
        /// Receives the response status and headers once available.
        headers_tx: oneshot::Sender<StreamingHeadersResult>,
    },
}
43
/// Per-stream state tracked by the driver while a request is in flight.
struct DriverStreamState {
    /// Oneshot sender for response completion; taken (set to `None`)
    /// once the response or an error has been delivered.
    response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
    /// Response status from the `:status` pseudo-header, once received.
    status: Option<u16>,
    /// Accumulated response headers (pseudo-headers excluded).
    headers: Vec<(String, String)>,
    /// Accumulated response body from DATA frames.
    body: BytesMut,
    /// Request body still to be sent, buffered for flow control.
    pending_body: Bytes,
    /// Number of bytes of `pending_body` already written to the connection.
    body_offset: usize,
}
59
60impl DriverStreamState {
61    fn new(response_tx: oneshot::Sender<Result<StreamResponse>>, pending_body: Bytes) -> Self {
62        Self {
63            response_tx: Some(response_tx),
64            status: None,
65            headers: Vec::new(),
66            body: BytesMut::new(),
67            pending_body,
68            body_offset: 0,
69        }
70    }
71}
72
/// HTTP/2 connection driver that runs in a background task.
///
/// Owns the raw connection, multiplexes frames between the socket and
/// per-request channels, and enforces the peer's concurrent-stream limit.
pub struct H2Driver<S> {
    /// Channel for receiving commands from handles.
    command_rx: mpsc::Receiver<DriverCommand>,
    /// Raw H2 connection (owned exclusively by the driver).
    connection: RawH2Connection<S>,
    /// Per-stream state, keyed by stream id, for routing responses.
    streams: HashMap<u32, DriverStreamState>,
    /// Requests queued while the peer's max-concurrent-streams limit is reached.
    pending_requests: std::collections::VecDeque<DriverCommand>,
}
84
85impl<S> H2Driver<S>
86where
87    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
88{
89    /// Create a new driver from an established connection
90    pub fn new(connection: RawH2Connection<S>, command_rx: mpsc::Receiver<DriverCommand>) -> Self {
91        Self {
92            command_rx,
93            connection,
94            streams: HashMap::new(),
95            pending_requests: std::collections::VecDeque::new(),
96        }
97    }
98
99    /// Run the driver loop - processes commands and reads frames
100    pub async fn drive(mut self) -> Result<()> {
101        loop {
102            // Processing pending requests if slots available
103            self.process_pending_requests().await?;
104
105            // Try to flush any pending data (flow control)
106            self.flush_pending_data().await?;
107
108            tokio::select! {
109                // Handle incoming commands (send requests)
110                command = self.command_rx.recv() => {
111                    match command {
112                        Some(cmd) => {
113                             match cmd {
114                                DriverCommand::SendRequest { .. } => {
115                                    self.handle_send_request(cmd).await?;
116                                }
117                                DriverCommand::SendStreamingRequest { .. } => {
118                                    tracing::warn!("Streaming requests not yet implemented in driver");
119                                }
120                             }
121                        }
122                        None => {
123                            // Channel closed - driver should shutdown
124                            break;
125                        }
126                    }
127                }
128
129                // Handle incoming frames
130                read_res = self.connection.read_next_frame() => {
131                    match read_res {
132                        Ok((header, payload)) => {
133                            if let Err(e) = self.handle_frame(header, payload).await {
134                                tracing::error!("H2Driver frame error: {:?}", e);
135                                // Protocol errors are fatal and require connection termination.
136                                // The connection state may be inconsistent after this error.
137                                return Err(e);
138                            }
139                        }
140                        Err(e) => {
141                             // Connection error
142                            tracing::error!("H2Driver read error: {:?}", e);
143                            return Err(e);
144                        }
145                    }
146                }
147            }
148        }
149        Ok(())
150    }
151
152    /// Handle SendRequest command
153    async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
154        let max_streams = self.connection.peer_settings().max_concurrent_streams;
155
156        if self.streams.len() >= max_streams as usize {
157            // Queue request
158            self.pending_requests.push_back(cmd);
159        } else {
160            // Send immediately
161            self.send_request_internal(cmd).await?;
162        }
163        Ok(())
164    }
165
166    /// Process pending requests if slots available
167    async fn process_pending_requests(&mut self) -> Result<()> {
168        let max_streams = self.connection.peer_settings().max_concurrent_streams;
169
170        while self.streams.len() < max_streams as usize {
171            if let Some(cmd) = self.pending_requests.pop_front() {
172                self.send_request_internal(cmd).await?;
173            } else {
174                break;
175            }
176        }
177        Ok(())
178    }
179
180    /// Internal helper to send request
181    async fn send_request_internal(&mut self, cmd: DriverCommand) -> Result<()> {
182        if let DriverCommand::SendRequest {
183            method,
184            uri,
185            headers,
186            body,
187            response_tx,
188        } = cmd
189        {
190            // Construct request
191            let mut req_builder = http::Request::builder().method(method).uri(uri);
192
193            for (k, v) in headers {
194                req_builder = req_builder.header(k, v);
195            }
196
197            // Body
198            let body_bytes = body.unwrap_or_default();
199            let has_body = !body_bytes.is_empty();
200
201            let req = match req_builder.body(body_bytes.clone()) {
202                Ok(r) => r,
203                Err(e) => {
204                    if response_tx
205                        .send(Err(Error::HttpProtocol(format!("Invalid request: {}", e))))
206                        .is_err()
207                    {
208                        tracing::debug!("Response channel closed while sending error");
209                    }
210                    return Ok(());
211                }
212            };
213
214            // Send HEADERS frame (non-blocking write)
215            // If body is present, end_stream=false (DATA frames will be sent separately)
216            let end_stream = !has_body;
217
218            match self.connection.send_headers(&req, end_stream).await {
219                Ok(stream_id) => {
220                    // Register stream state
221                    self.streams
222                        .insert(stream_id, DriverStreamState::new(response_tx, body_bytes));
223
224                    // Trigger flush to try sending body immediately
225                    self.flush_pending_data().await?;
226                }
227                Err(e) => {
228                    // Notify error immediately
229                    if response_tx.send(Err(e)).is_err() {
230                        tracing::debug!("Response channel closed while sending error");
231                    }
232                }
233            }
234        }
235        Ok(())
236    }
237
238    /// Iterate all active streams and try to send pending body data
239    async fn flush_pending_data(&mut self) -> Result<()> {
240        // Collect IDs to avoid borrow conflict
241        let stream_ids: Vec<u32> = self.streams.keys().cloned().collect();
242
243        for stream_id in stream_ids {
244            // Keep sending chunks for this stream until blocked or done
245            loop {
246                // Check if we have data to send
247                let (has_data, offset) = if let Some(stream) = self.streams.get(&stream_id) {
248                    (
249                        stream.body_offset < stream.pending_body.len(),
250                        stream.body_offset,
251                    )
252                } else {
253                    (false, 0)
254                };
255
256                if !has_data {
257                    break;
258                }
259
260                // Prepare arguments for send_data
261                // We clone the Bytes handle which is cheap
262                let pending_body = {
263                    let s = self.streams.get(&stream_id).unwrap();
264                    s.pending_body.clone()
265                };
266
267                let remaining = &pending_body[offset..];
268                let is_last_chunk = true;
269
270                // send_data returns bytes sent. If 0, it means blocked.
271                let sent = self
272                    .connection
273                    .send_data(stream_id, remaining, is_last_chunk)
274                    .await?;
275
276                if sent > 0 {
277                    if let Some(stream) = self.streams.get_mut(&stream_id) {
278                        stream.body_offset += sent;
279                    }
280                    // Loop again to send next chunk
281                } else {
282                    // Blocked by flow control
283                    break;
284                }
285            }
286        }
287        Ok(())
288    }
289
    /// Dispatch a single frame that was read off the connection.
    ///
    /// Control frames are first offered to the connection state machine
    /// (`handle_control_frame`); the returned `ControlAction` tells the
    /// driver which streams were affected. HEADERS and DATA frames are
    /// then routed to the matching `DriverStreamState`. Frames for
    /// unknown stream ids are silently ignored.
    async fn handle_frame(&mut self, header: FrameHeader, mut payload: Bytes) -> Result<()> {
        // 1. Check control frames that modify connection state.
        // The Bytes clone is a cheap refcount bump, not a data copy.
        match self
            .connection
            .handle_control_frame(&header, payload.clone())
            .await?
        {
            ControlAction::RstStream(sid, code) => {
                // Peer reset the stream: fail the pending request with the
                // received error code and drop the stream state.
                if let Some(mut stream) = self.streams.remove(&sid) {
                    if let Some(tx) = stream.response_tx.take() {
                        if tx
                            .send(Err(Error::HttpProtocol(format!(
                                "Stream reset by peer: {:?}",
                                code
                            ))))
                            .is_err()
                        {
                            tracing::debug!("Response channel closed while notifying stream reset");
                        }
                    }
                }
                // Stream slot freed, try to process pending requests.
                self.process_pending_requests().await?;
                return Ok(());
            }
            ControlAction::GoAway(last_sid) => {
                // Fail every stream the peer will not process (ids above
                // last_sid); streams at or below last_sid keep running.
                let sids: Vec<u32> = self.streams.keys().cloned().collect();
                for sid in sids {
                    if sid > last_sid {
                        if let Some(mut stream) = self.streams.remove(&sid) {
                            if let Some(tx) = stream.response_tx.take() {
                                if tx
                                    .send(Err(Error::HttpProtocol("GOAWAY received".into())))
                                    .is_err()
                                {
                                    tracing::debug!(
                                        "Response channel closed while notifying GOAWAY"
                                    );
                                }
                            }
                        }
                    }
                }
                // Driver continues processing existing streams until they
                // complete. A future enhancement could implement immediate
                // shutdown on GOAWAY.
                return Ok(());
            }
            ControlAction::RefusePush(_stream_id, promised_id) => {
                // Refuse the promised stream: RST_STREAM with REFUSED_STREAM
                // per RFC 9113 section 8.4. Failure to send is only logged
                // because the connection itself is still usable.
                if let Err(e) = self
                    .connection
                    .send_rst_stream(promised_id, ErrorCode::RefusedStream)
                    .await
                {
                    tracing::warn!(
                        "Failed to send RST_STREAM for refused push promise: {:?}",
                        e
                    );
                }
            }
            ControlAction::None => {
                // Not a state-changing control frame; fall through to the
                // stream-level routing below.
            }
        }

        // 2. Data / Headers routing.
        match header.frame_type {
            FrameType::Headers => {
                let stream_id = header.stream_id;

                // If END_HEADERS is not set, the header block continues in
                // CONTINUATION frames, which must arrive back-to-back with
                // no interleaved frames (RFC 9113 section 6.2) -- so reading
                // them inline here, blocking the select! loop, is correct.
                if (header.flags & flags::END_HEADERS) == 0 {
                    // Accumulate the full header block across CONTINUATIONs.
                    let mut block = BytesMut::from(payload);
                    loop {
                        let (next_header, next_payload) = self.connection.read_next_frame().await?;
                        if next_header.frame_type != FrameType::Continuation {
                            return Err(Error::HttpProtocol("Expected CONTINUATION frame".into()));
                        }
                        if next_header.stream_id != stream_id {
                            return Err(Error::HttpProtocol(
                                "CONTINUATION frame stream ID mismatch".into(),
                            ));
                        }
                        block.extend_from_slice(&next_payload);
                        if (next_header.flags & flags::END_HEADERS) != 0 {
                            break;
                        }
                    }
                    payload = block.freeze();
                }

                // HPACK-decode the complete header block.
                let decoded = self.connection.decode_header_block(payload)?;

                // Split the :status pseudo-header from regular headers;
                // all other pseudo-headers are discarded.
                let mut status = 0u16;
                let mut regular_headers = Vec::new();

                for (name, value) in decoded {
                    if name == ":status" {
                        // An unparsable status is recorded as 0; the stream
                        // still counts as having received a status.
                        status = value.parse().unwrap_or(0);
                    } else if !name.starts_with(':') {
                        regular_headers.push((name, value));
                    }
                }

                if let Some(stream) = self.streams.get_mut(&stream_id) {
                    stream.status = Some(status);
                    stream.headers = regular_headers;

                    if (header.flags & flags::END_STREAM) != 0 {
                        self.complete_stream(stream_id);
                    }
                }
            }
            FrameType::Data => {
                let stream_id = header.stream_id;
                let end_stream = (header.flags & flags::END_STREAM) != 0;

                // Account for inbound flow control (window bookkeeping) and
                // get back the frame's data payload.
                let data = self
                    .connection
                    .process_inbound_data_frame(stream_id, header.flags, payload)
                    .await?;

                if let Some(stream) = self.streams.get_mut(&stream_id) {
                    stream.body.extend_from_slice(&data);

                    if end_stream {
                        self.complete_stream(stream_id);
                    }
                }
            }
            FrameType::WindowUpdate => {
                // The window itself was already updated by handle_control_frame;
                // retry any body data previously blocked by flow control.
                self.flush_pending_data().await?;
            }
            _ => {} // Other frames handled by handle_control_frame (or ignored).
        }

        Ok(())
    }
444
445    /// Complete a stream: build response and send
446    fn complete_stream(&mut self, stream_id: u32) {
447        if let Some(mut stream) = self.streams.remove(&stream_id) {
448            if let Some(tx) = stream.response_tx.take() {
449                // If no status was received, this is a protocol violation
450                // Return an error rather than defaulting to 200
451                let response = match stream.status {
452                    Some(status) => Ok(StreamResponse {
453                        status,
454                        headers: stream.headers,
455                        body: stream.body.freeze(),
456                    }),
457                    None => Err(Error::HttpProtocol(format!(
458                        "Stream {} completed without status code",
459                        stream_id
460                    ))),
461                };
462                if tx.send(response).is_err() {
463                    tracing::debug!("Response channel closed while completing stream");
464                }
465            }
466        }
467        // Stream slot is now available. The main loop will call process_pending_requests
468        // to process any queued requests waiting for available stream slots.
469    }
470}