1use std::collections::HashMap;
54use std::fmt;
55use std::sync::Arc;
56use std::time::Instant;
57
58use chrono::{DateTime, Utc};
59use serde::{Deserialize, Serialize};
60use tokio_util::sync::CancellationToken;
61use uuid::Uuid;
62
63use crate::types::Timestamp;
64
/// Server-side operations that a request handler can reach through its
/// [`RequestContext`] (see `RequestContext::server_capabilities`).
///
/// Every method exchanges untyped `serde_json::Value` payloads and returns a
/// boxed future whose error is a boxed, thread-safe `std::error::Error`.
/// The `Send + Sync + Debug` bounds allow the trait object to be stored as
/// `Arc<dyn ServerCapabilities>` inside the cloneable `RequestContext`.
pub trait ServerCapabilities: Send + Sync + fmt::Debug {
    /// Sends a message-creation request built from `request`.
    ///
    /// NOTE(review): presumably backs the `sampling/createMessage` request;
    /// the payload schema is defined by the implementor — confirm there.
    fn create_message(
        &self,
        request: serde_json::Value,
    ) -> futures::future::BoxFuture<
        '_,
        Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
    >;

    /// Sends an elicitation (user-input) request built from `request`.
    ///
    /// NOTE(review): presumably backs `elicitation/create`; confirm the
    /// payload schema against the implementor.
    fn elicit(
        &self,
        request: serde_json::Value,
    ) -> futures::future::BoxFuture<
        '_,
        Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
    >;

    /// Requests the list of roots.
    ///
    /// NOTE(review): presumably backs `roots/list`; the shape of the returned
    /// JSON is defined by the implementor.
    fn list_roots(
        &self,
    ) -> futures::future::BoxFuture<
        '_,
        Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
    >;
}
94
/// Per-request state passed through handlers.
///
/// `metadata` and `server_capabilities` live behind `Arc`, so clones share
/// them. Builders such as `with_metadata` mutate `metadata` through
/// `Arc::make_mut`, which clones the map only when it is shared.
#[derive(Clone)]
pub struct RequestContext {
    /// Request identifier; `new()` fills it with a random UUID v4 string.
    pub request_id: String,

    /// Identifier of the requesting user, if known (see `with_user_id`).
    pub user_id: Option<String>,

    /// Session identifier, if known (see `with_session_id`).
    pub session_id: Option<String>,

    /// Client identifier, if known (see `with_client_id`).
    pub client_id: Option<String>,

    /// Wall-clock creation time (`Timestamp::now()` in `new`).
    pub timestamp: Timestamp,

    /// Monotonic creation instant, used by `elapsed()`.
    pub start_time: Instant,

    /// Free-form metadata. Typed sub-contexts (elicitation, completion,
    /// resource template, ping, bidirectional) are stored here as serialized
    /// JSON under fixed keys.
    pub metadata: Arc<HashMap<String, serde_json::Value>>,

    /// Tracing span for the request (only with the `tracing` feature).
    #[cfg(feature = "tracing")]
    pub span: Option<tracing::Span>,

    /// Optional cancellation token; its state is reported by `is_cancelled()`.
    pub cancellation_token: Option<Arc<CancellationToken>>,

    /// Handle for server-initiated operations; internal, set via
    /// `with_server_capabilities`.
    #[doc(hidden)]
    pub(crate) server_capabilities: Option<Arc<dyn ServerCapabilities>>,
}
131
132impl fmt::Debug for RequestContext {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("RequestContext")
135 .field("request_id", &self.request_id)
136 .field("user_id", &self.user_id)
137 .field("session_id", &self.session_id)
138 .field("client_id", &self.client_id)
139 .field("timestamp", &self.timestamp)
140 .field("metadata", &self.metadata)
141 .field("server_capabilities", &self.server_capabilities.is_some())
142 .finish()
143 }
144}
145
/// Outcome record for a finished request.
///
/// Built by the `success`, `error` and `cancelled` constructors.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Id of the request this response belongs to.
    pub request_id: String,

    /// Time at which this response context was created.
    pub timestamp: Timestamp,

    /// Processing duration supplied by the caller.
    pub duration: std::time::Duration,

    /// Final status of the request.
    pub status: ResponseStatus,

    /// Extra key/value data. Shared via `Arc`; `with_metadata` copies on write.
    pub metadata: Arc<HashMap<String, serde_json::Value>>,
}
164
/// Final status of a request. `Display` renders it as `Success`,
/// `Error(<code>: <message>)`, `Partial` or `Cancelled`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResponseStatus {
    /// The request completed successfully.
    Success,
    /// The request failed.
    Error {
        /// Numeric error code chosen by the caller.
        code: i32,
        /// Human-readable error message.
        message: String,
    },
    /// NOTE(review): presumably "only part of the result was produced";
    /// no constructor in this file creates it.
    Partial,
    /// The request was cancelled.
    Cancelled,
}
182
/// State of a request asking the client/user for structured input.
///
/// Stored inside a `RequestContext` as JSON under `"elicitation_context"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElicitationContext {
    /// Unique id; `new()` uses a random UUID v4 string.
    pub elicitation_id: String,
    /// Message shown to the user.
    pub message: String,
    /// JSON schema describing the expected input (structure not validated here).
    pub schema: serde_json::Value,
    /// Legacy prompt text; superseded by `message`.
    #[deprecated(note = "Use message field instead")]
    pub prompt: Option<String>,
    /// Optional extra constraints (free-form JSON).
    pub constraints: Option<serde_json::Value>,
    /// Optional default values keyed by field name.
    pub defaults: Option<HashMap<String, serde_json::Value>>,
    /// Whether a response is required; `new()` sets `true`.
    pub required: bool,
    /// Timeout in milliseconds (unit per field name); `new()` sets 30 000.
    pub timeout_ms: Option<u64>,
    /// Whether the elicitation may be cancelled; `new()` sets `true`.
    pub cancellable: bool,
    /// Session of the client being asked, if attached.
    pub client_session: Option<ClientSession>,
    /// Creation time; `new()` sets it to now.
    pub requested_at: Timestamp,
    /// Current lifecycle state; starts as `Pending`.
    pub state: ElicitationState,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
218
/// Lifecycle state of an [`ElicitationContext`].
///
/// Every state except `Pending` counts as complete
/// (`ElicitationContext::is_complete`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ElicitationState {
    /// Awaiting a response (initial state).
    Pending,
    /// The user accepted / supplied input.
    Accepted,
    /// The user declined.
    Declined,
    /// The elicitation was cancelled.
    Cancelled,
    /// No response arrived before the timeout.
    TimedOut,
}
233
234impl ElicitationContext {
235 pub fn new(message: String, schema: serde_json::Value) -> Self {
237 Self {
238 elicitation_id: Uuid::new_v4().to_string(),
239 message,
240 schema,
241 #[allow(deprecated)]
242 prompt: None,
243 constraints: None,
244 defaults: None,
245 required: true,
246 timeout_ms: Some(30000),
247 cancellable: true,
248 client_session: None,
249 requested_at: Timestamp::now(),
250 state: ElicitationState::Pending,
251 metadata: HashMap::new(),
252 }
253 }
254
255 pub fn with_client_session(mut self, session: ClientSession) -> Self {
257 self.client_session = Some(session);
258 self
259 }
260
261 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
263 self.timeout_ms = Some(timeout_ms);
264 self
265 }
266
267 pub fn set_state(&mut self, state: ElicitationState) {
269 self.state = state;
270 }
271
272 pub fn is_complete(&self) -> bool {
274 !matches!(self.state, ElicitationState::Pending)
275 }
276}
277
/// State of an argument-completion request.
///
/// Stored inside a `RequestContext` as JSON under `"completion_context"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionContext {
    /// Unique id; `new()` uses a random UUID v4 string.
    pub completion_id: String,
    /// What is being completed (prompt, resource template, tool, custom).
    pub completion_ref: CompletionReference,
    /// Name of the argument being completed, if known.
    pub argument_name: Option<String>,
    /// Partial value typed so far, if any.
    pub partial_value: Option<String>,
    /// Already-resolved arguments (name → value).
    pub resolved_arguments: HashMap<String, String>,
    /// Collected completion options, in insertion order.
    pub completions: Vec<CompletionOption>,
    /// NOTE(review): presumably the cursor offset within `partial_value`; confirm.
    pub cursor_position: Option<usize>,
    /// Intended cap on results; `new()` sets 100. Not enforced by `add_completion`.
    pub max_completions: Option<usize>,
    /// Whether more completions exist beyond `completions`.
    pub has_more: bool,
    /// Total number of completions available, if known.
    pub total_completions: Option<usize>,
    /// Client completion capabilities, if negotiated.
    pub client_capabilities: Option<CompletionCapabilities>,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
