1use dashmap::DashMap;
19use std::collections::{HashMap, HashSet, VecDeque};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use vellaveto_config::ToolManifest;
23use vellaveto_mcp::memory_tracking::MemoryTracker;
24use vellaveto_mcp::rug_pull::ToolAnnotations;
25use vellaveto_types::AgentIdentity;
26
/// Alias for [`ToolAnnotations`], used as the value type accepted by
/// `SessionState::insert_known_tool`.
pub type ToolAnnotationsCompact = ToolAnnotations;
29
/// Mutable state tracked for a single client session.
///
/// Created by `SessionState::new`; `SessionStore` keys each instance by its
/// `session_id`. Several collections are `pub(crate)` and are meant to be grown
/// through the bounded helpers on `SessionState` (`insert_known_tool`,
/// `insert_flagged_tool`, `insert_backend_session`, `insert_gateway_tools`,
/// `insert_granted_policy`, `record_discovered_tools`), which enforce the
/// module's `MAX_*` capacity constants.
#[derive(Debug)]
pub struct SessionState {
    /// Session identifier; `SessionStore` uses the same string as its map key.
    pub session_id: String,
    /// Creation time; compared against the optional absolute lifetime in `is_expired`.
    pub created_at: Instant,
    /// Time of the most recent `touch`; drives the inactivity timeout and
    /// oldest-session eviction in `SessionStore`.
    pub last_activity: Instant,
    /// Protocol version string recorded for the session (tests use values such
    /// as `"2025-11-25"`); `None` until set.
    pub protocol_version: Option<String>,
    /// Tool name -> annotations; capped at `MAX_KNOWN_TOOLS` via `insert_known_tool`.
    pub(crate) known_tools: HashMap<String, ToolAnnotations>,
    /// Number of `touch` calls (saturating).
    pub request_count: u64,
    /// Starts `false`. NOTE(review): presumably set once a tools list has been
    /// observed for this session — confirm with callers.
    pub tools_list_seen: bool,
    /// OAuth subject, if any; exported as `agent_id` by `to_evaluation_context`.
    pub oauth_subject: Option<String>,
    /// Names of flagged tools; capped at `MAX_FLAGGED_TOOLS` via `insert_flagged_tool`.
    pub(crate) flagged_tools: HashSet<String>,
    /// Optional tool manifest associated with the session.
    /// NOTE(review): semantics of "pinned" are defined by callers — confirm.
    pub pinned_manifest: Option<ToolManifest>,
    /// Per-tool call counters, exposed through `RequestContext::call_counts`.
    pub(crate) call_counts: HashMap<String, u64>,
    /// Ordered action history, exposed through `RequestContext::previous_actions`.
    pub(crate) action_history: VecDeque<String>,
    /// Per-session memory tracker from `vellaveto_mcp::memory_tracking`.
    pub memory_tracker: MemoryTracker,
    /// Counter starting at 0. NOTE(review): presumably counts elicitation requests — confirm.
    pub elicitation_count: u32,
    /// Counter starting at 0. NOTE(review): presumably counts sampling requests — confirm.
    pub sampling_count: u32,
    /// String -> String map of in-flight calls.
    /// NOTE(review): key/value meaning (e.g. request id -> tool name) not visible here — confirm.
    pub(crate) pending_tool_calls: HashMap<String, String>,
    /// Token expiry as whole seconds since the Unix epoch; `is_expired` treats
    /// the session as expired once the wall clock reaches this value.
    pub token_expires_at: Option<u64>,
    /// Current call chain, exposed through `RequestContext::call_chain`.
    pub current_call_chain: Vec<vellaveto_types::CallChainEntry>,
    /// Identity of the calling agent, if established.
    pub agent_identity: Option<AgentIdentity>,
    /// Backend ID -> upstream session ID; capped at `MAX_BACKEND_SESSIONS`
    /// via `insert_backend_session`.
    pub(crate) backend_sessions: HashMap<String, String>,
    /// Backend ID -> tool names; at most `MAX_GATEWAY_TOOLS` backends and
    /// `MAX_TOOLS_PER_BACKEND` tools each via `insert_gateway_tools`.
    pub(crate) gateway_tools: HashMap<String, Vec<String>>,
    /// Optional risk score, exposed through `RequestContext::risk_score`.
    pub risk_score: Option<vellaveto_types::RiskScore>,
    /// Deduplicated ABAC policy IDs in insertion order; capped at
    /// `MAX_GRANTED_POLICIES` via `insert_granted_policy`.
    pub(crate) abac_granted_policies: Vec<String>,
    /// Tool ID -> discovery record. `record_discovered_tools` caps this at
    /// `MAX_DISCOVERED_TOOLS_PER_SESSION`.
    /// NOTE(review): the field is fully `pub`, so direct inserts bypass that cap.
    pub discovered_tools: HashMap<String, DiscoveredToolSession>,
}
114
/// Ceiling on `SessionState::discovered_tools` enforced by
/// `record_discovered_tools`; expired entries are evicted before new ones are dropped.
const MAX_DISCOVERED_TOOLS_PER_SESSION: usize = 10_000;

