s2n_quic_transport/stream/
api.rs1use crate::connection::Connection;
7use bytes::Bytes;
8use core::{
9 fmt,
10 future::Future,
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14pub use s2n_quic_core::{
15 application,
16 stream::{ops, StreamError, StreamId, StreamType},
17};
18
19#[derive(Clone)]
20struct State {
21 connection: Connection,
22 stream_id: StreamId,
23 rx: ops::Status,
24 tx: ops::Status,
25}
26
27impl State {
28 fn new(connection: Connection, stream_id: StreamId) -> Self {
29 Self {
30 connection,
31 stream_id,
32 rx: ops::Status::Open,
33 tx: ops::Status::Open,
34 }
35 }
36
37 fn poll_request(
38 &mut self,
39 request: &mut ops::Request,
40 context: Option<&Context>,
41 ) -> Result<ops::Response, StreamError> {
42 let id = self.stream_id;
43 self.connection.poll_request(id, request, context)
44 }
45
46 fn request(&mut self) -> Request<'_, '_> {
47 Request {
48 state: self,
49 request: ops::Request::default(),
50 }
51 }
52}
53
54impl Drop for State {
55 fn drop(&mut self) {
56 let is_rx_open = !self.rx.is_closed();
57 let is_tx_open = !self.tx.is_closed();
58
59 if is_rx_open || is_tx_open {
60 let mut request = self.request();
61
62 if is_tx_open {
63 request.finish().detach_tx();
67 }
68
69 if is_rx_open {
70 request
74 .stop_sending(application::Error::UNKNOWN)
75 .detach_rx();
76 }
77
78 let _ = request.poll(None);
79 }
80 }
81}
82
/// Expands to the sending-side convenience methods of a stream handle.
///
/// The surrounding `impl` must provide
/// `tx_request(&mut self) -> Result<TxRequest<'_, '_>, StreamError>`.
macro_rules! tx_stream_apis {
    () => {
        /// Polls sending of a single chunk.
        ///
        /// An empty `chunk` completes immediately with `Ok(())`. Otherwise the
        /// chunk is submitted as a one-element send request and the response
        /// is converted into the returned poll value.
        pub fn poll_send(
            &mut self,
            chunk: &mut Bytes,
            cx: &mut Context,
        ) -> Poll<Result<(), StreamError>> {
            if chunk.is_empty() {
                return Poll::Ready(Ok(()));
            }

            let chunks = core::slice::from_mut(chunk);
            let mut request = self.tx_request()?;
            let response = request.send(chunks).poll(Some(cx))?;
            response.into()
        }

        /// Polls sending of several chunks at once.
        ///
        /// Resolves to the number of chunks the response reports as consumed.
        /// An empty slice resolves to `Ok(0)` right away; a response that
        /// consumed no chunk yields `Poll::Pending`.
        pub fn poll_send_vectored(
            &mut self,
            chunks: &mut [Bytes],
            cx: &mut Context,
        ) -> Poll<Result<usize, StreamError>> {
            if chunks.is_empty() {
                return Poll::Ready(Ok(0));
            }

            let mut request = self.tx_request()?;
            let response = request.send(chunks).poll(Some(cx))?;

            match response.chunks.consumed {
                0 => Poll::Pending,
                _ => Poll::Ready(Ok(response
                    .tx()
                    .expect("invalid response")
                    .chunks
                    .consumed)),
            }
        }

        /// Polls for send readiness and resolves to the `bytes.available`
        /// value carried by the response.
        pub fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<Result<usize, StreamError>> {
            let mut request = self.tx_request()?;
            let polled = request.send_readiness().poll(Some(cx))?.into_poll();
            let response = ready!(polled);
            let available = response.tx().expect("invalid response").bytes.available;
            Poll::Ready(Ok(available))
        }

        /// Attempts to send `chunk` without a task context.
        ///
        /// An empty chunk succeeds trivially. Otherwise the call succeeds only
        /// when the response reports exactly one consumed chunk; any other
        /// outcome is reported as `StreamError::sending_blocked()`.
        pub fn send_data(&mut self, chunk: Bytes) -> Result<(), StreamError> {
            if chunk.is_empty() {
                return Ok(());
            }

            let mut chunks = [chunk];
            let response = self.tx_request()?.send(&mut chunks).poll(None)?;

            if response.tx().expect("invalid response").chunks.consumed == 1 {
                Ok(())
            } else {
                Err(StreamError::sending_blocked())
            }
        }

        /// Polls a flush request and converts the response into the returned
        /// poll value.
        pub fn poll_flush(&mut self, cx: &mut Context) -> Poll<Result<(), StreamError>> {
            let mut request = self.tx_request()?;
            let response = request.flush().poll(Some(cx))?;
            response.into()
        }

        /// Submits a finish request without a task context; the response is
        /// discarded and only errors are surfaced.
        pub fn finish(&mut self) -> Result<(), StreamError> {
            self.tx_request()?.finish().poll(None).map(|_| ())
        }

        /// Polls a combined finish-and-flush request and converts the response
        /// into the returned poll value.
        pub fn poll_close(&mut self, cx: &mut Context) -> Poll<Result<(), StreamError>> {
            let mut request = self.tx_request()?;
            let response = request.finish().flush().poll(Some(cx))?;
            response.into()
        }

        /// Submits a reset request carrying `error_code` without a task
        /// context; the response is discarded and only errors are surfaced.
        pub fn reset(&mut self, error_code: application::Error) -> Result<(), StreamError> {
            self.tx_request()?
                .reset(error_code)
                .poll(None)
                .map(|_| ())
        }
    };
}
220
/// Expands to the receiving-side convenience methods of a stream handle.
///
/// The surrounding `impl` must provide
/// `rx_request(&mut self) -> Result<RxRequest<'_, '_>, StreamError>`.
macro_rules! rx_stream_apis {
    () => {
        /// Polls for a single chunk.
        ///
        /// Resolves to `Some(chunk)` when the underlying vectored receive
        /// consumed at least one chunk and to `None` otherwise.
        pub fn poll_receive(
            &mut self,
            cx: &mut Context,
        ) -> Poll<Result<Option<Bytes>, StreamError>> {
            let mut chunk = Bytes::new();
            let (consumed, _) =
                ready!(self.poll_receive_vectored(core::slice::from_mut(&mut chunk), cx))?;

            let received = if consumed > 0 { Some(chunk) } else { None };
            Poll::Ready(Ok(received))
        }

        /// Polls for chunks, filling `chunks` from the front.
        ///
        /// Resolves to `(consumed, is_open)`: the number of chunks the
        /// response reports as consumed, and whether the reported receive
        /// status is either open or finishing.
        pub fn poll_receive_vectored(
            &mut self,
            chunks: &mut [Bytes],
            cx: &mut Context,
        ) -> Poll<Result<(usize, bool), StreamError>> {
            // The request is kept as a temporary so its borrow of `chunks`
            // has ended by the time `chunks.len()` is read below.
            let polled = self
                .rx_request()?
                .receive(chunks)
                .poll(Some(cx))?
                .into_poll();

            let response = match polled {
                Poll::Ready(response) => response,
                Poll::Pending => return Poll::Pending,
            };

            let rx = response.rx().expect("invalid response");
            let consumed = rx.chunks.consumed;
            debug_assert!(
                consumed <= chunks.len(),
                "consumed exceeded the number of chunks provided"
            );

            let still_open = rx.status.is_open() || rx.status.is_finishing();
            Poll::Ready(Ok((consumed, still_open)))
        }

        /// Submits a stop-sending request carrying `error_code` without a task
        /// context; the response is discarded and only errors are surfaced.
        pub fn stop_sending(&mut self, error_code: application::Error) -> Result<(), StreamError> {
            self.rx_request()?
                .stop_sending(error_code)
                .poll(None)
                .map(|_| ())
        }
    };
}
298
299pub struct Stream(State);
301
302impl fmt::Debug for Stream {
303 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
304 let is_alternate = f.alternate();
305
306 let mut s = f.debug_struct("Stream");
307 s.field("id", &self.id());
308
309 if is_alternate {
311 }
313
314 s.finish()
315 }
316}
317
318impl Stream {
319 pub(crate) fn new(connection: Connection, stream_id: StreamId) -> Self {
323 Self(State::new(connection, stream_id))
324 }
325
326 pub fn id(&self) -> StreamId {
327 self.0.stream_id
328 }
329
330 pub fn connection(&self) -> &Connection {
331 &self.0.connection
332 }
333
334 pub fn request(&mut self) -> Request<'_, '_> {
335 self.0.request()
336 }
337
338 pub fn tx_request(&mut self) -> Result<TxRequest<'_, '_>, StreamError> {
339 Ok(TxRequest {
340 state: &mut self.0,
341 request: ops::Request::default(),
342 })
343 }
344
345 pub fn rx_request(&mut self) -> Result<RxRequest<'_, '_>, StreamError> {
346 Ok(RxRequest {
347 state: &mut self.0,
348 request: ops::Request::default(),
349 })
350 }
351
352 tx_stream_apis!();
353 rx_stream_apis!();
354
355 pub fn split(self) -> (ReceiveStream, SendStream) {
360 let mut rx_state = self.0;
361 let mut tx_state = rx_state.clone();
362
363 rx_state.tx = ops::Status::Finished;
365 tx_state.rx = ops::Status::Finished;
366
367 (ReceiveStream(rx_state), SendStream(tx_state))
368 }
369}
370
371pub struct SendStream(State);
373
374impl fmt::Debug for SendStream {
375 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
376 let is_alternate = f.alternate();
377
378 let mut s = f.debug_struct("SendStream");
379 s.field("id", &self.id());
380
381 if is_alternate {
383 }
385
386 s.finish()
387 }
388}
389
390impl SendStream {
391 pub fn id(&self) -> StreamId {
392 self.0.stream_id
393 }
394
395 pub fn connection(&self) -> &Connection {
396 &self.0.connection
397 }
398
399 pub fn tx_request(&mut self) -> Result<TxRequest<'_, '_>, StreamError> {
400 Ok(TxRequest {
401 state: &mut self.0,
402 request: ops::Request::default(),
403 })
404 }
405
406 tx_stream_apis!();
407}
408
409impl From<Stream> for SendStream {
410 fn from(stream: Stream) -> Self {
411 Self(stream.0)
412 }
413}
414
415pub struct ReceiveStream(State);
417
418impl fmt::Debug for ReceiveStream {
419 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
420 let is_alternate = f.alternate();
421
422 let mut s = f.debug_struct("ReceiveStream");
423 s.field("id", &self.id());
424
425 if is_alternate {
427 }
429
430 s.finish()
431 }
432}
433
434impl ReceiveStream {
435 pub fn id(&self) -> StreamId {
436 self.0.stream_id
437 }
438
439 pub fn connection(&self) -> &Connection {
440 &self.0.connection
441 }
442
443 pub fn rx_request(&mut self) -> Result<RxRequest<'_, '_>, StreamError> {
444 Ok(RxRequest {
445 state: &mut self.0,
446 request: ops::Request::default(),
447 })
448 }
449
450 rx_stream_apis!();
451}
452
453impl From<Stream> for ReceiveStream {
454 fn from(stream: Stream) -> Self {
455 Self(stream.0)
456 }
457}
458
/// Expands to the chainable sending-side setters of a request builder.
///
/// The surrounding `impl` must declare a `'chunks` lifetime and the type must
/// have a `request: ops::Request<'chunks>` field. Every setter records its
/// operation on that field and hands back `self` for chaining.
macro_rules! tx_request_apis {
    () => {
        /// Adds `chunks` to the request as data to send.
        pub fn send(&mut self, chunks: &'chunks mut [Bytes]) -> &mut Self {
            let pending = &mut self.request;
            pending.send(chunks);
            self
        }

        /// Ensures the request has a `tx` section, inserting a default one
        /// when it is absent and leaving an existing one untouched.
        pub fn send_readiness(&mut self) -> &mut Self {
            let _ = self.request.tx.get_or_insert_with(Default::default);
            self
        }

        /// Adds a finish operation to the request.
        pub fn finish(&mut self) -> &mut Self {
            let pending = &mut self.request;
            pending.finish();
            self
        }

        /// Adds a reset operation carrying `error_code` to the request.
        pub fn reset(&mut self, error_code: application::Error) -> &mut Self {
            let pending = &mut self.request;
            pending.reset(error_code);
            self
        }

        /// Adds a flush operation to the request.
        pub fn flush(&mut self) -> &mut Self {
            let pending = &mut self.request;
            pending.flush();
            self
        }
    };
}
490
/// Expands to the chainable receiving-side setters of a request builder.
///
/// The surrounding `impl` must declare a `'chunks` lifetime and the type must
/// have a `request: ops::Request<'chunks>` field. Every setter records its
/// operation on that field and hands back `self` for chaining.
macro_rules! rx_request_apis {
    () => {
        /// Registers `chunks` as the destination slots for received data.
        pub fn receive(&mut self, chunks: &'chunks mut [Bytes]) -> &mut Self {
            let pending = &mut self.request;
            pending.receive(chunks);
            self
        }

        /// Sets both the low and the high watermark on the request.
        pub fn with_watermark(&mut self, low: usize, high: usize) -> &mut Self {
            let pending = &mut self.request;
            pending.with_watermark(low, high);
            self
        }

        /// Sets only the low watermark on the request.
        pub fn with_low_watermark(&mut self, low: usize) -> &mut Self {
            let pending = &mut self.request;
            pending.with_low_watermark(low);
            self
        }

        /// Sets only the high watermark on the request.
        pub fn with_high_watermark(&mut self, high: usize) -> &mut Self {
            let pending = &mut self.request;
            pending.with_high_watermark(high);
            self
        }

        /// Adds a stop-sending operation carrying `error_code` to the request.
        pub fn stop_sending(&mut self, error_code: application::Error) -> &mut Self {
            let pending = &mut self.request;
            pending.stop_sending(error_code);
            self
        }
    };
}
519
520pub struct Request<'state, 'chunks> {
521 state: &'state mut State,
522 request: ops::Request<'chunks>,
523}
524
525impl<'chunks> Request<'_, 'chunks> {
526 tx_request_apis!();
527 rx_request_apis!();
528
529 fn detach_tx(&mut self) -> &mut Self {
530 self.request.detach_tx();
531 self
532 }
533
534 fn detach_rx(&mut self) -> &mut Self {
535 self.request.detach_rx();
536 self
537 }
538
539 pub fn poll(&mut self, context: Option<&Context>) -> Result<ops::Response, StreamError> {
540 if self.state.rx.is_finished() && self.state.tx.is_finished() {
541 return Ok(ops::Response {
544 tx: Some(ops::tx::Response {
545 status: ops::Status::Finished,
546 ..Default::default()
547 }),
548 rx: Some(ops::rx::Response {
549 status: ops::Status::Finished,
550 ..Default::default()
551 }),
552 });
553 }
554
555 let response = self.state.poll_request(&mut self.request, context)?;
556
557 if let Some(rx) = response.rx() {
558 self.state.rx = rx.status;
559 }
560
561 if let Some(tx) = response.tx() {
562 self.state.tx = tx.status;
563 }
564
565 Ok(response)
566 }
567}
568
569impl Future for Request<'_, '_> {
570 type Output = Result<ops::Response, StreamError>;
571
572 fn poll(
573 mut self: Pin<&mut Self>,
574 context: &mut Context,
575 ) -> Poll<Result<ops::Response, StreamError>> {
576 Self::poll(&mut self, Some(context))?.into()
577 }
578}
579
580pub struct TxRequest<'state, 'chunks> {
581 state: &'state mut State,
582 request: ops::Request<'chunks>,
583}
584
585impl<'chunks> TxRequest<'_, 'chunks> {
586 tx_request_apis!();
587
588 pub fn poll(&mut self, context: Option<&Context>) -> Result<ops::tx::Response, StreamError> {
589 if self.state.tx.is_finished() {
590 return Ok(ops::tx::Response {
593 status: ops::Status::Finished,
594 ..Default::default()
595 });
596 }
597
598 let response = self
599 .state
600 .poll_request(&mut self.request, context)?
601 .tx
602 .expect("invalid response");
603
604 self.state.tx = response.status;
605
606 Ok(response)
607 }
608}
609
610impl Future for TxRequest<'_, '_> {
611 type Output = Result<ops::tx::Response, StreamError>;
612
613 fn poll(
614 mut self: Pin<&mut Self>,
615 context: &mut Context,
616 ) -> Poll<Result<ops::tx::Response, StreamError>> {
617 Self::poll(&mut self, Some(context))?.into()
618 }
619}
620
621pub struct RxRequest<'state, 'chunks> {
622 state: &'state mut State,
623 request: ops::Request<'chunks>,
624}
625
626impl<'chunks> RxRequest<'_, 'chunks> {
627 rx_request_apis!();
628
629 pub fn poll(&mut self, context: Option<&Context>) -> Result<ops::rx::Response, StreamError> {
630 if self.state.rx.is_finished() {
631 return Ok(ops::rx::Response {
634 status: ops::Status::Finished,
635 ..Default::default()
636 });
637 }
638
639 let response = self
640 .state
641 .poll_request(&mut self.request, context)?
642 .rx
643 .expect("invalid response");
644
645 self.state.rx = response.status;
646
647 Ok(response)
648 }
649}
650
651impl Future for RxRequest<'_, '_> {
652 type Output = Result<ops::rx::Response, StreamError>;
653
654 fn poll(
655 mut self: Pin<&mut Self>,
656 context: &mut Context,
657 ) -> Poll<Result<ops::rx::Response, StreamError>> {
658 Self::poll(&mut self, Some(context))?.into()
659 }
660}