1use std::collections::HashMap;
54use std::fmt;
55use std::sync::Arc;
56use std::time::Instant;
57
58use chrono::{DateTime, Utc};
59use serde::{Deserialize, Serialize};
60use tokio_util::sync::CancellationToken;
61use uuid::Uuid;
62
63use crate::types::Timestamp;
64
/// Per-request state carried through request handling: identity, timing,
/// free-form JSON metadata, and optional tracing/cancellation hooks.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier of this request (a UUID v4 string when built with `new`).
    pub request_id: String,

    /// User identifier, if one has been attached (see `with_user_id`).
    pub user_id: Option<String>,

    /// Session identifier, if one has been attached (see `with_session_id`).
    pub session_id: Option<String>,

    /// Client identifier, if one has been attached (see `with_client_id`).
    pub client_id: Option<String>,

    /// Wall-clock creation time (`Timestamp::now()` in `new`).
    pub timestamp: Timestamp,

    /// Monotonic start instant; `elapsed()` measures from here.
    pub start_time: Instant,

    /// Arbitrary JSON metadata. Shared behind an `Arc` so clones and derived
    /// contexts share the map; builders go through `Arc::make_mut`, which
    /// copies the map only when it is shared.
    pub metadata: Arc<HashMap<String, serde_json::Value>>,

    /// Tracing span for this request (only present with the `tracing` feature).
    #[cfg(feature = "tracing")]
    pub span: Option<tracing::Span>,

    /// Optional shared cancellation token; `is_cancelled()` reports its state.
    pub cancellation_token: Option<Arc<CancellationToken>>,
}
96
/// Outcome record for a handled request: which request it answers, when it
/// was produced, how long processing took, and the final status.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Id of the request this response belongs to.
    pub request_id: String,

    /// Wall-clock time the response context was built (`Timestamp::now()`).
    pub timestamp: Timestamp,

    /// Processing duration, as supplied by the caller of the constructor.
    pub duration: std::time::Duration,

    /// Final status of the request.
    pub status: ResponseStatus,

    /// Arbitrary JSON metadata, shared behind an `Arc`; `with_metadata` uses
    /// `Arc::make_mut`, which copies the map only when it is shared.
    pub metadata: Arc<HashMap<String, serde_json::Value>>,
}
115
/// Final status of a request, as carried by a `ResponseContext`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResponseStatus {
    /// The request completed successfully.
    Success,
    /// The request failed.
    Error {
        /// Numeric error code supplied by the producer of the response.
        code: i32,
        /// Human-readable error description.
        message: String,
    },
    /// Partial result. NOTE(review): none of the `ResponseContext`
    /// constructors shown here produce this variant — confirm its meaning.
    Partial,
    /// The request was cancelled.
    Cancelled,
}
133
/// State of a single elicitation: a server request asking the client/user to
/// supply input matching a JSON schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElicitationContext {
    /// Identifier of this elicitation (a UUID v4 string when built with `new`).
    pub elicitation_id: String,
    /// Message shown to the user describing the requested input.
    pub message: String,
    /// JSON schema the provided input is expected to satisfy.
    pub schema: serde_json::Value,
    /// Legacy alias of `message`; kept only for compatibility.
    #[deprecated(note = "Use message field instead")]
    pub prompt: Option<String>,
    /// Optional extra constraints on the input (free-form JSON).
    pub constraints: Option<serde_json::Value>,
    /// Optional default values keyed by field name.
    pub defaults: Option<HashMap<String, serde_json::Value>>,
    /// Whether input is mandatory (`new` sets `true`).
    pub required: bool,
    /// Timeout in milliseconds (`new` sets `Some(30000)`).
    pub timeout_ms: Option<u64>,
    /// Whether the elicitation may be cancelled (`new` sets `true`).
    pub cancellable: bool,
    /// Session of the client being asked, if attached.
    pub client_session: Option<ClientSession>,
    /// Creation time (`Timestamp::now()` in `new`).
    pub requested_at: Timestamp,
    /// Current lifecycle state (`Pending` when created by `new`).
    pub state: ElicitationState,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
169
/// Lifecycle state of an `ElicitationContext`; every state other than
/// `Pending` counts as complete (see `ElicitationContext::is_complete`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ElicitationState {
    /// Waiting for a reply.
    Pending,
    /// The user accepted and supplied input.
    Accepted,
    /// The user declined to supply input.
    Declined,
    /// The elicitation was cancelled.
    Cancelled,
    /// No reply arrived before the timeout.
    TimedOut,
}
184
185impl ElicitationContext {
186 pub fn new(message: String, schema: serde_json::Value) -> Self {
188 Self {
189 elicitation_id: Uuid::new_v4().to_string(),
190 message,
191 schema,
192 #[allow(deprecated)]
193 prompt: None,
194 constraints: None,
195 defaults: None,
196 required: true,
197 timeout_ms: Some(30000),
198 cancellable: true,
199 client_session: None,
200 requested_at: Timestamp::now(),
201 state: ElicitationState::Pending,
202 metadata: HashMap::new(),
203 }
204 }
205
206 pub fn with_client_session(mut self, session: ClientSession) -> Self {
208 self.client_session = Some(session);
209 self
210 }
211
212 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
214 self.timeout_ms = Some(timeout_ms);
215 self
216 }
217
218 pub fn set_state(&mut self, state: ElicitationState) {
220 self.state = state;
221 }
222
223 pub fn is_complete(&self) -> bool {
225 !matches!(self.state, ElicitationState::Pending)
226 }
227}
228
/// State of a completion (autocomplete) request for a prompt, resource
/// template, tool argument, or custom reference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionContext {
    /// Identifier of this completion request (UUID v4 when built with `new`).
    pub completion_id: String,
    /// What is being completed.
    pub completion_ref: CompletionReference,
    /// Name of the argument being completed, if known.
    pub argument_name: Option<String>,
    /// Text typed so far, if any.
    pub partial_value: Option<String>,
    /// Values of arguments that are already filled in.
    pub resolved_arguments: HashMap<String, String>,
    /// Candidate completions collected so far (see `add_completion`).
    pub completions: Vec<CompletionOption>,
    /// Cursor offset within the partial value, if provided.
    pub cursor_position: Option<usize>,
    /// Maximum number of completions wanted (`new` sets `Some(100)`).
    /// NOTE(review): `add_completion` does not enforce this limit.
    pub max_completions: Option<usize>,
    /// Whether more completions exist beyond those returned.
    pub has_more: bool,
    /// Total number of available completions, if known.
    pub total_completions: Option<usize>,
    /// Completion features supported by the client, if known.
    pub client_capabilities: Option<CompletionCapabilities>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
