1#![deny(clippy::all)]
8#![deny(clippy::pedantic)]
9
10use openssl::{
11 bn::{BigNum, BigNumContext},
12 derive::Deriver,
13 ec::{EcGroup, EcKey, EcPoint, PointConversionForm},
14 hash::{Hasher, MessageDigest},
15 memcmp,
16 nid::Nid,
17 pkey::PKey,
18 sign::Signer,
19};
20use thiserror::Error;
21use tpm2_protocol::{
22 constant::TPM_MAX_COMMAND_SIZE,
23 data::{Tpm2bEccParameter, Tpm2bName, TpmAlgId, TpmEccCurve, TpmsEccPoint, TpmtPublic},
24 TpmMarshal, TpmWriter,
25};
26
/// Leading octet of an uncompressed elliptic-curve point encoding
/// (`0x04 || X || Y`). [`ecdh`] checks OpenSSL's uncompressed encoding of the
/// ephemeral public point against this tag before splitting off X and Y.
pub const UNCOMPRESSED_POINT_TAG: u8 = 0x04;

/// KDF label used by [`ecdh`] when deriving the seed with [`kdfe`].
pub const KDF_LABEL_DUPLICATE: &str = "DUPLICATE";
/// KDF label `"INTEGRITY"`. Not referenced in this section; presumably the
/// label for deriving an HMAC/integrity key with [`kdfa`] — confirm at call sites.
pub const KDF_LABEL_INTEGRITY: &str = "INTEGRITY";
/// KDF label `"STORAGE"`. Not referenced in this section; presumably the
/// label for deriving a symmetric storage key with [`kdfa`] — confirm at call sites.
pub const KDF_LABEL_STORAGE: &str = "STORAGE";
32
/// Errors produced by the cryptographic helpers in this module.
///
/// Underlying OpenSSL error stacks are discarded; each variant identifies
/// which step failed rather than carrying the OpenSSL details.
#[derive(Debug, Error)]
pub enum CryptoError {
    /// A TPM ECC parameter could not be converted into an OpenSSL `BigNum`.
    #[error("big number conversion failed")]
    BigNumConversion,
    /// A recomputed HMAC did not match the supplied value (see `hmac_verify`).
    #[error("HMAC mismatch")]
    HmacMismatch,
    /// The encoded ephemeral point did not have the uncompressed-point layout.
    #[error("invalid ECC point")]
    InvalidEccPoint,
    /// The TPM hash algorithm has no OpenSSL digest mapping in this module.
    #[error("invalid hash algorithm")]
    InvalidHashAlgorithm,
    /// OpenSSL failed while creating, updating or finalizing a hasher in `digest`.
    #[error("invalid data chunk")]
    InvalidChunk,
    /// The parent's public ECC point could not be turned into an OpenSSL key.
    #[error("invalid parent ECC point")]
    InvalidParent,
    /// Marshaling a `TpmtPublic` failed while computing its Name.
    #[error("invalid public area")]
    InvalidPublicArea,
    /// An EC operation on the ephemeral key failed in `ecdh` (e.g. scalar
    /// conversion, point multiplication or point encoding).
    #[error("malformed ECC parameter")]
    MalformedEccParameter,
    /// ECDH key agreement failed, or the requested seed size did not fit in `u16` bits.
    #[error("malformed ECDH seed")]
    MalformedEcdhSeed,
    /// The HMAC key was empty, or OpenSSL failed while computing the HMAC.
    #[error("malformed HMAC key")]
    MalformedHmacKey,
    /// The computed Name did not fit into a `Tpm2bName`.
    #[error("malformed name")]
    MalformedName,
    /// The TPM ECC curve is not one of the curves supported here.
    #[error("unsupported ECC curve")]
    UnsupportedEccCurve,
}
60
61fn map_tpm_alg_to_md(alg: TpmAlgId) -> Result<MessageDigest, CryptoError> {
63 match alg {
64 TpmAlgId::Sha1 => Ok(MessageDigest::sha1()),
65 TpmAlgId::Sha256 | TpmAlgId::Sm3_256 => Ok(MessageDigest::sha256()),
66 TpmAlgId::Sha384 => Ok(MessageDigest::sha384()),
67 TpmAlgId::Sha512 => Ok(MessageDigest::sha512()),
68 _ => Err(CryptoError::InvalidHashAlgorithm),
69 }
70}
71
72fn map_ecc_curve_to_nid(curve_id: TpmEccCurve) -> Result<Nid, CryptoError> {
74 match curve_id {
75 TpmEccCurve::NistP256 => Ok(Nid::X9_62_PRIME256V1),
76 TpmEccCurve::NistP384 => Ok(Nid::SECP384R1),
77 TpmEccCurve::NistP521 => Ok(Nid::SECP521R1),
78 _ => Err(CryptoError::UnsupportedEccCurve),
79 }
80}
81
82fn tpm_ecc_param_to_bignum(param: &Tpm2bEccParameter) -> Result<BigNum, CryptoError> {
84 BigNum::from_slice(param.as_ref()).map_err(|_| CryptoError::BigNumConversion)
85}
86
87pub fn hash_size(alg: TpmAlgId) -> Result<usize, CryptoError> {
94 Ok(map_tpm_alg_to_md(alg)?.size())
95}
96
97pub fn digest(alg: TpmAlgId, data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
104 let md = map_tpm_alg_to_md(alg).map_err(|_| CryptoError::InvalidHashAlgorithm)?;
105 let mut hasher = Hasher::new(md).map_err(|_| CryptoError::InvalidChunk)?;
106 for chunk in data_chunks {
107 hasher
108 .update(chunk)
109 .map_err(|_| CryptoError::InvalidChunk)?;
110 }
111 Ok(hasher
112 .finish()
113 .map_err(|_| CryptoError::InvalidChunk)?
114 .to_vec())
115}
116
117pub fn hmac(alg: TpmAlgId, key: &[u8], data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
126 if key.is_empty() {
127 return Err(CryptoError::MalformedHmacKey);
128 }
129 let md = map_tpm_alg_to_md(alg).map_err(|_| CryptoError::InvalidHashAlgorithm)?;
130 let public_key = PKey::hmac(key).map_err(|_| CryptoError::MalformedHmacKey)?;
131 let mut signer = Signer::new(md, &public_key).map_err(|_| CryptoError::MalformedHmacKey)?;
132 for chunk in data_chunks {
133 signer
134 .update(chunk)
135 .map_err(|_| CryptoError::MalformedHmacKey)?;
136 }
137 signer
138 .sign_to_vec()
139 .map_err(|_| CryptoError::MalformedHmacKey)
140}
141
142pub fn hmac_verify(
153 alg: TpmAlgId,
154 key: &[u8],
155 data_chunks: &[&[u8]],
156 signature: &[u8],
157) -> Result<(), CryptoError> {
158 let expected = hmac(alg, key, data_chunks)?;
159 if memcmp::eq(&expected, signature) {
160 Ok(())
161 } else {
162 Err(CryptoError::HmacMismatch)
163 }
164}
165
166pub fn kdfa(
175 auth_hash: TpmAlgId,
176 hmac_key: &[u8],
177 label: &str,
178 context_a: &[u8],
179 context_b: &[u8],
180 key_bits: u16,
181) -> Result<Vec<u8>, CryptoError> {
182 let mut key_stream = Vec::new();
183 let key_bytes = (key_bits as usize).div_ceil(8);
184 let label_bytes = {
185 let mut bytes = label.as_bytes().to_vec();
186 bytes.push(0);
187 bytes
188 };
189
190 let mut counter: u32 = 1;
191 while key_stream.len() < key_bytes {
192 let counter_bytes = counter.to_be_bytes();
193 let key_bits_bytes = u32::from(key_bits).to_be_bytes();
194 let hmac_payload = [
195 counter_bytes.as_slice(),
196 label_bytes.as_slice(),
197 context_a,
198 context_b,
199 key_bits_bytes.as_slice(),
200 ];
201
202 let result = hmac(auth_hash, hmac_key, &hmac_payload)?;
203 let remaining = key_bytes - key_stream.len();
204 let to_take = remaining.min(result.len());
205 key_stream.extend_from_slice(&result[..to_take]);
206
207 counter += 1;
208 }
209
210 Ok(key_stream)
211}
212
213pub fn kdfe(
222 hash_alg: TpmAlgId,
223 z: &[u8],
224 label: &str,
225 context_u: &[u8],
226 context_v: &[u8],
227 key_bits: u16,
228) -> Result<Vec<u8>, CryptoError> {
229 let mut key_stream = Vec::new();
230 let key_bytes = (key_bits as usize).div_ceil(8);
231 let mut label_bytes = label.as_bytes().to_vec();
232 if label_bytes.last() != Some(&0) {
233 label_bytes.push(0);
234 }
235
236 let other_info = [label_bytes.as_slice(), context_u, context_v].concat();
237
238 let mut counter: u32 = 1;
239 while key_stream.len() < key_bytes {
240 let counter_bytes = counter.to_be_bytes();
241 let digest_payload = [&counter_bytes, z, &other_info];
242
243 let result = digest(hash_alg, &digest_payload)?;
244 let remaining = key_bytes - key_stream.len();
245 let to_take = remaining.min(result.len());
246 key_stream.extend_from_slice(&result[..to_take]);
247
248 counter += 1;
249 }
250
251 Ok(key_stream)
252}
253
254pub fn make_name(public: &TpmtPublic) -> Result<Tpm2bName, CryptoError> {
263 let name_alg = public.name_alg;
264
265 let mut name_buf = Vec::new();
266 name_buf.extend_from_slice(&(name_alg as u16).to_be_bytes());
267
268 let mut public_bytes = vec![0u8; TPM_MAX_COMMAND_SIZE as usize];
269 let len = {
270 let mut writer = TpmWriter::new(&mut public_bytes);
271 public
272 .marshal(&mut writer)
273 .map_err(|_| CryptoError::InvalidPublicArea)?;
274 writer.len()
275 };
276 public_bytes.truncate(len);
277
278 let digest = digest(name_alg, &[&public_bytes])?;
279 name_buf.extend_from_slice(&digest);
280
281 Tpm2bName::try_from(name_buf.as_slice()).map_err(|_| CryptoError::MalformedName)
282}
283
284pub fn ecdh(
301 curve_id: TpmEccCurve,
302 parent_point: &TpmsEccPoint,
303 name_alg: TpmAlgId,
304 priv_bytes: &[u8],
305) -> Result<(Vec<u8>, TpmsEccPoint), CryptoError> {
306 let nid = map_ecc_curve_to_nid(curve_id).map_err(|_| CryptoError::UnsupportedEccCurve)?;
307 let group = EcGroup::from_curve_name(nid).map_err(|_| CryptoError::UnsupportedEccCurve)?;
308 let mut ctx = BigNumContext::new().map_err(|_| CryptoError::MalformedEccParameter)?;
309
310 let parent_x =
311 tpm_ecc_param_to_bignum(&parent_point.x).map_err(|_| CryptoError::InvalidParent)?;
312 let parent_y =
313 tpm_ecc_param_to_bignum(&parent_point.y).map_err(|_| CryptoError::InvalidParent)?;
314 let parent_key = EcKey::from_public_key_affine_coordinates(&group, &parent_x, &parent_y)
315 .map_err(|_| CryptoError::InvalidParent)?;
316 let parent_public_key =
317 PKey::from_ec_key(parent_key).map_err(|_| CryptoError::InvalidParent)?;
318
319 let mut order = BigNum::new().map_err(|_| CryptoError::MalformedEccParameter)?;
320 group
321 .order(&mut order, &mut ctx)
322 .map_err(|_| CryptoError::MalformedEccParameter)?;
323
324 let priv_bn = BigNum::from_slice(priv_bytes).map_err(|_| CryptoError::MalformedEccParameter)?;
325
326 let mut ephemeral_pub_point =
327 EcPoint::new(&group).map_err(|_| CryptoError::MalformedEccParameter)?;
328 ephemeral_pub_point
329 .mul_generator(&group, &priv_bn, &ctx)
330 .map_err(|_| CryptoError::MalformedEccParameter)?;
331 let ephemeral_key = EcKey::from_private_components(&group, &priv_bn, &ephemeral_pub_point)
332 .map_err(|_| CryptoError::MalformedEccParameter)?;
333
334 let ephemeral_public_key =
335 PKey::from_ec_key(ephemeral_key).map_err(|_| CryptoError::MalformedEccParameter)?;
336 let mut deriver =
337 Deriver::new(&ephemeral_public_key).map_err(|_| CryptoError::MalformedEcdhSeed)?;
338 deriver
339 .set_peer(&parent_public_key)
340 .map_err(|_| CryptoError::MalformedEcdhSeed)?;
341 let z = deriver
342 .derive_to_vec()
343 .map_err(|_| CryptoError::MalformedEcdhSeed)?;
344
345 let ephemeral_pub_bytes = ephemeral_pub_point
346 .to_bytes(&group, PointConversionForm::UNCOMPRESSED, &mut ctx)
347 .map_err(|_| CryptoError::MalformedEccParameter)?;
348
349 if ephemeral_pub_bytes.is_empty() || ephemeral_pub_bytes[0] != UNCOMPRESSED_POINT_TAG {
350 return Err(CryptoError::InvalidEccPoint);
351 }
352
353 let coord_len = (ephemeral_pub_bytes.len() - 1) / 2;
354 let ephemeral_x = &ephemeral_pub_bytes[1..=coord_len];
355 let ephemeral_y = &ephemeral_pub_bytes[1 + coord_len..];
356
357 let seed_bits =
358 u16::try_from(hash_size(name_alg)? * 8).map_err(|_| CryptoError::MalformedEcdhSeed)?;
359
360 let context_u = ephemeral_x;
361 let context_v = parent_point.x.as_ref();
362
363 let seed = kdfe(
364 name_alg,
365 &z,
366 KDF_LABEL_DUPLICATE,
367 context_u,
368 context_v,
369 seed_bits,
370 )?;
371
372 let ephemeral_point_tpm = TpmsEccPoint {
373 x: Tpm2bEccParameter::try_from(ephemeral_x)
374 .map_err(|_| CryptoError::MalformedEccParameter)?,
375 y: Tpm2bEccParameter::try_from(ephemeral_y)
376 .map_err(|_| CryptoError::MalformedEccParameter)?,
377 };
378
379 Ok((seed, ephemeral_point_tpm))
380}