s2n_quic_transport/stream/
controller.rs1mod local_initiated;
5mod remote_initiated;
6
7use self::{local_initiated::LocalInitiated, remote_initiated::RemoteInitiated};
8use crate::{
9 connection,
10 contexts::OnTransmitError,
11 transmission,
12 transmission::{interest::Provider, WriteContext},
13};
14use core::{
15 task::{ready, Context, Poll},
16 time::Duration,
17};
18use s2n_quic_core::{
19 ack, endpoint,
20 frame::MaxStreams,
21 stream::{self, iter::StreamIter, StreamId, StreamType},
22 time::{timer, Timestamp},
23 transport,
24 transport::parameters::InitialFlowControlLimits,
25 varint::VarInt,
26};
27
28#[cfg(test)]
29pub use remote_initiated::MAX_STREAMS_SYNC_FRACTION;
30
/// Tracks how many streams may be opened on each side of a connection.
///
/// The bookkeeping is split across four controllers, one for every combination of initiator
/// (this endpoint or its peer) and stream type (bidirectional or unidirectional). The methods on
/// this type route incoming events to whichever of the four is responsible.
#[derive(Debug)]
pub struct Controller {
    /// This endpoint's type; compared with `StreamId::initiator()` to decide whether a given
    /// stream was opened locally or by the peer.
    local_endpoint_type: endpoint::Type,
    /// Controller for bidirectional streams opened by this endpoint.
    local_bidi_controller: LocalInitiated<
        stream::limits::LocalBidirectional,
        local_initiated::OpenNotifyBidirectional,
    >,
    /// Controller for bidirectional streams opened by the peer.
    remote_bidi_controller: RemoteInitiated,
    /// Controller for unidirectional streams opened by this endpoint.
    local_uni_controller: LocalInitiated<
        stream::limits::LocalUnidirectional,
        local_initiated::OpenNotifyUnidirectional,
    >,
    /// Controller for unidirectional streams opened by the peer.
    remote_uni_controller: RemoteInitiated,
}
51
52impl Controller {
53 pub fn new(
65 local_endpoint_type: endpoint::Type,
66 initial_peer_limits: InitialFlowControlLimits,
67 initial_local_limits: InitialFlowControlLimits,
68 stream_limits: stream::Limits,
69 min_rtt: Duration,
70 ) -> Self {
71 Self {
72 local_endpoint_type,
73 local_bidi_controller: LocalInitiated::new(
74 initial_peer_limits.max_open_remote_bidirectional_streams,
75 stream_limits.max_open_local_bidirectional_streams,
76 ),
77 remote_bidi_controller: RemoteInitiated::new(
78 initial_local_limits.max_open_remote_bidirectional_streams,
79 min_rtt,
80 ),
81 local_uni_controller: LocalInitiated::new(
82 initial_peer_limits.max_open_remote_unidirectional_streams,
83 stream_limits.max_open_local_unidirectional_streams,
84 ),
85 remote_uni_controller: RemoteInitiated::new(
86 initial_local_limits.max_open_remote_unidirectional_streams,
87 min_rtt,
88 ),
89 }
90 }
91
92 pub fn on_max_streams(&mut self, frame: &MaxStreams) {
95 match frame.stream_type {
96 StreamType::Bidirectional => self.local_bidi_controller.on_max_streams(frame),
97 StreamType::Unidirectional => self.local_uni_controller.on_max_streams(frame),
98 }
99 }
100
101 pub fn poll_open_local_stream(
109 &mut self,
110 stream_type: StreamType,
111 open_tokens: &mut connection::OpenToken,
112 context: &Context,
113 ) -> Poll<()> {
114 let poll_open = match stream_type {
115 StreamType::Bidirectional => self
116 .local_bidi_controller
117 .poll_open_stream(&mut open_tokens.bidirectional, context),
118 StreamType::Unidirectional => self
119 .local_uni_controller
120 .poll_open_stream(&mut open_tokens.unidirectional, context),
121 };
122
123 ready!(poll_open);
125
126 let direction = self.direction(StreamId::initial(self.local_endpoint_type, stream_type));
128 self.on_open_stream(direction);
129 Poll::Ready(())
130 }
131
132 pub fn on_open_remote_stream(
141 &mut self,
142 stream_iter: StreamIter,
143 ) -> Result<(), transport::Error> {
144 debug_assert!(
145 self.direction(stream_iter.max_stream_id()).is_remote(),
146 "should only be called for remote initiated streams"
147 );
148
149 match stream_iter.max_stream_id().stream_type() {
151 StreamType::Bidirectional => self
152 .remote_bidi_controller
153 .on_remote_open_stream(stream_iter.max_stream_id())?,
154 StreamType::Unidirectional => self
155 .remote_uni_controller
156 .on_remote_open_stream(stream_iter.max_stream_id())?,
157 }
158
159 let direction = self.direction(stream_iter.max_stream_id());
160 for _stream_id in stream_iter {
162 self.on_open_stream(direction);
163 }
164 Ok(())
165 }
166
167 fn on_open_stream(&mut self, direction: StreamDirection) {
173 match direction {
174 StreamDirection::LocalInitiatedBidirectional => {
175 self.local_bidi_controller.on_open_stream()
176 }
177 StreamDirection::RemoteInitiatedBidirectional => {
178 self.remote_bidi_controller.on_open_stream()
179 }
180 StreamDirection::LocalInitiatedUnidirectional => {
181 self.local_uni_controller.on_open_stream()
182 }
183 StreamDirection::RemoteInitiatedUnidirectional => {
184 self.remote_uni_controller.on_open_stream()
185 }
186 }
187 }
188
189 pub fn on_close_stream(&mut self, stream_id: StreamId) {
191 match self.direction(stream_id) {
192 StreamDirection::LocalInitiatedBidirectional => {
193 self.local_bidi_controller.on_close_stream()
194 }
195 StreamDirection::RemoteInitiatedBidirectional => {
196 self.remote_bidi_controller.on_close_stream()
197 }
198 StreamDirection::LocalInitiatedUnidirectional => {
199 self.local_uni_controller.on_close_stream()
200 }
201 StreamDirection::RemoteInitiatedUnidirectional => {
202 self.remote_uni_controller.on_close_stream()
203 }
204 }
205 }
206
207 pub fn close(&mut self) {
210 self.local_bidi_controller.close();
211 self.remote_bidi_controller.close();
212 self.local_uni_controller.close();
213 self.remote_uni_controller.close();
214 }
215
216 pub fn on_packet_ack<A: ack::Set>(&mut self, ack_set: &A) {
218 self.local_bidi_controller.on_packet_ack(ack_set);
219 self.remote_bidi_controller.on_packet_ack(ack_set);
220 self.local_uni_controller.on_packet_ack(ack_set);
221 self.remote_uni_controller.on_packet_ack(ack_set);
222 }
223
224 pub fn on_packet_loss<A: ack::Set>(&mut self, ack_set: &A) {
226 self.local_bidi_controller.on_packet_loss(ack_set);
227 self.remote_bidi_controller.on_packet_loss(ack_set);
228 self.local_uni_controller.on_packet_loss(ack_set);
229 self.remote_uni_controller.on_packet_loss(ack_set);
230 }
231
232 pub fn update_blocked_sync_period(&mut self, blocked_sync_period: Duration) {
235 self.local_bidi_controller
236 .update_sync_period(blocked_sync_period);
237 self.local_uni_controller
238 .update_sync_period(blocked_sync_period);
239 }
240
241 pub fn update_min_rtt(&mut self, min_rtt: Duration, now: Timestamp) {
242 self.remote_uni_controller.update_min_rtt(min_rtt, now);
243 self.remote_bidi_controller.update_min_rtt(min_rtt, now);
244 }
245
246 #[inline]
248 pub fn on_transmit<W: WriteContext>(&mut self, context: &mut W) -> Result<(), OnTransmitError> {
249 if !self.has_transmission_interest() {
250 return Ok(());
251 }
252
253 let peer_endpoint_type = self.local_endpoint_type.peer_type();
254
255 macro_rules! on_transmit {
256 ($controller:ident, $endpoint:expr, $ty:expr) => {
257 if let Some(nth) = self
258 .$controller
259 .total_open_stream_count()
260 .checked_sub(VarInt::from_u32(1))
261 {
262 if let Some(stream_id) = StreamId::nth($endpoint, $ty, nth.as_u64()) {
263 self.$controller.on_transmit(stream_id, context)?;
264 }
265 }
266 };
267 }
268
269 on_transmit!(
270 local_bidi_controller,
271 self.local_endpoint_type,
272 StreamType::Bidirectional
273 );
274 on_transmit!(
275 remote_bidi_controller,
276 peer_endpoint_type,
277 StreamType::Bidirectional
278 );
279
280 on_transmit!(
281 local_uni_controller,
282 self.local_endpoint_type,
283 StreamType::Unidirectional
284 );
285 on_transmit!(
286 remote_uni_controller,
287 peer_endpoint_type,
288 StreamType::Unidirectional
289 );
290
291 Ok(())
292 }
293
294 pub fn on_timeout(&mut self, now: Timestamp) {
296 self.local_bidi_controller.on_timeout(now);
297 self.remote_bidi_controller.on_timeout(now);
298 self.local_uni_controller.on_timeout(now);
299 self.remote_uni_controller.on_timeout(now);
300 }
301
302 #[inline]
303 fn direction(&self, stream_id: StreamId) -> StreamDirection {
304 let is_local_initiated = self.local_endpoint_type == stream_id.initiator();
305 match (is_local_initiated, stream_id.stream_type()) {
306 (true, StreamType::Bidirectional) => StreamDirection::LocalInitiatedBidirectional,
307 (true, StreamType::Unidirectional) => StreamDirection::LocalInitiatedUnidirectional,
308 (false, StreamType::Bidirectional) => StreamDirection::RemoteInitiatedBidirectional,
309 (false, StreamType::Unidirectional) => StreamDirection::RemoteInitiatedUnidirectional,
310 }
311 }
312}
313
314impl timer::Provider for Controller {
315 #[inline]
316 fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
317 self.local_bidi_controller.timers(query)?;
318 self.remote_bidi_controller.timers(query)?;
319 self.local_uni_controller.timers(query)?;
320 self.remote_uni_controller.timers(query)?;
321 Ok(())
322 }
323}
324
325impl transmission::interest::Provider for Controller {
327 #[inline]
328 fn transmission_interest<Q: transmission::interest::Query>(
329 &self,
330 query: &mut Q,
331 ) -> transmission::interest::Result {
332 self.local_bidi_controller.transmission_interest(query)?;
333 self.remote_bidi_controller.transmission_interest(query)?;
334 self.local_uni_controller.transmission_interest(query)?;
335 self.remote_uni_controller.transmission_interest(query)?;
336 Ok(())
337 }
338}
339
/// Classification of a stream by which endpoint initiated it and by its stream type.
///
/// Each variant selects exactly one of the four per-direction controllers held by
/// `Controller`.
#[derive(Debug, Copy, Clone)]
enum StreamDirection {
    /// Bidirectional stream initiated by this endpoint.
    LocalInitiatedBidirectional,

