Skip to main content

vtcode_core/core/agent/session/
mod.rs

//! Centralized state management for an active agent session.
2
3use crate::core::agent::error_recovery::ErrorRecoveryState;
4use crate::core::agent::task::{TaskOutcome, TaskResults};
5use crate::exec::events::Usage;
6use crate::llm::provider::{Message, ResponsesContinuationState, responses_continuation_key};
7use crate::llm::providers::gemini::wire::{Content, FunctionResponse, Part};
8use hashbrown::HashMap;
9use parking_lot::Mutex;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use vtcode_exec_events::ThreadEvent;
13
/// Manages the state of an active agent session, including conversation history,
/// statistics, and turn-based constraints.
pub struct AgentSessionState {
    /// The thread or session ID.
    pub session_id: String,

    /// Provider-specific conversation history (e.g., Gemini style).
    ///
    /// User messages are always appended here; tool results are appended only
    /// when the caller passes `is_gemini = true`.
    pub conversation: Vec<Content>,

    /// Standardized conversation messages (OpenAI/Anthropic style).
    ///
    /// User messages and tool results are always appended here, regardless of
    /// provider. Token estimates (`total_tokens`) are computed from this list.
    pub messages: Vec<Message>,

    /// Statistics for the current session.
    pub stats: SessionStats,

    /// Constraints and limits for the session.
    pub constraints: SessionConstraints,

    /// Outcome of the session if completed.
    ///
    /// Starts as `TaskOutcome::Unknown`; `finalize_outcome` only fills it in
    /// while it is still `Unknown`.
    pub outcome: TaskOutcome,
    /// Provider stop reason associated with the last model turn, when available.
    pub stop_reason: Option<String>,
    /// Estimated total API cost in USD for the session, when available.
    pub total_cost_usd: Option<f64>,

    /// Whether the session has completed.
    pub is_completed: bool,

    /// Current reasoning stage.
    pub current_stage: Option<String>,

    // Tracking for side-effects and progress; these vectors are moved into
    // `TaskResults` by `into_results`.
    /// Contexts created during the session.
    pub created_contexts: Vec<String>,
    /// Files modified during the session.
    pub modified_files: Vec<String>,
    /// Tool names recorded as executed (`push_tool_result` appends here).
    pub executed_commands: Vec<String>,
    /// Warnings collected during the session.
    pub warnings: Vec<String>,
    /// Last file path seen. Not read or written in this module;
    /// presumably maintained by callers.
    pub last_file_path: Option<String>,
    /// Last directory path seen. Not read or written in this module;
    /// presumably maintained by callers.
    pub last_dir_path: Option<String>,

    // Internal loop state
    /// Length of the current streak of consecutive tool loops
    /// (incremented by `register_tool_loop`, zeroed by `reset_tool_loop_guard`).
    pub consecutive_tool_loops: usize,
    /// Set once by `mark_tool_loop_limit_hit`; takes priority in `finalize_outcome`.
    pub tool_loop_limit_hit: bool,
    /// Index into `messages` up to which processing has happened. Not used in
    /// this module; NOTE(review): confirm semantics with callers.
    pub last_processed_message_idx: usize,
    /// Responses-style continuation state keyed by normalized provider/model pairs.
    pub previous_response_chains: HashMap<(String, String), ResponsesContinuationState>,
    /// Agent-local recent error diagnostics for interrupted or repeated tool failures.
    pub error_recovery: Arc<Mutex<ErrorRecoveryState>>,

