// sec_http3/webtransport/stream.rs

1use std::task::Poll;
2
3use crate::{
4    quic::{self},
5    stream::BufRecvStream,
6};
7use bytes::{Buf, Bytes};
8use pin_project_lite::pin_project;
9use tokio::io::ReadBuf;
10
11pin_project! {
12    /// WebTransport receive stream
13    pub struct RecvStream<S,B> {
14        #[pin]
15        stream: BufRecvStream<S, B>,
16    }
17}
18
19impl<S, B> RecvStream<S, B> {
20    #[allow(missing_docs)]
21    pub fn new(stream: BufRecvStream<S, B>) -> Self {
22        Self { stream }
23    }
24}
25
26impl<S, B> quic::RecvStream for RecvStream<S, B>
27where
28    S: quic::RecvStream,
29    B: Buf,
30{
31    type Buf = Bytes;
32
33    type Error = S::Error;
34
35    fn poll_data(
36        &mut self,
37        cx: &mut std::task::Context<'_>,
38    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
39        self.stream.poll_data(cx)
40    }
41
42    fn stop_sending(&mut self, error_code: u64) {
43        self.stream.stop_sending(error_code)
44    }
45
46    fn recv_id(&self) -> quic::StreamId {
47        self.stream.recv_id()
48    }
49}
50
51impl<S, B> futures_util::io::AsyncRead for RecvStream<S, B>
52where
53    BufRecvStream<S, B>: futures_util::io::AsyncRead,
54{
55    fn poll_read(
56        self: std::pin::Pin<&mut Self>,
57        cx: &mut std::task::Context<'_>,
58        buf: &mut [u8],
59    ) -> Poll<std::io::Result<usize>> {
60        let p = self.project();
61        p.stream.poll_read(cx, buf)
62    }
63}
64
65impl<S, B> tokio::io::AsyncRead for RecvStream<S, B>
66where
67    BufRecvStream<S, B>: tokio::io::AsyncRead,
68{
69    fn poll_read(
70        self: std::pin::Pin<&mut Self>,
71        cx: &mut std::task::Context<'_>,
72        buf: &mut ReadBuf<'_>,
73    ) -> Poll<std::io::Result<()>> {
74        let p = self.project();
75        p.stream.poll_read(cx, buf)
76    }
77}
78
79pin_project! {
80    /// WebTransport send stream
81    pub struct SendStream<S,B> {
82        #[pin]
83        stream: BufRecvStream<S ,B>,
84    }
85}
86
87impl<S, B> std::fmt::Debug for SendStream<S, B> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("SendStream")
90            .field("stream", &self.stream)
91            .finish()
92    }
93}
94
95impl<S, B> SendStream<S, B> {
96    #[allow(missing_docs)]
97    pub(crate) fn new(stream: BufRecvStream<S, B>) -> Self {
98        Self { stream }
99    }
100}
101
102impl<S, B> quic::SendStreamUnframed<B> for SendStream<S, B>
103where
104    S: quic::SendStreamUnframed<B>,
105    B: Buf,
106{
107    fn poll_send<D: Buf>(
108        &mut self,
109        cx: &mut std::task::Context<'_>,
110        buf: &mut D,
111    ) -> Poll<Result<usize, Self::Error>> {
112        self.stream.poll_send(cx, buf)
113    }
114}
115
116impl<S, B> quic::SendStream<B> for SendStream<S, B>
117where
118    S: quic::SendStream<B>,
119    B: Buf,
120{
121    type Error = S::Error;
122
123    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
124        self.stream.poll_finish(cx)
125    }
126
127    fn reset(&mut self, reset_code: u64) {
128        self.stream.reset(reset_code)
129    }
130
131    fn send_id(&self) -> quic::StreamId {
132        self.stream.send_id()
133    }
134
135    fn send_data<T: Into<crate::stream::WriteBuf<B>>>(
136        &mut self,
137        data: T,
138    ) -> Result<(), Self::Error> {
139        self.stream.send_data(data)
140    }
141
142    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
143        self.stream.poll_ready(cx)
144    }
145}
146
147impl<S, B> futures_util::io::AsyncWrite for SendStream<S, B>
148where
149    BufRecvStream<S, B>: futures_util::io::AsyncWrite,
150{
151    fn poll_write(
152        self: std::pin::Pin<&mut Self>,
153        cx: &mut std::task::Context<'_>,
154        buf: &[u8],
155    ) -> Poll<std::io::Result<usize>> {
156        let p = self.project();
157        p.stream.poll_write(cx, buf)
158    }
159
160    fn poll_flush(
161        self: std::pin::Pin<&mut Self>,
162        cx: &mut std::task::Context<'_>,
163    ) -> Poll<std::io::Result<()>> {
164        let p = self.project();
165        p.stream.poll_flush(cx)
166    }
167
168    fn poll_close(
169        self: std::pin::Pin<&mut Self>,
170        cx: &mut std::task::Context<'_>,
171    ) -> Poll<std::io::Result<()>> {
172        let p = self.project();
173        p.stream.poll_close(cx)
174    }
175}
176
177impl<S, B> tokio::io::AsyncWrite for SendStream<S, B>
178where
179    BufRecvStream<S, B>: tokio::io::AsyncWrite,
180{
181    fn poll_write(
182        self: std::pin::Pin<&mut Self>,
183        cx: &mut std::task::Context<'_>,
184        buf: &[u8],
185    ) -> Poll<std::io::Result<usize>> {
186        let p = self.project();
187        p.stream.poll_write(cx, buf)
188    }
189
190    fn poll_flush(
191        self: std::pin::Pin<&mut Self>,
192        cx: &mut std::task::Context<'_>,
193    ) -> Poll<std::io::Result<()>> {
194        let p = self.project();
195        p.stream.poll_flush(cx)
196    }
197
198    fn poll_shutdown(
199        self: std::pin::Pin<&mut Self>,
200        cx: &mut std::task::Context<'_>,
201    ) -> Poll<std::io::Result<()>> {
202        let p = self.project();
203        p.stream.poll_shutdown(cx)
204    }
205}
206
207pin_project! {
208    /// Combined send and receive stream.
209    ///
210    /// Can be split into a [`RecvStream`] and [`SendStream`] if the underlying QUIC implementation
211    /// supports it.
212    pub struct BidiStream<S, B> {
213        #[pin]
214        stream: BufRecvStream<S, B>,
215    }
216}
217
218impl<S, B> BidiStream<S, B> {
219    pub(crate) fn new(stream: BufRecvStream<S, B>) -> Self {
220        Self { stream }
221    }
222}
223
224impl<S, B> quic::SendStream<B> for BidiStream<S, B>
225where
226    S: quic::SendStream<B>,
227    B: Buf,
228{
229    type Error = S::Error;
230
231    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
232        self.stream.poll_finish(cx)
233    }
234
235    fn reset(&mut self, reset_code: u64) {
236        self.stream.reset(reset_code)
237    }
238
239    fn send_id(&self) -> quic::StreamId {
240        self.stream.send_id()
241    }
242
243    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
244        self.stream.poll_ready(cx)
245    }
246
247    fn send_data<T: Into<crate::stream::WriteBuf<B>>>(
248        &mut self,
249        data: T,
250    ) -> Result<(), Self::Error> {
251        self.stream.send_data(data)
252    }
253}
254
255impl<S, B> quic::SendStreamUnframed<B> for BidiStream<S, B>
256where
257    S: quic::SendStreamUnframed<B>,
258    B: Buf,
259{
260    fn poll_send<D: Buf>(
261        &mut self,
262        cx: &mut std::task::Context<'_>,
263        buf: &mut D,
264    ) -> Poll<Result<usize, Self::Error>> {
265        self.stream.poll_send(cx, buf)
266    }
267}
268
269impl<S: quic::RecvStream, B> quic::RecvStream for BidiStream<S, B> {
270    type Buf = Bytes;
271
272    type Error = S::Error;
273
274    fn poll_data(
275        &mut self,
276        cx: &mut std::task::Context<'_>,
277    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
278        self.stream.poll_data(cx)
279    }
280
281    fn stop_sending(&mut self, error_code: u64) {
282        self.stream.stop_sending(error_code)
283    }
284
285    fn recv_id(&self) -> quic::StreamId {
286        self.stream.recv_id()
287    }
288}
289
290impl<S, B> quic::BidiStream<B> for BidiStream<S, B>
291where
292    S: quic::BidiStream<B>,
293    B: Buf,
294{
295    type SendStream = SendStream<S::SendStream, B>;
296
297    type RecvStream = RecvStream<S::RecvStream, B>;
298
299    fn split(self) -> (Self::SendStream, Self::RecvStream) {
300        let (send, recv) = self.stream.split();
301        (SendStream::new(send), RecvStream::new(recv))
302    }
303}
304
305impl<S, B> futures_util::io::AsyncRead for BidiStream<S, B>
306where
307    BufRecvStream<S, B>: futures_util::io::AsyncRead,
308{
309    fn poll_read(
310        self: std::pin::Pin<&mut Self>,
311        cx: &mut std::task::Context<'_>,
312        buf: &mut [u8],
313    ) -> Poll<std::io::Result<usize>> {
314        let p = self.project();
315        p.stream.poll_read(cx, buf)
316    }
317}
318
319impl<S, B> futures_util::io::AsyncWrite for BidiStream<S, B>
320where
321    BufRecvStream<S, B>: futures_util::io::AsyncWrite,
322{
323    fn poll_write(
324        self: std::pin::Pin<&mut Self>,
325        cx: &mut std::task::Context<'_>,
326        buf: &[u8],
327    ) -> Poll<std::io::Result<usize>> {
328        let p = self.project();
329        p.stream.poll_write(cx, buf)
330    }
331
332    fn poll_flush(
333        self: std::pin::Pin<&mut Self>,
334        cx: &mut std::task::Context<'_>,
335    ) -> Poll<std::io::Result<()>> {
336        let p = self.project();
337        p.stream.poll_flush(cx)
338    }
339
340    fn poll_close(
341        self: std::pin::Pin<&mut Self>,
342        cx: &mut std::task::Context<'_>,
343    ) -> Poll<std::io::Result<()>> {
344        let p = self.project();
345        p.stream.poll_close(cx)
346    }
347}
348
349impl<S, B> tokio::io::AsyncRead for BidiStream<S, B>
350where
351    BufRecvStream<S, B>: tokio::io::AsyncRead,
352{
353    fn poll_read(
354        self: std::pin::Pin<&mut Self>,
355        cx: &mut std::task::Context<'_>,
356        buf: &mut ReadBuf<'_>,
357    ) -> Poll<std::io::Result<()>> {
358        let p = self.project();
359        p.stream.poll_read(cx, buf)
360    }
361}
362
363impl<S, B> tokio::io::AsyncWrite for BidiStream<S, B>
364where
365    BufRecvStream<S, B>: tokio::io::AsyncWrite,
366{
367    fn poll_write(
368        self: std::pin::Pin<&mut Self>,
369        cx: &mut std::task::Context<'_>,
370        buf: &[u8],
371    ) -> Poll<std::io::Result<usize>> {
372        let p = self.project();
373        p.stream.poll_write(cx, buf)
374    }
375
376    fn poll_flush(
377        self: std::pin::Pin<&mut Self>,
378        cx: &mut std::task::Context<'_>,
379    ) -> Poll<std::io::Result<()>> {
380        let p = self.project();
381        p.stream.poll_flush(cx)
382    }
383
384    fn poll_shutdown(
385        self: std::pin::Pin<&mut Self>,
386        cx: &mut std::task::Context<'_>,
387    ) -> Poll<std::io::Result<()>> {
388        let p = self.project();
389        p.stream.poll_shutdown(cx)
390    }
391}