s2n_quic_core/crypto/
payload.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::packet::number::PacketNumberLen;
5use s2n_codec::{CheckedRange, DecoderBuffer, DecoderBufferMut, DecoderError};
6
7/// Type which restricts access to protected and encrypted payloads.
8///
9/// The `ProtectedPayload` is an `EncryptedPayload` that has had
10/// header protection applied. So to get to the cleartext payload,
11/// first you remove header protection, and then you decrypt the packet
12#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
13pub struct ProtectedPayload<'a> {
14    pub(crate) header_len: usize,
15    pub(crate) buffer: DecoderBufferMut<'a>,
16}
17
18impl core::fmt::Debug for ProtectedPayload<'_> {
19    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
20        // Since the protected payload is not very helpful for debugging purposes,
21        // we just print the length of the protected payload as long as we are not in
22        // pretty-printing mode.
23        // Snapshot tests use the pretty-printing mode, therefore we can't change the Debug behavior
24        // for those.
25        let print_buffer_content = f.alternate();
26
27        let mut debug_struct = f.debug_struct("ProtectedPayload");
28        let mut debug_struct = debug_struct.field("header_len", &self.header_len);
29
30        if !print_buffer_content {
31            debug_struct = debug_struct.field("buffer_len", &(self.buffer.len() - self.header_len))
32        } else {
33            debug_struct = debug_struct.field("buffer", &self.buffer)
34        }
35        debug_struct.finish()
36    }
37}
38
39impl<'a> ProtectedPayload<'a> {
40    /// Creates a new protected payload with a header_len
41    pub fn new(header_len: usize, buffer: &'a mut [u8]) -> Self {
42        debug_assert!(buffer.len() >= header_len, "header_len is too large");
43
44        Self {
45            header_len,
46            buffer: DecoderBufferMut::new(buffer),
47        }
48    }
49
50    /// Reads data from a `CheckedRange`
51    pub fn get_checked_range(&self, range: &CheckedRange) -> DecoderBuffer<'_> {
52        self.buffer.get_checked_range(range)
53    }
54
55    pub(crate) fn header_protection_sample(
56        &self,
57        sample_len: usize,
58    ) -> Result<&[u8], DecoderError> {
59        header_protection_sample(self.buffer.peek(), self.header_len, sample_len)
60    }
61
62    /// Returns the length of the payload, including the header
63    pub fn len(&self) -> usize {
64        self.buffer.len()
65    }
66
67    /// Returns `true` if the payload is empty
68    pub fn is_empty(&self) -> bool {
69        self.buffer.is_empty()
70    }
71}
72
73/// Type which restricts access to encrypted payloads
74#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
75pub struct EncryptedPayload<'a> {
76    pub(crate) header_len: usize,
77    pub(crate) packet_number_len: PacketNumberLen,
78    pub(crate) buffer: DecoderBufferMut<'a>,
79}
80
81impl core::fmt::Debug for EncryptedPayload<'_> {
82    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
83        // Since the protected payload is not very helpful for debugging purposes,
84        // we just print the length of the protected payload as long as we are not in
85        // pretty-printing mode.
86        // Snapshot tests use the pretty-printing mode, therefore we can't change the Debug behavior
87        // for those.
88        let print_buffer_content = f.alternate();
89
90        let mut debug_struct = f.debug_struct("EncryptedPayload");
91        let mut debug_struct = debug_struct
92            .field("header_len", &self.header_len)
93            .field("packet_number_len", &self.packet_number_len);
94
95        if !print_buffer_content {
96            debug_struct = debug_struct.field("buffer_len", &(self.buffer.len() - self.header_len))
97        } else {
98            debug_struct = debug_struct.field("buffer", &self.buffer)
99        }
100        debug_struct.finish()
101    }
102}
103
104impl<'a> EncryptedPayload<'a> {
105    pub(crate) fn new(
106        header_len: usize,
107        packet_number_len: PacketNumberLen,
108        buffer: &'a mut [u8],
109    ) -> Self {
110        debug_assert!(
111            buffer.len() >= header_len + packet_number_len.bytesize(),
112            "header_len is too large"
113        );
114
115        Self {
116            header_len,
117            packet_number_len,
118            buffer: DecoderBufferMut::new(buffer),
119        }
120    }
121
122    /// Reads the packet tag in the payload
123    pub fn get_tag(&self) -> u8 {
124        self.buffer.as_less_safe_slice()[0]
125    }
126
127    /// Reads data from a `CheckedRange`
128    pub fn get_checked_range(&self, range: &CheckedRange) -> DecoderBuffer<'_> {
129        self.buffer.get_checked_range(range)
130    }
131
132    pub(crate) fn split_mut(self) -> (&'a mut [u8], &'a mut [u8]) {
133        let (header, payload) = self
134            .buffer
135            .decode_slice(self.header_len + self.packet_number_len.bytesize())
136            .expect("header_len already checked");
137        (
138            header.into_less_safe_slice(),
139            payload.into_less_safe_slice(),
140        )
141    }
142
143    pub(crate) fn header_protection_sample(
144        &self,
145        sample_len: usize,
146    ) -> Result<&[u8], DecoderError> {
147        header_protection_sample(self.buffer.peek(), self.header_len, sample_len)
148    }
149}
150
151fn header_protection_sample(
152    buffer: DecoderBuffer<'_>,
153    header_len: usize,
154    sample_len: usize,
155) -> Result<&[u8], DecoderError> {
156    let buffer = buffer.skip(header_len)?;
157
158    //= https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
159    //# in sampling packet ciphertext for header protection, the Packet Number field is
160    //# assumed to be 4 bytes long
161    let buffer = buffer.skip(PacketNumberLen::MAX_LEN)?;
162
163    //= https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
164    //# An endpoint MUST discard packets that are not long enough to contain
165    //# a complete sample.
166    let (sample, _) = buffer.decode_slice(sample_len)?;
167
168    Ok(sample.into_less_safe_slice())
169}