1use std::collections::HashMap;
37use std::fmt;
38use std::sync::{Arc, RwLock};
39use std::time::{Duration, Instant, SystemTime};
40
/// Identifier of an agent session.
///
/// Used as the key of the `SessionManager` session map and interpolated into the
/// default working directory (`/agents/<id>`) by `AgentContext::new`.
pub type SessionId = String;
43
/// Per-session state for an agent: variables, permissions, audit trail, the active
/// transaction, budget counters and tool bookkeeping.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Identifier of the owning session.
    pub session_id: SessionId,
    /// Base directory that relative paths are resolved against by `resolve_path`.
    /// `AgentContext::new` sets it to `/agents/<session_id>`.
    pub working_dir: String,
    /// Named variables written by `set_var` and substituted as `$name` by `substitute_vars`.
    pub variables: HashMap<String, ContextValue>,
    /// Capability flags consulted by `check_fs_permission` / `check_db_permission`.
    pub permissions: AgentPermissions,
    /// Wall-clock creation time; `age()` is measured from here.
    pub started_at: SystemTime,
    /// Monotonic time of the last `set_var` / `get_var`; drives `idle_time()` and `is_expired()`.
    pub last_activity: Instant,
    /// Log of audited operations (variable access and transaction events), appended by
    /// the private `audit` helper.
    pub audit: Vec<AuditEntry>,
    /// The single active transaction, if any; `begin_transaction` rejects nesting.
    pub transaction: Option<TransactionScope>,
    /// Token/cost/operation limits and the usage counters updated by `consume_budget`.
    pub budget: OperationBudget,
    /// Tools added via `register_tool`, in registration order.
    pub tool_registry: Vec<ToolDefinition>,
    /// History of tool invocations appended via `record_tool_call`.
    pub tool_calls: Vec<ToolCallRecord>,
}
70
/// Description of a tool an agent may invoke.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Tool name; presumably matched against `ToolCallRecord::tool_name` — TODO confirm.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Optional parameter schema in string form. The format is not specified in this
    /// file (possibly JSON Schema — confirm with producers).
    pub parameters_schema: Option<String>,
    /// Whether invoking the tool should require confirmation.
    /// NOTE(review): this flag is not enforced anywhere in this file.
    pub requires_confirmation: bool,
}
83
/// Record of a single tool invocation, stored by `AgentContext::record_tool_call`.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
    /// Identifier of this call, supplied by whoever builds the record.
    pub call_id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Arguments passed to the tool in serialized string form (encoding not specified here).
    pub arguments: String,
    /// Tool output, if the call produced one.
    pub result: Option<String>,
    /// Error message, if the call failed.
    pub error: Option<String>,
    /// When the call happened / was recorded.
    pub timestamp: SystemTime,
}
100
/// A dynamically-typed value stored in an agent's variable table.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextValue {
    /// Text; rendered wrapped in double quotes by `Display`.
    String(String),
    /// Numeric value, stored as `f64`.
    Number(f64),
    /// Boolean value.
    Bool(bool),
    /// Ordered list of values.
    List(Vec<ContextValue>),
    /// String-keyed map of values.
    Object(HashMap<String, ContextValue>),
    /// Absence of a value; rendered as `null`.
    Null,
}

