1#![allow(clippy::no_effect_underscore_binding)]
6
7mod external_key;
8mod tpm_key;
9
10pub use external_key::*;
11pub use tpm_key::*;
12
13use crate::{crypto::CryptoError, device::DeviceError, session::SessionError};
14use rasn::{
15 types::{Integer, ObjectIdentifier},
16 AsnType, Decode, Decoder, Encode,
17};
18use strum::{Display, EnumString};
19use thiserror::Error;
20use tpm2_protocol::{
21 data::{TpmAlgId, TpmEccCurve, TpmtPublic, TpmuPublicParms},
22 TpmErrorKind,
23};
24
25#[derive(Debug, Error)]
26pub enum KeyError {
27 #[error("crypto: {0}")]
28 Crypto(#[from] CryptoError),
29 #[error("device: {0}")]
30 Device(#[from] DeviceError),
31 #[error("unsupported name algorithm: {0}")]
32 InvalidAlgorithm(String),
33 #[error("invalid algorithm format: '{0}'")]
34 InvalidAlgorithmFormat(String),
35 #[error("invalid ECC curve: {0}")]
36 InvalidEccCurve(String),
37 #[error("invalid ECC point: {0}")]
38 InvalidEccPoint(String),
39 #[error("invalid key format")]
40 InvalidFormat,
41 #[error("invalid key bits: {0}")]
42 InvalidKeyBits(String),
43 #[error("invalid modulus: {0}")]
44 InvalidModulus(String),
45 #[error("invalid OID")]
46 InvalidOid,
47 #[error("pem: {0}")]
48 Pem(#[from] pem::PemError),
49 #[error("rasn decode: {0}")]
50 RasnDecode(#[from] rasn::error::DecodeError),
51 #[error("rasn encode: {0}")]
52 RasnEncode(#[from] rasn::error::EncodeError),
53 #[error("unsupported file format")]
54 UnsupportedFileFormat,
55 #[error("unsupported OID: {0}")]
56 UnsupportedOid(String),
57 #[error("invalid PEM tag: {0}")]
58 UnsupportedPemTag(String),
59 #[error("value conversion failed: {0}")]
60 ValueConversionFailed(String),
61 #[error("session: {0}")]
62 Session(#[from] SessionError),
63 #[error("unsupported key algorithm: {0}")]
64 UnsupportedKeyAlgorithm(Tpm2shAlgId),
65 #[error("unsupported name algorithm: {0}")]
66 UnsupportedNameAlgorithm(Tpm2shAlgId),
67}
68
69impl From<TpmErrorKind> for KeyError {
70 fn from(err: TpmErrorKind) -> Self {
71 Self::Device(DeviceError::Tpm(err))
72 }
73}
74
75pub enum AnyKey {
76 Tpm(Box<TpmKey>),
77 External(Box<ExternalKey>),
78}
79
80#[derive(AsnType, Decode, Encode)]
82#[rasn(choice)]
83enum FirstElement {
84 Oid(ObjectIdentifier),
85 Int(Integer),
86}
87
88#[derive(AsnType, Decode, Encode)]
89struct KeyPeek {
90 first: FirstElement,
91}
92
93impl TryFrom<&[u8]> for AnyKey {
94 type Error = KeyError;
95 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
96 if let Ok(pems) = pem::parse_many(bytes) {
97 if let Some(pem) = pems.into_iter().find(|p| {
98 matches!(
99 p.tag(),
100 "TSS2 PRIVATE KEY" | "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY"
101 )
102 }) {
103 let contents = pem.contents();
104 let tag = pem.tag();
105 return match tag {
106 "TSS2 PRIVATE KEY" => {
107 TpmKey::from_der(contents).map(|k| AnyKey::Tpm(Box::new(k)))
108 }
109 "PRIVATE KEY" | "RSA PRIVATE KEY" | "EC PRIVATE KEY" => {
110 ExternalKey::from_der(contents).map(|k| AnyKey::External(Box::new(k)))
111 }
112 _ => Err(KeyError::UnsupportedPemTag(tag.to_string())),
113 };
114 }
115 }
116
117 match rasn::der::decode::<KeyPeek>(bytes)?.first {
118 FirstElement::Oid(..) => TpmKey::from_der(bytes).map(|k| AnyKey::Tpm(Box::new(k))),
119 FirstElement::Int(..) => {
120 ExternalKey::from_der(bytes).map(|k| AnyKey::External(Box::new(k)))
121 }
122 }
123 }
124}
125
126#[derive(Debug, Clone, Copy, PartialEq, Eq)]
127pub enum AlgInfo {
128 Rsa { key_bits: u16 },
129 Ecc { curve_id: TpmEccCurve },
130 KeyedHash,
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Alg {
135 pub name: String,
136 pub object_type: TpmAlgId,
137 pub name_alg: TpmAlgId,
138 pub params: AlgInfo,
139}
140
141pub fn from_str_to_alg_id(s: &str) -> Result<TpmAlgId, KeyError> {
148 match s {
149 "rsa" => Ok(TpmAlgId::Rsa),
150 "sha1" => Ok(TpmAlgId::Sha1),
151 "hmac" => Ok(TpmAlgId::Hmac),
152 "aes" => Ok(TpmAlgId::Aes),
153 "keyedhash" => Ok(TpmAlgId::KeyedHash),
154 "xor" => Ok(TpmAlgId::Xor),
155 "sha256" => Ok(TpmAlgId::Sha256),
156 "sha384" => Ok(TpmAlgId::Sha384),
157 "sha512" => Ok(TpmAlgId::Sha512),
158 "null" => Ok(TpmAlgId::Null),
159 "sm3_256" => Ok(TpmAlgId::Sm3_256),
160 "sm4" => Ok(TpmAlgId::Sm4),
161 "ecc" => Ok(TpmAlgId::Ecc),
162 _ => Err(KeyError::InvalidAlgorithm(s.to_string())),
163 }
164}
165
166impl Alg {
167 pub fn new_keyedhash(hash_alg: &str) -> Result<Self, KeyError> {
173 let name_alg = from_str_to_alg_id(hash_alg)?;
174 Ok(Self {
175 name: format!("keyedhash:{hash_alg}"),
176 object_type: TpmAlgId::KeyedHash,
177 name_alg,
178 params: AlgInfo::KeyedHash,
179 })
180 }
181}
182
183impl std::fmt::Display for Alg {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 write!(f, "{}", self.name)
186 }
187}
188
189impl std::str::FromStr for Alg {
190 type Err = KeyError;
191
192 fn from_str(s: &str) -> Result<Self, Self::Err> {
193 if let Some(rest) = s.strip_prefix("rsa-") {
194 let (bits_str, name_alg_str) = rest
195 .split_once(':')
196 .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
197 let key_bits: u16 = bits_str
198 .parse()
199 .map_err(|_| KeyError::InvalidKeyBits(bits_str.to_string()))?;
200 let name_alg = from_str_to_alg_id(name_alg_str)?;
201 Ok(Self {
202 name: s.to_string(),
203 object_type: TpmAlgId::Rsa,
204 name_alg,
205 params: AlgInfo::Rsa { key_bits },
206 })
207 } else if let Some(rest) = s.strip_prefix("ecc-") {
208 let (curve_str, name_alg_str) = rest
209 .split_once(':')
210 .ok_or_else(|| KeyError::InvalidAlgorithmFormat(s.to_string()))?;
211 let curve_id: TpmEccCurve = Tpm2shEccCurve::from_str(curve_str)
212 .map_err(|e| KeyError::InvalidEccCurve(e.to_string()))?
213 .into();
214 let name_alg = from_str_to_alg_id(name_alg_str)?;
215 Ok(Self {
216 name: s.to_string(),
217 object_type: TpmAlgId::Ecc,
218 name_alg,
219 params: AlgInfo::Ecc { curve_id },
220 })
221 } else if let Some(name_alg_str) = s.strip_prefix("keyedhash:") {
222 let name_alg = from_str_to_alg_id(name_alg_str)?;
223 Ok(Self {
224 name: s.to_string(),
225 object_type: TpmAlgId::KeyedHash,
226 name_alg,
227 params: AlgInfo::KeyedHash,
228 })
229 } else {
230 Err(KeyError::InvalidAlgorithmFormat(s.to_string()))
231 }
232 }
233}
234
235impl std::cmp::Ord for Alg {
236 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
237 self.name.cmp(&other.name)
238 }
239}
240
241impl std::cmp::PartialOrd for Alg {
242 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
243 Some(self.cmp(other))
244 }
245}
246
247#[derive(Debug, Clone, Copy)]
249pub struct Tpm2shAlgId(pub TpmAlgId);
250
251impl std::fmt::Display for Tpm2shAlgId {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 let s = match self.0 {
254 TpmAlgId::Sha1 => "sha1",
255 TpmAlgId::Sha256 => "sha256",
256 TpmAlgId::Sha384 => "sha384",
257 TpmAlgId::Sha512 => "sha512",
258 TpmAlgId::Rsa => "rsa",
259 TpmAlgId::Hmac => "hmac",
260 TpmAlgId::Aes => "aes",
261 TpmAlgId::KeyedHash => "keyedhash",
262 TpmAlgId::Xor => "xor",
263 TpmAlgId::Null => "null",
264 TpmAlgId::Sm3_256 => "sm3_256",
265 TpmAlgId::Sm4 => "sm4",
266 TpmAlgId::Ecc => "ecc",
267 _ => "unknown",
268 };
269 write!(f, "{s}")
270 }
271}
272
273#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
275#[strum(serialize_all = "kebab-case")]
276pub enum Tpm2shEccCurve {
277 NistP192,
278 NistP224,
279 NistP256,
280 NistP384,
281 NistP521,
282 BnP256,
283 BnP638,
284 Sm2P256,
285 #[strum(serialize = "bp-p256-r1")]
286 BpP256R1,
287 #[strum(serialize = "bp-p384-r1")]
288 BpP384R1,
289 #[strum(serialize = "bp-p512-r1")]
290 BpP512R1,
291 Curve25519,
292 Curve448,
293 None,
294}
295
296impl From<TpmEccCurve> for Tpm2shEccCurve {
297 fn from(curve: TpmEccCurve) -> Self {
298 match curve {
299 TpmEccCurve::NistP192 => Self::NistP192,
300 TpmEccCurve::NistP224 => Self::NistP224,
301 TpmEccCurve::NistP256 => Self::NistP256,
302 TpmEccCurve::NistP384 => Self::NistP384,
303 TpmEccCurve::NistP521 => Self::NistP521,
304 TpmEccCurve::BnP256 => Self::BnP256,
305 TpmEccCurve::BnP638 => Self::BnP638,
306 TpmEccCurve::Sm2P256 => Self::Sm2P256,
307 TpmEccCurve::BpP256R1 => Self::BpP256R1,
308 TpmEccCurve::BpP384R1 => Self::BpP384R1,
309 TpmEccCurve::BpP512R1 => Self::BpP512R1,
310 TpmEccCurve::Curve25519 => Self::Curve25519,
311 TpmEccCurve::Curve448 => Self::Curve448,
312 TpmEccCurve::None => Self::None,
313 }
314 }
315}
316
317impl From<Tpm2shEccCurve> for TpmEccCurve {
318 fn from(curve: Tpm2shEccCurve) -> Self {
319 match curve {
320 Tpm2shEccCurve::NistP192 => Self::NistP192,
321 Tpm2shEccCurve::NistP224 => Self::NistP224,
322 Tpm2shEccCurve::NistP256 => Self::NistP256,
323 Tpm2shEccCurve::NistP384 => Self::NistP384,
324 Tpm2shEccCurve::NistP521 => Self::NistP521,
325 Tpm2shEccCurve::BnP256 => Self::BnP256,
326 Tpm2shEccCurve::BnP638 => Self::BnP638,
327 Tpm2shEccCurve::Sm2P256 => Self::Sm2P256,
328 Tpm2shEccCurve::BpP256R1 => Self::BpP256R1,
329 Tpm2shEccCurve::BpP384R1 => Self::BpP384R1,
330 Tpm2shEccCurve::BpP512R1 => Self::BpP512R1,
331 Tpm2shEccCurve::Curve25519 => Self::Curve25519,
332 Tpm2shEccCurve::Curve448 => Self::Curve448,
333 Tpm2shEccCurve::None => Self::None,
334 }
335 }
336}
337
338#[must_use]
340pub fn format_alg_from_public(public: &TpmtPublic) -> String {
341 let name_alg_str = Tpm2shAlgId(public.name_alg).to_string();
342 match public.object_type {
343 TpmAlgId::Rsa => {
344 if let TpmuPublicParms::Rsa(params) = &public.parameters {
345 format!("rsa-{}:{}", params.key_bits, name_alg_str)
346 } else {
347 "rsa".to_string()
348 }
349 }
350 TpmAlgId::Ecc => {
351 if let TpmuPublicParms::Ecc(params) = &public.parameters {
352 let curve_str = Tpm2shEccCurve::from(params.curve_id).to_string();
353 format!("ecc-{curve_str}:{name_alg_str}")
354 } else {
355 "ecc".to_string()
356 }
357 }
358 TpmAlgId::KeyedHash => format!("keyedhash:{name_alg_str}"),
359 _ => Tpm2shAlgId(public.object_type).to_string(),
360 }
361}