1use crate::errors::WebexError;
12use crate::jwe;
13use crate::types::{FetchFn, FetchRequest};
14use base64::engine::general_purpose::URL_SAFE_NO_PAD;
15use base64::Engine;
16use p256::elliptic_curve::sec1::ToEncodedPoint;
17use p256::PublicKey;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::{oneshot, Mutex};
23use tracing::{debug, info, warn};
24use uuid::Uuid;
25
/// Maximum time `send_kms_request` waits for the wrapped KMS response to be delivered
/// through a `KmsResponseHandler` after the HTTP POST succeeded (30 seconds).
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_millis(30_000);
27
/// A KMS request that has been sent and is waiting for its response.
struct PendingRequest {
    /// Completed with the wrapped (still-encrypted) response string taken from an
    /// incoming `kmsMessages` array; the waiting caller unwraps it.
    tx: oneshot::Sender<String>,
}
32
/// Cloneable handle that delivers KMS responses, which arrive out-of-band (e.g. via
/// Mercury events), to the requests a [`KmsClient`] is waiting on.
///
/// Obtained from [`KmsClient::response_handler`]; it shares the client's pending-request
/// queue.
#[derive(Clone)]
pub struct KmsResponseHandler {
    /// `(request_id, pending request)` pairs in send order, shared with the owning client.
    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
}
42
43impl KmsResponseHandler {
44 pub async fn handle_kms_message(&self, data: &Value) {
46 let kms_messages = data
47 .get("kmsMessages")
48 .and_then(|v| v.as_array())
49 .or_else(|| {
50 data.get("encryption")
51 .and_then(|e| e.get("kmsMessages"))
52 .and_then(|v| v.as_array())
53 });
54
55 let kms_messages = match kms_messages {
56 Some(msgs) => msgs,
57 None => {
58 debug!("Received KMS message without kmsMessages array");
59 return;
60 }
61 };
62
63 let mut pending = self.pending_requests.lock().await;
64
65 for raw_msg in kms_messages {
66 let wrapped = match raw_msg.as_str() {
67 Some(s) => s.to_string(),
68 None => continue,
69 };
70
71 debug!("Received KMS response, pending requests: {}", pending.len());
72
73 if !pending.is_empty() {
75 let (_, req) = pending.remove(0);
76 let _ = req.tx.send(wrapped);
77 } else {
78 warn!("Received KMS response but no pending requests");
79 }
80 }
81 }
82}
83
/// Client for the Webex Key Management Service (KMS).
///
/// Performs an ECDH handshake with the KMS to obtain an ephemeral symmetric key, then uses
/// it to retrieve content keys by URI. Requests are POSTed over HTTP, but their responses
/// are delivered separately through a [`KmsResponseHandler`].
pub struct KmsClient {
    /// Bearer token used in `Authorization` headers and in KMS request credentials.
    token: String,
    /// Sent as the KMS request `clientId`.
    device_url: String,
    /// Sent as the credential `userId`; also part of the KMS details URL.
    user_id: String,
    /// Base URL for `/kms/{user_id}` and `/kms/messages`.
    encryption_service_url: String,
    /// Injected HTTP transport.
    http_do: FetchFn,

    /// KMS cluster URI from the KMS details; used as the message `destination` and as
    /// the base of the `/ecdhe` URI. Empty until `initialize` fetches it.
    kms_cluster: String,
    /// Symmetric key derived from the ECDH exchange; `None` until `initialize` succeeds.
    ephemeral_key: Option<[u8; 32]>,
    /// Key URI from the ECDH response, used as the `kid` when encrypting with
    /// `ephemeral_key`. Empty if the response carried no URI.
    ephemeral_key_kid: String,
    /// When the current ephemeral context is considered to expire.
    context_expiration: Option<Instant>,
    /// Content keys already retrieved, keyed by key URI (never evicted).
    key_cache: HashMap<String, [u8; 32]>,
    /// Set once `initialize` has completed successfully.
    initialized: bool,

