specter/transport/h2/driver.rs
1//! HTTP/2 connection driver - background task that reads frames and routes them to streams.
2//!
3//! The driver owns the raw H2Connection and continuously reads frames from the socket,
4//! routing them to the appropriate stream channels. This allows multiple requests
5//! to be multiplexed without blocking each other.
6
7use bytes::{Bytes, BytesMut};
8use http::{Method, Uri};
9use std::collections::HashMap;
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tracing;
13
14pub type StreamingHeadersResult = Result<(u16, Vec<(String, String)>)>;
15
16use crate::error::{Error, Result};
17use crate::transport::h2::connection::{
18 ControlAction, H2Connection as RawH2Connection, StreamResponse,
19};
20use crate::transport::h2::frame::{flags, ErrorCode, FrameHeader, FrameType};
21
22/// Command sent from handle to driver
23#[derive(Debug)]
24pub enum DriverCommand {
25 /// Send a request and get response via oneshot
26 /// Driver allocates stream_id
27 SendRequest {
28 method: http::Method,
29 uri: http::Uri,
30 headers: Vec<(String, String)>,
31 body: Option<bytes::Bytes>,
32 response_tx: oneshot::Sender<Result<StreamResponse>>,
33 },
34 /// Send a request with a streaming body
35 SendStreamingRequest {
36 method: Method,
37 uri: Uri,
38 headers: Vec<(String, String)>,
39 body_tx: mpsc::Sender<Result<Bytes>>,
40 headers_tx: oneshot::Sender<StreamingHeadersResult>,
41 },
42}
43
44/// Per-stream state tracked by driver
45struct DriverStreamState {
46 /// Oneshot sender for response completion
47 response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
48 /// Accumulated response status
49 status: Option<u16>,
50 /// Accumulated response headers
51 headers: Vec<(String, String)>,
52 /// Accumulated response body
53 body: BytesMut,
54 /// Pending request body to be sent (flow control buffer)
55 pending_body: Bytes,
56 /// Offset of pending body already sent
57 body_offset: usize,
58}
59
60impl DriverStreamState {
61 fn new(response_tx: oneshot::Sender<Result<StreamResponse>>, pending_body: Bytes) -> Self {
62 Self {
63 response_tx: Some(response_tx),
64 status: None,
65 headers: Vec::new(),
66 body: BytesMut::new(),
67 pending_body,
68 body_offset: 0,
69 }
70 }
71}
72
73/// HTTP/2 connection driver that runs in a background task
74pub struct H2Driver<S> {
75 /// Channel for receiving commands from handles
76 command_rx: mpsc::Receiver<DriverCommand>,
77 /// Raw H2 connection (owned by driver)
78 connection: RawH2Connection<S>,
79 /// Per-stream state for routing responses
80 streams: HashMap<u32, DriverStreamState>,
81 /// Queue for pending requests when max streams reached
82 pending_requests: std::collections::VecDeque<DriverCommand>,
83}
84
85impl<S> H2Driver<S>
86where
87 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
88{
89 /// Create a new driver from an established connection
90 pub fn new(connection: RawH2Connection<S>, command_rx: mpsc::Receiver<DriverCommand>) -> Self {
91 Self {
92 command_rx,
93 connection,
94 streams: HashMap::new(),
95 pending_requests: std::collections::VecDeque::new(),
96 }
97 }
98
99 /// Run the driver loop - processes commands and reads frames
100 pub async fn drive(mut self) -> Result<()> {
101 loop {
102 // Processing pending requests if slots available
103 self.process_pending_requests().await?;
104
105 // Try to flush any pending data (flow control)
106 self.flush_pending_data().await?;
107
108 tokio::select! {
109 // Handle incoming commands (send requests)
110 command = self.command_rx.recv() => {
111 match command {
112 Some(cmd) => {
113 match cmd {
114 DriverCommand::SendRequest { .. } => {
115 self.handle_send_request(cmd).await?;
116 }
117 DriverCommand::SendStreamingRequest { .. } => {
118 tracing::warn!("Streaming requests not yet implemented in driver");
119 }
120 }
121 }
122 None => {
123 // Channel closed - driver should shutdown
124 break;
125 }
126 }
127 }
128
129 // Handle incoming frames
130 read_res = self.connection.read_next_frame() => {
131 match read_res {
132 Ok((header, payload)) => {
133 if let Err(e) = self.handle_frame(header, payload).await {
134 tracing::error!("H2Driver frame error: {:?}", e);
135 // Protocol errors are fatal and require connection termination.
136 // The connection state may be inconsistent after this error.
137 return Err(e);
138 }
139 }
140 Err(e) => {
141 // Connection error
142 tracing::error!("H2Driver read error: {:?}", e);
143 return Err(e);
144 }
145 }
146 }
147 }
148 }
149 Ok(())
150 }
151
152 /// Handle SendRequest command
153 async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
154 let max_streams = self.connection.peer_settings().max_concurrent_streams;
155
156 if self.streams.len() >= max_streams as usize {
157 // Queue request
158 self.pending_requests.push_back(cmd);
159 } else {
160 // Send immediately
161 self.send_request_internal(cmd).await?;
162 }
163 Ok(())
164 }
165
166 /// Process pending requests if slots available
167 async fn process_pending_requests(&mut self) -> Result<()> {
168 let max_streams = self.connection.peer_settings().max_concurrent_streams;
169
170 while self.streams.len() < max_streams as usize {
171 if let Some(cmd) = self.pending_requests.pop_front() {
172 self.send_request_internal(cmd).await?;
173 } else {
174 break;
175 }
176 }
177 Ok(())
178 }
179
180 /// Internal helper to send request
181 async fn send_request_internal(&mut self, cmd: DriverCommand) -> Result<()> {
182 if let DriverCommand::SendRequest {
183 method,
184 uri,
185 headers,
186 body,
187 response_tx,
188 } = cmd
189 {
190 // Construct request
191 let mut req_builder = http::Request::builder().method(method).uri(uri);
192
193 for (k, v) in headers {
194 req_builder = req_builder.header(k, v);
195 }
196
197 // Body
198 let body_bytes = body.unwrap_or_default();
199 let has_body = !body_bytes.is_empty();
200
201 let req = match req_builder.body(body_bytes.clone()) {
202 Ok(r) => r,
203 Err(e) => {
204 if response_tx
205 .send(Err(Error::HttpProtocol(format!("Invalid request: {}", e))))
206 .is_err()
207 {
208 tracing::debug!("Response channel closed while sending error");
209 }
210 return Ok(());
211 }
212 };
213
214 // Send HEADERS frame (non-blocking write)
215 // If body is present, end_stream=false (DATA frames will be sent separately)
216 let end_stream = !has_body;
217
218 match self.connection.send_headers(&req, end_stream).await {
219 Ok(stream_id) => {
220 // Register stream state
221 self.streams
222 .insert(stream_id, DriverStreamState::new(response_tx, body_bytes));
223
224 // Trigger flush to try sending body immediately
225 self.flush_pending_data().await?;
226 }
227 Err(e) => {
228 // Notify error immediately
229 if response_tx.send(Err(e)).is_err() {
230 tracing::debug!("Response channel closed while sending error");
231 }
232 }
233 }
234 }
235 Ok(())
236 }
237
238 /// Iterate all active streams and try to send pending body data
239 async fn flush_pending_data(&mut self) -> Result<()> {
240 // Collect IDs to avoid borrow conflict
241 let stream_ids: Vec<u32> = self.streams.keys().cloned().collect();
242
243 for stream_id in stream_ids {
244 // Keep sending chunks for this stream until blocked or done
245 loop {
246 // Check if we have data to send
247 let (has_data, offset) = if let Some(stream) = self.streams.get(&stream_id) {
248 (
249 stream.body_offset < stream.pending_body.len(),
250 stream.body_offset,
251 )
252 } else {
253 (false, 0)
254 };
255
256 if !has_data {
257 break;
258 }
259
260 // Prepare arguments for send_data
261 // We clone the Bytes handle which is cheap
262 let pending_body = {
263 let s = self.streams.get(&stream_id).unwrap();
264 s.pending_body.clone()
265 };
266
267 let remaining = &pending_body[offset..];
268 let is_last_chunk = true;
269
270 // send_data returns bytes sent. If 0, it means blocked.
271 let sent = self
272 .connection
273 .send_data(stream_id, remaining, is_last_chunk)
274 .await?;
275
276 if sent > 0 {
277 if let Some(stream) = self.streams.get_mut(&stream_id) {
278 stream.body_offset += sent;
279 }
280 // Loop again to send next chunk
281 } else {
282 // Blocked by flow control
283 break;
284 }
285 }
286 }
287 Ok(())
288 }
289
290 /// Handle a single frame
291 async fn handle_frame(&mut self, header: FrameHeader, mut payload: Bytes) -> Result<()> {
292 // 1. Check control frames that modify connection state
293 match self
294 .connection
295 .handle_control_frame(&header, payload.clone())
296 .await?
297 {
298 ControlAction::RstStream(sid, code) => {
299 // Notify stream of reset
300 if let Some(mut stream) = self.streams.remove(&sid) {
301 if let Some(tx) = stream.response_tx.take() {
302 if tx
303 .send(Err(Error::HttpProtocol(format!(
304 "Stream reset by peer: {:?}",
305 code
306 ))))
307 .is_err()
308 {
309 tracing::debug!("Response channel closed while notifying stream reset");
310 }
311 }
312 }
313 // Stream slot freed, try to process pending
314 self.process_pending_requests().await?;
315 return Ok(());
316 }
317 ControlAction::GoAway(last_sid) => {
318 // Close all streams > last_sid
319 let sids: Vec<u32> = self.streams.keys().cloned().collect();
320 for sid in sids {
321 if sid > last_sid {
322 if let Some(mut stream) = self.streams.remove(&sid) {
323 if let Some(tx) = stream.response_tx.take() {
324 if tx
325 .send(Err(Error::HttpProtocol("GOAWAY received".into())))
326 .is_err()
327 {
328 tracing::debug!(
329 "Response channel closed while notifying GOAWAY"
330 );
331 }
332 }
333 }
334 }
335 }
336 // Driver continues processing existing streams until they complete.
337 // A future enhancement could implement immediate shutdown on GOAWAY.
338 return Ok(());
339 }
340 ControlAction::RefusePush(_stream_id, promised_id) => {
341 // Send RST_STREAM for the promised stream
342 // RFC 9113 8.4: RST_STREAM with REFUSED_STREAM
343 if let Err(e) = self
344 .connection
345 .send_rst_stream(promised_id, ErrorCode::RefusedStream)
346 .await
347 {
348 tracing::warn!(
349 "Failed to send RST_STREAM for refused push promise: {:?}",
350 e
351 );
352 }
353 }
354 ControlAction::None => {
355 // Continue to specific processing
356 }
357 }
358
359 // 2. Data / Headers routing
360 match header.frame_type {
361 FrameType::Headers => {
362 let stream_id = header.stream_id;
363
364 // Handle CONTINUATION frames if needed (END_HEADERS flag not set).
365 // CONTINUATION frames are collected in the loop below; this branch handles
366 // the initial HEADERS frame that starts a header block.
367 if (header.flags & flags::END_HEADERS) == 0 {
368 // Loop to read CONTINUATION frames
369 // This inner loop blocks the driver select! loop, which is expected
370 // per RFC 9113 Section 6.2 (CONTINUATION frames must be processed sequentially).
371 let mut block = BytesMut::from(payload);
372 loop {
373 let (next_header, next_payload) = self.connection.read_next_frame().await?;
374 if next_header.frame_type != FrameType::Continuation {
375 return Err(Error::HttpProtocol("Expected CONTINUATION frame".into()));
376 }
377 if next_header.stream_id != stream_id {
378 return Err(Error::HttpProtocol(
379 "CONTINUATION frame stream ID mismatch".into(),
380 ));
381 }
382 block.extend_from_slice(&next_payload);
383 if (next_header.flags & flags::END_HEADERS) != 0 {
384 break;
385 }
386 }
387 payload = block.freeze();
388 }
389
390 let decoded = self.connection.decode_header_block(payload)?;
391
392 // Parse pseudo-headers
393 let mut status = 0u16;
394 let mut regular_headers = Vec::new();
395
396 for (name, value) in decoded {
397 if name == ":status" {
398 status = value.parse().unwrap_or(0);
399 } else if !name.starts_with(':') {
400 regular_headers.push((name, value));
401 }
402 }
403
404 if let Some(stream) = self.streams.get_mut(&stream_id) {
405 stream.status = Some(status);
406 stream.headers = regular_headers;
407
408 if (header.flags & flags::END_STREAM) != 0 {
409 self.complete_stream(stream_id);
410 }
411 }
412 }
413 FrameType::Data => {
414 let stream_id = header.stream_id;
415 let end_stream = (header.flags & flags::END_STREAM) != 0;
416
417 // Process flow control for inbound DATA frame.
418 // The process_inbound_data_frame method takes stream_id, flags, and payload
419 // to handle window updates and flow control state.
420 let data = self
421 .connection
422 .process_inbound_data_frame(stream_id, header.flags, payload)
423 .await?;
424
425 if let Some(stream) = self.streams.get_mut(&stream_id) {
426 stream.body.extend_from_slice(&data);
427
428 if end_stream {
429 self.complete_stream(stream_id);
430 }
431 }
432 }
433 FrameType::WindowUpdate => {
434 // Window update received and processed by handle_control_frame,
435 // which updates the connection/stream window in self.connection.
436 // Flush any pending data that was previously blocked by flow control.
437 self.flush_pending_data().await?;
438 }
439 _ => {} // Other frames handled by handle_control_frame (or ignored)
440 }
441
442 Ok(())
443 }
444
445 /// Complete a stream: build response and send
446 fn complete_stream(&mut self, stream_id: u32) {
447 if let Some(mut stream) = self.streams.remove(&stream_id) {
448 if let Some(tx) = stream.response_tx.take() {
449 // If no status was received, this is a protocol violation
450 // Return an error rather than defaulting to 200
451 let response = match stream.status {
452 Some(status) => Ok(StreamResponse {
453 status,
454 headers: stream.headers,
455 body: stream.body.freeze(),
456 }),
457 None => Err(Error::HttpProtocol(format!(
458 "Stream {} completed without status code",
459 stream_id
460 ))),
461 };
462 if tx.send(response).is_err() {
463 tracing::debug!("Response channel closed while completing stream");
464 }
465 }
466 }
467 // Stream slot is now available. The main loop will call process_pending_requests
468 // to process any queued requests waiting for available stream slots.
469 }
470}