1use crate::{
4 application,
5 crypto::{
6 tls::{self, ConnectionInfo, NamedGroup, TlsSession},
7 CryptoSuite,
8 },
9 sync::spsc::{channel, Receiver, SendSlice, Sender},
10 transport,
11};
12use alloc::{boxed::Box, collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
13use core::{any::Any, future::Future, task::Poll};
14use std::sync::Mutex;
15
/// Abstraction over an async runtime that can run detached tasks.
///
/// [`OffloadSession`] hands the future that drives the wrapped TLS session to
/// [`Executor::spawn`]; nothing in this module polls that future itself.
pub trait Executor {
    /// Schedules `task` to be run by the executor.
    fn spawn(&self, task: impl Future<Output = ()> + Send + 'static);
}
21
22pub trait ExporterHandler {
24 fn on_tls_handshake_failed(
25 &self,
26 session: &impl TlsSession,
27 e: &(dyn core::error::Error + Send + Sync + 'static),
28 ) -> Option<Box<dyn Any + Send>>;
29 fn on_tls_exporter_ready(&self, session: &impl TlsSession) -> Option<Box<dyn Any + Send>>;
30}
31
32impl ExporterHandler for () {
34 fn on_tls_handshake_failed(
35 &self,
36 _session: &impl TlsSession,
37 _e: &(dyn core::error::Error + Send + Sync + 'static),
38 ) -> Option<Box<dyn std::any::Any + Send>> {
39 None
40 }
41
42 fn on_tls_exporter_ready(
43 &self,
44 _session: &impl TlsSession,
45 ) -> Option<Box<dyn std::any::Any + Send>> {
46 None
47 }
48}
49
50pub struct OffloadEndpoint<E: tls::Endpoint, X: Executor, H: ExporterHandler> {
51 inner: E,
52 executor: X,
53 exporter: H,
54 channel_capacity: usize,
55}
56
57impl<E: tls::Endpoint, X: Executor, H: ExporterHandler> OffloadEndpoint<E, X, H> {
58 pub fn new(inner: E, executor: X, exporter: H, channel_capacity: usize) -> Self {
59 Self {
60 inner,
61 executor,
62 exporter,
63 channel_capacity,
64 }
65 }
66}
67
68impl<E, X, H> tls::Endpoint for OffloadEndpoint<E, X, H>
69where
70 E: tls::Endpoint,
71 X: Executor + Send + 'static,
72 H: ExporterHandler + Send + 'static + Sync + Clone,
73{
74 type Session = OffloadSession<<E as tls::Endpoint>::Session>;
75
76 fn new_server_session<Params: s2n_codec::EncoderValue>(
77 &mut self,
78 transport_parameters: &Params,
79 connection_info: ConnectionInfo,
80 ) -> Self::Session {
81 OffloadSession::new(
82 self.inner
83 .new_server_session(transport_parameters, connection_info),
84 &self.executor,
85 self.exporter.clone(),
86 self.channel_capacity,
87 )
88 }
89
90 fn new_client_session<Params: s2n_codec::EncoderValue>(
91 &mut self,
92 transport_parameters: &Params,
93 server_name: application::ServerName,
94 ) -> Self::Session {
95 OffloadSession::new(
96 self.inner
97 .new_client_session(transport_parameters, server_name),
98 &self.executor,
99 self.exporter.clone(),
100 self.channel_capacity,
101 )
102 }
103
104 fn max_tag_length(&self) -> usize {
105 self.inner.max_tag_length()
106 }
107}
108
109#[derive(Debug)]
110pub struct OffloadSession<S: CryptoSuite> {
111 recv_from_tls: Receiver<Request<S>>,
112 send_to_tls: Sender<Response>,
113 allowed_to_send: Arc<Mutex<AllowedToSend>>,
114}
115
116impl<S: tls::Session + 'static> OffloadSession<S> {
117 fn new(
118 mut inner: S,
119 executor: &impl Executor,
120 exporter: impl ExporterHandler + Sync + Send + 'static + Clone,
121 channel_capacity: usize,
122 ) -> Self {
123 let (mut send_to_quic, recv_from_tls): (Sender<Request<S>>, Receiver<Request<S>>) =
124 channel(channel_capacity);
125 let (send_to_tls, mut recv_from_quic): (Sender<Response>, Receiver<Response>) =
126 channel(channel_capacity);
127 let allowed_to_send = Arc::new(Mutex::new(AllowedToSend::default()));
128 let clone = allowed_to_send.clone();
129
130 let future = async move {
131 let mut initial_data = VecDeque::default();
132 let mut handshake_data = VecDeque::default();
133 let mut application_data = VecDeque::default();
134
135 core::future::poll_fn(|ctx| {
136 match send_to_quic.poll_slice(ctx) {
137 Poll::Ready(res) => match res {
138 Ok(send_slice) => {
139 let allowed_to_send = *allowed_to_send.lock().unwrap();
140
141 let mut context = RemoteContext {
142 send_to_quic: send_slice,
143 waker: ctx.waker().clone(),
144 initial_data: &mut initial_data,
145 handshake_data: &mut handshake_data,
146 application_data: &mut application_data,
147 exporter_handler: exporter.clone(),
148 allowed_to_send,
149 error: None,
150 };
151
152 while let Poll::Ready(res) = recv_from_quic.poll_slice(ctx) {
153 match res {
154 Ok(mut recv_slice) => {
155 while let Some(response) = recv_slice.pop() {
156 match response {
157 Response::Initial(data) => {
158 context.initial_data.push_back(data);
159 }
160 Response::Handshake(data) => {
161 context.handshake_data.push_back(data);
162 }
163 Response::Application(data) => {
164 context.application_data.push_back(data)
165 }
166 Response::SendStatusChanged => (),
167 }
168 }
169 }
170 Err(_) => {
171 return Poll::Ready(());
174 }
175 }
176 }
177
178 let res = inner.poll(&mut context);
179 if let Poll::Ready(res) = res {
182 let request = match res {
183 Ok(_) => Request::TlsDone,
184 Err(e) => Request::TlsError(e),
185 };
186 let _ = context.send_to_quic.push(request);
187 }
188
189 if let Some(error) = context.error {
191 let _ = context.send_to_quic.push(Request::TlsError(error));
192 }
193
194 res.map(|_| ())
196 }
197 Err(_) => {
198 Poll::Ready(())
201 }
202 },
203 Poll::Pending => Poll::Pending,
204 }
205 })
206 .await;
207 };
208 executor.spawn(future);
209
210 Self {
211 recv_from_tls,
212 send_to_tls,
213 allowed_to_send: clone,
214 }
215 }
216}
217
218impl<S: tls::Session> tls::Session for OffloadSession<S> {
219 #[inline]
220 fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
221 where
222 W: tls::Context<Self>,
223 {
224 let cloned_waker = &context.waker().clone();
225 let mut ctx = core::task::Context::from_waker(cloned_waker);
226
227 match self.recv_from_tls.poll_slice(&mut ctx) {
228 Poll::Ready(res) => match res {
229 Ok(mut slice) => {
230 while let Some(request) = slice.pop() {
231 match request {
232 Request::HandshakeKeys(key, header_key) => {
233 context.on_handshake_keys(key, header_key)?;
234 }
235 Request::ServerName(server_name) => {
236 context.on_server_name(server_name)?
237 }
238 Request::SendInitial(bytes) => context.send_initial(bytes),
239 Request::ClientParams(client_params, mut server_params) => context
240 .on_client_application_params(
241 tls::ApplicationParameters {
242 transport_parameters: &client_params,
243 },
244 &mut server_params,
245 )?,
246 Request::ApplicationProtocol(bytes) => {
247 context.on_application_protocol(bytes)?;
248 }
249 Request::KeyExchangeGroup(named_group) => {
250 context.on_key_exchange_group(named_group)?;
251 }
252 Request::OneRttKeys(key, header_key, transport_parameters) => context
253 .on_one_rtt_keys(
254 key,
255 header_key,
256 tls::ApplicationParameters {
257 transport_parameters: &transport_parameters,
258 },
259 )?,
260 Request::SendHandshake(bytes) => {
261 context.send_handshake(bytes);
262 }
263 Request::HandshakeComplete => {
264 context.on_handshake_complete()?;
265 }
266 Request::TlsDone => {
267 return Poll::Ready(Ok(()));
268 }
269 Request::ZeroRtt(key, header_key, transport_parameters) => {
270 context.on_zero_rtt_keys(
271 key,
272 header_key,
273 tls::ApplicationParameters {
274 transport_parameters: &transport_parameters,
275 },
276 )?;
277 }
278 Request::TlsContext(ctx) => {
279 context.on_tls_context(ctx);
280 }
281 Request::SendApplication(transmission) => {
282 context.send_application(transmission);
283 }
284 Request::TlsError(e) => return Poll::Ready(Err(e)),
285 }
286 }
287 }
288 Err(_) => {
289 return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE)));
291 }
292 },
293 Poll::Pending => (),
294 }
295
296 let mut allowed_to_send = self.allowed_to_send.lock().unwrap();
297 let mut state_change = false;
298 if allowed_to_send.can_send_initial != context.can_send_initial()
299 || allowed_to_send.can_send_handshake != context.can_send_handshake()
300 || allowed_to_send.can_send_application != context.can_send_application()
301 {
302 *allowed_to_send = AllowedToSend {
303 can_send_initial: context.can_send_initial(),
304 can_send_handshake: context.can_send_handshake(),
305 can_send_application: context.can_send_application(),
306 };
307 state_change = true;
308 }
309 drop(allowed_to_send);
311
312 match self.send_to_tls.poll_slice(&mut ctx) {
313 Poll::Ready(res) => match res {
314 Ok(mut slice) => {
315 if let Some(resp) = context.receive_initial(None) {
316 let _ = slice.push(Response::Initial(resp));
317 }
318
319 if let Some(resp) = context.receive_handshake(None) {
320 let _ = slice.push(Response::Handshake(resp));
321 }
322
323 if let Some(resp) = context.receive_application(None) {
324 let _ = slice.push(Response::Application(resp));
325 }
326
327 if state_change {
328 let _ = slice.push(Response::SendStatusChanged);
329 }
330 }
331 Err(_) => {
332 return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE)));
334 }
335 },
336 Poll::Pending => (),
337 }
338
339 Poll::Pending
340 }
341}
342
343impl<S: tls::Session> CryptoSuite for OffloadSession<S> {
344 type HandshakeKey = <S as CryptoSuite>::HandshakeKey;
345 type HandshakeHeaderKey = <S as CryptoSuite>::HandshakeHeaderKey;
346 type InitialKey = <S as CryptoSuite>::InitialKey;
347 type InitialHeaderKey = <S as CryptoSuite>::InitialHeaderKey;
348 type ZeroRttKey = <S as CryptoSuite>::ZeroRttKey;
349 type ZeroRttHeaderKey = <S as CryptoSuite>::ZeroRttHeaderKey;
350 type OneRttKey = <S as CryptoSuite>::OneRttKey;
351 type OneRttHeaderKey = <S as CryptoSuite>::OneRttHeaderKey;
352 type RetryKey = <S as CryptoSuite>::RetryKey;
353}
354
/// Snapshot of the QUIC context's `can_send_*` answers, one flag per level.
///
/// Every flag defaults to `false`.
#[derive(Debug, Default, Copy, Clone)]
struct AllowedToSend {
    /// Last value observed from `can_send_initial`.
    can_send_initial: bool,
    /// Last value observed from `can_send_handshake`.
    can_send_handshake: bool,
    /// Last value observed from `can_send_application`.
    can_send_application: bool,
}
361
362const SLICE_ERROR: crate::transport::Error =
363 crate::transport::Error::INTERNAL_ERROR.with_reason("Slice is full");
364
365#[derive(Debug)]
366struct RemoteContext<'a, Request, H> {
367 send_to_quic: SendSlice<'a, Request>,
368 initial_data: &'a mut VecDeque<bytes::Bytes>,
369 handshake_data: &'a mut VecDeque<bytes::Bytes>,
370 application_data: &'a mut VecDeque<bytes::Bytes>,
371 waker: core::task::Waker,
372 allowed_to_send: AllowedToSend,
373 exporter_handler: H,
374 error: Option<crate::transport::Error>,
375}
376
377impl<S: CryptoSuite, H: ExporterHandler> tls::Context<S> for RemoteContext<'_, Request<S>, H> {
378 fn on_client_application_params(
379 &mut self,
380 client_params: tls::ApplicationParameters,
381 server_params: &mut alloc::vec::Vec<u8>,
382 ) -> Result<(), crate::transport::Error> {
383 match self.send_to_quic.push(Request::ClientParams(
384 client_params.transport_parameters.to_vec(),
385 server_params.to_vec(),
386 )) {
387 Ok(_) => return Ok(()),
388 Err(_) => self.error = Some(SLICE_ERROR),
389 }
390 Ok(())
391 }
392
393 fn on_handshake_keys(
394 &mut self,
395 key: <S as CryptoSuite>::HandshakeKey,
396 header_key: <S as CryptoSuite>::HandshakeHeaderKey,
397 ) -> Result<(), crate::transport::Error> {
398 match self
399 .send_to_quic
400 .push(Request::HandshakeKeys(key, header_key))
401 {
402 Ok(_) => return Ok(()),
403 Err(_) => self.error = Some(SLICE_ERROR),
404 }
405 Ok(())
406 }
407
408 fn on_zero_rtt_keys(
409 &mut self,
410 key: <S as CryptoSuite>::ZeroRttKey,
411 header_key: <S as CryptoSuite>::ZeroRttHeaderKey,
412 application_parameters: tls::ApplicationParameters,
413 ) -> Result<(), crate::transport::Error> {
414 match self.send_to_quic.push(Request::ZeroRtt(
415 key,
416 header_key,
417 application_parameters.transport_parameters.to_vec(),
418 )) {
419 Ok(_) => (),
420 Err(_) => self.error = Some(SLICE_ERROR),
421 }
422 Ok(())
423 }
424
425 fn on_one_rtt_keys(
426 &mut self,
427 key: <S as CryptoSuite>::OneRttKey,
428 header_key: <S as CryptoSuite>::OneRttHeaderKey,
429 application_parameters: tls::ApplicationParameters,
430 ) -> Result<(), crate::transport::Error> {
431 match self.send_to_quic.push(Request::OneRttKeys(
432 key,
433 header_key,
434 application_parameters.transport_parameters.to_vec(),
435 )) {
436 Ok(_) => (),
437 Err(_) => self.error = Some(SLICE_ERROR),
438 }
439 Ok(())
440 }
441
442 fn on_server_name(
443 &mut self,
444 server_name: crate::application::ServerName,
445 ) -> Result<(), crate::transport::Error> {
446 match self.send_to_quic.push(Request::ServerName(server_name)) {
447 Ok(_) => (),
448 Err(_) => self.error = Some(SLICE_ERROR),
449 }
450 Ok(())
451 }
452
453 fn on_application_protocol(
454 &mut self,
455 application_protocol: bytes::Bytes,
456 ) -> Result<(), crate::transport::Error> {
457 match self
458 .send_to_quic
459 .push(Request::ApplicationProtocol(application_protocol))
460 {
461 Ok(_) => (),
462 Err(_) => self.error = Some(SLICE_ERROR),
463 }
464 Ok(())
465 }
466
467 fn on_key_exchange_group(
468 &mut self,
469 named_group: tls::NamedGroup,
470 ) -> Result<(), crate::transport::Error> {
471 match self
472 .send_to_quic
473 .push(Request::KeyExchangeGroup(named_group))
474 {
475 Ok(_) => (),
476 Err(_) => self.error = Some(SLICE_ERROR),
477 }
478 Ok(())
479 }
480
481 fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error> {
482 match self.send_to_quic.push(Request::HandshakeComplete) {
483 Ok(_) => (),
484 Err(_) => self.error = Some(SLICE_ERROR),
485 }
486
487 Ok(())
488 }
489
490 fn on_tls_context(&mut self, _context: Box<dyn Any + Send>) {
491 unimplemented!("TLS Context is not supported in Offload implementation");
492 }
493
494 fn on_tls_exporter_ready(
495 &mut self,
496 session: &impl TlsSession,
497 ) -> Result<(), crate::transport::Error> {
498 if let Some(context) = self.exporter_handler.on_tls_exporter_ready(session) {
499 match self.send_to_quic.push(Request::TlsContext(context)) {
500 Ok(_) => (),
501 Err(_) => self.error = Some(SLICE_ERROR),
502 }
503 }
504
505 Ok(())
506 }
507
508 fn receive_initial(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
509 gimme_bytes(max_len, self.initial_data)
510 }
511
512 fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
513 gimme_bytes(max_len, self.handshake_data)
514 }
515
516 fn receive_application(&mut self, max_len: Option<usize>) -> Option<bytes::Bytes> {
517 gimme_bytes(max_len, self.application_data)
518 }
519
520 fn can_send_initial(&self) -> bool {
521 self.allowed_to_send.can_send_initial
522 }
523
524 fn send_initial(&mut self, transmission: bytes::Bytes) {
525 if self
526 .send_to_quic
527 .push(Request::SendInitial(transmission))
528 .is_err()
529 {
530 self.error = Some(SLICE_ERROR);
531 }
532 }
533
534 fn can_send_handshake(&self) -> bool {
535 self.allowed_to_send.can_send_handshake
536 }
537
538 fn send_handshake(&mut self, transmission: bytes::Bytes) {
539 if self
540 .send_to_quic
541 .push(Request::SendHandshake(transmission))
542 .is_err()
543 {
544 self.error = Some(SLICE_ERROR);
545 }
546 }
547
548 fn can_send_application(&self) -> bool {
549 self.allowed_to_send.can_send_application
550 }
551
552 fn send_application(&mut self, transmission: bytes::Bytes) {
553 if self
554 .send_to_quic
555 .push(Request::SendApplication(transmission))
556 .is_err()
557 {
558 self.error = Some(SLICE_ERROR);
559 }
560 }
561
562 fn waker(&self) -> &core::task::Waker {
563 &self.waker
564 }
565
566 fn on_tls_handshake_failed(
567 &mut self,
568 session: &impl tls::TlsSession,
569 e: &(dyn core::error::Error + Send + Sync + 'static),
570 ) -> Result<(), crate::transport::Error> {
571 if let Some(context) = self.exporter_handler.on_tls_handshake_failed(session, e) {
572 match self.send_to_quic.push(Request::TlsContext(context)) {
573 Ok(_) => (),
574 Err(_) => self.error = Some(SLICE_ERROR),
575 }
576 }
577 Ok(())
578 }
579}
580
581fn gimme_bytes(max_len: Option<usize>, vec: &mut VecDeque<bytes::Bytes>) -> Option<bytes::Bytes> {
582 let bytes = vec.pop_front();
583 if let Some(mut bytes) = bytes {
584 if let Some(max_len) = max_len {
585 if bytes.len() > max_len {
586 let remainder = bytes.split_off(max_len);
587 vec.push_front(remainder);
588 }
589 }
590 return Some(bytes);
591 }
592 None
593}
594
595enum Request<S: CryptoSuite> {
596 ZeroRtt(
597 <S as CryptoSuite>::ZeroRttKey,
598 <S as CryptoSuite>::ZeroRttHeaderKey,
599 Vec<u8>,
600 ),
601 ServerName(crate::application::ServerName),
602 SendInitial(bytes::Bytes),
603 ClientParams(Vec<u8>, Vec<u8>),
604 HandshakeKeys(
605 <S as CryptoSuite>::HandshakeKey,
606 <S as CryptoSuite>::HandshakeHeaderKey,
607 ),
608 SendHandshake(bytes::Bytes),
609 ApplicationProtocol(bytes::Bytes),
610 KeyExchangeGroup(NamedGroup),
611 OneRttKeys(
612 <S as CryptoSuite>::OneRttKey,
613 <S as CryptoSuite>::OneRttHeaderKey,
614 Vec<u8>,
615 ),
616 HandshakeComplete,
617 TlsDone,
618 TlsContext(Box<dyn Any + Send>),
619 SendApplication(bytes::Bytes),
620 TlsError(transport::Error),
621}
622
623enum Response {
624 Initial(bytes::Bytes),
625 Handshake(bytes::Bytes),
626 Application(bytes::Bytes),
627 SendStatusChanged,
628}
629
630impl<S: CryptoSuite> alloc::fmt::Debug for Request<S> {
631 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
632 match self {
633 Request::ServerName(_) => write!(f, "ServerName"),
634 Request::SendInitial(_) => write!(f, "SendInitial"),
635 Request::ClientParams(_, _) => write!(f, "ClientParams"),
636 Request::HandshakeKeys(_, _) => write!(f, "HandshakeKeys"),
637 Request::SendHandshake(_) => write!(f, "SendHandshake"),
638 Request::ApplicationProtocol(_) => write!(f, "ApplicationProtocol"),
639 Request::KeyExchangeGroup(_) => write!(f, "KeyExchangeGroup"),
640 Request::OneRttKeys(_, _, _) => write!(f, "OneRttKeys"),
641 Request::HandshakeComplete => write!(f, "HandshakeComplete"),
642 Request::TlsDone => write!(f, "TlsDone"),
643 Request::ZeroRtt(_, _, _) => write!(f, "ZeroRtt"),
644 Request::TlsContext(_) => write!(f, "TlsContext"),
645 Request::SendApplication(_) => write!(f, "SendApplication"),
646 Request::TlsError(_) => write!(f, "TlsError"),
647 }
648 }
649}
650
651impl alloc::fmt::Debug for Response {
652 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
653 match self {
654 Response::Initial(_) => write!(f, "ResponseInitial"),
655 Response::Handshake(_) => write!(f, "ResponseHandshake"),
656 Response::Application(_) => write!(f, "ResponseApplication"),
657 Response::SendStatusChanged => write!(f, "SendStatusChanged"),
658 }
659 }
660}