tpm2_protocol/basic/buffer.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
// Copyright (c) 2025 Opinsys Oy
// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5use crate::{
6    basic::TpmUint16, TpmMarshal, TpmProtocolError, TpmResult, TpmSized, TpmUnmarshal, TpmWriter,
7};
8use core::{
9    convert::TryFrom,
10    fmt::Debug,
11    hash::{Hash, Hasher},
12    mem::{size_of, MaybeUninit},
13    ops::Deref,
14    slice,
15};
16
/// A buffer in the TPM2B wire format.
///
/// The `size` field is stored in native endian and converted to big-endian
/// only during marshaling.
///
/// # Invariants
///
/// - `size as usize <= CAPACITY`.
/// - The first `size` elements of `data` are initialized. The `Deref`
///   implementation relies on this to expose them as `&[u8]` without copying.
///   The remaining elements may be uninitialized and must never be read.
#[derive(Clone, Copy)]
pub struct TpmBuffer<const CAPACITY: usize> {
    /// Number of initialized bytes at the start of `data`.
    size: u16,
    /// Backing storage; only `data[..size]` is initialized.
    data: [MaybeUninit<u8>; CAPACITY],
}
26
27impl<const CAPACITY: usize> TpmBuffer<CAPACITY> {
28    /// Creates a new, empty `TpmBuffer`.
29    #[must_use]
30    pub const fn new() -> Self {
31        Self {
32            size: 0,
33            data: [const { MaybeUninit::uninit() }; CAPACITY],
34        }
35    }
36
37    /// Appends a byte to the buffer.
38    ///
39    /// # Errors
40    ///
41    /// Returns [`OutOfMemory`](crate::TpmProtocolError::OutOfMemory) when the
42    /// buffer is full or the size exceeds `u16::MAX`.
43    pub fn try_push(&mut self, byte: u8) -> TpmResult<()> {
44        if (self.size as usize) >= CAPACITY || self.size == u16::MAX {
45            return Err(TpmProtocolError::BufferOverflow);
46        }
47        self.data[self.size as usize].write(byte);
48        self.size += 1;
49        Ok(())
50    }
51
52    /// Appends a slice of bytes to the buffer.
53    ///
54    /// # Errors
55    ///
56    /// Returns [`OutOfMemory`](crate::TpmProtocolError::OutOfMemory) when the
57    /// resulting size exceeds the buffer capacity or `u16::MAX`.
58    pub fn try_extend_from_slice(&mut self, slice: &[u8]) -> TpmResult<()> {
59        let current_len = self.size as usize;
60        let new_len = current_len
61            .checked_add(slice.len())
62            .ok_or(TpmProtocolError::BufferOverflow)?;
63
64        if new_len > CAPACITY {
65            return Err(TpmProtocolError::BufferOverflow);
66        }
67
68        self.size = u16::try_from(new_len).map_err(|_| TpmProtocolError::BufferOverflow)?;
69
70        for (dest, src) in self.data[current_len..new_len].iter_mut().zip(slice) {
71            dest.write(*src);
72        }
73        Ok(())
74    }
75}
76
77#[allow(unsafe_code)]
78impl<const CAPACITY: usize> Deref for TpmBuffer<CAPACITY> {
79    type Target = [u8];
80
81    /// # Safety
82    ///
83    /// This implementation uses `unsafe` to provide a view into the initialized
84    /// portion of the buffer. The caller can rely on this being safe because:
85    /// 1. The first `self.size` bytes are guaranteed to be initialized by the
86    ///    `try_push` and `try_extend_from_slice` methods.
87    /// 2. `MaybeUninit<u8>` is guaranteed to have the same memory layout as `u8`.
88    fn deref(&self) -> &Self::Target {
89        let size = self.size as usize;
90        unsafe { slice::from_raw_parts(self.data.as_ptr().cast::<u8>(), size) }
91    }
92}
93
94impl<const CAPACITY: usize> Default for TpmBuffer<CAPACITY> {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl<const CAPACITY: usize> PartialEq for TpmBuffer<CAPACITY> {
101    fn eq(&self, other: &Self) -> bool {
102        **self == **other
103    }
104}
105
106impl<const CAPACITY: usize> Eq for TpmBuffer<CAPACITY> {}
107
108impl<const CAPACITY: usize> Hash for TpmBuffer<CAPACITY> {
109    fn hash<H: Hasher>(&self, state: &mut H) {
110        (**self).hash(state);
111    }
112}
113
114impl<const CAPACITY: usize> TpmSized for TpmBuffer<CAPACITY> {
115    const SIZE: usize = size_of::<TpmUint16>() + CAPACITY;
116    fn len(&self) -> usize {
117        size_of::<TpmUint16>() + self.size as usize
118    }
119}
120
121impl<const CAPACITY: usize> TpmMarshal for TpmBuffer<CAPACITY> {
122    fn marshal(&self, writer: &mut TpmWriter) -> TpmResult<()> {
123        TpmUint16::from(self.size).marshal(writer)?;
124        writer.write_bytes(self)
125    }
126}
127
128impl<const CAPACITY: usize> TpmUnmarshal for TpmBuffer<CAPACITY> {
129    fn unmarshal(buf: &[u8]) -> TpmResult<(Self, &[u8])> {
130        let (native_size, remainder) = TpmUint16::unmarshal(buf)?;
131        let size_usize = u16::from(native_size) as usize;
132
133        if size_usize > CAPACITY {
134            return Err(TpmProtocolError::TooManyBytes);
135        }
136
137        if remainder.len() < size_usize {
138            return Err(TpmProtocolError::UnexpectedEnd);
139        }
140
141        let mut buffer = Self::new();
142        buffer.try_extend_from_slice(&remainder[..size_usize])?;
143        Ok((buffer, &remainder[size_usize..]))
144    }
145}
146
147impl<'a, const CAPACITY: usize> TryFrom<&'a [u8]> for TpmBuffer<CAPACITY> {
148    type Error = TpmProtocolError;
149
150    fn try_from(slice: &'a [u8]) -> Result<Self, Self::Error> {
151        if slice.len() > CAPACITY {
152            return Err(TpmProtocolError::TooManyBytes);
153        }
154        let mut buffer = Self::new();
155        buffer.try_extend_from_slice(slice)?;
156        Ok(buffer)
157    }
158}
159
160impl<const CAPACITY: usize> AsRef<[u8]> for TpmBuffer<CAPACITY> {
161    fn as_ref(&self) -> &[u8] {
162        self
163    }
164}
165
166impl<const CAPACITY: usize> Debug for TpmBuffer<CAPACITY> {
167    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
168        write!(f, "TpmBuffer(")?;
169        for byte in self.iter() {
170            write!(f, "{byte:02X}")?;
171        }
172        write!(f, ")")
173    }
174}