Skip to main content

webex_message_handler/
kms_client.rs

1//! KMS client for ECDH key exchange and encryption key retrieval.
2//!
3//! Implements the Webex KMS protocol:
4//! 1. Fetch KMS cluster details (rsaPublicKey, kmsCluster)
5//! 2. ECDH handshake: generate local P-256 keypair, wrap with RSA-OAEP,
6//!    send via HTTP, receive response via Mercury, derive shared key
7//! 3. Key retrieval: wrap request with ECDH-derived key, send via HTTP,
8//!    receive via Mercury, unwrap to get content key
9//! 4. Content keys are JWE A256KW + A256GCM
10
11use crate::errors::WebexError;
12use crate::jwe;
13use crate::types::{FetchFn, FetchRequest};
14use base64::engine::general_purpose::URL_SAFE_NO_PAD;
15use base64::Engine;
16use p256::elliptic_curve::sec1::ToEncodedPoint;
17use p256::PublicKey;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::{oneshot, Mutex};
23use tracing::{debug, info, warn};
24use uuid::Uuid;
25
/// How long to wait for a KMS reply to arrive over Mercury before giving up.
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_millis(30_000);
27
28/// A pending KMS request awaiting a Mercury response.
29struct PendingRequest {
30    tx: oneshot::Sender<String>,
31}
32
33/// Handle for resolving KMS responses from Mercury.
34///
35/// This is separated from KmsClient so the Mercury event loop can resolve
36/// pending requests without holding the KmsClient lock (which would deadlock
37/// during initialize/get_key).
38#[derive(Clone)]
39pub struct KmsResponseHandler {
40    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
41}
42
43impl KmsResponseHandler {
44    /// Handle a KMS response that arrived via Mercury WebSocket.
45    pub async fn handle_kms_message(&self, data: &Value) {
46        let kms_messages = data
47            .get("kmsMessages")
48            .and_then(|v| v.as_array())
49            .or_else(|| {
50                data.get("encryption")
51                    .and_then(|e| e.get("kmsMessages"))
52                    .and_then(|v| v.as_array())
53            });
54
55        let kms_messages = match kms_messages {
56            Some(msgs) => msgs,
57            None => {
58                debug!("Received KMS message without kmsMessages array");
59                return;
60            }
61        };
62
63        let mut pending = self.pending_requests.lock().await;
64
65        for raw_msg in kms_messages {
66            let wrapped = match raw_msg.as_str() {
67                Some(s) => s.to_string(),
68                None => continue,
69            };
70
71            debug!("Received KMS response, pending requests: {}", pending.len());
72
73            // Resolve the first pending request (FIFO)
74            if !pending.is_empty() {
75                let (_, req) = pending.remove(0);
76                let _ = req.tx.send(wrapped);
77            } else {
78                warn!("Received KMS response but no pending requests");
79            }
80        }
81    }
82}
83
84/// KMS client for Webex end-to-end encryption.
85pub struct KmsClient {
86    token: String,
87    device_url: String,
88    user_id: String,
89    encryption_service_url: String,
90    http_do: FetchFn,
91
92    kms_cluster: String,
93    /// The 256-bit symmetric key derived from ECDH, used as CEK for dir+A256GCM.
94    ephemeral_key: Option<[u8; 32]>,
95    /// The kid (key URI) for the ephemeral key, included in JWE headers.
96    ephemeral_key_kid: String,
97    context_expiration: Option<Instant>,
98    /// Cache of content encryption keys (key URI → raw 256-bit key bytes).
99    key_cache: HashMap<String, [u8; 32]>,
100    initialized: bool,
101
102    /// Pending KMS requests waiting for Mercury responses (FIFO).
103    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
104}
105
106impl KmsClient {
107    pub fn new(
108        http_do: FetchFn,
109        token: &str,
110        device_url: &str,
111        user_id: &str,
112        encryption_service_url: &str,
113    ) -> Self {
114        Self {
115            token: token.to_string(),
116            device_url: device_url.to_string(),
117            user_id: user_id.to_string(),
118            encryption_service_url: encryption_service_url.to_string(),
119            http_do,
120            kms_cluster: String::new(),
121            ephemeral_key: None,
122            ephemeral_key_kid: String::new(),
123            context_expiration: None,
124            key_cache: HashMap::new(),
125            initialized: false,
126            pending_requests: Arc::new(Mutex::new(Vec::new())),
127        }
128    }
129
130    /// Get a KmsResponseHandler that can resolve pending requests from Mercury.
131    ///
132    /// The returned handler can be used without holding the KmsClient lock,
133    /// preventing deadlocks during initialize() and get_key().
134    pub fn response_handler(&self) -> KmsResponseHandler {
135        KmsResponseHandler {
136            pending_requests: self.pending_requests.clone(),
137        }
138    }
139
140    /// Initialize KMS context with ECDH handshake.
141    pub async fn initialize(&mut self) -> Result<(), WebexError> {
142        info!("Initializing KMS client");
143
144        // Step 1: Fetch KMS details
145        let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
146
147        let mut headers = HashMap::new();
148        headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
149
150        let response = (self.http_do)(FetchRequest {
151            url: kms_details_url,
152            method: "GET".to_string(),
153            headers,
154            body: None,
155        })
156        .await
157        .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
158
159        if !response.ok {
160            return Err(WebexError::kms(format!(
161                "Failed to fetch KMS details: {}",
162                response.status
163            )));
164        }
165
166        let kms_details: Value = serde_json::from_slice(&response.body)
167            .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
168
169        self.kms_cluster = kms_details["kmsCluster"]
170            .as_str()
171            .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
172            .to_string();
173
174        // Parse RSA public key (may be string or object)
175        let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
176            Value::String(s) => serde_json::from_str::<Value>(s)
177                .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
178            v @ Value::Object(_) => v.clone(),
179            _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
180        };
181
182        // Step 2: Generate local ECDH keypair (P-256)
183        let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
184        let local_public = local_secret.public_key();
185        let local_public_point = local_public.to_encoded_point(false);
186
187        let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
188        let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
189
190        let public_jwk_map = serde_json::json!({
191            "kty": "EC",
192            "crv": "P-256",
193            "x": URL_SAFE_NO_PAD.encode(&*x_bytes),
194            "y": URL_SAFE_NO_PAD.encode(&*y_bytes),
195        });
196
197        // Step 3: Build ECDH request body
198        let request_id = Uuid::new_v4().to_string();
199        let ecdh_request_body = serde_json::json!({
200            "client": {
201                "clientId": self.device_url,
202                "credential": {
203                    "userId": self.user_id,
204                    "bearer": self.token,
205                },
206            },
207            "method": "create",
208            "uri": format!("{}/ecdhe", self.kms_cluster),
209            "requestId": request_id,
210            "jwk": public_jwk_map,
211        });
212
213        // Step 4: Wrap ECDH request with server RSA public key (JWE RSA-OAEP + A256GCM)
214        let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
215            ecdh_request_body.to_string().as_bytes(),
216            &rsa_jwk_value,
217        )?;
218
219        // Step 5: POST ECDH request and wait for Mercury response
220        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
221
222        // Step 6: Unwrap ECDH response (may be JWE encrypted with ECDH-ES or JWS signed)
223        let response_body = jwe::unwrap_kms_response(
224            &wrapped_response,
225            &jwe::JweKey::EcdhPrivate(local_secret.clone()),
226        )?;
227        let response_data: Value = serde_json::from_slice(&response_body)
228            .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
229
230        // Step 7: Extract remote key and derive shared secret
231        let remote_jwk_data = extract_jwk_from_response(&response_data)
232            .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
233
234        // Parse remote public key
235        let remote_x = remote_jwk_data["x"]
236            .as_str()
237            .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
238        let remote_y = remote_jwk_data["y"]
239            .as_str()
240            .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
241
242        let remote_x_bytes = URL_SAFE_NO_PAD
243            .decode(remote_x)
244            .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
245        let remote_y_bytes = URL_SAFE_NO_PAD
246            .decode(remote_y)
247            .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
248
249        // Build uncompressed point: 0x04 || x || y
250        let mut uncompressed = vec![0x04];
251        uncompressed.extend_from_slice(&remote_x_bytes);
252        uncompressed.extend_from_slice(&remote_y_bytes);
253
254        let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
255            .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
256
257        // Perform ECDH
258        let shared_secret = p256::ecdh::diffie_hellman(
259            local_secret.to_nonzero_scalar(),
260            remote_public.as_affine(),
261        );
262
263        // HKDF to derive 256-bit key
264        let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
265        let mut derived = [0u8; 32];
266        hkdf.expand(&[], &mut derived)
267            .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
268
269        self.ephemeral_key = Some(derived);
270        self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
271        self.initialized = true;
272
273        // Set context expiration (default 1 hour)
274        self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
275
276        info!("KMS client initialized successfully");
277        Ok(())
278    }
279
280    /// Retrieve an encryption key from KMS.
281    pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
282        // Check cache
283        if let Some(cached) = self.key_cache.get(key_uri) {
284            debug!("Cache hit for key: {key_uri}");
285            return Ok(*cached);
286        }
287
288        // Check context expiration
289        if self.is_context_expired() {
290            info!("Context expired, re-initializing");
291            self.initialize().await?;
292        }
293
294        if !self.initialized {
295            return Err(WebexError::kms("KMS context not initialized"));
296        }
297
298        let ephemeral_key = self
299            .ephemeral_key
300            .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
301
302        // Build retrieve request
303        let request_id = Uuid::new_v4().to_string();
304        let retrieve_body = serde_json::json!({
305            "client": {
306                "clientId": self.device_url,
307                "credential": {
308                    "userId": self.user_id,
309                    "bearer": self.token,
310                },
311            },
312            "method": "retrieve",
313            "uri": key_uri,
314            "requestId": request_id,
315        });
316
317        // Wrap with ephemeral key (dir + A256GCM — key is CEK directly)
318        let wrapped = jwe::encrypt_dir_a256gcm(
319            retrieve_body.to_string().as_bytes(),
320            &ephemeral_key,
321            &self.ephemeral_key_kid,
322        )?;
323
324        // POST and wait for Mercury response
325        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
326
327        // Unwrap response with ephemeral key (may be JWE or JWS)
328        let response_body = jwe::unwrap_kms_response(
329            &wrapped_response,
330            &jwe::JweKey::Symmetric(ephemeral_key),
331        )?;
332        let response_data: Value = serde_json::from_slice(&response_body)
333            .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
334
335        // Extract content key
336        let key_jwk_data = extract_jwk_from_response(&response_data)
337            .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
338
339        // Extract the symmetric key bytes from the JWK
340        let k_b64 = key_jwk_data["k"]
341            .as_str()
342            .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
343        let k_bytes = URL_SAFE_NO_PAD
344            .decode(k_b64)
345            .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
346
347        let content_key: [u8; 32] = k_bytes
348            .try_into()
349            .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
350
351        self.key_cache.insert(key_uri.to_string(), content_key);
352        info!("Key retrieved and cached: {key_uri}");
353        Ok(content_key)
354    }
355
356    /// Send a KMS request via HTTP and wait for the response via Mercury.
357    async fn send_kms_request(
358        &self,
359        request_id: &str,
360        wrapped: &str,
361    ) -> Result<String, WebexError> {
362        let (tx, rx) = oneshot::channel();
363
364        // Register pending request
365        {
366            let mut pending = self.pending_requests.lock().await;
367            pending.push((
368                request_id.to_string(),
369                PendingRequest { tx },
370            ));
371        }
372
373        // POST the request
374        let mut headers = HashMap::new();
375        headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
376        headers.insert("Content-Type".to_string(), "application/json".to_string());
377
378        let body = serde_json::to_string(&serde_json::json!({
379            "destination": self.kms_cluster,
380            "kmsMessages": [wrapped],
381        }))
382        .map_err(|e| WebexError::kms(format!("Failed to serialize KMS request: {e}")))?;
383
384        let http_response = (self.http_do)(FetchRequest {
385            url: format!("{}/kms/messages", self.encryption_service_url),
386            method: "POST".to_string(),
387            headers,
388            body: Some(body),
389        })
390        .await;
391
392        match http_response {
393            Ok(resp) if !resp.ok => {
394                let status = resp.status;
395                let body = String::from_utf8_lossy(&resp.body);
396                let mut pending = self.pending_requests.lock().await;
397                pending.retain(|(id, _)| id != request_id);
398                return Err(WebexError::kms(format!(
399                    "KMS HTTP request failed: {status} {body}"
400                )));
401            }
402            Err(e) => {
403                let mut pending = self.pending_requests.lock().await;
404                pending.retain(|(id, _)| id != request_id);
405                return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
406            }
407            Ok(resp) => {
408                debug!(
409                    "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
410                    resp.status
411                );
412            }
413        }
414
415        // Wait for Mercury response with timeout
416        match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
417            Ok(Ok(response)) => Ok(response),
418            Ok(Err(_)) => Err(WebexError::kms(format!(
419                "KMS request {request_id} channel closed"
420            ))),
421            Err(_) => {
422                let mut pending = self.pending_requests.lock().await;
423                pending.retain(|(id, _)| id != request_id);
424                Err(WebexError::kms(format!(
425                    "KMS request {request_id} timed out after {}s",
426                    KMS_RESPONSE_TIMEOUT.as_secs()
427                )))
428            }
429        }
430    }
431
432    fn is_context_expired(&self) -> bool {
433        if !self.initialized {
434            return true;
435        }
436        match self.context_expiration {
437            Some(exp) => {
438                let with_buffer = exp - Duration::from_secs(30);
439                Instant::now() > with_buffer
440            }
441            None => true,
442        }
443    }
444
445    /// Whether the KMS client has been initialized.
446    pub fn is_initialized(&self) -> bool {
447        self.initialized
448    }
449}
450
451/// Extract a JWK from the KMS response JSON.
452fn extract_jwk_from_response(data: &Value) -> Option<Value> {
453    // Try body.key.jwk
454    if let Some(jwk) = data.pointer("/body/key/jwk") {
455        if jwk.is_object() {
456            return Some(jwk.clone());
457        }
458    }
459    // Try body.key (as direct JWK)
460    if let Some(key) = data.pointer("/body/key") {
461        if key.is_object() {
462            return Some(key.clone());
463        }
464    }
465    // Try key.jwk
466    if let Some(jwk) = data.pointer("/key/jwk") {
467        if jwk.is_object() {
468            return Some(jwk.clone());
469        }
470    }
471    // Try key
472    if let Some(key) = data.get("key") {
473        if key.is_object() {
474            return Some(key.clone());
475        }
476    }
477    None
478}
479
480/// Extract the key URI from a KMS response JSON.
481fn extract_key_uri(data: &Value) -> Option<String> {
482    data.pointer("/body/key/uri")
483        .or_else(|| data.pointer("/key/uri"))
484        .and_then(|v| v.as_str())
485        .map(|s| s.to_string())
486}