cli/
device.rs

1// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5use crate::{
6    cli::LogFormat,
7    crypto::CryptoError,
8    key::{Tpm2shAlgId, Tpm2shEccCurve},
9    print::TpmPrint,
10    transport::{receive_from_stream, FileTransport, Transport},
11    TEARDOWN,
12};
13
14use std::{
15    cell::RefCell,
16    collections::HashMap,
17    io::{IsTerminal, Write},
18    num::TryFromIntError,
19    rc::Rc,
20    sync::atomic::Ordering,
21    time::{Duration, Instant},
22};
23
24use indicatif::{ProgressBar, ProgressStyle};
25use log::trace;
26use polling::{Event, Events, Poller};
27use rand::{thread_rng, RngCore};
28use thiserror::Error;
29use tpm2_protocol::{
30    constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
31    data::{
32        Tpm2bEncryptedSecret, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCap, TpmCc, TpmHt, TpmPt, TpmRc,
33        TpmRcBase, TpmRh, TpmSe, TpmSt, TpmsAuthCommand, TpmsCapabilityData, TpmsContext,
34        TpmsRsaParms, TpmtPublic, TpmtPublicParms, TpmtSymDefObject, TpmuCapabilities,
35        TpmuPublicParms,
36    },
37    message::{
38        tpm_build_command, tpm_parse_response, TpmAuthResponses, TpmBodyBuild,
39        TpmContextLoadCommand, TpmContextSaveCommand, TpmFlushContextCommand,
40        TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmHeader, TpmReadPublicCommand,
41        TpmResponseBody, TpmStartAuthSessionCommand, TpmStartAuthSessionResponse,
42        TpmTestParmsCommand,
43    },
44    tpm_hash_size, TpmErrorKind, TpmHandle, TpmWriter,
45};
46
/// Entry count requested per `TPM2_GetCapability` page by
/// `Device::get_all_handles`.
pub const TPM_CAP_PROPERTY_MAX: u32 = 0x80;
48
49/// A type-erased object safe TPM command object
50pub trait TpmCommandObject: TpmPrint + TpmHeader + TpmBodyBuild {}
51impl<T> TpmCommandObject for T where T: TpmHeader + TpmBodyBuild + TpmPrint {}
52
/// How a single command authorization is supplied.
#[derive(Clone, Debug)]
pub enum Auth {
    /// A tracked, stateful session referred to by its TPM handle.
    Tracked(u32),
    /// A stateless password authorization carrying the raw password bytes.
    Password(Vec<u8>),
}

