1use crate::core::agent::error_recovery::ErrorRecoveryState;
4use crate::core::agent::task::{TaskOutcome, TaskResults};
5use crate::exec::events::Usage;
6use crate::llm::provider::{Message, ResponsesContinuationState, responses_continuation_key};
7use crate::llm::providers::gemini::wire::{Content, FunctionResponse, Part};
8use hashbrown::HashMap;
9use parking_lot::Mutex;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use vtcode_exec_events::ThreadEvent;
13
14pub struct AgentSessionState {
17 pub session_id: String,
19
20 pub conversation: Vec<Content>,
22
23 pub messages: Vec<Message>,
25
26 pub stats: SessionStats,
28
29 pub constraints: SessionConstraints,
31
32 pub outcome: TaskOutcome,
34 pub stop_reason: Option<String>,
36 pub total_cost_usd: Option<f64>,
38
39 pub is_completed: bool,
41
42 pub current_stage: Option<String>,
44
45 pub created_contexts: Vec<String>,
47 pub modified_files: Vec<String>,
48 pub executed_commands: Vec<String>,
49 pub warnings: Vec<String>,
50 pub last_file_path: Option<String>,
51 pub last_dir_path: Option<String>,
52
53 pub consecutive_tool_loops: usize,
55 pub tool_loop_limit_hit: bool,
56 pub last_processed_message_idx: usize,
57 pub previous_response_chains: HashMap<(String, String), ResponsesContinuationState>,
59 pub error_recovery: Arc<Mutex<ErrorRecoveryState>>,
61
62 pub consecutive_idle_turns: usize,
64 pub max_tool_loop_streak: usize,
65 pub turn_count: usize,
66 pub turn_total_ms: u128,
67 pub turn_max_ms: u128,
68 pub turn_durations_ms: Vec<u128>,
69}
70
71#[derive(Debug, Default, Clone)]
73pub struct SessionStats {
74 pub turns_executed: usize,
75 pub total_duration: Duration,
76 pub turn_durations: Vec<Duration>,
77 pub total_usage: Usage,
78}
79
80impl SessionStats {
81 pub fn merge_usage(&mut self, usage: crate::llm::provider::Usage) {
82 self.total_usage.input_tokens = self
83 .total_usage
84 .input_tokens
85 .saturating_add(usage.prompt_tokens as u64);
86 self.total_usage.output_tokens = self
87 .total_usage
88 .output_tokens
89 .saturating_add(usage.completion_tokens as u64);
90 let cached = usage.cache_read_tokens_or_fallback();
91 if cached > 0 {
92 self.total_usage.cached_input_tokens = self
93 .total_usage
94 .cached_input_tokens
95 .saturating_add(cached as u64);
96 }
97 let cache_creation = usage.cache_creation_tokens_or_zero();
98 if cache_creation > 0 {
99 self.total_usage.cache_creation_tokens = self
100 .total_usage
101 .cache_creation_tokens
102 .saturating_add(cache_creation as u64);
103 }
104 }
105}
106
/// Hard limits applied to a session, fixed when the session is created.
#[derive(Debug, Clone)]
pub struct SessionConstraints {
    /// Maximum number of turns the session may execute.
    pub max_turns: usize,
    /// Maximum number of consecutive tool loops tolerated.
    pub max_tool_loops: usize,
    /// Context budget in tokens; used as the denominator of the
    /// utilization ratio.
    pub max_context_tokens: usize,
}
114
115impl AgentSessionState {
116 pub fn new(
117 session_id: String,
118 max_turns: usize,
119 max_tool_loops: usize,
120 max_context_tokens: usize,
121 ) -> Self {
122 Self {
123 session_id,
124 conversation: Vec::new(),
125 messages: Vec::new(),
126 stats: SessionStats::default(),
127 constraints: SessionConstraints {
128 max_turns,
129 max_tool_loops,
130 max_context_tokens,
131 },
132 outcome: TaskOutcome::Unknown,
133 stop_reason: None,
134 total_cost_usd: None,
135 is_completed: false,
136 current_stage: None,
137 created_contexts: Vec::with_capacity(16),
138 modified_files: Vec::with_capacity(32),
139 executed_commands: Vec::with_capacity(64),
140 warnings: Vec::with_capacity(16),
141 last_file_path: None,
142 last_dir_path: None,
143 consecutive_tool_loops: 0,
144 tool_loop_limit_hit: false,
145 last_processed_message_idx: 0,
146 previous_response_chains: HashMap::new(),
147 error_recovery: Arc::new(Mutex::new(ErrorRecoveryState::default())),
148 consecutive_idle_turns: 0,
149 max_tool_loop_streak: 0,
150 turn_count: 0,
151 turn_total_ms: 0,
152 turn_max_ms: 0,
153 turn_durations_ms: Vec::with_capacity(max_turns),
154 }
155 }
156
157 pub fn record_turn(&mut self, start: &Instant, recorded: &mut bool) {
159 if *recorded {
160 return;
161 }
162 let duration = start.elapsed();
163 let ms = duration.as_millis() as u64;
164
165 self.stats.turns_executed += 1;
166 self.stats.total_duration += duration;
167 self.stats.turn_durations.push(duration);
168
169 self.turn_count += 1;
171 self.turn_total_ms += ms as u128;
172 self.turn_max_ms = self.turn_max_ms.max(ms as u128);
173 self.turn_durations_ms.push(ms as u128);
174
175 *recorded = true;
176 }
177
178 pub fn finalize_outcome(&mut self, max_turns: usize) {
179 if self.outcome != TaskOutcome::Unknown {
180 return;
181 }
182 if self.tool_loop_limit_hit {
184 self.outcome = TaskOutcome::tool_loop_limit_reached(
185 self.constraints.max_tool_loops,
186 self.consecutive_tool_loops,
187 );
188 } else if self.is_completed {
189 self.outcome = TaskOutcome::Success;
190 } else if self.stats.turns_executed >= max_turns {
191 self.outcome = TaskOutcome::turn_limit_reached(max_turns, self.stats.turns_executed);
192 }
193 }
194
195 pub fn register_tool_loop(&mut self) -> usize {
196 self.consecutive_tool_loops += 1;
197 self.max_tool_loop_streak = self.max_tool_loop_streak.max(self.consecutive_tool_loops);
198 self.consecutive_tool_loops
199 }
200
201 pub fn reset_tool_loop_guard(&mut self) {
202 self.consecutive_tool_loops = 0;
203 }
204
205 pub fn previous_response_id_for(&self, provider: &str, model: &str) -> Option<String> {
206 self.previous_response_chain_for(provider, model)
207 .map(|chain| chain.response_id.clone())
208 }
209
210 pub fn previous_response_chain_for(
211 &self,
212 provider: &str,
213 model: &str,
214 ) -> Option<&ResponsesContinuationState> {
215 responses_continuation_key(provider, model)
216 .and_then(|key| self.previous_response_chains.get(&key))
217 }
218
219 pub fn set_previous_response_chain(
220 &mut self,
221 provider: &str,
222 model: &str,
223 response_id: Option<&str>,
224 messages: Vec<Message>,
225 ) {
226 let Some(key) = responses_continuation_key(provider, model) else {
227 return;
228 };
229 let Some(response_id) = response_id.map(str::trim).filter(|value| !value.is_empty()) else {
230 self.previous_response_chains.remove(&key);
231 return;
232 };
233
234 self.previous_response_chains.insert(
235 key,
236 ResponsesContinuationState {
237 response_id: response_id.to_string(),
238 messages,
239 },
240 );
241 }
242
243 pub fn clear_previous_response_chain_for(&mut self, provider: &str, model: &str) {
244 if let Some(key) = responses_continuation_key(provider, model) {
245 self.previous_response_chains.remove(&key);
246 }
247 }
248
249 pub fn clear_previous_response_chain(&mut self) {
250 self.previous_response_chains.clear();
251 }
252
253 pub fn mark_tool_loop_limit_hit(&mut self) {
254 if self.tool_loop_limit_hit {
255 return;
256 }
257 self.tool_loop_limit_hit = true;
258 self.outcome = TaskOutcome::tool_loop_limit_reached(
259 self.constraints.max_tool_loops,
260 self.consecutive_tool_loops,
261 );
262 }
263
264 pub fn add_user_message(&mut self, text: String) {
266 self.conversation.push(Content::user_text(text.as_str()));
267 self.messages.push(Message::user(text));
268 }
269
270 pub fn utilization(&self) -> f64 {
272 if self.constraints.max_context_tokens == 0 {
273 return 0.0;
274 }
275 self.total_tokens() as f64 / self.constraints.max_context_tokens as f64
276 }
277
278 pub fn total_tokens(&self) -> usize {
280 self.messages.iter().map(|m| m.estimate_tokens()).sum()
281 }
282
283 pub fn find_safe_split_point(&self, preferred_split_at: usize) -> usize {
285 crate::core::agent::state::safe_history_split_point(
286 &self.messages,
287 self.conversation.len(),
288 preferred_split_at,
289 )
290 }
291
292 pub fn normalize(&mut self) {
294 crate::core::agent::state::normalize_history(&mut self.messages);
295 }
296
297 pub fn into_results(
298 self,
299 summary: String,
300 thread_events: Vec<ThreadEvent>,
301 total_duration_ms: u128,
302 ) -> TaskResults {
303 let average_turn_duration_ms = if self.turn_count > 0 {
304 Some(self.turn_total_ms as f64 / self.turn_count as f64)
305 } else {
306 None
307 };
308 let max_turn_duration_ms = if self.turn_count > 0 {
309 Some(self.turn_max_ms)
310 } else {
311 None
312 };
313
314 TaskResults {
315 created_contexts: self.created_contexts,
316 modified_files: self.modified_files,
317 executed_commands: self.executed_commands,
318 summary,
319 stop_reason: self.stop_reason,
320 total_cost_usd: self.total_cost_usd,
321 warnings: self.warnings,
322 thread_events,
323 outcome: self.outcome,
324 turns_executed: self.stats.turns_executed,
325 total_duration_ms,
326 average_turn_duration_ms,
327 max_turn_duration_ms,
328 turn_durations_ms: self.turn_durations_ms,
329 }
330 }
331
332 pub fn push_tool_result(
338 &mut self,
339 call_id: String,
340 tool_name: &str,
341 result: &serde_json::Value,
342 is_gemini: bool,
343 ) {
344 if is_gemini {
345 self.conversation.push(Content {
346 role: "function".to_string(),
347 parts: vec![Part::FunctionResponse {
348 function_response: FunctionResponse {
349 name: tool_name.to_string(),
350 response: result.clone(),
351 id: Some(call_id.clone()),
352 },
353 thought_signature: None,
354 }],
355 });
356 }
357 let serialized = serde_json::to_string(result).expect("Value serialization is infallible");
358 self.messages
359 .push(Message::tool_response(call_id, serialized));
360 self.executed_commands.push(tool_name.to_owned());
361 }
362
363 pub fn push_tool_error(
368 &mut self,
369 call_id: String,
370 tool_name: &str,
371 error_payload: &serde_json::Value,
372 is_gemini: bool,
373 ) {
374 if is_gemini {
375 self.conversation.push(Content {
376 role: "function".to_string(),
377 parts: vec![Part::FunctionResponse {
378 function_response: FunctionResponse {
379 name: tool_name.to_string(),
380 response: error_payload.clone(),
381 id: Some(call_id.clone()),
382 },
383 thought_signature: None,
384 }],
385 });
386 }
387 let serialized =
388 serde_json::to_string(error_payload).expect("Value serialization is infallible");
389 self.messages
390 .push(Message::tool_response(call_id, serialized));
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::AgentSessionState;
    use crate::llm::provider::Message;
    use crate::llm::providers::gemini::wire::Part;

    /// Builds the small session fixture shared by every test in this module.
    fn new_state() -> AgentSessionState {
        AgentSessionState::new("session".to_string(), 4, 4, 16_000)
    }

    #[test]
    fn previous_response_chain_is_scoped_to_provider_and_model() {
        let mut state = new_state();
        let first_history = vec![Message::user("hello".to_string())];
        let second_history = vec![Message::user("continue".to_string())];

        state.set_previous_response_chain("openai", "gpt-5.2", Some("resp_123"), first_history);
        state.set_previous_response_chain(
            "openai",
            "gpt-5.4",
            Some("resp_456"),
            second_history.clone(),
        );

        // Each (provider, model) pair resolves to its own response id.
        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.2").as_deref(),
            Some("resp_123")
        );
        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.4").as_deref(),
            Some("resp_456")
        );
        assert!(state.previous_response_id_for("gemini", "gpt-5.2").is_none());

        // Clearing one pair must leave the other untouched.
        state.clear_previous_response_chain_for("openai", "gpt-5.2");

        assert!(state.previous_response_id_for("openai", "gpt-5.2").is_none());
        assert!(state.previous_response_chain_for("openai", "gpt-5.2").is_none());
        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.4").as_deref(),
            Some("resp_456")
        );
        let surviving_messages = state
            .previous_response_chain_for("openai", "gpt-5.4")
            .map(|chain| chain.messages.as_slice());
        assert_eq!(surviving_messages, Some(second_history.as_slice()));

        // Clearing everything drops the remaining pair as well.
        state.clear_previous_response_chain();
        assert!(state.previous_response_id_for("openai", "gpt-5.4").is_none());
        assert!(state.previous_response_chain_for("openai", "gpt-5.4").is_none());
    }

    #[test]
    fn register_tool_loop_tracks_current_and_max_streak() {
        let mut state = new_state();

        let first = state.register_tool_loop();
        let second = state.register_tool_loop();
        assert_eq!((first, second), (1, 2));
        assert_eq!(state.consecutive_tool_loops, 2);
        assert_eq!(state.max_tool_loop_streak, 2);

        // A reset restarts the streak but keeps the recorded maximum.
        state.reset_tool_loop_guard();
        assert_eq!(state.register_tool_loop(), 1);
        assert_eq!(state.max_tool_loop_streak, 2);
    }

    #[test]
    fn push_tool_error_preserves_structured_json_for_gemini() {
        let mut state = new_state();
        let error_json = serde_json::json!({
            "error": {
                "tool_name": "read_file",
                "message": "missing file",
                "category": "ResourceNotFound"
            }
        });

        state.push_tool_error("call_1".to_string(), "read_file", &error_json, true);

        // The Gemini conversation keeps the payload as structured JSON.
        let first_part = &state.conversation[0].parts[0];
        match first_part {
            Part::FunctionResponse {
                function_response, ..
            } => assert_eq!(
                function_response.response["error"]["message"],
                "missing file"
            ),
            other => panic!("expected function response, got {other:?}"),
        }

        // The provider message carries the same payload serialized to a string.
        let expected_message = Message::tool_response(
            "call_1".to_string(),
            serde_json::to_string(&error_json).unwrap(),
        );
        assert_eq!(state.messages[0], expected_message);
    }
}