cli/
device.rs

// SPDX-License-Identifier: GPL-3.0-or-later
// Copyright (c) 2025 Opinsys Oy
// Copyright (c) 2024-2025 Jarkko Sakkinen

5use crate::{
6    cli::LogFormat,
7    crypto::crypto_hash_size,
8    handle::{Handle, HandleClass},
9    print::TpmPrint,
10    spinner::Spinner,
11    TEARDOWN,
12};
13use log::trace;
14use polling::{Event, Events, Poller};
15use rand::{thread_rng, RngCore};
16use std::{
17    cell::RefCell,
18    collections::HashMap,
19    fs::File,
20    io::{Read, Write},
21    num::TryFromIntError,
22    rc::Rc,
23    sync::atomic::Ordering,
24    time::{Duration, Instant},
25};
26use thiserror::Error;
27use tpm2_protocol::{
28    constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
29    data::{
30        Tpm2bEncryptedSecret, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCap, TpmCc, TpmHt, TpmPt, TpmRc,
31        TpmRcBase, TpmRh, TpmSe, TpmSt, TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData,
32        TpmsContext, TpmsRsaParms, TpmtPublic, TpmtPublicParms, TpmtSymDefObject, TpmuCapabilities,
33        TpmuPublicParms,
34    },
35    message::{
36        tpm_build_command, tpm_parse_response, TpmAuthResponses, TpmBodyBuild,
37        TpmContextLoadCommand, TpmContextSaveCommand, TpmEvictControlCommand,
38        TpmFlushContextCommand, TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmHeader,
39        TpmReadPublicCommand, TpmResponseBody, TpmStartAuthSessionCommand,
40        TpmStartAuthSessionResponse, TpmTestParmsCommand,
41    },
42    TpmError, TpmHandle, TpmWriter,
43};
44
45/// A type-erased object safe TPM command object
46pub trait TpmCommandObject: TpmPrint + TpmHeader + TpmBodyBuild {}
47impl<T> TpmCommandObject for T where T: TpmHeader + TpmBodyBuild + TpmPrint {}
48
/// Errors produced while driving the TPM device.
#[derive(Debug, Error)]
pub enum DeviceError {
    /// `with_device` found the shared device already mutably borrowed.
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// A `TPM2_GetCapability` response did not carry the expected capability
    /// data (or the requested entry was absent).
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
    /// The teardown flag was raised while waiting for a TPM response.
    #[error("operation interrupted by user")]
    Interrupted,
    /// The response header announced a size outside the accepted bounds.
    #[error("invalid response")]
    InvalidResponse,
    /// `with_device` was given no device.
    #[error("device not available")]
    NotAvailable,
    /// The parsed response body was not of the kind expected for this command
    /// code.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),
    /// No response arrived before the wait deadline in `Device::execute`.
    #[error("TPM command timed out")]
    Timeout,
    /// A fallible integer conversion (`TryFrom`) failed.
    #[error("int decode: {0}")]
    IntDecode(#[from] TryFromIntError),
    /// An I/O operation on the device file or the poller failed.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// A system call made through the `nix` crate failed.
    #[error("syscall: {0}")]
    Nix(#[from] nix::Error),
    /// Building a command or parsing a response failed in `tpm2_protocol`.
    #[error("protocol: {0}")]
    TpmProtocol(TpmError),
    /// The TPM answered with an error return code.
    #[error("TPM return code: {0}")]
    TpmRc(TpmRc),
}
76
77impl From<TpmError> for DeviceError {
78    fn from(err: TpmError) -> Self {
79        Self::TpmProtocol(err)
80    }
81}
82
83impl From<TpmRc> for DeviceError {
84    fn from(rc: TpmRc) -> Self {
85        Self::TpmRc(rc)
86    }
87}
88
89/// Executes a closure with a mutable reference to a `Device`.
90///
91/// This helper function centralizes the boilerplate for safely acquiring a
92/// mutable borrow of a `Device` from the shared `Rc<RefCell<...>>`.
93///
94/// # Errors
95///
96/// Returns an error if the device is not available or is already borrowed. The
97/// error is converted into the caller's error type `E`.
98pub fn with_device<F, T, E>(device: Option<Rc<RefCell<Device>>>, f: F) -> Result<T, E>
99where
100    F: FnOnce(&mut Device) -> Result<T, E>,
101    E: From<DeviceError>,
102{
103    let device_rc = device.ok_or(DeviceError::NotAvailable)?;
104    let mut device_guard = device_rc
105        .try_borrow_mut()
106        .map_err(|_| DeviceError::AlreadyBorrowed)?;
107    f(&mut device_guard)
108}
109
/// Connection to a TPM through its device file, plus the state needed to
/// send commands and wait for responses.
#[derive(Debug)]
pub struct Device {
    /// Device file that commands are written to and responses are read from.
    file: File,
    /// Poller used to wait, with a timeout, for `file` to become readable.
    poller: Poller,
    /// Selects decoded (`Pretty`) or hex trace logging of traffic.
    log_format: LogFormat,
    /// Cache of `TPM2_ReadPublic` results, keyed by the raw handle value.
    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
}
117
118/// Checks if the TPM supports a given set of RSA parameters.
119pub(crate) fn test_rsa_parms(device: &mut Device, key_bits: u16) -> Result<(), DeviceError> {
120    let cmd = TpmTestParmsCommand {
121        parameters: TpmtPublicParms {
122            object_type: TpmAlgId::Rsa,
123            parameters: TpmuPublicParms::Rsa(TpmsRsaParms {
124                key_bits,
125                ..Default::default()
126            }),
127        },
128    };
129    let sessions = vec![];
130    device.execute(&cmd, &sessions).map(|(_, _)| ())
131}
132
133impl Device {
134    /// Creates a new TPM device from an owned transport.
135    ///
136    /// # Errors
137    ///
138    /// Returns an error if the system poller cannot be created.
139    pub fn new(file: File, log_format: LogFormat) -> Result<Self, DeviceError> {
140        let poller = Poller::new()?;
141        Ok(Self {
142            file,
143            poller,
144            log_format,
145            name_cache: HashMap::new(),
146        })
147    }
148
149    fn receive_from_stream(&mut self) -> Result<Vec<u8>, DeviceError> {
150        let mut header = [0u8; 10];
151        self.file.read_exact(&mut header)?;
152        let Ok(size_bytes): Result<[u8; 4], _> = header[2..6].try_into() else {
153            return Err(DeviceError::InvalidResponse);
154        };
155        let size = u32::from_be_bytes(size_bytes) as usize;
156        if size < header.len() || size > TPM_MAX_COMMAND_SIZE {
157            return Err(DeviceError::InvalidResponse);
158        }
159        let mut resp_buf = header.to_vec();
160        resp_buf.resize(size, 0);
161        self.file.read_exact(&mut resp_buf[header.len()..])?;
162        Ok(resp_buf)
163    }
164
165    /// Performs the whole TPM command transmission process.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`Interrupted`](crate::device::DeviceError::Interrupted) when
170    /// user interrupts the program.
171    /// Returns [`Io`](crate::device::DeviceError::Io) when an I/O operation
172    /// fails.
173    /// Returns [`Timeout`](crate::device::DeviceError::Timeout) when the
174    /// transmission timeouts.
175    /// Returns [`TpmProtocol`](crate::device::DeviceError::TpmProtocol) when
176    /// either built command or parsed response is malformed.
177    /// Returns [`TpmRc`](crate::device::DeviceError::TpmRc) when the chip
178    /// responses with a return code.
179    pub fn execute<C: TpmCommandObject>(
180        &mut self,
181        command: &C,
182        sessions: &[TpmsAuthCommand],
183    ) -> Result<(TpmResponseBody, TpmAuthResponses), DeviceError> {
184        let command_vec = self.build_command_buffer(command, sessions)?;
185        let cc = command.cc();
186
187        let mut spinner = Spinner::new("Waiting for TPM...");
188
189        self.file.write_all(&command_vec)?;
190        self.file.flush()?;
191
192        let mut events = Events::new();
193        unsafe { self.poller.add(&self.file, Event::readable(0))? };
194
195        let start_time = Instant::now();
196        let resp_buf = loop {
197            if TEARDOWN.load(Ordering::Relaxed) {
198                spinner.finish();
199                let _ = self.poller.delete(&self.file);
200                break Err(DeviceError::Interrupted);
201            }
202            if start_time.elapsed() > Duration::from_secs(60) {
203                spinner.finish();
204                let _ = self.poller.delete(&self.file);
205                break Err(DeviceError::Timeout);
206            }
207
208            spinner.tick();
209
210            self.poller
211                .wait(&mut events, Some(Duration::from_millis(100)))?;
212
213            if !events.is_empty() {
214                let _ = self.poller.delete(&self.file);
215                break self.receive_from_stream();
216            }
217        }?;
218
219        let result = tpm_parse_response(cc, &resp_buf);
220        if self.log_format == LogFormat::Pretty {
221            let mut buf = Vec::new();
222            match &result {
223                Ok(Ok((response, _))) => {
224                    response.print(&mut buf, "Response", 1)?;
225                    for line in String::from_utf8_lossy(&buf).lines() {
226                        trace!(target: "cli::device", "{line}");
227                    }
228                }
229                Ok(Err(_)) | Err(_) => {
230                    trace!(
231                        target: "cli::device",
232                        "Response: {}",
233                        hex::encode(&resp_buf)
234                    );
235                }
236            }
237        } else {
238            trace!(
239                target: "cli::device",
240                "Response: {}",
241                hex::encode(&resp_buf)
242            );
243        }
244        Ok(result??)
245    }
246
247    fn build_command_buffer<C: TpmCommandObject>(
248        &self,
249        command: &C,
250        sessions: &[TpmsAuthCommand],
251    ) -> Result<Vec<u8>, DeviceError> {
252        let cc = command.cc();
253        let tag = if sessions.is_empty() {
254            TpmSt::NoSessions
255        } else {
256            TpmSt::Sessions
257        };
258        let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
259        let len = {
260            let mut writer = TpmWriter::new(&mut buf);
261            tpm_build_command(command, tag, sessions, &mut writer)?;
262            writer.len()
263        };
264        buf.truncate(len);
265
266        if self.log_format == LogFormat::Pretty {
267            let mut print_buf = Vec::new();
268            writeln!(&mut print_buf, "{cc}")?;
269            command.print(&mut print_buf, "Command", 1)?;
270            for line in String::from_utf8_lossy(&print_buf).lines() {
271                trace!(target: "cli::device", "{line}");
272            }
273        } else {
274            trace!(
275                target: "cli::device",
276                "Command: {}",
277                hex::encode(&buf)
278            );
279        }
280        Ok(buf)
281    }
282
283    /// Fetches a complete list of capabilities from the TPM, handling pagination.
284    ///
285    /// # Errors
286    ///
287    /// This function will return an error if the underlying `execute` call fails
288    /// or if the TPM returns a response of an unexpected type.
289    pub fn get_capability<T, F, N>(
290        &mut self,
291        cap: TpmCap,
292        property_start: u32,
293        count: u32,
294        mut extract: F,
295        next_prop: N,
296    ) -> Result<Vec<T>, DeviceError>
297    where
298        T: Copy,
299        F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], DeviceError>,
300        N: Fn(&T) -> u32,
301    {
302        let mut results = Vec::new();
303        let mut prop = property_start;
304        loop {
305            let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
306            let items: &[T] = extract(&cap_data.data)?;
307            results.extend_from_slice(items);
308
309            if more_data {
310                if let Some(last) = items.last() {
311                    prop = next_prop(last);
312                } else {
313                    break;
314                }
315            } else {
316                break;
317            }
318        }
319        Ok(results)
320    }
321
322    /// Retrieves all algorithm properties supported by the TPM.
323    pub(crate) fn fetch_algorithm_properties(
324        &mut self,
325    ) -> Result<Vec<TpmsAlgProperty>, DeviceError> {
326        self.get_capability(
327            TpmCap::Algs,
328            0,
329            u32::try_from(MAX_HANDLES)?,
330            |caps| match caps {
331                TpmuCapabilities::Algs(algs) => Ok(algs),
332                _ => Err(DeviceError::CapabilityMissing(TpmCap::Algs)),
333            },
334            |last| last.alg as u32 + 1,
335        )
336    }
337
338    /// Retrieves all handles of a specific type from the TPM.
339    ///
340    /// # Errors
341    ///
342    /// Returns a `DeviceError` if the `get_capability_page` call to the TPM device fails.
343    pub fn fetch_handles(&mut self, class: u32) -> Result<Vec<Handle>, DeviceError> {
344        self.get_capability(
345            TpmCap::Handles,
346            class,
347            u32::try_from(MAX_HANDLES)?,
348            |caps| match caps {
349                TpmuCapabilities::Handles(handles) => Ok(handles),
350                _ => Err(DeviceError::CapabilityMissing(TpmCap::Handles)),
351            },
352            |last| *last + 1,
353        )
354        .map(|handles| {
355            handles
356                .into_iter()
357                .map(|h| Handle((HandleClass::Tpm, h)))
358                .collect()
359        })
360    }
361
362    /// Fetches and returns one page of capabilities of a certain type from the TPM.
363    ///
364    /// # Errors
365    ///
366    /// This function will return an error if the underlying `execute` call fails
367    /// or if the TPM returns a response of an unexpected type.
368    pub fn get_capability_page(
369        &mut self,
370        cap: TpmCap,
371        property: u32,
372        count: u32,
373    ) -> Result<(bool, TpmsCapabilityData), DeviceError> {
374        let cmd = TpmGetCapabilityCommand {
375            cap,
376            property,
377            property_count: count,
378        };
379        let sessions = vec![];
380
381        let (resp, _) = self.execute(&cmd, &sessions)?;
382        let TpmGetCapabilityResponse {
383            more_data,
384            capability_data,
385        } = resp
386            .GetCapability()
387            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::GetCapability))?;
388
389        Ok((more_data.into(), capability_data))
390    }
391
392    /// Reads a specific TPM property.
393    ///
394    /// # Errors
395    ///
396    /// Returns a `DeviceError` if the capability or property is not found, or
397    /// if the `get_capability` call fails.
398    pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, DeviceError> {
399        let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
400
401        let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
402            return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
403        };
404
405        let Some(prop) = props.first() else {
406            return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
407        };
408
409        Ok(prop.value)
410    }
411
412    /// Reads the public area of a TPM object.
413    ///
414    /// # Errors
415    ///
416    /// Returns a `DeviceError` if the underlying `TPM2_ReadPublic` command
417    /// execution fails or if the TPM returns a response of an unexpected type.
418    pub fn read_public(
419        &mut self,
420        handle: TpmHandle,
421    ) -> Result<(TpmtPublic, Tpm2bName), DeviceError> {
422        if let Some(cached) = self.name_cache.get(&handle.0) {
423            return Ok(cached.clone());
424        }
425
426        let cmd = TpmReadPublicCommand {
427            object_handle: handle,
428        };
429        let sessions = vec![];
430        let (resp, _) = self.execute(&cmd, &sessions)?;
431
432        let read_public_resp = resp
433            .ReadPublic()
434            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
435
436        let public = read_public_resp.out_public.inner;
437        let name = read_public_resp.name;
438
439        self.name_cache.insert(handle.0, (public.clone(), name));
440        Ok((public, name))
441    }
442
443    /// Finds a persistent handle by its public area.
444    ///
445    /// # Errors
446    ///
447    /// Returns a `DeviceError` if fetching handles or reading public areas fails.
448    pub fn find_persistent(
449        &mut self,
450        target: &TpmtPublic,
451    ) -> Result<Option<(TpmHandle, Tpm2bName)>, DeviceError> {
452        let handles = self.fetch_handles((TpmHt::Persistent as u32) << 24)?;
453        for handle in handles {
454            if let Ok((public, name)) = self.read_public(handle.value().into()) {
455                if public == *target {
456                    return Ok(Some((handle.value().into(), name)));
457                }
458            }
459        }
460        Ok(None)
461    }
462
463    /// Saves the context of a transient object or session.
464    ///
465    /// # Errors
466    ///
467    /// Returns a `DeviceError` if the underlying `TPM2_ContextSave` command
468    /// execution fails or if the TPM returns a response of an unexpected type.
469    pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, DeviceError> {
470        let cmd = TpmContextSaveCommand { save_handle };
471        let sessions = vec![];
472        let (resp, _) = self.execute(&cmd, &sessions)?;
473        let save_resp = resp
474            .ContextSave()
475            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextSave))?;
476        Ok(save_resp.context)
477    }
478
479    /// Loads a TPM context and returns the handle.
480    ///
481    /// # Errors
482    ///
483    /// Returns a `DeviceError` if the `TPM2_ContextLoad` command fails.
484    pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, DeviceError> {
485        let cmd = TpmContextLoadCommand { context };
486        let sessions = vec![];
487        let (resp, _) = self.execute(&cmd, &sessions)?;
488        let resp_inner = resp
489            .ContextLoad()
490            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
491        Ok(resp_inner.loaded_handle)
492    }
493
494    /// Flushes a transient object or session from the TPM and removes it from the cache.
495    ///
496    /// # Errors
497    ///
498    /// Returns a `DeviceError` if the underlying `TPM2_FlushContext` command
499    /// execution fails.
500    pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), DeviceError> {
501        self.name_cache.remove(&handle.0);
502        let cmd = TpmFlushContextCommand {
503            flush_handle: handle,
504        };
505        let sessions = vec![];
506        self.execute(&cmd, &sessions)?;
507        Ok(())
508    }
509
510    /// Loads a session context and then flushes the resulting handle.
511    ///
512    /// # Errors
513    ///
514    /// Returns `DeviceError` on `ContextLoad` or `FlushContext` failure.
515    pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), DeviceError> {
516        match self.load_context(context) {
517            Ok(handle) => self.flush_context(handle),
518            Err(DeviceError::TpmRc(rc)) => {
519                let base = rc.base();
520                if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
521                    Ok(())
522                } else {
523                    Err(DeviceError::TpmRc(rc))
524                }
525            }
526            Err(e) => Err(e),
527        }
528    }
529
530    /// Starts a new authorization session.
531    ///
532    /// This function sends a `TPM2_StartAuthSession` command to the TPM and
533    /// returns the raw response, which can be used to construct a higher-level
534    /// session object.
535    ///
536    /// # Errors
537    ///
538    /// Returns `DeviceError` on TPM command failure.
539    pub fn start_session(
540        &mut self,
541        session_type: TpmSe,
542        auth_hash: TpmAlgId,
543        bind: TpmHandle,
544    ) -> Result<(TpmStartAuthSessionResponse, Tpm2bNonce), DeviceError> {
545        let digest_len =
546            crypto_hash_size(auth_hash).ok_or(DeviceError::TpmProtocol(TpmError::MalformedData))?;
547        let mut nonce_bytes = vec![0; digest_len];
548        thread_rng().fill_bytes(&mut nonce_bytes);
549        let nonce_caller = Tpm2bNonce::try_from(nonce_bytes.as_slice())?;
550
551        let cmd = TpmStartAuthSessionCommand {
552            tpm_key: (TpmRh::Null as u32).into(),
553            bind,
554            nonce_caller,
555            encrypted_salt: Tpm2bEncryptedSecret::default(),
556            session_type,
557            symmetric: TpmtSymDefObject::default(),
558            auth_hash,
559        };
560        let sessions = vec![];
561
562        let (response_body, _) = self.execute(&cmd, &sessions)?;
563
564        let resp = response_body
565            .StartAuthSession()
566            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::StartAuthSession))?;
567
568        Ok((resp, nonce_caller))
569    }
570
571    /// Evicts a persistent object or makes a transient object persistent.
572    ///
573    /// # Errors
574    ///
575    /// Returns `DeviceError` on TPM command failure.
576    pub fn evict_control(
577        &mut self,
578        auth: TpmHandle,
579        object_handle: TpmHandle,
580        persistent_handle: TpmHandle,
581        sessions: &[TpmsAuthCommand],
582    ) -> Result<(), DeviceError> {
583        let cmd = TpmEvictControlCommand {
584            auth,
585            object_handle: object_handle.0.into(),
586            persistent_handle,
587        };
588        let (resp, _) = self.execute(&cmd, sessions)?;
589
590        resp.EvictControl()
591            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::EvictControl))?;
592        Ok(())
593    }
594}