tpm2_protocol/basic/
buffer.rs1use crate::{
6 basic::TpmUint16, TpmMarshal, TpmProtocolError, TpmResult, TpmSized, TpmUnmarshal, TpmWriter,
7};
8use core::{
9 convert::TryFrom,
10 fmt::Debug,
11 hash::{Hash, Hasher},
12 mem::{size_of, MaybeUninit},
13 ops::Deref,
14 slice,
15};
16
/// A fixed-capacity, stack-allocated byte buffer with a `u16` length.
///
/// Storage is `[MaybeUninit<u8>; CAPACITY]`, so creating an empty buffer does not
/// touch the backing bytes. Invariant: `size <= CAPACITY`, and every element of
/// `data[..size]` has been initialized.
#[derive(Copy, Clone)]
pub struct TpmBuffer<const CAPACITY: usize> {
    /// Count of leading `data` elements that are initialized and logically present.
    size: u16,
    /// Backing storage; only the prefix `data[..size]` is guaranteed initialized.
    data: [MaybeUninit<u8>; CAPACITY],
}
26
27impl<const CAPACITY: usize> TpmBuffer<CAPACITY> {
28 #[must_use]
30 pub const fn new() -> Self {
31 Self {
32 size: 0,
33 data: [const { MaybeUninit::uninit() }; CAPACITY],
34 }
35 }
36
37 pub fn try_push(&mut self, byte: u8) -> TpmResult<()> {
44 if (self.size as usize) >= CAPACITY || self.size == u16::MAX {
45 return Err(TpmProtocolError::BufferOverflow);
46 }
47 self.data[self.size as usize].write(byte);
48 self.size += 1;
49 Ok(())
50 }
51
52 pub fn try_extend_from_slice(&mut self, slice: &[u8]) -> TpmResult<()> {
59 let current_len = self.size as usize;
60 let new_len = current_len
61 .checked_add(slice.len())
62 .ok_or(TpmProtocolError::BufferOverflow)?;
63
64 if new_len > CAPACITY {
65 return Err(TpmProtocolError::BufferOverflow);
66 }
67
68 self.size = u16::try_from(new_len).map_err(|_| TpmProtocolError::BufferOverflow)?;
69
70 for (dest, src) in self.data[current_len..new_len].iter_mut().zip(slice) {
71 dest.write(*src);
72 }
73 Ok(())
74 }
75}
76
77#[allow(unsafe_code)]
78impl<const CAPACITY: usize> Deref for TpmBuffer<CAPACITY> {
79 type Target = [u8];
80
81 fn deref(&self) -> &Self::Target {
89 let size = self.size as usize;
90 unsafe { slice::from_raw_parts(self.data.as_ptr().cast::<u8>(), size) }
91 }
92}
93
94impl<const CAPACITY: usize> Default for TpmBuffer<CAPACITY> {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl<const CAPACITY: usize> PartialEq for TpmBuffer<CAPACITY> {
101 fn eq(&self, other: &Self) -> bool {
102 **self == **other
103 }
104}
105
106impl<const CAPACITY: usize> Eq for TpmBuffer<CAPACITY> {}
107
108impl<const CAPACITY: usize> Hash for TpmBuffer<CAPACITY> {
109 fn hash<H: Hasher>(&self, state: &mut H) {
110 (**self).hash(state);
111 }
112}
113
114impl<const CAPACITY: usize> TpmSized for TpmBuffer<CAPACITY> {
115 const SIZE: usize = size_of::<TpmUint16>() + CAPACITY;
116 fn len(&self) -> usize {
117 size_of::<TpmUint16>() + self.size as usize
118 }
119}
120
121impl<const CAPACITY: usize> TpmMarshal for TpmBuffer<CAPACITY> {
122 fn marshal(&self, writer: &mut TpmWriter) -> TpmResult<()> {
123 TpmUint16::from(self.size).marshal(writer)?;
124 writer.write_bytes(self)
125 }
126}
127
128impl<const CAPACITY: usize> TpmUnmarshal for TpmBuffer<CAPACITY> {
129 fn unmarshal(buf: &[u8]) -> TpmResult<(Self, &[u8])> {
130 let (native_size, remainder) = TpmUint16::unmarshal(buf)?;
131 let size_usize = u16::from(native_size) as usize;
132
133 if size_usize > CAPACITY {
134 return Err(TpmProtocolError::TooManyBytes);
135 }
136
137 if remainder.len() < size_usize {
138 return Err(TpmProtocolError::UnexpectedEnd);
139 }
140
141 let mut buffer = Self::new();
142 buffer.try_extend_from_slice(&remainder[..size_usize])?;
143 Ok((buffer, &remainder[size_usize..]))
144 }
145}
146
147impl<'a, const CAPACITY: usize> TryFrom<&'a [u8]> for TpmBuffer<CAPACITY> {
148 type Error = TpmProtocolError;
149
150 fn try_from(slice: &'a [u8]) -> Result<Self, Self::Error> {
151 if slice.len() > CAPACITY {
152 return Err(TpmProtocolError::TooManyBytes);
153 }
154 let mut buffer = Self::new();
155 buffer.try_extend_from_slice(slice)?;
156 Ok(buffer)
157 }
158}
159
160impl<const CAPACITY: usize> AsRef<[u8]> for TpmBuffer<CAPACITY> {
161 fn as_ref(&self) -> &[u8] {
162 self
163 }
164}
165
166impl<const CAPACITY: usize> Debug for TpmBuffer<CAPACITY> {
167 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
168 write!(f, "TpmBuffer(")?;
169 for byte in self.iter() {
170 write!(f, "{byte:02X}")?;
171 }
172 write!(f, ")")
173 }
174}