1use std::fmt;
25use std::io;
26use std::marker::PhantomData;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29
30#[cfg(any(feature = "http2", feature = "http3"))]
31use bytes::{Buf, Bytes, BytesMut};
32use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
33
34use crate::transport::{Http1, Transport};
35
36#[cfg(feature = "http2")]
37use crate::transport::Http2;
38
39#[cfg(feature = "http3")]
40use crate::transport::Http3;
41
42pub struct Stream<T: Transport> {
57 inner: StreamInner,
58 _marker: PhantomData<T>,
59}
60
61#[allow(clippy::large_enum_variant)]
64enum StreamInner {
65 Http1(Box<dyn Http1Stream>),
67
68 #[cfg(feature = "http2")]
70 Http2(Http2StreamInner),
71
72 #[cfg(feature = "http3")]
74 Http3(Http3StreamInner),
75}
76
77trait Http1Stream: AsyncRead + AsyncWrite + Unpin + Send {}
83impl<T: AsyncRead + AsyncWrite + Unpin + Send> Http1Stream for T {}
84
85impl Stream<Http1> {
86 pub fn new<S>(inner: S) -> Self
98 where
99 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
100 {
101 Self {
102 inner: StreamInner::Http1(Box::new(inner)),
103 _marker: PhantomData,
104 }
105 }
106}
107
108impl AsyncRead for Stream<Http1> {
109 fn poll_read(
110 mut self: Pin<&mut Self>,
111 cx: &mut Context<'_>,
112 buf: &mut ReadBuf<'_>,
113 ) -> Poll<io::Result<()>> {
114 match &mut self.inner {
115 StreamInner::Http1(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf),
116 #[cfg(any(feature = "http2", feature = "http3"))]
117 _ => unreachable!(),
118 }
119 }
120}
121
122impl AsyncWrite for Stream<Http1> {
123 fn poll_write(
124 mut self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 buf: &[u8],
127 ) -> Poll<io::Result<usize>> {
128 match &mut self.inner {
129 StreamInner::Http1(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf),
130 #[cfg(any(feature = "http2", feature = "http3"))]
131 _ => unreachable!(),
132 }
133 }
134
135 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136 match &mut self.inner {
137 StreamInner::Http1(stream) => Pin::new(stream.as_mut()).poll_flush(cx),
138 #[cfg(any(feature = "http2", feature = "http3"))]
139 _ => unreachable!(),
140 }
141 }
142
143 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
144 match &mut self.inner {
145 StreamInner::Http1(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx),
146 #[cfg(any(feature = "http2", feature = "http3"))]
147 _ => unreachable!(),
148 }
149 }
150}
151
152impl fmt::Debug for Stream<Http1> {
153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154 f.debug_struct("Stream<Http1>").finish()
155 }
156}
157
/// State of an HTTP/2 stream: the `h2` send/receive halves plus read/write bookkeeping.
#[cfg(feature = "http2")]
struct Http2StreamInner {
    /// Outgoing half; DATA frames are written through it.
    send: h2::SendStream<Bytes>,
    /// Incoming half; DATA frames are read from it.
    recv: h2::RecvStream,
    /// Bytes of a received DATA frame that did not fit into the caller's read buffer.
    recv_buf: BytesMut,
    /// Set once `poll_data` has reported the end of the incoming stream.
    recv_eof: bool,
    /// Bytes requested via `reserve_capacity` for a write still waiting on capacity
    /// (0 once that write has been sent).
    capacity_needed: usize,
}
170
#[cfg(feature = "http2")]
impl Stream<Http2> {
    /// Builds an HTTP/2 stream from the send and receive halves of an `h2` stream.
    ///
    /// A 64 KiB buffer is preallocated for received bytes that do not fit into a read.
    pub fn from_h2(send: h2::SendStream<Bytes>, recv: h2::RecvStream) -> Self {
        let state = Http2StreamInner {
            send,
            recv,
            recv_buf: BytesMut::with_capacity(64 * 1024),
            recv_eof: false,
            capacity_needed: 0,
        };
        Self {
            inner: StreamInner::Http2(state),
            _marker: PhantomData,
        }
    }

    /// Borrows the underlying `h2` send stream, or `None` if this is not an HTTP/2 stream.
    pub fn send_stream(&self) -> Option<&h2::SendStream<Bytes>> {
        if let StreamInner::Http2(state) = &self.inner {
            Some(&state.send)
        } else {
            None
        }
    }

