cli/command/algorithm.rs

// SPDX-License-Identifier: GPL-3.0-or-later
// Copyright (c) 2025 Opinsys Oy
// Copyright (c) 2024-2025 Jarkko Sakkinen
4use crate::{
5    cli::SubCommand,
6    command::{print_table, CommandError, Tabled},
7    crypto::crypto_hash_size,
8    device::{test_rsa_parms, with_device, Device, DeviceError},
9    job::Job,
10    key::{Tpm2shAlgId, Tpm2shEccCurve},
11};
12use clap::Args;
13use strum::{Display, EnumString};
14use tpm2_protocol::{
15    constant::MAX_HANDLES,
16    data::{TpmAlgId, TpmCap, TpmRcBase, TpmuCapabilities},
17};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
20#[strum(serialize_all = "kebab-case")]
21pub enum AlgorithmType {
22    Key,
23    Name,
24}
25
/// One printable line of the algorithm table.
struct AlgorithmRow {
    /// Text for the ALGORITHM column, e.g. `rsa-2048:sha256` or `sha256`.
    algorithm: String,
    /// Text for the TYPE column (the category's display form).
    algorithm_type: String,
}
30
31impl Tabled for AlgorithmRow {
32    fn headers() -> Vec<String> {
33        vec!["ALGORITHM".to_string(), "TYPE".to_string()]
34    }
35
36    fn row(&self) -> Vec<String> {
37        vec![self.algorithm.clone(), self.algorithm_type.clone()]
38    }
39}
40
41/// Lists available algorithms supported by the chip.
42#[derive(Args, Debug)]
43pub struct Algorithm {
44    /// Algorithm type: 'key' or 'name'
45    #[arg(short = 't', long = "type")]
46    pub algorithm_type: Option<AlgorithmType>,
47}
48
49impl Algorithm {
50    fn fetch_hash_algorithms(device: &mut Device) -> Result<Vec<String>, CommandError> {
51        let all_algs = device.fetch_algorithm_properties()?;
52        let hashes: Vec<String> = all_algs
53            .iter()
54            .map(|prop| prop.alg)
55            .filter(|p| crypto_hash_size(*p).is_some())
56            .map(|p| Tpm2shAlgId(p).to_string())
57            .collect();
58        Ok(hashes)
59    }
60
61    fn fetch_algorithms(device: &mut Device) -> Result<Vec<(String, AlgorithmType)>, CommandError> {
62        let mut results: Vec<(String, AlgorithmType)> = Vec::new();
63        let all_alg_props = device.fetch_algorithm_properties()?;
64        let all_algs: std::collections::HashSet<TpmAlgId> =
65            all_alg_props.into_iter().map(|p| p.alg).collect();
66
67        let name_algs: Vec<TpmAlgId> = [TpmAlgId::Sha256, TpmAlgId::Sha384, TpmAlgId::Sha512]
68            .into_iter()
69            .filter(|alg| all_algs.contains(alg))
70            .collect();
71
72        if all_algs.contains(&TpmAlgId::Rsa) {
73            let rsa_key_sizes = [2048, 3072, 4096];
74            for key_bits in rsa_key_sizes {
75                match test_rsa_parms(device, key_bits) {
76                    Ok(()) => {
77                        for &name_alg in &name_algs {
78                            results.push((
79                                format!("rsa-{}:{}", key_bits, Tpm2shAlgId(name_alg)),
80                                AlgorithmType::Key,
81                            ));
82                        }
83                    }
84                    Err(DeviceError::TpmRc(rc)) => {
85                        if rc.base() != TpmRcBase::Value {
86                            return Err(DeviceError::TpmRc(rc).into());
87                        }
88                    }
89                    Err(e) => return Err(e.into()),
90                }
91            }
92        }
93
94        if all_algs.contains(&TpmAlgId::Ecc) {
95            let supported_curves = device.get_capability(
96                TpmCap::EccCurves,
97                0,
98                u32::try_from(MAX_HANDLES)?,
99                |caps| match caps {
100                    TpmuCapabilities::EccCurves(curves) => Ok(curves),
101                    _ => Err(DeviceError::CapabilityMissing(TpmCap::EccCurves)),
102                },
103                |last| *last as u32 + 1,
104            )?;
105            for curve_id in supported_curves {
106                for &name_alg in &name_algs {
107                    results.push((
108                        format!(
109                            "ecc-{}:{}",
110                            Tpm2shEccCurve::from(curve_id),
111                            Tpm2shAlgId(name_alg)
112                        ),
113                        AlgorithmType::Key,
114                    ));
115                }
116            }
117        }
118
119        if all_algs.contains(&TpmAlgId::KeyedHash) {
120            for &name_alg in &name_algs {
121                results.push((
122                    format!("keyedhash:{}", Tpm2shAlgId(name_alg)),
123                    AlgorithmType::Key,
124                ));
125            }
126        }
127        Ok(results)
128    }
129}
130
131impl SubCommand for Algorithm {
132    fn run(&self, job: &mut Job) -> Result<(), CommandError> {
133        with_device(job.device.clone(), |device| {
134            let mut results: Vec<(String, AlgorithmType)> = Vec::new();
135
136            let fetch_keys =
137                self.algorithm_type.is_none() || self.algorithm_type == Some(AlgorithmType::Key);
138            let fetch_names =
139                self.algorithm_type.is_none() || self.algorithm_type == Some(AlgorithmType::Name);
140
141            if fetch_keys {
142                results.extend(Algorithm::fetch_algorithms(device)?);
143            }
144
145            if fetch_names {
146                let hashes = Self::fetch_hash_algorithms(device)?
147                    .into_iter()
148                    .map(|name| (name, AlgorithmType::Name));
149                results.extend(hashes);
150            }
151
152            results.sort_by(|a, b| a.0.cmp(&b.0));
153
154            let rows: Vec<AlgorithmRow> = results
155                .into_iter()
156                .map(|(algorithm, algorithm_type)| AlgorithmRow {
157                    algorithm,
158                    algorithm_type: algorithm_type.to_string(),
159                })
160                .collect();
161            print_table(&mut job.writer, &rows)?;
162            Ok(())
163        })
164    }
165}