s2n_netbench_driver_s2n_quic/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use bytes::Bytes;
5use core::{
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use netbench::{client, connection::Owner, helper::IdPrefixReader, scenario, Driver, Result};
11use s2n_quic::{
12    connection,
13    stream::{LocalStream, PeerStream, SplittableStream},
14};
15use s2n_quic_core::stream::testing::Data;
16use std::task::ready;
17use std::{
18    collections::{hash_map::Entry, HashMap},
19    ops,
20    sync::Arc,
21};
22
23fn stream_error(err: s2n_quic::stream::Error) -> Result<()> {
24    if let s2n_quic::stream::Error::StreamReset { error, .. } = err {
25        if *error == 0 {
26            return Ok(());
27        }
28    }
29
30    if let s2n_quic::stream::Error::ConnectionError { error, .. } = err {
31        return conn_error(error);
32    }
33
34    Err(err.into())
35}
36
37fn conn_error(err: s2n_quic::connection::Error) -> Result<()> {
38    if let s2n_quic::connection::Error::Application { error, .. } = err {
39        if *error == 0 {
40            return Ok(());
41        }
42    }
43
44    Err(err.into())
45}
46
/// Netbench client driver backed by an [`s2n_quic::Client`].
///
/// Derefs (mutably and immutably) to the wrapped client, so its methods can be called
/// directly on this type.
pub struct Client(pub s2n_quic::Client);
48
49impl ops::Deref for Client {
50    type Target = s2n_quic::Client;
51
52    #[inline]
53    fn deref(&self) -> &Self::Target {
54        &self.0
55    }
56}
57
58impl ops::DerefMut for Client {
59    #[inline]
60    fn deref_mut(&mut self) -> &mut Self::Target {
61        &mut self.0
62    }
63}
64
65impl<'a> client::Client<'a> for Client {
66    type Connect = Connect<'a>;
67    type Connection = Driver<'a, Connection>;
68
69    fn connect(
70        &mut self,
71        addr: std::net::SocketAddr,
72        server_name: &str,
73        _server_conn_id: u64,
74        scenario: &'a Arc<scenario::Connection>,
75    ) -> Self::Connect {
76        let connect = s2n_quic::client::Connect::new(addr).with_server_name(server_name);
77        let attempt = s2n_quic::Client::connect(self, connect);
78        Connect { attempt, scenario }
79    }
80}
81
/// Future returned by this driver's `client::Client::connect`.
///
/// It resolves to a netbench `Driver` once the s2n-quic connection attempt completes.
pub struct Connect<'a> {
    /// The in-flight s2n-quic connection attempt.
    attempt: s2n_quic::client::ConnectionAttempt,
    /// The scenario the resulting `Driver` executes on the established connection.
    scenario: &'a scenario::Connection,
}
86
87impl<'a> Future for Connect<'a> {
88    type Output = Result<crate::Driver<'a, Connection>>;
89
90    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91        let conn = ready!(Pin::new(&mut self.attempt).poll(cx))?;
92        let conn = Connection::new(conn);
93        let conn = crate::Driver::new(self.scenario, conn);
94        Ok(conn).into()
95    }
96}
97
/// Netbench view of an s2n-quic connection.
///
/// Maps netbench stream ids to QUIC streams.
pub struct Connection {
    /// The underlying s2n-quic connection.
    conn: s2n_quic::Connection,
    /// Registered streams, indexed by [`Owner`] (e.g. `streams[Owner::Local]`) and keyed by
    /// netbench stream id.
    streams: [HashMap<u64, Stream>; 2],
    /// Locally opened streams whose id prefix has not finished sending yet, keyed by id,
    /// along with the unsent remainder of the prefix.
    opened_streams: HashMap<u64, (Bytes, LocalStream)>,
    /// An accepted peer stream whose id prefix is still being read, if any.
    unidentified_peer_stream: Option<(IdPrefixReader, PeerStream)>,
}
104
105impl From<s2n_quic::Connection> for Connection {
106    fn from(conn: s2n_quic::Connection) -> Self {
107        Self::new(conn)
108    }
109}
110
impl Connection {
    /// Wraps an s2n-quic connection with empty stream bookkeeping.
    pub fn new(connection: s2n_quic::Connection) -> Self {
        Self {
            conn: connection,
            streams: [HashMap::new(), HashMap::new()],
            opened_streams: HashMap::new(),
            unidentified_peer_stream: Default::default(),
        }
    }

