1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::packet::number::PacketNumberLen;
use s2n_codec::{CheckedRange, DecoderBuffer, DecoderBufferMut, DecoderError};

/// Type which restricts access to protected and encrypted payloads.
///
/// The `ProtectedPayload` is an `EncryptedPayload` that has had
/// header protection applied. So to get to the cleartext payload,
/// first you remove header protection, and then you decrypt the packet.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProtectedPayload<'a> {
    /// Number of bytes at the start of `buffer` that are skipped to reach the
    /// packet number field when taking the header protection sample.
    pub(crate) header_len: usize,
    /// The packet bytes. `len()` reports the full length of this buffer,
    /// header included.
    pub(crate) buffer: DecoderBufferMut<'a>,
}

impl<'a> core::fmt::Debug for ProtectedPayload<'a> {
    /// Formats the payload for diagnostics.
    ///
    /// In the default (`{:?}`) mode only `header_len` and the number of bytes
    /// following the header (`buffer_len`) are printed. In the alternate,
    /// pretty-printing (`{:#?}`) mode the full `buffer` is printed instead.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
        // Since the protected payload is not very helpful for debugging purposes,
        // we just print the length of the protected payload as long as we are not in
        // pretty-printing mode.
        // Snapshot tests use the pretty-printing mode, therefore we can't change the Debug behavior
        // for those.
        let print_buffer_content = f.alternate();

        let mut debug_struct = f.debug_struct("ProtectedPayload");
        debug_struct.field("header_len", &self.header_len);

        if print_buffer_content {
            debug_struct.field("buffer", &self.buffer);
        } else {
            // `header_len <= buffer.len()` is only enforced by a `debug_assert!` in `new`,
            // so saturate rather than risk an arithmetic underflow while formatting.
            let payload_len = self.buffer.len().saturating_sub(self.header_len);
            debug_struct.field("buffer_len", &payload_len);
        }

        debug_struct.finish()
    }
}

impl<'a> ProtectedPayload<'a> {
    /// Wraps `buffer` as a protected payload with the given `header_len`.
    ///
    /// Debug builds assert that `header_len` does not exceed `buffer.len()`.
    pub fn new(header_len: usize, buffer: &'a mut [u8]) -> Self {
        debug_assert!(header_len <= buffer.len(), "header_len is too large");

        let buffer = DecoderBufferMut::new(buffer);
        Self {
            header_len,
            buffer,
        }
    }

    /// Returns the `DecoderBuffer` the underlying buffer yields for `range`
    pub fn get_checked_range(&self, range: &CheckedRange) -> DecoderBuffer {
        let Self { buffer, .. } = self;
        buffer.get_checked_range(range)
    }

    /// Returns the `sample_len`-byte header protection sample of this payload.
    ///
    /// Delegates to the module-level `header_protection_sample` using a
    /// peeked view of the buffer, and forwards any `DecoderError` it reports.
    pub(crate) fn header_protection_sample(
        &self,
        sample_len: usize,
    ) -> Result<&[u8], DecoderError> {
        let view = self.buffer.peek();
        header_protection_sample(view, self.header_len, sample_len)
    }

    /// Returns the total number of bytes in the payload, header included
    pub fn len(&self) -> usize {
        let Self { buffer, .. } = self;
        buffer.len()
    }

    /// Returns `true` when the underlying buffer is empty
    pub fn is_empty(&self) -> bool {
        let Self { buffer, .. } = self;
        buffer.is_empty()
    }
}

/// Type which restricts access to encrypted payloads
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EncryptedPayload<'a> {
    /// Number of bytes at the start of `buffer` that precede the packet
    /// number field.
    pub(crate) header_len: usize,
    /// Length of the packet number field. `split_mut` treats the first
    /// `header_len + packet_number_len.bytesize()` bytes as the header.
    pub(crate) packet_number_len: PacketNumberLen,
    /// The packet bytes, starting at the tag byte returned by `get_tag`.
    pub(crate) buffer: DecoderBufferMut<'a>,
}

