1#![allow(deprecated)] use aes_gcm::aead::{Aead, KeyInit, Payload};
11use aes_gcm::{Aes256Gcm, Nonce};
12use aes_kw::Kek;
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use base64::Engine;
15use p256::PublicKey;
16use rand::RngCore;
17use rsa::{Oaep, RsaPublicKey};
18use serde_json::Value;
19use sha1::Sha1;
20
21use crate::errors::WebexError;
22
23pub fn encrypt_rsa_oaep_a256gcm(
25 plaintext: &[u8],
26 rsa_jwk: &Value,
27) -> Result<String, WebexError> {
28 let rsa_key = parse_rsa_public_key(rsa_jwk)?;
30
31 let mut cek = [0u8; 32];
33 rand::thread_rng().fill_bytes(&mut cek);
34
35 let mut iv = [0u8; 12];
37 rand::thread_rng().fill_bytes(&mut iv);
38
39 let mut header = serde_json::json!({"alg": "RSA-OAEP", "enc": "A256GCM"});
41 if let Some(kid) = rsa_jwk.get("kid").and_then(|v| v.as_str()) {
42 header["kid"] = Value::String(kid.to_string());
43 }
44 let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
45
46 let padding = Oaep::new::<Sha1>();
48 let encrypted_key = rsa_key
49 .encrypt(&mut rand::thread_rng(), padding, &cek)
50 .map_err(|e| WebexError::kms(format!("RSA-OAEP encryption failed: {e}")))?;
51
52 let cipher = Aes256Gcm::new_from_slice(&cek)
54 .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
55 let nonce = Nonce::from_slice(&iv);
56 let aad = header_b64.as_bytes();
57 let ciphertext_with_tag = cipher
58 .encrypt(nonce, Payload { msg: plaintext, aad })
59 .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
60
61 let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
63
64 Ok(format!(
66 "{}.{}.{}.{}.{}",
67 header_b64,
68 URL_SAFE_NO_PAD.encode(&encrypted_key),
69 URL_SAFE_NO_PAD.encode(&iv),
70 URL_SAFE_NO_PAD.encode(ciphertext),
71 URL_SAFE_NO_PAD.encode(tag),
72 ))
73}
74
75pub fn encrypt_dir_a256gcm(
77 plaintext: &[u8],
78 cek: &[u8; 32],
79 kid: &str,
80) -> Result<String, WebexError> {
81 let mut iv = [0u8; 12];
83 rand::thread_rng().fill_bytes(&mut iv);
84
85 let mut header = serde_json::json!({"alg": "dir", "enc": "A256GCM"});
87 if !kid.is_empty() {
88 header["kid"] = Value::String(kid.to_string());
89 }
90 let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
91
92 let cipher = Aes256Gcm::new_from_slice(cek)
94 .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
95 let nonce = Nonce::from_slice(&iv);
96 let aad = header_b64.as_bytes();
97 let ciphertext_with_tag = cipher
98 .encrypt(nonce, Payload { msg: plaintext, aad })
99 .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
100
101 let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
102
103 Ok(format!(
105 "{}.{}.{}.{}.{}",
106 header_b64,
107 "", URL_SAFE_NO_PAD.encode(&iv),
109 URL_SAFE_NO_PAD.encode(ciphertext),
110 URL_SAFE_NO_PAD.encode(tag),
111 ))
112}
113
114pub fn encrypt_a256kw_a256gcm(
116 plaintext: &[u8],
117 wrapping_key: &[u8; 32],
118) -> Result<String, WebexError> {
119 let mut cek = [0u8; 32];
121 rand::thread_rng().fill_bytes(&mut cek);
122
123 let mut iv = [0u8; 12];
125 rand::thread_rng().fill_bytes(&mut iv);
126
127 let header = serde_json::json!({"alg": "A256KW", "enc": "A256GCM"});
129 let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
130
131 let kek = Kek::from(*wrapping_key);
133 let mut wrapped_key = vec![0u8; cek.len() + 8]; kek.wrap(&cek, &mut wrapped_key)
135 .map_err(|e| WebexError::kms(format!("AES key wrap failed: {e}")))?;
136
137 let cipher = Aes256Gcm::new_from_slice(&cek)
139 .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
140 let nonce = Nonce::from_slice(&iv);
141 let aad = header_b64.as_bytes();
142 let ciphertext_with_tag = cipher
143 .encrypt(nonce, Payload { msg: plaintext, aad })
144 .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
145
146 let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
147
148 Ok(format!(
149 "{}.{}.{}.{}.{}",
150 header_b64,
151 URL_SAFE_NO_PAD.encode(&wrapped_key),
152 URL_SAFE_NO_PAD.encode(&iv),
153 URL_SAFE_NO_PAD.encode(ciphertext),
154 URL_SAFE_NO_PAD.encode(tag),
155 ))
156}
157
158pub fn decrypt_a256kw_a256gcm(
160 token: &str,
161 wrapping_key: &[u8; 32],
162) -> Result<Vec<u8>, WebexError> {
163 let parts = parse_jwe_compact(token)?;
164
165 let kek = Kek::from(*wrapping_key);
167 let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
168 kek.unwrap(&parts.encrypted_key, &mut cek)
169 .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
170
171 decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
173}
174
175pub fn decrypt_dir_a256gcm(
177 token: &str,
178 cek: &[u8; 32],
179) -> Result<Vec<u8>, WebexError> {
180 let parts = parse_jwe_compact(token)?;
181
182 decrypt_a256gcm(cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
184}
185
186pub fn decrypt_message_jwe(
188 token: &str,
189 key: &[u8; 32],
190) -> Result<Vec<u8>, WebexError> {
191 let parts = parse_jwe_compact(token)?;
192
193 let header_json: Value = serde_json::from_slice(&parts.header_bytes)
195 .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
196 let alg = header_json
197 .get("alg")
198 .and_then(|v| v.as_str())
199 .unwrap_or("");
200
201 match alg {
202 "dir" => {
203 decrypt_a256gcm(key, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
205 }
206 "A256KW" => {
207 let kek = Kek::from(*key);
209 let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
210 kek.unwrap(&parts.encrypted_key, &mut cek)
211 .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
212 decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
213 }
214 _ => Err(WebexError::kms(format!(
215 "Unsupported message JWE algorithm: {alg}"
216 ))),
217 }
218}
219
220pub fn decrypt_ecdh_es(
225 token: &str,
226 local_private_key: &p256::SecretKey,
227) -> Result<Vec<u8>, WebexError> {
228 let parts = parse_jwe_compact(token)?;
229
230 let header_json: Value = serde_json::from_slice(&parts.header_bytes)
232 .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
233
234 let alg = header_json
235 .get("alg")
236 .and_then(|v| v.as_str())
237 .unwrap_or("");
238 let enc = header_json
239 .get("enc")
240 .and_then(|v| v.as_str())
241 .unwrap_or("A256GCM");
242
243 let epk = header_json
245 .get("epk")
246 .ok_or_else(|| WebexError::kms("No epk in ECDH-ES JWE header"))?;
247
248 let server_public = parse_ec_public_key(epk)?;
249
250 let shared_secret = p256::ecdh::diffie_hellman(
252 local_private_key.to_nonzero_scalar(),
253 server_public.as_affine(),
254 );
255
256 let apu = header_json
258 .get("apu")
259 .and_then(|v| v.as_str())
260 .map(|s| URL_SAFE_NO_PAD.decode(s).unwrap_or_default())
261 .unwrap_or_default();
262 let apv = header_json
263 .get("apv")
264 .and_then(|v| v.as_str())
265 .map(|s| URL_SAFE_NO_PAD.decode(s).unwrap_or_default())
266 .unwrap_or_default();
267
268 match alg {
269 "ECDH-ES" => {
270 let key_len = enc_key_length(enc);
272 let cek = concat_kdf(
273 shared_secret.raw_secret_bytes(),
274 enc, &apu,
276 &apv,
277 (key_len * 8) as u32,
278 )?;
279
280 decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
281 }
282 "ECDH-ES+A256KW" => {
283 let kek_bytes = concat_kdf(
285 shared_secret.raw_secret_bytes(),
286 "A256KW",
287 &apu,
288 &apv,
289 256,
290 )?;
291
292 let kek_arr: [u8; 32] = kek_bytes
293 .try_into()
294 .map_err(|_| WebexError::kms("Derived KEK is not 32 bytes"))?;
295 let kek = Kek::from(kek_arr);
296
297 let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
298 kek.unwrap(&parts.encrypted_key, &mut cek)
299 .map_err(|e| WebexError::kms(format!("ECDH-ES+A256KW unwrap failed: {e}")))?;
300
301 decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
302 }
303 _ => Err(WebexError::kms(format!("Unsupported ECDH algorithm: {alg}"))),
304 }
305}
306
307pub fn decrypt_jwe(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
309 match key {
310 JweKey::Symmetric(k) => decrypt_message_jwe(token, k),
311 JweKey::EcdhPrivate(k) => decrypt_ecdh_es(token, k),
312 }
313}
314
315pub fn unwrap_kms_response(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
318 let dot_count = token.chars().filter(|&c| c == '.').count();
319 match dot_count {
320 4 => decrypt_jwe(token, key),
321 2 => {
322 let parts: Vec<&str> = token.split('.').collect();
324 URL_SAFE_NO_PAD
325 .decode(parts[1])
326 .map_err(|e| WebexError::kms(format!("Failed to decode JWS payload: {e}")))
327 }
328 _ => Err(WebexError::kms(format!(
329 "Invalid KMS response format: expected 3 or 5 parts, got {} dots",
330 dot_count
331 ))),
332 }
333}
334
335pub enum JweKey {
337 Symmetric([u8; 32]),
339 EcdhPrivate(p256::SecretKey),
341}
342
/// The five segments of a JWE compact token, as produced by `parse_jwe_compact`.
struct JweParts {
    /// Protected header exactly as it appeared in the token (base64url); used verbatim as AAD.
    header_b64: String,
    /// Decoded protected header bytes (JSON text).
    header_bytes: Vec<u8>,
    /// Decoded encrypted-key segment; empty for `dir` tokens.
    encrypted_key: Vec<u8>,
    /// Decoded initialisation vector.
    iv: Vec<u8>,
    /// Decoded ciphertext, without the authentication tag.
    ciphertext: Vec<u8>,
    /// Decoded authentication tag.
    tag: Vec<u8>,
}
353
354fn parse_jwe_compact(token: &str) -> Result<JweParts, WebexError> {
355 let parts: Vec<&str> = token.split('.').collect();
356 if parts.len() != 5 {
357 return Err(WebexError::kms(format!(
358 "Invalid JWE compact: expected 5 parts, got {}",
359 parts.len()
360 )));
361 }
362
363 let header_b64 = parts[0].to_string();
364 let header_bytes = URL_SAFE_NO_PAD
365 .decode(parts[0])
366 .map_err(|e| WebexError::kms(format!("Failed to decode JWE header: {e}")))?;
367 let encrypted_key = URL_SAFE_NO_PAD
368 .decode(parts[1])
369 .map_err(|e| WebexError::kms(format!("Failed to decode encrypted key: {e}")))?;
370 let iv = URL_SAFE_NO_PAD
371 .decode(parts[2])
372 .map_err(|e| WebexError::kms(format!("Failed to decode IV: {e}")))?;
373 let ciphertext = URL_SAFE_NO_PAD
374 .decode(parts[3])
375 .map_err(|e| WebexError::kms(format!("Failed to decode ciphertext: {e}")))?;
376 let tag = URL_SAFE_NO_PAD
377 .decode(parts[4])
378 .map_err(|e| WebexError::kms(format!("Failed to decode tag: {e}")))?;
379
380 Ok(JweParts {
381 header_b64,
382 header_bytes,
383 encrypted_key,
384 iv,
385 ciphertext,
386 tag,
387 })
388}
389
390fn decrypt_a256gcm(
391 cek: &[u8],
392 iv: &[u8],
393 ciphertext: &[u8],
394 tag: &[u8],
395 aad: &str,
396) -> Result<Vec<u8>, WebexError> {
397 let cipher = Aes256Gcm::new_from_slice(cek)
398 .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
399
400 let nonce = Nonce::from_slice(iv);
401
402 let mut ct_with_tag = ciphertext.to_vec();
404 ct_with_tag.extend_from_slice(tag);
405
406 let plaintext = cipher
407 .decrypt(
408 nonce,
409 Payload {
410 msg: &ct_with_tag,
411 aad: aad.as_bytes(),
412 },
413 )
414 .map_err(|e| WebexError::kms(format!("AES-GCM decryption failed: {e}")))?;
415
416 Ok(plaintext)
417}
418
419fn concat_kdf(
421 shared_secret: &[u8],
422 algorithm_id: &str,
423 apu: &[u8],
424 apv: &[u8],
425 key_data_len_bits: u32,
426) -> Result<Vec<u8>, WebexError> {
427 use sha2::{Digest, Sha256};
428
429 let key_data_len = (key_data_len_bits / 8) as usize;
430 let reps = (key_data_len + 31) / 32; let mut derived = Vec::with_capacity(key_data_len);
433
434 for counter in 1..=reps as u32 {
435 let mut hasher = Sha256::new();
436 hasher.update(counter.to_be_bytes());
437 hasher.update(shared_secret);
438
439 hasher.update((algorithm_id.len() as u32).to_be_bytes());
442 hasher.update(algorithm_id.as_bytes());
443
444 hasher.update((apu.len() as u32).to_be_bytes());
446 hasher.update(apu);
447
448 hasher.update((apv.len() as u32).to_be_bytes());
450 hasher.update(apv);
451
452 hasher.update(key_data_len_bits.to_be_bytes());
454
455 derived.extend_from_slice(&hasher.finalize());
456 }
457
458 derived.truncate(key_data_len);
459 Ok(derived)
460}
461
/// Returns the content-encryption key length in bytes for a JWE `enc` value.
///
/// Unrecognised values fall back to 32 bytes.
fn enc_key_length(enc: &str) -> usize {
    const FALLBACK: usize = 32;

    match enc {
        "A128GCM" => 16,
        "A192GCM" => 24,
        "A256GCM" | "A128CBC-HS256" => 32,
        "A256CBC-HS512" => 64,
        _ => FALLBACK,
    }
}
472
473fn parse_rsa_public_key(jwk: &Value) -> Result<RsaPublicKey, WebexError> {
474 let n = jwk
475 .get("n")
476 .and_then(|v| v.as_str())
477 .ok_or_else(|| WebexError::kms("Missing 'n' in RSA JWK"))?;
478 let e = jwk
479 .get("e")
480 .and_then(|v| v.as_str())
481 .ok_or_else(|| WebexError::kms("Missing 'e' in RSA JWK"))?;
482
483 let n_bytes = URL_SAFE_NO_PAD
484 .decode(n)
485 .map_err(|e| WebexError::kms(format!("Failed to decode RSA n: {e}")))?;
486 let e_bytes = URL_SAFE_NO_PAD
487 .decode(e)
488 .map_err(|e| WebexError::kms(format!("Failed to decode RSA e: {e}")))?;
489
490 let n_uint = rsa::BigUint::from_bytes_be(&n_bytes);
491 let e_uint = rsa::BigUint::from_bytes_be(&e_bytes);
492
493 RsaPublicKey::new(n_uint, e_uint)
494 .map_err(|e| WebexError::kms(format!("Invalid RSA public key: {e}")))
495}
496
497fn parse_ec_public_key(jwk: &Value) -> Result<PublicKey, WebexError> {
498 let x = jwk
499 .get("x")
500 .and_then(|v| v.as_str())
501 .ok_or_else(|| WebexError::kms("Missing 'x' in EC JWK"))?;
502 let y = jwk
503 .get("y")
504 .and_then(|v| v.as_str())
505 .ok_or_else(|| WebexError::kms("Missing 'y' in EC JWK"))?;
506
507 let x_bytes = URL_SAFE_NO_PAD
508 .decode(x)
509 .map_err(|e| WebexError::kms(format!("Failed to decode EC x: {e}")))?;
510 let y_bytes = URL_SAFE_NO_PAD
511 .decode(y)
512 .map_err(|e| WebexError::kms(format!("Failed to decode EC y: {e}")))?;
513
514 let mut uncompressed = vec![0x04];
516 uncompressed.extend_from_slice(&x_bytes);
517 uncompressed.extend_from_slice(&y_bytes);
518
519 PublicKey::from_sec1_bytes(&uncompressed)
520 .map_err(|e| WebexError::kms(format!("Invalid EC public key: {e}")))
521}