zerortt_mio/
group.rs

use std::{
    borrow::Cow,
    collections::HashMap,
    io::{Error, ErrorKind, Result},
    net::{SocketAddr, ToSocketAddrs},
    task::Poll,
    time::Instant,
};

use crossbeam_utils::sync::Parker;
use mio::{Events, Interest, Waker, net::UdpSocket};
use parking_lot::Mutex;

use zerortt_api::{
    Acceptor, Event, EventKind, QuicBind, QuicClient, QuicPoll, QuicServerTransport, QuicTransport,
    StreamKind, Token, WouldBlock,
    quiche::{self, RecvInfo},
};

use crate::{
    buf::QuicBuf,
    udp::{QuicSocket, QuicSocketError},
};

25struct PollState {
26    /// server-side connection acceptor.
27    acceptor: Option<Acceptor>,
28    /// mio poller.
29    poll: mio::Poll,
30    /// `UDP` sockets bound to this group.
31    sockets: Vec<QuicSocket>,
32}
33
34/// Facade to access `QUIC` group.
35pub struct Group {
36    /// A waker of `mio::Poll`.
37    waker: mio::Waker,
38    /// quic group.
39    group: zerortt_poll::Group,
40    /// local bound addresses
41    laddrs: HashMap<SocketAddr, usize>,
42    /// poll state.
43    state: Mutex<PollState>,
44}
45
46impl Group {
47    fn mio_poll_once(&self, poll_state: &mut PollState, deadline: Option<Instant>) -> Result<()> {
48        let timeout = if let Some(next_release_time) = deadline {
49            next_release_time.checked_duration_since(Instant::now())
50        } else {
51            None
52        };
53
54        let mut events = Events::with_capacity(1024);
55
56        log::trace!("mio poll: timeout({:?})", timeout);
57
58        poll_state
59            .poll
60            .poll(&mut events, timeout)
61            .inspect_err(|err| log::error!("mio poll error: {}", err))?;
62
63        for event in events.iter() {
64            log::trace!("readiness, event={:?}", event);
65
66            let token = event.token();
67
68            // Event to wakeup this `Poll`, skip it!!
69            if token.0 == poll_state.sockets.len() {
70                continue;
71            }
72
73            if event.is_readable() {
74                self.on_udp_recv(poll_state, token)?;
75            }
76
77            if event.is_writable() {
78                self.on_udp_send(poll_state, token)?;
79            }
80        }
81
82        Ok(())
83    }
84
85    fn on_quic_send(&self, poll_state: &mut PollState, token: Token) -> Result<()> {
86        let mut buf = QuicBuf::new();
87
88        let Poll::Ready(Ok((send_size, send_info))) =
89            self.group.send(token, buf.writable_buf()).would_block()
90        else {
91            // skip all other errors.
92            return Ok(());
93        };
94
95        assert!(send_size > 0);
96
97        buf.writable_consume(send_size);
98
99        let index = self
100            .laddrs
101            .get(&send_info.from)
102            .cloned()
103            .expect("Quic socket");
104
105        let quic_socket = poll_state.sockets.get_mut(index).expect("Quic socket");
106
107        let len = match quic_socket.send_to(buf, send_info.to) {
108            Ok(len) => len,
109            Err(QuicSocketError::IsFull(_)) => {
110                log::warn!("udp send queue is full, socket=Token({})", index);
111                return Ok(());
112            }
113            Err(err) => return Err(err.into()),
114        };
115
116        log::trace!("quic socket sending fifo, len={}", len);
117
118        Ok(())
119    }
120
121    #[inline]
122    fn on_quic_recv(&self, _: &mut PollState, _: Token) -> Result<()> {
123        Ok(())
124    }
125
126    fn on_udp_recv(&self, poll_state: &mut PollState, token: mio::Token) -> Result<()> {
127        let quic_socket = poll_state.sockets.get_mut(token.0).expect("Quic socket");
128
129        let parker = Parker::new();
130
131        loop {
132            let mut buf = QuicBuf::new();
133
134            let Poll::Ready(from) = quic_socket
135                .recv_from(&mut buf)
136                .map_err(|err| Error::from(err))
137                .would_block()?
138            else {
139                return Ok(());
140            };
141
142            let read_size = buf.readable();
143
144            if let Some(acceptor) = &mut poll_state.acceptor {
145                // for server-side dispatching.
146                loop {
147                    match self.group.recv_with_acceptor(
148                        acceptor,
149                        buf.writable_buf(),
150                        read_size,
151                        RecvInfo {
152                            from,
153                            to: quic_socket.local_addr(),
154                        },
155                        Some(parker.unparker()),
156                    ) {
157                        Ok((send_size, send_info)) => {
158                            if send_size == 0 {
159                                // handle next udp packet.
160                                break;
161                            }
162
163                            buf.writable_consume(send_size);
164
165                            match quic_socket.send_to(buf, send_info.to) {
166                                Ok(_) => {}
167                                Err(QuicSocketError::IsFull(_)) => {
168                                    log::warn!(
169                                        "`QuicSocket` sending queue is full, socket={}",
170                                        token.0
171                                    );
172                                }
173                                Err(err) => return Err(err.into()),
174                            }
175                        }
176                        Err(zerortt_api::Error::Busy) | Err(zerortt_api::Error::Retry) => {
177                            parker.park();
178                            // try agian.
179                            continue;
180                        }
181                        Err(_) => {}
182                    }
183
184                    break;
185                }
186            } else {
187                let header =
188                    quiche::Header::from_slice(buf.readable_buf_mut(), quiche::MAX_CONN_ID_LEN)
189                        .map_err(zerortt_api::Error::Quiche)?;
190
191                // for client-side dispatching.
192                loop {
193                    match self.group.recv_with_connection_id(
194                        &header.dcid,
195                        buf.readable_buf_mut(),
196                        RecvInfo {
197                            from,
198                            to: quic_socket.local_addr(),
199                        },
200                        Some(parker.unparker()),
201                    ) {
202                        Ok(_) => {}
203                        // Current connection is busy.
204                        Err(zerortt_api::Error::Busy) | Err(zerortt_api::Error::Retry) => {
205                            parker.park();
206                            // try agian.
207                            continue;
208                        }
209                        Err(_) => {}
210                    }
211
212                    break;
213                }
214            }
215        }
216    }
217
218    fn on_udp_send(&self, poll_state: &mut PollState, token: mio::Token) -> Result<()> {
219        let socket = poll_state.sockets.get_mut(token.0).expect("Quic socket");
220
221        // try flush pending packets.
222        _ = socket
223            .flush()
224            .map_err(|err| Error::from(err))
225            .would_block()?;
226
227        Ok(())
228    }
229}
230
231impl QuicPoll for Group {
232    type Error = std::io::Error;
233    /// Returns number of connections in the group.
234    #[inline]
235    fn len(&self) -> usize {
236        self.group.len()
237    }
238
239    /// Wrap and register a new `quiche::Connection`.
240    #[inline]
241    fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
242        let token = self.group.register(wrapped);
243
244        self.waker.wake()?;
245
246        Ok(token?)
247    }
248
249    /// Unwrap a bound `quiche::Connection`
250    #[inline]
251    fn deregister(&self, token: Token) -> Result<quiche::Connection> {
252        let conn = self.group.deregister(token);
253
254        self.waker.wake()?;
255
256        Ok(conn?)
257    }
258
259    /// Close one wrapped `quiche::Connection`
260    #[inline]
261    fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
262        let r = self.group.close(token, app, err, reason);
263
264        self.waker.wake()?;
265
266        Ok(r?)
267    }
268
269    /// Open a new outbound stream.
270    fn stream_open(
271        &self,
272        token: Token,
273        kind: StreamKind,
274        non_blocking: bool,
275    ) -> Result<Option<u64>> {
276        let r = self.group.stream_open(token, kind, non_blocking);
277
278        self.waker.wake()?;
279
280        Ok(r?)
281    }
282
283    /// Shutdown a stream.
284    #[inline]
285    fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
286        let r = self.group.stream_shutdown(token, stream_id, err);
287
288        self.waker.wake()?;
289
290        Ok(r?)
291    }
292
293    /// Writes data to a stream.
294    #[inline]
295    fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
296        let send_size = self.group.stream_send(token, stream_id, buf, fin);
297
298        self.waker.wake()?;
299
300        Ok(send_size?)
301    }
302
303    /// Reads contiguous data from a stream into the provided slice.
304    #[inline]
305    fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
306        let r = self.group.stream_recv(token, stream_id, buf);
307
308        self.waker.wake()?;
309
310        Ok(r?)
311    }
312
313    /// Waits for readiness events.
314    fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
315        let mut poll_state = self.state.lock();
316
317        loop {
318            let next_release_time = self.group.poll(events)?;
319
320            // filter events: `Send` and `Recv`.
321            for event in events.drain(..).collect::<Vec<_>>() {
322                match event.kind {
323                    EventKind::Send => {
324                        self.on_quic_send(&mut poll_state, event.token)?;
325                    }
326                    EventKind::Recv => {
327                        self.on_quic_recv(&mut poll_state, event.token)?;
328                    }
329                    _ => events.push(event),
330                }
331            }
332
333            // Readiness `events` is not empty, returns immediately.
334            if !events.is_empty() {
335                return Ok(None);
336            }
337
338            self.mio_poll_once(&mut poll_state, next_release_time)?;
339        }
340    }
341}
342
343impl QuicClient for Group {
344    fn connect(
345        &self,
346        server_name: Option<&str>,
347        local: SocketAddr,
348        peer: SocketAddr,
349        config: &mut quiche::Config,
350    ) -> std::result::Result<Token, Self::Error> {
351        assert!(self.laddrs.contains_key(&local), "invalid local address.");
352
353        let token = self.group.connect(server_name, local, peer, config);
354
355        self.waker.wake()?;
356
357        Ok(token?)
358    }
359}
360
361impl QuicBind for Group {
362    /// Returns local bound addresses.
363    fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
364        self.laddrs.keys()
365    }
366
367    /// Create a new `Group` and bind to `laddrs`.
368    fn bind<S>(laddrs: S, acceptor: Option<Acceptor>) -> Result<Self>
369    where
370        S: ToSocketAddrs,
371    {
372        let poll = mio::Poll::new()?;
373        let group = zerortt_poll::Group::new();
374
375        let mut sockets = vec![];
376        let mut addrs = HashMap::new();
377
378        for laddr in laddrs.to_socket_addrs()? {
379            let mut socket = UdpSocket::bind(laddr)?;
380            addrs.insert(socket.local_addr()?, sockets.len());
381
382            poll.registry().register(
383                &mut socket,
384                mio::Token(sockets.len()),
385                Interest::READABLE | Interest::WRITABLE,
386            )?;
387
388            sockets.push(QuicSocket::new(socket, 1024)?);
389        }
390
391        let waker = Waker::new(poll.registry(), mio::Token(sockets.len()))?;
392
393        Ok(Group {
394            waker,
395            group,
396            laddrs: addrs,
397            state: Mutex::new(PollState {
398                acceptor,
399                poll,
400                sockets,
401            }),
402        })
403    }
404}