s2n_tls/
connection.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::missing_safety_doc)] // TODO add safety docs
5
6#[cfg(feature = "unstable-renegotiate")]
7use crate::renegotiate::RenegotiateState;
8use crate::{
9    callbacks::*,
10    cert_chain::{CertificateChain, CertificateChainHandle},
11    config::Config,
12    enums::*,
13    error::{Error, Fallible, Pollable},
14    psk::Psk,
15    security,
16};
17
18use core::{
19    convert::TryInto,
20    fmt,
21    mem::{self, ManuallyDrop, MaybeUninit},
22    pin::Pin,
23    ptr::NonNull,
24    task::{Poll, Waker},
25    time::Duration,
26};
27use libc::c_void;
28use s2n_tls_sys::*;
29use std::{any::Any, ffi::CStr};
30
31mod builder;
32pub use builder::*;
33
/// Borrows a C string as a `&str` whose lifetime is tied to the enclosing scope.
///
/// Expands to a `Result<&str, Error>`; a string that is not valid UTF-8 becomes
/// `Error::INVALID_INPUT`.
///
/// SAFETY: the argument must point to a NUL-terminated string.
///
/// SAFETY: the pointed-to data must remain valid for at least the surrounding scope.
// Implemented as a macro rather than a function so that the output lifetime is
// inferred from the call site.
macro_rules! const_str {
    ($c_chars:expr) => {
        match CStr::from_ptr($c_chars).to_str() {
            Ok(text) => Ok(text),
            Err(_) => Err(Error::INVALID_INPUT),
        }
    };
}
48
/// Per-direction count of TLS1.3 key updates on a connection.
///
/// s2n-tls only tracks up to u8::MAX (255) key updates. If any of the fields show
/// 255 updates, then more than 255 updates may have occurred.
#[non_exhaustive]
#[derive(Debug, PartialEq)]
pub struct KeyUpdateCount {
    /// Number of times the sending key has been updated (saturates at 255).
    pub send_key_updates: u8,
    /// Number of times the receiving key has been updated (saturates at 255).
    pub recv_key_updates: u8,
}
57
58/// Corresponds to [s2n_connection].
59pub struct Connection {
60    connection: NonNull<s2n_connection>,
61}
62
63impl fmt::Debug for Connection {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        let mut debug = f.debug_struct("Connection");
66        if let Ok(handshake) = self.handshake_type() {
67            debug.field("handshake_type", &handshake);
68        }
69        if let Ok(cipher) = self.cipher_suite() {
70            debug.field("cipher_suite", &cipher);
71        }
72        if let Ok(version) = self.actual_protocol_version() {
73            debug.field("actual_protocol_version", &version);
74        }
75        if let Ok(curve) = self.selected_curve() {
76            debug.field("selected_curve", &curve);
77        }
78        debug.finish_non_exhaustive()
79    }
80}
81
82/// # Safety
83///
84/// s2n_connection objects can be sent across threads
85unsafe impl Send for Connection {}
86
87/// # Sync
88///
89/// Although NonNull isn't Sync and allows access to mutable pointers even from
90/// immutable references, the Connection interface enforces that all mutating
91/// methods correctly require &mut self.
92///
93/// Developers and reviewers MUST ensure that new methods correctly use
94/// either &self or &mut self depending on their behavior. No mechanism enforces this.
95///
96/// Note: Although non-mutating methods like getters should be thread-safe by definition,
97/// technically the only thread safety guarantee provided by the underlying C library
98/// is that s2n_send and s2n_recv can be called concurrently.
99///
100unsafe impl Sync for Connection {}
101
102impl Connection {
103    /// # Warning
104    ///
105    /// The newly created connection uses the default security policy.
106    /// Consider changing this depending on your security and compatibility requirements
107    /// by calling [`Connection::set_security_policy`].
108    /// Alternatively, you can use [`crate::config::Builder`], [`crate::config::Builder::set_security_policy`],
109    /// and [`Connection::set_config`] to set the policy on the Config instead of on the Connection.
110    /// See the s2n-tls usage guide:
111    /// <https://aws.github.io/s2n-tls/usage-guide/ch06-security-policies.html>
112    ///
113    /// Corresponds to [s2n_connection_new].
114    pub fn new(mode: Mode) -> Self {
115        crate::init::init();
116
117        let connection = unsafe { s2n_connection_new(mode.into()).into_result() }.unwrap();
118
119        unsafe {
120            debug_assert! {
121                s2n_connection_get_config(connection.as_ptr(), &mut core::ptr::null_mut())
122                    .into_result()
123                    .is_err()
124            }
125        }
126
127        let mut connection = Self { connection };
128        connection.init_context(mode);
129        connection
130    }
131
132    fn init_context(&mut self, mode: Mode) {
133        let context = Box::new(Context::new(mode));
134        let context = Box::into_raw(context) as *mut c_void;
135        // allocate a new context object
136        unsafe {
137            // There should never be an existing context
138            debug_assert!(s2n_connection_get_ctx(self.connection.as_ptr())
139                .into_result()
140                .is_err());
141
142            s2n_connection_set_ctx(self.connection.as_ptr(), context)
143                .into_result()
144                .unwrap();
145        }
146    }
147
148    pub fn new_client() -> Self {
149        Self::new(Mode::Client)
150    }
151
152    pub fn new_server() -> Self {
153        Self::new(Mode::Server)
154    }
155
156    pub(crate) fn as_ptr(&mut self) -> *mut s2n_connection {
157        self.connection.as_ptr()
158    }
159
160    /// # Safety
161    ///
162    /// Caller must ensure s2n_connection is a valid reference to a [`s2n_connection`] object
163    pub(crate) unsafe fn from_raw(connection: NonNull<s2n_connection>) -> Self {
164        Self { connection }
165    }
166
167    pub(crate) fn mode(&self) -> Mode {
168        self.context().mode
169    }
170
171    /// can be used to configure s2n to either use built-in blinding (set blinding
172    /// to Blinding::BuiltIn) or self-service blinding (set blinding to
173    /// Blinding::SelfService).
174    ///
175    /// Corresponds to [s2n_connection_set_blinding].
176    pub fn set_blinding(&mut self, blinding: Blinding) -> Result<&mut Self, Error> {
177        unsafe {
178            s2n_connection_set_blinding(self.connection.as_ptr(), blinding.into()).into_result()
179        }?;
180        Ok(self)
181    }
182
183    /// Reports the remaining nanoseconds before the connection may be gracefully shutdown.
184    ///
185    /// This method is expected to succeed, but could fail if the
186    /// [underlying C call](`s2n_connection_get_delay`) encounters errors.
187    /// Failure indicates that calls to [`Self::poll_shutdown`] will also fail and
188    /// that a graceful two-way shutdown of the connection will not be possible.
189    ///
190    /// Corresponds to [s2n_connection_get_delay].
191    pub fn remaining_blinding_delay(&self) -> Result<Duration, Error> {
192        let nanos = unsafe { s2n_connection_get_delay(self.connection.as_ptr()).into_result() }?;
193        Ok(Duration::from_nanos(nanos))
194    }
195
196    /// Sets whether or not a Client Certificate should be required to complete the TLS Connection.
197    ///
198    /// If this is set to ClientAuthType::Optional the server will request a client certificate
199    /// but allow the client to not provide one. Rejecting a client certificate when using
200    /// ClientAuthType::Optional will terminate the handshake.
201    ///
202    /// Corresponds to [s2n_connection_set_client_auth_type].
203    pub fn set_client_auth_type(
204        &mut self,
205        client_auth_type: ClientAuthType,
206    ) -> Result<&mut Self, Error> {
207        unsafe {
208            s2n_connection_set_client_auth_type(self.connection.as_ptr(), client_auth_type.into())
209                .into_result()
210        }?;
211        Ok(self)
212    }
213
214    /// Attempts to drop the config on the connection.
215    ///
216    /// # Safety
217    ///
218    /// The caller must ensure the config associated with the connection was created
219    /// with a [`config::Builder`].
220    unsafe fn drop_config(&mut self) -> Result<(), Error> {
221        let mut prev_config = core::ptr::null_mut();
222
223        // A valid non-null pointer is returned only if the application previously called
224        // [`Self::set_config()`].
225        if s2n_connection_get_config(self.connection.as_ptr(), &mut prev_config)
226            .into_result()
227            .is_ok()
228        {
229            let prev_config = NonNull::new(prev_config).expect(
230                "config should exist since the call to s2n_connection_get_config was successful",
231            );
232            drop(Config::from_raw(prev_config));
233        }
234
235        Ok(())
236    }
237
238    /// Associates a configuration object with a connection.
239    ///
240    /// Corresponds to [s2n_connection_set_config].
241    pub fn set_config(&mut self, mut config: Config) -> Result<&mut Self, Error> {
242        unsafe {
243            // attempt to drop the currently set config
244            self.drop_config()?;
245
246            s2n_connection_set_config(self.connection.as_ptr(), config.as_mut_ptr())
247                .into_result()?;
248
249            debug_assert! {
250                s2n_connection_get_config(self.connection.as_ptr(), &mut core::ptr::null_mut()).into_result().is_ok(),
251                "s2n_connection_set_config was successful"
252            };
253
254            // Setting the config on the connection creates one additional reference to the config
255            // so do not drop so prevent Rust from calling `drop()` at the end of this function.
256            mem::forget(config);
257        }
258
259        Ok(self)
260    }
261
262    pub(crate) fn config(&self) -> Option<Config> {
263        let mut raw = core::ptr::null_mut();
264        let config = unsafe {
265            s2n_connection_get_config(self.connection.as_ptr(), &mut raw)
266                .into_result()
267                .ok()?;
268            let raw = NonNull::new(raw)?;
269            Config::from_raw(raw)
270        };
271        // Because the config pointer is still set on the connection, this is a copy,
272        // not the original config. This is fine -- Configs are immutable.
273        let _ = ManuallyDrop::new(config.clone());
274        Some(config)
275    }
276
277    /// Corresponds to [s2n_connection_set_cipher_preferences].
278    pub fn set_security_policy(&mut self, policy: &security::Policy) -> Result<&mut Self, Error> {
279        unsafe {
280            s2n_connection_set_cipher_preferences(
281                self.connection.as_ptr(),
282                policy.as_cstr().as_ptr(),
283            )
284            .into_result()
285        }?;
286        Ok(self)
287    }
288
289    /// provides a smooth transition from s2n_connection_prefer_low_latency to s2n_connection_prefer_throughput.
290    ///
291    /// s2n_send uses small TLS records that fit into a single TCP segment for the resize_threshold
292    /// bytes (cap to 8M) of data and reset record size back to a single segment after timeout_threshold
293    /// seconds of inactivity.
294    ///
295    /// Corresponds to [s2n_connection_set_dynamic_record_threshold].
296    pub fn set_dynamic_record_threshold(
297        &mut self,
298        resize_threshold: u32,
299        timeout_threshold: u16,
300    ) -> Result<&mut Self, Error> {
301        unsafe {
302            s2n_connection_set_dynamic_record_threshold(
303                self.connection.as_ptr(),
304                resize_threshold,
305                timeout_threshold,
306            )
307            .into_result()
308        }?;
309        Ok(self)
310    }
311
312    /// Signals the connection to do a key_update at the next possible opportunity.
313    /// Note that the resulting key update message will not be sent until `send` is
314    /// called on the connection.
315    ///
316    /// `peer_request` indicates if a key update should also be requested
317    /// of the peer. When set to `KeyUpdateNotRequested`, then only the sending
318    /// key of the connection will be updated. If set to `KeyUpdateRequested`, then
319    /// the sending key of conn will be updated AND the peer will be requested to
320    /// update their sending key. Note that s2n-tls currently only supports
321    /// `peer_request` being set to `KeyUpdateNotRequested` and will return an error
322    /// if any other value is used.
323    ///
324    /// Corresponds to [s2n_connection_request_key_update].
325    pub fn request_key_update(&mut self, peer_request: PeerKeyUpdate) -> Result<&mut Self, Error> {
326        unsafe {
327            s2n_connection_request_key_update(self.connection.as_ptr(), peer_request.into())
328                .into_result()
329        }?;
330        Ok(self)
331    }
332
333    /// Reports the number of times sending and receiving keys have been updated.
334    ///
335    /// This only applies to TLS1.3. Earlier versions do not support key updates.
336    ///
337    /// Corresponds to [s2n_connection_get_key_update_counts].
338    #[cfg(feature = "unstable-ktls")]
339    pub fn key_update_counts(&self) -> Result<KeyUpdateCount, Error> {
340        let mut send_key_updates = 0;
341        let mut recv_key_updates = 0;
342        unsafe {
343            s2n_connection_get_key_update_counts(
344                self.connection.as_ptr(),
345                &mut send_key_updates,
346                &mut recv_key_updates,
347            )
348            .into_result()?;
349        }
350        Ok(KeyUpdateCount {
351            send_key_updates,
352            recv_key_updates,
353        })
354    }
355
356    /// sets the application protocol preferences on an s2n_connection object.
357    ///
358    /// protocols is a list in order of preference, with most preferred protocol first, and of
359    /// length protocol_count. When acting as a client the protocol list is included in the
360    /// Client Hello message as the ALPN extension. As a server, the list is used to negotiate
361    /// a mutual application protocol with the client. After the negotiation for the connection has
362    /// completed, the agreed upon protocol can be retrieved with s2n_get_application_protocol
363    ///
364    /// Corresponds to [s2n_connection_set_protocol_preferences].
365    pub fn set_application_protocol_preference<P: IntoIterator<Item = I>, I: AsRef<[u8]>>(
366        &mut self,
367        protocols: P,
368    ) -> Result<&mut Self, Error> {
369        // reset the list
370        unsafe {
371            s2n_connection_set_protocol_preferences(self.connection.as_ptr(), core::ptr::null(), 0)
372                .into_result()
373        }?;
374
375        for protocol in protocols {
376            self.append_application_protocol_preference(protocol.as_ref())?;
377        }
378
379        Ok(self)
380    }
381
382    /// Corresponds to [s2n_connection_append_protocol_preference].
383    pub fn append_application_protocol_preference(
384        &mut self,
385        protocol: &[u8],
386    ) -> Result<&mut Self, Error> {
387        unsafe {
388            s2n_connection_append_protocol_preference(
389                self.connection.as_ptr(),
390                protocol.as_ptr(),
391                protocol
392                    .len()
393                    .try_into()
394                    .map_err(|_| Error::INVALID_INPUT)?,
395            )
396            .into_result()
397        }?;
398        Ok(self)
399    }
400
401    /// may be used to receive data with callbacks defined by the user.
402    ///
403    /// Corresponds to [s2n_connection_set_recv_cb].
404    pub fn set_receive_callback(&mut self, callback: s2n_recv_fn) -> Result<&mut Self, Error> {
405        unsafe { s2n_connection_set_recv_cb(self.connection.as_ptr(), callback).into_result() }?;
406        Ok(self)
407    }
408
409    /// # Safety
410    ///
411    /// The `context` pointer must live at least as long as the connection
412    ///
413    /// Corresponds to [s2n_connection_set_recv_ctx].
414    pub unsafe fn set_receive_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
415        s2n_connection_set_recv_ctx(self.connection.as_ptr(), context).into_result()?;
416        Ok(self)
417    }
418
419    /// may be used to receive data with callbacks defined by the user.
420    ///
421    /// Corresponds to [s2n_connection_set_send_cb].
422    pub fn set_send_callback(&mut self, callback: s2n_send_fn) -> Result<&mut Self, Error> {
423        unsafe { s2n_connection_set_send_cb(self.connection.as_ptr(), callback).into_result() }?;
424        Ok(self)
425    }
426
427    /// # Safety
428    ///
429    /// The `context` pointer must live at least as long as the connection
430    ///
431    /// Corresponds to [s2n_connection_set_send_ctx].
432    pub unsafe fn set_send_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
433        s2n_connection_set_send_ctx(self.connection.as_ptr(), context).into_result()?;
434        Ok(self)
435    }
436
437    /// Sets the callback to use for verifying that a hostname from an X.509 certificate is
438    /// trusted.
439    ///
440    /// The callback may be called more than once during certificate validation as each SAN on
441    /// the certificate will be checked.
442    ///
443    /// Corresponds to [s2n_connection_set_verify_host_callback].
444    pub fn set_verify_host_callback<T: 'static + VerifyHostNameCallback>(
445        &mut self,
446        handler: T,
447    ) -> Result<&mut Self, Error> {
448        unsafe extern "C" fn verify_host_cb_fn(
449            host_name: *const ::libc::c_char,
450            host_name_len: usize,
451            context: *mut ::libc::c_void,
452        ) -> u8 {
453            let context = &mut *(context as *mut Context);
454            let handler = context.verify_host_callback.as_mut().unwrap();
455            verify_host(host_name, host_name_len, handler)
456        }
457
458        self.context_mut().verify_host_callback = Some(Box::new(handler));
459        unsafe {
460            s2n_connection_set_verify_host_callback(
461                self.connection.as_ptr(),
462                Some(verify_host_cb_fn),
463                self.context_mut() as *mut Context as *mut c_void,
464            )
465            .into_result()
466        }?;
467        Ok(self)
468    }
469
470    /// Connections preferring low latency will be encrypted using small record sizes that
471    /// can be decrypted sooner by the recipient.
472    ///
473    /// Corresponds to [s2n_connection_prefer_low_latency].
474    pub fn prefer_low_latency(&mut self) -> Result<&mut Self, Error> {
475        unsafe { s2n_connection_prefer_low_latency(self.connection.as_ptr()).into_result() }?;
476        Ok(self)
477    }
478
479    /// Connections preferring throughput will use large record sizes that minimize overhead.
480    ///
481    /// Corresponds to [s2n_connection_prefer_throughput].
482    pub fn prefer_throughput(&mut self) -> Result<&mut Self, Error> {
483        unsafe { s2n_connection_prefer_throughput(self.connection.as_ptr()).into_result() }?;
484        Ok(self)
485    }
486
487    /// Configure the connection to reduce potentially expensive calls to recv.
488    ///
489    /// Corresponds to [s2n_connection_set_recv_buffering].
490    pub fn set_receive_buffering(&mut self, enabled: bool) -> Result<&mut Self, Error> {
491        unsafe {
492            s2n_connection_set_recv_buffering(self.connection.as_ptr(), enabled).into_result()
493        }?;
494        Ok(self)
495    }
496
497    /// wipes and free the in and out buffers associated with a connection.
498    ///
499    /// This function may be called when a connection is in keep-alive or idle state to
500    /// reduce memory overhead of long lived connections.
501    ///
502    /// Corresponds to [s2n_connection_release_buffers].
503    pub fn release_buffers(&mut self) -> Result<&mut Self, Error> {
504        unsafe { s2n_connection_release_buffers(self.connection.as_ptr()).into_result() }?;
505        Ok(self)
506    }
507
508    /// Corresponds to [s2n_connection_use_corked_io].
509    pub fn use_corked_io(&mut self) -> Result<&mut Self, Error> {
510        unsafe { s2n_connection_use_corked_io(self.connection.as_ptr()).into_result() }?;
511        Ok(self)
512    }
513
514    pub(crate) fn wipe_method<F, T>(&mut self, wipe: F) -> Result<(), Error>
515    where
516        F: FnOnce(&mut Self) -> Result<T, Error>,
517    {
518        let mode = self.mode();
519
520        // Safety:
521        // We re-init the context after the wipe
522        unsafe { self.drop_context()? };
523
524        let result = wipe(self);
525        // We must initialize the context again whether or not wipe succeeds.
526        // A connection without a context is invalid and has undefined behavior.
527        self.init_context(mode);
528        result?;
529
530        Ok(())
531    }
532
533    /// wipes an existing connection and allows it to be reused.
534    ///
535    /// This method erases all data associated with a connection including pending reads.
536    /// This function should be called after all I/O is completed and s2n_shutdown has been
537    /// called. Reusing the same connection handle(s) is more performant than repeatedly
538    /// calling s2n_connection_new and s2n_connection_free
539    ///
540    /// Corresponds to [s2n_connection_wipe].
541    pub fn wipe(&mut self) -> Result<&mut Self, Error> {
542        self.wipe_method(|conn| unsafe { s2n_connection_wipe(conn.as_ptr()).into_result() })?;
543        Ok(self)
544    }
545
546    fn trigger_initializer(&mut self) {
547        if !core::mem::replace(&mut self.context_mut().connection_initialized, true) {
548            if let Some(config) = self.config() {
549                if let Some(callback) = config.context().connection_initializer.as_ref() {
550                    let future = callback.initialize_connection(self);
551                    AsyncCallback::trigger(future, self);
552                }
553            }
554        }
555    }
556
557    // Poll the connection future if it exists.
558    //
559    // If the future returns Pending, then re-set it back on the Connection.
560    fn poll_async_task(&mut self) -> Option<Poll<Result<(), Error>>> {
561        self.take_async_callback().map(|mut callback| {
562            let waker = self.waker().ok_or(Error::MISSING_WAKER)?.clone();
563            let mut ctx = core::task::Context::from_waker(&waker);
564            match Pin::new(&mut callback).poll(self, &mut ctx) {
565                Poll::Ready(result) => Poll::Ready(result),
566                Poll::Pending => {
567                    // replace the future if it hasn't completed yet
568                    self.set_async_callback(callback);
569                    Poll::Pending
570                }
571            }
572        })
573    }
574
575    pub(crate) fn poll_negotiate_method<F, T>(
576        &mut self,
577        mut negotiate: F,
578    ) -> Poll<Result<(), Error>>
579    where
580        F: FnMut(&mut Connection) -> Poll<Result<T, Error>>,
581    {
582        self.trigger_initializer();
583
584        loop {
585            // Check whether renegotiate is blocked by any async callbacks
586            match self.poll_async_task().unwrap_or(Poll::Ready(Ok(()))) {
587                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
588                Poll::Pending => return Poll::Pending,
589                Poll::Ready(Ok(_)) => {}
590            };
591
592            match negotiate(self) {
593                Poll::Ready(res) => return Poll::Ready(res.map(|_| ())),
594                Poll::Pending => {
595                    // If `negotiate` returned `Pending` it could be blocked on a connection future
596                    // (i.e. not socket IO) so before we return, we need to make sure we poll
597                    // the associated future at least once. Otherwise, we will violate the waker contract.
598                    //
599                    // See https://github.com/aws/s2n-quic/pull/2248
600                    if self.context_mut().async_callback.is_some() {
601                        // continuing in the loop will poll the task
602                        continue;
603                    }
604
605                    // we don't have anything else to poll so return `Pending`
606                    return Poll::Pending;
607                }
608            }
609        }
610    }
611
612    /// Performs the TLS handshake to completion
613    ///
614    /// Multiple callbacks can be configured for a connection and config, but
615    /// [`Self::poll_negotiate()`] can only execute and block on one callback at a time.
616    /// The handshake is sequential, not concurrent, and stops execution when
617    /// it encounters an async callback.
618    ///
619    /// The handshake does not continue execution (and therefore can't call
620    /// any other callbacks) until the blocking async task reports completion.
621    ///
622    /// Corresponds to [s2n_negotiate].
623    pub fn poll_negotiate(&mut self) -> Poll<Result<&mut Self, Error>> {
624        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
625        self.poll_negotiate_method(|conn| unsafe {
626            s2n_negotiate(conn.as_ptr(), &mut blocked).into_poll()
627        })
628        .map_ok(|_| self)
629    }
630
631    /// Encrypts and sends data on a connection where
632    /// [negotiate](`Self::poll_negotiate`) has succeeded.
633    ///
634    /// Returns the number of bytes written, and may indicate a partial write.
635    ///
636    /// Corresponds to [s2n_send].
637    #[cfg(not(feature = "unstable-renegotiate"))]
638    pub fn poll_send(&mut self, buf: &[u8]) -> Poll<Result<usize, Error>> {
639        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
640        let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
641        let buf_ptr = buf.as_ptr() as *const ::libc::c_void;
642        unsafe { s2n_send(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
643    }
644
645    #[cfg(not(feature = "unstable-renegotiate"))]
646    pub(crate) fn poll_recv_raw(
647        &mut self,
648        buf_ptr: *mut ::libc::c_void,
649        buf_len: isize,
650    ) -> Poll<Result<usize, Error>> {
651        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
652        unsafe { s2n_recv(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
653    }
654
655    /// Reads and decrypts data from a connection where
656    /// [negotiate](`Self::poll_negotiate`) has succeeded.
657    ///
658    /// Returns the number of bytes read, and may indicate a partial read.
659    /// 0 bytes returned indicates EOF due to connection closure.
660    ///
661    /// Corresponds to [s2n_recv].
662    pub fn poll_recv(&mut self, buf: &mut [u8]) -> Poll<Result<usize, Error>> {
663        let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
664        let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
665        self.poll_recv_raw(buf_ptr, buf_len)
666    }
667
668    /// Reads and decrypts data from a connection where
669    /// [negotiate](`Self::poll_negotiate`) has succeeded
670    /// to a uninitialized buffer.
671    ///
672    /// Returns the number of bytes read, and may indicate a partial read.
673    /// 0 bytes returned indicates EOF due to connection closure.
674    ///
675    /// Safety: this function is always safe to call, and additionally:
676    /// 1. It will never uninitialize any bytes in `buf`.
677    /// 2. If it returns `Ok(n)`, then the first `n` bytes of `buf`
678    ///    will have been initialized by this function.
679    ///
680    /// Corresponds to [s2n_recv].
681    pub fn poll_recv_uninitialized(
682        &mut self,
683        buf: &mut [MaybeUninit<u8>],
684    ) -> Poll<Result<usize, Error>> {
685        let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
686        let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
687
688        // Safety:
689        // 1. s2n_recv never writes uninitialized garbage to `buf`.
690        // 2. if s2n_recv returns `+n`, it guarantees that the first
691        // `n` bytes of `buf` have been initialized, which allows this
692        // function to return `Ok(n)`
693        self.poll_recv_raw(buf_ptr, buf_len)
694    }
695
696    /// Attempts to flush any data previously buffered by a call to [send](`Self::poll_send`).
697    ///
698    /// poll_flush can only flush data that s2n-tls has already encrypted and
699    /// buffered for sending. poll_send may need to be called again to fully send
700    /// all data. See the [Usage Guide](https://github.com/aws/s2n-tls/blob/main/docs/usage-guide/topics/ch07-io.md)
701    /// for more details.
702    ///
703    /// Corresponds to [s2n_flush].
704    pub fn poll_flush(&mut self) -> Poll<Result<&mut Self, Error>> {
705        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
706        unsafe {
707            s2n_flush(self.connection.as_ptr(), &mut blocked)
708                .into_poll()
709                .map_ok(|_| self)
710        }
711    }
712
713    /// Gets the number of bytes that are currently available in the buffer to be read.
714    ///
715    /// Corresponds to [s2n_peek].
716    pub fn peek_len(&self) -> usize {
717        unsafe { s2n_peek(self.connection.as_ptr()) as usize }
718    }
719
720    /// Attempts a graceful shutdown of the TLS connection.
721    ///
722    /// The shutdown is not complete until the necessary shutdown messages
723    /// have been successfully sent and received. If the peer does not respond
724    /// correctly, the graceful shutdown may fail.
725    ///
726    /// Corresponds to [s2n_shutdown].
727    pub fn poll_shutdown(&mut self) -> Poll<Result<&mut Self, Error>> {
728        if !self.remaining_blinding_delay()?.is_zero() {
729            return Poll::Pending;
730        }
731        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
732        unsafe {
733            s2n_shutdown(self.connection.as_ptr(), &mut blocked)
734                .into_poll()
735                .map_ok(|_| self)
736        }
737    }
738
739    /// Attempts a graceful shutdown of the write side of a TLS connection.
740    ///
741    /// Unlike Self::poll_shutdown, no response from the peer is necessary.
742    /// If using TLS1.3, the connection can continue to be used for reading afterwards.
743    ///
744    /// Corresponds to [s2n_shutdown_send].
745    pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
746        if !self.remaining_blinding_delay()?.is_zero() {
747            return Poll::Pending;
748        }
749        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
750        unsafe {
751            s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
752                .into_poll()
753                .map_ok(|_| self)
754        }
755    }
756
757    /// Returns the TLS alert code, if any
758    ///
759    /// Corresponds to [s2n_connection_get_alert].
760    pub fn alert(&self) -> Option<u8> {
761        let alert =
762            unsafe { s2n_connection_get_alert(self.connection.as_ptr()).into_result() }.ok()?;
763        Some(alert as u8)
764    }
765
766    /// Sets the server name value for the connection
767    ///
768    /// Corresponds to [s2n_set_server_name].
769    pub fn set_server_name(&mut self, server_name: &str) -> Result<&mut Self, Error> {
770        let server_name = std::ffi::CString::new(server_name).map_err(|_| Error::INVALID_INPUT)?;
771        unsafe {
772            s2n_set_server_name(self.connection.as_ptr(), server_name.as_ptr()).into_result()
773        }?;
774        Ok(self)
775    }
776
777    /// Get the server name associated with the connection client hello.
778    ///
779    /// Corresponds to [s2n_get_server_name].
780    pub fn server_name(&self) -> Option<&str> {
781        unsafe {
782            let server_name = s2n_get_server_name(self.connection.as_ptr());
783            match server_name.into_result() {
784                Ok(server_name) => CStr::from_ptr(server_name).to_str().ok(),
785                Err(_) => None,
786            }
787        }
788    }
789
790    /// Adds a session ticket from a previous TLS connection to create a resumed session
791    ///
792    /// Corresponds to [s2n_connection_set_session].
793    pub fn set_session_ticket(&mut self, session: &[u8]) -> Result<&mut Self, Error> {
794        unsafe {
795            s2n_connection_set_session(self.connection.as_ptr(), session.as_ptr(), session.len())
796                .into_result()
797        }?;
798        Ok(self)
799    }
800
801    /// Retrieves the size of the session ticket.
802    ///
803    /// Corresponds to [s2n_connection_get_session_length].
804    pub fn session_ticket_length(&self) -> Result<usize, Error> {
805        let len =
806            unsafe { s2n_connection_get_session_length(self.connection.as_ptr()).into_result()? };
807        Ok(len.try_into().unwrap())
808    }
809
810    /// Serializes the session state from the connection into `output` and returns
811    /// the length of the session ticket.
812    ///
813    /// If the buffer does not have the size for the session_ticket,
814    /// `Error::INVALID_INPUT` is returned.
815    ///
816    /// Note: This function is not recommended for > TLS1.2 because in TLS1.3
817    /// servers can send multiple session tickets and this will return only
818    /// the most recently received ticket.
819    ///
820    /// Corresponds to [s2n_connection_get_session].
821    pub fn session_ticket(&self, output: &mut [u8]) -> Result<usize, Error> {
822        if output.len() < self.session_ticket_length()? {
823            return Err(Error::INVALID_INPUT);
824        }
825        let written = unsafe {
826            s2n_connection_get_session(self.connection.as_ptr(), output.as_mut_ptr(), output.len())
827                .into_result()?
828        };
829        Ok(written.try_into().unwrap())
830    }
831
832    /// Sets a Waker on the connection context or clears it if `None` is passed.
833    pub fn set_waker(&mut self, waker: Option<&Waker>) -> Result<&mut Self, Error> {
834        let ctx = self.context_mut();
835
836        if let Some(waker) = waker {
837            if let Some(prev_waker) = ctx.waker.as_mut() {
838                // only replace the Waker if they don't reference the same task
839                if !prev_waker.will_wake(waker) {
840                    prev_waker.clone_from(waker);
841                }
842            } else {
843                ctx.waker = Some(waker.clone());
844            }
845        } else {
846            ctx.waker = None;
847        }
848        Ok(self)
849    }
850
851    /// Returns the Waker set on the connection context.
852    pub fn waker(&self) -> Option<&Waker> {
853        let ctx = self.context();
854        ctx.waker.as_ref()
855    }
856
857    /// Takes the [`Option::take`] the connection_future stored on the
858    /// connection context.
859    ///
860    /// If the Future returns `Poll::Pending` and has not completed, then it
861    /// should be re-set using [`Self::set_connection_future()`]
862    fn take_async_callback(&mut self) -> Option<AsyncCallback> {
863        let ctx = self.context_mut();
864        ctx.async_callback.take()
865    }
866
867    /// Sets a `connection_future` on the connection context.
868    pub(crate) fn set_async_callback(&mut self, callback: AsyncCallback) {
869        let ctx = self.context_mut();
870        debug_assert!(ctx.async_callback.is_none());
871        ctx.async_callback = Some(callback);
872    }
873
874    /// Retrieve a mutable reference to the [`Context`] stored on the connection.
875    fn context_mut(&mut self) -> &mut Context {
876        unsafe {
877            let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
878                .into_result()
879                .unwrap();
880            &mut *(ctx.as_ptr() as *mut Context)
881        }
882    }
883
884    /// Retrieve a reference to the [`Context`] stored on the connection.
885    fn context(&self) -> &Context {
886        unsafe {
887            let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
888                .into_result()
889                .unwrap();
890            &*(ctx.as_ptr() as *mut Context)
891        }
892    }
893
894    /// Drop the context
895    ///
896    /// SAFETY:
897    /// A connection without a context is invalid. After calling this method
898    /// from anywhere other than Drop, you must reinitialize the context.
899    unsafe fn drop_context(&mut self) -> Result<(), Error> {
900        let ctx = s2n_connection_get_ctx(self.connection.as_ptr()).into_result();
901        if let Ok(ctx) = ctx {
902            drop(Box::from_raw(ctx.as_ptr() as *mut Context));
903        }
904        // Setting a NULL context is important: if we don't also remove the context
905        // from the connection, then the invalid memory is still accessible and
906        // may even be double-freed.
907        s2n_connection_set_ctx(self.connection.as_ptr(), core::ptr::null_mut()).into_result()?;
908        Ok(())
909    }
910
911    /// Mark that the server_name extension was used to configure the connection.
912    ///
913    /// Corresponds to [s2n_connection_server_name_extension_used].
914    pub fn server_name_extension_used(&mut self) {
915        // TODO: requiring the application to call this method is a pretty sharp edge.
916        // Figure out if its possible to automatically call this from the Rust bindings.
917        unsafe {
918            s2n_connection_server_name_extension_used(self.connection.as_ptr())
919                .into_result()
920                .unwrap();
921        }
922    }
923
924    /// Check if client auth was used for a connection.
925    ///
926    /// This is only relevant if [`ClientAuthType::Optional] was used.
927    ///
928    /// Corresponds to [s2n_connection_client_cert_used].
929    pub fn client_cert_used(&self) -> bool {
930        unsafe { s2n_connection_client_cert_used(self.connection.as_ptr()) == 1 }
931    }
932
933    /// Retrieves the raw bytes of the client cert chain received from the peer, if present.
934    ///
935    /// Corresponds to [s2n_connection_get_client_cert_chain].
936    pub fn client_cert_chain_bytes(&self) -> Result<Option<&[u8]>, Error> {
937        if !self.client_cert_used() {
938            return Ok(None);
939        }
940
941        let mut chain = std::ptr::null_mut();
942        let mut len = 0;
943        unsafe {
944            s2n_connection_get_client_cert_chain(self.connection.as_ptr(), &mut chain, &mut len)
945                .into_result()?;
946        }
947
948        if chain.is_null() || len == 0 {
949            return Ok(None);
950        }
951
952        unsafe { Ok(Some(std::slice::from_raw_parts(chain, len as usize))) }
953    }
954
955    // The memory backing the ClientHello is owned by the Connection, so we
956    // tie the ClientHello to the lifetime of the Connection. This is validated
957    // with a doc test that ensures the ClientHello is invalid once the
958    // connection has gone out of scope.
959    //
960    /// Returns a reference to the ClientHello associated with the connection.
961    /// ```compile_fail
962    /// use s2n_tls::client_hello::ClientHello;
963    /// use s2n_tls::connection::Connection;
964    /// use s2n_tls::enums::Mode;
965    ///
966    /// let mut conn = Connection::new(Mode::Server);
967    /// let mut client_hello: &ClientHello = conn.client_hello().unwrap();
968    /// drop(conn);
969    /// client_hello.raw_message();
970    /// ```
971    ///
972    /// The compilation could be failing for a variety of reasons, so make sure
973    /// that the test case is actually good.
974    /// ```no_run
975    /// use s2n_tls::client_hello::ClientHello;
976    /// use s2n_tls::connection::Connection;
977    /// use s2n_tls::enums::Mode;
978    ///
979    /// let mut conn = Connection::new(Mode::Server);
980    /// let mut client_hello: &ClientHello = conn.client_hello().unwrap();
981    /// client_hello.raw_message();
982    /// drop(conn);
983    /// ```
984    ///
985    /// Corresponds to [s2n_connection_get_client_hello].
986    pub fn client_hello(&self) -> Result<&crate::client_hello::ClientHello, Error> {
987        let mut handle =
988            unsafe { s2n_connection_get_client_hello(self.connection.as_ptr()).into_result()? };
989        Ok(crate::client_hello::ClientHello::from_ptr(unsafe {
990            handle.as_mut()
991        }))
992    }
993
994    /// Corresponds to [s2n_client_hello_cb_done].
995    pub(crate) fn mark_client_hello_cb_done(&mut self) -> Result<(), Error> {
996        unsafe {
997            s2n_client_hello_cb_done(self.connection.as_ptr()).into_result()?;
998        }
999        Ok(())
1000    }
1001
1002    /// Access the protocol version selected for the connection.
1003    ///
1004    /// Corresponds to [s2n_connection_get_actual_protocol_version].
1005    pub fn actual_protocol_version(&self) -> Result<Version, Error> {
1006        let version = unsafe {
1007            s2n_connection_get_actual_protocol_version(self.connection.as_ptr()).into_result()?
1008        };
1009        version.try_into()
1010    }
1011
1012    /// Detects if the client hello is using the SSLv2 format.
1013    ///
1014    /// s2n-tls will not negotiate SSLv2, but will accept SSLv2 ClientHellos
1015    /// advertising a higher protocol version like SSLv3 or TLS1.0.
1016    /// [Connection::actual_protocol_version()] can be used to retrieve the
1017    /// protocol version that is actually used on the connection.
1018    ///
1019    /// Corresponds to [s2n_connection_get_client_hello_version], but only checks
1020    /// for SSLv2.
1021    pub fn client_hello_is_sslv2(&self) -> Result<bool, Error> {
1022        let version = unsafe {
1023            s2n_connection_get_client_hello_version(self.connection.as_ptr()).into_result()?
1024        };
1025        let version: Version = version.try_into()?;
1026        Ok(version == Version::SSLV2)
1027    }
1028
1029    /// Corresponds to [s2n_connection_get_handshake_type_name].
1030    pub fn handshake_type(&self) -> Result<&str, Error> {
1031        let handshake = unsafe {
1032            s2n_connection_get_handshake_type_name(self.connection.as_ptr()).into_result()?
1033        };
1034        unsafe {
1035            // SAFETY: Constructed strings have a null byte appended to them.
1036            // SAFETY: The data has a 'static lifetime, because it resides in a
1037            //         static char array, and is never modified after its initial
1038            //         creation.
1039            const_str!(handshake)
1040        }
1041    }
1042
1043    /// Corresponds to [s2n_connection_get_cipher].
1044    pub fn cipher_suite(&self) -> Result<&str, Error> {
1045        let cipher = unsafe { s2n_connection_get_cipher(self.connection.as_ptr()).into_result()? };
1046        unsafe {
1047            // SAFETY: The data is null terminated because it is declared as a C
1048            //         string literal.
1049            // SAFETY: cipher has a static lifetime because it lives on s2n_cipher_suite,
1050            //         a static struct.
1051            const_str!(cipher)
1052        }
1053    }
1054
1055    /// Corresponds to [s2n_connection_get_kem_name].
1056    #[deprecated = "PQ TLS 1.2 KEM Names are no longer supported. Use kem_group_name() to retrieve PQ TLS 1.3 Group name."]
1057    pub fn kem_name(&self) -> Option<&str> {
1058        let name_bytes = {
1059            let name = unsafe { s2n_connection_get_kem_name(self.connection.as_ptr()) };
1060            if name.is_null() {
1061                return None;
1062            }
1063            name
1064        };
1065
1066        let name_str = unsafe {
1067            // SAFETY: The data is null terminated because it is declared as a C
1068            //         string literal.
1069            // SAFETY: kem_name has a static lifetime because it lives on a const
1070            //         struct s2n_kem with file scope.
1071            const_str!(name_bytes)
1072        };
1073
1074        match name_str {
1075            Ok("NONE") => None,
1076            Ok(name) => Some(name),
1077            Err(_) => {
1078                // Unreachable: This would indicate a non-utf-8 string literal in
1079                // the s2n-tls C codebase.
1080                None
1081            }
1082        }
1083    }
1084
1085    /// Corresponds to [s2n_connection_get_kem_group_name].
1086    pub fn kem_group_name(&self) -> Option<&str> {
1087        let name_bytes = {
1088            let name = unsafe { s2n_connection_get_kem_group_name(self.connection.as_ptr()) };
1089            if name.is_null() {
1090                return None;
1091            }
1092            name
1093        };
1094
1095        let name_str = unsafe {
1096            // SAFETY: The data is null terminated because it is declared as a C
1097            //         string literal.
1098            // SAFETY: kem_name has a static lifetime because it lives on a const
1099            //         struct s2n_kem with file scope.
1100            const_str!(name_bytes)
1101        };
1102
1103        match name_str {
1104            Ok("NONE") => None,
1105            Ok(name) => Some(name),
1106            Err(_) => {
1107                // Unreachable: This would indicate a non-utf-8 string literal in
1108                // the s2n-tls C codebase.
1109                None
1110            }
1111        }
1112    }
1113
1114    /// Corresponds to [s2n_connection_get_curve].
1115    pub fn selected_curve(&self) -> Result<&str, Error> {
1116        let curve = unsafe { s2n_connection_get_curve(self.connection.as_ptr()).into_result()? };
1117        unsafe {
1118            // SAFETY: The data is null terminated because it is declared as a C
1119            //         string literal.
1120            // SAFETY: curve has a static lifetime because it lives on s2n_ecc_named_curve,
1121            //         which is a static const struct.
1122            const_str!(curve)
1123        }
1124    }
1125
1126    /// Corresponds to [s2n_connection_get_selected_signature_algorithm].
1127    pub fn selected_signature_algorithm(&self) -> Result<SignatureAlgorithm, Error> {
1128        let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1129        unsafe {
1130            s2n_connection_get_selected_signature_algorithm(self.connection.as_ptr(), &mut sig_alg)
1131                .into_result()?;
1132        }
1133        sig_alg.try_into()
1134    }
1135
1136    /// Corresponds to [s2n_connection_get_selected_digest_algorithm].
1137    pub fn selected_hash_algorithm(&self) -> Result<HashAlgorithm, Error> {
1138        let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1139        unsafe {
1140            s2n_connection_get_selected_digest_algorithm(self.connection.as_ptr(), &mut hash_alg)
1141                .into_result()?;
1142        }
1143        hash_alg.try_into()
1144    }
1145
1146    /// Corresponds to [s2n_connection_get_selected_client_cert_signature_algorithm].
1147    pub fn selected_client_signature_algorithm(&self) -> Result<Option<SignatureAlgorithm>, Error> {
1148        let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1149        unsafe {
1150            s2n_connection_get_selected_client_cert_signature_algorithm(
1151                self.connection.as_ptr(),
1152                &mut sig_alg,
1153            )
1154            .into_result()?;
1155        }
1156        Ok(match sig_alg {
1157            s2n_tls_signature_algorithm::ANONYMOUS => None,
1158            sig_alg => Some(sig_alg.try_into()?),
1159        })
1160    }
1161
1162    /// Corresponds to [s2n_connection_get_selected_client_cert_digest_algorithm].
1163    pub fn selected_client_hash_algorithm(&self) -> Result<Option<HashAlgorithm>, Error> {
1164        let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1165        unsafe {
1166            s2n_connection_get_selected_client_cert_digest_algorithm(
1167                self.connection.as_ptr(),
1168                &mut hash_alg,
1169            )
1170            .into_result()?;
1171        }
1172        Ok(match hash_alg {
1173            s2n_tls_hash_algorithm::NONE => None,
1174            hash_alg => Some(hash_alg.try_into()?),
1175        })
1176    }
1177
1178    /// Corresponds to [s2n_get_application_protocol].
1179    pub fn application_protocol(&self) -> Option<&[u8]> {
1180        let protocol = unsafe { s2n_get_application_protocol(self.connection.as_ptr()) };
1181        if protocol.is_null() {
1182            return None;
1183        }
1184        Some(unsafe { CStr::from_ptr(protocol).to_bytes() })
1185    }
1186
1187    /// Provides access to the TLS-Exporter functionality.
1188    ///
1189    /// See https://datatracker.ietf.org/doc/html/rfc5705 and https://www.rfc-editor.org/rfc/rfc8446.
1190    ///
1191    /// This is currently only available with TLS 1.3 connections which have finished a handshake.
1192    ///
1193    /// Corresponds to [s2n_connection_tls_exporter].
1194    pub fn tls_exporter(
1195        &self,
1196        label: &[u8],
1197        context: &[u8],
1198        output: &mut [u8],
1199    ) -> Result<(), Error> {
1200        unsafe {
1201            s2n_connection_tls_exporter(
1202                self.connection.as_ptr(),
1203                label.as_ptr(),
1204                label.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1205                context.as_ptr(),
1206                context.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1207                output.as_mut_ptr(),
1208                output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1209            )
1210            .into_result()
1211            .map(|_| ())
1212        }
1213    }
1214
1215    /// Returns the validated peer certificate chain.
1216    // 'static lifetime is because this copies the certificate chain from the connection into a new
1217    // chain, so the lifetime is independent of the connection.
1218    ///
1219    /// Corresponds to [s2n_connection_get_peer_cert_chain].
1220    pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
1221        unsafe {
1222            let chain_handle = CertificateChainHandle::allocate()?;
1223            s2n_connection_get_peer_cert_chain(
1224                self.connection.as_ptr(),
1225                chain_handle.cert.as_ptr(),
1226            )
1227            .into_result()
1228            .map(|_| ())?;
1229            Ok(CertificateChain::from_allocated(chain_handle))
1230        }
1231    }
1232
    /// Get the certificate used during the TLS handshake
    ///
    /// - If `self` is a server connection, the certificate selected will depend on the
    ///   ServerName sent by the client and supported ciphers.
    /// - If `self` is a client connection, the certificate sent in response to a CertificateRequest
    ///   message is returned. Currently s2n-tls supports loading only one certificate in client mode. Note that
    ///   not all TLS endpoints will request a certificate.
    ///
    /// Returns `None` when s2n-tls returns a null certificate pointer. The returned
    /// chain is bound to the lifetime of `self`.
    ///
    /// Corresponds to [s2n_connection_get_selected_cert].
    pub fn selected_cert(&self) -> Option<CertificateChain<'_>> {
        unsafe {
            // The API only returns null, no error is actually set.
            // Clippy doesn't realize from_ptr_reference is unsafe.
            #[allow(clippy::manual_map)]
            if let Some(ptr) =
                NonNull::new(s2n_connection_get_selected_cert(self.connection.as_ptr()))
            {
                // Wrap the non-null pointer as a chain borrowed from the connection.
                Some(CertificateChain::from_ptr_reference(ptr))
            } else {
                None
            }
        }
    }
1256
1257    /// Corresponds to [s2n_connection_get_master_secret].
1258    pub fn master_secret(&self) -> Result<Vec<u8>, Error> {
1259        // TLS1.2 master secrets are always 48 bytes
1260        let mut secret = vec![0; 48];
1261        unsafe {
1262            s2n_connection_get_master_secret(
1263                self.connection.as_ptr(),
1264                secret.as_mut_ptr(),
1265                secret.len(),
1266            )
1267            .into_result()?;
1268        }
1269        Ok(secret)
1270    }
1271
1272    /// Retrieves the size of the serialized connection
1273    ///
1274    /// Corresponds to [s2n_connection_serialization_length].
1275    pub fn serialization_length(&self) -> Result<usize, Error> {
1276        unsafe {
1277            let mut length = 0;
1278            s2n_connection_serialization_length(self.connection.as_ptr(), &mut length)
1279                .into_result()?;
1280            Ok(length.try_into().unwrap())
1281        }
1282    }
1283
1284    /// Serializes the TLS connection into the provided buffer
1285    ///
1286    /// Corresponds to [s2n_connection_serialize].
1287    pub fn serialize(&self, output: &mut [u8]) -> Result<(), Error> {
1288        unsafe {
1289            s2n_connection_serialize(
1290                self.connection.as_ptr(),
1291                output.as_mut_ptr(),
1292                output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1293            )
1294            .into_result()?;
1295            Ok(())
1296        }
1297    }
1298
1299    /// Deserializes the input buffer into a new TLS connection that can send/recv
1300    /// data from the original peer.
1301    ///
1302    /// Corresponds to [s2n_connection_deserialize].
1303    pub fn deserialize(&mut self, input: &[u8]) -> Result<(), Error> {
1304        let size = input.len();
1305        /* This is not ideal, we know that s2n_connection_deserialize will not mutate the
1306         * input value, however, the mut is needed to use the stuffer functions. */
1307        let input = input.as_ptr() as *mut u8;
1308        unsafe {
1309            s2n_connection_deserialize(
1310                self.as_ptr(),
1311                input,
1312                size.try_into().map_err(|_| Error::INVALID_INPUT)?,
1313            )
1314            .into_result()?;
1315            Ok(())
1316        }
1317    }
1318
1319    /// Determines whether the connection was resumed from an earlier handshake.
1320    ///
1321    /// Corresponds to [s2n_connection_is_session_resumed].
1322    pub fn resumed(&self) -> bool {
1323        unsafe { s2n_connection_is_session_resumed(self.connection.as_ptr()) == 1 }
1324    }
1325
1326    /// Append an external psk to a connection.
1327    ///
1328    /// This may be called repeatedly to support multiple PSKs.
1329    ///
1330    /// Corresponds to [s2n_connection_append_psk].
1331    pub fn append_psk(&mut self, psk: &Psk) -> Result<(), Error> {
1332        unsafe {
1333            // SAFETY: retrieving a *mut s2n_psk from &Psk: s2n-tls does not treat
1334            // the pointer as mutable, and only holds the reference to copy the
1335            // PSK onto the connection.
1336            s2n_connection_append_psk(self.as_ptr(), psk.ptr.as_ptr()).into_result()?
1337        };
1338        Ok(())
1339    }
1340
1341    /// Corresponds to [s2n_connection_get_negotiated_psk_identity_length].
1342    pub fn negotiated_psk_identity_length(&self) -> Result<usize, Error> {
1343        let mut length = 0;
1344        unsafe {
1345            s2n_connection_get_negotiated_psk_identity_length(self.connection.as_ptr(), &mut length)
1346                .into_result()?
1347        };
1348        Ok(length as usize)
1349    }
1350
1351    /// Retrieve the negotiated psk identity. Use [Connection::negotiated_psk_identity_length]
1352    /// to retrieve the length of the psk identity.
1353    ///
1354    /// Corresponds to [s2n_connection_get_negotiated_psk_identity].
1355    pub fn negotiated_psk_identity(&self, destination: &mut [u8]) -> Result<(), Error> {
1356        unsafe {
1357            s2n_connection_get_negotiated_psk_identity(
1358                self.connection.as_ptr(),
1359                destination.as_mut_ptr(),
1360                destination.len().min(u16::MAX as usize) as u16,
1361            )
1362            .into_result()?;
1363        }
1364        Ok(())
1365    }
1366
1367    /// Associates an arbitrary application context with the Connection to be later retrieved via
1368    /// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
1369    ///
1370    /// This API will override an existing application context set on the Connection.
1371    ///
1372    /// Corresponds to [s2n_connection_set_ctx].
1373    pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
1374        self.context_mut().app_context = Some(Box::new(app_context));
1375    }
1376
1377    /// Retrieves a reference to the application context associated with the Connection.
1378    ///
1379    /// If an application context hasn't already been set on the Connection, or if the set
1380    /// application context isn't of type T, None will be returned.
1381    ///
1382    /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
1383    /// mutable reference to the context, use [`Self::application_context_mut()`].
1384    ///
1385    /// Corresponds to [s2n_connection_get_ctx].
1386    pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
1387        match self.context().app_context.as_ref() {
1388            None => None,
1389            // The Any trait keeps track of the application context's type. downcast_ref() returns
1390            // Some only if the correct type is provided:
1391            // https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
1392            Some(app_context) => app_context.downcast_ref::<T>(),
1393        }
1394    }
1395
1396    /// Retrieves a mutable reference to the application context associated with the Connection.
1397    ///
1398    /// If an application context hasn't already been set on the Connection, or if the set
1399    /// application context isn't of type T, None will be returned.
1400    ///
1401    /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
1402    /// immutable reference to the context, use [`Self::application_context()`].
1403    ///
1404    /// Corresponds to [s2n_connection_get_ctx].
1405    pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
1406        match self.context_mut().app_context.as_mut() {
1407            None => None,
1408            Some(app_context) => app_context.downcast_mut::<T>(),
1409        }
1410    }
1411
1412    #[cfg(feature = "unstable-renegotiate")]
1413    pub(crate) fn renegotiate_state_mut(&mut self) -> &mut RenegotiateState {
1414        &mut self.context_mut().renegotiate_state
1415    }
1416
1417    #[cfg(feature = "unstable-renegotiate")]
1418    pub(crate) fn renegotiate_state(&self) -> &RenegotiateState {
1419        &self.context().renegotiate_state
1420    }
1421}
1422
/// Rust-side state attached to a connection through the s2n-tls connection
/// context pointer, and accessed via `Connection::context()` /
/// `Connection::context_mut()`.
///
/// Must be `Send` and `Sync`; the unit tests below assert both.
struct Context {
    /// The mode passed to `Context::new`.
    mode: Mode,
    /// Waker stored by `Connection::set_waker`, if any.
    waker: Option<Waker>,
    /// Pending async callback, stored by `set_async_callback` and removed by
    /// `take_async_callback`.
    async_callback: Option<AsyncCallback>,
    /// Optional host name verification callback (set outside this view).
    verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
    /// Starts as `false` in `Context::new`.
    /// NOTE(review): updated outside this view; confirm its exact meaning.
    connection_initialized: bool,
    /// Arbitrary application data set via `Connection::set_application_context`.
    app_context: Option<Box<dyn Any + Send + Sync>>,
    #[cfg(feature = "unstable-renegotiate")]
    pub(crate) renegotiate_state: RenegotiateState,
}
1433
1434impl Context {
1435    fn new(mode: Mode) -> Self {
1436        Context {
1437            mode,
1438            waker: None,
1439            async_callback: None,
1440            verify_host_callback: None,
1441            connection_initialized: false,
1442            app_context: None,
1443            #[cfg(feature = "unstable-renegotiate")]
1444            renegotiate_state: RenegotiateState::default(),
1445        }
1446    }
1447}
1448
#[cfg(feature = "quic")]
impl Connection {
    /// Enables QUIC support on the connection.
    ///
    /// Corresponds to [s2n_connection_enable_quic].
    pub fn enable_quic(&mut self) -> Result<&mut Self, Error> {
        unsafe { s2n_connection_enable_quic(self.connection.as_ptr()).into_result() }?;
        Ok(self)
    }

