1use dashmap::DashMap;
19use std::collections::{HashMap, HashSet, VecDeque};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use vellaveto_config::ToolManifest;
23use vellaveto_mcp::memory_tracking::MemoryTracker;
24use vellaveto_mcp::rug_pull::ToolAnnotations;
25use vellaveto_types::{AgentIdentity, ReplayStatus};
26
/// Alias for [`ToolAnnotations`]; it is the annotation parameter type accepted by
/// `SessionState::insert_known_tool`.
pub type ToolAnnotationsCompact = ToolAnnotations;
29
/// Per-session state kept by [`SessionStore`], keyed by `session_id`.
///
/// Collections that can be fed by peers are bounded by the `MAX_*` constants in this
/// module when mutated through the `insert_*` / `record_*` methods on `SessionState`.
#[derive(Debug)]
pub struct SessionState {
    /// ID this session is stored under.
    pub session_id: String,
    /// Random `sidbind:v1:<uuid>` token generated when the session is created.
    pub session_scope_binding: String,
    /// Creation time; compared against the optional absolute lifetime in `is_expired`.
    pub created_at: Instant,
    /// Time of the most recent `touch`; compared against the inactivity timeout.
    pub last_activity: Instant,
    /// Protocol version recorded for this session, if any.
    pub protocol_version: Option<String>,
    /// Tool annotations keyed by tool name (bounded by `MAX_KNOWN_TOOLS` via `insert_known_tool`).
    pub(crate) known_tools: HashMap<String, ToolAnnotations>,
    /// Number of `touch` calls (saturating).
    pub request_count: u64,
    /// NOTE(review): presumably set once a tools listing has been observed; set outside this file.
    pub tools_list_seen: bool,
    /// OAuth subject; surfaced as `agent_id` in the evaluation context.
    pub oauth_subject: Option<String>,
    /// Tool names flagged in this session (bounded by `MAX_FLAGGED_TOOLS` via `insert_flagged_tool`).
    pub(crate) flagged_tools: HashSet<String>,
    /// Tool manifest pinned for this session, if any.
    pub pinned_manifest: Option<ToolManifest>,
    /// Per-tool call counters exposed to policy evaluation through `StatefulContext`.
    pub(crate) call_counts: HashMap<String, u64>,
    /// Ordered history of previous actions exposed through `StatefulContext`.
    pub(crate) action_history: VecDeque<String>,
    /// Memory tracker owned by this session.
    pub memory_tracker: MemoryTracker,
    /// Counter for elicitation requests (maintained outside this file).
    pub elicitation_count: u32,
    /// Counter for sampling requests (maintained outside this file).
    pub sampling_count: u32,
    /// Request nonces already accepted; used for replay detection.
    pub(crate) verified_request_nonces: HashSet<String>,
    /// Insertion order of `verified_request_nonces`, used for FIFO eviction at
    /// `MAX_VERIFIED_REQUEST_NONCES`.
    pub(crate) verified_request_nonce_order: VecDeque<String>,
    /// Pending tool calls. NOTE(review): key/value meaning not visible here — confirm with callers.
    pub(crate) pending_tool_calls: HashMap<String, String>,
    /// Token expiry in Unix seconds; the session counts as expired once the wall clock reaches it.
    pub token_expires_at: Option<u64>,
    /// Call chain of the request currently being evaluated.
    pub current_call_chain: Vec<vellaveto_types::CallChainEntry>,
    /// Identity of the agent behind this session, if established.
    pub agent_identity: Option<AgentIdentity>,
    /// Backend ID -> upstream session ID (bounded by `MAX_BACKEND_SESSIONS`).
    pub(crate) backend_sessions: HashMap<String, String>,
    /// Backend ID -> tool names (bounded by `MAX_GATEWAY_TOOLS` backends and
    /// `MAX_TOOLS_PER_BACKEND` names each).
    pub(crate) gateway_tools: HashMap<String, Vec<String>>,
    /// Latest risk score for the session, if computed.
    pub risk_score: Option<vellaveto_types::RiskScore>,
    /// Granted policy IDs, deduplicated, bounded by `MAX_GRANTED_POLICIES`.
    pub(crate) abac_granted_policies: Vec<String>,
    /// Tools discovered in this session, keyed by tool ID, each with its own TTL.
    pub discovered_tools: HashMap<String, DiscoveredToolSession>,
}
126
/// Maximum number of entries in `SessionState::discovered_tools`.
const MAX_DISCOVERED_TOOLS_PER_SESSION: usize = 10_000;

/// Maximum number of distinct backends in `SessionState::backend_sessions`.
const MAX_BACKEND_SESSIONS: usize = 128;
/// Maximum number of distinct backends in `SessionState::gateway_tools`.
const MAX_GATEWAY_TOOLS: usize = 128;
/// Tool names kept per backend in `gateway_tools`; longer lists are truncated.
const MAX_TOOLS_PER_BACKEND: usize = 1000;

/// Maximum number of entries in `SessionState::abac_granted_policies`.
const MAX_GRANTED_POLICIES: usize = 1024;

/// Maximum number of entries in `SessionState::known_tools`.
const MAX_KNOWN_TOOLS: usize = 2048;

/// Maximum number of entries in `SessionState::flagged_tools`.
const MAX_FLAGGED_TOOLS: usize = 2048;

/// Maximum number of entries in the store-wide flagged-tools registry.
const MAX_GLOBAL_FLAGGED_TOOLS: usize = 10_000;

/// How long a global tool flag stays in effect (24 hours).
const GLOBAL_FLAGGED_TOOL_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// Number of verified request nonces remembered per session (oldest evicted first).
const MAX_VERIFIED_REQUEST_NONCES: usize = 1024;
158
/// One record in the store-wide flagged-tools registry: when the tool was flagged and
/// for how long the flag applies.
#[derive(Debug, Clone)]
pub struct GlobalFlaggedToolEntry {
    /// Moment the tool was flagged.
    pub flagged_at: Instant,
    /// Duration after `flagged_at` during which the flag is in effect.
    pub ttl: Duration,
}

