1use std::{
2 borrow::Cow,
3 cell::RefCell,
4 collections::HashMap,
5 ops::DerefMut,
6 time::{Duration, Instant},
7};
8
9use crossbeam_utils::sync::Unparker;
10use parking_lot::{Mutex, RwLock};
11use zerortt_api::{
12 Acceptor, Error, Event, QuicClient, QuicPoll, QuicServerTransport, QuicTransport, Result,
13 StreamKind, Token,
14 quiche::{self, ConnectionId, RecvInfo, SendInfo},
15 random_conn_id,
16};
17
18#[cfg(feature = "server")]
19use zerortt_api::Handshake;
20
21use crate::{
22 conn::{LocKind, LockContext, QuicConn},
23 readiness::Readiness,
24 utils::release_time,
25};
26
/// Threshold (250 microseconds) passed to `release_time` when sending and to
/// `Readiness::poll` when polling.
const DEFAULT_RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);
28
/// Looks up the connection registered under `$token` and tries to lock it
/// with `$kind`, evaluating to the lock guard.
///
/// Propagates `Error::NotFound` (unknown token) or the `try_lock` error out
/// of the enclosing function via `?`. The guard's release callback is
/// `$self.unlock(..)`.
macro_rules! lock {
    ($self:ident, $token:ident, $kind:expr) => {{
        let registry = $self.state.lock();

        let cell = registry
            .conns
            .get(&$token)
            .ok_or(zerortt_api::Error::NotFound)?;

        let guard = cell
            .borrow_mut()
            .try_lock($kind, |ctx| $self.unlock(ctx))?;

        // Release the state mutex before handing the guard out: the guard's
        // release callback (`unlock`) acquires that mutex again.
        drop(registry);

        guard
    }};
}
45
46#[derive(Default)]
47struct State {
48 token_next: u32,
49 conns: HashMap<Token, RefCell<QuicConn>>,
50 readiness: RefCell<Readiness>,
51 unparkers: HashMap<Token, Unparker>,
52}
53
54pub struct Group {
56 state: Mutex<State>,
57 scids: RwLock<HashMap<ConnectionId<'static>, Token>>,
58}
59
60impl Default for Group {
61 fn default() -> Self {
62 Self {
63 state: Default::default(),
64 scids: Default::default(),
65 }
66 }
67}
68
69impl Group {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 fn unlock(&self, ctx: LockContext) {
76 let mut state = self.state.lock();
77
78 if let Some(unparker) = state.unparkers.remove(&ctx.token) {
79 unparker.unpark();
80 }
81
82 if state
83 .conns
84 .get(&ctx.token)
85 .expect("Unlock.")
86 .borrow_mut()
87 .unlock(
88 ctx.lock_count,
89 ctx.send_done,
90 state.readiness.borrow_mut().deref_mut(),
91 )
92 {
93 log::trace!(
94 "automatic deregister closed connection, token={:?}",
95 ctx.token
96 );
97
98 drop(state);
99 _ = self.deregister(ctx.token);
100 }
101 }
102
103 pub fn recv_with_connection_id(
106 &self,
107 scid: &ConnectionId<'_>,
108 buf: &mut [u8],
109 info: RecvInfo,
110 unparker: Option<&Unparker>,
111 ) -> Result<(Token, usize)> {
112 let token = self
113 .scids
114 .read()
115 .get(&scid)
116 .ok_or_else(|| Error::NotFound)?
117 .clone();
118
119 let mut state = self.state.lock();
120
121 let Ok(mut conn) = state
122 .conns
123 .get(&token)
124 .ok_or_else(|| Error::NotFound)?
125 .borrow_mut()
126 .try_lock(LocKind::Recv, |ctx| self.unlock(ctx))
127 else {
128 if let Some(unparker) = unparker {
130 state.unparkers.insert(token, unparker.clone());
131 }
132
133 return Err(Error::Busy);
134 };
135
136 drop(state);
137
138 match conn.recv(buf, info) {
139 Ok(recv_size) => {
140 log::trace!(
141 "Connection recv, scid={:?}, len={}",
142 conn.source_id(),
143 recv_size
144 );
145
146 Ok((token, recv_size))
147 }
148 Err(err) => {
149 log::error!("Connection recv, scid={:?}, err={}", conn.source_id(), err);
150 Err(Error::Quiche(err))
151 }
152 }
153 }
154}
155
156impl QuicPoll for Group {
157 type Error = zerortt_api::Error;
158 #[inline]
160 fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
161 let mut state = self.state.lock();
162
163 loop {
164 let token = Token(state.token_next);
165
166 (state.token_next, _) = state.token_next.overflowing_add(1);
167
168 if state.conns.contains_key(&token) {
169 continue;
170 }
171
172 assert!(
173 self.scids
174 .write()
175 .insert(wrapped.source_id().into_owned(), token)
176 .is_none()
177 );
178
179 log::trace!(
180 "register quic connection, token={:?}, trace_id={}",
181 token,
182 wrapped.trace_id()
183 );
184
185 let conn = RefCell::new(QuicConn::new(token, wrapped));
186
187 let guard = conn.borrow_mut().try_lock(LocKind::ReadLock, |context| {
188 conn.borrow_mut().unlock(
189 context.lock_count,
190 false,
191 state.readiness.borrow_mut().deref_mut(),
192 );
193 })?;
194
195 drop(guard);
196
197 state.conns.insert(token, conn);
198
199 return Ok(token);
200 }
201 }
202
203 #[inline]
205 fn deregister(&self, token: Token) -> Result<quiche::Connection> {
206 let mut state = self.state.lock();
207
208 let conn: quiche::Connection = state
209 .conns
210 .remove(&token)
211 .ok_or_else(|| Error::NotFound)?
212 .into_inner()
213 .into();
214
215 drop(state);
216
217 assert_eq!(
218 self.scids.write().remove(&conn.source_id().into_owned()),
219 Some(token)
220 );
221
222 Ok(conn)
223 }
224
225 #[inline]
227 fn len(&self) -> usize {
228 self.state.lock().conns.len()
229 }
230
231 #[inline]
233 fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
234 let state = self.state.lock();
235 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
236
237 conn.borrow_mut()
238 .close(app, err, reason, state.readiness.borrow_mut().deref_mut())
239 }
240
241 #[inline]
243 fn stream_open(
244 &self,
245 token: Token,
246 kind: StreamKind,
247 max_streams_as_error: bool,
248 ) -> Result<u64> {
249 let state = self.state.lock();
250 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
251
252 conn.borrow_mut().stream_open(
253 kind,
254 max_streams_as_error,
255 state.readiness.borrow_mut().deref_mut(),
256 )
257 }
258
259 #[inline]
261 fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
262 let state = self.state.lock();
263 let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
264
265 conn.borrow_mut()
266 .stream_close(stream_id, err, state.readiness.borrow_mut().deref_mut())
267 }
268
269 #[inline]
271 fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
272 let mut conn = lock!(
273 self,
274 token,
275 LocKind::StreamSend {
276 id: stream_id,
277 len: buf.len()
278 }
279 );
280
281 match conn.stream_send(stream_id, buf, fin) {
282 Ok(send_size) => {
283 log::trace!(
284 "stream send, scid={:?}, stream_id={}, len={}, fin={}",
285 conn.source_id(),
286 stream_id,
287 send_size,
288 fin
289 );
290 return Ok(send_size);
291 }
292 Err(quiche::Error::Done) => {
293 log::trace!(
294 "stream send, scid={:?}, stream_id={}, fin={}, Done",
295 conn.source_id(),
296 stream_id,
297 fin
298 );
299 return Err(Error::Retry);
300 }
301 Err(err) => {
302 log::error!(
303 "stream send, scid={:?}, stream_id={}, fin={}, err={}",
304 conn.source_id(),
305 stream_id,
306 fin,
307 err
308 );
309
310 return Err(Error::Quiche(err));
311 }
312 }
313 }
314
315 #[inline]
317 fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
318 let mut conn = lock!(self, token, LocKind::StreamRecv(stream_id));
319
320 match conn.stream_recv(stream_id, buf) {
321 Ok((recv_size, fin)) => {
322 log::trace!(
323 "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
324 conn.source_id(),
325 stream_id,
326 recv_size,
327 fin,
328 conn.is_server(),
329 );
330 return Ok((recv_size, fin));
331 }
332 Err(quiche::Error::Done) => {
333 if conn.stream_finished(stream_id) {
334 log::trace!(
335 "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
336 conn.source_id(),
337 stream_id,
338 0,
339 true,
340 conn.is_server(),
341 );
342
343 return Ok((0, true));
344 }
345
346 log::trace!(
347 "stream recv, scid={:?}, stream_id={}, is_server={}, Done",
348 conn.source_id(),
349 stream_id,
350 conn.is_server(),
351 );
352 return Err(Error::Retry);
353 }
354 Err(err) => {
355 log::error!(
356 "stream recv, scid={:?}, stream_id={}, is_server={}, err={}",
357 conn.source_id(),
358 stream_id,
359 conn.is_server(),
360 err
361 );
362
363 return Err(Error::Quiche(err));
364 }
365 }
366 }
367
368 #[inline]
370 fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
371 let state = self.state.lock();
372
373 Ok(state
374 .readiness
375 .borrow_mut()
376 .poll(events, DEFAULT_RELEASE_TIMER_THRESHOLD))
377 }
378}
379
380impl QuicTransport for Group {
381 type Error = zerortt_api::Error;
382 #[inline]
384 fn recv(&self, buf: &mut [u8], info: RecvInfo) -> Result<usize> {
385 let header =
386 quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(Error::Quiche)?;
387
388 self.recv_with_connection_id(&header.dcid, buf, info, None)
389 .map(|(_, recv_size)| recv_size)
390 }
391
392 #[inline]
394 fn send(&self, token: Token, buf: &mut [u8]) -> Result<(usize, SendInfo)> {
395 let mut conn = lock!(self, token, LocKind::Recv);
396
397 if let Some(release_time) =
398 release_time(&conn, Instant::now(), DEFAULT_RELEASE_TIMER_THRESHOLD)
399 {
400 log::trace!(
401 "connection send, scid={:?}, next_release_time={:?}",
402 conn.trace_id(),
403 release_time,
404 );
405 return Err(Error::Retry);
406 }
407
408 conn.on_timeout();
410
411 match conn.send(buf) {
412 Ok((send_size, send_info)) => {
413 log::trace!(
414 "connection send, scid={:?}, send_size={}, send_info={:?}",
415 conn.trace_id(),
416 send_size,
417 send_info
418 );
419 return Ok((send_size, send_info));
420 }
421 Err(quiche::Error::Done) => {
422 log::trace!("connection send, scid={:?}, done", conn.trace_id());
423 conn.send_done();
424 return Err(Error::Retry);
425 }
426 Err(err) => {
427 log::error!("connection send, scid={:?}, err={}", conn.trace_id(), err);
428 return Err(Error::Quiche(err));
429 }
430 }
431 }
432}
433
#[cfg(feature = "server")]
impl QuicServerTransport for Group {
    /// Processes one received packet (`buf[..recv_size]`) on the server side
    /// and writes the reply, if any, into `buf`.
    ///
    /// * If a connection is registered for the packet's destination
    ///   connection id, the packet is delivered to it and `send` is invoked.
    /// * If none is registered, the packet is passed to `acceptor`. An
    ///   accepted connection is registered, receives the packet and is asked
    ///   to send; a handshake reply is returned with the size reported by
    ///   the acceptor.
    ///
    /// A `send` that fails with `Error::Busy` or `Error::Retry` is reported
    /// as a zero-length reply addressed back to the sender.
    ///
    /// # Panics
    ///
    /// Panics if delivering the packet to a newly registered connection
    /// fails with `Error::Busy` or `Error::Retry`.
    fn recv_with_acceptor(
        &self,
        acceptor: &mut Acceptor,
        buf: &mut [u8],
        recv_size: usize,
        recv_info: RecvInfo,
        unparker: Option<&Unparker>,
    ) -> Result<(usize, SendInfo)> {
        // Addressing for a reply that goes straight back to the sender.
        let reply_info = || SendInfo {
            at: Instant::now(),
            from: recv_info.to,
            to: recv_info.from,
        };

        let header = quiche::Header::from_slice(&mut buf[..recv_size], quiche::MAX_CONN_ID_LEN)
            .map_err(Error::Quiche)?;

        let delivered =
            self.recv_with_connection_id(&header.dcid, &mut buf[..recv_size], recv_info, unparker);

        let token = match delivered {
            Ok((token, _)) => token,
            Err(Error::NotFound) => match acceptor.handshake(&header, buf, recv_size, recv_info) {
                Ok(Handshake::Accept(conn)) => {
                    let token = self.register(conn)?;

                    match self.recv_with_connection_id(
                        &header.dcid,
                        &mut buf[..recv_size],
                        recv_info,
                        None,
                    ) {
                        Ok(_) => token,
                        Err(Error::Busy) | Err(Error::Retry) => {
                            unreachable!("Newly registered connections should be idle");
                        }
                        Err(err) => return Err(err),
                    }
                }
                Ok(Handshake::Handshake(send_size)) => return Ok((send_size, reply_info())),
                Err(err) => return Err(err),
            },
            Err(err) => return Err(err),
        };

        match self.send(token, buf) {
            Err(Error::Busy) | Err(Error::Retry) => Ok((0, reply_info())),
            other => other,
        }
    }
}
504
// NOTE(review): this client implementation is gated on the "server" feature;
// confirm whether a client feature was intended.
#[cfg(feature = "server")]
impl QuicClient for Group {
    /// Creates a client connection via `quiche::connect`, using a source
    /// connection id from `random_conn_id`, and registers it with this
    /// group.
    ///
    /// Returns the token assigned by `register`; errors from
    /// `quiche::connect` and `register` are propagated.
    fn connect(
        &self,
        server_name: Option<&str>,
        local: std::net::SocketAddr,
        peer: std::net::SocketAddr,
        config: &mut quiche::Config,
    ) -> Result<Token> {
        let scid = random_conn_id();
        let conn = quiche::connect(server_name, &scid, local, peer, config)?;

        self.register(conn)
    }
}