s2n_quic_core/crypto/tls/
offload.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3use crate::{
4    application,
5    crypto::{
6        tls::{self, ConnectionInfo, NamedGroup, TlsSession},
7        CryptoSuite,
8    },
9    sync::spsc::{channel, Receiver, SendSlice, Sender},
10    transport,
11};
12use alloc::{boxed::Box, collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
13use core::{any::Any, future::Future, task::Poll};
14use std::sync::Mutex;
15
/// Trait used for spawning async tasks corresponding to TLS operations. Each task will signify TLS work
/// that needs to be done per QUIC connection.
///
/// `OffloadEndpoint` calls `spawn` once per session it creates. The spawned future owns the
/// wrapped TLS session and returns once the handshake result has been handed to the QUIC side,
/// or once the QUIC side has gone away.
pub trait Executor {
    /// Runs `task` to completion in the background. The future is `Send + 'static`, so it may be
    /// moved to another thread.
    fn spawn(&self, task: impl Future<Output = ()> + Send + 'static);
}
21
/// Allows access to the TlsSession on handshake failure and when the exporter secret is ready.
///
/// Both callbacks run on the offload task, where the TLS session lives. When a callback returns
/// `Some(context)`, the boxed value is sent to the QUIC side as a TLS context. `None` means
/// nothing is sent.
pub trait ExporterHandler {
    /// Called when the TLS handshake fails with error `e`.
    fn on_tls_handshake_failed(
        &self,
        session: &impl TlsSession,
        e: &(dyn core::error::Error + Send + Sync + 'static),
    ) -> Option<Box<dyn Any + Send>>;
    /// Called when the TLS session's exporter secret is ready to be used.
    fn on_tls_exporter_ready(&self, session: &impl TlsSession) -> Option<Box<dyn Any + Send>>;
}
31
32// Most people don't need the TlsSession so we ignore these callbacks by default
33impl ExporterHandler for () {
34    fn on_tls_handshake_failed(
35        &self,
36        _session: &impl TlsSession,
37        _e: &(dyn core::error::Error + Send + Sync + 'static),
38    ) -> Option<Box<dyn std::any::Any + Send>> {
39        None
40    }
41
42    fn on_tls_exporter_ready(
43        &self,
44        _session: &impl TlsSession,
45    ) -> Option<Box<dyn std::any::Any + Send>> {
46        None
47    }
48}
49
50pub struct OffloadEndpoint<E: tls::Endpoint, X: Executor, H: ExporterHandler> {
51    inner: E,
52    executor: X,
53    exporter: H,
54    channel_capacity: usize,
55}
56
57impl<E: tls::Endpoint, X: Executor, H: ExporterHandler> OffloadEndpoint<E, X, H> {
58    pub fn new(inner: E, executor: X, exporter: H, channel_capacity: usize) -> Self {
59        Self {
60            inner,
61            executor,
62            exporter,
63            channel_capacity,
64        }
65    }
66}
67
68impl<E, X, H> tls::Endpoint for OffloadEndpoint<E, X, H>
69where
70    E: tls::Endpoint,
71    X: Executor + Send + 'static,
72    H: ExporterHandler + Send + 'static + Sync + Clone,
73{
74    type Session = OffloadSession<<E as tls::Endpoint>::Session>;
75
76    fn new_server_session<Params: s2n_codec::EncoderValue>(
77        &mut self,
78        transport_parameters: &Params,
79        connection_info: ConnectionInfo,
80    ) -> Self::Session {
81        OffloadSession::new(
82            self.inner
83                .new_server_session(transport_parameters, connection_info),
84            &self.executor,
85            self.exporter.clone(),
86            self.channel_capacity,
87        )
88    }
89
90    fn new_client_session<Params: s2n_codec::EncoderValue>(
91        &mut self,
92        transport_parameters: &Params,
93        server_name: application::ServerName,
94    ) -> Self::Session {
95        OffloadSession::new(
96            self.inner
97                .new_client_session(transport_parameters, server_name),
98            &self.executor,
99            self.exporter.clone(),
100            self.channel_capacity,
101        )
102    }
103
104    fn max_tag_length(&self) -> usize {
105        self.inner.max_tag_length()
106    }
107}
108
109#[derive(Debug)]
110pub struct OffloadSession<S: CryptoSuite> {
111    recv_from_tls: Receiver<Request<S>>,
112    send_to_tls: Sender<Response>,
113    allowed_to_send: Arc<Mutex<AllowedToSend>>,
114}
115
116impl<S: tls::Session + 'static> OffloadSession<S> {
117    fn new(
118        mut inner: S,
119        executor: &impl Executor,
120        exporter: impl ExporterHandler + Sync + Send + 'static + Clone,
121        channel_capacity: usize,
122    ) -> Self {
123        let (mut send_to_quic, recv_from_tls): (Sender<Request<S>>, Receiver<Request<S>>) =
124            channel(channel_capacity);
125        let (send_to_tls, mut recv_from_quic): (Sender<Response>, Receiver<Response>) =
126            channel(channel_capacity);
127        let allowed_to_send = Arc::new(Mutex::new(AllowedToSend::default()));
128        let clone = allowed_to_send.clone();
129
130        let future = async move {
131            let mut initial_data = VecDeque::default();
132            let mut handshake_data = VecDeque::default();
133            let mut application_data = VecDeque::default();
134
135            core::future::poll_fn(|ctx| {
136                match send_to_quic.poll_slice(ctx) {
137                    Poll::Ready(res) => match res {
138                        Ok(send_slice) => {
139                            let allowed_to_send = *allowed_to_send.lock().unwrap();
140
141                            let mut context = RemoteContext {
142                                send_to_quic: send_slice,
143                                waker: ctx.waker().clone(),
144                                initial_data: &mut initial_data,
145                                handshake_data: &mut handshake_data,
146                                application_data: &mut application_data,
147                                exporter_handler: exporter.clone(),
148                                allowed_to_send,
149                                error: None,
150                            };
151
152                            while let Poll::Ready(res) = recv_from_quic.poll_slice(ctx) {
153                                match res {
154                                    Ok(mut recv_slice) => {
155                                        while let Some(response) = recv_slice.pop() {
156                                            match response {
157                                                Response::Initial(data) => {
158                                                    context.initial_data.push_back(data);
159                                                }
160                                                Response::Handshake(data) => {
161                                                    context.handshake_data.push_back(data);
162                                                }
163                                                Response::Application(data) => {
164                                                    context.application_data.push_back(data)
165                                                }
166                                                Response::SendStatusChanged => (),
167                                            }
168                                        }
169                                    }
170                                    Err(_) => {
171                                        // For whatever reason the QUIC side decided to drop this channel. In this case
172                                        // we complete the future.
173                                        return Poll::Ready(());
174                                    }
175                                }
176                            }
177
178                            let res = inner.poll(&mut context);
179                            // Either there was an error or the handshake has finished if TLS returned Poll::Ready.
180                            // Notify the QUIC side accordingly.
181                            if let Poll::Ready(res) = res {
182                                let request = match res {
183                                    Ok(_) => Request::TlsDone,
184                                    Err(e) => Request::TlsError(e),
185                                };
186                                let _ = context.send_to_quic.push(request);
187                            }
188
189                            // We also need to notify the QUIC side of any stored errors that we have.
190                            if let Some(error) = context.error {
191                                let _ = context.send_to_quic.push(Request::TlsError(error));
192                            }
193
194                            // We've already sent the Result to the QUIC side so we can just map it out here.
195                            res.map(|_| ())
196                        }
197                        Err(_) => {
198                            // For whatever reason the QUIC side decided to drop this channel. In this case
199                            // we complete the future.
200                            Poll::Ready(())
201                        }
202                    },
203                    Poll::Pending => Poll::Pending,
204                }
205            })
206            .await;
207        };
208        executor.spawn(future);
209
210        Self {
211            recv_from_tls,
212            send_to_tls,
213            allowed_to_send: clone,
214        }
215    }
216}
217
218impl<S: tls::Session> tls::Session for OffloadSession<S> {
219    #[inline]
220    fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
221    where
222        W: tls::Context<Self>,
223    {
224        let cloned_waker = &context.waker().clone();
225        let mut ctx = core::task::Context::from_waker(cloned_waker);
226
227        match self.recv_from_tls.poll_slice(&mut ctx) {
228            Poll::Ready(res) => match res {
229                Ok(mut slice) => {
230                    while let Some(request) = slice.pop() {
231                        match request {
232                            Request::HandshakeKeys(key, header_key) => {
233                                context.on_handshake_keys(key, header_key)?;
234                            }
235                            Request::ServerName(server_name) => {
236                                context.on_server_name(server_name)?
237                            }
238                            Request::SendInitial(bytes) => context.send_initial(bytes),
239                            Request::ClientParams(client_params, mut server_params) => context
240                                .on_client_application_params(
241                                    tls::ApplicationParameters {
242                                        transport_parameters: &client_params,
243                                    },
244                                    &mut server_params,
245                                )?,
246                            Request::ApplicationProtocol(bytes) => {
247                                context.on_application_protocol(bytes)?;
248                            }
249                            Request::KeyExchangeGroup(named_group) => {
250                                context.on_key_exchange_group(named_group)?;
251                            }
252                            Request::OneRttKeys(key, header_key, transport_parameters) => context
253                                .on_one_rtt_keys(
254                                key,
255                                header_key,
256                                tls::ApplicationParameters {
257                                    transport_parameters: &transport_parameters,
258                                },
259                            )?,
260                            Request::SendHandshake(bytes) => {
261                                context.send_handshake(bytes);
262                            }
263                            Request::HandshakeComplete => {
264                                context.on_handshake_complete()?;
265                            }
266                            Request::TlsDone => {
267                                return Poll::Ready(Ok(()));
268                            }
269                            Request::ZeroRtt(key, header_key, transport_parameters) => {
270                                context.on_zero_rtt_keys(
271                                    key,
272                                    header_key,
273                                    tls::ApplicationParameters {
274                                        transport_parameters: &transport_parameters,
275                                    },
276                                )?;
277                            }
278                            Request::TlsContext(ctx) => {
279                                context.on_tls_context(ctx);
280                            }
281                            Request::SendApplication(transmission) => {
282                                context.send_application(transmission);
283                            }
284                            Request::TlsError(e) => return Poll::Ready(Err(e)),
285                        }
286                    }
287                }
288                Err(_) => {
289                    // For whatever reason the TLS task was cancelled. We cannot continue the handshake.
290                    return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE)));
291                }
292            },
293            Poll::Pending => (),
294        }
295
296        let mut allowed_to_send = self.allowed_to_send.lock().unwrap();
297        let mut state_change = false;
298        if allowed_to_send.can_send_initial != context.can_send_initial()
299            || allowed_to_send.can_send_handshake != context.can_send_handshake()
300            || allowed_to_send.can_send_application != context.can_send_application()
301        {
302            *allowed_to_send = AllowedToSend {
303                can_send_initial: context.can_send_initial(),
304                can_send_handshake: context.can_send_handshake(),
305                can_send_application: context.can_send_application(),
306            };
307            state_change = true;
308        }
309        // Drop the lock ASAP
310        drop(allowed_to_send);
311
312        match self.send_to_tls.poll_slice(&mut ctx) {
313            Poll::Ready(res) => match res {
314                Ok(mut slice) => {
315                    if let Some(resp) = context.receive_initial(None) {
316                        let _ = slice.push(Response::Initial(resp));
317                    }
318
319                    if let Some(resp) = context.receive_handshake(None) {
320                        let _ = slice.push(Response::Handshake(resp));
321                    }
322
323                    if let Some(resp) = context.receive_application(None) {
324                        let _ = slice.push(Response::Application(resp));
325                    }
326
327                    if state_change {
328                        let _ = slice.push(Response::SendStatusChanged);
329                    }
330                }
331                Err(_) => {
332                    // For whatever reason the TLS task was cancelled. We cannot continue the handshake.
333                    return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE)));
334                }
335            },
336            Poll::Pending => (),
337        }
338
339        Poll::Pending
340    }
341}
342
343impl<S: tls::Session> CryptoSuite for OffloadSession<S> {
344    type HandshakeKey = <S as CryptoSuite>::HandshakeKey;
345    type HandshakeHeaderKey = <S as CryptoSuite>::HandshakeHeaderKey;
346    type InitialKey = <S as CryptoSuite>::InitialKey;
347    type InitialHeaderKey = <S as CryptoSuite>::InitialHeaderKey;
348    type ZeroRttKey = <S as CryptoSuite>::ZeroRttKey;
349    type ZeroRttHeaderKey = <S as CryptoSuite>::ZeroRttHeaderKey;
350    type OneRttKey = <S as CryptoSuite>::OneRttKey;
351    type OneRttHeaderKey = <S as CryptoSuite>::OneRttHeaderKey;
352    type RetryKey = <S as CryptoSuite>::RetryKey;
353}
354
/// A copy of the QUIC connection's `can_send_*` flags, shared with the TLS task.
///
/// The default value has every flag set to `false`.
#[derive(Clone, Copy, Debug, Default)]
struct AllowedToSend {
    /// Whether Initial-space CRYPTO data may currently be sent.
    can_send_initial: bool,
    /// Whether Handshake-space CRYPTO data may currently be sent.
    can_send_handshake: bool,
    /// Whether application-space CRYPTO data may currently be sent.
    can_send_application: bool,
}
361
362const SLICE_ERROR: crate::transport::Error =
363    crate::transport::Error::INTERNAL_ERROR.with_reason("Slice is full");
364
365#[derive(Debug)]
366struct RemoteContext<'a, Request, H> {
367    send_to_quic: SendSlice<'a, Request>,
368    initial_data: &'a mut VecDeque<bytes::Bytes>,
369    handshake_data: &'a mut VecDeque<bytes::Bytes>,
370    application_data: &'a mut VecDeque<bytes::Bytes>,
371    waker: core::task::Waker,
372    allowed_to_send: AllowedToSend,
373    exporter_handler: H,
374    error: Option<crate::transport::Error>,
375}
376
377impl<S: CryptoSuite, H: ExporterHandler> tls::Context<S> for RemoteContext<'_, Request<S>, H> {
378    fn on_client_application_params(
379        &mut self,
380        client_params: tls::ApplicationParameters,
381        server_params: &mut alloc::vec::Vec<u8>,
382    ) -> Result<(), crate::transport::Error> {
383        match self.send_to_quic.push(Request::ClientParams(
384            client_params.transport_parameters.to_vec(),
385            server_params.to_vec(),
386        )) {
387            Ok(_) => return Ok(()),
388            Err(_) => self.error = Some(SLICE_ERROR),
389        }
390        Ok(())
391    }
392
393    fn on_handshake_keys(
394        &mut self,
395        key: <S as CryptoSuite>::HandshakeKey,
396        header_key: <S as CryptoSuite>::HandshakeHeaderKey,
397    ) -> Result<(), crate::transport::Error> {
398        match self
399            .send_to_quic
400            .push(Request::HandshakeKeys(key, header_key))
401        {
402            Ok(_) => return Ok(()),
403            Err(_) => self.error = Some(SLICE_ERROR),
404        }
405        Ok(())
406    }
407
408    fn on_zero_rtt_keys(
409        &mut self,
410        key: <S as CryptoSuite>::ZeroRttKey,
411        header_key: <S as CryptoSuite>::ZeroRttHeaderKey,
412        application_parameters: tls::ApplicationParameters,
413    ) -> Result<(), crate::transport::Error> {
414        match self.send_to_quic.push(Request::ZeroRtt(
415            key,
416            header_key,
417            application_parameters.transport_parameters.to_vec(),
418        )) {
419            Ok(_) => (),
420            Err(_) => self.error = Some(SLICE_ERROR),
421        }
422        Ok(())
423    }
424
425    fn on_one_rtt_keys(
426        &mut self,
427        key: <S as CryptoSuite>::OneRttKey,
428        header_key: <S as CryptoSuite>::OneRttHeaderKey,
429        application_parameters: tls::ApplicationParameters,
430    ) -> Result<(), crate::transport::Error> {
431        match self.send_to_quic.push(Request::OneRttKeys(
432            key,
433            header_key,
434            application_parameters.transport_parameters.to_vec(),
435        )) {
436            Ok(_) => (),
437            Err(_) => self.error = Some(SLICE_ERROR),
438        }
439        Ok(())
440    }
441
442    fn on_server_name(
443        &mut self,
444        server_name: crate::application::ServerName,
445    ) -> Result<(), crate::transport::Error> {
446        match self.send_to_quic.push(Request::ServerName(server_name)) {
447            Ok(_) => (),
448            Err(_) => self.error = Some(SLICE_ERROR),
449        }
450        Ok(())
451    }
452
453    fn on_application_protocol(
454        &mut self,
455        application_protocol: bytes::Bytes,
456    ) -> Result<(), crate::transport::Error> {
457        match self
458            .send_to_quic
459            .push(Request::ApplicationProtocol(application_protocol))
460        {
461            Ok(_) => (),
462            Err(_) => self.error = Some(SLICE_ERROR),
463        }
464        Ok(())
465    }
466
467    fn on_key_exchange_group(
468        &mut self,
469        named_group: tls::NamedGroup,
470    ) -> Result<(), crate::transport::Error> {
471        match self
472            .send_to_quic
473            .push(Request::KeyExchangeGroup(named_group))
474        {
475            Ok(_) => (),
476            Err(_) => self.error = Some(SLICE_ERROR),
477        }
478        Ok(())
479    }
480
481    fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error> {
482        match self.send_to_quic.push(Request::HandshakeComplete) {
483            Ok(_) => (),
484            Err(_) => self.error = Some(SLICE_ERROR),
485        }
486
487        Ok(())
488    }
489
490    fn on_tls_context(&mut self, _context: Box<dyn Any + Send>) {
491        unimplemented!("TLS Context is not supported in Offload implementation");
492    }
493
494    fn on_tls_exporter_ready(
495        &mut self,
496        session: &impl TlsSession,
497    ) -> Result<(), crate::transport::Error> {
498        if let Some(context) = self.exporter_handler.on_tls_exporter_ready(session) {
499            match self.send_to_quic.push(Request::TlsContext(context)) {
500                Ok(_) => (),
501                Err(_) => self.error = Some(SLICE_ERROR),
502            }
503        }
504
505        Ok(())
506    }
507
508    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
509        gimme_bytes(max_len, self.initial_data)
510    }
511
512    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
513        gimme_bytes(max_len, self.handshake_data)
514    }
515
516    fn receive_application(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
517        gimme_bytes(max_len, self.application_data)
518    }
519
520    fn can_send_initial(&self) -> bool {
521        self.allowed_to_send.can_send_initial
522    }
523
524    fn send_initial(&mut self, transmission: bytes::Bytes) {
525        if self
526            .send_to_quic
527            .push(Request::SendInitial(transmission))
528            .is_err()
529        {
530            self.error = Some(SLICE_ERROR);
531        }
532    }
533
534    fn can_send_handshake(&self) -> bool {
535        self.allowed_to_send.can_send_handshake
536    }
537
538    fn send_handshake(&mut self, transmission: bytes::Bytes) {
539        if self
540            .send_to_quic
541            .push(Request::SendHandshake(transmission))
542            .is_err()
543        {
544            self.error = Some(SLICE_ERROR);
545        }
546    }
547
548    fn can_send_application(&self) -> bool {
549        self.allowed_to_send.can_send_application
550    }
551
552    fn send_application(&mut self, transmission: bytes::Bytes) {
553        if self
554            .send_to_quic
555            .push(Request::SendApplication(transmission))
556            .is_err()
557        {
558            self.error = Some(SLICE_ERROR);
559        }
560    }
561
562    fn waker(&self) -> &core::task::Waker {
563        &self.waker
564    }
565
566    fn on_tls_handshake_failed(
567        &mut self,
568        session: &impl tls::TlsSession,
569        e: &(dyn core::error::Error + Send + Sync + 'static),
570    ) -> Result<(), crate::transport::Error> {
571        if let Some(context) = self.exporter_handler.on_tls_handshake_failed(session, e) {
572            match self.send_to_quic.push(Request::TlsContext(context)) {
573                Ok(_) => (),
574                Err(_) => self.error = Some(SLICE_ERROR),
575            }
576        }
577        Ok(())
578    }
579}
580
581fn gimme_bytes(max_len: Option<usize>, vec: &mut VecDeque<bytes::Bytes>) -> Option<bytes::Bytes> {
582    let bytes = vec.pop_front();
583    if let Some(mut bytes) = bytes {
584        if let Some(max_len) = max_len {
585            if bytes.len() > max_len {
586                let remainder = bytes.split_off(max_len);
587                vec.push_front(remainder);
588            }
589        }
590        return Some(bytes);
591    }
592    None
593}
594
595enum Request<S: CryptoSuite> {
596    ZeroRtt(
597        <S as CryptoSuite>::ZeroRttKey,
598        <S as CryptoSuite>::ZeroRttHeaderKey,
599        Vec<u8>,
600    ),
601    ServerName(crate::application::ServerName),
602    SendInitial(bytes::Bytes),
603    ClientParams(Vec<u8>, Vec<u8>),
604    HandshakeKeys(
605        <S as CryptoSuite>::HandshakeKey,
606        <S as CryptoSuite>::HandshakeHeaderKey,
607    ),
608    SendHandshake(bytes::Bytes),
609    ApplicationProtocol(bytes::Bytes),
610    KeyExchangeGroup(NamedGroup),
611    OneRttKeys(
612        <S as CryptoSuite>::OneRttKey,
613        <S as CryptoSuite>::OneRttHeaderKey,
614        Vec<u8>,
615    ),
616    HandshakeComplete,
617    TlsDone,
618    TlsContext(Box<dyn Any + Send>),
619    SendApplication(bytes::Bytes),
620    TlsError(transport::Error),
621}
622
623enum Response {
624    Initial(bytes::Bytes),
625    Handshake(bytes::Bytes),
626    Application(bytes::Bytes),
627    SendStatusChanged,
628}
629
630impl<S: CryptoSuite> alloc::fmt::Debug for Request<S> {
631    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
632        match self {
633            Request::ServerName(_) => write!(f, "ServerName"),
634            Request::SendInitial(_) => write!(f, "SendInitial"),
635            Request::ClientParams(_, _) => write!(f, "ClientParams"),
636            Request::HandshakeKeys(_, _) => write!(f, "HandshakeKeys"),
637            Request::SendHandshake(_) => write!(f, "SendHandshake"),
638            Request::ApplicationProtocol(_) => write!(f, "ApplicationProtocol"),
639            Request::KeyExchangeGroup(_) => write!(f, "KeyExchangeGroup"),
640            Request::OneRttKeys(_, _, _) => write!(f, "OneRttKeys"),
641            Request::HandshakeComplete => write!(f, "HandshakeComplete"),
642            Request::TlsDone => write!(f, "TlsDone"),
643            Request::ZeroRtt(_, _, _) => write!(f, "ZeroRtt"),
644            Request::TlsContext(_) => write!(f, "TlsContext"),
645            Request::SendApplication(_) => write!(f, "SendApplication"),
646            Request::TlsError(_) => write!(f, "TlsError"),
647        }
648    }
649}
650
651impl alloc::fmt::Debug for Response {
652    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
653        match self {
654            Response::Initial(_) => write!(f, "ResponseInitial"),
655            Response::Handshake(_) => write!(f, "ResponseHandshake"),
656            Response::Application(_) => write!(f, "ResponseApplication"),
657            Response::SendStatusChanged => write!(f, "SendStatusChanged"),
658        }
659    }
660}