impl fmt::Display for ContextValue {
    /// Renders the value in a JSON-like notation: strings in double quotes (not escaped),
    /// lists as `[a, b]`, objects as `{"k": v}`, and `Null` as `null`.
    ///
    /// Object entries are emitted in ascending key order so the same object always
    /// renders identically (a `HashMap` iterates in an unspecified order).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ContextValue::String(s) => write!(f, "\"{}\"", s),
            ContextValue::Number(n) => write!(f, "{}", n),
            ContextValue::Bool(b) => write!(f, "{}", b),
            ContextValue::List(items) => {
                f.write_str("[")?;
                for (i, item) in items.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", item)?;
                }
                f.write_str("]")
            }
            ContextValue::Object(map) => {
                // Keys are unique, so an unstable sort is fully deterministic here.
                let mut entries: Vec<(&String, &ContextValue)> = map.iter().collect();
                entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
                f.write_str("{")?;
                for (i, (key, value)) in entries.into_iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "\"{}\": {}", key, value)?;
                }
                f.write_str("}")
            }
            ContextValue::Null => f.write_str("null"),
        }
    }
}
142
/// Capability set for an agent.
///
/// The derived `Default` turns every flag off and leaves every allow-list empty.
#[derive(Debug, Clone, Default)]
pub struct AgentPermissions {
    /// Filesystem capabilities, checked by `AgentContext::check_fs_permission`.
    pub filesystem: FsPermissions,
    /// Database capabilities, checked by `AgentContext::check_db_permission`.
    pub database: DbPermissions,
    /// Whether calculator operations are permitted (not checked anywhere in this file).
    pub calculator: bool,
    /// Network capabilities (not checked anywhere in this file).
    pub network: NetworkPermissions,
}
155
/// Filesystem capability flags plus a path allow-list.
#[derive(Debug, Clone, Default)]
pub struct FsPermissions {
    /// Permits `AuditOperation::FsRead` and `AuditOperation::FsList`.
    pub read: bool,
    /// Permits `AuditOperation::FsWrite`.
    pub write: bool,
    /// Permits `AuditOperation::FsMkdir`.
    pub mkdir: bool,
    /// Permits `AuditOperation::FsDelete`.
    pub delete: bool,
    /// Path prefixes that operations are confined to. `"*"` matches any path, and an
    /// empty list disables the path check entirely.
    pub allowed_paths: Vec<String>,
}
170
/// Database capability flags plus a table allow-list.
#[derive(Debug, Clone, Default)]
pub struct DbPermissions {
    /// Permits `AuditOperation::DbQuery`.
    pub read: bool,
    /// Permits `AuditOperation::DbInsert` and `AuditOperation::DbUpdate`.
    pub write: bool,
    /// Table-creation flag (not checked anywhere in this file).
    pub create: bool,
    /// Permits `AuditOperation::DbDelete`. Row deletion is gated on this flag, not on `write`.
    pub drop: bool,
    /// Allowed tables: exact names, `"*"` for any table, or a trailing-`*` prefix pattern
    /// such as `"logs_*"`. An empty list disables the table check.
    pub allowed_tables: Vec<String>,
}
185
/// Network capability flag plus a domain allow-list.
///
/// NOTE(review): neither field is consulted anywhere in this file.
#[derive(Debug, Clone, Default)]
pub struct NetworkPermissions {
    /// Whether HTTP access is permitted.
    pub http: bool,
    /// Domains that HTTP access is presumably restricted to; confirm the matching rules
    /// at the enforcement site.
    pub allowed_domains: Vec<String>,
}
194
/// One entry in an agent's audit log.
#[derive(Debug, Clone)]
pub struct AuditEntry {
    /// Wall-clock time the entry was recorded.
    pub timestamp: SystemTime,
    /// Kind of operation that was audited.
    pub operation: AuditOperation,
    /// Resource the operation targeted (e.g. a variable name or `tx:<id>`).
    pub resource: String,
    /// Outcome of the operation.
    pub result: AuditResult,
    /// Extra key/value pairs. `AgentContext`'s own audit helper always leaves this empty.
    pub metadata: HashMap<String, String>,
}
209
/// Kind of operation recorded in the audit log and used to select a permission check.
///
/// Fieldless, so it is `Copy` and `Hash`: it can be passed by value freely and used as a
/// set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuditOperation {
    // Filesystem operations.
    FsRead,
    FsWrite,
    FsMkdir,
    FsDelete,
    FsList,
    // Database operations.
    DbQuery,
    DbInsert,
    DbUpdate,
    DbDelete,
    // Calculator.
    Calculate,
    // Context-variable access.
    VarSet,
    VarGet,
    // Transaction lifecycle.
    TxBegin,
    TxCommit,
    TxRollback,
}

