cli/key/
mod.rs

// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5#![allow(clippy::no_effect_underscore_binding)]
6
7mod external_key;
8mod tpm_key;
9
10pub use external_key::*;
11pub use tpm_key::*;
12
13use crate::{crypto::CryptoError, device::DeviceError, session::SessionError};
14use rasn::{
15    types::{Integer, ObjectIdentifier},
16    AsnType, Decode, Decoder, Encode,
17};
18use strum::{Display, EnumString};
19use thiserror::Error;
20use tpm2_protocol::{
21    data::{TpmAlgId, TpmEccCurve, TpmtPublic, TpmuPublicParms},
22    TpmErrorKind,
23};
24
25#[derive(Debug, Error)]
26pub enum KeyError {
27    #[error("crypto: {0}")]
28    Crypto(#[from] CryptoError),
29    #[error("device: {0}")]
30    Device(#[from] DeviceError),
31    #[error("unsupported name algorithm: {0}")]
32    InvalidAlgorithm(String),
33    #[error("invalid algorithm format: '{0}'")]
34    InvalidAlgorithmFormat(String),
35    #[error("invalid ECC curve: {0}")]
36    InvalidEccCurve(String),
37    #[error("invalid ECC point: {0}")]
38    InvalidEccPoint(String),
39    #[error("invalid key format")]
40    InvalidFormat,
41    #[error("invalid key bits: {0}")]
42    InvalidKeyBits(String),
43    #[error("invalid modulus: {0}")]
44    InvalidModulus(String),
45    #[error("invalid OID")]
46    InvalidOid,
47    #[error("pem: {0}")]
48    Pem(#[from] pem::PemError),
49    #[error("rasn decode: {0}")]
50    RasnDecode(#[from] rasn::error::DecodeError),
51    #[error("rasn encode: {0}")]
52    RasnEncode(#[from] rasn::error::EncodeError),
53    #[error("unsupported file format")]
54    UnsupportedFileFormat,
55    #[error("unsupported OID: {0}")]
56    UnsupportedOid(String),
57    #[error("invalid PEM tag: {0}")]
58    UnsupportedPemTag(String),
59    #[error("value conversion failed: {0}")]
60    ValueConversionFailed(String),
61    #[error("session: {0}")]
62    Session(#[from] SessionError),
63    #[error("unsupported key algorithm: {0}")]
64    UnsupportedKeyAlgorithm(Tpm2shAlgId),
65    #[error("unsupported name algorithm: {0}")]
66    UnsupportedNameAlgorithm(Tpm2shAlgId),
67}
68
69impl From<TpmErrorKind> for KeyError {
70    fn from(err: TpmErrorKind) -> Self {
71        Self::Device(DeviceError::Tpm(err))
72    }
73}
74
75pub enum AnyKey {
76    Tpm(Box<TpmKey>),
77    External(Box<ExternalKey>),
78}
79
80/// Helper types for peeking at the DER structure to determine the key type.
81#[derive(AsnType, Decode, Encode)]
82#[rasn(choice)]
83enum FirstElement {
84    Oid(ObjectIdentifier),
85    Int(Integer),
86}
87
88#[derive(AsnType, Decode, Encode)]
89struct KeyPeek {
90    first: FirstElement,
91}
92
93impl TryFrom<&[u8]> for AnyKey {
94    type Error = KeyError;
95    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
96        if let Ok(pems) = pem::parse_many(bytes) {
97            if let Some(pem) = pems.into_iter().find(|p| {
98                matches!(
99                    p.tag(),
100                    "TSS2 PRIVATE KEY" | "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY"
101                )
102            }) {
103                let contents = pem.contents();
104                let tag = pem.tag();
105                return match tag {
106                    "TSS2 PRIVATE KEY" => {
107                        TpmKey::from_der(contents).map(|k| AnyKey::Tpm(Box::new(k)))
108                    }
109                    "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY" => {
110                        ExternalKey::from_der(contents).map(|k| AnyKey::External(Box::new(k)))
111                    }
112                    _ => Err(KeyError::UnsupportedPemTag(tag.to_string())),
113                };
114            }
115        }
116
117        match rasn::der::decode::<KeyPeek>(bytes)?.first {
118            FirstElement::Oid(..) => TpmKey::from_der(bytes).map(|k| AnyKey::Tpm(Box::new(k))),
119            FirstElement::Int(..) => {
120                ExternalKey::from_der(bytes).map(|k| AnyKey::External(Box::new(k)))
121            }
122        }
123    }
124}
125
126#[derive(Debug, Clone, Copy, PartialEq, Eq)]
127pub enum AlgInfo {
128    Rsa { key_bits: u16 },
129    Ecc { curve_id: TpmEccCurve },
130    KeyedHash,
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Alg {
135    pub name: String,
136    pub object_type: TpmAlgId,
137    pub name_alg: TpmAlgId,
138    pub params: AlgInfo,
139}
140
141/// Converts string to `TpmAlgId`.
142///
143/// # Errors
144///
145/// `KeyError::InvalidNameAlgorithm` returned when string is not match any
146/// algorithm.
147pub fn from_str_to_alg_id(s: &str) -> Result<TpmAlgId, KeyError> {
148    match s {
149        "rsa" => Ok(TpmAlgId::Rsa),
150        "sha1" => Ok(TpmAlgId::Sha1),
151        "hmac" => Ok(TpmAlgId::Hmac),
152        "aes" => Ok(TpmAlgId::Aes),
153        "keyedhash" => Ok(TpmAlgId::KeyedHash),
154        "xor" => Ok(TpmAlgId::Xor),
155        "sha256" => Ok(TpmAlgId::Sha256),
156        "sha384" => Ok(TpmAlgId::Sha384),
157        "sha512" => Ok(TpmAlgId::Sha512),
158        "null" => Ok(TpmAlgId::Null),
159        "sm3_256" => Ok(TpmAlgId::Sm3_256),
160        "sm4" => Ok(TpmAlgId::Sm4),
161        "ecc" => Ok(TpmAlgId::Ecc),
162        _ => Err(KeyError::InvalidAlgorithm(s.to_string())),
163    }
164}
165
166impl Alg {
167    /// Creates a new `Alg` struct for a `KeyedHash` object.
168    ///
169    /// # Errors
170    ///
171    /// Returns an `KeyError` if the provided hash algorithm string is invalid.
172    pub fn new_keyedhash(hash_alg: &str) -> Result<Self, KeyError> {
173        let name_alg = from_str_to_alg_id(hash_alg)?;
174        Ok(Self {
175            name: format!("keyedhash:{hash_alg}"),
176            object_type: TpmAlgId::KeyedHash,
177            name_alg,
178            params: AlgInfo::KeyedHash,
179        })
180    }
181}
182
183impl std::fmt::Display for Alg {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        write!(f, "{}", self.name)
186    }
187}
188
189impl std::str::FromStr for Alg {
190    type Err = KeyError;
191
192    fn from_str(s: &str) -> Result<Self, Self::Err> {
193        if let Some(rest) = s.strip_prefix("rsa-") {
194            let (bits_str, name_alg_str) = rest
195                .split_once(':')
196                .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
197            let key_bits: u16 = bits_str
198                .parse()
199                .map_err(|_| KeyError::InvalidKeyBits(bits_str.to_string()))?;
200            let name_alg = from_str_to_alg_id(name_alg_str)?;
201            Ok(Self {
202                name: s.to_string(),
203                object_type: TpmAlgId::Rsa,
204                name_alg,
205                params: AlgInfo::Rsa { key_bits },
206            })
207        } else if let Some(rest) = s.strip_prefix("ecc-") {
208            let (curve_str, name_alg_str) = rest
209                .split_once(':')
210                .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
211            let curve_id: TpmEccCurve = Tpm2shEccCurve::from_str(curve_str)
212                .map_err(|e| KeyError::InvalidEccCurve(e.to_string()))?
213                .into();
214            let name_alg = from_str_to_alg_id(name_alg_str)?;
215            Ok(Self {
216                name: s.to_string(),
217                object_type: TpmAlgId::Ecc,
218                name_alg,
219                params: AlgInfo::Ecc { curve_id },
220            })
221        } else if let Some(name_alg_str) = s.strip_prefix("keyedhash:") {
222            let name_alg = from_str_to_alg_id(name_alg_str)?;
223            Ok(Self {
224                name: s.to_string(),
225                object_type: TpmAlgId::KeyedHash,
226                name_alg,
227                params: AlgInfo::KeyedHash,
228            })
229        } else {
230            Err(KeyError::InvalidAlgorithmFormat(s.to_string()))
231        }
232    }
233}
234
235impl std::cmp::Ord for Alg {
236    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
237        self.name.cmp(&other.name)
238    }
239}
240
241impl std::cmp::PartialOrd for Alg {
242    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
243        Some(self.cmp(other))
244    }
245}
246
247/// A newtype wrapper to provide a project-specific `Display` implementation for `TpmAlgId`.
248#[derive(Debug, Clone, Copy)]
249pub struct Tpm2shAlgId(pub TpmAlgId);
250
251impl std::fmt::Display for Tpm2shAlgId {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        let s = match self.0 {
254            TpmAlgId::Sha1 => "sha1",
255            TpmAlgId::Sha256 => "sha256",
256            TpmAlgId::Sha384 => "sha384",
257            TpmAlgId::Sha512 => "sha512",
258            TpmAlgId::Rsa => "rsa",
259            TpmAlgId::Hmac => "hmac",
260            TpmAlgId::Aes => "aes",
261            TpmAlgId::KeyedHash => "keyedhash",
262            TpmAlgId::Xor => "xor",
263            TpmAlgId::Null => "null",
264            TpmAlgId::Sm3_256 => "sm3_256",
265            TpmAlgId::Sm4 => "sm4",
266            TpmAlgId::Ecc => "ecc",
267            _ => "unknown",
268        };
269        write!(f, "{s}")
270    }
271}
272
/// A local mirror of `TpmEccCurve` (which comes from `tpm2_protocol`), so that
/// `strum` string conversions can be derived for it.
///
/// `EnumString` provides `FromStr`, and `Display` prints the same names.
/// `serialize_all = "kebab-case"` sets each variant's default name. The
/// explicit `serialize` attributes on the Brainpool curves take precedence
/// over that default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
#[strum(serialize_all = "kebab-case")]
pub enum Tpm2shEccCurve {
    NistP192,
    NistP224,
    NistP256,
    NistP384,
    NistP521,
    BnP256,
    BnP638,
    Sm2P256,
    // Spelled out explicitly. Presumably the automatic kebab-case conversion
    // does not split the trailing `R1` as `-r1` — TODO confirm.
    #[strum(serialize = "bp-p256-r1")]
    BpP256R1,
    #[strum(serialize = "bp-p384-r1")]
    BpP384R1,
    #[strum(serialize = "bp-p512-r1")]
    BpP512R1,
    Curve25519,
    Curve448,
    None,
}
295
296impl From<TpmEccCurve> for Tpm2shEccCurve {
297    fn from(curve: TpmEccCurve) -> Self {
298        match curve {
299            TpmEccCurve::NistP192 => Self::NistP192,
300            TpmEccCurve::NistP224 => Self::NistP224,
301            TpmEccCurve::NistP256 => Self::NistP256,
302            TpmEccCurve::NistP384 => Self::NistP384,
303            TpmEccCurve::NistP521 => Self::NistP521,
304            TpmEccCurve::BnP256 => Self::BnP256,
305            TpmEccCurve::BnP638 => Self::BnP638,
306            TpmEccCurve::Sm2P256 => Self::Sm2P256,
307            TpmEccCurve::BpP256R1 => Self::BpP256R1,
308            TpmEccCurve::BpP384R1 => Self::BpP384R1,
309            TpmEccCurve::BpP512R1 => Self::BpP512R1,
310            TpmEccCurve::Curve25519 => Self::Curve25519,
311            TpmEccCurve::Curve448 => Self::Curve448,
312            TpmEccCurve::None => Self::None,
313        }
314    }
315}
316
317impl From<Tpm2shEccCurve> for TpmEccCurve {
318    fn from(curve: Tpm2shEccCurve) -> Self {
319        match curve {
320            Tpm2shEccCurve::NistP192 => Self::NistP192,
321            Tpm2shEccCurve::NistP224 => Self::NistP224,
322            Tpm2shEccCurve::NistP256 => Self::NistP256,
323            Tpm2shEccCurve::NistP384 => Self::NistP384,
324            Tpm2shEccCurve::NistP521 => Self::NistP521,
325            Tpm2shEccCurve::BnP256 => Self::BnP256,
326            Tpm2shEccCurve::BnP638 => Self::BnP638,
327            Tpm2shEccCurve::Sm2P256 => Self::Sm2P256,
328            Tpm2shEccCurve::BpP256R1 => Self::BpP256R1,
329            Tpm2shEccCurve::BpP384R1 => Self::BpP384R1,
330            Tpm2shEccCurve::BpP512R1 => Self::BpP512R1,
331            Tpm2shEccCurve::Curve25519 => Self::Curve25519,
332            Tpm2shEccCurve::Curve448 => Self::Curve448,
333            Tpm2shEccCurve::None => Self::None,
334        }
335    }
336}
337
338/// Formats a human-readable algorithm string from a `TpmtPublic` structure.
339#[must_use]
340pub fn format_alg_from_public(public: &TpmtPublic) -> String {
341    let name_alg_str = Tpm2shAlgId(public.name_alg).to_string();
342    match public.object_type {
343        TpmAlgId::Rsa => {
344            if let TpmuPublicParms::Rsa(params) = &public.parameters {
345                format!("rsa-{}:{}", params.key_bits, name_alg_str)
346            } else {
347                "rsa".to_string()
348            }
349        }
350        TpmAlgId::Ecc => {
351            if let TpmuPublicParms::Ecc(params) = &public.parameters {
352                let curve_str = Tpm2shEccCurve::from(params.curve_id).to_string();
353                format!("ecc-{curve_str}:{name_alg_str}")
354            } else {
355                "ecc".to_string()
356            }
357        }
358        TpmAlgId::KeyedHash => format!("keyedhash:{name_alg_str}"),
359        _ => Tpm2shAlgId(public.object_type).to_string(),
360    }
361}