1use dashmap::DashMap;
19use std::collections::{HashMap, HashSet, VecDeque};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use vellaveto_config::ToolManifest;
23use vellaveto_mcp::memory_tracking::MemoryTracker;
24use vellaveto_mcp::rug_pull::ToolAnnotations;
25use vellaveto_types::AgentIdentity;
26
/// Alias of [`ToolAnnotations`]; it is the annotation type accepted by
/// `SessionState::insert_known_tool`.
pub type ToolAnnotationsCompact = ToolAnnotations;
29
/// Per-session state tracked across requests.
///
/// Several collections are capped by this module's `MAX_*` constants; the caps are
/// enforced by the `insert_*` / `record_*` methods, so prefer those over mutating the
/// collections directly.
#[derive(Debug)]
pub struct SessionState {
    /// Identifier of this session (also attached to capacity warnings).
    pub session_id: String,
    /// Instant the state was created; basis for the absolute-lifetime check in `is_expired`.
    pub created_at: Instant,
    /// Instant of creation or of the most recent `touch`; basis for the inactivity check.
    pub last_activity: Instant,
    /// Protocol version recorded for this session, if any (not set by methods in this file).
    pub protocol_version: Option<String>,
    /// Tool name -> annotations. Use `insert_known_tool` to respect `MAX_KNOWN_TOOLS`.
    pub(crate) known_tools: HashMap<String, ToolAnnotations>,
    /// Number of `touch` calls (saturating at `u64::MAX`).
    pub request_count: u64,
    /// Starts `false`. NOTE(review): presumably set once a tools list has been seen — confirm with callers.
    pub tools_list_seen: bool,
    /// OAuth subject; exported as `agent_id` by `StatefulContext::to_evaluation_context`.
    pub oauth_subject: Option<String>,
    /// Names of tools flagged in this session. Use `insert_flagged_tool` to respect `MAX_FLAGGED_TOOLS`.
    pub(crate) flagged_tools: HashSet<String>,
    /// Optional pinned tool manifest; not read or written by methods in this file.
    pub pinned_manifest: Option<ToolManifest>,
    /// Per-tool call counts, exposed through `RequestContext::call_counts`.
    pub(crate) call_counts: HashMap<String, u64>,
    /// Ordered action history, exposed through `RequestContext::previous_actions`.
    pub(crate) action_history: VecDeque<String>,
    /// Memory tracker for this session (opaque here; see `vellaveto_mcp::memory_tracking`).
    pub memory_tracker: MemoryTracker,
    /// Elicitation counter. NOTE(review): updated outside this file — confirm semantics.
    pub elicitation_count: u32,
    /// Sampling counter. NOTE(review): updated outside this file — confirm semantics.
    pub sampling_count: u32,
    /// Pending tool calls (String -> String). NOTE(review): key/value meaning not visible here.
    pub(crate) pending_tool_calls: HashMap<String, String>,
    /// Token expiry as Unix time in whole seconds; `is_expired` reports expiry once reached.
    pub token_expires_at: Option<u64>,
    /// Current call chain, exposed through `RequestContext::call_chain`.
    pub current_call_chain: Vec<vellaveto_types::CallChainEntry>,
    /// Agent identity, exposed through `RequestContext::agent_identity`.
    pub agent_identity: Option<AgentIdentity>,
    /// Backend ID -> upstream session ID. Use `insert_backend_session` to respect `MAX_BACKEND_SESSIONS`.
    pub(crate) backend_sessions: HashMap<String, String>,
    /// Backend ID -> tool names. Use `insert_gateway_tools` to respect the gateway caps.
    pub(crate) gateway_tools: HashMap<String, Vec<String>>,
    /// Risk score, exposed through `RequestContext::risk_score`.
    pub risk_score: Option<vellaveto_types::RiskScore>,
    /// Granted ABAC policy IDs; `insert_granted_policy` deduplicates and caps at `MAX_GRANTED_POLICIES`.
    pub(crate) abac_granted_policies: Vec<String>,
    /// Tool ID -> discovery record.
    /// NOTE(review): this field is `pub`, so direct inserts bypass the
    /// `MAX_DISCOVERED_TOOLS_PER_SESSION` cap enforced by `record_discovered_tools`.
    pub discovered_tools: HashMap<String, DiscoveredToolSession>,
}
114
/// Upper bound on `SessionState::discovered_tools` entries, enforced by
/// `record_discovered_tools` (which first evicts expired entries when full).
const MAX_DISCOVERED_TOOLS_PER_SESSION: usize = 10_000;

