1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 collections::HashMap,
5 ops::DerefMut,
6 time::{Duration, Instant},
7};
8
9use crossbeam_utils::sync::Unparker;
10use parking_lot::{Mutex, RwLock};
11use zerortt_api::{
12 Acceptor, Error, Event, QuicClient, QuicPoll, QuicServerTransport, QuicTransport, Result,
13 StreamKind, Token,
14 quiche::{self, ConnectionId, RecvInfo, SendInfo},
15 random_conn_id,
16};
17
18#[cfg(feature = "server")]
19use zerortt_api::Handshake;
20
21use crate::{
22 conn::{LocKind, LockContext, QuicConn},
23 readiness::Readiness,
24 utils::release_time,
25};
26
/// Timer threshold (250 µs) handed to `release_time` in `send` and to
/// `Readiness::poll` in `poll`.
const DEFAULT_RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);
28
/// Looks up the connection registered under `$token` and acquires its lock of
/// kind `$kind`, evaluating to the resulting guard.
///
/// The group state mutex is held only for the lookup and the `try_lock` call.
/// It is released before the guard is handed back, because the guard's release
/// callback (`Group::unlock`) locks that same mutex again.
///
/// Early-returns (via `?`) `Error::NotFound` when the token is unknown, and
/// propagates the error from `try_lock` when the connection cannot be locked.
macro_rules! lock {
    ($self:ident, $token: ident, $kind: expr) => {{
        let state = $self.state.lock();

        let conn = state
            .conns
            .get(&$token)
            .ok_or_else(|| zerortt_api::Error::NotFound)?
            .borrow_mut()
            // The closure runs when the guard is released and re-locks `state`.
            .try_lock($kind, |ctx| $self.unlock(ctx))?;

        // Release the state mutex before the guard escapes; otherwise its
        // release callback would try to re-lock a mutex this thread still holds.
        drop(state);

        conn
    }};
}
45
46#[derive(Default)]
47struct State {
48 token_next: u32,
49 conns: HashMap<Token, RefCell<QuicConn>>,
50 readiness: RefCell<Readiness>,
51 unparkers: HashMap<Token, Unparker>,
52}
53
54pub struct Group {
56 state: Mutex<State>,
57 scids: RwLock<HashMap<ConnectionId<'static>, Token>>,
58}
59
60impl Default for Group {
61 fn default() -> Self {
62 Self {
63 state: Default::default(),
64 scids: Default::default(),
65 }
66 }
67}
68
69impl Group {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 fn unlock(&self, ctx: LockContext) {
76 let mut state = self.state.lock();
77
78 if let Some(unparker) = state.unparkers.remove(&ctx.token) {
79 unparker.unpark();
80 }
81
82 if state
83 .conns
84 .get(&ctx.token)
85 .expect("Unlock.")
86 .borrow_mut()
87 .unlock(
88 ctx.lock_count,
89 ctx.send_done,
90 state.readiness.borrow_mut().deref_mut(),
91 )
92 {
93 log::trace!(
94 "automatic deregister closed connection, token={:?}",
95 ctx.token
96 );
97
98 drop(state);
99 _ = self.deregister(ctx.token);
100 }
101 }
102
103 pub fn recv_with_connection_id(
106 &self,
107 scid: &ConnectionId<'_>,
108 buf: &mut [u8],
109 info: RecvInfo,
110 unparker: Option<&Unparker>,
111 ) -> Result<(Token, usize)> {
112 let token = self
113 .scids
114 .read()
115 .get(&scid)
116 .ok_or_else(|| Error::NotFound)?
117 .clone();
118
119 let mut state = self.state.lock();
120
121 let Ok(mut conn) = state
122 .conns
123 .get(&token)
124 .ok_or_else(|| Error::NotFound)?
125 .borrow_mut()
126 .try_lock(LocKind::Recv, |ctx| self.unlock(ctx))
127 else {
128 if let Some(unparker) = unparker {
130 state.unparkers.insert(token, unparker.clone());
131 }
132
133 return Err(Error::Busy);
134 };
135
136 drop(state);
137
138 match conn.recv(buf, info) {
139 Ok(recv_size) => {
140 log::trace!(
141 "Connection recv, scid={:?}, len={}",
142 conn.source_id(),
143 recv_size
144 );
145
146 Ok((token, recv_size))
147 }
148 Err(err) => {
149 log::error!("Connection recv, scid={:?}, err={}", conn.source_id(), err);
150 Err(Error::Quiche(err))
151 }
152 }
153 }
154}
155
156impl QuicPoll for Group {
157 type Error = zerortt_api::Error;
158 #[inline]
160 fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
161 let mut state = self.state.lock();
162
163 loop {
164 let token = Token(state.token_next);
165
166 (state.token_next, _) = state.token_next.overflowing_add(1);
167
168 if state.conns.contains_key(&token) {
169 continue;
170 }
171
172 assert!(
173 self.scids
174 .write()
175 .insert(wrapped.source_id().into_owned(), token)
176 .is_none()
177 );
178
179 log::trace!(
180 "register quic connection, token={:?}, trace_id={}",
181 token,
182 wrapped.trace_id()
183 );
184
185 let conn = RefCell::new(QuicConn::new(token, wrapped));
186
187 let guard = conn.borrow_mut().try_lock(LocKind::ReadLock, |context| {
188 conn.borrow_mut().unlock(
189 context.lock_count,
190 false,
191 state.readiness.borrow_mut().deref_mut(),
192 );
193 })?;
194
195 drop(guard);
196
197 state.conns.insert(token, conn);
198
199 return Ok(token);
200 }
201 }
202
203 #[inline]
205 fn deregister(&self, token: Token) -> Result<quiche::Connection> {
206 let mut state = self.state.lock();
207
208 let conn: quiche::Connection = state
209 .conns
210 .remove(&token)
211 .ok_or_else(|| Error::NotFound)?
212 .into_inner()
213 .into();
214
215 drop(state);
216
217 assert_eq!(
218 self.scids.write().remove(&conn.source_id().into_owned()),
219 Some(token)
220 );
221
222 Ok(conn)
223 }
224
225 #[inline]
227 fn len(&self) -> usize {
228 self.state.lock().conns.len()
229 }
230
231 #[inline]
233 fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
234 let state = self.state.lock();
235 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
236
237 conn.borrow_mut()
238 .close(app, err, reason, state.readiness.borrow_mut().deref_mut())
239 }
240
241 #[inline]
243 fn stream_open(
244 &self,
245 token: Token,
246 kind: StreamKind,
247 non_blocking: bool,
248 ) -> Result<Option<u64>> {
249 let state = self.state.lock();
250 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
251
252 conn.borrow_mut()
253 .stream_open(kind, non_blocking, state.readiness.borrow_mut().deref_mut())
254 }
255
256 #[inline]
258 fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
259 let state = self.state.lock();
260 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
261
262 conn.borrow_mut()
263 .stream_close(stream_id, err, state.readiness.borrow_mut().deref_mut())
264 }
265
266 #[inline]
268 fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
269 let mut conn = lock!(
270 self,
271 token,
272 LocKind::StreamSend {
273 id: stream_id,
274 len: buf.len()
275 }
276 );
277
278 match conn.stream_send(stream_id, buf, fin) {
279 Ok(send_size) => {
280 log::trace!(
281 "stream send, scid={:?}, stream_id={}, len={}, fin={}",
282 conn.source_id(),
283 stream_id,
284 send_size,
285 fin
286 );
287 return Ok(send_size);
288 }
289 Err(quiche::Error::Done) => {
290 log::trace!(
291 "stream send, scid={:?}, stream_id={}, fin={}, Done",
292 conn.source_id(),
293 stream_id,
294 fin
295 );
296 return Err(Error::Retry);
297 }
298 Err(err) => {
299 log::error!(
300 "stream send, scid={:?}, stream_id={}, fin={}, err={}",
301 conn.source_id(),
302 stream_id,
303 fin,
304 err
305 );
306
307 return Err(Error::Quiche(err));
308 }
309 }
310 }
311
312 #[inline]
314 fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
315 let mut conn = lock!(self, token, LocKind::StreamRecv(stream_id));
316
317 match conn.stream_recv(stream_id, buf) {
318 Ok((recv_size, fin)) => {
319 log::trace!(
320 "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
321 conn.source_id(),
322 stream_id,
323 recv_size,
324 fin,
325 conn.is_server(),
326 );
327 return Ok((recv_size, fin));
328 }
329 Err(quiche::Error::Done) => {
330 if conn.stream_finished(stream_id) {
331 log::trace!(
332 "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
333 conn.source_id(),
334 stream_id,
335 0,
336 true,
337 conn.is_server(),
338 );
339
340 return Ok((0, true));
341 }
342
343 log::trace!(
344 "stream recv, scid={:?}, stream_id={}, is_server={}, Done",
345 conn.source_id(),
346 stream_id,
347 conn.is_server(),
348 );
349 return Err(Error::Retry);
350 }
351 Err(err) => {
352 log::error!(
353 "stream recv, scid={:?}, stream_id={}, is_server={}, err={}",
354 conn.source_id(),
355 stream_id,
356 conn.is_server(),
357 err
358 );
359
360 return Err(Error::Quiche(err));
361 }
362 }
363 }
364
365 #[inline]
367 fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
368 let state = self.state.lock();
369
370 Ok(state
371 .readiness
372 .borrow_mut()
373 .poll(events, DEFAULT_RELEASE_TIMER_THRESHOLD))
374 }
375}
376
377impl QuicTransport for Group {
378 type Error = zerortt_api::Error;
379 #[inline]
381 fn recv(&self, buf: &mut [u8], info: RecvInfo) -> Result<usize> {
382 let header =
383 quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(Error::Quiche)?;
384
385 self.recv_with_connection_id(&header.dcid, buf, info, None)
386 .map(|(_, recv_size)| recv_size)
387 }
388
389 #[inline]
391 fn send(&self, token: Token, buf: &mut [u8]) -> Result<(usize, SendInfo)> {
392 let mut conn = lock!(self, token, LocKind::Recv);
393
394 if let Some(release_time) =
395 release_time(&conn, Instant::now(), DEFAULT_RELEASE_TIMER_THRESHOLD)
396 {
397 log::trace!(
398 "connection send, scid={:?}, next_release_time={:?}",
399 conn.trace_id(),
400 release_time,
401 );
402 return Err(Error::Retry);
403 }
404
405 conn.on_timeout();
407
408 match conn.send(buf) {
409 Ok((send_size, send_info)) => {
410 log::trace!(
411 "connection send, scid={:?}, send_size={}, send_info={:?}",
412 conn.trace_id(),
413 send_size,
414 send_info
415 );
416 return Ok((send_size, send_info));
417 }
418 Err(quiche::Error::Done) => {
419 log::trace!("connection send, scid={:?}, done", conn.trace_id());
420 conn.send_done();
421 return Err(Error::Retry);
422 }
423 Err(err) => {
424 log::error!("connection send, scid={:?}, err={}", conn.trace_id(), err);
425 return Err(Error::Quiche(err));
426 }
427 }
428 }
429}
430
#[cfg(feature = "server")]
impl QuicServerTransport for Group {
    /// Server-side receive of the datagram in `buf[..recv_size]`; any reply is
    /// written into `buf`.
    ///
    /// * Packets for a registered connection are fed to it, then one outgoing
    ///   packet is flushed.
    /// * Packets for unknown connection ids go to `acceptor`:
    ///   * an accepted connection is registered, fed the packet and flushed;
    ///   * a handshake-only reply of `send_size` bytes is returned directly.
    ///
    /// A returned size of `0` means nothing is to be sent right now.
    ///
    /// # Errors
    ///
    /// * `Error::Busy` if the target connection is locked by another operation.
    ///   This applies to a just-accepted one too, since it is reachable by other
    ///   threads as soon as `register` returns. `unparker`, when given, is woken
    ///   once that lock is released.
    /// * Header-parse, acceptor, registration and quiche errors are propagated.
    fn recv_with_acceptor(
        &self,
        acceptor: &mut Acceptor,
        buf: &mut [u8],
        recv_size: usize,
        recv_info: RecvInfo,
        unparker: Option<&Unparker>,
    ) -> Result<(usize, SendInfo)> {
        // Addressing for a reply to the datagram's sender.
        let reply_info = || SendInfo {
            at: Instant::now(),
            from: recv_info.to,
            to: recv_info.from,
        };

        // Flush one packet for `token`; "nothing to send right now" becomes an
        // empty reply instead of an error.
        let flush = |token: Token, buf: &mut [u8]| match self.send(token, buf) {
            Err(Error::Busy) | Err(Error::Retry) => Ok((0, reply_info())),
            other => other,
        };

        let header = quiche::Header::from_slice(&mut buf[..recv_size], quiche::MAX_CONN_ID_LEN)
            .map_err(Error::Quiche)?;

        match self.recv_with_connection_id(&header.dcid, &mut buf[..recv_size], recv_info, unparker)
        {
            Ok((token, _)) => flush(token, buf),
            Err(Error::NotFound) => match acceptor.handshake(&header, buf, recv_size, recv_info)? {
                Handshake::Accept(conn) => {
                    let token = self.register(conn)?;

                    // `register` has made the connection routable and updated the
                    // shared readiness set. A concurrent send or a retransmitted
                    // packet on another thread may therefore already hold its
                    // lock. Report that as `Busy` (with the caller's unparker),
                    // exactly like the known-connection path, instead of panicking.
                    self.recv_with_connection_id(
                        &header.dcid,
                        &mut buf[..recv_size],
                        recv_info,
                        unparker,
                    )?;

                    flush(token, buf)
                }
                Handshake::Handshake(send_size) => Ok((send_size, reply_info())),
            },
            Err(err) => Err(err),
        }
    }
}
501
#[cfg(feature = "server")]
impl QuicClient for Group {
    /// Creates an outgoing connection from `local` to `peer` with a random
    /// source connection id, registers it, and returns its token.
    ///
    /// NOTE(review): this client implementation is gated on the `server`
    /// feature; confirm that is intentional.
    fn connect(
        &self,
        server_name: Option<&str>,
        local: std::net::SocketAddr,
        peer: std::net::SocketAddr,
        config: &mut quiche::Config,
    ) -> Result<Token> {
        let scid = random_conn_id();
        let conn = quiche::connect(server_name, &scid, local, peer, config)?;

        self.register(conn)
    }
}
517}