Skip to main content

vtcode_core/core/agent/
state.rs

1use crate::llm::provider::Message;
2use hashbrown::{HashMap, HashSet};
3use std::time::Duration;
4use vtcode_macros::StringNewtype;
5
6// ============================================================================
7// Context Manager: Call/Output Pairing Invariants (OpenAI Codex pattern)
8// ============================================================================
9
/// Unique identifier for a tool call.
///
/// A newtype over `String`. Elsewhere in this module it is built with
/// `ToolCallId::new(..)`, read back with `as_str()`, and formatted with `{}`
/// in log messages; presumably those come from the `StringNewtype` derive
/// (TODO confirm in `vtcode_macros`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, StringNewtype)]
pub struct ToolCallId(String);
13
/// Final status reported for a tool execution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OutputStatus {
    Success,
    Failed,
    Canceled,
    Timeout,
}

impl OutputStatus {
    /// Returns the lowercase string form of the status:
    /// `"success"`, `"failed"`, `"canceled"` or `"timeout"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            OutputStatus::Success => "success",
            OutputStatus::Failed => "failed",
            OutputStatus::Canceled => "canceled",
            OutputStatus::Timeout => "timeout",
        }
    }
}
34
/// Items that participate in call/output pairing for validation.
///
/// NOTE(review): nothing in this module constructs or matches on this enum;
/// `validate_history_invariants` works directly on `Message`s. Confirm whether
/// other code depends on it before extending it.
#[derive(Debug, Clone)]
pub enum PairableHistoryItem {
    /// Tool call without output (yet).
    ToolCall {
        /// Identifier of the tool call.
        call_id: ToolCallId,
        /// Name of the tool that was called.
        tool_name: String,
    },
    /// Tool output for a previous call.
    ToolOutput {
        /// Identifier of the call this output belongs to.
        call_id: ToolCallId,
        /// Reported status of the execution.
        status: OutputStatus,
    },
}
49
/// Record of a missing output in conversation history.
///
/// Listed in `HistoryValidationReport::missing_outputs` for each tool call
/// that has no message whose `tool_call_id` matches it.
#[derive(Debug, Clone)]
pub struct MissingOutput {
    /// Id of the tool call that has no output.
    pub call_id: ToolCallId,
    /// Name of the called tool. `validate_history_invariants` currently always
    /// fills this with `"unknown"`.
    pub tool_name: String,
}
56
57/// Validation report for conversation history state
58#[derive(Debug, Default, Clone)]
59pub struct HistoryValidationReport {
60    /// Tool calls without corresponding outputs
61    pub missing_outputs: Vec<MissingOutput>,
62    /// Outputs without corresponding calls (orphans)
63    pub orphan_outputs: Vec<ToolCallId>,
64}
65
66impl HistoryValidationReport {
67    /// Check if history is in a valid state
68    pub fn is_valid(&self) -> bool {
69        self.missing_outputs.is_empty() && self.orphan_outputs.is_empty()
70    }
71
72    /// Get a human-readable summary
73    pub fn summary(&self) -> String {
74        if self.is_valid() {
75            "History invariants are valid".to_string()
76        } else {
77            format!(
78                "{} missing outputs, {} orphan outputs",
79                self.missing_outputs.len(),
80                self.orphan_outputs.len()
81            )
82        }
83    }
84}
85
/// Records how long the current turn has taken, at most once per turn.
///
/// When `recorded` is still `false`, the milliseconds elapsed since `start` are
/// appended to `turn_durations`, added to `turn_total_ms`, folded into
/// `turn_max_ms`, `turn_count` is incremented and `recorded` is set. When
/// `recorded` is already `true`, nothing changes.
#[cfg(test)]
#[inline]
pub(crate) fn record_turn_duration(
    turn_durations: &mut Vec<u128>,
    turn_total_ms: &mut u128,
    turn_max_ms: &mut u128,
    turn_count: &mut usize,
    recorded: &mut bool,
    start: &std::time::Instant,
) {
    // Guard against counting the same turn twice.
    if *recorded {
        return;
    }

    let elapsed_ms = start.elapsed().as_millis();
    *turn_max_ms = (*turn_max_ms).max(elapsed_ms);
    *turn_total_ms += elapsed_ms;
    turn_durations.push(elapsed_ms);
    *turn_count += 1;
    *recorded = true;
}
107
/// API failure tracking for exponential backoff.
///
/// Counts consecutive failures and remembers when the most recent one was
/// recorded; [`ApiFailureTracker::reset`] clears both.
pub struct ApiFailureTracker {
    /// Number of failures recorded since the last reset (saturates at `u32::MAX`).
    pub consecutive_failures: u32,
    /// Time of the most recently recorded failure, if any.
    pub last_failure: Option<std::time::Instant>,
}