/// Upper bound on distinct backends in `SessionState::backend_sessions`.
const MAX_BACKEND_SESSIONS: usize = 128;
/// Upper bound on distinct backends in `SessionState::gateway_tools`.
const MAX_GATEWAY_TOOLS: usize = 128;
/// Maximum number of tool names kept per backend in `gateway_tools`; longer lists are truncated.
const MAX_TOOLS_PER_BACKEND: usize = 1000;

/// Upper bound on `SessionState::abac_granted_policies`.
const MAX_GRANTED_POLICIES: usize = 1024;

/// Upper bound on distinct tool names in `SessionState::known_tools`.
const MAX_KNOWN_TOOLS: usize = 2048;

/// Upper bound on `SessionState::flagged_tools`.
const MAX_FLAGGED_TOOLS: usize = 2048;
134
/// Record of a tool surfaced through discovery in one session, with its own expiry window.
#[derive(Debug, Clone)]
pub struct DiscoveredToolSession {
    /// Identifier of the discovered tool.
    pub tool_id: String,
    /// When the tool was (most recently) recorded as discovered.
    pub discovered_at: Instant,
    /// How long after `discovered_at` the discovery stays valid.
    pub ttl: Duration,
    /// Whether the tool has been marked as used since it was last recorded.
    pub used: bool,
}

