1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 io::{Error, Result},
5 net::{SocketAddr, ToSocketAddrs},
6 task::Poll,
7 time::Instant,
8};
9
10use crossbeam_utils::sync::Parker;
11use mio::{Events, Interest, Waker, net::UdpSocket};
12use parking_lot::Mutex;
13
14use zerortt_api::{
15 Acceptor, Event, EventKind, QuicBind, QuicClient, QuicPoll, QuicServerTransport, QuicTransport,
16 StreamKind, Token, WouldBlock,
17 quiche::{self, RecvInfo},
18};
19
20use crate::{
21 buf::QuicBuf,
22 udp::{QuicSocket, QuicSocketError},
23};
24
/// I/O-side state of a [`Group`]: the mio poller, the bound UDP sockets and
/// the optional acceptor. Kept behind `Group::state`; `QuicPoll::poll` holds
/// that lock for its whole run.
struct PollState {
    /// Optional acceptor; when set, every received datagram is passed to
    /// `recv_with_acceptor`. When `None`, datagrams are routed with
    /// `recv_with_connection_id` using the parsed destination connection ID.
    acceptor: Option<Acceptor>,
    /// mio poller with every socket and the group's waker registered.
    poll: mio::Poll,
    /// Bound UDP sockets; the socket at index `i` is registered under
    /// `mio::Token(i)`.
    sockets: Vec<QuicSocket>,
}
33
/// QUIC connections driven over one or more bound UDP sockets.
///
/// Connection operations are forwarded to the inner `zerortt_poll::Group`
/// and then wake the poller; socket I/O happens inside `QuicPoll::poll`.
pub struct Group {
    /// Interrupts a blocked mio wait; registered under `mio::Token(n)` where
    /// `n` is the number of bound sockets.
    waker: mio::Waker,
    /// Underlying connection group that the `QuicPoll`/`QuicClient`
    /// operations delegate to.
    group: zerortt_poll::Group,
    /// Bound local address -> index of its socket in `PollState::sockets`.
    laddrs: HashMap<SocketAddr, usize>,
    /// I/O state, locked by `QuicPoll::poll` for the duration of the call.
    state: Mutex<PollState>,
}
45
46impl Group {
47 fn mio_poll_once(&self, poll_state: &mut PollState, deadline: Option<Instant>) -> Result<()> {
48 let timeout = if let Some(next_release_time) = deadline {
49 next_release_time.checked_duration_since(Instant::now())
50 } else {
51 None
52 };
53
54 let mut events = Events::with_capacity(1024);
55
56 log::trace!("mio poll: timeout({:?})", timeout);
57
58 poll_state
59 .poll
60 .poll(&mut events, timeout)
61 .inspect_err(|err| log::error!("mio poll error: {}", err))?;
62
63 for event in events.iter() {
64 log::trace!("readiness, event={:?}", event);
65
66 let token = event.token();
67
68 if token.0 == poll_state.sockets.len() {
70 continue;
71 }
72
73 if event.is_readable() {
74 self.on_udp_recv(poll_state, token)?;
75 }
76
77 if event.is_writable() {
78 self.on_udp_send(poll_state, token)?;
79 }
80 }
81
82 Ok(())
83 }
84
85 fn on_quic_send(&self, poll_state: &mut PollState, token: Token) -> Result<()> {
86 let mut buf = QuicBuf::new();
87
88 let Poll::Ready(Ok((send_size, send_info))) =
89 self.group.send(token, buf.writable_buf()).would_block()
90 else {
91 return Ok(());
93 };
94
95 assert!(send_size > 0);
96
97 buf.writable_consume(send_size);
98
99 let index = self
100 .laddrs
101 .get(&send_info.from)
102 .cloned()
103 .expect("Quic socket");
104
105 let quic_socket = poll_state.sockets.get_mut(index).expect("Quic socket");
106
107 let len = match quic_socket.send_to(buf, send_info.to) {
108 Ok(len) => len,
109 Err(QuicSocketError::IsFull(_)) => {
110 log::warn!("udp send queue is full, socket=Token({})", index);
111 return Ok(());
112 }
113 Err(err) => return Err(err.into()),
114 };
115
116 log::trace!("quic socket sending fifo, len={}", len);
117
118 Ok(())
119 }
120
121 #[inline]
122 fn on_quic_recv(&self, _: &mut PollState, _: Token) -> Result<()> {
123 Ok(())
124 }
125
126 fn on_udp_recv(&self, poll_state: &mut PollState, token: mio::Token) -> Result<()> {
127 let quic_socket = poll_state.sockets.get_mut(token.0).expect("Quic socket");
128
129 let parker = Parker::new();
130
131 loop {
132 let mut buf = QuicBuf::new();
133
134 let Poll::Ready(from) = quic_socket
135 .recv_from(&mut buf)
136 .map_err(|err| Error::from(err))
137 .would_block()?
138 else {
139 return Ok(());
140 };
141
142 let read_size = buf.readable();
143
144 if let Some(acceptor) = &mut poll_state.acceptor {
145 loop {
147 match self.group.recv_with_acceptor(
148 acceptor,
149 buf.writable_buf(),
150 read_size,
151 RecvInfo {
152 from,
153 to: quic_socket.local_addr(),
154 },
155 Some(parker.unparker()),
156 ) {
157 Ok((send_size, send_info)) => {
158 if send_size == 0 {
159 break;
161 }
162
163 buf.writable_consume(send_size);
164
165 match quic_socket.send_to(buf, send_info.to) {
166 Ok(_) => {}
167 Err(QuicSocketError::IsFull(_)) => {
168 log::warn!(
169 "`QuicSocket` sending queue is full, socket={}",
170 token.0
171 );
172 }
173 Err(err) => return Err(err.into()),
174 }
175 }
176 Err(zerortt_api::Error::Busy) | Err(zerortt_api::Error::Retry) => {
177 parker.park();
178 continue;
180 }
181 Err(_) => {}
182 }
183
184 break;
185 }
186 } else {
187 let header =
188 quiche::Header::from_slice(buf.readable_buf_mut(), quiche::MAX_CONN_ID_LEN)
189 .map_err(zerortt_api::Error::Quiche)?;
190
191 loop {
193 match self.group.recv_with_connection_id(
194 &header.dcid,
195 buf.readable_buf_mut(),
196 RecvInfo {
197 from,
198 to: quic_socket.local_addr(),
199 },
200 Some(parker.unparker()),
201 ) {
202 Ok(_) => {}
203 Err(zerortt_api::Error::Busy) | Err(zerortt_api::Error::Retry) => {
205 parker.park();
206 continue;
208 }
209 Err(_) => {}
210 }
211
212 break;
213 }
214 }
215 }
216 }
217
218 fn on_udp_send(&self, poll_state: &mut PollState, token: mio::Token) -> Result<()> {
219 let socket = poll_state.sockets.get_mut(token.0).expect("Quic socket");
220
221 _ = socket
223 .flush()
224 .map_err(|err| Error::from(err))
225 .would_block()?;
226
227 Ok(())
228 }
229}
230
231impl QuicPoll for Group {
232 type Error = std::io::Error;
233 #[inline]
235 fn len(&self) -> usize {
236 self.group.len()
237 }
238
239 #[inline]
241 fn register(&self, wrapped: quiche::Connection) -> Result<Token> {
242 let token = self.group.register(wrapped);
243
244 self.waker.wake()?;
245
246 Ok(token?)
247 }
248
249 #[inline]
251 fn deregister(&self, token: Token) -> Result<quiche::Connection> {
252 let conn = self.group.deregister(token);
253
254 self.waker.wake()?;
255
256 Ok(conn?)
257 }
258
259 #[inline]
261 fn close(&self, token: Token, app: bool, err: u64, reason: Cow<'static, [u8]>) -> Result<()> {
262 let r = self.group.close(token, app, err, reason);
263
264 self.waker.wake()?;
265
266 Ok(r?)
267 }
268
269 fn stream_open(
271 &self,
272 token: Token,
273 kind: StreamKind,
274 non_blocking: bool,
275 ) -> Result<Option<u64>> {
276 let r = self.group.stream_open(token, kind, non_blocking);
277
278 self.waker.wake()?;
279
280 Ok(r?)
281 }
282
283 #[inline]
285 fn stream_shutdown(&self, token: Token, stream_id: u64, err: u64) -> Result<()> {
286 let r = self.group.stream_shutdown(token, stream_id, err);
287
288 self.waker.wake()?;
289
290 Ok(r?)
291 }
292
293 #[inline]
295 fn stream_send(&self, token: Token, stream_id: u64, buf: &[u8], fin: bool) -> Result<usize> {
296 let send_size = self.group.stream_send(token, stream_id, buf, fin);
297
298 self.waker.wake()?;
299
300 Ok(send_size?)
301 }
302
303 #[inline]
305 fn stream_recv(&self, token: Token, stream_id: u64, buf: &mut [u8]) -> Result<(usize, bool)> {
306 let r = self.group.stream_recv(token, stream_id, buf);
307
308 self.waker.wake()?;
309
310 Ok(r?)
311 }
312
313 fn poll(&self, events: &mut Vec<Event>) -> Result<Option<Instant>> {
315 let mut poll_state = self.state.lock();
316
317 loop {
318 let next_release_time = self.group.poll(events)?;
319
320 for event in events.drain(..).collect::<Vec<_>>() {
322 match event.kind {
323 EventKind::Send => {
324 self.on_quic_send(&mut poll_state, event.token)?;
325 }
326 EventKind::Recv => {
327 self.on_quic_recv(&mut poll_state, event.token)?;
328 }
329 _ => events.push(event),
330 }
331 }
332
333 if !events.is_empty() {
335 return Ok(None);
336 }
337
338 self.mio_poll_once(&mut poll_state, next_release_time)?;
339 }
340 }
341}
342
343impl QuicClient for Group {
344 fn connect(
345 &self,
346 server_name: Option<&str>,
347 local: SocketAddr,
348 peer: SocketAddr,
349 config: &mut quiche::Config,
350 ) -> std::result::Result<Token, Self::Error> {
351 assert!(self.laddrs.contains_key(&local), "invalid local address.");
352
353 let token = self.group.connect(server_name, local, peer, config);
354
355 self.waker.wake()?;
356
357 Ok(token?)
358 }
359}
360
361impl QuicBind for Group {
362 fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
364 self.laddrs.keys()
365 }
366
367 fn bind<S>(laddrs: S, acceptor: Option<Acceptor>) -> Result<Self>
369 where
370 S: ToSocketAddrs,
371 {
372 let poll = mio::Poll::new()?;
373 let group = zerortt_poll::Group::new();
374
375 let mut sockets = vec![];
376 let mut addrs = HashMap::new();
377
378 for laddr in laddrs.to_socket_addrs()? {
379 let mut socket = UdpSocket::bind(laddr)?;
380 addrs.insert(socket.local_addr()?, sockets.len());
381
382 poll.registry().register(
383 &mut socket,
384 mio::Token(sockets.len()),
385 Interest::READABLE | Interest::WRITABLE,
386 )?;
387
388 sockets.push(QuicSocket::new(socket, 1024)?);
389 }
390
391 let waker = Waker::new(poll.registry(), mio::Token(sockets.len()))?;
392
393 Ok(Group {
394 waker,
395 group,
396 laddrs: addrs,
397 state: Mutex::new(PollState {
398 acceptor,
399 poll,
400 sockets,
401 }),
402 })
403 }
404}