impl GlobalFlaggedToolEntry {
    /// `true` once strictly more than `ttl` has passed since `flagged_at`.
    fn is_expired(&self) -> bool {
        let age = self.flagged_at.elapsed();
        age > self.ttl
    }
}
174
/// A tool discovered within a session, with its own discovery TTL and usage flag.
#[derive(Debug, Clone)]
pub struct DiscoveredToolSession {
    /// Identifier of the discovered tool.
    pub tool_id: String,
    /// Moment of (re)discovery.
    pub discovered_at: Instant,
    /// How long the discovery stays valid after `discovered_at`.
    pub ttl: Duration,
    /// Whether the tool has been marked as used since it was (re)discovered.
    pub used: bool,
}

impl DiscoveredToolSession {
    /// Reports whether strictly more than `ttl` has elapsed since discovery.
    pub fn is_expired(&self) -> bool {
        self.ttl < self.discovered_at.elapsed()
    }
}
194
195impl SessionState {
196 pub fn new(session_id: String) -> Self {
197 let now = Instant::now();
198 Self {
199 session_id,
200 session_scope_binding: generate_session_scope_binding(),
201 created_at: now,
202 last_activity: now,
203 protocol_version: None,
204 known_tools: HashMap::new(),
205 request_count: 0,
206 tools_list_seen: false,
207 oauth_subject: None,
208 flagged_tools: HashSet::new(),
209 pinned_manifest: None,
210 call_counts: HashMap::new(),
211 action_history: VecDeque::new(),
212 memory_tracker: MemoryTracker::new(),
213 elicitation_count: 0,
214 sampling_count: 0,
215 verified_request_nonces: HashSet::new(),
216 verified_request_nonce_order: VecDeque::new(),
217 pending_tool_calls: HashMap::new(),
218 token_expires_at: None,
219 current_call_chain: Vec::new(),
220 agent_identity: None,
221 backend_sessions: HashMap::new(),
222 gateway_tools: HashMap::new(),
223 risk_score: None,
224 abac_granted_policies: Vec::new(),
225 discovered_tools: HashMap::new(),
226 }
227 }
228
229 pub fn known_tools(&self) -> &HashMap<String, ToolAnnotations> {
237 &self.known_tools
238 }
239
240 pub fn flagged_tools(&self) -> &HashSet<String> {
242 &self.flagged_tools
243 }
244
245 pub fn backend_sessions(&self) -> &HashMap<String, String> {
247 &self.backend_sessions
248 }
249
250 pub fn gateway_tools(&self) -> &HashMap<String, Vec<String>> {
252 &self.gateway_tools
253 }
254
255 pub fn abac_granted_policies(&self) -> &[String] {
257 &self.abac_granted_policies
258 }
259
260 #[allow(clippy::map_entry)] pub fn insert_backend_session(
264 &mut self,
265 backend_id: String,
266 upstream_session_id: String,
267 ) -> bool {
268 if self.backend_sessions.contains_key(&backend_id) {
269 self.backend_sessions
270 .insert(backend_id, upstream_session_id);
271 return true;
272 }
273 if self.backend_sessions.len() >= MAX_BACKEND_SESSIONS {
274 tracing::warn!(
275 session_id = %self.session_id,
276 capacity = MAX_BACKEND_SESSIONS,
277 "Backend sessions capacity reached; dropping new entry"
278 );
279 return false;
280 }
281 self.backend_sessions
282 .insert(backend_id, upstream_session_id);
283 true
284 }
285
286 pub fn insert_gateway_tools(&mut self, backend_id: String, tools: Vec<String>) -> bool {
289 if !self.gateway_tools.contains_key(&backend_id)
290 && self.gateway_tools.len() >= MAX_GATEWAY_TOOLS
291 {
292 tracing::warn!(
293 session_id = %self.session_id,
294 capacity = MAX_GATEWAY_TOOLS,
295 "Gateway tools capacity reached; dropping new backend entry"
296 );
297 return false;
298 }
299 let bounded_tools: Vec<String> = tools.into_iter().take(MAX_TOOLS_PER_BACKEND).collect();
301 self.gateway_tools.insert(backend_id, bounded_tools);
302 true
303 }
304
305 pub fn insert_granted_policy(&mut self, policy_id: String) {
307 if !self.abac_granted_policies.contains(&policy_id)
308 && self.abac_granted_policies.len() < MAX_GRANTED_POLICIES
309 {
310 self.abac_granted_policies.push(policy_id);
311 }
312 }
313
314 #[allow(clippy::map_entry)] pub fn insert_known_tool(&mut self, name: String, annotations: ToolAnnotationsCompact) -> bool {
318 if self.known_tools.contains_key(&name) {
319 self.known_tools.insert(name, annotations);
320 return true;
321 }
322 if self.known_tools.len() >= MAX_KNOWN_TOOLS {
323 tracing::warn!(
324 session_id = %self.session_id,
325 capacity = MAX_KNOWN_TOOLS,
326 "Known tools capacity reached; dropping new tool"
327 );
328 return false;
329 }
330 self.known_tools.insert(name, annotations);
331 true
332 }
333
334 pub fn insert_flagged_tool(&mut self, name: String) {
336 if self.flagged_tools.len() < MAX_FLAGGED_TOOLS {
337 self.flagged_tools.insert(name);
338 }
339 }
340
341 pub fn record_discovered_tools(&mut self, tool_ids: &[String], ttl: Duration) {
347 let now = Instant::now();
348 for tool_id in tool_ids {
349 if !self.discovered_tools.contains_key(tool_id) {
351 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
352 self.evict_expired_discoveries();
354 }
355 if self.discovered_tools.len() >= MAX_DISCOVERED_TOOLS_PER_SESSION {
356 tracing::warn!(
357 session_id = %self.session_id,
358 capacity = MAX_DISCOVERED_TOOLS_PER_SESSION,
359 "Discovered tools capacity reached; dropping new tool"
360 );
361 continue;
362 }
363 }
364 self.discovered_tools.insert(
365 tool_id.clone(),
366 DiscoveredToolSession {
367 tool_id: tool_id.clone(),
368 discovered_at: now,
369 ttl,
370 used: false,
371 },
372 );
373 }
374 }
375
376 pub fn is_tool_discovery_expired(&self, tool_id: &str) -> Option<bool> {
382 self.discovered_tools.get(tool_id).map(|d| d.is_expired())
383 }
384
385 pub fn mark_tool_used(&mut self, tool_id: &str) -> bool {
389 if let Some(entry) = self.discovered_tools.get_mut(tool_id) {
390 entry.used = true;
391 true
392 } else {
393 false
394 }
395 }
396
397 pub fn evict_expired_discoveries(&mut self) -> usize {
401 let before = self.discovered_tools.len();
402 self.discovered_tools.retain(|_, d| !d.is_expired());
403 before - self.discovered_tools.len()
404 }
405
406 pub fn touch(&mut self) {
408 self.last_activity = Instant::now();
409 self.request_count = self.request_count.saturating_add(1);
411 }
412
413 pub fn record_verified_request_nonce(&mut self, nonce: &str) -> ReplayStatus {
415 if nonce.is_empty() || vellaveto_types::has_dangerous_chars(nonce) {
416 return ReplayStatus::NotChecked;
417 }
418 if self.verified_request_nonces.contains(nonce) {
419 return ReplayStatus::ReplayDetected;
420 }
421 if self.verified_request_nonce_order.len() >= MAX_VERIFIED_REQUEST_NONCES {
422 if let Some(evicted) = self.verified_request_nonce_order.pop_front() {
423 self.verified_request_nonces.remove(&evicted);
424 }
425 }
426 let bounded_nonce = nonce.to_string();
427 self.verified_request_nonce_order
428 .push_back(bounded_nonce.clone());
429 self.verified_request_nonces.insert(bounded_nonce);
430 ReplayStatus::Fresh
431 }
432
433 pub fn is_expired(&self, timeout: Duration, max_lifetime: Option<Duration>) -> bool {
439 if self.last_activity.elapsed() > timeout {
440 return true;
441 }
442 if let Some(max) = max_lifetime {
443 if self.created_at.elapsed() > max {
444 return true;
445 }
446 }
447 if let Some(exp) = self.token_expires_at {
448 let now = std::time::SystemTime::now()
449 .duration_since(std::time::UNIX_EPOCH)
450 .unwrap_or_default()
451 .as_secs();
452 if now >= exp {
453 return true;
454 }
455 }
456 false
457 }
458}
459
460fn generate_session_scope_binding() -> String {
461 format!("sidbind:v1:{}", uuid::Uuid::new_v4().simple())
462}
463
464use vellaveto_types::identity::RequestContext;
469
/// Borrowed, read-only view of a [`SessionState`] implementing [`RequestContext`].
pub struct StatefulContext<'a> {
    /// Session backing every accessor.
    session: &'a SessionState,
    /// Lazily built `Vec` copy of `session.action_history`; a `VecDeque` is not always
    /// contiguous, so this lets `previous_actions` return a plain slice.
    previous_actions_cache: std::sync::OnceLock<Vec<String>>,
}
490
491impl<'a> StatefulContext<'a> {
492 pub fn new(session: &'a SessionState) -> Self {
494 Self {
495 session,
496 previous_actions_cache: std::sync::OnceLock::new(),
497 }
498 }
499}
500
501impl RequestContext for StatefulContext<'_> {
502 fn call_counts(&self) -> &HashMap<String, u64> {
503 &self.session.call_counts
504 }
505
506 fn previous_actions(&self) -> &[String] {
507 self.previous_actions_cache
508 .get_or_init(|| self.session.action_history.iter().cloned().collect())
509 }
510
511 fn call_chain(&self) -> &[vellaveto_types::CallChainEntry] {
512 &self.session.current_call_chain
513 }
514
515 fn agent_identity(&self) -> Option<&AgentIdentity> {
516 self.session.agent_identity.as_ref()
517 }
518
519 fn session_guard_state(&self) -> Option<&str> {
520 None }
522
523 fn risk_score(&self) -> Option<&vellaveto_types::RiskScore> {
524 self.session.risk_score.as_ref()
525 }
526
527 fn to_evaluation_context(&self) -> vellaveto_types::EvaluationContext {
528 vellaveto_types::EvaluationContext {
529 agent_id: self.session.oauth_subject.clone(),
530 agent_identity: self.session.agent_identity.clone(),
531 call_counts: self.session.call_counts.clone(),
532 previous_actions: self.session.action_history.iter().cloned().collect(),
533 call_chain: self.session.current_call_chain.clone(),
534 session_state: None,
535 ..Default::default()
536 }
537 }
538}
539
/// Maximum length, in bytes, of a client-supplied session ID that
/// `SessionStore::get_or_create` will consider reusing; longer IDs are ignored.
const MAX_SESSION_ID_LEN: usize = 128;
544
/// Concurrent store of [`SessionState`]s keyed by session ID, plus a store-wide registry
/// of flagged tools that is kept independently of session eviction.
pub struct SessionStore {
    /// Live sessions keyed by session ID.
    sessions: Arc<DashMap<String, SessionState>>,
    /// Inactivity timeout passed to `SessionState::is_expired`.
    session_timeout: Duration,
    /// Session count at which `get_or_create` evicts before creating a new session.
    max_sessions: usize,
    /// Optional absolute session lifetime, set via `with_max_lifetime`.
    max_lifetime: Option<Duration>,
    /// Tool flags shared across all sessions, each with its own TTL.
    global_flagged_tools: Arc<DashMap<String, GlobalFlaggedToolEntry>>,
}
560
561impl SessionStore {
562 pub fn new(session_timeout: Duration, max_sessions: usize) -> Self {
563 Self {
564 sessions: Arc::new(DashMap::new()),
565 session_timeout,
566 max_sessions,
567 max_lifetime: None,
568 global_flagged_tools: Arc::new(DashMap::new()),
569 }
570 }
571
572 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
575 self.max_lifetime = Some(lifetime);
576 self
577 }
578
579 pub fn get_or_create(&self, client_session_id: Option<&str>) -> String {
585 let client_session_id = client_session_id.filter(|id| {
590 id.len() <= MAX_SESSION_ID_LEN
591 && !id
592 .chars()
593 .any(|c| c.is_control() || vellaveto_types::is_unicode_format_char(c))
594 });
595
596 if let Some(id) = client_session_id {
598 if let Some(mut session) = self.sessions.get_mut(id) {
599 if !session.is_expired(self.session_timeout, self.max_lifetime) {
600 session.touch();
601 return id.to_string();
602 }
603 drop(session);
605 self.sessions.remove(id);
606 }
607 }
608
609 if self.sessions.len() >= self.max_sessions {
616 self.evict_expired();
617 if self.sessions.len() >= self.max_sessions {
619 self.evict_oldest();
620 }
621 }
622
623 let session_id = uuid::Uuid::new_v4().to_string();
625 self.sessions
626 .insert(session_id.clone(), SessionState::new(session_id.clone()));
627 session_id
628 }
629
630 pub fn get(
632 &self,
633 session_id: &str,
634 ) -> Option<dashmap::mapref::one::Ref<'_, String, SessionState>> {
635 self.sessions.get(session_id)
636 }
637
638 pub fn get_mut(
640 &self,
641 session_id: &str,
642 ) -> Option<dashmap::mapref::one::RefMut<'_, String, SessionState>> {
643 self.sessions.get_mut(session_id)
644 }
645
646 pub fn try_get(
650 &self,
651 session_id: &str,
652 ) -> dashmap::try_result::TryResult<dashmap::mapref::one::Ref<'_, String, SessionState>> {
653 self.sessions.try_get(session_id)
654 }
655
656 pub fn try_get_mut(
660 &self,
661 session_id: &str,
662 ) -> dashmap::try_result::TryResult<dashmap::mapref::one::RefMut<'_, String, SessionState>>
663 {
664 self.sessions.try_get_mut(session_id)
665 }
666
667 pub fn evict_expired(&self) {
669 self.sessions
670 .retain(|_, session| !session.is_expired(self.session_timeout, self.max_lifetime));
671 }
672
673 fn evict_oldest(&self) {
675 let oldest = self
676 .sessions
677 .iter()
678 .min_by_key(|entry| entry.value().last_activity)
679 .map(|entry| entry.key().clone());
680
681 if let Some(id) = oldest {
682 self.sessions.remove(&id);
683 }
684 }
685
686 pub fn len(&self) -> usize {
688 self.sessions.len()
689 }
690
691 pub fn is_empty(&self) -> bool {
693 self.sessions.is_empty()
694 }
695
696 pub fn remove(&self, session_id: &str) -> bool {
698 self.sessions.remove(session_id).is_some()
699 }
700
701 pub fn flag_tool_globally(&self, tool_name: String) {
711 if self.global_flagged_tools.len() >= MAX_GLOBAL_FLAGGED_TOOLS {
712 self.evict_expired_global_flags();
714 if self.global_flagged_tools.len() >= MAX_GLOBAL_FLAGGED_TOOLS {
715 tracing::warn!(
716 tool = %tool_name,
717 capacity = MAX_GLOBAL_FLAGGED_TOOLS,
718 "Global flagged-tools registry at capacity; dropping new entry"
719 );
720 return;
721 }
722 }
723 self.global_flagged_tools
725 .entry(tool_name)
726 .or_insert_with(|| GlobalFlaggedToolEntry {
727 flagged_at: Instant::now(),
728 ttl: GLOBAL_FLAGGED_TOOL_TTL,
729 });
730 }
731
732 pub fn is_tool_globally_flagged(&self, tool_name: &str) -> bool {
738 self.global_flagged_tools
739 .get(tool_name)
740 .map(|entry| !entry.is_expired())
741 .unwrap_or(false)
742 }
743
744 pub fn evict_expired_global_flags(&self) -> usize {
746 let before = self.global_flagged_tools.len();
747 self.global_flagged_tools
748 .retain(|_, entry| !entry.is_expired());
749 before.saturating_sub(self.global_flagged_tools.len())
750 }
751
752 pub fn global_flagged_tools_len(&self) -> usize {
754 self.global_flagged_tools.len()
755 }
756}
757
758#[cfg(test)]
759mod tests {
760 use super::*;
761
762 #[test]
763 fn test_session_creation() {
764 let store = SessionStore::new(Duration::from_secs(300), 100);
765 let id = store.get_or_create(None);
766 assert_eq!(id.len(), 36); assert_eq!(store.len(), 1);
768 }
769
770 #[test]
771 fn test_session_reuse() {
772 let store = SessionStore::new(Duration::from_secs(300), 100);
773 let id1 = store.get_or_create(None);
774 let id2 = store.get_or_create(Some(&id1));
775 assert_eq!(id1, id2);
776 assert_eq!(store.len(), 1);
777 }
778
779 #[test]
780 fn test_session_unknown_id_creates_new() {
781 let store = SessionStore::new(Duration::from_secs(300), 100);
782 let id = store.get_or_create(Some("nonexistent-id"));
783 assert_ne!(id, "nonexistent-id");
784 assert_eq!(store.len(), 1);
785 }
786
787 #[test]
788 fn test_max_sessions_enforced() {
789 let store = SessionStore::new(Duration::from_secs(300), 3);
790 store.get_or_create(None);
791 store.get_or_create(None);
792 store.get_or_create(None);
793 assert_eq!(store.len(), 3);
794 store.get_or_create(None);
796 assert_eq!(store.len(), 3);
797 }
798
799 #[test]
800 fn test_session_remove() {
801 let store = SessionStore::new(Duration::from_secs(300), 100);
802 let id = store.get_or_create(None);
803 assert!(store.remove(&id));
804 assert_eq!(store.len(), 0);
805 assert!(!store.remove(&id));
806 }
807
808 #[test]
809 fn test_session_touch_increments_count() {
810 let store = SessionStore::new(Duration::from_secs(300), 100);
811 let id = store.get_or_create(None);
812 store.get_or_create(Some(&id));
815 let session = store.get_mut(&id).unwrap();
816 assert_eq!(session.request_count, 1);
817 }
818
819 #[test]
820 fn test_flagged_tools_insert_and_contains() {
821 let store = SessionStore::new(Duration::from_secs(300), 100);
822 let id = store.get_or_create(None);
823
824 {
826 let mut session = store.get_mut(&id).unwrap();
827 session.flagged_tools.insert("evil_tool".to_string());
828 session.flagged_tools.insert("suspicious_tool".to_string());
829 }
830
831 let session = store.get_mut(&id).unwrap();
833 assert!(session.flagged_tools.contains("evil_tool"));
834 assert!(session.flagged_tools.contains("suspicious_tool"));
835 assert!(!session.flagged_tools.contains("safe_tool"));
836 assert_eq!(session.flagged_tools.len(), 2);
837 }
838
839 #[test]
840 fn test_flagged_tools_empty_by_default() {
841 let state = SessionState::new("test-session".to_string());
842 assert!(state.flagged_tools.is_empty());
843 assert!(state.pending_tool_calls.is_empty());
844 }
845
846 #[test]
847 fn test_oauth_subject_storage() {
848 let store = SessionStore::new(Duration::from_secs(300), 100);
849 let id = store.get_or_create(None);
850
851 {
853 let session = store.get_mut(&id).unwrap();
854 assert!(session.oauth_subject.is_none());
855 }
856
857 {
859 let mut session = store.get_mut(&id).unwrap();
860 session.oauth_subject = Some("user-42".to_string());
861 }
862
863 let session = store.get_mut(&id).unwrap();
865 assert_eq!(session.oauth_subject.as_deref(), Some("user-42"));
866 }
867
868 #[test]
869 fn test_protocol_version_tracking() {
870 let store = SessionStore::new(Duration::from_secs(300), 100);
871 let id = store.get_or_create(None);
872
873 {
874 let session = store.get_mut(&id).unwrap();
875 assert!(session.protocol_version.is_none());
876 }
877
878 {
879 let mut session = store.get_mut(&id).unwrap();
880 session.protocol_version = Some("2025-11-25".to_string());
881 }
882
883 let session = store.get_mut(&id).unwrap();
884 assert_eq!(session.protocol_version.as_deref(), Some("2025-11-25"));
885 }
886
887 #[test]
888 fn test_known_tools_mutations() {
889 let store = SessionStore::new(Duration::from_secs(300), 100);
890 let id = store.get_or_create(None);
891
892 {
893 let mut session = store.get_mut(&id).unwrap();
894 session.known_tools.insert(
895 "read_file".to_string(),
896 ToolAnnotations {
897 read_only_hint: true,
898 destructive_hint: false,
899 idempotent_hint: true,
900 open_world_hint: false,
901 input_schema_hash: None,
902 },
903 );
904 }
905
906 let session = store.get_mut(&id).unwrap();
907 assert_eq!(session.known_tools.len(), 1);
908 let ann = session.known_tools.get("read_file").unwrap();
909 assert!(ann.read_only_hint);
910 assert!(!ann.destructive_hint);
911 }
912
913 #[test]
914 fn test_tool_annotations_default() {
915 let ann = ToolAnnotations::default();
916 assert!(!ann.read_only_hint);
917 assert!(ann.destructive_hint);
918 assert!(!ann.idempotent_hint);
919 assert!(ann.open_world_hint);
920 }
921
922 #[test]
923 fn test_tool_annotations_equality() {
924 let a = ToolAnnotations {
925 read_only_hint: true,
926 destructive_hint: false,
927 idempotent_hint: true,
928 open_world_hint: false,
929 input_schema_hash: None,
930 };
931 let b = ToolAnnotations {
932 read_only_hint: true,
933 destructive_hint: false,
934 idempotent_hint: true,
935 open_world_hint: false,
936 input_schema_hash: None,
937 };
938 let c = ToolAnnotations::default();
939 assert_eq!(a, b);
940 assert_ne!(a, c);
941 }
942
943 #[test]
944 fn test_tools_list_seen_flag() {
945 let state = SessionState::new("test".to_string());
946 assert!(!state.tools_list_seen);
947 }
948
949 #[test]
952 fn test_inactivity_expiry_preserved() {
953 let state = SessionState::new("test-inactivity".to_string());
954 assert!(!state.is_expired(Duration::from_secs(300), None));
956 assert!(state.is_expired(Duration::from_nanos(0), None));
958 }
959
960 #[test]
961 fn test_absolute_lifetime_enforced() {
962 let state = SessionState::new("test-lifetime".to_string());
963 assert!(state.is_expired(Duration::from_secs(300), Some(Duration::from_nanos(0))));
965 assert!(!state.is_expired(Duration::from_secs(300), Some(Duration::from_secs(86400))));
967 }
968
969 #[test]
970 fn test_none_max_lifetime_no_absolute_limit() {
971 let state = SessionState::new("test-no-limit".to_string());
972 assert!(!state.is_expired(Duration::from_secs(300), None));
974 }
975
976 #[test]
977 fn test_eviction_checks_both_timeouts() {
978 let store = SessionStore::new(Duration::from_secs(300), 100)
980 .with_max_lifetime(Duration::from_nanos(0));
981
982 let _id = store.get_or_create(None);
983 assert_eq!(store.len(), 1);
984
985 store.evict_expired();
987 assert_eq!(store.len(), 0);
988 }
989
990 #[test]
991 fn test_with_max_lifetime_builder() {
992 let store = SessionStore::new(Duration::from_secs(300), 100)
993 .with_max_lifetime(Duration::from_secs(86400));
994 let id = store.get_or_create(None);
996 assert_eq!(store.len(), 1);
997 let id2 = store.get_or_create(Some(&id));
999 assert_eq!(id, id2);
1000 }
1001
1002 #[test]
1005 fn test_session_id_at_max_length_accepted() {
1006 let store = SessionStore::new(Duration::from_secs(300), 100);
1007 let long_id = "a".repeat(MAX_SESSION_ID_LEN);
1009 let id = store.get_or_create(Some(&long_id));
1011 assert_ne!(id, long_id); assert_eq!(store.len(), 1);
1013
1014 store
1016 .sessions
1017 .insert(long_id.clone(), SessionState::new(long_id.clone()));
1018 let reused = store.get_or_create(Some(&long_id));
1019 assert_eq!(reused, long_id);
1020 }
1021
1022 #[test]
1023 fn test_session_id_exceeding_max_length_rejected() {
1024 let store = SessionStore::new(Duration::from_secs(300), 100);
1025 let too_long = "b".repeat(MAX_SESSION_ID_LEN + 1);
1027 store
1028 .sessions
1029 .insert(too_long.clone(), SessionState::new(too_long.clone()));
1030
1031 let id = store.get_or_create(Some(&too_long));
1034 assert_ne!(id, too_long, "Oversized session ID must not be reused");
1035 assert_eq!(id.len(), 36, "Should return a UUID-format session ID");
1036 }
1037
1038 #[test]
1039 fn test_session_id_empty_string_accepted() {
1040 let store = SessionStore::new(Duration::from_secs(300), 100);
1041 let id = store.get_or_create(Some(""));
1043 assert_eq!(id.len(), 36); assert_eq!(store.len(), 1);
1045 }
1046
1047 #[test]
1048 fn test_session_id_exactly_128_chars_boundary() {
1049 let store = SessionStore::new(Duration::from_secs(300), 100);
1050 let exact = "x".repeat(128);
1051 let id = store.get_or_create(Some(&exact));
1053 assert_eq!(id.len(), 36);
1056
1057 let one_over = "x".repeat(129);
1058 let id2 = store.get_or_create(Some(&one_over));
1059 assert_eq!(id2.len(), 36);
1060 assert_eq!(store.len(), 2);
1062 }
1063
1064 #[test]
1070 fn test_stateful_context_implements_trait() {
1071 let session = SessionState::new("test-ctx".to_string());
1072 let ctx = StatefulContext::new(&session);
1073
1074 let _: &dyn RequestContext = &ctx;
1076 assert!(ctx.call_counts().is_empty());
1077 assert!(ctx.previous_actions().is_empty());
1078 assert!(ctx.call_chain().is_empty());
1079 assert!(ctx.agent_identity().is_none());
1080 assert!(ctx.session_guard_state().is_none());
1081 assert!(ctx.risk_score().is_none());
1082 }
1083
1084 #[test]
1086 fn test_stateful_context_call_counts() {
1087 let mut session = SessionState::new("test-counts".to_string());
1088 session.call_counts.insert("read_file".to_string(), 5);
1089 session.call_counts.insert("write_file".to_string(), 3);
1090
1091 let ctx = StatefulContext::new(&session);
1092 assert_eq!(ctx.call_counts().len(), 2);
1093 assert_eq!(ctx.call_counts()["read_file"], 5);
1094 assert_eq!(ctx.call_counts()["write_file"], 3);
1095 }
1096
1097 #[test]
1099 fn test_stateful_context_previous_actions() {
1100 let mut session = SessionState::new("test-actions".to_string());
1101 session.action_history.push_back("read_file".to_string());
1102 session.action_history.push_back("write_file".to_string());
1103 session.action_history.push_back("execute".to_string());
1104
1105 let ctx = StatefulContext::new(&session);
1106 let actions = ctx.previous_actions();
1107 assert_eq!(actions.len(), 3);
1108 assert_eq!(actions[0], "read_file");
1109 assert_eq!(actions[1], "write_file");
1110 assert_eq!(actions[2], "execute");
1111 }
1112
1113 #[test]
1118 fn test_discovered_tools_empty_by_default() {
1119 let state = SessionState::new("test".to_string());
1120 assert!(state.discovered_tools.is_empty());
1121 }
1122
1123 #[test]
1124 fn test_record_discovered_tools() {
1125 let mut state = SessionState::new("test".to_string());
1126 let tools = vec![
1127 "server:read_file".to_string(),
1128 "server:write_file".to_string(),
1129 ];
1130 state.record_discovered_tools(&tools, Duration::from_secs(300));
1131
1132 assert_eq!(state.discovered_tools.len(), 2);
1133 assert!(state.discovered_tools.contains_key("server:read_file"));
1134 assert!(state.discovered_tools.contains_key("server:write_file"));
1135 }
1136
1137 #[test]
1138 fn test_record_discovered_tools_sets_ttl() {
1139 let mut state = SessionState::new("test".to_string());
1140 state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));
1141
1142 let entry = state.discovered_tools.get("server:tool1").unwrap();
1143 assert_eq!(entry.ttl, Duration::from_secs(60));
1144 assert!(!entry.used);
1145 }
1146
1147 #[test]
1148 fn test_record_discovered_tools_rediscovery_resets_ttl() {
1149 let mut state = SessionState::new("test".to_string());
1150 state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(60));
1151
1152 state.mark_tool_used("server:tool1");
1154 assert!(state.discovered_tools.get("server:tool1").unwrap().used);
1155
1156 state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(120));
1158
1159 let entry = state.discovered_tools.get("server:tool1").unwrap();
1160 assert_eq!(entry.ttl, Duration::from_secs(120));
1161 assert!(!entry.used); }
1163
1164 #[test]
1165 fn test_is_tool_discovery_expired_unknown_tool() {
1166 let state = SessionState::new("test".to_string());
1167 assert_eq!(state.is_tool_discovery_expired("unknown:tool"), None);
1168 }
1169
1170 #[test]
1171 fn test_is_tool_discovery_expired_fresh_tool() {
1172 let mut state = SessionState::new("test".to_string());
1173 state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
1174 assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(false));
1175 }
1176
1177 #[test]
1178 fn test_is_tool_discovery_expired_zero_ttl() {
1179 let mut state = SessionState::new("test".to_string());
1180 state.discovered_tools.insert(
1182 "server:tool1".to_string(),
1183 DiscoveredToolSession {
1184 tool_id: "server:tool1".to_string(),
1185 discovered_at: Instant::now() - Duration::from_secs(1),
1186 ttl: Duration::from_nanos(0),
1187 used: false,
1188 },
1189 );
1190 assert_eq!(state.is_tool_discovery_expired("server:tool1"), Some(true));
1191 }
1192
1193 #[test]
1194 fn test_mark_tool_used_existing() {
1195 let mut state = SessionState::new("test".to_string());
1196 state.record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
1197 assert!(!state.discovered_tools.get("server:tool1").unwrap().used);
1198
1199 assert!(state.mark_tool_used("server:tool1"));
1200 assert!(state.discovered_tools.get("server:tool1").unwrap().used);
1201 }
1202
1203 #[test]
1204 fn test_mark_tool_used_nonexistent() {
1205 let mut state = SessionState::new("test".to_string());
1206 assert!(!state.mark_tool_used("unknown:tool"));
1207 }
1208
1209 #[test]
1210 fn test_evict_expired_discoveries_none_expired() {
1211 let mut state = SessionState::new("test".to_string());
1212 state.record_discovered_tools(
1213 &["server:tool1".to_string(), "server:tool2".to_string()],
1214 Duration::from_secs(300),
1215 );
1216 assert_eq!(state.evict_expired_discoveries(), 0);
1217 assert_eq!(state.discovered_tools.len(), 2);
1218 }
1219
1220 #[test]
1221 fn test_evict_expired_discoveries_some_expired() {
1222 let mut state = SessionState::new("test".to_string());
1223
1224 state.record_discovered_tools(&["server:fresh".to_string()], Duration::from_secs(300));
1226
1227 state.discovered_tools.insert(
1229 "server:stale".to_string(),
1230 DiscoveredToolSession {
1231 tool_id: "server:stale".to_string(),
1232 discovered_at: Instant::now() - Duration::from_secs(10),
1233 ttl: Duration::from_secs(1),
1234 used: true,
1235 },
1236 );
1237
1238 assert_eq!(state.evict_expired_discoveries(), 1);
1239 assert_eq!(state.discovered_tools.len(), 1);
1240 assert!(state.discovered_tools.contains_key("server:fresh"));
1241 assert!(!state.discovered_tools.contains_key("server:stale"));
1242 }
1243
1244 #[test]
1245 fn test_evict_expired_discoveries_all_expired() {
1246 let mut state = SessionState::new("test".to_string());
1247 let past = Instant::now() - Duration::from_secs(10);
1248 for i in 0..5 {
1249 state.discovered_tools.insert(
1250 format!("server:tool{i}"),
1251 DiscoveredToolSession {
1252 tool_id: format!("server:tool{i}"),
1253 discovered_at: past,
1254 ttl: Duration::from_secs(1),
1255 used: false,
1256 },
1257 );
1258 }
1259
1260 assert_eq!(state.evict_expired_discoveries(), 5);
1261 assert!(state.discovered_tools.is_empty());
1262 }
1263
1264 #[test]
1265 fn test_discovered_tool_session_is_expired() {
1266 let fresh = DiscoveredToolSession {
1267 tool_id: "t".to_string(),
1268 discovered_at: Instant::now(),
1269 ttl: Duration::from_secs(300),
1270 used: false,
1271 };
1272 assert!(!fresh.is_expired());
1273
1274 let stale = DiscoveredToolSession {
1275 tool_id: "t".to_string(),
1276 discovered_at: Instant::now() - Duration::from_secs(10),
1277 ttl: Duration::from_secs(1),
1278 used: false,
1279 };
1280 assert!(stale.is_expired());
1281 }
1282
1283 #[test]
1284 fn test_discovered_tools_survive_session_touch() {
1285 let store = SessionStore::new(Duration::from_secs(300), 100);
1286 let id = store.get_or_create(None);
1287
1288 {
1290 let mut session = store.get_mut(&id).unwrap();
1291 session
1292 .record_discovered_tools(&["server:tool1".to_string()], Duration::from_secs(300));
1293 }
1294
1295 store.get_or_create(Some(&id));
1297
1298 let session = store.get_mut(&id).unwrap();
1300 assert_eq!(session.discovered_tools.len(), 1);
1301 assert!(session.discovered_tools.contains_key("server:tool1"));
1302 }
1303
1304 #[test]
1305 fn test_multiple_tools_independent_ttl() {
1306 let mut state = SessionState::new("test".to_string());
1307
1308 state.discovered_tools.insert(
1310 "server:short".to_string(),
1311 DiscoveredToolSession {
1312 tool_id: "server:short".to_string(),
1313 discovered_at: Instant::now() - Duration::from_secs(5),
1314 ttl: Duration::from_secs(1),
1315 used: false,
1316 },
1317 );
1318
1319 state.record_discovered_tools(&["server:long".to_string()], Duration::from_secs(3600));
1321
1322 assert_eq!(state.is_tool_discovery_expired("server:short"), Some(true));
1323 assert_eq!(state.is_tool_discovery_expired("server:long"), Some(false));
1324 }
1325
#[test]
fn test_evaluation_context_from_stateful() {
    // Populate a session with an agent identity, one history entry, a call
    // count and an OAuth subject, then check each survives conversion into the
    // evaluation context.
    let mut state = SessionState::new("test-eval".to_string());
    state.agent_identity = Some(AgentIdentity {
        issuer: Some("test-issuer".to_string()),
        subject: Some("agent-sub".to_string()),
        ..Default::default()
    });
    state.action_history.push_back("tool_a".to_string());
    state.call_counts.insert("tool_a".to_string(), 7);
    state.oauth_subject = Some("user-42".to_string());

    let stateful = StatefulContext::new(&state);
    let eval = stateful.to_evaluation_context();

    // The OAuth subject is surfaced as the agent id.
    assert_eq!(eval.agent_id.as_deref(), Some("user-42"));
    assert_eq!(eval.call_counts["tool_a"], 7);
    assert_eq!(eval.previous_actions, vec!["tool_a".to_string()]);

    let identity = eval.agent_identity.as_ref().unwrap();
    assert_eq!(identity.issuer.as_deref(), Some("test-issuer"));
}
1350
#[test]
fn test_global_flagged_tool_basic() {
    // Flagging one tool globally marks exactly that tool and nothing else.
    let store = SessionStore::new(Duration::from_secs(300), 100);
    let (flagged, unrelated) = ("evil_tool", "safe_tool");

    // A fresh store has nothing flagged.
    assert_eq!(store.global_flagged_tools_len(), 0);
    assert!(!store.is_tool_globally_flagged(flagged));

    store.flag_tool_globally(flagged.to_string());

    assert_eq!(store.global_flagged_tools_len(), 1);
    assert!(store.is_tool_globally_flagged(flagged));
    assert!(!store.is_tool_globally_flagged(unrelated));
}
1366
#[test]
fn test_global_flagged_tool_survives_session_eviction() {
    // A per-session flag goes away with its session; the global flag must not.
    // The store is capped at two sessions so an eviction is easy to force.
    let store = SessionStore::new(Duration::from_secs(300), 2);
    let victim = store.get_or_create(None);

    // Flag the tool both on the session and globally.
    if let Some(mut session) = store.get_mut(&victim) {
        session.insert_flagged_tool("rug_pulled_tool".to_string());
    }
    store.flag_tool_globally("rug_pulled_tool".to_string());

    let flagged_in_session = match store.get_mut(&victim) {
        Some(session) => session.flagged_tools.contains("rug_pulled_tool"),
        None => false,
    };
    assert!(flagged_in_session);

    // Two more sessions push the store past its cap and evict `victim`.
    for _ in 0..2 {
        store.get_or_create(None);
    }
    assert!(
        store.get_mut(&victim).is_none(),
        "session should have been evicted"
    );

    // The global flag outlives the evicted session.
    assert!(store.is_tool_globally_flagged("rug_pulled_tool"));
}
1399
#[test]
fn test_global_flagged_tool_expiry() {
    let store = SessionStore::new(Duration::from_secs(300), 100);

    // Backdate the entry one hour past the TTL rather than hard-coding "25h", so
    // the test keeps exercising expiry if GLOBAL_FLAGGED_TOOL_TTL changes.
    // `Instant - Duration` panics when the result would predate the platform's
    // monotonic clock origin (e.g. low-uptime Windows CI hosts), so use
    // `checked_sub` and skip instead of failing spuriously.
    let Some(flagged_at) =
        Instant::now().checked_sub(GLOBAL_FLAGGED_TOOL_TTL + Duration::from_secs(60 * 60))
    else {
        eprintln!(
            "skipping test_global_flagged_tool_expiry: \
             monotonic clock too young to backdate"
        );
        return;
    };

    store.global_flagged_tools.insert(
        "expired_tool".to_string(),
        GlobalFlaggedToolEntry {
            flagged_at,
            ttl: GLOBAL_FLAGGED_TOOL_TTL,
        },
    );

    // An expired entry must not count as flagged, even before it is evicted.
    assert!(!store.is_tool_globally_flagged("expired_tool"));

    // Explicit eviction removes exactly that one stale entry.
    let evicted = store.evict_expired_global_flags();
    assert_eq!(evicted, 1);
    assert_eq!(store.global_flagged_tools_len(), 0);
}
1421
#[test]
fn test_global_flagged_tool_capacity_bound() {
    // Once the global table is full of live entries, further flags are dropped
    // instead of growing the table.
    let store = SessionStore::new(Duration::from_secs(300), 100);

    (0..MAX_GLOBAL_FLAGGED_TOOLS).for_each(|n| {
        store.flag_tool_globally(format!("tool_{n}"));
    });
    assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);

    let overflow = "overflow_tool";
    store.flag_tool_globally(overflow.to_string());
    assert!(!store.is_tool_globally_flagged(overflow));
    assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);
}
1437
#[test]
fn test_global_flagged_tool_capacity_evicts_expired_first() {
    let store = SessionStore::new(Duration::from_secs(300), 100);

    // Compute the stale timestamp once: one hour past the TTL, tied to the
    // constant rather than a hard-coded "25h". `checked_sub` avoids the
    // `Instant - Duration` panic on hosts whose monotonic clock origin is too
    // recent (e.g. freshly booted Windows CI runners).
    let Some(stale_at) =
        Instant::now().checked_sub(GLOBAL_FLAGGED_TOOL_TTL + Duration::from_secs(60 * 60))
    else {
        eprintln!(
            "skipping test_global_flagged_tool_capacity_evicts_expired_first: \
             monotonic clock too young to backdate"
        );
        return;
    };

    // Fill the table to capacity with entries that have all expired.
    for i in 0..MAX_GLOBAL_FLAGGED_TOOLS {
        store.global_flagged_tools.insert(
            format!("old_tool_{i}"),
            GlobalFlaggedToolEntry {
                flagged_at: stale_at,
                ttl: GLOBAL_FLAGGED_TOOL_TTL,
            },
        );
    }
    assert_eq!(store.global_flagged_tools_len(), MAX_GLOBAL_FLAGGED_TOOLS);

    // A table full of expired entries must make room for a fresh flag instead of
    // rejecting it, without exceeding the capacity bound.
    store.flag_tool_globally("fresh_tool".to_string());
    assert!(store.is_tool_globally_flagged("fresh_tool"));
    assert!(store.global_flagged_tools_len() <= MAX_GLOBAL_FLAGGED_TOOLS);
}
1458
#[test]
fn test_global_flagged_tool_no_ttl_reset_on_reflag() {
    let store = SessionStore::new(Duration::from_secs(300), 100);

    // Backdate the original flag by an hour so that a reset would be
    // observable. `checked_sub` avoids the `Instant - Duration` panic when the
    // monotonic clock origin is less than an hour ago (e.g. low-uptime CI hosts).
    let Some(old_time) = Instant::now().checked_sub(Duration::from_secs(60 * 60)) else {
        eprintln!(
            "skipping test_global_flagged_tool_no_ttl_reset_on_reflag: \
             monotonic clock too young to backdate"
        );
        return;
    };
    store.global_flagged_tools.insert(
        "tool_a".to_string(),
        GlobalFlaggedToolEntry {
            flagged_at: old_time,
            ttl: GLOBAL_FLAGGED_TOOL_TTL,
        },
    );

    // Re-flagging an already-flagged tool must keep the original timestamp, so
    // repeated flagging cannot keep extending the entry's lifetime.
    store.flag_tool_globally("tool_a".to_string());
    let entry = store
        .global_flagged_tools
        .get("tool_a")
        .expect("re-flagged tool must still be present");
    assert_eq!(entry.flagged_at, old_time);
}
1478
#[test]
fn test_global_flagged_tool_unwrap_or_else_fallback() {
    let tool = "globally_flagged";
    let store = SessionStore::new(Duration::from_secs(300), 100);
    store.flag_tool_globally(tool.to_string());

    // Mirrors the `unwrap_or_else` fallback this test is named after: consult
    // the session's own flags, and fall back to the global table when the
    // session lookup misses.
    let session_lookup = store
        .get_mut("nonexistent-session")
        .map(|session| session.flagged_tools.contains(tool));
    let is_flagged = session_lookup.unwrap_or_else(|| store.is_tool_globally_flagged(tool));

    assert!(is_flagged, "global fallback should catch flagged tool");
}
1494
#[test]
fn test_r253_get_or_create_rejects_control_chars_in_session_id() {
    let store = SessionStore::new(Duration::from_secs(300), 100);

    // Sanity check: a well-formed, store-issued id is reused as-is.
    let id = store.get_or_create(None);
    let reused = store.get_or_create(Some(&id));
    assert_eq!(id, reused);

    // An id containing an ASCII control character (NUL) must not be adopted.
    // Check the property itself, not just inequality with the rejected input:
    // the issued id must be free of control characters.
    let new_id = store.get_or_create(Some("session\x00id"));
    assert_ne!(new_id, "session\x00id");
    assert!(
        !new_id.chars().any(char::is_control),
        "issued session id must not contain control chars: {new_id:?}"
    );

    // Same for an invisible format character (U+200B ZERO WIDTH SPACE, category
    // Cf), which `char::is_control` does not cover, so check for it explicitly.
    let new_id2 = store.get_or_create(Some("session\u{200B}id"));
    assert_ne!(new_id2, "session\u{200B}id");
    assert!(
        !new_id2.contains('\u{200B}'),
        "issued session id must not contain U+200B: {new_id2:?}"
    );
}
1515}