    /// Mutably borrows the underlying `h2` send stream, or `None` if this is not an
    /// HTTP/2 stream.
    pub fn send_stream_mut(&mut self) -> Option<&mut h2::SendStream<Bytes>> {
        if let StreamInner::Http2(state) = &mut self.inner {
            Some(&mut state.send)
        } else {
            None
        }
    }
}
214
#[cfg(feature = "http2")]
impl AsyncRead for Stream<Http2> {
    /// Reads DATA-frame payload from the HTTP/2 stream.
    ///
    /// Bytes left over from an earlier frame are served first. Otherwise the next non-empty
    /// DATA frame is taken from `h2`, its flow-control window is released, and as much as
    /// fits is copied into `buf`; the rest is kept in `recv_buf` for the next call.
    /// Returning `Ok(())` without filling `buf` signals end of stream.
    ///
    /// # Errors
    /// Stream errors reported by `h2` are returned as `io::Error`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let inner = match &mut self.inner {
            StreamInner::Http2(inner) => inner,
            _ => unreachable!(),
        };

        // Serve bytes buffered from a frame that did not fit last time.
        if !inner.recv_buf.is_empty() {
            let to_copy = std::cmp::min(buf.remaining(), inner.recv_buf.len());
            buf.put_slice(&inner.recv_buf.split_to(to_copy));
            return Poll::Ready(Ok(()));
        }

        if inner.recv_eof {
            return Poll::Ready(Ok(()));
        }

        loop {
            match Pin::new(&mut inner.recv).poll_data(cx) {
                Poll::Ready(Some(Ok(mut data))) => {
                    if !data.has_remaining() {
                        // A zero-length DATA frame carries no payload. Returning Ok here
                        // without filling `buf` would look like EOF to the reader, so keep
                        // polling for the next frame instead.
                        continue;
                    }

                    // Hand the window back to the peer as soon as the frame is taken off
                    // the stream. Best effort: a failure here is deliberately ignored.
                    let _ = inner.recv.flow_control().release_capacity(data.len());

                    let to_copy = std::cmp::min(buf.remaining(), data.len());
                    buf.put_slice(&data.split_to(to_copy));

                    // Keep whatever did not fit for the next read.
                    if data.has_remaining() {
                        inner.recv_buf.extend_from_slice(data.chunk());
                    }

                    return Poll::Ready(Ok(()));
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(io::Error::other(e))),
                Poll::Ready(None) => {
                    inner.recv_eof = true;
                    return Poll::Ready(Ok(()));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
266
#[cfg(feature = "http2")]
impl AsyncWrite for Stream<Http2> {
    /// Sends as much of `buf` as the stream's HTTP/2 send capacity currently allows.
    ///
    /// Capacity for `buf.len()` bytes is reserved with `h2`. If any capacity is assigned,
    /// up to that many bytes are sent as a DATA frame (without END_STREAM) and the count is
    /// returned. Otherwise the task waits for capacity, so a non-empty write never
    /// reports `Ok(0)`.
    ///
    /// # Errors
    /// - `h2` errors from `send_data` / `poll_capacity` are returned as `io::Error`.
    /// - `BrokenPipe` is returned if the stream can no longer send.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let inner = match &mut self.inner {
            StreamInner::Http2(inner) => inner,
            _ => unreachable!(),
        };

        // `reserve_capacity` sets the total requested capacity (it is not additive), so
        // it only needs re-issuing when the amount this write waits for changes. After a
        // successful send `capacity_needed` is reset to 0, so the next write reserves again.
        if inner.capacity_needed != buf.len() {
            inner.send.reserve_capacity(buf.len());
            inner.capacity_needed = buf.len();
        }

        loop {
            let available = inner.send.capacity();
            if available > 0 {
                let to_send = std::cmp::min(available, buf.len());
                let data = Bytes::copy_from_slice(&buf[..to_send]);

                inner
                    .send
                    .send_data(data, false)
                    .map_err(io::Error::other)?;

                inner.capacity_needed = 0;
                return Poll::Ready(Ok(to_send));
            }

            // No capacity assigned yet. `poll_capacity` only reports *increases* and may
            // report `Some(Ok(0))`, so its value is not trusted; loop back and re-check
            // `capacity()` instead. A Pending result registers the waker.
            match inner.send.poll_capacity(cx) {
                Poll::Ready(Some(Ok(_))) => continue,
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(io::Error::other(e))),
                Poll::Ready(None) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::BrokenPipe,
                        "HTTP/2 stream closed",
                    )))
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }

    /// Always succeeds immediately; this layer keeps no write buffer of its own.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Ends the send side by sending an empty DATA frame with END_STREAM set.
    ///
    /// # Errors
    /// Any `h2` error from `send_data` is returned as `io::Error`.
    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = match &mut self.inner {
            StreamInner::Http2(inner) => inner,
            _ => unreachable!(),
        };

        inner
            .send
            .send_data(Bytes::new(), true)
            .map_err(io::Error::other)?;
        Poll::Ready(Ok(()))
    }
}
334
#[cfg(feature = "http2")]
impl fmt::Debug for Stream<Http2> {
    /// Shows the read-side bookkeeping: buffered byte count and whether EOF was seen.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StreamInner::Http2(state) = &self.inner else {
            unreachable!()
        };

        let buffered = state.recv_buf.len();
        f.debug_struct("Stream<Http2>")
            .field("recv_buf_len", &buffered)
            .field("recv_eof", &state.recv_eof)
            .finish()
    }
}
348
/// State of an HTTP/3-family stream.
#[cfg(feature = "http3")]
enum Http3StreamInner {
    /// A bare QUIC bidirectional stream, read and written directly without h3 framing.
    Raw {
        send: quinn::SendStream,
        recv: quinn::RecvStream,
        /// Received bytes waiting to be handed to a reader.
        recv_buf: BytesMut,
        /// Set once a read failed because the peer reset the receive side.
        recv_finished: bool,
    },
    /// The server side of an h3 request stream.
    Server {
        stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
        /// Bytes of a received data chunk that did not fit into the caller's buffer.
        read_buf: BytesMut,
    },
    /// The client side of an h3 request stream.
    Client {
        stream: h3::client::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
        /// Bytes of a received data chunk that did not fit into the caller's buffer.
        read_buf: BytesMut,
    },
}
373
#[cfg(feature = "http3")]
impl Stream<Http3> {
    /// Wraps the two halves of a raw QUIC bidirectional stream (no h3 framing).
    pub fn from_quic(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
        Self::from_http3_inner(Http3StreamInner::Raw {
            send,
            recv,
            recv_buf: BytesMut::with_capacity(64 * 1024),
            recv_finished: false,
        })
    }

