tpm2_device/
lib.rs

1// SPDX-License-Identifier: GPL-3-0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9    fcntl,
10    poll::{poll, PollFd, PollFlags},
11};
12use std::{
13    cell::RefCell,
14    collections::HashMap,
15    fs::{File, OpenOptions},
16    io::{Read, Write},
17    os::fd::{AsFd, AsRawFd},
18    path::{Path, PathBuf},
19    rc::Rc,
20    time::{Duration, Instant},
21};
22
23use thiserror::Error;
24use tpm2_protocol::{
25    constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
26    data::{
27        Tpm2bName, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
28        TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
29        TpmtPublic, TpmuCapabilities,
30    },
31    frame::{
32        tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
33        TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame, TpmGetCapabilityCommand,
34        TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
35    },
36    TpmHandle, TpmWriter,
37};
38use tracing::trace;
39
/// Errors that can occur when talking to a TPM device.
#[derive(Debug, Error)]
pub enum TpmDeviceError {
    /// The shared device is already mutably borrowed (see `with_device`).
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// The TPM reply did not carry the requested capability class, or the
    /// requested property was not among the returned properties.
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
    /// The interrupt callback requested cancellation while waiting for a
    /// response.
    #[error("operation interrupted by user")]
    Interrupted,
    /// The response header announced a size outside the accepted range, or
    /// more bytes arrived than the header announced.
    #[error("invalid response")]
    InvalidResponse,

    /// An I/O operation on the device file failed. `nix` errors from `poll`
    /// and `fcntl` are also mapped into this variant.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),

    /// Marshaling a TPM protocol encoded object failed.
    #[error("marshal: {0}")]
    Marshal(tpm2_protocol::TpmProtocolError),

    /// No device was supplied (see `with_device`).
    #[error("device not available")]
    NotAvailable,
    /// Generic failure. In this file it is produced when `MAX_HANDLES` cannot
    /// be represented as `u32`.
    #[error("operation failed")]
    OperationFailed,
    /// The TPM reported no PCR banks.
    #[error("PCR banks not available")]
    PcrBanksNotAvailable,
    /// The reported PCR banks do not all have the same selection size.
    #[error("PCR bank size mismatch")]
    PcrBankSizeMismatch,

    /// The TPM response did not match the expected command code.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),

    /// No complete response arrived within the configured timeout.
    #[error("TPM command timed out")]
    Timeout,
    /// The TPM answered with a non-success response code.
    #[error("TPM return code: {0}")]
    TpmRc(TpmRc),

    /// Unmarshaling a TPM protocol encoded object failed.
    #[error("unmarshal: {0}")]
    Unmarshal(tpm2_protocol::TpmProtocolError),

