1use crate::errors::WebexError;
12use crate::jwe;
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use base64::Engine;
15use p256::elliptic_curve::sec1::ToEncodedPoint;
16use p256::PublicKey;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::{oneshot, Mutex};
22use tracing::{debug, info, warn};
23use uuid::Uuid;
24
/// Maximum time to wait for the asynchronous response to a KMS request
/// after the HTTP submission has succeeded.
const KMS_RESPONSE_TIMEOUT: Duration = Duration::from_secs(30);
26
27struct PendingRequest {
29 tx: oneshot::Sender<String>,
30}
31
32#[derive(Clone)]
38pub struct KmsResponseHandler {
39 pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
40}
41
42impl KmsResponseHandler {
43 pub async fn handle_kms_message(&self, data: &Value) {
45 let kms_messages = data
46 .get("kmsMessages")
47 .and_then(|v| v.as_array())
48 .or_else(|| {
49 data.get("encryption")
50 .and_then(|e| e.get("kmsMessages"))
51 .and_then(|v| v.as_array())
52 });
53
54 let kms_messages = match kms_messages {
55 Some(msgs) => msgs,
56 None => {
57 debug!("Received KMS message without kmsMessages array");
58 return;
59 }
60 };
61
62 let mut pending = self.pending_requests.lock().await;
63
64 for raw_msg in kms_messages {
65 let wrapped = match raw_msg.as_str() {
66 Some(s) => s.to_string(),
67 None => continue,
68 };
69
70 debug!("Received KMS response, pending requests: {}", pending.len());
71
72 if !pending.is_empty() {
74 let (_, req) = pending.remove(0);
75 let _ = req.tx.send(wrapped);
76 } else {
77 warn!("Received KMS response but no pending requests");
78 }
79 }
80 }
81}
82
83pub struct KmsClient {
85 token: String,
86 device_url: String,
87 user_id: String,
88 encryption_service_url: String,
89 client: reqwest::Client,
90
91 kms_cluster: String,
92 ephemeral_key: Option<[u8; 32]>,
94 ephemeral_key_kid: String,
96 context_expiration: Option<Instant>,
97 key_cache: HashMap<String, [u8; 32]>,
99 initialized: bool,
100
101 pending_requests: Arc<Mutex<Vec<(String, PendingRequest)>>>,
103}
104
105impl KmsClient {
106 pub fn new(
107 token: &str,
108 device_url: &str,
109 user_id: &str,
110 encryption_service_url: &str,
111 ) -> Self {
112 Self {
113 token: token.to_string(),
114 device_url: device_url.to_string(),
115 user_id: user_id.to_string(),
116 encryption_service_url: encryption_service_url.to_string(),
117 client: reqwest::Client::new(),
118 kms_cluster: String::new(),
119 ephemeral_key: None,
120 ephemeral_key_kid: String::new(),
121 context_expiration: None,
122 key_cache: HashMap::new(),
123 initialized: false,
124 pending_requests: Arc::new(Mutex::new(Vec::new())),
125 }
126 }
127
128 pub fn response_handler(&self) -> KmsResponseHandler {
133 KmsResponseHandler {
134 pending_requests: self.pending_requests.clone(),
135 }
136 }
137
138 pub async fn initialize(&mut self) -> Result<(), WebexError> {
140 info!("Initializing KMS client");
141
142 let kms_details_url = format!("{}/kms/{}", self.encryption_service_url, self.user_id);
144 let response = self
145 .client
146 .get(&kms_details_url)
147 .header("Authorization", format!("Bearer {}", self.token))
148 .send()
149 .await
150 .map_err(|e| WebexError::kms(format!("Failed to fetch KMS details: {e}")))?;
151
152 if !response.status().is_success() {
153 return Err(WebexError::kms(format!(
154 "Failed to fetch KMS details: {}",
155 response.status()
156 )));
157 }
158
159 let kms_details: Value = response
160 .json()
161 .await
162 .map_err(|e| WebexError::kms(format!("Failed to parse KMS details: {e}")))?;
163
164 self.kms_cluster = kms_details["kmsCluster"]
165 .as_str()
166 .ok_or_else(|| WebexError::kms("Missing kmsCluster in KMS details"))?
167 .to_string();
168
169 let rsa_jwk_value = match &kms_details["rsaPublicKey"] {
171 Value::String(s) => serde_json::from_str::<Value>(s)
172 .map_err(|e| WebexError::kms(format!("Failed to parse RSA public key string: {e}")))?,
173 v @ Value::Object(_) => v.clone(),
174 _ => return Err(WebexError::kms("Invalid rsaPublicKey format")),
175 };
176
177 let local_secret = p256::SecretKey::random(&mut rand::thread_rng());
179 let local_public = local_secret.public_key();
180 let local_public_point = local_public.to_encoded_point(false);
181
182 let x_bytes = local_public_point.x().ok_or_else(|| WebexError::kms("Missing x coordinate"))?;
183 let y_bytes = local_public_point.y().ok_or_else(|| WebexError::kms("Missing y coordinate"))?;
184
185 let public_jwk_map = serde_json::json!({
186 "kty": "EC",
187 "crv": "P-256",
188 "x": URL_SAFE_NO_PAD.encode(&*x_bytes),
189 "y": URL_SAFE_NO_PAD.encode(&*y_bytes),
190 });
191
192 let request_id = Uuid::new_v4().to_string();
194 let ecdh_request_body = serde_json::json!({
195 "client": {
196 "clientId": self.device_url,
197 "credential": {
198 "userId": self.user_id,
199 "bearer": self.token,
200 },
201 },
202 "method": "create",
203 "uri": format!("{}/ecdhe", self.kms_cluster),
204 "requestId": request_id,
205 "jwk": public_jwk_map,
206 });
207
208 let wrapped = jwe::encrypt_rsa_oaep_a256gcm(
210 ecdh_request_body.to_string().as_bytes(),
211 &rsa_jwk_value,
212 )?;
213
214 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
216
217 let response_body = jwe::unwrap_kms_response(
219 &wrapped_response,
220 &jwe::JweKey::EcdhPrivate(local_secret.clone()),
221 )?;
222 let response_data: Value = serde_json::from_slice(&response_body)
223 .map_err(|e| WebexError::kms(format!("Failed to parse ECDH response: {e}")))?;
224
225 let remote_jwk_data = extract_jwk_from_response(&response_data)
227 .ok_or_else(|| WebexError::kms("No key in ECDH response"))?;
228
229 let remote_x = remote_jwk_data["x"]
231 .as_str()
232 .ok_or_else(|| WebexError::kms("Missing x in remote key"))?;
233 let remote_y = remote_jwk_data["y"]
234 .as_str()
235 .ok_or_else(|| WebexError::kms("Missing y in remote key"))?;
236
237 let remote_x_bytes = URL_SAFE_NO_PAD
238 .decode(remote_x)
239 .map_err(|e| WebexError::kms(format!("Failed to decode remote x: {e}")))?;
240 let remote_y_bytes = URL_SAFE_NO_PAD
241 .decode(remote_y)
242 .map_err(|e| WebexError::kms(format!("Failed to decode remote y: {e}")))?;
243
244 let mut uncompressed = vec![0x04];
246 uncompressed.extend_from_slice(&remote_x_bytes);
247 uncompressed.extend_from_slice(&remote_y_bytes);
248
249 let remote_public = PublicKey::from_sec1_bytes(&uncompressed)
250 .map_err(|e| WebexError::kms(format!("Failed to parse remote public key: {e}")))?;
251
252 let shared_secret = p256::ecdh::diffie_hellman(
254 local_secret.to_nonzero_scalar(),
255 remote_public.as_affine(),
256 );
257
258 let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(None, shared_secret.raw_secret_bytes());
260 let mut derived = [0u8; 32];
261 hkdf.expand(&[], &mut derived)
262 .map_err(|e| WebexError::kms(format!("HKDF derivation failed: {e}")))?;
263
264 self.ephemeral_key = Some(derived);
265 self.ephemeral_key_kid = extract_key_uri(&response_data).unwrap_or_default();
266 self.initialized = true;
267
268 self.context_expiration = Some(Instant::now() + Duration::from_secs(3600));
270
271 info!("KMS client initialized successfully");
272 Ok(())
273 }
274
275 pub async fn get_key(&mut self, key_uri: &str) -> Result<[u8; 32], WebexError> {
277 if let Some(cached) = self.key_cache.get(key_uri) {
279 debug!("Cache hit for key: {key_uri}");
280 return Ok(*cached);
281 }
282
283 if self.is_context_expired() {
285 info!("Context expired, re-initializing");
286 self.initialize().await?;
287 }
288
289 if !self.initialized {
290 return Err(WebexError::kms("KMS context not initialized"));
291 }
292
293 let ephemeral_key = self
294 .ephemeral_key
295 .ok_or_else(|| WebexError::kms("No ephemeral key"))?;
296
297 let request_id = Uuid::new_v4().to_string();
299 let retrieve_body = serde_json::json!({
300 "client": {
301 "clientId": self.device_url,
302 "credential": {
303 "userId": self.user_id,
304 "bearer": self.token,
305 },
306 },
307 "method": "retrieve",
308 "uri": key_uri,
309 "requestId": request_id,
310 });
311
312 let wrapped = jwe::encrypt_dir_a256gcm(
314 retrieve_body.to_string().as_bytes(),
315 &ephemeral_key,
316 &self.ephemeral_key_kid,
317 )?;
318
319 let wrapped_response = self.send_kms_request(&request_id, &wrapped).await?;
321
322 let response_body = jwe::unwrap_kms_response(
324 &wrapped_response,
325 &jwe::JweKey::Symmetric(ephemeral_key),
326 )?;
327 let response_data: Value = serde_json::from_slice(&response_body)
328 .map_err(|e| WebexError::kms(format!("Failed to parse key response: {e}")))?;
329
330 let key_jwk_data = extract_jwk_from_response(&response_data)
332 .ok_or_else(|| WebexError::kms("No key found in KMS response"))?;
333
334 let k_b64 = key_jwk_data["k"]
336 .as_str()
337 .ok_or_else(|| WebexError::kms("Missing 'k' in content key JWK"))?;
338 let k_bytes = URL_SAFE_NO_PAD
339 .decode(k_b64)
340 .map_err(|e| WebexError::kms(format!("Failed to decode content key: {e}")))?;
341
342 let content_key: [u8; 32] = k_bytes
343 .try_into()
344 .map_err(|_| WebexError::kms("Content key is not 32 bytes"))?;
345
346 self.key_cache.insert(key_uri.to_string(), content_key);
347 info!("Key retrieved and cached: {key_uri}");
348 Ok(content_key)
349 }
350
351 async fn send_kms_request(
353 &self,
354 request_id: &str,
355 wrapped: &str,
356 ) -> Result<String, WebexError> {
357 let (tx, rx) = oneshot::channel();
358
359 {
361 let mut pending = self.pending_requests.lock().await;
362 pending.push((
363 request_id.to_string(),
364 PendingRequest { tx },
365 ));
366 }
367
368 let http_response = self
370 .client
371 .post(format!("{}/kms/messages", self.encryption_service_url))
372 .header("Authorization", format!("Bearer {}", self.token))
373 .header("Content-Type", "application/json")
374 .json(&serde_json::json!({
375 "destination": self.kms_cluster,
376 "kmsMessages": [wrapped],
377 }))
378 .send()
379 .await;
380
381 match http_response {
382 Ok(resp) if !resp.status().is_success() => {
383 let status = resp.status();
384 let body = resp.text().await.unwrap_or_default();
385 let mut pending = self.pending_requests.lock().await;
386 pending.retain(|(id, _)| id != request_id);
387 return Err(WebexError::kms(format!(
388 "KMS HTTP request failed: {status} {body}"
389 )));
390 }
391 Err(e) => {
392 let mut pending = self.pending_requests.lock().await;
393 pending.retain(|(id, _)| id != request_id);
394 return Err(WebexError::kms(format!("KMS HTTP request failed: {e}")));
395 }
396 Ok(resp) => {
397 debug!(
398 "KMS request {request_id} sent (HTTP {}), waiting for Mercury response...",
399 resp.status()
400 );
401 }
402 }
403
404 match tokio::time::timeout(KMS_RESPONSE_TIMEOUT, rx).await {
406 Ok(Ok(response)) => Ok(response),
407 Ok(Err(_)) => Err(WebexError::kms(format!(
408 "KMS request {request_id} channel closed"
409 ))),
410 Err(_) => {
411 let mut pending = self.pending_requests.lock().await;
412 pending.retain(|(id, _)| id != request_id);
413 Err(WebexError::kms(format!(
414 "KMS request {request_id} timed out after {}s",
415 KMS_RESPONSE_TIMEOUT.as_secs()
416 )))
417 }
418 }
419 }
420
421 fn is_context_expired(&self) -> bool {
422 if !self.initialized {
423 return true;
424 }
425 match self.context_expiration {
426 Some(exp) => {
427 let with_buffer = exp - Duration::from_secs(30);
428 Instant::now() > with_buffer
429 }
430 None => true,
431 }
432 }
433
434 pub fn is_initialized(&self) -> bool {
436 self.initialized
437 }
438}
439
440fn extract_jwk_from_response(data: &Value) -> Option<Value> {
442 if let Some(jwk) = data.pointer("/body/key/jwk") {
444 if jwk.is_object() {
445 return Some(jwk.clone());
446 }
447 }
448 if let Some(key) = data.pointer("/body/key") {
450 if key.is_object() {
451 return Some(key.clone());
452 }
453 }
454 if let Some(jwk) = data.pointer("/key/jwk") {
456 if jwk.is_object() {
457 return Some(jwk.clone());
458 }
459 }
460 if let Some(key) = data.get("key") {
462 if key.is_object() {
463 return Some(key.clone());
464 }
465 }
466 None
467}
468
469fn extract_key_uri(data: &Value) -> Option<String> {
471 data.pointer("/body/key/uri")
472 .or_else(|| data.pointer("/key/uri"))
473 .and_then(|v| v.as_str())
474 .map(|s| s.to_string())
475}