use crate::{features::gso, message::default as message, socket, syscall};
use s2n_quic_core::{
endpoint::Endpoint,
event::{self, EndpointPublisher as _},
inet::{self, SocketAddress},
io::event_loop::EventLoop,
path::{mtu, MaxMtu},
task::cooldown::Cooldown,
time::Clock as ClockTrait,
};
use std::{convert::TryInto, io, io::ErrorKind};
use tokio::runtime::Handle;
mod builder;
mod clock;
pub(crate) mod task;
#[cfg(test)]
mod tests;
pub type PathHandle = message::Handle;
pub use builder::Builder;
pub(crate) use clock::Clock;
/// Tokio-based I/O provider for an s2n-quic endpoint.
///
/// Construct one with [`Io::builder`] or [`Io::new`], then call [`Io::start`] to set up the
/// sockets and spawn the endpoint's event loop.
#[derive(Debug, Default)]
pub struct Io {
    builder: Builder,
}

impl Io {
    /// Returns a default [`Builder`] for configuring the I/O provider.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Creates an I/O provider that receives on the first address resolved from `addr`.
    ///
    /// # Errors
    ///
    /// Returns an error if `addr` cannot be resolved, if it resolves to no addresses
    /// (`ErrorKind::InvalidInput`), or if the builder rejects the receive address.
    pub fn new<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
        // An empty resolution result is an input/environment problem rather than a program bug,
        // so surface it through the `io::Result` instead of panicking.
        let address = addr
            .to_socket_addrs()?
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing address"))?;
        let builder = Builder::default().with_receive_address(address)?;
        Ok(Self { builder })
    }

    /// Sets up the RX/TX sockets, spawns the socket tasks and the endpoint event loop, and
    /// returns the event loop's join handle together with the local receive address.
    ///
    /// Tasks are spawned on the builder's runtime handle when one was provided, otherwise on
    /// the current Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns an error if no runtime handle is available, if neither a receive socket nor a
    /// receive address was configured (`ErrorKind::InvalidInput`), if a socket cannot be
    /// bound, cloned, or configured, or if the MTU configuration fails to build.
    pub fn start<E: Endpoint<PathHandle = PathHandle>>(
        self,
        mut endpoint: E,
    ) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> {
        let Builder {
            handle,
            rx_socket,
            tx_socket,
            recv_addr,
            send_addr,
            socket_recv_buffer_size,
            socket_send_buffer_size,
            queue_recv_buffer_size,
            queue_send_buffer_size,
            mtu_config_builder,
            max_segments,
            gro_enabled,
            reuse_address,
            reuse_port,
        } = self.builder;

        let clock = Clock::default();

        let mut publisher = event::EndpointPublisherSubscriber::new(
            event::builder::EndpointMeta {
                endpoint_type: E::ENDPOINT_TYPE,
                timestamp: clock.get_time(),
            },
            None,
            endpoint.subscriber(),
        );

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Gso {
                max_segments: max_segments.into(),
            },
        });

        let handle = if let Some(handle) = handle {
            handle
        } else {
            Handle::try_current().map_err(|err| std::io::Error::new(io::ErrorKind::Other, err))?
        };

        // Enter the runtime's context for the rest of setup; the guard is dropped explicitly
        // once every task has been spawned.
        let guard = handle.enter();

        let rx_socket = if let Some(rx_socket) = rx_socket {
            rx_socket
        } else if let Some(recv_addr) = recv_addr {
            syscall::bind_udp(recv_addr, reuse_address, reuse_port)?
        } else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "missing bind address",
            ));
        };

        let rx_addr = convert_addr_to_std(rx_socket.local_addr()?)?;

        // Without an explicit TX socket or send address, transmit on a clone of the RX socket.
        let tx_socket = if let Some(tx_socket) = tx_socket {
            tx_socket
        } else if let Some(send_addr) = send_addr {
            syscall::bind_udp(send_addr, reuse_address, reuse_port)?
        } else {
            rx_socket.try_clone()?
        };

        if let Some(size) = socket_send_buffer_size {
            tx_socket.set_send_buffer_size(size)?;
        }

        if let Some(size) = socket_recv_buffer_size {
            rx_socket.set_recv_buffer_size(size)?;
        }

        let mut mtu_config = mtu_config_builder
            .build()
            .map_err(|err| io::Error::new(ErrorKind::InvalidInput, format!("{err}")))?;
        // Remember the configured max MTU before a possible fallback below; it sizes the RX
        // buffers when GRO is disabled.
        let original_max_mtu = mtu_config.max_mtu;

        // Fall back to the minimum MTU configuration when MTU discovery cannot be configured
        // on the TX socket.
        if !syscall::configure_mtu_disc(&tx_socket) {
            mtu_config = mtu::Config::MIN;
        }

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::BaseMtu {
                mtu: mtu_config.base_mtu.into(),
            },
        });

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::InitialMtu {
                mtu: mtu_config.initial_mtu.into(),
            },
        });

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::MaxMtu {
                mtu: mtu_config.max_mtu.into(),
            },
        });

        // GRO defaults to on, but only counts as enabled if the socket accepts it.
        let gro_enabled = gro_enabled.unwrap_or(true) && syscall::configure_gro(&rx_socket);

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Gro {
                enabled: gro_enabled,
            },
        });

        syscall::configure_pktinfo(&rx_socket);

        let tos_enabled = syscall::configure_tos(&rx_socket);

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Ecn {
                enabled: tos_enabled,
            },
        });

        let rx = {
            // With GRO each receive slot must hold a maximum-size buffer; otherwise size slots
            // by the originally configured max MTU (not the possibly reduced fallback MTU).
            let payload_len = if gro_enabled {
                u16::MAX
            } else {
                original_max_mtu.into()
            } as u32;

            let rx_buffer_size = queue_recv_buffer_size.unwrap_or(8 * (1 << 20));
            let entries = Self::ring_entries(rx_buffer_size, payload_len);

            let mut consumers = vec![];

            // A count of 0 would spawn no RX task, so nothing would ever read from the socket;
            // always use at least one.
            let rx_socket_count = parse_env::<usize>("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT")
                .unwrap_or(1)
                .max(1);

            let rx_cooldown = cooldown("RX");

            for idx in 0usize..rx_socket_count {
                let (producer, consumer) = socket::ring::pair(entries, payload_len);
                consumers.push(consumer);

                if idx + 1 == rx_socket_count {
                    // The last task takes ownership of the original socket (no extra clone).
                    handle.spawn(task::rx(rx_socket, producer, rx_cooldown));
                    break;
                } else {
                    let rx_socket = rx_socket.try_clone()?;
                    handle.spawn(task::rx(rx_socket, producer, rx_cooldown.clone()));
                }
            }

            // `payload_len` was widened from a `u16` above, so this cast is lossless.
            let max_mtu = MaxMtu::try_from(payload_len as u16).unwrap();
            let addr: inet::SocketAddress = rx_addr.into();
            socket::io::rx::Rx::new(consumers, max_mtu, addr.into())
        };

        let tx = {
            let gso = crate::features::Gso::from(max_segments);

            // Each TX slot holds up to `max_segments` MTU-sized segments, capped at `u16::MAX`.
            let payload_len = {
                let max_mtu: u16 = mtu_config.max_mtu.into();
                (max_mtu as u32 * gso.max_segments() as u32).min(u16::MAX as u32)
            };

            let tx_buffer_size = queue_send_buffer_size.unwrap_or(128 * 1024);
            let entries = Self::ring_entries(tx_buffer_size, payload_len);

            let mut producers = vec![];

            // A count of 0 would spawn no TX task, so nothing would ever write to the socket;
            // always use at least one.
            let tx_socket_count = parse_env::<usize>("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT")
                .unwrap_or(1)
                .max(1);

            let tx_cooldown = cooldown("TX");

            for idx in 0usize..tx_socket_count {
                let (producer, consumer) = socket::ring::pair(entries, payload_len);
                producers.push(producer);

                if idx + 1 == tx_socket_count {
                    // The last task takes ownership of the original socket (no extra clone).
                    handle.spawn(task::tx(tx_socket, consumer, gso.clone(), tx_cooldown));
                    break;
                } else {
                    let tx_socket = tx_socket.try_clone()?;
                    handle.spawn(task::tx(
                        tx_socket,
                        consumer,
                        gso.clone(),
                        tx_cooldown.clone(),
                    ));
                }
            }

            socket::io::tx::Tx::new(producers, gso, mtu_config.max_mtu)
        };

        endpoint.set_mtu_config(mtu_config);

        let task = handle.spawn(
            EventLoop {
                endpoint,
                clock,
                rx,
                tx,
                cooldown: cooldown("ENDPOINT"),
            }
            .start(),
        );

        drop(guard);

        Ok((task, rx_addr.into()))
    }

    /// Returns the number of ring entries for a queue of `buffer_size` bytes split into
    /// `payload_len`-byte slots, rounded up to a power of two (at least 1).
    fn ring_entries(buffer_size: u32, payload_len: u32) -> u32 {
        // `next_power_of_two` leaves powers of two unchanged and maps 0 to 1.
        (buffer_size / payload_len).next_power_of_two()
    }
}
/// Converts a `socket2` address into a `std::net::SocketAddr`.
///
/// Fails with `ErrorKind::InvalidInput` when `SockAddr::as_socket` cannot produce an IP
/// socket address for `addr`.
fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAddr> {
    match addr.as_socket() {
        Some(std_addr) => Ok(std_addr),
        None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "invalid domain for socket",
        )),
    }
}
/// Reads environment variable `name` and parses it as `T`.
///
/// Yields `None` when the variable is unset, is not valid Unicode, or fails to parse.
fn parse_env<T: core::str::FromStr>(name: &str) -> Option<T> {
    let raw = std::env::var(name).ok()?;
    raw.parse::<T>().ok()
}
/// Builds a [`Cooldown`] for the given `direction` (for example `"RX"`, `"TX"` or
/// `"ENDPOINT"`).
///
/// The limit comes from the `S2N_QUIC_UNSTABLE_COOLDOWN_<direction>` environment variable.
/// It falls back to `0` when that variable is unset or unparsable.
pub fn cooldown(direction: &str) -> Cooldown {
    let env_name = format!("S2N_QUIC_UNSTABLE_COOLDOWN_{direction}");
    Cooldown::new(parse_env(&env_name).unwrap_or(0))
}