impl DiscoveredToolSession {
    /// Reports whether strictly more than `ttl` has passed since `discovered_at`.
    ///
    /// An age exactly equal to `ttl` still counts as valid.
    pub fn is_expired(&self) -> bool {
        let age = self.discovered_at.elapsed();
        self.ttl < age
    }
}
154
155impl SessionState {
156 pub fn new(session_id: String) -> Self {
157 let now = Instant::now();
158 Self {
159 session_id,
160 created_at: now,
161 last_activity: now,
162 protocol_version: None,
163 known_tools: HashMap::new(),
164 request_count: 0,
165 tools_list_seen: false,
166 oauth_subject: None,
167 flagged_tools: HashSet::new(),
168 pinned_manifest: None,
169 call_counts: HashMap::new(),
170 action_history: VecDeque::new(),
171 memory_tracker: MemoryTracker::new(),
172 elicitation_count: 0,
173 sampling_count: 0,
174 pending_tool_calls: HashMap::new(),
175 token_expires_at: None,
176 current_call_chain: Vec::new(),
177 agent_identity: None,
178 backend_sessions: HashMap::new(),
179 gateway_tools: HashMap::new(),
180 risk_score: None,
181 abac_granted_policies: Vec::new(),
182 discovered_tools: HashMap::new(),
183 }
184 }
185
186 pub fn known_tools(&self) -> &HashMap<String, ToolAnnotations> {
194 &self.known_tools
195 }
196
197 pub fn flagged_tools(&self) -> &HashSet<String> {
199 &self.flagged_tools
200 }
201
202 pub fn backend_sessions(&self) -> &HashMap<String, String> {
204 &self.backend_sessions
205 }
206
207 pub fn gateway_tools(&self) -> &HashMap<String, Vec<String>> {
209 &self.gateway_tools
210 }
211
212 pub fn abac_granted_policies(&self) -> &[String] {
214 &self.abac_granted_policies
215 }
216
217 #[allow(clippy::map_entry)] pub fn insert_backend_session(
221 &mut self,
222 backend_id: String,
223 upstream_session_id: String,
224 ) -> bool {
225 if self.backend_sessions.contains_key(&backend_id) {
226 self.backend_sessions
227 .insert(backend_id, upstream_session_id);
228 return true;
229 }
230 if self.backend_sessions.len() >= MAX_BACKEND_SESSIONS {
231 tracing::warn!(
232 session_id = %self.session_id,
233 capacity = MAX_BACKEND_SESSIONS,
234 "Backend sessions capacity reached; dropping new entry"
235 );
236 return false;
237 }
238 self.backend_sessions
239 .insert(backend_id, upstream_session_id);
240 true
241 }
242
243 pub fn insert_gateway_tools(&mut self, backend_id: String, tools: Vec<String>) -> bool {
246 if !self.gateway_tools.contains_key(&backend_id)
247 && self.gateway_tools.len() >= MAX_GATEWAY_TOOLS
248 {
249 tracing::warn!(
250 session_id = %self.session_id,
251 capacity = MAX_GATEWAY_TOOLS,
252 "Gateway tools capacity reached; dropping new backend entry"
253 );
254 return false;
255 }
256 let bounded_tools: Vec<String> = tools.into_iter().take(MAX_TOOLS_PER_BACKEND).collect();
258 self.gateway_tools.insert(backend_id, bounded_tools);
259 true
260 }
261
262 pub fn insert_granted_policy(&mut self, policy_id: String) {
264 if !self.abac_granted_policies.contains(&policy_id)
265 && self.abac_granted_policies.len() < MAX_GRANTED_POLICIES
266 {
267 self.abac_granted_policies.push(policy_id);
268 }
269 }
270
271 #[allow(clippy::map_entry)] pub fn insert_known_tool(&mut self, name: String, annotations: ToolAnnotationsCompact) -> bool {
275 if self.known_tools.contains_key(&name) {
276 self.known_tools.insert(name, annotations);
277 return true;
278 }
279 if self.known_tools.len() >= MAX_KNOWN_TOOLS {
280 tracing::warn!(
281 session_id = %self.session_id,
282 capacity = MAX_KNOWN_TOOLS,
283 "Known tools capacity reached; dropping new tool"
284 );
285 return false;
286 }
287 self.known_tools.insert(name, annotations);
288 true
289 }
290
291 pub fn insert_flagged_tool(&mut self, name: String) {
293 if self.flagged_tools.len() < MAX_FLAGGED_TOOLS {
294 self.flagged_tools.insert(name);
295 }
296 }
297
298 pub fn record_discovered_tools(&mut self, tool_ids: &[String], ttl: Duration) {
304 let now = Instant::now();
305 for tool_id in tool_ids {
306 if !self.discovered_tools.contains_key(tool_id) {
308 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
309 self.evict_expired_discoveries();
311 }
312 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
313 tracing::warn!(
314 session_id = %self.session_id,
315 capacity = MAX_DISCOVERED_TOOLS_PER_SESSION,
316 "Discovered tools capacity reached; dropping new tool"
317 );
318 continue;
319 }
320 }
321 self.discovered_tools.insert(
322 tool_id.clone(),
323 DiscoveredToolSession {
324 tool_id: tool_id.clone(),
325 discovered_at: now,
326 ttl,
327 used: false,
328 },
329 );
330 }
331 }
332
333 pub fn is_tool_discovery_expired(&self, tool_id: &str) -> Option<bool> {
339 self.discovered_tools.get(tool_id).map(|d| d.is_expired())
340 }
341
342 pub fn mark_tool_used(&mut self, tool_id: &str) -> bool {
346 if let Some(entry) = self.discovered_tools.get_mut(tool_id) {
347 entry.used = true;
348 true
349 } else {
350 false
351 }
352 }
353
354 pub fn evict_expired_discoveries(&mut self) -> usize {
358 let before = self.discovered_tools.len();
359 self.discovered_tools.retain(|_, d| !d.is_expired());
360 before - self.discovered_tools.len()
361 }
362
363 pub fn touch(&mut self) {
365 self.last_activity = Instant::now();
366 self.request_count = self.request_count.saturating_add(1);
368 }
369
370 pub fn is_expired(&self, timeout: Duration, max_lifetime: Option<Duration>) -> bool {
376 if self.last_activity.elapsed() > timeout {
377 return true;
378 }
379 if let Some(max) = max_lifetime {
380 if self.created_at.elapsed() > max {
381 return true;
382 }
383 }
384 if let Some(exp) = self.token_expires_at {
385 let now = std::time::SystemTime::now()
386 .duration_since(std::time::UNIX_EPOCH)
387 .unwrap_or_default()
388 .as_secs();
389 if now >= exp {
390 return true;
391 }
392 }
393 false
394 }
395}
396
397use vellaveto_types::identity::RequestContext;
402
/// Read-only [`RequestContext`] view over a borrowed [`SessionState`].
pub struct StatefulContext<'a> {
    /// Session whose counters, history, call chain, identity and risk score are exposed.
    session: &'a SessionState,
    /// Lazily built `Vec` copy of `session.action_history`, so `previous_actions` can
    /// return a contiguous slice even though the history is stored as a `VecDeque`.
    previous_actions_cache: std::sync::OnceLock<Vec<String>>,
}
423
424impl<'a> StatefulContext<'a> {
425 pub fn new(session: &'a SessionState) -> Self {
427 Self {
428 session,
429 previous_actions_cache: std::sync::OnceLock::new(),
430 }
431 }
432}
433
434impl RequestContext for StatefulContext<'_> {
435 fn call_counts(&self) -> &HashMap<String, u64> {
436 &self.session.call_counts
437 }
438
439 fn previous_actions(&self) -> &[String] {
440 self.previous_actions_cache
441 .get_or_init(|| self.session.action_history.iter().cloned().collect())
442 }
443
444 fn call_chain(&self) -> &[vellaveto_types::CallChainEntry] {
445 &self.session.current_call_chain
446 }
447
448 fn agent_identity(&self) -> Option<&AgentIdentity> {
449 self.session.agent_identity.as_ref()
450 }
451
452 fn session_guard_state(&self) -> Option<&str> {
453 None }
455
456 fn risk_score(&self) -> Option<&vellaveto_types::RiskScore> {
457 self.session.risk_score.as_ref()
458 }
459
460 fn to_evaluation_context(&self) -> vellaveto_types::EvaluationContext {
461 vellaveto_types::EvaluationContext {
462 agent_id: self.session.oauth_subject.clone(),
463 agent_identity: self.session.agent_identity.clone(),
464 call_counts: self.session.call_counts.clone(),
465 previous_actions: self.session.action_history.iter().cloned().collect(),
466 call_chain: self.session.current_call_chain.clone(),
467 session_state: None,
468 ..Default::default()
469 }
470 }
471}
472
/// Longest client-supplied session ID (in bytes, per `str::len`) that
/// `SessionStore::get_or_create` will look up; longer IDs are ignored and a fresh
/// session is created instead.
const MAX_SESSION_ID_LEN: usize = 128;
477
/// Concurrent store of [`SessionState`] values keyed by session ID.
pub struct SessionStore {
    /// Live sessions keyed by session ID.
    /// NOTE(review): wrapped in `Arc`, but nothing in this file clones it — confirm whether sharing is needed.
    sessions: Arc<DashMap<String, SessionState>>,
    /// Inactivity timeout passed to `SessionState::is_expired`.
    session_timeout: Duration,
    /// Cap checked by `get_or_create` before inserting a new session.
    max_sessions: usize,
    /// Optional absolute session lifetime; `None` disables the absolute limit.
    max_lifetime: Option<Duration>,
}
488
489impl SessionStore {
490 pub fn new(session_timeout: Duration, max_sessions: usize) -> Self {
491 Self {
492 sessions: Arc::new(DashMap::new()),
493 session_timeout,
494 max_sessions,
495 max_lifetime: None,
496 }
497 }
498
499 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
502 self.max_lifetime = Some(lifetime);
503 self
504 }
505
506 pub fn get_or_create(&self, client_session_id: Option<&str>) -> String {
512 let client_session_id = client_session_id.filter(|id| id.len() <= MAX_SESSION_ID_LEN);
515
516 if let Some(id) = client_session_id {
518 if let Some(mut session) = self.sessions.get_mut(id) {
519 if !session.is_expired(self.session_timeout, self.max_lifetime) {
520 session.touch();
521 return id.to_string();
522 }
523 drop(session);
525 self.sessions.remove(id);
526 }
527 }
528
529 if self.sessions.len() >= self.max_sessions {
536 self.evict_expired();
537 if self.sessions.len() >= self.max_sessions {
539 self.evict_oldest();
540 }
541 }
542
543 let session_id = uuid::Uuid::new_v4().to_string();
545 self.sessions
546 .insert(session_id.clone(), SessionState::new(session_id.clone()));
547 session_id
548 }
549
550 pub fn get(
552 &self,
553 session_id: &str,
554 ) -> Option<dashmap::mapref::one::Ref<'_, String, SessionState>> {
555 self.sessions.get(session_id)
556 }
557
558 pub fn get_mut(
560 &self,
561 session_id: &str,
562 ) -> Option<dashmap::mapref::one::RefMut<'_, String, SessionState>> {
563 self.sessions.get_mut(session_id)
564 }
565
566 pub fn evict_expired(&self) {
568 self.sessions
569 .retain(|_, session| !session.is_expired(self.session_timeout, self.max_lifetime));
570 }
571
572 fn evict_oldest(&self) {
574 let oldest = self
575 .sessions
576 .iter()
577 .min_by_key(|entry| entry.value().last_activity)
578 .map(|entry| entry.key().clone());
579
580 if let Some(id) = oldest {
581 self.sessions.remove(&id);
582 }
583 }
584
585 pub fn len(&self) -> usize {
587 self.sessions.len()
588 }
589
590 pub fn is_empty(&self) -> bool {
592 self.sessions.is_empty()
593 }
594
595 pub fn remove(&self, session_id: &str) -> bool {
597 self.sessions.remove(session_id).is_some()
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh store hands out a UUID-formatted (36-char) ID and holds one session.
    #[test]
    fn test_session_creation() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    // Presenting an existing, unexpired ID returns the same ID without creating a session.
    #[test]
    fn test_session_reuse() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id1 = store.get_or_create(None);
        let id2 = store.get_or_create(Some(&id1));
        assert_eq!(id1, id2);
        assert_eq!(store.len(), 1);
    }

    // Unknown client-supplied IDs are not adopted; a server-generated ID is returned.
    #[test]
    fn test_session_unknown_id_creates_new() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some("nonexistent-id"));
        assert_ne!(id, "nonexistent-id");
        assert_eq!(store.len(), 1);
    }

    // Creating beyond `max_sessions` evicts so the count stays at the cap.
    #[test]
    fn test_max_sessions_enforced() {
        let store = SessionStore::new(Duration::from_secs(300), 3);
        store.get_or_create(None);
        store.get_or_create(None);
        store.get_or_create(None);
        assert_eq!(store.len(), 3);
        store.get_or_create(None);
        // The fourth creation must have evicted one session.
        assert_eq!(store.len(), 3);
    }

    // `remove` reports whether the session existed.
    #[test]
    fn test_session_remove() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        assert!(store.remove(&id));
        assert_eq!(store.len(), 0);
        assert!(!store.remove(&id));
    }

    // Reusing a session touches it: creation leaves the count at 0, reuse makes it 1.
    #[test]
    fn test_session_touch_increments_count() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);
        store.get_or_create(Some(&id));
        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.request_count, 1);
    }

    // Flagged tools persist across separate guard acquisitions.
    #[test]
    fn test_flagged_tools_insert_and_contains() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            // Scope the write guard so it is released before the next lookup.
            let mut session = store.get_mut(&id).unwrap();
            session.flagged_tools.insert("evil_tool".to_string());
            session.flagged_tools.insert("suspicious_tool".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert!(session.flagged_tools.contains("evil_tool"));
        assert!(session.flagged_tools.contains("suspicious_tool"));
        assert!(!session.flagged_tools.contains("safe_tool"));
        assert_eq!(session.flagged_tools.len(), 2);
    }

    #[test]
    fn test_flagged_tools_empty_by_default() {
        let state = SessionState::new("test-session".to_string());
        assert!(state.flagged_tools.is_empty());
        assert!(state.pending_tool_calls.is_empty());
    }

    // `oauth_subject` starts unset and keeps a value written through a guard.
    #[test]
    fn test_oauth_subject_storage() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            // Fresh sessions have no subject.
            let session = store.get_mut(&id).unwrap();
            assert!(session.oauth_subject.is_none());
        }

        {
            // Set the subject; the guard is dropped at the end of this scope.
            let mut session = store.get_mut(&id).unwrap();
            session.oauth_subject = Some("user-42".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        // The value survives re-acquiring the session.
        assert_eq!(session.oauth_subject.as_deref(), Some("user-42"));
    }

    #[test]
    fn test_protocol_version_tracking() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let session = store.get_mut(&id).unwrap();
            assert!(session.protocol_version.is_none());
        }

        {
            let mut session = store.get_mut(&id).unwrap();
            session.protocol_version = Some("2025-11-25".to_string());
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.protocol_version.as_deref(), Some("2025-11-25"));
    }

    #[test]
    fn test_known_tools_mutations() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            let mut session = store.get_mut(&id).unwrap();
            session.known_tools.insert(
                "read_file".to_string(),
                ToolAnnotations {
                    read_only_hint: true,
                    destructive_hint: false,
                    idempotent_hint: true,
                    open_world_hint: false,
                    input_schema_hash: None,
                },
            );
        }

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.known_tools.len(), 1);
        let ann = session.known_tools.get("read_file").unwrap();
        assert!(ann.read_only_hint);
        assert!(!ann.destructive_hint);
    }

    // Pins the default annotation profile: not read-only, destructive, not idempotent, open-world.
    #[test]
    fn test_tool_annotations_default() {
        let ann = ToolAnnotations::default();
        assert!(!ann.read_only_hint);
        assert!(ann.destructive_hint);
        assert!(!ann.idempotent_hint);
        assert!(ann.open_world_hint);
    }

    #[test]
    fn test_tool_annotations_equality() {
        let a = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let b = ToolAnnotations {
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: true,
            open_world_hint: false,
            input_schema_hash: None,
        };
        let c = ToolAnnotations::default();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_tools_list_seen_flag() {
        let state = SessionState::new("test".to_string());
        assert!(!state.tools_list_seen);
    }

    // Inactivity timeout still applies when no absolute lifetime is configured.
    #[test]
    fn test_inactivity_expiry_preserved() {
        let state = SessionState::new("test-inactivity".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
        // A zero timeout is exceeded as soon as any time has elapsed.
        assert!(state.is_expired(Duration::from_nanos(0), None));
    }

    // The absolute lifetime expires a session even when it is recently active.
    #[test]
    fn test_absolute_lifetime_enforced() {
        let state = SessionState::new("test-lifetime".to_string());
        assert!(state.is_expired(Duration::from_secs(300), Some(Duration::from_nanos(0))));
        // A generous lifetime leaves the fresh session alive.
        assert!(!state.is_expired(Duration::from_secs(300), Some(Duration::from_secs(86400))));
    }

    #[test]
    fn test_none_max_lifetime_no_absolute_limit() {
        let state = SessionState::new("test-no-limit".to_string());
        assert!(!state.is_expired(Duration::from_secs(300), None));
    }

    // `evict_expired` honours the store's absolute lifetime, not just inactivity.
    #[test]
    fn test_eviction_checks_both_timeouts() {
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_nanos(0));

        let _id = store.get_or_create(None);
        assert_eq!(store.len(), 1);

        store.evict_expired();
        // Zero lifetime: the session is gone although it is not inactive.
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_with_max_lifetime_builder() {
        let store = SessionStore::new(Duration::from_secs(300), 100)
            .with_max_lifetime(Duration::from_secs(86400));
        let id = store.get_or_create(None);
        assert_eq!(store.len(), 1);
        let id2 = store.get_or_create(Some(&id));
        // Within the lifetime, the session is reused.
        assert_eq!(id, id2);
    }

    // A max-length ID passes the length filter: it is looked up (and reused if present),
    // but an unknown one is still not adopted.
    #[test]
    fn test_session_id_at_max_length_accepted() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let long_id = "a".repeat(MAX_SESSION_ID_LEN);
        let id = store.get_or_create(Some(&long_id));
        // Not present yet, so a new server-generated ID is returned.
        assert_ne!(id, long_id);
        assert_eq!(store.len(), 1);

        // Insert a session under the max-length ID directly, bypassing get_or_create.
        store
            .sessions
            .insert(long_id.clone(), SessionState::new(long_id.clone()));
        let reused = store.get_or_create(Some(&long_id));
        assert_eq!(reused, long_id);
    }

    // IDs over the limit are never looked up, even if a session exists under them.
    #[test]
    fn test_session_id_exceeding_max_length_rejected() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let too_long = "b".repeat(MAX_SESSION_ID_LEN + 1);
        store
            .sessions
            .insert(too_long.clone(), SessionState::new(too_long.clone()));

        let id = store.get_or_create(Some(&too_long));
        assert_ne!(id, too_long, "Oversized session ID must not be reused");
        assert_eq!(id.len(), 36, "Should return a UUID-format session ID");
    }

    // An empty client ID passes the length filter but is unknown, so a new session is made.
    #[test]
    fn test_session_id_empty_string_accepted() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(Some(""));
        assert_eq!(id.len(), 36);
        assert_eq!(store.len(), 1);
    }

    // 128 (at the limit) and 129 (over) characters both yield fresh sessions when unknown.
    #[test]
    fn test_session_id_exactly_128_chars_boundary() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let exact = "x".repeat(128);
        let id = store.get_or_create(Some(&exact));
        assert_eq!(id.len(), 36);

        let one_over = "x".repeat(129);
        let id2 = store.get_or_create(Some(&one_over));
        assert_eq!(id2.len(), 36);
        assert_eq!(store.len(), 2);
    }

    // StatefulContext is usable as a trait object and reflects an empty session.
    #[test]
    fn test_stateful_context_implements_trait() {
        let session = SessionState::new("test-ctx".to_string());
        let ctx = StatefulContext::new(&session);

        let _: &dyn RequestContext = &ctx;
        assert!(ctx.call_counts().is_empty());
        assert!(ctx.previous_actions().is_empty());
        assert!(ctx.call_chain().is_empty());
        assert!(ctx.agent_identity().is_none());
        assert!(ctx.session_guard_state().is_none());
        assert!(ctx.risk_score().is_none());
    }

    // Call counts pass through unchanged.
    #[test]
    fn test_stateful_context_call_counts() {
        let mut session = SessionState::new("test-counts".to_string());
        session.call_counts.insert("read_file".to_string(), 5);
        session.call_counts.insert("write_file".to_string(), 3);

        let ctx = StatefulContext::new(&session);
        assert_eq!(ctx.call_counts().len(), 2);
        assert_eq!(ctx.call_counts()["read_file"], 5);
        assert_eq!(ctx.call_counts()["write_file"], 3);
    }

    // Previous actions preserve history order.
    #[test]
    fn test_stateful_context_previous_actions() {
        let mut session = SessionState::new("test-actions".to_string());
        session.action_history.push_back("read_file".to_string());
        session.action_history.push_back("write_file".to_string());
        session.action_history.push_back("execute".to_string());

        let ctx = StatefulContext::new(&session);
        let actions = ctx.previous_actions();
        assert_eq!(actions.len(), 3);
        assert_eq!(actions[0], "read_file");
        assert_eq!(actions[1], "write_file");
        assert_eq!(actions[2], "execute");
    }

    // ---- Tool discovery TTL tracking ----

    #[test]
    fn test_discovered_tools_empty_by_default() {
        let state = SessionState::new("test".to_string());
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_record_discovered_tools() {
        let mut state = SessionState::new("test".to_string());
        let tools = vec![
            "server:read_file".to_string(),
            "server:write_file".to_string(),
        ];
        state.record_discovered_tools(&tools, Duration::from_secs(300));

        assert_eq!(state.discovered_tools.len(), 2);
        assert!(state.discovered_tools.contains_key("server:read_file"));
        assert!(state.discovered_tools.contains_key("server:write_file"));
    }

    #[test]
    fn test_record_discovered_tools_sets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(60));
        assert!(!entry.used);
    }

    // Rediscovery overwrites the record: new TTL, `used` cleared.
    #[test]
    fn test_record_discovered_tools_rediscovery_resets_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));

        state.mark_tool_used("server:tool1");
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);

        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(120));

        let entry = state.discovered_tools.get("server:tool1").unwrap();
        assert_eq!(entry.ttl, Duration::from_secs(120));
        assert!(!entry.used);
    }

    #[test]
    fn test_is_tool_discovery_expired_unknown_tool() {
        let state = SessionState::new("test".to_string());
        assert_eq!(state.is_tool_discovery_expired("unknown:tool"), None);
    }

    #[test]
    fn test_is_tool_discovery_expired_fresh_tool() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(false));
    }

    // A zero TTL on a record back-dated by one second is expired.
    #[test]
    fn test_is_tool_discovery_expired_zero_ttl() {
        let mut state = SessionState::new("test".to_string());
        state.discovered_tools.insert(
            "server:tool1".to_string(),
            DiscoveredToolSession {
                tool_id: "server:tool1".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(1),
                ttl: Duration::from_nanos(0),
                used: false,
            },
        );
        assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(true));
    }

    #[test]
    fn test_mark_tool_used_existing() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        assert!(!state.discovered_tools.get("server:tool1").unwrap().used);

        assert!(state.mark_tool_used("server:tool1"));
        assert!(state.discovered_tools.get("server:tool1").unwrap().used);
    }

    #[test]
    fn test_mark_tool_used_nonexistent() {
        let mut state = SessionState::new("test".to_string());
        assert!(!state.mark_tool_used("unknown:tool"));
    }

    #[test]
    fn test_evict_expired_discoveries_none_expired() {
        let mut state = SessionState::new("test".to_string());
        state.record_discovered_tools(
            &["server:tool1".to_string(), "server:tool2".to_string()],
            Duration::from_secs(300),
        );
        assert_eq!(state.evict_expired_discoveries(), 0);
        assert_eq!(state.discovered_tools.len(), 2);
    }

    // Only the back-dated record is evicted; the fresh one survives.
    #[test]
    fn test_evict_expired_discoveries_some_expired() {
        let mut state = SessionState::new("test".to_string());

        state.record_discovered_tools(&["server:fresh".to_string()], Duration::from_secs(300));

        state.discovered_tools.insert(
            "server:stale".to_string(),
            DiscoveredToolSession {
                tool_id: "server:stale".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(10),
                ttl: Duration::from_secs(1),
                used: true,
            },
        );

        assert_eq!(state.evict_expired_discoveries(), 1);
        assert_eq!(state.discovered_tools.len(), 1);
        assert!(state.discovered_tools.contains_key("server:fresh"));
        assert!(!state.discovered_tools.contains_key("server:stale"));
    }

    #[test]
    fn test_evict_expired_discoveries_all_expired() {
        let mut state = SessionState::new("test".to_string());
        let past = Instant::now() - Duration::from_secs(10);
        for i in 0..5 {
            state.discovered_tools.insert(
                format!("server:tool{}", i),
                DiscoveredToolSession {
                    tool_id: format!("server:tool{}", i),
                    discovered_at: past,
                    ttl: Duration::from_secs(1),
                    used: false,
                },
            );
        }

        assert_eq!(state.evict_expired_discoveries(), 5);
        assert!(state.discovered_tools.is_empty());
    }

    #[test]
    fn test_discovered_tool_session_is_expired() {
        let fresh = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now(),
            ttl: Duration::from_secs(300),
            used: false,
        };
        assert!(!fresh.is_expired());

        let stale = DiscoveredToolSession {
            tool_id: "t".to_string(),
            discovered_at: Instant::now() - Duration::from_secs(10),
            ttl: Duration::from_secs(1),
            used: false,
        };
        assert!(stale.is_expired());
    }

    // Discovery records are not cleared when the session is touched via reuse.
    #[test]
    fn test_discovered_tools_survive_session_touch() {
        let store = SessionStore::new(Duration::from_secs(300), 100);
        let id = store.get_or_create(None);

        {
            // Record under a scoped write guard.
            let mut session = store.get_mut(&id).unwrap();
            session
                .record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
        }

        store.get_or_create(Some(&id));

        let session = store.get_mut(&id).unwrap();
        assert_eq!(session.discovered_tools.len(), 1);
        assert!(session.discovered_tools.contains_key("server:tool1"));
    }

    // Each record expires according to its own TTL.
    #[test]
    fn test_multiple_tools_independent_ttl() {
        let mut state = SessionState::new("test".to_string());

        state.discovered_tools.insert(
            "server:short".to_string(),
            DiscoveredToolSession {
                tool_id: "server:short".to_string(),
                discovered_at: Instant::now() - Duration::from_secs(5),
                ttl: Duration::from_secs(1),
                used: false,
            },
        );

        state.record_discovered_tools(&["server:long".to_string()], Duration::from_secs(3600));

        assert_eq!(state.is_tool_discovery_expired("server:short"), Some(true));
        assert_eq!(state.is_tool_discovery_expired("server:long"), Some(false));
    }

    // to_evaluation_context maps oauth_subject -> agent_id and copies the other session fields.
    #[test]
    fn test_evaluation_context_from_stateful() {
        let mut session = SessionState::new("test-eval".to_string());
        session.oauth_subject = Some("user-42".to_string());
        session.call_counts.insert("tool_a".to_string(), 7);
        session.action_history.push_back("tool_a".to_string());
        session.agent_identity = Some(AgentIdentity {
            issuer: Some("test-issuer".to_string()),
            subject: Some("agent-sub".to_string()),
            ..Default::default()
        });

        let ctx = StatefulContext::new(&session);
        let eval = ctx.to_evaluation_context();

        assert_eq!(eval.agent_id.as_deref(), Some("user-42"));
        assert_eq!(eval.call_counts["tool_a"], 7);
        assert_eq!(eval.previous_actions, vec!["tool_a".to_string()]);
        assert_eq!(
            eval.agent_identity.as_ref().unwrap().issuer.as_deref(),
            Some("test-issuer")
        );
    }
}