impl Default for ApiFailureTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl ApiFailureTracker {
    /// Backoff delay for the first failure, in milliseconds.
    const BASE_BACKOFF_MS: u64 = 1_000;
    /// Upper bound on the backoff delay, in milliseconds.
    const MAX_BACKOFF_MS: u64 = 30_000;
    /// Consecutive failures at which [`Self::should_circuit_break`] trips.
    const CIRCUIT_BREAK_THRESHOLD: u32 = 3;

    /// Creates a tracker with no recorded failures.
    pub fn new() -> Self {
        Self {
            consecutive_failures: 0,
            last_failure: None,
        }
    }

    /// Records one more consecutive failure, stamped with the current time.
    pub fn record_failure(&mut self) {
        // Saturate rather than overflow (which would panic in debug builds).
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.last_failure = Some(std::time::Instant::now());
    }

    /// Clears the failure count and the last-failure timestamp.
    pub fn reset(&mut self) {
        self.consecutive_failures = 0;
        self.last_failure = None;
    }

    /// Returns `true` once at least three consecutive failures were recorded.
    pub fn should_circuit_break(&self) -> bool {
        self.consecutive_failures >= Self::CIRCUIT_BREAK_THRESHOLD
    }

    /// Delay before the next attempt: `1s * 2^(failures - 1)`, capped at 30s.
    ///
    /// Zero failures yields the same 1s delay as one failure. The arithmetic
    /// saturates, so arbitrarily large failure counts return the cap instead of
    /// overflowing.
    pub fn backoff_duration(&self) -> Duration {
        let exponent = self.consecutive_failures.saturating_sub(1);
        let backoff_ms = Self::BASE_BACKOFF_MS
            .saturating_mul(2_u64.saturating_pow(exponent))
            .min(Self::MAX_BACKOFF_MS);
        Duration::from_millis(backoff_ms)
    }
}
149
/// Renders `items` as a comma-separated list for display.
///
/// Returns `"none"` for an empty slice. At most five items are shown; any
/// remainder is summarized as a trailing ` [+N more]`.
pub fn summarize_list(items: &[String]) -> String {
    const MAX_ITEMS: usize = 5;

    let total = items.len();
    if total == 0 {
        return String::from("none");
    }

    let visible = items[..total.min(MAX_ITEMS)].join(", ");
    match total.checked_sub(MAX_ITEMS) {
        Some(hidden) if hidden > 0 => format!("{visible} [+{hidden} more]"),
        _ => visible,
    }
}
162
163// ============================================================================
164// Standalone History Invariant Functions
165// ============================================================================
166
167/// Validate that conversation history maintains call/output invariants.
168pub fn validate_history_invariants(messages: &[Message]) -> HistoryValidationReport {
169    let mut call_map: HashMap<String, String> = HashMap::new();
170    let mut output_ids: HashSet<String> = HashSet::new();
171
172    // Scan messages to find tool calls and responses
173    for msg in messages {
174        // Tool calls: assistant messages with tool_calls field
175        if let Some(tool_calls) = &msg.tool_calls {
176            for tool_call in tool_calls {
177                call_map.insert(tool_call.id.clone(), msg.role.to_string());
178            }
179        }
180
181        // Tool responses: messages with tool_call_id set
182        if let Some(tool_call_id) = &msg.tool_call_id {
183            output_ids.insert(tool_call_id.clone());
184        }
185    }
186
187    // Find missing outputs (calls without corresponding responses)
188    let missing_outputs: Vec<_> = call_map
189        .keys()
190        .filter(|call_id| !output_ids.contains(*call_id))
191        .map(|call_id| MissingOutput {
192            call_id: ToolCallId::new(call_id.clone()),
193            tool_name: "unknown".to_string(),
194        })
195        .collect();
196
197    // Find orphan outputs (responses without matching calls)
198    let orphan_outputs: Vec<_> = output_ids
199        .iter()
200        .filter(|output_id| !call_map.contains_key(*output_id))
201        .map(|output_id| ToolCallId::new(output_id.clone()))
202        .collect();
203
204    HistoryValidationReport {
205        missing_outputs,
206        orphan_outputs,
207    }
208}
209
210/// Find a split point that keeps tool-call outputs paired with their calls.
211pub fn safe_history_split_point(
212    messages: &[Message],
213    conversation_len: usize,
214    preferred_split_at: usize,
215) -> usize {
216    if preferred_split_at == 0 || preferred_split_at >= conversation_len {
217        return preferred_split_at;
218    }
219
220    let mut call_indices: HashMap<&str, usize> = HashMap::new();
221    for (i, msg) in messages.iter().enumerate() {
222        if let Some(tool_calls) = &msg.tool_calls {
223            for call in tool_calls {
224                call_indices.insert(&call.id, i);
225            }
226        }
227    }
228
229    let mut safe_split_at = preferred_split_at;
230    loop {
231        if safe_split_at == 0 {
232            break;
233        }
234
235        let has_orphan = ((safe_split_at + 1)..messages.len()).any(|i| {
236            messages
237                .get(i)
238                .and_then(|msg| msg.tool_call_id.as_ref())
239                .and_then(|id| call_indices.get(id.as_str()))
240                .is_some_and(|&call_idx| call_idx <= safe_split_at)
241        });
242
243        if !has_orphan {
244            break;
245        }
246
247        safe_split_at -= 1;
248    }
249
250    safe_split_at
251}
252
253/// Ensure all tool calls have corresponding outputs in the message list.
254pub fn ensure_call_outputs_present(messages: &mut Vec<Message>) {
255    let report = validate_history_invariants(messages);
256
257    // Create synthetic outputs for missing calls in reverse order to avoid index shifting
258    for missing in report.missing_outputs.iter().rev() {
259        let synthetic_message = Message::tool_response(
260            missing.call_id.as_str().to_string(),
261            "canceled: Tool execution was interrupted. This synthetic output was created \
262             during history normalization to maintain conversation invariants."
263                .to_string(),
264        );
265
266        tracing::warn!(
267            "Creating synthetic output for call {} due to missing execution result",
268            missing.call_id
269        );
270
271        // Find the position to insert: right after the corresponding call
272        let insert_pos = messages
273            .iter()
274            .position(|msg| {
275                msg.tool_calls.as_ref().is_some_and(|calls| {
276                    calls.iter().any(|call| call.id == missing.call_id.as_str())
277                })
278            })
279            .map(|pos| pos + 1);
280
281        if let Some(pos) = insert_pos {
282            messages.insert(pos, synthetic_message);
283        } else {
284            // If we can't find the call, just append the synthetic output
285            messages.push(synthetic_message);
286        }
287    }
288}
289
290/// Remove outputs without corresponding calls (orphaned outputs) from the message list.
291pub fn remove_orphan_outputs(messages: &mut Vec<Message>) {
292    let report = validate_history_invariants(messages);
293
294    if report.orphan_outputs.is_empty() {
295        return;
296    }
297
298    let orphan_ids: HashSet<String> = report
299        .orphan_outputs
300        .iter()
301        .map(|id| id.as_str().to_string())
302        .collect();
303
304    let initial_len = messages.len();
305
306    // Retain only messages that either:
307    // - Don't have a tool_call_id (not a tool response)
308    // - Have a tool_call_id that matches an existing call
309    messages.retain(|msg| {
310        if let Some(tool_call_id) = msg.tool_call_id.as_ref()
311            && orphan_ids.contains(tool_call_id)
312        {
313            tracing::warn!("Removing orphan output for call {}", tool_call_id);
314            return false;
315        }
316        true
317    });
318
319    if messages.len() != initial_len {
320        tracing::info!("Removed {} orphan outputs", initial_len - messages.len());
321    }
322}
323
324/// Normalize history to enforce call/output pairing invariants.
325pub fn normalize_history(messages: &mut Vec<Message>) {
326    ensure_call_outputs_present(messages);
327    remove_orphan_outputs(messages);
328
329    // Log if issues were found
330    let report = validate_history_invariants(messages);
331    if !report.is_valid() {
332        tracing::warn!("History validation: {}", report.summary());
333    } else {
334        tracing::debug!("History normalized successfully");
335    }
336}
337
338/// Recover from crashed or interrupted session by fixing history invariants.
339pub fn recover_history_from_crash(messages: &mut Vec<Message>) {
340    let report = validate_history_invariants(messages);
341
342    if !report.missing_outputs.is_empty() {
343        tracing::warn!(
344            "Found {} missing outputs during recovery",
345            report.missing_outputs.len()
346        );
347        ensure_call_outputs_present(messages);
348    }
349
350    if !report.orphan_outputs.is_empty() {
351        tracing::warn!(
352            "Found {} orphan outputs during recovery",
353            report.orphan_outputs.len()
354        );
355        remove_orphan_outputs(messages);
356    }
357
358    if report.is_valid() {
359        tracing::debug!("History invariants are valid");
360    }
361}
362
363// ============================================================================
364// Tests: Context Manager - Call/Output Pairing Invariants
365// ============================================================================
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::provider::Message;

    /// Helper: build an assistant message that issues a single tool call.
    fn make_tool_call(call_id: &str, tool_name: &str) -> Message {
        Message::assistant_with_tools(
            "".to_string(),
            vec![crate::llm::provider::ToolCall::function(
                call_id.to_string(),
                tool_name.to_string(),
                "{}".to_string(),
            )],
        )
    }

    /// Helper: build a tool response message for `call_id`.
    fn make_tool_response(call_id: &str, content: &str) -> Message {
        Message::tool_response(call_id.to_string(), content.to_string())
    }

    /// Test: Valid history with matched calls and outputs
    #[test]
    fn test_validate_history_valid_matched_pairs() {
        let mut messages = vec![
            make_tool_call("call_1", "list_files"),
            make_tool_response("call_1", "file1.rs\nfile2.rs"),
        ];

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid(), "Valid paired call/output should pass");
        assert!(report.missing_outputs.is_empty());
        assert!(report.orphan_outputs.is_empty());

        // Normalize should be a no-op
        normalize_history(&mut messages);
        assert_eq!(messages.len(), 2);
    }

    /// Test: Missing output (tool call without response)
    #[test]
    fn test_validate_history_missing_output() {
        let messages = vec![make_tool_call("call_1", "list_files")];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert_eq!(report.missing_outputs.len(), 1);
        assert_eq!(report.missing_outputs[0].call_id.as_str(), "call_1");
        assert!(report.orphan_outputs.is_empty());
    }

    /// Test: Orphan output (response without corresponding call)
    #[test]
    fn test_validate_history_orphan_output() {
        let messages = vec![make_tool_response("orphan_call", "Some result")];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert!(report.missing_outputs.is_empty());
        assert_eq!(report.orphan_outputs.len(), 1);
        assert_eq!(report.orphan_outputs[0].as_str(), "orphan_call");
    }

    /// Test: ensure_call_outputs_present creates synthetic outputs
    #[test]
    fn test_ensure_call_outputs_present() {
        let mut messages = vec![make_tool_call("call_1", "list_files")];
        let initial_len = messages.len();

        ensure_call_outputs_present(&mut messages);

        assert_eq!(messages.len(), initial_len + 1);
        let last_msg = &messages[initial_len];
        assert_eq!(last_msg.tool_call_id, Some("call_1".to_string()));
        assert!(last_msg.content.as_text().contains("canceled"));

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
    }

    /// Test: remove_orphan_outputs filters out orphaned responses
    #[test]
    fn test_remove_orphan_outputs() {
        let mut messages = vec![
            make_tool_call("call_1", "list_files"),
            make_tool_response("call_1", "valid result"),
            make_tool_response("orphan_call", "orphan result"),
        ];

        let initial_len = messages.len();
        remove_orphan_outputs(&mut messages);

        assert_eq!(messages.len(), initial_len - 1);
        assert!(
            messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "call_1"))
        );
        assert!(!messages.iter().any(|msg| {
            msg.tool_call_id
                .as_ref()
                .is_some_and(|id| id == "orphan_call")
        }));

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
    }

    /// Test: normalize() applies both fixes (synthetic output + orphan removal)
    #[test]
    fn test_normalize_combined_fixes() {
        let mut messages = vec![
            make_tool_call("call_1", "read_file"),
            make_tool_call("call_2", "write_file"),
            make_tool_response("call_2", "written"),
            make_tool_response("orphan", "orphan result"),
        ];

        normalize_history(&mut messages);

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
        assert!(
            messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "call_1"))
        );
        assert!(
            !messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "orphan"))
        );
    }

    /// Test: recover_history_from_crash handles both missing and orphan outputs
    #[test]
    fn test_recover_from_crash() {
        let mut messages = vec![
            make_tool_call("crashed_call", "dangerous_op"),
            make_tool_response("old_call", "stale result"),
        ];

        recover_history_from_crash(&mut messages);

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
        assert!(messages.iter().any(|msg| {
            msg.tool_call_id
                .as_ref()
                .is_some_and(|id| id == "crashed_call")
        }));
        assert!(
            !messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "old_call"))
        );
    }

    /// Test: HistoryValidationReport summary messages
    #[test]
    fn test_validation_report_summary() {
        let valid = HistoryValidationReport::default();
        assert_eq!(valid.summary(), "History invariants are valid");
        assert!(valid.is_valid());

        let invalid = HistoryValidationReport {
            missing_outputs: vec![
                MissingOutput {
                    call_id: ToolCallId::new("call_1"),
                    tool_name: "tool_a".into(),
                },
                MissingOutput {
                    call_id: ToolCallId::new("call_2"),
                    tool_name: "tool_b".into(),
                },
            ],
            orphan_outputs: vec![ToolCallId::new("orphan_1")],
        };
        assert_eq!(invalid.summary(), "2 missing outputs, 1 orphan outputs");
        assert!(!invalid.is_valid());
    }

    /// Test: Multiple tool calls with selective missing outputs
    #[test]
    fn test_multiple_calls_partial_outputs() {
        // call_1 and call_3 are answered; call_2 never received a response.
        let mut messages = vec![
            make_tool_call("call_1", "tool_1"),
            make_tool_response("call_1", "result_1"),
            make_tool_call("call_2", "tool_2"),
            make_tool_call("call_3", "tool_3"),
            make_tool_response("call_3", "result_3"),
        ];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert_eq!(report.missing_outputs.len(), 1);
        assert_eq!(report.missing_outputs[0].call_id.as_str(), "call_2");

        normalize_history(&mut messages);
        assert!(validate_history_invariants(&messages).is_valid());
    }

    /// Test: OutputStatus enum conversion
    #[test]
    fn test_output_status_as_str() {
        assert_eq!(OutputStatus::Success.as_str(), "success");
        assert_eq!(OutputStatus::Failed.as_str(), "failed");
        assert_eq!(OutputStatus::Canceled.as_str(), "canceled");
        assert_eq!(OutputStatus::Timeout.as_str(), "timeout");
    }

    /// Test: safe_history_split_point maintains call/output pairs
    #[test]
    fn test_find_safe_split_point() {
        let messages = vec![
            Message::user("User 1".into()),           // 0
            make_tool_call("call_a", "tool_a"),       // 1
            make_tool_response("call_a", "Result A"), // 2
            make_tool_call("call_b", "tool_b"),       // 3
            make_tool_response("call_b", "Result B"), // 4
        ];
        let conversation_len = 5;

        // At split 3, response B (index 4) follows the split but answers call B
        // (index 3), which is at or before it, so the split moves back to 2.
        let safe = safe_history_split_point(&messages, conversation_len, 3);
        assert_eq!(safe, 2, "Should move split before Call B");

        // At split 4 no message follows the split, so it is kept.
        let safe2 = safe_history_split_point(&messages, conversation_len, 4);
        assert_eq!(safe2, 4, "Should stay at 4 as it is safe");

        // Boundary values are returned unchanged.
        assert_eq!(safe_history_split_point(&messages, conversation_len, 0), 0);
        assert_eq!(
            safe_history_split_point(&messages, conversation_len, conversation_len),
            conversation_len
        );
    }

    /// Test: ApiFailureTracker backoff progression and circuit breaking
    #[test]
    fn test_api_failure_tracker_backoff_and_circuit_break() {
        let mut tracker = ApiFailureTracker::new();
        assert!(!tracker.should_circuit_break());
        assert_eq!(tracker.backoff_duration(), std::time::Duration::from_millis(1000));

        tracker.record_failure();
        assert_eq!(tracker.backoff_duration(), std::time::Duration::from_millis(1000));
        tracker.record_failure();
        assert_eq!(tracker.backoff_duration(), std::time::Duration::from_millis(2000));
        tracker.record_failure();
        assert!(tracker.should_circuit_break());
        assert_eq!(tracker.backoff_duration(), std::time::Duration::from_millis(4000));

        tracker.consecutive_failures = 6;
        assert_eq!(tracker.backoff_duration(), std::time::Duration::from_millis(30000));

        tracker.reset();
        assert_eq!(tracker.consecutive_failures, 0);
        assert!(tracker.last_failure.is_none());
        assert!(!tracker.should_circuit_break());
    }

    #[test]
    fn test_summarize_list_formatting() {
        assert_eq!(summarize_list(&[]), "none");
        assert_eq!(summarize_list(&["a".into()]), "a");
        assert_eq!(summarize_list(&["a".into(), "b".into()]), "a, b");
        let many: Vec<String> = (1..=7).map(|i| format!("item{i}")).collect();
        let result = summarize_list(&many);
        assert!(result.contains("item1, item2, item3, item4, item5"));
        assert!(result.contains("[+2 more]"));
    }
}