1use crate::{
17 convert::from_tpm_object_to_vec,
18 crypto::crypto_digest,
19 device::{Auth, Device, DeviceError, TpmCommandObject},
20 key::{AnyKey, KeyError, TpmKey},
21 session::SessionCache,
22 uri::{Uri, UriError},
23};
24
25use std::{
26 cell::RefCell,
27 cmp,
28 collections::{HashMap, HashSet},
29 fmt, fs,
30 io::Write,
31 num::TryFromIntError,
32 path::{Path, PathBuf},
33 rc::Rc,
34};
35
36use thiserror::Error;
37use tpm2_protocol::{
38 data::{Tpm2bName, TpmAlgId, TpmCc, TpmHt, TpmRcBase, TpmRh, TpmaNv, TpmsContext, TpmtPublic},
39 message::{
40 TpmAuthResponses, TpmEvictControlCommand, TpmFlushContextCommand, TpmNvReadCommand,
41 TpmNvReadPublicCommand, TpmResponseBody,
42 },
43 TpmErrorKind, TpmHandle, TpmParse,
44};
45
/// Errors produced by [`ContextCache`] operations.
#[derive(Debug, Error)]
pub enum ContextError {
    /// `track` was asked to track a handle that is already tracked.
    #[error("already tracked: {0}")]
    AlreadyTracked(TpmHandle),
    /// No cached context exists for the given grip.
    #[error("context not found: {0}")]
    ContextNotFound(String),
    /// Wrapped error from the crypto module (e.g. digest computation).
    #[error("crypto: {0}")]
    Crypto(#[from] crate::crypto::CryptoError),
    /// Wrapped device / TPM transport error.
    #[error("device: {0}")]
    Device(#[from] DeviceError),
    /// The handle's type is not acceptable for the requested operation.
    #[error("invalid handle: {0:08x}")]
    InvalidHandle(u32),
    /// `load_parent` was given a URI that cannot name a parent object.
    #[error("invalid parent URI: must be a tpm:// or key:// URI")]
    InvalidParentUri,
    /// The URI variant is not supported by the requested operation.
    #[error("invalid URI: {0}")]
    InvalidUri(UriError),
    /// Filesystem or writer I/O failure.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// Wrapped key handling error.
    #[error("key: {0}")]
    Key(#[from] KeyError),
    /// The handle was expected to be tracked but is not (see `evict_key`).
    #[error("not tracked: {0}")]
    NotTracked(TpmHandle),
    /// Loading a context failed with a `Handle` response code. Presumably the
    /// context's parent is not loaded — NOTE(review): confirm.
    #[error("parent not loaded")]
    ParentNotLoaded,
    /// Wrapped session cache error.
    #[error("session: {0}")]
    Session(#[from] crate::session::SessionError),
    /// The TPM rejected the handle passed to `delete`.
    #[error("unknown handle: {0:08x}")]
    UnknownHandle(u32),
    /// Wrapped URI error, converted automatically via `?`.
    #[error("uri: {0}")]
    Uri(#[from] UriError),
}
77
78impl From<TpmErrorKind> for ContextError {
79 fn from(err: TpmErrorKind) -> Self {
80 Self::Device(DeviceError::from(err))
81 }
82}
83
84impl From<TryFromIntError> for ContextError {
85 fn from(err: TryFromIntError) -> Self {
86 Self::Device(err.into())
87 }
88}
89
/// Cache of tracked TPM handles and saved object contexts for one run.
///
/// Saved contexts are kept in memory and mirrored on disk under `contexts_dir`,
/// one file per grip.
pub struct ContextCache<'a> {
    /// Transient object and session handles to flush on teardown, keyed by raw value.
    pub handles: HashMap<u32, TpmHandle>,
    /// Sink for user-visible output (new `key://` URIs, written paths, raw data).
    pub writer: &'a mut dyn Write,
    /// Serialized context blobs keyed by grip (16 hex characters).
    pub contexts: HashMap<String, Vec<u8>>,
    // Grips whose in-memory blob must be written to disk on the next save.
    dirty_contexts: HashSet<String>,
    // Directory that holds one context file per grip.
    contexts_dir: PathBuf,
    /// Authorization session cache used when executing commands.
    pub session_map: SessionCache,
}
98
99impl std::fmt::Debug for ContextCache<'_> {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
101 let handles: Vec<String> = self
102 .handles
103 .values()
104 .map(|t| Uri::Tpm(t.0).to_string())
105 .collect();
106 f.debug_struct("Context")
107 .field("handles", &handles)
108 .field("contexts", &self.contexts.keys())
109 .field("writer", &"<dyn Write>")
110 .finish()
111 }
112}
113
/// A saved context that has been loaded into the TPM.
///
/// The handle is flushed from the TPM when the value is dropped.
pub struct Context {
    /// Shared device used to flush the handle on drop.
    pub device: Rc<RefCell<Device>>,
    /// Live handle assigned by the TPM when the context was loaded.
    pub handle: TpmHandle,
    /// Grip under which the context is cached.
    pub grip: String,
    /// Public area read back from the loaded object.
    pub public: TpmtPublic,
}
121
122impl Drop for Context {
123 fn drop(&mut self) {
124 if let Ok(mut device) = self.device.try_borrow_mut() {
125 if let Err(e) = device.flush_context(self.handle.0) {
126 log::warn!(
127 "Failed to flush context for handle {:08x}: {e}",
128 self.handle
129 );
130 }
131 } else {
132 log::error!(
133 "Could not borrow device to flush context handle {}",
134 self.handle
135 );
136 }
137 }
138}
139
/// Iterator that loads cached contexts into the TPM one at a time.
///
/// Created by `ContextCache::loaded_contexts`.
pub struct ContextIterator<'a> {
    // Shared device used to load contexts and handed to each yielded `Context`.
    device: Rc<RefCell<Device>>,
    // Grips still to visit; consumed from the end via `pop`.
    context_keys: Vec<String>,
    // Borrowed grip -> blob map the grips are looked up in.
    contexts_map: &'a HashMap<String, Vec<u8>>,
}
146
/// Outcome of loading one cached context through [`ContextIterator`].
pub enum ContextItem {
    /// The context was loaded and its public area read; dropping it flushes the handle.
    Loaded(Box<Context>),
    /// The TPM rejected the blob with a `ReferenceH0` response code; holds the grip.
    /// Presumably the saved context is no longer usable — NOTE(review): confirm.
    Stale(String),
}
154
155impl Iterator for ContextIterator<'_> {
156 type Item = Result<ContextItem, DeviceError>;
157
158 fn next(&mut self) -> Option<Self::Item> {
159 while let Some(grip) = self.context_keys.pop() {
160 let Some(context_blob) = self.contexts_map.get(&grip) else {
161 continue;
162 };
163
164 let context_struct = match TpmsContext::parse(context_blob) {
165 Ok((cs, _)) => cs,
166 Err(e) => {
167 log::warn!("Failed to parse context for grip {grip}: {e}");
168 continue;
169 }
170 };
171
172 let mut device = self.device.borrow_mut();
173 let live_handle = match device.load_context(context_struct) {
174 Ok(h) => h,
175 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => {
176 return Some(Ok(ContextItem::Stale(grip)));
177 }
178 Err(e) => {
179 log::warn!("Skipping unloadable context {grip}: {e}");
180 continue;
181 }
182 };
183
184 let public = match device.read_public(live_handle.into()) {
185 Ok((p, _)) => p,
186 Err(e) => {
187 log::warn!("Failed to read public area for context {grip}: {e}");
188 if let Err(flush_err) = device.flush_context(live_handle) {
189 log::error!(
190 "Failed to flush context after read_public failed: {flush_err}"
191 );
192 }
193 continue;
194 }
195 };
196
197 return Some(Ok(ContextItem::Loaded(Box::new(Context {
198 device: self.device.clone(),
199 handle: live_handle.into(),
200 grip,
201 public,
202 }))));
203 }
204 None
205 }
206}
207
208impl<'a> ContextCache<'a> {
209 pub fn teardown(&mut self, device: Option<Rc<RefCell<Device>>>) {
211 if let Err(e) = self.save_contexts() {
212 eprintln!("teardown: {e:#}");
213 }
214 if let Err(e) = self.session_map.save() {
215 eprintln!("teardown: {e:#}");
216 }
217 if let Some(device_rc) = device {
218 match device_rc.try_borrow_mut() {
219 Ok(mut device_guard) => {
220 if let Err(e) = self.flush(&mut device_guard) {
221 eprintln!("teardown: {e:#}");
222 }
223 }
224 Err(e) => {
225 eprintln!("teardown: {e:#}");
226 }
227 }
228 }
229 }
230
231 pub fn execute<C: TpmCommandObject>(
241 &mut self,
242 device: &mut Device,
243 command: &C,
244 handles: &[u32],
245 auths: &[Auth],
246 ) -> Result<(TpmResponseBody, TpmAuthResponses), ContextError> {
247 let activated_handles = self.session_map.prepare_sessions(device, auths)?;
248
249 for &handle in &activated_handles {
250 self.track(TpmHandle(handle))?;
251 }
252
253 let (sessions, session_handles) = self
254 .session_map
255 .build_auth_area(device, command, handles, auths)?;
256
257 let (resp, auth_responses) = device.execute(command, &sessions)?;
258
259 self.session_map
260 .teardown_sessions(device, &session_handles, &auth_responses)?;
261
262 for handle in activated_handles {
263 self.untrack(handle);
264 }
265
266 Ok((resp, auth_responses))
267 }
268
269 pub fn new(
275 device: Option<&mut Device>,
276 cache_dir: &Path,
277 writer: &'a mut dyn Write,
278 session_map: SessionCache,
279 ) -> Result<ContextCache<'a>, ContextError> {
280 let contexts_dir = cache_dir.join("contexts");
281 let mut new_context = Self {
282 handles: HashMap::new(),
283 writer,
284 contexts: HashMap::new(),
285 dirty_contexts: HashSet::new(),
286 contexts_dir,
287 session_map,
288 };
289
290 new_context.load_contexts()?;
291
292 if let Some(dev) = device {
293 new_context.refresh_contexts(dev)?;
294 }
295
296 Ok(new_context)
297 }
298
299 pub fn loaded_contexts(&self, device: Rc<RefCell<Device>>) -> ContextIterator<'_> {
301 let keys = self.contexts.keys().cloned().collect();
302 ContextIterator {
303 device,
304 context_keys: keys,
305 contexts_map: &self.contexts,
306 }
307 }
308
309 fn load_contexts(&mut self) -> Result<(), ContextError> {
311 fs::create_dir_all(&self.contexts_dir)?;
312 let entries = match fs::read_dir(&self.contexts_dir) {
313 Ok(entries) => entries.filter_map(Result::ok),
314 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
315 Err(e) => return Err(e.into()),
316 };
317
318 for entry in entries {
319 let path = entry.path();
320 if path.is_file() {
321 if let Some(grip) = path.file_stem().and_then(|s| s.to_str()) {
322 if grip.len() == 16 && grip.chars().all(|c| c.is_ascii_hexdigit()) {
323 let content = fs::read(&path)?;
324 self.contexts.insert(grip.to_string(), content);
325 } else {
326 log::trace!(
327 "Pruning invalid or outdated context file: {}",
328 path.display()
329 );
330 fs::remove_file(path)?;
331 }
332 }
333 }
334 }
335 Ok(())
336 }
337
338 fn save_contexts(&mut self) -> Result<(), ContextError> {
340 if self.dirty_contexts.is_empty() {
341 return Ok(());
342 }
343 fs::create_dir_all(&self.contexts_dir)?;
344 for grip in self.dirty_contexts.drain() {
345 if let Some(data) = self.contexts.get(&grip) {
346 let path = self.contexts_dir.join(&grip);
347 fs::write(path, data)?;
348 }
349 }
350 Ok(())
351 }
352
353 pub fn remove_context(&mut self, grip: &str) -> Result<(), ContextError> {
359 if self.contexts.remove(grip).is_some() {
360 let path = self.contexts_dir.join(grip);
361 if let Err(e) = fs::remove_file(path) {
362 if e.kind() != std::io::ErrorKind::NotFound {
363 return Err(e.into());
364 }
365 }
366 }
367 Ok(())
368 }
369
370 pub fn reset(&mut self) -> Result<(), ContextError> {
376 let paths_to_delete: Vec<_> = self
377 .contexts
378 .keys()
379 .map(|grip| self.contexts_dir.join(grip))
380 .collect();
381 for path in paths_to_delete {
382 if let Err(e) = fs::remove_file(path) {
383 if e.kind() != std::io::ErrorKind::NotFound {
384 return Err(e.into());
385 }
386 }
387 }
388 self.contexts.clear();
389 self.dirty_contexts.clear();
390 Ok(())
391 }
392
393 fn refresh_contexts(&mut self, device: &mut Device) -> Result<(), ContextError> {
395 let grips_to_refresh: Vec<String> = self.contexts.keys().cloned().collect();
396 for grip in grips_to_refresh {
397 let context_blob = match self.contexts.get(&grip) {
398 Some(blob) => blob.clone(),
399 None => continue,
400 };
401
402 let (context_struct, _) = TpmsContext::parse(&context_blob)?;
403
404 match device.load_context(context_struct) {
405 Ok(live_handle) => {
406 let new_context_struct = device.save_context(live_handle)?;
407 device.flush_context(live_handle)?;
408
409 let new_context_blob = from_tpm_object_to_vec(&new_context_struct)?;
410
411 self.contexts.insert(grip.clone(), new_context_blob);
412 self.dirty_contexts.insert(grip);
413 }
414 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => {
415 self.remove_context(&grip)?;
416 }
417 Err(e) => {
418 log::warn!("key://{grip}: {e}");
419 self.remove_context(&grip)?;
420 }
421 }
422 }
423 Ok(())
424 }
425
426 pub fn new_context(
433 &mut self,
434 device: &mut Device,
435 handle: TpmHandle,
436 name: &Tpm2bName,
437 ) -> Result<(), ContextError> {
438 let context_struct = device.save_context(handle.0)?;
439 let context_bytes = from_tpm_object_to_vec(&context_struct)?;
440 let digest = crypto_digest(TpmAlgId::Sha256, &[name.as_ref()])?;
441 let grip = hex::encode(&digest[..8]);
442
443 self.contexts.insert(grip.clone(), context_bytes);
444 self.dirty_contexts.insert(grip.clone());
445
446 writeln!(self.writer, "key://{grip}")?;
447 Ok(())
448 }
449
450 pub fn mark_dirty(&mut self, grip: String) {
452 self.dirty_contexts.insert(grip);
453 }
454
455 #[must_use]
456 pub fn cache_dir(&self) -> &Path {
457 &self.contexts_dir
458 }
459
460 pub fn import_key(
466 &mut self,
467 device: &mut Device,
468 parent_handle: TpmHandle,
469 input_bytes: &[u8],
470 auths: &[Auth],
471 ) -> Result<TpmKey, ContextError> {
472 let external_key = match AnyKey::try_from(input_bytes)? {
473 AnyKey::Tpm(_) => return Err(ContextError::Key(KeyError::InvalidFormat)),
474 AnyKey::External(key) => key,
475 };
476
477 let mut rng = rand::thread_rng();
478 let handles = [parent_handle.0];
479
480 self.session_map.prepare_sessions(device, auths)?;
481 Ok(TpmKey::from_external_key(
482 device,
483 parent_handle,
484 &external_key,
485 &mut rng,
486 &handles,
487 auths,
488 self,
489 )?)
490 }
491
492 pub fn load_context_from_bytes(
498 &mut self,
499 device: &mut Device,
500 blob: &[u8],
501 ) -> Result<(TpmHandle, Tpm2bName), ContextError> {
502 let (context, _) = TpmsContext::parse(blob)?;
503 match device.load_context(context) {
504 Ok(handle) => {
505 let handle = TpmHandle(handle);
506 let (_, name) = device.read_public(handle)?;
507 device.add_name_to_cache(handle.0, name);
508 self.track(handle)?;
509 Ok((handle, name))
510 }
511 Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::Handle => {
512 Err(ContextError::ParentNotLoaded)
513 }
514 Err(e) => Err(e.into()),
515 }
516 }
517
518 pub fn load_parent(
525 &mut self,
526 device: &mut Device,
527 uri: &Uri,
528 ) -> Result<TpmHandle, ContextError> {
529 if matches!(uri, Uri::Path(_) | Uri::Password(_)) {
530 return Err(ContextError::InvalidParentUri);
531 }
532 self.load_context(device, uri)
533 }
534
535 pub fn load_context(
545 &mut self,
546 device: &mut Device,
547 uri: &Uri,
548 ) -> Result<TpmHandle, ContextError> {
549 match uri {
550 Uri::Tpm(handle) => Ok(TpmHandle(*handle)),
551 Uri::Context(grip) => {
552 let context_blob = self
553 .contexts
554 .get(grip)
555 .ok_or_else(|| ContextError::ContextNotFound(grip.clone()))?
556 .clone();
557 self.load_context_from_bytes(device, &context_blob)
558 .map(|(handle, _)| handle)
559 }
560 Uri::Path(_) => {
561 let context_blob = uri.to_bytes()?;
562 self.load_context_from_bytes(device, &context_blob)
563 .map(|(handle, _)| handle)
564 }
565 Uri::Password(_) | Uri::Session(_) => {
566 Err(ContextError::InvalidUri(UriError::InvalidUriType))
567 }
568 }
569 }
570
571 pub fn delete(
577 &mut self,
578 device: &mut Device,
579 uri: &Uri,
580 auths: &[Auth],
581 ) -> Result<u32, ContextError> {
582 let handle = self.load_context(device, uri)?.0;
583
584 let mso = (handle >> 24) as u8;
585 let result = match TpmHt::try_from(mso) {
586 Ok(TpmHt::Persistent) => self.delete_persistent(device, TpmHandle(handle), auths),
587 Ok(TpmHt::Transient) => self.delete_transient(device, TpmHandle(handle)),
588 Ok(TpmHt::HmacSession | TpmHt::PolicySession) => {
589 let cmd = TpmFlushContextCommand {
590 flush_handle: handle.into(),
591 };
592 let sessions = vec![];
593 device.execute(&cmd, &sessions)?;
594 self.handles.remove(&handle);
595 Ok(())
596 }
597 _ => return Err(ContextError::InvalidHandle(handle)),
598 };
599
600 match result {
601 Ok(()) => Ok(handle),
602 Err(ContextError::Device(DeviceError::TpmRc(rc))) if rc.base() == TpmRcBase::Handle => {
603 Err(ContextError::UnknownHandle(handle))
604 }
605 Err(e) => Err(e),
606 }
607 }
608
609 pub fn track(&mut self, handle: TpmHandle) -> Result<(), ContextError> {
615 self.non_existence_invariant(handle)?;
616
617 let mso = (handle.0 >> 24) as u8;
618 match TpmHt::try_from(mso) {
619 Ok(TpmHt::Transient | TpmHt::HmacSession | TpmHt::PolicySession) => {
620 self.handles.insert(handle.0, handle);
621 Ok(())
622 }
623 _ => Err(ContextError::InvalidHandle(handle.0)),
624 }
625 }
626
627 pub fn untrack(&mut self, handle: u32) {
629 self.handles.remove(&handle);
630 }
631
632 pub fn flush(&mut self, device: &mut Device) -> Result<(), ContextError> {
639 let handles_to_flush: Vec<TpmHandle> = self.handles.drain().map(|(_, v)| v).collect();
640
641 for handle in handles_to_flush {
642 let cmd = TpmFlushContextCommand {
643 flush_handle: handle,
644 };
645 let sessions = vec![];
646 if let Err(err) = device.execute(&cmd, &sessions) {
647 let uri = Uri::Tpm(handle.0);
648 log::error!("{uri}: {err}");
649 }
650 }
651
652 Ok(())
653 }
654
655 pub fn write_key_data(
661 &mut self,
662 output_uri: Option<&Uri>,
663 key: &TpmKey,
664 ) -> Result<(), ContextError> {
665 let output_is_der = if let Some(Uri::Path(path_str)) = output_uri {
666 Path::new(path_str)
667 .extension()
668 .and_then(std::ffi::OsStr::to_str)
669 == Some("der")
670 } else {
671 false
672 };
673
674 let output_bytes = if output_is_der {
675 key.to_der()?
676 } else {
677 key.to_pem()?.into_bytes()
678 };
679
680 self.write_data(output_uri, &output_bytes)
681 }
682
683 pub fn write_data(
689 &mut self,
690 output_uri: Option<&Uri>,
691 data: &[u8],
692 ) -> Result<(), ContextError> {
693 if let Some(uri) = output_uri {
694 match uri {
695 Uri::Path(path) => {
696 std::fs::write(path, data)?;
697 writeln!(self.writer, "{uri}")?;
698 }
699 _ => return Err(ContextError::InvalidUri(UriError::InvalidUriType)),
700 }
701 } else {
702 self.writer.write_all(data)?;
703 }
704 Ok(())
705 }
706
707 pub fn read_certificate(
713 &mut self,
714 device: &mut Device,
715 auths: &[Auth],
716 handle: u32,
717 max_read_size: usize,
718 ) -> Result<Option<Vec<u8>>, ContextError> {
719 let nv_read_public_cmd = TpmNvReadPublicCommand {
720 nv_index: handle.into(),
721 };
722 let (resp, _) = device.execute(&nv_read_public_cmd, &[])?;
723 let read_public_resp = resp
724 .NvReadPublic()
725 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::NvReadPublic))?;
726 let nv_public = read_public_resp.nv_public;
727 let data_size = nv_public.data_size as usize;
728
729 if data_size == 0 {
730 return Ok(None);
731 }
732
733 let auth_handle = if nv_public.attributes.contains(TpmaNv::AUTHREAD) {
734 handle
735 } else if nv_public.attributes.contains(TpmaNv::PPREAD) {
736 TpmRh::Platform as u32
737 } else if nv_public.attributes.contains(TpmaNv::OWNERREAD) {
738 TpmRh::Owner as u32
739 } else {
740 handle
741 };
742
743 let mut cert_bytes = Vec::with_capacity(data_size);
744 let mut offset = 0;
745 while offset < data_size {
746 let chunk_size = cmp::min(max_read_size, data_size - offset);
747
748 let nv_read_cmd = TpmNvReadCommand {
749 auth_handle: auth_handle.into(),
750 nv_index: handle.into(),
751 size: u16::try_from(chunk_size)?,
752 offset: u16::try_from(offset)?,
753 };
754
755 let (resp, _) = self.execute(device, &nv_read_cmd, &[auth_handle], auths)?;
756
757 let read_resp = resp
758 .NvRead()
759 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::NvRead))?;
760 cert_bytes.extend_from_slice(read_resp.data.as_ref());
761 offset += chunk_size;
762 }
763
764 Ok(Some(cert_bytes))
765 }
766
767 fn delete_persistent(
768 &mut self,
769 device: &mut Device,
770 handle: TpmHandle,
771 auths: &[Auth],
772 ) -> Result<(), ContextError> {
773 let auth_handle = TpmRh::Owner;
774 let cmd = TpmEvictControlCommand {
775 auth: (auth_handle as u32).into(),
776 object_handle: handle.0.into(),
777 persistent_handle: handle,
778 };
779 let handles = [auth_handle as u32, handle.0];
780
781 let (resp, _) = self.execute(device, &cmd, &handles, auths)?;
782
783 resp.EvictControl()
784 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::EvictControl))?;
785 Ok(())
786 }
787
788 fn delete_transient(
789 &mut self,
790 device: &mut Device,
791 handle: TpmHandle,
792 ) -> Result<(), ContextError> {
793 let cmd = TpmFlushContextCommand {
794 flush_handle: handle,
795 };
796 let sessions = vec![];
797 let (_, _) = device.execute(&cmd, &sessions)?;
798 self.handles.remove(&handle.0);
799 Ok(())
800 }
801
802 pub fn evict_key(
809 &mut self,
810 device: &mut Device,
811 transient_handle: TpmHandle,
812 persistent_handle: TpmHandle,
813 auths: &[Auth],
814 ) -> Result<(), ContextError> {
815 self.existence_invariant(transient_handle)?;
816 let auth_handle = TpmRh::Owner;
817 let cmd = TpmEvictControlCommand {
818 auth: (auth_handle as u32).into(),
819 object_handle: transient_handle.0.into(),
820 persistent_handle,
821 };
822 let handles = [auth_handle as u32, transient_handle.0];
823
824 let (resp, _) = self.execute(device, &cmd, &handles, auths)?;
825
826 resp.EvictControl()
827 .map_err(|_| DeviceError::ResponseMismatch(TpmCc::EvictControl))?;
828 self.handles.remove(&transient_handle.0);
829 Ok(())
830 }
831
832 fn existence_invariant(&self, handle: TpmHandle) -> Result<(), ContextError> {
833 if self.handles.contains_key(&handle.0) {
834 Ok(())
835 } else {
836 Err(ContextError::NotTracked(handle))
837 }
838 }
839
840 fn non_existence_invariant(&self, handle: TpmHandle) -> Result<(), ContextError> {
841 if self.handles.contains_key(&handle.0) {
842 Err(ContextError::AlreadyTracked(handle))
843 } else {
844 Ok(())
845 }
846 }
847}