    /// Same as [`Stream::from_quic`], taking the `(send, recv)` pair as one tuple.
    pub fn from_quic_bi(stream: (quinn::SendStream, quinn::RecvStream)) -> Self {
        let (send, recv) = stream;
        Self::from_quic(send, recv)
    }

    /// Wraps the server side of an h3 request stream.
    pub fn from_h3_server(
        stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
    ) -> Self {
        Self::from_http3_inner(Http3StreamInner::Server {
            stream,
            read_buf: BytesMut::with_capacity(64 * 1024),
        })
    }

    /// Wraps the client side of an h3 request stream.
    pub fn from_h3_client(
        stream: h3::client::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
    ) -> Self {
        Self::from_http3_inner(Http3StreamInner::Client {
            stream,
            read_buf: BytesMut::with_capacity(64 * 1024),
        })
    }

    /// Returns the QUIC stream id for raw QUIC streams; `None` for h3 request streams.
    pub fn stream_id(&self) -> Option<quinn::StreamId> {
        if let StreamInner::Http3(Http3StreamInner::Raw { send, .. }) = &self.inner {
            Some(send.id())
        } else {
            None
        }
    }

    /// Shared constructor tail: wraps already-built HTTP/3 state in a `Stream`.
    fn from_http3_inner(state: Http3StreamInner) -> Self {
        Self {
            inner: StreamInner::Http3(state),
            _marker: PhantomData,
        }
    }
}
434
/// Moves as many bytes from `spill` into `buf` as `buf` has room for.
#[cfg(feature = "http3")]
fn h3_drain_spill(spill: &mut BytesMut, buf: &mut ReadBuf<'_>) {
    let n = std::cmp::min(buf.remaining(), spill.len());
    buf.put_slice(&spill[..n]);
    spill.advance(n);
}

