1#![allow(clippy::no_effect_underscore_binding)]
6
7mod external_key;
8mod tpm_key;
9
10pub use external_key::*;
11pub use tpm_key::*;
12
13use crate::{crypto::CryptoError, device::DeviceError};
14use rasn::{
15 types::{Integer, ObjectIdentifier},
16 AsnType, Decode, Decoder, Encode,
17};
18use strum::{Display, EnumString};
19use thiserror::Error;
20use tpm2_protocol::{
21 data::{TpmAlgId, TpmEccCurve, TpmaObject, TpmtPublic, TpmuPublicParms},
22 TpmError,
23};
24
/// Errors returned by the key-loading and algorithm-parsing code in this module.
#[derive(Debug, Error)]
pub enum KeyError {
    /// An algorithm name string was not recognised (returned by
    /// `Tpm2shAlgId::try_from`).
    ///
    /// NOTE(review): the message is identical to `UnsupportedNameAlgorithm`'s,
    /// which makes the two variants indistinguishable in user output.
    #[error("unsupported name algorithm: {0}")]
    InvalidAlgorithm(String),
    /// An algorithm descriptor did not match any accepted `prefix:suffix`
    /// form; carries the whole input string.
    #[error("invalid algorithm format: '{0}'")]
    InvalidAlgorithmFormat(String),
    /// An ECC curve name could not be parsed; carries the parser's message.
    #[error("invalid ECC curve: {0}")]
    InvalidEccCurve(String),
    /// An ECC point was rejected.
    #[error("invalid ECC point: {0}")]
    InvalidEccPoint(String),
    /// Key material was not in a recognised format.
    #[error("invalid key format")]
    InvalidFormat,
    /// An RSA public exponent was rejected.
    #[error("invalid RSA exponent")]
    InvalidRsaExponent,
    /// The RSA key size in an algorithm descriptor is not a valid `u16`;
    /// carries the offending text.
    #[error("invalid RSA key bits: {0}")]
    InvalidRsaKeyBits(String),
    /// An RSA modulus was rejected.
    #[error("invalid RSA modulus: {0}")]
    InvalidRsaModulus(String),
    /// An object identifier was malformed.
    #[error("invalid OID")]
    InvalidOid,
    /// A parent handle was rejected; displayed as 8 hex digits.
    #[error("invalid parent: {0:08x}")]
    InvalidParent(u32),
    /// PEM parsing failed.
    #[error("pem: {0}")]
    Pem(#[from] pem::PemError),
    /// The input file format is not supported.
    #[error("unsupported file format")]
    UnsupportedFileFormat,
    /// The key (object) algorithm is not supported.
    #[error("unsupported key algorithm: {0}")]
    UnsupportedKeyAlgorithm(Tpm2shAlgId),
    /// The name algorithm is not supported.
    #[error("unsupported name algorithm: {0}")]
    UnsupportedNameAlgorithm(Tpm2shAlgId),
    /// An object identifier is well-formed but not supported.
    #[error("unsupported OID: {0}")]
    UnsupportedOid(String),
    /// A PEM block's tag is not one of the supported private-key tags.
    #[error("invalid PEM tag: {0}")]
    UnsupportedPemTag(String),
    /// Converting a value between representations failed.
    #[error("value conversion failed: {0}")]
    ValueConversionFailed(String),
    /// Wrapped crypto-layer error.
    #[error("crypto: {0}")]
    Crypto(#[from] CryptoError),
    /// Wrapped device-layer error (TPM protocol errors also land here, see
    /// the `From<TpmError>` impl).
    #[error("device: {0}")]
    Device(#[from] DeviceError),
    /// Hex decoding failed.
    #[error("hex decode: {0}")]
    HexDecode(#[from] hex::FromHexError),
    /// ASN.1 decoding via rasn failed.
    #[error("rasn decode: {0}")]
    RasnDecode(#[from] rasn::error::DecodeError),
    /// ASN.1 encoding via rasn failed.
    #[error("rasn encode: {0}")]
    RasnEncode(#[from] rasn::error::EncodeError),
}
72
73impl From<TpmError> for KeyError {
74 fn from(err: TpmError) -> Self {
75 Self::Device(DeviceError::TpmProtocol(err))
76 }
77}
78
/// A private key loaded from PEM or DER bytes (see the `TryFrom<&[u8]>` impl).
pub enum AnyKey {
    /// A TPM-wrapped key (PEM tag `TSS2 PRIVATE KEY`, or DER whose first
    /// element is an OBJECT IDENTIFIER).
    Tpm(Box<TpmKey>),
    /// A software key (PEM tags `PRIVATE KEY`, `RSA PRIVATE KEY`,
    /// `EC PRIVATE KEY`, or DER whose first element is an INTEGER).
    External(Box<ExternalKey>),
}
83
/// The first element of a DER key SEQUENCE, decoded as an ASN.1 CHOICE so the
/// key format can be detected without decoding the rest of the structure.
#[derive(AsnType, Decode, Encode)]
#[rasn(choice)]
enum FirstElement {
    /// Leading OBJECT IDENTIFIER; `AnyKey::try_from` treats this as a TSS2 key.
    Oid(ObjectIdentifier),
    /// Leading INTEGER; `AnyKey::try_from` treats this as an external key.
    Int(Integer),
}
91
/// Minimal SEQUENCE view used to peek at the first element of a DER key.
///
/// NOTE(review): the real key structures have more fields than this; this
/// relies on rasn accepting the trailing elements when decoding `KeyPeek` —
/// confirm against the rasn version in use.
#[derive(AsnType, Decode, Encode)]
struct KeyPeek {
    /// The element whose ASN.1 type selects the key format.
    first: FirstElement,
}
96
97impl TryFrom<&[u8]> for AnyKey {
98 type Error = KeyError;
99 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
100 if let Ok(pems) = pem::parse_many(bytes) {
101 if let Some(pem) = pems.into_iter().find(|p| {
102 matches!(
103 p.tag(),
104 "TSS2 PRIVATE KEY" | "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY"
105 )
106 }) {
107 let contents = pem.contents();
108 let tag = pem.tag();
109 return match tag {
110 "TSS2 PRIVATE KEY" => {
111 TpmKey::from_der(contents).map(|k| AnyKey::Tpm(Box::new(k)))
112 }
113 "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY" => {
114 ExternalKey::from_der(contents).map(|k| AnyKey::External(Box::new(k)))
115 }
116 _ => Err(KeyError::UnsupportedPemTag(tag.to_string())),
117 };
118 }
119 }
120
121 match rasn::der::decode::<KeyPeek>(bytes)?.first {
122 FirstElement::Oid(..) => TpmKey::from_der(bytes).map(|k| AnyKey::Tpm(Box::new(k))),
123 FirstElement::Int(..) => {
124 ExternalKey::from_der(bytes).map(|k| AnyKey::External(Box::new(k)))
125 }
126 }
127 }
128}
129
/// Type-specific parameters of an [`Alg`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlgInfo {
    /// RSA key with the given modulus size in bits.
    Rsa { key_bits: u16 },
    /// ECC key on the given TPM curve.
    Ecc { curve_id: TpmEccCurve },
    /// Keyed-hash object; carries no extra parameters.
    KeyedHash,
}
136
/// An algorithm descriptor such as `rsa-2048:sha256`, `ecc-nist-p256:sha256`
/// or `keyedhash:sha256` (see the `FromStr` impl for the accepted syntax).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Alg {
    /// Textual form of the descriptor; also its `Display` output and the
    /// sole key used by `Ord`.
    pub name: String,
    /// TPM object type (`Rsa`, `Ecc` or `KeyedHash` for values built here).
    pub object_type: TpmAlgId,
    /// Name algorithm, parsed from the part after the `:`.
    pub name_alg: TpmAlgId,
    /// Type-specific parameters matching `object_type`.
    pub params: AlgInfo,
}
144
145impl From<Alg> for TpmaObject {
146 fn from(alg: Alg) -> TpmaObject {
147 let mut attributes = TpmaObject::FIXED_TPM | TpmaObject::FIXED_PARENT;
148
149 if alg.object_type != TpmAlgId::KeyedHash {
150 attributes |=
151 TpmaObject::SENSITIVE_DATA_ORIGIN | TpmaObject::DECRYPT | TpmaObject::RESTRICTED;
152 }
153
154 attributes
155 }
156}
157
158impl Alg {
159 pub fn new_keyedhash(hash_alg: &str) -> Result<Self, KeyError> {
165 let name_alg = Tpm2shAlgId::try_from(hash_alg)?.0;
166 Ok(Self {
167 name: format!("keyedhash:{hash_alg}"),
168 object_type: TpmAlgId::KeyedHash,
169 name_alg,
170 params: AlgInfo::KeyedHash,
171 })
172 }
173}
174
175impl std::fmt::Display for Alg {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 write!(f, "{}", self.name)
178 }
179}
180
181impl std::str::FromStr for Alg {
182 type Err = KeyError;
183
184 fn from_str(s: &str) -> Result<Self, Self::Err> {
185 if let Some(rest) = s.strip_prefix("rsa-") {
186 let (bits_str, name_alg_str) = rest
187 .split_once(':')
188 .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
189 let key_bits: u16 = bits_str
190 .parse()
191 .map_err(|_| KeyError::InvalidRsaKeyBits(bits_str.to_string()))?;
192 let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
193 Ok(Self {
194 name: s.to_string(),
195 object_type: TpmAlgId::Rsa,
196 name_alg,
197 params: AlgInfo::Rsa { key_bits },
198 })
199 } else if let Some(rest) = s.strip_prefix("ecc-") {
200 let (curve_str, name_alg_str) = rest
201 .split_once(':')
202 .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
203 let curve_id: TpmEccCurve = Tpm2shEccCurve::from_str(curve_str)
204 .map_err(|e| KeyError::InvalidEccCurve(e.to_string()))?
205 .into();
206 let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
207 Ok(Self {
208 name: s.to_string(),
209 object_type: TpmAlgId::Ecc,
210 name_alg,
211 params: AlgInfo::Ecc { curve_id },
212 })
213 } else if let Some(name_alg_str) = s.strip_prefix("keyedhash:") {
214 let name_alg = Tpm2shAlgId::try_from(name_alg_str)?.0;
215 Ok(Self {
216 name: s.to_string(),
217 object_type: TpmAlgId::KeyedHash,
218 name_alg,
219 params: AlgInfo::KeyedHash,
220 })
221 } else {
222 Err(KeyError::InvalidAlgorithmFormat(s.to_string()))
223 }
224 }
225}
226
227impl std::cmp::Ord for Alg {
228 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
229 self.name.cmp(&other.name)
230 }
231}
232
233impl std::cmp::PartialOrd for Alg {
234 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
235 Some(self.cmp(other))
236 }
237}
238
/// Newtype over [`TpmAlgId`] that converts algorithms to and from tpm2sh's
/// lower-case names (e.g. `sha256`, `ecc`).
#[derive(Debug, Clone, Copy)]
pub struct Tpm2shAlgId(pub TpmAlgId);
242
243impl TryFrom<&str> for Tpm2shAlgId {
244 type Error = KeyError;
245
246 fn try_from(s: &str) -> Result<Self, Self::Error> {
247 let alg_id = match s {
248 "rsa" => TpmAlgId::Rsa,
249 "sha1" => TpmAlgId::Sha1,
250 "hmac" => TpmAlgId::Hmac,
251 "aes" => TpmAlgId::Aes,
252 "keyedhash" => TpmAlgId::KeyedHash,
253 "xor" => TpmAlgId::Xor,
254 "sha256" => TpmAlgId::Sha256,
255 "sha384" => TpmAlgId::Sha384,
256 "sha512" => TpmAlgId::Sha512,
257 "null" => TpmAlgId::Null,
258 "sm3_256" => TpmAlgId::Sm3_256,
259 "sm4" => TpmAlgId::Sm4,
260 "ecc" => TpmAlgId::Ecc,
261 _ => return Err(KeyError::InvalidAlgorithm(s.to_string())),
262 };
263 Ok(Self(alg_id))
264 }
265}
266
267impl std::fmt::Display for Tpm2shAlgId {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 let s = match self.0 {
270 TpmAlgId::Sha1 => "sha1",
271 TpmAlgId::Sha256 => "sha256",
272 TpmAlgId::Sha384 => "sha384",
273 TpmAlgId::Sha512 => "sha512",
274 TpmAlgId::Rsa => "rsa",
275 TpmAlgId::Hmac => "hmac",
276 TpmAlgId::Aes => "aes",
277 TpmAlgId::KeyedHash => "keyedhash",
278 TpmAlgId::Xor => "xor",
279 TpmAlgId::Null => "null",
280 TpmAlgId::Sm3_256 => "sm3_256",
281 TpmAlgId::Sm4 => "sm4",
282 TpmAlgId::Ecc => "ecc",
283 _ => "unknown",
284 };
285 write!(f, "{s}")
286 }
287}
288
/// tpm2sh's textual names for TPM ECC curves.
///
/// Parsed and displayed via strum as kebab-case names (e.g. `nist-p256`).
/// The Brainpool curves use the explicit names `bp-p256-r1`, `bp-p384-r1`
/// and `bp-p512-r1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
#[strum(serialize_all = "kebab-case")]
pub enum Tpm2shEccCurve {
    NistP192,
    NistP224,
    NistP256,
    NistP384,
    NistP521,
    BnP256,
    BnP638,
    Sm2P256,
    #[strum(serialize = "bp-p256-r1")]
    BpP256R1,
    #[strum(serialize = "bp-p384-r1")]
    BpP384R1,
    #[strum(serialize = "bp-p512-r1")]
    BpP512R1,
    Curve25519,
    Curve448,
    /// Mirrors `TpmEccCurve::None`.
    None,
}
311
312impl From<TpmEccCurve> for Tpm2shEccCurve {
313 fn from(curve: TpmEccCurve) -> Self {
314 match curve {
315 TpmEccCurve::NistP192 => Self::NistP192,
316 TpmEccCurve::NistP224 => Self::NistP224,
317 TpmEccCurve::NistP256 => Self::NistP256,
318 TpmEccCurve::NistP384 => Self::NistP384,
319 TpmEccCurve::NistP521 => Self::NistP521,
320 TpmEccCurve::BnP256 => Self::BnP256,
321 TpmEccCurve::BnP638 => Self::BnP638,
322 TpmEccCurve::Sm2P256 => Self::Sm2P256,
323 TpmEccCurve::BpP256R1 => Self::BpP256R1,
324 TpmEccCurve::BpP384R1 => Self::BpP384R1,
325 TpmEccCurve::BpP512R1 => Self::BpP512R1,
326 TpmEccCurve::Curve25519 => Self::Curve25519,
327 TpmEccCurve::Curve448 => Self::Curve448,
328 TpmEccCurve::None => Self::None,
329 }
330 }
331}
332
333impl From<Tpm2shEccCurve> for TpmEccCurve {
334 fn from(curve: Tpm2shEccCurve) -> Self {
335 match curve {
336 Tpm2shEccCurve::NistP192 => Self::NistP192,
337 Tpm2shEccCurve::NistP224 => Self::NistP224,
338 Tpm2shEccCurve::NistP256 => Self::NistP256,
339 Tpm2shEccCurve::NistP384 => Self::NistP384,
340 Tpm2shEccCurve::NistP521 => Self::NistP521,
341 Tpm2shEccCurve::BnP256 => Self::BnP256,
342 Tpm2shEccCurve::BnP638 => Self::BnP638,
343 Tpm2shEccCurve::Sm2P256 => Self::Sm2P256,
344 Tpm2shEccCurve::BpP256R1 => Self::BpP256R1,
345 Tpm2shEccCurve::BpP384R1 => Self::BpP384R1,
346 Tpm2shEccCurve::BpP512R1 => Self::BpP512R1,
347 Tpm2shEccCurve::Curve25519 => Self::Curve25519,
348 Tpm2shEccCurve::Curve448 => Self::Curve448,
349 Tpm2shEccCurve::None => Self::None,
350 }
351 }
352}
353
354#[must_use]
356pub fn format_alg_from_public(public: &TpmtPublic) -> String {
357 let name_alg_str = Tpm2shAlgId(public.name_alg).to_string();
358 match public.object_type {
359 TpmAlgId::Rsa => {
360 if let TpmuPublicParms::Rsa(params) = &public.parameters {
361 format!("rsa-{}:{}", params.key_bits, name_alg_str)
362 } else {
363 "rsa".to_string()
364 }
365 }
366 TpmAlgId::Ecc => {
367 if let TpmuPublicParms::Ecc(params) = &public.parameters {
368 let curve_str = Tpm2shEccCurve::from(params.curve_id).to_string();
369 format!("ecc-{curve_str}:{name_alg_str}")
370 } else {
371 "ecc".to_string()
372 }
373 }
374 TpmAlgId::KeyedHash => format!("keyedhash:{name_alg_str}"),
375 _ => Tpm2shAlgId(public.object_type).to_string(),
376 }
377}