306
/// Completion features a client reports supporting.
///
/// NOTE(review): semantics inferred from field names; nothing in this file
/// reads them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionCapabilities {
    /// Client accepts paginated results.
    pub supports_pagination: bool,
    /// Client supports fuzzy matching.
    pub supports_fuzzy: bool,
    /// Largest batch of completions the client accepts.
    pub max_batch_size: usize,
    /// Client can display per-option descriptions.
    pub supports_descriptions: bool,
}
319
320impl CompletionContext {
321 pub fn new(completion_ref: CompletionReference) -> Self {
323 Self {
324 completion_id: Uuid::new_v4().to_string(),
325 completion_ref,
326 argument_name: None,
327 partial_value: None,
328 resolved_arguments: HashMap::new(),
329 completions: Vec::new(),
330 cursor_position: None,
331 max_completions: Some(100),
332 has_more: false,
333 total_completions: None,
334 client_capabilities: None,
335 metadata: HashMap::new(),
336 }
337 }
338
339 pub fn add_completion(&mut self, option: CompletionOption) {
341 self.completions.push(option);
342 }
343
344 pub fn with_resolved_arguments(mut self, args: HashMap<String, String>) -> Self {
346 self.resolved_arguments = args;
347 self
348 }
349}
350
/// Identifies what a completion request is completing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompletionReference {
    /// An argument of a named prompt.
    Prompt {
        /// Prompt name.
        name: String,
        /// Argument being completed.
        argument: String,
    },
    /// A parameter of a named resource template.
    ResourceTemplate {
        /// Template name.
        name: String,
        /// Template parameter being completed.
        parameter: String,
    },
    /// An argument of a named tool.
    Tool {
        /// Tool name.
        name: String,
        /// Argument being completed.
        argument: String,
    },
    /// Application-defined reference kind.
    Custom {
        /// Free-form type tag.
        ref_type: String,
        /// Free-form data for the custom reference.
        metadata: HashMap<String, serde_json::Value>,
    },
}
383
/// A single completion suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionOption {
    /// The suggested value.
    pub value: String,
    /// Optional display label.
    pub label: Option<String>,
    /// Optional free-form kind tag for the suggestion.
    pub completion_type: Option<String>,
    /// Optional documentation text.
    pub documentation: Option<String>,
    /// NOTE(review): ordering hint; whether lower or higher sorts first is not
    /// defined in this file.
    pub sort_priority: Option<i32>,
    /// Optional text to insert instead of `value`.
    pub insert_text: Option<String>,
}
400
/// Description of a resource template.
///
/// Stored inside a `RequestContext` as JSON under `"resource_template_context"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceTemplateContext {
    /// Template name.
    pub template_name: String,
    /// URI template string (syntax not validated here).
    pub uri_template: String,
    /// Template parameters keyed by name.
    pub parameters: HashMap<String, TemplateParameter>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional free-form preset/category tag.
    pub preset_type: Option<String>,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
417
/// One parameter of a resource template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateParameter {
    /// Parameter name.
    pub name: String,
    /// Free-form type name (e.g. as declared by the template author).
    pub param_type: String,
    /// Whether the parameter must be supplied.
    pub required: bool,
    /// Default value, if any.
    pub default: Option<serde_json::Value>,
    /// Optional description.
    pub description: Option<String>,
    /// NOTE(review): presumably a regex the value must match; not checked here.
    pub pattern: Option<String>,
    /// Allowed values, if restricted.
    pub enum_values: Option<Vec<serde_json::Value>>,
}
436
/// Details of a ping exchange.
///
/// Stored inside a `RequestContext` as JSON under `"ping_context"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingContext {
    /// Which side sent the ping.
    pub origin: PingOrigin,
    /// Response-time threshold in milliseconds (unit per field name);
    /// `RequestContext::for_ping` sets 5 000.
    pub response_threshold_ms: Option<u64>,
    /// Optional payload carried by the ping.
    pub payload: Option<serde_json::Value>,
    /// Free-form health information.
    pub health_metadata: HashMap<String, serde_json::Value>,
    /// Optional connection statistics.
    pub connection_metrics: Option<ConnectionMetrics>,
}
451
/// Side that initiated a ping.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PingOrigin {
    /// Sent by the client (client → server).
    Client,
    /// Sent by the server (server → client).
    Server,
}
460
/// Optional connection statistics attached to a ping.
///
/// NOTE(review): meanings and units below are inferred from field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionMetrics {
    /// Round-trip time in milliseconds.
    pub rtt_ms: Option<f64>,
    /// Packet-loss measure (ratio vs. percentage not specified here).
    pub packet_loss: Option<f64>,
    /// Connection uptime in seconds.
    pub uptime_seconds: Option<u64>,
    /// Bytes sent over the connection.
    pub bytes_sent: Option<u64>,
    /// Bytes received over the connection.
    pub bytes_received: Option<u64>,
    /// Time of the last successful exchange (UTC).
    pub last_success: Option<DateTime<Utc>>,
}
477
/// Direction and correlation data for a request.
///
/// Stored inside a `RequestContext` as JSON under `"bidirectional_context"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BidirectionalContext {
    /// Direction of the message.
    pub direction: CommunicationDirection,
    /// Side that started the exchange.
    pub initiator: CommunicationInitiator,
    /// Whether a response is expected; `new()` sets `true`.
    pub expects_response: bool,
    /// Id of the request this one was spawned from, if any.
    pub parent_request_id: Option<String>,
    /// Method name such as `"tools/call"`, checked by `validate_direction`.
    pub request_type: Option<String>,
    /// Originating server id, if known.
    pub server_id: Option<String>,
    /// Correlation id; `new()` uses a random UUID v4 string.
    pub correlation_id: String,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