257
/// Completion features a client reports supporting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionCapabilities {
    /// Client can page through completion results.
    pub supports_pagination: bool,
    /// Client supports fuzzy matching.
    pub supports_fuzzy: bool,
    /// Largest batch of completions the client accepts at once.
    pub max_batch_size: usize,
    /// Client can display per-completion descriptions.
    pub supports_descriptions: bool,
}
270
271impl CompletionContext {
272 pub fn new(completion_ref: CompletionReference) -> Self {
274 Self {
275 completion_id: Uuid::new_v4().to_string(),
276 completion_ref,
277 argument_name: None,
278 partial_value: None,
279 resolved_arguments: HashMap::new(),
280 completions: Vec::new(),
281 cursor_position: None,
282 max_completions: Some(100),
283 has_more: false,
284 total_completions: None,
285 client_capabilities: None,
286 metadata: HashMap::new(),
287 }
288 }
289
290 pub fn add_completion(&mut self, option: CompletionOption) {
292 self.completions.push(option);
293 }
294
295 pub fn with_resolved_arguments(mut self, args: HashMap<String, String>) -> Self {
297 self.resolved_arguments = args;
298 self
299 }
300}
301
/// Identifies the entity whose argument/parameter is being completed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompletionReference {
    /// An argument of a named prompt.
    Prompt {
        /// Prompt name.
        name: String,
        /// Argument being completed.
        argument: String,
    },
    /// A parameter of a named resource template.
    ResourceTemplate {
        /// Template name.
        name: String,
        /// Parameter being completed.
        parameter: String,
    },
    /// An argument of a named tool.
    Tool {
        /// Tool name.
        name: String,
        /// Argument being completed.
        argument: String,
    },
    /// An application-defined reference.
    Custom {
        /// Free-form reference kind.
        ref_type: String,
        /// Arbitrary JSON data describing the reference.
        metadata: HashMap<String, serde_json::Value>,
    },
}
334
/// One completion candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionOption {
    /// The completed value.
    pub value: String,
    /// Optional display label.
    pub label: Option<String>,
    /// Optional free-form kind of completion.
    pub completion_type: Option<String>,
    /// Optional longer description.
    pub documentation: Option<String>,
    /// Optional sort priority. NOTE(review): ordering direction not defined here.
    pub sort_priority: Option<i32>,
    /// Optional text to insert instead of `value`.
    pub insert_text: Option<String>,
}
351
/// Description of a URI resource template and its parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceTemplateContext {
    /// Template name.
    pub template_name: String,
    /// The URI template string.
    pub uri_template: String,
    /// Parameter definitions keyed by parameter name.
    pub parameters: HashMap<String, TemplateParameter>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional free-form preset/category label.
    pub preset_type: Option<String>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
368
/// Definition of one parameter of a resource template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateParameter {
    /// Parameter name.
    pub name: String,
    /// Free-form type name of the parameter.
    pub param_type: String,
    /// Whether the parameter must be supplied.
    pub required: bool,
    /// Optional default value.
    pub default: Option<serde_json::Value>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional validation pattern (presumably a regex — confirm with consumers).
    pub pattern: Option<String>,
    /// Optional list of allowed values.
    pub enum_values: Option<Vec<serde_json::Value>>,
}
387
/// Details of a ping / health-check exchange.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingContext {
    /// Which side sent the ping.
    pub origin: PingOrigin,
    /// Expected response time budget in milliseconds, if any.
    pub response_threshold_ms: Option<u64>,
    /// Optional JSON payload sent with the ping.
    pub payload: Option<serde_json::Value>,
    /// Free-form health information.
    pub health_metadata: HashMap<String, serde_json::Value>,
    /// Connection statistics, if collected.
    pub connection_metrics: Option<ConnectionMetrics>,
}
402
/// Side that originated a ping.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PingOrigin {
    /// The client sent the ping.
    Client,
    /// The server sent the ping.
    Server,
}
411
/// Optional connection statistics attached to a ping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionMetrics {
    /// Round-trip time in milliseconds (per the field name).
    pub rtt_ms: Option<f64>,
    /// Packet loss. NOTE(review): unit (fraction vs. percent) not defined here.
    pub packet_loss: Option<f64>,
    /// Connection uptime in seconds (per the field name).
    pub uptime_seconds: Option<u64>,
    /// Bytes sent over the connection.
    pub bytes_sent: Option<u64>,
    /// Bytes received over the connection.
    pub bytes_received: Option<u64>,
    /// Time of the last successful exchange (presumably a ping — confirm).
    pub last_success: Option<DateTime<Utc>>,
}
428
/// Routing information for a message that may flow client→server or
/// server→client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BidirectionalContext {
    /// Direction the message travels.
    pub direction: CommunicationDirection,
    /// Side that started the exchange.
    pub initiator: CommunicationInitiator,
    /// Whether a response is expected (`new` sets `true`).
    pub expects_response: bool,
    /// Id of the request this one was issued on behalf of, if any.
    pub parent_request_id: Option<String>,
    /// Protocol method name (e.g. `"tools/call"`), checked by `validate_direction`.
    pub request_type: Option<String>,
    /// Server identifier, if known.
    pub server_id: Option<String>,
    /// Correlation id (a UUID v4 string when built with `new`).
    pub correlation_id: String,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
449
450impl BidirectionalContext {
451 pub fn new(direction: CommunicationDirection, initiator: CommunicationInitiator) -> Self {
453 Self {
454 direction,
455 initiator,
456 expects_response: true,
457 parent_request_id: None,
458 request_type: None,
459 server_id: None,
460 correlation_id: Uuid::new_v4().to_string(),
461 metadata: HashMap::new(),
462 }
463 }
464
465 pub fn with_direction(mut self, direction: CommunicationDirection) -> Self {
467 self.direction = direction;
468 self
469 }
470
471 pub fn with_request_type(mut self, request_type: String) -> Self {
473 self.request_type = Some(request_type);
474 self
475 }
476
477 pub fn validate_direction(&self) -> Result<(), String> {
479 if let Some(ref request_type) = self.request_type {
480 match (request_type.as_str(), &self.direction) {
481 ("sampling/createMessage", CommunicationDirection::ServerToClient) => Ok(()),
483 ("roots/list", CommunicationDirection::ServerToClient) => Ok(()),
484 ("elicitation/create", CommunicationDirection::ServerToClient) => Ok(()),
485
486 ("completion/complete", CommunicationDirection::ClientToServer) => Ok(()),
488 ("tools/call", CommunicationDirection::ClientToServer) => Ok(()),
489 ("resources/read", CommunicationDirection::ClientToServer) => Ok(()),
490 ("prompts/get", CommunicationDirection::ClientToServer) => Ok(()),
491
492 ("ping", _) => Ok(()), (req, dir) => Err(format!(
497 "Invalid direction {:?} for request type '{}'",
498 dir, req
499 )),
500 }
501 } else {
502 Ok(()) }
504 }
505}
506
/// Metadata for a request the server sends to the client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInitiatedContext {
    /// Kind of server-initiated request.
    pub request_type: ServerInitiatedType,
    /// Identifier of the issuing server.
    pub server_id: String,
    /// Correlation id (a UUID v4 string when built with `new`).
    pub correlation_id: String,
    /// Client capabilities, if known.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Creation time (`Timestamp::now()` in `new`).
    pub initiated_at: Timestamp,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
