Skip to main content

webex_message_handler/
jwe.rs

1//! Pure-Rust JWE compact serialization (RFC 7516) for the algorithms used by Webex KMS.
2#![allow(deprecated)] // aes-gcm Nonce::from_slice uses deprecated generic-array API
3//!
//! Supported algorithms:
//! - RSA-OAEP + A256GCM  (encrypt only — initial ECDH handshake)
//! - dir + A256GCM       (encrypt + decrypt — caller-supplied key used directly as the CEK)
//! - A256KW + A256GCM    (encrypt + decrypt — key retrieval, message decryption)
//! - ECDH-ES + A256GCM   (decrypt only — ECDH handshake response)
//! - ECDH-ES+A256KW + A256GCM (decrypt only — ECDH handshake response variant)
9
10use aes_gcm::aead::{Aead, KeyInit, Payload};
11use aes_gcm::{Aes256Gcm, Nonce};
12use aes_kw::Kek;
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use base64::Engine;
15use p256::PublicKey;
16use rand::RngCore;
17use rsa::{Oaep, RsaPublicKey};
18use serde_json::Value;
19use sha1::Sha1;
20
21use crate::errors::WebexError;
22
23/// Encrypt plaintext using RSA-OAEP + A256GCM and return JWE compact serialization.
24pub fn encrypt_rsa_oaep_a256gcm(
25    plaintext: &[u8],
26    rsa_jwk: &Value,
27) -> Result<String, WebexError> {
28    // Parse RSA public key from JWK
29    let rsa_key = parse_rsa_public_key(rsa_jwk)?;
30
31    // Generate random CEK (32 bytes for A256GCM)
32    let mut cek = [0u8; 32];
33    rand::thread_rng().fill_bytes(&mut cek);
34
35    // Generate random IV (12 bytes)
36    let mut iv = [0u8; 12];
37    rand::thread_rng().fill_bytes(&mut iv);
38
39    // Build protected header (include kid if present so KMS can identify the decryption key)
40    let mut header = serde_json::json!({"alg": "RSA-OAEP", "enc": "A256GCM"});
41    if let Some(kid) = rsa_jwk.get("kid").and_then(|v| v.as_str()) {
42        header["kid"] = Value::String(kid.to_string());
43    }
44    let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
45
46    // Encrypt CEK with RSA-OAEP (SHA-1 per RFC 7518 "RSA-OAEP")
47    let padding = Oaep::new::<Sha1>();
48    let encrypted_key = rsa_key
49        .encrypt(&mut rand::thread_rng(), padding, &cek)
50        .map_err(|e| WebexError::kms(format!("RSA-OAEP encryption failed: {e}")))?;
51
52    // Encrypt plaintext with A256GCM
53    let cipher = Aes256Gcm::new_from_slice(&cek)
54        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
55    let nonce = Nonce::from_slice(&iv);
56    let aad = header_b64.as_bytes();
57    let ciphertext_with_tag = cipher
58        .encrypt(nonce, Payload { msg: plaintext, aad })
59        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
60
61    // Split ciphertext and tag (last 16 bytes are the tag)
62    let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
63
64    // JWE compact: header.encrypted_key.iv.ciphertext.tag
65    Ok(format!(
66        "{}.{}.{}.{}.{}",
67        header_b64,
68        URL_SAFE_NO_PAD.encode(&encrypted_key),
69        URL_SAFE_NO_PAD.encode(&iv),
70        URL_SAFE_NO_PAD.encode(ciphertext),
71        URL_SAFE_NO_PAD.encode(tag),
72    ))
73}
74
75/// Encrypt plaintext using dir + A256GCM (direct encryption — key is CEK).
76pub fn encrypt_dir_a256gcm(
77    plaintext: &[u8],
78    cek: &[u8; 32],
79    kid: &str,
80) -> Result<String, WebexError> {
81    // Generate random IV (12 bytes)
82    let mut iv = [0u8; 12];
83    rand::thread_rng().fill_bytes(&mut iv);
84
85    // Build protected header
86    let mut header = serde_json::json!({"alg": "dir", "enc": "A256GCM"});
87    if !kid.is_empty() {
88        header["kid"] = Value::String(kid.to_string());
89    }
90    let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
91
92    // Encrypt plaintext with A256GCM using the key directly as CEK
93    let cipher = Aes256Gcm::new_from_slice(cek)
94        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
95    let nonce = Nonce::from_slice(&iv);
96    let aad = header_b64.as_bytes();
97    let ciphertext_with_tag = cipher
98        .encrypt(nonce, Payload { msg: plaintext, aad })
99        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
100
101    let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
102
103    // JWE compact with dir: header . "" . iv . ciphertext . tag (empty encrypted key)
104    Ok(format!(
105        "{}.{}.{}.{}.{}",
106        header_b64,
107        "",  // empty encrypted key for dir
108        URL_SAFE_NO_PAD.encode(&iv),
109        URL_SAFE_NO_PAD.encode(ciphertext),
110        URL_SAFE_NO_PAD.encode(tag),
111    ))
112}
113
114/// Encrypt plaintext using A256KW + A256GCM and return JWE compact serialization.
115pub fn encrypt_a256kw_a256gcm(
116    plaintext: &[u8],
117    wrapping_key: &[u8; 32],
118) -> Result<String, WebexError> {
119    // Generate random CEK (32 bytes)
120    let mut cek = [0u8; 32];
121    rand::thread_rng().fill_bytes(&mut cek);
122
123    // Generate random IV (12 bytes)
124    let mut iv = [0u8; 12];
125    rand::thread_rng().fill_bytes(&mut iv);
126
127    // Build protected header
128    let header = serde_json::json!({"alg": "A256KW", "enc": "A256GCM"});
129    let header_b64 = URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
130
131    // AES Key Wrap the CEK
132    let kek = Kek::from(*wrapping_key);
133    let mut wrapped_key = vec![0u8; cek.len() + 8]; // wrapped key is 8 bytes longer
134    kek.wrap(&cek, &mut wrapped_key)
135        .map_err(|e| WebexError::kms(format!("AES key wrap failed: {e}")))?;
136
137    // Encrypt plaintext with A256GCM
138    let cipher = Aes256Gcm::new_from_slice(&cek)
139        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
140    let nonce = Nonce::from_slice(&iv);
141    let aad = header_b64.as_bytes();
142    let ciphertext_with_tag = cipher
143        .encrypt(nonce, Payload { msg: plaintext, aad })
144        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
145
146    let (ciphertext, tag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
147
148    Ok(format!(
149        "{}.{}.{}.{}.{}",
150        header_b64,
151        URL_SAFE_NO_PAD.encode(&wrapped_key),
152        URL_SAFE_NO_PAD.encode(&iv),
153        URL_SAFE_NO_PAD.encode(ciphertext),
154        URL_SAFE_NO_PAD.encode(tag),
155    ))
156}
157
158/// Decrypt a JWE compact serialization token using A256KW + A256GCM.
159pub fn decrypt_a256kw_a256gcm(
160    token: &str,
161    wrapping_key: &[u8; 32],
162) -> Result<Vec<u8>, WebexError> {
163    let parts = parse_jwe_compact(token)?;
164
165    // AES Key Unwrap to get CEK
166    let kek = Kek::from(*wrapping_key);
167    let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
168    kek.unwrap(&parts.encrypted_key, &mut cek)
169        .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
170
171    // Decrypt with A256GCM
172    decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
173}
174
175/// Decrypt a JWE compact serialization token using dir + A256GCM (direct key).
176pub fn decrypt_dir_a256gcm(
177    token: &str,
178    cek: &[u8; 32],
179) -> Result<Vec<u8>, WebexError> {
180    let parts = parse_jwe_compact(token)?;
181
182    // With "dir", the encrypted_key part should be empty — the provided key IS the CEK
183    decrypt_a256gcm(cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
184}
185
186/// Decrypt a JWE message, auto-detecting "dir" vs "A256KW" from the header.
187pub fn decrypt_message_jwe(
188    token: &str,
189    key: &[u8; 32],
190) -> Result<Vec<u8>, WebexError> {
191    let parts = parse_jwe_compact(token)?;
192
193    // Parse the header to detect algorithm
194    let header_json: Value = serde_json::from_slice(&parts.header_bytes)
195        .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
196    let alg = header_json
197        .get("alg")
198        .and_then(|v| v.as_str())
199        .unwrap_or("");
200
201    match alg {
202        "dir" => {
203            // Direct: key IS the CEK
204            decrypt_a256gcm(key, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
205        }
206        "A256KW" => {
207            // Key wrapping: unwrap CEK first
208            let kek = Kek::from(*key);
209            let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
210            kek.unwrap(&parts.encrypted_key, &mut cek)
211                .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
212            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
213        }
214        _ => Err(WebexError::kms(format!(
215            "Unsupported message JWE algorithm: {alg}"
216        ))),
217    }
218}
219
220/// Decrypt a JWE compact serialization token encrypted with ECDH-ES (or ECDH-ES+A256KW).
221///
222/// The JWE header contains `epk` (the server's ephemeral public key).
223/// We use our local private key + the server's epk to derive the decryption key.
224pub fn decrypt_ecdh_es(
225    token: &str,
226    local_private_key: &p256::SecretKey,
227) -> Result<Vec<u8>, WebexError> {
228    let parts = parse_jwe_compact(token)?;
229
230    // Parse the protected header to get the algorithm and epk
231    let header_json: Value = serde_json::from_slice(&parts.header_bytes)
232        .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
233
234    let alg = header_json
235        .get("alg")
236        .and_then(|v| v.as_str())
237        .unwrap_or("");
238    let enc = header_json
239        .get("enc")
240        .and_then(|v| v.as_str())
241        .unwrap_or("A256GCM");
242
243    // Extract the server's ephemeral public key from the header
244    let epk = header_json
245        .get("epk")
246        .ok_or_else(|| WebexError::kms("No epk in ECDH-ES JWE header"))?;
247
248    let server_public = parse_ec_public_key(epk)?;
249
250    // Perform ECDH
251    let shared_secret = p256::ecdh::diffie_hellman(
252        local_private_key.to_nonzero_scalar(),
253        server_public.as_affine(),
254    );
255
256    // Extract apu and apv from header (optional)
257    let apu = header_json
258        .get("apu")
259        .and_then(|v| v.as_str())
260        .map(|s| URL_SAFE_NO_PAD.decode(s).unwrap_or_default())
261        .unwrap_or_default();
262    let apv = header_json
263        .get("apv")
264        .and_then(|v| v.as_str())
265        .map(|s| URL_SAFE_NO_PAD.decode(s).unwrap_or_default())
266        .unwrap_or_default();
267
268    match alg {
269        "ECDH-ES" => {
270            // Direct key agreement — derive CEK directly
271            let key_len = enc_key_length(enc);
272            let cek = concat_kdf(
273                shared_secret.raw_secret_bytes(),
274                enc, // for direct, algorithm ID is the enc algorithm
275                &apu,
276                &apv,
277                (key_len * 8) as u32,
278            )?;
279
280            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
281        }
282        "ECDH-ES+A256KW" => {
283            // Key agreement with key wrapping — derive KEK, then unwrap CEK
284            let kek_bytes = concat_kdf(
285                shared_secret.raw_secret_bytes(),
286                "A256KW",
287                &apu,
288                &apv,
289                256,
290            )?;
291
292            let kek_arr: [u8; 32] = kek_bytes
293                .try_into()
294                .map_err(|_| WebexError::kms("Derived KEK is not 32 bytes"))?;
295            let kek = Kek::from(kek_arr);
296
297            let mut cek = vec![0u8; parts.encrypted_key.len() - 8];
298            kek.unwrap(&parts.encrypted_key, &mut cek)
299                .map_err(|e| WebexError::kms(format!("ECDH-ES+A256KW unwrap failed: {e}")))?;
300
301            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
302        }
303        _ => Err(WebexError::kms(format!("Unsupported ECDH algorithm: {alg}"))),
304    }
305}
306
307/// Generic JWE decryption that auto-detects the algorithm from the header.
308pub fn decrypt_jwe(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
309    match key {
310        JweKey::Symmetric(k) => decrypt_message_jwe(token, k),
311        JweKey::EcdhPrivate(k) => decrypt_ecdh_es(token, k),
312    }
313}
314
315/// Unwrap a KMS response — may be JWE (encrypted, 5 parts) or JWS (signed, 3 parts).
316/// For JWS, the payload is extracted directly (arrives over authenticated Mercury channel).
317pub fn unwrap_kms_response(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
318    let dot_count = token.chars().filter(|&c| c == '.').count();
319    match dot_count {
320        4 => decrypt_jwe(token, key),
321        2 => {
322            // JWS compact: header.payload.signature — extract payload
323            let parts: Vec<&str> = token.split('.').collect();
324            URL_SAFE_NO_PAD
325                .decode(parts[1])
326                .map_err(|e| WebexError::kms(format!("Failed to decode JWS payload: {e}")))
327        }
328        _ => Err(WebexError::kms(format!(
329            "Invalid KMS response format: expected 3 or 5 parts, got {} dots",
330            dot_count
331        ))),
332    }
333}
334
/// Key material accepted by [`decrypt_jwe`] and [`unwrap_kms_response`].
///
/// The variant selects the decryption family; the token's `alg` header then
/// selects the exact algorithm within it.
pub enum JweKey {
    /// A 256-bit symmetric key. It is used directly as the content key for
    /// `alg: "dir"`, or as the AES key-wrapping key for `alg: "A256KW"`.
    Symmetric([u8; 32]),
    /// A P-256 private key for `alg: "ECDH-ES"` and `alg: "ECDH-ES+A256KW"`.
    EcdhPrivate(p256::SecretKey),
}
342
343// ──────────────────────────── Internal helpers ────────────────────────────
344
/// The five segments of a JWE compact serialization, base64url-decoded.
struct JweParts {
    /// Protected header exactly as it appeared in the token (base64url text).
    /// Kept encoded because this ASCII form is the AES-GCM AAD.
    header_b64: String,
    /// Decoded protected header bytes (parsed as JSON by the callers).
    header_bytes: Vec<u8>,
    /// JWE Encrypted Key; expected to be empty for `dir`.
    encrypted_key: Vec<u8>,
    /// Initialization vector (nonce) for the content cipher.
    iv: Vec<u8>,
    /// Ciphertext without the authentication tag.
    ciphertext: Vec<u8>,
    /// Detached authentication tag.
    tag: Vec<u8>,
}
353
354fn parse_jwe_compact(token: &str) -> Result<JweParts, WebexError> {
355    let parts: Vec<&str> = token.split('.').collect();
356    if parts.len() != 5 {
357        return Err(WebexError::kms(format!(
358            "Invalid JWE compact: expected 5 parts, got {}",
359            parts.len()
360        )));
361    }
362
363    let header_b64 = parts[0].to_string();
364    let header_bytes = URL_SAFE_NO_PAD
365        .decode(parts[0])
366        .map_err(|e| WebexError::kms(format!("Failed to decode JWE header: {e}")))?;
367    let encrypted_key = URL_SAFE_NO_PAD
368        .decode(parts[1])
369        .map_err(|e| WebexError::kms(format!("Failed to decode encrypted key: {e}")))?;
370    let iv = URL_SAFE_NO_PAD
371        .decode(parts[2])
372        .map_err(|e| WebexError::kms(format!("Failed to decode IV: {e}")))?;
373    let ciphertext = URL_SAFE_NO_PAD
374        .decode(parts[3])
375        .map_err(|e| WebexError::kms(format!("Failed to decode ciphertext: {e}")))?;
376    let tag = URL_SAFE_NO_PAD
377        .decode(parts[4])
378        .map_err(|e| WebexError::kms(format!("Failed to decode tag: {e}")))?;
379
380    Ok(JweParts {
381        header_b64,
382        header_bytes,
383        encrypted_key,
384        iv,
385        ciphertext,
386        tag,
387    })
388}
389
390fn decrypt_a256gcm(
391    cek: &[u8],
392    iv: &[u8],
393    ciphertext: &[u8],
394    tag: &[u8],
395    aad: &str,
396) -> Result<Vec<u8>, WebexError> {
397    let cipher = Aes256Gcm::new_from_slice(cek)
398        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
399
400    let nonce = Nonce::from_slice(iv);
401
402    // Combine ciphertext and tag for aes-gcm
403    let mut ct_with_tag = ciphertext.to_vec();
404    ct_with_tag.extend_from_slice(tag);
405
406    let plaintext = cipher
407        .decrypt(
408            nonce,
409            Payload {
410                msg: &ct_with_tag,
411                aad: aad.as_bytes(),
412            },
413        )
414        .map_err(|e| WebexError::kms(format!("AES-GCM decryption failed: {e}")))?;
415
416    Ok(plaintext)
417}
418
419/// Concat KDF (NIST SP 800-56A, used by JWE ECDH-ES per RFC 7518 §4.6.2).
420fn concat_kdf(
421    shared_secret: &[u8],
422    algorithm_id: &str,
423    apu: &[u8],
424    apv: &[u8],
425    key_data_len_bits: u32,
426) -> Result<Vec<u8>, WebexError> {
427    use sha2::{Digest, Sha256};
428
429    let key_data_len = (key_data_len_bits / 8) as usize;
430    let reps = (key_data_len + 31) / 32; // ceil(keyDataLen / hashLen)
431
432    let mut derived = Vec::with_capacity(key_data_len);
433
434    for counter in 1..=reps as u32 {
435        let mut hasher = Sha256::new();
436        hasher.update(counter.to_be_bytes());
437        hasher.update(shared_secret);
438
439        // OtherInfo = AlgorithmID || PartyUInfo || PartyVInfo || SuppPubInfo
440        // AlgorithmID: length(4 bytes) || value
441        hasher.update((algorithm_id.len() as u32).to_be_bytes());
442        hasher.update(algorithm_id.as_bytes());
443
444        // PartyUInfo: length(4 bytes) || value
445        hasher.update((apu.len() as u32).to_be_bytes());
446        hasher.update(apu);
447
448        // PartyVInfo: length(4 bytes) || value
449        hasher.update((apv.len() as u32).to_be_bytes());
450        hasher.update(apv);
451
452        // SuppPubInfo: key length in bits (4 bytes big-endian)
453        hasher.update(key_data_len_bits.to_be_bytes());
454
455        derived.extend_from_slice(&hasher.finalize());
456    }
457
458    derived.truncate(key_data_len);
459    Ok(derived)
460}
461
/// Content-key length in bytes for a JWE `enc` algorithm name.
///
/// Unrecognised names fall back to 32 bytes (a 256-bit key).
fn enc_key_length(enc: &str) -> usize {
    const KEY_LENGTHS: [(&str, usize); 5] = [
        ("A128GCM", 16),
        ("A192GCM", 24),
        ("A256GCM", 32),
        ("A128CBC-HS256", 32),
        ("A256CBC-HS512", 64),
    ];

    KEY_LENGTHS
        .iter()
        .find(|(name, _)| *name == enc)
        .map_or(32, |&(_, len)| len)
}
472
473fn parse_rsa_public_key(jwk: &Value) -> Result<RsaPublicKey, WebexError> {
474    let n = jwk
475        .get("n")
476        .and_then(|v| v.as_str())
477        .ok_or_else(|| WebexError::kms("Missing 'n' in RSA JWK"))?;
478    let e = jwk
479        .get("e")
480        .and_then(|v| v.as_str())
481        .ok_or_else(|| WebexError::kms("Missing 'e' in RSA JWK"))?;
482
483    let n_bytes = URL_SAFE_NO_PAD
484        .decode(n)
485        .map_err(|e| WebexError::kms(format!("Failed to decode RSA n: {e}")))?;
486    let e_bytes = URL_SAFE_NO_PAD
487        .decode(e)
488        .map_err(|e| WebexError::kms(format!("Failed to decode RSA e: {e}")))?;
489
490    let n_uint = rsa::BigUint::from_bytes_be(&n_bytes);
491    let e_uint = rsa::BigUint::from_bytes_be(&e_bytes);
492
493    RsaPublicKey::new(n_uint, e_uint)
494        .map_err(|e| WebexError::kms(format!("Invalid RSA public key: {e}")))
495}
496
497fn parse_ec_public_key(jwk: &Value) -> Result<PublicKey, WebexError> {
498    let x = jwk
499        .get("x")
500        .and_then(|v| v.as_str())
501        .ok_or_else(|| WebexError::kms("Missing 'x' in EC JWK"))?;
502    let y = jwk
503        .get("y")
504        .and_then(|v| v.as_str())
505        .ok_or_else(|| WebexError::kms("Missing 'y' in EC JWK"))?;
506
507    let x_bytes = URL_SAFE_NO_PAD
508        .decode(x)
509        .map_err(|e| WebexError::kms(format!("Failed to decode EC x: {e}")))?;
510    let y_bytes = URL_SAFE_NO_PAD
511        .decode(y)
512        .map_err(|e| WebexError::kms(format!("Failed to decode EC y: {e}")))?;
513
514    // Build uncompressed point: 0x04 || x || y
515    let mut uncompressed = vec![0x04];
516    uncompressed.extend_from_slice(&x_bytes);
517    uncompressed.extend_from_slice(&y_bytes);
518
519    PublicKey::from_sec1_bytes(&uncompressed)
520        .map_err(|e| WebexError::kms(format!("Invalid EC public key: {e}")))
521}