498
499impl BidirectionalContext {
500 pub fn new(direction: CommunicationDirection, initiator: CommunicationInitiator) -> Self {
502 Self {
503 direction,
504 initiator,
505 expects_response: true,
506 parent_request_id: None,
507 request_type: None,
508 server_id: None,
509 correlation_id: Uuid::new_v4().to_string(),
510 metadata: HashMap::new(),
511 }
512 }
513
514 pub fn with_direction(mut self, direction: CommunicationDirection) -> Self {
516 self.direction = direction;
517 self
518 }
519
520 pub fn with_request_type(mut self, request_type: String) -> Self {
522 self.request_type = Some(request_type);
523 self
524 }
525
526 pub fn validate_direction(&self) -> Result<(), String> {
528 if let Some(ref request_type) = self.request_type {
529 match (request_type.as_str(), &self.direction) {
530 ("sampling/createMessage", CommunicationDirection::ServerToClient) => Ok(()),
532 ("roots/list", CommunicationDirection::ServerToClient) => Ok(()),
533 ("elicitation/create", CommunicationDirection::ServerToClient) => Ok(()),
534
535 ("completion/complete", CommunicationDirection::ClientToServer) => Ok(()),
537 ("tools/call", CommunicationDirection::ClientToServer) => Ok(()),
538 ("resources/read", CommunicationDirection::ClientToServer) => Ok(()),
539 ("prompts/get", CommunicationDirection::ClientToServer) => Ok(()),
540
541 ("ping", _) => Ok(()), (req, dir) => Err(format!(
546 "Invalid direction {:?} for request type '{}'",
547 dir, req
548 )),
549 }
550 } else {
551 Ok(()) }
553 }
554}
555
/// Context for a request that the server sends to the client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInitiatedContext {
    /// Kind of server-initiated request.
    pub request_type: ServerInitiatedType,
    /// Id of the server sending it.
    pub server_id: String,
    /// Correlation id; `new()` uses a random UUID v4 string.
    pub correlation_id: String,
    /// Client capabilities, if known.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Creation time; `new()` sets it to now.
    pub initiated_at: Timestamp,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
