s2n_quic_platform/socket/io/tx.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{features::Gso, message::Message, socket::ring::Producer};
5use core::task::{Context, Poll};
6use s2n_quic_core::{
7    event,
8    inet::ExplicitCongestionNotification,
9    io::tx,
10    path::{Handle as _, MaxMtu},
11    task::waker,
12};
13
/// Structure for sending messages to producer channels
pub struct Tx<T: Message> {
    /// The producer halves of the message rings that this endpoint writes into
    channels: Vec<Producer<T>>,
    /// Queried each `queue` call for the current maximum number of GSO segments
    gso: Gso,
    /// The maximum MTU in bytes; used to reset each message entry before writing
    max_mtu: usize,
    /// Set when every channel is out of free slots; `poll_ready` only polls the
    /// channels while this is `true` to avoid spinning the endpoint
    is_full: bool,
}
21
22impl<T: Message> Tx<T> {
23    #[inline]
24    pub fn new(channels: Vec<Producer<T>>, gso: Gso, max_mtu: MaxMtu) -> Self {
25        Self {
26            channels,
27            gso,
28            max_mtu: max_mtu.into(),
29            is_full: true,
30        }
31    }
32}
33
impl<T: Message> tx::Tx for Tx<T> {
    type PathHandle = T::Handle;
    type Queue = TxQueue<'static, T>;
    type Error = ();

    /// Polls all of the producer channels for capacity to send messages.
    ///
    /// Returns `Ready(Ok(()))` when at least one channel acquired a free slot,
    /// `Ready(Err(()))` when every channel has been closed, and `Pending`
    /// otherwise.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        // We only need to poll for capacity if we completely filled up all of the channels.
        // If we always polled, this would cause the endpoint to spin since most of the time it has
        // capacity for sending.
        if !self.is_full {
            return Poll::Pending;
        }

        // NOTE: we don't wrap the above check in the contract as we'd technically violate the
        // contract since we're returning `Pending` without storing a waker
        waker::debug_assert_contract(cx, |cx| {
            let mut is_any_ready = false;
            let mut is_all_closed = true;

            // poll every channel (rather than stopping at the first ready one) so each
            // pending ring stores the waker and can notify the endpoint later
            for channel in &mut self.channels {
                match channel.poll_acquire(1, cx) {
                    Poll::Ready(_) => {
                        is_all_closed = false;
                        is_any_ready = true;
                    }
                    Poll::Pending => {
                        // a pending channel only counts as closed once its ring is shut down
                        is_all_closed &= !channel.is_open();
                    }
                }
            }

            // if all of the channels were closed then shut the task down
            if is_all_closed {
                return Err(()).into();
            }

            // if any of the channels became ready then wake the endpoint up
            if is_any_ready {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        })
    }

    /// Builds a [`TxQueue`] over the producer channels and hands it to `f`.
    ///
    /// Re-acquires capacity on every channel, records whether the queue is
    /// completely full (so `poll_ready` knows whether it must poll again), and
    /// positions the queue at the first channel with free slots.
    #[inline]
    fn queue<F: FnOnce(&mut Self::Queue)>(&mut self, f: F) {
        let this: &'static mut Self = unsafe {
            // Safety: As noted in the [transmute examples](https://doc.rust-lang.org/std/mem/fn.transmute.html#examples)
            // it can be used to temporarily extend the lifetime of a reference. In this case, we
            // don't want to use GATs until the MSRV is >=1.65.0, which means `Self::Queue` is not
            // allowed to take generic lifetimes.
            //
            // We are left with using a `'static` lifetime here and encapsulating it in a private
            // field. The `Self::Queue` struct is then borrowed for the lifetime of the `F`
            // function. This will prevent the value from escaping beyond the lifetime of `&mut
            // self`.
            //
            // See https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=9a32abe85c666f36fb2ec86496cc41b4
            //
            // Once https://github.com/aws/s2n-quic/issues/1742 is resolved this code can go away
            core::mem::transmute(self)
        };

        let mut capacity = 0;
        let mut first_with_free_slots = None;
        for (idx, channel) in this.channels.iter_mut().enumerate() {
            // try to make one more effort to acquire capacity for sending
            let count = channel.acquire(u32::MAX) as usize;

            if count > 0 && first_with_free_slots.is_none() {
                // find the first channel that had capacity
                first_with_free_slots = Some(idx);
            }

            capacity += count;
        }

