Skip to main content

specter/transport/h3/
driver.rs

1//! HTTP/3 connection driver - background task that reads packets and routes them to streams.
2//!
3//! The driver owns the QUIC connection and UdpSocket.
4
5use bytes::{Bytes, BytesMut};
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc;
12use tokio::sync::oneshot;
13use tokio::time::sleep;
14use tracing;
15
16use crate::error::{Error, Result};
17use quiche::h3::NameValue;
18
/// Command sent from a handle to the [`H3Driver`] task over its `mpsc` channel.
#[derive(Debug)]
pub enum DriverCommand {
    /// Send a request and get the response via the `response_tx` oneshot.
    ///
    /// The driver emits the request pseudo-headers itself; pseudo-headers and
    /// connection-specific fields found in `headers` are not forwarded.
    SendRequest {
        /// Request method, sent as `:method`.
        method: http::Method,
        /// Target URI, used to build `:authority` and `:path` (`:scheme` is always `https`).
        uri: http::Uri,
        /// Additional request header fields as `(name, value)` pairs.
        headers: Vec<(String, String)>,
        /// Optional request body; `None` finishes the stream together with the HEADERS frame.
        body: Option<Bytes>,
        /// Receives the buffered response, or the error that ended the request.
        response_tx: oneshot::Sender<Result<StreamResponse>>,
    },
}
31
/// Fully buffered HTTP/3 response delivered to the requester once its stream finishes.
#[derive(Debug)]
pub struct StreamResponse {
    /// Value of the `:status` pseudo-header, or `0` when none was received or it did not
    /// parse as a `u16`.
    pub status: u16,
    /// All received header fields except `:status`, in arrival order; fields from every
    /// HEADERS frame on the stream (e.g. trailers) are appended here.
    pub headers: Vec<(String, String)>,
    /// Response body, concatenated from all DATA frames.
    pub body: Bytes,
}
38
/// Per-stream state the driver accumulates while a response is being received.
struct DriverStreamState {
    /// Where the finished response (or error) is delivered; `None` once it has been used.
    response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
    /// Parsed `:status`, if one has been received and parsed successfully.
    status: Option<u16>,
    /// Non-`:status` header fields received so far, in arrival order.
    headers: Vec<(String, String)>,
    /// Body bytes received so far.
    body: BytesMut,
}
50
51impl DriverStreamState {
52    fn new(response_tx: oneshot::Sender<Result<StreamResponse>>) -> Self {
53        Self {
54            response_tx: Some(response_tx),
55            status: None,
56            headers: Vec::new(),
57            body: BytesMut::new(),
58        }
59    }
60}
61
62/// HTTP/3 connection driver
63pub struct H3Driver {
64    command_rx: mpsc::Receiver<DriverCommand>,
65    conn: quiche::Connection,
66    h3_conn: quiche::h3::Connection,
67    socket: Arc<UdpSocket>,
68    peer_addr: SocketAddr,
69    streams: HashMap<u64, DriverStreamState>,
70}
71
72impl H3Driver {
73    pub fn new(
74        command_rx: mpsc::Receiver<DriverCommand>,
75        conn: quiche::Connection,
76        h3_conn: quiche::h3::Connection,
77        socket: Arc<UdpSocket>,
78        peer_addr: SocketAddr,
79    ) -> Self {
80        Self {
81            command_rx,
82            conn,
83            h3_conn,
84            socket,
85            peer_addr,
86            streams: HashMap::new(),
87        }
88    }
89
90    pub async fn drive(mut self) -> Result<()> {
91        let result = self.drive_loop().await;
92
93        // Propagate error to all pending streams
94        if let Err(ref e) = result {
95            tracing::error!("H3 Driver error: {}", e);
96            for (_, mut stream) in self.streams.drain() {
97                if let Some(tx) = stream.response_tx.take() {
98                    let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
99                }
100            }
101        }
102
103        result
104    }
105
106    async fn drive_loop(&mut self) -> Result<()> {
107        let mut buf = vec![0u8; 65535];
108        let mut out = vec![0u8; 1350];
109
110        loop {
111            // 1. Process sending any pending packets first (egress)
112            loop {
113                match self.conn.send(&mut out) {
114                    Ok((len, _)) => {
115                        if let Err(e) = self.socket.send_to(&out[..len], self.peer_addr).await {
116                            tracing::error!("H3 socket send error: {}", e);
117                            return Err(Error::Io(e));
118                        }
119                    }
120                    Err(quiche::Error::Done) => break,
121                    Err(e) => {
122                        tracing::error!("H3 quiche send error: {}", e);
123                        return Err(Error::Quic(format!("QUIC send error: {}", e)));
124                    }
125                }
126            }
127
128            // 2. Select: Recv Packet OR Command OR Timeout
129            let timeout_duration = self.conn.timeout().unwrap_or(Duration::from_secs(60));
130
131            tokio::select! {
132                // Incoming Command
133                cmd = self.command_rx.recv() => {
134                    match cmd {
135                        Some(c) => self.handle_command(c).await?,
136                        None => {
137                            match self.conn.close(true, 0x00, b"Client shutdown") {
138                                Ok(_) => {},
139                                Err(quiche::Error::Done) => {},
140                                Err(_) => {}
141                            }
142                            while let Ok((len, _)) = self.conn.send(&mut out) {
143                                let _ = self.socket.send_to(&out[..len], self.peer_addr).await;
144                            }
145                            return Ok(());
146                        }
147                    }
148                }
149
150                // Incoming Packet
151                res = self.socket.recv_from(&mut buf) => {
152                    match res {
153                        Ok((len, from)) => {
154                            if from == self.peer_addr {
155                                let info = quiche::RecvInfo {
156                                    from,
157                                    to: self.socket.local_addr().unwrap(),
158                                    // to: self.socket.local_addr().unwrap(), // Need to handle unchecked?
159                                    // The original code unwrapped, presumably safe if bound.
160                                };
161                                match self.conn.recv(&mut buf[..len], info) {
162                                    Ok(_) => {
163                                        self.process_h3_events()?;
164                                    }
165                                    Err(quiche::Error::Done) => {},
166                                    Err(e) => {
167                                        tracing::warn!("QUIC recv error: {}", e);
168                                    }
169                                }
170                            }
171                        }
172                        Err(e) => return Err(Error::Io(e)),
173                    }
174                }
175
176                // Timer
177                _ = sleep(timeout_duration) => {
178                    self.conn.on_timeout();
179                }
180            }
181
182            // Check for connection closure
183            if self.conn.is_closed() {
184                tracing::info!("H3 Driver: Connection closed");
185                for (_id, mut stream) in self.streams.drain() {
186                    if let Some(tx) = stream.response_tx.take() {
187                        let _ = tx.send(Err(Error::Connection("Connection closed".into())));
188                    }
189                }
190                return Ok(());
191            }
192        }
193    }
194
195    async fn handle_command(&mut self, cmd: DriverCommand) -> Result<()> {
196        match cmd {
197            DriverCommand::SendRequest {
198                method,
199                uri,
200                headers,
201                body,
202                response_tx,
203            } => {
204                // Construct H3 headers
205                let path = uri.path();
206                let path = if path.is_empty() { "/" } else { path };
207                let host = uri.host().unwrap_or("").to_string();
208
209                let mut h3_headers = vec![
210                    quiche::h3::Header::new(b":method", method.as_str().as_bytes()),
211                    quiche::h3::Header::new(b":scheme", b"https"),
212                    quiche::h3::Header::new(b":authority", host.as_bytes()),
213                    quiche::h3::Header::new(b":path", path.as_bytes()),
214                ];
215
216                for (k, v) in &headers {
217                    let k_lower = k.to_lowercase();
218                    // Filter pseudo and prohibited headers
219                    if !k.starts_with(':')
220                        && k_lower != "connection"
221                        && k_lower != "keep-alive"
222                        && k_lower != "proxy-connection"
223                        && k_lower != "transfer-encoding"
224                        && k_lower != "upgrade"
225                    {
226                        h3_headers.push(quiche::h3::Header::new(k.as_bytes(), v.as_bytes()));
227                    }
228                }
229
230                // Send request logic
231                let fin = body.is_none();
232                match self.h3_conn.send_request(&mut self.conn, &h3_headers, fin) {
233                    Ok(stream_id) => {
234                        // Store stream state
235                        let mut state = DriverStreamState::new(response_tx);
236
237                        // Send body if present
238                        if let Some(data) = body {
239                            if let Err(e) =
240                                self.h3_conn
241                                    .send_body(&mut self.conn, stream_id, &data, true)
242                            {
243                                // Error sending body
244                                if let Some(tx) = state.response_tx.take() {
245                                    let _ = tx
246                                        .send(Err(Error::Quic(format!("Send body failed: {}", e))));
247                                }
248                                return Ok(());
249                            }
250                        }
251
252                        self.streams.insert(stream_id, state);
253                    }
254                    Err(e) => {
255                        let _ = response_tx
256                            .send(Err(Error::Quic(format!("Send request failed: {}", e))));
257                    }
258                }
259            }
260        }
261        Ok(())
262    }
263
264    fn process_h3_events(&mut self) -> Result<()> {
265        loop {
266            match self.h3_conn.poll(&mut self.conn) {
267                Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
268                    if let Some(stream) = self.streams.get_mut(&stream_id) {
269                        for header in list {
270                            let name = String::from_utf8_lossy(header.name());
271                            let value = String::from_utf8_lossy(header.value());
272
273                            if name == ":status" {
274                                stream.status = value.parse().ok();
275                            } else {
276                                stream.headers.push((name.into_owned(), value.into_owned()));
277                            }
278                        }
279                    }
280                }
281                Ok((stream_id, quiche::h3::Event::Data)) => {
282                    if let Some(stream) = self.streams.get_mut(&stream_id) {
283                        let mut buf = vec![0u8; 65535];
284                        while let Ok(len) =
285                            self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf)
286                        {
287                            stream.body.extend_from_slice(&buf[..len]);
288                        }
289                    }
290                }
291                Ok((stream_id, quiche::h3::Event::Finished)) => {
292                    if let Some(mut stream) = self.streams.remove(&stream_id) {
293                        if let Some(tx) = stream.response_tx.take() {
294                            let resp = StreamResponse {
295                                status: stream.status.unwrap_or(0),
296                                headers: stream.headers,
297                                body: stream.body.freeze(),
298                            };
299                            let _ = tx.send(Ok(resp));
300                        }
301                    }
302                }
303                Ok((stream_id, quiche::h3::Event::Reset(error_code))) => {
304                    if let Some(mut stream) = self.streams.remove(&stream_id) {
305                        if let Some(tx) = stream.response_tx.take() {
306                            let _ =
307                                tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
308                        }
309                    }
310                }
311                Err(quiche::h3::Error::Done) => break,
312                Ok(_) => {} // Ignore other events
313                Err(e) => {
314                    tracing::warn!("H3 poll error: {}", e);
315                    return Err(Error::Quic(format!("H3 poll error: {}", e)));
316                }
317            }
318        }
319        Ok(())
320    }
321}