cli/key/
mod.rs

// SPDX-License-Identifier: GPL-3.0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5#![allow(clippy::no_effect_underscore_binding)]
6
7mod external_key;
8mod tpm_key;
9
10pub use external_key::*;
11pub use tpm_key::*;
12
13use crate::{crypto::CryptoError, device::DeviceError};
14use rasn::{
15    types::{Integer, ObjectIdentifier},
16    AsnType, Decode, Decoder, Encode,
17};
18use strum::{Display, EnumString};
19use thiserror::Error;
20use tpm2_protocol::{
21    data::{TpmAlgId, TpmEccCurve, TpmaObject, TpmtPublic, TpmuPublicParms},
22    TpmError,
23};
24
25#[derive(Debug, Error)]
26pub enum KeyError {
27    #[error("unsupported name algorithm: {0}")]
28    InvalidAlgorithm(String),
29    #[error("invalid algorithm format: '{0}'")]
30    InvalidAlgorithmFormat(String),
31    #[error("invalid ECC curve: {0}")]
32    InvalidEccCurve(String),
33    #[error("invalid ECC point: {0}")]
34    InvalidEccPoint(String),
35    #[error("invalid key format")]
36    InvalidFormat,
37    #[error("invalid RSA exponent")]
38    InvalidRsaExponent,
39    #[error("invalid RSA key bits: {0}")]
40    InvalidRsaKeyBits(String),
41    #[error("invalid RSA modulus: {0}")]
42    InvalidRsaModulus(String),
43    #[error("invalid OID")]
44    InvalidOid,
45    #[error("invalid parent: {0:08x}")]
46    InvalidParent(u32),
47    #[error("pem: {0}")]
48    Pem(#[from] pem::PemError),
49    #[error("unsupported file format")]
50    UnsupportedFileFormat,
51    #[error("unsupported key algorithm: {0}")]
52    UnsupportedKeyAlgorithm(Tpm2shAlgId),
53    #[error("unsupported name algorithm: {0}")]
54    UnsupportedNameAlgorithm(Tpm2shAlgId),
55    #[error("unsupported OID: {0}")]
56    UnsupportedOid(String),
57    #[error("invalid PEM tag: {0}")]
58    UnsupportedPemTag(String),
59    #[error("value conversion failed: {0}")]
60    ValueConversionFailed(String),
61    #[error("crypto: {0}")]
62    Crypto(#[from] CryptoError),
63    #[error("device: {0}")]
64    Device(#[from] DeviceError),
65    #[error("hex decode: {0}")]
66    HexDecode(#[from] hex::FromHexError),
67    #[error("rasn decode: {0}")]
68    RasnDecode(#[from] rasn::error::DecodeError),
69    #[error("rasn encode: {0}")]
70    RasnEncode(#[from] rasn::error::EncodeError),
71}
72
73impl From<TpmError> for KeyError {
74    fn from(err: TpmError) -> Self {
75        Self::Device(DeviceError::TpmProtocol(err))
76    }
77}
78
79pub enum AnyKey {
80    Tpm(Box<TpmKey>),
81    External(Box<ExternalKey>),
82}
83
84/// Helper types for peeking at the DER structure to determine the key type.
85#[derive(AsnType, Decode, Encode)]
86#[rasn(choice)]
87enum FirstElement {
88    Oid(ObjectIdentifier),
89    Int(Integer),
90}
91
92#[derive(AsnType, Decode, Encode)]
93struct KeyPeek {
94    first: FirstElement,
95}
96
97impl TryFrom<&[u8]> for AnyKey {
98    type Error = KeyError;
99    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
100        if let Ok(pems) = pem::parse_many(bytes) {
101            if let Some(pem) = pems.into_iter().find(|p| {
102                matches!(
103                    p.tag(),
104                    "TSS2 PRIVATE KEY" | "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY"
105                )
106            }) {
107                let contents = pem.contents();
108                let tag = pem.tag();
109                return match tag {
110                    "TSS2 PRIVATE KEY" => {
111                        TpmKey::from_der(contents).map(|k| AnyKey::Tpm(Box::new(k)))
112                    }
113                    "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY" => {
114                        ExternalKey::from_der(contents).map(|k| AnyKey::External(Box::new(k)))
115                    }
116                    _ => Err(KeyError::UnsupportedPemTag(tag.to_string())),
117                };
118            }
119        }
120
121        match rasn::der::decode::<KeyPeek>(bytes)?.first {
122            FirstElement::Oid(..) => TpmKey::from_der(bytes).map(|k| AnyKey::Tpm(Box::new(k))),
123            FirstElement::Int(..) => {
124                ExternalKey::from_der(bytes).map(|k| AnyKey::External(Box::new(k)))
125            }
126        }
127    }
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq)]
131pub enum AlgInfo {
132    Rsa { key_bits: u16 },
133    Ecc { curve_id: TpmEccCurve },
134    KeyedHash,
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct Alg {
139    pub name: String,
140    pub object_type: TpmAlgId,
141    pub name_alg: TpmAlgId,
142    pub params: AlgInfo,
143}
144
145impl From<Alg> for TpmaObject {
146    fn from(alg: Alg) -> TpmaObject {
147        let mut attributes = TpmaObject::FIXED_TPM | TpmaObject::FIXED_PARENT;
148
149        if alg.object_type != TpmAlgId::KeyedHash {
150            attributes |=
151                TpmaObject::SENSITIVE_DATA_ORIGIN | TpmaObject::DECRYPT | TpmaObject::RESTRICTED;
152        }
153
154        attributes
155    }
156}
157
158impl Alg {
159    /// Creates a new `Alg` struct for a `KeyedHash` object.
160    ///
161    /// # Errors
162    ///
163    /// Returns an `KeyError` if the provided hash algorithm string is invalid.
164    pub fn new_keyedhash(hash_alg: &str) -> Result<Self, KeyError> {
165        let name_alg = Tpm2shAlgId::try_from(hash_alg)?.0;
166        Ok(Self {
167            name: format!("keyedhash:{hash_alg}"),
168            object_type: TpmAlgId::KeyedHash,
169            name_alg,
170            params: AlgInfo::KeyedHash,
171        })
172    }
173}
174
175impl std::fmt::Display for Alg {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        write!(f, "{}", self.name)
178    }
179}
180
181impl std::str::FromStr for Alg {
182    type Err = KeyError;
183
184    fn from_str(s: &str) -> Result<Self, Self::Err> {
185        if let Some(rest) = s.strip_prefix("rsa-") {
186            let (bits_str, name_alg_str) = rest
187                .split_once(':')
188                .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
189            let key_bits: u16 = bits_str
190                .parse()
191                .map_err(|_| KeyError::InvalidRsaKeyBits(bits_str.to_string()))?;
192            let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
193            Ok(Self {
194                name: s.to_string(),
195                object_type: TpmAlgId::Rsa,
196                name_alg,
197                params: AlgInfo::Rsa { key_bits },
198            })
199        } else if let Some(rest) = s.strip_prefix("ecc-") {
200            let (curve_str, name_alg_str) = rest
201                .split_once(':')
202                .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
203            let curve_id: TpmEccCurve = Tpm2shEccCurve::from_str(curve_str)
204                .map_err(|e| KeyError::InvalidEccCurve(e.to_string()))?
205                .into();
206            let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
207            Ok(Self {
208                name: s.to_string(),
209                object_type: TpmAlgId::Ecc,
210                name_alg,
211                params: AlgInfo::Ecc { curve_id },
212            })
213        } else if let Some(name_alg_str) = s.strip_prefix("keyedhash:") {
214            let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
215            Ok(Self {
216                name: s.to_string(),
217                object_type: TpmAlgId::KeyedHash,
218                name_alg,
219                params: AlgInfo::KeyedHash,
220            })
221        } else {
222            Err(KeyError::InvalidAlgorithmFormat(s.to_string()))
223        }
224    }
225}
226
227impl std::cmp::Ord for Alg {
228    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
229        self.name.cmp(&other.name)
230    }
231}
232
233impl std::cmp::PartialOrd for Alg {
234    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
235        Some(self.cmp(other))
236    }
237}
238
239/// A newtype wrapper to provide a project-specific `Display` implementation for `TpmAlgId`.
240#[derive(Debug, Clone, Copy)]
241pub struct Tpm2shAlgId(pub TpmAlgId);
242
243impl TryFrom<&str> for Tpm2shAlgId {
244    type Error = KeyError;
245
246    fn try_from(s: &str) -> Result<Self, Self::Error> {
247        let alg_id = match s {
248            "rsa" => TpmAlgId::Rsa,
249            "sha1" => TpmAlgId::Sha1,
250            "hmac" => TpmAlgId::Hmac,
251            "aes" => TpmAlgId::Aes,
252            "keyedhash" => TpmAlgId::KeyedHash,
253            "xor" => TpmAlgId::Xor,
254            "sha256" => TpmAlgId::Sha256,
255            "sha384" => TpmAlgId::Sha384,
256            "sha512" => TpmAlgId::Sha512,
257            "null" => TpmAlgId::Null,
258            "sm3_256" => TpmAlgId::Sm3_256,
259            "sm4" => TpmAlgId::Sm4,
260            "ecc" => TpmAlgId::Ecc,
261            _ => return Err(KeyError::InvalidAlgorithm(s.to_string())),
262        };
263        Ok(Self(alg_id))
264    }
265}
266
267impl std::fmt::Display for Tpm2shAlgId {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        let s = match self.0 {
270            TpmAlgId::Sha1 => "sha1",
271            TpmAlgId::Sha256 => "sha256",
272            TpmAlgId::Sha384 => "sha384",
273            TpmAlgId::Sha512 => "sha512",
274            TpmAlgId::Rsa => "rsa",
275            TpmAlgId::Hmac => "hmac",
276            TpmAlgId::Aes => "aes",
277            TpmAlgId::KeyedHash => "keyedhash",
278            TpmAlgId::Xor => "xor",
279            TpmAlgId::Null => "null",
280            TpmAlgId::Sm3_256 => "sm3_256",
281            TpmAlgId::Sm4 => "sm4",
282            TpmAlgId::Ecc => "ecc",
283            _ => "unknown",
284        };
285        write!(f, "{s}")
286    }
287}
288
289/// A local wrapper enum for `TpmEccCurve` to allow `strum` derives.
290#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
291#[strum(serialize_all = "kebab-case")]
292pub enum Tpm2shEccCurve {
293    NistP192,
294    NistP224,
295    NistP256,
296    NistP384,
297    NistP521,
298    BnP256,
299    BnP638,
300    Sm2P256,
301    #[strum(serialize = "bp-p256-r1")]
302    BpP256R1,
303    #[strum(serialize = "bp-p384-r1")]
304    BpP384R1,
305    #[strum(serialize = "bp-p512-r1")]
306    BpP512R1,
307    Curve25519,
308    Curve448,
309    None,
310}
311
312impl From<TpmEccCurve> for Tpm2shEccCurve {
313    fn from(curve: TpmEccCurve) -> Self {
314        match curve {
315            TpmEccCurve::NistP192 => Self::NistP192,
316            TpmEccCurve::NistP224 => Self::NistP224,
317            TpmEccCurve::NistP256 => Self::NistP256,
318            TpmEccCurve::NistP384 => Self::NistP384,
319            TpmEccCurve::NistP521 => Self::NistP521,
320            TpmEccCurve::BnP256 => Self::BnP256,
321            TpmEccCurve::BnP638 => Self::BnP638,
322            TpmEccCurve::Sm2P256 => Self::Sm2P256,
323            TpmEccCurve::BpP256R1 => Self::BpP256R1,
324            TpmEccCurve::BpP384R1 => Self::BpP384R1,
325            TpmEccCurve::BpP512R1 => Self::BpP512R1,
326            TpmEccCurve::Curve25519 => Self::Curve25519,
327            TpmEccCurve::Curve448 => Self::Curve448,
328            TpmEccCurve::None => Self::None,
329        }
330    }
331}
332
333impl From<Tpm2shEccCurve> for TpmEccCurve {
334    fn from(curve: Tpm2shEccCurve) -> Self {
335        match curve {
336            Tpm2shEccCurve::NistP192 => Self::NistP192,
337            Tpm2shEccCurve::NistP224 => Self::NistP224,
338            Tpm2shEccCurve::NistP256 => Self::NistP256,
339            Tpm2shEccCurve::NistP384 => Self::NistP384,
340            Tpm2shEccCurve::NistP521 => Self::NistP521,
341            Tpm2shEccCurve::BnP256 => Self::BnP256,
342            Tpm2shEccCurve::BnP638 => Self::BnP638,
343            Tpm2shEccCurve::Sm2P256 => Self::Sm2P256,
344            Tpm2shEccCurve::BpP256R1 => Self::BpP256R1,
345            Tpm2shEccCurve::BpP384R1 => Self::BpP384R1,
346            Tpm2shEccCurve::BpP512R1 => Self::BpP512R1,
347            Tpm2shEccCurve::Curve25519 => Self::Curve25519,
348            Tpm2shEccCurve::Curve448 => Self::Curve448,
349            Tpm2shEccCurve::None => Self::None,
350        }
351    }
352}
353
354/// Formats a human-readable algorithm string from a `TpmtPublic` structure.
355#[must_use]
356pub fn format_alg_from_public(public: &TpmtPublic) -> String {
357    let name_alg_str = Tpm2shAlgId(public.name_alg).to_string();
358    match public.object_type {
359        TpmAlgId::Rsa => {
360            if let TpmuPublicParms::Rsa(params) = &public.parameters {
361                format!("rsa-{}:{}", params.key_bits, name_alg_str)
362            } else {
363                "rsa".to_string()
364            }
365        }
366        TpmAlgId::Ecc => {
367            if let TpmuPublicParms::Ecc(params) = &public.parameters {
368                let curve_str = Tpm2shEccCurve::from(params.curve_id).to_string();
369                format!("ecc-{curve_str}:{name_alg_str}")
370            } else {
371                "ecc".to_string()
372            }
373        }
374        TpmAlgId::KeyedHash => format!("keyedhash:{name_alg_str}"),
375        _ => Tpm2shAlgId(public.object_type).to_string(),
376    }
377}