1#![deny(clippy::all)]
8#![deny(clippy::pedantic)]
9
10use num_bigint::{BigUint, RandBigInt};
11use openssl::{
12 bn::{BigNum, BigNumContext},
13 derive::Deriver,
14 ec::{EcGroup, EcKey, EcPoint, PointConversionForm},
15 hash::{Hasher, MessageDigest},
16 memcmp,
17 nid::Nid,
18 pkey::PKey,
19 sign::Signer,
20};
21use rand::{CryptoRng, RngCore};
22use thiserror::Error;
23use tpm2_protocol::{
24 constant::TPM_MAX_COMMAND_SIZE,
25 data::{Tpm2bEccParameter, Tpm2bName, TpmAlgId, TpmEccCurve, TpmsEccPoint, TpmtPublic},
26 TpmMarshal, TpmWriter,
27};
28
/// First octet of an uncompressed elliptic-curve point encoding
/// (`0x04 || X || Y`); [`ecdh`] checks the ephemeral point against it.
pub const UNCOMPRESSED_POINT_TAG: u8 = 0x04;

/// Label passed to [`kdfe`] by [`ecdh`] when deriving the seed.
pub const KDF_LABEL_DUPLICATE: &str = "DUPLICATE";
/// `INTEGRITY` KDF label; not referenced in this section, presumably used
/// with [`kdfa`] elsewhere — confirm at call sites.
pub const KDF_LABEL_INTEGRITY: &str = "INTEGRITY";
/// `STORAGE` KDF label; not referenced in this section, presumably used
/// with [`kdfa`] elsewhere — confirm at call sites.
pub const KDF_LABEL_STORAGE: &str = "STORAGE";
34
35#[derive(Debug, Error)]
36pub enum CryptoError {
37 #[error("big number conversion failed")]
38 BigNumConversion,
39 #[error("HMAC mismatch")]
40 HmacMismatch,
41 #[error("invalid ECC point")]
42 InvalidEccPoint,
43 #[error("invalid hash algorithm")]
44 InvalidHashAlgorithm,
45 #[error("invalid data chunk")]
46 InvalidChunk,
47 #[error("invalid parent ECC point")]
48 InvalidParent,
49 #[error("invalid public area")]
50 InvalidPublicArea,
51 #[error("malformed ECC parameter")]
52 MalformedEccParameter,
53 #[error("malformed ECDH seed")]
54 MalformedEcdhSeed,
55 #[error("malformed HMAC key")]
56 MalformedHmacKey,
57 #[error("malformed name")]
58 MalformedName,
59 #[error("unsupported ECC curve")]
60 UnsupportedEccCurve,
61}
62
63fn map_tpm_alg_to_md(alg: TpmAlgId) -> Result<MessageDigest, CryptoError> {
65 match alg {
66 TpmAlgId::Sha1 => Ok(MessageDigest::sha1()),
67 TpmAlgId::Sha256 | TpmAlgId::Sm3_256 => Ok(MessageDigest::sha256()),
68 TpmAlgId::Sha384 => Ok(MessageDigest::sha384()),
69 TpmAlgId::Sha512 => Ok(MessageDigest::sha512()),
70 _ => Err(CryptoError::InvalidHashAlgorithm),
71 }
72}
73
74fn map_ecc_curve_to_nid(curve_id: TpmEccCurve) -> Result<Nid, CryptoError> {
76 match curve_id {
77 TpmEccCurve::NistP256 => Ok(Nid::X9_62_PRIME256V1),
78 TpmEccCurve::NistP384 => Ok(Nid::SECP384R1),
79 TpmEccCurve::NistP521 => Ok(Nid::SECP521R1),
80 _ => Err(CryptoError::UnsupportedEccCurve),
81 }
82}
83
84fn tpm_ecc_param_to_bignum(param: &Tpm2bEccParameter) -> Result<BigNum, CryptoError> {
86 BigNum::from_slice(param.as_ref()).map_err(|_| CryptoError::BigNumConversion)
87}
88
89pub fn hash_size(alg: TpmAlgId) -> Result<usize, CryptoError> {
96 Ok(map_tpm_alg_to_md(alg)?.size())
97}
98
99pub fn digest(alg: TpmAlgId, data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
106 let md = map_tpm_alg_to_md(alg).map_err(|_| CryptoError::InvalidHashAlgorithm)?;
107 let mut hasher = Hasher::new(md).map_err(|_| CryptoError::InvalidChunk)?;
108 for chunk in data_chunks {
109 hasher
110 .update(chunk)
111 .map_err(|_| CryptoError::InvalidChunk)?;
112 }
113 Ok(hasher
114 .finish()
115 .map_err(|_| CryptoError::InvalidChunk)?
116 .to_vec())
117}
118
119pub fn hmac(alg: TpmAlgId, key: &[u8], data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
128 if key.is_empty() {
129 return Err(CryptoError::MalformedHmacKey);
130 }
131 let md = map_tpm_alg_to_md(alg).map_err(|_| CryptoError::InvalidHashAlgorithm)?;
132 let public_key = PKey::hmac(key).map_err(|_| CryptoError::MalformedHmacKey)?;
133 let mut signer = Signer::new(md, &public_key).map_err(|_| CryptoError::MalformedHmacKey)?;
134 for chunk in data_chunks {
135 signer
136 .update(chunk)
137 .map_err(|_| CryptoError::MalformedHmacKey)?;
138 }
139 signer
140 .sign_to_vec()
141 .map_err(|_| CryptoError::MalformedHmacKey)
142}
143
144pub fn hmac_verify(
155 alg: TpmAlgId,
156 key: &[u8],
157 data_chunks: &[&[u8]],
158 signature: &[u8],
159) -> Result<(), CryptoError> {
160 let expected = hmac(alg, key, data_chunks)?;
161 if memcmp::eq(&expected, signature) {
162 Ok(())
163 } else {
164 Err(CryptoError::HmacMismatch)
165 }
166}
167
168pub fn kdfa(
177 auth_hash: TpmAlgId,
178 hmac_key: &[u8],
179 label: &str,
180 context_a: &[u8],
181 context_b: &[u8],
182 key_bits: u16,
183) -> Result<Vec<u8>, CryptoError> {
184 let mut key_stream = Vec::new();
185 let key_bytes = (key_bits as usize).div_ceil(8);
186 let label_bytes = {
187 let mut bytes = label.as_bytes().to_vec();
188 bytes.push(0);
189 bytes
190 };
191
192 let mut counter: u32 = 1;
193 while key_stream.len() < key_bytes {
194 let counter_bytes = counter.to_be_bytes();
195 let key_bits_bytes = u32::from(key_bits).to_be_bytes();
196 let hmac_payload = [
197 counter_bytes.as_slice(),
198 label_bytes.as_slice(),
199 context_a,
200 context_b,
201 key_bits_bytes.as_slice(),
202 ];
203
204 let result = hmac(auth_hash, hmac_key, &hmac_payload)?;
205 let remaining = key_bytes - key_stream.len();
206 let to_take = remaining.min(result.len());
207 key_stream.extend_from_slice(&result[..to_take]);
208
209 counter += 1;
210 }
211
212 Ok(key_stream)
213}
214
215pub fn kdfe(
224 hash_alg: TpmAlgId,
225 z: &[u8],
226 label: &str,
227 context_u: &[u8],
228 context_v: &[u8],
229 key_bits: u16,
230) -> Result<Vec<u8>, CryptoError> {
231 let mut key_stream = Vec::new();
232 let key_bytes = (key_bits as usize).div_ceil(8);
233 let mut label_bytes = label.as_bytes().to_vec();
234 if label_bytes.last() != Some(&0) {
235 label_bytes.push(0);
236 }
237
238 let other_info = [label_bytes.as_slice(), context_u, context_v].concat();
239
240 let mut counter: u32 = 1;
241 while key_stream.len() < key_bytes {
242 let counter_bytes = counter.to_be_bytes();
243 let digest_payload = [&counter_bytes, z, &other_info];
244
245 let result = digest(hash_alg, &digest_payload)?;
246 let remaining = key_bytes - key_stream.len();
247 let to_take = remaining.min(result.len());
248 key_stream.extend_from_slice(&result[..to_take]);
249
250 counter += 1;
251 }
252
253 Ok(key_stream)
254}
255
256pub fn make_name(public: &TpmtPublic) -> Result<Tpm2bName, CryptoError> {
265 let name_alg = public.name_alg;
266
267 let mut name_buf = Vec::new();
268 name_buf.extend_from_slice(&(name_alg as u16).to_be_bytes());
269
270 let mut public_bytes = vec![0u8; TPM_MAX_COMMAND_SIZE as usize];
271 let len = {
272 let mut writer = TpmWriter::new(&mut public_bytes);
273 public
274 .marshal(&mut writer)
275 .map_err(|_| CryptoError::InvalidPublicArea)?;
276 writer.len()
277 };
278 public_bytes.truncate(len);
279
280 let digest = digest(name_alg, &[&public_bytes])?;
281 name_buf.extend_from_slice(&digest);
282
283 Tpm2bName::try_from(name_buf.as_slice()).map_err(|_| CryptoError::MalformedName)
284}
285
286pub fn ecdh(
303 curve_id: TpmEccCurve,
304 parent_point: &TpmsEccPoint,
305 name_alg: TpmAlgId,
306 rng: &mut (impl RngCore + CryptoRng),
307) -> Result<(Vec<u8>, TpmsEccPoint), CryptoError> {
308 let nid = map_ecc_curve_to_nid(curve_id).map_err(|_| CryptoError::UnsupportedEccCurve)?;
309 let group = EcGroup::from_curve_name(nid).map_err(|_| CryptoError::UnsupportedEccCurve)?;
310 let mut ctx = BigNumContext::new().map_err(|_| CryptoError::MalformedEccParameter)?;
311
312 let parent_x =
313 tpm_ecc_param_to_bignum(&parent_point.x).map_err(|_| CryptoError::InvalidParent)?;
314 let parent_y =
315 tpm_ecc_param_to_bignum(&parent_point.y).map_err(|_| CryptoError::InvalidParent)?;
316 let parent_key = EcKey::from_public_key_affine_coordinates(&group, &parent_x, &parent_y)
317 .map_err(|_| CryptoError::InvalidParent)?;
318 let parent_public_key =
319 PKey::from_ec_key(parent_key).map_err(|_| CryptoError::InvalidParent)?;
320
321 let mut order = BigNum::new().map_err(|_| CryptoError::MalformedEccParameter)?;
322 group
323 .order(&mut order, &mut ctx)
324 .map_err(|_| CryptoError::MalformedEccParameter)?;
325 let order_uint = BigUint::from_bytes_be(&order.to_vec());
326 let one = BigUint::from(1u8);
327
328 let priv_uint = rng.gen_biguint_range(&one, &order_uint);
329 let priv_bn = BigNum::from_slice(&priv_uint.to_bytes_be())
330 .map_err(|_| CryptoError::MalformedEccParameter)?;
331
332 let mut ephemeral_pub_point =
333 EcPoint::new(&group).map_err(|_| CryptoError::MalformedEccParameter)?;
334 ephemeral_pub_point
335 .mul_generator(&group, &priv_bn, &ctx)
336 .map_err(|_| CryptoError::MalformedEccParameter)?;
337 let ephemeral_key = EcKey::from_private_components(&group, &priv_bn, &ephemeral_pub_point)
338 .map_err(|_| CryptoError::MalformedEccParameter)?;
339
340 let ephemeral_public_key =
341 PKey::from_ec_key(ephemeral_key).map_err(|_| CryptoError::MalformedEccParameter)?;
342 let mut deriver =
343 Deriver::new(&ephemeral_public_key).map_err(|_| CryptoError::MalformedEcdhSeed)?;
344 deriver
345 .set_peer(&parent_public_key)
346 .map_err(|_| CryptoError::MalformedEcdhSeed)?;
347 let z = deriver
348 .derive_to_vec()
349 .map_err(|_| CryptoError::MalformedEcdhSeed)?;
350
351 let ephemeral_pub_bytes = ephemeral_pub_point
352 .to_bytes(&group, PointConversionForm::UNCOMPRESSED, &mut ctx)
353 .map_err(|_| CryptoError::MalformedEccParameter)?;
354
355 if ephemeral_pub_bytes.is_empty() || ephemeral_pub_bytes[0] != UNCOMPRESSED_POINT_TAG {
356 return Err(CryptoError::InvalidEccPoint);
357 }
358
359 let coord_len = (ephemeral_pub_bytes.len() - 1) / 2;
360 let ephemeral_x = &ephemeral_pub_bytes[1..=coord_len];
361 let ephemeral_y = &ephemeral_pub_bytes[1 + coord_len..];
362
363 let seed_bits =
364 u16::try_from(hash_size(name_alg)? * 8).map_err(|_| CryptoError::MalformedEcdhSeed)?;
365
366 let context_u = ephemeral_x;
367 let context_v = parent_point.x.as_ref();
368
369 let seed = kdfe(
370 name_alg,
371 &z,
372 KDF_LABEL_DUPLICATE,
373 context_u,
374 context_v,
375 seed_bits,
376 )?;
377
378 let ephemeral_point_tpm = TpmsEccPoint {
379 x: Tpm2bEccParameter::try_from(ephemeral_x)
380 .map_err(|_| CryptoError::MalformedEccParameter)?,
381 y: Tpm2bEccParameter::try_from(ephemeral_y)
382 .map_err(|_| CryptoError::MalformedEccParameter)?,
383 };
384
385 Ok((seed, ephemeral_point_tpm))
386}