572
/// Kinds of request a server may initiate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ServerInitiatedType {
    /// Message creation (cf. `ServerCapabilities::create_message`).
    CreateMessage,
    /// Root listing (cf. `ServerCapabilities::list_roots`).
    ListRoots,
    /// User-input elicitation (cf. `ServerCapabilities::elicit`).
    Elicitation,
    /// Liveness ping.
    Ping,
}
585
/// Capabilities a client reports for server-initiated requests.
///
/// NOTE(review): meanings inferred from field names; nothing here enforces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCapabilities {
    /// Client accepts message-creation (sampling) requests.
    pub sampling: bool,
    /// Client can list roots.
    pub roots: bool,
    /// Client supports elicitation.
    pub elicitation: bool,
    /// Maximum number of concurrent requests the client accepts.
    pub max_concurrent_requests: usize,
    /// Experimental feature flags by name.
    pub experimental: HashMap<String, bool>,
}
600
601impl ServerInitiatedContext {
602 pub fn new(request_type: ServerInitiatedType, server_id: String) -> Self {
604 Self {
605 request_type,
606 server_id,
607 correlation_id: Uuid::new_v4().to_string(),
608 client_capabilities: None,
609 initiated_at: Timestamp::now(),
610 metadata: HashMap::new(),
611 }
612 }
613
614 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
616 self.client_capabilities = Some(capabilities);
617 self
618 }
619
620 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
622 self.metadata.insert(key, value);
623 self
624 }
625}
626
/// Direction in which a message travels.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationDirection {
    /// Client → server (the default reported by
    /// `RequestContext::communication_direction`).
    ClientToServer,
    /// Server → client.
    ServerToClient,
}
635
/// Side that started an exchange.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationInitiator {
    /// The client started it.
    Client,
    /// The server started it.
    Server,
}
644
645impl RequestContext {
646 #[must_use]
660 pub fn new() -> Self {
661 Self {
662 request_id: Uuid::new_v4().to_string(),
663 user_id: None,
664 session_id: None,
665 client_id: None,
666 timestamp: Timestamp::now(),
667 start_time: Instant::now(),
668 metadata: Arc::new(HashMap::new()),
669 #[cfg(feature = "tracing")]
670 span: None,
671 cancellation_token: None,
672 server_capabilities: None,
673 }
674 }
675 #[must_use]
690 pub fn is_authenticated(&self) -> bool {
691 self.metadata
692 .get("authenticated")
693 .and_then(serde_json::Value::as_bool)
694 .unwrap_or(false)
695 }
696
697 #[must_use]
699 pub fn user(&self) -> Option<&str> {
700 self.user_id.as_deref()
701 }
702
703 #[must_use]
705 pub fn roles(&self) -> Vec<String> {
706 self.metadata
707 .get("auth")
708 .and_then(|v| v.get("roles"))
709 .and_then(|v| v.as_array())
710 .map(|arr| {
711 arr.iter()
712 .filter_map(|v| v.as_str().map(std::string::ToString::to_string))
713 .collect()
714 })
715 .unwrap_or_default()
716 }
717
718 pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
720 if required.is_empty() {
721 return true;
722 }
723 let user_roles = self.roles();
724 if user_roles.is_empty() {
725 return false;
726 }
727 let set: std::collections::HashSet<_> = user_roles.into_iter().collect();
728 required.iter().any(|r| set.contains(r.as_ref()))
729 }
730
731 pub fn with_id(id: impl Into<String>) -> Self {
733 Self {
734 request_id: id.into(),
735 ..Self::new()
736 }
737 }
738
739 #[doc(hidden)]
742 pub fn with_server_capabilities(mut self, capabilities: Arc<dyn ServerCapabilities>) -> Self {
743 self.server_capabilities = Some(capabilities);
744 self
745 }
746
747 #[doc(hidden)]
749 pub fn server_capabilities(&self) -> Option<&Arc<dyn ServerCapabilities>> {
750 self.server_capabilities.as_ref()
751 }
752
753 #[must_use]
755 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
756 self.user_id = Some(user_id.into());
757 self
758 }
759
760 #[must_use]
762 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
763 self.session_id = Some(session_id.into());
764 self
765 }
766
767 #[must_use]
769 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
770 self.client_id = Some(client_id.into());
771 self
772 }
773
774 #[must_use]
776 pub fn with_metadata(
777 mut self,
778 key: impl Into<String>,
779 value: impl Into<serde_json::Value>,
780 ) -> Self {
781 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
782 self
783 }
784
785 #[must_use]
787 pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
788 self.cancellation_token = Some(token);
789 self
790 }
791
792 #[must_use]
794 pub fn elapsed(&self) -> std::time::Duration {
795 self.start_time.elapsed()
796 }
797
798 #[must_use]
800 pub fn is_cancelled(&self) -> bool {
801 self.cancellation_token
802 .as_ref()
803 .is_some_and(|token| token.is_cancelled())
804 }
805
806 #[must_use]
808 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
809 self.metadata.get(key)
810 }
811
812 #[must_use]
814 pub fn derive(&self) -> Self {
815 Self {
816 request_id: Uuid::new_v4().to_string(),
817 user_id: self.user_id.clone(),
818 session_id: self.session_id.clone(),
819 client_id: self.client_id.clone(),
820 timestamp: Timestamp::now(),
821 start_time: Instant::now(),
822 metadata: self.metadata.clone(),
823 #[cfg(feature = "tracing")]
824 span: None,
825 cancellation_token: self.cancellation_token.clone(),
826 server_capabilities: self.server_capabilities.clone(),
827 }
828 }
829
830 #[must_use]
836 pub fn with_elicitation_context(mut self, context: ElicitationContext) -> Self {
837 Arc::make_mut(&mut self.metadata).insert(
838 "elicitation_context".to_string(),
839 serde_json::to_value(context).unwrap_or_default(),
840 );
841 self
842 }
843
844 pub fn elicitation_context(&self) -> Option<ElicitationContext> {
846 self.metadata
847 .get("elicitation_context")
848 .and_then(|v| serde_json::from_value(v.clone()).ok())
849 }
850
851 #[must_use]
853 pub fn with_completion_context(mut self, context: CompletionContext) -> Self {
854 Arc::make_mut(&mut self.metadata).insert(
855 "completion_context".to_string(),
856 serde_json::to_value(context).unwrap_or_default(),
857 );
858 self
859 }
860
861 pub fn completion_context(&self) -> Option<CompletionContext> {
863 self.metadata
864 .get("completion_context")
865 .and_then(|v| serde_json::from_value(v.clone()).ok())
866 }
867
868 #[must_use]
870 pub fn with_resource_template_context(mut self, context: ResourceTemplateContext) -> Self {
871 Arc::make_mut(&mut self.metadata).insert(
872 "resource_template_context".to_string(),
873 serde_json::to_value(context).unwrap_or_default(),
874 );
875 self
876 }
877
878 pub fn resource_template_context(&self) -> Option<ResourceTemplateContext> {
880 self.metadata
881 .get("resource_template_context")
882 .and_then(|v| serde_json::from_value(v.clone()).ok())
883 }
884
885 #[must_use]
887 pub fn with_ping_context(mut self, context: PingContext) -> Self {
888 Arc::make_mut(&mut self.metadata).insert(
889 "ping_context".to_string(),
890 serde_json::to_value(context).unwrap_or_default(),
891 );
892 self
893 }
894
895 pub fn ping_context(&self) -> Option<PingContext> {
897 self.metadata
898 .get("ping_context")
899 .and_then(|v| serde_json::from_value(v.clone()).ok())
900 }
901
902 #[must_use]
904 pub fn with_bidirectional_context(mut self, context: BidirectionalContext) -> Self {
905 Arc::make_mut(&mut self.metadata).insert(
906 "bidirectional_context".to_string(),
907 serde_json::to_value(context).unwrap_or_default(),
908 );
909 self
910 }
911
912 pub fn bidirectional_context(&self) -> Option<BidirectionalContext> {
914 self.metadata
915 .get("bidirectional_context")
916 .and_then(|v| serde_json::from_value(v.clone()).ok())
917 }
918
919 pub fn is_server_initiated(&self) -> bool {
921 self.bidirectional_context()
922 .map(|ctx| ctx.direction == CommunicationDirection::ServerToClient)
923 .unwrap_or(false)
924 }
925
926 pub fn is_client_initiated(&self) -> bool {
928 !self.is_server_initiated()
929 }
930
931 pub fn communication_direction(&self) -> CommunicationDirection {
933 self.bidirectional_context()
934 .map(|ctx| ctx.direction)
935 .unwrap_or(CommunicationDirection::ClientToServer)
936 }
937
938 pub fn for_elicitation(schema: serde_json::Value, prompt: Option<String>) -> Self {
940 let message = prompt.unwrap_or_else(|| "Please provide input".to_string());
941 let elicitation_ctx = ElicitationContext::new(message, schema);
942
943 let bidirectional_ctx = BidirectionalContext::new(
944 CommunicationDirection::ServerToClient,
945 CommunicationInitiator::Server,
946 )
947 .with_request_type("elicitation/create".to_string());
948
949 Self::new()
950 .with_elicitation_context(elicitation_ctx)
951 .with_bidirectional_context(bidirectional_ctx)
952 }
953
954 pub fn for_completion(completion_ref: CompletionReference) -> Self {
956 let completion_ctx = CompletionContext::new(completion_ref);
957
958 Self::new().with_completion_context(completion_ctx)
959 }
960
961 pub fn for_resource_template(template_name: String, uri_template: String) -> Self {
963 let template_ctx = ResourceTemplateContext {
964 template_name,
965 uri_template,
966 parameters: HashMap::new(),
967 description: None,
968 preset_type: None,
969 metadata: HashMap::new(),
970 };
971
972 Self::new().with_resource_template_context(template_ctx)
973 }
974
975 pub fn for_ping(origin: PingOrigin) -> Self {
977 let ping_ctx = PingContext {
978 origin: origin.clone(),
979 response_threshold_ms: Some(5_000), payload: None,
981 health_metadata: HashMap::new(),
982 connection_metrics: None,
983 };
984
985 let direction = match origin {
986 PingOrigin::Client => CommunicationDirection::ClientToServer,
987 PingOrigin::Server => CommunicationDirection::ServerToClient,
988 };
989 let initiator = match origin {
990 PingOrigin::Client => CommunicationInitiator::Client,
991 PingOrigin::Server => CommunicationInitiator::Server,
992 };
993
994 let bidirectional_ctx =
995 BidirectionalContext::new(direction, initiator).with_request_type("ping".to_string());
996
997 Self::new()
998 .with_ping_context(ping_ctx)
999 .with_bidirectional_context(bidirectional_ctx)
1000 }
1001}
1002
1003impl ResponseContext {
1004 pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
1006 Self {
1007 request_id: request_id.into(),
1008 timestamp: Timestamp::now(),
1009 duration,
1010 status: ResponseStatus::Success,
1011 metadata: Arc::new(HashMap::new()),
1012 }
1013 }
1014
1015 pub fn error(
1017 request_id: impl Into<String>,
1018 duration: std::time::Duration,
1019 code: i32,
1020 message: impl Into<String>,
1021 ) -> Self {
1022 Self {
1023 request_id: request_id.into(),
1024 timestamp: Timestamp::now(),
1025 duration,
1026 status: ResponseStatus::Error {
1027 code,
1028 message: message.into(),
1029 },
1030 metadata: Arc::new(HashMap::new()),
1031 }
1032 }
1033
1034 pub fn cancelled(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
1036 Self {
1037 request_id: request_id.into(),
1038 timestamp: Timestamp::now(),
1039 duration,
1040 status: ResponseStatus::Cancelled,
1041 metadata: Arc::new(HashMap::new()),
1042 }
1043 }
1044
1045 #[must_use]
1047 pub fn with_metadata(
1048 mut self,
1049 key: impl Into<String>,
1050 value: impl Into<serde_json::Value>,
1051 ) -> Self {
1052 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
1053 self
1054 }
1055
1056 #[must_use]
1058 pub const fn is_success(&self) -> bool {
1059 matches!(self.status, ResponseStatus::Success)
1060 }
1061
1062 #[must_use]
1064 pub const fn is_error(&self) -> bool {
1065 matches!(self.status, ResponseStatus::Error { .. })
1066 }
1067
1068 #[must_use]
1070 pub fn error_info(&self) -> Option<(i32, &str)> {
1071 match &self.status {
1072 ResponseStatus::Error { code, message } => Some((*code, message)),
1073 _ => None,
1074 }
1075 }
1076}
1077
1078impl Default for RequestContext {
1079 fn default() -> Self {
1080 Self::new()
1081 }
1082}
1083
1084impl fmt::Display for ResponseStatus {
1085 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1086 match self {
1087 Self::Success => write!(f, "Success"),
1088 Self::Error { code, message } => write!(f, "Error({code}: {message})"),
1089 Self::Partial => write!(f, "Partial"),
1090 Self::Cancelled => write!(f, "Cancelled"),
1091 }
1092 }
1093}
1094
/// Client identity together with the mechanism it was obtained from.
///
/// `auth_method()` names the mechanism. `is_authenticated()` is `true` only
/// for `Token` and `Session`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientId {
    /// Value of the `x-client-id` header (`"header"`).
    Header(String),
    /// From a bearer token (`"bearer_token"`). `ClientIdExtractor` stores the
    /// registered client id here, or the raw token when it is not registered.
    Token(String),
    /// From a session cookie (`"session_cookie"`).
    Session(String),
    /// From the `client_id` query parameter (`"query_param"`).
    QueryParam(String),
    /// Hash-derived id from the `User-Agent` header (`"user_agent"`).
    UserAgent(String),
    /// No identifying information (`"anonymous"`).
    Anonymous,
}
1115
1116impl ClientId {
1117 #[must_use]
1119 pub fn as_str(&self) -> &str {
1120 match self {
1121 Self::Header(id)
1122 | Self::Token(id)
1123 | Self::Session(id)
1124 | Self::QueryParam(id)
1125 | Self::UserAgent(id) => id,
1126 Self::Anonymous => "anonymous",
1127 }
1128 }
1129
1130 #[must_use]
1132 pub const fn is_authenticated(&self) -> bool {
1133 matches!(self, Self::Token(_) | Self::Session(_))
1134 }
1135
1136 #[must_use]
1138 pub const fn auth_method(&self) -> &'static str {
1139 match self {
1140 Self::Header(_) => "header",
1141 Self::Token(_) => "bearer_token",
1142 Self::Session(_) => "session_cookie",
1143 Self::QueryParam(_) => "query_param",
1144 Self::UserAgent(_) => "user_agent",
1145 Self::Anonymous => "anonymous",
1146 }
1147 }
1148}
1149
/// Tracks a connected client's activity and authentication state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientSession {
    /// Client identifier.
    pub client_id: String,
    /// Display name, set by `authenticate`.
    pub client_name: Option<String>,
    /// Connection time (UTC); equals `last_activity` right after `new`.
    pub connected_at: DateTime<Utc>,
    /// Time of the most recent activity (UTC), refreshed by `update_activity`.
    pub last_activity: DateTime<Utc>,
    /// Number of `update_activity` calls since creation.
    pub request_count: usize,
    /// Free-form transport name supplied at construction.
    pub transport_type: String,
    /// Set to `true` by `authenticate`.
    pub authenticated: bool,
    /// Client capabilities as raw JSON, set by `set_capabilities`.
    pub capabilities: Option<serde_json::Value>,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
