s2n_quic_platform/io/tokio/task/simple.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    features::Gso,
6    message::{simple::Message, Message as _},
7    socket::{
8        ring, stats, task,
9        task::{rx, tx},
10    },
11    syscall::SocketEvents,
12};
13use core::task::{Context, Poll};
14use s2n_quic_core::task::cooldown::Cooldown;
15use tokio::{io, net::UdpSocket};
16
17pub async fn rx<S: Into<std::net::UdpSocket>>(
18    socket: S,
19    producer: ring::Producer<Message>,
20    cooldown: Cooldown,
21    stats: stats::Sender,
22) -> io::Result<()> {
23    let socket = socket.into();
24    socket.set_nonblocking(true).unwrap();
25
26    let socket = UdpSocket::from_std(socket).unwrap();
27    let result = task::Receiver::new(producer, socket, cooldown, stats).await;
28    if let Some(err) = result {
29        Err(err)
30    } else {
31        Ok(())
32    }
33}
34
35pub async fn tx<S: Into<std::net::UdpSocket>>(
36    socket: S,
37    consumer: ring::Consumer<Message>,
38    gso: Gso,
39    cooldown: Cooldown,
40    stats: stats::Sender,
41) -> io::Result<()> {
42    let socket = socket.into();
43    socket.set_nonblocking(true).unwrap();
44
45    let socket = UdpSocket::from_std(socket).unwrap();
46    let result = task::Sender::new(consumer, socket, gso, cooldown, stats).await;
47    if let Some(err) = result {
48        Err(err)
49    } else {
50        Ok(())
51    }
52}
53
54impl tx::Socket<Message> for UdpSocket {
55    type Error = io::Error;
56
57    #[inline]
58    fn send(
59        &mut self,
60        cx: &mut Context,
61        entries: &mut [Message],
62        events: &mut tx::Events,
63        stats: &stats::Sender,
64    ) -> io::Result<()> {
65        for entry in entries {
66            let target = (*entry.remote_address()).into();
67            let payload = entry.payload_mut();
68
69            let res = self.poll_send_to(cx, payload, target);
70            stats.send().on_operation(&res, |_len| 1);
71            match res {
72                Poll::Ready(Ok(_)) => {
73                    if events.on_complete(1).is_break() {
74                        return Ok(());
75                    }
76                }
77                Poll::Ready(Err(err)) => {
78                    if events.on_error(err).is_break() {
79                        return Ok(());
80                    }
81                }
82                Poll::Pending => {
83                    events.blocked();
84                    break;
85                }
86            }
87        }
88
89        Ok(())
90    }
91}
92
93impl rx::Socket<Message> for UdpSocket {
94    type Error = io::Error;
95
96    #[inline]
97    fn recv(
98        &mut self,
99        cx: &mut Context,
100        entries: &mut [Message],
101        events: &mut rx::Events,
102        stats: &stats::Sender,
103    ) -> io::Result<()> {
104        for entry in entries {
105            let payload = entry.payload_mut();
106            let mut buf = io::ReadBuf::new(payload);
107
108            let res = self.poll_recv_from(cx, &mut buf);
109            stats.recv().on_operation(&res, |_len| 1);
110            match res {
111                Poll::Ready(Ok(addr)) => {
112                    unsafe {
113                        let len = buf.filled().len();
114                        entry.set_payload_len(len);
115                    }
116                    entry.set_remote_address(&(addr.into()));
117
118                    if events.on_complete(1).is_break() {
119                        return Ok(());
120                    }
121                }
122                Poll::Ready(Err(err)) => {
123                    if events.on_error(err).is_break() {
124                        return Ok(());
125                    }
126                }
127                Poll::Pending => {
128                    events.blocked();
129                    break;
130                }
131            }
132        }
133
134        Ok(())
135    }
136}