1use crate::errors::WebexError;
12use crate::jwe;
13use crate::types::{FetchFn, FetchRequest};
14use crate::url_validation::validate_webex_url;
15use base64::engine::general_purpose::URL_SAFE_NO_PAD;
16use base64::Engine;
17use p256::elliptic_curve::sec1::ToEncodedPoint;
18use p256::PublicKey;
19use serde_json::Value;
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::{oneshot, Mutex};
24use tracing::{debug, info, warn};
25use uuid::Uuid;
26
/// Longest time `send_kms_request` waits for the Mercury-delivered response
/// once the HTTP POST has been accepted.
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Maximum number of KMS requests that may be awaiting a response at once.
const MAX_PENDING_REQUESTS: usize = 100;
/// Size threshold at which the content-key cache is cleared.
const MAX_KEY_CACHE_SIZE: usize = 100;
30
31struct PendingRequest {
33 tx: oneshot::Sender<String>,
34}
35
36#[derive(Clone)]
42pub struct KmsResponseHandler {
43 pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
44}
45
46impl KmsResponseHandler {
47 pub async fn handle_kms_message(&self, data: &Value) {
49 let kms_messages = data
50 .get("kmsMessages")
51 .and_then(|v| v.as_array())
52 .or_else(|| {
53 data.get("encryption")
54 .and_then(|e| e.get("kmsMessages"))
55 .and_then(|v| v.as_array())
56 });
57
58 let kms_messages = match kms_messages {
59 Some(msgs) => msgs,
60 None => {
61 debug!("Received KMS message without kmsMessages array");
62 return;
63 }
64 };
65
66 let mut pending = self.pending_requests.lock().await;
67
68 for raw_msg in kms_messages {
69 let wrapped = match raw_msg.as_str() {
70 Some(s) => s.to_string(),
71 None => continue,
72 };
73
74 debug!("Received KMS response, pending requests: {}", pending.len());
75
76 if !pending.is_empty() {
78 let (_, req) = pending.remove(0);
79 let _ = req.tx.send(wrapped);
80 } else {
81 warn!("Received KMS response but no pending requests");
82 }
83 }
84 }
85}
86
87pub struct KmsClient {
89 token: String,
90 device_url: String,
91 user_id: String,
92 encryption_service_url: String,
93 http_do: FetchFn,
94
95 kms_cluster: String,
96 ephemeral_key: Option<[u8; 32]>,
98 ephemeral_key_kid: String,
100 context_expiration: Option<Instant>,
101 key_cache: HashMap<String, [u8; 32]>,
103 initialized: bool,
104
105 pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
107}
108
109impl KmsClient {
110 pub fn new(
111 http_do: FetchFn,
112 token: &str,
113 device_url: &str,
114 user_id: &str,
115 encryption_service_url: &str,
116 ) -> Self {
117 Self {
118 token: token.to_string(),
119 device_url: device_url.to_string(),
120 user_id: user_id.to_string(),
121 encryption_service_url: encryption_service_url.to_string(),
122 http_do,
123 kms_cluster: String::new(),
124 ephemeral_key: None,
125 ephemeral_key_kid: String::new(),
126 context_expiration: None,
127 key_cache: HashMap::new(),
128 initialized: false,
129 pending_requests: Arc::new(Mutex::new(Vec::new())),
130 }
131 }
132
133 pub fn response_handler(&self) -> KmsResponseHandler {
138 KmsResponseHandler {
139 pending_requests: self.pending_requests.clone(),
140 }
141 }
142
143 pub async fn initialize(&mut self) -> Result<(), WebexError> {
145 info!("Initializing KMS client");
146
147 let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
149
150 let mut headers = HashMap::new();
151 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
152
153 let response = (self.http_do)(FetchRequest {
154 url: kms_details_url,
155 method: "GET".to_string(),
156 headers,
157 body: None,
158 })
159 .await
160 .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
161
162 if !response.ok {
163 return Err(WebexError::kms(format!(
164 "Failed to fetch KMS details: {}",
165 response.status
166 )));
167 }
168
169 let kms_details: Value = serde_json::from_slice(&response.body)
170 .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
171
172 self.kms_cluster = kms_details["kmsCluster"]
173 .as_str()
174 .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
175 .to_string();
176
177 validate_webex_url(&self.kms_cluster, "https")
179 .map_err(|e| WebexError::kms(format!("Invalid kmsCluster URL: {e}")))?;
180
181 let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
183 Value::String(s) => serde_json::from_str::<Value>(s)
184 .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
185 v @ Value::Object(_) => v.clone(),
186 _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
187 };
188
189 let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
191 let local_public = local_secret.public_key();
192 let local_public_point = local_public.to_encoded_point(false);
193
194 let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
195 let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
196
197 let public_jwk_map = serde_json::json!({
198 "kty": "EC",
199 "crv": "P-256",
200 "x": URL_SAFE_NO_PAD.encode(*x_bytes),
201 "y": URL_SAFE_NO_PAD.encode(*y_bytes),
202 });
203
204 let request_id = Uuid::new_v4().to_string();
206 let ecdh_request_body = serde_json::json!({
207 "client": {
208 "clientId": self.device_url,
209 "credential": {
210 "userId": self.user_id,
211 "bearer": self.token,
212 },
213 },
214 "method": "create",
215 "uri": format!("{}/ecdhe", self.kms_cluster),
216 "requestId": request_id,
217 "jwk": public_jwk_map,
218 });
219
220 let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
222 ecdh_request_body.to_string().as_bytes(),
223 &rsa_jwk_value,
224 )?;
225
226 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
228
229 let response_body = jwe::unwrap_kms_response(
231 &wrapped_response,
232 &jwe::JweKey::EcdhPrivate(local_secret.clone()),
233 )?;
234 let response_data: Value = serde_json::from_slice(&response_body)
235 .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
236
237 let remote_jwk_data = extract_jwk_from_response(&response_data)
239 .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
240
241 let kty = remote_jwk_data.get("kty").and_then(|v| v.as_str()).unwrap_or("");
243 let crv = remote_jwk_data.get("crv").and_then(|v| v.as_str()).unwrap_or("");
244 if kty != "EC" || crv != "P-256" {
245 return Err(WebexError::kms(format!(
246 "Invalid remote key type: kty={}, crv={}", kty, crv
247 )));
248 }
249
250 let remote_x = remote_jwk_data["x"]
252 .as_str()
253 .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
254 let remote_y = remote_jwk_data["y"]
255 .as_str()
256 .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
257
258 let remote_x_bytes = URL_SAFE_NO_PAD
259 .decode(remote_x)
260 .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
261 let remote_y_bytes = URL_SAFE_NO_PAD
262 .decode(remote_y)
263 .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
264
265 let mut uncompressed = vec![0x04];
267 uncompressed.extend_from_slice(&remote_x_bytes);
268 uncompressed.extend_from_slice(&remote_y_bytes);
269
270 let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
271 .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
272
273 let shared_secret = p256::ecdh::diffie_hellman(
275 local_secret.to_nonzero_scalar(),
276 remote_public.as_affine(),
277 );
278
279 let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
281 let mut derived = [0u8; 32];
282 hkdf.expand(&[], &mut derived)
283 .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
284
285 self.ephemeral_key = Some(derived);
286 self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
287 self.initialized = true;
288
289 self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
291
292 info!("KMS client initialized successfully");
293 Ok(())
294 }
295
296 pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
298 if let Some(cached) = self.key_cache.get(key_uri) {
300 debug!("Cache hit for key: {key_uri}");
301 return Ok(*cached);
302 }
303
304 if self.key_cache.len() > MAX_KEY_CACHE_SIZE {
306 warn!("Key cache exceeded size limit ({}), clearing cache", MAX_KEY_CACHE_SIZE);
307 self.key_cache.clear();
308 }
309
310 if self.is_context_expired() {
312 info!("Context expired, re-initializing");
313 self.initialize().await?;
314 }
315
316 if !self.initialized {
317 return Err(WebexError::kms("KMS context not initialized"));
318 }
319
320 let ephemeral_key = self
321 .ephemeral_key
322 .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
323
324 let request_id = Uuid::new_v4().to_string();
326 let retrieve_body = serde_json::json!({
327 "client": {
328 "clientId": self.device_url,
329 "credential": {
330 "userId": self.user_id,
331 "bearer": self.token,
332 },
333 },
334 "method": "retrieve",
335 "uri": key_uri,
336 "requestId": request_id,
337 });
338
339 let wrapped = jwe::encrypt_dir_a256gcm(
341 retrieve_body.to_string().as_bytes(),
342 &ephemeral_key,
343 &self.ephemeral_key_kid,
344 )?;
345
346 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
348
349 let response_body = jwe::unwrap_kms_response(
351 &wrapped_response,
352 &jwe::JweKey::Symmetric(ephemeral_key),
353 )?;
354 let response_data: Value = serde_json::from_slice(&response_body)
355 .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
356
357 let key_jwk_data = extract_jwk_from_response(&response_data)
359 .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
360
361 let k_b64 = key_jwk_data["k"]
363 .as_str()
364 .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
365 let k_bytes = URL_SAFE_NO_PAD
366 .decode(k_b64)
367 .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
368
369 let content_key: [u8; 32] = k_bytes
370 .try_into()
371 .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
372
373 self.key_cache.insert(key_uri.to_string(), content_key);
374 info!("Key retrieved and cached: {key_uri}");
375 Ok(content_key)
376 }
377
378 async fn send_kms_request(
380 &self,
381 request_id: &str,
382 wrapped: &str,
383 ) -> Result<String, WebexError> {
384 let (tx, rx) = oneshot::channel();
385
386 {
388 let mut pending = self.pending_requests.lock().await;
389 if pending.len() >= MAX_PENDING_REQUESTS {
390 return Err(WebexError::kms(format!(
391 "Too many pending KMS requests (max: {})",
392 MAX_PENDING_REQUESTS
393 )));
394 }
395 pending.push((
396 request_id.to_string(),
397 PendingRequest { tx },
398 ));
399 }
400
401 let mut headers = HashMap::new();
403 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
404 headers.insert("Content-Type".to_string(), "application/json".to_string());
405
406 let body = serde_json::to_string(&serde_json::json!({
407 "destination": self.kms_cluster,
408 "kmsMessages": [wrapped],
409 }))
410 .map_err(|e| WebexError::kms(format!("Failed to serialize KMS request: {e}")))?;
411
412 let http_response = (self.http_do)(FetchRequest {
413 url: format!("{}/kms/messages", self.encryption_service_url),
414 method: "POST".to_string(),
415 headers,
416 body: Some(body),
417 })
418 .await;
419
420 match http_response {
421 Ok(resp) if !resp.ok => {
422 let status = resp.status;
423 let body = String::from_utf8_lossy(&resp.body);
424 let mut pending = self.pending_requests.lock().await;
425 pending.retain(|(id, _)| id != request_id);
426 return Err(WebexError::kms(format!(
427 "KMS HTTP request failed: {status} {body}"
428 )));
429 }
430 Err(e) => {
431 let mut pending = self.pending_requests.lock().await;
432 pending.retain(|(id, _)| id != request_id);
433 return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
434 }
435 Ok(resp) => {
436 debug!(
437 "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
438 resp.status
439 );
440 }
441 }
442
443 match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
445 Ok(Ok(response)) => Ok(response),
446 Ok(Err(_)) => Err(WebexError::kms(format!(
447 "KMS request {request_id} channel closed"
448 ))),
449 Err(_) => {
450 let mut pending = self.pending_requests.lock().await;
451 pending.retain(|(id, _)| id != request_id);
452 Err(WebexError::kms(format!(
453 "KMS request {request_id} timed out after {}s",
454 KMS_RESPONSE_TIMEOUT.as_secs()
455 )))
456 }
457 }
458 }
459
460 fn is_context_expired(&self) -> bool {
461 if !self.initialized {
462 return true;
463 }
464 match self.context_expiration {
465 Some(exp) => {
466 let with_buffer = exp - Duration::from_secs(30);
467 Instant::now() > with_buffer
468 }
469 None => true,
470 }
471 }
472
473 pub fn is_initialized(&self) -> bool {
475 self.initialized
476 }
477}
478
479fn extract_jwk_from_response(data: &Value) -> Option<Value> {
481 if let Some(jwk) = data.pointer("/body/key/jwk") {
483 if jwk.is_object() {
484 return Some(jwk.clone());
485 }
486 }
487 if let Some(key) = data.pointer("/body/key") {
489 if key.is_object() {
490 return Some(key.clone());
491 }
492 }
493 if let Some(jwk) = data.pointer("/key/jwk") {
495 if jwk.is_object() {
496 return Some(jwk.clone());
497 }
498 }
499 if let Some(key) = data.get("key") {
501 if key.is_object() {
502 return Some(key.clone());
503 }
504 }
505 None
506}
507
508fn extract_key_uri(data: &Value) -> Option<String> {
510 data.pointer("/body/key/uri")
511 .or_else(|| data.pointer("/key/uri"))
512 .and_then(|v| v.as_str())
513 .map(|s| s.to_string())
514}