Skip to main content

tpm2_protocol/basic/
buffer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5use crate::{
6    TpmCast, TpmCastMut, TpmMarshal, TpmProtocolError, TpmResult, TpmSized, TpmUnmarshal,
7    TpmWriter, basic::TpmUint16,
8};
9use core::{
10    convert::TryFrom,
11    fmt::Debug,
12    hash::{Hash, Hasher},
13    mem::{MaybeUninit, size_of},
14    ops::Deref,
15    slice,
16};
17
18const TPM2B_SIZE_LEN: usize = size_of::<TpmUint16>();
19
/// A zero-copy TPM2B wire view over caller-owned bytes.
///
/// The wrapped slice is the complete wire encoding: a big-endian `u16` size
/// header followed by the payload. Views created through [`Tpm2b::cast`] or
/// [`Tpm2b::cast_mut`] cover exactly one TPM2B whose payload is at most
/// `CAPACITY` bytes. The `*_unchecked` constructors leave upholding that to
/// the caller. Rewriting the header through [`Tpm2b::as_bytes_mut`] can also
/// break it.
///
/// Being `repr(transparent)` over `[u8]`, a `&Tpm2b` has the same layout and
/// pointer metadata as the `&[u8]` it was cast from.
#[repr(transparent)]
pub struct Tpm2b<const CAPACITY: usize>([u8]);
23
24impl<const CAPACITY: usize> Tpm2b<CAPACITY> {
25    /// Casts a byte slice into a TPM2B wire view.
26    ///
27    /// # Errors
28    ///
29    /// Returns [`UnexpectedEnd`](crate::TpmProtocolError::UnexpectedEnd) when
30    /// `buf` is shorter than the TPM2B header or declared payload size.
31    /// Returns [`TrailingData`](crate::TpmProtocolError::TrailingData) when
32    /// `buf` contains bytes after the declared payload.
33    /// Returns [`TooManyBytes`](crate::TpmProtocolError::TooManyBytes) when
34    /// the declared payload exceeds `CAPACITY`.
35    pub fn cast(buf: &[u8]) -> TpmResult<&Self> {
36        Self::validate(buf)?;
37
38        // SAFETY: `validate` checked the complete TPM2B byte range and size
39        // limit for this transparent wire view.
40        Ok(unsafe { Self::cast_unchecked(buf) })
41    }
42
43    /// Casts a byte slice into a TPM2B wire view without validation.
44    ///
45    /// # Safety
46    ///
47    /// The caller must ensure that `buf` contains exactly one complete TPM2B
48    /// value and that its declared payload length does not exceed `CAPACITY`.
49    #[must_use]
50    pub unsafe fn cast_unchecked(buf: &[u8]) -> &Self {
51        let ptr = core::ptr::from_ref(buf) as *const Self;
52
53        // SAFETY: `Tpm2b` is `repr(transparent)` over `[u8]`, so it has the
54        // same layout, metadata, and alignment as the referenced slice.
55        unsafe { &*ptr }
56    }
57
58    /// Casts a mutable byte slice into a mutable TPM2B wire view.
59    ///
60    /// # Errors
61    ///
62    /// Returns [`UnexpectedEnd`](crate::TpmProtocolError::UnexpectedEnd) when
63    /// `buf` is shorter than the TPM2B header or declared payload size.
64    /// Returns [`TrailingData`](crate::TpmProtocolError::TrailingData) when
65    /// `buf` contains bytes after the declared payload.
66    /// Returns [`TooManyBytes`](crate::TpmProtocolError::TooManyBytes) when
67    /// the declared payload exceeds `CAPACITY`.
68    pub fn cast_mut(buf: &mut [u8]) -> TpmResult<&mut Self> {
69        Self::validate(buf)?;
70
71        // SAFETY: `validate` checked the complete TPM2B byte range and size
72        // limit for this transparent wire view. The `&mut` input provides
73        // exclusive access.
74        Ok(unsafe { Self::cast_mut_unchecked(buf) })
75    }
76
77    /// Casts a mutable byte slice into a mutable TPM2B wire view without validation.
78    ///
79    /// # Safety
80    ///
81    /// The caller must ensure that `buf` contains exactly one complete TPM2B
82    /// value and that its declared payload length does not exceed `CAPACITY`.
83    /// The returned reference inherits the exclusive access represented by
84    /// `buf`.
85    #[must_use]
86    pub unsafe fn cast_mut_unchecked(buf: &mut [u8]) -> &mut Self {
87        let ptr = core::ptr::from_mut(buf) as *mut Self;
88
89        // SAFETY: `Tpm2b` is `repr(transparent)` over `[u8]`, so it has the
90        // same layout, metadata, and alignment as the referenced slice.
91        unsafe { &mut *ptr }
92    }
93
94    /// Returns the complete TPM2B byte representation.
95    #[must_use]
96    pub const fn as_bytes(&self) -> &[u8] {
97        &self.0
98    }
99
100    /// Returns the complete mutable TPM2B byte representation.
101    #[must_use]
102    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
103        &mut self.0
104    }
105
106    /// Returns the declared payload size.
107    #[must_use]
108    pub fn size(&self) -> usize {
109        Self::read_size(&self.0)
110    }
111
112    /// Returns the payload bytes.
113    #[must_use]
114    pub fn data(&self) -> &[u8] {
115        &self.0[TPM2B_SIZE_LEN..]
116    }
117
118    /// Returns the mutable payload bytes.
119    #[must_use]
120    pub fn data_mut(&mut self) -> &mut [u8] {
121        &mut self.0[TPM2B_SIZE_LEN..]
122    }
123
124    /// Returns the complete TPM2B wire length.
125    #[must_use]
126    pub const fn len(&self) -> usize {
127        self.0.len()
128    }
129
130    /// Returns `true` when the TPM2B payload is empty.
131    #[must_use]
132    pub fn is_empty(&self) -> bool {
133        self.size() == 0
134    }
135
136    fn validate(buf: &[u8]) -> TpmResult<()> {
137        if buf.len() < TPM2B_SIZE_LEN {
138            return Err(TpmProtocolError::UnexpectedEnd);
139        }
140
141        let payload_len = Self::read_size(buf);
142        if payload_len > CAPACITY {
143            return Err(TpmProtocolError::TooManyBytes);
144        }
145
146        let wire_len = TPM2B_SIZE_LEN
147            .checked_add(payload_len)
148            .ok_or(TpmProtocolError::IntegerTooLarge)?;
149
150        if buf.len() < wire_len {
151            return Err(TpmProtocolError::UnexpectedEnd);
152        }
153
154        if buf.len() > wire_len {
155            return Err(TpmProtocolError::TrailingData);
156        }
157
158        Ok(())
159    }
160
161    fn read_size(buf: &[u8]) -> usize {
162        usize::from(u16::from_be_bytes([buf[0], buf[1]]))
163    }
164}
165
166impl<const CAPACITY: usize> TpmCast for Tpm2b<CAPACITY> {
167    fn cast(buf: &[u8]) -> TpmResult<&Self> {
168        Self::cast(buf)
169    }
170
171    unsafe fn cast_unchecked(buf: &[u8]) -> &Self {
172        // SAFETY: The caller upholds the unchecked cast contract for `Tpm2b`.
173        unsafe { Self::cast_unchecked(buf) }
174    }
175}
176
177impl<const CAPACITY: usize> TpmCastMut for Tpm2b<CAPACITY> {
178    fn cast_mut(buf: &mut [u8]) -> TpmResult<&mut Self> {
179        Self::cast_mut(buf)
180    }
181
182    unsafe fn cast_mut_unchecked(buf: &mut [u8]) -> &mut Self {
183        // SAFETY: The caller upholds the unchecked mutable cast contract for
184        // `Tpm2b`.
185        unsafe { Self::cast_mut_unchecked(buf) }
186    }
187}
188
189impl<const CAPACITY: usize> AsRef<[u8]> for Tpm2b<CAPACITY> {
190    fn as_ref(&self) -> &[u8] {
191        self.as_bytes()
192    }
193}
194
195impl<const CAPACITY: usize> AsMut<[u8]> for Tpm2b<CAPACITY> {
196    fn as_mut(&mut self) -> &mut [u8] {
197        self.as_bytes_mut()
198    }
199}
200
/// A fixed-capacity, owned buffer that marshals to and from the TPM2B wire
/// format.
///
/// The `size` field is stored in native endian and converted to big-endian
/// only during marshaling.
///
/// Invariant: `size` never exceeds `CAPACITY` (nor `u16::MAX`), and the first
/// `size` elements of `data` are initialized. The `Deref` implementation reads
/// those elements as `u8` and relies on this, so only the inherent mutation
/// methods should change `size`.
#[derive(Clone, Copy)]
pub struct TpmBuffer<const CAPACITY: usize> {
    /// Number of initialized payload bytes at the front of `data`.
    size: u16,
    /// Backing storage; only `data[..size]` is initialized.
    data: [MaybeUninit<u8>; CAPACITY],
}
210
211impl<const CAPACITY: usize> TpmBuffer<CAPACITY> {
212    /// Creates a new, empty `TpmBuffer`.
213    #[must_use]
214    pub const fn new() -> Self {
215        Self {
216            size: 0,
217            data: [const { MaybeUninit::uninit() }; CAPACITY],
218        }
219    }
220
221    /// Appends a byte to the buffer.
222    ///
223    /// # Errors
224    ///
225    /// Returns [`BufferOverflow`](crate::TpmProtocolError::BufferOverflow) when the
226    /// buffer is full or the size exceeds `u16::MAX`.
227    pub fn try_push(&mut self, byte: u8) -> TpmResult<()> {
228        if (self.size as usize) >= CAPACITY || self.size == u16::MAX {
229            return Err(TpmProtocolError::BufferOverflow);
230        }
231        self.data[self.size as usize].write(byte);
232        self.size += 1;
233        Ok(())
234    }
235
236    /// Appends a slice of bytes to the buffer.
237    ///
238    /// # Errors
239    ///
240    /// Returns [`BufferOverflow`](crate::TpmProtocolError::BufferOverflow) when the
241    /// resulting size exceeds the buffer capacity or `u16::MAX`.
242    pub fn try_extend_from_slice(&mut self, slice: &[u8]) -> TpmResult<()> {
243        let current_len = self.size as usize;
244        let new_len = current_len
245            .checked_add(slice.len())
246            .ok_or(TpmProtocolError::BufferOverflow)?;
247
248        if new_len > CAPACITY {
249            return Err(TpmProtocolError::BufferOverflow);
250        }
251
252        self.size = u16::try_from(new_len).map_err(|_| TpmProtocolError::BufferOverflow)?;
253
254        for (dest, src) in self.data[current_len..new_len].iter_mut().zip(slice) {
255            dest.write(*src);
256        }
257        Ok(())
258    }
259}
260
261impl<const CAPACITY: usize> Deref for TpmBuffer<CAPACITY> {
262    type Target = [u8];
263
264    fn deref(&self) -> &Self::Target {
265        let size = self.size as usize;
266
267        // SAFETY: The first `size` bytes are initialized by the mutation APIs,
268        // and `MaybeUninit<u8>` has the same layout as `u8`.
269        unsafe { slice::from_raw_parts(self.data.as_ptr().cast::<u8>(), size) }
270    }
271}
272
273impl<const CAPACITY: usize> Default for TpmBuffer<CAPACITY> {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
279impl<const CAPACITY: usize> PartialEq for TpmBuffer<CAPACITY> {
280    fn eq(&self, other: &Self) -> bool {
281        **self == **other
282    }
283}
284
285impl<const CAPACITY: usize> Eq for TpmBuffer<CAPACITY> {}
286
287impl<const CAPACITY: usize> Hash for TpmBuffer<CAPACITY> {
288    fn hash<H: Hasher>(&self, state: &mut H) {
289        (**self).hash(state);
290    }
291}
292
293impl<const CAPACITY: usize> TpmSized for TpmBuffer<CAPACITY> {
294    const SIZE: usize = size_of::<TpmUint16>() + CAPACITY;
295    fn len(&self) -> usize {
296        size_of::<TpmUint16>() + self.size as usize
297    }
298}
299
300impl<const CAPACITY: usize> TpmMarshal for TpmBuffer<CAPACITY> {
301    fn marshal(&self, writer: &mut TpmWriter) -> TpmResult<()> {
302        TpmUint16::from(self.size).marshal(writer)?;
303        writer.write_bytes(self)
304    }
305}
306
307impl<const CAPACITY: usize> TpmUnmarshal for TpmBuffer<CAPACITY> {
308    fn unmarshal(buf: &[u8]) -> TpmResult<(Self, &[u8])> {
309        let (native_size, remainder) = TpmUint16::unmarshal(buf)?;
310        let size_usize = u16::from(native_size) as usize;
311
312        if size_usize > CAPACITY {
313            return Err(TpmProtocolError::TooManyBytes);
314        }
315
316        if remainder.len() < size_usize {
317            return Err(TpmProtocolError::UnexpectedEnd);
318        }
319
320        let mut buffer = Self::new();
321        buffer.try_extend_from_slice(&remainder[..size_usize])?;
322        Ok((buffer, &remainder[size_usize..]))
323    }
324}
325
326impl<const CAPACITY: usize> TryFrom<&[u8]> for TpmBuffer<CAPACITY> {
327    type Error = TpmProtocolError;
328
329    fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
330        if slice.len() > CAPACITY {
331            return Err(TpmProtocolError::TooManyBytes);
332        }
333        let mut buffer = Self::new();
334        buffer.try_extend_from_slice(slice)?;
335        Ok(buffer)
336    }
337}
338
339impl<const CAPACITY: usize> AsRef<[u8]> for TpmBuffer<CAPACITY> {
340    fn as_ref(&self) -> &[u8] {
341        self
342    }
343}
344
345impl<const CAPACITY: usize> Debug for TpmBuffer<CAPACITY> {
346    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
347        write!(f, "TpmBuffer(")?;
348        for byte in self.iter() {
349            write!(f, "{byte:02X}")?;
350        }
351        write!(f, ")")
352    }
353}