Skip to main content

specter/transport/h3/
driver.rs

//! HTTP/3 connection driver: a background task that reads packets and routes them to streams.
//!
//! The driver owns the QUIC connection and its `UdpSocket`; connection handles talk to it
//! only through [`DriverCommand`] messages.
4
5use bytes::{Bytes, BytesMut};
6use quiche::h3::NameValue;
7use std::collections::{HashMap, VecDeque};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::net::UdpSocket;
12use tokio::sync::mpsc;
13use tokio::sync::oneshot;
14use tokio::time::sleep;
15
16use crate::error::{Error, Result};
17use crate::transport::h3::{H3Tunnel, H3TunnelEvent, H3TunnelOutbound};
18
19/// Command sent from handle to driver.
20#[derive(Debug)]
21pub enum DriverCommand {
22    /// Send a request and get response via oneshot.
23    SendRequest {
24        method: http::Method,
25        uri: http::Uri,
26        headers: Vec<(String, String)>,
27        body: Option<Bytes>,
28        response_tx: oneshot::Sender<Result<StreamResponse>>,
29    },
30    /// Open an RFC 9220 WebSocket-over-HTTP/3 tunnel.
31    OpenWebSocketTunnel {
32        uri: http::Uri,
33        headers: Vec<(String, String)>,
34        response_tx: oneshot::Sender<Result<H3Tunnel>>,
35    },
36    /// Queue outbound DATA for an open RFC 9220 tunnel.
37    SendTunnelData {
38        stream_id: u64,
39        outbound: H3TunnelOutbound,
40    },
41}
42
43#[derive(Debug)]
44pub struct StreamResponse {
45    pub status: u16,
46    pub headers: Vec<(String, String)>,
47    pub body: Bytes,
48}
49
50/// Per-stream state tracked by driver.
51struct DriverStreamState {
52    response_tx: Option<oneshot::Sender<Result<StreamResponse>>>,
53    status: Option<u16>,
54    headers: Vec<(String, String)>,
55    body: BytesMut,
56}
57
58impl DriverStreamState {
59    fn new(response_tx: oneshot::Sender<Result<StreamResponse>>) -> Self {
60        Self {
61            response_tx: Some(response_tx),
62            status: None,
63            headers: Vec::new(),
64            body: BytesMut::new(),
65        }
66    }
67}
68
69struct DriverTunnelState {
70    response_tx: Option<oneshot::Sender<Result<H3Tunnel>>>,
71    outbound_tx: Option<mpsc::Sender<H3TunnelOutbound>>,
72    outbound_rx: Option<mpsc::Receiver<H3TunnelOutbound>>,
73    inbound_tx: mpsc::Sender<Result<H3TunnelEvent>>,
74    inbound_rx: Option<mpsc::Receiver<Result<H3TunnelEvent>>>,
75    pending_outbound: VecDeque<H3TunnelOutbound>,
76    opened: bool,
77    status: Option<u16>,
78    headers: Vec<(String, String)>,
79}
80
81impl DriverTunnelState {
82    fn new(response_tx: oneshot::Sender<Result<H3Tunnel>>) -> Self {
83        let (outbound_tx, outbound_rx) = mpsc::channel(32);
84        let (inbound_tx, inbound_rx) = mpsc::channel(32);
85
86        Self {
87            response_tx: Some(response_tx),
88            outbound_tx: Some(outbound_tx),
89            outbound_rx: Some(outbound_rx),
90            inbound_tx,
91            inbound_rx: Some(inbound_rx),
92            pending_outbound: VecDeque::new(),
93            opened: false,
94            status: None,
95            headers: Vec::new(),
96        }
97    }
98}
99
100/// HTTP/3 connection driver.
101pub struct H3Driver {
102    command_tx: mpsc::Sender<DriverCommand>,
103    command_rx: mpsc::Receiver<DriverCommand>,
104    conn: quiche::Connection,
105    h3_conn: quiche::h3::Connection,
106    socket: Arc<UdpSocket>,
107    peer_addr: SocketAddr,
108    streams: HashMap<u64, DriverStreamState>,
109    tunnels: HashMap<u64, DriverTunnelState>,
110    pending_commands: VecDeque<DriverCommand>,
111    goaway_id: Option<u64>,
112}
113
114impl H3Driver {
115    pub fn new(
116        command_tx: mpsc::Sender<DriverCommand>,
117        command_rx: mpsc::Receiver<DriverCommand>,
118        conn: quiche::Connection,
119        h3_conn: quiche::h3::Connection,
120        socket: Arc<UdpSocket>,
121        peer_addr: SocketAddr,
122    ) -> Self {
123        Self {
124            command_tx,
125            command_rx,
126            conn,
127            h3_conn,
128            socket,
129            peer_addr,
130            streams: HashMap::new(),
131            tunnels: HashMap::new(),
132            pending_commands: VecDeque::new(),
133            goaway_id: None,
134        }
135    }
136
137    pub async fn drive(mut self) -> Result<()> {
138        let result = self.drive_loop().await;
139
140        if let Err(ref e) = result {
141            tracing::error!("H3 Driver error: {}", e);
142            for (_, mut stream) in self.streams.drain() {
143                if let Some(tx) = stream.response_tx.take() {
144                    let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
145                }
146            }
147            for (_, mut tunnel) in self.tunnels.drain() {
148                if let Some(tx) = tunnel.response_tx.take() {
149                    let _ = tx.send(Err(Error::Quic(format!("Driver error: {}", e))));
150                } else {
151                    let _ = tunnel
152                        .inbound_tx
153                        .send(Err(Error::Quic(format!("Driver error: {}", e))))
154                        .await;
155                }
156            }
157            for cmd in self.pending_commands.drain(..) {
158                Self::fail_pending_command(cmd, Error::Quic(format!("Driver error: {}", e)));
159            }
160        }
161
162        result
163    }
164
165    async fn drive_loop(&mut self) -> Result<()> {
166        let mut buf = vec![0u8; 65535];
167        let mut out = vec![0u8; 1350];
168
169        loop {
170            self.process_h3_events().await?;
171            self.process_pending_commands().await?;
172            self.flush_tunnel_data().await?;
173
174            loop {
175                match self.conn.send(&mut out) {
176                    Ok((len, _)) => {
177                        if let Err(e) = self.socket.send_to(&out[..len], self.peer_addr).await {
178                            tracing::error!("H3 socket send error: {}", e);
179                            return Err(Error::Io(e));
180                        }
181                    }
182                    Err(quiche::Error::Done) => break,
183                    Err(e) => {
184                        tracing::error!("H3 quiche send error: {}", e);
185                        return Err(Error::Quic(format!("QUIC send error: {}", e)));
186                    }
187                }
188            }
189
190            let timeout_duration = self.conn.timeout().unwrap_or(Duration::from_secs(60));
191
192            tokio::select! {
193                cmd = self.command_rx.recv() => {
194                    match cmd {
195                        Some(c) => self.handle_command(c).await?,
196                        None => {
197                            match self.conn.close(true, 0x00, b"Client shutdown") {
198                                Ok(_) | Err(quiche::Error::Done) => {},
199                                Err(_) => {}
200                            }
201                            while let Ok((len, _)) = self.conn.send(&mut out) {
202                                let _ = self.socket.send_to(&out[..len], self.peer_addr).await;
203                            }
204                            return Ok(());
205                        }
206                    }
207                }
208
209                res = self.socket.recv_from(&mut buf) => {
210                    match res {
211                        Ok((len, from)) => {
212                            if from == self.peer_addr {
213                                let info = quiche::RecvInfo {
214                                    from,
215                                    to: self.socket.local_addr().unwrap(),
216                                };
217                                match self.conn.recv(&mut buf[..len], info) {
218                                    Ok(_) => self.process_h3_events().await?,
219                                    Err(quiche::Error::Done) => {},
220                                    Err(e) => {
221                                        tracing::warn!("QUIC recv error: {}", e);
222                                    }
223                                }
224                            }
225                        }
226                        Err(e) => return Err(Error::Io(e)),
227                    }
228                }
229
230                _ = sleep(timeout_duration) => {
231                    self.conn.on_timeout();
232                }
233            }
234
235            if self.conn.is_closed() {
236                tracing::info!("H3 Driver: Connection closed");
237                self.fail_all(Error::Connection("Connection closed".into()))
238                    .await;
239                return Ok(());
240            }
241        }
242    }
243
244    async fn handle_command(&mut self, cmd: DriverCommand) -> Result<()> {
245        match cmd {
246            DriverCommand::SendRequest { .. } => self.handle_send_request(cmd).await?,
247            DriverCommand::OpenWebSocketTunnel { .. } => {
248                self.handle_open_websocket_tunnel(cmd).await?
249            }
250            DriverCommand::SendTunnelData {
251                stream_id,
252                outbound,
253            } => self.queue_tunnel_outbound(stream_id, outbound).await?,
254        }
255        Ok(())
256    }
257
258    async fn process_pending_commands(&mut self) -> Result<()> {
259        let original_len = self.pending_commands.len();
260        for _ in 0..original_len {
261            let Some(cmd) = self.pending_commands.pop_front() else {
262                break;
263            };
264
265            match cmd {
266                DriverCommand::OpenWebSocketTunnel { .. } => {
267                    if self.h3_conn.peer_settings_raw().is_none() {
268                        self.pending_commands.push_back(cmd);
269                    } else {
270                        self.handle_open_websocket_tunnel(cmd).await?;
271                    }
272                }
273                other => self.handle_command(other).await?,
274            }
275        }
276
277        Ok(())
278    }
279
280    async fn handle_send_request(&mut self, cmd: DriverCommand) -> Result<()> {
281        if let DriverCommand::SendRequest {
282            method,
283            uri,
284            headers,
285            body,
286            response_tx,
287        } = cmd
288        {
289            if self.goaway_id.is_some() {
290                let _ = response_tx.send(Err(Error::HttpProtocol(
291                    "HTTP/3 GOAWAY received; refusing new request".into(),
292                )));
293                return Ok(());
294            }
295
296            let h3_headers = match build_request_headers(&method, &uri, &headers) {
297                Ok(headers) => headers,
298                Err(err) => {
299                    let _ = response_tx.send(Err(err));
300                    return Ok(());
301                }
302            };
303
304            let fin = body.is_none();
305            match self.h3_conn.send_request(&mut self.conn, &h3_headers, fin) {
306                Ok(stream_id) => {
307                    let mut state = DriverStreamState::new(response_tx);
308
309                    if let Some(data) = body {
310                        match self
311                            .h3_conn
312                            .send_body(&mut self.conn, stream_id, &data, true)
313                        {
314                            Ok(sent) if sent == data.len() => {}
315                            Ok(sent) => {
316                                if let Some(tx) = state.response_tx.take() {
317                                    let _ = tx.send(Err(Error::Quic(format!(
318                                        "Partial H3 request body write: sent {sent} of {} bytes",
319                                        data.len()
320                                    ))));
321                                }
322                                return Ok(());
323                            }
324                            Err(e) => {
325                                if let Some(tx) = state.response_tx.take() {
326                                    let _ = tx
327                                        .send(Err(Error::Quic(format!("Send body failed: {}", e))));
328                                }
329                                return Ok(());
330                            }
331                        }
332                    }
333
334                    self.streams.insert(stream_id, state);
335                }
336                Err(e) => {
337                    let _ =
338                        response_tx.send(Err(Error::Quic(format!("Send request failed: {}", e))));
339                }
340            }
341        }
342
343        Ok(())
344    }
345
346    async fn handle_open_websocket_tunnel(&mut self, cmd: DriverCommand) -> Result<()> {
347        if let DriverCommand::OpenWebSocketTunnel {
348            uri,
349            headers,
350            response_tx,
351        } = cmd
352        {
353            if self.goaway_id.is_some() {
354                let _ = response_tx.send(Err(Error::HttpProtocol(
355                    "HTTP/3 GOAWAY received; refusing new RFC 9220 tunnel".into(),
356                )));
357                return Ok(());
358            }
359
360            if self.h3_conn.peer_settings_raw().is_none() {
361                self.pending_commands
362                    .push_back(DriverCommand::OpenWebSocketTunnel {
363                        uri,
364                        headers,
365                        response_tx,
366                    });
367                return Ok(());
368            }
369
370            if !self.h3_conn.extended_connect_enabled_by_peer() {
371                let _ = response_tx.send(Err(Error::WebSocketUnsupported(
372                    "RFC 9220 requires peer SETTINGS_ENABLE_CONNECT_PROTOCOL = 1".into(),
373                )));
374                return Ok(());
375            }
376
377            let h3_headers = match build_websocket_connect_headers(&uri, &headers) {
378                Ok(headers) => headers,
379                Err(err) => {
380                    let _ = response_tx.send(Err(err));
381                    return Ok(());
382                }
383            };
384
385            match self
386                .h3_conn
387                .send_request(&mut self.conn, &h3_headers, false)
388            {
389                Ok(stream_id) => {
390                    self.tunnels
391                        .insert(stream_id, DriverTunnelState::new(response_tx));
392                }
393                Err(e) => {
394                    let _ = response_tx
395                        .send(Err(Error::Quic(format!("RFC 9220 CONNECT failed: {}", e))));
396                }
397            }
398        }
399
400        Ok(())
401    }
402
403    async fn queue_tunnel_outbound(
404        &mut self,
405        stream_id: u64,
406        outbound: H3TunnelOutbound,
407    ) -> Result<()> {
408        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
409            tunnel.pending_outbound.push_back(outbound);
410            self.flush_tunnel_data().await?;
411        }
412
413        Ok(())
414    }
415
416    async fn flush_tunnel_data(&mut self) -> Result<()> {
417        let stream_ids: Vec<u64> = self.tunnels.keys().copied().collect();
418
419        for stream_id in stream_ids {
420            loop {
421                let outbound = match self
422                    .tunnels
423                    .get_mut(&stream_id)
424                    .and_then(|tunnel| tunnel.pending_outbound.pop_front())
425                {
426                    Some(outbound) => outbound,
427                    None => break,
428                };
429
430                match self.h3_conn.send_body(
431                    &mut self.conn,
432                    stream_id,
433                    &outbound.bytes,
434                    outbound.fin,
435                ) {
436                    Ok(sent) if sent == outbound.bytes.len() => {}
437                    Ok(sent) => {
438                        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
439                            tunnel.pending_outbound.push_front(H3TunnelOutbound {
440                                bytes: outbound.bytes.slice(sent..),
441                                fin: outbound.fin,
442                            });
443                        }
444                        break;
445                    }
446                    Err(quiche::h3::Error::Done) | Err(quiche::h3::Error::StreamBlocked) => {
447                        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
448                            tunnel.pending_outbound.push_front(outbound);
449                        }
450                        break;
451                    }
452                    Err(e) => {
453                        return Err(Error::Quic(format!("H3 tunnel send body failed: {}", e)));
454                    }
455                }
456            }
457        }
458
459        Ok(())
460    }
461
462    async fn process_h3_events(&mut self) -> Result<()> {
463        loop {
464            match self.h3_conn.poll(&mut self.conn) {
465                Ok((stream_id, quiche::h3::Event::Headers { list, .. })) => {
466                    self.handle_headers_event(stream_id, list).await?;
467                }
468                Ok((stream_id, quiche::h3::Event::Data)) => {
469                    self.handle_data_event(stream_id).await?;
470                }
471                Ok((stream_id, quiche::h3::Event::Finished)) => {
472                    self.handle_finished_event(stream_id).await?;
473                }
474                Ok((stream_id, quiche::h3::Event::Reset(error_code))) => {
475                    self.handle_reset_event(stream_id, error_code).await?;
476                }
477                Ok((id, quiche::h3::Event::GoAway)) => {
478                    self.handle_goaway_event(id).await?;
479                }
480                Err(quiche::h3::Error::Done) => break,
481                Ok(_) => {}
482                Err(e) => {
483                    tracing::warn!("H3 poll error: {}", e);
484                    return Err(Error::Quic(format!("H3 poll error: {}", e)));
485                }
486            }
487        }
488
489        Ok(())
490    }
491
492    async fn handle_headers_event(
493        &mut self,
494        stream_id: u64,
495        list: Vec<quiche::h3::Header>,
496    ) -> Result<()> {
497        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
498            for header in list {
499                let name = String::from_utf8_lossy(header.name());
500                let value = String::from_utf8_lossy(header.value());
501
502                if name == ":status" {
503                    tunnel.status = value.parse().ok();
504                } else if !name.starts_with(':') {
505                    tunnel.headers.push((name.into_owned(), value.into_owned()));
506                }
507            }
508
509            match tunnel.status {
510                Some(200) if !tunnel.opened => {
511                    let outbound_tx = tunnel.outbound_tx.take().expect("outbound tx");
512                    let inbound_rx = tunnel.inbound_rx.take().expect("inbound rx");
513                    let mut outbound_rx = tunnel.outbound_rx.take().expect("outbound rx");
514                    let command_tx = self.command_tx.clone();
515
516                    tokio::spawn(async move {
517                        while let Some(outbound) = outbound_rx.recv().await {
518                            if command_tx
519                                .send(DriverCommand::SendTunnelData {
520                                    stream_id,
521                                    outbound,
522                                })
523                                .await
524                                .is_err()
525                            {
526                                break;
527                            }
528                        }
529                    });
530
531                    tunnel.opened = true;
532                    if let Some(tx) = tunnel.response_tx.take() {
533                        let _ = tx.send(Ok(H3Tunnel::new(outbound_tx, inbound_rx)));
534                    }
535                }
536                Some(status) if status >= 200 && !tunnel.opened => {
537                    let headers = crate::headers::Headers::from(tunnel.headers.clone());
538                    if let Some(tx) = tunnel.response_tx.take() {
539                        let _ = tx.send(Err(Error::WebSocketHandshake { status, headers }));
540                    }
541                    self.tunnels.remove(&stream_id);
542                }
543                _ => {}
544            }
545
546            return Ok(());
547        }
548
549        if let Some(stream) = self.streams.get_mut(&stream_id) {
550            for header in list {
551                let name = String::from_utf8_lossy(header.name());
552                let value = String::from_utf8_lossy(header.value());
553
554                if name == ":status" {
555                    stream.status = value.parse().ok();
556                } else {
557                    stream.headers.push((name.into_owned(), value.into_owned()));
558                }
559            }
560        }
561
562        Ok(())
563    }
564
565    async fn handle_data_event(&mut self, stream_id: u64) -> Result<()> {
566        let mut buf = vec![0u8; 65535];
567
568        if let Some(tunnel) = self.tunnels.get_mut(&stream_id) {
569            loop {
570                match self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf) {
571                    Ok(0) => break,
572                    Ok(len) => {
573                        if tunnel.opened {
574                            let _ = tunnel
575                                .inbound_tx
576                                .send(Ok(H3TunnelEvent::Data(Bytes::copy_from_slice(&buf[..len]))))
577                                .await;
578                        } else if let Some(tx) = tunnel.response_tx.take() {
579                            let _ = tx.send(Err(Error::HttpProtocol(
580                                "RFC 9220 tunnel DATA received before :status 200".into(),
581                            )));
582                        }
583                    }
584                    Err(quiche::h3::Error::Done) => break,
585                    Err(e) => return Err(Error::Quic(format!("H3 recv body failed: {}", e))),
586                }
587            }
588            return Ok(());
589        }
590
591        if let Some(stream) = self.streams.get_mut(&stream_id) {
592            loop {
593                match self.h3_conn.recv_body(&mut self.conn, stream_id, &mut buf) {
594                    Ok(0) => break,
595                    Ok(len) => stream.body.extend_from_slice(&buf[..len]),
596                    Err(quiche::h3::Error::Done) => break,
597                    Err(e) => return Err(Error::Quic(format!("H3 recv body failed: {}", e))),
598                }
599            }
600        }
601
602        Ok(())
603    }
604
605    async fn handle_finished_event(&mut self, stream_id: u64) -> Result<()> {
606        if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
607            if tunnel.opened {
608                let _ = tunnel.inbound_tx.send(Ok(H3TunnelEvent::EndStream)).await;
609            } else if let Some(tx) = tunnel.response_tx.take() {
610                let _ = tx.send(Err(Error::HttpProtocol(
611                    "RFC 9220 tunnel completed before :status 200".into(),
612                )));
613            }
614            return Ok(());
615        }
616
617        if let Some(mut stream) = self.streams.remove(&stream_id) {
618            if let Some(tx) = stream.response_tx.take() {
619                let response = match stream.status {
620                    Some(status) => Ok(StreamResponse {
621                        status,
622                        headers: stream.headers,
623                        body: stream.body.freeze(),
624                    }),
625                    None => Err(Error::HttpProtocol(format!(
626                        "H3 stream {} completed without status code",
627                        stream_id
628                    ))),
629                };
630                let _ = tx.send(response);
631            }
632        }
633
634        Ok(())
635    }
636
637    async fn handle_reset_event(&mut self, stream_id: u64, error_code: u64) -> Result<()> {
638        if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
639            if tunnel.opened {
640                let _ = tunnel
641                    .inbound_tx
642                    .send(Ok(H3TunnelEvent::Reset(error_code.to_string())))
643                    .await;
644            } else if let Some(tx) = tunnel.response_tx.take() {
645                let _ = tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
646            }
647            return Ok(());
648        }
649
650        if let Some(mut stream) = self.streams.remove(&stream_id) {
651            if let Some(tx) = stream.response_tx.take() {
652                let _ = tx.send(Err(Error::Quic(format!("Stream reset: {}", error_code))));
653            }
654        }
655
656        Ok(())
657    }
658
659    async fn handle_goaway_event(&mut self, id: u64) -> Result<()> {
660        self.goaway_id = Some(id);
661
662        let tunnel_ids: Vec<u64> = self.tunnels.keys().copied().collect();
663        for stream_id in tunnel_ids {
664            if stream_id > id {
665                if let Some(mut tunnel) = self.tunnels.remove(&stream_id) {
666                    if tunnel.opened {
667                        let _ = tunnel
668                            .inbound_tx
669                            .send(Ok(H3TunnelEvent::GoAway { id }))
670                            .await;
671                    } else if let Some(tx) = tunnel.response_tx.take() {
672                        let _ = tx.send(Err(Error::HttpProtocol(format!(
673                            "HTTP/3 GOAWAY received id={id}"
674                        ))));
675                    }
676                }
677            }
678        }
679
680        let stream_ids: Vec<u64> = self.streams.keys().copied().collect();
681        for stream_id in stream_ids {
682            if stream_id > id {
683                if let Some(mut stream) = self.streams.remove(&stream_id) {
684                    if let Some(tx) = stream.response_tx.take() {
685                        let _ = tx.send(Err(Error::HttpProtocol(format!(
686                            "HTTP/3 GOAWAY received id={id}"
687                        ))));
688                    }
689                }
690            }
691        }
692
693        Ok(())
694    }
695
696    async fn fail_all(&mut self, err: Error) {
697        for (_, mut stream) in self.streams.drain() {
698            if let Some(tx) = stream.response_tx.take() {
699                let _ = tx.send(Err(Error::HttpProtocol(err.to_string())));
700            }
701        }
702
703        for (_, mut tunnel) in self.tunnels.drain() {
704            if let Some(tx) = tunnel.response_tx.take() {
705                let _ = tx.send(Err(Error::HttpProtocol(err.to_string())));
706            } else {
707                let _ = tunnel
708                    .inbound_tx
709                    .send(Err(Error::HttpProtocol(err.to_string())))
710                    .await;
711            }
712        }
713
714        for cmd in self.pending_commands.drain(..) {
715            Self::fail_pending_command(cmd, Error::HttpProtocol(err.to_string()));
716        }
717    }
718
719    fn fail_pending_command(cmd: DriverCommand, err: Error) {
720        match cmd {
721            DriverCommand::SendRequest { response_tx, .. } => {
722                let _ = response_tx.send(Err(Error::HttpProtocol(err.to_string())));
723            }
724            DriverCommand::OpenWebSocketTunnel { response_tx, .. } => {
725                let _ = response_tx.send(Err(Error::HttpProtocol(err.to_string())));
726            }
727            DriverCommand::SendTunnelData { .. } => {}
728        }
729    }
730}
731
732pub(crate) fn build_websocket_connect_headers(
733    uri: &http::Uri,
734    headers: &[(String, String)],
735) -> Result<Vec<quiche::h3::Header>> {
736    let scheme = uri.scheme_str().ok_or_else(|| {
737        Error::WebSocketUnsupported("RFC 9220 requires an https URI internally".into())
738    })?;
739    if scheme != "https" {
740        return Err(Error::WebSocketUnsupported(
741            "RFC 9220 WebSocket over HTTP/3 requires wss://".into(),
742        ));
743    }
744
745    let authority = uri
746        .authority()
747        .ok_or_else(|| Error::HttpProtocol("RFC 9220 CONNECT requires :authority".into()))?
748        .as_str();
749    let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
750
751    let mut h3_headers = vec![
752        quiche::h3::Header::new(b":method", b"CONNECT"),
753        quiche::h3::Header::new(b":protocol", b"websocket"),
754        quiche::h3::Header::new(b":scheme", scheme.as_bytes()),
755        quiche::h3::Header::new(b":path", path.as_bytes()),
756        quiche::h3::Header::new(b":authority", authority.as_bytes()),
757    ];
758
759    for (name, value) in headers {
760        let lower = name.to_ascii_lowercase();
761        if name.starts_with(':') {
762            return Err(Error::HttpProtocol(format!(
763                "user pseudo-header {name} is not allowed on RFC 9220 CONNECT"
764            )));
765        }
766
767        if matches!(
768            lower.as_str(),
769            "connection"
770                | "upgrade"
771                | "host"
772                | "sec-websocket-key"
773                | "sec-websocket-accept"
774                | "sec-websocket-extensions"
775        ) {
776            return Err(Error::WebSocketUnsupported(format!(
777                "header {name} is not allowed on RFC 9220 WebSocket over HTTP/3"
778            )));
779        }
780
781        if matches!(
782            lower.as_str(),
783            "keep-alive" | "proxy-connection" | "transfer-encoding"
784        ) {
785            continue;
786        }
787
788        h3_headers.push(quiche::h3::Header::new(lower.as_bytes(), value.as_bytes()));
789    }
790
791    Ok(h3_headers)
792}
793
794fn build_request_headers(
795    method: &http::Method,
796    uri: &http::Uri,
797    headers: &[(String, String)],
798) -> Result<Vec<quiche::h3::Header>> {
799    let scheme = uri.scheme_str().unwrap_or("https");
800    let authority = uri
801        .authority()
802        .map(|authority| authority.as_str())
803        .or_else(|| uri.host())
804        .unwrap_or("");
805    let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
806
807    let mut h3_headers = vec![
808        quiche::h3::Header::new(b":method", method.as_str().as_bytes()),
809        quiche::h3::Header::new(b":scheme", scheme.as_bytes()),
810        quiche::h3::Header::new(b":authority", authority.as_bytes()),
811        quiche::h3::Header::new(b":path", path.as_bytes()),
812    ];
813
814    for (name, value) in headers {
815        let lower = name.to_ascii_lowercase();
816        if !name.starts_with(':')
817            && lower != "connection"
818            && lower != "keep-alive"
819            && lower != "proxy-connection"
820            && lower != "transfer-encoding"
821            && lower != "upgrade"
822        {
823            h3_headers.push(quiche::h3::Header::new(lower.as_bytes(), value.as_bytes()));
824        }
825    }
826
827    Ok(h3_headers)
828}
829
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts quiche header fields into owned `(name, value)` string pairs.
    fn header_pairs(headers: &[quiche::h3::Header]) -> Vec<(String, String)> {
        headers
            .iter()
            .map(|h| {
                (
                    String::from_utf8_lossy(h.name()).into_owned(),
                    String::from_utf8_lossy(h.value()).into_owned(),
                )
            })
            .collect()
    }

    #[test]
    fn rfc9220_headers_have_required_pseudo_headers_in_order() {
        let uri: http::Uri = "https://example.test:443/chat?room=one".parse().unwrap();
        let headers =
            build_websocket_connect_headers(&uri, &[("User-Agent".into(), "specter-test".into())])
                .unwrap();
        let pairs = header_pairs(&headers);

        assert_eq!(
            &pairs[..5],
            &[
                (":method".into(), "CONNECT".into()),
                (":protocol".into(), "websocket".into()),
                (":scheme".into(), "https".into()),
                (":path".into(), "/chat?room=one".into()),
                (":authority".into(), "example.test:443".into()),
            ]
        );
        assert!(pairs.contains(&("user-agent".into(), "specter-test".into())));
    }

    #[test]
    fn rfc9220_rejects_h1_websocket_bootstrap_headers() {
        let uri: http::Uri = "https://example.test/chat".parse().unwrap();
        for name in [
            "Connection",
            "Upgrade",
            "Host",
            "Sec-WebSocket-Key",
            "Sec-WebSocket-Accept",
            "Sec-WebSocket-Extensions",
        ] {
            let err = build_websocket_connect_headers(&uri, &[(name.into(), "x".into())])
                .expect_err("forbidden header must fail");
            let msg = err.to_string();
            assert!(msg.contains("not allowed"), "{name}: {msg}");
        }
    }

    #[test]
    fn rfc9220_rejects_user_pseudo_headers() {
        let uri: http::Uri = "https://example.test/chat".parse().unwrap();
        let err = build_websocket_connect_headers(&uri, &[(":authority".into(), "evil".into())])
            .expect_err("user pseudo headers must fail");
        assert!(err.to_string().contains("pseudo-header"));
    }

    #[test]
    fn rfc9220_requires_https_scheme() {
        let uri: http::Uri = "http://example.test/chat".parse().unwrap();
        let err = build_websocket_connect_headers(&uri, &[])
            .expect_err("plain http must be rejected");
        assert!(err.to_string().contains("wss://"));
    }

    #[test]
    fn rfc9220_drops_hop_by_hop_headers() {
        let uri: http::Uri = "https://example.test/chat".parse().unwrap();
        let headers = build_websocket_connect_headers(
            &uri,
            &[
                ("Keep-Alive".into(), "timeout=5".into()),
                ("Proxy-Connection".into(), "keep-alive".into()),
                ("Transfer-Encoding".into(), "chunked".into()),
            ],
        )
        .unwrap();
        // Only the five pseudo-header fields remain.
        assert_eq!(headers.len(), 5);
    }

    #[test]
    fn request_headers_have_pseudo_headers_first() {
        let uri: http::Uri = "https://example.test/a?b=1".parse().unwrap();
        let headers = build_request_headers(&http::Method::GET, &uri, &[]).unwrap();
        let pairs = header_pairs(&headers);

        assert_eq!(
            pairs,
            vec![
                (":method".to_string(), "GET".to_string()),
                (":scheme".to_string(), "https".to_string()),
                (":authority".to_string(), "example.test".to_string()),
                (":path".to_string(), "/a?b=1".to_string()),
            ]
        );
    }

    #[test]
    fn request_headers_strip_connection_specific_fields() {
        let uri: http::Uri = "https://example.test/".parse().unwrap();
        let headers = build_request_headers(
            &http::Method::POST,
            &uri,
            &[
                ("Connection".into(), "close".into()),
                ("Keep-Alive".into(), "timeout=5".into()),
                ("Proxy-Connection".into(), "keep-alive".into()),
                ("Transfer-Encoding".into(), "chunked".into()),
                ("Upgrade".into(), "h2c".into()),
                (":path".into(), "/evil".into()),
                ("X-Trace".into(), "1".into()),
            ],
        )
        .unwrap();
        let pairs = header_pairs(&headers);

        assert_eq!(pairs.len(), 5);
        assert_eq!(pairs[4], ("x-trace".to_string(), "1".to_string()));
        assert_eq!(pairs.iter().filter(|(name, _)| name == ":path").count(), 1);
    }
}