    // Legacy / Stats fields for compatibility
    /// Count of consecutive idle turns. Not updated in this module.
    pub consecutive_idle_turns: usize,
    /// Longest tool-loop streak observed by `register_tool_loop`; survives resets.
    pub max_tool_loop_streak: usize,
    /// Number of turns recorded by `record_turn`.
    pub turn_count: usize,
    /// Sum of recorded turn durations, in milliseconds.
    pub turn_total_ms: u128,
    /// Longest recorded turn duration, in milliseconds.
    pub turn_max_ms: u128,
    /// Every recorded turn duration, in milliseconds, in recording order.
    pub turn_durations_ms: Vec<u128>,
}
70
/// Statistics tracked during an agent session.
#[derive(Debug, Default, Clone)]
pub struct SessionStats {
    /// Number of turns recorded (incremented by `AgentSessionState::record_turn`).
    pub turns_executed: usize,
    /// Sum of all recorded turn durations.
    pub total_duration: Duration,
    /// Each recorded turn duration, in recording order.
    pub turn_durations: Vec<Duration>,
    /// Token usage accumulated across provider responses via `merge_usage`.
    pub total_usage: Usage,
}
79
80impl SessionStats {
81    pub fn merge_usage(&mut self, usage: crate::llm::provider::Usage) {
82        self.total_usage.input_tokens = self
83            .total_usage
84            .input_tokens
85            .saturating_add(usage.prompt_tokens as u64);
86        self.total_usage.output_tokens = self
87            .total_usage
88            .output_tokens
89            .saturating_add(usage.completion_tokens as u64);
90        let cached = usage.cache_read_tokens_or_fallback();
91        if cached > 0 {
92            self.total_usage.cached_input_tokens = self
93                .total_usage
94                .cached_input_tokens
95                .saturating_add(cached as u64);
96        }
97        let cache_creation = usage.cache_creation_tokens_or_zero();
98        if cache_creation > 0 {
99            self.total_usage.cache_creation_tokens = self
100                .total_usage
101                .cache_creation_tokens
102                .saturating_add(cache_creation as u64);
103        }
104    }
105}
106
/// Constraints applied to an agent session.
///
/// A small bundle of `usize` limits, so it is `Copy` and comparable; callers can
/// pass it by value and check for configuration changes with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionConstraints {
    /// Maximum number of turns the session may execute.
    pub max_turns: usize,
    /// Maximum number of consecutive tool loops before the limit is considered hit.
    pub max_tool_loops: usize,
    /// Context-window budget in (estimated) tokens; `0` disables utilization tracking.
    pub max_context_tokens: usize,
}
114
115impl AgentSessionState {
116    pub fn new(
117        session_id: String,
118        max_turns: usize,
119        max_tool_loops: usize,
120        max_context_tokens: usize,
121    ) -> Self {
122        Self {
123            session_id,
124            conversation: Vec::new(),
125            messages: Vec::new(),
126            stats: SessionStats::default(),
127            constraints: SessionConstraints {
128                max_turns,
129                max_tool_loops,
130                max_context_tokens,
131            },
132            outcome: TaskOutcome::Unknown,
133            stop_reason: None,
134            total_cost_usd: None,
135            is_completed: false,
136            current_stage: None,
137            created_contexts: Vec::with_capacity(16),
138            modified_files: Vec::with_capacity(32),
139            executed_commands: Vec::with_capacity(64),
140            warnings: Vec::with_capacity(16),
141            last_file_path: None,
142            last_dir_path: None,
143            consecutive_tool_loops: 0,
144            tool_loop_limit_hit: false,
145            last_processed_message_idx: 0,
146            previous_response_chains: HashMap::new(),
147            error_recovery: Arc::new(Mutex::new(ErrorRecoveryState::default())),
148            consecutive_idle_turns: 0,
149            max_tool_loop_streak: 0,
150            turn_count: 0,
151            turn_total_ms: 0,
152            turn_max_ms: 0,
153            turn_durations_ms: Vec::with_capacity(max_turns),
154        }
155    }
156
157    /// Record a completed turn.
158    pub fn record_turn(&mut self, start: &Instant, recorded: &mut bool) {
159        if *recorded {
160            return;
161        }
162        let duration = start.elapsed();
163        let ms = duration.as_millis() as u64;
164
165        self.stats.turns_executed += 1;
166        self.stats.total_duration += duration;
167        self.stats.turn_durations.push(duration);
168
169        // Legacy stats
170        self.turn_count += 1;
171        self.turn_total_ms += ms as u128;
172        self.turn_max_ms = self.turn_max_ms.max(ms as u128);
173        self.turn_durations_ms.push(ms as u128);
174
175        *recorded = true;
176    }
177
178    pub fn finalize_outcome(&mut self, max_turns: usize) {
179        if self.outcome != TaskOutcome::Unknown {
180            return;
181        }
182        // Priority order: tool loop limit > completion > turn limit
183        if self.tool_loop_limit_hit {
184            self.outcome = TaskOutcome::tool_loop_limit_reached(
185                self.constraints.max_tool_loops,
186                self.consecutive_tool_loops,
187            );
188        } else if self.is_completed {
189            self.outcome = TaskOutcome::Success;
190        } else if self.stats.turns_executed >= max_turns {
191            self.outcome = TaskOutcome::turn_limit_reached(max_turns, self.stats.turns_executed);
192        }
193    }
194
195    pub fn register_tool_loop(&mut self) -> usize {
196        self.consecutive_tool_loops += 1;
197        self.max_tool_loop_streak = self.max_tool_loop_streak.max(self.consecutive_tool_loops);
198        self.consecutive_tool_loops
199    }
200
201    pub fn reset_tool_loop_guard(&mut self) {
202        self.consecutive_tool_loops = 0;
203    }
204
205    pub fn previous_response_id_for(&self, provider: &str, model: &str) -> Option<String> {
206        self.previous_response_chain_for(provider, model)
207            .map(|chain| chain.response_id.clone())
208    }
209
210    pub fn previous_response_chain_for(
211        &self,
212        provider: &str,
213        model: &str,
214    ) -> Option<&ResponsesContinuationState> {
215        responses_continuation_key(provider, model)
216            .and_then(|key| self.previous_response_chains.get(&key))
217    }
218
219    pub fn set_previous_response_chain(
220        &mut self,
221        provider: &str,
222        model: &str,
223        response_id: Option<&str>,
224        messages: Vec<Message>,
225    ) {
226        let Some(key) = responses_continuation_key(provider, model) else {
227            return;
228        };
229        let Some(response_id) = response_id.map(str::trim).filter(|value| !value.is_empty()) else {
230            self.previous_response_chains.remove(&key);
231            return;
232        };
233
234        self.previous_response_chains.insert(
235            key,
236            ResponsesContinuationState {
237                response_id: response_id.to_string(),
238                messages,
239            },
240        );
241    }
242
243    pub fn clear_previous_response_chain_for(&mut self, provider: &str, model: &str) {
244        if let Some(key) = responses_continuation_key(provider, model) {
245            self.previous_response_chains.remove(&key);
246        }
247    }
248
249    pub fn clear_previous_response_chain(&mut self) {
250        self.previous_response_chains.clear();
251    }
252
253    pub fn mark_tool_loop_limit_hit(&mut self) {
254        if self.tool_loop_limit_hit {
255            return;
256        }
257        self.tool_loop_limit_hit = true;
258        self.outcome = TaskOutcome::tool_loop_limit_reached(
259            self.constraints.max_tool_loops,
260            self.consecutive_tool_loops,
261        );
262    }
263
264    /// Add a user message to the history.
265    pub fn add_user_message(&mut self, text: String) {
266        self.conversation.push(Content::user_text(text.as_str()));
267        self.messages.push(Message::user(text));
268    }
269
270    /// Check if context limits are approaching.
271    pub fn utilization(&self) -> f64 {
272        if self.constraints.max_context_tokens == 0 {
273            return 0.0;
274        }
275        self.total_tokens() as f64 / self.constraints.max_context_tokens as f64
276    }
277
278    /// Calculate total estimated tokens in the conversation.
279    pub fn total_tokens(&self) -> usize {
280        self.messages.iter().map(|m| m.estimate_tokens()).sum()
281    }
282
283    /// Find a safe split point for history trimming that doesn't break tool call/output pairs.
284    pub fn find_safe_split_point(&self, preferred_split_at: usize) -> usize {
285        crate::core::agent::state::safe_history_split_point(
286            &self.messages,
287            self.conversation.len(),
288            preferred_split_at,
289        )
290    }
291
292    /// Normalize history to enforce call/output pairing invariants.
293    pub fn normalize(&mut self) {
294        crate::core::agent::state::normalize_history(&mut self.messages);
295    }
296
297    pub fn into_results(
298        self,
299        summary: String,
300        thread_events: Vec<ThreadEvent>,
301        total_duration_ms: u128,
302    ) -> TaskResults {
303        let average_turn_duration_ms = if self.turn_count > 0 {
304            Some(self.turn_total_ms as f64 / self.turn_count as f64)
305        } else {
306            None
307        };
308        let max_turn_duration_ms = if self.turn_count > 0 {
309            Some(self.turn_max_ms)
310        } else {
311            None
312        };
313
314        TaskResults {
315            created_contexts: self.created_contexts,
316            modified_files: self.modified_files,
317            executed_commands: self.executed_commands,
318            summary,
319            stop_reason: self.stop_reason,
320            total_cost_usd: self.total_cost_usd,
321            warnings: self.warnings,
322            thread_events,
323            outcome: self.outcome,
324            turns_executed: self.stats.turns_executed,
325            total_duration_ms,
326            average_turn_duration_ms,
327            max_turn_duration_ms,
328            turn_durations_ms: self.turn_durations_ms,
329        }
330    }
331
332    /// Push a successful tool result to both conversation (for Gemini) and messages.
333    ///
334    /// Accepts a `&Value` to avoid redundant serialize/deserialize cycles for
335    /// Gemini providers — the value is used directly in `FunctionResponse`
336    /// instead of being serialized to a string and then re-parsed.
337    pub fn push_tool_result(
338        &mut self,
339        call_id: String,
340        tool_name: &str,
341        result: &serde_json::Value,
342        is_gemini: bool,
343    ) {
344        if is_gemini {
345            self.conversation.push(Content {
346                role: "function".to_string(),
347                parts: vec![Part::FunctionResponse {
348                    function_response: FunctionResponse {
349                        name: tool_name.to_string(),
350                        response: result.clone(),
351                        id: Some(call_id.clone()),
352                    },
353                    thought_signature: None,
354                }],
355            });
356        }
357        let serialized = serde_json::to_string(result).expect("Value serialization is infallible");
358        self.messages
359            .push(Message::tool_response(call_id, serialized));
360        self.executed_commands.push(tool_name.to_owned());
361    }
362
363    /// Push a tool error to both conversation (for Gemini) and messages.
364    ///
365    /// Accepts a `&Value` to avoid redundant serialize/deserialize cycles for
366    /// Gemini providers.
367    pub fn push_tool_error(
368        &mut self,
369        call_id: String,
370        tool_name: &str,
371        error_payload: &serde_json::Value,
372        is_gemini: bool,
373    ) {
374        if is_gemini {
375            self.conversation.push(Content {
376                role: "function".to_string(),
377                parts: vec![Part::FunctionResponse {
378                    function_response: FunctionResponse {
379                        name: tool_name.to_string(),
380                        response: error_payload.clone(),
381                        id: Some(call_id.clone()),
382                    },
383                    thought_signature: None,
384                }],
385            });
386        }
387        let serialized =
388            serde_json::to_string(error_payload).expect("Value serialization is infallible");
389        self.messages
390            .push(Message::tool_response(call_id, serialized));
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::AgentSessionState;
    use crate::llm::provider::Message;
    use crate::llm::providers::gemini::wire::Part;

    /// Small session shared by every test: 4 turns, 4 tool loops, 16k tokens.
    fn fresh_state() -> AgentSessionState {
        AgentSessionState::new("session".to_string(), 4, 4, 16_000)
    }

    #[test]
    fn previous_response_chain_is_scoped_to_provider_and_model() {
        let mut state = fresh_state();
        let older_history = vec![Message::user("hello".to_string())];
        let newer_history = vec![Message::user("continue".to_string())];

        // Two models under one provider must keep independent chains.
        state.set_previous_response_chain("openai", "gpt-5.2", Some("resp_123"), older_history);
        state.set_previous_response_chain(
            "openai",
            "gpt-5.4",
            Some("resp_456"),
            newer_history.clone(),
        );

        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.2").as_deref(),
            Some("resp_123")
        );
        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.4").as_deref(),
            Some("resp_456")
        );
        // Same model name under a different provider is a different key.
        assert!(state.previous_response_id_for("gemini", "gpt-5.2").is_none());

        // Clearing one pair leaves the other untouched.
        state.clear_previous_response_chain_for("openai", "gpt-5.2");
        assert!(state.previous_response_id_for("openai", "gpt-5.2").is_none());
        assert!(state.previous_response_chain_for("openai", "gpt-5.2").is_none());
        assert_eq!(
            state.previous_response_id_for("openai", "gpt-5.4").as_deref(),
            Some("resp_456")
        );
        let surviving = state
            .previous_response_chain_for("openai", "gpt-5.4")
            .expect("gpt-5.4 chain should survive clearing gpt-5.2");
        assert_eq!(surviving.messages, newer_history);

        // Clearing everything drops the remaining chain as well.
        state.clear_previous_response_chain();
        assert!(state.previous_response_id_for("openai", "gpt-5.4").is_none());
        assert!(state.previous_response_chain_for("openai", "gpt-5.4").is_none());
    }

    #[test]
    fn register_tool_loop_tracks_current_and_max_streak() {
        let mut state = fresh_state();

        let streak: Vec<usize> = (0..2).map(|_| state.register_tool_loop()).collect();
        assert_eq!(streak, [1, 2]);
        assert_eq!(
            (state.consecutive_tool_loops, state.max_tool_loop_streak),
            (2, 2)
        );

        // A reset restarts the current streak but keeps the historical maximum.
        state.reset_tool_loop_guard();
        assert_eq!(state.register_tool_loop(), 1);
        assert_eq!(state.max_tool_loop_streak, 2);
    }

    #[test]
    fn push_tool_error_preserves_structured_json_for_gemini() {
        let mut state = fresh_state();
        let payload = serde_json::json!({
            "error": {
                "tool_name": "read_file",
                "message": "missing file",
                "category": "ResourceNotFound"
            }
        });

        state.push_tool_error("call_1".to_string(), "read_file", &payload, true);

        // The Gemini history keeps the payload as structured JSON...
        let first_part = &state.conversation[0].parts[0];
        match first_part {
            Part::FunctionResponse {
                function_response, ..
            } => assert_eq!(
                function_response.response["error"]["message"],
                "missing file"
            ),
            other => panic!("expected function response, got {other:?}"),
        }

        // ...while the standardized history carries its serialized form.
        let wire_text = serde_json::to_string(&payload).unwrap();
        assert_eq!(
            state.messages[0],
            Message::tool_response("call_1".to_string(), wire_text)
        );
    }
}
493}