Skip to main content

sh_layer2/session_manager/
session.rs

1//! # Session Definition
2//!
3//! 会话结构定义。
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use crate::types::{AgentId, AgentState, Message, MessageRole, SessionId, ToolCall, ToolResult};
9
10/// 会话配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SessionConfig {
13    pub model: String,
14    pub temperature: f32,
15    pub max_iterations: i32,
16    pub system_prompt: Option<String>,
17    /// 最大消息数量(防止内存无限增长)
18    #[serde(default = "default_max_messages")]
19    pub max_messages: usize,
20    /// 最大工具注册数量
21    #[serde(default = "default_max_tools")]
22    pub max_tools: usize,
23}
24
25fn default_max_messages() -> usize {
26    1000
27}
28fn default_max_tools() -> usize {
29    100
30}
31
32impl Default for SessionConfig {
33    fn default() -> Self {
34        Self {
35            model: "claude-sonnet-4-6".to_string(),
36            temperature: 0.7,
37            max_iterations: 100,
38            system_prompt: None,
39            max_messages: 1000,
40            max_tools: 100,
41        }
42    }
43}
44
45/// 会话状态
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Session {
48    /// 会话 ID
49    pub session_id: SessionId,
50    /// Agent ID
51    pub agent_id: AgentId,
52    /// 当前状态
53    pub state: AgentState,
54    /// 当前迭代次数
55    pub iteration: i32,
56    /// 最大迭代次数
57    pub max_iterations: i32,
58    /// 消息历史
59    pub messages: Vec<Message>,
60    /// 已注册工具
61    pub tools_registered: Vec<String>,
62    /// 待处理的工具调用
63    pub tool_calls_pending: Vec<ToolCall>,
64    /// 工具结果缓存
65    pub tool_results_cache: Vec<ToolResult>,
66    /// 模型名称
67    pub model: String,
68    /// 温度参数
69    pub temperature: f32,
70    /// 系统提示词
71    pub system_prompt: String,
72    /// Token 使用统计
73    pub tokens_total: i64,
74    pub tokens_prompt: i64,
75    pub tokens_completion: i64,
76    /// 成本估算
77    pub cost_estimate: f64,
78    /// 创建时间
79    pub created_at: DateTime<Utc>,
80    /// 最后更新时间
81    pub last_updated: DateTime<Utc>,
82    /// 检查点计数
83    pub checkpoint_count: i32,
84    /// 最大消息数量
85    #[serde(default = "default_max_messages")]
86    pub max_messages: usize,
87    /// 最大工具数量
88    #[serde(default = "default_max_tools")]
89    pub max_tools: usize,
90}
91
92impl Session {
93    /// 创建新会话
94    pub fn new(config: &SessionConfig) -> Self {
95        let now = Utc::now();
96        Self {
97            session_id: SessionId::new(),
98            agent_id: AgentId::new(),
99            state: AgentState::Idle,
100            iteration: 0,
101            max_iterations: config.max_iterations,
102            messages: Vec::new(),
103            tools_registered: Vec::new(),
104            tool_calls_pending: Vec::new(),
105            tool_results_cache: Vec::new(),
106            model: config.model.clone(),
107            temperature: config.temperature,
108            system_prompt: config.system_prompt.clone().unwrap_or_default(),
109            tokens_total: 0,
110            tokens_prompt: 0,
111            tokens_completion: 0,
112            cost_estimate: 0.0,
113            created_at: now,
114            last_updated: now,
115            checkpoint_count: 0,
116            max_messages: config.max_messages,
117            max_tools: config.max_tools,
118        }
119    }
120
121    /// 添加用户消息
122    pub fn add_user_message(&mut self, content: &str) {
123        self.messages.push(Message::user(content));
124        self.trim_messages();
125        self.iteration += 1;
126        self.touch();
127    }
128
129    /// 添加助手消息
130    pub fn add_assistant_message(&mut self, content: &str) {
131        self.messages.push(Message::assistant(content));
132        self.trim_messages();
133        self.touch();
134    }
135
136    /// 添加系统消息
137    pub fn add_system_message(&mut self, content: &str) {
138        self.messages.push(Message::system(content));
139        self.trim_messages();
140        self.touch();
141    }
142
143    /// 当消息超过上限时,删除最旧的消息(保留第一条系统消息)
144    fn trim_messages(&mut self) {
145        if self.messages.len() > self.max_messages {
146            let excess = self.messages.len() - self.max_messages;
147            // 保留第一条消息(通常是系统提示)
148            let first_is_system = self
149                .messages
150                .first()
151                .map(|m| m.role == MessageRole::System)
152                .unwrap_or(false);
153
154            if first_is_system && excess > 0 {
155                // 删除第二条到第excess+1条
156                self.messages.drain(1..=excess.min(self.messages.len() - 1));
157            } else {
158                // 删除最旧的excess条
159                self.messages.drain(0..excess);
160            }
161        }
162    }
163
164    /// 注册工具,如果超过上限则移除最旧的
165    pub fn register_tool(&mut self, tool_name: &str) {
166        if !self.tools_registered.contains(&tool_name.to_string()) {
167            if self.tools_registered.len() >= self.max_tools {
168                // 移除最旧的工具
169                self.tools_registered.remove(0);
170            }
171            self.tools_registered.push(tool_name.to_string());
172            self.touch();
173        }
174    }
175
176    /// 更新最后修改时间
177    pub fn touch(&mut self) {
178        self.last_updated = Utc::now();
179    }
180
181    /// 检查是否可以继续执行
182    pub fn can_continue(&self) -> bool {
183        self.iteration < self.max_iterations
184            && matches!(self.state, AgentState::Running | AgentState::Idle)
185    }
186
187    /// 序列化为 JSON
188    pub fn to_json(&self) -> serde_json::Result<String> {
189        serde_json::to_string_pretty(self)
190    }
191
192    /// 从 JSON 反序列化
193    pub fn from_json(json: &str) -> serde_json::Result<Self> {
194        serde_json::from_str(json)
195    }
196
197    /// 转换为字典(兼容 Python 版本)
198    pub fn to_dict(&self) -> serde_json::Value {
199        serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
200    }
201}
202
203impl Default for Session {
204    fn default() -> Self {
205        Self::new(&SessionConfig::default())
206    }
207}
208
209#[cfg(test)]
210mod tests {
211    use super::*;
212
213    #[test]
214    fn test_session_creation() {
215        let config = SessionConfig::default();
216        let session = Session::new(&config);
217
218        assert!(session.messages.is_empty());
219        assert_eq!(session.state, AgentState::Idle);
220        assert_eq!(session.iteration, 0);
221    }
222
223    #[test]
224    fn test_session_messages() {
225        let config = SessionConfig::default();
226        let mut session = Session::new(&config);
227
228        session.add_user_message("Hello");
229        assert_eq!(session.messages.len(), 1);
230        assert_eq!(session.iteration, 1);
231
232        session.add_assistant_message("Hi there!");
233        assert_eq!(session.messages.len(), 2);
234    }
235
236    #[test]
237    fn test_session_can_continue() {
238        let config = SessionConfig {
239            max_iterations: 5,
240            ..Default::default()
241        };
242
243        let mut session = Session::new(&config);
244        assert!(session.can_continue());
245
246        session.state = AgentState::Running;
247        assert!(session.can_continue());
248
249        session.state = AgentState::Stopped;
250        assert!(!session.can_continue());
251    }
252
253    #[test]
254    fn test_session_serialization() {
255        let config = SessionConfig::default();
256        let session = Session::new(&config);
257
258        let json = session.to_json().unwrap();
259        let restored = Session::from_json(&json).unwrap();
260
261        assert_eq!(session.session_id, restored.session_id);
262        assert_eq!(session.state, restored.state);
263    }
264
265    #[test]
266    fn test_session_max_messages_limit() {
267        let config = SessionConfig {
268            max_messages: 5,
269            ..Default::default()
270        };
271        let mut session = Session::new(&config);
272
273        // 添加超过上限的消息
274        for i in 0..10 {
275            session.add_user_message(&format!("Message {}", i));
276        }
277
278        // 应该被限制在 max_messages
279        assert_eq!(session.messages.len(), 5);
280    }
281
282    #[test]
283    fn test_session_preserves_system_message() {
284        let config = SessionConfig {
285            max_messages: 3,
286            system_prompt: Some("System prompt".to_string()),
287            ..Default::default()
288        };
289        let mut session = Session::new(&config);
290
291        session.add_system_message("System prompt");
292        for i in 0..5 {
293            session.add_user_message(&format!("User {}", i));
294        }
295
296        // 第一条系统消息应该保留
297        assert_eq!(session.messages.len(), 3);
298        assert!(session
299            .messages
300            .first()
301            .map(|m| m.role == MessageRole::System)
302            .unwrap_or(false));
303    }
304
305    #[test]
306    fn test_session_max_tools_limit() {
307        let config = SessionConfig {
308            max_tools: 3,
309            ..Default::default()
310        };
311        let mut session = Session::new(&config);
312
313        for i in 0..5 {
314            session.register_tool(&format!("tool_{}", i));
315        }
316
317        // 应该被限制在 max_tools
318        assert_eq!(session.tools_registered.len(), 3);
319        // 最旧的工具应该被移除
320        assert!(!session.tools_registered.contains(&"tool_0".to_string()));
321        assert!(!session.tools_registered.contains(&"tool_1".to_string()));
322    }
323
324    #[test]
325    fn test_session_no_duplicate_tools() {
326        let config = SessionConfig::default();
327        let mut session = Session::new(&config);
328
329        session.register_tool("tool_a");
330        session.register_tool("tool_a");
331        session.register_tool("tool_a");
332
333        assert_eq!(session.tools_registered.len(), 1);
334    }
335}