s2n_netbench_driver_s2n_quic/
lib.rs1use bytes::Bytes;
5use core::{
6 future::Future,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use netbench::{client, connection::Owner, helper::IdPrefixReader, scenario, Driver, Result};
11use s2n_quic::{
12 connection,
13 stream::{LocalStream, PeerStream, SplittableStream},
14};
15use s2n_quic_core::stream::testing::Data;
16use std::task::ready;
17use std::{
18 collections::{hash_map::Entry, HashMap},
19 ops,
20 sync::Arc,
21};
22
23fn stream_error(err: s2n_quic::stream::Error) -> Result<()> {
24 if let s2n_quic::stream::Error::StreamReset { error, .. } = err {
25 if *error == 0 {
26 return Ok(());
27 }
28 }
29
30 if let s2n_quic::stream::Error::ConnectionError { error, .. } = err {
31 return conn_error(error);
32 }
33
34 Err(err.into())
35}
36
37fn conn_error(err: s2n_quic::connection::Error) -> Result<()> {
38 if let s2n_quic::connection::Error::Application { error, .. } = err {
39 if *error == 0 {
40 return Ok(());
41 }
42 }
43
44 Err(err.into())
45}
46
47pub struct Client(pub s2n_quic::Client);
48
49impl ops::Deref for Client {
50 type Target = s2n_quic::Client;
51
52 #[inline]
53 fn deref(&self) -> &Self::Target {
54 &self.0
55 }
56}
57
58impl ops::DerefMut for Client {
59 #[inline]
60 fn deref_mut(&mut self) -> &mut Self::Target {
61 &mut self.0
62 }
63}
64
65impl<'a> client::Client<'a> for Client {
66 type Connect = Connect<'a>;
67 type Connection = Driver<'a, Connection>;
68
69 fn connect(
70 &mut self,
71 addr: std::net::SocketAddr,
72 server_name: &str,
73 _server_conn_id: u64,
74 scenario: &'a Arc<scenario::Connection>,
75 ) -> Self::Connect {
76 let connect = s2n_quic::client::Connect::new(addr).with_server_name(server_name);
77 let attempt = s2n_quic::Client::connect(self, connect);
78 Connect { attempt, scenario }
79 }
80}
81
82pub struct Connect<'a> {
83 attempt: s2n_quic::client::ConnectionAttempt,
84 scenario: &'a scenario::Connection,
85}
86
87impl<'a> Future for Connect<'a> {
88 type Output = Result<crate::Driver<'a, Connection>>;
89
90 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91 let conn = ready!(Pin::new(&mut self.attempt).poll(cx))?;
92 let conn = Connection::new(conn);
93 let conn = crate::Driver::new(self.scenario, conn);
94 Ok(conn).into()
95 }
96}
97
98pub struct Connection {
99 conn: s2n_quic::Connection,
100 streams: [HashMap<u64, Stream>; 2],
101 opened_streams: HashMap<u64, (Bytes, LocalStream)>,
102 unidentified_peer_stream: Option<(IdPrefixReader, PeerStream)>,
103}
104
105impl From<s2n_quic::Connection> for Connection {
106 fn from(conn: s2n_quic::Connection) -> Self {
107 Self::new(conn)
108 }
109}
110
111impl Connection {
112 pub fn new(connection: s2n_quic::Connection) -> Self {
113 Self {
114 conn: connection,
115 streams: [HashMap::new(), HashMap::new()],
116 opened_streams: HashMap::new(),
117 unidentified_peer_stream: Default::default(),
118 }
119 }
120
121 pub fn into_inner(self) -> s2n_quic::Connection {
122 self.conn
123 }
124
125 fn open_local_stream<
126 F: FnOnce(&mut s2n_quic::Connection, &mut Context) -> Poll<Result<S, connection::Error>>,
127 S: Into<LocalStream>,
128 >(
129 &mut self,
130 id: u64,
131 open: F,
132 cx: &mut Context,
133 ) -> Poll<Result<()>> {
134 if let Entry::Occupied(mut entry) = self.opened_streams.entry(id) {
136 let (prefix, stream) = entry.get_mut();
137 return match stream.poll_send(prefix, cx) {
138 Poll::Ready(Ok(_)) => {
139 let (_, stream) = entry.remove();
140 let stream = Stream::new(stream);
141 self.streams[Owner::Local].insert(id, stream);
142 Poll::Ready(Ok(()))
143 }
144 Poll::Ready(Err(err)) => {
145 entry.remove();
146 Poll::Ready(stream_error(err))
147 }
148 Poll::Pending => Poll::Pending,
149 };
150 }
151
152 let mut stream = ready!(open(&mut self.conn, cx))?.into();
153
154 let mut prefix = Bytes::copy_from_slice(&id.to_be_bytes());
155
156 match stream.poll_send(&mut prefix, cx) {
157 Poll::Ready(Ok(_)) => {
158 let stream = Stream::new(stream);
159 self.streams[Owner::Local].insert(id, stream);
160 Poll::Ready(Ok(()))
161 }
162 Poll::Ready(Err(err)) => Poll::Ready(stream_error(err)),
163 Poll::Pending => {
164 self.opened_streams.insert(id, (prefix, stream));
165 Poll::Pending
166 }
167 }
168 }
169}
170
171impl netbench::Connection for Connection {
172 fn id(&self) -> u64 {
173 self.conn.id()
174 }
175
176 fn poll_open_bidirectional_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> {
177 self.open_local_stream(id, |conn, cx| conn.poll_open_bidirectional_stream(cx), cx)
178 }
179
180 fn poll_open_send_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> {
181 self.open_local_stream(id, |conn, cx| conn.poll_open_send_stream(cx), cx)
182 }
183
184 fn poll_accept_stream(&mut self, cx: &mut Context) -> Poll<Result<Option<u64>>> {
185 loop {
186 if let Some((id, stream)) = self.unidentified_peer_stream.as_mut() {
187 let len = ready!(futures::io::AsyncRead::poll_read(
188 Pin::new(stream),
189 cx,
190 id.remaining()
191 ))?;
192 let id = ready!(id.on_read(len));
193
194 let (_, stream) = self.unidentified_peer_stream.take().unwrap();
195 let stream = Stream::new(stream);
196 self.streams[Owner::Remote].insert(id, stream);
197 return Poll::Ready(Ok(Some(id)));
198 }
199
200 let stream = ready!(self.conn.poll_accept(cx));
201
202 if let Ok(Some(stream)) = stream {
203 self.unidentified_peer_stream = Some((Default::default(), stream));
204 } else {
205 return Poll::Ready(Ok(None));
206 };
207 }
208 }
209
210 fn poll_send(
211 &mut self,
212 owner: Owner,
213 id: u64,
214 bytes: u64,
215 cx: &mut Context,
216 ) -> Poll<Result<u64>> {
217 self.streams[owner]
218 .get_mut(&id)
219 .unwrap()
220 .tx
221 .as_mut()
222 .unwrap()
223 .poll_send(bytes, cx)
224 }
225
226 fn poll_receive(
227 &mut self,
228 owner: Owner,
229 id: u64,
230 bytes: u64,
231 cx: &mut Context,
232 ) -> Poll<Result<u64>> {
233 self.streams[owner]
234 .get_mut(&id)
235 .unwrap()
236 .rx
237 .as_mut()
238 .unwrap()
239 .poll_receive(bytes, cx)
240 }
241
242 fn poll_send_finish(&mut self, owner: Owner, id: u64, _cx: &mut Context) -> Poll<Result<()>> {
243 if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) {
244 let stream = entry.get_mut();
245 if let Some(mut stream) = stream.tx.take() {
246 stream.inner.finish().or_else(stream_error)?;
247 }
248
249 if stream.rx.is_none() {
250 entry.remove();
251 }
252 }
253
254 Poll::Ready(Ok(()))
255 }
256
257 fn poll_receive_finish(
258 &mut self,
259 owner: Owner,
260 id: u64,
261 _cx: &mut Context,
262 ) -> Poll<Result<()>> {
263 if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) {
264 let stream = entry.get_mut();
265 if let Some(mut stream) = stream.rx.take() {
266 let _ = stream.inner.stop_sending(0u8.into());
267 }
268
269 if stream.tx.is_none() {
270 entry.remove();
271 }
272 }
273
274 Poll::Ready(Ok(()))
275 }
276}
277
/// Expands to a fresh `[Bytes; 32]` array in which every entry is an empty `Bytes`.
///
/// The array is used as scratch space for the vectored send/receive calls below.
macro_rules! chunks {
    () => {
        core::array::from_fn::<Bytes, 32, _>(|_| Bytes::new())
    };
}
316
317struct Stream {
318 rx: Option<ReceiveStream>,
319 tx: Option<SendStream>,
320}
321
322impl Stream {
323 fn new(stream: impl SplittableStream) -> Self {
324 let (rx, tx) = stream.split();
325 let rx = rx.map(ReceiveStream::new);
326 let tx = tx.map(SendStream::new);
327 Self { rx, tx }
328 }
329}
330
331struct ReceiveStream {
332 inner: s2n_quic::stream::ReceiveStream,
333 buffered: u64,
334 is_open: bool,
335}
336
337impl ReceiveStream {
338 fn new(inner: s2n_quic::stream::ReceiveStream) -> Self {
339 Self {
340 inner,
341 buffered: 0,
342 is_open: true,
343 }
344 }
345
346 fn poll_receive(&mut self, bytes: u64, cx: &mut Context) -> Poll<Result<u64>> {
347 if !self.is_open && self.buffered == 0 {
348 return Ok(0).into();
349 }
350
351 while self.buffered <= bytes && self.is_open {
352 let mut chunks = chunks!();
353
354 if let Poll::Ready(res) = self.inner.poll_receive_vectored(&mut chunks, cx) {
355 let (count, is_open) = res?;
356 self.is_open &= is_open;
357
358 for chunk in &chunks[..count] {
359 self.buffered += chunk.len() as u64;
360 }
361 } else {
362 break;
363 }
364 }
365
366 let received_len = bytes.min(self.buffered);
367 self.buffered -= received_len;
368
369 if !self.is_open && received_len == 0 {
370 return Ok(0).into();
371 }
372
373 if received_len == 0 {
374 Poll::Pending
375 } else {
376 Ok(received_len).into()
377 }
378 }
379}
380
381struct SendStream {
382 inner: s2n_quic::stream::SendStream,
383 data: Data,
384}
385
386impl SendStream {
387 fn new(inner: s2n_quic::stream::SendStream) -> Self {
388 Self {
389 inner,
390 data: Data::new(u64::MAX),
391 }
392 }
393
394 fn poll_send(&mut self, mut bytes: u64, cx: &mut Context) -> Poll<Result<u64>> {
395 if bytes == 0 {
396 return Ok(0).into();
397 }
398
399 let mut len = 0;
400 let mut data = self.data;
401
402 while bytes > 0 {
403 let mut chunks = chunks!();
404
405 let count = data.send(bytes as usize, &mut chunks).unwrap();
406 let initial_len: u64 = chunks.iter().map(|chunk| chunk.len() as u64).sum();
407
408 let count = if let Poll::Ready(count) =
409 self.inner.poll_send_vectored(&mut chunks[..count], cx)?
410 {
411 count
412 } else {
413 break;
414 };
415
416 if count == chunks.len() {
417 len += initial_len;
418 bytes -= initial_len;
419 continue;
420 }
421
422 let remaining_len: u64 = chunks[count..].iter().map(|chunk| chunk.len() as u64).sum();
423
424 len += initial_len - remaining_len;
425
426 break;
427 }
428
429 if len == 0 {
430 return Poll::Pending;
431 }
432
433 self.data.seek_forward(len as usize);
434
435 Poll::Ready(Ok(len))
436 }
437}