Skip to main content

webex_message_handler/
kms_client.rs

//! KMS client for ECDH key exchange and encryption key retrieval.
//!
//! Implements the Webex KMS protocol:
//! 1. Fetch KMS cluster details (`rsaPublicKey`, `kmsCluster`).
//! 2. ECDH handshake: generate a local P-256 keypair, wrap the request with
//!    RSA-OAEP, send it via HTTP, receive the response via Mercury, and derive
//!    the shared key.
//! 3. Key retrieval: wrap the request with the ECDH-derived key, send it via
//!    HTTP, receive the response via Mercury, and unwrap it to get the content key.
//! 4. Content keys are JWE A256KW + A256GCM.
10
11use crate::errors::WebexError;
12use crate::jwe;
13use crate::types::{FetchFn, FetchRequest};
14use crate::url_validation::validate_webex_url;
15use base64::engine::general_purpose::URL_SAFE_NO_PAD;
16use base64::Engine;
17use p256::elliptic_curve::sec1::ToEncodedPoint;
18use p256::PublicKey;
19use serde_json::Value;
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::{oneshot, Mutex};
24use tracing::{debug, info, warn};
25use uuid::Uuid;
26
/// Maximum time to wait for a KMS response to arrive over Mercury.
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Upper bound on KMS requests that may be awaiting a response at once.
const MAX_PENDING_REQUESTS: usize = 100;
/// Size limit for the cache of retrieved content keys.
const MAX_KEY_CACHE_SIZE: usize = 100;
30
31/// A pending KMS request awaiting a Mercury response.
32struct PendingRequest {
33    tx: oneshot::Sender<String>,
34}
35
36/// Handle for resolving KMS responses from Mercury.
37///
38/// This is separated from KmsClient so the Mercury event loop can resolve
39/// pending requests without holding the KmsClient lock (which would deadlock
40/// during initialize/get_key).
41#[derive(Clone)]
42pub struct KmsResponseHandler {
43    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
44}
45
46impl KmsResponseHandler {
47    /// Handle a KMS response that arrived via Mercury WebSocket.
48    pub async fn handle_kms_message(&self, data: &Value) {
49        let kms_messages = data
50            .get("kmsMessages")
51            .and_then(|v| v.as_array())
52            .or_else(|| {
53                data.get("encryption")
54                    .and_then(|e| e.get("kmsMessages"))
55                    .and_then(|v| v.as_array())
56            });
57
58        let kms_messages = match kms_messages {
59            Some(msgs) => msgs,
60            None => {
61                debug!("Received KMS message without kmsMessages array");
62                return;
63            }
64        };
65
66        let mut pending = self.pending_requests.lock().await;
67
68        for raw_msg in kms_messages {
69            let wrapped = match raw_msg.as_str() {
70                Some(s) => s.to_string(),
71                None => continue,
72            };
73
74            debug!("Received KMS response, pending requests: {}", pending.len());
75
76            // Resolve the first pending request (FIFO)
77            if !pending.is_empty() {
78                let (_, req) = pending.remove(0);
79                let _ = req.tx.send(wrapped);
80            } else {
81                warn!("Received KMS response but no pending requests");
82            }
83        }
84    }
85}
86
87/// KMS client for Webex end-to-end encryption.
88pub struct KmsClient {
89    token: String,
90    device_url: String,
91    user_id: String,
92    encryption_service_url: String,
93    http_do: FetchFn,
94
95    kms_cluster: String,
96    /// The 256-bit symmetric key derived from ECDH, used as CEK for dir+A256GCM.
97    ephemeral_key: Option<[u8; 32]>,
98    /// The kid (key URI) for the ephemeral key, included in JWE headers.
99    ephemeral_key_kid: String,
100    context_expiration: Option<Instant>,
101    /// Cache of content encryption keys (key URI → raw 256-bit key bytes).
102    key_cache: HashMap<String, [u8; 32]>,
103    initialized: bool,
104
105    /// Pending KMS requests waiting for Mercury responses (FIFO).
106    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
107}
108
109impl KmsClient {
110    pub fn new(
111        http_do: FetchFn,
112        token: &str,
113        device_url: &str,
114        user_id: &str,
115        encryption_service_url: &str,
116    ) -> Self {
117        Self {
118            token: token.to_string(),
119            device_url: device_url.to_string(),
120            user_id: user_id.to_string(),
121            encryption_service_url: encryption_service_url.to_string(),
122            http_do,
123            kms_cluster: String::new(),
124            ephemeral_key: None,
125            ephemeral_key_kid: String::new(),
126            context_expiration: None,
127            key_cache: HashMap::new(),
128            initialized: false,
129            pending_requests: Arc::new(Mutex::new(Vec::new())),
130        }
131    }
132
133    /// Get a KmsResponseHandler that can resolve pending requests from Mercury.
134    ///
135    /// The returned handler can be used without holding the KmsClient lock,
136    /// preventing deadlocks during initialize() and get_key().
137    pub fn response_handler(&self) -> KmsResponseHandler {
138        KmsResponseHandler {
139            pending_requests: self.pending_requests.clone(),
140        }
141    }
142
143    /// Initialize KMS context with ECDH handshake.
144    pub async fn initialize(&mut self) -> Result<(), WebexError> {
145        info!("Initializing KMS client");
146
147        // Step 1: Fetch KMS details
148        let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
149
150        let mut headers = HashMap::new();
151        headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
152
153        let response = (self.http_do)(FetchRequest {
154            url: kms_details_url,
155            method: "GET".to_string(),
156            headers,
157            body: None,
158        })
159        .await
160        .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
161
162        if !response.ok {
163            return Err(WebexError::kms(format!(
164                "Failed to fetch KMS details: {}",
165                response.status
166            )));
167        }
168
169        let kms_details: Value = serde_json::from_slice(&response.body)
170            .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
171
172        self.kms_cluster = kms_details["kmsCluster"]
173            .as_str()
174            .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
175            .to_string();
176
177        // Validate KMS cluster URL
178        validate_webex_url(&self.kms_cluster, "https")
179            .map_err(|e| WebexError::kms(format!("Invalid kmsCluster URL: {e}")))?;
180
181        // Parse RSA public key (may be string or object)
182        let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
183            Value::String(s) => serde_json::from_str::<Value>(s)
184                .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
185            v @ Value::Object(_) => v.clone(),
186            _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
187        };
188
189        // Step 2: Generate local ECDH keypair (P-256)
190        let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
191        let local_public = local_secret.public_key();
192        let local_public_point = local_public.to_encoded_point(false);
193
194        let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
195        let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
196
197        let public_jwk_map = serde_json::json!({
198            "kty": "EC",
199            "crv": "P-256",
200            "x": URL_SAFE_NO_PAD.encode(*x_bytes),
201            "y": URL_SAFE_NO_PAD.encode(*y_bytes),
202        });
203
204        // Step 3: Build ECDH request body
205        let request_id = Uuid::new_v4().to_string();
206        let ecdh_request_body = serde_json::json!({
207            "client": {
208                "clientId": self.device_url,
209                "credential": {
210                    "userId": self.user_id,
211                    "bearer": self.token,
212                },
213            },
214            "method": "create",
215            "uri": format!("{}/ecdhe", self.kms_cluster),
216            "requestId": request_id,
217            "jwk": public_jwk_map,
218        });
219
220        // Step 4: Wrap ECDH request with server RSA public key (JWE RSA-OAEP + A256GCM)
221        let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
222            ecdh_request_body.to_string().as_bytes(),
223            &rsa_jwk_value,
224        )?;
225
226        // Step 5: POST ECDH request and wait for Mercury response
227        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
228
229        // Step 6: Unwrap ECDH response (may be JWE encrypted with ECDH-ES or JWS signed)
230        let response_body = jwe::unwrap_kms_response(
231            &wrapped_response,
232            &jwe::JweKey::EcdhPrivate(local_secret.clone()),
233        )?;
234        let response_data: Value = serde_json::from_slice(&response_body)
235            .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
236
237        // Step 7: Extract remote key and derive shared secret
238        let remote_jwk_data = extract_jwk_from_response(&response_data)
239            .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
240
241        // Validate remote key type before attempting to parse
242        let kty = remote_jwk_data.get("kty").and_then(|v| v.as_str()).unwrap_or("");
243        let crv = remote_jwk_data.get("crv").and_then(|v| v.as_str()).unwrap_or("");
244        if kty != "EC" || crv != "P-256" {
245            return Err(WebexError::kms(format!(
246                "Invalid remote key type: kty={}, crv={}", kty, crv
247            )));
248        }
249
250        // Parse remote public key
251        let remote_x = remote_jwk_data["x"]
252            .as_str()
253            .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
254        let remote_y = remote_jwk_data["y"]
255            .as_str()
256            .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
257
258        let remote_x_bytes = URL_SAFE_NO_PAD
259            .decode(remote_x)
260            .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
261        let remote_y_bytes = URL_SAFE_NO_PAD
262            .decode(remote_y)
263            .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
264
265        // Build uncompressed point: 0x04 || x || y
266        let mut uncompressed = vec![0x04];
267        uncompressed.extend_from_slice(&remote_x_bytes);
268        uncompressed.extend_from_slice(&remote_y_bytes);
269
270        let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
271            .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
272
273        // Perform ECDH
274        let shared_secret = p256::ecdh::diffie_hellman(
275            local_secret.to_nonzero_scalar(),
276            remote_public.as_affine(),
277        );
278
279        // HKDF to derive 256-bit key
280        let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
281        let mut derived = [0u8; 32];
282        hkdf.expand(&[], &mut derived)
283            .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
284
285        self.ephemeral_key = Some(derived);
286        self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
287        self.initialized = true;
288
289        // Set context expiration (default 1 hour)
290        self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
291
292        info!("KMS client initialized successfully");
293        Ok(())
294    }
295
296    /// Retrieve an encryption key from KMS.
297    pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
298        // Check cache
299        if let Some(cached) = self.key_cache.get(key_uri) {
300            debug!("Cache hit for key: {key_uri}");
301            return Ok(*cached);
302        }
303
304        // Check and enforce key cache size limit
305        if self.key_cache.len() > MAX_KEY_CACHE_SIZE {
306            warn!("Key cache exceeded size limit ({}), clearing cache", MAX_KEY_CACHE_SIZE);
307            self.key_cache.clear();
308        }
309
310        // Check context expiration
311        if self.is_context_expired() {
312            info!("Context expired, re-initializing");
313            self.initialize().await?;
314        }
315
316        if !self.initialized {
317            return Err(WebexError::kms("KMS context not initialized"));
318        }
319
320        let ephemeral_key = self
321            .ephemeral_key
322            .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
323
324        // Build retrieve request
325        let request_id = Uuid::new_v4().to_string();
326        let retrieve_body = serde_json::json!({
327            "client": {
328                "clientId": self.device_url,
329                "credential": {
330                    "userId": self.user_id,
331                    "bearer": self.token,
332                },
333            },
334            "method": "retrieve",
335            "uri": key_uri,
336            "requestId": request_id,
337        });
338
339        // Wrap with ephemeral key (dir + A256GCM — key is CEK directly)
340        let wrapped = jwe::encrypt_dir_a256gcm(
341            retrieve_body.to_string().as_bytes(),
342            &ephemeral_key,
343            &self.ephemeral_key_kid,
344        )?;
345
346        // POST and wait for Mercury response
347        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
348
349        // Unwrap response with ephemeral key (may be JWE or JWS)
350        let response_body = jwe::unwrap_kms_response(
351            &wrapped_response,
352            &jwe::JweKey::Symmetric(ephemeral_key),
353        )?;
354        let response_data: Value = serde_json::from_slice(&response_body)
355            .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
356
357        // Extract content key
358        let key_jwk_data = extract_jwk_from_response(&response_data)
359            .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
360
361        // Extract the symmetric key bytes from the JWK
362        let k_b64 = key_jwk_data["k"]
363            .as_str()
364            .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
365        let k_bytes = URL_SAFE_NO_PAD
366            .decode(k_b64)
367            .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
368
369        let content_key: [u8; 32] = k_bytes
370            .try_into()
371            .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
372
373        self.key_cache.insert(key_uri.to_string(), content_key);
374        info!("Key retrieved and cached: {key_uri}");
375        Ok(content_key)
376    }
377
378    /// Send a KMS request via HTTP and wait for the response via Mercury.
379    async fn send_kms_request(
380        &self,
381        request_id: &str,
382        wrapped: &str,
383    ) -> Result<String, WebexError> {
384        let (tx, rx) = oneshot::channel();
385
386        // Register pending request with bound checking
387        {
388            let mut pending = self.pending_requests.lock().await;
389            if pending.len() >= MAX_PENDING_REQUESTS {
390                return Err(WebexError::kms(format!(
391                    "Too many pending KMS requests (max: {})",
392                    MAX_PENDING_REQUESTS
393                )));
394            }
395            pending.push((
396                request_id.to_string(),
397                PendingRequest { tx },
398            ));
399        }
400
401        // POST the request
402        let mut headers = HashMap::new();
403        headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
404        headers.insert("Content-Type".to_string(), "application/json".to_string());
405
406        let body = serde_json::to_string(&serde_json::json!({
407            "destination": self.kms_cluster,
408            "kmsMessages": [wrapped],
409        }))
410        .map_err(|e| WebexError::kms(format!("Failed to serialize KMS request: {e}")))?;
411
412        let http_response = (self.http_do)(FetchRequest {
413            url: format!("{}/kms/messages", self.encryption_service_url),
414            method: "POST".to_string(),
415            headers,
416            body: Some(body),
417        })
418        .await;
419
420        match http_response {
421            Ok(resp) if !resp.ok => {
422                let status = resp.status;
423                let body = String::from_utf8_lossy(&resp.body);
424                let mut pending = self.pending_requests.lock().await;
425                pending.retain(|(id, _)| id != request_id);
426                return Err(WebexError::kms(format!(
427                    "KMS HTTP request failed: {status} {body}"
428                )));
429            }
430            Err(e) => {
431                let mut pending = self.pending_requests.lock().await;
432                pending.retain(|(id, _)| id != request_id);
433                return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
434            }
435            Ok(resp) => {
436                debug!(
437                    "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
438                    resp.status
439                );
440            }
441        }
442
443        // Wait for Mercury response with timeout
444        match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
445            Ok(Ok(response)) => Ok(response),
446            Ok(Err(_)) => Err(WebexError::kms(format!(
447                "KMS request {request_id} channel closed"
448            ))),
449            Err(_) => {
450                let mut pending = self.pending_requests.lock().await;
451                pending.retain(|(id, _)| id != request_id);
452                Err(WebexError::kms(format!(
453                    "KMS request {request_id} timed out after {}s",
454                    KMS_RESPONSE_TIMEOUT.as_secs()
455                )))
456            }
457        }
458    }
459
460    fn is_context_expired(&self) -> bool {
461        if !self.initialized {
462            return true;
463        }
464        match self.context_expiration {
465            Some(exp) => {
466                let with_buffer = exp - Duration::from_secs(30);
467                Instant::now() > with_buffer
468            }
469            None => true,
470        }
471    }
472
473    /// Whether the KMS client has been initialized.
474    pub fn is_initialized(&self) -> bool {
475        self.initialized
476    }
477}
478
479/// Extract a JWK from the KMS response JSON.
480fn extract_jwk_from_response(data: &Value) -> Option<Value> {
481    // Try body.key.jwk
482    if let Some(jwk) = data.pointer("/body/key/jwk") {
483        if jwk.is_object() {
484            return Some(jwk.clone());
485        }
486    }
487    // Try body.key (as direct JWK)
488    if let Some(key) = data.pointer("/body/key") {
489        if key.is_object() {
490            return Some(key.clone());
491        }
492    }
493    // Try key.jwk
494    if let Some(jwk) = data.pointer("/key/jwk") {
495        if jwk.is_object() {
496            return Some(jwk.clone());
497        }
498    }
499    // Try key
500    if let Some(key) = data.get("key") {
501        if key.is_object() {
502            return Some(key.clone());
503        }
504    }
505    None
506}
507
508/// Extract the key URI from a KMS response JSON.
509fn extract_key_uri(data: &Value) -> Option<String> {
510    data.pointer("/body/key/uri")
511        .or_else(|| data.pointer("/key/uri"))
512        .and_then(|v| v.as_str())
513        .map(|s| s.to_string())
514}