1use bytes::{Bytes, BytesMut};
8use http::{Method, Uri};
9use std::collections::{HashMap, VecDeque};
10use tokio::sync::mpsc;
11use tokio::sync::oneshot;
12use tracing;
13
14pub type StreamingHeadersResult = Result<(u16, Vec<(String, String)>)>;
15
16use crate::error::{Error, Result};
17use crate::transport::h2::connection::{
18 ControlAction, H2Connection as RawH2Connection, StreamResponse,
19};
20use crate::transport::h2::frame::{flags, ErrorCode, FrameHeader, FrameType};
21use crate::transport::h2::tunnel::{H2Tunnel, H2TunnelEvent, H2TunnelOutbound};
22
23#[derive(Debug)]
25pub enum DriverCommand {
26 SendRequest {
29 method: http::Method,
30 uri: http::Uri,
31 headers: Vec<(String, String)>,
32 body: Option<bytes::Bytes>,
33 response_tx: oneshot::Sender<Result<StreamResponse>>,
34 },
35 SendStreamingRequest {
37 method: Method,
38 uri: Uri,
39 headers: Vec<(String, String)>,
40 body_tx: mpsc::Sender<Result<Bytes>>,
41 headers_tx: oneshot::Sender<StreamingHeadersResult>,
42 },
43 OpenWebSocketTunnel {
45 uri: Uri,
46 headers: Vec<(String, String)>,
47 response_tx: oneshot::Sender<Result<H2Tunnel>>,
48 },
49 SendTunnelData {
51 stream_id: u32,
52 outbound: H2TunnelOutbound,
53 },
54}
55
56struct DriverStreamState {
58 response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
60 status: Option<u16>,
62 headers: Vec<(String, String)>,
64 body: BytesMut,
66 pending_body: Bytes,
68 body_offset: usize,
70}
71
72impl DriverStreamState {
73 fn new(response_tx: oneshot::Sender<Result<StreamResponse>>, pending_body: Bytes) -> Self {
74 Self {
75 response_tx: Some(response_tx),
76 status: None,
77 headers: Vec::new(),
78 body: BytesMut::new(),
79 pending_body,
80 body_offset: 0,
81 }
82 }
83}
84
85struct DriverTunnelState {
86 inbound_tx: mpsc::Sender<Result<H2TunnelEvent>>,
87 pending_outbound: VecDeque<H2TunnelOutbound>,
88}
89
90pub struct H2Driver<S> {
92 command_rx: mpsc::Receiver<DriverCommand>,
94 command_tx: mpsc::Sender<DriverCommand>,
96 connection: RawH2Connection<S>,
98 streams: HashMap<u32, DriverStreamState>,
100 tunnels: HashMap<u32, DriverTunnelState>,
102 pending_requests: std::collections::VecDeque<DriverCommand>,
104}
105
106impl<S> H2Driver<S>
107where
108 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
109{
110 pub fn new(
112 connection: RawH2Connection<S>,
113 command_tx: mpsc::Sender<DriverCommand>,
114 command_rx: mpsc::Receiver<DriverCommand>,
115 ) -> Self {
116 Self {
117 command_rx,
118 command_tx,
119 connection,
120 streams: HashMap::new(),
121 tunnels: HashMap::new(),
122 pending_requests: std::collections::VecDeque::new(),
123 }
124 }
125
126 pub async fn drive(mut self) -> Result<()> {
128 loop {
129 self.process_pending_requests().await?;
131
132 self.flush_pending_data().await?;
134 self.flush_tunnel_data().await?;
135
136 tokio::select! {
137 command = self.command_rx.recv() => {
139 match command {
140 Some(cmd) => {
141 match cmd {
142 DriverCommand::SendRequest { .. } => {
143 self.handle_send_request(cmd).await?;
144 }
145 DriverCommand::SendStreamingRequest { .. } => {
146 tracing::warn!("Streaming requests not yet implemented in driver");
147 }
148 DriverCommand::OpenWebSocketTunnel { uri, headers, response_tx } => {
149 self.handle_open_websocket_tunnel(uri, headers, response_tx).await?;
150 }
151 DriverCommand::SendTunnelData { stream_id, outbound } => {
152 self.queue_tunnel_outbound(stream_id, outbound).await?;
153 }
154 }
155 }
156 None => {
157 break;
159 }
160 }
161 }
162
163 read_res = self.connection.read_next_frame() => {
165 match read_res {
166 Ok((header, payload)) => {
167 if let Err(e) = self.handle_frame(header, payload).await {
168 tracing::error!("H2Driver frame error: {:?}", e);
169 return Err(e);
172 }
173 }
174 Err(e) => {
175 tracing::error!("H2Driver read error: {:?}", e);
177 return Err(e);
178 }
179 }
180 }
181 }
182 }
183 Ok(())
184 }
185
186 async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
188 if !self.has_available_stream_slot() {
189 self.pending_requests.push_back(cmd);
191 } else {
192 self.send_request_internal(cmd).await?;
194 }
195 Ok(())
196 }
197
198 async fn process_pending_requests(&mut self) -> Result<()> {
200 while self.has_available_stream_slot() {
201 if let Some(cmd) = self.pending_requests.pop_front() {
202 match cmd {
203 DriverCommand::SendRequest { .. } => {
204 self.send_request_internal(cmd).await?;
205 }
206 DriverCommand::OpenWebSocketTunnel {
207 uri,
208 headers,
209 response_tx,
210 } => {
211 self.open_websocket_tunnel_internal(uri, headers, response_tx)
212 .await?;
213 }
214 DriverCommand::SendStreamingRequest { .. } => {
215 tracing::warn!("Streaming requests not yet implemented in driver");
216 }
217 DriverCommand::SendTunnelData {
218 stream_id,
219 outbound,
220 } => {
221 self.queue_tunnel_outbound(stream_id, outbound).await?;
222 }
223 }
224 } else {
225 break;
226 }
227 }
228 Ok(())
229 }
230
231 fn active_stream_count(&self) -> usize {
232 self.streams.len() + self.tunnels.len()
233 }
234
235 fn has_available_stream_slot(&self) -> bool {
236 let max_streams = self.connection.peer_settings().max_concurrent_streams as usize;
237 self.active_stream_count() < max_streams
238 }
239
240 async fn send_request_internal(&mut self, cmd: DriverCommand) -> Result<()> {
242 if let DriverCommand::SendRequest {
243 method,
244 uri,
245 headers,
246 body,
247 response_tx,
248 } = cmd
249 {
250 let mut req_builder = http::Request::builder().method(method).uri(uri);
252
253 for (k, v) in headers {
254 req_builder = req_builder.header(k, v);
255 }
256
257 let body_bytes = body.unwrap_or_default();
259 let has_body = !body_bytes.is_empty();
260
261 let req = match req_builder.body(body_bytes.clone()) {
262 Ok(r) => r,
263 Err(e) => {
264 if response_tx
265 .send(Err(Error::HttpProtocol(format!("Invalid request: {}", e))))
266 .is_err()
267 {
268 tracing::debug!("Response channel closed while sending error");
269 }
270 return Ok(());
271 }
272 };
273
274 let end_stream = !has_body;
277
278 match self.connection.send_headers(&req, end_stream).await {
279 Ok(stream_id) => {
280 self.streams
282 .insert(stream_id, DriverStreamState::new(response_tx, body_bytes));
283
284 self.flush_pending_data().await?;
286 }
287 Err(e) => {
288 if response_tx.send(Err(e)).is_err() {
290 tracing::debug!("Response channel closed while sending error");
291 }
292 }
293 }
294 }
295 Ok(())
296 }
297
298 async fn handle_open_websocket_tunnel(
299 &mut self,
300 uri: Uri,
301 headers: Vec<(String, String)>,
302 response_tx: oneshot::Sender<Result<H2Tunnel>>,
303 ) -> Result<()> {
304 if !self.has_available_stream_slot() {
305 self.pending_requests
306 .push_back(DriverCommand::OpenWebSocketTunnel {
307 uri,
308 headers,
309 response_tx,
310 });
311 return Ok(());
312 }
313
314 self.open_websocket_tunnel_internal(uri, headers, response_tx)
315 .await
316 }
317
318 async fn open_websocket_tunnel_internal(
319 &mut self,
320 uri: Uri,
321 headers: Vec<(String, String)>,
322 response_tx: oneshot::Sender<Result<H2Tunnel>>,
323 ) -> Result<()> {
324 match self
325 .connection
326 .open_extended_connect_websocket_with_end_stream(&uri, headers)
327 .await
328 {
329 Ok((stream_id, end_stream)) => {
330 let (outbound_tx, outbound_rx) = mpsc::channel(32);
331 let (inbound_tx, inbound_rx) = mpsc::channel(32);
332 if end_stream {
333 let _ = inbound_tx.send(Ok(H2TunnelEvent::EndStream)).await;
334 self.connection.remove_stream(stream_id);
335 } else {
336 let command_tx = self.command_tx.clone();
337 tokio::spawn(async move {
338 let mut outbound_rx = outbound_rx;
339 while let Some(outbound) = outbound_rx.recv().await {
340 if command_tx
341 .send(DriverCommand::SendTunnelData {
342 stream_id,
343 outbound,
344 })
345 .await
346 .is_err()
347 {
348 break;
349 }
350 }
351 });
352 self.tunnels.insert(
353 stream_id,
354 DriverTunnelState {
355 inbound_tx,
356 pending_outbound: VecDeque::new(),
357 },
358 );
359 }
360
361 if response_tx
362 .send(Ok(H2Tunnel::new(outbound_tx, inbound_rx)))
363 .is_err()
364 {
365 tracing::debug!("Tunnel response channel closed after open");
366 self.tunnels.remove(&stream_id);
367 }
368 }
369 Err(e) => {
370 if response_tx.send(Err(e)).is_err() {
371 tracing::debug!("Tunnel response channel closed while sending open error");
372 }
373 }
374 }
375 Ok(())
376 }
377
378 async fn queue_tunnel_outbound(
379 &mut self,
380 stream_id: u32,
381 outbound: H2TunnelOutbound,
382 ) -> Result<()> {
383 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
384 tunnel.pending_outbound.push_back(outbound);
385 self.flush_tunnel_data().await?;
386 }
387
388 Ok(())
389 }
390
391 async fn flush_pending_data(&mut self) -> Result<()> {
393 let stream_ids: Vec<u32> = self.streams.keys().cloned().collect();
395
396 for stream_id in stream_ids {
397 loop {
399 let (has_data, offset) = if let Some(stream) = self.streams.get(&stream_id) {
401 (
402 stream.body_offset < stream.pending_body.len(),
403 stream.body_offset,
404 )
405 } else {
406 (false, 0)
407 };
408
409 if !has_data {
410 break;
411 }
412
413 let pending_body = {
416 let s = self.streams.get(&stream_id).unwrap();
417 s.pending_body.clone()
418 };
419
420 let remaining = &pending_body[offset..];
421 let is_last_chunk = true;
422
423 let sent = self
425 .connection
426 .send_data(stream_id, remaining, is_last_chunk)
427 .await?;
428
429 if sent > 0 {
430 if let Some(stream) = self.streams.get_mut(&stream_id) {
431 stream.body_offset += sent;
432 }
433 } else {
435 break;
437 }
438 }
439 }
440 Ok(())
441 }
442
443 async fn flush_tunnel_data(&mut self) -> Result<()> {
444 let stream_ids: Vec<u32> = self.tunnels.keys().copied().collect();
445
446 for stream_id in stream_ids {
447 loop {
448 let outbound = match self
449 .tunnels
450 .get_mut(&stream_id)
451 .and_then(|tunnel| tunnel.pending_outbound.pop_front())
452 {
453 Some(outbound) => outbound,
454 None => break,
455 };
456
457 let sent = self
458 .connection
459 .send_data(stream_id, &outbound.bytes, outbound.end_stream)
460 .await?;
461
462 if outbound.bytes.is_empty() {
463 continue;
464 }
465
466 if sent == 0 {
467 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
468 tunnel.pending_outbound.push_front(outbound);
469 }
470 break;
471 }
472
473 if sent < outbound.bytes.len() {
474 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
475 tunnel.pending_outbound.push_front(H2TunnelOutbound {
476 bytes: outbound.bytes.slice(sent..),
477 end_stream: outbound.end_stream,
478 });
479 }
480 break;
481 }
482 }
483 }
484
485 Ok(())
486 }
487
488 async fn handle_frame(&mut self, header: FrameHeader, mut payload: Bytes) -> Result<()> {
490 match self
492 .connection
493 .handle_control_frame(&header, payload.clone())
494 .await?
495 {
496 ControlAction::RstStream(sid, code) => {
497 if let Some(tunnel) = self.tunnels.remove(&sid) {
498 let _ = tunnel
499 .inbound_tx
500 .send(Ok(H2TunnelEvent::Reset(format!("{:?}", code))))
501 .await;
502 }
503 if let Some(mut stream) = self.streams.remove(&sid) {
505 if let Some(tx) = stream.response_tx.take() {
506 if tx
507 .send(Err(Error::HttpProtocol(format!(
508 "Stream reset by peer: {:?}",
509 code
510 ))))
511 .is_err()
512 {
513 tracing::debug!("Response channel closed while notifying stream reset");
514 }
515 }
516 }
517 self.process_pending_requests().await?;
519 return Ok(());
520 }
521 ControlAction::GoAway(last_sid) => {
522 let tunnel_ids: Vec<u32> = self.tunnels.keys().copied().collect();
523 for sid in tunnel_ids {
524 if sid > last_sid {
525 if let Some(tunnel) = self.tunnels.remove(&sid) {
526 let _ = tunnel
527 .inbound_tx
528 .send(Ok(H2TunnelEvent::GoAway {
529 last_stream_id: last_sid,
530 }))
531 .await;
532 }
533 }
534 }
535 let sids: Vec<u32> = self.streams.keys().cloned().collect();
537 for sid in sids {
538 if sid > last_sid {
539 if let Some(mut stream) = self.streams.remove(&sid) {
540 if let Some(tx) = stream.response_tx.take() {
541 if tx
542 .send(Err(Error::HttpProtocol("GOAWAY received".into())))
543 .is_err()
544 {
545 tracing::debug!(
546 "Response channel closed while notifying GOAWAY"
547 );
548 }
549 }
550 }
551 }
552 }
553 return Ok(());
556 }
557 ControlAction::RefusePush(_stream_id, promised_id) => {
558 if let Err(e) = self
561 .connection
562 .send_rst_stream(promised_id, ErrorCode::RefusedStream)
563 .await
564 {
565 tracing::warn!(
566 "Failed to send RST_STREAM for refused push promise: {:?}",
567 e
568 );
569 }
570 }
571 ControlAction::None => {
572 }
574 }
575
576 match header.frame_type {
578 FrameType::Headers => {
579 let stream_id = header.stream_id;
580
581 if (header.flags & flags::END_HEADERS) == 0 {
585 let mut block = BytesMut::from(payload);
589 loop {
590 let (next_header, next_payload) = self.connection.read_next_frame().await?;
591 if next_header.frame_type != FrameType::Continuation {
592 return Err(Error::HttpProtocol("Expected CONTINUATION frame".into()));
593 }
594 if next_header.stream_id != stream_id {
595 return Err(Error::HttpProtocol(
596 "CONTINUATION frame stream ID mismatch".into(),
597 ));
598 }
599 block.extend_from_slice(&next_payload);
600 if (next_header.flags & flags::END_HEADERS) != 0 {
601 break;
602 }
603 }
604 payload = block.freeze();
605 }
606
607 let decoded = self.connection.decode_header_block(payload)?;
608
609 let mut status = 0u16;
611 let mut regular_headers = Vec::new();
612
613 for (name, value) in decoded {
614 if name == ":status" {
615 status = value.parse().unwrap_or(0);
616 } else if !name.starts_with(':') {
617 regular_headers.push((name, value));
618 }
619 }
620
621 if let Some(stream) = self.streams.get_mut(&stream_id) {
622 stream.status = Some(status);
623 stream.headers = regular_headers;
624
625 if (header.flags & flags::END_STREAM) != 0 {
626 self.complete_stream(stream_id);
627 }
628 }
629 }
630 FrameType::Data => {
631 let stream_id = header.stream_id;
632 let end_stream = (header.flags & flags::END_STREAM) != 0;
633
634 let data = self
638 .connection
639 .process_inbound_data_frame(stream_id, header.flags, payload)
640 .await?;
641
642 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
643 if !data.is_empty() {
644 let _ = tunnel.inbound_tx.send(Ok(H2TunnelEvent::Data(data))).await;
645 }
646 if end_stream {
647 let _ = tunnel.inbound_tx.send(Ok(H2TunnelEvent::EndStream)).await;
648 self.tunnels.remove(&stream_id);
649 }
650 return Ok(());
651 }
652
653 if let Some(stream) = self.streams.get_mut(&stream_id) {
654 stream.body.extend_from_slice(&data);
655
656 if end_stream {
657 self.complete_stream(stream_id);
658 }
659 }
660 }
661 FrameType::WindowUpdate => {
662 self.flush_pending_data().await?;
666 self.flush_tunnel_data().await?;
667 }
668 _ => {} }
670
671 Ok(())
672 }
673
674 fn complete_stream(&mut self, stream_id: u32) {
676 if let Some(mut stream) = self.streams.remove(&stream_id) {
677 if let Some(tx) = stream.response_tx.take() {
678 let response = match stream.status {
681 Some(status) => Ok(StreamResponse {
682 status,
683 headers: stream.headers,
684 body: stream.body.freeze(),
685 }),
686 None => Err(Error::HttpProtocol(format!(
687 "Stream {} completed without status code",
688 stream_id
689 ))),
690 };
691 if tx.send(response).is_err() {
692 tracing::debug!("Response channel closed while completing stream");
693 }
694 }
695 }
696 }
699}