    /// Consumes the wrapper and returns the underlying s2n-quic connection.
    ///
    /// Any streams tracked by the wrapper are dropped.
    pub fn into_inner(self) -> s2n_quic::Connection {
        self.conn
    }

    /// Opens, or resumes opening, a locally initiated stream for netbench stream `id`.
    ///
    /// Every new stream first sends `id` as an 8-byte big-endian prefix. The stream is only
    /// registered in `streams[Owner::Local]` once that prefix has been sent. The prefix is
    /// presumably decoded by the peer's `poll_accept_stream` via `IdPrefixReader`; confirm
    /// that both ends run this driver.
    ///
    /// * `open` - opens a fresh stream on the connection (bidirectional or send-only).
    ///
    /// Returns:
    /// * `Ready(Ok(()))` once the stream is registered.
    /// * `Pending` while opening, or while the prefix is still being sent.
    /// * `Ready(stream_error(err))` if sending the prefix fails.
    ///
    /// Errors from `open` are propagated with `?`.
    ///
    /// NOTE(review): if `stream_error` maps the failure to `Ok(())` (a reset with code 0), the
    /// caller sees success but no stream is registered under `id`. A later `poll_send` or
    /// `poll_receive` for that id would then panic on `unwrap` — confirm this cannot happen.
    fn open_local_stream<
        F: FnOnce(&mut s2n_quic::Connection, &mut Context) -> Poll<Result<S, connection::Error>>,
        S: Into<LocalStream>,
    >(
        &mut self,
        id: u64,
        open: F,
        cx: &mut Context,
    ) -> Poll<Result<()>> {
        // the stream has already been opened and is waiting to send the prefix
        if let Entry::Occupied(mut entry) = self.opened_streams.entry(id) {
            let (prefix, stream) = entry.get_mut();
            return match stream.poll_send(prefix, cx) {
                Poll::Ready(Ok(_)) => {
                    // Prefix delivered: promote the stream to the registered set.
                    let (_, stream) = entry.remove();
                    let stream = Stream::new(stream);
                    self.streams[Owner::Local].insert(id, stream);
                    Poll::Ready(Ok(()))
                }
                Poll::Ready(Err(err)) => {
                    entry.remove();
                    Poll::Ready(stream_error(err))
                }
                Poll::Pending => Poll::Pending,
            };
        }

        // No parked stream for this id: open a new one. Connection errors propagate via `?`.
        let mut stream = ready!(open(&mut self.conn, cx))?.into();

        // The netbench stream id, encoded as 8 big-endian bytes, is the first thing sent.
        let mut prefix = Bytes::copy_from_slice(&id.to_be_bytes());

        match stream.poll_send(&mut prefix, cx) {
            Poll::Ready(Ok(_)) => {
                let stream = Stream::new(stream);
                self.streams[Owner::Local].insert(id, stream);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(err)) => Poll::Ready(stream_error(err)),
            Poll::Pending => {
                // Park the stream with its unsent prefix. The next call for this `id`
                // resumes in the `Occupied` branch above.
                self.opened_streams.insert(id, (prefix, stream));
                Poll::Pending
            }
        }
    }
}
170
impl netbench::Connection for Connection {
    /// Returns the id of the underlying s2n-quic connection.
    fn id(&self) -> u64 {
        self.conn.id()
    }

    /// Opens a locally initiated bidirectional stream for netbench stream `id`.
    fn poll_open_bidirectional_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> {
        self.open_local_stream(id, |conn, cx| conn.poll_open_bidirectional_stream(cx), cx)
    }

    /// Opens a locally initiated send-only stream for netbench stream `id`.
    fn poll_open_send_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> {
        self.open_local_stream(id, |conn, cx| conn.poll_open_send_stream(cx), cx)
    }

