zerortt_poll/
group.rs

1use std::{
2    borrow::Cow,
3    cell::RefCell,
4    collections::HashMap,
5    ops::DerefMut,
6    time::{Duration, Instant},
7};
8
9use crossbeam_utils::sync::Unparker;
10use parking_lot::{Mutex, RwLock};
11use zerortt_api::{
12    Acceptor, Error, Event, QuicClient, QuicPoll, QuicServerTransport, QuicTransport, Result,
13    StreamKind, Token,
14    quiche::{self, ConnectionId, RecvInfo, SendInfo},
15    random_conn_id,
16};
17
18#[cfg(feature = "server")]
19use zerortt_api::Handshake;
20
21use crate::{
22    conn::{LocKind, LockContext, QuicConn},
23    readiness::Readiness,
24    utils::release_time,
25};
26
/// Threshold passed to `release_time` (in `send`) and to `Readiness::poll` (in `poll`).
///
/// NOTE(review): presumably the pacing slack under which a scheduled packet release is treated
/// as already due — confirm against `utils::release_time`.
const DEFAULT_RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);
28
/// Try-locks the connection registered under `$token` for `$kind` and evaluates to the guard.
///
/// The group-wide `state` mutex is held only for the lookup and the `try_lock` call, and is
/// released before the guard is handed back. The guard's release hook routes to
/// `Group::unlock`.
///
/// Expands to an expression containing `?`, so it can only be used inside a function that
/// returns a `Result` whose error type accepts `zerortt_api::Error`. It fails with
/// `Error::NotFound` for an unknown token, or with the error reported by `try_lock`.
macro_rules! lock {
    ($self:ident, $token: ident, $kind: expr) => {{
        let state = $self.state.lock();

        let acquired = state
            .conns
            .get(&$token)
            .ok_or(zerortt_api::Error::NotFound)?
            .borrow_mut()
            .try_lock($kind, |ctx| $self.unlock(ctx));

        // Release the group mutex before the caller starts working on the connection.
        drop(state);

        acquired?
    }};
}
45
46#[derive(Default)]
47struct State {
48    token_next: u32,
49    conns: HashMap<Token, RefCell<QuicConn>>,
50    readiness: RefCell<Readiness>,
51    unparkers: HashMap<Token, Unparker>,
52}
53
54/// A group of `quiche:Connection`s.
55pub struct Group {
56    state: Mutex<State>,
57    scids: RwLock<HashMap<ConnectionId<'static>, Token>>,
58}
59
60impl Default for Group {
61    fn default() -> Self {
62        Self {
63            state: Default::default(),
64            scids: Default::default(),
65        }
66    }
67}
68
69impl Group {
70    /// Create a group with default parameters.
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    fn unlock(&self, ctx: LockContext) {
76        let mut state = self.state.lock();
77
78        if let Some(unparker) = state.unparkers.remove(&ctx.token) {
79            unparker.unpark();
80        }
81
82        if state
83            .conns
84            .get(&ctx.token)
85            .expect("Unlock.")
86            .borrow_mut()
87            .unlock(
88                ctx.lock_count,
89                ctx.send_done,
90                state.readiness.borrow_mut().deref_mut(),
91            )
92        {
93            log::trace!(
94                "automatic deregister closed connection, token={:?}",
95                ctx.token
96            );
97
98            drop(state);
99            _ = self.deregister(ctx.token);
100        }
101    }
102
103    /// Process a packet recv.
104    /// Processes QUIC packets received from the peer.
105    pub fn recv_with_connection_id(
106        &self,
107        scid: &ConnectionId<'_>,
108        buf: &mut [u8],
109        info: RecvInfo,
110        unparker: Option<&Unparker>,
111    ) -> Result<(Token, usize)> {
112        let token = self
113            .scids
114            .read()
115            .get(&scid)
116            .ok_or_else(|| Error::NotFound)?
117            .clone();
118
119        let mut state = self.state.lock();
120
121        let Ok(mut conn) = state
122            .conns
123            .get(&token)
124            .ok_or_else(|| Error::NotFound)?
125            .borrow_mut()
126            .try_lock(LocKind::Recv, |ctx| self.unlock(ctx))
127        else {
128            // insert unparker.
129            if let Some(unparker) = unparker {
130                state.unparkers.insert(token, unparker.clone());
131            }
132
133            return Err(Error::Busy);
134        };
135
136        drop(state);
137
138        match conn.recv(buf, info) {
139            Ok(recv_size) => {
140                log::trace!(
141                    "Connection recv, scid={:?}, len={}",
142                    conn.source_id(),
143                    recv_size
144                );
145
146                Ok((token, recv_size))
147            }
148            Err(err) => {
149                log::error!("Connection recv, scid={:?}, err={}", conn.source_id(), err);
150                Err(Error::Quiche(err))
151            }
152        }
153    }
154}
155
156impl QuicPoll for Group {
157    type Error = zerortt_api::Error;
158    /// Wrap and register a new `quiche::Connection`.
159    #[inline]
160    fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
161        let mut state = self.state.lock();
162
163        loop {
164            let token = Token(state.token_next);
165
166            (state.token_next, _) = state.token_next.overflowing_add(1);
167
168            if state.conns.contains_key(&token) {
169                continue;
170            }
171
172            assert!(
173                self.scids
174                    .write()
175                    .insert(wrapped.source_id().into_owned(), token)
176                    .is_none()
177            );
178
179            log::trace!(
180                "register quic connection, token={:?}, trace_id={}",
181                token,
182                wrapped.trace_id()
183            );
184
185            let conn = RefCell::new(QuicConn::new(token, wrapped));
186
187            let guard = conn.borrow_mut().try_lock(LocKind::ReadLock, |context| {
188                conn.borrow_mut().unlock(
189                    context.lock_count,
190                    false,
191                    state.readiness.borrow_mut().deref_mut(),
192                );
193            })?;
194
195            drop(guard);
196
197            state.conns.insert(token, conn);
198
199            return Ok(token);
200        }
201    }
202
203    /// Unwrap a bound `quiche::Connection`
204    #[inline]
205    fn deregister(&self, token: Token) -> Result<quiche::Connection> {
206        let mut state = self.state.lock();
207
208        let conn: quiche::Connection = state
209            .conns
210            .remove(&token)
211            .ok_or_else(|| Error::NotFound)?
212            .into_inner()
213            .into();
214
215        drop(state);
216
217        assert_eq!(
218            self.scids.write().remove(&conn.source_id().into_owned()),
219            Some(token)
220        );
221
222        Ok(conn)
223    }
224
225    /// Returns number of connections in the group.
226    #[inline]
227    fn len(&self) -> usize {
228        self.state.lock().conns.len()
229    }
230
231    /// Close one connection.
232    #[inline]
233    fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
234        let state = self.state.lock();
235        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
236
237        conn.borrow_mut()
238            .close(app, err, reason, state.readiness.borrow_mut().deref_mut())
239    }
240
241    /// Open a outbound stream.
242    #[inline]
243    fn stream_open(
244        &self,
245        token: Token,
246        kind: StreamKind,
247        non_blocking: bool,
248    ) -> Result<Option<u64>> {
249        let state = self.state.lock();
250        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
251
252        conn.borrow_mut()
253            .stream_open(kind, non_blocking, state.readiness.borrow_mut().deref_mut())
254    }
255
256    /// Shutdown a stream.
257    #[inline]
258    fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
259        let state = self.state.lock();
260        let conn = state.conns.get(&token).ok_or_else(|| Error::NotFound)?;
261
262        conn.borrow_mut()
263            .stream_close(stream_id, err, state.readiness.borrow_mut().deref_mut())
264    }
265
266    /// Writes data to a stream.
267    #[inline]
268    fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
269        let mut conn = lock!(
270            self,
271            token,
272            LocKind::StreamSend {
273                id: stream_id,
274                len: buf.len()
275            }
276        );
277
278        match conn.stream_send(stream_id, buf, fin) {
279            Ok(send_size) => {
280                log::trace!(
281                    "stream send, scid={:?}, stream_id={}, len={}, fin={}",
282                    conn.source_id(),
283                    stream_id,
284                    send_size,
285                    fin
286                );
287                return Ok(send_size);
288            }
289            Err(quiche::Error::Done) => {
290                log::trace!(
291                    "stream send, scid={:?}, stream_id={}, fin={}, Done",
292                    conn.source_id(),
293                    stream_id,
294                    fin
295                );
296                return Err(Error::Retry);
297            }
298            Err(err) => {
299                log::error!(
300                    "stream send, scid={:?}, stream_id={}, fin={}, err={}",
301                    conn.source_id(),
302                    stream_id,
303                    fin,
304                    err
305                );
306
307                return Err(Error::Quiche(err));
308            }
309        }
310    }
311
312    /// Reads contiguous data from a stream into the provided slice.
313    #[inline]
314    fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
315        let mut conn = lock!(self, token, LocKind::StreamRecv(stream_id));
316
317        match conn.stream_recv(stream_id, buf) {
318            Ok((recv_size, fin)) => {
319                log::trace!(
320                    "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
321                    conn.source_id(),
322                    stream_id,
323                    recv_size,
324                    fin,
325                    conn.is_server(),
326                );
327                return Ok((recv_size, fin));
328            }
329            Err(quiche::Error::Done) => {
330                if conn.stream_finished(stream_id) {
331                    log::trace!(
332                        "stream recv, scid={:?}, stream_id={}, len={}, fin={}, is_server={}",
333                        conn.source_id(),
334                        stream_id,
335                        0,
336                        true,
337                        conn.is_server(),
338                    );
339
340                    return Ok((0, true));
341                }
342
343                log::trace!(
344                    "stream recv, scid={:?}, stream_id={}, is_server={}, Done",
345                    conn.source_id(),
346                    stream_id,
347                    conn.is_server(),
348                );
349                return Err(Error::Retry);
350            }
351            Err(err) => {
352                log::error!(
353                    "stream recv, scid={:?}, stream_id={}, is_server={}, err={}",
354                    conn.source_id(),
355                    stream_id,
356                    conn.is_server(),
357                    err
358                );
359
360                return Err(Error::Quiche(err));
361            }
362        }
363    }
364
365    /// Waits for readiness events without blocking current thread and returns possible retry time duration.
366    #[inline]
367    fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
368        let state = self.state.lock();
369
370        Ok(state
371            .readiness
372            .borrow_mut()
373            .poll(events, DEFAULT_RELEASE_TIMER_THRESHOLD))
374    }
375}
376
377impl QuicTransport for Group {
378    type Error = zerortt_api::Error;
379    /// Processes QUIC packets received from the peer.
380    #[inline]
381    fn recv(&self, buf: &mut [u8], info: RecvInfo) -> Result<usize> {
382        let header =
383            quiche::Header::from_slice(buf, quiche::MAX_CONN_ID_LEN).map_err(Error::Quiche)?;
384
385        self.recv_with_connection_id(&header.dcid, buf, info, None)
386            .map(|(_, recv_size)| recv_size)
387    }
388
389    /// Writes a single QUIC packet to be sent to the peer.
390    #[inline]
391    fn send(&self, token: Token, buf: &mut [u8]) -> Result<(usize, SendInfo)> {
392        let mut conn = lock!(self, token, LocKind::Recv);
393
394        if let Some(release_time) =
395            release_time(&conn, Instant::now(), DEFAULT_RELEASE_TIMER_THRESHOLD)
396        {
397            log::trace!(
398                "connection send, scid={:?}, next_release_time={:?}",
399                conn.trace_id(),
400                release_time,
401            );
402            return Err(Error::Retry);
403        }
404
405        // TODO: prevent frequent calls to on_timeout
406        conn.on_timeout();
407
408        match conn.send(buf) {
409            Ok((send_size, send_info)) => {
410                log::trace!(
411                    "connection send, scid={:?}, send_size={}, send_info={:?}",
412                    conn.trace_id(),
413                    send_size,
414                    send_info
415                );
416                return Ok((send_size, send_info));
417            }
418            Err(quiche::Error::Done) => {
419                log::trace!("connection send, scid={:?}, done", conn.trace_id());
420                conn.send_done();
421                return Err(Error::Retry);
422            }
423            Err(err) => {
424                log::error!("connection send, scid={:?}, err={}", conn.trace_id(), err);
425                return Err(Error::Quiche(err));
426            }
427        }
428    }
429}
430
#[cfg(feature = "server")]
impl QuicServerTransport for Group {
    /// Handles one received datagram on the server side.
    ///
    /// `buf[..recv_size]` holds the datagram, and the whole of `buf` is then reused as the
    /// output buffer for the reply. The packet is handled in one of three ways:
    /// - Known connection id: the packet is fed to that connection and one outgoing packet
    ///   is produced.
    /// - Unknown id, acceptor answers directly: the acceptor's handshake reply is returned.
    /// - Unknown id, acceptor yields a new connection: the connection is registered, fed the
    ///   same packet, and asked for one outgoing packet.
    ///
    /// A `(0, info)` result means there is nothing to send right now.
    ///
    /// # Errors
    /// - `Error::Busy` when the target connection is locked elsewhere. This includes a
    ///   freshly accepted connection that another thread reached first. `unparker`, if
    ///   given, is woken when the lock is released.
    /// - Other errors from header parsing, the acceptor, `register`, or quiche are
    ///   propagated.
    fn recv_with_acceptor(
        &self,
        acceptor: &mut Acceptor,
        buf: &mut [u8],
        recv_size: usize,
        recv_info: RecvInfo,
        unparker: Option<&Unparker>,
    ) -> Result<(usize, SendInfo)> {
        // Reply addressing: back to where the datagram came from.
        let reply_info = move || SendInfo {
            at: Instant::now(),
            from: recv_info.to,
            to: recv_info.from,
        };

        let header = quiche::Header::from_slice(&mut buf[..recv_size], quiche::MAX_CONN_ID_LEN)
            .map_err(Error::Quiche)?;

        let token = match self.recv_with_connection_id(
            &header.dcid,
            &mut buf[..recv_size],
            recv_info,
            unparker,
        ) {
            Ok((token, _)) => token,
            Err(Error::NotFound) => {
                match acceptor.handshake(&header, buf, recv_size, recv_info)? {
                    Handshake::Handshake(send_size) => return Ok((send_size, reply_info())),
                    Handshake::Accept(conn) => {
                        let token = self.register(conn)?;

                        // `register` makes the connection reachable through `scids` at once,
                        // so another thread may already hold its lock. Report `Busy` as the
                        // established-connection path does, instead of treating it as
                        // impossible.
                        self.recv_with_connection_id(
                            &header.dcid,
                            &mut buf[..recv_size],
                            recv_info,
                            unparker,
                        )?;

                        token
                    }
                }
            }
            Err(err) => return Err(err),
        };

        match self.send(token, buf) {
            Err(Error::Busy) | Err(Error::Retry) => Ok((0, reply_info())),
            r => r,
        }
    }
}
501
502#[cfg(feature = "server")]
503impl QuicClient for Group {
504    fn connect(
505        &self,
506        server_name: Option<&str>,
507        local: std::net::SocketAddr,
508        peer: std::net::SocketAddr,
509        config: &mut quiche::Config,
510    ) -> Result<Token> {
511        let conn = quiche::connect(server_name, &random_conn_id(), local, peer, config)?;
512
513        let token = self.register(conn)?;
514
515        Ok(token)
516    }
517}