tycho_network/network/
connection.rs1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use anyhow::{Context as _, Result};
7use bytes::Bytes;
8use quinn::{ConnectionError, SendDatagramError};
9use webpki::types::CertificateDer;
10
11use crate::network::crypto::peer_id_from_certificate;
12use crate::types::{Direction, InboundRequestMeta, PeerId};
13
14#[derive(Clone)]
15pub struct Connection {
16 inner: quinn::Connection,
17 request_meta: Arc<InboundRequestMeta>,
18}
19
20impl Connection {
21 pub fn new(inner: quinn::Connection, origin: Direction) -> Result<Self> {
22 let peer_id = extract_peer_id(&inner)?;
23 Ok(Self::with_peer_id(inner, origin, peer_id))
24 }
25
26 pub fn with_peer_id(inner: quinn::Connection, origin: Direction, peer_id: PeerId) -> Self {
27 Self {
28 request_meta: Arc::new(InboundRequestMeta {
29 peer_id,
30 origin,
31 remote_address: inner.remote_address(),
32 }),
33 inner,
34 }
35 }
36
37 pub fn request_meta(&self) -> &Arc<InboundRequestMeta> {
38 &self.request_meta
39 }
40
41 pub fn peer_id(&self) -> &PeerId {
42 &self.request_meta.peer_id
43 }
44
45 pub fn stable_id(&self) -> usize {
46 self.inner.stable_id()
47 }
48
49 pub fn origin(&self) -> Direction {
50 self.request_meta.origin
51 }
52
53 pub fn remote_address(&self) -> SocketAddr {
54 self.request_meta.remote_address
55 }
56
57 pub fn close(&self) {
58 self.inner.close(0u8.into(), b"connection closed");
59 }
60
61 pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
62 self.inner
63 .open_bi()
64 .await
65 .map(|(send, recv)| (SendStream(send), RecvStream(recv)))
66 }
67
68 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
69 self.inner
70 .accept_bi()
71 .await
72 .map(|(send, recv)| (SendStream(send), RecvStream(recv)))
73 }
74
75 pub async fn open_uni(&self) -> Result<SendStream, ConnectionError> {
76 self.inner.open_uni().await.map(SendStream)
77 }
78
79 pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
80 self.inner.accept_uni().await.map(RecvStream)
81 }
82
83 pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
84 self.inner.send_datagram(data)
85 }
86
87 pub async fn read_datagram(&self) -> Result<Bytes, ConnectionError> {
88 self.inner.read_datagram().await
89 }
90}
91
92impl std::fmt::Debug for Connection {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("Connection")
95 .field("origin", &self.request_meta.origin)
96 .field("id", &self.stable_id())
97 .field("remote_address", &self.remote_address())
98 .field("peer_id", &self.request_meta.peer_id)
99 .finish_non_exhaustive()
100 }
101}
102
103#[repr(transparent)]
104pub struct SendStream(quinn::SendStream);
105
106impl Drop for SendStream {
107 fn drop(&mut self) {
108 _ = self.0.reset(0u8.into());
109 }
110}
111
112impl std::ops::Deref for SendStream {
113 type Target = quinn::SendStream;
114
115 #[inline]
116 fn deref(&self) -> &Self::Target {
117 &self.0
118 }
119}
120
121impl std::ops::DerefMut for SendStream {
122 #[inline]
123 fn deref_mut(&mut self) -> &mut Self::Target {
124 &mut self.0
125 }
126}
127
128impl tokio::io::AsyncWrite for SendStream {
129 #[inline]
130 fn poll_write(
131 mut self: Pin<&mut Self>,
132 cx: &mut Context<'_>,
133 buf: &[u8],
134 ) -> Poll<Result<usize, std::io::Error>> {
135 Pin::new(&mut self.0)
136 .poll_write(cx, buf)
137 .map_err(std::io::Error::from)
138 }
139
140 #[inline]
141 fn poll_flush(
142 mut self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 ) -> Poll<Result<(), std::io::Error>> {
145 Pin::new(&mut self.0).poll_flush(cx)
146 }
147
148 #[inline]
149 fn poll_shutdown(
150 mut self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 ) -> Poll<Result<(), std::io::Error>> {
153 Pin::new(&mut self.0).poll_flush(cx)
154 }
155}
156
157#[repr(transparent)]
158pub struct RecvStream(quinn::RecvStream);
159
160impl std::ops::Deref for RecvStream {
161 type Target = quinn::RecvStream;
162
163 #[inline]
164 fn deref(&self) -> &Self::Target {
165 &self.0
166 }
167}
168
169impl std::ops::DerefMut for RecvStream {
170 #[inline]
171 fn deref_mut(&mut self) -> &mut Self::Target {
172 &mut self.0
173 }
174}
175
176impl tokio::io::AsyncRead for RecvStream {
177 #[inline]
178 fn poll_read(
179 mut self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 buf: &mut tokio::io::ReadBuf<'_>,
182 ) -> Poll<std::io::Result<()>> {
183 Pin::new(&mut self.0).poll_read(cx, buf)
184 }
185}
186
187pub(crate) fn extract_peer_id(connection: &quinn::Connection) -> Result<PeerId> {
188 parse_peer_identity(
189 connection
190 .peer_identity()
191 .context("No identity found in the connection")?,
192 )
193}
194
195pub(crate) fn parse_peer_identity(identity: Box<dyn std::any::Any>) -> Result<PeerId> {
196 let certificate = identity
197 .downcast::<Vec<CertificateDer<'static>>>()
198 .ok()
199 .and_then(|certificates| certificates.into_iter().next())
200 .context("No certificate found in the connection")?;
201
202 peer_id_from_certificate(&certificate).map_err(Into::into)
203}