1use crate::errors::WebexError;
12use crate::jwe;
13use crate::types::{FetchFn, FetchRequest};
14use base64::engine::general_purpose::URL_SAFE_NO_PAD;
15use base64::Engine;
16use p256::elliptic_curve::sec1::ToEncodedPoint;
17use p256::PublicKey;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::{oneshot, Mutex};
23use tracing::{debug, info, warn};
24use uuid::Uuid;
25
26const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_secs(30);
27
/// A KMS request that has been sent and is waiting for its response.
///
/// Entries are queued in the `pending_requests` vector shared between
/// `KmsClient` and `KmsResponseHandler`. The handler removes them from the
/// front of the queue (FIFO) and delivers the raw response through `tx`.
struct PendingRequest {
    /// One-shot channel that receives the still-wrapped KMS response string.
    /// The awaiting caller unwraps it with `jwe::unwrap_kms_response`.
    tx: oneshot::Sender<String>,
}
32
/// Cloneable handle that feeds incoming KMS responses back to a `KmsClient`.
///
/// Obtained from `KmsClient::response_handler`. It shares the client's queue of
/// pending requests, so messages received outside the client can complete the
/// requests the client is awaiting. The client's logs describe these as
/// "Mercury" responses.
#[derive(Clone)]
pub struct KmsResponseHandler {
    /// FIFO queue of `(request_id, pending request)` pairs, shared with the client.
    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
}
42
43impl KmsResponseHandler {
44 pub async fn handle_kms_message(&self, data: &Value) {
46 let kms_messages = data
47 .get("kmsMessages")
48 .and_then(|v| v.as_array())
49 .or_else(|| {
50 data.get("encryption")
51 .and_then(|e| e.get("kmsMessages"))
52 .and_then(|v| v.as_array())
53 });
54
55 let kms_messages = match kms_messages {
56 Some(msgs) => msgs,
57 None => {
58 debug!("Received KMS message without kmsMessages array");
59 return;
60 }
61 };
62
63 let mut pending = self.pending_requests.lock().await;
64
65 for raw_msg in kms_messages {
66 let wrapped = match raw_msg.as_str() {
67 Some(s) => s.to_string(),
68 None => continue,
69 };
70
71 debug!("Received KMS response, pending requests: {}", pending.len());
72
73 if !pending.is_empty() {
75 let (_, req) = pending.remove(0);
76 let _ = req.tx.send(wrapped);
77 } else {
78 warn!("Received KMS response but no pending requests");
79 }
80 }
81 }
82}
83
/// Client for the key management service (KMS) used to retrieve content keys.
///
/// `KmsClient::initialize` negotiates an ephemeral ECDH context. It can be
/// called explicitly, or lazily by `KmsClient::get_key`. Content keys retrieved
/// afterwards are cached by key URI.
pub struct KmsClient {
    /// Bearer token. Sent in the `Authorization` header and in the KMS request credentials.
    token: String,
    /// Sent as `clientId` in KMS requests.
    device_url: String,
    /// Used to look up KMS details (`/kms/{user_id}`) and sent as `userId` in credentials.
    user_id: String,
    /// Base URL for the `/kms/{user_id}` and `/kms/messages` endpoints.
    encryption_service_url: String,
    /// Injected HTTP transport used for every request.
    http_do: FetchFn,

    /// KMS cluster URI from the KMS details. Used as the `destination` of KMS messages.
    kms_cluster: String,
    /// Symmetric key derived from the ECDH exchange; `None` until initialized.
    ephemeral_key: Option<[u8; 32]>,
    /// Key URI returned by the ECDH exchange. It is passed to `jwe::encrypt_dir_a256gcm`
    /// together with the ephemeral key (presumably as the JWE `kid` — confirm in `jwe`).
    ephemeral_key_kid: String,
    /// Instant after which the current ephemeral context is treated as expired.
    context_expiration: Option<Instant>,
    /// Content keys already retrieved, keyed by key URI.
    key_cache: HashMap<String, [u8; 32]>,
    /// Set once `initialize` has completed successfully.
    initialized: bool,

