1pub mod software;
8pub mod tpm;
9
10pub use software::*;
11pub use tpm::*;
12
13use crate::{
14 context::ContextError,
15 crypto::CryptoError,
16 device::{Device, DeviceError},
17 pcr::{self, PcrError},
18 session::SessionError,
19 uri::{Uri, UriError},
20};
21use nom::{
22 branch::alt,
23 bytes::complete::{tag, take_while1},
24 character::complete::{char, hex_digit1, space0},
25 combinator::{map, map_res, opt},
26 error::ErrorKind,
27 multi::separated_list1,
28 sequence::{delimited, preceded, terminated, tuple},
29 Err as NomErr, IResult,
30};
31use std::{collections::HashMap, fmt, path::Path, str::FromStr};
32use thiserror::Error;
33use tpm2_protocol::{
34 data::{Tpm2bDigest, TpmAlgId, TpmlDigest, TpmlPcrSelection, TpmsContext},
35 message::TpmFlushContextCommand,
36 TpmErrorKind, TpmParse,
37};
38
/// A policy session that [`execute_policy`] can drive.
///
/// Implementors apply individual policy assertions and report the resulting
/// policy digest. `execute_policy` also creates `SoftwarePolicySession`
/// instances (with this session's hash algorithm and device) to evaluate the
/// branches of an `or(...)` expression.
pub trait PolicySession {
    /// Returns mutable access to the TPM device behind the session.
    ///
    /// `execute_policy` uses it to query PCR banks, load and flush saved
    /// contexts, and read public areas.
    fn device(&mut self) -> &mut Device;

    /// Applies a PCR assertion: `pcr_digest` is the expected digest of the
    /// PCRs selected by `pcrs`.
    fn policy_pcr(
        &mut self,
        pcr_digest: &Tpm2bDigest,
        pcrs: TpmlPcrSelection,
    ) -> Result<(), PolicyError>;

    /// Applies an OR assertion over `p_hash_list`, which holds the policy
    /// digests computed for each alternative branch.
    fn policy_or(&mut self, p_hash_list: &TpmlDigest) -> Result<(), PolicyError>;

    /// Applies a secret assertion authorized by the entity at `auth_handle`.
    ///
    /// * `auth_handle_name` - Name of that entity, computed by the caller.
    /// * `password` - optional authorization value bytes.
    /// * `cp_hash` - optional cpHash digest.
    fn policy_secret(
        &mut self,
        auth_handle: u32,
        auth_handle_name: &tpm2_protocol::data::Tpm2bName,
        password: Option<&[u8]>,
        cp_hash: Option<Tpm2bDigest>,
    ) -> Result<(), PolicyError>;

    /// Returns the session's current policy digest.
    fn get_digest(&mut self) -> Result<Tpm2bDigest, PolicyError>;