/// Maximum number of distinct backends tracked in `SessionState::backend_sessions`.
const MAX_BACKEND_SESSIONS: usize = 128;
/// Maximum number of distinct backends tracked in `SessionState::gateway_tools`.
const MAX_GATEWAY_TOOLS: usize = 128;
/// Tool names kept per backend in `gateway_tools`; longer lists are truncated.
const MAX_TOOLS_PER_BACKEND: usize = 1_000;

/// Maximum number of ABAC policy IDs held in `abac_granted_policies`.
const MAX_GRANTED_POLICIES: usize = 1_024;

/// Maximum number of tool names tracked in `known_tools`.
const MAX_KNOWN_TOOLS: usize = 2_048;

/// Maximum number of tool names tracked in `flagged_tools`.
const MAX_FLAGGED_TOOLS: usize = 2_048;
134
/// A tool recorded via `SessionState::record_discovered_tools`, with its own
/// time-to-live.
#[derive(Debug, Clone)]
pub struct DiscoveredToolSession {
    /// Tool identifier (tests use `"server:tool"`-style IDs); also the key in
    /// `SessionState::discovered_tools`.
    pub tool_id: String,
    /// When the tool was (re)recorded; rediscovery resets it.
    pub discovered_at: Instant,
    /// How long after `discovered_at` the record stays valid (see `is_expired`).
    pub ttl: Duration,
    /// Set by `SessionState::mark_tool_used`; reset to `false` on rediscovery.
    pub used: bool,
}
147
148impl DiscoveredToolSession {
149 pub fn is_expired(&self) -> bool {
151 self.discovered_at.elapsed() > self.ttl
152 }
153}
154
155impl SessionState {
156 pub fn new(session_id: String) -> Self {
157 let now = Instant::now();
158 Self {
159 session_id,
160 created_at: now,
161 last_activity: now,
162 protocol_version: None,
163 known_tools: HashMap::new(),
164 request_count: 0,
165 tools_list_seen: false,
166 oauth_subject: None,
167 flagged_tools: HashSet::new(),
168 pinned_manifest: None,
169 call_counts: HashMap::new(),
170 action_history: VecDeque::new(),
171 memory_tracker: MemoryTracker::new(),
172 elicitation_count: 0,
173 sampling_count: 0,
174 pending_tool_calls: HashMap::new(),
175 token_expires_at: None,
176 current_call_chain: Vec::new(),
177 agent_identity: None,
178 backend_sessions: HashMap::new(),
179 gateway_tools: HashMap::new(),
180 risk_score: None,
181 abac_granted_policies: Vec::new(),
182 discovered_tools: HashMap::new(),
183 }
184 }
185
186 pub fn known_tools(&self) -> &HashMap<String, ToolAnnotations> {
194 &self.known_tools
195 }
196
197 pub fn flagged_tools(&self) -> &HashSet<String> {
199 &self.flagged_tools
200 }
201
202 pub fn backend_sessions(&self) -> &HashMap<String, String> {
204 &self.backend_sessions
205 }
206
207 pub fn gateway_tools(&self) -> &HashMap<String, Vec<String>> {
209 &self.gateway_tools
210 }
211
212 pub fn abac_granted_policies(&self) -> &[String] {
214 &self.abac_granted_policies
215 }
216
217 #[allow(clippy::map_entry)] pub fn insert_backend_session(
221 &mut self,
222 backend_id: String,
223 upstream_session_id: String,
224 ) -> bool {
225 if self.backend_sessions.contains_key(&backend_id) {
226 self.backend_sessions
227 .insert(backend_id, upstream_session_id);
228 return true;
229 }
230 if self.backend_sessions.len() >= MAX_BACKEND_SESSIONS {
231 tracing::warn!(
232 session_id = %self.session_id,
233 capacity = MAX_BACKEND_SESSIONS,
234 "Backend sessions capacity reached; dropping new entry"
235 );
236 return false;
237 }
238 self.backend_sessions
239 .insert(backend_id, upstream_session_id);
240 true
241 }
242
243 pub fn insert_gateway_tools(&mut self, backend_id: String, tools: Vec<String>) -> bool {
246 if !self.gateway_tools.contains_key(&backend_id)
247 && self.gateway_tools.len() >= MAX_GATEWAY_TOOLS
248 {
249 tracing::warn!(
250 session_id = %self.session_id,
251 capacity = MAX_GATEWAY_TOOLS,
252 "Gateway tools capacity reached; dropping new backend entry"
253 );
254 return false;
255 }
256 let bounded_tools: Vec<String> = tools.into_iter().take(MAX_TOOLS_PER_BACKEND).collect();
258 self.gateway_tools.insert(backend_id, bounded_tools);
259 true
260 }
261
262 pub fn insert_granted_policy(&mut self, policy_id: String) {
264 if !self.abac_granted_policies.contains(&policy_id)
265 && self.abac_granted_policies.len() < MAX_GRANTED_POLICIES
266 {
267 self.abac_granted_policies.push(policy_id);
268 }
269 }
270
271 #[allow(clippy::map_entry)] pub fn insert_known_tool(&mut self, name: String, annotations: ToolAnnotationsCompact) -> bool {
275 if self.known_tools.contains_key(&name) {
276 self.known_tools.insert(name, annotations);
277 return true;
278 }
279 if self.known_tools.len() >= MAX_KNOWN_TOOLS {
280 tracing::warn!(
281 session_id = %self.session_id,
282 capacity = MAX_KNOWN_TOOLS,
283 "Known tools capacity reached; dropping new tool"
284 );
285 return false;
286 }
287 self.known_tools.insert(name, annotations);
288 true
289 }
290
291 pub fn insert_flagged_tool(&mut self, name: String) {
293 if self.flagged_tools.len() < MAX_FLAGGED_TOOLS {
294 self.flagged_tools.insert(name);
295 }
296 }
297
298 pub fn record_discovered_tools(&mut self, tool_ids: &[String], ttl: Duration) {
304 let now = Instant::now();
305 for tool_id in tool_ids {
306 if !self.discovered_tools.contains_key(tool_id) {
308 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
309 self.evict_expired_discoveries();
311 }
312 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
313 tracing::warn!(
314 session_id = %self.session_id,
315 capacity = MAX_DISCOVERED_TOOLS_PER_SESSION,
316 "Discovered tools capacity reached; dropping new tool"
317 );
318 continue;
319 }
320 }
321 self.discovered_tools.insert(
322 tool_id.clone(),
323 DiscoveredToolSession {
324 tool_id: tool_id.clone(),
325 discovered_at: now,
326 ttl,
327 used: false,
328 },
329 );
330 }
331 }
332
333 pub fn is_tool_discovery_expired(&self, tool_id: &str) -> Option<bool> {
339 self.discovered_tools.get(tool_id).map(|d| d.is_expired())
340 }
341
342 pub fn mark_tool_used(&mut self, tool_id: &str) -> bool {
346 if let Some(entry) = self.discovered_tools.get_mut(tool_id) {
347 entry.used = true;
348 true
349 } else {
350 false
351 }
352 }
353
354 pub fn evict_expired_discoveries(&mut self) -> usize {
358 let before = self.discovered_tools.len();
359 self.discovered_tools.retain(|_, d| !d.is_expired());
360 before - self.discovered_tools.len()
361 }
362
363 pub fn touch(&mut self) {
365 self.last_activity = Instant::now();
366 self.request_count = self.request_count.saturating_add(1);
368 }
369
370 pub fn is_expired(&self, timeout: Duration, max_lifetime: Option<Duration>) -> bool {
376 if self.last_activity.elapsed() > timeout {
377 return true;
378 }
379 if let Some(max) = max_lifetime {
380 if self.created_at.elapsed() > max {
381 return true;
382 }
383 }
384 if let Some(exp) = self.token_expires_at {
385 let now = std::time::SystemTime::now()
386 .duration_since(std::time::UNIX_EPOCH)
387 .unwrap_or_default()
388 .as_secs();
389 if now >= exp {
390 return true;
391 }
392 }
393 false
394 }
395}
396
397use vellaveto_types::identity::RequestContext;
402
403pub struct StatefulContext<'a> {
418 session: &'a SessionState,
419 previous_actions_cache: std::sync::OnceLock<Vec<String>>,
422}
423
424impl<'a> StatefulContext<'a> {
425 pub fn new(session: &'a SessionState) -> Self {
427 Self {
428 session,
429 previous_actions_cache: std::sync::OnceLock::new(),
430 }
431 }
432}
433
434impl RequestContext for StatefulContext<'_> {
435 fn call_counts(&self) -> &HashMap<String, u64> {
436 &self.session.call_counts
437 }
438
439 fn previous_actions(&self) -> &[String] {
440 self.previous_actions_cache
441 .get_or_init(|| self.session.action_history.iter().cloned().collect())
442 }
443
444 fn call_chain(&self) -> &[vellaveto_types::CallChainEntry] {
445 &self.session.current_call_chain
446 }
447
448 fn agent_identity(&self) -> Option<&AgentIdentity> {
449 self.session.agent_identity.as_ref()
450 }
451
452 fn session_guard_state(&self) -> Option<&str> {
453 None }
455
456 fn risk_score(&self) -> Option<&vellaveto_types::RiskScore> {
457 self.session.risk_score.as_ref()
458 }
459
460 fn to_evaluation_context(&self) -> vellaveto_types::EvaluationContext {
461 vellaveto_types::EvaluationContext {
462 agent_id: self.session.oauth_subject.clone(),
463 agent_identity: self.session.agent_identity.clone(),
464 call_counts: self.session.call_counts.clone(),
465 previous_actions: self.session.action_history.iter().cloned().collect(),
466 call_chain: self.session.current_call_chain.clone(),
467 session_state: None,
468 ..Default::default()
469 }
470 }
471}
472
/// Longest client-supplied session ID, in bytes, that `SessionStore::get_or_create`
/// will look up; longer IDs are ignored and a fresh session is issued instead.
const MAX_SESSION_ID_LEN: usize = 0x80;
477
478pub struct SessionStore {
480 sessions: Arc<DashMap<String, SessionState>>,
481 session_timeout: Duration,
482 max_sessions: usize,
483 max_lifetime: Option<Duration>,
487}
488
489impl SessionStore {
490 pub fn new(session_timeout: Duration, max_sessions: usize) -> Self {
491 Self {
492 sessions: Arc::new(DashMap::new()),
493 session_timeout,
494 max_sessions,
495 max_lifetime: None,
496 }
497 }
498
499 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
502 self.max_lifetime = Some(lifetime);
503 self
504 }
505
506 pub fn get_or_create(&self, client_session_id: Option<&str>) -> String {
512 let client_session_id = client_session_id.filter(|id| id.len() <= MAX_SESSION_ID_LEN);
515
516 if let Some(id) = client_session_id {
518 if let Some(mut session) = self.sessions.get_mut(id) {
519 if !session.is_expired(self.session_timeout, self.max_lifetime) {
520 session.touch();
521 return id.to_string();
522 }
523 drop(session);
525 self.sessions.remove(id);
526 }
527 }
528
529 if self.sessions.len() >= self.max_sessions {
536 self.evict_expired();
537 if self.sessions.len() >= self.max_sessions {
539 self.evict_oldest();
540 }
541 }
542
543 let session_id = uuid::Uuid::new_v4().to_string();
545 self.sessions
546 .insert(session_id.clone(), SessionState::new(session_id.clone()));
547 session_id
548 }
549
550 pub fn get_mut(
552 &self,
553 session_id: &str,
554 ) -> Option<dashmap::mapref::one::RefMut<'_, String, SessionState>> {
555 self.sessions.get_mut(session_id)
556 }
557
558 pub fn evict_expired(&self) {
560 self.sessions
561 .retain(|_, session| !session.is_expired(self.session_timeout, self.max_lifetime));
562 }
563
564 fn evict_oldest(&self) {
566 let oldest = self
567 .sessions
568 .iter()
569 .min_by_key(|entry| entry.value().last_activity)
570 .map(|entry| entry.key().clone());
571
572 if let Some(id) = oldest {
573 self.sessions.remove(&id);
574 }
575 }
576
577 pub fn len(&self) -> usize {
579 self.sessions.len()
580 }
581
582 pub fn is_empty(&self) -> bool {
584 self.sessions.is_empty()
585 }
586
587 pub fn remove(&self, session_id: &str) -> bool {
589 self.sessions.remove(session_id).is_some()
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        // New IDs are hyphenated UUIDv4 strings.
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_reuse() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id1 = store.get_or_create(None);
        let id2 = store.get_or_create(Some(&id1));
        assert_eq!(id1, id2);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_unknown_id_creates_new() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some("nonexistent-id"));
        assert_ne!(id, "nonexistent-id");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_max_sessions_enforced() {
        let store = SessionStore::new(Duration::from_secs(300), 3);
        store.get_or_create(None);
        store.get_or_create(None);
        store.get_or_create(None);
        assert_eq!(store.len(), 3);
        // A fourth session forces eviction of the least recently active one.
        store.get_or_create(None);
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_session_remove() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        assert!(store.remove(&id));
        assert_eq!(store.len(), 0);
        assert!(!store.remove(&id));
    }

    #[test]
    fn test_session_touch_increments_count() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        // Resuming the session touches it once.
        store.get_or_create(Some(&id));
        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.request_count, 1);
    }

    #[test]
    fn test_flagged_tools_insert_and_contains() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session.flagged_tools.insert("evil_tool".to_string());
            session.flagged_tools.insert("suspicious_tool".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert!(session.flagged_tools.contains("evil_tool"));
        assert!(session.flagged_tools.contains("suspicious_tool"));
        assert!(!session.flagged_tools.contains("safe_tool"));
        assert_eq!(session.flagged_tools.len(), 2);
    }

    #[test]
    fn test_flagged_tools_empty_by_default() {
        let state = SessionState::new("test-session".to_string());
        assert!(state.flagged_tools.is_empty());
        assert!(state.pending_tool_calls.is_empty());
    }

    #[test]
    fn test_oauth_subject_storage() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let session = store.get_mut(&id).unwrap();
            assert!(session.oauth_subject.is_none());
        }

        {
            let mut session = store.get_mut(&id).unwrap();
            session.oauth_subject = Some("user-42".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.oauth_subject.as_deref(), Some("user-42"));
    }

    #[test]
    fn test_protocol_version_tracking() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let session = store.get_mut(&id).unwrap();
            assert!(session.protocol_version.is_none());
        }

        {
            let mut session = store.get_mut(&id).unwrap();
            session.protocol_version = Some("2025-11-25".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.protocol_version.as_deref(), Some("2025-11-25"));
    }

    #[test]
    fn test_known_tools_mutations() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session.known_tools.insert(
                "read_file".to_string(),
                ToolAnnotations {
                    read_only_hint: true,
                    destructive_hint: false,
                    idempotent_hint: true,
                    open_world_hint: false,
                    input_schema_hash: None,
                },
            );
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.known_tools.len(), 1);
        let ann = session.known_tools.get("read_file").unwrap();
        assert!(ann.read_only_hint);
        assert!(!ann.destructive_hint);
    }

    #[test]
    fn test_tool_annotations_default() {
        let ann = ToolAnnotations::default();
        assert!(!ann.read_only_hint);
        assert!(ann.destructive_hint);
        assert!(!ann.idempotent_hint);
        assert!(ann.open_world_hint);
    }

    #[test]
    fn test_tool_annotations_equality() {
        let a = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let b = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let c = ToolAnnotations::default();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_tools_list_seen_flag() {
        let state = SessionState::new("test".to_string());
        assert!(!state.tools_list_seen);
    }

    #[test]
    fn test_inactivity_expiry_preserved() {
        let mut state = SessionState::new("test-inactivity".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
        // Backdate rather than relying on a non-zero `elapsed()` right after
        // creation, which is not guaranteed on coarse monotonic clocks.
        state.last_activity = Instant::now() - Duration::from_secs(1);
        assert!(state.is_expired(Duration::from_millis(500), None));
    }

    #[test]
    fn test_absolute_lifetime_enforced() {
        let mut state = SessionState::new("test-lifetime".to_string());
        // Backdate creation only; `last_activity` stays fresh so the inactivity
        // timeout is not what triggers expiry.
        state.created_at = Instant::now() - Duration::from_secs(1);
        assert!(state.is_expired(Duration::from_secs(300), Some(Duration::from_millis(500))));
        assert!(!state.is_expired(Duration::from_secs(300), Some(Duration::from_secs(86400))));
    }

    #[test]
    fn test_none_max_lifetime_no_absolute_limit() {
        let state = SessionState::new("test-no-limit".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
    }

    #[test]
    fn test_eviction_checks_both_timeouts() {
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_millis(500));

        let id = store.get_or_create(None);
        assert_eq!(store.len(), 1);

        // The guard is a temporary, so it is dropped before `evict_expired` runs.
        store.get_mut(&id).expect("session was just created").created_at =
            Instant::now() - Duration::from_secs(1);

        store.evict_expired();
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_with_max_lifetime_builder() {
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_secs(86400));
        let id = store.get_or_create(None);
        assert_eq!(store.len(), 1);
        let id2 = store.get_or_create(Some(&id));
        assert_eq!(id, id2);
    }

    #[test]
    fn test_session_id_at_max_length_accepted() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let long_id = "a".repeat(MAX_SESSION_ID_LEN);
        // Unknown ID of the maximum length: looked up, not found, so a new UUID is issued.
        let id = store.get_or_create(Some(&long_id));
        assert_ne!(id, long_id);
        assert_eq!(store.len(), 1);

        // Once a session with that ID exists, it is reused.
        store
            .sessions
            .insert(long_id.clone(), SessionState::new(long_id.clone()));
        let reused = store.get_or_create(Some(&long_id));
        assert_eq!(reused, long_id);
    }

    #[test]
    fn test_session_id_exceeding_max_length_rejected() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let too_long = "b".repeat(MAX_SESSION_ID_LEN + 1);
        store
            .sessions
            .insert(too_long.clone(), SessionState::new(too_long.clone()));

        let id = store.get_or_create(Some(&too_long));
        assert_ne!(id, too_long, "Oversized session ID must not be reused");
        assert_eq!(id.len(), 36, "Should return a UUID-format session ID");
    }

    #[test]
    fn test_session_id_empty_string_accepted() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some(""));
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_session_id_exactly_128_chars_boundary() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let exact = "x".repeat(128);
        let id = store.get_or_create(Some(&exact));
        assert_eq!(id.len(), 36);

        let one_over = "x".repeat(129);
        let id2 = store.get_or_create(Some(&one_over));
        assert_eq!(id2.len(), 36);
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_stateful_context_implements_trait() {
        let session = SessionState::new("test-ctx".to_string());
        let ctx = StatefulContext::new(&session);

        let _: &dyn RequestContext = &ctx;
        assert!(ctx.call_counts().is_empty());
        assert!(ctx.previous_actions().is_empty());
        assert!(ctx.call_chain().is_empty());
        assert!(ctx.agent_identity().is_none());
        assert!(ctx.session_guard_state().is_none());
        assert!(ctx.risk_score().is_none());
    }

    #[test]
    fn test_stateful_context_call_counts() {
        let mut session = SessionState::new("test-counts".to_string());
        session.call_counts.insert("read_file".to_string(), 5);
        session.call_counts.insert("write_file".to_string(), 3);

        let ctx = StatefulContext::new(&session);
        assert_eq!(ctx.call_counts().len(), 2);
        assert_eq!(ctx.call_counts()["read_file"], 5);
        assert_eq!(ctx.call_counts()["write_file"], 3);
    }

    #[test]
    fn test_stateful_context_previous_actions() {
        let mut session = SessionState::new("test-actions".to_string());
        session.action_history.push_back("read_file".to_string());
        session.action_history.push_back("write_file".to_string());
        session.action_history.push_back("execute".to_string());

        let ctx = StatefulContext::new(&session);
        let actions = ctx.previous_actions();
        assert_eq!(actions.len(), 3);
        assert_eq!(actions[0], "read_file");
        assert_eq!(actions[1], "write_file");
        assert_eq!(actions[2], "execute");
    }

    #[test]
    fn test_discovered_tools_empty_by_default() {
        let state = SessionState::new("test".to_string());
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_record_discovered_tools() {
        let mut state = SessionState::new("test".to_string());
        let tools = vec![
            "server:read_file".to_string(),
            "server:write_file".to_string(),
        ];
        state.record_discovered_tools(&tools, Duration::from_secs(300));

        assert_eq!(state.discovered_tools.len(), 2);
        assert!(state.discovered_tools.contains_key("server:read_file"));
        assert!(state.discovered_tools.contains_key("server:write_file"));
    }

    #[test]
    fn test_record_discovered_tools_sets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(60));
        assert!(!entry.used);
    }

    #[test]
    fn test_record_discovered_tools_rediscovery_resets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        state.mark_tool_used("server:tool1");
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);

        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(120));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(120));
        // Rediscovery also clears the `used` flag.
        assert!(!entry.used);
    }

    #[test]
    fn test_is_tool_discovery_expired_unknown_tool() {
        let state = SessionState::new("test".to_string());
        assert_eq!(state.is_tool_discovery_expired("unknown:tool"), None);
    }

    #[test]
    fn test_is_tool_discovery_expired_fresh_tool() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(false));
    }

    #[test]
    fn test_is_tool_discovery_expired_zero_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.discovered_tools.insert(
            "server:tool1".to_string(),
            DiscoveredToolSession {
                tool_id: "server:tool1".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(1),
                ttl: Duration::from_nanos(0),
                used: false,
            },
        );
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(true));
    }

    #[test]
    fn test_mark_tool_used_existing() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert!(!state.discovered_tools.get("server:tool1").unwrap().used);

        assert!(state.mark_tool_used("server:tool1"));
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);
    }

    #[test]
    fn test_mark_tool_used_nonexistent() {
        let mut state = SessionState::new("test".to_string());
        assert!(!state.mark_tool_used("unknown:tool"));
    }

    #[test]
    fn test_evict_expired_discoveries_none_expired() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(
            &["server:tool1".to_string(), "server:tool2".to_string()],
            Duration::from_secs(300),
        );
        assert_eq!(state.evict_expired_discoveries(), 0);
        assert_eq!(state.discovered_tools.len(), 2);
    }

    #[test]
    fn test_evict_expired_discoveries_some_expired() {
        let mut state = SessionState::new("test".to_string());

        state.record_discovered_tools(&["server:fresh".to_string()], Duration::from_secs(300));

        state.discovered_tools.insert(
            "server:stale".to_string(),
            DiscoveredToolSession {
                tool_id: "server:stale".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(10),
                ttl: Duration::from_secs(1),
                used: true,
            },
        );

        assert_eq!(state.evict_expired_discoveries(), 1);
        assert_eq!(state.discovered_tools.len(), 1);
        assert!(state.discovered_tools.contains_key("server:fresh"));
        assert!(!state.discovered_tools.contains_key("server:stale"));
    }

    #[test]
    fn test_evict_expired_discoveries_all_expired() {
        let mut state = SessionState::new("test".to_string());
        let past = Instant::now() - Duration::from_secs(10);
        for i in 0..5 {
            state.discovered_tools.insert(
                format!("server:tool{}", i),
                DiscoveredToolSession {
                    tool_id: format!("server:tool{}", i),
                    discovered_at: past,
                    ttl: Duration::from_secs(1),
                    used: false,
                },
            );
        }

        assert_eq!(state.evict_expired_discoveries(), 5);
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_discovered_tool_session_is_expired() {
        let fresh = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now(),
            ttl: Duration::from_secs(300),
            used: false,
        };
        assert!(!fresh.is_expired());

        let stale = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now() - Duration::from_secs(10),
            ttl: Duration::from_secs(1),
            used: false,
        };
        assert!(stale.is_expired());
    }

    #[test]
    fn test_discovered_tools_survive_session_touch() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session
                .record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        }

        store.get_or_create(Some(&id));

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.discovered_tools.len(), 1);
        assert!(session.discovered_tools.contains_key("server:tool1"));
    }

    #[test]
    fn test_multiple_tools_independent_ttl() {
        let mut state = SessionState::new("test".to_string());

        state.discovered_tools.insert(
            "server:short".to_string(),
            DiscoveredToolSession {
                tool_id: "server:short".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(5),
                ttl: Duration::from_secs(1),
                used: false,
            },
        );

        state.record_discovered_tools(&["server:long".to_string()], Duration::from_secs(3600));

        assert_eq!(state.is_tool_discovery_expired("server:short"), Some(true));
        assert_eq!(state.is_tool_discovery_expired("server:long"), Some(false));
    }

    #[test]
    fn test_evaluation_context_from_stateful() {
        let mut session = SessionState::new("test-eval".to_string());
        session.oauth_subject = Some("user-42".to_string());
        session.call_counts.insert("tool_a".to_string(), 7);
        session.action_history.push_back("tool_a".to_string());
        session.agent_identity = Some(AgentIdentity {
            issuer: Some("test-issuer".to_string()),
            subject: Some("agent-sub".to_string()),
            ..Default::default()
        });

        let ctx = StatefulContext::new(&session);
        let eval = ctx.to_evaluation_context();

        assert_eq!(eval.agent_id.as_deref(), Some("user-42"));
        assert_eq!(eval.call_counts["tool_a"], 7);
        assert_eq!(eval.previous_actions, vec!["tool_a".to_string()]);
        assert_eq!(
            eval.agent_identity.as_ref().unwrap().issuer.as_deref(),
            Some("test-issuer")
        );
    }

    #[test]
    fn test_insert_backend_session_capacity_and_update() {
        let mut state = SessionState::new("test".to_string());
        for i in 0..MAX_BACKEND_SESSIONS {
            assert!(state.insert_backend_session(format!("backend-{}", i), format!("up-{}", i)));
        }
        assert!(!state.insert_backend_session("overflow".to_string(), "x".to_string()));
        // Updating an existing backend is still allowed at capacity.
        assert!(state.insert_backend_session("backend-0".to_string(), "updated".to_string()));
        assert_eq!(state.backend_sessions().len(), MAX_BACKEND_SESSIONS);
        assert_eq!(state.backend_sessions()["backend-0"], "updated");
        assert!(!state.backend_sessions().contains_key("overflow"));
    }

    #[test]
    fn test_insert_gateway_tools_truncates_per_backend() {
        let mut state = SessionState::new("test".to_string());
        let tools: Vec<String> = (0..MAX_TOOLS_PER_BACKEND + 5)
            .map(|i| format!("tool{}", i))
            .collect();
        assert!(state.insert_gateway_tools("backend".to_string(), tools));
        let stored = &state.gateway_tools()["backend"];
        assert_eq!(stored.len(), MAX_TOOLS_PER_BACKEND);
        assert_eq!(stored[0], "tool0");
    }

    #[test]
    fn test_insert_granted_policy_dedups() {
        let mut state = SessionState::new("test".to_string());
        state.insert_granted_policy("p1".to_string());
        state.insert_granted_policy("p1".to_string());
        state.insert_granted_policy("p2".to_string());
        assert_eq!(
            state.abac_granted_policies().to_vec(),
            vec!["p1".to_string(), "p2".to_string()]
        );
    }

    #[test]
    fn test_insert_known_tool_capacity_and_update() {
        let mut state = SessionState::new("test".to_string());
        for i in 0..MAX_KNOWN_TOOLS {
            assert!(state.insert_known_tool(format!("tool{}", i), ToolAnnotations::default()));
        }
        assert!(!state.insert_known_tool("overflow".to_string(), ToolAnnotations::default()));
        // Replacing an existing entry is still allowed at capacity.
        assert!(state.insert_known_tool(
            "tool0".to_string(),
            ToolAnnotations {
                read_only_hint: true,
                destructive_hint: false,
                idempotent_hint: true,
                open_world_hint: false,
                input_schema_hash: None,
            },
        ));
        assert_eq!(state.known_tools().len(), MAX_KNOWN_TOOLS);
        assert!(state.known_tools()["tool0"].read_only_hint);
        assert!(!state.known_tools().contains_key("overflow"));
    }

    #[test]
    fn test_insert_flagged_tool_capacity() {
        let mut state = SessionState::new("test".to_string());
        for i in 0..MAX_FLAGGED_TOOLS {
            state.insert_flagged_tool(format!("tool{}", i));
        }
        state.insert_flagged_tool("overflow".to_string());
        // Re-flagging an existing name at capacity is a harmless no-op.
        state.insert_flagged_tool("tool0".to_string());
        assert_eq!(state.flagged_tools().len(), MAX_FLAGGED_TOOLS);
        assert!(!state.flagged_tools().contains("overflow"));
        assert!(state.flagged_tools().contains("tool0"));
    }

    #[test]
    fn test_record_discovered_tools_at_capacity_evicts_expired() {
        let mut state = SessionState::new("test".to_string());
        let past = Instant::now() - Duration::from_secs(10);
        for i in 0..MAX_DISCOVERED_TOOLS_PER_SESSION {
            state.discovered_tools.insert(
                format!("server:old{}", i),
                DiscoveredToolSession {
                    tool_id: format!("server:old{}", i),
                    discovered_at: past,
                    ttl: Duration::from_secs(1),
                    used: false,
                },
            );
        }
        state.record_discovered_tools(&["server:new".to_string()], Duration::from_secs(300));
        assert_eq!(state.discovered_tools.len(), 1);
        assert!(state.discovered_tools.contains_key("server:new"));
    }

    #[test]
    fn test_record_discovered_tools_at_capacity_drops_new_keeps_refresh() {
        let mut state = SessionState::new("test".to_string());
        let fresh: Vec<String> = (0..MAX_DISCOVERED_TOOLS_PER_SESSION)
            .map(|i| format!("server:tool{}", i))
            .collect();
        state.record_discovered_tools(&fresh, Duration::from_secs(300));
        state.record_discovered_tools(
            &["server:overflow1".to_string(), "server:overflow2".to_string()],
            Duration::from_secs(300),
        );
        assert_eq!(state.discovered_tools.len(), MAX_DISCOVERED_TOOLS_PER_SESSION);
        assert!(!state.discovered_tools.contains_key("server:overflow1"));
        assert!(!state.discovered_tools.contains_key("server:overflow2"));

        // Rediscovering an existing tool at capacity still refreshes it.
        state.record_discovered_tools(&["server:tool0".to_string()], Duration::from_secs(60));
        assert_eq!(state.discovered_tools["server:tool0"].ttl, Duration::from_secs(60));
    }
}