impl fmt::Display for AuditOperation {
    /// Writes the dotted operation label (e.g. `fs.read`, `tx.commit`); `Calculate` is `calc`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::FsRead => "fs.read",
            Self::FsWrite => "fs.write",
            Self::FsMkdir => "fs.mkdir",
            Self::FsDelete => "fs.delete",
            Self::FsList => "fs.list",
            Self::DbQuery => "db.query",
            Self::DbInsert => "db.insert",
            Self::DbUpdate => "db.update",
            Self::DbDelete => "db.delete",
            Self::Calculate => "calc",
            Self::VarSet => "var.set",
            Self::VarGet => "var.get",
            Self::TxBegin => "tx.begin",
            Self::TxCommit => "tx.commit",
            Self::TxRollback => "tx.rollback",
        };
        f.write_str(label)
    }
}
251
/// Outcome recorded for an audited operation.
///
/// Derives `PartialEq`/`Eq` so outcomes can be compared directly (e.g. in tests or filters).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuditResult {
    /// The operation completed.
    Success,
    /// The operation failed; carries a short reason.
    Error(String),
    /// The operation was refused; carries the reason.
    Denied(String),
}
259
/// State of the active transaction on an `AgentContext`.
#[derive(Debug, Clone)]
pub struct TransactionScope {
    /// Caller-supplied transaction id; appears in audit resources as `tx:<id>`.
    pub tx_id: u64,
    /// Monotonic time the transaction began.
    pub started_at: Instant,
    /// Savepoint names in creation order. They are only recorded; nothing in this file
    /// rolls back to a savepoint.
    pub savepoints: Vec<String>,
    /// Writes recorded during the transaction; returned by `rollback_transaction` and
    /// discarded by `commit_transaction`.
    pub pending_writes: Vec<PendingWrite>,
}
272
/// A write made inside a transaction, handed back by `rollback_transaction` so the
/// caller can undo it.
#[derive(Debug, Clone)]
pub struct PendingWrite {
    /// Kind of resource that was written.
    pub resource_type: ResourceType,
    /// Identifier of the resource (e.g. a file path).
    pub resource_key: String,
    /// Bytes the resource held before the write. Presumably `None` means it did not
    /// exist yet — confirm with callers.
    pub original_value: Option<Vec<u8>>,
}
283
/// Kind of resource touched by a pending transactional write.
///
/// Fieldless, so it is `Copy` and `Hash` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResourceType {
    /// A regular file.
    File,
    /// A directory.
    Directory,
    /// A database table.
    Table,
    /// A context variable.
    Variable,
}
292
/// Usage limits and running counters for an agent session.
///
/// Each `max_*` limit is optional: `None` means unlimited. The matching `*_used`
/// counter accumulates consumption.
#[derive(Debug, Clone)]
pub struct OperationBudget {
    /// Upper bound on tokens, if any.
    pub max_tokens: Option<u64>,
    /// Tokens consumed so far.
    pub tokens_used: u64,
    /// Upper bound on cost (unit defined by callers), if any.
    pub max_cost: Option<u64>,
    /// Cost consumed so far.
    pub cost_used: u64,
    /// Upper bound on the number of operations, if any.
    pub max_operations: Option<u64>,
    /// Operations performed so far.
    pub operations_used: u64,
}

impl Default for OperationBudget {
    /// Unlimited tokens and cost, at most 10 000 operations, all counters at zero.
    fn default() -> Self {
        const DEFAULT_MAX_OPERATIONS: u64 = 10_000;
        Self {
            max_operations: Some(DEFAULT_MAX_OPERATIONS),
            operations_used: 0,
            max_tokens: None,
            tokens_used: 0,
            max_cost: None,
            cost_used: 0,
        }
    }
}
322
/// Errors produced by `AgentContext` operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can match on exact errors with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    /// A capability or allow-list check failed; carries the reason.
    PermissionDenied(String),
    /// A named variable does not exist.
    VariableNotFound(String),
    /// A budget limit was exceeded; carries which limit.
    BudgetExceeded(String),
    /// A transaction operation was invalid in the current state.
    TransactionError(String),
    /// A path was malformed or otherwise unusable.
    InvalidPath(String),
    /// The session is no longer valid.
    SessionExpired,
}