    /// Bidirectional stream initiated by the peer.
    RemoteInitiatedBidirectional,

    /// Unidirectional stream initiated by this endpoint.
    LocalInitiatedUnidirectional,

    /// Unidirectional stream initiated by the peer.
    RemoteInitiatedUnidirectional,
}
357
358impl StreamDirection {
359 fn is_remote(&self) -> bool {
360 match self {
361 StreamDirection::LocalInitiatedBidirectional => false,
362 StreamDirection::RemoteInitiatedBidirectional => true,
363 StreamDirection::LocalInitiatedUnidirectional => false,
364 StreamDirection::RemoteInitiatedUnidirectional => true,
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use s2n_quic_core::varint::VarInt;

    /// Test-only accessors for inspecting the stream capacity tracked by a `Controller`.
    impl Controller {
        /// Returns `available_stream_capacity()` of the locally initiated controller for
        /// `stream_type`.
        pub fn available_local_initiated_stream_capacity(&self, stream_type: StreamType) -> VarInt {
            match stream_type {
                StreamType::Bidirectional => self.local_bidi_controller.available_stream_capacity(),
                StreamType::Unidirectional => self.local_uni_controller.available_stream_capacity(),
            }
        }

        /// Returns `latest_limit()` of the remotely initiated controller for `stream_type`.
        pub fn remote_initiated_max_streams_latest_value(&self, stream_type: StreamType) -> VarInt {
            match stream_type {
                StreamType::Bidirectional => self.remote_bidi_controller.latest_limit(),
                StreamType::Unidirectional => self.remote_uni_controller.latest_limit(),
            }
        }

        /// Returns the latest remote limit for `stream_type` minus the number of streams of
        /// that type the remotely initiated controller currently counts as open.
        pub fn available_remote_initiated_stream_capacity(
            &self,
            stream_type: StreamType,
        ) -> VarInt {
            let latest_limit = self.remote_initiated_max_streams_latest_value(stream_type);
            let currently_open = match stream_type {
                StreamType::Bidirectional => self.remote_bidi_controller.open_stream_count(),
                StreamType::Unidirectional => self.remote_uni_controller.open_stream_count(),
            };
            latest_limit - currently_open
        }
    }
}
406
407#[cfg(test)]
408mod fuzz_target;