523
/// Kinds of request a server may send to a client.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ServerInitiatedType {
    /// Ask the client to create (sample) a message.
    CreateMessage,
    /// Ask the client for its roots.
    ListRoots,
    /// Ask the client/user for structured input.
    Elicitation,
    /// Liveness check.
    Ping,
}
536
/// Features a client reports supporting for server-initiated requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCapabilities {
    /// Client supports sampling requests.
    pub sampling: bool,
    /// Client supports roots listing.
    pub roots: bool,
    /// Client supports elicitation.
    pub elicitation: bool,
    /// Maximum concurrent requests the client accepts.
    pub max_concurrent_requests: usize,
    /// Named experimental feature flags.
    pub experimental: HashMap<String, bool>,
}
551
552impl ServerInitiatedContext {
553 pub fn new(request_type: ServerInitiatedType, server_id: String) -> Self {
555 Self {
556 request_type,
557 server_id,
558 correlation_id: Uuid::new_v4().to_string(),
559 client_capabilities: None,
560 initiated_at: Timestamp::now(),
561 metadata: HashMap::new(),
562 }
563 }
564
565 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
567 self.client_capabilities = Some(capabilities);
568 self
569 }
570
571 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
573 self.metadata.insert(key, value);
574 self
575 }
576}
577
/// Direction in which a message travels.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationDirection {
    /// From the client to the server.
    ClientToServer,
    /// From the server to the client.
    ServerToClient,
}
586
/// Side that started an exchange.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationInitiator {
    /// The client started the exchange.
    Client,
    /// The server started the exchange.
    Server,
}
595
596impl RequestContext {
597 #[must_use]
611 pub fn new() -> Self {
612 Self {
613 request_id: Uuid::new_v4().to_string(),
614 user_id: None,
615 session_id: None,
616 client_id: None,
617 timestamp: Timestamp::now(),
618 start_time: Instant::now(),
619 metadata: Arc::new(HashMap::new()),
620 #[cfg(feature = "tracing")]
621 span: None,
622 cancellation_token: None,
623 }
624 }
625 #[must_use]
640 pub fn is_authenticated(&self) -> bool {
641 self.metadata
642 .get("authenticated")
643 .and_then(serde_json::Value::as_bool)
644 .unwrap_or(false)
645 }
646
647 #[must_use]
649 pub fn user(&self) -> Option<&str> {
650 self.user_id.as_deref()
651 }
652
653 #[must_use]
655 pub fn roles(&self) -> Vec<String> {
656 self.metadata
657 .get("auth")
658 .and_then(|v| v.get("roles"))
659 .and_then(|v| v.as_array())
660 .map(|arr| {
661 arr.iter()
662 .filter_map(|v| v.as_str().map(std::string::ToString::to_string))
663 .collect()
664 })
665 .unwrap_or_default()
666 }
667
668 pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
670 if required.is_empty() {
671 return true;
672 }
673 let user_roles = self.roles();
674 if user_roles.is_empty() {
675 return false;
676 }
677 let set: std::collections::HashSet<_> = user_roles.into_iter().collect();
678 required.iter().any(|r| set.contains(r.as_ref()))
679 }
680
681 pub fn with_id(id: impl Into<String>) -> Self {
683 Self {
684 request_id: id.into(),
685 ..Self::new()
686 }
687 }
688
689 #[must_use]
691 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
692 self.user_id = Some(user_id.into());
693 self
694 }
695
696 #[must_use]
698 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
699 self.session_id = Some(session_id.into());
700 self
701 }
702
703 #[must_use]
705 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
706 self.client_id = Some(client_id.into());
707 self
708 }
709
710 #[must_use]
712 pub fn with_metadata(
713 mut self,
714 key: impl Into<String>,
715 value: impl Into<serde_json::Value>,
716 ) -> Self {
717 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
718 self
719 }
720
721 #[must_use]
723 pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
724 self.cancellation_token = Some(token);
725 self
726 }
727
728 #[must_use]
730 pub fn elapsed(&self) -> std::time::Duration {
731 self.start_time.elapsed()
732 }
733
734 #[must_use]
736 pub fn is_cancelled(&self) -> bool {
737 self.cancellation_token
738 .as_ref()
739 .is_some_and(|token| token.is_cancelled())
740 }
741
742 #[must_use]
744 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
745 self.metadata.get(key)
746 }
747
748 #[must_use]
750 pub fn derive(&self) -> Self {
751 Self {
752 request_id: Uuid::new_v4().to_string(),
753 user_id: self.user_id.clone(),
754 session_id: self.session_id.clone(),
755 client_id: self.client_id.clone(),
756 timestamp: Timestamp::now(),
757 start_time: Instant::now(),
758 metadata: self.metadata.clone(),
759 #[cfg(feature = "tracing")]
760 span: None,
761 cancellation_token: self.cancellation_token.clone(),
762 }
763 }
764
765 #[must_use]
771 pub fn with_elicitation_context(mut self, context: ElicitationContext) -> Self {
772 Arc::make_mut(&mut self.metadata).insert(
773 "elicitation_context".to_string(),
774 serde_json::to_value(context).unwrap_or_default(),
775 );
776 self
777 }
778
779 pub fn elicitation_context(&self) -> Option<ElicitationContext> {
781 self.metadata
782 .get("elicitation_context")
783 .and_then(|v| serde_json::from_value(v.clone()).ok())
784 }
785
786 #[must_use]
788 pub fn with_completion_context(mut self, context: CompletionContext) -> Self {
789 Arc::make_mut(&mut self.metadata).insert(
790 "completion_context".to_string(),
791 serde_json::to_value(context).unwrap_or_default(),
792 );
793 self
794 }
795
796 pub fn completion_context(&self) -> Option<CompletionContext> {
798 self.metadata
799 .get("completion_context")
800 .and_then(|v| serde_json::from_value(v.clone()).ok())
801 }
802
803 #[must_use]
805 pub fn with_resource_template_context(mut self, context: ResourceTemplateContext) -> Self {
806 Arc::make_mut(&mut self.metadata).insert(
807 "resource_template_context".to_string(),
808 serde_json::to_value(context).unwrap_or_default(),
809 );
810 self
811 }
812
813 pub fn resource_template_context(&self) -> Option<ResourceTemplateContext> {
815 self.metadata
816 .get("resource_template_context")
817 .and_then(|v| serde_json::from_value(v.clone()).ok())
818 }
819
820 #[must_use]
822 pub fn with_ping_context(mut self, context: PingContext) -> Self {
823 Arc::make_mut(&mut self.metadata).insert(
824 "ping_context".to_string(),
825 serde_json::to_value(context).unwrap_or_default(),
826 );
827 self
828 }
829
830 pub fn ping_context(&self) -> Option<PingContext> {
832 self.metadata
833 .get("ping_context")
834 .and_then(|v| serde_json::from_value(v.clone()).ok())
835 }
836
837 #[must_use]
839 pub fn with_bidirectional_context(mut self, context: BidirectionalContext) -> Self {
840 Arc::make_mut(&mut self.metadata).insert(
841 "bidirectional_context".to_string(),
842 serde_json::to_value(context).unwrap_or_default(),
843 );
844 self
845 }
846
847 pub fn bidirectional_context(&self) -> Option<BidirectionalContext> {
849 self.metadata
850 .get("bidirectional_context")
851 .and_then(|v| serde_json::from_value(v.clone()).ok())
852 }
853
854 pub fn is_server_initiated(&self) -> bool {
856 self.bidirectional_context()
857 .map(|ctx| ctx.direction == CommunicationDirection::ServerToClient)
858 .unwrap_or(false)
859 }
860
861 pub fn is_client_initiated(&self) -> bool {
863 !self.is_server_initiated()
864 }
865
866 pub fn communication_direction(&self) -> CommunicationDirection {
868 self.bidirectional_context()
869 .map(|ctx| ctx.direction)
870 .unwrap_or(CommunicationDirection::ClientToServer)
871 }
872
873 pub fn for_elicitation(schema: serde_json::Value, prompt: Option<String>) -> Self {
875 let message = prompt.unwrap_or_else(|| "Please provide input".to_string());
876 let elicitation_ctx = ElicitationContext::new(message, schema);
877
878 let bidirectional_ctx = BidirectionalContext::new(
879 CommunicationDirection::ServerToClient,
880 CommunicationInitiator::Server,
881 )
882 .with_request_type("elicitation/create".to_string());
883
884 Self::new()
885 .with_elicitation_context(elicitation_ctx)
886 .with_bidirectional_context(bidirectional_ctx)
887 }
888
889 pub fn for_completion(completion_ref: CompletionReference) -> Self {
891 let completion_ctx = CompletionContext::new(completion_ref);
892
893 Self::new().with_completion_context(completion_ctx)
894 }
895
896 pub fn for_resource_template(template_name: String, uri_template: String) -> Self {
898 let template_ctx = ResourceTemplateContext {
899 template_name,
900 uri_template,
901 parameters: HashMap::new(),
902 description: None,
903 preset_type: None,
904 metadata: HashMap::new(),
905 };
906
907 Self::new().with_resource_template_context(template_ctx)
908 }
909
910 pub fn for_ping(origin: PingOrigin) -> Self {
912 let ping_ctx = PingContext {
913 origin: origin.clone(),
914 response_threshold_ms: Some(5_000), payload: None,
916 health_metadata: HashMap::new(),
917 connection_metrics: None,
918 };
919
920 let direction = match origin {
921 PingOrigin::Client => CommunicationDirection::ClientToServer,
922 PingOrigin::Server => CommunicationDirection::ServerToClient,
923 };
924 let initiator = match origin {
925 PingOrigin::Client => CommunicationInitiator::Client,
926 PingOrigin::Server => CommunicationInitiator::Server,
927 };
928
929 let bidirectional_ctx =
930 BidirectionalContext::new(direction, initiator).with_request_type("ping".to_string());
931
932 Self::new()
933 .with_ping_context(ping_ctx)
934 .with_bidirectional_context(bidirectional_ctx)
935 }
936}
937
938impl ResponseContext {
939 pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
941 Self {
942 request_id: request_id.into(),
943 timestamp: Timestamp::now(),
944 duration,
945 status: ResponseStatus::Success,
946 metadata: Arc::new(HashMap::new()),
947 }
948 }
949
950 pub fn error(
952 request_id: impl Into<String>,
953 duration: std::time::Duration,
954 code: i32,
955 message: impl Into<String>,
956 ) -> Self {
957 Self {
958 request_id: request_id.into(),
959 timestamp: Timestamp::now(),
960 duration,
961 status: ResponseStatus::Error {
962 code,
963 message: message.into(),
964 },
965 metadata: Arc::new(HashMap::new()),
966 }
967 }
968
969 pub fn cancelled(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
971 Self {
972 request_id: request_id.into(),
973 timestamp: Timestamp::now(),
974 duration,
975 status: ResponseStatus::Cancelled,
976 metadata: Arc::new(HashMap::new()),
977 }
978 }
979
980 #[must_use]
982 pub fn with_metadata(
983 mut self,
984 key: impl Into<String>,
985 value: impl Into<serde_json::Value>,
986 ) -> Self {
987 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
988 self
989 }
990
991 #[must_use]
993 pub const fn is_success(&self) -> bool {
994 matches!(self.status, ResponseStatus::Success)
995 }
996
997 #[must_use]
999 pub const fn is_error(&self) -> bool {
1000 matches!(self.status, ResponseStatus::Error { .. })
1001 }
1002
1003 #[must_use]
1005 pub fn error_info(&self) -> Option<(i32, &str)> {
1006 match &self.status {
1007 ResponseStatus::Error { code, message } => Some((*code, message)),
1008 _ => None,
1009 }
1010 }
1011}
1012
1013impl Default for RequestContext {
1014 fn default() -> Self {
1015 Self::new()
1016 }
1017}
1018
1019impl fmt::Display for ResponseStatus {
1020 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1021 match self {
1022 Self::Success => write!(f, "Success"),
1023 Self::Error { code, message } => write!(f, "Error({code}: {message})"),
1024 Self::Partial => write!(f, "Partial"),
1025 Self::Cancelled => write!(f, "Cancelled"),
1026 }
1027 }
1028}
1029
/// How a client was identified, together with the identifier itself.
///
/// `Token` and `Session` count as authenticated (see `is_authenticated`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientId {
    /// Taken from the `x-client-id` header.
    Header(String),
    /// Derived from a bearer token (the registered client id, or the raw
    /// token if it is not registered).
    Token(String),
    /// Taken from a session cookie.
    Session(String),
    /// Taken from the `client_id` query parameter.
    QueryParam(String),
    /// Hashed from the `user-agent` header (`ua_<hex>`).
    UserAgent(String),
    /// No identifying information was available.
    Anonymous,
}
1050
1051impl ClientId {
1052 #[must_use]
1054 pub fn as_str(&self) -> &str {
1055 match self {
1056 Self::Header(id)
1057 | Self::Token(id)
1058 | Self::Session(id)
1059 | Self::QueryParam(id)
1060 | Self::UserAgent(id) => id,
1061 Self::Anonymous => "anonymous",
1062 }
1063 }
1064
1065 #[must_use]
1067 pub const fn is_authenticated(&self) -> bool {
1068 matches!(self, Self::Token(_) | Self::Session(_))
1069 }
1070
1071 #[must_use]
1073 pub const fn auth_method(&self) -> &'static str {
1074 match self {
1075 Self::Header(_) => "header",
1076 Self::Token(_) => "bearer_token",
1077 Self::Session(_) => "session_cookie",
1078 Self::QueryParam(_) => "query_param",
1079 Self::UserAgent(_) => "user_agent",
1080 Self::Anonymous => "anonymous",
1081 }
1082 }
1083}
1084
/// Tracking record for one connected client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientSession {
    /// Client identifier.
    pub client_id: String,
    /// Display name, set via `authenticate`.
    pub client_name: Option<String>,
    /// Connection time (UTC).
    pub connected_at: DateTime<Utc>,
    /// Time of the most recent `update_activity` (or of creation).
    pub last_activity: DateTime<Utc>,
    /// Number of `update_activity` calls so far.
    pub request_count: usize,
    /// Free-form transport name (e.g. `"http"` as used in tests).
    pub transport_type: String,
    /// Whether `authenticate` has been called.
    pub authenticated: bool,
    /// Client-reported capabilities as raw JSON, if set.
    pub capabilities: Option<serde_json::Value>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
