// s2n_quic_core/crypto/tls/slow_tls.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
3use crate::{
4    application,
5    crypto::{tls, CryptoSuite},
6    transport,
7};
8use alloc::{boxed::Box, vec::Vec};
9use core::{any::Any, task::Poll};
10
/// Number of consecutive `Poll::Pending` results a [`SlowSession`] hands out
/// (self-waking each time) before it actually polls the wrapped TLS session.
const DEFER_COUNT: u8 = 3;
12
13pub struct SlowEndpoint<E: tls::Endpoint> {
14    endpoint: E,
15}
16
17impl<E: tls::Endpoint> SlowEndpoint<E> {
18    pub fn new(endpoint: E) -> Self {
19        SlowEndpoint { endpoint }
20    }
21}
22
23impl<E: tls::Endpoint> tls::Endpoint for SlowEndpoint<E> {
24    type Session = SlowSession<<E as tls::Endpoint>::Session>;
25
26    fn new_server_session<Params: s2n_codec::EncoderValue>(
27        &mut self,
28        transport_parameters: &Params,
29        connection_info: tls::ConnectionInfo,
30    ) -> Self::Session {
31        let inner_session = self
32            .endpoint
33            .new_server_session(transport_parameters, connection_info);
34        SlowSession {
35            defer: DEFER_COUNT,
36            inner_session,
37        }
38    }
39
40    fn new_client_session<Params: s2n_codec::EncoderValue>(
41        &mut self,
42        transport_parameters: &Params,
43        server_name: application::ServerName,
44    ) -> Self::Session {
45        let inner_session = self
46            .endpoint
47            .new_client_session(transport_parameters, server_name);
48        SlowSession {
49            defer: DEFER_COUNT,
50            inner_session,
51        }
52    }
53
54    fn max_tag_length(&self) -> usize {
55        self.endpoint.max_tag_length()
56    }
57}
58
59// SlowSession is a test TLS provider that is slow, namely, for each call to poll,
60// it returns Poll::Pending several times before actually polling the real TLS library.
61// This is used in an integration test to assert that our code is correct in the event
62// of any random pendings/wakeups that might occur when negotiating TLS.
63#[derive(Debug)]
64pub struct SlowSession<S: tls::Session> {
65    defer: u8,
66    inner_session: S,
67}
68
69impl<S: tls::Session> tls::Session for SlowSession<S> {
70    #[inline]
71    fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
72    where
73        W: tls::Context<Self>,
74    {
75        // Self-wake and return Pending if defer is non-zero
76        if let Some(d) = self.defer.checked_sub(1) {
77            self.defer = d;
78            context.waker().wake_by_ref();
79            return Poll::Pending;
80        }
81
82        // Otherwise we'll call the function to actually make progress
83        // in the TLS handshake and set up to defer again the next time
84        // we're here.
85        self.defer = DEFER_COUNT;
86        self.inner_session.poll(&mut SlowContext(context))
87    }
88}
89
90impl<S: tls::Session> CryptoSuite for SlowSession<S> {
91    type HandshakeKey = <S as CryptoSuite>::HandshakeKey;
92    type HandshakeHeaderKey = <S as CryptoSuite>::HandshakeHeaderKey;
93    type InitialKey = <S as CryptoSuite>::InitialKey;
94    type InitialHeaderKey = <S as CryptoSuite>::InitialHeaderKey;
95    type ZeroRttKey = <S as CryptoSuite>::ZeroRttKey;
96    type ZeroRttHeaderKey = <S as CryptoSuite>::ZeroRttHeaderKey;
97    type OneRttKey = <S as CryptoSuite>::OneRttKey;
98    type OneRttHeaderKey = <S as CryptoSuite>::OneRttHeaderKey;
99    type RetryKey = <S as CryptoSuite>::RetryKey;
100}
101
102struct SlowContext<'a, Inner>(&'a mut Inner);
103
104impl<I, S: tls::Session> tls::Context<S> for SlowContext<'_, I>
105where
106    I: tls::Context<SlowSession<S>>,
107{
108    fn on_client_application_params(
109        &mut self,
110        client_params: tls::ApplicationParameters,
111        server_params: &mut Vec<u8>,
112    ) -> Result<(), transport::Error> {
113        self.0
114            .on_client_application_params(client_params, server_params)
115    }
116
117    fn on_handshake_keys(
118        &mut self,
119        key: <S as CryptoSuite>::HandshakeKey,
120        header_key: <S as CryptoSuite>::HandshakeHeaderKey,
121    ) -> Result<(), transport::Error> {
122        self.0.on_handshake_keys(key, header_key)
123    }
124
125    fn on_zero_rtt_keys(
126        &mut self,
127        key: <S>::ZeroRttKey,
128        header_key: <S>::ZeroRttHeaderKey,
129        application_parameters: tls::ApplicationParameters,
130    ) -> Result<(), transport::Error> {
131        self.0
132            .on_zero_rtt_keys(key, header_key, application_parameters)
133    }
134
135    fn on_one_rtt_keys(
136        &mut self,
137        key: <S>::OneRttKey,
138        header_key: <S>::OneRttHeaderKey,
139        application_parameters: tls::ApplicationParameters,
140    ) -> Result<(), transport::Error> {
141        self.0
142            .on_one_rtt_keys(key, header_key, application_parameters)
143    }
144
145    fn on_server_name(
146        &mut self,
147        server_name: application::ServerName,
148    ) -> Result<(), transport::Error> {
149        self.0.on_server_name(server_name)
150    }
151
152    fn on_application_protocol(
153        &mut self,
154        application_protocol: tls::Bytes,
155    ) -> Result<(), transport::Error> {
156        self.0.on_application_protocol(application_protocol)
157    }
158
159    fn on_handshake_complete(&mut self) -> Result<(), transport::Error> {
160        self.0.on_handshake_complete()
161    }
162
163    fn on_tls_exporter_ready(
164        &mut self,
165        session: &impl tls::TlsSession,
166    ) -> Result<(), transport::Error> {
167        self.0.on_tls_exporter_ready(session)
168    }
169
170    fn on_tls_handshake_failed(
171        &mut self,
172        session: &impl tls::TlsSession,
173        e: &(dyn core::error::Error + Send + Sync + 'static),
174    ) -> Result<(), transport::Error> {
175        self.0.on_tls_handshake_failed(session, e)
176    }
177
178    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
179        self.0.receive_initial(max_len)
180    }
181
182    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
183        self.0.receive_handshake(max_len)
184    }
185
186    fn receive_application(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
187        self.0.receive_application(max_len)
188    }
189
190    fn can_send_initial(&self) -> bool {
191        self.0.can_send_initial()
192    }
193
194    fn send_initial(&mut self, transmission: tls::Bytes) {
195        self.0.send_initial(transmission);
196    }
197
198    fn can_send_handshake(&self) -> bool {
199        self.0.can_send_handshake()
200    }
201
202    fn send_handshake(&mut self, transmission: tls::Bytes) {
203        self.0.send_handshake(transmission);
204    }
205
206    fn can_send_application(&self) -> bool {
207        self.0.can_send_application()
208    }
209
210    fn send_application(&mut self, transmission: tls::Bytes) {
211        self.0.send_application(transmission)
212    }
213
214    fn waker(&self) -> &core::task::Waker {
215        self.0.waker()
216    }
217
218    fn on_key_exchange_group(
219        &mut self,
220        named_group: tls::NamedGroup,
221    ) -> Result<(), transport::Error> {
222        self.0.on_key_exchange_group(named_group)
223    }
224
225    fn on_tls_context(&mut self, context: Box<dyn Any + Send>) {
226        self.0.on_tls_context(context)
227    }
228}