1use crate::{
8 auth::{Auth, AuthClass, AuthError},
9 crypto::CryptoError,
10 device::{Device, DeviceError},
11 handle::HandleError,
12 key::Tpm2shAlgId,
13};
14use std::{
15 any::Any,
16 collections::{hash_map::Entry, HashMap, HashSet},
17 fs, io,
18 num::TryFromIntError,
19 path::Path,
20 rc::Rc,
21};
22use thiserror::Error;
23use tpm2_protocol::{
24 data::{Tpm2bPublic, TpmHt, TpmRc, TpmsContext, TpmtPublic},
25 message::TpmAuthResponses,
26 TpmError, TpmHandle,
27};
28
29mod key;
30mod session;
31
32pub use key::*;
33pub use session::*;
34
35#[derive(Debug, Error)]
36pub enum VtpmError {
37 #[error("already tracked: {0}")]
38 AlreadyTracked(TpmHandle),
39 #[error("handle not found: {0}{1:08x}")]
40 HandleNotFound(&'static str, u32),
41 #[error("invalid auth")]
42 InvalidAuth,
43 #[error("invalid key bits: {0}")]
44 InvalidKeyBits(String),
45 #[error("invalid parent: {0:08x}")]
46 InvalidParent(u32),
47 #[error("no handles")]
48 NoHandles,
49 #[error("parent not loaded")]
50 ParentNotLoaded,
51 #[error("trailing authorizations")]
52 TrailingAuthorizations,
53 #[error("unsupported name algorithm: {0}")]
54 UnsupportedNameAlgorithm(Tpm2shAlgId),
55 #[error("auth error: {0}")]
56 Auth(#[from] AuthError),
57 #[error("crypto: {0}")]
58 Crypto(#[from] CryptoError),
59 #[error("device: {0}")]
60 Device(#[from] DeviceError),
61 #[error("handle: {0}")]
62 Handle(#[from] HandleError),
63 #[error("int decode: {0}")]
64 IntDecode(#[from] TryFromIntError),
65 #[error("I/O: {0}")]
66 Io(#[from] io::Error),
67 #[error("TPM: {0}")]
68 Tpm(TpmError),
69}
70
71impl From<TpmError> for VtpmError {
72 fn from(err: TpmError) -> Self {
73 Self::Device(DeviceError::from(err))
74 }
75}
76
77impl From<TpmRc> for VtpmError {
78 fn from(rc: TpmRc) -> Self {
79 Self::Device(DeviceError::TpmRc(rc))
80 }
81}
82
83#[derive(Debug)]
86pub enum RefreshAction {
87 Keep,
89 Stale,
91 Updated(Box<TpmsContext>),
94}
95
96pub trait VtpmContext: 'static {
98 fn as_any(&self) -> &dyn Any;
100
101 fn as_any_mut(&mut self) -> &mut dyn Any;
103
104 fn handle(&self) -> u32;
106
107 fn class(&self) -> &'static str;
109
110 fn details(&self) -> String;
112
113 fn save(&self, path: &Path) -> Result<(), VtpmError>;
120
121 fn delete(&self, device: &mut Device, cache_dir: &Path, vhandle: u32) -> Result<(), VtpmError>;
130
131 fn refresh(&mut self, device: &mut Device) -> Result<RefreshAction, VtpmError>;
140}
141
142pub struct VtpmCache<'a> {
143 pub contexts: HashMap<u32, Box<dyn VtpmContext>>,
144 pub handles: HashMap<u32, TpmHandle>,
145 dirty: HashSet<u32>,
146 cache_dir: &'a Path,
147}
148
149impl<'a> VtpmCache<'a> {
150 pub fn new(cache_dir: &'a Path) -> Result<Self, VtpmError> {
156 let mut cache = Self {
157 contexts: HashMap::new(),
158 handles: HashMap::new(),
159 dirty: HashSet::new(),
160 cache_dir,
161 };
162 cache.load()?;
163 Ok(cache)
164 }
165
166 fn cache_dir(&self) -> &Path {
167 self.cache_dir
168 }
169
170 #[must_use]
172 pub fn find_by_public(&self, public: &TpmtPublic) -> Option<&VtpmKey> {
173 self.key_iter()
174 .find(|(_, key)| key.public.inner == *public)
175 .map(|(_, key)| key)
176 }
177
178 pub fn find_by_phandle(
191 &self,
192 device: &mut Device,
193 phandle: u32,
194 ) -> Result<&VtpmKey, VtpmError> {
195 let (public, _) = device.read_public(phandle.into())?;
196 self.find_by_public(&public)
197 .ok_or(VtpmError::HandleNotFound("tpm:", phandle))
198 }
199
200 pub fn find_by_vhandle(&self, vhandle: u32) -> Result<&VtpmKey, VtpmError> {
211 self.key_iter()
212 .find(|(h, _)| **h == vhandle)
213 .map(|(_, key)| key)
214 .ok_or(VtpmError::HandleNotFound("vtpm:", vhandle))
215 }
216
217 fn load(&mut self) -> Result<(), VtpmError> {
219 let entries = match fs::read_dir(self.cache_dir()) {
220 Ok(entries) => entries.filter_map(Result::ok),
221 Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(()),
222 Err(e) => return Err(e.into()),
223 };
224
225 for entry in entries {
226 let path = entry.path();
227 if path.extension().and_then(|s| s.to_str()) != Some("bin") {
228 continue;
229 }
230 let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
231 continue;
232 };
233 let Ok(vhandle) = u32::from_str_radix(stem, 16) else {
234 continue;
235 };
236
237 let ht = (vhandle >> 24) as u8;
238 let context: Box<dyn VtpmContext> = if ht == TpmHt::Transient as u8 {
239 Box::new(VtpmKey::load_from_path(&path)?)
240 } else if ht == TpmHt::HmacSession as u8 || ht == TpmHt::PolicySession as u8 {
241 Box::new(VtpmSession::load_from_path(&path)?)
242 } else {
243 continue;
244 };
245 self.contexts.insert(vhandle, context);
246 }
247 Ok(())
248 }
249
250 pub fn save(&mut self) -> Result<(), VtpmError> {
256 let vhandles_to_save: Vec<u32> = self.dirty.drain().collect();
257 for vhandle in vhandles_to_save {
258 if let Some(context) = self.contexts.get(&vhandle) {
259 let path = self.cache_dir().join(format!("{vhandle:08x}.bin"));
260 context.save(&path)?;
261 }
262 }
263 Ok(())
264 }
265
266 pub fn remove(&mut self, device: &mut Device, vhandle: u32) -> Result<(), VtpmError> {
272 if let Some(context) = self.contexts.remove(&vhandle) {
273 context.delete(device, self.cache_dir(), vhandle)?;
274 }
275 self.dirty.remove(&vhandle);
276 Ok(())
277 }
278
279 pub fn track(&mut self, handle: TpmHandle) -> Result<(), VtpmError> {
285 if self.handles.contains_key(&handle.0) {
286 return Err(VtpmError::AlreadyTracked(handle));
287 }
288 self.handles.insert(handle.0, handle);
289 Ok(())
290 }
291
292 pub fn untrack(&mut self, handle: u32) {
294 self.handles.remove(&handle);
295 }
296
297 fn flush(&mut self, device: &mut Device) {
299 let handles_to_flush: Vec<TpmHandle> = self.handles.drain().map(|(_, v)| v).collect();
300 for handle in handles_to_flush {
301 if let Err(err) = device.flush_context(handle) {
302 log::error!("{handle}: {err}");
303 }
304 }
305 }
306
307 pub fn teardown(&mut self, device: Option<Rc<std::cell::RefCell<Device>>>) {
309 if let Err(e) = self.save() {
310 log::error!("teardown: {e:#}");
311 }
312 if let Some(device_rc) = device {
313 if let Ok(mut dev) = device_rc.try_borrow_mut() {
314 self.flush(&mut dev);
315 }
316 }
317 }
318
319 pub fn save_context(
326 &mut self,
327 device: &mut Device,
328 handle: TpmHandle,
329 public: &Tpm2bPublic,
330 parent_public: &Tpm2bPublic,
331 ) -> Result<u32, VtpmError> {
332 let context = device.save_context(handle)?;
333 for vhandle in 0x8000_0000u32..=0x80FF_FFFF {
334 if let Entry::Vacant(e) = self.contexts.entry(vhandle) {
335 let key = VtpmKey {
336 context,
337 handle: TpmHandle(vhandle),
338 public: public.clone(),
339 parent: parent_public.clone(),
340 };
341 e.insert(Box::new(key));
342 self.dirty.insert(vhandle);
343 return Ok(vhandle);
344 }
345 }
346 Err(VtpmError::NoHandles)
347 }
348
349 pub fn add_session(&mut self, session: VtpmSession) -> u32 {
351 let vhandle = session.handle();
352 self.contexts.insert(vhandle, Box::new(session));
353 self.dirty.insert(vhandle);
354 vhandle
355 }
356
357 pub fn mark_dirty(&mut self, vhandle: u32) {
359 self.dirty.insert(vhandle);
360 }
361
362 #[must_use]
364 pub fn get_session(&self, vhandle: u32) -> Option<&VtpmSession> {
365 self.contexts
366 .get(&vhandle)
367 .and_then(|ctx| ctx.as_any().downcast_ref::<VtpmSession>())
368 }
369
370 pub fn get_mut_session(&mut self, vhandle: u32) -> Option<&mut VtpmSession> {
372 self.dirty.insert(vhandle);
373 self.contexts
374 .get_mut(&vhandle)
375 .and_then(|ctx| ctx.as_any_mut().downcast_mut::<VtpmSession>())
376 }
377
378 pub fn key_iter(&self) -> impl Iterator<Item = (&u32, &VtpmKey)> {
380 self.contexts
381 .iter()
382 .filter_map(|(h, ctx)| ctx.as_any().downcast_ref::<VtpmKey>().map(|key| (h, key)))
383 }
384
385 pub fn prepare_sessions(
391 &mut self,
392 device: &mut Device,
393 auth_list: &[Auth],
394 ) -> Result<Vec<TpmHandle>, VtpmError> {
395 let mut activated_handles = Vec::new();
396 for auth in auth_list {
397 if auth.class() == AuthClass::Session {
398 let vhandle = auth.session()?;
399 let session = self
400 .get_session(vhandle)
401 .ok_or(VtpmError::HandleNotFound("vtpm:", vhandle))?;
402 activated_handles.push(device.load_context(session.context.clone())?);
403 }
404 }
405 Ok(activated_handles)
406 }
407
408 pub fn teardown_sessions(
415 &mut self,
416 device: &mut Device,
417 session_vhandles: &HashSet<u32>,
418 auth_responses: &TpmAuthResponses,
419 ) -> Result<(), VtpmError> {
420 for (i, vhandle) in session_vhandles.iter().enumerate() {
421 let session_handle = self
422 .get_session(*vhandle)
423 .ok_or(VtpmError::HandleNotFound("vtpm:", *vhandle))?
424 .context
425 .saved_handle;
426
427 match device.save_context(session_handle) {
428 Ok(new_context) => {
429 let session = self
430 .get_mut_session(*vhandle)
431 .ok_or(VtpmError::HandleNotFound("vtpm:", *vhandle))?;
432 session.context = new_context;
433 let auth = auth_responses[i];
434 session.nonce_tpm = auth.nonce;
435 session.attributes = auth.session_attributes;
436 }
437 Err(e) => {
438 if let Err(e) = device.flush_context(session_handle) {
439 log::warn!("{session_handle}: {e}");
440 }
441 return Err(e.into());
442 }
443 }
444 }
445 Ok(())
446 }
447}