        // mark that we're still full so we need to poll and wake up next iteration
        this.is_full = capacity == 0;

        // start with the first queue that has free slots, otherwise set the index to the length,
        // which will return an AtCapacity error immediately.
        let channel_index = first_with_free_slots.unwrap_or(this.channels.len());

        // query the maximum number of segments we can fill at this point in time
        //
        // NOTE: this value could be lowered in the case the TX task encounters an error with GSO
        //       so we do need to query it each iteration.
        let max_segments = this.gso.max_segments();

        let mut queue = TxQueue {
            channels: &mut this.channels,
            channel_index,
            message_index: 0,
            pending_release: 0,
            gso_segment: None,
            max_segments,
            max_mtu: this.max_mtu,
            capacity,
            is_full: &mut this.is_full,
        };

        // `queue` is dropped at the end of this call; its `Drop` impl flushes any
        // pending GSO message and wakes the consumer for released messages
        f(&mut queue);
    }

    #[inline]
    fn handle_error<E: event::EndpointPublisher>(self, _error: Self::Error, _events: &mut E) {
        // The only reason we would be returning an error is if a channel closed. This could either
        // be because the endpoint is shutting down or one of the tasks panicked. Either way, we
        // don't know what the cause is here so we don't have any events to emit.
    }
}
148
/// Tracks the current state of a GSO message
#[derive(Debug, Default)]
pub struct GsoSegment<Handle> {
    /// The path handle of the current GSO segment being written
    ///
    /// This is used to determine if future messages should be included in this payload or need a
    /// separate packet.
    handle: Handle,
    /// The value of the ecn markings for the current GSO segment being written.
    ///
    /// This is used to determine if future messages should be included in this payload or need a
    /// separate packet.
    ecn: ExplicitCongestionNotification,
    /// The number of segments that have been written
    ///
    /// Bounded by `TxQueue::max_segments`; once reached the message is flushed.
    count: usize,
    /// The size of each segment.
    ///
    /// Note that the last segment can be smaller than the previous ones and will result in a flush
    size: usize,
}
169
/// A queue borrowed from [`Tx`] for the duration of a single `queue` call.
///
/// Fills acquired message slots across the producer channels, coalescing
/// consecutive compatible messages into GSO payloads where supported.
pub struct TxQueue<'a, T: Message> {
    /// The producer channels being filled during this iteration
    channels: &'a mut [Producer<T>],
    /// The channel index that we are currently operating on.
    ///
    /// This will be incremented after each channel is filled until it exceeds the len of `channels`.
    channel_index: usize,
    /// The message index into the current channel that we are operating on.
    ///
    /// This is incremented after each message is finished until it exceeds the acquired free
    /// slots, after which the `channel_index` is incremented (and message_index is reset to zero).
    message_index: usize,
    /// The number of messages in the current channel that need to be released to notify the
    /// consumer.
    ///
    /// This is to avoid calling `release` for each message and waking up the socket task too much.
    pending_release: u32,
    /// The current GSO segment we are filling, if any
    gso_segment: Option<GsoSegment<T::Handle>>,
    /// The maximum number of GSO segments that can be written
    max_segments: usize,
    /// The maximum MTU for any given packet
    max_mtu: usize,
    /// The maximum number of packets that can be sent in the current iteration
    capacity: usize,
    /// Used to track if we have filled up the producer queue and waiting on free slots to be
    /// released by the consumer.
    is_full: &'a mut bool,
}
198
impl<T: Message> TxQueue<'_, T> {
    /// Tries to send a message as a GSO segment
    ///
    /// Returns `Ok(Err(message))` if the message could not be coalesced into the current GSO
    /// payload. Otherwise, the outcome of the GSO'd message is returned.
    #[inline]
    fn try_gso<M: tx::Message<Handle = T::Handle>>(
        &mut self,
        mut message: M,
    ) -> Result<Result<tx::Outcome, M>, tx::Error> {
        // the message type doesn't support GSO, so return the message to the caller
        if !T::SUPPORTS_GSO {
            return Ok(Err(message));
        }

        let max_segments = self.max_segments;

        // if there is no in-progress GSO message, the caller writes a fresh entry instead
        let (prev_message, gso) = if let Some(gso) = self.gso_message() {
            gso
        } else {
            return Ok(Err(message));
        };

        debug_assert!(
            max_segments > 1,
            "gso_segment should only be set when max_gso > 1"
        );

        // check to make sure the message can be GSO'd and can be included in the same
        // GSO payload as the previous message
        let can_gso = message.can_gso(gso.size, gso.count)
            && message.path_handle().strict_eq(&gso.handle)
            && message.ecn() == gso.ecn;

        // if we can't use GSO then flush the current message
        if !can_gso {
            self.flush_gso();
            return Ok(Err(message));
        }

        debug_assert!(
            gso.count < max_segments,
            "{} cannot exceed {}",
            gso.count,
            max_segments
        );

        let payload_len = prev_message.payload_len();

        let buffer = unsafe {
            // Create a slice the `message` can write into. This avoids having to update the
            // payload length and worrying about panic safety.

            let payload = prev_message.payload_ptr_mut();

            // Safety: all payloads should have enough capacity to extend max_segments *
            // gso.size
            let current_payload = payload.add(payload_len);
            core::slice::from_raw_parts_mut(current_payload, gso.size)
        };
        let buffer = tx::PayloadBuffer::new(buffer);

        let size = message.write_payload(buffer, gso.count)?;

        // we don't want to send empty packets
        if size == 0 {
            return Err(tx::Error::EmptyPayload);
        }

        unsafe {
            debug_assert!(
                gso.size >= size,
                "the payload tried to write more than available"
            );
            // Set the len to the actual amount written to the payload. In case there is a bug,
            // take the min anyway so we don't have errors elsewhere.
            prev_message.set_payload_len(payload_len + size.min(gso.size));
        }
        // increment the number of segments that we've written
        gso.count += 1;

        debug_assert!(
            gso.count <= max_segments,
            "{} cannot exceed {}",
            gso.count,
            max_segments
        );

        // the last segment can be smaller but we can't write any more if it is
        let size_mismatch = gso.size != size;

        // we're bounded by the max_segments amount
        let at_segment_limit = gso.count >= max_segments;

        // we also can't write more data than u16::MAX
        let at_payload_limit = gso.size * (gso.count + 1) > u16::MAX as usize;

        // if we've hit any limits, then flush the GSO information to the message
        if size_mismatch || at_segment_limit || at_payload_limit {
            self.flush_gso();
        }

        Ok(Ok(tx::Outcome {
            len: size,
            index: 0,
        }))
    }

    /// Flushes the current GSO message, if any
    ///
    /// In the `gso_segment` field, we track which message is currently being
    /// built. If there ended up being multiple payloads written to the single message
    /// we need to set the msg_control values to indicate the GSO size.
    #[inline]
    fn flush_gso(&mut self) {
        // no need to flush if the message type doesn't support GSO
        if !T::SUPPORTS_GSO {
            debug_assert!(
                self.gso_segment.is_none(),
                "gso_segment should not be set if GSO is unsupported"
            );
            return;
        }

        if let Some((message, gso)) = self.gso_message() {
            // only need to set the segment size if there was more than one payload written to the message
            if gso.count > 1 {
                message.set_segment_size(gso.size);
            }

            // clear out the current state and release the message
            self.gso_segment = None;
            self.release_message();
        }
    }

    /// Returns the current GSO message waiting for more segments
    ///
    /// Returns `None` when no GSO payload is being built.
    #[inline]
    fn gso_message(&mut self) -> Option<(&mut T, &mut GsoSegment<T::Handle>)> {
        let gso = self.gso_segment.as_mut()?;

        let channel = unsafe {
            // Safety: the channel_index should always be in-bound if gso_segment is set
            s2n_quic_core::assume!(self.channels.len() > self.channel_index);
            &mut self.channels[self.channel_index]
        };

        let message = unsafe {
            // Safety: the message_index should always be in-bound if gso_segment is set
            let data = channel.data();
            s2n_quic_core::assume!(data.len() > self.message_index);
            &mut data[self.message_index]
        };

        Some((message, gso))
    }

    /// Releases the current message and marks it pending for release
    ///
    /// Decrements the remaining capacity and updates the shared `is_full` flag so the
    /// endpoint knows whether it needs to poll for more capacity next iteration.
    #[inline]
    fn release_message(&mut self) {
        self.capacity -= 1;
        *self.is_full = self.capacity == 0;

        let channel = unsafe {
            // Safety: the channel_index should always be in-bound if gso_segment is set
            s2n_quic_core::assume!(self.channels.len() > self.channel_index);
            &mut self.channels[self.channel_index]
        };

        // hand the message to the consumer without waking it; the wake-up is batched
        // in `flush_channel` to avoid waking the socket task per message
        channel.release_no_wake(1);

        self.pending_release += 1;
    }

    /// Flushes the current channel and releases any pending messages
    #[inline]
    fn flush_channel(&mut self) {
        if self.pending_release > 0 {
            if let Some(channel) = self.channels.get_mut(self.channel_index) {
                channel.wake();
                // NOTE(review): resetting message_index here assumes the channel's `data()`
                // view advances past entries released with `release_no_wake` — confirm
                // against the `Producer` ring semantics
                self.message_index = 0;
                self.pending_release = 0;
            }
        }
    }
}
384
impl<T: Message> tx::Queue for TxQueue<'_, T> {
    type Handle = T::Handle;

    const SUPPORTS_ECN: bool = T::SUPPORTS_ECN;
    const SUPPORTS_FLOW_LABELS: bool = T::SUPPORTS_FLOW_LABELS;

    /// Pushes a message into the queue.
    ///
    /// First attempts to coalesce the message into the in-progress GSO payload; otherwise
    /// writes it into the next free entry, advancing to the next channel when the current
    /// one runs out of acquired slots.
    ///
    /// # Errors
    ///
    /// Returns `tx::Error::AtCapacity` when every channel is exhausted, or any error
    /// produced while writing the message payload.
    #[inline]
    fn push<M>(&mut self, message: M) -> Result<tx::Outcome, tx::Error>
    where
        M: tx::Message<Handle = Self::Handle>,
    {
        // first try to write a GSO payload, if supported
        let mut message = match self.try_gso(message)? {
            Ok(outcome) => return Ok(outcome),
            Err(message) => message,
        };

        // find the next free entry, if any
        let entry = loop {
            let channel = self
                .channels
                .get_mut(self.channel_index)
                .ok_or(tx::Error::AtCapacity)?;

            if let Some(entry) = channel.data().get_mut(self.message_index) {
                break entry;
            } else {
                // this channel is out of free messages so flush it and move to the next channel
                self.flush_channel();
                self.channel_index += 1;
            };
        };

        // prepare the entry for writing and reset all of the fields
        unsafe {
            // Safety: the entries should have been allocated with the MaxMtu
            entry.reset(self.max_mtu);
        }

        // query the values that we use for GSO before we write the message to the entry
        let handle = *message.path_handle();
        let ecn = message.ecn();
        let can_gso = message.can_gso(self.max_mtu, 0);

        // write the message to the entry
        let payload_len = entry.tx_write(message)?;

        // if GSO is supported and we are allowed to have additional segments, store the GSO state
        // for another potential message to be written later
        if T::SUPPORTS_GSO && self.max_segments > 1 && can_gso {
            self.gso_segment = Some(GsoSegment {
                handle,
                ecn,
                count: 1,
                size: payload_len,
            });
        } else {
            // otherwise, release the message to the consumer
            self.release_message();
        }

        // let the caller know how big the payload was
        let outcome = tx::Outcome {
            len: payload_len,
            index: 0,
        };

        Ok(outcome)
    }

    /// Flushes any in-progress GSO payload so it is not shared across callers.
    #[inline]
    fn flush(&mut self) {
        // flush GSO segments between connections
        self.flush_gso();
    }

    /// Returns the number of packets that can still be written this iteration.
    #[inline]
    fn capacity(&self) -> usize {
        self.capacity
    }
}
466
467impl<T: Message> Drop for TxQueue<'_, T> {
468    #[inline]
469    fn drop(&mut self) {
470        // flush the current GSO message, if possible
471        self.flush_gso();
472        // flush the pending messages for the channel
473        self.flush_channel();
474    }
475}