1use crate::llm::provider::Message;
2use hashbrown::{HashMap, HashSet};
3use std::time::Duration;
4use vtcode_macros::StringNewtype;
5
/// Identifier linking an assistant tool call to the tool output message that answers it.
///
/// Wraps the raw id string. The string helpers used in this file (`ToolCallId::new`,
/// `as_str`, and `Display` formatting) are presumably generated by the `StringNewtype`
/// derive from `vtcode_macros` — NOTE(review): confirm against the macro.
#[derive(Debug, Clone, PartialEq, Eq, Hash, StringNewtype)]
pub struct ToolCallId(String);
13
/// Outcome of a tool execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputStatus {
    Success,
    Failed,
    Canceled,
    Timeout,
}

impl OutputStatus {
    /// Returns the lowercase name of this status (`"success"`, `"failed"`,
    /// `"canceled"`, or `"timeout"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            OutputStatus::Success => "success",
            OutputStatus::Failed => "failed",
            OutputStatus::Canceled => "canceled",
            OutputStatus::Timeout => "timeout",
        }
    }
}
34
/// A history entry that takes part in tool call/output pairing.
///
/// NOTE(review): this enum is not referenced in the visible part of this file; presumably
/// it is constructed by callers elsewhere in the crate — confirm before changing its shape.
#[derive(Debug, Clone)]
pub enum PairableHistoryItem {
    /// A tool call issued in the history.
    ToolCall {
        /// Identifier of the tool call.
        call_id: ToolCallId,
        /// Name of the tool that was called.
        tool_name: String,
    },
    /// The output recorded for a tool call.
    ToolOutput {
        /// Identifier of the tool call this output belongs to.
        call_id: ToolCallId,
        /// How the tool execution ended.
        status: OutputStatus,
    },
}
49
/// A tool call found in the history that has no matching tool output message.
#[derive(Debug, Clone)]
pub struct MissingOutput {
    /// Id of the unanswered tool call.
    pub call_id: ToolCallId,
    /// Name of the called tool. `validate_history_invariants` in this file does not look the
    /// name up and always fills in `"unknown"`.
    pub tool_name: String,
}
56
57#[derive(Debug, Default, Clone)]
59pub struct HistoryValidationReport {
60 pub missing_outputs: Vec<MissingOutput>,
62 pub orphan_outputs: Vec<ToolCallId>,
64}
65
66impl HistoryValidationReport {
67 pub fn is_valid(&self) -> bool {
69 self.missing_outputs.is_empty() && self.orphan_outputs.is_empty()
70 }
71
72 pub fn summary(&self) -> String {
74 if self.is_valid() {
75 "History invariants are valid".to_string()
76 } else {
77 format!(
78 "{} missing outputs, {} orphan outputs",
79 self.missing_outputs.len(),
80 self.orphan_outputs.len()
81 )
82 }
83 }
84}
85
/// Records the elapsed time since `start` as one turn duration, at most once per turn.
///
/// When `recorded` is already `true` nothing happens. Otherwise the elapsed milliseconds
/// are appended to `turn_durations` and added to `turn_total_ms`, `turn_max_ms` is raised
/// if needed, `turn_count` is incremented, and `recorded` is set to `true`.
#[cfg(test)]
pub(crate) fn record_turn_duration(
    turn_durations: &mut Vec<u128>,
    turn_total_ms: &mut u128,
    turn_max_ms: &mut u128,
    turn_count: &mut usize,
    recorded: &mut bool,
    start: &std::time::Instant,
) {
    // Guard: a turn is only ever counted once.
    if *recorded {
        return;
    }

    let elapsed_ms = start.elapsed().as_millis();
    turn_durations.push(elapsed_ms);
    *turn_total_ms += elapsed_ms;
    *turn_max_ms = (*turn_max_ms).max(elapsed_ms);
    *turn_count += 1;
    *recorded = true;
}
107
/// Tracks consecutive API failures to drive circuit breaking and exponential backoff.
#[derive(Debug)]
pub struct ApiFailureTracker {
    /// Number of failures recorded since construction or the last `reset`.
    pub consecutive_failures: u32,
    /// When the most recent failure was recorded; `None` if none since construction/reset.
    pub last_failure: Option<std::time::Instant>,
}

