s2n_quic_core/io/rx/
pair.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::Rx;
5use crate::{event, inet::datagram};
6use core::task::{Context, Poll};
7
8/// A pair of Rx channels that feed into the same endpoint
9pub struct Channel<A, B>
10where
11    A: Rx,
12    B: Rx<PathHandle = A::PathHandle>,
13{
14    pub(super) a: A,
15    pub(super) b: B,
16}
17
/// Presents a [`Channel`] pair as a single [`Rx`] channel.
///
/// * `poll_ready` is ready when either channel is ready, and fails as soon as either errors.
/// * `queue` hands the callback one combined queue borrowing both channels' queues.
/// * Errors are tagged with the channel that produced them (`Error::A` / `Error::B`) so that
///   `handle_error` can route them back to that channel.
impl<A, B> Rx for Channel<A, B>
where
    A: Rx,
    B: Rx<PathHandle = A::PathHandle>,
    // required so that `Queue<'static, A::Queue, B::Queue>` (which holds `&'static mut` references
    // to both queues) is a well-formed type
    A::Queue: 'static,
    B::Queue: 'static,
{
    type PathHandle = A::PathHandle;
    type Queue = Queue<'static, A::Queue, B::Queue>;
    type Error = Error<A::Error, B::Error>;

    /// Polls both channels and reports ready if at least one of them is ready.
    ///
    /// Returns the first error encountered, tagged with the channel it came from. If `a`
    /// returns an error, `b` is not polled during this call.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        // assume we aren't ready until one channel returns ready
        let mut is_ready = false;

        // Folds one channel's poll result into `is_ready`, or returns early from `poll_ready`
        // with the error wrapped in the `Error::$var` variant.
        macro_rules! ready {
            ($value:expr, $var:ident) => {
                match $value {
                    Poll::Ready(Ok(())) => is_ready = true,
                    Poll::Ready(Err(err)) => {
                        // one of the channels returned an error so shut down both
                        return Err(Error::$var(err)).into();
                    }
                    Poll::Pending => {}
                }
            };
        }

        // Both channels are polled even when `a` is already ready (no early return on
        // `Ready(Ok)`), presumably so each one registers interest in `cx`'s waker.
        ready!(self.a.poll_ready(cx), A);
        ready!(self.b.poll_ready(cx), B);

        if is_ready {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Borrows both channels' queues and passes them to `f` as one combined [`Queue`].
    ///
    /// `f` is invoked from inside `b`'s `queue` callback, which is itself inside `a`'s. If
    /// either underlying `queue` implementation does not invoke its callback, `f` is not
    /// called either.
    #[inline]
    fn queue<F: FnOnce(&mut Self::Queue)>(&mut self, f: F) {
        // split the borrow of `self` so both fields can be used inside the nested closures
        let a = &mut self.a;
        let b = &mut self.b;
        a.queue(|a| {
            b.queue(|b| {
                let (a, b): (&'static mut _, &'static mut _) = unsafe {
                    // Safety: As noted in the [transmute examples](https://doc.rust-lang.org/std/mem/fn.transmute.html#examples)
                    // it can be used to temporarily extend the lifetime of a reference. In this case, we
                    // don't want to use GATs until the MSRV is >=1.65.0, which means `Self::Queue` is not
                    // allowed to take generic lifetimes.
                    //
                    // We are left with using a `'static` lifetime here and encapsulating it in a private
                    // field. The `Self::Queue` struct is then borrowed for the lifetime of the `F`
                    // function. This will prevent the value from escaping beyond the lifetime of `&mut
                    // self`.
                    //
                    // NOTE(review): the encapsulation is not airtight. Two `Channel<A, B>` values with
                    // identical type parameters can be queued in a nested fashion. The inner callback
                    // then holds a second `&mut Queue<'static, ..>` of the same type, and
                    // `core::mem::swap` on the two lets the inner references outlive the inner
                    // callback. Confirm this is acceptable until the GAT-based fix lands.
                    //
                    // See https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=9a32abe85c666f36fb2ec86496cc41b4
                    //
                    // Once https://github.com/aws/s2n-quic/issues/1742 is resolved this code can go away
                    (
                        core::mem::transmute::<&mut <A as Rx>::Queue, &mut <A as Rx>::Queue>(a),
                        core::mem::transmute::<&mut <B as Rx>::Queue, &mut <B as Rx>::Queue>(b),
                    )
                };

                // the combined queue lives on this stack frame and is only lent to `f`
                let mut queue = Queue { a, b };
                f(&mut queue);
            });
        });
    }

    /// Routes a tagged error back to the channel that produced it.
    #[inline]
    fn handle_error<E: event::EndpointPublisher>(self, error: Self::Error, event: &mut E) {
        // dispatch the error to the appropriate channel
        match error {
            Error::A(error) => self.a.handle_error(error, event),
            Error::B(error) => self.b.handle_error(error, event),
        }
    }
}
98
/// Tagged error for a pair of channels.
///
/// The variant records which channel of the pair produced the error, so the error can be
/// dispatched back to that channel.
///
/// The derived traits are only implemented when both inner error types implement them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error<A, B> {
    /// An error produced by the first (`a`) channel
    A(A),
    /// An error produced by the second (`b`) channel
    B(B),
}
104
105pub struct Queue<'a, A, B>
106where
107    A: super::Queue,
108    B: super::Queue,
109{
110    a: &'a mut A,
111    b: &'a mut B,
112}
113
114impl<A, B> super::Queue for Queue<'_, A, B>
115where
116    A: super::Queue,
117    B: super::Queue<Handle = A::Handle>,
118{
119    type Handle = A::Handle;
120
121    #[inline]
122    fn for_each<F: FnMut(datagram::Header<Self::Handle>, &mut [u8])>(&mut self, mut on_packet: F) {
123        // drain both of the channels
124        self.a.for_each(&mut on_packet);
125        self.b.for_each(&mut on_packet);
126    }
127
128    #[inline]
129    fn is_empty(&self) -> bool {
130        self.a.is_empty() && self.b.is_empty()
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{
        rx::{Queue as _, RxExt as _},
        testing,
    };
    use futures_test::task::noop_waker;

    /// Walks every combination of "A received a packet" / "B received a packet" and checks that
    /// the merged channel is ready exactly when at least one side has data. It also checks that
    /// draining the merged queue yields every pushed packet.
    #[test]
    fn pair_test() {
        let channel_a = testing::Channel::default();
        let channel_b = testing::Channel::default();
        let mut merged = channel_a.clone().with_pair(channel_b.clone());

        let waker = noop_waker();
        let mut context = Context::from_waker(&waker);
        let cx = &mut context;

        let cases = [(false, false), (false, true), (true, false), (true, true)];

        for (push_a, push_b) in cases {
            // the previous iteration drained both sides, so nothing should be ready yet
            assert!(merged.poll_ready(cx).is_pending());

            if push_a {
                channel_a.push(Default::default());
            }
            if push_b {
                channel_b.push(Default::default());
            }

            let any_pushed = push_a || push_b;
            let expected = usize::from(push_a) + usize::from(push_b);

            assert_eq!(merged.poll_ready(cx).is_ready(), any_pushed);

            let mut actual = 0;
            merged.queue(|queue| {
                assert_eq!(queue.is_empty(), !any_pushed);
                queue.for_each(|_header, _payload| actual += 1);
            });

            assert_eq!(expected, actual);
        }
    }
}
184}