    /// Accepts the next peer-initiated stream and returns its netbench id.
    ///
    /// Each accepted stream starts with an id prefix, which is decoded with
    /// `IdPrefixReader`. Until the prefix is complete, the stream is held in
    /// `unidentified_peer_stream`. Afterwards it is registered in `streams[Owner::Remote]`.
    ///
    /// Returns `Ready(Ok(None))` whenever `poll_accept` yields anything other than a new
    /// stream. Read errors on the prefix are propagated with `?`.
    ///
    /// NOTE(review): "anything other than a new stream" includes `Err(..)`, so connection
    /// errors are not surfaced here (unlike the `conn_error` handling elsewhere). Confirm
    /// this is intended.
    fn poll_accept_stream(&mut self, cx: &mut Context) -> Poll<Result<Option<u64>>> {
        loop {
            // Finish identifying a previously accepted stream before accepting another.
            if let Some((id, stream)) = self.unidentified_peer_stream.as_mut() {
                let len = ready!(futures::io::AsyncRead::poll_read(
                    Pin::new(stream),
                    cx,
                    id.remaining()
                ))?;
                // NOTE(review): behavior when `len == 0` (EOF before the full prefix) depends
                // on `IdPrefixReader::on_read`, which is not visible here. Confirm it cannot
                // stall.
                let id = ready!(id.on_read(len));

                // Cannot fail: the `if let Some` above just matched.
                let (_, stream) = self.unidentified_peer_stream.take().unwrap();
                let stream = Stream::new(stream);
                self.streams[Owner::Remote].insert(id, stream);
                return Poll::Ready(Ok(Some(id)));
            }

            let stream = ready!(self.conn.poll_accept(cx));

            if let Ok(Some(stream)) = stream {
                self.unidentified_peer_stream = Some((Default::default(), stream));
            } else {
                return Poll::Ready(Ok(None));
            };
        }
    }

    /// Sends up to `bytes` bytes of generated payload on stream `id` owned by `owner`.
    ///
    /// # Panics
    ///
    /// Panics if no stream `id` is registered for `owner`, or if its send half is absent or
    /// already finished.
    fn poll_send(
        &mut self,
        owner: Owner,
        id: u64,
        bytes: u64,
        cx: &mut Context,
    ) -> Poll<Result<u64>> {
        self.streams[owner]
            .get_mut(&id)
            .unwrap()
            .tx
            .as_mut()
            .unwrap()
            .poll_send(bytes, cx)
    }

    /// Receives up to `bytes` bytes on stream `id` owned by `owner`.
    ///
    /// # Panics
    ///
    /// Panics if no stream `id` is registered for `owner`, or if its receive half is absent
    /// or already finished.
    fn poll_receive(
        &mut self,
        owner: Owner,
        id: u64,
        bytes: u64,
        cx: &mut Context,
    ) -> Poll<Result<u64>> {
        self.streams[owner]
            .get_mut(&id)
            .unwrap()
            .rx
            .as_mut()
            .unwrap()
            .poll_receive(bytes, cx)
    }

    /// Finishes the send half of stream `id`, if it is still present.
    ///
    /// Errors from `finish` are routed through `stream_error`. The stream entry is dropped
    /// once both halves are gone. Unknown ids are ignored. Never returns `Pending`.
    fn poll_send_finish(&mut self, owner: Owner, id: u64, _cx: &mut Context) -> Poll<Result<()>> {
        if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) {
            let stream = entry.get_mut();
            if let Some(mut stream) = stream.tx.take() {
                stream.inner.finish().or_else(stream_error)?;
            }

            if stream.rx.is_none() {
                entry.remove();
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Stops receiving on stream `id` by calling `stop_sending` with application code 0.
    ///
    /// The result of `stop_sending` is deliberately ignored. The stream entry is dropped
    /// once both halves are gone. Unknown ids are ignored. Never returns `Pending`.
    fn poll_receive_finish(
        &mut self,
        owner: Owner,
        id: u64,
        _cx: &mut Context,
    ) -> Poll<Result<()>> {
        if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) {
            let stream = entry.get_mut();
            if let Some(mut stream) = stream.rx.take() {
                let _ = stream.inner.stop_sending(0u8.into());
            }

            if stream.tx.is_none() {
                entry.remove();
            }
        }

        Poll::Ready(Ok(()))
    }
}
277
/// Builds a fresh `[Bytes; 32]` of empty handles.
///
/// The array is scratch space for vectored stream I/O (`poll_receive_vectored` /
/// `poll_send_vectored`). Its length (32) bounds how many chunks a single vectored call can
/// fill or submit. Change it here in one place rather than in a hand-written literal.
///
/// `Bytes` is resolved at the expansion site, as before.
macro_rules! chunks {
    () => {
        core::array::from_fn::<Bytes, 32, _>(|_| Bytes::new())
    };
}
316
/// The two halves of one QUIC stream, as tracked by the driver.
///
/// A half is `None` if `split` did not produce it (presumably a unidirectional stream;
/// confirm), or after it has been finished via `poll_send_finish` / `poll_receive_finish`.
struct Stream {
    /// Receive half, if present and not yet finished.
    rx: Option<ReceiveStream>,
    /// Send half, if present and not yet finished.
    tx: Option<SendStream>,
}
321
322impl Stream {
323    fn new(stream: impl SplittableStream) -> Self {
324        let (rx, tx) = stream.split();
325        let rx = rx.map(ReceiveStream::new);
326        let tx = tx.map(SendStream::new);
327        Self { rx, tx }
328    }
329}
330
/// Receive half of a stream that counts incoming bytes without keeping them.
struct ReceiveStream {
    /// The underlying s2n-quic receive stream.
    inner: s2n_quic::stream::ReceiveStream,
    /// Bytes pulled from `inner` but not yet reported to the caller.
    buffered: u64,
    /// Cleared once `poll_receive_vectored` reports the stream is no longer open.
    is_open: bool,
}
336
impl ReceiveStream {
    /// Wraps a receive half with nothing buffered; the stream starts out treated as open.
    fn new(inner: s2n_quic::stream::ReceiveStream) -> Self {
        Self {
            inner,
            buffered: 0,
            is_open: true,
        }
    }