impl Default for ApiFailureTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl ApiFailureTracker {
    /// Consecutive failures at which `should_circuit_break` starts returning `true`.
    const CIRCUIT_BREAK_THRESHOLD: u32 = 3;
    /// Backoff used for zero or one failure; doubled for each further failure.
    const BASE_BACKOFF_MS: u64 = 1_000;
    /// Upper bound on the backoff delay.
    const MAX_BACKOFF_MS: u64 = 30_000;

    /// Creates a tracker with no recorded failures.
    pub fn new() -> Self {
        Self {
            consecutive_failures: 0,
            last_failure: None,
        }
    }

    /// Records one more failure and timestamps it with `Instant::now()`.
    pub fn record_failure(&mut self) {
        // Saturate instead of overflowing (a plain `+= 1` panics in debug builds at u32::MAX).
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.last_failure = Some(std::time::Instant::now());
    }

    /// Clears the failure count and the last-failure timestamp.
    pub fn reset(&mut self) {
        self.consecutive_failures = 0;
        self.last_failure = None;
    }

    /// Returns `true` once 3 or more consecutive failures have been recorded.
    pub fn should_circuit_break(&self) -> bool {
        self.consecutive_failures >= Self::CIRCUIT_BREAK_THRESHOLD
    }

    /// Returns the delay to wait before the next attempt.
    ///
    /// The delay is `1s * 2^(failures - 1)`, capped at 30s: 0 or 1 failures give 1s,
    /// 2 give 2s, 3 give 4s, 4 give 8s, 5 give 16s, and 6 or more give 30s. Very large
    /// failure counts saturate at the cap instead of overflowing the `u64` arithmetic.
    pub fn backoff_duration(&self) -> Duration {
        let exponent = self.consecutive_failures.saturating_sub(1);
        // `None` from either checked step means the true value exceeds u64::MAX ms,
        // which is certainly above the cap.
        let backoff_ms = 2_u64
            .checked_pow(exponent)
            .and_then(|factor| factor.checked_mul(Self::BASE_BACKOFF_MS))
            .map_or(Self::MAX_BACKOFF_MS, |ms| ms.min(Self::MAX_BACKOFF_MS));
        Duration::from_millis(backoff_ms)
    }
}
149
/// Renders `items` as a comma-separated list of at most five entries.
///
/// An empty slice yields `"none"`. When there are more than five items, the first five
/// are shown followed by ` [+N more]`, where `N` is the number of hidden items.
pub fn summarize_list(items: &[String]) -> String {
    const MAX_ITEMS: usize = 5;
    match items.len() {
        0 => "none".into(),
        len => {
            let head = items[..len.min(MAX_ITEMS)].join(", ");
            let hidden = len.saturating_sub(MAX_ITEMS);
            if hidden == 0 {
                head
            } else {
                format!("{head} [+{hidden} more]")
            }
        }
    }
}
162
163pub fn validate_history_invariants(messages: &[Message]) -> HistoryValidationReport {
169 let mut call_map: HashMap<String, String> = HashMap::new();
170 let mut output_ids: HashSet<String> = HashSet::new();
171
172 for msg in messages {
174 if let Some(tool_calls) = &msg.tool_calls {
176 for tool_call in tool_calls {
177 call_map.insert(tool_call.id.clone(), msg.role.to_string());
178 }
179 }
180
181 if let Some(tool_call_id) = &msg.tool_call_id {
183 output_ids.insert(tool_call_id.clone());
184 }
185 }
186
187 let missing_outputs: Vec<_> = call_map
189 .keys()
190 .filter(|call_id| !output_ids.contains(*call_id))
191 .map(|call_id| MissingOutput {
192 call_id: ToolCallId::new(call_id.clone()),
193 tool_name: "unknown".to_string(),
194 })
195 .collect();
196
197 let orphan_outputs: Vec<_> = output_ids
199 .iter()
200 .filter(|output_id| !call_map.contains_key(*output_id))
201 .map(|output_id| ToolCallId::new(output_id.clone()))
202 .collect();
203
204 HistoryValidationReport {
205 missing_outputs,
206 orphan_outputs,
207 }
208}
209
210pub fn safe_history_split_point(
212 messages: &[Message],
213 conversation_len: usize,
214 preferred_split_at: usize,
215) -> usize {
216 if preferred_split_at == 0 || preferred_split_at >= conversation_len {
217 return preferred_split_at;
218 }
219
220 let mut call_indices: HashMap<&str, usize> = HashMap::new();
221 for (i, msg) in messages.iter().enumerate() {
222 if let Some(tool_calls) = &msg.tool_calls {
223 for call in tool_calls {
224 call_indices.insert(&call.id, i);
225 }
226 }
227 }
228
229 let mut safe_split_at = preferred_split_at;
230 loop {
231 if safe_split_at == 0 {
232 break;
233 }
234
235 let has_orphan = ((safe_split_at + 1)..messages.len()).any(|i| {
236 messages
237 .get(i)
238 .and_then(|msg| msg.tool_call_id.as_ref())
239 .and_then(|id| call_indices.get(id.as_str()))
240 .is_some_and(|&call_idx| call_idx <= safe_split_at)
241 });
242
243 if !has_orphan {
244 break;
245 }
246
247 safe_split_at -= 1;
248 }
249
250 safe_split_at
251}
252
253pub fn ensure_call_outputs_present(messages: &mut Vec<Message>) {
255 let report = validate_history_invariants(messages);
256
257 for missing in report.missing_outputs.iter().rev() {
259 let synthetic_message = Message::tool_response(
260 missing.call_id.as_str().to_string(),
261 "canceled: Tool execution was interrupted. This synthetic output was created \
262 during history normalization to maintain conversation invariants."
263 .to_string(),
264 );
265
266 tracing::warn!(
267 "Creating synthetic output for call {} due to missing execution result",
268 missing.call_id
269 );
270
271 let insert_pos = messages
273 .iter()
274 .position(|msg| {
275 msg.tool_calls.as_ref().is_some_and(|calls| {
276 calls.iter().any(|call| call.id == missing.call_id.as_str())
277 })
278 })
279 .map(|pos| pos + 1);
280
281 if let Some(pos) = insert_pos {
282 messages.insert(pos, synthetic_message);
283 } else {
284 messages.push(synthetic_message);
286 }
287 }
288}
289
290pub fn remove_orphan_outputs(messages: &mut Vec<Message>) {
292 let report = validate_history_invariants(messages);
293
294 if report.orphan_outputs.is_empty() {
295 return;
296 }
297
298 let orphan_ids: HashSet<String> = report
299 .orphan_outputs
300 .iter()
301 .map(|id| id.as_str().to_string())
302 .collect();
303
304 let initial_len = messages.len();
305
306 messages.retain(|msg| {
310 if let Some(tool_call_id) = msg.tool_call_id.as_ref()
311 && orphan_ids.contains(tool_call_id)
312 {
313 tracing::warn!("Removing orphan output for call {}", tool_call_id);
314 return false;
315 }
316 true
317 });
318
319 if messages.len() != initial_len {
320 tracing::info!("Removed {} orphan outputs", initial_len - messages.len());
321 }
322}
323
324pub fn normalize_history(messages: &mut Vec<Message>) {
326 ensure_call_outputs_present(messages);
327 remove_orphan_outputs(messages);
328
329 let report = validate_history_invariants(messages);
331 if !report.is_valid() {
332 tracing::warn!("History validation: {}", report.summary());
333 } else {
334 tracing::debug!("History normalized successfully");
335 }
336}
337
338pub fn recover_history_from_crash(messages: &mut Vec<Message>) {
340 let report = validate_history_invariants(messages);
341
342 if !report.missing_outputs.is_empty() {
343 tracing::warn!(
344 "Found {} missing outputs during recovery",
345 report.missing_outputs.len()
346 );
347 ensure_call_outputs_present(messages);
348 }
349
350 if !report.orphan_outputs.is_empty() {
351 tracing::warn!(
352 "Found {} orphan outputs during recovery",
353 report.orphan_outputs.len()
354 );
355 remove_orphan_outputs(messages);
356 }
357
358 if report.is_valid() {
359 tracing::debug!("History invariants are valid");
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::provider::Message;

    /// Builds an assistant message carrying one function tool call with `{}` arguments.
    fn make_tool_call(call_id: &str, tool_name: &str) -> Message {
        Message::assistant_with_tools(
            "".to_string(),
            vec![crate::llm::provider::ToolCall::function(
                call_id.to_string(),
                tool_name.to_string(),
                "{}".to_string(),
            )],
        )
    }

    /// Builds a tool-response message answering `call_id` with `content`.
    fn make_tool_response(call_id: &str, content: &str) -> Message {
        Message::tool_response(call_id.to_string(), content.to_string())
    }

    #[test]
    fn test_validate_history_valid_matched_pairs() {
        let mut messages = vec![
            make_tool_call("call_1", "list_files"),
            make_tool_response("call_1", "file1.rs\nfile2.rs"),
        ];

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid(), "Valid paired call/output should pass");
        assert!(report.missing_outputs.is_empty());
        assert!(report.orphan_outputs.is_empty());

        // Normalizing an already-valid history must not add or remove messages.
        normalize_history(&mut messages);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_validate_history_missing_output() {
        let messages = vec![make_tool_call("call_1", "list_files")];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert_eq!(report.missing_outputs.len(), 1);
        assert_eq!(report.missing_outputs[0].call_id.as_str(), "call_1");
        assert!(report.orphan_outputs.is_empty());
    }

    #[test]
    fn test_validate_history_orphan_output() {
        let messages = vec![make_tool_response("orphan_call", "Some result")];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert!(report.missing_outputs.is_empty());
        assert_eq!(report.orphan_outputs.len(), 1);
        assert_eq!(report.orphan_outputs[0].as_str(), "orphan_call");
    }

    #[test]
    fn test_ensure_call_outputs_present() {
        let mut messages = vec![make_tool_call("call_1", "list_files")];
        let initial_len = messages.len();

        ensure_call_outputs_present(&mut messages);

        assert_eq!(messages.len(), initial_len + 1);
        let last_msg = &messages[initial_len];
        assert_eq!(last_msg.tool_call_id, Some("call_1".to_string()));
        assert!(last_msg.content.as_text().contains("canceled"));

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
    }

    #[test]
    fn test_remove_orphan_outputs() {
        let mut messages = vec![
            make_tool_call("call_1", "list_files"),
            make_tool_response("call_1", "valid result"),
            make_tool_response("orphan_call", "orphan result"),
        ];

        let initial_len = messages.len();
        remove_orphan_outputs(&mut messages);

        assert_eq!(messages.len(), initial_len - 1);
        assert!(
            messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "call_1"))
        );
        assert!(!messages.iter().any(|msg| {
            msg.tool_call_id
                .as_ref()
                .is_some_and(|id| id == "orphan_call")
        }));

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
    }

    #[test]
    fn test_normalize_combined_fixes() {
        let mut messages = vec![
            make_tool_call("call_1", "read_file"),
            make_tool_call("call_2", "write_file"),
            make_tool_response("call_2", "written"),
            make_tool_response("orphan", "orphan result"),
        ];

        normalize_history(&mut messages);

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
        assert!(
            messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "call_1"))
        );
        assert!(
            !messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "orphan"))
        );
    }

    #[test]
    fn test_recover_from_crash() {
        let mut messages = vec![
            make_tool_call("crashed_call", "dangerous_op"),
            make_tool_response("old_call", "stale result"),
        ];

        recover_history_from_crash(&mut messages);

        let report = validate_history_invariants(&messages);
        assert!(report.is_valid());
        assert!(messages.iter().any(|msg| {
            msg.tool_call_id
                .as_ref()
                .is_some_and(|id| id == "crashed_call")
        }));
        assert!(
            !messages
                .iter()
                .any(|msg| msg.tool_call_id.as_ref().is_some_and(|id| id == "old_call"))
        );
    }

    #[test]
    fn test_validation_report_summary() {
        let valid = HistoryValidationReport::default();
        assert_eq!(valid.summary(), "History invariants are valid");
        assert!(valid.is_valid());

        let invalid = HistoryValidationReport {
            missing_outputs: vec![
                MissingOutput {
                    call_id: ToolCallId::new("call_1"),
                    tool_name: "tool_a".into(),
                },
                MissingOutput {
                    call_id: ToolCallId::new("call_2"),
                    tool_name: "tool_b".into(),
                },
            ],
            orphan_outputs: vec![ToolCallId::new("orphan_1")],
        };
        assert_eq!(invalid.summary(), "2 missing outputs, 1 orphan outputs");
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_multiple_calls_partial_outputs() {
        // call_2 is issued but never answered; call_1 and call_3 are answered.
        let mut messages = vec![
            make_tool_call("call_1", "tool_1"),
            make_tool_response("call_1", "result_1"),
            make_tool_call("call_2", "tool_2"),
            make_tool_call("call_3", "tool_3"),
            make_tool_response("call_3", "result_3"),
        ];

        let report = validate_history_invariants(&messages);
        assert!(!report.is_valid());
        assert_eq!(report.missing_outputs.len(), 1);
        assert_eq!(report.missing_outputs[0].call_id.as_str(), "call_2");

        normalize_history(&mut messages);
        assert!(validate_history_invariants(&messages).is_valid());
    }

    #[test]
    fn test_output_status_as_str() {
        assert_eq!(OutputStatus::Success.as_str(), "success");
        assert_eq!(OutputStatus::Failed.as_str(), "failed");
        assert_eq!(OutputStatus::Canceled.as_str(), "canceled");
        assert_eq!(OutputStatus::Timeout.as_str(), "timeout");
    }

    #[test]
    fn test_find_safe_split_point() {
        let messages = vec![
            Message::user("User 1".into()),
            make_tool_call("call_a", "tool_a"),
            make_tool_response("call_a", "Result A"),
            make_tool_call("call_b", "tool_b"),
            make_tool_response("call_b", "Result B"),
        ];
        let conversation_len = 5;

        // Splitting after index 3 would separate call_b (index 3) from its output (index 4).
        let safe = safe_history_split_point(&messages, conversation_len, 3);
        assert_eq!(safe, 2, "Should move split to include Call A");

        let safe2 = safe_history_split_point(&messages, conversation_len, 4);
        assert_eq!(safe2, 4, "Should stay at 4 as it is safe");
    }

    #[test]
    fn test_summarize_list_formatting() {
        assert_eq!(summarize_list(&[]), "none");
        assert_eq!(summarize_list(&["a".into()]), "a");
        assert_eq!(summarize_list(&["a".into(), "b".into()]), "a, b");
        let exactly_five: Vec<String> = (1..=5).map(|i| format!("item{i}")).collect();
        assert_eq!(
            summarize_list(&exactly_five),
            "item1, item2, item3, item4, item5"
        );
        let many: Vec<String> = (1..=7).map(|i| format!("item{i}")).collect();
        let result = summarize_list(&many);
        assert!(result.contains("item1, item2, item3, item4, item5"));
        assert!(result.contains("[+2 more]"));
    }

    #[test]
    fn test_api_failure_tracker_backoff_schedule() {
        let mut tracker = ApiFailureTracker::new();
        assert_eq!(
            tracker.backoff_duration(),
            std::time::Duration::from_millis(1000)
        );
        assert!(!tracker.should_circuit_break());

        tracker.record_failure();
        tracker.record_failure();
        assert_eq!(
            tracker.backoff_duration(),
            std::time::Duration::from_millis(2000)
        );
        tracker.record_failure();
        assert!(tracker.should_circuit_break());
        assert!(tracker.last_failure.is_some());

        tracker.consecutive_failures = 10;
        assert_eq!(
            tracker.backoff_duration(),
            std::time::Duration::from_millis(30_000)
        );

        tracker.reset();
        assert_eq!(tracker.consecutive_failures, 0);
        assert!(tracker.last_failure.is_none());
    }

    #[test]
    fn test_record_turn_duration_records_once() {
        let mut durations: Vec<u128> = Vec::new();
        let mut total: u128 = 0;
        let mut max: u128 = 0;
        let mut count: usize = 0;
        let mut recorded = false;
        let start = std::time::Instant::now();

        for _ in 0..2 {
            record_turn_duration(
                &mut durations,
                &mut total,
                &mut max,
                &mut count,
                &mut recorded,
                &start,
            );
        }

        assert!(recorded);
        assert_eq!(count, 1);
        assert_eq!(durations.len(), 1);
        assert_eq!(total, durations[0]);
        assert_eq!(max, durations[0]);
    }
}