1use crate::{
8 crypto::{crypto_digest, CryptoError},
9 device::{Device, DeviceError},
10 key::from_str_to_alg_id,
11};
12use std::{convert::TryFrom, fmt, str::FromStr};
13use thiserror::Error;
14use tpm2_protocol::{
15 constant::TPM_PCR_SELECT_MAX,
16 data::{TpmAlgId, TpmCap, TpmCc, TpmlPcrSelection, TpmsPcrSelection, TpmuCapabilities},
17 message::TpmPcrReadCommand,
18 tpm_hash_size, TpmBuffer, TpmErrorKind,
19};
20
21#[derive(Debug, Error)]
22pub enum PcrError {
23 #[error("device: {0}")]
24 Device(#[from] DeviceError),
25 #[error("invalid algorithm: {0:?}")]
26 InvalidAlgorithm(TpmAlgId),
27 #[error("invalid PCR selection: {0}")]
28 InvalidPcrSelection(String),
29 #[error("TPM: {0}")]
30 Tpm(TpmErrorKind),
31 #[error("crypto: {0}")]
32 Crypto(#[from] CryptoError),
33}
34
35impl From<TpmErrorKind> for PcrError {
36 fn from(err: TpmErrorKind) -> Self {
37 Self::Tpm(err)
38 }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct Pcr {
44 pub bank: TpmAlgId,
45 pub index: u32,
46 pub value: Vec<u8>,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct PcrBank {
52 pub alg: TpmAlgId,
53 pub count: usize,
54}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct PcrSelection {
59 pub alg: TpmAlgId,
60 pub indices: Vec<u32>,
61}
62
63impl fmt::Display for PcrSelection {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 let indices_str = self
66 .indices
67 .iter()
68 .map(ToString::to_string)
69 .collect::<Vec<_>>()
70 .join(",");
71 write!(f, "{}:{}", crate::key::Tpm2shAlgId(self.alg), indices_str)
72 }
73}
74
75impl FromStr for PcrSelection {
76 type Err = PcrError;
77
78 fn from_str(s: &str) -> Result<Self, Self::Err> {
79 let (alg_str, indices_str) = s
80 .split_once(':')
81 .ok_or_else(|| PcrError::InvalidPcrSelection(format!("invalid bank format: '{s}'")))?;
82 let alg = from_str_to_alg_id(alg_str)
83 .map_err(|e| PcrError::InvalidPcrSelection(e.to_string()))?;
84 let indices: Vec<u32> = indices_str
85 .split(',')
86 .map(str::parse)
87 .collect::<Result<_, _>>()
88 .map_err(|e: std::num::ParseIntError| PcrError::InvalidPcrSelection(e.to_string()))?;
89 Ok(PcrSelection { alg, indices })
90 }
91}
92
93pub fn pcr_get_bank_list(device: &mut Device) -> Result<Vec<PcrBank>, PcrError> {
100 let (_, cap_data) = device.get_capability(TpmCap::Pcrs, 0, 1)?;
101 let mut banks = Vec::new();
102 if let TpmuCapabilities::Pcrs(pcrs) = cap_data.data {
103 for bank in pcrs.iter() {
104 banks.push(PcrBank {
105 alg: bank.hash,
106 count: bank.pcr_select.len() * 8,
107 });
108 }
109 }
110 if banks.is_empty() {
111 return Err(PcrError::InvalidPcrSelection(
112 "TPM reported no active PCR banks.".to_string(),
113 ));
114 }
115 banks.sort_by_key(|b| b.alg);
116 Ok(banks)
117}
118
119pub fn pcr_selection_vec_from_str(selection_str: &str) -> Result<Vec<PcrSelection>, PcrError> {
127 selection_str
128 .split('+')
129 .map(PcrSelection::from_str)
130 .collect()
131}
132
133pub fn parse_pcr_policy_string(
139 policy_str: &str,
140) -> Result<(Vec<PcrSelection>, Option<String>), PcrError> {
141 let (selection_part, digest_part) =
142 if let Some((selection, digest)) = policy_str.rsplit_once(':') {
143 let is_digest = !digest.is_empty()
144 && digest.len() >= tpm_hash_size(&TpmAlgId::Sha1).unwrap_or(20) * 2
145 && digest.chars().all(|c| c.is_ascii_hexdigit());
146
147 if is_digest {
148 (selection, Some(digest.to_string()))
149 } else {
150 (policy_str, None)
151 }
152 } else {
153 (policy_str, None)
154 };
155
156 let selections = pcr_selection_vec_from_str(selection_part)?;
157 Ok((selections, digest_part))
158}
159
160pub fn pcr_selection_vec_to_tpml(
168 selections: &[PcrSelection],
169 banks: &[PcrBank],
170) -> Result<TpmlPcrSelection, PcrError> {
171 let mut list = TpmlPcrSelection::new();
172 for selection in selections {
173 let bank = banks
174 .iter()
175 .find(|b| b.alg == selection.alg)
176 .ok_or_else(|| {
177 PcrError::InvalidPcrSelection(format!(
178 "PCR bank for algorithm {:?} not found or supported by TPM",
179 selection.alg
180 ))
181 })?;
182 let pcr_select_size = bank.count.div_ceil(8);
183 if pcr_select_size > TPM_PCR_SELECT_MAX {
184 return Err(PcrError::InvalidPcrSelection(format!(
185 "invalid select size {pcr_select_size} (> {TPM_PCR_SELECT_MAX})"
186 )));
187 }
188 let mut pcr_select_bytes = vec![0u8; pcr_select_size];
189 for &pcr_index in &selection.indices {
190 let pcr_index = pcr_index as usize;
191 if pcr_index >= bank.count {
192 return Err(PcrError::InvalidPcrSelection(format!(
193 "invalid index {pcr_index} for {:?} bank (max is {})",
194 bank.alg,
195 bank.count - 1
196 )));
197 }
198 pcr_select_bytes[pcr_index / 8] |= 1 << (pcr_index % 8);
199 }
200 list.try_push(TpmsPcrSelection {
201 hash: selection.alg,
202 pcr_select: TpmBuffer::try_from(pcr_select_bytes.as_slice())?,
203 })?;
204 }
205 Ok(list)
206}
207
208pub fn pcr_read(
215 device: &mut Device,
216 pcr_selection_in: &TpmlPcrSelection,
217) -> Result<(Vec<Pcr>, u32), PcrError> {
218 let cmd = TpmPcrReadCommand {
219 pcr_selection_in: *pcr_selection_in,
220 };
221 let (resp, _) = device.execute(&cmd, &[])?;
222 let pcr_read_resp = resp
223 .PcrRead()
224 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::PcrRead))?;
225 let mut pcrs = Vec::new();
226 let mut digest_iter = pcr_read_resp.pcr_values.iter();
227 for selection in pcr_read_resp.pcr_selection_out.iter() {
228 for (byte_idx, &byte) in selection.pcr_select.iter().enumerate() {
229 if byte == 0 {
230 continue;
231 }
232 for bit_idx in 0..8 {
233 if (byte >> bit_idx) & 1 == 1 {
234 let pcr_index = u32::try_from(byte_idx * 8 + bit_idx)
235 .map_err(|_| PcrError::InvalidPcrSelection("PCR index overflow".into()))?;
236 let value = digest_iter.next().ok_or_else(|| {
237 PcrError::InvalidPcrSelection("PCR selection mismatch".to_string())
238 })?;
239 pcrs.push(Pcr {
240 bank: selection.hash,
241 index: pcr_index,
242 value: value.to_vec(),
243 });
244 }
245 }
246 }
247 }
248 Ok((pcrs, pcr_read_resp.pcr_update_counter))
249}
250
251pub fn pcr_composite_digest(pcrs: &[Pcr], alg: TpmAlgId) -> Result<Vec<u8>, PcrError> {
258 let digests: Vec<&[u8]> = pcrs.iter().map(|p| p.value.as_slice()).collect();
259 Ok(crypto_digest(alg, &digests)?)
260}