1172
1173impl ClientSession {
1174 #[must_use]
1176 pub fn new(client_id: String, transport_type: String) -> Self {
1177 let now = Utc::now();
1178 Self {
1179 client_id,
1180 client_name: None,
1181 connected_at: now,
1182 last_activity: now,
1183 request_count: 0,
1184 transport_type,
1185 authenticated: false,
1186 capabilities: None,
1187 metadata: HashMap::new(),
1188 }
1189 }
1190
1191 pub fn update_activity(&mut self) {
1193 self.last_activity = Utc::now();
1194 self.request_count += 1;
1195 }
1196
1197 pub fn authenticate(&mut self, client_name: Option<String>) {
1199 self.authenticated = true;
1200 self.client_name = client_name;
1201 }
1202
1203 pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
1205 self.capabilities = Some(capabilities);
1206 }
1207
1208 #[must_use]
1210 pub fn session_duration(&self) -> chrono::Duration {
1211 self.last_activity - self.connected_at
1212 }
1213
1214 #[must_use]
1216 pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
1217 Utc::now() - self.last_activity > idle_threshold
1218 }
1219}
1220
/// Record of a single handled request, e.g. for logging or analytics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestInfo {
    /// When the record was created (UTC).
    pub timestamp: DateTime<Utc>,
    /// Client that made the request.
    pub client_id: String,
    /// Invoked method name.
    pub method_name: String,
    /// Request parameters as raw JSON.
    pub parameters: serde_json::Value,
    /// Response time in milliseconds (unit per field name), set on completion.
    pub response_time_ms: Option<u64>,
    /// `false` until `complete_success` is called.
    pub success: bool,
    /// Error text, set by `complete_error`.
    pub error_message: Option<String>,
    /// HTTP-style status code: 200 / 500 from `complete_*`, or set explicitly.
    pub status_code: Option<u16>,
    /// Free-form metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
1243
1244impl RequestInfo {
1245 #[must_use]
1247 pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
1248 Self {
1249 timestamp: Utc::now(),
1250 client_id,
1251 method_name,
1252 parameters,
1253 response_time_ms: None,
1254 success: false,
1255 error_message: None,
1256 status_code: None,
1257 metadata: HashMap::new(),
1258 }
1259 }
1260
1261 #[must_use]
1263 pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
1264 self.response_time_ms = Some(response_time_ms);
1265 self.success = true;
1266 self.status_code = Some(200);
1267 self
1268 }
1269
1270 #[must_use]
1272 pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
1273 self.response_time_ms = Some(response_time_ms);
1274 self.success = false;
1275 self.error_message = Some(error);
1276 self.status_code = Some(500);
1277 self
1278 }
1279
1280 #[must_use]
1282 pub const fn with_status_code(mut self, code: u16) -> Self {
1283 self.status_code = Some(code);
1284 self
1285 }
1286
1287 #[must_use]
1289 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
1290 self.metadata.insert(key, value);
1291 self
1292 }
1293}
1294
1295#[derive(Debug)]
1297pub struct ClientIdExtractor {
1298 auth_tokens: Arc<dashmap::DashMap<String, String>>,
1300}
1301
1302impl ClientIdExtractor {
1303 #[must_use]
1305 pub fn new() -> Self {
1306 Self {
1307 auth_tokens: Arc::new(dashmap::DashMap::new()),
1308 }
1309 }
1310
1311 pub fn register_token(&self, token: String, client_id: String) {
1313 self.auth_tokens.insert(token, client_id);
1314 }
1315
1316 pub fn revoke_token(&self, token: &str) {
1318 self.auth_tokens.remove(token);
1319 }
1320
1321 #[must_use]
1323 pub fn list_tokens(&self) -> Vec<(String, String)> {
1324 self.auth_tokens
1325 .iter()
1326 .map(|entry| (entry.key().clone(), entry.value().clone()))
1327 .collect()
1328 }
1329
1330 #[must_use]
1332 #[allow(clippy::significant_drop_tightening)]
1333 pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
1334 if let Some(client_id) = headers.get("x-client-id") {
1336 return ClientId::Header(client_id.clone());
1337 }
1338
1339 if let Some(auth) = headers.get("authorization")
1341 && let Some(token) = auth.strip_prefix("Bearer ")
1342 {
1343 let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
1345 if let Some(entry) = token_lookup {
1346 let client_id = entry.value().clone();
1347 drop(entry); return ClientId::Token(client_id);
1349 }
1350 return ClientId::Token(token.to_string());
1352 }
1353
1354 if let Some(cookie) = headers.get("cookie") {
1356 for cookie_part in cookie.split(';') {
1357 let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
1358 if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
1359 return ClientId::Session(parts[1].to_string());
1360 }
1361 }
1362 }
1363
1364 if let Some(user_agent) = headers.get("user-agent") {
1366 use std::collections::hash_map::DefaultHasher;
1367 use std::hash::{Hash, Hasher};
1368 let mut hasher = DefaultHasher::new();
1369 user_agent.hash(&mut hasher);
1370 return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
1371 }
1372
1373 ClientId::Anonymous
1374 }
1375
1376 #[must_use]
1378 pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
1379 query_params
1380 .get("client_id")
1381 .map(|client_id| ClientId::QueryParam(client_id.clone()))
1382 }
1383
1384 #[must_use]
1386 pub fn extract_client_id(
1387 &self,
1388 headers: Option<&HashMap<String, String>>,
1389 query_params: Option<&HashMap<String, String>>,
1390 ) -> ClientId {
1391 if let Some(params) = query_params
1393 && let Some(client_id) = self.extract_from_query(params)
1394 {
1395 return client_id;
1396 }
1397
1398 if let Some(headers) = headers {
1400 return self.extract_from_http_headers(headers);
1401 }
1402
1403 ClientId::Anonymous
1404 }
1405}
1406
1407impl Default for ClientIdExtractor {
1408 fn default() -> Self {
1409 Self::new()
1410 }
1411}
1412
/// Extension methods for attaching a [`ClientId`] to a request context.
pub trait RequestContextExt {
    /// Returns the context with `client_id` attached.
    ///
    /// The `RequestContext` implementation stores the id string plus the
    /// `client_id_method` and `client_authenticated` metadata entries.
    #[must_use]
    fn with_enhanced_client_id(self, client_id: ClientId) -> Self;