    /// Sets the local QUIC transport parameters to send to the peer.
    ///
    /// Returns `Error::INVALID_INPUT` if `buffer` is too long for the C API's
    /// length type.
    ///
    /// Corresponds to [s2n_connection_set_quic_transport_parameters].
    pub fn set_quic_transport_parameters(&mut self, buffer: &[u8]) -> Result<&mut Self, Error> {
        unsafe {
            s2n_connection_set_quic_transport_parameters(
                self.connection.as_ptr(),
                buffer.as_ptr(),
                buffer.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
            )
            .into_result()
        }?;
        Ok(self)
    }

    /// Returns the QUIC transport parameters reported by s2n-tls for the peer.
    ///
    /// If s2n-tls reports a null buffer or a zero length (for example, when no
    /// parameters are available), an empty slice is returned.
    ///
    /// Corresponds to [s2n_connection_get_quic_transport_parameters].
    pub fn quic_transport_parameters(&mut self) -> Result<&[u8], Error> {
        let mut ptr = core::ptr::null();
        let mut len = 0;
        unsafe {
            s2n_connection_get_quic_transport_parameters(
                self.connection.as_ptr(),
                &mut ptr,
                &mut len,
            )
            .into_result()
        }?;
        if ptr.is_null() || len == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for a
            // zero-length slice, so a null buffer must never be passed to it.
            return Ok(&[]);
        }
        // SAFETY: `ptr` is non-null and s2n-tls reported `len` readable bytes.
        // The slice borrows `self`; the buffer is presumably owned by the connection.
        let buffer = unsafe { core::slice::from_raw_parts(ptr, len as _) };
        Ok(buffer)
    }

    /// Registers a callback that s2n-tls invokes with connection secrets.
    ///
    /// # Safety
    ///
    /// The `context` pointer must live at least as long as the connection
    ///
    /// Corresponds to [s2n_connection_set_secret_callback].
    pub unsafe fn set_secret_callback(
        &mut self,
        callback: s2n_secret_cb,
        context: *mut c_void,
    ) -> Result<&mut Self, Error> {
        s2n_connection_set_secret_callback(self.connection.as_ptr(), callback, context)
            .into_result()?;
        Ok(self)
    }

    /// Processes a post-handshake QUIC message.
    ///
    /// The blocked status reported by s2n-tls is not surfaced to the caller.
    ///
    /// Corresponds to [s2n_recv_quic_post_handshake_message].
    pub fn quic_process_post_handshake_message(&mut self) -> Result<&mut Self, Error> {
        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
        unsafe {
            s2n_recv_quic_post_handshake_message(self.connection.as_ptr(), &mut blocked)
                .into_result()
        }?;
        Ok(self)
    }

    /// Allows the quic library to check if session tickets are expected
    ///
    /// Corresponds to [s2n_connection_are_session_tickets_enabled].
    pub fn are_session_tickets_enabled(&self) -> bool {
        unsafe { s2n_connection_are_session_tickets_enabled(self.connection.as_ptr()) }
    }
}
1518
1519impl AsRef<Connection> for Connection {
1520    fn as_ref(&self) -> &Connection {
1521        self
1522    }
1523}
1524
1525impl AsMut<Connection> for Connection {
1526    fn as_mut(&mut self) -> &mut Connection {
1527        self
1528    }
1529}
1530
impl Drop for Connection {
    /// Corresponds to [s2n_connection_free].
    ///
    /// Releases resources in a fixed order: the Rust-side context first, then
    /// the config, and finally the underlying s2n connection.
    fn drop(&mut self) {
        // ignore failures since there's not much we can do about it
        unsafe {
            // clean up context
            // (drop_context frees the boxed Context and nulls the ctx pointer)
            let _ = self.drop_context();

            // cleanup config
            // NOTE(review): drop_config is defined outside this view; presumably it
            // releases the config associated with this connection.
            let _ = self.drop_config();

            // cleanup connection
            // Freed last, since the steps above still operate on this connection.
            let _ = s2n_connection_free(self.connection.as_ptr()).into_result();
        }
    }
}
1547
1548#[cfg(test)]
1549mod tests {
1550    use super::*;
1551
1552    // ensure the connection context is send
1553    #[test]
1554    fn context_send_test() {
1555        fn assert_send<T: 'static + Send>() {}
1556        assert_send::<Context>();
1557    }
1558
1559    // ensure the connection context is sync
1560    #[test]
1561    fn context_sync_test() {
1562        fn assert_sync<T: 'static + Sync>() {}
1563        assert_sync::<Context>();
1564    }
1565
1566    /// Test that an application context can be set and retrieved.
1567    #[test]
1568    fn test_app_context_set_and_retrieve() {
1569        let mut connection = Connection::new_server();
1570
1571        // Before a context is set, None is returned.
1572        assert!(connection.application_context::<u32>().is_none());
1573
1574        let test_value: u32 = 1142;
1575        connection.set_application_context(test_value);
1576
1577        // After a context is set, the application data is returned.
1578        assert_eq!(*connection.application_context::<u32>().unwrap(), 1142);
1579    }
1580
1581    /// Test that an application context can be modified.
1582    #[test]
1583    fn test_app_context_modify() {
1584        let test_value: u64 = 0;
1585
1586        let mut connection = Connection::new_server();
1587        connection.set_application_context(test_value);
1588
1589        let context_value = connection.application_context_mut::<u64>().unwrap();
1590        *context_value += 1;
1591
1592        assert_eq!(*connection.application_context::<u64>().unwrap(), 1);
1593    }
1594
1595    /// Test that an application context can be overridden.
1596    #[test]
1597    fn test_app_context_override() {
1598        let mut connection = Connection::new_server();
1599
1600        let test_value: u16 = 1142;
1601        connection.set_application_context(test_value);
1602
1603        assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);
1604
1605        // Override the context with a new value.
1606        let test_value: u16 = 10;
1607        connection.set_application_context(test_value);
1608
1609        assert_eq!(*connection.application_context::<u16>().unwrap(), 10);
1610
1611        // Override the context with a new type.
1612        let test_value: i16 = -20;
1613        connection.set_application_context(test_value);
1614
1615        assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
1616    }
1617
1618    /// Test that a context of another type can't be retrieved.
1619    #[test]
1620    fn test_app_context_invalid_type() {
1621        let mut connection = Connection::new_server();
1622
1623        let test_value: u32 = 0;
1624        connection.set_application_context(test_value);
1625
1626        // A context type that wasn't set shouldn't be returned.
1627        assert!(connection.application_context::<i16>().is_none());
1628
1629        // Retrieving the correct type succeeds.
1630        assert!(connection.application_context::<u32>().is_some());
1631    }
1632}