Skip to main content

webex_message_handler/
kms_client.rs

1//! KMS client for ECDH key exchange and encryption key retrieval.
2//!
3//! Implements the Webex KMS protocol:
4//! 1. Fetch KMS cluster details (rsaPublicKey, kmsCluster)
5//! 2. ECDH handshake: generate local P-256 keypair, wrap with RSA-OAEP,
6//!    send via HTTP, receive response via Mercury, derive shared key
7//! 3. Key retrieval: wrap request with ECDH-derived key, send via HTTP,
8//!    receive via Mercury, unwrap to get content key
9//! 4. Content keys are JWE A256KW + A256GCM
10
11use crate::errors::WebexError;
12use crate::jwe;
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use base64::Engine;
15use p256::elliptic_curve::sec1::ToEncodedPoint;
16use p256::PublicKey;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::{oneshot, Mutex};
22use tracing::{debug, info, warn};
23use uuid::Uuid;
24
/// Upper bound on how long a KMS request waits for its reply to arrive over
/// Mercury once the HTTP POST carrying the request has been accepted.
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_millis(30_000);
26
27/// A pending KMS request awaiting a Mercury response.
28struct PendingRequest {
29    tx: oneshot::Sender<String>,
30}
31
32/// Handle for resolving KMS responses from Mercury.
33///
34/// This is separated from KmsClient so the Mercury event loop can resolve
35/// pending requests without holding the KmsClient lock (which would deadlock
36/// during initialize/get_key).
37#[derive(Clone)]
38pub struct KmsResponseHandler {
39    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
40}
41
42impl KmsResponseHandler {
43    /// Handle a KMS response that arrived via Mercury WebSocket.
44    pub async fn handle_kms_message(&self, data: &Value) {
45        let kms_messages = data
46            .get("kmsMessages")
47            .and_then(|v| v.as_array())
48            .or_else(|| {
49                data.get("encryption")
50                    .and_then(|e| e.get("kmsMessages"))
51                    .and_then(|v| v.as_array())
52            });
53
54        let kms_messages = match kms_messages {
55            Some(msgs) => msgs,
56            None => {
57                debug!("Received KMS message without kmsMessages array");
58                return;
59            }
60        };
61
62        let mut pending = self.pending_requests.lock().await;
63
64        for raw_msg in kms_messages {
65            let wrapped = match raw_msg.as_str() {
66                Some(s) => s.to_string(),
67                None => continue,
68            };
69
70            debug!("Received KMS response, pending requests: {}", pending.len());
71
72            // Resolve the first pending request (FIFO)
73            if !pending.is_empty() {
74                let (_, req) = pending.remove(0);
75                let _ = req.tx.send(wrapped);
76            } else {
77                warn!("Received KMS response but no pending requests");
78            }
79        }
80    }
81}
82
83/// KMS client for Webex end-to-end encryption.
84pub struct KmsClient {
85    token: String,
86    device_url: String,
87    user_id: String,
88    encryption_service_url: String,
89    client: reqwest::Client,
90
91    kms_cluster: String,
92    /// The 256-bit symmetric key derived from ECDH, used as CEK for dir+A256GCM.
93    ephemeral_key: Option<[u8; 32]>,
94    /// The kid (key URI) for the ephemeral key, included in JWE headers.
95    ephemeral_key_kid: String,
96    context_expiration: Option<Instant>,
97    /// Cache of content encryption keys (key URI → raw 256-bit key bytes).
98    key_cache: HashMap<String, [u8; 32]>,
99    initialized: bool,
100
101    /// Pending KMS requests waiting for Mercury responses (FIFO).
102    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
103}
104
105impl KmsClient {
106    pub fn new(
107        token: &str,
108        device_url: &str,
109        user_id: &str,
110        encryption_service_url: &str,
111    ) -> Self {
112        Self {
113            token: token.to_string(),
114            device_url: device_url.to_string(),
115            user_id: user_id.to_string(),
116            encryption_service_url: encryption_service_url.to_string(),
117            client: reqwest::Client::new(),
118            kms_cluster: String::new(),
119            ephemeral_key: None,
120            ephemeral_key_kid: String::new(),
121            context_expiration: None,
122            key_cache: HashMap::new(),
123            initialized: false,
124            pending_requests: Arc::new(Mutex::new(Vec::new())),
125        }
126    }
127
128    /// Get a KmsResponseHandler that can resolve pending requests from Mercury.
129    ///
130    /// The returned handler can be used without holding the KmsClient lock,
131    /// preventing deadlocks during initialize() and get_key().
132    pub fn response_handler(&self) -> KmsResponseHandler {
133        KmsResponseHandler {
134            pending_requests: self.pending_requests.clone(),
135        }
136    }
137
138    /// Initialize KMS context with ECDH handshake.
139    pub async fn initialize(&mut self) -> Result<(), WebexError> {
140        info!("Initializing KMS client");
141
142        // Step 1: Fetch KMS details
143        let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
144        let response = self
145            .client
146            .get(&kms_details_url)
147            .header("Authorization", format!("Bearer {}", self.token))
148            .send()
149            .await
150            .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
151
152        if !response.status().is_success() {
153            return Err(WebexError::kms(format!(
154                "Failed to fetch KMS details: {}",
155                response.status()
156            )));
157        }
158
159        let kms_details: Value = response
160            .json()
161            .await
162            .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
163
164        self.kms_cluster = kms_details["kmsCluster"]
165            .as_str()
166            .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
167            .to_string();
168
169        // Parse RSA public key (may be string or object)
170        let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
171            Value::String(s) => serde_json::from_str::<Value>(s)
172                .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
173            v @ Value::Object(_) => v.clone(),
174            _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
175        };
176
177        // Step 2: Generate local ECDH keypair (P-256)
178        let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
179        let local_public = local_secret.public_key();
180        let local_public_point = local_public.to_encoded_point(false);
181
182        let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
183        let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
184
185        let public_jwk_map = serde_json::json!({
186            "kty": "EC",
187            "crv": "P-256",
188            "x": URL_SAFE_NO_PAD.encode(&*x_bytes),
189            "y": URL_SAFE_NO_PAD.encode(&*y_bytes),
190        });
191
192        // Step 3: Build ECDH request body
193        let request_id = Uuid::new_v4().to_string();
194        let ecdh_request_body = serde_json::json!({
195            "client": {
196                "clientId": self.device_url,
197                "credential": {
198                    "userId": self.user_id,
199                    "bearer": self.token,
200                },
201            },
202            "method": "create",
203            "uri": format!("{}/ecdhe", self.kms_cluster),
204            "requestId": request_id,
205            "jwk": public_jwk_map,
206        });
207
208        // Step 4: Wrap ECDH request with server RSA public key (JWE RSA-OAEP + A256GCM)
209        let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
210            ecdh_request_body.to_string().as_bytes(),
211            &rsa_jwk_value,
212        )?;
213
214        // Step 5: POST ECDH request and wait for Mercury response
215        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
216
217        // Step 6: Unwrap ECDH response (may be JWE encrypted with ECDH-ES or JWS signed)
218        let response_body = jwe::unwrap_kms_response(
219            &wrapped_response,
220            &jwe::JweKey::EcdhPrivate(local_secret.clone()),
221        )?;
222        let response_data: Value = serde_json::from_slice(&response_body)
223            .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
224
225        // Step 7: Extract remote key and derive shared secret
226        let remote_jwk_data = extract_jwk_from_response(&response_data)
227            .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
228
229        // Parse remote public key
230        let remote_x = remote_jwk_data["x"]
231            .as_str()
232            .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
233        let remote_y = remote_jwk_data["y"]
234            .as_str()
235            .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
236
237        let remote_x_bytes = URL_SAFE_NO_PAD
238            .decode(remote_x)
239            .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
240        let remote_y_bytes = URL_SAFE_NO_PAD
241            .decode(remote_y)
242            .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
243
244        // Build uncompressed point: 0x04 || x || y
245        let mut uncompressed = vec![0x04];
246        uncompressed.extend_from_slice(&remote_x_bytes);
247        uncompressed.extend_from_slice(&remote_y_bytes);
248
249        let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
250            .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
251
252        // Perform ECDH
253        let shared_secret = p256::ecdh::diffie_hellman(
254            local_secret.to_nonzero_scalar(),
255            remote_public.as_affine(),
256        );
257
258        // HKDF to derive 256-bit key
259        let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
260        let mut derived = [0u8; 32];
261        hkdf.expand(&[], &mut derived)
262            .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
263
264        self.ephemeral_key = Some(derived);
265        self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
266        self.initialized = true;
267
268        // Set context expiration (default 1 hour)
269        self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
270
271        info!("KMS client initialized successfully");
272        Ok(())
273    }
274
275    /// Retrieve an encryption key from KMS.
276    pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
277        // Check cache
278        if let Some(cached) = self.key_cache.get(key_uri) {
279            debug!("Cache hit for key: {key_uri}");
280            return Ok(*cached);
281        }
282
283        // Check context expiration
284        if self.is_context_expired() {
285            info!("Context expired, re-initializing");
286            self.initialize().await?;
287        }
288
289        if !self.initialized {
290            return Err(WebexError::kms("KMS context not initialized"));
291        }
292
293        let ephemeral_key = self
294            .ephemeral_key
295            .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
296
297        // Build retrieve request
298        let request_id = Uuid::new_v4().to_string();
299        let retrieve_body = serde_json::json!({
300            "client": {
301                "clientId": self.device_url,
302                "credential": {
303                    "userId": self.user_id,
304                    "bearer": self.token,
305                },
306            },
307            "method": "retrieve",
308            "uri": key_uri,
309            "requestId": request_id,
310        });
311
312        // Wrap with ephemeral key (dir + A256GCM — key is CEK directly)
313        let wrapped = jwe::encrypt_dir_a256gcm(
314            retrieve_body.to_string().as_bytes(),
315            &ephemeral_key,
316            &self.ephemeral_key_kid,
317        )?;
318
319        // POST and wait for Mercury response
320        let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
321
322        // Unwrap response with ephemeral key (may be JWE or JWS)
323        let response_body = jwe::unwrap_kms_response(
324            &wrapped_response,
325            &jwe::JweKey::Symmetric(ephemeral_key),
326        )?;
327        let response_data: Value = serde_json::from_slice(&response_body)
328            .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
329
330        // Extract content key
331        let key_jwk_data = extract_jwk_from_response(&response_data)
332            .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
333
334        // Extract the symmetric key bytes from the JWK
335        let k_b64 = key_jwk_data["k"]
336            .as_str()
337            .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
338        let k_bytes = URL_SAFE_NO_PAD
339            .decode(k_b64)
340            .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
341
342        let content_key: [u8; 32] = k_bytes
343            .try_into()
344            .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
345
346        self.key_cache.insert(key_uri.to_string(), content_key);
347        info!("Key retrieved and cached: {key_uri}");
348        Ok(content_key)
349    }
350
351    /// Send a KMS request via HTTP and wait for the response via Mercury.
352    async fn send_kms_request(
353        &self,
354        request_id: &str,
355        wrapped: &str,
356    ) -> Result<String, WebexError> {
357        let (tx, rx) = oneshot::channel();
358
359        // Register pending request
360        {
361            let mut pending = self.pending_requests.lock().await;
362            pending.push((
363                request_id.to_string(),
364                PendingRequest { tx },
365            ));
366        }
367
368        // POST the request
369        let http_response = self
370            .client
371            .post(format!("{}/kms/messages", self.encryption_service_url))
372            .header("Authorization", format!("Bearer {}", self.token))
373            .header("Content-Type", "application/json")
374            .json(&serde_json::json!({
375                "destination": self.kms_cluster,
376                "kmsMessages": [wrapped],
377            }))
378            .send()
379            .await;
380
381        match http_response {
382            Ok(resp) if !resp.status().is_success() => {
383                let status = resp.status();
384                let body = resp.text().await.unwrap_or_default();
385                let mut pending = self.pending_requests.lock().await;
386                pending.retain(|(id, _)| id != request_id);
387                return Err(WebexError::kms(format!(
388                    "KMS HTTP request failed: {status} {body}"
389                )));
390            }
391            Err(e) => {
392                let mut pending = self.pending_requests.lock().await;
393                pending.retain(|(id, _)| id != request_id);
394                return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
395            }
396            Ok(resp) => {
397                debug!(
398                    "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
399                    resp.status()
400                );
401            }
402        }
403
404        // Wait for Mercury response with timeout
405        match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
406            Ok(Ok(response)) => Ok(response),
407            Ok(Err(_)) => Err(WebexError::kms(format!(
408                "KMS request {request_id} channel closed"
409            ))),
410            Err(_) => {
411                let mut pending = self.pending_requests.lock().await;
412                pending.retain(|(id, _)| id != request_id);
413                Err(WebexError::kms(format!(
414                    "KMS request {request_id} timed out after {}s",
415                    KMS_RESPONSE_TIMEOUT.as_secs()
416                )))
417            }
418        }
419    }
420
421    fn is_context_expired(&self) -> bool {
422        if !self.initialized {
423            return true;
424        }
425        match self.context_expiration {
426            Some(exp) => {
427                let with_buffer = exp - Duration::from_secs(30);
428                Instant::now() > with_buffer
429            }
430            None => true,
431        }
432    }
433
434    /// Whether the KMS client has been initialized.
435    pub fn is_initialized(&self) -> bool {
436        self.initialized
437    }
438}
439
440/// Extract a JWK from the KMS response JSON.
441fn extract_jwk_from_response(data: &Value) -> Option<Value> {
442    // Try body.key.jwk
443    if let Some(jwk) = data.pointer("/body/key/jwk") {
444        if jwk.is_object() {
445            return Some(jwk.clone());
446        }
447    }
448    // Try body.key (as direct JWK)
449    if let Some(key) = data.pointer("/body/key") {
450        if key.is_object() {
451            return Some(key.clone());
452        }
453    }
454    // Try key.jwk
455    if let Some(jwk) = data.pointer("/key/jwk") {
456        if jwk.is_object() {
457            return Some(jwk.clone());
458        }
459    }
460    // Try key
461    if let Some(key) = data.get("key") {
462        if key.is_object() {
463            return Some(key.clone());
464        }
465    }
466    None
467}
468
469/// Extract the key URI from a KMS response JSON.
470fn extract_key_uri(data: &Value) -> Option<String> {
471    data.pointer("/body/key/uri")
472        .or_else(|| data.pointer("/key/uri"))
473        .and_then(|v| v.as_str())
474        .map(|s| s.to_string())
475}