    /// Resolves a client id with `extractor` from the optional headers and
    /// query parameters, then attaches it as in `with_enhanced_client_id`.
    #[must_use]
    fn extract_client_id(
        self,
        extractor: &ClientIdExtractor,
        headers: Option<&HashMap<String, String>>,
        query_params: Option<&HashMap<String, String>>,
    ) -> Self;

    /// Rebuilds the [`ClientId`] previously attached to the context, if any.
    fn get_enhanced_client_id(&self) -> Option<ClientId>;
}
1431
1432impl RequestContextExt for RequestContext {
1433 fn with_enhanced_client_id(self, client_id: ClientId) -> Self {
1434 self.with_client_id(client_id.as_str())
1435 .with_metadata("client_id_method", client_id.auth_method())
1436 .with_metadata("client_authenticated", client_id.is_authenticated())
1437 }
1438
1439 fn extract_client_id(
1440 self,
1441 extractor: &ClientIdExtractor,
1442 headers: Option<&HashMap<String, String>>,
1443 query_params: Option<&HashMap<String, String>>,
1444 ) -> Self {
1445 let client_id = extractor.extract_client_id(headers, query_params);
1446 self.with_enhanced_client_id(client_id)
1447 }
1448
1449 fn get_enhanced_client_id(&self) -> Option<ClientId> {
1450 self.client_id.as_ref().map(|id| {
1451 let method = self
1452 .get_metadata("client_id_method")
1453 .and_then(|v| v.as_str())
1454 .unwrap_or("header");
1455
1456 match method {
1457 "bearer_token" => ClientId::Token(id.clone()),
1458 "session_cookie" => ClientId::Session(id.clone()),
1459 "query_param" => ClientId::QueryParam(id.clone()),
1460 "user_agent" => ClientId::UserAgent(id.clone()),
1461 "anonymous" => ClientId::Anonymous,
1462 _ => ClientId::Header(id.clone()), }
1464 })
1465 }
1466}
1467
#[cfg(test)]
mod tests {
    //! Unit tests for request/response contexts, client identification
    //! (`ClientIdExtractor`, `ClientSession`, `RequestInfo`), and the
    //! elicitation, completion, resource-template, ping and bidirectional
    //! context extensions carried on `RequestContext`.

    use super::*;