    /// Returns the session's hash algorithm; `execute_policy` uses it for the
    /// temporary sessions that evaluate `or(...)` branches.
    fn hash_alg(&self) -> TpmAlgId;
}
85
/// Errors produced while parsing, preparing, or executing a policy expression.
#[derive(Debug, Error)]
pub enum PolicyError {
    /// Error propagated from the context module.
    #[error("context: {0}")]
    Context(#[from] ContextError),
    /// Error reported by the device layer. `TpmErrorKind` protocol errors are
    /// also routed here via `DeviceError`.
    #[error("device: {0}")]
    Device(#[from] DeviceError),
    /// An algorithm identifier that was rejected.
    #[error("invalid algorithm: {0:?}")]
    InvalidAlgorithm(TpmAlgId),
    /// The expression is malformed or cannot be used where it appears
    /// (parse failures, non-executable nodes, wrong URI kinds).
    #[error("invalid expression: {0}")]
    InvalidExpression(String),
    /// A textual value failed to decode (hex, base64, integer, or UTF-8).
    #[error("invalid value: {0}")]
    InvalidValue(String),
    /// Filesystem I/O error, e.g. while reading a context or password file.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// Error propagated from the PCR helpers.
    #[error("pcr: {0}")]
    Pcr(#[from] PcrError),
    /// Error propagated from the crypto module.
    #[error("crypto: {0}")]
    Crypto(#[from] CryptoError),
    /// `populate_pcr_digests` found no value for the given PCR selection text.
    #[error("PCR value for selection '{0}' not provided")]
    PcrValueMissing(String),
    /// Error propagated from the session module.
    #[error("session: {0}")]
    Session(#[from] SessionError),
    /// Error propagated from URI parsing.
    #[error("uri: {0}")]
    Uri(#[from] UriError),
}
111
112impl From<hex::FromHexError> for PolicyError {
113 fn from(err: hex::FromHexError) -> Self {
114 Self::InvalidValue(err.to_string())
115 }
116}
117
118impl From<base64::DecodeError> for PolicyError {
119 fn from(err: base64::DecodeError) -> Self {
120 Self::InvalidValue(err.to_string())
121 }
122}
123
124impl From<std::num::ParseIntError> for PolicyError {
125 fn from(err: std::num::ParseIntError) -> Self {
126 Self::InvalidValue(err.to_string())
127 }
128}
129
130impl From<std::str::Utf8Error> for PolicyError {
131 fn from(err: std::str::Utf8Error) -> Self {
132 Self::InvalidValue(err.to_string())
133 }
134}
135
136impl From<TpmErrorKind> for PolicyError {
137 fn from(err: TpmErrorKind) -> Self {
138 Self::Device(err.into())
139 }
140}
141
/// Abstract syntax tree of a policy expression, as produced by [`parse`].
#[derive(Debug, PartialEq, Clone)]
pub enum Expression {
    /// A PCR assertion leaf.
    Pcr {
        /// PCR selection text; the parser joins multiple selections with `+`.
        selection: String,
        /// Expected PCR digest as a hex string. `execute_policy` requires it;
        /// the parser may leave it `None`, and `populate_pcr_digests` fills it in.
        digest: Option<String>,
        /// Optional count. `Display` renders it as `, count=N`; the parser in
        /// this file always sets `None`, and `execute_policy` ignores it.
        count: Option<u32>,
    },
    /// A `secret(handle[, password[, cp_hash]])` assertion.
    Secret {
        /// Authorizing entity; at execution time it must be a `tpm://` handle
        /// or a path to a saved context.
        auth_handle_uri: Box<Expression>,
        /// Optional password expression; its bytes are read with `to_bytes`,
        /// so at execution time it must be a file path.
        password: Option<Box<Expression>>,
        /// Optional cpHash as a hex string.
        cp_hash: Option<String>,
    },
    /// `or(a, b, ...)`: each branch is evaluated separately and the branch
    /// digests are combined with `policy_or`.
    Or(Vec<Expression>),
    /// A bare URI, e.g. a `tpm://` handle or a file path.
    Uri(Uri),
}
158
159impl fmt::Display for Expression {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 match self {
162 Expression::Pcr {
163 selection,
164 digest,
165 count,
166 } => {
167 write!(f, "{selection}")?;
168 if let Some(d) = digest {
169 write!(f, ":{d}")?;
170 }
171 if let Some(c) = count {
172 write!(f, ", count={c}")?;
173 }
174 Ok(())
175 }
176 Expression::Secret {
177 auth_handle_uri,
178 password,
179 cp_hash,
180 } => {
181 write!(f, "secret({auth_handle_uri}")?;
182 if let Some(p) = password {
183 write!(f, ", {p}")?;
184 }
185 if let Some(c) = cp_hash {
186 write!(f, ", {c}")?;
187 }
188 write!(f, ")")
189 }
190 Expression::Or(branches) => {
191 let branches_str: Vec<String> = branches.iter().map(ToString::to_string).collect();
192 write!(f, "or({})", branches_str.join(", "))
193 }
194 Expression::Uri(uri) => write!(f, "{uri}"),
195 }
196 }
197}
198
199impl Expression {
200 pub fn to_bytes(&self) -> Result<Vec<u8>, PolicyError> {
207 match self {
208 Self::Uri(Uri::Path(path)) => Ok(std::fs::read(Path::new(path))?),
209 _ => Err(PolicyError::InvalidExpression(format!(
210 "invalid expression: {self:?}"
211 ))),
212 }
213 }
214
215 pub fn to_tpm_handle(&self) -> Result<u32, PolicyError> {
221 match self {
222 Self::Uri(Uri::Tpm(handle)) => Ok(*handle),
223 _ => Err(PolicyError::InvalidExpression(format!(
224 "invalid expression: {self:?}"
225 ))),
226 }
227 }
228}
229
230fn comma_sep<'a, F, O>(f: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
231where
232 F: FnMut(&'a str) -> IResult<&'a str, O>,
233{
234 preceded(terminated(char(','), space0), f)
235}
236
237fn secret_expression(input: &str) -> IResult<&str, Expression> {
238 map(
239 tuple((
240 parse_expression,
241 opt(comma_sep(parse_expression)),
242 opt(comma_sep(map(hex_digit1, |s: &str| s.to_string()))),
243 )),
244 |(uri_expr, password_expr, cp_hash_str)| Expression::Secret {
245 auth_handle_uri: Box::new(uri_expr),
246 password: password_expr.map(Box::new),
247 cp_hash: cp_hash_str,
248 },
249 )(input)
250}
251
252fn or_expression(input: &str) -> IResult<&str, Expression> {
253 map(
254 separated_list1(
255 preceded(space0, terminated(char(','), space0)),
256 parse_expression,
257 ),
258 Expression::Or,
259 )(input)
260}
261
262fn call<'a, F, O>(name: &'static str, f: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
263where
264 F: FnMut(&'a str) -> IResult<&'a str, O>,
265{
266 delimited(
267 terminated(tag(name), char('(')),
268 delimited(space0, f, space0),
269 char(')'),
270 )
271}
272
273fn uri_expression(input: &str) -> IResult<&str, Expression> {
274 map_res(take_while1(|c: char| c != ',' && c != ')'), |s: &str| {
275 Uri::from_str(s).map(Expression::Uri)
276 })(input)
277}
278
279fn pcr_policy_expression(input: &str) -> IResult<&str, Expression> {
280 let (remainder, pcr_substring) = take_while1(|c: char| !matches!(c, '(' | ')' | ','))(input)?;
281
282 match pcr::parse_pcr_policy_string(pcr_substring) {
283 Ok((selections, digest)) => {
284 let selection_str = selections
285 .into_iter()
286 .map(|s| s.to_string())
287 .collect::<Vec<_>>()
288 .join("+");
289 let expr = Expression::Pcr {
290 selection: selection_str,
291 digest,
292 count: None,
293 };
294 Ok((remainder, expr))
295 }
296 Err(_) => {
297 if pcr_substring.contains(':') && !pcr_substring.contains("://") {
298 Err(NomErr::Failure(nom::error::Error::new(
299 input,
300 ErrorKind::Verify,
301 )))
302 } else {
303 Err(NomErr::Error(nom::error::Error::new(
304 input,
305 ErrorKind::Verify,
306 )))
307 }
308 }
309 }
310}
311
312fn parse_expression(input: &str) -> IResult<&str, Expression> {
314 alt((
315 call("secret", secret_expression),
316 call("or", or_expression),
317 pcr_policy_expression,
318 uri_expression,
319 ))(input)
320}
321
322pub fn parse(input: &str) -> Result<Expression, PolicyError> {
329 let (remaining, expr) =
330 parse_expression(input).map_err(|e| PolicyError::InvalidExpression(e.to_string()))?;
331
332 if !remaining.is_empty() {
333 return Err(PolicyError::InvalidExpression(format!(
334 "unexpected trailing input: '{remaining}'"
335 )));
336 }
337 Ok(expr)
338}
339
340pub fn execute_policy(
346 ast: &Expression,
347 session: &mut impl PolicySession,
348) -> Result<Tpm2bDigest, PolicyError> {
349 match ast {
350 Expression::Pcr {
351 selection,
352 digest,
353 count: _,
354 } => {
355 let digest_bytes =
356 hex::decode(digest.as_ref().ok_or(PolicyError::InvalidExpression(
357 "PCR policy requires a digest for execution".to_string(),
358 ))?)?;
359 let pcr_digest = Tpm2bDigest::try_from(digest_bytes.as_slice())?;
360 let selections = pcr::pcr_selection_vec_from_str(selection)?;
361 let banks = pcr::pcr_get_bank_list(session.device())?;
362 let pcrs = pcr::pcr_selection_vec_to_tpml(&selections, &banks)?;
363 session.policy_pcr(&pcr_digest, pcrs)?;
364 session.get_digest()
365 }
366 Expression::Secret {
367 auth_handle_uri,
368 password,
369 cp_hash,
370 } => {
371 let mut flush_handle: Option<u32> = None;
372
373 let handle = match &**auth_handle_uri {
374 Expression::Uri(Uri::Tpm(h)) => *h,
375 Expression::Uri(Uri::Path(_)) => {
376 let context_bytes = auth_handle_uri.to_bytes()?;
377 let (context, _) = TpmsContext::parse(&context_bytes)?;
378 let new_handle = session.device().load_context(context)?;
379 flush_handle = Some(new_handle);
380 new_handle
381 }
382 _ => {
383 return Err(PolicyError::InvalidExpression(
384 "secret() auth handle must be tpm:// or <path>".to_string(),
385 ))
386 }
387 };
388
389 let name = if (handle >> 24) as u8 == tpm2_protocol::data::TpmHt::Transient as u8 {
390 session.device().read_public(handle.into())?.1
391 } else {
392 tpm2_protocol::data::Tpm2bName::try_from(handle.to_be_bytes().as_slice())?
393 };
394 let password_bytes = password.as_ref().map(|p| p.to_bytes()).transpose()?;
395 let cp_hash_digest = cp_hash
396 .as_ref()
397 .map(|hex_str| -> Result<Tpm2bDigest, PolicyError> {
398 let bytes = hex::decode(hex_str)?;
399 Ok(Tpm2bDigest::try_from(bytes.as_slice())?)
400 })
401 .transpose()?;
402
403 session.policy_secret(handle, &name, password_bytes.as_deref(), cp_hash_digest)?;
404
405 if let Some(h) = flush_handle {
406 let cmd = TpmFlushContextCommand {
407 flush_handle: h.into(),
408 };
409 let sessions = vec![];
410 session.device().execute(&cmd, &sessions)?;
411 }
412
413 session.get_digest()
414 }
415 Expression::Or(branches) => {
416 let mut branch_digests = TpmlDigest::new();
417 for branch in branches {
418 let mut temp_session =
419 SoftwarePolicySession::new(session.hash_alg(), session.device())?;
420 let branch_digest = execute_policy(branch, &mut temp_session)?;
421 branch_digests
422 .try_push(branch_digest)
423 .map_err(|e| PolicyError::InvalidExpression(e.to_string()))?;
424 }
425 session.policy_or(&branch_digests)?;
426 session.get_digest()
427 }
428 Expression::Uri(uri) => Err(PolicyError::InvalidExpression(uri.to_string())),
429 }
430}
431
432pub fn populate_pcr_digests<S: std::hash::BuildHasher>(
440 ast: &mut Expression,
441 pcr_map: &HashMap<String, Vec<u8>, S>,
442) -> Result<(), PolicyError> {
443 match ast {
444 Expression::Pcr {
445 selection, digest, ..
446 } => {
447 if digest.is_none() {
448 let digest_bytes = pcr_map
449 .get(selection)
450 .ok_or_else(|| PolicyError::PcrValueMissing(selection.clone()))?;
451 *digest = Some(hex::encode(digest_bytes));
452 }
453 }
454 Expression::Or(branches) => {
455 for branch in branches.iter_mut() {
456 populate_pcr_digests(branch, pcr_map)?;
457 }
458 }
459 Expression::Secret {
460 auth_handle_uri, ..
461 } => {
462 populate_pcr_digests(auth_handle_uri, pcr_map)?;
463 }
464 Expression::Uri(_) => {}
465 }
466 Ok(())
467}