1use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9use std::time::Instant;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use super::capabilities::ServerToClientRequests;
17use crate::types::Timestamp;
18
19#[derive(Clone)]
25pub struct RequestContext {
26 pub request_id: String,
28
29 pub user_id: Option<String>,
31
32 pub session_id: Option<String>,
34
35 pub client_id: Option<String>,
37
38 pub timestamp: Timestamp,
40
41 pub start_time: Instant,
43
44 pub metadata: Arc<HashMap<String, serde_json::Value>>,
46
47 #[cfg(feature = "tracing")]
49 pub span: Option<tracing::Span>,
50
51 pub cancellation_token: Option<Arc<CancellationToken>>,
53
54 #[doc(hidden)]
57 pub(crate) server_to_client: Option<Arc<dyn ServerToClientRequests>>,
58}
59
60impl fmt::Debug for RequestContext {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("RequestContext")
63 .field("request_id", &self.request_id)
64 .field("user_id", &self.user_id)
65 .field("session_id", &self.session_id)
66 .field("client_id", &self.client_id)
67 .field("timestamp", &self.timestamp)
68 .field("metadata", &self.metadata)
69 .field("server_to_client", &self.server_to_client.is_some())
70 .finish()
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct ResponseContext {
77 pub request_id: String,
79
80 pub timestamp: Timestamp,
82
83 pub duration: std::time::Duration,
85
86 pub status: ResponseStatus,
88
89 pub metadata: Arc<HashMap<String, serde_json::Value>>,
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
95pub enum ResponseStatus {
96 Success,
98 Error {
100 code: i32,
102 message: String,
104 },
105 Partial,
107 Cancelled,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RequestInfo {
114 pub timestamp: DateTime<Utc>,
116 pub client_id: String,
118 pub method_name: String,
120 pub parameters: serde_json::Value,
122 pub response_time_ms: Option<u64>,
124 pub success: bool,
126 pub error_message: Option<String>,
128 pub status_code: Option<u16>,
130 pub metadata: HashMap<String, serde_json::Value>,
132}
133
134impl RequestContext {
135 #[must_use]
137 pub fn new() -> Self {
138 Self {
139 request_id: Uuid::new_v4().to_string(),
140 user_id: None,
141 session_id: None,
142 client_id: None,
143 timestamp: Timestamp::now(),
144 start_time: Instant::now(),
145 metadata: Arc::new(HashMap::new()),
146 #[cfg(feature = "tracing")]
147 span: None,
148 cancellation_token: None,
149 server_to_client: None,
150 }
151 }
152
153 pub fn with_id(id: impl Into<String>) -> Self {
155 Self {
156 request_id: id.into(),
157 ..Self::new()
158 }
159 }
160
161 #[must_use]
170 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
171 self.user_id = Some(user_id.into());
172 self
173 }
174
175 #[must_use]
177 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
178 self.session_id = Some(session_id.into());
179 self
180 }
181
182 #[must_use]
184 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
185 self.client_id = Some(client_id.into());
186 self
187 }
188
189 #[must_use]
199 pub fn with_metadata(
200 mut self,
201 key: impl Into<String>,
202 value: impl Into<serde_json::Value>,
203 ) -> Self {
204 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
205 self
206 }
207
208 #[must_use]
210 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
211 self.metadata.get(key)
212 }
213
214 #[must_use]
216 pub fn elapsed(&self) -> std::time::Duration {
217 self.start_time.elapsed()
218 }
219
220 #[must_use]
222 pub fn is_cancelled(&self) -> bool {
223 self.cancellation_token
224 .as_ref()
225 .is_some_and(|token| token.is_cancelled())
226 }
227
228 #[must_use]
234 pub fn with_server_to_client(mut self, capabilities: Arc<dyn ServerToClientRequests>) -> Self {
235 self.server_to_client = Some(capabilities);
236 self
237 }
238
239 #[must_use]
242 pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
243 self.cancellation_token = Some(token);
244 self
245 }
246
247 #[must_use]
249 pub fn user(&self) -> Option<&str> {
250 self.user_id.as_deref()
251 }
252
253 #[must_use]
256 pub fn is_authenticated(&self) -> bool {
257 self.get_metadata("client_authenticated")
258 .and_then(|v| v.as_bool())
259 .unwrap_or(false)
260 }
261
262 #[must_use]
265 pub fn roles(&self) -> Vec<String> {
266 self.get_metadata("auth")
267 .and_then(|auth| auth.get("roles"))
268 .and_then(|roles| roles.as_array())
269 .map(|roles| {
270 roles
271 .iter()
272 .filter_map(|role| role.as_str().map(ToString::to_string))
273 .collect()
274 })
275 .unwrap_or_default()
276 }
277
278 pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
281 if required.is_empty() {
282 return true; }
284
285 let user_roles = self.roles();
286 required
287 .iter()
288 .any(|required_role| user_roles.contains(&required_role.as_ref().to_string()))
289 }
290
291 #[doc(hidden)]
296 pub fn server_to_client(&self) -> Option<&Arc<dyn ServerToClientRequests>> {
297 self.server_to_client.as_ref()
298 }
299
300 #[must_use]
317 pub fn headers(&self) -> Option<HashMap<String, String>> {
318 self.get_metadata("http_headers")
319 .and_then(|v| serde_json::from_value(v.clone()).ok())
320 }
321
322 #[must_use]
336 pub fn header(&self, name: &str) -> Option<String> {
337 let headers = self.headers()?;
338 let name_lower = name.to_lowercase();
339
340 headers
342 .iter()
343 .find(|(key, _)| key.to_lowercase() == name_lower)
344 .map(|(_, value)| value.clone())
345 }
346
347 #[must_use]
361 pub fn transport(&self) -> Option<String> {
362 self.get_metadata("transport")
363 .and_then(|v| v.as_str())
364 .map(|s| s.to_string())
365 }
366}
367
368impl Default for RequestContext {
369 fn default() -> Self {
370 Self::new()
371 }
372}
373
374impl ResponseContext {
375 pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
377 Self {
378 request_id: request_id.into(),
379 timestamp: Timestamp::now(),
380 duration,
381 status: ResponseStatus::Success,
382 metadata: Arc::new(HashMap::new()),
383 }
384 }
385
386 pub fn error(
388 request_id: impl Into<String>,
389 duration: std::time::Duration,
390 code: i32,
391 message: impl Into<String>,
392 ) -> Self {
393 Self {
394 request_id: request_id.into(),
395 timestamp: Timestamp::now(),
396 duration,
397 status: ResponseStatus::Error {
398 code,
399 message: message.into(),
400 },
401 metadata: Arc::new(HashMap::new()),
402 }
403 }
404}
405
406impl RequestInfo {
407 #[must_use]
409 pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
410 Self {
411 timestamp: Utc::now(),
412 client_id,
413 method_name,
414 parameters,
415 response_time_ms: None,
416 success: false,
417 error_message: None,
418 status_code: None,
419 metadata: HashMap::new(),
420 }
421 }
422
423 #[must_use]
425 pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
426 self.response_time_ms = Some(response_time_ms);
427 self.success = true;
428 self.status_code = Some(200);
429 self
430 }
431
432 #[must_use]
434 pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
435 self.response_time_ms = Some(response_time_ms);
436 self.success = false;
437 self.error_message = Some(error);
438 self.status_code = Some(500);
439 self
440 }
441
442 #[must_use]
444 pub const fn with_status_code(mut self, code: u16) -> Self {
445 self.status_code = Some(code);
446 self
447 }
448
449 #[must_use]
451 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
452 self.metadata.insert(key, value);
453 self
454 }
455}
456
457pub trait RequestContextExt {
459 #[must_use]
461 fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self;
462
463 #[must_use]
465 fn extract_client_id(
466 self,
467 extractor: &super::client::ClientIdExtractor,
468 headers: Option<&HashMap<String, String>>,
469 query_params: Option<&HashMap<String, String>>,
470 ) -> Self;
471
472 fn get_enhanced_client_id(&self) -> Option<super::client::ClientId>;
474}
475
476impl RequestContextExt for RequestContext {
477 fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self {
478 self.with_client_id(client_id.as_str())
479 .with_metadata(
480 "client_id_method".to_string(),
481 serde_json::Value::String(client_id.auth_method().to_string()),
482 )
483 .with_metadata(
484 "client_authenticated".to_string(),
485 serde_json::Value::Bool(client_id.is_authenticated()),
486 )
487 }
488
489 fn extract_client_id(
490 self,
491 extractor: &super::client::ClientIdExtractor,
492 headers: Option<&HashMap<String, String>>,
493 query_params: Option<&HashMap<String, String>>,
494 ) -> Self {
495 let client_id = extractor.extract_client_id(headers, query_params);
496 self.with_enhanced_client_id(client_id)
497 }
498
499 fn get_enhanced_client_id(&self) -> Option<super::client::ClientId> {
500 self.client_id.as_ref().map(|id| {
501 let method = self
502 .get_metadata("client_id_method")
503 .and_then(|v| v.as_str())
504 .unwrap_or("header");
505
506 match method {
507 "bearer_token" => super::client::ClientId::Token(id.clone()),
508 "session_cookie" => super::client::ClientId::Session(id.clone()),
509 "query_param" => super::client::ClientId::QueryParam(id.clone()),
510 "user_agent" => super::client::ClientId::UserAgent(id.clone()),
511 "anonymous" => super::client::ClientId::Anonymous,
512 _ => super::client::ClientId::Header(id.clone()), }
514 })
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context whose `"http_headers"` metadata holds `pairs`.
    fn context_with_headers(pairs: &[(&str, &str)]) -> RequestContext {
        let map: HashMap<String, String> = pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect();
        let json = serde_json::to_value(&map).unwrap();
        RequestContext::new().with_metadata("http_headers", json)
    }

    /// Builds a context whose `"transport"` metadata is `kind`.
    fn context_with_transport(kind: &str) -> RequestContext {
        RequestContext::new().with_metadata("transport", kind)
    }

    #[test]
    fn test_headers_returns_none_when_not_set() {
        assert!(RequestContext::new().headers().is_none());
    }

    #[test]
    fn test_headers_returns_headers_when_set() {
        let ctx = context_with_headers(&[
            ("user-agent", "Test-Agent/1.0"),
            ("content-type", "application/json"),
        ]);

        let headers = ctx.headers().expect("headers should be present");

        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("user-agent").map(String::as_str),
            Some("Test-Agent/1.0")
        );
        assert_eq!(
            headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_header_case_insensitive_lookup() {
        let ctx = context_with_headers(&[
            ("User-Agent", "Test-Agent/1.0"),
            ("Content-Type", "application/json"),
        ]);

        for name in ["user-agent", "USER-AGENT", "User-Agent"] {
            assert_eq!(ctx.header(name).as_deref(), Some("Test-Agent/1.0"));
        }
        for name in ["content-type", "CONTENT-TYPE"] {
            assert_eq!(ctx.header(name).as_deref(), Some("application/json"));
        }
    }

    #[test]
    fn test_header_returns_none_when_not_found() {
        let ctx = context_with_headers(&[("user-agent", "Test-Agent/1.0")]);

        assert!(ctx.header("x-custom-header").is_none());
    }

    #[test]
    fn test_header_returns_none_when_headers_not_set() {
        assert!(RequestContext::new().header("user-agent").is_none());
    }

    #[test]
    fn test_transport_returns_none_when_not_set() {
        assert!(RequestContext::new().transport().is_none());
    }

    #[test]
    fn test_transport_returns_transport_type() {
        let ctx = context_with_transport("http");

        assert_eq!(ctx.transport().as_deref(), Some("http"));
    }

    #[test]
    fn test_multiple_transport_types() {
        for kind in ["http", "websocket", "stdio"] {
            let ctx = context_with_transport(kind);
            assert_eq!(ctx.transport().as_deref(), Some(kind));
        }
    }
}