1use crate::{features::gso, message::default as message, socket, syscall};
5use s2n_quic_core::{
6 endpoint::Endpoint,
7 event::{self, EndpointPublisher as _},
8 inet::{self, SocketAddress},
9 io::event_loop::EventLoop,
10 path::{mtu, MaxMtu},
11 task::cooldown::Cooldown,
12 time::Clock as ClockTrait,
13};
14use std::{convert::TryInto, io, io::ErrorKind};
15use tokio::runtime::Handle;
16
17mod builder;
18mod clock;
19pub(crate) mod task;
20#[cfg(test)]
21mod tests;
22
23pub type PathHandle = message::Handle;
24pub use builder::Builder;
25pub(crate) use clock::Clock;
26
27#[derive(Debug, Default)]
28pub struct Io {
29 builder: Builder,
30}
31
32impl Io {
33 pub fn builder() -> Builder {
34 Builder::default()
35 }
36
37 pub fn new<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
38 let address = addr.to_socket_addrs()?.next().expect("missing address");
39 let builder = Builder::default().with_receive_address(address)?;
40 Ok(Self { builder })
41 }
42
43 pub fn start<E: Endpoint<PathHandle = PathHandle>>(
44 self,
45 mut endpoint: E,
46 ) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> {
47 let Builder {
48 handle,
49 rx_socket,
50 tx_socket,
51 recv_addr,
52 send_addr,
53 socket_recv_buffer_size,
54 socket_send_buffer_size,
55 queue_recv_buffer_size,
56 queue_send_buffer_size,
57 mtu_config_builder,
58 max_segments,
59 gro_enabled,
60 reuse_address,
61 reuse_port,
62 only_v6,
63 } = self.builder;
64
65 let clock = Clock::default();
66
67 let mut publisher = event::EndpointPublisherSubscriber::new(
68 event::builder::EndpointMeta {
69 endpoint_type: E::ENDPOINT_TYPE,
70 timestamp: clock.get_time(),
71 },
72 None,
73 endpoint.subscriber(),
74 );
75
76 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
77 configuration: event::builder::PlatformFeatureConfiguration::Gso {
78 max_segments: max_segments.into(),
79 },
80 });
81
82 let handle = if let Some(handle) = handle {
85 handle
86 } else {
87 Handle::try_current().map_err(std::io::Error::other)?
88 };
89
90 let guard = handle.enter();
91
92 let rx_socket = if let Some(rx_socket) = rx_socket {
93 rx_socket
94 } else if let Some(recv_addr) = recv_addr {
95 syscall::bind_udp(recv_addr, reuse_address, reuse_port, only_v6)?
96 } else {
97 return Err(io::Error::new(
98 io::ErrorKind::InvalidInput,
99 "missing bind address",
100 ));
101 };
102
103 let rx_addr = convert_addr_to_std(rx_socket.local_addr()?)?;
104
105 let tx_socket = if let Some(tx_socket) = tx_socket {
106 tx_socket
107 } else if let Some(send_addr) = send_addr {
108 syscall::bind_udp(send_addr, reuse_address, reuse_port, only_v6)?
109 } else {
110 rx_socket.try_clone()?
113 };
114
115 if let Some(size) = socket_send_buffer_size {
116 tx_socket.set_send_buffer_size(size)?;
117 }
118
119 if let Some(size) = socket_recv_buffer_size {
120 rx_socket.set_recv_buffer_size(size)?;
121 }
122
123 let mut mtu_config = mtu_config_builder
124 .build()
125 .map_err(|err| io::Error::new(ErrorKind::InvalidInput, format!("{err}")))?;
126 let original_max_mtu = mtu_config.max_mtu();
127
128 if !syscall::configure_mtu_disc(&tx_socket) {
130 mtu_config = mtu::Config::MIN;
132 }
133
134 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
135 configuration: event::builder::PlatformFeatureConfiguration::BaseMtu {
136 mtu: mtu_config.base_mtu().into(),
137 },
138 });
139
140 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
141 configuration: event::builder::PlatformFeatureConfiguration::InitialMtu {
142 mtu: mtu_config.initial_mtu().into(),
143 },
144 });
145
146 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
147 configuration: event::builder::PlatformFeatureConfiguration::MaxMtu {
148 mtu: mtu_config.max_mtu().into(),
149 },
150 });
151
152 let gro_enabled = gro_enabled.unwrap_or(true) && syscall::configure_gro(&rx_socket);
154
155 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
156 configuration: event::builder::PlatformFeatureConfiguration::Gro {
157 enabled: gro_enabled,
158 },
159 });
160
161 syscall::configure_pktinfo(&rx_socket);
163
164 let tos_enabled = syscall::configure_tos(&rx_socket);
166
167 publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
168 configuration: event::builder::PlatformFeatureConfiguration::Ecn {
169 enabled: tos_enabled,
170 },
171 });
172
173 let (stats_sender, stats_recv) = crate::socket::stats::channel();
174
175 let rx = {
176 let payload_len = if gro_enabled {
178 u16::MAX
179 } else {
180 original_max_mtu.into()
183 } as u32;
184
185 let rx_buffer_size = queue_recv_buffer_size.unwrap_or(8 * (1 << 20));
186 let entries = rx_buffer_size / payload_len;
187 let entries = if entries.is_power_of_two() {
188 entries
189 } else {
190 entries.next_power_of_two()
192 };
193
194 let mut consumers = vec![];
195
196 let rx_socket_count = parse_env("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT").unwrap_or(1);
197
198 let rx_cooldown = cooldown("RX");
201
202 for idx in 0usize..rx_socket_count {
203 let (producer, consumer) = socket::ring::pair(entries, payload_len);
204 consumers.push(consumer);
205
206 if idx + 1 == rx_socket_count {
208 handle.spawn(task::rx(
209 rx_socket,
210 producer,
211 rx_cooldown,
212 stats_sender.clone(),
213 ));
214 break;
215 } else {
216 let rx_socket = rx_socket.try_clone()?;
217 handle.spawn(task::rx(
218 rx_socket,
219 producer,
220 rx_cooldown.clone(),
221 stats_sender.clone(),
222 ));
223 }
224 }
225
226 let max_mtu = MaxMtu::try_from(payload_len as u16).unwrap();
228 let addr: inet::SocketAddress = rx_addr.into();
229 socket::io::rx::Rx::new(consumers, max_mtu, addr.into())
230 };
231
232 let tx = {
233 let gso = crate::features::Gso::from(max_segments);
234
235 let payload_len = {
238 let max_mtu: u16 = mtu_config.max_mtu().into();
239 (max_mtu as u32 * gso.max_segments() as u32).min(u16::MAX as u32)
240 };
241
242 let tx_buffer_size = queue_send_buffer_size.unwrap_or(128 * 1024);
243 let entries = tx_buffer_size / payload_len;
244 let entries = if entries.is_power_of_two() {
245 entries
246 } else {
247 entries.next_power_of_two()
249 };
250
251 let mut producers = vec![];
252
253 let tx_socket_count = parse_env("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT").unwrap_or(1);
254
255 let tx_cooldown = cooldown("TX");
258
259 for idx in 0usize..tx_socket_count {
260 let (producer, consumer) = socket::ring::pair(entries, payload_len);
261 producers.push(producer);
262
263 if idx + 1 == tx_socket_count {
265 handle.spawn(task::tx(
266 tx_socket,
267 consumer,
268 gso.clone(),
269 tx_cooldown,
270 stats_sender.clone(),
271 ));
272 break;
273 } else {
274 let tx_socket = tx_socket.try_clone()?;
275 handle.spawn(task::tx(
276 tx_socket,
277 consumer,
278 gso.clone(),
279 tx_cooldown.clone(),
280 stats_sender.clone(),
281 ));
282 }
283 }
284
285 socket::io::tx::Tx::new(producers, gso, mtu_config.max_mtu())
287 };
288
289 endpoint.set_mtu_config(mtu_config);
291
292 let task = handle.spawn(
293 EventLoop {
294 endpoint,
295 clock,
296 rx,
297 tx,
298 cooldown: cooldown("ENDPOINT"),
299 stats: stats_recv,
300 }
301 .start(rx_addr.into()),
302 );
303
304 drop(guard);
305
306 Ok((task, rx_addr.into()))
307 }
308}
309
310fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAddr> {
311 addr.as_socket()
312 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid domain for socket"))
313}
314
/// Reads the environment variable `name` and parses its value as `T`.
///
/// Yields `None` when the variable is unset, is not valid Unicode, or does not
/// parse as `T`.
fn parse_env<T: core::str::FromStr>(name: &str) -> Option<T> {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<T>().ok(),
        Err(_) => None,
    }
}
318
319pub fn cooldown(direction: &str) -> Cooldown {
320 let name = format!("S2N_QUIC_UNSTABLE_COOLDOWN_{direction}");
321 let limit = parse_env(&name).unwrap_or(0);
322 Cooldown::new(limit)
323}