// s2n_quic_platform/socket/io/tx.rs
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use crate::{features::Gso, message::Message, socket::ring::Producer};
5use core::task::{Context, Poll};
6use s2n_quic_core::{
7 event,
8 inet::ExplicitCongestionNotification,
9 io::tx,
10 path::{Handle as _, MaxMtu},
11 task::waker,
12};
13
/// Structure for sending messages to producer channels
pub struct Tx<T: Message> {
    /// One producer ring per underlying socket task that consumes and transmits messages
    channels: Vec<Producer<T>>,
    /// Runtime-queryable GSO (Generic Segmentation Offload) state; consulted on every
    /// `queue` call since the max segment count can be lowered after a socket error
    gso: Gso,
    /// The maximum transmission unit, in bytes, used when resetting message entries
    max_mtu: usize,
    /// Tracks whether all of the channels were completely filled on the last `queue` call.
    ///
    /// `poll_ready` only polls the channels while this is `true`; otherwise the endpoint
    /// would spin, since it usually has capacity for sending. Initialized to `true` so
    /// the very first `poll_ready` call actually polls (and registers with) the channels.
    is_full: bool,
}
21
22impl<T: Message> Tx<T> {
23 #[inline]
24 pub fn new(channels: Vec<Producer<T>>, gso: Gso, max_mtu: MaxMtu) -> Self {
25 Self {
26 channels,
27 gso,
28 max_mtu: max_mtu.into(),
29 is_full: true,
30 }
31 }
32}
33
impl<T: Message> tx::Tx for Tx<T> {
    type PathHandle = T::Handle;
    // The `'static` lifetime is a stand-in; see the safety notes in `queue` below —
    // the queue value never actually outlives the `&mut self` borrow.
    type Queue = TxQueue<'static, T>;
    // The error carries no information: the only failure mode is "all channels closed".
    type Error = ();

    /// Polls all producer channels for send capacity.
    ///
    /// Returns `Ready(Ok(()))` if any channel has (or just acquired) free slots,
    /// `Ready(Err(()))` once every channel has closed, and `Pending` otherwise.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        // We only need to poll for capacity if we completely filled up all of the channels.
        // If we always polled, this would cause the endpoint to spin since most of the time it has
        // capacity for sending.
        if !self.is_full {
            return Poll::Pending;
        }

        // NOTE: we don't wrap the above check in the contract as we'd technically violate the
        // contract since we're returning `Pending` without storing a waker
        waker::debug_assert_contract(cx, |cx| {
            let mut is_any_ready = false;
            let mut is_all_closed = true;

            // Poll every channel — not just until the first ready one — so each channel
            // registers `cx`'s waker and fully-closed channels can be detected.
            for channel in &mut self.channels {
                match channel.poll_acquire(1, cx) {
                    Poll::Ready(_) => {
                        is_all_closed = false;
                        is_any_ready = true;
                    }
                    Poll::Pending => {
                        // a pending channel only counts toward "all closed" if it is
                        // no longer open
                        is_all_closed &= !channel.is_open();
                    }
                }
            }

            // if all of the channels were closed then shut the task down
            if is_all_closed {
                return Err(()).into();
            }

            // if any of the channels became ready then wake the endpoint up
            if is_any_ready {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        })
    }

    /// Acquires capacity on the channels and invokes `f` with a [`TxQueue`] that
    /// borrows them for the duration of the call.
    #[inline]
    fn queue<F: FnOnce(&mut Self::Queue)>(&mut self, f: F) {
        let this: &'static mut Self = unsafe {
            // Safety: As noted in the [transmute examples](https://doc.rust-lang.org/std/mem/fn.transmute.html#examples)
            // it can be used to temporarily extend the lifetime of a reference. In this case, we
            // don't want to use GATs until the MSRV is >=1.65.0, which means `Self::Queue` is not
            // allowed to take generic lifetimes.
            //
            // We are left with using a `'static` lifetime here and encapsulating it in a private
            // field. The `Self::Queue` struct is then borrowed for the lifetime of the `F`
            // function. This will prevent the value from escaping beyond the lifetime of `&mut
            // self`.
            //
            // See https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=9a32abe85c666f36fb2ec86496cc41b4
            //
            // Once https://github.com/aws/s2n-quic/issues/1742 is resolved this code can go away
            core::mem::transmute(self)
        };