    /// Requests awaiting a response; shared with every handler from `response_handler`.
    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
}
105
106impl KmsClient {
107 pub fn new(
108 http_do: FetchFn,
109 token: &str,
110 device_url: &str,
111 user_id: &str,
112 encryption_service_url: &str,
113 ) -> Self {
114 Self {
115 token: token.to_string(),
116 device_url: device_url.to_string(),
117 user_id: user_id.to_string(),
118 encryption_service_url: encryption_service_url.to_string(),
119 http_do,
120 kms_cluster: String::new(),
121 ephemeral_key: None,
122 ephemeral_key_kid: String::new(),
123 context_expiration: None,
124 key_cache: HashMap::new(),
125 initialized: false,
126 pending_requests: Arc::new(Mutex::new(Vec::new())),
127 }
128 }
129
130 pub fn response_handler(&self) -> KmsResponseHandler {
135 KmsResponseHandler {
136 pending_requests: self.pending_requests.clone(),
137 }
138 }
139
140 pub async fn initialize(&mut self) -> Result<(), WebexError> {
142 info!("Initializing KMS client");
143
144 let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
146
147 let mut headers = HashMap::new();
148 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
149
150 let response = (self.http_do)(FetchRequest {
151 url: kms_details_url,
152 method: "GET".to_string(),
153 headers,
154 body: None,
155 })
156 .await
157 .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
158
159 if !response.ok {
160 return Err(WebexError::kms(format!(
161 "Failed to fetch KMS details: {}",
162 response.status
163 )));
164 }
165
166 let kms_details: Value = serde_json::from_slice(&response.body)
167 .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
168
169 self.kms_cluster = kms_details["kmsCluster"]
170 .as_str()
171 .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
172 .to_string();
173
174 let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
176 Value::String(s) => serde_json::from_str::<Value>(s)
177 .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
178 v @ Value::Object(_) => v.clone(),
179 _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
180 };
181
182 let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
184 let local_public = local_secret.public_key();
185 let local_public_point = local_public.to_encoded_point(false);
186
187 let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
188 let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
189
190 let public_jwk_map = serde_json::json!({
191 "kty": "EC",
192 "crv": "P-256",
193 "x": URL_SAFE_NO_PAD.encode(*x_bytes),
194 "y": URL_SAFE_NO_PAD.encode(*y_bytes),
195 });
196
197 let request_id = Uuid::new_v4().to_string();
199 let ecdh_request_body = serde_json::json!({
200 "client": {
201 "clientId": self.device_url,
202 "credential": {
203 "userId": self.user_id,
204 "bearer": self.token,
205 },
206 },
207 "method": "create",
208 "uri": format!("{}/ecdhe", self.kms_cluster),
209 "requestId": request_id,
210 "jwk": public_jwk_map,
211 });
212
213 let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
215 ecdh_request_body.to_string().as_bytes(),
216 &rsa_jwk_value,
217 )?;
218
219 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
221
222 let response_body = jwe::unwrap_kms_response(
224 &wrapped_response,
225 &jwe::JweKey::EcdhPrivate(local_secret.clone()),
226 )?;
227 let response_data: Value = serde_json::from_slice(&response_body)
228 .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
229
230 let remote_jwk_data = extract_jwk_from_response(&response_data)
232 .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
233
234 let kty = remote_jwk_data.get("kty").and_then(|v| v.as_str()).unwrap_or("");
236 let crv = remote_jwk_data.get("crv").and_then(|v| v.as_str()).unwrap_or("");
237 if kty != "EC" || crv != "P-256" {
238 return Err(WebexError::kms(format!(
239 "Invalid remote key type: kty={}, crv={}", kty, crv
240 )));
241 }
242
243 let remote_x = remote_jwk_data["x"]
245 .as_str()
246 .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
247 let remote_y = remote_jwk_data["y"]
248 .as_str()
249 .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
250
251 let remote_x_bytes = URL_SAFE_NO_PAD
252 .decode(remote_x)
253 .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
254 let remote_y_bytes = URL_SAFE_NO_PAD
255 .decode(remote_y)
256 .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
257
258 let mut uncompressed = vec![0x04];
260 uncompressed.extend_from_slice(&remote_x_bytes);
261 uncompressed.extend_from_slice(&remote_y_bytes);
262
263 let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
264 .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
265
266 let shared_secret = p256::ecdh::diffie_hellman(
268 local_secret.to_nonzero_scalar(),
269 remote_public.as_affine(),
270 );
271
272 let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
274 let mut derived = [0u8; 32];
275 hkdf.expand(&[], &mut derived)
276 .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
277
278 self.ephemeral_key = Some(derived);
279 self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
280 self.initialized = true;
281
282 self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
284
285 info!("KMS client initialized successfully");
286 Ok(())
287 }
288
289 pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
291 if let Some(cached) = self.key_cache.get(key_uri) {
293 debug!("Cache hit for key: {key_uri}");
294 return Ok(*cached);
295 }
296
297 if self.is_context_expired() {
299 info!("Context expired, re-initializing");
300 self.initialize().await?;
301 }
302
303 if !self.initialized {
304 return Err(WebexError::kms("KMS context not initialized"));
305 }
306
307 let ephemeral_key = self
308 .ephemeral_key
309 .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
310
311 let request_id = Uuid::new_v4().to_string();
313 let retrieve_body = serde_json::json!({
314 "client": {
315 "clientId": self.device_url,
316 "credential": {
317 "userId": self.user_id,
318 "bearer": self.token,
319 },
320 },
321 "method": "retrieve",
322 "uri": key_uri,
323 "requestId": request_id,
324 });
325
326 let wrapped = jwe::encrypt_dir_a256gcm(
328 retrieve_body.to_string().as_bytes(),
329 &ephemeral_key,
330 &self.ephemeral_key_kid,
331 )?;
332
333 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
335
336 let response_body = jwe::unwrap_kms_response(
338 &wrapped_response,
339 &jwe::JweKey::Symmetric(ephemeral_key),
340 )?;
341 let response_data: Value = serde_json::from_slice(&response_body)
342 .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
343
344 let key_jwk_data = extract_jwk_from_response(&response_data)
346 .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
347
348 let k_b64 = key_jwk_data["k"]
350 .as_str()
351 .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
352 let k_bytes = URL_SAFE_NO_PAD
353 .decode(k_b64)
354 .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
355
356 let content_key: [u8; 32] = k_bytes
357 .try_into()
358 .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
359
360 self.key_cache.insert(key_uri.to_string(), content_key);
361 info!("Key retrieved and cached: {key_uri}");
362 Ok(content_key)
363 }
364
365 async fn send_kms_request(
367 &self,
368 request_id: &str,
369 wrapped: &str,
370 ) -> Result<String, WebexError> {
371 let (tx, rx) = oneshot::channel();
372
373 {
375 let mut pending = self.pending_requests.lock().await;
376 pending.push((
377 request_id.to_string(),
378 PendingRequest { tx },
379 ));
380 }
381
382 let mut headers = HashMap::new();
384 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
385 headers.insert("Content-Type".to_string(), "application/json".to_string());
386
387 let body = serde_json::to_string(&serde_json::json!({
388 "destination": self.kms_cluster,
389 "kmsMessages": [wrapped],
390 }))
391 .map_err(|e| WebexError::kms(format!("Failed to serialize KMS request: {e}")))?;
392
393 let http_response = (self.http_do)(FetchRequest {
394 url: format!("{}/kms/messages", self.encryption_service_url),
395 method: "POST".to_string(),
396 headers,
397 body: Some(body),
398 })
399 .await;
400
401 match http_response {
402 Ok(resp) if !resp.ok => {
403 let status = resp.status;
404 let body = String::from_utf8_lossy(&resp.body);
405 let mut pending = self.pending_requests.lock().await;
406 pending.retain(|(id, _)| id != request_id);
407 return Err(WebexError::kms(format!(
408 "KMS HTTP request failed: {status} {body}"
409 )));
410 }
411 Err(e) => {
412 let mut pending = self.pending_requests.lock().await;
413 pending.retain(|(id, _)| id != request_id);
414 return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
415 }
416 Ok(resp) => {
417 debug!(
418 "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
419 resp.status
420 );
421 }
422 }
423
424 match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
426 Ok(Ok(response)) => Ok(response),
427 Ok(Err(_)) => Err(WebexError::kms(format!(
428 "KMS request {request_id} channel closed"
429 ))),
430 Err(_) => {
431 let mut pending = self.pending_requests.lock().await;
432 pending.retain(|(id, _)| id != request_id);
433 Err(WebexError::kms(format!(
434 "KMS request {request_id} timed out after {}s",
435 KMS_RESPONSE_TIMEOUT.as_secs()
436 )))
437 }
438 }
439 }
440
441 fn is_context_expired(&self) -> bool {
442 if !self.initialized {
443 return true;
444 }
445 match self.context_expiration {
446 Some(exp) => {
447 let with_buffer = exp - Duration::from_secs(30);
448 Instant::now() > with_buffer
449 }
450 None => true,
451 }
452 }
453
454 pub fn is_initialized(&self) -> bool {
456 self.initialized
457 }
458}
459
460fn extract_jwk_from_response(data: &Value) -> Option<Value> {
462 if let Some(jwk) = data.pointer("/body/key/jwk") {
464 if jwk.is_object() {
465 return Some(jwk.clone());
466 }
467 }
468 if let Some(key) = data.pointer("/body/key") {
470 if key.is_object() {
471 return Some(key.clone());
472 }
473 }
474 if let Some(jwk) = data.pointer("/key/jwk") {
476 if jwk.is_object() {
477 return Some(jwk.clone());
478 }
479 }
480 if let Some(key) = data.get("key") {
482 if key.is_object() {
483 return Some(key.clone());
484 }
485 }
486 None
487}
488
489fn extract_key_uri(data: &Value) -> Option<String> {
491 data.pointer("/body/key/uri")
492 .or_else(|| data.pointer("/key/uri"))
493 .and_then(|v| v.as_str())
494 .map(|s| s.to_string())
495}