    /// A fresh context has a non-empty request id, no user, and a start time
    /// captured at construction.
    #[test]
    fn test_request_context_creation() {
        let ctx = RequestContext::new();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.user_id.is_none());
        assert!(ctx.elapsed() < std::time::Duration::from_millis(100));
    }

    /// Builder methods populate user id, session id and metadata.
    #[test]
    fn test_request_context_builder() {
        let ctx = RequestContext::new()
            .with_user_id("user123")
            .with_session_id("session456")
            .with_metadata("key", "value");

        assert_eq!(ctx.user_id, Some("user123".to_string()));
        assert_eq!(ctx.session_id, Some("session456".to_string()));
        assert_eq!(
            ctx.get_metadata("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
    }

    /// `ResponseContext::success` / `::error` set the status predicates and
    /// error info accordingly.
    #[test]
    fn test_response_context_creation() {
        let duration = std::time::Duration::from_millis(100);

        let success_ctx = ResponseContext::success("req1", duration);
        assert!(success_ctx.is_success());
        assert!(!success_ctx.is_error());

        let error_ctx = ResponseContext::error("req2", duration, 500, "Internal error");
        assert!(!error_ctx.is_success());
        assert!(error_ctx.is_error());
        assert_eq!(error_ctx.error_info(), Some((500, "Internal error")));
    }

    /// A derived context gets its own request id but inherits the parent's
    /// user id and metadata.
    #[test]
    fn test_context_derivation() {
        let parent_ctx = RequestContext::new()
            .with_user_id("user123")
            .with_metadata("key", "value");

        let child_ctx = parent_ctx.derive();

        assert_ne!(parent_ctx.request_id, child_ctx.request_id);

        assert_eq!(parent_ctx.user_id, child_ctx.user_id);
        assert_eq!(
            parent_ctx.get_metadata("key"),
            child_ctx.get_metadata("key")
        );
    }

    /// An `x-client-id` header yields an unauthenticated `ClientId::Header`.
    #[test]
    fn test_client_id_extraction() {
        let extractor = ClientIdExtractor::new();

        let mut headers = HashMap::new();
        headers.insert("x-client-id".to_string(), "test-client".to_string());

        let client_id = extractor.extract_from_http_headers(&headers);
        assert_eq!(client_id, ClientId::Header("test-client".to_string()));
        assert_eq!(client_id.as_str(), "test-client");
        assert_eq!(client_id.auth_method(), "header");
        assert!(!client_id.is_authenticated());
    }

    /// A registered bearer token resolves to its mapped client id and counts
    /// as authenticated.
    #[test]
    fn test_bearer_token_extraction() {
        let extractor = ClientIdExtractor::new();
        extractor.register_token("token123".to_string(), "client-1".to_string());

        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), "Bearer token123".to_string());

        let client_id = extractor.extract_from_http_headers(&headers);
        assert_eq!(client_id, ClientId::Token("client-1".to_string()));
        assert!(client_id.is_authenticated());
        assert_eq!(client_id.auth_method(), "bearer_token");
    }

    /// The `session_id` cookie is picked out of a multi-cookie header.
    #[test]
    fn test_session_cookie_extraction() {
        let extractor = ClientIdExtractor::new();

        let mut headers = HashMap::new();
        headers.insert(
            "cookie".to_string(),
            "session_id=sess123; other=value".to_string(),
        );

        let client_id = extractor.extract_from_http_headers(&headers);
        assert_eq!(client_id, ClientId::Session("sess123".to_string()));
        assert!(client_id.is_authenticated());
    }

    /// With only a User-Agent header, extraction falls back to a
    /// `ua_`-prefixed `ClientId::UserAgent`.
    #[test]
    fn test_user_agent_fallback() {
        let extractor = ClientIdExtractor::new();

        let mut headers = HashMap::new();
        headers.insert("user-agent".to_string(), "TestAgent/1.0".to_string());

        // Fail loudly with the variant actually produced, instead of an
        // always-false `matches!` assertion that hides it.
        match extractor.extract_from_http_headers(&headers) {
            ClientId::UserAgent(id) => {
                assert!(id.starts_with("ua_"), "unexpected user-agent id: {id}");
            }
            other => panic!("Expected UserAgent ClientId, got {other:?}"),
        }
    }

    /// Session bookkeeping: request counting, authentication, idleness.
    #[test]
    fn test_client_session() {
        let mut session = ClientSession::new("test-client".to_string(), "http".to_string());
        assert!(!session.authenticated);
        assert_eq!(session.request_count, 0);

        session.update_activity();
        assert_eq!(session.request_count, 1);

        session.authenticate(Some("Test Client".to_string()));
        assert!(session.authenticated);
        assert_eq!(session.client_name, Some("Test Client".to_string()));

        assert!(!session.is_idle(chrono::Duration::seconds(1)));
    }

    /// A new `RequestInfo` is incomplete; `complete_success` records timing
    /// and a 200 status.
    #[test]
    fn test_request_info() {
        let params = serde_json::json!({"param": "value"});
        let request = RequestInfo::new("client-1".to_string(), "test_method".to_string(), params);

        assert!(!request.success);
        assert!(request.response_time_ms.is_none());

        let completed = request.complete_success(150);
        assert!(completed.success);
        assert_eq!(completed.response_time_ms, Some(150));
        assert_eq!(completed.status_code, Some(200));
    }

    /// `extract_client_id` stores the id plus method/authentication metadata,
    /// and `get_enhanced_client_id` reconstructs the typed `ClientId`.
    #[test]
    fn test_request_context_ext() {
        let extractor = ClientIdExtractor::new();

        let mut headers = HashMap::new();
        headers.insert("x-client-id".to_string(), "test-client".to_string());

        let ctx = RequestContext::new().extract_client_id(&extractor, Some(&headers), None);

        assert_eq!(ctx.client_id, Some("test-client".to_string()));
        assert_eq!(
            ctx.get_metadata("client_id_method"),
            Some(&serde_json::Value::String("header".to_string()))
        );
        assert_eq!(
            ctx.get_metadata("client_authenticated"),
            Some(&serde_json::Value::Bool(false))
        );

        let enhanced_id = ctx.get_enhanced_client_id();
        assert_eq!(
            enhanced_id,
            Some(ClientId::Header("test-client".to_string()))
        );
    }

    /// Completed requests can carry arbitrary analytics metadata.
    #[test]
    fn test_request_analytics() {
        let start = std::time::Instant::now();
        let request = RequestInfo::new(
            "client-123".to_string(),
            "get_data".to_string(),
            serde_json::json!({"filter": "active"}),
        );

        let response_time = start.elapsed().as_millis() as u64;
        let completed = request
            .complete_success(response_time)
            .with_metadata("cache_hit".to_string(), serde_json::json!(true));

        assert!(completed.success);
        assert!(completed.response_time_ms.is_some());
        assert_eq!(
            completed.metadata.get("cache_hit"),
            Some(&serde_json::json!(true))
        );
    }

    /// `for_elicitation` builds a server-initiated context with elicitation
    /// defaults and a matching bidirectional context.
    #[test]
    fn test_elicitation_context() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            }
        });

        let ctx = RequestContext::for_elicitation(
            schema.clone(),
            Some("Please enter your details".to_string()),
        );

        let elicit_ctx = ctx.elicitation_context().unwrap();
        assert_eq!(elicit_ctx.schema, schema);
        assert_eq!(elicit_ctx.message, "Please enter your details".to_string());
        assert!(elicit_ctx.required);
        assert!(elicit_ctx.cancellable);
        assert_eq!(elicit_ctx.timeout_ms, Some(30_000));

        assert!(ctx.is_server_initiated());
        assert!(!ctx.is_client_initiated());
        assert_eq!(
            ctx.communication_direction(),
            CommunicationDirection::ServerToClient
        );

        let bi_ctx = ctx.bidirectional_context().unwrap();
        assert_eq!(bi_ctx.direction, CommunicationDirection::ServerToClient);
        assert_eq!(bi_ctx.initiator, CommunicationInitiator::Server);
        assert!(bi_ctx.expects_response);
    }

    /// `for_completion` defaults, and a fully-populated `CompletionContext`
    /// round-trips through `RequestContext`.
    #[test]
    fn test_completion_context() {
        let comp_ref = CompletionReference::Tool {
            name: "test_tool".to_string(),
            argument: "file_path".to_string(),
        };

        let ctx = RequestContext::for_completion(comp_ref.clone());
        let completion_ctx = ctx.completion_context().unwrap();

        assert!(matches!(
            completion_ctx.completion_ref,
            CompletionReference::Tool { .. }
        ));
        assert_eq!(completion_ctx.max_completions, Some(100));
        assert!(completion_ctx.completions.is_empty());

        let completion_option = CompletionOption {
            value: "/home/user/document.txt".to_string(),
            label: Some("document.txt".to_string()),
            completion_type: Some("file".to_string()),
            documentation: Some("A text document".to_string()),
            sort_priority: Some(1),
            insert_text: Some("document.txt".to_string()),
        };

        let mut completion_ctx_with_options = CompletionContext::new(comp_ref);
        completion_ctx_with_options.argument_name = Some("file_path".to_string());
        completion_ctx_with_options.partial_value = Some("/home/user/".to_string());
        completion_ctx_with_options.completions = vec![completion_option];
        completion_ctx_with_options.cursor_position = Some(11);
        completion_ctx_with_options.max_completions = Some(10);

        let ctx_with_options =
            RequestContext::new().with_completion_context(completion_ctx_with_options);
        let retrieved_ctx = ctx_with_options.completion_context().unwrap();

        assert_eq!(retrieved_ctx.argument_name, Some("file_path".to_string()));
        assert_eq!(retrieved_ctx.partial_value, Some("/home/user/".to_string()));
        assert_eq!(retrieved_ctx.completions.len(), 1);
        assert_eq!(
            retrieved_ctx.completions[0].value,
            "/home/user/document.txt"
        );
        assert_eq!(retrieved_ctx.cursor_position, Some(11));
    }

    /// `for_resource_template` defaults, and a template with typed
    /// parameters round-trips through `RequestContext`.
    #[test]
    fn test_resource_template_context() {
        let template_name = "file_system".to_string();
        let uri_template = "file://{path}".to_string();

        let ctx =
            RequestContext::for_resource_template(template_name.clone(), uri_template.clone());
        let template_ctx = ctx.resource_template_context().unwrap();

        assert_eq!(template_ctx.template_name, template_name);
        assert_eq!(template_ctx.uri_template, uri_template);
        assert!(template_ctx.parameters.is_empty());

        let mut parameters = HashMap::new();
        parameters.insert(
            "path".to_string(),
            TemplateParameter {
                name: "path".to_string(),
                param_type: "string".to_string(),
                required: true,
                default: None,
                description: Some("File system path".to_string()),
                pattern: Some(r"^[/\w.-]+$".to_string()),
                enum_values: None,
            },
        );

        let template_ctx_with_params = ResourceTemplateContext {
            template_name: "file_system_detailed".to_string(),
            uri_template: "file://{path}".to_string(),
            parameters,
            description: Some("Access file system resources".to_string()),
            preset_type: Some("file_system".to_string()),
            metadata: HashMap::new(),
        };

        let ctx_with_params =
            RequestContext::new().with_resource_template_context(template_ctx_with_params);
        let retrieved_ctx = ctx_with_params.resource_template_context().unwrap();

        assert_eq!(retrieved_ctx.parameters.len(), 1);
        let path_param = retrieved_ctx.parameters.get("path").unwrap();
        assert_eq!(path_param.param_type, "string");
        assert!(path_param.required);
        assert_eq!(path_param.description, Some("File system path".to_string()));
        assert_eq!(
            retrieved_ctx.description,
            Some("Access file system resources".to_string())
        );
    }

    /// A client-originated ping is client-initiated and flows client→server.
    #[test]
    fn test_ping_context_client_initiated() {
        let ctx = RequestContext::for_ping(PingOrigin::Client);
        let ping_ctx = ctx.ping_context().unwrap();

        assert_eq!(ping_ctx.origin, PingOrigin::Client);
        assert_eq!(ping_ctx.response_threshold_ms, Some(5_000));
        assert!(ping_ctx.payload.is_none());

        assert!(!ctx.is_server_initiated());
        assert!(ctx.is_client_initiated());
        assert_eq!(
            ctx.communication_direction(),
            CommunicationDirection::ClientToServer
        );

        let bi_ctx = ctx.bidirectional_context().unwrap();
        assert_eq!(bi_ctx.initiator, CommunicationInitiator::Client);
    }

    /// A server-originated ping is server-initiated and flows server→client.
    #[test]
    fn test_ping_context_server_initiated() {
        let ctx = RequestContext::for_ping(PingOrigin::Server);
        let ping_ctx = ctx.ping_context().unwrap();

        assert_eq!(ping_ctx.origin, PingOrigin::Server);

        assert!(ctx.is_server_initiated());
        assert!(!ctx.is_client_initiated());
        assert_eq!(
            ctx.communication_direction(),
            CommunicationDirection::ServerToClient
        );

        let bi_ctx = ctx.bidirectional_context().unwrap();
        assert_eq!(bi_ctx.initiator, CommunicationInitiator::Server);
    }

    /// Connection metrics attached to a ping context survive the round trip.
    #[test]
    fn test_connection_metrics() {
        let mut ping_ctx = PingContext {
            origin: PingOrigin::Client,
            response_threshold_ms: Some(1_000),
            payload: Some(serde_json::json!({"test": true})),
            health_metadata: HashMap::new(),
            connection_metrics: None,
        };

        let metrics = ConnectionMetrics {
            rtt_ms: Some(150.5),
            packet_loss: Some(0.1),
            uptime_seconds: Some(3600),
            bytes_sent: Some(1024),
            bytes_received: Some(2048),
            last_success: Some(Utc::now()),
        };

        ping_ctx.connection_metrics = Some(metrics);

        let ctx = RequestContext::new().with_ping_context(ping_ctx);
        let retrieved_ctx = ctx.ping_context().unwrap();
        let conn_metrics = retrieved_ctx.connection_metrics.unwrap();

        // Exact float equality is intentional: values are stored and read
        // back, never computed.
        assert_eq!(conn_metrics.rtt_ms, Some(150.5));
        assert_eq!(conn_metrics.packet_loss, Some(0.1));
        assert_eq!(conn_metrics.uptime_seconds, Some(3600));
        assert_eq!(conn_metrics.bytes_sent, Some(1024));
        assert_eq!(conn_metrics.bytes_received, Some(2048));
        assert!(conn_metrics.last_success.is_some());
    }

    /// A standalone bidirectional context drives the direction/initiator
    /// accessors and keeps its parent request id.
    #[test]
    fn test_bidirectional_context_standalone() {
        let mut bi_ctx = BidirectionalContext::new(
            CommunicationDirection::ServerToClient,
            CommunicationInitiator::Server,
        );
        bi_ctx.expects_response = true;
        bi_ctx.parent_request_id = Some("parent-123".to_string());

        let ctx = RequestContext::new().with_bidirectional_context(bi_ctx.clone());

        assert!(ctx.is_server_initiated());
        assert_eq!(
            ctx.communication_direction(),
            CommunicationDirection::ServerToClient
        );

        let retrieved_ctx = ctx.bidirectional_context().unwrap();
        assert_eq!(
            retrieved_ctx.parent_request_id,
            Some("parent-123".to_string())
        );
        assert_eq!(
            retrieved_ctx.direction,
            CommunicationDirection::ServerToClient
        );
        assert_eq!(retrieved_ctx.initiator, CommunicationInitiator::Server);
        assert!(retrieved_ctx.expects_response);
    }

    /// Every `CompletionReference` variant survives a serde JSON round trip
    /// with identical variant and field values.
    #[test]
    fn test_completion_reference_serialization() {
        let prompt_ref = CompletionReference::Prompt {
            name: "test_prompt".to_string(),
            argument: "user_input".to_string(),
        };

        let template_ref = CompletionReference::ResourceTemplate {
            name: "api_endpoint".to_string(),
            parameter: "api_key".to_string(),
        };

        let tool_ref = CompletionReference::Tool {
            name: "file_reader".to_string(),
            argument: "path".to_string(),
        };

        let custom_ref = CompletionReference::Custom {
            ref_type: "database_query".to_string(),
            metadata: {
                let mut map = HashMap::new();
                map.insert("table".to_string(), serde_json::json!("users"));
                map
            },
        };

        for ref_item in [prompt_ref, template_ref, tool_ref, custom_ref] {
            let serialized = serde_json::to_value(&ref_item).unwrap();
            let deserialized: CompletionReference = serde_json::from_value(serialized).unwrap();

            // Variants must match pairwise; any mismatch falls through to panic.
            match (&ref_item, &deserialized) {
                (
                    CompletionReference::Prompt {
                        name: n1,
                        argument: a1,
                    },
                    CompletionReference::Prompt {
                        name: n2,
                        argument: a2,
                    },
                ) => {
                    assert_eq!(n1, n2);
                    assert_eq!(a1, a2);
                }
                (
                    CompletionReference::ResourceTemplate {
                        name: n1,
                        parameter: p1,
                    },
                    CompletionReference::ResourceTemplate {
                        name: n2,
                        parameter: p2,
                    },
                ) => {
                    assert_eq!(n1, n2);
                    assert_eq!(p1, p2);
                }
                (
                    CompletionReference::Tool {
                        name: n1,
                        argument: a1,
                    },
                    CompletionReference::Tool {
                        name: n2,
                        argument: a2,
                    },
                ) => {
                    assert_eq!(n1, n2);
                    assert_eq!(a1, a2);
                }
                (
                    CompletionReference::Custom {
                        ref_type: t1,
                        metadata: m1,
                    },
                    CompletionReference::Custom {
                        ref_type: t2,
                        metadata: m2,
                    },
                ) => {
                    assert_eq!(t1, t2);
                    // Compare full contents: a length check alone would miss
                    // corrupted keys or values.
                    assert_eq!(m1, m2);
                }
                _ => panic!("Serialization round-trip failed for CompletionReference"),
            }
        }
    }

    /// Several extension contexts and plain metadata can coexist on one
    /// `RequestContext` without clobbering each other.
    #[test]
    fn test_context_metadata_integration() {
        let mut elicit_ctx = ElicitationContext::new(
            "Enter name".to_string(),
            serde_json::json!({"type": "string"}),
        );
        elicit_ctx.required = true;
        elicit_ctx.timeout_ms = Some(30_000);
        elicit_ctx.cancellable = true;

        let ping_ctx = PingContext {
            origin: PingOrigin::Server,
            response_threshold_ms: Some(2_000),
            payload: None,
            health_metadata: HashMap::new(),
            connection_metrics: None,
        };

        let ctx = RequestContext::new()
            .with_elicitation_context(elicit_ctx)
            .with_ping_context(ping_ctx)
            .with_metadata("custom_field", "custom_value");

        assert!(ctx.elicitation_context().is_some());
        assert!(ctx.ping_context().is_some());
        assert_eq!(
            ctx.get_metadata("custom_field"),
            Some(&serde_json::json!("custom_value"))
        );

        let elicit = ctx.elicitation_context().unwrap();
        assert_eq!(elicit.message, "Enter name".to_string());

        let ping = ctx.ping_context().unwrap();
        assert_eq!(ping.response_threshold_ms, Some(2_000));
    }
}