1use bytes::{Bytes, BytesMut};
6use quiche::h3::NameValue;
7use std::collections::{HashMap, VecDeque};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::net::UdpSocket;
12use tokio::sync::mpsc;
13use tokio::sync::oneshot;
14use tokio::time::sleep;
15
16use crate::error::{Error, Result};
17use crate::transport::h3::{H3Tunnel, H3TunnelEvent, H3TunnelOutbound};
18
/// A unit of work handed to the [`H3Driver`] task through its command channel.
///
/// The driver alone owns the QUIC/HTTP/3 connection; everything else reaches
/// it by sending one of these.
#[derive(Debug)]
pub enum DriverCommand {
    /// Issue a buffered HTTP/3 request; the complete response (or an error) is
    /// delivered on `response_tx`.
    SendRequest {
        /// Request method, sent as `:method`.
        method: http::Method,
        uri: http::Uri,
        headers: Vec<(String, String)>,
        /// Optional request body; `None` sends the request headers with FIN set.
        body: Option<Bytes>,
        response_tx: oneshot::Sender<Result<StreamResponse>>,
    },
    /// Open an RFC 9220 WebSocket tunnel (extended CONNECT with
    /// `:protocol = websocket`); the tunnel handle is delivered on
    /// `response_tx` once the peer answers with `:status 200`.
    OpenWebSocketTunnel {
        /// Target URI; must use the `https` scheme.
        uri: http::Uri,
        headers: Vec<(String, String)>,
        response_tx: oneshot::Sender<Result<H3Tunnel>>,
    },
    /// Queue outbound bytes on an open tunnel stream. Sent by the forwarding
    /// task the driver spawns when a tunnel opens.
    SendTunnelData {
        /// Stream ID of the tunnel's CONNECT request stream.
        stream_id: u64,
        outbound: H3TunnelOutbound,
    },
}
42
/// A fully buffered HTTP/3 response, produced when the request stream finishes.
#[derive(Debug)]
pub struct StreamResponse {
    /// Value of the response's `:status` pseudo-header.
    pub status: u16,
    /// Response header fields (excluding `:status`), in arrival order.
    pub headers: Vec<(String, String)>,
    /// Concatenated response body.
    pub body: Bytes,
}
49
/// Driver-side bookkeeping for one in-flight buffered request stream.
struct DriverStreamState {
    /// Receives the final response or an error; an `Option` so it is taken
    /// (and therefore answered) at most once.
    response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
    /// `:status` from the response headers, once it has arrived.
    status: Option<u16>,
    /// Response header fields collected so far, in arrival order.
    headers: Vec<(String, String)>,
    /// Response body, accumulated from DATA until the stream finishes.
    body: BytesMut,
}
57
58impl DriverStreamState {
59 fn new(response_tx: oneshot::Sender<Result<StreamResponse>>) -> Self {
60 Self {
61 response_tx: Some(response_tx),
62 status: None,
63 headers: Vec::new(),
64 body: BytesMut::new(),
65 }
66 }
67}
68
/// Driver-side bookkeeping for one RFC 9220 tunnel stream.
///
/// The channel halves that end up in the caller's `H3Tunnel` (`outbound_tx`,
/// `inbound_rx`) and the receiver drained by the forwarding task
/// (`outbound_rx`) are kept in `Option`s so they can be moved out exactly
/// once, when the peer accepts the CONNECT.
struct DriverTunnelState {
    /// Receives the opened tunnel or the handshake failure; taken once answered.
    response_tx: Option<oneshot::Sender<Result<H3Tunnel>>>,
    /// Sender half given to the caller for data headed to the peer.
    outbound_tx: Option<mpsc::Sender<H3TunnelOutbound>>,
    /// Receiver half drained by the forwarding task once the tunnel opens.
    outbound_rx: Option<mpsc::Receiver<H3TunnelOutbound>>,
    /// Driver's end for delivering peer data, end-of-stream, resets and errors.
    inbound_tx: mpsc::Sender<Result<H3TunnelEvent>>,
    /// Receiver half given to the caller once the tunnel opens.
    inbound_rx: Option<mpsc::Receiver<Result<H3TunnelEvent>>>,
    /// Outbound chunks not yet (fully) accepted by `send_body`.
    pending_outbound: VecDeque<H3TunnelOutbound>,
    /// Set once the tunnel has been handed to the caller.
    opened: bool,
    /// Last `:status` seen on this stream.
    status: Option<u16>,
    /// Non-pseudo response header fields received so far.
    headers: Vec<(String, String)>,
}
80
81impl DriverTunnelState {
82 fn new(response_tx: oneshot::Sender<Result<H3Tunnel>>) -> Self {
83 let (outbound_tx, outbound_rx) = mpsc::channel(32);
84 let (inbound_tx, inbound_rx) = mpsc::channel(32);
85
86 Self {
87 response_tx: Some(response_tx),
88 outbound_tx: Some(outbound_tx),
89 outbound_rx: Some(outbound_rx),
90 inbound_tx,
91 inbound_rx: Some(inbound_rx),
92 pending_outbound: VecDeque::new(),
93 opened: false,
94 status: None,
95 headers: Vec::new(),
96 }
97 }
98}
99
/// Single-task owner of one QUIC connection and its HTTP/3 layer.
///
/// Other tasks interact with it only through [`DriverCommand`]s; `drive`
/// runs the I/O loop until the connection closes or an error occurs.
pub struct H3Driver {
    /// Cloned into each tunnel's forwarding task so outbound tunnel data
    /// re-enters the driver as `DriverCommand::SendTunnelData`.
    ///
    /// NOTE(review): because the driver itself holds this sender,
    /// `command_rx.recv()` cannot report a closed channel while the driver is
    /// alive; a `mpsc::WeakSender` here would make that shutdown path reachable.
    command_tx: mpsc::Sender<DriverCommand>,
    command_rx: mpsc::Receiver<DriverCommand>,
    conn: quiche::Connection,
    h3_conn: quiche::h3::Connection,
    /// UDP socket used for every datagram of this connection.
    socket: Arc<UdpSocket>,
    /// The only address datagrams are sent to and accepted from.
    peer_addr: SocketAddr,
    /// Buffered request streams, keyed by stream ID.
    streams: HashMap<u64, DriverStreamState>,
    /// RFC 9220 tunnel streams, keyed by stream ID.
    tunnels: HashMap<u64, DriverTunnelState>,
    /// Commands deferred until the peer's SETTINGS arrive (tunnel opens).
    pending_commands: VecDeque<DriverCommand>,
    /// ID from the most recent GOAWAY; once set, new work is refused.
    goaway_id: Option<u64>,
}
113
114impl H3Driver {
115 pub fn new(
116 command_tx: mpsc::Sender<DriverCommand>,
117 command_rx: mpsc::Receiver<DriverCommand>,
118 conn: quiche::Connection,
119 h3_conn: quiche::h3::Connection,
120 socket: Arc<UdpSocket>,
121 peer_addr: SocketAddr,
122 ) -> Self {
123 Self {
124 command_tx,
125 command_rx,
126 conn,
127 h3_conn,
128 socket,
129 peer_addr,
130 streams: HashMap::new(),
131 tunnels: HashMap::new(),
132 pending_commands: VecDeque::new(),
133 goaway_id: None,
134 }
135 }
136
137 pub async fn drive(mut self) -> Result<()> {
138 let result = self.drive_loop().await;
139
140 if let Err(ref e) = result {
141 tracing::error!("H3 Driver error: {}", e);
142 for (_, mut stream) in self.streams.drain() {
143 if let Some(tx) = stream.response_tx.take() {
144 let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
145 }
146 }
147 for (_, mut tunnel) in self.tunnels.drain() {
148 if let Some(tx) = tunnel.response_tx.take() {
149 let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
150 } else {
151 let _ = tunnel
152 .inbound_tx
153 .send(Err(Error::Quic(format!("Driver error: {}", e))))
154 .await;
155 }
156 }
157 for cmd in self.pending_commands.drain(..) {
158 Self::fail_pending_command(cmd, Error::Quic(format!("Driver error: {}", e)));
159 }
160 }
161
162 result
163 }
164
165 async fn drive_loop(&mut self) -> Result<()> {
166 let mut buf = vec![0u8; 65535];
167 let mut out = vec![0u8; 1350];
168
169 loop {
170 self.process_h3_events().await?;
171 self.process_pending_commands().await?;
172 self.flush_tunnel_data().await?;
173
174 loop {
175 match self.conn.send(&mut out) {
176 Ok((len, _)) => {
177 if let Err(e) = self.socket.send_to(&out[..len], self.peer_addr).await {
178 tracing::error!("H3 socket send error: {}", e);
179 return Err(Error::Io(e));
180 }
181 }
182 Err(quiche::Error::Done) => break,
183 Err(e) => {
184 tracing::error!("H3 quiche send error: {}", e);
185 return Err(Error::Quic(format!("QUIC send error: {}", e)));
186 }
187 }
188 }
189
190 let timeout_duration = self.conn.timeout().unwrap_or(Duration::from_secs(60));
191
192 tokio::select! {
193 cmd = self.command_rx.recv() => {
194 match cmd {
195 Some(c) => self.handle_command(c).await?,
196 None => {
197 match self.conn.close(true, 0x00, b"Client shutdown") {
198 Ok(_) | Err(quiche::Error::Done) => {},
199 Err(_) => {}
200 }
201 while let Ok((len, _)) = self.conn.send(&mut out) {
202 let _ = self.socket.send_to(&out[..len], self.peer_addr).await;
203 }
204 return Ok(());
205 }
206 }
207 }
208
209 res = self.socket.recv_from(&mut buf) => {
210 match res {
211 Ok((len, from)) => {
212 if from == self.peer_addr {
213 let info = quiche::RecvInfo {
214 from,
215 to: self.socket.local_addr().unwrap(),
216 };
217 match self.conn.recv(&mut buf[..len], info) {
218 Ok(_) => self.process_h3_events().await?,
219 Err(quiche::Error::Done) => {},
220 Err(e) => {
221 tracing::warn!("QUIC recv error: {}", e);
222 }
223 }
224 }
225 }
226 Err(e) => return Err(Error::Io(e)),
227 }
228 }
229
230 _ = sleep(timeout_duration) => {
231 self.conn.on_timeout();
232 }
233 }
234
235 if self.conn.is_closed() {
236 tracing::info!("H3 Driver: Connection closed");
237 self.fail_all(Error::Connection("Connection closed".into()))
238 .await;
239 return Ok(());
240 }
241 }
242 }
243
244 async fn handle_command(&mut self, cmd: DriverCommand) -> Result<()> {
245 match cmd {
246 DriverCommand::SendRequest { .. } => self.handle_send_request(cmd).await?,
247 DriverCommand::OpenWebSocketTunnel { .. } => {
248 self.handle_open_websocket_tunnel(cmd).await?
249 }
250 DriverCommand::SendTunnelData {
251 stream_id,
252 outbound,
253 } => self.queue_tunnel_outbound(stream_id, outbound).await?,
254 }
255 Ok(())
256 }
257
258 async fn process_pending_commands(&mut self) -> Result<()> {
259 let original_len = self.pending_commands.len();
260 for _ in 0..original_len {
261 let Some(cmd) = self.pending_commands.pop_front() else {
262 break;
263 };
264
265 match cmd {
266 DriverCommand::OpenWebSocketTunnel { .. } => {
267 if self.h3_conn.peer_settings_raw().is_none() {
268 self.pending_commands.push_back(cmd);
269 } else {
270 self.handle_open_websocket_tunnel(cmd).await?;
271 }
272 }
273 other => self.handle_command(other).await?,
274 }
275 }
276
277 Ok(())
278 }
279
280 async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
281 if let DriverCommand::SendRequest {
282 method,
283 uri,
284 headers,
285 body,
286 response_tx,
287 } = cmd
288 {
289 if self.goaway_id.is_some() {
290 let _ = response_tx.send(Err(Error::HttpProtocol(
291 "HTTP/3 GOAWAY received; refusing new request".into(),
292 )));
293 return Ok(());
294 }
295
296 let h3_headers = match build_request_headers(&method, &uri, &headers) {
297 Ok(headers) => headers,
298 Err(err) => {
299 let _ = response_tx.send(Err(err));
300 return Ok(());
301 }
302 };
303
304 let fin = body.is_none();
305 match self.h3_conn.send_request(&mut self.conn, &h3_headers, fin) {
306 Ok(stream_id) => {
307 let mut state = DriverStreamState::new(response_tx);
308
309 if let Some(data) = body {
310 match self
311 .h3_conn
312 .send_body(&mut self.conn, stream_id, &data, true)
313 {
314 Ok(sent) if sent == data.len() => {}
315 Ok(sent) => {
316 if let Some(tx) = state.response_tx.take() {
317 let _ = tx.send(Err(Error::Quic(format!(
318 "Partial H3 request body write: sent {sent} of {} bytes",
319 data.len()
320 ))));
321 }
322 return Ok(());
323 }
324 Err(e) => {
325 if let Some(tx) = state.response_tx.take() {
326 let _ = tx
327 .send(Err(Error::Quic(format!("Send body failed: {}", e))));
328 }
329 return Ok(());
330 }
331 }
332 }
333
334 self.streams.insert(stream_id, state);
335 }
336 Err(e) => {
337 let _ =
338 response_tx.send(Err(Error::Quic(format!("Send request failed: {}", e))));
339 }
340 }
341 }
342
343 Ok(())
344 }
345
346 async fn handle_open_websocket_tunnel(&mut self, cmd: DriverCommand) -> Result<()> {
347 if let DriverCommand::OpenWebSocketTunnel {
348 uri,
349 headers,
350 response_tx,
351 } = cmd
352 {
353 if self.goaway_id.is_some() {
354 let _ = response_tx.send(Err(Error::HttpProtocol(
355 "HTTP/3 GOAWAY received; refusing new RFC 9220 tunnel".into(),
356 )));
357 return Ok(());
358 }
359
360 if self.h3_conn.peer_settings_raw().is_none() {
361 self.pending_commands
362 .push_back(DriverCommand::OpenWebSocketTunnel {
363 uri,
364 headers,
365 response_tx,
366 });
367 return Ok(());
368 }
369
370 if !self.h3_conn.extended_connect_enabled_by_peer() {
371 let _ = response_tx.send(Err(Error::WebSocketUnsupported(
372 "RFC 9220 requires peer SETTINGS_ENABLE_CONNECT_PROTOCOL = 1".into(),
373 )));
374 return Ok(());
375 }
376
377 let h3_headers = match build_websocket_connect_headers(&uri, &headers) {
378 Ok(headers) => headers,
379 Err(err) => {
380 let _ = response_tx.send(Err(err));
381 return Ok(());
382 }
383 };
384
385 match self
386 .h3_conn
387 .send_request(&mut self.conn, &h3_headers, false)
388 {
389 Ok(stream_id) => {
390 self.tunnels
391 .insert(stream_id, DriverTunnelState::new(response_tx));
392 }
393 Err(e) => {
394 let _ = response_tx
395 .send(Err(Error::Quic(format!("RFC 9220 CONNECT failed: {}", e))));
396 }
397 }
398 }
399
400 Ok(())
401 }
402
403 async fn queue_tunnel_outbound(
404 &mut self,
405 stream_id: u64,
406 outbound: H3TunnelOutbound,
407 ) -> Result<()> {
408 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
409 tunnel.pending_outbound.push_back(outbound);
410 self.flush_tunnel_data().await?;
411 }
412
413 Ok(())
414 }
415
416 async fn flush_tunnel_data(&mut self) -> Result<()> {
417 let stream_ids: Vec<u64> = self.tunnels.keys().copied().collect();
418
419 for stream_id in stream_ids {
420 loop {
421 let outbound = match self
422 .tunnels
423 .get_mut(&stream_id)
424 .and_then(|tunnel| tunnel.pending_outbound.pop_front())
425 {
426 Some(outbound) => outbound,
427 None => break,
428 };
429
430 match self.h3_conn.send_body(
431 &mut self.conn,
432 stream_id,
433 &outbound.bytes,
434 outbound.fin,
435 ) {
436 Ok(sent) if sent == outbound.bytes.len() => {}
437 Ok(sent) => {
438 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
439 tunnel.pending_outbound.push_front(H3TunnelOutbound {
440 bytes: outbound.bytes.slice(sent..),
441 fin: outbound.fin,
442 });
443 }
444 break;
445 }
446 Err(quiche::h3::Error::Done) | Err(quiche::h3::Error::StreamBlocked) => {
447 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
448 tunnel.pending_outbound.push_front(outbound);
449 }
450 break;
451 }
452 Err(e) => {
453 return Err(Error::Quic(format!("H3 tunnel send body failed: {}", e)));
454 }
455 }
456 }
457 }
458
459 Ok(())
460 }
461
462 async fn process_h3_events(&mut self) -> Result<()> {
463 loop {
464 match self.h3_conn.poll(&mut self.conn) {
465 Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
466 self.handle_headers_event(stream_id, list).await?;
467 }
468 Ok((stream_id, quiche::h3::Event::Data)) => {
469 self.handle_data_event(stream_id).await?;
470 }
471 Ok((stream_id, quiche::h3::Event::Finished)) => {
472 self.handle_finished_event(stream_id).await?;
473 }
474 Ok((stream_id, quiche::h3::Event::Reset(error_code))) => {
475 self.handle_reset_event(stream_id, error_code).await?;
476 }
477 Ok((id, quiche::h3::Event::GoAway)) => {
478 self.handle_goaway_event(id).await?;
479 }
480 Err(quiche::h3::Error::Done) => break,
481 Ok(_) => {}
482 Err(e) => {
483 tracing::warn!("H3 poll error: {}", e);
484 return Err(Error::Quic(format!("H3 poll error: {}", e)));
485 }
486 }
487 }
488
489 Ok(())
490 }
491
492 async fn handle_headers_event(
493 &mut self,
494 stream_id: u64,
495 list: Vec<quiche::h3::Header>,
496 ) -> Result<()> {
497 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
498 for header in list {
499 let name = String::from_utf8_lossy(header.name());
500 let value = String::from_utf8_lossy(header.value());
501
502 if name == ":status" {
503 tunnel.status = value.parse().ok();
504 } else if !name.starts_with(':') {
505 tunnel.headers.push((name.into_owned(), value.into_owned()));
506 }
507 }
508
509 match tunnel.status {
510 Some(200) if !tunnel.opened => {
511 let outbound_tx = tunnel.outbound_tx.take().expect("outbound tx");
512 let inbound_rx = tunnel.inbound_rx.take().expect("inbound rx");
513 let mut outbound_rx = tunnel.outbound_rx.take().expect("outbound rx");
514 let command_tx = self.command_tx.clone();
515
516 tokio::spawn(async move {
517 while let Some(outbound) = outbound_rx.recv().await {
518 if command_tx
519 .send(DriverCommand::SendTunnelData {
520 stream_id,
521 outbound,
522 })
523 .await
524 .is_err()
525 {
526 break;
527 }
528 }
529 });
530
531 tunnel.opened = true;
532 if let Some(tx) = tunnel.response_tx.take() {
533 let _ = tx.send(Ok(H3Tunnel::new(outbound_tx, inbound_rx)));
534 }
535 }
536 Some(status) if status >= 200 && !tunnel.opened => {
537 let headers = crate::headers::Headers::from(tunnel.headers.clone());
538 if let Some(tx) = tunnel.response_tx.take() {
539 let _ = tx.send(Err(Error::WebSocketHandshake { status, headers }));
540 }
541 self.tunnels.remove(&stream_id);
542 }
543 _ => {}
544 }
545
546 return Ok(());
547 }
548
549 if let Some(stream) = self.streams.get_mut(&stream_id) {
550 for header in list {
551 let name = String::from_utf8_lossy(header.name());
552 let value = String::from_utf8_lossy(header.value());
553
554 if name == ":status" {
555 stream.status = value.parse().ok();
556 } else {
557 stream.headers.push((name.into_owned(), value.into_owned()));
558 }
559 }
560 }
561
562 Ok(())
563 }
564
565 async fn handle_data_event(&mut self, stream_id: u64) -> Result<()> {
566 let mut buf = vec![0u8; 65535];
567
568 if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
569 loop {
570 match self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf) {
571 Ok(0) => break,
572 Ok(len) => {
573 if tunnel.opened {
574 let _ = tunnel
575 .inbound_tx
576 .send(Ok(H3TunnelEvent::Data(Bytes::copy_from_slice(&buf[..len]))))
577 .await;
578 } else if let Some(tx) = tunnel.response_tx.take() {
579 let _ = tx.send(Err(Error::HttpProtocol(
580 "RFC 9220 tunnel DATA received before :status 200".into(),
581 )));
582 }
583 }
584 Err(quiche::h3::Error::Done) => break,
585 Err(e) => return Err(Error::Quic(format!("H3 recv body failed: {}", e))),
586 }
587 }
588 return Ok(());
589 }
590
591 if let Some(stream) = self.streams.get_mut(&stream_id) {
592 loop {
593 match self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf) {
594 Ok(0) => break,
595 Ok(len) => stream.body.extend_from_slice(&buf[..len]),
596 Err(quiche::h3::Error::Done) => break,
597 Err(e) => return Err(Error::Quic(format!("H3 recv body failed: {}", e))),
598 }
599 }
600 }
601
602 Ok(())
603 }
604
605 async fn handle_finished_event(&mut self, stream_id: u64) -> Result<()> {
606 if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
607 if tunnel.opened {
608 let _ = tunnel.inbound_tx.send(Ok(H3TunnelEvent::EndStream)).await;
609 } else if let Some(tx) = tunnel.response_tx.take() {
610 let _ = tx.send(Err(Error::HttpProtocol(
611 "RFC 9220 tunnel completed before :status 200".into(),
612 )));
613 }
614 return Ok(());
615 }
616
617 if let Some(mut stream) = self.streams.remove(&stream_id) {
618 if let Some(tx) = stream.response_tx.take() {
619 let response = match stream.status {
620 Some(status) => Ok(StreamResponse {
621 status,
622 headers: stream.headers,
623 body: stream.body.freeze(),
624 }),
625 None => Err(Error::HttpProtocol(format!(
626 "H3 stream {} completed without status code",
627 stream_id
628 ))),
629 };
630 let _ = tx.send(response);
631 }
632 }
633
634 Ok(())
635 }
636
637 async fn handle_reset_event(&mut self, stream_id: u64, error_code: u64) -> Result<()> {
638 if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
639 if tunnel.opened {
640 let _ = tunnel
641 .inbound_tx
642 .send(Ok(H3TunnelEvent::Reset(error_code.to_string())))
643 .await;
644 } else if let Some(tx) = tunnel.response_tx.take() {
645 let _ = tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
646 }
647 return Ok(());
648 }
649
650 if let Some(mut stream) = self.streams.remove(&stream_id) {
651 if let Some(tx) = stream.response_tx.take() {
652 let _ = tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
653 }
654 }
655
656 Ok(())
657 }
658
659 async fn handle_goaway_event(&mut self, id: u64) -> Result<()> {
660 self.goaway_id = Some(id);
661
662 let tunnel_ids: Vec<u64> = self.tunnels.keys().copied().collect();
663 for stream_id in tunnel_ids {
664 if stream_id > id {
665 if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
666 if tunnel.opened {
667 let _ = tunnel
668 .inbound_tx
669 .send(Ok(H3TunnelEvent::GoAway { id }))
670 .await;
671 } else if let Some(tx) = tunnel.response_tx.take() {
672 let _ = tx.send(Err(Error::HttpProtocol(format!(
673 "HTTP/3 GOAWAY received id={id}"
674 ))));
675 }
676 }
677 }
678 }
679
680 let stream_ids: Vec<u64> = self.streams.keys().copied().collect();
681 for stream_id in stream_ids {
682 if stream_id > id {
683 if let Some(mut stream) = self.streams.remove(&stream_id) {
684 if let Some(tx) = stream.response_tx.take() {
685 let _ = tx.send(Err(Error::HttpProtocol(format!(
686 "HTTP/3 GOAWAY received id={id}"
687 ))));
688 }
689 }
690 }
691 }
692
693 Ok(())
694 }
695
696 async fn fail_all(&mut self, err: Error) {
697 for (_, mut stream) in self.streams.drain() {
698 if let Some(tx) = stream.response_tx.take() {
699 let _ = tx.send(Err(Error::HttpProtocol(err.to_string())));
700 }
701 }
702
703 for (_, mut tunnel) in self.tunnels.drain() {
704 if let Some(tx) = tunnel.response_tx.take() {
705 let _ = tx.send(Err(Error::HttpProtocol(err.to_string())));
706 } else {
707 let _ = tunnel
708 .inbound_tx
709 .send(Err(Error::HttpProtocol(err.to_string())))
710 .await;
711 }
712 }
713
714 for cmd in self.pending_commands.drain(..) {
715 Self::fail_pending_command(cmd, Error::HttpProtocol(err.to_string()));
716 }
717 }
718
719 fn fail_pending_command(cmd: DriverCommand, err: Error) {
720 match cmd {
721 DriverCommand::SendRequest { response_tx, .. } => {
722 let _ = response_tx.send(Err(Error::HttpProtocol(err.to_string())));
723 }
724 DriverCommand::OpenWebSocketTunnel { response_tx, .. } => {
725 let _ = response_tx.send(Err(Error::HttpProtocol(err.to_string())));
726 }
727 DriverCommand::SendTunnelData { .. } => {}
728 }
729 }
730}
731
732pub(crate) fn build_websocket_connect_headers(
733 uri: &http::Uri,
734 headers: &[(String, String)],
735) -> Result<Vec<quiche::h3::Header>> {
736 let scheme = uri.scheme_str().ok_or_else(|| {
737 Error::WebSocketUnsupported("RFC 9220 requires an https URI internally".into())
738 })?;
739 if scheme != "https" {
740 return Err(Error::WebSocketUnsupported(
741 "RFC 9220 WebSocket over HTTP/3 requires wss://".into(),
742 ));
743 }
744
745 let authority = uri
746 .authority()
747 .ok_or_else(|| Error::HttpProtocol("RFC 9220 CONNECT requires :authority".into()))?
748 .as_str();
749 let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
750
751 let mut h3_headers = vec![
752 quiche::h3::Header::new(b":method", b"CONNECT"),
753 quiche::h3::Header::new(b":protocol", b"websocket"),
754 quiche::h3::Header::new(b":scheme", scheme.as_bytes()),
755 quiche::h3::Header::new(b":path", path.as_bytes()),
756 quiche::h3::Header::new(b":authority", authority.as_bytes()),
757 ];
758
759 for (name, value) in headers {
760 let lower = name.to_ascii_lowercase();
761 if name.starts_with(':') {
762 return Err(Error::HttpProtocol(format!(
763 "user pseudo-header {name} is not allowed on RFC 9220 CONNECT"
764 )));
765 }
766
767 if matches!(
768 lower.as_str(),
769 "connection"
770 | "upgrade"
771 | "host"
772 | "sec-websocket-key"
773 | "sec-websocket-accept"
774 | "sec-websocket-extensions"
775 ) {
776 return Err(Error::WebSocketUnsupported(format!(
777 "header {name} is not allowed on RFC 9220 WebSocket over HTTP/3"
778 )));
779 }
780
781 if matches!(
782 lower.as_str(),
783 "keep-alive" | "proxy-connection" | "transfer-encoding"
784 ) {
785 continue;
786 }
787
788 h3_headers.push(quiche::h3::Header::new(lower.as_bytes(), value.as_bytes()));
789 }
790
791 Ok(h3_headers)
792}
793
794fn build_request_headers(
795 method: &http::Method,
796 uri: &http::Uri,
797 headers: &[(String, String)],
798) -> Result<Vec<quiche::h3::Header>> {
799 let scheme = uri.scheme_str().unwrap_or("https");
800 let authority = uri
801 .authority()
802 .map(|authority| authority.as_str())
803 .or_else(|| uri.host())
804 .unwrap_or("");
805 let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
806
807 let mut h3_headers = vec![
808 quiche::h3::Header::new(b":method", method.as_str().as_bytes()),
809 quiche::h3::Header::new(b":scheme", scheme.as_bytes()),
810 quiche::h3::Header::new(b":authority", authority.as_bytes()),
811 quiche::h3::Header::new(b":path", path.as_bytes()),
812 ];
813
814 for (name, value) in headers {
815 let lower = name.to_ascii_lowercase();
816 if !name.starts_with(':')
817 && lower != "connection"
818 && lower != "keep-alive"
819 && lower != "proxy-connection"
820 && lower != "transfer-encoding"
821 && lower != "upgrade"
822 {
823 h3_headers.push(quiche::h3::Header::new(lower.as_bytes(), value.as_bytes()));
824 }
825 }
826
827 Ok(h3_headers)
828}
829
830#[cfg(test)]
831mod tests {
832 use super::*;
833
834 fn header_pairs(headers: &[quiche::h3::Header]) -> Vec<(String, String)> {
835 headers
836 .iter()
837 .map(|h| {
838 (
839 String::from_utf8_lossy(h.name()).into_owned(),
840 String::from_utf8_lossy(h.value()).into_owned(),
841 )
842 })
843 .collect()
844 }
845
846 #[test]
847 fn rfc9220_headers_have_required_pseudo_headers_in_order() {
848 let uri: http::Uri = "https://example.test:443/chat?room=one".parse().unwrap();
849 let headers =
850 build_websocket_connect_headers(&uri, &[("User-Agent".into(), "specter-test".into())])
851 .unwrap();
852 let pairs = header_pairs(&headers);
853
854 assert_eq!(
855 &pairs[..5],
856 &[
857 (":method".into(), "CONNECT".into()),
858 (":protocol".into(), "websocket".into()),
859 (":scheme".into(), "https".into()),
860 (":path".into(), "/chat?room=one".into()),
861 (":authority".into(), "example.test:443".into()),
862 ]
863 );
864 assert!(pairs.contains(&("user-agent".into(), "specter-test".into())));
865 }
866
867 #[test]
868 fn rfc9220_rejects_h1_websocket_bootstrap_headers() {
869 let uri: http::Uri = "https://example.test/chat".parse().unwrap();
870 for name in [
871 "Connection",
872 "Upgrade",
873 "Host",
874 "Sec-WebSocket-Key",
875 "Sec-WebSocket-Accept",
876 "Sec-WebSocket-Extensions",
877 ] {
878 let err = build_websocket_connect_headers(&uri, &[(name.into(), "x".into())])
879 .expect_err("forbidden header must fail");
880 let msg = err.to_string();
881 assert!(msg.contains("not allowed"), "{name}: {msg}");
882 }
883 }
884
885 #[test]
886 fn rfc9220_rejects_user_pseudo_headers() {
887 let uri: http::Uri = "https://example.test/chat".parse().unwrap();
888 let err = build_websocket_connect_headers(&uri, &[(":authority".into(), "evil".into())])
889 .expect_err("user pseudo headers must fail");
890 assert!(err.to_string().contains("pseudo-header"));
891 }
892}