// Source: sh_layer2/session_manager/context.rs

//! # Execution Context
//!
//! Agent execution context, with support for serialization and restoration.
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::types::{AgentId, AgentState, Message, SessionId, ToolCall, ToolResult};
10
11/// 执行上下文
12///
13/// Agent 执行的完整上下文,包含所有必要的状态信息。
14/// 这是 Python 版本 ExecutionContext 的 Rust 移植。
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ExecutionContext {
17    /// 会话 ID
18    pub session_id: SessionId,
19    /// Agent ID
20    pub agent_id: AgentId,
21
22    // 状态
23    pub state: AgentState,
24    pub iteration: i32,
25    pub max_iterations: i32,
26
27    /// 消息历史(OpenAI 格式)
28    pub messages: Vec<Message>,
29
30    /// Tool 管理
31    pub tools_registered: Vec<String>,
32    pub tool_calls_pending: Vec<ToolCall>,
33    pub tool_results_cache: Vec<ToolResult>,
34
35    /// 配置快照
36    pub model: String,
37    pub temperature: f32,
38    pub system_prompt: String,
39
40    /// 追踪数据
41    pub tokens_total: i64,
42    pub tokens_prompt: i64,
43    pub tokens_completion: i64,
44    pub cost_estimate: f64,
45
46    /// 元数据
47    pub created_at: DateTime<Utc>,
48    pub last_updated: DateTime<Utc>,
49    pub checkpoint_count: i32,
50
51    /// 扩展数据(用于存储额外信息)
52    #[serde(default)]
53    pub metadata: HashMap<String, serde_json::Value>,
54}
55
56impl ExecutionContext {
57    /// 创建新的执行上下文
58    pub fn new() -> Self {
59        let now = Utc::now();
60        Self {
61            session_id: SessionId::new(),
62            agent_id: AgentId::new(),
63            state: AgentState::Idle,
64            iteration: 0,
65            max_iterations: 100,
66            messages: Vec::new(),
67            tools_registered: Vec::new(),
68            tool_calls_pending: Vec::new(),
69            tool_results_cache: Vec::new(),
70            model: "claude-sonnet-4-6".to_string(),
71            temperature: 0.7,
72            system_prompt: String::new(),
73            tokens_total: 0,
74            tokens_prompt: 0,
75            tokens_completion: 0,
76            cost_estimate: 0.0,
77            created_at: now,
78            last_updated: now,
79            checkpoint_count: 0,
80            metadata: HashMap::new(),
81        }
82    }
83
84    /// 使用配置创建执行上下文
85    pub fn with_config(
86        model: impl Into<String>,
87        temperature: f32,
88        system_prompt: impl Into<String>,
89    ) -> Self {
90        let mut ctx = Self::new();
91        ctx.model = model.into();
92        ctx.temperature = temperature;
93        ctx.system_prompt = system_prompt.into();
94        ctx
95    }
96
97    /// 添加消息
98    pub fn add_message(&mut self, role: &str, content: &str) {
99        use crate::types::MessageRole;
100        let role = match role {
101            "user" => MessageRole::User,
102            "assistant" => MessageRole::Assistant,
103            "system" => MessageRole::System,
104            "tool" => MessageRole::Tool,
105            _ => MessageRole::User,
106        };
107        self.messages.push(Message::new(role, content));
108        self.iteration += 1;
109        self.touch();
110    }
111
112    /// 更新最后修改时间
113    pub fn touch(&mut self) {
114        self.last_updated = Utc::now();
115    }
116
117    /// 增加检查点计数
118    pub fn increment_checkpoint(&mut self) {
119        self.checkpoint_count += 1;
120        self.touch();
121    }
122
123    /// 添加 token 使用
124    pub fn add_tokens(&mut self, prompt: i64, completion: i64) {
125        self.tokens_prompt += prompt;
126        self.tokens_completion += completion;
127        self.tokens_total += prompt + completion;
128        self.touch();
129    }
130
131    /// 序列化为字典
132    pub fn to_dict(&self) -> serde_json::Value {
133        serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
134    }
135
136    /// 从字典恢复
137    pub fn from_dict(data: &serde_json::Value) -> serde_json::Result<Self> {
138        serde_json::from_value(data.clone())
139    }
140
141    /// 转换为 JSON 字符串
142    pub fn to_json(&self) -> serde_json::Result<String> {
143        serde_json::to_string_pretty(self)
144    }
145
146    /// 从 JSON 字符串恢复
147    pub fn from_json(json: &str) -> serde_json::Result<Self> {
148        serde_json::from_str(json)
149    }
150
151    /// 设置状态
152    pub fn set_state(&mut self, state: AgentState) {
153        self.state = state;
154        self.touch();
155    }
156
157    /// 检查是否可以继续
158    pub fn can_continue(&self) -> bool {
159        self.iteration < self.max_iterations
160            && matches!(
161                self.state,
162                AgentState::Running | AgentState::Idle | AgentState::WaitingTool
163            )
164    }
165
166    /// 获取消息数量
167    pub fn message_count(&self) -> usize {
168        self.messages.len()
169    }
170}
171
172impl Default for ExecutionContext {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with non-default values in most fields, so round-trip
    /// tests actually exercise them.
    fn sample_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::with_config("test-model", 0.5, "You are a test.");
        ctx.add_message("user", "Hello");
        ctx.add_message("assistant", "Hi");
        ctx.add_tokens(10, 5);
        ctx.increment_checkpoint();
        ctx.metadata
            .insert("key".to_string(), serde_json::Value::from("value"));
        ctx
    }

    /// Asserts that the fields which support equality survived a round trip.
    fn assert_round_trip(original: &ExecutionContext, restored: &ExecutionContext) {
        assert_eq!(original.session_id, restored.session_id);
        assert_eq!(original.state, restored.state);
        assert_eq!(original.iteration, restored.iteration);
        assert_eq!(original.max_iterations, restored.max_iterations);
        assert_eq!(original.message_count(), restored.message_count());
        assert_eq!(original.model, restored.model);
        assert_eq!(original.system_prompt, restored.system_prompt);
        assert_eq!(original.tokens_prompt, restored.tokens_prompt);
        assert_eq!(original.tokens_completion, restored.tokens_completion);
        assert_eq!(original.tokens_total, restored.tokens_total);
        assert_eq!(original.checkpoint_count, restored.checkpoint_count);
        assert_eq!(original.metadata, restored.metadata);
    }

    #[test]
    fn test_execution_context_creation() {
        let ctx = ExecutionContext::new();
        assert_eq!(ctx.state, AgentState::Idle);
        assert_eq!(ctx.iteration, 0);
        assert!(ctx.messages.is_empty());
        assert_eq!(ctx.checkpoint_count, 0);
        assert_eq!(ctx.created_at, ctx.last_updated);
    }

    #[test]
    fn test_with_config() {
        let ctx = ExecutionContext::with_config("some-model", 0.2, "be brief");
        assert_eq!(ctx.model, "some-model");
        assert_eq!(ctx.system_prompt, "be brief");
        assert_eq!(ctx.state, AgentState::Idle);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn test_add_message() {
        let mut ctx = ExecutionContext::new();
        ctx.add_message("user", "Hello");

        assert_eq!(ctx.messages.len(), 1);
        assert_eq!(ctx.iteration, 1);
    }

    #[test]
    fn test_add_message_unknown_role_is_still_recorded() {
        let mut ctx = ExecutionContext::new();
        ctx.add_message("narrator", "Once upon a time");

        assert_eq!(ctx.message_count(), 1);
        assert_eq!(ctx.iteration, 1);
    }

    #[test]
    fn test_add_tokens() {
        let mut ctx = ExecutionContext::new();
        ctx.add_tokens(100, 50);

        assert_eq!(ctx.tokens_prompt, 100);
        assert_eq!(ctx.tokens_completion, 50);
        assert_eq!(ctx.tokens_total, 150);

        ctx.add_tokens(1, 2);
        assert_eq!(ctx.tokens_prompt, 101);
        assert_eq!(ctx.tokens_completion, 52);
        assert_eq!(ctx.tokens_total, 153);
    }

    #[test]
    fn test_increment_checkpoint() {
        let mut ctx = ExecutionContext::new();
        ctx.increment_checkpoint();
        ctx.increment_checkpoint();

        assert_eq!(ctx.checkpoint_count, 2);
        assert!(ctx.last_updated >= ctx.created_at);
    }

    #[test]
    fn test_can_continue_respects_iteration_limit() {
        let mut ctx = ExecutionContext::new();
        ctx.max_iterations = 1;
        assert!(ctx.can_continue());

        ctx.add_message("user", "Hello");
        assert!(!ctx.can_continue());
    }

    #[test]
    fn test_can_continue_in_active_states() {
        let mut ctx = ExecutionContext::new();
        ctx.set_state(AgentState::Running);
        assert!(ctx.can_continue());

        ctx.set_state(AgentState::WaitingTool);
        assert!(ctx.can_continue());
    }

    #[test]
    fn test_serialization() {
        let ctx = sample_context();
        let json = ctx.to_json().unwrap();
        let restored = ExecutionContext::from_json(&json).unwrap();

        assert_round_trip(&ctx, &restored);
    }

    #[test]
    fn test_dict_round_trip() {
        let ctx = sample_context();
        let restored = ExecutionContext::from_dict(&ctx.to_dict()).unwrap();

        assert_round_trip(&ctx, &restored);
    }

    #[test]
    fn test_from_dict_defaults_missing_metadata() {
        let mut data = sample_context().to_dict();
        data.as_object_mut()
            .expect("a context serializes to a JSON object")
            .remove("metadata");

        let restored = ExecutionContext::from_dict(&data).unwrap();
        assert!(restored.metadata.is_empty());
    }

    #[test]
    fn test_from_dict_rejects_non_object() {
        assert!(ExecutionContext::from_dict(&serde_json::Value::from(42)).is_err());
    }
}