cli/policy/
mod.rs

// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5//! This module contains the parser and executor for the unified policy language.
6
7pub mod software;
8pub mod tpm;
9
10pub use software::*;
11pub use tpm::*;
12
13use crate::{
14    context::ContextError,
15    crypto::CryptoError,
16    device::{Device, DeviceError},
17    pcr::{self, PcrError},
18    session::SessionError,
19    uri::{Uri, UriError},
20};
21use nom::{
22    branch::alt,
23    bytes::complete::{tag, take_while1},
24    character::complete::{char, hex_digit1, space0},
25    combinator::{map, map_res, opt},
26    error::ErrorKind,
27    multi::separated_list1,
28    sequence::{delimited, preceded, terminated, tuple},
29    Err as NomErr, IResult,
30};
31use std::{collections::HashMap, fmt, path::Path, str::FromStr};
32use thiserror::Error;
33use tpm2_protocol::{
34    data::{Tpm2bDigest, TpmAlgId, TpmlDigest, TpmlPcrSelection, TpmsContext},
35    message::TpmFlushContextCommand,
36    TpmErrorKind, TpmParse,
37};
38
39/// An abstract interface for a session that can have a policy applied to it.
40pub trait PolicySession {
41    /// Returns the device associated with the session.
42    fn device(&mut self) -> &mut Device;
43
44    /// Applies a `TPM2_PolicyPCR` action to the session.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the policy action fails.
49    fn policy_pcr(
50        &mut self,
51        pcr_digest: &Tpm2bDigest,
52        pcrs: TpmlPcrSelection,
53    ) -> Result<(), PolicyError>;
54
55    /// Applies a `TPM2_PolicyOR` action to the session.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the policy action fails.
60    fn policy_or(&mut self, p_hash_list: &TpmlDigest) -> Result<(), PolicyError>;
61
62    /// Applies a `TPM2_PolicySecret` action to the session.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error if the policy action fails.
67    fn policy_secret(
68        &mut self,
69        auth_handle: u32,
70        auth_handle_name: &tpm2_protocol::data::Tpm2bName,
71        password: Option<&[u8]>,
72        cp_hash: Option<Tpm2bDigest>,
73    ) -> Result<(), PolicyError>;
74
75    /// Retrieves the final policy digest from the session.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the digest cannot be retrieved.
80    fn get_digest(&mut self) -> Result<Tpm2bDigest, PolicyError>;
81
82    /// Returns the session's hash algorithm.
83    fn hash_alg(&self) -> TpmAlgId;
84}
85
86#[derive(Debug, Error)]
87pub enum PolicyError {
88    #[error("context: {0}")]
89    Context(#[from] ContextError),
90    #[error("device: {0}")]
91    Device(#[from] DeviceError),
92    #[error("invalid algorithm: {0:?}")]
93    InvalidAlgorithm(TpmAlgId),
94    #[error("invalid expression: {0}")]
95    InvalidExpression(String),
96    #[error("invalid value: {0}")]
97    InvalidValue(String),
98    #[error("I/O: {0}")]
99    Io(#[from] std::io::Error),
100    #[error("pcr: {0}")]
101    Pcr(#[from] PcrError),
102    #[error("crypto: {0}")]
103    Crypto(#[from] CryptoError),
104    #[error("PCR value for selection '{0}' not provided")]
105    PcrValueMissing(String),
106    #[error("session: {0}")]
107    Session(#[from] SessionError),
108    #[error("uri: {0}")]
109    Uri(#[from] UriError),
110}
111
112impl From<hex::FromHexError> for PolicyError {
113    fn from(err: hex::FromHexError) -> Self {
114        Self::InvalidValue(err.to_string())
115    }
116}
117
118impl From<base64::DecodeError> for PolicyError {
119    fn from(err: base64::DecodeError) -> Self {
120        Self::InvalidValue(err.to_string())
121    }
122}
123
124impl From<std::num::ParseIntError> for PolicyError {
125    fn from(err: std::num::ParseIntError) -> Self {
126        Self::InvalidValue(err.to_string())
127    }
128}
129
130impl From<std::str::Utf8Error> for PolicyError {
131    fn from(err: std::str::Utf8Error) -> Self {
132        Self::InvalidValue(err.to_string())
133    }
134}
135
136impl From<TpmErrorKind> for PolicyError {
137    fn from(err: TpmErrorKind) -> Self {
138        Self::Device(err.into())
139    }
140}
141
142/// The Abstract Syntax Tree (AST) for the unified policy language.
143#[derive(Debug, PartialEq, Clone)]
144pub enum Expression {
145    Pcr {
146        selection: String,
147        digest: Option<String>,
148        count: Option<u32>,
149    },
150    Secret {
151        auth_handle_uri: Box<Expression>,
152        password: Option<Box<Expression>>,
153        cp_hash: Option<String>,
154    },
155    Or(Vec<Expression>),
156    Uri(Uri),
157}
158
159impl fmt::Display for Expression {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        match self {
162            Expression::Pcr {
163                selection,
164                digest,
165                count,
166            } => {
167                write!(f, "{selection}")?;
168                if let Some(d) = digest {
169                    write!(f, ":{d}")?;
170                }
171                if let Some(c) = count {
172                    write!(f, ", count={c}")?;
173                }
174                Ok(())
175            }
176            Expression::Secret {
177                auth_handle_uri,
178                password,
179                cp_hash,
180            } => {
181                write!(f, "secret({auth_handle_uri}")?;
182                if let Some(p) = password {
183                    write!(f, ", {p}")?;
184                }
185                if let Some(c) = cp_hash {
186                    write!(f, ", {c}")?;
187                }
188                write!(f, ")")
189            }
190            Expression::Or(branches) => {
191                let branches_str: Vec<String> = branches.iter().map(ToString::to_string).collect();
192                write!(f, "or({})", branches_str.join(", "))
193            }
194            Expression::Uri(uri) => write!(f, "{uri}"),
195        }
196    }
197}
198
199impl Expression {
200    /// Resolves a file path expression into bytes.
201    ///
202    /// # Errors
203    ///
204    /// Returns a `PolicyError` if the expression is not a file path or the file
205    /// cannot be read.
206    pub fn to_bytes(&self) -> Result<Vec<u8>, PolicyError> {
207        match self {
208            Self::Uri(Uri::Path(path)) => Ok(std::fs::read(Path::new(path))?),
209            _ => Err(PolicyError::InvalidExpression(format!(
210                "invalid expression: {self:?}"
211            ))),
212        }
213    }
214
215    /// Parses a TPM handle from a `tpm://` expression.
216    ///
217    /// # Errors
218    ///
219    /// Returns a `PolicyError` if the expression is not a `Expression::Uri(Uri::Tpm)`.
220    pub fn to_tpm_handle(&self) -> Result<u32, PolicyError> {
221        match self {
222            Self::Uri(Uri::Tpm(handle)) => Ok(*handle),
223            _ => Err(PolicyError::InvalidExpression(format!(
224                "invalid expression: {self:?}"
225            ))),
226        }
227    }
228}
229
230fn comma_sep<'a, F, O>(f: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
231where
232    F: FnMut(&'a str) -> IResult<&'a str, O>,
233{
234    preceded(terminated(char(','), space0), f)
235}
236
237fn secret_expression(input: &str) -> IResult<&str, Expression> {
238    map(
239        tuple((
240            parse_expression,
241            opt(comma_sep(parse_expression)),
242            opt(comma_sep(map(hex_digit1, |s: &str| s.to_string()))),
243        )),
244        |(uri_expr, password_expr, cp_hash_str)| Expression::Secret {
245            auth_handle_uri: Box::new(uri_expr),
246            password: password_expr.map(Box::new),
247            cp_hash: cp_hash_str,
248        },
249    )(input)
250}
251
252fn or_expression(input: &str) -> IResult<&str, Expression> {
253    map(
254        separated_list1(
255            preceded(space0, terminated(char(','), space0)),
256            parse_expression,
257        ),
258        Expression::Or,
259    )(input)
260}
261
262fn call<'a, F, O>(name: &'static str, f: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
263where
264    F: FnMut(&'a str) -> IResult<&'a str, O>,
265{
266    delimited(
267        terminated(tag(name), char('(')),
268        delimited(space0, f, space0),
269        char(')'),
270    )
271}
272
273fn uri_expression(input: &str) -> IResult<&str, Expression> {
274    map_res(take_while1(|c: char| c != ',' && c != ')'), |s: &str| {
275        Uri::from_str(s).map(Expression::Uri)
276    })(input)
277}
278
279fn pcr_policy_expression(input: &str) -> IResult<&str, Expression> {
280    let (remainder, pcr_substring) = take_while1(|c: char| !matches!(c, '(' | ')' | ','))(input)?;
281
282    match pcr::parse_pcr_policy_string(pcr_substring) {
283        Ok((selections, digest)) => {
284            let selection_str = selections
285                .into_iter()
286                .map(|s| s.to_string())
287                .collect::<Vec<_>>()
288                .join("+");
289            let expr = Expression::Pcr {
290                selection: selection_str,
291                digest,
292                count: None,
293            };
294            Ok((remainder, expr))
295        }
296        Err(_) => {
297            if pcr_substring.contains(':') && !pcr_substring.contains("://") {
298                Err(NomErr::Failure(nom::error::Error::new(
299                    input,
300                    ErrorKind::Verify,
301                )))
302            } else {
303                Err(NomErr::Error(nom::error::Error::new(
304                    input,
305                    ErrorKind::Verify,
306                )))
307            }
308        }
309    }
310}
311
312/// Parses any valid expression.
313fn parse_expression(input: &str) -> IResult<&str, Expression> {
314    alt((
315        call("secret", secret_expression),
316        call("or", or_expression),
317        pcr_policy_expression,
318        uri_expression,
319    ))(input)
320}
321
322/// Parses an expression string, ensuring the entire input is consumed.
323///
324/// # Errors
325///
326/// Returns a `PolicyError` if the input is not a valid expression or if there
327/// is trailing input left after parsing.
328pub fn parse(input: &str) -> Result<Expression, PolicyError> {
329    let (remaining, expr) =
330        parse_expression(input).map_err(|e| PolicyError::InvalidExpression(e.to_string()))?;
331
332    if !remaining.is_empty() {
333        return Err(PolicyError::InvalidExpression(format!(
334            "unexpected trailing input: '{remaining}'"
335        )));
336    }
337    Ok(expr)
338}
339
340/// Traverses a policy AST and applies the commands to a session object.
341///
342/// # Errors
343///
344/// Returns an error if any policy command fails.
345pub fn execute_policy(
346    ast: &Expression,
347    session: &mut impl PolicySession,
348) -> Result<Tpm2bDigest, PolicyError> {
349    match ast {
350        Expression::Pcr {
351            selection,
352            digest,
353            count: _,
354        } => {
355            let digest_bytes =
356                hex::decode(digest.as_ref().ok_or(PolicyError::InvalidExpression(
357                    "PCR policy requires a digest for execution".to_string(),
358                ))?)?;
359            let pcr_digest = Tpm2bDigest::try_from(digest_bytes.as_slice())?;
360            let selections = pcr::pcr_selection_vec_from_str(selection)?;
361            let banks = pcr::pcr_get_bank_list(session.device())?;
362            let pcrs = pcr::pcr_selection_vec_to_tpml(&selections, &banks)?;
363            session.policy_pcr(&pcr_digest, pcrs)?;
364            session.get_digest()
365        }
366        Expression::Secret {
367            auth_handle_uri,
368            password,
369            cp_hash,
370        } => {
371            let mut flush_handle: Option<u32> = None;
372
373            let handle = match &**auth_handle_uri {
374                Expression::Uri(Uri::Tpm(h)) => *h,
375                Expression::Uri(Uri::Path(_)) => {
376                    let context_bytes = auth_handle_uri.to_bytes()?;
377                    let (context, _) = TpmsContext::parse(&context_bytes)?;
378                    let new_handle = session.device().load_context(context)?;
379                    flush_handle = Some(new_handle);
380                    new_handle
381                }
382                _ => {
383                    return Err(PolicyError::InvalidExpression(
384                        "secret() auth handle must be tpm:// or <path>".to_string(),
385                    ))
386                }
387            };
388
389            let name = if (handle >> 24) as u8 == tpm2_protocol::data::TpmHt::Transient as u8 {
390                session.device().read_public(handle.into())?.1
391            } else {
392                tpm2_protocol::data::Tpm2bName::try_from(handle.to_be_bytes().as_slice())?
393            };
394            let password_bytes = password.as_ref().map(|p| p.to_bytes()).transpose()?;
395            let cp_hash_digest = cp_hash
396                .as_ref()
397                .map(|hex_str| -> Result<Tpm2bDigest, PolicyError> {
398                    let bytes = hex::decode(hex_str)?;
399                    Ok(Tpm2bDigest::try_from(bytes.as_slice())?)
400                })
401                .transpose()?;
402
403            session.policy_secret(handle, &name, password_bytes.as_deref(), cp_hash_digest)?;
404
405            if let Some(h) = flush_handle {
406                let cmd = TpmFlushContextCommand {
407                    flush_handle: h.into(),
408                };
409                let sessions = vec![];
410                session.device().execute(&cmd, &sessions)?;
411            }
412
413            session.get_digest()
414        }
415        Expression::Or(branches) => {
416            let mut branch_digests = TpmlDigest::new();
417            for branch in branches {
418                let mut temp_session =
419                    SoftwarePolicySession::new(session.hash_alg(), session.device())?;
420                let branch_digest = execute_policy(branch, &mut temp_session)?;
421                branch_digests
422                    .try_push(branch_digest)
423                    .map_err(|e| PolicyError::InvalidExpression(e.to_string()))?;
424            }
425            session.policy_or(&branch_digests)?;
426            session.get_digest()
427        }
428        Expression::Uri(uri) => Err(PolicyError::InvalidExpression(uri.to_string())),
429    }
430}
431
432/// Recursively traverses a policy AST and populates any missing PCR digests
433/// from a provided map.
434///
435/// # Errors
436///
437/// Returns a `PolicyError` if a PCR selection is malformed or if its value is
438/// not in the map.
439pub fn populate_pcr_digests<S: std::hash::BuildHasher>(
440    ast: &mut Expression,
441    pcr_map: &HashMap<String, Vec<u8>, S>,
442) -> Result<(), PolicyError> {
443    match ast {
444        Expression::Pcr {
445            selection, digest, ..
446        } => {
447            if digest.is_none() {
448                let digest_bytes = pcr_map
449                    .get(selection)
450                    .ok_or_else(|| PolicyError::PcrValueMissing(selection.clone()))?;
451                *digest = Some(hex::encode(digest_bytes));
452            }
453        }
454        Expression::Or(branches) => {
455            for branch in branches.iter_mut() {
456                populate_pcr_digests(branch, pcr_map)?;
457            }
458        }
459        Expression::Secret {
460            auth_handle_uri, ..
461        } => {
462            populate_pcr_digests(auth_handle_uri, pcr_map)?;
463        }
464        Expression::Uri(_) => {}
465    }
466    Ok(())
467}