1use bytes::{Bytes, BytesMut};
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc;
12use tokio::sync::oneshot;
13use tokio::time::sleep;
14use tracing;
15
16use crate::error::{Error, Result};
17use quiche::h3::NameValue;
18
19#[derive(Debug)]
21pub enum DriverCommand {
22 SendRequest {
24 method: http::Method,
25 uri: http::Uri,
26 headers: Vec<(String, String)>,
27 body: Option<Bytes>,
28 response_tx: oneshot::Sender<Result<StreamResponse>>,
29 },
30}
31
32#[derive(Debug)]
33pub struct StreamResponse {
34 pub status: u16,
35 pub headers: Vec<(String, String)>,
36 pub body: Bytes,
37}
38
39struct DriverStreamState {
41 response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
43 status: Option<u16>,
45 headers: Vec<(String, String)>,
47 body: BytesMut,
49}
50
51impl DriverStreamState {
52 fn new(response_tx: oneshot::Sender<Result<StreamResponse>>) -> Self {
53 Self {
54 response_tx: Some(response_tx),
55 status: None,
56 headers: Vec::new(),
57 body: BytesMut::new(),
58 }
59 }
60}
61
62pub struct H3Driver {
64 command_rx: mpsc::Receiver<DriverCommand>,
65 conn: quiche::Connection,
66 h3_conn: quiche::h3::Connection,
67 socket: Arc<UdpSocket>,
68 peer_addr: SocketAddr,
69 streams: HashMap<u64, DriverStreamState>,
70}
71
72impl H3Driver {
73 pub fn new(
74 command_rx: mpsc::Receiver<DriverCommand>,
75 conn: quiche::Connection,
76 h3_conn: quiche::h3::Connection,
77 socket: Arc<UdpSocket>,
78 peer_addr: SocketAddr,
79 ) -> Self {
80 Self {
81 command_rx,
82 conn,
83 h3_conn,
84 socket,
85 peer_addr,
86 streams: HashMap::new(),
87 }
88 }
89
90 pub async fn drive(mut self) -> Result<()> {
91 let result = self.drive_loop().await;
92
93 if let Err(ref e) = result {
95 tracing::error!("H3 Driver error: {}", e);
96 for (_, mut stream) in self.streams.drain() {
97 if let Some(tx) = stream.response_tx.take() {
98 let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
99 }
100 }
101 }
102
103 result
104 }
105
106 async fn drive_loop(&mut self) -> Result<()> {
107 let mut buf = vec![0u8; 65535];
108 let mut out = vec![0u8; 1350];
109
110 loop {
111 loop {
113 match self.conn.send(&mut out) {
114 Ok((len, _)) => {
115 if let Err(e) = self.socket.send_to(&out[..len], self.peer_addr).await {
116 tracing::error!("H3 socket send error: {}", e);
117 return Err(Error::Io(e));
118 }
119 }
120 Err(quiche::Error::Done) => break,
121 Err(e) => {
122 tracing::error!("H3 quiche send error: {}", e);
123 return Err(Error::Quic(format!("QUIC send error: {}", e)));
124 }
125 }
126 }
127
128 let timeout_duration = self.conn.timeout().unwrap_or(Duration::from_secs(60));
130
131 tokio::select! {
132 cmd = self.command_rx.recv() => {
134 match cmd {
135 Some(c) => self.handle_command(c).await?,
136 None => {
137 match self.conn.close(true, 0x00, b"Client shutdown") {
138 Ok(_) => {},
139 Err(quiche::Error::Done) => {},
140 Err(_) => {}
141 }
142 while let Ok((len, _)) = self.conn.send(&mut out) {
143 let _ = self.socket.send_to(&out[..len], self.peer_addr).await;
144 }
145 return Ok(());
146 }
147 }
148 }
149
150 res = self.socket.recv_from(&mut buf) => {
152 match res {
153 Ok((len, from)) => {
154 if from == self.peer_addr {
155 let info = quiche::RecvInfo {
156 from,
157 to: self.socket.local_addr().unwrap(),
158 };
161 match self.conn.recv(&mut buf[..len], info) {
162 Ok(_) => {
163 self.process_h3_events()?;
164 }
165 Err(quiche::Error::Done) => {},
166 Err(e) => {
167 tracing::warn!("QUIC recv error: {}", e);
168 }
169 }
170 }
171 }
172 Err(e) => return Err(Error::Io(e)),
173 }
174 }
175
176 _ = sleep(timeout_duration) => {
178 self.conn.on_timeout();
179 }
180 }
181
182 if self.conn.is_closed() {
184 tracing::info!("H3 Driver: Connection closed");
185 for (_id, mut stream) in self.streams.drain() {
186 if let Some(tx) = stream.response_tx.take() {
187 let _ = tx.send(Err(Error::Connection("Connection closed".into())));
188 }
189 }
190 return Ok(());
191 }
192 }
193 }
194
195 async fn handle_command(&mut self, cmd: DriverCommand) -> Result<()> {
196 match cmd {
197 DriverCommand::SendRequest {
198 method,
199 uri,
200 headers,
201 body,
202 response_tx,
203 } => {
204 let path = uri.path();
206 let path = if path.is_empty() { "/" } else { path };
207 let host = uri.host().unwrap_or("").to_string();
208
209 let mut h3_headers = vec![
210 quiche::h3::Header::new(b":method", method.as_str().as_bytes()),
211 quiche::h3::Header::new(b":scheme", b"https"),
212 quiche::h3::Header::new(b":authority", host.as_bytes()),
213 quiche::h3::Header::new(b":path", path.as_bytes()),
214 ];
215
216 for (k, v) in &headers {
217 let k_lower = k.to_lowercase();
218 if !k.starts_with(':')
220 && k_lower != "connection"
221 && k_lower != "keep-alive"
222 && k_lower != "proxy-connection"
223 && k_lower != "transfer-encoding"
224 && k_lower != "upgrade"
225 {
226 h3_headers.push(quiche::h3::Header::new(k.as_bytes(), v.as_bytes()));
227 }
228 }
229
230 let fin = body.is_none();
232 match self.h3_conn.send_request(&mut self.conn, &h3_headers, fin) {
233 Ok(stream_id) => {
234 let mut state = DriverStreamState::new(response_tx);
236
237 if let Some(data) = body {
239 if let Err(e) =
240 self.h3_conn
241 .send_body(&mut self.conn, stream_id, &data, true)
242 {
243 if let Some(tx) = state.response_tx.take() {
245 let _ = tx
246 .send(Err(Error::Quic(format!("Send body failed: {}", e))));
247 }
248 return Ok(());
249 }
250 }
251
252 self.streams.insert(stream_id, state);
253 }
254 Err(e) => {
255 let _ = response_tx
256 .send(Err(Error::Quic(format!("Send request failed: {}", e))));
257 }
258 }
259 }
260 }
261 Ok(())
262 }
263
264 fn process_h3_events(&mut self) -> Result<()> {
265 loop {
266 match self.h3_conn.poll(&mut self.conn) {
267 Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
268 if let Some(stream) = self.streams.get_mut(&stream_id) {
269 for header in list {
270 let name = String::from_utf8_lossy(header.name());
271 let value = String::from_utf8_lossy(header.value());
272
273 if name == ":status" {
274 stream.status = value.parse().ok();
275 } else {
276 stream.headers.push((name.into_owned(), value.into_owned()));
277 }
278 }
279 }
280 }
281 Ok((stream_id, quiche::h3::Event::Data)) => {
282 if let Some(stream) = self.streams.get_mut(&stream_id) {
283 let mut buf = vec![0u8; 65535];
284 while let Ok(len) =
285 self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf)
286 {
287 stream.body.extend_from_slice(&buf[..len]);
288 }
289 }
290 }
291 Ok((stream_id, quiche::h3::Event::Finished)) => {
292 if let Some(mut stream) = self.streams.remove(&stream_id) {
293 if let Some(tx) = stream.response_tx.take() {
294 let resp = StreamResponse {
295 status: stream.status.unwrap_or(0),
296 headers: stream.headers,
297 body: stream.body.freeze(),
298 };
299 let _ = tx.send(Ok(resp));
300 }
301 }
302 }
303 Ok((stream_id, quiche::h3::Event::Reset(error_code))) => {
304 if let Some(mut stream) = self.streams.remove(&stream_id) {
305 if let Some(tx) = stream.response_tx.take() {
306 let _ =
307 tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
308 }
309 }
310 }
311 Err(quiche::h3::Error::Done) => break,
312 Ok(_) => {} Err(e) => {
314 tracing::warn!("H3 poll error: {}", e);
315 return Err(Error::Quic(format!("H3 poll error: {}", e)));
316 }
317 }
318 }
319 Ok(())
320 }
321}