impl<'a> core::fmt::Debug for EncryptedPayload<'a> {
    /// Formats the payload for diagnostics.
    ///
    /// In the default (`{:?}`) mode only `header_len`, `packet_number_len` and
    /// the number of bytes following `header_len` (`buffer_len`) are printed.
    /// In the alternate, pretty-printing (`{:#?}`) mode the full `buffer` is
    /// printed instead.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
        // Since the encrypted payload is not very helpful for debugging purposes,
        // we just print the length of the encrypted payload as long as we are not in
        // pretty-printing mode.
        // Snapshot tests use the pretty-printing mode, therefore we can't change the Debug behavior
        // for those.
        let print_buffer_content = f.alternate();

        let mut debug_struct = f.debug_struct("EncryptedPayload");
        debug_struct
            .field("header_len", &self.header_len)
            .field("packet_number_len", &self.packet_number_len);

        if print_buffer_content {
            debug_struct.field("buffer", &self.buffer);
        } else {
            // The length invariant is only enforced by a `debug_assert!` in `new`,
            // so saturate rather than risk an arithmetic underflow while formatting.
            let payload_len = self.buffer.len().saturating_sub(self.header_len);
            debug_struct.field("buffer_len", &payload_len);
        }

        debug_struct.finish()
    }
}

impl<'a> EncryptedPayload<'a> {
    /// Wraps `buffer` as an encrypted payload with the given `header_len`
    /// and `packet_number_len`.
    ///
    /// Debug builds assert that `buffer` is at least
    /// `header_len + packet_number_len.bytesize()` bytes long.
    pub(crate) fn new(
        header_len: usize,
        packet_number_len: PacketNumberLen,
        buffer: &'a mut [u8],
    ) -> Self {
        debug_assert!(
            header_len + packet_number_len.bytesize() <= buffer.len(),
            "header_len is too large"
        );

        let buffer = DecoderBufferMut::new(buffer);
        Self {
            header_len,
            packet_number_len,
            buffer,
        }
    }

    /// Returns the packet tag, i.e. the first byte of the payload
    ///
    /// # Panics
    ///
    /// Panics if the underlying buffer is empty.
    pub fn get_tag(&self) -> u8 {
        let bytes = self.buffer.as_less_safe_slice();
        bytes[0]
    }

    /// Returns the `DecoderBuffer` the underlying buffer yields for `range`
    pub fn get_checked_range(&self, range: &CheckedRange) -> DecoderBuffer {
        let Self { buffer, .. } = self;
        buffer.get_checked_range(range)
    }

    /// Consumes the payload and splits it into `(header, payload)` slices.
    ///
    /// The header slice covers the first
    /// `header_len + packet_number_len.bytesize()` bytes; the payload slice
    /// is everything after it.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is shorter than the header slice.
    pub(crate) fn split_mut(self) -> (&'a mut [u8], &'a mut [u8]) {
        let split_at = self.header_len + self.packet_number_len.bytesize();
        let (header, payload) = self
            .buffer
            .decode_slice(split_at)
            .expect("header_len already checked");

        let header = header.into_less_safe_slice();
        let payload = payload.into_less_safe_slice();
        (header, payload)
    }

    /// Returns the `sample_len`-byte header protection sample of this payload.
    ///
    /// Delegates to the module-level `header_protection_sample` using a
    /// peeked view of the buffer, and forwards any `DecoderError` it reports.
    pub(crate) fn header_protection_sample(
        &self,
        sample_len: usize,
    ) -> Result<&[u8], DecoderError> {
        let view = self.buffer.peek();
        header_protection_sample(view, self.header_len, sample_len)
    }
}

/// Extracts the `sample_len`-byte ciphertext sample used for header protection.
///
/// The sample begins after skipping `header_len` bytes and then a further
/// `PacketNumberLen::MAX_LEN` bytes of `buffer`.
///
/// # Errors
///
/// Returns a `DecoderError` if `buffer` is too short for either skip or for
/// the sample itself.
fn header_protection_sample(
    buffer: DecoderBuffer,
    header_len: usize,
    sample_len: usize,
) -> Result<&[u8], DecoderError> {
    // Step over everything that precedes the packet number field.
    let after_header = buffer.skip(header_len)?;

    //= https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
    //# in sampling packet ciphertext for header protection, the Packet Number field is
    //# assumed to be 4 bytes long
    let sample_region = after_header.skip(PacketNumberLen::MAX_LEN)?;

    //= https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
    //# An endpoint MUST discard packets that are not long enough to contain
    //# a complete sample.
    sample_region
        .decode_slice(sample_len)
        .map(|(sample, _remaining)| sample.into_less_safe_slice())
}