1107
1108impl ClientSession {
1109 #[must_use]
1111 pub fn new(client_id: String, transport_type: String) -> Self {
1112 let now = Utc::now();
1113 Self {
1114 client_id,
1115 client_name: None,
1116 connected_at: now,
1117 last_activity: now,
1118 request_count: 0,
1119 transport_type,
1120 authenticated: false,
1121 capabilities: None,
1122 metadata: HashMap::new(),
1123 }
1124 }
1125
1126 pub fn update_activity(&mut self) {
1128 self.last_activity = Utc::now();
1129 self.request_count += 1;
1130 }
1131
1132 pub fn authenticate(&mut self, client_name: Option<String>) {
1134 self.authenticated = true;
1135 self.client_name = client_name;
1136 }
1137
1138 pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
1140 self.capabilities = Some(capabilities);
1141 }
1142
1143 #[must_use]
1145 pub fn session_duration(&self) -> chrono::Duration {
1146 self.last_activity - self.connected_at
1147 }
1148
1149 #[must_use]
1151 pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
1152 Utc::now() - self.last_activity > idle_threshold
1153 }
1154}
1155
/// Log/analytics record for one handled method call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestInfo {
    /// Creation time (UTC).
    pub timestamp: DateTime<Utc>,
    /// Identifier of the calling client.
    pub client_id: String,
    /// Invoked method name.
    pub method_name: String,
    /// Call parameters as raw JSON.
    pub parameters: serde_json::Value,
    /// Response time in milliseconds, once completed.
    pub response_time_ms: Option<u64>,
    /// `true` after `complete_success`, `false` otherwise.
    pub success: bool,
    /// Error text set by `complete_error`.
    pub error_message: Option<String>,
    /// Status code: 200 after `complete_success`, 500 after `complete_error`,
    /// or whatever `with_status_code` set last.
    pub status_code: Option<u16>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