    /// Polls for up to `bytes` bytes from the peer.
    ///
    /// Payload is not retained. Chunk contents are dropped and only their lengths are added
    /// to `buffered`; this call reports how many of those bytes the caller consumed.
    ///
    /// Returns:
    /// * `Ready(Ok(n))` with `0 < n <= bytes` when buffered data is available.
    /// * `Ready(Ok(0))` once the inner stream is no longer open and nothing is left buffered.
    /// * `Pending` otherwise.
    /// * The error from `poll_receive_vectored`, via `?`.
    ///
    /// NOTE(review): with `bytes == 0` on an open stream, the loop exits as soon as any data
    /// arrives and the call returns `Pending` after an inner `Ready`. That inner `Ready` may
    /// not have registered a waker. Confirm callers never pass `bytes == 0`.
    fn poll_receive(&mut self, bytes: u64, cx: &mut Context) -> Poll<Result<u64>> {
        // Closed and fully drained: nothing left to report.
        if !self.is_open && self.buffered == 0 {
            return Ok(0).into();
        }

        // Pull batches until more than `bytes` is buffered, the stream stops being open, or
        // the inner poll is pending. Because of `<=`, one extra batch is requested even when
        // exactly `bytes` are already buffered.
        while self.buffered <= bytes && self.is_open {
            let mut chunks = chunks!();

            if let Poll::Ready(res) = self.inner.poll_receive_vectored(&mut chunks, cx) {
                let (count, is_open) = res?;
                self.is_open &= is_open;

                // Only the length is tracked; the chunk contents are discarded.
                for chunk in &chunks[..count] {
                    self.buffered += chunk.len() as u64;
                }
            } else {
                break;
            }
        }

        let received_len = bytes.min(self.buffered);
        self.buffered -= received_len;

        if !self.is_open && received_len == 0 {
            return Ok(0).into();
        }

        if received_len == 0 {
            Poll::Pending
        } else {
            Ok(received_len).into()
        }
    }
}
380
/// Send half of a stream whose payload comes from a deterministic generator.
struct SendStream {
    /// The underlying s2n-quic send stream.
    inner: s2n_quic::stream::SendStream,
    /// Payload generator from `s2n_quic_core`'s testing utilities.
    ///
    /// Advanced only by the number of bytes the stream actually accepted.
    data: Data,
}
385
386impl SendStream {
387    fn new(inner: s2n_quic::stream::SendStream) -> Self {
388        Self {
389            inner,
390            data: Data::new(u64::MAX),
391        }
392    }
393
394    fn poll_send(&mut self, mut bytes: u64, cx: &mut Context) -> Poll<Result<u64>> {
395        if bytes == 0 {
396            return Ok(0).into();
397        }
398
399        let mut len = 0;
400        let mut data = self.data;
401
402        while bytes > 0 {
403            let mut chunks = chunks!();
404
405            let count = data.send(bytes as usize, &mut chunks).unwrap();
406            let initial_len: u64 = chunks.iter().map(|chunk| chunk.len() as u64).sum();
407
408            let count = if let Poll::Ready(count) =
409                self.inner.poll_send_vectored(&mut chunks[..count], cx)?
410            {
411                count
412            } else {
413                break;
414            };
415
416            if count == chunks.len() {
417                len += initial_len;
418                bytes -= initial_len;
419                continue;
420            }
421
422            let remaining_len: u64 = chunks[count..].iter().map(|chunk| chunk.len() as u64).sum();
423
424            len += initial_len - remaining_len;
425
426            break;
427        }
428
429        if len == 0 {
430            return Poll::Pending;
431        }
432
433        self.data.seek_forward(len as usize);
434
435        Poll::Ready(Ok(len))
436    }
437}