1#![deny(missing_docs)]
5
6use std::{
7 convert::TryInto,
8 fmt::{self, Display},
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{self, Poll},
13};
14
15use bytes::{Buf, Bytes, BytesMut};
16
17use futures::{
18 ready,
19 stream::{self, BoxStream},
20 StreamExt,
21};
22use quinn::ReadDatagram;
23pub use quinn::{
24 self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError,
25};
26
27use crate::{
28 ext::Datagram,
29 quic::{self, Error, StreamId, WriteBuf},
30};
31use tokio_util::sync::ReusableBoxFuture;
32
33pub struct Connection {
37 conn: quinn::Connection,
38 incoming_bi: BoxStream<'static, <AcceptBi<'static> as Future>::Output>,
39 opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>,
40 incoming_uni: BoxStream<'static, <AcceptUni<'static> as Future>::Output>,
41 opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>,
42 datagrams: BoxStream<'static, <ReadDatagram<'static> as Future>::Output>,
43}
44
45impl Connection {
46 pub fn new(conn: quinn::Connection) -> Self {
48 Self {
49 conn: conn.clone(),
50 incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
51 Some((conn.accept_bi().await, conn))
52 })),
53 opening_bi: None,
54 incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
55 Some((conn.accept_uni().await, conn))
56 })),
57 opening_uni: None,
58 datagrams: Box::pin(stream::unfold(conn, |conn| async {
59 Some((conn.read_datagram().await, conn))
60 })),
61 }
62 }
63}
64
65#[derive(Debug)]
69pub struct ConnectionError(quinn::ConnectionError);
70
71impl std::error::Error for ConnectionError {}
72
73impl fmt::Display for ConnectionError {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 self.0.fmt(f)
76 }
77}
78
79impl Error for ConnectionError {
80 fn is_timeout(&self) -> bool {
81 matches!(self.0, quinn::ConnectionError::TimedOut)
82 }
83
84 fn err_code(&self) -> Option<u64> {
85 match self.0 {
86 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
87 error_code,
88 ..
89 }) => Some(error_code.into_inner()),
90 _ => None,
91 }
92 }
93}
94
95impl From<quinn::ConnectionError> for ConnectionError {
96 fn from(e: quinn::ConnectionError) -> Self {
97 Self(e)
98 }
99}
100
101#[derive(Debug)]
103pub enum SendDatagramError {
104 UnsupportedByPeer,
106 Disabled,
108 TooLarge,
110 ConnectionLost(Box<dyn Error>),
112}
113
114impl fmt::Display for SendDatagramError {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 match self {
117 SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"),
118 SendDatagramError::Disabled => write!(f, "datagram support disabled"),
119 SendDatagramError::TooLarge => write!(f, "datagram too large"),
120 SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"),
121 }
122 }
123}
124
125impl std::error::Error for SendDatagramError {}
126
127impl Error for SendDatagramError {
128 fn is_timeout(&self) -> bool {
129 false
130 }
131
132 fn err_code(&self) -> Option<u64> {
133 match self {
134 Self::ConnectionLost(err) => err.err_code(),
135 _ => None,
136 }
137 }
138}
139
140impl From<quinn::SendDatagramError> for SendDatagramError {
141 fn from(value: quinn::SendDatagramError) -> Self {
142 match value {
143 quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer,
144 quinn::SendDatagramError::Disabled => Self::Disabled,
145 quinn::SendDatagramError::TooLarge => Self::TooLarge,
146 quinn::SendDatagramError::ConnectionLost(err) => {
147 Self::ConnectionLost(ConnectionError::from(err).into())
148 }
149 }
150 }
151}
152
153impl<B> quic::Connection<B> for Connection
154where
155 B: Buf,
156{
157 type SendStream = SendStream<B>;
158 type RecvStream = RecvStream;
159 type BidiStream = BidiStream<B>;
160 type OpenStreams = OpenStreams;
161 type Error = ConnectionError;
162
163 fn poll_accept_bidi(
164 &mut self,
165 cx: &mut task::Context<'_>,
166 ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> {
167 let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) {
168 Some(x) => x?,
169 None => return Poll::Ready(Ok(None)),
170 };
171 Poll::Ready(Ok(Some(Self::BidiStream {
172 send: Self::SendStream::new(send),
173 recv: Self::RecvStream::new(recv),
174 })))
175 }
176
177 fn poll_accept_recv(
178 &mut self,
179 cx: &mut task::Context<'_>,
180 ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> {
181 let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
182 Some(x) => x?,
183 None => return Poll::Ready(Ok(None)),
184 };
185 Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
186 }
187
188 fn poll_open_bidi(
189 &mut self,
190 cx: &mut task::Context<'_>,
191 ) -> Poll<Result<Self::BidiStream, Self::Error>> {
192 if self.opening_bi.is_none() {
193 self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
194 Some((conn.clone().open_bi().await, conn))
195 })));
196 }
197
198 let (send, recv) =
199 ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
200 Poll::Ready(Ok(Self::BidiStream {
201 send: Self::SendStream::new(send),
202 recv: Self::RecvStream::new(recv),
203 }))
204 }
205
206 fn poll_open_send(
207 &mut self,
208 cx: &mut task::Context<'_>,
209 ) -> Poll<Result<Self::SendStream, Self::Error>> {
210 if self.opening_uni.is_none() {
211 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
212 Some((conn.open_uni().await, conn))
213 })));
214 }
215
216 let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
217 Poll::Ready(Ok(Self::SendStream::new(send)))
218 }
219
220 fn opener(&self) -> Self::OpenStreams {
221 OpenStreams {
222 conn: self.conn.clone(),
223 opening_bi: None,
224 opening_uni: None,
225 }
226 }
227
228 fn close(&mut self, code: crate::error::Code, reason: &[u8]) {
229 self.conn.close(
230 VarInt::from_u64(code.value()).expect("error code VarInt"),
231 reason,
232 );
233 }
234}
235
236impl<B> quic::SendDatagramExt<B> for Connection
237where
238 B: Buf,
239{
240 type Error = SendDatagramError;
241
242 fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> {
243 let mut buf = BytesMut::new();
245 data.encode(&mut buf);
246 self.conn.send_datagram(buf.freeze())?;
247
248 Ok(())
249 }
250}
251
252impl quic::RecvDatagramExt for Connection {
253 type Buf = Bytes;
254
255 type Error = ConnectionError;
256
257 #[inline]
258 fn poll_accept_datagram(
259 &mut self,
260 cx: &mut task::Context<'_>,
261 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
262 match ready!(self.datagrams.poll_next_unpin(cx)) {
263 Some(Ok(x)) => Poll::Ready(Ok(Some(x))),
264 Some(Err(e)) => Poll::Ready(Err(e.into())),
265 None => Poll::Ready(Ok(None)),
266 }
267 }
268}
269
270pub struct OpenStreams {
275 conn: quinn::Connection,
276 opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>,
277 opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>,
278}
279
280impl<B> quic::OpenStreams<B> for OpenStreams
281where
282 B: Buf,
283{
284 type RecvStream = RecvStream;
285 type SendStream = SendStream<B>;
286 type BidiStream = BidiStream<B>;
287 type Error = ConnectionError;
288
289 fn poll_open_bidi(
290 &mut self,
291 cx: &mut task::Context<'_>,
292 ) -> Poll<Result<Self::BidiStream, Self::Error>> {
293 if self.opening_bi.is_none() {
294 self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
295 Some((conn.open_bi().await, conn))
296 })));
297 }
298
299 let (send, recv) =
300 ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
301 Poll::Ready(Ok(Self::BidiStream {
302 send: Self::SendStream::new(send),
303 recv: Self::RecvStream::new(recv),
304 }))
305 }
306
307 fn poll_open_send(
308 &mut self,
309 cx: &mut task::Context<'_>,
310 ) -> Poll<Result<Self::SendStream, Self::Error>> {
311 if self.opening_uni.is_none() {
312 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
313 Some((conn.open_uni().await, conn))
314 })));
315 }
316
317 let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
318 Poll::Ready(Ok(Self::SendStream::new(send)))
319 }
320
321 fn close(&mut self, code: crate::error::Code, reason: &[u8]) {
322 self.conn.close(
323 VarInt::from_u64(code.value()).expect("error code VarInt"),
324 reason,
325 );
326 }
327}
328
329impl Clone for OpenStreams {
330 fn clone(&self) -> Self {
331 Self {
332 conn: self.conn.clone(),
333 opening_bi: None,
334 opening_uni: None,
335 }
336 }
337}
338
339pub struct BidiStream<B>
344where
345 B: Buf,
346{
347 send: SendStream<B>,
348 recv: RecvStream,
349}
350
351impl<B> quic::BidiStream<B> for BidiStream<B>
352where
353 B: Buf,
354{
355 type SendStream = SendStream<B>;
356 type RecvStream = RecvStream;
357
358 fn split(self) -> (Self::SendStream, Self::RecvStream) {
359 (self.send, self.recv)
360 }
361}
362
363impl<B: Buf> quic::RecvStream for BidiStream<B> {
364 type Buf = Bytes;
365 type Error = ReadError;
366
367 fn poll_data(
368 &mut self,
369 cx: &mut task::Context<'_>,
370 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
371 self.recv.poll_data(cx)
372 }
373
374 fn stop_sending(&mut self, error_code: u64) {
375 self.recv.stop_sending(error_code)
376 }
377
378 fn recv_id(&self) -> StreamId {
379 self.recv.recv_id()
380 }
381}
382
383impl<B> quic::SendStream<B> for BidiStream<B>
384where
385 B: Buf,
386{
387 type Error = SendStreamError;
388
389 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
390 self.send.poll_ready(cx)
391 }
392
393 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
394 self.send.poll_finish(cx)
395 }
396
397 fn reset(&mut self, reset_code: u64) {
398 self.send.reset(reset_code)
399 }
400
401 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
402 self.send.send_data(data)
403 }
404
405 fn send_id(&self) -> StreamId {
406 self.send.send_id()
407 }
408}
409impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
410where
411 B: Buf,
412{
413 fn poll_send<D: Buf>(
414 &mut self,
415 cx: &mut task::Context<'_>,
416 buf: &mut D,
417 ) -> Poll<Result<usize, Self::Error>> {
418 self.send.poll_send(cx, buf)
419 }
420}
421
422pub struct RecvStream {
426 stream: Option<quinn::RecvStream>,
427 read_chunk_fut: ReadChunkFuture,
428}
429
430type ReadChunkFuture = ReusableBoxFuture<
431 'static,
432 (
433 quinn::RecvStream,
434 Result<Option<quinn::Chunk>, quinn::ReadError>,
435 ),
436>;
437
438impl RecvStream {
439 fn new(stream: quinn::RecvStream) -> Self {
440 Self {
441 stream: Some(stream),
442 read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
444 }
445 }
446}
447
448impl quic::RecvStream for RecvStream {
449 type Buf = Bytes;
450 type Error = ReadError;
451
452 fn poll_data(
453 &mut self,
454 cx: &mut task::Context<'_>,
455 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
456 if let Some(mut stream) = self.stream.take() {
457 self.read_chunk_fut.set(async move {
458 let chunk = stream.read_chunk(usize::MAX, true).await;
459 (stream, chunk)
460 })
461 };
462
463 let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
464 self.stream = Some(stream);
465 Poll::Ready(Ok(chunk?.map(|c| c.bytes)))
466 }
467
468 fn stop_sending(&mut self, error_code: u64) {
469 self.stream
470 .as_mut()
471 .unwrap()
472 .stop(VarInt::from_u64(error_code).expect("invalid error_code"))
473 .ok();
474 }
475
476 fn recv_id(&self) -> StreamId {
477 self.stream
478 .as_ref()
479 .unwrap()
480 .id()
481 .0
482 .try_into()
483 .expect("invalid stream id")
484 }
485}
486
487#[derive(Debug)]
491pub struct ReadError(quinn::ReadError);
492
493impl From<ReadError> for std::io::Error {
494 fn from(value: ReadError) -> Self {
495 value.0.into()
496 }
497}
498
499impl std::error::Error for ReadError {
500 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
501 self.0.source()
502 }
503}
504
505impl fmt::Display for ReadError {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 self.0.fmt(f)
508 }
509}
510
511impl From<ReadError> for Arc<dyn Error> {
512 fn from(e: ReadError) -> Self {
513 Arc::new(e)
514 }
515}
516
517impl From<quinn::ReadError> for ReadError {
518 fn from(e: quinn::ReadError) -> Self {
519 Self(e)
520 }
521}
522
523impl Error for ReadError {
524 fn is_timeout(&self) -> bool {
525 matches!(
526 self.0,
527 quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut)
528 )
529 }
530
531 fn err_code(&self) -> Option<u64> {
532 match self.0 {
533 quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(
534 quinn_proto::ApplicationClose { error_code, .. },
535 )) => Some(error_code.into_inner()),
536 quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()),
537 _ => None,
538 }
539 }
540}
541
542pub struct SendStream<B: Buf> {
546 stream: Option<quinn::SendStream>,
547 writing: Option<WriteBuf<B>>,
548 write_fut: WriteFuture,
549}
550
551type WriteFuture =
552 ReusableBoxFuture<'static, (quinn::SendStream, Result<usize, quinn::WriteError>)>;
553
554impl<B> SendStream<B>
555where
556 B: Buf,
557{
558 fn new(stream: quinn::SendStream) -> SendStream<B> {
559 Self {
560 stream: Some(stream),
561 writing: None,
562 write_fut: ReusableBoxFuture::new(async { unreachable!() }),
563 }
564 }
565}
566
567impl<B> quic::SendStream<B> for SendStream<B>
568where
569 B: Buf,
570{
571 type Error = SendStreamError;
572
573 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
574 if let Some(ref mut data) = self.writing {
575 while data.has_remaining() {
576 if let Some(mut stream) = self.stream.take() {
577 let chunk = data.chunk().to_owned(); self.write_fut.set(async move {
579 let ret = stream.write(&chunk).await;
580 (stream, ret)
581 });
582 }
583
584 let (stream, res) = ready!(self.write_fut.poll(cx));
585 self.stream = Some(stream);
586 match res {
587 Ok(cnt) => data.advance(cnt),
588 Err(err) => {
589 return Poll::Ready(Err(SendStreamError::Write(err)));
590 }
591 }
592 }
593 }
594 self.writing = None;
595 Poll::Ready(Ok(()))
596 }
597
598 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
599 self.stream
600 .as_mut()
601 .unwrap()
602 .poll_finish(cx)
603 .map_err(Into::into)
604 }
605
606 fn reset(&mut self, reset_code: u64) {
607 let _ = self
608 .stream
609 .as_mut()
610 .unwrap()
611 .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
612 }
613
614 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
615 if self.writing.is_some() {
616 return Err(Self::Error::NotReady);
617 }
618 self.writing = Some(data.into());
619 Ok(())
620 }
621
622 fn send_id(&self) -> StreamId {
623 self.stream
624 .as_ref()
625 .unwrap()
626 .id()
627 .0
628 .try_into()
629 .expect("invalid stream id")
630 }
631}
632
633impl<B> quic::SendStreamUnframed<B> for SendStream<B>
634where
635 B: Buf,
636{
637 fn poll_send<D: Buf>(
638 &mut self,
639 cx: &mut task::Context<'_>,
640 buf: &mut D,
641 ) -> Poll<Result<usize, Self::Error>> {
642 if self.writing.is_some() {
643 panic!("poll_send called while send stream is not ready")
645 }
646
647 let s = Pin::new(self.stream.as_mut().unwrap());
648
649 let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk()));
650 match res {
651 Ok(written) => {
652 buf.advance(written);
653 Poll::Ready(Ok(written))
654 }
655 Err(err) => {
656 let err = err
664 .into_inner()
665 .expect("write stream returned an empty error")
666 .downcast::<WriteError>()
667 .expect("write stream returned an error which type is not WriteError");
668
669 Poll::Ready(Err(SendStreamError::Write(*err)))
670 }
671 }
672 }
673}
674
675#[derive(Debug)]
679pub enum SendStreamError {
680 Write(WriteError),
682 NotReady,
685}
686
687impl From<SendStreamError> for std::io::Error {
688 fn from(value: SendStreamError) -> Self {
689 match value {
690 SendStreamError::Write(err) => err.into(),
691 SendStreamError::NotReady => {
692 std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready")
693 }
694 }
695 }
696}
697
698impl std::error::Error for SendStreamError {}
699
700impl Display for SendStreamError {
701 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
702 write!(f, "{:?}", self)
703 }
704}
705
706impl From<WriteError> for SendStreamError {
707 fn from(e: WriteError) -> Self {
708 Self::Write(e)
709 }
710}
711
712impl Error for SendStreamError {
713 fn is_timeout(&self) -> bool {
714 matches!(
715 self,
716 Self::Write(quinn::WriteError::ConnectionLost(
717 quinn::ConnectionError::TimedOut
718 ))
719 )
720 }
721
722 fn err_code(&self) -> Option<u64> {
723 match self {
724 Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()),
725 Self::Write(quinn::WriteError::ConnectionLost(
726 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
727 error_code,
728 ..
729 }),
730 )) => Some(error_code.into_inner()),
731 _ => None,
732 }
733 }
734}
735
736impl From<SendStreamError> for Arc<dyn Error> {
737 fn from(e: SendStreamError) -> Self {
738 Arc::new(e)
739 }
740}