        let mut capacity = 0;
        let mut first_with_free_slots = None;
        for (idx, channel) in this.channels.iter_mut().enumerate() {
            // try to make one more effort to acquire capacity for sending
            let count = channel.acquire(u32::MAX) as usize;

            if count > 0 && first_with_free_slots.is_none() {
                // find the first channel that had capacity
                first_with_free_slots = Some(idx);
            }

            capacity += count;
        }

        // mark that we're still full so we need to poll and wake up next iteration
        this.is_full = capacity == 0;

        // start with the first queue that has free slots, otherwise set the index to the length,
        // which will return an AtCapacity error immediately.
        let channel_index = first_with_free_slots.unwrap_or(this.channels.len());

        // query the maximum number of segments we can fill at this point in time
        //
        // NOTE: this value could be lowered in the case the TX task encounters an error with GSO
        // so we do need to query it each iteration.
        let max_segments = this.gso.max_segments();

        let mut queue = TxQueue {
            channels: &mut this.channels,
            channel_index,
            message_index: 0,
            pending_release: 0,
            gso_segment: None,
            max_segments,
            max_mtu: this.max_mtu,
            capacity,
            is_full: &mut this.is_full,
        };

        f(&mut queue);
    }

    /// Handles a terminal error from `poll_ready`; intentionally a no-op.
    #[inline]
    fn handle_error<E: event::EndpointPublisher>(self, _error: Self::Error, _events: &mut E) {
        // The only reason we would be returning an error is if a channel closed. This could either
        // be because the endpoint is shutting down or one of the tasks panicked. Either way, we
        // don't know what the cause is here so we don't have any events to emit.
    }
}
148
/// Tracks the current state of a GSO (Generic Segmentation Offload) message
///
/// While a GSO message is open, multiple packet payloads are written back-to-back
/// into a single message entry and later flushed to the consumer as one message.
#[derive(Debug, Default)]
pub struct GsoSegment<Handle> {
    /// The path handle of the current GSO segment being written
    ///
    /// This is used to determine if future messages should be included in this payload or need a
    /// separate packet.
    handle: Handle,
    /// The value of the ecn markings for the current GSO segment being written.
    ///
    /// This is used to determine if future messages should be included in this payload or need a
    /// separate packet.
    ecn: ExplicitCongestionNotification,
    /// The number of segments that have been written
    count: usize,
    /// The size of each segment, taken from the first payload written to the message.
    ///
    /// Note that the last segment can be smaller than the previous ones and will result in a flush
    size: usize,
}
169
/// A queue borrowing the producer channels for the duration of a single `Tx::queue`
/// call; messages pushed into it are written directly into the channels' free slots.
pub struct TxQueue<'a, T: Message> {
    channels: &'a mut [Producer<T>],
    /// The channel index that we are currently operating on.
    ///
    /// This will be incremented after each channel is filled until it exceeds the len of `channels`.
    channel_index: usize,
    /// The message index into the current channel that we are operating on.
    ///
    /// This is incremented after each message is finished until it exceeds the acquired free
    /// slots, after which the `channel_index` is incremented (and message_index is reset to zero).
    //
    // NOTE(review): no code in this file increments this field — presumably the
    // channel's `data()` window advances when messages are released, keeping this
    // index current; confirm against the ring `Producer` implementation.
    message_index: usize,
    /// The number of messages in the current channel that need to be released to notify the
    /// consumer.
    ///
    /// This is to avoid calling `release` for each message and waking up the socket task too much.
    pending_release: u32,
    /// The current GSO segment we are filling, if any
    gso_segment: Option<GsoSegment<T::Handle>>,
    /// The maximum number of GSO segments that can be written
    max_segments: usize,
    /// The maximum MTU for any given packet
    max_mtu: usize,
    /// The maximum number of packets that can be sent in the current iteration
    capacity: usize,
    /// Used to track if we have filled up the producer queue and waiting on free slots to be
    /// released by the consumer.
    is_full: &'a mut bool,
}
198
impl<T: Message> TxQueue<'_, T> {
    /// Tries to send a message as a GSO segment
    ///
    /// Returns the Err(Message) if it was not able to. Otherwise, the index of the GSO'd message is returned.
    #[inline]
    fn try_gso<M: tx::Message<Handle = T::Handle>>(
        &mut self,
        mut message: M,
    ) -> Result<Result<tx::Outcome, M>, tx::Error> {
        // if the message type doesn't support GSO, return it so it can be written
        // to a fresh entry instead
        if !T::SUPPORTS_GSO {
            return Ok(Err(message));
        }

        let max_segments = self.max_segments;

        // if there is no in-progress GSO message, hand the message back to `push`
        let (prev_message, gso) = if let Some(gso) = self.gso_message() {
            gso
        } else {
            return Ok(Err(message));
        };

        debug_assert!(
            max_segments > 1,
            "gso_segment should only be set when max_gso > 1"
        );

        // check to make sure the message can be GSO'd and can be included in the same
        // GSO payload as the previous message
        let can_gso = message.can_gso(gso.size, gso.count)
            && message.path_handle().strict_eq(&gso.handle)
            && message.ecn() == gso.ecn;

        // if we can't use GSO then flush the current message
        if !can_gso {
            self.flush_gso();
            return Ok(Err(message));
        }

        debug_assert!(
            gso.count < max_segments,
            "{} cannot exceed {}",
            gso.count,
            max_segments
        );

        let payload_len = prev_message.payload_len();

        let buffer = unsafe {
            // Create a slice the `message` can write into. This avoids having to update the
            // payload length and worrying about panic safety.

            let payload = prev_message.payload_ptr_mut();

            // Safety: all payloads should have enough capacity to extend max_segments *
            // gso.size
            let current_payload = payload.add(payload_len);
            core::slice::from_raw_parts_mut(current_payload, gso.size)
        };
        let buffer = tx::PayloadBuffer::new(buffer);

        // any error from the caller-provided writer is propagated as-is
        let size = message.write_payload(buffer, gso.count)?;

        // we don't want to send empty packets
        if size == 0 {
            return Err(tx::Error::EmptyPayload);
        }

        unsafe {
            debug_assert!(
                gso.size >= size,
                "the payload tried to write more than available"
            );
            // Set the len to the actual amount written to the payload. In case there is a bug,
            // take the min anyway so we don't have errors elsewhere.
            prev_message.set_payload_len(payload_len + size.min(gso.size));
        }
        // increment the number of segments that we've written
        gso.count += 1;

        debug_assert!(
            gso.count <= max_segments,
            "{} cannot exceed {}",
            gso.count,
            max_segments
        );

        // the last segment can be smaller but we can't write any more if it is
        let size_mismatch = gso.size != size;

        // we're bounded by the max_segments amount
        let at_segment_limit = gso.count >= max_segments;

        // we also can't write more data than u16::MAX
        let at_payload_limit = gso.size * (gso.count + 1) > u16::MAX as usize;

        // if we've hit any limits, then flush the GSO information to the message
        if size_mismatch || at_segment_limit || at_payload_limit {
            self.flush_gso();
        }

        Ok(Ok(tx::Outcome {
            len: size,
            index: 0,
        }))
    }

    /// Flushes the current GSO message, if any
    ///
    /// In the `gso_segment` field, we track which message is currently being
    /// built. If there ended up being multiple payloads written to the single message
    /// we need to set the msg_control values to indicate the GSO size.
    #[inline]
    fn flush_gso(&mut self) {
        // no need to flush if the message type doesn't support GSO
        if !T::SUPPORTS_GSO {
            debug_assert!(
                self.gso_segment.is_none(),
                "gso_segment should not be set if GSO is unsupported"
            );
            return;
        }

        if let Some((message, gso)) = self.gso_message() {
            // only need to set the segment size if there was more than one payload written to the message
            if gso.count > 1 {
                message.set_segment_size(gso.size);
            }

            // clear out the current state and release the message
            self.gso_segment = None;
            self.release_message();
        }
    }

    /// Returns the current GSO message waiting for more segments
    ///
    /// Returns `None` when no GSO message is currently open.
    #[inline]
    fn gso_message(&mut self) -> Option<(&mut T, &mut GsoSegment<T::Handle>)> {
        let gso = self.gso_segment.as_mut()?;

        let channel = unsafe {
            // Safety: the channel_index should always be in-bound if gso_segment is set
            s2n_quic_core::assume!(self.channels.len() > self.channel_index);
            &mut self.channels[self.channel_index]
        };

        let message = unsafe {
            // Safety: the message_index should always be in-bound if gso_segment is set
            let data = channel.data();
            s2n_quic_core::assume!(data.len() > self.message_index);
            &mut data[self.message_index]
        };

        Some((message, gso))
    }

    /// Releases the current message and marks it pending for release
    #[inline]
    fn release_message(&mut self) {
        // one fewer free slot overall; if this was the last one, record that the
        // queue is full so the next `poll_ready` actually polls the channels
        self.capacity -= 1;
        *self.is_full = self.capacity == 0;

        let channel = unsafe {
            // Safety: the channel_index should always be in-bound if gso_segment is set
            s2n_quic_core::assume!(self.channels.len() > self.channel_index);
            &mut self.channels[self.channel_index]
        };

        // release without waking the consumer; the wake-up is batched in
        // `flush_channel` to avoid waking the socket task once per message
        channel.release_no_wake(1);

        self.pending_release += 1;
    }

    /// Flushes the current channel and releases any pending messages
    #[inline]
    fn flush_channel(&mut self) {
        // only wake the consumer if at least one message was released since the
        // last flush
        if self.pending_release > 0 {
            if let Some(channel) = self.channels.get_mut(self.channel_index) {
                channel.wake();
                self.message_index = 0;
                self.pending_release = 0;
            }
        }
    }
}
384
impl<T: Message> tx::Queue for TxQueue<'_, T> {
    type Handle = T::Handle;

    const SUPPORTS_ECN: bool = T::SUPPORTS_ECN;
    const SUPPORTS_FLOW_LABELS: bool = T::SUPPORTS_FLOW_LABELS;

    /// Writes `message` into the next available slot, preferring to append it to
    /// the currently-open GSO payload when the message is compatible with it.
    ///
    /// # Errors
    ///
    /// Returns `tx::Error::AtCapacity` when every channel is out of free slots;
    /// other errors are propagated from the message's payload writer.
    #[inline]
    fn push<M>(&mut self, message: M) -> Result<tx::Outcome, tx::Error>
    where
        M: tx::Message<Handle = Self::Handle>,
    {
        // first try to write a GSO payload, if supported
        let mut message = match self.try_gso(message)? {
            Ok(outcome) => return Ok(outcome),
            Err(message) => message,
        };

        // find the next free entry, if any
        let entry = loop {
            let channel = self
                .channels
                .get_mut(self.channel_index)
                .ok_or(tx::Error::AtCapacity)?;

            if let Some(entry) = channel.data().get_mut(self.message_index) {
                break entry;
            } else {
                // this channel is out of free messages so flush it and move to the next channel
                self.flush_channel();
                self.channel_index += 1;
            };
        };

        // prepare the entry for writing and reset all of the fields
        unsafe {
            // Safety: the entries should have been allocated with the MaxMtu
            entry.reset(self.max_mtu);
        }

        // query the values that we use for GSO before we write the message to the entry
        let handle = *message.path_handle();
        let ecn = message.ecn();
        let can_gso = message.can_gso(self.max_mtu, 0);

        // write the message to the entry
        let payload_len = entry.tx_write(message)?;

        // if GSO is supported and we are allowed to have additional segments, store the GSO state
        // for another potential message to be written later
        if T::SUPPORTS_GSO && self.max_segments > 1 && can_gso {
            self.gso_segment = Some(GsoSegment {
                handle,
                ecn,
                count: 1,
                size: payload_len,
            });
        } else {
            // otherwise, release the message to the consumer
            self.release_message();
        }

        // let the caller know how big the payload was
        let outcome = tx::Outcome {
            len: payload_len,
            index: 0,
        };

        Ok(outcome)
    }

    /// Closes out any in-progress GSO message so it can't accumulate unrelated packets.
    #[inline]
    fn flush(&mut self) {
        // flush GSO segments between connections
        self.flush_gso();
    }

    /// Returns how many more messages can be pushed in the current iteration.
    #[inline]
    fn capacity(&self) -> usize {
        self.capacity
    }
}
466
impl<T: Message> Drop for TxQueue<'_, T> {
    #[inline]
    fn drop(&mut self) {
        // flush the current GSO message, if possible
        //
        // NOTE: this must run before `flush_channel`: flushing the GSO state
        // releases its message, which increments `pending_release`, and
        // `flush_channel` only wakes the consumer when `pending_release > 0`.
        self.flush_gso();
        // flush the pending messages for the channel
        self.flush_channel();
    }
}