cli/vtpm/
mod.rs

// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5//! Manages caching for TPM keys and sessions.
6
7use crate::{
8    auth::{Auth, AuthClass, AuthError},
9    crypto::CryptoError,
10    device::{Device, DeviceError},
11    handle::HandleError,
12    key::Tpm2shAlgId,
13};
14use std::{
15    any::Any,
16    collections::{hash_map::Entry, HashMap, HashSet},
17    fs, io,
18    num::TryFromIntError,
19    path::Path,
20    rc::Rc,
21};
22use thiserror::Error;
23use tpm2_protocol::{
24    data::{Tpm2bPublic, TpmHt, TpmRc, TpmsContext, TpmtPublic},
25    message::TpmAuthResponses,
26    TpmError, TpmHandle,
27};
28
29mod key;
30mod session;
31
32pub use key::*;
33pub use session::*;
34
35#[derive(Debug, Error)]
36pub enum VtpmError {
37    #[error("already tracked: {0}")]
38    AlreadyTracked(TpmHandle),
39    #[error("handle not found: {0}{1:08x}")]
40    HandleNotFound(&'static str, u32),
41    #[error("invalid auth")]
42    InvalidAuth,
43    #[error("invalid key bits: {0}")]
44    InvalidKeyBits(String),
45    #[error("invalid parent: {0:08x}")]
46    InvalidParent(u32),
47    #[error("no handles")]
48    NoHandles,
49    #[error("parent not loaded")]
50    ParentNotLoaded,
51    #[error("trailing authorizations")]
52    TrailingAuthorizations,
53    #[error("unsupported name algorithm: {0}")]
54    UnsupportedNameAlgorithm(Tpm2shAlgId),
55    #[error("auth error: {0}")]
56    Auth(#[from] AuthError),
57    #[error("crypto: {0}")]
58    Crypto(#[from] CryptoError),
59    #[error("device: {0}")]
60    Device(#[from] DeviceError),
61    #[error("handle: {0}")]
62    Handle(#[from] HandleError),
63    #[error("int decode: {0}")]
64    IntDecode(#[from] TryFromIntError),
65    #[error("I/O: {0}")]
66    Io(#[from] io::Error),
67    #[error("TPM: {0}")]
68    Tpm(TpmError),
69}
70
71impl From<TpmError> for VtpmError {
72    fn from(err: TpmError) -> Self {
73        Self::Device(DeviceError::from(err))
74    }
75}
76
77impl From<TpmRc> for VtpmError {
78    fn from(rc: TpmRc) -> Self {
79        Self::Device(DeviceError::TpmRc(rc))
80    }
81}
82
83/// Outcome of refreshing [`VtpmContext`](crate::vtpm::VtpmContext) against the
84/// TPM.
85#[derive(Debug)]
86pub enum RefreshAction {
87    /// The context is still valid.
88    Keep,
89    /// The context is no longer valid.
90    Stale,
91    /// A new [`TpmsContext`](tpm2_protocol::data::VtpmsContext) substituting
92    /// the old one.
93    Updated(Box<TpmsContext>),
94}
95
96/// A VTPM object.
97pub trait VtpmContext: 'static {
98    /// Immutable cast.
99    fn as_any(&self) -> &dyn Any;
100
101    /// Mutable cast.
102    fn as_any_mut(&mut self) -> &mut dyn Any;
103
104    /// Returns the VTPM handle.
105    fn handle(&self) -> u32;
106
107    /// Returns class string.
108    fn class(&self) -> &'static str;
109
110    /// Returns details string.
111    fn details(&self) -> String;
112
113    /// Saves a context to a file.
114    ///
115    /// # Errors
116    ///
117    /// Returns a [`Io`](crate::vtpm::VtpmError::Io) when an I/O operation
118    /// fails.
119    fn save(&self, path: &Path) -> Result<(), VtpmError>;
120
121    /// Deletes a context.
122    ///
123    /// # Errors
124    ///
125    /// Returns a [`Device`](crate::vtpm::VtpmError::Device) when the TPM
126    /// transmission fails.
127    /// Returns a [`Io`](crate::vtpm::VtpmError::Io) when an I/O operation
128    /// fails.
129    fn delete(&self, device: &mut Device, cache_dir: &Path, vhandle: u32) -> Result<(), VtpmError>;
130
131    /// Refreshes a context.
132    ///
133    /// # Errors
134    ///
135    /// Returns a [`Device`](crate::vtpm::VtpmError::Device) when the TPM
136    /// transmission fails.
137    /// Returns a [`Io`](crate::vtpm::VtpmError::Io) when an I/O operation
138    /// fails.
139    fn refresh(&mut self, device: &mut Device) -> Result<RefreshAction, VtpmError>;
140}
141
142pub struct VtpmCache<'a> {
143    pub contexts: HashMap<u32, Box<dyn VtpmContext>>,
144    pub handles: HashMap<u32, TpmHandle>,
145    dirty: HashSet<u32>,
146    cache_dir: &'a Path,
147}
148
149impl<'a> VtpmCache<'a> {
150    /// Creates a new cache and loads existing contexts from disk.
151    ///
152    /// # Errors
153    ///
154    /// Returns a [`VtpmError`] if loading contexts from the cache directory fails.
155    pub fn new(cache_dir: &'a Path) -> Result<Self, VtpmError> {
156        let mut cache = Self {
157            contexts: HashMap::new(),
158            handles: HashMap::new(),
159            dirty: HashSet::new(),
160            cache_dir,
161        };
162        cache.load()?;
163        Ok(cache)
164    }
165
166    fn cache_dir(&self) -> &Path {
167        self.cache_dir
168    }
169
170    /// Finds a VTPM key corresponding to a `TpmtPublic`,
171    #[must_use]
172    pub fn find_by_public(&self, public: &TpmtPublic) -> Option<&VtpmKey> {
173        self.key_iter()
174            .find(|(_, key)| key.public.inner == *public)
175            .map(|(_, key)| key)
176    }
177
178    /// Finds a VTPM key corresponding to a physical handle.
179    ///
180    /// Reads the public area of a physical TPM handle and searches the cache
181    /// for a loaded key with a matching public area. If found, it returns the
182    /// corresponding virtual handle.
183    ///
184    /// # Errors
185    ///
186    /// Returns [`ContextNotFound`](crate::vtpm::VtpmError::ContextNotFound)
187    /// when context is not found.
188    /// Returns [`Device`](crate::vtpm::VtpmError::Device) when
189    /// `TPM2_ReadPublic` fails.
190    pub fn find_by_phandle(
191        &self,
192        device: &mut Device,
193        phandle: u32,
194    ) -> Result<&VtpmKey, VtpmError> {
195        let (public, _) = device.read_public(phandle.into())?;
196        self.find_by_public(&public)
197            .ok_or(VtpmError::HandleNotFound("tpm:", phandle))
198    }
199
200    /// Finds a VTPM key corresponding to a virtual handle.
201    ///
202    /// Reads the public area of a physical TPM handle and searches the cache
203    /// for a loaded key with a matching public area. If found, it returns the
204    /// corresponding virtual handle.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`ContextNotFound`](crate::vtpm::VtpmError::ContextNotFound)
209    /// when context is not found.
210    pub fn find_by_vhandle(&self, vhandle: u32) -> Result<&VtpmKey, VtpmError> {
211        self.key_iter()
212            .find(|(h, _)| **h == vhandle)
213            .map(|(_, key)| key)
214            .ok_or(VtpmError::HandleNotFound("vtpm:", vhandle))
215    }
216
217    /// Loads all contexts from the cache directory.
218    fn load(&mut self) -> Result<(), VtpmError> {
219        let entries = match fs::read_dir(self.cache_dir()) {
220            Ok(entries) => entries.filter_map(Result::ok),
221            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(()),
222            Err(e) => return Err(e.into()),
223        };
224
225        for entry in entries {
226            let path = entry.path();
227            if path.extension().and_then(|s| s.to_str()) != Some("bin") {
228                continue;
229            }
230            let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
231                continue;
232            };
233            let Ok(vhandle) = u32::from_str_radix(stem, 16) else {
234                continue;
235            };
236
237            let ht = (vhandle >> 24) as u8;
238            let context: Box<dyn VtpmContext> = if ht == TpmHt::Transient as u8 {
239                Box::new(VtpmKey::load_from_path(&path)?)
240            } else if ht == TpmHt::HmacSession as u8 || ht == TpmHt::PolicySession as u8 {
241                Box::new(VtpmSession::load_from_path(&path)?)
242            } else {
243                continue;
244            };
245            self.contexts.insert(vhandle, context);
246        }
247        Ok(())
248    }
249
250    /// Saves all dirty contexts to disk.
251    ///
252    /// # Errors
253    ///
254    /// Returns a [`VtpmError`] if saving any of the dirty contexts fails.
255    pub fn save(&mut self) -> Result<(), VtpmError> {
256        let vhandles_to_save: Vec<u32> = self.dirty.drain().collect();
257        for vhandle in vhandles_to_save {
258            if let Some(context) = self.contexts.get(&vhandle) {
259                let path = self.cache_dir().join(format!("{vhandle:08x}.bin"));
260                context.save(&path)?;
261            }
262        }
263        Ok(())
264    }
265
266    /// Removes a context from the cache and performs necessary cleanup.
267    ///
268    /// # Errors
269    ///
270    /// Returns a [`VtpmError`] if deleting the context fails.
271    pub fn remove(&mut self, device: &mut Device, vhandle: u32) -> Result<(), VtpmError> {
272        if let Some(context) = self.contexts.remove(&vhandle) {
273            context.delete(device, self.cache_dir(), vhandle)?;
274        }
275        self.dirty.remove(&vhandle);
276        Ok(())
277    }
278
279    /// Tracks a transient handle for automatic cleanup.
280    ///
281    /// # Errors
282    ///
283    /// Returns a [`VtpmError::AlreadyTracked`] if the handle is already being tracked.
284    pub fn track(&mut self, handle: TpmHandle) -> Result<(), VtpmError> {
285        if self.handles.contains_key(&handle.0) {
286            return Err(VtpmError::AlreadyTracked(handle));
287        }
288        self.handles.insert(handle.0, handle);
289        Ok(())
290    }
291
292    /// Removes a handle from the tracking list.
293    pub fn untrack(&mut self, handle: u32) {
294        self.handles.remove(&handle);
295    }
296
297    /// Flushes all tracked transient handles from the TPM.
298    fn flush(&mut self, device: &mut Device) {
299        let handles_to_flush: Vec<TpmHandle> = self.handles.drain().map(|(_, v)| v).collect();
300        for handle in handles_to_flush {
301            if let Err(err) = device.flush_context(handle) {
302                log::error!("{handle}: {err}");
303            }
304        }
305    }
306
307    /// Finalizes the cache, saving dirty contexts and flushing handles.
308    pub fn teardown(&mut self, device: Option<Rc<std::cell::RefCell<Device>>>) {
309        if let Err(e) = self.save() {
310            log::error!("teardown: {e:#}");
311        }
312        if let Some(device_rc) = device {
313            if let Ok(mut dev) = device_rc.try_borrow_mut() {
314                self.flush(&mut dev);
315            }
316        }
317    }
318
319    /// Saves a new key context.
320    ///
321    /// # Errors
322    ///
323    /// Returns a [`VtpmError`] if saving the context to the TPM or writing the
324    /// cache file fails.
325    pub fn save_context(
326        &mut self,
327        device: &mut Device,
328        handle: TpmHandle,
329        public: &Tpm2bPublic,
330        parent_public: &Tpm2bPublic,
331    ) -> Result<u32, VtpmError> {
332        let context = device.save_context(handle)?;
333        for vhandle in 0x8000_0000u32..=0x80FF_FFFF {
334            if let Entry::Vacant(e) = self.contexts.entry(vhandle) {
335                let key = VtpmKey {
336                    context,
337                    handle: TpmHandle(vhandle),
338                    public: public.clone(),
339                    parent: parent_public.clone(),
340                };
341                e.insert(Box::new(key));
342                self.dirty.insert(vhandle);
343                return Ok(vhandle);
344            }
345        }
346        Err(VtpmError::NoHandles)
347    }
348
349    /// Adds a session to the cache.
350    pub fn add_session(&mut self, session: VtpmSession) -> u32 {
351        let vhandle = session.handle();
352        self.contexts.insert(vhandle, Box::new(session));
353        self.dirty.insert(vhandle);
354        vhandle
355    }
356
357    /// Marks a context as dirty.
358    pub fn mark_dirty(&mut self, vhandle: u32) {
359        self.dirty.insert(vhandle);
360    }
361
362    /// Gets an immutable reference to a session.
363    #[must_use]
364    pub fn get_session(&self, vhandle: u32) -> Option<&VtpmSession> {
365        self.contexts
366            .get(&vhandle)
367            .and_then(|ctx| ctx.as_any().downcast_ref::<VtpmSession>())
368    }
369
370    /// Gets a mutable reference to a session.
371    pub fn get_mut_session(&mut self, vhandle: u32) -> Option<&mut VtpmSession> {
372        self.dirty.insert(vhandle);
373        self.contexts
374            .get_mut(&vhandle)
375            .and_then(|ctx| ctx.as_any_mut().downcast_mut::<VtpmSession>())
376    }
377
378    /// Returns an iterator over the key contexts.
379    pub fn key_iter(&self) -> impl Iterator<Item = (&u32, &VtpmKey)> {
380        self.contexts
381            .iter()
382            .filter_map(|(h, ctx)| ctx.as_any().downcast_ref::<VtpmKey>().map(|key| (h, key)))
383    }
384
385    /// Prepares sessions by loading them into the TPM.
386    ///
387    /// # Errors
388    ///
389    /// Returns a [`VtpmError`] if a session is not found or loading its context fails.
390    pub fn prepare_sessions(
391        &mut self,
392        device: &mut Device,
393        auth_list: &[Auth],
394    ) -> Result<Vec<TpmHandle>, VtpmError> {
395        let mut activated_handles = Vec::new();
396        for auth in auth_list {
397            if auth.class() == AuthClass::Session {
398                let vhandle = auth.session()?;
399                let session = self
400                    .get_session(vhandle)
401                    .ok_or(VtpmError::HandleNotFound("vtpm:", vhandle))?;
402                activated_handles.push(device.load_context(session.context.clone())?);
403            }
404        }
405        Ok(activated_handles)
406    }
407
408    /// Tears down sessions by saving their updated contexts.
409    ///
410    /// # Errors
411    ///
412    /// Returns a [`VtpmError`] if a session is not found or if saving/flushing
413    /// the context fails.
414    pub fn teardown_sessions(
415        &mut self,
416        device: &mut Device,
417        session_vhandles: &HashSet<u32>,
418        auth_responses: &TpmAuthResponses,
419    ) -> Result<(), VtpmError> {
420        for (i, vhandle) in session_vhandles.iter().enumerate() {
421            let session_handle = self
422                .get_session(*vhandle)
423                .ok_or(VtpmError::HandleNotFound("vtpm:", *vhandle))?
424                .context
425                .saved_handle;
426
427            match device.save_context(session_handle) {
428                Ok(new_context) => {
429                    let session = self
430                        .get_mut_session(*vhandle)
431                        .ok_or(VtpmError::HandleNotFound("vtpm:", *vhandle))?;
432                    session.context = new_context;
433                    let auth = auth_responses[i];
434                    session.nonce_tpm = auth.nonce;
435                    session.attributes = auth.session_attributes;
436                }
437                Err(e) => {
438                    if let Err(e) = device.flush_context(session_handle) {
439                        log::warn!("{session_handle}: {e}");
440                    }
441                    return Err(e.into());
442                }
443            }
444        }
445        Ok(())
446    }
447}