1178
1179impl RequestInfo {
1180 #[must_use]
1182 pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
1183 Self {
1184 timestamp: Utc::now(),
1185 client_id,
1186 method_name,
1187 parameters,
1188 response_time_ms: None,
1189 success: false,
1190 error_message: None,
1191 status_code: None,
1192 metadata: HashMap::new(),
1193 }
1194 }
1195
1196 #[must_use]
1198 pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
1199 self.response_time_ms = Some(response_time_ms);
1200 self.success = true;
1201 self.status_code = Some(200);
1202 self
1203 }
1204
1205 #[must_use]
1207 pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
1208 self.response_time_ms = Some(response_time_ms);
1209 self.success = false;
1210 self.error_message = Some(error);
1211 self.status_code = Some(500);
1212 self
1213 }
1214
1215 #[must_use]
1217 pub const fn with_status_code(mut self, code: u16) -> Self {
1218 self.status_code = Some(code);
1219 self
1220 }
1221
1222 #[must_use]
1224 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
1225 self.metadata.insert(key, value);
1226 self
1227 }
1228}
1229
/// Derives a `ClientId` from HTTP headers and query parameters, using a
/// registry that maps bearer tokens to client ids.
#[derive(Debug)]
pub struct ClientIdExtractor {
    /// Bearer token → client id. A concurrent map so registration and lookup
    /// work through `&self`.
    auth_tokens: Arc<dashmap::DashMap<String, String>>,
}
1236
1237impl ClientIdExtractor {
1238 #[must_use]
1240 pub fn new() -> Self {
1241 Self {
1242 auth_tokens: Arc::new(dashmap::DashMap::new()),
1243 }
1244 }
1245
1246 pub fn register_token(&self, token: String, client_id: String) {
1248 self.auth_tokens.insert(token, client_id);
1249 }
1250
1251 pub fn revoke_token(&self, token: &str) {
1253 self.auth_tokens.remove(token);
1254 }
1255
1256 #[must_use]
1258 pub fn list_tokens(&self) -> Vec<(String, String)> {
1259 self.auth_tokens
1260 .iter()
1261 .map(|entry| (entry.key().clone(), entry.value().clone()))
1262 .collect()
1263 }
1264
1265 #[must_use]
1267 #[allow(clippy::significant_drop_tightening)]
1268 pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
1269 if let Some(client_id) = headers.get("x-client-id") {
1271 return ClientId::Header(client_id.clone());
1272 }
1273
1274 if let Some(auth) = headers.get("authorization")
1276 && let Some(token) = auth.strip_prefix("Bearer ")
1277 {
1278 let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
1280 if let Some(entry) = token_lookup {
1281 let client_id = entry.value().clone();
1282 drop(entry); return ClientId::Token(client_id);
1284 }
1285 return ClientId::Token(token.to_string());
1287 }
1288
1289 if let Some(cookie) = headers.get("cookie") {
1291 for cookie_part in cookie.split(';') {
1292 let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
1293 if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
1294 return ClientId::Session(parts[1].to_string());
1295 }
1296 }
1297 }
1298
1299 if let Some(user_agent) = headers.get("user-agent") {
1301 use std::collections::hash_map::DefaultHasher;
1302 use std::hash::{Hash, Hasher};
1303 let mut hasher = DefaultHasher::new();
1304 user_agent.hash(&mut hasher);
1305 return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
1306 }
1307
1308 ClientId::Anonymous
1309 }
1310
1311 #[must_use]
1313 pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
1314 query_params
1315 .get("client_id")
1316 .map(|client_id| ClientId::QueryParam(client_id.clone()))
1317 }
1318
1319 #[must_use]
1321 pub fn extract_client_id(
1322 &self,
1323 headers: Option<&HashMap<String, String>>,
1324 query_params: Option<&HashMap<String, String>>,
1325 ) -> ClientId {
1326 if let Some(params) = query_params
1328 && let Some(client_id) = self.extract_from_query(params)
1329 {
1330 return client_id;
1331 }
1332
1333 if let Some(headers) = headers {
1335 return self.extract_from_http_headers(headers);
1336 }
1337
1338 ClientId::Anonymous
1339 }
1340}
1341
1342impl Default for ClientIdExtractor {
1343 fn default() -> Self {
1344 Self::new()
1345 }
1346}
1347
/// Extension methods for attaching a [`ClientId`] to a request context and
/// reading it back.
pub trait RequestContextExt {
    /// Returns the context with `client_id` recorded. The `RequestContext`
    /// implementation stores the id string plus the `"client_id_method"` and
    /// `"client_authenticated"` metadata entries.
    #[must_use]
    fn with_enhanced_client_id(self, client_id: ClientId) -> Self;

    /// Uses `extractor` to derive a client id from the optional headers and
    /// query parameters, then records it like `with_enhanced_client_id`.
    #[must_use]
    fn extract_client_id(
        self,
        extractor: &ClientIdExtractor,
        headers: Option<&HashMap<String, String>>,
        query_params: Option<&HashMap<String, String>>,
    ) -> Self;

