cli/policy/
mod.rs

1// SPDX-License-Identifier: GPL-3-0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5//! This module contains the parser and executor for the unified policy language.
6
7pub mod software;
8pub mod tpm;
9
10pub use software::*;
11pub use tpm::*;
12
13use crate::{
14    auth::{Auth, AuthClass},
15    crypto::CryptoError,
16    device::{Device, DeviceError},
17    handle::{Handle, HandleClass, HandleError},
18    pcr::{self, PcrError, PcrSelection},
19    vtpm::VtpmError,
20};
21use std::{
22    collections::HashMap, fmt, iter::Peekable, num::ParseIntError, slice::Iter, str::FromStr,
23};
24use thiserror::Error;
25use tpm2_protocol::{
26    data::{Tpm2bDigest, TpmAlgId, TpmHt, TpmlDigest, TpmlPcrSelection},
27    TpmError, TpmHandle,
28};
29
30#[derive(Debug, Error)]
31pub enum PolicyError {
32    #[error("invalid algorithm: {0:?}")]
33    InvalidAlgorithm(TpmAlgId),
34    #[error("invalid expression: {0}")]
35    InvalidExpression(String),
36    #[error("invalid secret: {0}")]
37    InvalidSecret(String),
38    #[error("invalid value: {0}")]
39    InvalidValue(String),
40    #[error("no valid branch found for OR policy")]
41    NoValidPolicyOrBranch,
42    #[error("PCR value for selection '{0}' not provided")]
43    PcrValueMissing(String),
44    #[error("unexpected end of expression")]
45    UnexpectedEndOfExpression,
46    #[error("unexpected token: {0}")]
47    UnexpectedToken(String),
48    #[error("unmatched parenthesis")]
49    UnmatchedParenthesis,
50    #[error("crypto: {0}")]
51    Crypto(#[from] CryptoError),
52    #[error("device: {0}")]
53    Device(#[from] DeviceError),
54    #[error("I/O: {0}")]
55    Io(#[from] std::io::Error),
56    #[error("cache: {0}")]
57    Cache(#[from] VtpmError),
58    #[error("pcr: {0}")]
59    Pcr(#[from] PcrError),
60    #[error("hex decode: {0}")]
61    HexDecode(#[from] hex::FromHexError),
62    #[error("handle decode: {0}")]
63    IntDecode(#[from] ParseIntError),
64    #[error("protocol: {0}")]
65    TpmProtocol(TpmError),
66    #[error("handle: {0}")]
67    Handle(#[from] HandleError),
68}
69
70impl From<TpmError> for PolicyError {
71    fn from(err: TpmError) -> Self {
72        Self::TpmProtocol(err)
73    }
74}
75
76/// An abstract interface for a session that can have a policy applied to it.
77pub trait PolicySession {
78    /// Returns the device associated with the session.
79    fn device(&mut self) -> &mut Device;
80
81    /// Applies a `TPM2_PolicyPCR` action to the session.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the policy action fails.
86    fn policy_pcr(
87        &mut self,
88        pcr_digest: &Tpm2bDigest,
89        pcrs: TpmlPcrSelection,
90    ) -> Result<(), PolicyError>;
91
92    /// Applies a `TPM2_PolicyOR` action to the session.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the policy action fails.
97    fn policy_or(&mut self, p_hash_list: &TpmlDigest) -> Result<(), PolicyError>;
98
99    /// Applies a `TPM2_PolicySecret` action to the session.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if the policy action fails.
104    fn policy_secret(
105        &mut self,
106        auth_handle: u32,
107        auth_handle_name: &tpm2_protocol::data::Tpm2bName,
108        password: Option<&[u8]>,
109        cp_hash: Option<Tpm2bDigest>,
110    ) -> Result<(), PolicyError>;
111
112    /// Applies a `TPM2_PolicyRestart` action to the session.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if the policy action fails.
117    fn policy_restart(&mut self) -> Result<(), PolicyError>;
118
119    /// Retrieves the final policy digest from the session.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the digest cannot be retrieved.
124    fn get_digest(&mut self) -> Result<Tpm2bDigest, PolicyError>;
125
126    /// Returns the session's hash algorithm.
127    fn hash_alg(&self) -> TpmAlgId;
128}
129
130/// The Abstract Syntax Tree (AST) for the unified policy language.
131#[derive(Debug, PartialEq, Clone)]
132pub enum Expression {
133    Auth(Auth),
134    Pcr {
135        selections: Vec<PcrSelection>,
136        digest: Option<String>,
137        count: Option<u32>,
138    },
139    Secret {
140        auth_handle: Box<Expression>,
141        password: Option<Box<Expression>>,
142        cp_hash: Option<String>,
143    },
144    And(Vec<Expression>),
145    Or(Vec<Expression>),
146    Handle(Handle),
147}
148
149impl fmt::Display for Expression {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        match self {
152            Expression::Auth(auth) => write!(f, "{auth}"),
153            Expression::Pcr {
154                selections,
155                digest,
156                count,
157            } => {
158                let selection_str = selections
159                    .iter()
160                    .map(ToString::to_string)
161                    .collect::<Vec<_>>()
162                    .join("+");
163                write!(f, "pcr({selection_str}")?;
164                if let Some(d) = digest {
165                    write!(f, ":{d}")?;
166                }
167                if let Some(c) = count {
168                    write!(f, ", count={c}")?;
169                }
170                write!(f, ")")
171            }
172            Expression::Secret {
173                auth_handle,
174                password,
175                cp_hash,
176            } => {
177                write!(f, "secret({auth_handle}")?;
178                if let Some(p) = password {
179                    write!(f, ", {p}")?;
180                }
181                if let Some(c) = cp_hash {
182                    write!(f, ", {c}")?;
183                }
184                write!(f, ")")
185            }
186            Expression::And(expressions) => {
187                let s: Vec<String> = expressions.iter().map(ToString::to_string).collect();
188                write!(f, "({})", s.join(" and "))
189            }
190            Expression::Or(expressions) => {
191                let s: Vec<String> = expressions.iter().map(ToString::to_string).collect();
192                write!(f, "({})", s.join(" or "))
193            }
194            Expression::Handle(handle) => write!(f, "{handle}"),
195        }
196    }
197}
198
199impl Expression {
200    /// Resolves a password expression into bytes.
201    ///
202    /// # Errors
203    ///
204    /// Returns a `PolicyError` if the expression is not a file path or the file
205    /// cannot be read.
206    pub fn to_bytes(&self) -> Result<Vec<u8>, PolicyError> {
207        match self {
208            Self::Auth(auth_instance) if auth_instance.class() == AuthClass::Password => {
209                Ok(auth_instance.value().to_vec())
210            }
211            _ => Err(PolicyError::InvalidSecret(format!(
212                "{self:?}: expected 'password:<hex>'"
213            ))),
214        }
215    }
216}
217
/// A lexical token of the policy language. Identifiers borrow from the input string.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token<'a> {
    /// The `and` keyword.
    And,
    /// The `or` keyword.
    Or,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// Any other maximal run of characters up to whitespace, `(`, `)` or `,`.
    Ident(&'a str),
}

impl fmt::Display for Token<'_> {
    /// Renders the token single-quoted, for use in error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Token::And => write!(f, "'and'"),
            Token::Or => write!(f, "'or'"),
            Token::LParen => write!(f, "'('"),
            Token::RParen => write!(f, "')'"),
            Token::Comma => write!(f, "','"),
            Token::Ident(s) => write!(f, "'{s}'"),
        }
    }
}

/// Splits a policy expression into tokens.
///
/// Whitespace separates tokens and is otherwise discarded. Any Unicode whitespace counts,
/// as defined by [`char::is_whitespace`]. `(`, `)` and `,` are single-character tokens.
/// Every other maximal run of characters is an identifier, except the exact words `and`
/// and `or`, which become keywords. Never fails: any input yields a token list.
fn tokenize(input: &str) -> Vec<Token<'_>> {
    // Characters that end an identifier without being part of it.
    fn ends_ident(c: char) -> bool {
        c.is_whitespace() || matches!(c, '(' | ')' | ',')
    }

    let mut tokens = Vec::new();
    // Work on `&str` slices so every step lands on a char boundary and whitespace is
    // classified by whole characters. Classifying only the first byte misses multi-byte
    // whitespace such as U+00A0, which would produce an empty identifier and stall.
    let mut rest = input.trim_start();
    while let Some(first) = rest.chars().next() {
        let (token, len) = match first {
            '(' => (Token::LParen, 1),
            ')' => (Token::RParen, 1),
            ',' => (Token::Comma, 1),
            _ => {
                // `first` is neither whitespace (trimmed) nor a delimiter, so `len > 0`
                // and the loop always advances.
                let len = rest.find(ends_ident).unwrap_or(rest.len());
                let word = &rest[..len];
                let token = match word {
                    "and" => Token::And,
                    "or" => Token::Or,
                    _ => Token::Ident(word),
                };
                (token, len)
            }
        };
        tokens.push(token);
        rest = rest[len..].trim_start();
    }
    tokens
}
290
291struct Parser<'a, 'b> {
292    tokens: &'a mut Peekable<Iter<'b, Token<'b>>>,
293}
294
295impl Parser<'_, '_> {
296    fn parse_or(&mut self) -> Result<Expression, PolicyError> {
297        let mut node = self.parse_and()?;
298        while let Some(Token::Or) = self.tokens.peek() {
299            self.tokens.next();
300            let rhs = self.parse_and()?;
301            node = match node {
302                Expression::Or(mut terms) => {
303                    terms.push(rhs);
304                    Expression::Or(terms)
305                }
306                lhs => Expression::Or(vec![lhs, rhs]),
307            };
308        }
309        Ok(node)
310    }
311
312    fn parse_and(&mut self) -> Result<Expression, PolicyError> {
313        let mut node = self.parse_primary()?;
314        while let Some(Token::And) = self.tokens.peek() {
315            self.tokens.next();
316            let rhs = self.parse_primary()?;
317            node = match node {
318                Expression::And(mut factors) => {
319                    factors.push(rhs);
320                    Expression::And(factors)
321                }
322                lhs => Expression::And(vec![lhs, rhs]),
323            };
324        }
325        Ok(node)
326    }
327
328    fn parse_primary(&mut self) -> Result<Expression, PolicyError> {
329        let token = self
330            .tokens
331            .next()
332            .ok_or(PolicyError::UnexpectedEndOfExpression)?;
333
334        match token {
335            Token::LParen => {
336                let expr = self.parse_or()?;
337                if self.tokens.next() != Some(&Token::RParen) {
338                    return Err(PolicyError::UnmatchedParenthesis);
339                }
340                Ok(expr)
341            }
342            Token::Ident(name) => match *name {
343                "pcr" => self.parse_pcr_call(),
344                "secret" => self.parse_secret_call(),
345                _ => Self::parse_literal(name),
346            },
347            _ => Err(PolicyError::UnexpectedToken(token.to_string())),
348        }
349    }
350
351    fn parse_literal(s: &str) -> Result<Expression, PolicyError> {
352        if let Ok(auth) = Auth::from_str(s) {
353            Ok(Expression::Auth(auth))
354        } else if let Ok(handle) = Handle::from_str(s) {
355            Ok(Expression::Handle(handle))
356        } else {
357            Err(PolicyError::InvalidExpression(format!(
358                "unrecognized literal: {s}"
359            )))
360        }
361    }
362
363    fn parse_call_args(&mut self) -> Result<Vec<Expression>, PolicyError> {
364        match self.tokens.peek() {
365            Some(&&Token::LParen) => {
366                self.tokens.next();
367            }
368            Some(actual_token) => {
369                return Err(PolicyError::UnexpectedToken(format!(
370                    "Expected '(' to start argument list, found {actual_token}"
371                )));
372            }
373            None => {
374                return Err(PolicyError::UnexpectedEndOfExpression);
375            }
376        }
377
378        let mut args = Vec::new();
379        if self.tokens.peek() == Some(&&Token::RParen) {
380            self.tokens.next();
381            return Ok(args);
382        }
383
384        loop {
385            if let Some(Token::Ident(ident)) = self.tokens.peek() {
386                if let Ok(selections) = pcr::pcr_selection_vec_from_str(ident) {
387                    self.tokens.next();
388                    args.push(Expression::Pcr {
389                        selections,
390                        digest: None,
391                        count: None,
392                    });
393                } else {
394                    args.push(self.parse_or()?);
395                }
396            } else {
397                args.push(self.parse_or()?);
398            }
399
400            match self.tokens.peek() {
401                Some(&&Token::RParen) => {
402                    self.tokens.next();
403                    break;
404                }
405                Some(&&Token::Comma) => {
406                    self.tokens.next();
407                }
408                Some(actual_token) => {
409                    return Err(PolicyError::UnexpectedToken(format!(
410                        "Expected ',' or ')' in argument list, found {actual_token}"
411                    )));
412                }
413                None => return Err(PolicyError::UnmatchedParenthesis),
414            }
415        }
416        Ok(args)
417    }
418
419    fn parse_pcr_call(&mut self) -> Result<Expression, PolicyError> {
420        let mut args = self.parse_call_args()?;
421        if args.len() != 1 {
422            return Err(PolicyError::InvalidExpression(
423                "pcr() expects one argument ".to_string(),
424            ));
425        }
426
427        let arg = args.remove(0);
428        if let Expression::Pcr {
429            selections,
430            digest,
431            count,
432        } = arg
433        {
434            Ok(Expression::Pcr {
435                selections,
436                digest,
437                count,
438            })
439        } else if let Expression::Auth(_) | Expression::Handle(_) = arg {
440            let pcr_content = arg.to_string();
441            let (selection_part, digest_part) =
442                if let Some((selection, digest)) = pcr_content.rsplit_once(':') {
443                    if digest.chars().all(|c| c.is_ascii_hexdigit())
444                        && digest.len()
445                            >= crate::crypto::crypto_hash_size(TpmAlgId::Sha1).unwrap_or(20) * 2
446                    {
447                        (selection.to_string(), Some(digest.to_string()))
448                    } else {
449                        (pcr_content, None)
450                    }
451                } else {
452                    (pcr_content, None)
453                };
454
455            let selections = pcr::pcr_selection_vec_from_str(&selection_part)?;
456            Ok(Expression::Pcr {
457                selections,
458                digest: digest_part,
459                count: None,
460            })
461        } else {
462            Err(PolicyError::InvalidExpression(
463                "pcr() argument is not a valid PCR selection ".to_string(),
464            ))
465        }
466    }
467
468    fn parse_secret_call(&mut self) -> Result<Expression, PolicyError> {
469        let args = self.parse_call_args()?;
470        if args.is_empty() || args.len() > 3 {
471            return Err(PolicyError::InvalidExpression(
472                "secret() expects 1 to 3 arguments ".to_string(),
473            ));
474        }
475
476        let mut arg_iter = args.into_iter();
477        let auth_handle = Box::new(arg_iter.next().unwrap());
478        let password = arg_iter.next().map(Box::new);
479        let cp_hash = arg_iter.next().map(|expr| expr.to_string());
480
481        Ok(Expression::Secret {
482            auth_handle,
483            password,
484            cp_hash,
485        })
486    }
487}
488
489/// Parses a policy expression string.
490///
491/// # Errors
492///
493/// Returns a `PolicyError::InvalidExpression` if expression parsing fails.
494/// Returns a `PolicyError::UnexpectedToken` if there is trailing data after the
495/// expression.
496pub fn parse(input: &str) -> Result<Expression, PolicyError> {
497    let tokens = tokenize(input);
498    let mut iter = tokens.iter().peekable();
499    let mut parser = Parser { tokens: &mut iter };
500    let expr = parser.parse_or()?;
501
502    if parser.tokens.peek().is_some() {
503        return Err(PolicyError::UnexpectedToken(
504            "Trailing data after expression ".to_string(),
505        ));
506    }
507
508    Ok(expr)
509}
510
511/// Traverses a policy AST and applies the commands to a session object.
512///
513/// # Errors
514///
515/// Returns an error if any policy command fails.
516pub fn execute_policy(
517    ast: &Expression,
518    session: &mut impl PolicySession,
519) -> Result<Tpm2bDigest, PolicyError> {
520    match ast {
521        Expression::Auth(auth) => Err(PolicyError::InvalidExpression(auth.to_string())),
522        Expression::Pcr {
523            selections,
524            digest,
525            count: _,
526        } => {
527            let digest_bytes =
528                hex::decode(digest.as_ref().ok_or(PolicyError::InvalidExpression(
529                    "expected a hex string for optional digest in pcr()".to_string(),
530                ))?)?;
531            let pcr_digest = Tpm2bDigest::try_from(digest_bytes.as_slice())?;
532            let banks = pcr::pcr_get_bank_list(session.device())?;
533            let pcrs = pcr::pcr_selection_vec_to_tpml(selections, &banks)?;
534            session.policy_pcr(&pcr_digest, pcrs)?;
535            session.get_digest()
536        }
537        Expression::Secret {
538            auth_handle,
539            password,
540            cp_hash,
541        } => {
542            let handle_val =
543                match &**auth_handle {
544                    Expression::Handle(handle) if handle.class() == HandleClass::Tpm => {
545                        let h_val = handle.value();
546                        if (h_val >> 24) as u8 != TpmHt::Persistent as u8 {
547                            return Err(PolicyError::InvalidExpression(
548                                "secret() handle must be a persistent TPM handle ('tpm:81xxxxxx')"
549                                    .to_string(),
550                            ));
551                        }
552                        h_val
553                    }
554                    _ => return Err(PolicyError::InvalidExpression(
555                        "secret() first argument must be a persistent TPM handle ('tpm:81xxxxxx')"
556                            .to_string(),
557                    )),
558                };
559            let handle = TpmHandle(handle_val);
560
561            let (_, name) = session.device().read_public(handle)?;
562
563            let password_bytes = password.as_ref().map(|p| p.to_bytes()).transpose()?;
564            let cp_hash_digest = cp_hash
565                .as_ref()
566                .map(|hex_str| -> Result<Tpm2bDigest, PolicyError> {
567                    let bytes = hex::decode(hex_str)?;
568                    Ok(Tpm2bDigest::try_from(bytes.as_slice())?)
569                })
570                .transpose()?;
571
572            session.policy_secret(handle_val, &name, password_bytes.as_deref(), cp_hash_digest)?;
573
574            session.get_digest()
575        }
576        Expression::And(expressions) => {
577            let (last_expr, other_exprs) = expressions.split_last().ok_or_else(|| {
578                PolicyError::InvalidExpression("'and'-expression must be non-empty".to_string())
579            })?;
580
581            for expr in other_exprs {
582                execute_policy(expr, session)?;
583            }
584            execute_policy(last_expr, session)
585        }
586        Expression::Or(branches) => {
587            let mut branch_digests = TpmlDigest::new();
588            for branch in branches {
589                let mut temp_session =
590                    SoftwarePolicySession::new(session.hash_alg(), session.device())?;
591                let branch_digest = execute_policy(branch, &mut temp_session)?;
592                branch_digests
593                    .try_push(branch_digest)
594                    .map_err(|e| PolicyError::InvalidExpression(e.to_string()))?;
595            }
596
597            let mut last_error: Option<PolicyError> = None;
598            let mut branch_succeeded = false;
599            for branch in branches {
600                session.policy_restart()?;
601                match execute_policy(branch, session) {
602                    Ok(_) => {
603                        branch_succeeded = true;
604                        break;
605                    }
606                    Err(e) => {
607                        last_error = Some(e);
608                    }
609                }
610            }
611
612            if !branch_succeeded {
613                return Err(last_error.unwrap_or(PolicyError::NoValidPolicyOrBranch));
614            }
615
616            session.policy_or(&branch_digests)?;
617            session.get_digest()
618        }
619        Expression::Handle(handle) => Err(PolicyError::InvalidExpression(handle.to_string())),
620    }
621}
622
623/// Recursively traverses a policy AST and populates any missing PCR digests
624/// from a provided map.
625///
626/// # Errors
627///
628/// Returns a `PolicyError` if a PCR selection is malformed or if its value is
629/// not in the map.
630pub fn populate_pcr_digests<S: std::hash::BuildHasher>(
631    ast: &mut Expression,
632    pcr_map: &HashMap<String, Vec<u8>, S>,
633) -> Result<(), PolicyError> {
634    match ast {
635        Expression::Pcr {
636            selections, digest, ..
637        } => {
638            if digest.is_none() {
639                let selection_str = selections
640                    .iter()
641                    .map(ToString::to_string)
642                    .collect::<Vec<_>>()
643                    .join("+");
644                let digest_bytes = pcr_map
645                    .get(&selection_str)
646                    .ok_or(PolicyError::PcrValueMissing(selection_str))?;
647                *digest = Some(hex::encode(digest_bytes));
648            }
649        }
650        Expression::And(expressions) | Expression::Or(expressions) => {
651            for expr in expressions.iter_mut() {
652                populate_pcr_digests(expr, pcr_map)?;
653            }
654        }
655        Expression::Secret {
656            auth_handle,
657            password,
658            ..
659        } => {
660            populate_pcr_digests(auth_handle, pcr_map)?;
661            if let Some(pwd_expr) = password {
662                populate_pcr_digests(pwd_expr, pcr_map)?;
663            }
664        }
665        Expression::Auth(_) | Expression::Handle(_) => {}
666    }
667    Ok(())
668}
669
670/// Traverses the AST, applying a fallible visitor closure to each `Pcr` expression.
671///
672/// # Errors
673///
674/// Returns a `PolicyError` if the provided visitor closure returns an error.
675pub fn visit_pcr_expressions_mut<F>(
676    ast: &mut Expression,
677    visitor: &mut F,
678) -> Result<(), PolicyError>
679where
680    F: FnMut(&mut Expression) -> Result<(), PolicyError>,
681{
682    match ast {
683        Expression::Pcr { .. } => visitor(ast)?,
684        Expression::And(branches) | Expression::Or(branches) => {
685            for branch in branches.iter_mut() {
686                visit_pcr_expressions_mut(branch, visitor)?;
687            }
688        }
689        Expression::Secret { auth_handle, .. } => {
690            visit_pcr_expressions_mut(auth_handle, visitor)?;
691        }
692        Expression::Auth(_) | Expression::Handle(_) => {}
693    }
694    Ok(())
695}