s2n_quic_platform/io/tokio.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{features::gso, message::default as message, socket, syscall};
5use s2n_quic_core::{
6    endpoint::Endpoint,
7    event::{self, EndpointPublisher as _},
8    inet::{self, SocketAddress},
9    io::event_loop::EventLoop,
10    path::{mtu, MaxMtu},
11    task::cooldown::Cooldown,
12    time::Clock as ClockTrait,
13};
14use std::{convert::TryInto, io, io::ErrorKind};
15use tokio::runtime::Handle;
16
17mod builder;
18mod clock;
19pub(crate) mod task;
20#[cfg(test)]
21mod tests;
22
/// Path handle type used by this IO provider; endpoints passed to [`Io::start`] must use it
/// (`E: Endpoint<PathHandle = PathHandle>`).
pub type PathHandle = message::Handle;
24pub use builder::Builder;
25pub(crate) use clock::Clock;
26
27#[derive(Debug, Default)]
28pub struct Io {
29    builder: Builder,
30}
31
32impl Io {
33    pub fn builder() -> Builder {
34        Builder::default()
35    }
36
37    pub fn new<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
38        let address = addr.to_socket_addrs()?.next().expect("missing address");
39        let builder = Builder::default().with_receive_address(address)?;
40        Ok(Self { builder })
41    }
42
43    pub fn start<E: Endpoint<PathHandle = PathHandle>>(
44        self,
45        mut endpoint: E,
46    ) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> {
47        let Builder {
48            handle,
49            rx_socket,
50            tx_socket,
51            recv_addr,
52            send_addr,
53            socket_recv_buffer_size,
54            socket_send_buffer_size,
55            queue_recv_buffer_size,
56            queue_send_buffer_size,
57            mtu_config_builder,
58            max_segments,
59            gro_enabled,
60            reuse_address,
61            reuse_port,
62            only_v6,
63        } = self.builder;
64
65        let clock = Clock::default();
66
67        let mut publisher = event::EndpointPublisherSubscriber::new(
68            event::builder::EndpointMeta {
69                endpoint_type: E::ENDPOINT_TYPE,
70                timestamp: clock.get_time(),
71            },
72            None,
73            endpoint.subscriber(),
74        );
75
76        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
77            configuration: event::builder::PlatformFeatureConfiguration::Gso {
78                max_segments: max_segments.into(),
79            },
80        });
81
82        // try to use the tokio runtime handle if provided, otherwise try to use the implicit tokio
83        // runtime in the current scope of the application.
84        let handle = if let Some(handle) = handle {
85            handle
86        } else {
87            Handle::try_current().map_err(std::io::Error::other)?
88        };
89
90        let guard = handle.enter();
91
92        let rx_socket = if let Some(rx_socket) = rx_socket {
93            rx_socket
94        } else if let Some(recv_addr) = recv_addr {
95            syscall::bind_udp(recv_addr, reuse_address, reuse_port, only_v6)?
96        } else {
97            return Err(io::Error::new(
98                io::ErrorKind::InvalidInput,
99                "missing bind address",
100            ));
101        };
102
103        let rx_addr = convert_addr_to_std(rx_socket.local_addr()?)?;
104
105        let tx_socket = if let Some(tx_socket) = tx_socket {
106            tx_socket
107        } else if let Some(send_addr) = send_addr {
108            syscall::bind_udp(send_addr, reuse_address, reuse_port, only_v6)?
109        } else {
110            // No tx_socket or send address was specified, so the tx socket
111            // will be a handle to the rx socket.
112            rx_socket.try_clone()?
113        };
114
115        if let Some(size) = socket_send_buffer_size {
116            tx_socket.set_send_buffer_size(size)?;
117        }
118
119        if let Some(size) = socket_recv_buffer_size {
120            rx_socket.set_recv_buffer_size(size)?;
121        }
122
123        let mut mtu_config = mtu_config_builder
124            .build()
125            .map_err(|err| io::Error::new(ErrorKind::InvalidInput, format!("{err}")))?;
126        let original_max_mtu = mtu_config.max_mtu();
127
128        // Configure MTU discovery
129        if !syscall::configure_mtu_disc(&tx_socket) {
130            // disable MTU probing if we can't prevent fragmentation
131            mtu_config = mtu::Config::MIN;
132        }
133
134        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
135            configuration: event::builder::PlatformFeatureConfiguration::BaseMtu {
136                mtu: mtu_config.base_mtu().into(),
137            },
138        });
139
140        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
141            configuration: event::builder::PlatformFeatureConfiguration::InitialMtu {
142                mtu: mtu_config.initial_mtu().into(),
143            },
144        });
145
146        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
147            configuration: event::builder::PlatformFeatureConfiguration::MaxMtu {
148                mtu: mtu_config.max_mtu().into(),
149            },
150        });
151
152        // Configure the socket with GRO
153        let gro_enabled = gro_enabled.unwrap_or(true) && syscall::configure_gro(&rx_socket);
154
155        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
156            configuration: event::builder::PlatformFeatureConfiguration::Gro {
157                enabled: gro_enabled,
158            },
159        });
160
161        // Configure packet info CMSG
162        syscall::configure_pktinfo(&rx_socket);
163
164        // Configure TOS/ECN
165        let tos_enabled = syscall::configure_tos(&rx_socket);
166
167        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
168            configuration: event::builder::PlatformFeatureConfiguration::Ecn {
169                enabled: tos_enabled,
170            },
171        });
172
173        let (stats_sender, stats_recv) = crate::socket::stats::channel();
174
175        let rx = {
176            // if GRO is enabled, then we need to provide the syscall with the maximum size buffer
177            let payload_len = if gro_enabled {
178                u16::MAX
179            } else {
180                // Use the originally configured MTU to allow larger packets to be received
181                // even if the tx MTU has been reduced due to configure_mtu_disc failing
182                original_max_mtu.into()
183            } as u32;
184
185            let rx_buffer_size = queue_recv_buffer_size.unwrap_or(8 * (1 << 20));
186            let entries = rx_buffer_size / payload_len;
187            let entries = if entries.is_power_of_two() {
188                entries
189            } else {
190                // round up to the nearest power of two, since the ring buffers require it
191                entries.next_power_of_two()
192            };
193
194            let mut consumers = vec![];
195
196            let rx_socket_count = parse_env("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT").unwrap_or(1);
197
198            // configure the number of self-wakes before "cooling down" and waiting for epoll to
199            // complete
200            let rx_cooldown = cooldown("RX");
201
202            for idx in 0usize..rx_socket_count {
203                let (producer, consumer) = socket::ring::pair(entries, payload_len);
204                consumers.push(consumer);
205
206                // spawn a task that actually reads from the socket into the ring buffer
207                if idx + 1 == rx_socket_count {
208                    handle.spawn(task::rx(
209                        rx_socket,
210                        producer,
211                        rx_cooldown,
212                        stats_sender.clone(),
213                    ));
214                    break;
215                } else {
216                    let rx_socket = rx_socket.try_clone()?;
217                    handle.spawn(task::rx(
218                        rx_socket,
219                        producer,
220                        rx_cooldown.clone(),
221                        stats_sender.clone(),
222                    ));
223                }
224            }
225
226            // construct the RX side for the endpoint event loop
227            let max_mtu = MaxMtu::try_from(payload_len as u16).unwrap();
228            let addr: inet::SocketAddress = rx_addr.into();
229            socket::io::rx::Rx::new(consumers, max_mtu, addr.into())
230        };
231
232        let tx = {
233            let gso = crate::features::Gso::from(max_segments);
234
235            // compute the payload size for each message from the number of GSO segments we can
236            // fill
237            let payload_len = {
238                let max_mtu: u16 = mtu_config.max_mtu().into();
239                (max_mtu as u32 * gso.max_segments() as u32).min(u16::MAX as u32)
240            };
241
242            let tx_buffer_size = queue_send_buffer_size.unwrap_or(128 * 1024);
243            let entries = tx_buffer_size / payload_len;
244            let entries = if entries.is_power_of_two() {
245                entries
246            } else {
247                // round up to the nearest power of two, since the ring buffers require it
248                entries.next_power_of_two()
249            };
250
251            let mut producers = vec![];
252
253            let tx_socket_count = parse_env("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT").unwrap_or(1);
254
255            // configure the number of self-wakes before "cooling down" and waiting for epoll to
256            // complete
257            let tx_cooldown = cooldown("TX");
258
259            for idx in 0usize..tx_socket_count {
260                let (producer, consumer) = socket::ring::pair(entries, payload_len);
261                producers.push(producer);
262
263                // spawn a task that actually flushes the ring buffer to the socket
264                if idx + 1 == tx_socket_count {
265                    handle.spawn(task::tx(
266                        tx_socket,
267                        consumer,
268                        gso.clone(),
269                        tx_cooldown,
270                        stats_sender.clone(),
271                    ));
272                    break;
273                } else {
274                    let tx_socket = tx_socket.try_clone()?;
275                    handle.spawn(task::tx(
276                        tx_socket,
277                        consumer,
278                        gso.clone(),
279                        tx_cooldown.clone(),
280                        stats_sender.clone(),
281                    ));
282                }
283            }
284
285            // construct the TX side for the endpoint event loop
286            socket::io::tx::Tx::new(producers, gso, mtu_config.max_mtu())
287        };
288
289        // Notify the endpoint of the MTU that we chose
290        endpoint.set_mtu_config(mtu_config);
291
292        let task = handle.spawn(
293            EventLoop {
294                endpoint,
295                clock,
296                rx,
297                tx,
298                cooldown: cooldown("ENDPOINT"),
299                stats: stats_recv,
300            }
301            .start(rx_addr.into()),
302        );
303
304        drop(guard);
305
306        Ok((task, rx_addr.into()))
307    }
308}
309
310fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAddr> {
311    addr.as_socket()
312        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid domain for socket"))
313}
314
/// Reads the environment variable `name` and parses it as `T`.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or fails to parse.
fn parse_env<T>(name: &str) -> Option<T>
where
    T: core::str::FromStr,
{
    match std::env::var(name) {
        Ok(raw) => raw.parse::<T>().ok(),
        Err(_) => None,
    }
}
318
319pub fn cooldown(direction: &str) -> Cooldown {
320    let name = format!("S2N_QUIC_UNSTABLE_COOLDOWN_{direction}");
321    let limit = parse_env(&name).unwrap_or(0);
322    Cooldown::new(limit)
323}