zerortt_poll/
group.rs

1use std::{
2    borrow::Cow,
3    cell::RefCell,
4    collections::HashMap,
5    ops::DerefMut,
6    time::{Duration, Instant},
7};
8
9use crossbeam_utils::sync::Unparker;
10use parking_lot::{Mutex, RwLock};
11use zerortt_api::{
12    Acceptor, Error, Event, QuicClient, QuicPoll, QuicServerTransport, QuicTransport, Result,
13    StreamKind, Token,
14    quiche::{self, ConnectionId, RecvInfo, SendInfo},
15    random_conn_id,
16};
17
18#[cfg(feature = "server")]
19use zerortt_api::Handshake;
20
21use crate::{
22    conn::{LocKind, LockContext, QuicConn},
23    readiness::Readiness,
24    utils::release_time,
25};
26
/// Release-timer threshold (250µs) that this module passes to `release_time` in `send` and to
/// the readiness tracker in `poll`.
///
/// Declared as a `const` rather than a `static`: the value is immutable, `Copy`, and never needs
/// a fixed address.
const DEFAULT_RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);
28
/// Takes the per-connection lock of kind `$kind` on the connection registered under `$token`.
///
/// Expands to an expression that evaluates to the guard returned by `QuicConn::try_lock`; the
/// guard calls back into `$self.unlock(..)` through the closure passed to `try_lock`.
///
/// The expansion uses `?`, so it must sit inside a function returning `zerortt_api::Result`:
/// an unknown token yields `Error::NotFound`, and a `try_lock` failure is propagated.
macro_rules! lock {
    ($self:ident, $token:ident, $kind:expr) => {{
        let registry = $self.state.lock();

        let cell = registry
            .conns
            .get(&$token)
            .ok_or(zerortt_api::Error::NotFound)?;

        let guard = cell
            .borrow_mut()
            .try_lock($kind, |ctx| $self.unlock(ctx))?;

        // The group mutex is released before the guard is handed to the caller.
        drop(registry);

        guard
    }};
}
45
/// Mutable bookkeeping of a `Group`, kept behind `Group::state`.
#[derive(Default)]
struct State {
    /// Next candidate value for a newly allocated `Token`; `register` advances it with
    /// wrap-around and skips values that are still in use.
    token_next: u32,
    /// Registered connections, keyed by their token.
    conns: HashMap<Token, RefCell<QuicConn>>,
    /// Readiness tracker that is handed to connection operations and queried by `poll`.
    readiness: RefCell<Readiness>,
    /// Unparkers stored by `recv_with_connection_id` when a connection was busy; `unlock`
    /// removes and unparks the entry for the connection being unlocked.
    unparkers: HashMap<Token, Unparker>,
}
53
54/// A group of `quiche:Connection`s.
55pub struct Group {
56    state: Mutex<State>,
57    scids: RwLock<HashMap<ConnectionId<'static>, Token>>,
58}
59
60impl Default for Group {
61    fn default() -> Self {
62        Self {
63            state: Default::default(),
64            scids: Default::default(),
65        }
66    }
67}
68
69impl Group {
70    /// Create a group with default parameters.
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    fn unlock(&self, ctx: LockContext) {
76        let mut state = self.state.lock();
77
78        if let Some(unparker) = state.unparkers.remove(&ctx.token) {
79            unparker.unpark();
80        }
81
82        if state
83            .conns
84            .get(&ctx.token)
85            .expect("Unlock.")
86            .borrow_mut()
87            .unlock(
88                ctx.lock_count,
89                ctx.send_done,
90                state.readiness.borrow_mut().deref_mut(),
91            )
92        {
93            log::trace!(
94                "automatic deregister closed connection, token={:?}",
95                ctx.token
96            );
97
98            drop(state);
99            _ = self.deregister(ctx.token);
100        }
101    }
102
103    /// Process a packet recv.
104    /// Processes QUIC packets received from the peer.
105    pub fn recv_with_connection_id(
106        &self,
107        scid: &ConnectionId<'_>,
108        buf: &mut [u8],
109        info: RecvInfo,
110        unparker: Option<&Unparker>,
111    ) -> Result<(Token, usize)> {
112        let token = self
113            .scids
114            .read()
115            .get(&scid)
116            .ok_or_else(|| Error::NotFound)?
117            .clone();
118
119        let mut state = self.state.lock();
120
121        let Ok(mut conn) = state
122            .conns
123            .get(&token)
124            .ok_or_else(|| Error::NotFound)?
125            .borrow_mut()
126            .try_lock(LocKind::Recv, |ctx| self.unlock(ctx))
127        else {
128            // insert unparker.
129            if let Some(unparker) = unparker {
130                state.unparkers.insert(token, unparker.clone());
131            }
132
133            return Err(Error::Busy);
134        };
135
136        drop(state);
137
138        match conn.recv(buf, info) {
139            Ok(recv_size) => {
140                log::trace!(
141                    "Connection recv, scid={:?}, len={}",
142                    conn.source_id(),
143                    recv_size
144                );
145
146                Ok((token, recv_size))
147            }
148            Err(err) => {
149                log::error!("Connection recv, scid={:?}, err={}", conn.source_id(), err);
150                Err(Error::Quiche(err))
151            }
152        }
153    }
154}
155
156impl QuicPoll for Group {
157    type Error = zerortt_api::Error;
158    /// Wrap and register a new `quiche::Connection`.
159    #[inline]
160    fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
161        let mut state = self.state.lock();
162
163        loop {
164            let token = Token(state.token_next);
165
166            (state.token_next, _) = state.token_next.overflowing_add(1);
167
168            if state.conns.contains_key(&token) {
169                continue;
170            }
171
172            assert!(
173                self.scids
174                    .write()
175                    .insert(wrapped.source_id().into_owned(), token)
176                    .is_none()
177            );
178
179            log::trace!(
180                "register quic connection, token={:?}, trace_id={}",
181                token,
182                wrapped.trace_id()
183            );
184
185            let conn = RefCell::new(QuicConn::new(token, wrapped));
186
187            let guard = conn.borrow_mut().try_lock(LocKind::ReadLock, |context| {
188                conn.borrow_mut().unlock(
189                    context.lock_count,
190                    false,
191                    state.readiness.borrow_mut().deref_mut(),
192                );
193            })?;
194
195            drop(guard);
196
197            state.conns.insert(token, conn);
198
199            return Ok(token);
200        }
201    }
202
203    /// Unwrap a bound `quiche::Connection`
204    #[inline]
205    fn deregister(&self, token: Token) -> Result<quiche::Connection> {
206        let mut state = self.state.lock();
207
208        let conn: quiche::Connection = state
209            .conns
210            .remove(&token)
211            .ok_or_else(|| Error::NotFound)?
212            .into_inner()
213            .into();
214
215        drop(state);
216
217        assert_eq!(
218            self.scids.write().remove(&conn.source_id().into_owned()),
219            Some(token)
220        );
221
222        Ok(conn)
223    }
224
225    /// Returns number of connections in the group.
226    #[inline]
227    fn len(&self) -> usize {
228        self.state.lock().conns.len()
229    }
230
231    /// Close one connection.
232    #[inline]
233    fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
234        let state = self.state.lock();
235        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
236
237        conn.borrow_mut()
238            .close(app, err, reason, state.readiness.borrow_mut().deref_mut())
239    }
240
241    /// Open a outbound stream.
242    #[inline]
243    fn stream_open(
244        &self,
245        token: Token,
246        kind: StreamKind,
247        max_streams_as_error: bool,
248    ) -> Result<u64> {
249        let state = self.state.lock();
250        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
251
252        conn.borrow_mut().stream_open(
253            kind,
254            max_streams_as_error,
255            state.readiness.borrow_mut().deref_mut(),
256        )
257    }
258
259    /// Shutdown a stream.
260    #[inline]
261    fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
262        let state = self.state.lock();
263        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
264
265        conn.borrow_mut()
266            .stream_close(stream_id, err, state.readiness.borrow_mut().deref_mut())
267    }
268
269    /// Writes data to a stream.
270    #[inline]
271    fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
272        let mut conn = lock!(
273            self,
274            token,
275            LocKind::StreamSend {
276                id: stream_id,
277                len: buf.len()
278            }
279        );
280
281        match conn.stream_send(stream_id, buf, fin) {
282            Ok(send_size) => {
283                log::trace!(
284                    "stream send, scid={:?}, stream_id={}, len={}, fin={}",
285                    conn.source_id(),
286                    stream_id,
287                    send_size,
288                    fin
289                );
290                return Ok(send_size);
291            }
292            Err(quiche::Error::Done) => {
293                log::trace!(
294                    "stream send, scid={:?}, stream_id={}, fin={}, Done",
295                    conn.source_id(),
296                    stream_id,
297                    fin
298                );
299                return Err(Error::Retry);
300            }
301            Err(err) => {
302                log::error!(
303                    "stream send, scid={:?}, stream_id={}, fin={}, err={}",
304                    conn.source_id(),
305                    stream_id,
306                    fin,
307                    err
308                );
309
310                return Err(Error::Quiche(err));
311            }
312        }
313    }
314
315    /// Reads contiguous data from a stream into the provided slice.
316    #[inline]
317    fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
318        let mut conn = lock!(self, token, LocKind::StreamRecv(stream_id));
319
320        match conn.stream_recv(stream_id, buf) {
321            Ok((recv_size, fin)) => {
322                log::trace!(
323                    "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
324                    conn.source_id(),
325                    stream_id,
326                    recv_size,
327                    fin,
328                    conn.is_server(),
329                );
330                return Ok((recv_size, fin));
331            }
332            Err(quiche::Error::Done) => {
333                if conn.stream_finished(stream_id) {
334                    log::trace!(
335                        "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
336                        conn.source_id(),
337                        stream_id,
338                        0,
339                        true,
340                        conn.is_server(),
341                    );
342
343                    return Ok((0, true));
344                }
345
346                log::trace!(
347                    "stream recv, scid={:?}, stream_id={}, is_server={}, Done",
348                    conn.source_id(),
349                    stream_id,
350                    conn.is_server(),
351                );
352                return Err(Error::Retry);
353            }
354            Err(err) => {
355                log::error!(
356                    "stream recv, scid={:?}, stream_id={}, is_server={}, err={}",
357                    conn.source_id(),
358                    stream_id,
359                    conn.is_server(),
360                    err
361                );
362
363                return Err(Error::Quiche(err));
364            }
365        }
366    }
367
368    /// Waits for readiness events without blocking current thread and returns possible retry time duration.
369    #[inline]
370    fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
371        let state = self.state.lock();
372
373        Ok(state
374            .readiness
375            .borrow_mut()
376            .poll(events, DEFAULT_RELEASE_TIMER_THRESHOLD))
377    }
378}
379
380impl QuicTransport for Group {
381    type Error = zerortt_api::Error;
382    /// Processes QUIC packets received from the peer.
383    #[inline]
384    fn recv(&self, buf: &mut [u8], info: RecvInfo) -> Result<usize> {
385        let header =
386            quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(Error::Quiche)?;
387
388        self.recv_with_connection_id(&header.dcid, buf, info, None)
389            .map(|(_, recv_size)| recv_size)
390    }
391
392    /// Writes a single QUIC packet to be sent to the peer.
393    #[inline]
394    fn send(&self, token: Token, buf: &mut [u8]) -> Result<(usize, SendInfo)> {
395        let mut conn = lock!(self, token, LocKind::Recv);
396
397        if let Some(release_time) =
398            release_time(&conn, Instant::now(), DEFAULT_RELEASE_TIMER_THRESHOLD)
399        {
400            log::trace!(
401                "connection send, scid={:?}, next_release_time={:?}",
402                conn.trace_id(),
403                release_time,
404            );
405            return Err(Error::Retry);
406        }
407
408        // TODO: prevent frequent calls to on_timeout
409        conn.on_timeout();
410
411        match conn.send(buf) {
412            Ok((send_size, send_info)) => {
413                log::trace!(
414                    "connection send, scid={:?}, send_size={}, send_info={:?}",
415                    conn.trace_id(),
416                    send_size,
417                    send_info
418                );
419                return Ok((send_size, send_info));
420            }
421            Err(quiche::Error::Done) => {
422                log::trace!("connection send, scid={:?}, done", conn.trace_id());
423                conn.send_done();
424                return Err(Error::Retry);
425            }
426            Err(err) => {
427                log::error!("connection send, scid={:?}, err={}", conn.trace_id(), err);
428                return Err(Error::Quiche(err));
429            }
430        }
431    }
432}
433
#[cfg(feature = "server")]
impl QuicServerTransport for Group {
    /// Processes one received datagram on the server side and produces the reply packet.
    ///
    /// The first `recv_size` bytes of `buf` hold the datagram. It is first routed to an
    /// already registered connection; if none matches (`Error::NotFound`), it is handed to
    /// `acceptor`. A connection accepted by the acceptor is registered and fed the same
    /// datagram. In both cases the connection's next outgoing packet is then written into
    /// `buf`.
    ///
    /// Returns the number of bytes to send and where to send them. A length of `0` is returned,
    /// with a `SendInfo` addressed back to the sender, when the connection has nothing to send
    /// right now (`Error::Busy` / `Error::Retry` from `send`).
    ///
    /// # Errors
    ///
    /// `Error::Quiche` if the header cannot be parsed; otherwise errors from routing the
    /// datagram (including `Error::Busy`), from the acceptor, from `register`, or from `send`
    /// other than `Busy`/`Retry`.
    ///
    /// # Panics
    ///
    /// Panics if `recv_size` exceeds `buf.len()`, or if a freshly registered connection reports
    /// `Busy`/`Retry` for its first datagram.
    fn recv_with_acceptor(
        &self,
        acceptor: &mut Acceptor,
        buf: &mut [u8],
        recv_size: usize,
        recv_info: RecvInfo,
        unparker: Option<&Unparker>,
    ) -> Result<(usize, SendInfo)> {
        // Addressing for a reply that goes straight back to the sender of the datagram.
        let reply_info = || SendInfo {
            at: Instant::now(),
            from: recv_info.to,
            to: recv_info.from,
        };

        // Writes the connection's next packet; "nothing to send" becomes an empty reply.
        let flush = |token: Token, out: &mut [u8]| match self.send(token, out) {
            Err(Error::Busy) | Err(Error::Retry) => Ok((0, reply_info())),
            other => other,
        };

        let header = quiche::Header::from_slice(&mut buf[..recv_size], quiche::MAX_CONN_ID_LEN)
            .map_err(Error::Quiche)?;

        let routed =
            self.recv_with_connection_id(&header.dcid, &mut buf[..recv_size], recv_info, unparker);

        match routed {
            Ok((token, _)) => return flush(token, buf),
            // Unknown connection id: fall through to the acceptor.
            Err(Error::NotFound) => {}
            Err(err) => return Err(err),
        }

        match acceptor.handshake(&header, buf, recv_size, recv_info)? {
            Handshake::Handshake(send_size) => Ok((send_size, reply_info())),
            Handshake::Accept(conn) => {
                let token = self.register(conn)?;

                // Newly registered connections should be idle.
                match self.recv_with_connection_id(
                    &header.dcid,
                    &mut buf[..recv_size],
                    recv_info,
                    None,
                ) {
                    Ok(_) => {}
                    Err(Error::Busy) | Err(Error::Retry) => {
                        unreachable!("Newly registered connections should be idle");
                    }
                    Err(err) => return Err(err),
                }

                flush(token, buf)
            }
        }
    }
}
504
// NOTE(review): this client-side impl is gated on the "server" feature, the same gate as
// `QuicServerTransport` above - confirm whether a client feature was intended.
#[cfg(feature = "server")]
impl QuicClient for Group {
    /// Creates a client connection with `quiche::connect`, using a random source connection
    /// id, and registers it with the group.
    ///
    /// `server_name`, `local`, `peer` and `config` are forwarded unchanged to
    /// `quiche::connect`. Returns the token of the registered connection.
    ///
    /// # Errors
    ///
    /// Propagates errors of `quiche::connect` and of `register`.
    fn connect(
        &self,
        server_name: Option<&str>,
        local: std::net::SocketAddr,
        peer: std::net::SocketAddr,
        config: &mut quiche::Config,
    ) -> Result<Token> {
        let scid = random_conn_id();
        let conn = quiche::connect(server_name, &scid, local, peer, config)?;

        self.register(conn)
    }
}