impl fmt::Display for ContextError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::PermissionDenied(reason) => write!(f, "Permission denied: {}", reason),
            Self::VariableNotFound(var) => write!(f, "Variable not found: {}", var),
            Self::BudgetExceeded(limit) => write!(f, "Budget exceeded: {}", limit),
            Self::TransactionError(reason) => write!(f, "Transaction error: {}", reason),
            Self::InvalidPath(bad) => write!(f, "Invalid path: {}", bad),
            Self::SessionExpired => f.write_str("Session expired"),
        }
    }
}

impl std::error::Error for ContextError {}
348
349impl AgentContext {
350 pub fn new(session_id: SessionId) -> Self {
352 let now = Instant::now();
353 Self {
354 session_id: session_id.clone(),
355 working_dir: format!("/agents/{}", session_id),
356 variables: HashMap::new(),
357 permissions: AgentPermissions::default(),
358 started_at: SystemTime::now(),
359 last_activity: now,
360 audit: Vec::new(),
361 transaction: None,
362 budget: OperationBudget::default(),
363 tool_registry: Vec::new(),
364 tool_calls: Vec::new(),
365 }
366 }
367
368 pub fn with_working_dir(session_id: SessionId, working_dir: String) -> Self {
370 let mut ctx = Self::new(session_id);
371 ctx.working_dir = working_dir;
372 ctx
373 }
374
375 pub fn with_full_permissions(session_id: SessionId) -> Self {
377 let mut ctx = Self::new(session_id);
378 ctx.permissions = AgentPermissions {
379 filesystem: FsPermissions {
380 read: true,
381 write: true,
382 mkdir: true,
383 delete: true,
384 allowed_paths: vec!["/".into()],
385 },
386 database: DbPermissions {
387 read: true,
388 write: true,
389 create: true,
390 drop: true,
391 allowed_tables: vec!["*".into()],
392 },
393 calculator: true,
394 network: NetworkPermissions::default(),
395 };
396 ctx
397 }
398
399 pub fn register_tool(&mut self, tool: ToolDefinition) {
401 self.tool_registry.push(tool);
402 }
403
404 pub fn record_tool_call(&mut self, call: ToolCallRecord) {
406 self.tool_calls.push(call);
407 }
408
409 pub fn set_var(&mut self, name: &str, value: ContextValue) {
411 self.variables.insert(name.to_string(), value.clone());
412 self.touch();
413 self.audit(AuditOperation::VarSet, name, AuditResult::Success);
414 }
415
416 pub fn get_var(&mut self, name: &str) -> Option<ContextValue> {
418 self.touch();
419 let result = self.variables.get(name).cloned();
420 if result.is_some() {
421 self.audit(AuditOperation::VarGet, name, AuditResult::Success);
422 } else {
423 self.audit(
424 AuditOperation::VarGet,
425 name,
426 AuditResult::Error("not found".into()),
427 );
428 }
429 result
430 }
431
432 pub fn peek_var(&self, name: &str) -> Option<&ContextValue> {
434 self.variables.get(name)
435 }
436
437 fn touch(&mut self) {
439 self.last_activity = Instant::now();
440 }
441
442 fn audit(&mut self, operation: AuditOperation, resource: &str, result: AuditResult) {
444 self.audit.push(AuditEntry {
445 timestamp: SystemTime::now(),
446 operation,
447 resource: resource.to_string(),
448 result,
449 metadata: HashMap::new(),
450 });
451 }
452
453 pub fn check_fs_permission(&self, path: &str, op: AuditOperation) -> Result<(), ContextError> {
455 let perm = match op {
456 AuditOperation::FsRead | AuditOperation::FsList => self.permissions.filesystem.read,
457 AuditOperation::FsWrite => self.permissions.filesystem.write,
458 AuditOperation::FsMkdir => self.permissions.filesystem.mkdir,
459 AuditOperation::FsDelete => self.permissions.filesystem.delete,
460 _ => {
461 return Err(ContextError::PermissionDenied(
462 "invalid fs operation".into(),
463 ));
464 }
465 };
466
467 if !perm {
468 return Err(ContextError::PermissionDenied(format!(
469 "{} not allowed",
470 op
471 )));
472 }
473
474 if !self.permissions.filesystem.allowed_paths.is_empty() {
476 let allowed = self
477 .permissions
478 .filesystem
479 .allowed_paths
480 .iter()
481 .any(|p| path.starts_with(p) || p == "*");
482 if !allowed {
483 return Err(ContextError::PermissionDenied(format!(
484 "path {} not in allowed paths",
485 path
486 )));
487 }
488 }
489
490 Ok(())
491 }
492
493 pub fn check_db_permission(&self, table: &str, op: AuditOperation) -> Result<(), ContextError> {
495 let perm = match op {
496 AuditOperation::DbQuery => self.permissions.database.read,
497 AuditOperation::DbInsert | AuditOperation::DbUpdate => self.permissions.database.write,
498 AuditOperation::DbDelete => self.permissions.database.drop,
499 _ => {
500 return Err(ContextError::PermissionDenied(
501 "invalid db operation".into(),
502 ));
503 }
504 };
505
506 if !perm {
507 return Err(ContextError::PermissionDenied(format!(
508 "{} not allowed",
509 op
510 )));
511 }
512
513 if !self.permissions.database.allowed_tables.is_empty() {
515 let allowed = self.permissions.database.allowed_tables.iter().any(|t| {
516 t == "*" || t == table || (t.ends_with('*') && table.starts_with(&t[..t.len() - 1]))
517 });
518 if !allowed {
519 return Err(ContextError::PermissionDenied(format!(
520 "table {} not in allowed tables",
521 table
522 )));
523 }
524 }
525
526 Ok(())
527 }
528
529 pub fn consume_budget(&mut self, tokens: u64, cost: u64) -> Result<(), ContextError> {
531 self.budget.operations_used += 1;
532 self.budget.tokens_used += tokens;
533 self.budget.cost_used += cost;
534
535 if let Some(max) = self.budget.max_operations
536 && self.budget.operations_used > max
537 {
538 return Err(ContextError::BudgetExceeded("max operations".into()));
539 }
540 if let Some(max) = self.budget.max_tokens
541 && self.budget.tokens_used > max
542 {
543 return Err(ContextError::BudgetExceeded("max tokens".into()));
544 }
545 if let Some(max) = self.budget.max_cost
546 && self.budget.cost_used > max
547 {
548 return Err(ContextError::BudgetExceeded("max cost".into()));
549 }
550
551 Ok(())
552 }
553
554 pub fn begin_transaction(&mut self, tx_id: u64) -> Result<(), ContextError> {
556 if self.transaction.is_some() {
557 return Err(ContextError::TransactionError(
558 "already in transaction".into(),
559 ));
560 }
561
562 self.transaction = Some(TransactionScope {
563 tx_id,
564 started_at: Instant::now(),
565 savepoints: Vec::new(),
566 pending_writes: Vec::new(),
567 });
568
569 self.audit(
570 AuditOperation::TxBegin,
571 &format!("tx:{}", tx_id),
572 AuditResult::Success,
573 );
574 Ok(())
575 }
576
577 pub fn commit_transaction(&mut self) -> Result<(), ContextError> {
579 let tx = self
580 .transaction
581 .take()
582 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
583
584 self.audit(
585 AuditOperation::TxCommit,
586 &format!("tx:{}", tx.tx_id),
587 AuditResult::Success,
588 );
589 Ok(())
590 }
591
592 pub fn rollback_transaction(&mut self) -> Result<Vec<PendingWrite>, ContextError> {
594 let tx = self
595 .transaction
596 .take()
597 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
598
599 self.audit(
600 AuditOperation::TxRollback,
601 &format!("tx:{}", tx.tx_id),
602 AuditResult::Success,
603 );
604
605 Ok(tx.pending_writes)
606 }
607
608 pub fn savepoint(&mut self, name: &str) -> Result<(), ContextError> {
610 let tx = self
611 .transaction
612 .as_mut()
613 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
614
615 tx.savepoints.push(name.to_string());
616 Ok(())
617 }
618
619 pub fn record_pending_write(
621 &mut self,
622 resource_type: ResourceType,
623 resource_key: String,
624 original_value: Option<Vec<u8>>,
625 ) -> Result<(), ContextError> {
626 let tx = self
627 .transaction
628 .as_mut()
629 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
630
631 tx.pending_writes.push(PendingWrite {
632 resource_type,
633 resource_key,
634 original_value,
635 });
636 Ok(())
637 }
638
639 pub fn resolve_path(&self, path: &str) -> String {
641 if path.starts_with('/') {
642 path.to_string()
643 } else {
644 format!("{}/{}", self.working_dir, path)
645 }
646 }
647
648 pub fn substitute_vars(&self, input: &str) -> String {
650 let mut result = input.to_string();
651
652 for (name, value) in &self.variables {
653 let pattern = format!("${}", name);
654 let replacement = match value {
655 ContextValue::String(s) => s.clone(),
656 ContextValue::Number(n) => n.to_string(),
657 ContextValue::Bool(b) => b.to_string(),
658 _ => value.to_string(),
659 };
660 result = result.replace(&pattern, &replacement);
661 }
662
663 result
664 }
665
666 pub fn age(&self) -> Duration {
668 SystemTime::now()
669 .duration_since(self.started_at)
670 .unwrap_or_default()
671 }
672
673 pub fn idle_time(&self) -> Duration {
675 self.last_activity.elapsed()
676 }
677
678 pub fn is_expired(&self, idle_timeout: Duration) -> bool {
680 self.idle_time() > idle_timeout
681 }
682
683 pub fn export_audit(&self) -> Vec<HashMap<String, String>> {
685 self.audit
686 .iter()
687 .map(|entry| {
688 let mut m = HashMap::new();
689 m.insert(
690 "timestamp".into(),
691 entry
692 .timestamp
693 .duration_since(SystemTime::UNIX_EPOCH)
694 .map(|d| d.as_secs().to_string())
695 .unwrap_or_default(),
696 );
697 m.insert("operation".into(), entry.operation.to_string());
698 m.insert("resource".into(), entry.resource.clone());
699 m.insert(
700 "result".into(),
701 match &entry.result {
702 AuditResult::Success => "success".into(),
703 AuditResult::Error(e) => format!("error:{}", e),
704 AuditResult::Denied(r) => format!("denied:{}", r),
705 },
706 );
707 for (k, v) in &entry.metadata {
708 m.insert(k.clone(), v.clone());
709 }
710 m
711 })
712 .collect()
713 }
714}
715
716pub struct SessionManager {
718 sessions: RwLock<HashMap<SessionId, Arc<RwLock<AgentContext>>>>,
719 idle_timeout: Duration,
720}
721
722impl SessionManager {
723 pub fn new(idle_timeout: Duration) -> Self {
725 Self {
726 sessions: RwLock::new(HashMap::new()),
727 idle_timeout,
728 }
729 }
730
731 pub fn create_session(&self, session_id: SessionId) -> Arc<RwLock<AgentContext>> {
733 let ctx = Arc::new(RwLock::new(AgentContext::new(session_id.clone())));
734 self.sessions
735 .write()
736 .unwrap()
737 .insert(session_id, ctx.clone());
738 ctx
739 }
740
741 pub fn get_session(&self, session_id: &str) -> Option<Arc<RwLock<AgentContext>>> {
743 let sessions = self.sessions.read().unwrap();
744 sessions.get(session_id).cloned()
745 }
746
747 pub fn get_or_create(&self, session_id: SessionId) -> Arc<RwLock<AgentContext>> {
749 if let Some(ctx) = self.get_session(&session_id) {
750 return ctx;
751 }
752 self.create_session(session_id)
753 }
754
755 pub fn remove_session(&self, session_id: &str) -> Option<Arc<RwLock<AgentContext>>> {
757 self.sessions.write().unwrap().remove(session_id)
758 }
759
760 pub fn cleanup_expired(&self) -> usize {
762 let mut sessions = self.sessions.write().unwrap();
763 let initial_count = sessions.len();
764
765 sessions.retain(|_, ctx| {
766 let ctx = ctx.read().unwrap();
767 !ctx.is_expired(self.idle_timeout)
768 });
769
770 initial_count - sessions.len()
771 }
772
773 pub fn session_count(&self) -> usize {
775 self.sessions.read().unwrap().len()
776 }
777}
778
779impl Default for SessionManager {
780 fn default() -> Self {
781 Self::new(Duration::from_secs(3600)) }
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = AgentContext::new(String::from("test-session"));
        assert_eq!(ctx.session_id, "test-session");
        assert_eq!(ctx.working_dir, "/agents/test-session");
    }

    #[test]
    fn test_variables() {
        let mut ctx = AgentContext::new(String::from("test"));
        let model = ContextValue::String(String::from("gpt-4"));
        let budget = ContextValue::Number(1000.0);
        ctx.set_var("model", model.clone());
        ctx.set_var("budget", budget.clone());

        assert_eq!(ctx.get_var("model"), Some(model));
        assert_eq!(ctx.get_var("budget"), Some(budget));
    }

    #[test]
    fn test_variable_substitution() {
        let mut ctx = AgentContext::new(String::from("test"));
        ctx.set_var("name", ContextValue::String(String::from("Alice")));
        ctx.set_var("count", ContextValue::Number(42.0));

        let rendered = ctx.substitute_vars("Hello $name, you have $count items");
        assert_eq!(rendered, "Hello Alice, you have 42 items");
    }

    #[test]
    fn test_path_resolution() {
        let ctx =
            AgentContext::with_working_dir(String::from("test"), String::from("/home/agent"));
        let cases = [
            ("data.json", "/home/agent/data.json"),
            ("/absolute/path", "/absolute/path"),
        ];
        for (input, expected) in cases {
            assert_eq!(ctx.resolve_path(input), expected);
        }
    }

    #[test]
    fn test_permissions() {
        let mut ctx = AgentContext::new(String::from("test"));
        ctx.permissions.filesystem.read = true;
        ctx.permissions.filesystem.allowed_paths = vec![String::from("/allowed")];

        let read_allowed = ctx.check_fs_permission("/allowed/file", AuditOperation::FsRead);
        let read_forbidden = ctx.check_fs_permission("/forbidden/file", AuditOperation::FsRead);
        let write_denied = ctx.check_fs_permission("/allowed/file", AuditOperation::FsWrite);

        assert!(read_allowed.is_ok());
        assert!(read_forbidden.is_err());
        assert!(write_denied.is_err());
    }

    #[test]
    fn test_budget() {
        let mut ctx = AgentContext::new(String::from("test"));
        ctx.budget.max_operations = Some(3);

        for _ in 0..3 {
            assert!(ctx.consume_budget(100, 10).is_ok());
        }
        // The fourth operation exceeds the limit of three.
        assert!(ctx.consume_budget(100, 10).is_err());
    }

    #[test]
    fn test_transaction() {
        let mut ctx = AgentContext::new(String::from("test"));

        assert!(ctx.begin_transaction(1).is_ok());
        // Nested transactions are rejected.
        assert!(ctx.begin_transaction(2).is_err());

        ctx.record_pending_write(
            ResourceType::File,
            String::from("/test/file"),
            Some(b"original".to_vec()),
        )
        .expect("a transaction is active");

        let pending = ctx
            .rollback_transaction()
            .expect("a transaction is active");
        assert_eq!(pending.len(), 1);
    }

    #[test]
    fn test_session_manager() {
        let mgr = SessionManager::default();

        let _first = mgr.create_session(String::from("s1"));
        let _second = mgr.create_session(String::from("s2"));

        assert_eq!(mgr.session_count(), 2);
        assert!(mgr.get_session("s1").is_some());
        assert!(mgr.get_session("s3").is_none());

        let _ = mgr.remove_session("s1");
        assert_eq!(mgr.session_count(), 1);
    }
}