/// Copies one received chunk into `buf`; bytes that do not fit are appended to `spill`
/// so a later `poll_read` can return them.
#[cfg(feature = "http3")]
fn h3_copy_chunk<B: Buf>(mut data: B, buf: &mut ReadBuf<'_>, spill: &mut BytesMut) {
    while data.has_remaining() {
        let chunk = data.chunk();
        let consumed = if buf.remaining() > 0 {
            let n = std::cmp::min(buf.remaining(), chunk.len());
            buf.put_slice(&chunk[..n]);
            n
        } else {
            spill.extend_from_slice(chunk);
            chunk.len()
        };
        data.advance(consumed);
    }
}

#[cfg(feature = "http3")]
impl AsyncRead for Stream<Http3> {
    /// Reads from a raw QUIC stream or from the data frames of an h3 request stream.
    ///
    /// Buffered leftovers are served first. For h3 streams, empty data chunks are skipped:
    /// returning `Ok(())` without filling `buf` would signal EOF, so only `Ok(None)` from
    /// `recv_data` (or a finished raw stream) ends the stream.
    ///
    /// # Errors
    /// QUIC read errors and h3 stream errors are returned as `io::Error`. After a peer
    /// reset on a raw stream, later reads report EOF.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match &mut self.inner {
            StreamInner::Http3(Http3StreamInner::Raw {
                recv,
                recv_buf,
                recv_finished,
                ..
            }) => {
                if !recv_buf.is_empty() {
                    h3_drain_spill(recv_buf, buf);
                    return Poll::Ready(Ok(()));
                }

                if *recv_finished {
                    return Poll::Ready(Ok(()));
                }

                match recv.poll_read_buf(cx, buf) {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(e)) => {
                        if matches!(e, quinn::ReadError::Reset(_)) {
                            *recv_finished = true;
                        }
                        Poll::Ready(Err(io::Error::other(e)))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            StreamInner::Http3(Http3StreamInner::Server { stream, read_buf }) => {
                if !read_buf.is_empty() {
                    h3_drain_spill(read_buf, buf);
                    return Poll::Ready(Ok(()));
                }

                loop {
                    // Pinned on the stack: no heap allocation per poll.
                    let recv = std::pin::pin!(stream.recv_data());
                    match std::future::Future::poll(recv, cx) {
                        Poll::Ready(Ok(Some(data))) => {
                            if !data.has_remaining() {
                                // Empty chunk: not EOF, fetch the next one.
                                continue;
                            }
                            h3_copy_chunk(data, buf, read_buf);
                            return Poll::Ready(Ok(()));
                        }
                        Poll::Ready(Ok(None)) => return Poll::Ready(Ok(())),
                        Poll::Ready(Err(e)) => {
                            return Poll::Ready(Err(io::Error::other(e.to_string())))
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
            }
            StreamInner::Http3(Http3StreamInner::Client { stream, read_buf }) => {
                if !read_buf.is_empty() {
                    h3_drain_spill(read_buf, buf);
                    return Poll::Ready(Ok(()));
                }

                loop {
                    let recv = std::pin::pin!(stream.recv_data());
                    match std::future::Future::poll(recv, cx) {
                        Poll::Ready(Ok(Some(data))) => {
                            if !data.has_remaining() {
                                continue;
                            }
                            h3_copy_chunk(data, buf, read_buf);
                            return Poll::Ready(Ok(()));
                        }
                        Poll::Ready(Ok(None)) => return Poll::Ready(Ok(())),
                        Poll::Ready(Err(e)) => {
                            return Poll::Ready(Err(io::Error::other(e.to_string())))
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
            }
            _ => unreachable!(),
        }
    }
}
540
#[cfg(feature = "http3")]
impl AsyncWrite for Stream<Http3> {
    /// Writes `buf` to the underlying QUIC or h3 stream.
    ///
    /// - `Raw`: delegates to the QUIC send stream's `poll_write` and returns however many
    ///   bytes it accepted.
    /// - `Server` / `Client`: copies all of `buf` into one `Bytes`, polls the h3 `send_data`
    ///   future once, and on success reports the whole of `buf` as written.
    ///
    /// NOTE(review): the `send_data` future is created, polled once and dropped on every
    /// call. If it returns `Pending`, the caller retries with the same bytes and a brand-new
    /// future. Whether the first attempt had already queued the data (risking duplicated
    /// bytes or a "not ready" error on retry) depends on h3/h3-quinn internals. Confirm
    /// against the versions in use; a persistent in-flight write state is likely needed.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Zero-length writes are reported as done without touching the stream.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        match &mut self.inner {
            StreamInner::Http3(Http3StreamInner::Raw { send, .. }) => {
                match Pin::new(send).poll_write(cx, buf) {
                    Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
                    Poll::Pending => Poll::Pending,
                }
            }
            StreamInner::Http3(Http3StreamInner::Server { stream, .. }) => {
                let data = Bytes::copy_from_slice(buf);
                // Fresh future per call; see the NOTE(review) on this method.
                let fut = stream.send_data(data);
                tokio::pin!(fut);

                match fut.poll(cx) {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(buf.len())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e.to_string()))),
                    Poll::Pending => Poll::Pending,
                }
            }
            StreamInner::Http3(Http3StreamInner::Client { stream, .. }) => {
                let data = Bytes::copy_from_slice(buf);
                // Fresh future per call; see the NOTE(review) on this method.
                let fut = stream.send_data(data);
                tokio::pin!(fut);

                match fut.poll(cx) {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(buf.len())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e.to_string()))),
                    Poll::Pending => Poll::Pending,
                }
            }
            // A `Stream<Http3>` is only ever built with the HTTP/3 variant.
            _ => unreachable!(),
        }
    }

    /// Always reports success immediately; nothing is flushed at this layer.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the send side of the stream.
    ///
    /// `Raw` calls the QUIC stream's synchronous `finish()`. `Server` / `Client` poll the h3
    /// `finish()` future once per call. NOTE(review): the same recreate-per-poll concern
    /// as `poll_write` applies if that future returns `Pending`.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.inner {
            StreamInner::Http3(Http3StreamInner::Raw { send, .. }) => match send.finish() {
                Ok(()) => Poll::Ready(Ok(())),
                Err(e) => Poll::Ready(Err(io::Error::other(e))),
            },
            StreamInner::Http3(Http3StreamInner::Server { stream, .. }) => {
                let fut = stream.finish();
                tokio::pin!(fut);

                match fut.poll(cx) {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e.to_string()))),
                    Poll::Pending => Poll::Pending,
                }
            }
            StreamInner::Http3(Http3StreamInner::Client { stream, .. }) => {
                let fut = stream.finish();
                tokio::pin!(fut);

                match fut.poll(cx) {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e.to_string()))),
                    Poll::Pending => Poll::Pending,
                }
            }
            _ => unreachable!(),
        }
    }
}
621
#[cfg(feature = "http3")]
impl fmt::Debug for Stream<Http3> {
    /// Shows which HTTP/3 flavour backs the stream plus its read-side buffer state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StreamInner::Http3(state) = &self.inner else {
            unreachable!()
        };

        let mut out = f.debug_struct("Stream<Http3>");
        match state {
            Http3StreamInner::Raw {
                recv_buf,
                recv_finished,
                ..
            } => out
                .field("variant", &"Raw")
                .field("recv_buf_len", &recv_buf.len())
                .field("recv_finished", recv_finished),
            Http3StreamInner::Server { read_buf, .. } => out
                .field("variant", &"Server")
                .field("read_buf_len", &read_buf.len()),
            Http3StreamInner::Client { read_buf, .. } => out
                .field("variant", &"Client")
                .field("read_buf_len", &read_buf.len()),
        };
        out.finish()
    }
}
650
651unsafe impl Send for Stream<Http1> {}
657#[cfg(feature = "http2")]
658unsafe impl Send for Stream<Http2> {}
659#[cfg(feature = "http3")]
660unsafe impl Send for Stream<Http3> {}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Stream` flavour must be `Send` so it can be moved into spawned tasks.
    ///
    /// NOTE: an explicit `unsafe impl Send` for these types would make this check pass
    /// unconditionally; it is only meaningful when `Send` is derived automatically.
    #[test]
    fn test_stream_is_send() {
        fn require_send<S: Send>() {}

        require_send::<Stream<Http1>>();
        #[cfg(feature = "http2")]
        require_send::<Stream<Http2>>();
        #[cfg(feature = "http3")]
        require_send::<Stream<Http3>>();
    }
}