    /// Reconstructs the recorded [`ClientId`], or `None` if no client id was set.
    fn get_enhanced_client_id(&self) -> Option<ClientId>;
}
1366
1367impl RequestContextExt for RequestContext {
1368 fn with_enhanced_client_id(self, client_id: ClientId) -> Self {
1369 self.with_client_id(client_id.as_str())
1370 .with_metadata("client_id_method", client_id.auth_method())
1371 .with_metadata("client_authenticated", client_id.is_authenticated())
1372 }
1373
1374 fn extract_client_id(
1375 self,
1376 extractor: &ClientIdExtractor,
1377 headers: Option<&HashMap<String, String>>,
1378 query_params: Option<&HashMap<String, String>>,
1379 ) -> Self {
1380 let client_id = extractor.extract_client_id(headers, query_params);
1381 self.with_enhanced_client_id(client_id)
1382 }
1383
1384 fn get_enhanced_client_id(&self) -> Option<ClientId> {
1385 self.client_id.as_ref().map(|id| {
1386 let method = self
1387 .get_metadata("client_id_method")
1388 .and_then(|v| v.as_str())
1389 .unwrap_or("header");
1390
1391 match method {
1392 "bearer_token" => ClientId::Token(id.clone()),
1393 "session_cookie" => ClientId::Session(id.clone()),
1394 "query_param" => ClientId::QueryParam(id.clone()),
1395 "user_agent" => ClientId::UserAgent(id.clone()),
1396 "anonymous" => ClientId::Anonymous,
1397 _ => ClientId::Header(id.clone()), }
1399 })
1400 }
1401}
1402
1403#[cfg(test)]
1404mod tests {
1405 use super::*;
1406
1407 #[test]
1408 fn test_request_context_creation() {
1409 let ctx = RequestContext::new();
1410 assert!(!ctx.request_id.is_empty());
1411 assert!(ctx.user_id.is_none());
1412 assert!(ctx.elapsed() < std::time::Duration::from_millis(100));
1413 }
1414
1415 #[test]
1416 fn test_request_context_builder() {
1417 let ctx = RequestContext::new()
1418 .with_user_id("user123")
1419 .with_session_id("session456")
1420 .with_metadata("key", "value");
1421
1422 assert_eq!(ctx.user_id, Some("user123".to_string()));
1423 assert_eq!(ctx.session_id, Some("session456".to_string()));
1424 assert_eq!(
1425 ctx.get_metadata("key"),
1426 Some(&serde_json::Value::String("value".to_string()))
1427 );
1428 }
1429
1430 #[test]
1431 fn test_response_context_creation() {
1432 let duration = std::time::Duration::from_millis(100);
1433
1434 let success_ctx = ResponseContext::success("req1", duration);
1435 assert!(success_ctx.is_success());
1436 assert!(!success_ctx.is_error());
1437
1438 let error_ctx = ResponseContext::error("req2", duration, 500, "Internal error");
1439 assert!(!error_ctx.is_success());
1440 assert!(error_ctx.is_error());
1441 assert_eq!(error_ctx.error_info(), Some((500, "Internal error")));
1442 }
1443
1444 #[test]
1445 fn test_context_derivation() {
1446 let parent_ctx = RequestContext::new()
1447 .with_user_id("user123")
1448 .with_metadata("key", "value");
1449
1450 let child_ctx = parent_ctx.derive();
1451
1452 assert_ne!(parent_ctx.request_id, child_ctx.request_id);
1454
1455 assert_eq!(parent_ctx.user_id, child_ctx.user_id);
1457 assert_eq!(
1458 parent_ctx.get_metadata("key"),
1459 child_ctx.get_metadata("key")
1460 );
1461 }
1462
1463 #[test]
1466 fn test_client_id_extraction() {
1467 let extractor = ClientIdExtractor::new();
1468
1469 let mut headers = HashMap::new();
1471 headers.insert("x-client-id".to_string(), "test-client".to_string());
1472
1473 let client_id = extractor.extract_from_http_headers(&headers);
1474 assert_eq!(client_id, ClientId::Header("test-client".to_string()));
1475 assert_eq!(client_id.as_str(), "test-client");
1476 assert_eq!(client_id.auth_method(), "header");
1477 assert!(!client_id.is_authenticated());
1478 }
1479
1480 #[test]
1481 fn test_bearer_token_extraction() {
1482 let extractor = ClientIdExtractor::new();
1483 extractor.register_token("token123".to_string(), "client-1".to_string());
1484
1485 let mut headers = HashMap::new();
1486 headers.insert("authorization".to_string(), "Bearer token123".to_string());
1487
1488 let client_id = extractor.extract_from_http_headers(&headers);
1489 assert_eq!(client_id, ClientId::Token("client-1".to_string()));
1490 assert!(client_id.is_authenticated());
1491 assert_eq!(client_id.auth_method(), "bearer_token");
1492 }
1493
/// The `session_id` entry of a multi-valued `cookie` header is extracted as an
/// authenticated `ClientId::Session`; other cookies are ignored.
#[test]
fn test_session_cookie_extraction() {
    let extractor = ClientIdExtractor::new();
    let cookie_header = "session_id=sess123; other=value".to_string();
    let headers = HashMap::from([("cookie".to_string(), cookie_header)]);

    let extracted = extractor.extract_from_http_headers(&headers);

    assert_eq!(extracted, ClientId::Session("sess123".to_string()));
    assert!(extracted.is_authenticated());
}
1508
/// With only a `user-agent` header present, extraction yields a
/// `ClientId::UserAgent` whose id starts with the `ua_` prefix.
#[test]
fn test_user_agent_fallback() {
    let extractor = ClientIdExtractor::new();
    let headers = HashMap::from([("user-agent".to_string(), "TestAgent/1.0".to_string())]);

    let client_id = extractor.extract_from_http_headers(&headers);

    // Match exhaustively so a wrong variant fails with the value actually
    // produced, rather than via an `assert!(matches!(..))` that can only ever
    // be false in the non-matching branch.
    match client_id {
        ClientId::UserAgent(id) => {
            assert!(id.starts_with("ua_"), "unexpected user-agent client id: {id}");
        }
        other => panic!("Expected UserAgent ClientId, got {other:?}"),
    }
}
1527
/// Exercises the `ClientSession` lifecycle: creation, activity tracking,
/// authentication, and the idle check.
#[test]
fn test_client_session() {
    let mut session = ClientSession::new("test-client".to_string(), "http".to_string());

    // Freshly created: unauthenticated, no requests counted yet.
    assert!(!session.authenticated);
    assert_eq!(session.request_count, 0);

    // `update_activity` increments the request counter.
    session.update_activity();
    assert_eq!(session.request_count, 1);

    // `authenticate` flips the flag and records the supplied client name.
    session.authenticate(Some("Test Client".to_string()));
    assert!(session.authenticated);
    assert_eq!(session.client_name, Some("Test Client".to_string()));

    // Just-active session must not count as idle under a 1s threshold.
    let idle_threshold = chrono::Duration::seconds(1);
    assert!(!session.is_idle(idle_threshold));
}
1544
/// A new `RequestInfo` starts unsuccessful with no timing; `complete_success`
/// marks it successful and records the response time and a 200 status code.
#[test]
fn test_request_info() {
    let pending = RequestInfo::new(
        "client-1".to_string(),
        "test_method".to_string(),
        serde_json::json!({"param": "value"}),
    );

    assert!(!pending.success);
    assert!(pending.response_time_ms.is_none());

    let done = pending.complete_success(150);

    assert!(done.success);
    assert_eq!(done.response_time_ms, Some(150));
    assert_eq!(done.status_code, Some(200));
}
1558
/// `extract_client_id` stores the extracted id on the context, records the
/// extraction method and authentication state as metadata, and the enhanced
/// `ClientId` can be recovered afterwards.
#[test]
fn test_request_context_ext() {
    let extractor = ClientIdExtractor::new();
    let headers = HashMap::from([("x-client-id".to_string(), "test-client".to_string())]);

    let ctx = RequestContext::new().extract_client_id(&extractor, Some(&headers), None);

    assert_eq!(ctx.client_id, Some("test-client".to_string()));

    let expected_method = serde_json::Value::String("header".to_string());
    assert_eq!(ctx.get_metadata("client_id_method"), Some(&expected_method));

    let expected_authenticated = serde_json::Value::Bool(false);
    assert_eq!(
        ctx.get_metadata("client_authenticated"),
        Some(&expected_authenticated)
    );

    assert_eq!(
        ctx.get_enhanced_client_id(),
        Some(ClientId::Header("test-client".to_string()))
    );
}
1584
/// A request completed with a measured response time keeps that timing and
/// any metadata attached after completion.
#[test]
fn test_request_analytics() {
    let timer = std::time::Instant::now();
    let params = serde_json::json!({"filter": "active"});
    let request = RequestInfo::new("client-123".to_string(), "get_data".to_string(), params);

    let elapsed_ms = timer.elapsed().as_millis() as u64;
    let cache_hit = serde_json::json!(true);
    let completed = request
        .complete_success(elapsed_ms)
        .with_metadata("cache_hit".to_string(), cache_hit);

    assert!(completed.success);
    assert!(completed.response_time_ms.is_some());
    assert_eq!(
        completed.metadata.get("cache_hit"),
        Some(&serde_json::json!(true))
    );
}
1606
/// `for_elicitation` builds a server-initiated context whose elicitation
/// payload carries the schema and message with the defaults asserted below,
/// plus a server-to-client bidirectional context that expects a response.
#[test]
fn test_elicitation_context() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"}
        }
    });
    let prompt = Some("Please enter your details".to_string());
    let ctx = RequestContext::for_elicitation(schema.clone(), prompt);

    // Elicitation payload and its defaults.
    let elicitation = ctx.elicitation_context().unwrap();
    assert_eq!(elicitation.schema, schema);
    assert_eq!(elicitation.message, "Please enter your details".to_string());
    assert!(elicitation.required);
    assert!(elicitation.cancellable);
    assert_eq!(elicitation.timeout_ms, Some(30_000));

    // Direction: the server initiates and the client responds.
    assert!(ctx.is_server_initiated());
    assert!(!ctx.is_client_initiated());
    assert_eq!(
        ctx.communication_direction(),
        CommunicationDirection::ServerToClient
    );

    let bidirectional = ctx.bidirectional_context().unwrap();
    assert_eq!(bidirectional.direction, CommunicationDirection::ServerToClient);
    assert_eq!(bidirectional.initiator, CommunicationInitiator::Server);
    assert!(bidirectional.expects_response);
}
1647
/// Checks the defaults produced by `for_completion`, then that a fully
/// populated `CompletionContext` survives `with_completion_context`.
#[test]
fn test_completion_context() {
    let tool_ref = CompletionReference::Tool {
        name: "test_tool".to_string(),
        argument: "file_path".to_string(),
    };

    // Defaults from `for_completion`.
    let default_ctx = RequestContext::for_completion(tool_ref.clone());
    let defaults = default_ctx.completion_context().unwrap();
    assert!(matches!(
        defaults.completion_ref,
        CompletionReference::Tool { .. }
    ));
    assert_eq!(defaults.max_completions, Some(100));
    assert!(defaults.completions.is_empty());

    // Fully populated context.
    let document = CompletionOption {
        value: "/home/user/document.txt".to_string(),
        label: Some("document.txt".to_string()),
        completion_type: Some("file".to_string()),
        documentation: Some("A text document".to_string()),
        sort_priority: Some(1),
        insert_text: Some("document.txt".to_string()),
    };
    let mut populated = CompletionContext::new(tool_ref);
    populated.argument_name = Some("file_path".to_string());
    populated.partial_value = Some("/home/user/".to_string());
    populated.completions = vec![document];
    populated.cursor_position = Some(11);
    populated.max_completions = Some(10);

    let populated_ctx = RequestContext::new().with_completion_context(populated);
    let retrieved = populated_ctx.completion_context().unwrap();

    assert_eq!(retrieved.argument_name, Some("file_path".to_string()));
    assert_eq!(retrieved.partial_value, Some("/home/user/".to_string()));
    assert_eq!(retrieved.completions.len(), 1);
    assert_eq!(retrieved.completions[0].value, "/home/user/document.txt");
    assert_eq!(retrieved.cursor_position, Some(11));
}
1695
/// `for_resource_template` records the template name and URI template with no
/// parameters; a hand-built `ResourceTemplateContext` with a parameter map is
/// retrievable intact via `with_resource_template_context`.
#[test]
fn test_resource_template_context() {
    let template_name = String::from("file_system");
    let uri_template = String::from("file://{path}");

    let basic_ctx =
        RequestContext::for_resource_template(template_name.clone(), uri_template.clone());
    let basic = basic_ctx.resource_template_context().unwrap();
    assert_eq!(basic.template_name, template_name);
    assert_eq!(basic.uri_template, uri_template);
    assert!(basic.parameters.is_empty());

    let path_parameter = TemplateParameter {
        name: "path".to_string(),
        param_type: "string".to_string(),
        required: true,
        default: None,
        description: Some("File system path".to_string()),
        pattern: Some(r"^[/\w.-]+$".to_string()),
        enum_values: None,
    };
    let detailed = ResourceTemplateContext {
        template_name: "file_system_detailed".to_string(),
        uri_template: "file://{path}".to_string(),
        parameters: HashMap::from([("path".to_string(), path_parameter)]),
        description: Some("Access file system resources".to_string()),
        preset_type: Some("file_system".to_string()),
        metadata: HashMap::new(),
    };

    let detailed_ctx = RequestContext::new().with_resource_template_context(detailed);
    let retrieved = detailed_ctx.resource_template_context().unwrap();

    assert_eq!(retrieved.parameters.len(), 1);
    let path = retrieved.parameters.get("path").unwrap();
    assert_eq!(path.param_type, "string");
    assert!(path.required);
    assert_eq!(path.description, Some("File system path".to_string()));
    assert_eq!(
        retrieved.description,
        Some("Access file system resources".to_string())
    );
}
1747
/// A client-originated ping uses a 5s response threshold and no payload, and
/// the context reports a client-to-server, client-initiated exchange.
#[test]
fn test_ping_context_client_initiated() {
    let ctx = RequestContext::for_ping(PingOrigin::Client);

    let ping = ctx.ping_context().unwrap();
    assert_eq!(ping.origin, PingOrigin::Client);
    assert_eq!(ping.response_threshold_ms, Some(5_000));
    assert!(ping.payload.is_none());

    assert!(!ctx.is_server_initiated());
    assert!(ctx.is_client_initiated());
    assert_eq!(
        ctx.communication_direction(),
        CommunicationDirection::ClientToServer
    );

    let bidirectional = ctx.bidirectional_context().unwrap();
    assert_eq!(bidirectional.initiator, CommunicationInitiator::Client);
}
1768
/// A server-originated ping makes the context report a server-to-client,
/// server-initiated exchange.
#[test]
fn test_ping_context_server_initiated() {
    let ctx = RequestContext::for_ping(PingOrigin::Server);

    let origin = ctx.ping_context().unwrap().origin;
    assert_eq!(origin, PingOrigin::Server);

    assert!(ctx.is_server_initiated());
    assert!(!ctx.is_client_initiated());
    assert_eq!(
        ctx.communication_direction(),
        CommunicationDirection::ServerToClient
    );

    let initiator = ctx.bidirectional_context().unwrap().initiator;
    assert_eq!(initiator, CommunicationInitiator::Server);
}
1787
/// Connection metrics attached to a `PingContext` are preserved when the ping
/// context is stored on and read back from a `RequestContext`.
#[test]
fn test_connection_metrics() {
    let metrics = ConnectionMetrics {
        rtt_ms: Some(150.5),
        packet_loss: Some(0.1),
        uptime_seconds: Some(3600),
        bytes_sent: Some(1024),
        bytes_received: Some(2048),
        last_success: Some(Utc::now()),
    };
    // Build the ping context with its metrics directly instead of patching
    // the field in afterwards.
    let ping = PingContext {
        origin: PingOrigin::Client,
        response_threshold_ms: Some(1_000),
        payload: Some(serde_json::json!({"test": true})),
        health_metadata: HashMap::new(),
        connection_metrics: Some(metrics),
    };

    let ctx = RequestContext::new().with_ping_context(ping);
    let stored = ctx.ping_context().unwrap();
    let roundtripped = stored.connection_metrics.unwrap();

    assert_eq!(roundtripped.rtt_ms, Some(150.5));
    assert_eq!(roundtripped.packet_loss, Some(0.1));
    assert_eq!(roundtripped.uptime_seconds, Some(3600));
    assert_eq!(roundtripped.bytes_sent, Some(1024));
    assert_eq!(roundtripped.bytes_received, Some(2048));
    assert!(roundtripped.last_success.is_some());
}
1821
/// A `BidirectionalContext` attached via `with_bidirectional_context` drives
/// `is_server_initiated` / `communication_direction`, and every field set on
/// it (direction, initiator, `expects_response`, `parent_request_id`) is
/// retrievable afterwards.
#[test]
fn test_bidirectional_context_standalone() {
    let mut bi_ctx = BidirectionalContext::new(
        CommunicationDirection::ServerToClient,
        CommunicationInitiator::Server,
    );
    bi_ctx.expects_response = true;
    bi_ctx.parent_request_id = Some("parent-123".to_string());

    // `bi_ctx` is not used after this point, so move it rather than cloning.
    let ctx = RequestContext::new().with_bidirectional_context(bi_ctx);

    assert!(ctx.is_server_initiated());
    assert_eq!(
        ctx.communication_direction(),
        CommunicationDirection::ServerToClient
    );

    let retrieved_ctx = ctx.bidirectional_context().unwrap();
    assert_eq!(
        retrieved_ctx.parent_request_id,
        Some("parent-123".to_string())
    );
    assert_eq!(
        retrieved_ctx.direction,
        CommunicationDirection::ServerToClient
    );
    assert_eq!(retrieved_ctx.initiator, CommunicationInitiator::Server);
    assert!(retrieved_ctx.expects_response);
}
1851
/// Every `CompletionReference` variant must survive a JSON round trip
/// (`serde_json::to_value` followed by `serde_json::from_value`) as the same
/// variant with identical field values.
#[test]
fn test_completion_reference_serialization() {
    let refs = [
        CompletionReference::Prompt {
            name: "test_prompt".to_string(),
            argument: "user_input".to_string(),
        },
        CompletionReference::ResourceTemplate {
            name: "api_endpoint".to_string(),
            parameter: "api_key".to_string(),
        },
        CompletionReference::Tool {
            name: "file_reader".to_string(),
            argument: "path".to_string(),
        },
        CompletionReference::Custom {
            ref_type: "database_query".to_string(),
            metadata: HashMap::from([("table".to_string(), serde_json::json!("users"))]),
        },
    ];

    for ref_item in &refs {
        let serialized = serde_json::to_value(ref_item).unwrap();
        let deserialized: CompletionReference = serde_json::from_value(serialized).unwrap();

        match (ref_item, &deserialized) {
            (
                CompletionReference::Prompt {
                    name: n1,
                    argument: a1,
                },
                CompletionReference::Prompt {
                    name: n2,
                    argument: a2,
                },
            ) => {
                assert_eq!(n1, n2);
                assert_eq!(a1, a2);
            }
            (
                CompletionReference::ResourceTemplate {
                    name: n1,
                    parameter: p1,
                },
                CompletionReference::ResourceTemplate {
                    name: n2,
                    parameter: p2,
                },
            ) => {
                assert_eq!(n1, n2);
                assert_eq!(p1, p2);
            }
            (
                CompletionReference::Tool {
                    name: n1,
                    argument: a1,
                },
                CompletionReference::Tool {
                    name: n2,
                    argument: a2,
                },
            ) => {
                assert_eq!(n1, n2);
                assert_eq!(a1, a2);
            }
            (
                CompletionReference::Custom {
                    ref_type: t1,
                    metadata: m1,
                },
                CompletionReference::Custom {
                    ref_type: t2,
                    metadata: m2,
                },
            ) => {
                assert_eq!(t1, t2);
                // Compare the full maps: a length check alone would accept
                // metadata whose values were corrupted by the round trip.
                assert_eq!(m1, m2);
            }
            _ => panic!("Serialization round-trip failed for CompletionReference"),
        }
    }
}
1941
/// An elicitation context, a ping context and a plain metadata entry can all
/// be attached to one `RequestContext` and read back independently.
#[test]
fn test_context_metadata_integration() {
    let mut elicitation = ElicitationContext::new(
        "Enter name".to_string(),
        serde_json::json!({"type": "string"}),
    );
    elicitation.required = true;
    elicitation.timeout_ms = Some(30_000);
    elicitation.cancellable = true;

    let ping = PingContext {
        origin: PingOrigin::Server,
        response_threshold_ms: Some(2_000),
        payload: None,
        health_metadata: HashMap::new(),
        connection_metrics: None,
    };

    let ctx = RequestContext::new()
        .with_elicitation_context(elicitation)
        .with_ping_context(ping)
        .with_metadata("custom_field", "custom_value");

    // Each unwrap doubles as the presence check for that context.
    let stored_elicitation = ctx.elicitation_context().unwrap();
    let stored_ping = ctx.ping_context().unwrap();
    assert_eq!(
        ctx.get_metadata("custom_field"),
        Some(&serde_json::json!("custom_value"))
    );

    assert_eq!(stored_elicitation.message, "Enter name".to_string());
    assert_eq!(stored_ping.response_threshold_ms, Some(2_000));
}
1981}