1#![deny(clippy::all)]
8#![deny(clippy::pedantic)]
9
10use num_bigint::{BigUint, RandBigInt};
11use openssl::{
12 bn::{BigNum, BigNumContext},
13 derive::Deriver,
14 ec::{EcGroup, EcKey, EcPoint, PointConversionForm},
15 error::ErrorStack,
16 hash::{Hasher, MessageDigest},
17 memcmp,
18 nid::Nid,
19 pkey::PKey,
20 sign::Signer,
21};
22use rand::{CryptoRng, RngCore};
23use thiserror::Error;
24use tpm2_protocol::{
25 constant::TPM_MAX_COMMAND_SIZE,
26 data::{Tpm2bEccParameter, Tpm2bName, TpmAlgId, TpmEccCurve, TpmsEccPoint, TpmtPublic},
27 TpmBuild, TpmError, TpmWriter,
28};
29
/// Leading octet of an uncompressed `SEC1` elliptic-curve point encoding
/// (`0x04 || X || Y`); `ecdh` checks for it before splitting the coordinates.
pub const UNCOMPRESSED_POINT_TAG: u8 = 0x04;

/// KDF label used by `ecdh` when deriving the seed with `kdfe`.
pub const KDF_LABEL_DUPLICATE: &str = "DUPLICATE";
/// KDF label for integrity-key derivation (not used in the visible code;
/// presumably passed to `kdfa` by callers — confirm).
pub const KDF_LABEL_INTEGRITY: &str = "INTEGRITY";
/// KDF label for storage-key derivation (not used in the visible code;
/// presumably passed to `kdfa` by callers — confirm).
pub const KDF_LABEL_STORAGE: &str = "STORAGE";
35
36#[derive(Debug, Error)]
37pub enum CryptoError {
38 #[error("OpenSSL error: {0}")]
39 OpenSslError(#[from] ErrorStack),
40 #[error("big number conversion failed")]
41 BigNumConversion,
42 #[error("invalid ECC point")]
43 InvalidEccPoint,
44 #[error("unsupported ECC curve")]
45 UnsupportedEccCurve,
46 #[error("invalid hash algorithm")]
47 InvalidHashAlgorithm,
48 #[error("HMAC mismatch")]
49 HmacMismatch,
50 #[error("invalid public area")]
51 InvalidPublicArea,
52 #[error("malformed ECC parameter")]
53 MalformedEccParameter,
54 #[error("malformed ECDH seed")]
55 MalformedEcdhSeed,
56 #[error("malformed HMAC key")]
57 MalformedHmacKey,
58 #[error("malformed name")]
59 MalformedName,
60}
61
62impl From<TpmError> for CryptoError {
63 fn from(_: TpmError) -> Self {
64 CryptoError::MalformedEccParameter
65 }
66}
67
68fn map_tpm_alg_to_md(alg: TpmAlgId) -> Result<MessageDigest, CryptoError> {
70 match alg {
71 TpmAlgId::Sha1 => Ok(MessageDigest::sha1()),
72 TpmAlgId::Sha256 | TpmAlgId::Sm3_256 => Ok(MessageDigest::sha256()),
73 TpmAlgId::Sha384 => Ok(MessageDigest::sha384()),
74 TpmAlgId::Sha512 => Ok(MessageDigest::sha512()),
75 _ => Err(CryptoError::InvalidHashAlgorithm),
76 }
77}
78
79fn map_ecc_curve_to_nid(curve_id: TpmEccCurve) -> Result<Nid, CryptoError> {
81 match curve_id {
82 TpmEccCurve::NistP256 => Ok(Nid::X9_62_PRIME256V1),
83 TpmEccCurve::NistP384 => Ok(Nid::SECP384R1),
84 TpmEccCurve::NistP521 => Ok(Nid::SECP521R1),
85 _ => Err(CryptoError::UnsupportedEccCurve),
86 }
87}
88
89fn tpm_ecc_param_to_bignum(param: &Tpm2bEccParameter) -> Result<BigNum, CryptoError> {
91 BigNum::from_slice(param.as_ref()).map_err(|_| CryptoError::BigNumConversion)
92}
93
94pub fn hash_size(alg: TpmAlgId) -> Result<usize, CryptoError> {
101 Ok(map_tpm_alg_to_md(alg)?.size())
102}
103
104pub fn digest(alg: TpmAlgId, data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
111 let md = map_tpm_alg_to_md(alg)?;
112 let mut hasher = Hasher::new(md)?;
113 for chunk in data_chunks {
114 hasher.update(chunk)?;
115 }
116 Ok(hasher.finish()?.to_vec())
117}
118
119pub fn hmac(alg: TpmAlgId, key: &[u8], data_chunks: &[&[u8]]) -> Result<Vec<u8>, CryptoError> {
128 if key.is_empty() {
129 return Err(CryptoError::MalformedHmacKey);
130 }
131 let md = map_tpm_alg_to_md(alg)?;
132 let public_key = PKey::hmac(key)?;
133 let mut signer = Signer::new(md, &public_key)?;
134 for chunk in data_chunks {
135 signer.update(chunk)?;
136 }
137 Ok(signer.sign_to_vec()?)
138}
139
140pub fn hmac_verify(
151 alg: TpmAlgId,
152 key: &[u8],
153 data_chunks: &[&[u8]],
154 signature: &[u8],
155) -> Result<(), CryptoError> {
156 let expected = hmac(alg, key, data_chunks)?;
157 if memcmp::eq(&expected, signature) {
158 Ok(())
159 } else {
160 Err(CryptoError::HmacMismatch)
161 }
162}
163
164pub fn kdfa(
173 auth_hash: TpmAlgId,
174 hmac_key: &[u8],
175 label: &str,
176 context_a: &[u8],
177 context_b: &[u8],
178 key_bits: u16,
179) -> Result<Vec<u8>, CryptoError> {
180 let mut key_stream = Vec::new();
181 let key_bytes = (key_bits as usize).div_ceil(8);
182 let label_bytes = {
183 let mut bytes = label.as_bytes().to_vec();
184 bytes.push(0);
185 bytes
186 };
187
188 let mut counter: u32 = 1;
189 while key_stream.len() < key_bytes {
190 let counter_bytes = counter.to_be_bytes();
191 let key_bits_bytes = u32::from(key_bits).to_be_bytes();
192 let hmac_payload = [
193 counter_bytes.as_slice(),
194 label_bytes.as_slice(),
195 context_a,
196 context_b,
197 key_bits_bytes.as_slice(),
198 ];
199
200 let result = hmac(auth_hash, hmac_key, &hmac_payload)?;
201 let remaining = key_bytes - key_stream.len();
202 let to_take = remaining.min(result.len());
203 key_stream.extend_from_slice(&result[..to_take]);
204
205 counter += 1;
206 }
207
208 Ok(key_stream)
209}
210
211pub fn kdfe(
220 hash_alg: TpmAlgId,
221 z: &[u8],
222 label: &str,
223 context_u: &[u8],
224 context_v: &[u8],
225 key_bits: u16,
226) -> Result<Vec<u8>, CryptoError> {
227 let mut key_stream = Vec::new();
228 let key_bytes = (key_bits as usize).div_ceil(8);
229 let mut label_bytes = label.as_bytes().to_vec();
230 if label_bytes.last() != Some(&0) {
231 label_bytes.push(0);
232 }
233
234 let other_info = [label_bytes.as_slice(), context_u, context_v].concat();
235
236 let mut counter: u32 = 1;
237 while key_stream.len() < key_bytes {
238 let counter_bytes = counter.to_be_bytes();
239 let digest_payload = [&counter_bytes, z, &other_info];
240
241 let result = digest(hash_alg, &digest_payload)?;
242 let remaining = key_bytes - key_stream.len();
243 let to_take = remaining.min(result.len());
244 key_stream.extend_from_slice(&result[..to_take]);
245
246 counter += 1;
247 }
248
249 Ok(key_stream)
250}
251
252pub fn make_name(public: &TpmtPublic) -> Result<Tpm2bName, CryptoError> {
261 let name_alg = public.name_alg;
262
263 let mut name_buf = Vec::new();
264 name_buf.extend_from_slice(&(name_alg as u16).to_be_bytes());
265
266 let mut public_bytes = vec![0u8; TPM_MAX_COMMAND_SIZE];
267 let len = {
268 let mut writer = TpmWriter::new(&mut public_bytes);
269 public
270 .build(&mut writer)
271 .map_err(|_| CryptoError::InvalidPublicArea)?;
272 writer.len()
273 };
274 public_bytes.truncate(len);
275
276 let digest = digest(name_alg, &[&public_bytes])?;
277 name_buf.extend_from_slice(&digest);
278
279 Tpm2bName::try_from(name_buf.as_slice()).map_err(|_| CryptoError::MalformedName)
280}
281
282pub fn ecdh(
301 curve_id: TpmEccCurve,
302 parent_point: &TpmsEccPoint,
303 name_alg: TpmAlgId,
304 rng: &mut (impl RngCore + CryptoRng),
305) -> Result<(Vec<u8>, TpmsEccPoint), CryptoError> {
306 let nid = map_ecc_curve_to_nid(curve_id)?;
307 let group = EcGroup::from_curve_name(nid)?;
308 let mut ctx = BigNumContext::new()?;
309
310 let parent_x = tpm_ecc_param_to_bignum(&parent_point.x)?;
311 let parent_y = tpm_ecc_param_to_bignum(&parent_point.y)?;
312 let parent_key = EcKey::from_public_key_affine_coordinates(&group, &parent_x, &parent_y)?;
313 let parent_public_key = PKey::from_ec_key(parent_key)?;
314
315 let mut order = BigNum::new()?;
316 group.order(&mut order, &mut ctx)?;
317 let order_uint = BigUint::from_bytes_be(&order.to_vec());
318 let one = BigUint::from(1u8);
319
320 let priv_uint = rng.gen_biguint_range(&one, &order_uint);
321 let priv_bn = BigNum::from_slice(&priv_uint.to_bytes_be())?;
322
323 let mut ephemeral_pub_point = EcPoint::new(&group)?;
324 ephemeral_pub_point.mul_generator(&group, &priv_bn, &ctx)?;
325 let ephemeral_key = EcKey::from_private_components(&group, &priv_bn, &ephemeral_pub_point)?;
326
327 let ephemeral_public_key = PKey::from_ec_key(ephemeral_key)?;
328 let mut deriver = Deriver::new(&ephemeral_public_key)?;
329 deriver.set_peer(&parent_public_key)?;
330 let z = deriver.derive_to_vec()?;
331
332 let ephemeral_pub_bytes =
333 ephemeral_pub_point.to_bytes(&group, PointConversionForm::UNCOMPRESSED, &mut ctx)?;
334
335 if ephemeral_pub_bytes.is_empty() || ephemeral_pub_bytes[0] != UNCOMPRESSED_POINT_TAG {
336 return Err(CryptoError::InvalidEccPoint);
337 }
338
339 let coord_len = (ephemeral_pub_bytes.len() - 1) / 2;
340 let ephemeral_x = &ephemeral_pub_bytes[1..=coord_len];
341 let ephemeral_y = &ephemeral_pub_bytes[1 + coord_len..];
342
343 let seed_bits =
344 u16::try_from(hash_size(name_alg)? * 8).map_err(|_| CryptoError::MalformedEcdhSeed)?;
345
346 let context_u = ephemeral_x;
347 let context_v = parent_point.x.as_ref();
348
349 let seed = kdfe(
350 name_alg,
351 &z,
352 KDF_LABEL_DUPLICATE,
353 context_u,
354 context_v,
355 seed_bits,
356 )?;
357
358 let ephemeral_point_tpm = TpmsEccPoint {
359 x: Tpm2bEccParameter::try_from(ephemeral_x)?,
360 y: Tpm2bEccParameter::try_from(ephemeral_y)?,
361 };
362
363 Ok((seed, ephemeral_point_tpm))
364}