sh_layer2/session_manager/
session.rs1use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use crate::types::{AgentId, AgentState, Message, MessageRole, SessionId, ToolCall, ToolResult};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SessionConfig {
13 pub model: String,
14 pub temperature: f32,
15 pub max_iterations: i32,
16 pub system_prompt: Option<String>,
17 #[serde(default = "default_max_messages")]
19 pub max_messages: usize,
20 #[serde(default = "default_max_tools")]
22 pub max_tools: usize,
23}
24
/// Fallback for `max_messages` when the key is absent from serialized input.
fn default_max_messages() -> usize {
    const MAX_MESSAGES: usize = 1_000;
    MAX_MESSAGES
}

/// Fallback for `max_tools` when the key is absent from serialized input.
fn default_max_tools() -> usize {
    const MAX_TOOLS: usize = 100;
    MAX_TOOLS
}
31
32impl Default for SessionConfig {
33 fn default() -> Self {
34 Self {
35 model: "claude-sonnet-4-6".to_string(),
36 temperature: 0.7,
37 max_iterations: 100,
38 system_prompt: None,
39 max_messages: 1000,
40 max_tools: 100,
41 }
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Session {
48 pub session_id: SessionId,
50 pub agent_id: AgentId,
52 pub state: AgentState,
54 pub iteration: i32,
56 pub max_iterations: i32,
58 pub messages: Vec<Message>,
60 pub tools_registered: Vec<String>,
62 pub tool_calls_pending: Vec<ToolCall>,
64 pub tool_results_cache: Vec<ToolResult>,
66 pub model: String,
68 pub temperature: f32,
70 pub system_prompt: String,
72 pub tokens_total: i64,
74 pub tokens_prompt: i64,
75 pub tokens_completion: i64,
76 pub cost_estimate: f64,
78 pub created_at: DateTime<Utc>,
80 pub last_updated: DateTime<Utc>,
82 pub checkpoint_count: i32,
84 #[serde(default = "default_max_messages")]
86 pub max_messages: usize,
87 #[serde(default = "default_max_tools")]
89 pub max_tools: usize,
90}
91
92impl Session {
93 pub fn new(config: &SessionConfig) -> Self {
95 let now = Utc::now();
96 Self {
97 session_id: SessionId::new(),
98 agent_id: AgentId::new(),
99 state: AgentState::Idle,
100 iteration: 0,
101 max_iterations: config.max_iterations,
102 messages: Vec::new(),
103 tools_registered: Vec::new(),
104 tool_calls_pending: Vec::new(),
105 tool_results_cache: Vec::new(),
106 model: config.model.clone(),
107 temperature: config.temperature,
108 system_prompt: config.system_prompt.clone().unwrap_or_default(),
109 tokens_total: 0,
110 tokens_prompt: 0,
111 tokens_completion: 0,
112 cost_estimate: 0.0,
113 created_at: now,
114 last_updated: now,
115 checkpoint_count: 0,
116 max_messages: config.max_messages,
117 max_tools: config.max_tools,
118 }
119 }
120
121 pub fn add_user_message(&mut self, content: &str) {
123 self.messages.push(Message::user(content));
124 self.trim_messages();
125 self.iteration += 1;
126 self.touch();
127 }
128
129 pub fn add_assistant_message(&mut self, content: &str) {
131 self.messages.push(Message::assistant(content));
132 self.trim_messages();
133 self.touch();
134 }
135
136 pub fn add_system_message(&mut self, content: &str) {
138 self.messages.push(Message::system(content));
139 self.trim_messages();
140 self.touch();
141 }
142
143 fn trim_messages(&mut self) {
145 if self.messages.len() > self.max_messages {
146 let excess = self.messages.len() - self.max_messages;
147 let first_is_system = self
149 .messages
150 .first()
151 .map(|m| m.role == MessageRole::System)
152 .unwrap_or(false);
153
154 if first_is_system && excess > 0 {
155 self.messages.drain(1..=excess.min(self.messages.len() - 1));
157 } else {
158 self.messages.drain(0..excess);
160 }
161 }
162 }
163
164 pub fn register_tool(&mut self, tool_name: &str) {
166 if !self.tools_registered.contains(&tool_name.to_string()) {
167 if self.tools_registered.len() >= self.max_tools {
168 self.tools_registered.remove(0);
170 }
171 self.tools_registered.push(tool_name.to_string());
172 self.touch();
173 }
174 }
175
176 pub fn touch(&mut self) {
178 self.last_updated = Utc::now();
179 }
180
181 pub fn can_continue(&self) -> bool {
183 self.iteration < self.max_iterations
184 && matches!(self.state, AgentState::Running | AgentState::Idle)
185 }
186
187 pub fn to_json(&self) -> serde_json::Result<String> {
189 serde_json::to_string_pretty(self)
190 }
191
192 pub fn from_json(json: &str) -> serde_json::Result<Self> {
194 serde_json::from_str(json)
195 }
196
197 pub fn to_dict(&self) -> serde_json::Value {
199 serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
200 }
201}
202
203impl Default for Session {
204 fn default() -> Self {
205 Self::new(&SessionConfig::default())
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session from an owned config (keeps the tests terse).
    fn session_from(config: SessionConfig) -> Session {
        Session::new(&config)
    }

    #[test]
    fn test_session_creation() {
        let session = session_from(SessionConfig::default());

        assert_eq!(session.iteration, 0);
        assert_eq!(session.state, AgentState::Idle);
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_session_messages() {
        let mut session = session_from(SessionConfig::default());

        session.add_user_message("Hello");
        assert_eq!(session.iteration, 1);
        assert_eq!(session.messages.len(), 1);

        session.add_assistant_message("Hi there!");
        assert_eq!(session.messages.len(), 2);
    }

    #[test]
    fn test_session_can_continue() {
        let mut session = session_from(SessionConfig {
            max_iterations: 5,
            ..SessionConfig::default()
        });
        assert!(session.can_continue());

        session.state = AgentState::Running;
        assert!(session.can_continue());

        session.state = AgentState::Stopped;
        assert!(!session.can_continue());
    }

    #[test]
    fn test_session_serialization() {
        let original = session_from(SessionConfig::default());

        let encoded = original.to_json().unwrap();
        let restored = Session::from_json(&encoded).unwrap();

        assert_eq!(original.session_id, restored.session_id);
        assert_eq!(original.state, restored.state);
    }

    #[test]
    fn test_session_max_messages_limit() {
        let mut session = session_from(SessionConfig {
            max_messages: 5,
            ..SessionConfig::default()
        });

        (0..10).for_each(|i| session.add_user_message(&format!("Message {}", i)));

        assert_eq!(session.messages.len(), 5);
    }

    #[test]
    fn test_session_preserves_system_message() {
        let mut session = session_from(SessionConfig {
            max_messages: 3,
            system_prompt: Some("System prompt".to_string()),
            ..SessionConfig::default()
        });

        session.add_system_message("System prompt");
        (0..5).for_each(|i| session.add_user_message(&format!("User {}", i)));

        assert_eq!(session.messages.len(), 3);
        assert!(matches!(
            session.messages.first(),
            Some(first) if first.role == MessageRole::System
        ));
    }

    #[test]
    fn test_session_max_tools_limit() {
        let mut session = session_from(SessionConfig {
            max_tools: 3,
            ..SessionConfig::default()
        });

        (0..5).for_each(|i| session.register_tool(&format!("tool_{}", i)));

        assert_eq!(session.tools_registered.len(), 3);
        for evicted in ["tool_0", "tool_1"].iter() {
            assert!(!session
                .tools_registered
                .iter()
                .any(|name| name.as_str() == *evicted));
        }
    }

    #[test]
    fn test_session_no_duplicate_tools() {
        let mut session = session_from(SessionConfig::default());

        for _ in 0..3 {
            session.register_tool("tool_a");
        }

        assert_eq!(session.tools_registered.len(), 1);
    }
}