tycho_network/network/connection.rs

1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use anyhow::{Context as _, Result};
7use bytes::Bytes;
8use quinn::{ConnectionError, SendDatagramError};
9use webpki::types::CertificateDer;
10
11use crate::network::crypto::peer_id_from_certificate;
12use crate::types::{Direction, InboundRequestMeta, PeerId};
13
14#[derive(Clone)]
15pub struct Connection {
16    inner: quinn::Connection,
17    request_meta: Arc<InboundRequestMeta>,
18}
19
20impl Connection {
21    pub fn new(inner: quinn::Connection, origin: Direction) -> Result<Self> {
22        let peer_id = extract_peer_id(&inner)?;
23        Ok(Self::with_peer_id(inner, origin, peer_id))
24    }
25
26    pub fn with_peer_id(inner: quinn::Connection, origin: Direction, peer_id: PeerId) -> Self {
27        Self {
28            request_meta: Arc::new(InboundRequestMeta {
29                peer_id,
30                origin,
31                remote_address: inner.remote_address(),
32            }),
33            inner,
34        }
35    }
36
37    pub fn request_meta(&self) -> &Arc<InboundRequestMeta> {
38        &self.request_meta
39    }
40
41    pub fn peer_id(&self) -> &PeerId {
42        &self.request_meta.peer_id
43    }
44
45    pub fn stable_id(&self) -> usize {
46        self.inner.stable_id()
47    }
48
49    pub fn origin(&self) -> Direction {
50        self.request_meta.origin
51    }
52
53    pub fn remote_address(&self) -> SocketAddr {
54        self.request_meta.remote_address
55    }
56
57    pub fn close(&self) {
58        self.inner.close(0u8.into(), b"connection closed");
59    }
60
61    pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
62        self.inner
63            .open_bi()
64            .await
65            .map(|(send, recv)| (SendStream(send), RecvStream(recv)))
66    }
67
68    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
69        self.inner
70            .accept_bi()
71            .await
72            .map(|(send, recv)| (SendStream(send), RecvStream(recv)))
73    }
74
75    pub async fn open_uni(&self) -> Result<SendStream, ConnectionError> {
76        self.inner.open_uni().await.map(SendStream)
77    }
78
79    pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
80        self.inner.accept_uni().await.map(RecvStream)
81    }
82
83    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
84        self.inner.send_datagram(data)
85    }
86
87    pub async fn read_datagram(&self) -> Result<Bytes, ConnectionError> {
88        self.inner.read_datagram().await
89    }
90}
91
92impl std::fmt::Debug for Connection {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("Connection")
95            .field("origin", &self.request_meta.origin)
96            .field("id", &self.stable_id())
97            .field("remote_address", &self.remote_address())
98            .field("peer_id", &self.request_meta.peer_id)
99            .finish_non_exhaustive()
100    }
101}
102
103#[repr(transparent)]
104pub struct SendStream(quinn::SendStream);
105
106impl Drop for SendStream {
107    fn drop(&mut self) {
108        _ = self.0.reset(0u8.into());
109    }
110}
111
112impl std::ops::Deref for SendStream {
113    type Target = quinn::SendStream;
114
115    #[inline]
116    fn deref(&self) -> &Self::Target {
117        &self.0
118    }
119}
120
121impl std::ops::DerefMut for SendStream {
122    #[inline]
123    fn deref_mut(&mut self) -> &mut Self::Target {
124        &mut self.0
125    }
126}
127
128impl tokio::io::AsyncWrite for SendStream {
129    #[inline]
130    fn poll_write(
131        mut self: Pin<&mut Self>,
132        cx: &mut Context<'_>,
133        buf: &[u8],
134    ) -> Poll<Result<usize, std::io::Error>> {
135        Pin::new(&mut self.0)
136            .poll_write(cx, buf)
137            .map_err(std::io::Error::from)
138    }
139
140    #[inline]
141    fn poll_flush(
142        mut self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144    ) -> Poll<Result<(), std::io::Error>> {
145        Pin::new(&mut self.0).poll_flush(cx)
146    }
147
148    #[inline]
149    fn poll_shutdown(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152    ) -> Poll<Result<(), std::io::Error>> {
153        Pin::new(&mut self.0).poll_flush(cx)
154    }
155}
156
157#[repr(transparent)]
158pub struct RecvStream(quinn::RecvStream);
159
160impl std::ops::Deref for RecvStream {
161    type Target = quinn::RecvStream;
162
163    #[inline]
164    fn deref(&self) -> &Self::Target {
165        &self.0
166    }
167}
168
169impl std::ops::DerefMut for RecvStream {
170    #[inline]
171    fn deref_mut(&mut self) -> &mut Self::Target {
172        &mut self.0
173    }
174}
175
176impl tokio::io::AsyncRead for RecvStream {
177    #[inline]
178    fn poll_read(
179        mut self: Pin<&mut Self>,
180        cx: &mut Context<'_>,
181        buf: &mut tokio::io::ReadBuf<'_>,
182    ) -> Poll<std::io::Result<()>> {
183        Pin::new(&mut self.0).poll_read(cx, buf)
184    }
185}
186
187pub(crate) fn extract_peer_id(connection: &quinn::Connection) -> Result<PeerId> {
188    parse_peer_identity(
189        connection
190            .peer_identity()
191            .context("No identity found in the connection")?,
192    )
193}
194
195pub(crate) fn parse_peer_identity(identity: Box<dyn std::any::Any>) -> Result<PeerId> {
196    let certificate = identity
197        .downcast::<Vec<CertificateDer<'static>>>()
198        .ok()
199        .and_then(|certificates| certificates.into_iter().next())
200        .context("No certificate found in the connection")?;
201
202    peer_id_from_certificate(&certificate).map_err(Into::into)
203}