/// Convenience alias for an ordered list of [`Auth`] entries.
pub type AuthList = Vec<Auth>;
64
65#[derive(Debug, Error)]
66pub enum DeviceError {
67    #[error("I/O: {0}")]
68    Io(#[from] std::io::Error),
69    #[error("syscall: {0}")]
70    Nix(#[from] nix::Error),
71    #[error("response corrupted")]
72    ResponseCorrupted,
73    #[error("response mismatch: {0}")]
74    ResponseMismatch(TpmCc),
75    #[error("operation interrupted by user")]
76    Interrupted,
77    #[error("crypto: {0}")]
78    Crypto(#[from] CryptoError),
79    #[error("unknown handle name: {0:08x}")]
80    UnknownHandleName(u32),
81    #[error("TPM: {0}")]
82    Tpm(TpmErrorKind),
83    #[error("TPM RC: {0}")]
84    TpmRc(TpmRc),
85    #[error("TPM command timed out")]
86    Timeout,
87    #[error("device not available")]
88    NotAvailable,
89    #[error("device is already borrowed")]
90    AlreadyBorrowed,
91    #[error("capability not found: {0}")]
92    CapabilityMissing(TpmCap),
93}
94
95impl From<TpmErrorKind> for DeviceError {
96    fn from(err: TpmErrorKind) -> Self {
97        Self::Tpm(err)
98    }
99}
100
101impl From<TpmRc> for DeviceError {
102    fn from(rc: TpmRc) -> Self {
103        Self::TpmRc(rc)
104    }
105}
106
107impl From<TryFromIntError> for DeviceError {
108    fn from(_err: TryFromIntError) -> Self {
109        Self::Tpm(TpmErrorKind::InvalidValue)
110    }
111}
112
113/// Executes a closure with a mutable reference to a `Device`.
114///
115/// This helper function centralizes the boilerplate for safely acquiring a
116/// mutable borrow of a `Device` from the shared `Rc<RefCell<...>>`.
117///
118/// # Errors
119///
120/// Returns an error if the device is not available or is already borrowed. The
121/// error is converted into the caller's error type `E`.
122pub fn with_device<F, T, E>(device: Option<Rc<RefCell<Device>>>, f: F) -> Result<T, E>
123where
124    F: FnOnce(&mut Device) -> Result<T, E>,
125    E: From<DeviceError>,
126{
127    let device_rc = device.ok_or(DeviceError::NotAvailable)?;
128    let mut device_guard = device_rc
129        .try_borrow_mut()
130        .map_err(|_| DeviceError::AlreadyBorrowed)?;
131    f(&mut device_guard)
132}
133
134#[derive(Debug)]
135pub struct Device {
136    transport: Box<dyn Transport>,
137    poller: Poller,
138    log_format: LogFormat,
139    name_cache: HashMap<u32, Tpm2bName>,
140}
141
142/// Checks if the TPM supports a given set of RSA parameters.
143fn test_rsa_parms(device: &mut Device, key_bits: u16) -> Result<(), DeviceError> {
144    let cmd = TpmTestParmsCommand {
145        parameters: TpmtPublicParms {
146            object_type: TpmAlgId::Rsa,
147            parameters: TpmuPublicParms::Rsa(TpmsRsaParms {
148                key_bits,
149                ..Default::default()
150            }),
151        },
152    };
153    let sessions = vec![];
154    device.execute(&cmd, &sessions).map(|(_, _)| ())
155}
156
157impl Device {
158    /// Creates a new TPM device from an owned transport.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if the system poller cannot be created.
163    pub fn new(
164        transport: impl Transport + 'static,
165        log_format: LogFormat,
166    ) -> Result<Self, DeviceError> {
167        let poller = Poller::new()?;
168        Ok(Self {
169            transport: Box::new(transport),
170            poller,
171            log_format,
172            name_cache: HashMap::new(),
173        })
174    }
175
176    /// Adds a transient handle's name to the internal cache.
177    pub fn add_name_to_cache(&mut self, handle: u32, name: Tpm2bName) {
178        self.name_cache.insert(handle, name);
179    }
180
181    /// Retrieves the TPM Name for a handle, required for authorization computations.
182    pub(crate) fn get_handle_name(&mut self, handle: u32) -> Result<Tpm2bName, DeviceError> {
183        if let Some(name) = self.name_cache.get(&handle) {
184            return Ok(*name);
185        }
186
187        let mso = (handle >> 24) as u8;
188        if mso == TpmHt::Transient as u8 || mso == TpmHt::Persistent as u8 {
189            let (_, name) = self.read_public(handle.into())?;
190            Ok(name)
191        } else {
192            Tpm2bName::try_from(handle.to_be_bytes().as_slice()).map_err(Into::into)
193        }
194    }
195
196    fn receive_with_progress(&mut self) -> Result<Vec<u8>, DeviceError> {
197        if let Some(ft) = self.transport.as_any_mut().downcast_mut::<FileTransport>() {
198            let spinner = ProgressBar::new_spinner();
199            spinner.enable_steady_tick(Duration::from_millis(100));
200            spinner.set_style(
201                ProgressStyle::with_template("{spinner:.green} {msg}")
202                    .expect("Invalid progress spinner template")
203                    .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏ "),
204            );
205            spinner.set_message("Waiting for TPM...");
206
207            let mut events = Events::new();
208            let file = &mut ft.0;
209            unsafe { self.poller.add(&*file, Event::readable(0))? };
210
211            let start_time = Instant::now();
212            let result = loop {
213                if TEARDOWN.load(Ordering::Relaxed) {
214                    break Err(DeviceError::Interrupted);
215                }
216                if start_time.elapsed() > Duration::from_secs(60) {
217                    break Err(DeviceError::Timeout);
218                }
219
220                self.poller
221                    .wait(&mut events, Some(Duration::from_millis(100)))?;
222                if !events.is_empty() {
223                    break receive_from_stream(file);
224                }
225            };
226
227            spinner.finish_and_clear();
228            self.poller.delete(&*file)?;
229            result
230        } else {
231            self.transport.receive()
232        }
233    }
234
235    /// Sends a command to the TPM and waits for the response.
236    ///
237    /// # Errors
238    ///
239    /// This function will return an error if building the command fails, I/O
240    /// with the device fails, or the TPM itself returns an error.
241    pub fn execute<C: TpmCommandObject>(
242        &mut self,
243        command: &C,
244        sessions: &[TpmsAuthCommand],
245    ) -> Result<(TpmResponseBody, TpmAuthResponses), DeviceError> {
246        let command_vec = self.build_command_buffer(command, sessions)?;
247        let cc = command.cc();
248        self.transport.send(&command_vec)?;
249        let resp_buf = if std::io::stderr().is_terminal() {
250            self.receive_with_progress()?
251        } else {
252            self.transport.receive()?
253        };
254        let result = tpm_parse_response(cc, &resp_buf);
255        if self.log_format == LogFormat::Pretty {
256            let mut buf = Vec::new();
257            match &result {
258                Ok(Ok((response, _))) => {
259                    response.print(&mut buf, "Response", 1)?;
260                    for line in String::from_utf8_lossy(&buf).lines() {
261                        trace!(target: "cli::device", "{line}");
262                    }
263                }
264                Ok(Err(_)) | Err(_) => {
265                    trace!(
266                        target: "cli::device",
267                        "Response: {}",
268                        hex::encode(&resp_buf)
269                    );
270                }
271            }
272        } else {
273            trace!(
274                target: "cli::device",
275                "Response: {}",
276                hex::encode(&resp_buf)
277            );
278        }
279        Ok(result??)
280    }
281
282    fn build_command_buffer<C: TpmCommandObject>(
283        &self,
284        command: &C,
285        sessions: &[TpmsAuthCommand],
286    ) -> Result<Vec<u8>, DeviceError> {
287        let cc = command.cc();
288        let tag = if sessions.is_empty() {
289            TpmSt::NoSessions
290        } else {
291            TpmSt::Sessions
292        };
293        let mut buf = vec![0u8; TPM_MAX_COMMAND_SIZE];
294        let len = {
295            let mut writer = TpmWriter::new(&mut buf);
296            tpm_build_command(command, tag, sessions, &mut writer)?;
297            writer.len()
298        };
299        buf.truncate(len);
300
301        if self.log_format == LogFormat::Pretty {
302            let mut print_buf = Vec::new();
303            writeln!(&mut print_buf, "{cc}")?;
304            command.print(&mut print_buf, "Command", 1)?;
305            for line in String::from_utf8_lossy(&print_buf).lines() {
306                trace!(target: "cli::device", "{line}");
307            }
308        } else {
309            trace!(
310                target: "cli::device",
311                "Command: {}",
312                hex::encode(&buf)
313            );
314        }
315        Ok(buf)
316    }
317
318    /// Retrieves all supported algorithms from the TPM by probing its capabilities.
319    ///
320    /// # Errors
321    ///
322    /// Returns a `DeviceError` if querying the TPM fails.
323    pub fn get_all_algorithms(&mut self) -> Result<Vec<(TpmAlgId, String)>, DeviceError> {
324        let mut supported_algs = Vec::new();
325        let mut all_algs = Vec::new();
326        let mut prop = 0;
327        loop {
328            let (more_data, cap_data) =
329                self.get_capability(TpmCap::Algs, prop, u32::try_from(MAX_HANDLES)?)?;
330
331            if let TpmuCapabilities::Algs(p) = cap_data.data {
332                all_algs.extend(p.iter().map(|prop| prop.alg));
333            } else {
334                return Err(DeviceError::CapabilityMissing(TpmCap::Algs));
335            }
336
337            if more_data {
338                if let TpmuCapabilities::Algs(algs) = cap_data.data {
339                    prop = algs.last().map_or(prop, |p| p.alg as u32 + 1);
340                }
341            } else {
342                break;
343            }
344        }
345        let all_algs: std::collections::HashSet<TpmAlgId> = all_algs.into_iter().collect();
346
347        let name_algs: Vec<TpmAlgId> = [TpmAlgId::Sha256, TpmAlgId::Sha384, TpmAlgId::Sha512]
348            .into_iter()
349            .filter(|alg| all_algs.contains(alg))
350            .collect();
351
352        if all_algs.contains(&TpmAlgId::Rsa) {
353            let rsa_key_sizes = [2048, 3072, 4096];
354            for key_bits in rsa_key_sizes {
355                match test_rsa_parms(self, key_bits) {
356                    Ok(()) => {
357                        for &name_alg in &name_algs {
358                            supported_algs.push((
359                                TpmAlgId::Rsa,
360                                format!("rsa-{}:{}", key_bits, Tpm2shAlgId(name_alg)),
361                            ));
362                        }
363                    }
364                    Err(DeviceError::TpmRc(rc)) => {
365                        if rc.base() != TpmRcBase::Value {
366                            return Err(DeviceError::TpmRc(rc));
367                        }
368                    }
369                    Err(e) => return Err(e),
370                }
371            }
372        }
373
374        if all_algs.contains(&TpmAlgId::Ecc) {
375            let mut supported_curves = Vec::new();
376            let mut prop = 0;
377            loop {
378                let (more_data, cap_data) =
379                    self.get_capability(TpmCap::EccCurves, prop, u32::try_from(MAX_HANDLES)?)?;
380                if let TpmuCapabilities::EccCurves(curves) = &cap_data.data {
381                    supported_curves.extend(curves.iter().copied());
382                } else {
383                    return Err(DeviceError::CapabilityMissing(TpmCap::EccCurves));
384                }
385                if more_data {
386                    if let TpmuCapabilities::EccCurves(curves) = cap_data.data {
387                        prop = curves.last().map_or(prop, |&c| c as u32 + 1);
388                    }
389                } else {
390                    break;
391                }
392            }
393            for curve_id in supported_curves {
394                for &name_alg in &name_algs {
395                    supported_algs.push((
396                        TpmAlgId::Ecc,
397                        format!(
398                            "ecc-{}:{}",
399                            Tpm2shEccCurve::from(curve_id),
400                            Tpm2shAlgId(name_alg)
401                        ),
402                    ));
403                }
404            }
405        }
406
407        if all_algs.contains(&TpmAlgId::KeyedHash) {
408            for &name_alg in &name_algs {
409                supported_algs.push((
410                    TpmAlgId::KeyedHash,
411                    format!("keyedhash:{}", Tpm2shAlgId(name_alg)),
412                ));
413            }
414        }
415
416        Ok(supported_algs)
417    }
418
419    /// Retrieves all supported hash algorithms from the TPM.
420    ///
421    /// # Errors
422    ///
423    /// Returns a `DeviceError` if querying the TPM fails.
424    pub fn get_all_hashes(&mut self) -> Result<Vec<String>, DeviceError> {
425        let mut all_algs = Vec::new();
426        let mut prop = 0;
427
428        loop {
429            let (more_data, cap_data) =
430                self.get_capability(TpmCap::Algs, prop, u32::try_from(MAX_HANDLES)?)?;
431
432            if let TpmuCapabilities::Algs(p) = &cap_data.data {
433                all_algs.extend(p.iter().map(|prop| prop.alg));
434            } else {
435                return Err(DeviceError::CapabilityMissing(TpmCap::Algs));
436            }
437
438            if more_data {
439                if let TpmuCapabilities::Algs(algs) = cap_data.data {
440                    prop = algs.last().map_or(prop, |p| p.alg as u32 + 1);
441                }
442            } else {
443                break;
444            }
445        }
446
447        let hashes: Vec<String> = all_algs
448            .iter()
449            .filter(|p| tpm_hash_size(p).is_some())
450            .map(|p| Tpm2shAlgId(*p).to_string())
451            .collect();
452        Ok(hashes)
453    }
454
455    /// Retrieves all handles of a specific type from the TPM.
456    ///
457    /// # Errors
458    ///
459    /// Returns a `DeviceError` if the `get_capability` call to the TPM device fails.
460    pub fn get_all_handles(&mut self, handle_type: u32) -> Result<Vec<u32>, DeviceError> {
461        let mut all_handles = Vec::new();
462        let mut prop = handle_type;
463
464        loop {
465            let (more_data, cap_data) =
466                self.get_capability(TpmCap::Handles, prop, TPM_CAP_PROPERTY_MAX)?;
467
468            if let TpmuCapabilities::Handles(handles) = cap_data.data {
469                all_handles.extend(handles.iter().copied());
470            } else {
471                return Err(DeviceError::CapabilityMissing(TpmCap::Handles));
472            }
473
474            if more_data {
475                if let TpmuCapabilities::Handles(handles) = cap_data.data {
476                    prop = handles.last().map_or(prop, |&h| h + 1);
477                }
478            } else {
479                break;
480            }
481        }
482
483        Ok(all_handles)
484    }
485
486    /// Fetches and returns one page of capabilities of a certain type from the TPM.
487    ///
488    /// # Errors
489    ///
490    /// This function will return an error if the underlying `execute` call fails
491    /// or if the TPM returns a response of an unexpected type.
492    pub fn get_capability(
493        &mut self,
494        cap: TpmCap,
495        property: u32,
496        count: u32,
497    ) -> Result<(bool, TpmsCapabilityData), DeviceError> {
498        let cmd = TpmGetCapabilityCommand {
499            cap,
500            property,
501            property_count: count,
502        };
503        let sessions = vec![];
504
505        let (resp, _) = self.execute(&cmd, &sessions)?;
506        let TpmGetCapabilityResponse {
507            more_data,
508            capability_data,
509        } = resp
510            .GetCapability()
511            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::GetCapability))?;
512
513        Ok((more_data.into(), capability_data))
514    }
515
516    /// Reads a specific TPM property.
517    ///
518    /// # Errors
519    ///
520    /// Returns a `DeviceError` if the capability or property is not found, or
521    /// if the `get_capability` call fails.
522    pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, DeviceError> {
523        let (_, cap_data) = self.get_capability(TpmCap::TpmProperties, property as u32, 1)?;
524
525        let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
526            return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
527        };
528
529        let Some(prop) = props.first() else {
530            return Err(DeviceError::CapabilityMissing(TpmCap::TpmProperties));
531        };
532
533        Ok(prop.value)
534    }
535
536    /// Reads the public area of a TPM object.
537    ///
538    /// # Errors
539    ///
540    /// Returns a `DeviceError` if the underlying `TPM2_ReadPublic` command
541    /// execution fails or if the TPM returns a response of an unexpected type.
542    pub fn read_public(
543        &mut self,
544        handle: TpmHandle,
545    ) -> Result<(TpmtPublic, Tpm2bName), DeviceError> {
546        let cmd = TpmReadPublicCommand {
547            object_handle: handle,
548        };
549        let sessions = vec![];
550        let (resp, _) = self.execute(&cmd, &sessions)?;
551        let read_public_resp = resp
552            .ReadPublic()
553            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
554        let name = read_public_resp.name;
555        self.add_name_to_cache(handle.0, name);
556        Ok((read_public_resp.out_public.inner, name))
557    }
558
559    /// Saves the context of a transient object or session.
560    ///
561    /// # Errors
562    ///
563    /// Returns a `DeviceError` if the underlying `TPM2_ContextSave` command
564    /// execution fails or if the TPM returns a response of an unexpected type.
565    pub fn save_context(&mut self, handle: u32) -> Result<TpmsContext, DeviceError> {
566        let cmd = TpmContextSaveCommand {
567            save_handle: handle.into(),
568        };
569        let sessions = vec![];
570        let (resp, _) = self.execute(&cmd, &sessions)?;
571        let save_resp = resp
572            .ContextSave()
573            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextSave))?;
574        Ok(save_resp.context)
575    }
576
577    /// Loads a TPM context and returns the handle.
578    ///
579    /// # Errors
580    ///
581    /// Returns a `DeviceError` if the `TPM2_ContextLoad` command fails.
582    pub fn load_context(&mut self, context: TpmsContext) -> Result<u32, DeviceError> {
583        let cmd = TpmContextLoadCommand { context };
584        let sessions = vec![];
585        let (resp, _) = self.execute(&cmd, &sessions)?;
586        let resp_inner = resp
587            .ContextLoad()
588            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
589        Ok(resp_inner.loaded_handle.0)
590    }
591
592    /// Flushes a transient object or session from the TPM.
593    ///
594    /// # Errors
595    ///
596    /// Returns a `DeviceError` if the underlying `TPM2_FlushContext` command
597    /// execution fails.
598    pub fn flush_context(&mut self, handle: u32) -> Result<(), DeviceError> {
599        let cmd = TpmFlushContextCommand {
600            flush_handle: handle.into(),
601        };
602        let sessions = vec![];
603        self.execute(&cmd, &sessions)?;
604        Ok(())
605    }
606
607    /// Loads a session context and then flushes the resulting handle.
608    ///
609    /// # Errors
610    ///
611    /// Returns `DeviceError` on `ContextLoad` or `FlushContext` failure.
612    pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), DeviceError> {
613        match self.load_context(context) {
614            Ok(live_handle) => self.flush_context(live_handle),
615            Err(DeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(()),
616            Err(e) => Err(e),
617        }
618    }
619
620    /// Starts a new authorization session.
621    ///
622    /// This function sends a `TPM2_StartAuthSession` command to the TPM and
623    /// returns the raw response, which can be used to construct a higher-level
624    /// session object.
625    ///
626    /// # Errors
627    ///
628    /// Returns `DeviceError` on TPM command failure.
629    pub fn start_session(
630        &mut self,
631        session_type: TpmSe,
632        auth_hash: TpmAlgId,
633    ) -> Result<(TpmStartAuthSessionResponse, Tpm2bNonce), DeviceError> {
634        let digest_len =
635            tpm_hash_size(&auth_hash).ok_or(DeviceError::Tpm(TpmErrorKind::InvalidValue))?;
636        let mut nonce_bytes = vec![0; digest_len];
637        thread_rng().fill_bytes(&mut nonce_bytes);
638        let nonce_caller = Tpm2bNonce::try_from(nonce_bytes.as_slice())?;
639
640        let cmd = TpmStartAuthSessionCommand {
641            tpm_key: (TpmRh::Null as u32).into(),
642            bind: (TpmRh::Null as u32).into(),
643            nonce_caller,
644            encrypted_salt: Tpm2bEncryptedSecret::default(),
645            session_type,
646            symmetric: TpmtSymDefObject::default(),
647            auth_hash,
648        };
649        let sessions = vec![];
650
651        let (response_body, _) = self.execute(&cmd, &sessions)?;
652
653        let resp = response_body
654            .StartAuthSession()
655            .map_err(|_| DeviceError::ResponseMismatch(TpmCc::StartAuthSession))?;
656
657        Ok((resp, nonce_caller))
658    }
659}