    /// Requests awaiting a response; shared with every `KmsResponseHandler`.
    pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
}
105
106impl KmsClient {
107 pub fn new(
108 http_do: FetchFn,
109 token: &str,
110 device_url: &str,
111 user_id: &str,
112 encryption_service_url: &str,
113 ) -> Self {
114 Self {
115 token: token.to_string(),
116 device_url: device_url.to_string(),
117 user_id: user_id.to_string(),
118 encryption_service_url: encryption_service_url.to_string(),
119 http_do,
120 kms_cluster: String::new(),
121 ephemeral_key: None,
122 ephemeral_key_kid: String::new(),
123 context_expiration: None,
124 key_cache: HashMap::new(),
125 initialized: false,
126 pending_requests: Arc::new(Mutex::new(Vec::new())),
127 }
128 }
129
130 pub fn response_handler(&self) -> KmsResponseHandler {
135 KmsResponseHandler {
136 pending_requests: self.pending_requests.clone(),
137 }
138 }
139
140 pub async fn initialize(&mut self) -> Result<(), WebexError> {
142 info!("Initializing KMS client");
143
144 let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
146
147 let mut headers = HashMap::new();
148 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
149
150 let response = (self.http_do)(FetchRequest {
151 url: kms_details_url,
152 method: "GET".to_string(),
153 headers,
154 body: None,
155 })
156 .await
157 .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
158
159 if !response.ok {
160 return Err(WebexError::kms(format!(
161 "Failed to fetch KMS details: {}",
162 response.status
163 )));
164 }
165
166 let kms_details: Value = serde_json::from_slice(&response.body)
167 .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
168
169 self.kms_cluster = kms_details["kmsCluster"]
170 .as_str()
171 .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
172 .to_string();
173
174 let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
176 Value::String(s) => serde_json::from_str::<Value>(s)
177 .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
178 v @ Value::Object(_) => v.clone(),
179 _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
180 };
181
182 let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
184 let local_public = local_secret.public_key();
185 let local_public_point = local_public.to_encoded_point(false);
186
187 let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
188 let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
189
190 let public_jwk_map = serde_json::json!({
191 "kty": "EC",
192 "crv": "P-256",
193 "x": URL_SAFE_NO_PAD.encode(&*x_bytes),
194 "y": URL_SAFE_NO_PAD.encode(&*y_bytes),
195 });
196
197 let request_id = Uuid::new_v4().to_string();
199 let ecdh_request_body = serde_json::json!({
200 "client": {
201 "clientId": self.device_url,
202 "credential": {
203 "userId": self.user_id,
204 "bearer": self.token,
205 },
206 },
207 "method": "create",
208 "uri": format!("{}/ecdhe", self.kms_cluster),
209 "requestId": request_id,
210 "jwk": public_jwk_map,
211 });
212
213 let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
215 ecdh_request_body.to_string().as_bytes(),
216 &rsa_jwk_value,
217 )?;
218
219 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
221
222 let response_body = jwe::unwrap_kms_response(
224 &wrapped_response,
225 &jwe::JweKey::EcdhPrivate(local_secret.clone()),
226 )?;
227 let response_data: Value = serde_json::from_slice(&response_body)
228 .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
229
230 let remote_jwk_data = extract_jwk_from_response(&response_data)
232 .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
233
234 let remote_x = remote_jwk_data["x"]
236 .as_str()
237 .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
238 let remote_y = remote_jwk_data["y"]
239 .as_str()
240 .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
241
242 let remote_x_bytes = URL_SAFE_NO_PAD
243 .decode(remote_x)
244 .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
245 let remote_y_bytes = URL_SAFE_NO_PAD
246 .decode(remote_y)
247 .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
248
249 let mut uncompressed = vec![0x04];
251 uncompressed.extend_from_slice(&remote_x_bytes);
252 uncompressed.extend_from_slice(&remote_y_bytes);
253
254 let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
255 .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
256
257 let shared_secret = p256::ecdh::diffie_hellman(
259 local_secret.to_nonzero_scalar(),
260 remote_public.as_affine(),
261 );
262
263 let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
265 let mut derived = [0u8; 32];
266 hkdf.expand(&[], &mut derived)
267 .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
268
269 self.ephemeral_key = Some(derived);
270 self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
271 self.initialized = true;
272
273 self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
275
276 info!("KMS client initialized successfully");
277 Ok(())
278 }
279
280 pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
282 if let Some(cached) = self.key_cache.get(key_uri) {
284 debug!("Cache hit for key: {key_uri}");
285 return Ok(*cached);
286 }
287
288 if self.is_context_expired() {
290 info!("Context expired, re-initializing");
291 self.initialize().await?;
292 }
293
294 if !self.initialized {
295 return Err(WebexError::kms("KMS context not initialized"));
296 }
297
298 let ephemeral_key = self
299 .ephemeral_key
300 .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
301
302 let request_id = Uuid::new_v4().to_string();
304 let retrieve_body = serde_json::json!({
305 "client": {
306 "clientId": self.device_url,
307 "credential": {
308 "userId": self.user_id,
309 "bearer": self.token,
310 },
311 },
312 "method": "retrieve",
313 "uri": key_uri,
314 "requestId": request_id,
315 });
316
317 let wrapped = jwe::encrypt_dir_a256gcm(
319 retrieve_body.to_string().as_bytes(),
320 &ephemeral_key,
321 &self.ephemeral_key_kid,
322 )?;
323
324 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
326
327 let response_body = jwe::unwrap_kms_response(
329 &wrapped_response,
330 &jwe::JweKey::Symmetric(ephemeral_key),
331 )?;
332 let response_data: Value = serde_json::from_slice(&response_body)
333 .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
334
335 let key_jwk_data = extract_jwk_from_response(&response_data)
337 .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
338
339 let k_b64 = key_jwk_data["k"]
341 .as_str()
342 .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
343 let k_bytes = URL_SAFE_NO_PAD
344 .decode(k_b64)
345 .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
346
347 let content_key: [u8; 32] = k_bytes
348 .try_into()
349 .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
350
351 self.key_cache.insert(key_uri.to_string(), content_key);
352 info!("Key retrieved and cached: {key_uri}");
353 Ok(content_key)
354 }
355
356 async fn send_kms_request(
358 &self,
359 request_id: &str,
360 wrapped: &str,
361 ) -> Result<String, WebexError> {
362 let (tx, rx) = oneshot::channel();
363
364 {
366 let mut pending = self.pending_requests.lock().await;
367 pending.push((
368 request_id.to_string(),
369 PendingRequest { tx },
370 ));
371 }
372
373 let mut headers = HashMap::new();
375 headers.insert("Authorization".to_string(), format!("Bearer {}", self.token));
376 headers.insert("Content-Type".to_string(), "application/json".to_string());
377
378 let body = serde_json::to_string(&serde_json::json!({
379 "destination": self.kms_cluster,
380 "kmsMessages": [wrapped],
381 }))
382 .map_err(|e| WebexError::kms(format!("Failed to serialize KMS request: {e}")))?;
383
384 let http_response = (self.http_do)(FetchRequest {
385 url: format!("{}/kms/messages", self.encryption_service_url),
386 method: "POST".to_string(),
387 headers,
388 body: Some(body),
389 })
390 .await;
391
392 match http_response {
393 Ok(resp) if !resp.ok => {
394 let status = resp.status;
395 let body = String::from_utf8_lossy(&resp.body);
396 let mut pending = self.pending_requests.lock().await;
397 pending.retain(|(id, _)| id != request_id);
398 return Err(WebexError::kms(format!(
399 "KMS HTTP request failed: {status} {body}"
400 )));
401 }
402 Err(e) => {
403 let mut pending = self.pending_requests.lock().await;
404 pending.retain(|(id, _)| id != request_id);
405 return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
406 }
407 Ok(resp) => {
408 debug!(
409 "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
410 resp.status
411 );
412 }
413 }
414
415 match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
417 Ok(Ok(response)) => Ok(response),
418 Ok(Err(_)) => Err(WebexError::kms(format!(
419 "KMS request {request_id} channel closed"
420 ))),
421 Err(_) => {
422 let mut pending = self.pending_requests.lock().await;
423 pending.retain(|(id, _)| id != request_id);
424 Err(WebexError::kms(format!(
425 "KMS request {request_id} timed out after {}s",
426 KMS_RESPONSE_TIMEOUT.as_secs()
427 )))
428 }
429 }
430 }
431
432 fn is_context_expired(&self) -> bool {
433 if !self.initialized {
434 return true;
435 }
436 match self.context_expiration {
437 Some(exp) => {
438 let with_buffer = exp - Duration::from_secs(30);
439 Instant::now() > with_buffer
440 }
441 None => true,
442 }
443 }
444
445 pub fn is_initialized(&self) -> bool {
447 self.initialized
448 }
449}
450
451fn extract_jwk_from_response(data: &Value) -> Option<Value> {
453 if let Some(jwk) = data.pointer("/body/key/jwk") {
455 if jwk.is_object() {
456 return Some(jwk.clone());
457 }
458 }
459 if let Some(key) = data.pointer("/body/key") {
461 if key.is_object() {
462 return Some(key.clone());
463 }
464 }
465 if let Some(jwk) = data.pointer("/key/jwk") {
467 if jwk.is_object() {
468 return Some(jwk.clone());
469 }
470 }
471 if let Some(key) = data.get("key") {
473 if key.is_object() {
474 return Some(key.clone());
475 }
476 }
477 None
478}
479
480fn extract_key_uri(data: &Value) -> Option<String> {
482 data.pointer("/body/key/uri")
483 .or_else(|| data.pointer("/key/uri"))
484 .and_then(|v| v.as_str())
485 .map(|s| s.to_string())
486}