use crate::{
contexts::{OnTransmitError, WriteContext},
stream::{
incoming_connection_flow_controller::IncomingConnectionFlowController,
outgoing_connection_flow_controller::OutgoingConnectionFlowController,
receive_stream::ReceiveStream,
send_stream::SendStream,
stream_events::StreamEvents,
stream_interests::{StreamInterestProvider, StreamInterests},
StreamError,
},
};
use core::{task::Context, time::Duration};
use s2n_quic_core::{
ack, endpoint,
frame::{stream::StreamRef, MaxStreamData, ResetStream, StopSending, StreamDataBlocked},
stream::{ops, StreamId},
time::{timer, Timestamp},
transport,
varint::VarInt,
};
/// Parameters consumed by [`StreamTrait::new`] to build a stream.
#[derive(Debug)]
pub struct StreamConfig {
    /// Identifier of the stream being created.
    pub stream_id: StreamId,
    /// Endpoint type of the local side. `StreamImpl::new` compares it with the
    /// stream's initiator to decide which half of a unidirectional stream is closed.
    pub local_endpoint_type: endpoint::Type,
    /// Connection-level flow controller handed to the receive half.
    pub incoming_connection_flow_controller: IncomingConnectionFlowController,
    /// Connection-level flow controller handed to the send half.
    pub outgoing_connection_flow_controller: OutgoingConnectionFlowController,
    /// Initial receive window passed to the receive half.
    pub initial_receive_window: VarInt,
    /// Desired flow-control window passed to the receive half.
    pub desired_flow_control_window: u32,
    /// Initial send window passed to the send half (presumably the peer's
    /// advertised initial limit — TODO confirm).
    pub initial_send_window: VarInt,
    /// Maximum send buffer size passed to the send half (presumably in bytes —
    /// TODO confirm).
    pub max_send_buffer_size: u32,
}
pub trait StreamTrait: StreamInterestProvider + timer::Provider + core::fmt::Debug {
fn new(config: StreamConfig) -> Self;
fn stream_id(&self) -> StreamId;
fn on_data(
&mut self,
frame: &StreamRef,
events: &mut StreamEvents,
) -> Result<(), transport::Error>;
fn on_stream_data_blocked(
&mut self,
frame: &StreamDataBlocked,
events: &mut StreamEvents,
) -> Result<(), transport::Error>;
fn on_reset(
&mut self,
frame: &ResetStream,
events: &mut StreamEvents,
) -> Result<(), transport::Error>;
fn on_max_stream_data(
&mut self,
frame: &MaxStreamData,
events: &mut StreamEvents,
) -> Result<(), transport::Error>;
fn on_stop_sending(
&mut self,
frame: &StopSending,
events: &mut StreamEvents,
) -> Result<(), transport::Error>;
fn on_packet_ack<A: ack::Set>(&mut self, ack_set: &A, events: &mut StreamEvents);
fn on_packet_loss<A: ack::Set>(&mut self, ack_set: &A, events: &mut StreamEvents);
fn update_blocked_sync_period(&mut self, blocked_sync_period: Duration);
fn on_timeout(&mut self, now: Timestamp);
fn on_internal_reset(&mut self, error: StreamError, events: &mut StreamEvents);
fn on_flush(&mut self, error: StreamError, events: &mut StreamEvents);
fn on_transmit<W: WriteContext>(&mut self, context: &mut W) -> Result<(), OnTransmitError>;
fn on_connection_window_available(&mut self);
fn poll_request(
&mut self,
request: &mut ops::Request,
context: Option<&Context>,
) -> Result<ops::Response, StreamError>;
}
/// A [`StreamTrait`] implementation composed of a receive half and a send half.
#[derive(Debug)]
pub struct StreamImpl {
    /// Identifier of this stream.
    pub(super) stream_id: StreamId,
    /// Receiving half of the stream.
    pub(super) receive_stream: ReceiveStream,
    /// `false` when the stream is unidirectional and was initiated by the peer,
    /// i.e. it is receive-only for the local endpoint. Used to reject
    /// MAX_STREAM_DATA frames on such streams.
    has_send: bool,
    /// Sending half of the stream. It is constructed even for receive-only
    /// streams; in that case it is created with its "closed" flag set.
    pub(super) send_stream: SendStream,
}
impl StreamImpl {
    /// Dispatches an application request to the receive and/or send half.
    ///
    /// The receive half services `request.rx` first, then the send half services
    /// `request.tx`. A half that was not requested is skipped, and its response
    /// slot stays `None`.
    ///
    /// # Errors
    ///
    /// If only one half was requested, an error from that half is returned as
    /// `Err`.
    ///
    /// If both halves were requested, an error from one half is reported in-band
    /// as `ops::Status::Reset(err)` on that half's response, so the other half's
    /// outcome is not lost. If both halves fail, both responses carry a `Reset`
    /// status and the call returns `Ok`.
    fn poll_request_impl(
        &mut self,
        request: &mut ops::Request,
        context: Option<&Context>,
    ) -> Result<ops::Response, StreamError> {
        let mut response = ops::Response::default();

        if let Some(rx) = request.rx.as_mut() {
            match self.receive_stream.poll_request(rx, context) {
                Ok(rx) => response.rx = Some(rx),
                Err(err) => {
                    // Fail the whole call only if there is no tx half to report on.
                    // This must consult the *request*: the send half has not been
                    // polled yet, so `response.tx` is always `None` at this point.
                    if request.tx.is_none() {
                        return Err(err);
                    }
                    response.rx = Some(ops::rx::Response {
                        status: ops::Status::Reset(err),
                        ..Default::default()
                    });
                }
            }
        }

        if let Some(tx) = request.tx.as_mut() {
            match self.send_stream.poll_request(tx, context) {
                Ok(tx) => response.tx = Some(tx),
                Err(err) => {
                    // Fail the whole call only if there is no rx half to report on.
                    if request.rx.is_none() {
                        return Err(err);
                    }
                    response.tx = Some(ops::tx::Response {
                        status: ops::Status::Reset(err),
                        ..Default::default()
                    });
                }
            }
        }

        Ok(response)
    }
}
impl StreamTrait for StreamImpl {
fn new(config: StreamConfig) -> StreamImpl {
let receive_is_closed = config.stream_id.stream_type().is_unidirectional()
&& config.stream_id.initiator() == config.local_endpoint_type;
let send_is_closed = config.stream_id.stream_type().is_unidirectional()
&& config.stream_id.initiator() != config.local_endpoint_type;
StreamImpl {
stream_id: config.stream_id,
receive_stream: ReceiveStream::new(
receive_is_closed,
config.incoming_connection_flow_controller,
config.initial_receive_window,
config.desired_flow_control_window,
),
has_send: !send_is_closed,
send_stream: SendStream::new(
config.outgoing_connection_flow_controller,
send_is_closed,
config.initial_send_window,
config.max_send_buffer_size,
),
}
}
#[inline]
fn stream_id(&self) -> StreamId {
self.stream_id
}
#[inline]
fn on_data(
&mut self,
frame: &StreamRef,
events: &mut StreamEvents,
) -> Result<(), transport::Error> {
self.receive_stream.on_data(frame, events)
}
#[inline]
fn on_stream_data_blocked(
&mut self,
frame: &StreamDataBlocked,
events: &mut StreamEvents,
) -> Result<(), transport::Error> {
self.receive_stream.on_stream_data_blocked(frame, events)
}
#[inline]
fn on_reset(
&mut self,
frame: &ResetStream,
events: &mut StreamEvents,
) -> Result<(), transport::Error> {
self.receive_stream.on_reset(frame, events)
}
#[inline]
fn on_max_stream_data(
&mut self,
frame: &MaxStreamData,
events: &mut StreamEvents,
) -> Result<(), transport::Error> {
if !self.has_send {
return Err(transport::Error::STREAM_STATE_ERROR
.with_reason("MAX_STREAM_DATA sent on receive-only stream"));
}
self.send_stream.on_max_stream_data(frame, events)
}
#[inline]
fn on_stop_sending(
&mut self,
frame: &StopSending,
events: &mut StreamEvents,
) -> Result<(), transport::Error> {
self.send_stream.on_stop_sending(frame, events)
}
#[inline]
fn on_packet_ack<A: ack::Set>(&mut self, ack_set: &A, events: &mut StreamEvents) {
self.receive_stream.on_packet_ack(ack_set);
self.send_stream.on_packet_ack(ack_set, events);
}
#[inline]
fn on_packet_loss<A: ack::Set>(&mut self, ack_set: &A, _events: &mut StreamEvents) {
self.receive_stream.on_packet_loss(ack_set);
self.send_stream.on_packet_loss(ack_set);
}
#[inline]
fn update_blocked_sync_period(&mut self, blocked_sync_period: Duration) {
self.send_stream
.update_blocked_sync_period(blocked_sync_period);
}
#[inline]
fn on_timeout(&mut self, now: Timestamp) {
self.send_stream.on_timeout(now)
}
#[inline]
fn on_internal_reset(&mut self, error: StreamError, events: &mut StreamEvents) {
self.receive_stream.on_internal_reset(error, events);
self.send_stream.on_internal_reset(error, events);
}
#[inline]
fn on_flush(&mut self, error: StreamError, events: &mut StreamEvents) {
self.receive_stream.on_internal_reset(error, events);
self.send_stream.on_flush(error, events);
}
#[inline]
fn on_transmit<W: WriteContext>(&mut self, context: &mut W) -> Result<(), OnTransmitError> {
self.receive_stream.on_transmit(self.stream_id, context)?;
self.send_stream.on_transmit(self.stream_id, context)
}
#[inline]
fn on_connection_window_available(&mut self) {
self.send_stream.on_connection_window_available()
}
fn poll_request(
&mut self,
request: &mut ops::Request,
context: Option<&Context>,
) -> Result<ops::Response, StreamError> {
#[cfg(debug_assertions)]
let contract: crate::stream::contract::Request = (&*request).into();
let result = self.poll_request_impl(request, context);
#[cfg(debug_assertions)]
contract.validate_response(request, result.as_ref(), context);
result
}
}
impl timer::Provider for StreamImpl {
    /// Reports this stream's timers to `query`.
    ///
    /// Only the send half is queried; the receive half is not consulted here.
    #[inline]
    fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
        let Self { send_stream, .. } = self;
        send_stream.timers(query)?;
        Ok(())
    }
}
impl StreamInterestProvider for StreamImpl {
    /// Accumulates the interests of both halves into `interests`.
    ///
    /// The send half contributes first, followed by the receive half.
    #[inline]
    fn stream_interests(&self, interests: &mut StreamInterests) {
        let Self {
            send_stream,
            receive_stream,
            ..
        } = self;
        send_stream.stream_interests(interests);
        receive_stream.stream_interests(interests);
    }
}