    /// The device reported an error or hang-up condition, or a read returned
    /// zero bytes before the full response was received.
    #[error("unexpected EOF")]
    UnexpectedEof,
}
84
85impl From<TpmRc> for TpmDeviceError {
86    fn from(rc: TpmRc) -> Self {
87        Self::TpmRc(rc)
88    }
89}
90
91impl From<nix::Error> for TpmDeviceError {
92    fn from(err: nix::Error) -> Self {
93        Self::Io(std::io::Error::from_raw_os_error(err as i32))
94    }
95}
96
97/// Executes a closure with a mutable reference to a `TpmDevice`.
98///
99/// This helper function centralizes the boilerplate for safely acquiring a
100/// mutable borrow of a `TpmDevice` from the shared `Rc<RefCell<...>>`.
101///
102/// # Errors
103///
104/// Returns [`NotAvailable`](crate::TpmDeviceError::NotAvailable) when no device
105/// is present and [`AlreadyBorrowed`](crate::TpmDeviceError::AlreadyBorrowed)
106/// when the device is already mutably borrowed, both converted into the caller's
107/// error type `E`. Propagates any error returned by the closure `f`.
108pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
109where
110    F: FnOnce(&mut TpmDevice) -> Result<T, E>,
111    E: From<TpmDeviceError>,
112{
113    let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
114    let mut device_guard = device_rc
115        .try_borrow_mut()
116        .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
117    f(&mut device_guard)
118}
119
/// A builder for constructing a `TpmDevice`.
///
/// Obtain one with `TpmDevice::builder()` or `TpmDeviceBuilder::default()`,
/// adjust it with the `with_*` methods, and finish with `build()`.
pub struct TpmDeviceBuilder {
    // Path of the device file to open; the `Default` impl uses `/dev/tpmrm0`.
    path: PathBuf,
    // Upper bound on how long `transmit` waits for a complete response;
    // the `Default` impl uses 120 seconds.
    timeout: Duration,
    // Checked on every iteration of the response loop in `transmit`;
    // returning `true` aborts with `TpmDeviceError::Interrupted`.
    interrupted: Box<dyn Fn() -> bool>,
}
126
127impl Default for TpmDeviceBuilder {
128    fn default() -> Self {
129        Self {
130            path: PathBuf::from("/dev/tpmrm0"),
131            timeout: Duration::from_secs(120),
132            interrupted: Box::new(|| false),
133        }
134    }
135}
136
137impl TpmDeviceBuilder {
138    /// Sets the device file path.
139    #[must_use]
140    pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
141        self.path = path.as_ref().to_path_buf();
142        self
143    }
144
145    /// Sets the operation timeout.
146    #[must_use]
147    pub fn with_timeout(mut self, timeout: Duration) -> Self {
148        self.timeout = timeout;
149        self
150    }
151
152    /// Sets the interruption check callback.
153    #[must_use]
154    pub fn with_interrupted<F>(mut self, handler: F) -> Self
155    where
156        F: Fn() -> bool + 'static,
157    {
158        self.interrupted = Box::new(handler);
159        self
160    }
161
162    /// Opens the TPM device file and constructs the `TpmDevice`.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`Io`](crate::TpmDeviceError::Io) when the device file cannot be
167    /// opened or when configuring the file descriptor flags fails.
168    pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
169        let file = OpenOptions::new()
170            .read(true)
171            .write(true)
172            .open(&self.path)
173            .map_err(TpmDeviceError::Io)?;
174
175        let fd = file.as_raw_fd();
176        let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
177        let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
178        oflags.insert(fcntl::OFlag::O_NONBLOCK);
179        fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
180
181        Ok(TpmDevice {
182            file,
183            name_cache: HashMap::new(),
184            interrupted: self.interrupted,
185            timeout: self.timeout,
186            command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
187            response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
188        })
189    }
190}
191
/// An open TPM device file through which commands are marshaled, written,
/// and answered.
///
/// Construct it with [`TpmDevice::builder`].
pub struct TpmDevice {
    // Device file, opened read/write and set to `O_NONBLOCK` by `build()`.
    file: File,
    // Cached `read_public` results keyed by the raw handle value; entries are
    // dropped by `flush_context`.
    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
    // Cancellation callback checked while waiting for a response.
    interrupted: Box<dyn Fn() -> bool>,
    // Maximum time to wait for a complete response.
    timeout: Duration,
    // Reusable buffer holding the marshaled command bytes.
    command: Vec<u8>,
    // Reusable buffer accumulating the raw response bytes.
    response: Vec<u8>,
}
200
201impl std::fmt::Debug for TpmDevice {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        f.debug_struct("Device")
204            .field("file", &self.file)
205            .field("name_cache", &self.name_cache)
206            .field("timeout", &self.timeout)
207            .finish_non_exhaustive()
208    }
209}
210
211impl TpmDevice {
212    const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
213
214    /// Creates a new builder for `TpmDevice`.
215    #[must_use]
216    pub fn builder() -> TpmDeviceBuilder {
217        TpmDeviceBuilder::default()
218    }
219
220    fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
221        let fd = self.file.as_fd();
222        let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
223
224        let num_events = match poll(&mut fds, 100u16) {
225            Ok(num) => num,
226            Err(nix::Error::EINTR) => return Ok(0),
227            Err(e) => return Err(e.into()),
228        };
229
230        if num_events == 0 {
231            return Ok(0);
232        }
233
234        let revents = fds[0].revents().unwrap_or(PollFlags::empty());
235
236        if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
237            return Err(TpmDeviceError::UnexpectedEof);
238        }
239
240        if revents.contains(PollFlags::POLLIN) {
241            match self.file.read(buf) {
242                Ok(0) => Err(TpmDeviceError::UnexpectedEof),
243                Ok(n) => Ok(n),
244                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
245                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
246                Err(e) => Err(e.into()),
247            }
248        } else if revents.contains(PollFlags::POLLHUP) {
249            Err(TpmDeviceError::UnexpectedEof)
250        } else {
251            Ok(0)
252        }
253    }
254
255    /// Performs the whole TPM command transmission process.
256    ///
257    /// # Errors
258    ///
259    /// Returns [`Interrupted`](crate::TpmDeviceError::Interrupted) when the
260    /// interrupt callback requests cancellation.
261    /// Returns [`Timeout`](crate::TpmDeviceError::Timeout) when the TPM does
262    /// not respond within the configured timeout.
263    /// Returns [`Io`](crate::TpmDeviceError::Io) when a write, flush, or read
264    /// operation on the device file fails, or when polling the device file
265    /// descriptor fails.
266    /// Returns [`InvalidResponse`](crate::TpmDeviceError::InvalidResponse) or
267    /// [`UnexpectedEof`](crate::TpmDeviceError::UnexpectedEof) when the TPM
268    /// reply is malformed, truncated, or longer than the announced size.
269    /// Returns [`Marshal`](crate::TpmDeviceError::Marshal) or
270    /// [`Unmarshal`](crate::TpmDeviceError::Unmarshal) when encoding the
271    /// command or decoding the response fails.
272    /// Returns [`TpmRc`](crate::TpmDeviceError::TpmRc) when the TPM returns an
273    /// error code.
274    pub fn transmit<C: TpmFrame>(
275        &mut self,
276        command: &C,
277        sessions: &[TpmsAuthCommand],
278    ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
279        self.prepare_command(command, sessions)?;
280        let cc = command.cc();
281
282        self.file.write_all(&self.command)?;
283        self.file.flush()?;
284
285        let start_time = Instant::now();
286        self.response.clear();
287        let mut total_size: Option<usize> = None;
288        let mut temp_buf = [0u8; 1024];
289
290        loop {
291            if (self.interrupted)() {
292                return Err(TpmDeviceError::Interrupted);
293            }
294            if start_time.elapsed() > self.timeout {
295                return Err(TpmDeviceError::Timeout);
296            }
297
298            let n = self.receive(&mut temp_buf)?;
299            if n > 0 {
300                self.response.extend_from_slice(&temp_buf[..n]);
301            }
302
303            if total_size.is_none() && self.response.len() >= 10 {
304                let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
305                    return Err(TpmDeviceError::InvalidResponse);
306                };
307                let size = u32::from_be_bytes(size_bytes) as usize;
308                if !(10..=TPM_MAX_COMMAND_SIZE as usize).contains(&size) {
309                    return Err(TpmDeviceError::InvalidResponse);
310                }
311                total_size = Some(size);
312            }
313
314            if let Some(size) = total_size {
315                if self.response.len() == size {
316                    break;
317                }
318                if self.response.len() > size {
319                    return Err(TpmDeviceError::InvalidResponse);
320                }
321            }
322        }
323
324        let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
325        trace!("{} R: {}", cc, hex::encode(&self.response));
326        Ok(result??)
327    }
328
329    fn prepare_command<C: TpmFrame>(
330        &mut self,
331        command: &C,
332        sessions: &[TpmsAuthCommand],
333    ) -> Result<(), TpmDeviceError> {
334        let cc = command.cc();
335        let tag = if sessions.is_empty() {
336            TpmSt::NoSessions
337        } else {
338            TpmSt::Sessions
339        };
340
341        self.command.resize(TPM_MAX_COMMAND_SIZE as usize, 0);
342
343        let len = {
344            let mut writer = TpmWriter::new(&mut self.command);
345            tpm_marshal_command(command, tag, sessions, &mut writer)
346                .map_err(TpmDeviceError::Marshal)?;
347            writer.len()
348        };
349        self.command.truncate(len);
350
351        trace!("{} C: {}", cc, hex::encode(&self.command));
352        Ok(())
353    }
354
355    /// Fetches a complete list of capabilities from the TPM, handling
356    /// pagination.
357    ///
358    /// # Errors
359    ///
360    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) returned by
361    /// [`get_capability_page`](TpmDevice::get_capability_page) or by the
362    /// `extract` closure.
363    fn get_capability<T, F, N>(
364        &mut self,
365        cap: TpmCap,
366        property_start: u32,
367        count: u32,
368        mut extract: F,
369        next_prop: N,
370    ) -> Result<Vec<T>, TpmDeviceError>
371    where
372        T: Copy,
373        F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
374        N: Fn(&T) -> u32,
375    {
376        let mut results = Vec::new();
377        let mut prop = property_start;
378        loop {
379            let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
380            let items: &[T] = extract(&cap_data.data)?;
381            results.extend_from_slice(items);
382
383            if more_data {
384                if let Some(last) = items.last() {
385                    prop = next_prop(last);
386                } else {
387                    break;
388                }
389            } else {
390                break;
391            }
392        }
393        Ok(results)
394    }
395
396    /// Retrieves all algorithm properties supported by the TPM.
397    ///
398    /// # Errors
399    ///
400    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
401    /// the handle count cannot be represented as `u32`. Propagates any
402    /// [`TpmDeviceError`](crate::TpmDeviceError) from
403    /// [`get_capability`](TpmDevice::get_capability), including
404    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
405    /// TPM does not report algorithm properties.
406    pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
407        self.get_capability(
408            TpmCap::Algs,
409            0,
410            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
411            |caps| match caps {
412                TpmuCapabilities::Algs(algs) => Ok(algs),
413                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
414            },
415            |last| last.alg as u32 + 1,
416        )
417    }
418
419    /// Retrieves all handles of a specific type from the TPM.
420    ///
421    /// # Errors
422    ///
423    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
424    /// the handle count cannot be represented as `u32`. Propagates any
425    /// [`TpmDeviceError`](crate::TpmDeviceError) from
426    /// [`get_capability`](TpmDevice::get_capability), including
427    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
428    /// TPM does not report handles of the requested class.
429    pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
430        self.get_capability(
431            TpmCap::Handles,
432            (class as u32) << 24,
433            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
434            |caps| match caps {
435                TpmuCapabilities::Handles(handles) => Ok(handles),
436                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
437            },
438            |last| *last + 1,
439        )
440        .map(|handles| handles.into_iter().map(TpmHandle).collect())
441    }
442
443    /// Retrieves all available ECC curves supported by the TPM.
444    ///
445    /// # Errors
446    ///
447    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
448    /// the handle count cannot be represented as `u32`. Propagates any
449    /// [`TpmDeviceError`](crate::TpmDeviceError) from
450    /// [`get_capability`](TpmDevice::get_capability), including
451    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
452    /// TPM does not report ECC curves.
453    pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
454        self.get_capability(
455            TpmCap::EccCurves,
456            0,
457            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
458            |caps| match caps {
459                TpmuCapabilities::EccCurves(curves) => Ok(curves),
460                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
461            },
462            |last| *last as u32 + 1,
463        )
464    }
465
466    /// Retrieves the list of active PCR banks and the bank size.
467    ///
468    /// # Errors
469    ///
470    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
471    /// the handle count cannot be represented as `u32`. Propagates any
472    /// [`TpmDeviceError`](crate::TpmDeviceError) from
473    /// [`get_capability`](TpmDevice::get_capability), including
474    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
475    /// TPM does not report PCRs.
476    /// Returns [`PcrBanksNotAvailable`](crate::TpmDeviceError::PcrBanksNotAvailable)
477    /// if the list of banks is empty.
478    /// Returns [`PcrBankSizeMismatch`](crate::TpmDeviceError::PcrBankSizeMismatch)
479    /// if bank sizes are inconsistent.
480    pub fn fetch_pcr_bank_list(&mut self) -> Result<(usize, Vec<TpmAlgId>), TpmDeviceError> {
481        let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
482            TpmCap::Pcrs,
483            0,
484            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
485            |caps| match caps {
486                TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
487                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
488            },
489            |last| last.hash as u32 + 1,
490        )?;
491
492        if pcrs.is_empty() {
493            return Err(TpmDeviceError::PcrBanksNotAvailable);
494        }
495
496        let mut count = 0;
497        let mut algs = Vec::with_capacity(pcrs.len());
498
499        for bank in pcrs {
500            let next_count = bank.pcr_select.len();
501            if count == 0 {
502                count = next_count;
503            }
504            if next_count != count {
505                return Err(TpmDeviceError::PcrBankSizeMismatch);
506            }
507            algs.push(bank.hash);
508        }
509
510        algs.sort();
511        Ok((count, algs))
512    }
513
514    /// Fetches and returns one page of capabilities of a certain type from the
515    /// TPM.
516    ///
517    /// # Errors
518    ///
519    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
520    /// [`transmit`](TpmDevice::transmit). Returns
521    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
522    /// TPM response does not contain `TPM2_GetCapability` data.
523    fn get_capability_page(
524        &mut self,
525        cap: TpmCap,
526        property: u32,
527        property_count: u32,
528    ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
529        let cmd = TpmGetCapabilityCommand {
530            cap,
531            property,
532            property_count,
533            handles: [],
534        };
535
536        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
537        let TpmGetCapabilityResponse {
538            more_data,
539            capability_data,
540            handles: [],
541        } = resp
542            .GetCapability()
543            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
544
545        Ok((more_data.into(), capability_data))
546    }
547
548    /// Reads a specific TPM property.
549    ///
550    /// # Errors
551    ///
552    /// Returns [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing)
553    /// when the TPM does not report the requested property. Propagates any
554    /// [`TpmDeviceError`](crate::TpmDeviceError) from
555    /// [`get_capability_page`](TpmDevice::get_capability_page).
556    pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, TpmDeviceError> {
557        let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
558
559        let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
560            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
561        };
562
563        let Some(prop) = props.iter().find(|prop| prop.property == property) else {
564            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
565        };
566
567        Ok(prop.value)
568    }
569
570    /// Reads the public area of a TPM object.
571    ///
572    /// # Errors
573    ///
574    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
575    /// [`transmit`](TpmDevice::transmit). Returns
576    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
577    /// TPM response does not contain `TPM2_ReadPublic` data.
578    pub fn read_public(
579        &mut self,
580        handle: TpmHandle,
581    ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
582        if let Some(cached) = self.name_cache.get(&handle.0) {
583            return Ok(cached.clone());
584        }
585
586        let cmd = TpmReadPublicCommand { handles: [handle] };
587        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
588
589        let read_public_resp = resp
590            .ReadPublic()
591            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
592
593        let public = read_public_resp.out_public.inner;
594        let name = read_public_resp.name;
595
596        self.name_cache.insert(handle.0, (public.clone(), name));
597        Ok((public, name))
598    }
599
600    /// Finds a persistent handle by its `Tpm2bName`.
601    ///
602    /// # Errors
603    ///
604    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
605    /// [`fetch_handles`](TpmDevice::fetch_handles) and
606    /// [`read_public`](TpmDevice::read_public), except for TPM reference and
607    /// handle errors with base
608    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
609    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
610    /// invalid handles and skipped.
611    pub fn find_persistent(
612        &mut self,
613        target_name: &Tpm2bName,
614    ) -> Result<Option<TpmHandle>, TpmDeviceError> {
615        for handle in self.fetch_handles(TpmHt::Persistent)? {
616            match self.read_public(handle) {
617                Ok((_, name)) => {
618                    if name == *target_name {
619                        return Ok(Some(handle));
620                    }
621                }
622                Err(TpmDeviceError::TpmRc(rc)) => {
623                    let base = rc.base();
624                    if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
625                        continue;
626                    }
627                    return Err(TpmDeviceError::TpmRc(rc));
628                }
629                Err(e) => return Err(e),
630            }
631        }
632        Ok(None)
633    }
634
635    /// Saves the context of a transient object or session.
636    ///
637    /// # Errors
638    ///
639    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
640    /// [`transmit`](TpmDevice::transmit). Returns
641    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
642    /// TPM response does not contain `TPM2_ContextSave` data.
643    pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
644        let cmd = TpmContextSaveCommand {
645            handles: [save_handle],
646        };
647        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
648        let save_resp = resp
649            .ContextSave()
650            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
651        Ok(save_resp.context)
652    }
653
654    /// Loads a TPM context and returns the handle.
655    ///
656    /// # Errors
657    ///
658    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
659    /// [`transmit`](TpmDevice::transmit). Returns
660    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
661    /// TPM response does not contain `TPM2_ContextLoad` data.
662    pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
663        let cmd = TpmContextLoadCommand {
664            context,
665            handles: [],
666        };
667        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
668        let resp_inner = resp
669            .ContextLoad()
670            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
671        Ok(resp_inner.handles[0])
672    }
673
674    /// Flushes a transient object or session from the TPM and removes it from
675    /// the cache.
676    ///
677    /// # Errors
678    ///
679    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
680    /// [`transmit`](TpmDevice::transmit).
681    pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
682        self.name_cache.remove(&handle.0);
683        let cmd = TpmFlushContextCommand {
684            flush_handle: handle,
685            handles: [],
686        };
687        self.transmit(&cmd, Self::NO_SESSIONS)?;
688        Ok(())
689    }
690
691    /// Loads a session context and then flushes the resulting handle.
692    ///
693    /// # Errors
694    ///
695    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
696    /// [`load_context`](TpmDevice::load_context) or
697    /// [`flush_context`](TpmDevice::flush_context) except for TPM reference
698    /// errors with base
699    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
700    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
701    /// a successful no-op.
702    pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
703        match self.load_context(context) {
704            Ok(handle) => self.flush_context(handle),
705            Err(TpmDeviceError::TpmRc(rc)) => {
706                let base = rc.base();
707                if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
708                    Ok(())
709                } else {
710                    Err(TpmDeviceError::TpmRc(rc))
711                }
712            }
713            Err(e) => Err(e),
714        }
715    }
716
717    /// Refreshes a key context. Returns `true` if the context is still valid,
718    /// and `false` if it is stale.
719    ///
720    /// # Errors
721    ///
722    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
723    /// [`load_context`](TpmDevice::load_context) or
724    /// [`flush_context`](TpmDevice::flush_context) except for TPM reference
725    /// errors with base
726    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0), which are
727    /// treated as a stale context and reported as `Ok(false)`.
728    pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
729        match self.load_context(context) {
730            Ok(handle) => match self.flush_context(handle) {
731                Ok(()) => Ok(true),
732                Err(e) => Err(e),
733            },
734            Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
735            Err(e) => Err(e),
736        }
737    }
738}