Skip to main content

sgr_agent/
agent_loop.rs

1//! Generic agent loop — drives agent + tools until completion or limit.
2//!
3//! Includes 3-tier loop detection (exact signature, tool name frequency, output stagnation).
4
5use crate::agent::{Agent, AgentError, Decision};
6use crate::context::{AgentContext, AgentState};
7use crate::registry::ToolRegistry;
8use crate::types::{Message, SgrError};
9use futures::future::join_all;
10use std::collections::HashMap;
11
/// Max consecutive recoverable `decide` errors (see `is_recoverable_error`) tolerated by the loop.
///
/// The loops compare with `>`, so the failure that pushes the counter *past* this
/// value is returned to the caller. The counter resets after any successful decision.
const MAX_PARSE_RETRIES: usize = 3;
14
15/// Check if an agent error is recoverable (parsing/empty response).
16fn is_recoverable_error(e: &AgentError) -> bool {
17    matches!(
18        e,
19        AgentError::Llm(SgrError::Json(_))
20            | AgentError::Llm(SgrError::EmptyResponse)
21            | AgentError::Llm(SgrError::Schema(_))
22    )
23}
24
/// Loop configuration.
///
/// Tuning knobs shared by `run_loop` and `run_loop_interactive`.
#[derive(Debug, Clone)]
pub struct LoopConfig {
    /// Maximum steps before aborting.
    /// The loop runs steps `1..=max_steps` and then fails with `AgentError::MaxSteps`.
    pub max_steps: usize,
    /// Consecutive repeated tool calls before loop detection triggers.
    /// Passed to `LoopDetector::new`, where the same value is used as the
    /// threshold for all three detection tiers.
    pub loop_abort_threshold: usize,
    /// Max messages to keep in context (0 = unlimited).
    /// Keeps first 2 (system + user prompt) + last N messages.
    /// `trim_messages` does nothing when the limit is below 4, so values 1–3
    /// effectively disable trimming as well.
    pub max_messages: usize,
    /// Auto-complete if agent returns same situation text N times.
    /// `CompletionDetector::new` raises values below 2 to 2.
    pub auto_complete_threshold: usize,
}
38
39impl Default for LoopConfig {
40    fn default() -> Self {
41        Self {
42            max_steps: 50,
43            loop_abort_threshold: 6,
44            max_messages: 80,
45            auto_complete_threshold: 3,
46        }
47    }
48}
49
/// Events emitted during the agent loop.
///
/// Delivered synchronously to the `on_event` callback passed to `run_loop` /
/// `run_loop_interactive`.
#[derive(Debug)]
pub enum LoopEvent {
    /// A new step is starting; `step` is 1-based.
    StepStart {
        step: usize,
    },
    /// The agent produced a decision for the current step.
    Decision(Decision),
    /// A tool call finished. `output` holds the tool's content on success, or a
    /// `"Tool error: …"` / `"Unknown tool: …"` message on failure.
    ToolResult {
        name: String,
        output: String,
    },
    /// The loop finished successfully after `steps` steps.
    Completed {
        steps: usize,
    },
    /// Loop detection aborted the run; `count` is the repeat counter that tripped.
    LoopDetected {
        count: usize,
    },
    /// A recoverable decision error occurred and the loop is retrying.
    Error(AgentError),
    /// Agent needs user input. Content is the question.
    WaitingForInput {
        question: String,
        tool_call_id: String,
    },
}
74
75/// Run the agent loop: decide → execute tools → feed results → repeat.
76///
77/// Returns the number of steps taken.
78pub async fn run_loop(
79    agent: &dyn Agent,
80    tools: &ToolRegistry,
81    ctx: &mut AgentContext,
82    messages: &mut Vec<Message>,
83    config: &LoopConfig,
84    mut on_event: impl FnMut(LoopEvent),
85) -> Result<usize, AgentError> {
86    let mut detector = LoopDetector::new(config.loop_abort_threshold);
87    let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
88    let mut parse_retries: usize = 0;
89
90    for step in 1..=config.max_steps {
91        // Sliding window: trim messages if over limit
92        if config.max_messages > 0 && messages.len() > config.max_messages {
93            trim_messages(messages, config.max_messages);
94        }
95        ctx.iteration = step;
96        on_event(LoopEvent::StepStart { step });
97
98        // Lifecycle hook: prepare context
99        agent.prepare_context(ctx, messages);
100
101        // Lifecycle hook: prepare tools (filter/reorder)
102        let active_tool_names = agent.prepare_tools(ctx, tools);
103        let filtered_tools = if active_tool_names.len() == tools.list().len() {
104            None // no filtering needed
105        } else {
106            Some(active_tool_names)
107        };
108
109        // Use filtered registry if hooks modified the tool set
110        let effective_tools = if let Some(ref names) = filtered_tools {
111            &tools.filter(names)
112        } else {
113            tools
114        };
115
116        let decision = match agent.decide(messages, effective_tools).await {
117            Ok(d) => {
118                parse_retries = 0;
119                d
120            }
121            Err(e) if is_recoverable_error(&e) => {
122                parse_retries += 1;
123                if parse_retries > MAX_PARSE_RETRIES {
124                    return Err(e);
125                }
126                let err_msg = format!(
127                    "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
128                    parse_retries, MAX_PARSE_RETRIES, e
129                );
130                on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
131                    err_msg.clone(),
132                ))));
133                messages.push(Message::user(&err_msg));
134                continue;
135            }
136            Err(e) => return Err(e),
137        };
138        on_event(LoopEvent::Decision(decision.clone()));
139
140        // Auto-completion: detect when agent is done but forgot to call finish_task
141        if completion_detector.check(&decision) {
142            ctx.state = AgentState::Completed;
143            if !decision.situation.is_empty() {
144                messages.push(Message::assistant(&decision.situation));
145            }
146            on_event(LoopEvent::Completed { steps: step });
147            return Ok(step);
148        }
149
150        if decision.completed || decision.tool_calls.is_empty() {
151            ctx.state = AgentState::Completed;
152            // Add assistant message with situation
153            if !decision.situation.is_empty() {
154                messages.push(Message::assistant(&decision.situation));
155            }
156            on_event(LoopEvent::Completed { steps: step });
157            return Ok(step);
158        }
159
160        // Loop detection
161        let sig: Vec<String> = decision
162            .tool_calls
163            .iter()
164            .map(|tc| tc.name.clone())
165            .collect();
166        match detector.check(&sig) {
167            LoopCheckResult::Abort => {
168                ctx.state = AgentState::Failed;
169                on_event(LoopEvent::LoopDetected {
170                    count: detector.consecutive,
171                });
172                return Err(AgentError::LoopDetected(detector.consecutive));
173            }
174            LoopCheckResult::Tier2Warning(dominant_tool) => {
175                // Inject a system hint: give the agent one more chance to change approach
176                let hint = format!(
177                    "LOOP WARNING: You are repeatedly using '{}' without making progress. \
178                     Try a different approach: re-read the file with read_file to see current contents, \
179                     use write_file instead of edit_file, or break the problem into smaller steps.",
180                    dominant_tool
181                );
182                messages.push(Message::system(&hint));
183            }
184            LoopCheckResult::Ok => {}
185        }
186
187        // Add assistant message with tool calls (Gemini requires model turn before function responses)
188        messages.push(Message::assistant_with_tool_calls(
189            &decision.situation,
190            decision.tool_calls.clone(),
191        ));
192
193        // Execute tool calls: read-only in parallel, write sequentially
194        let mut step_outputs: Vec<String> = Vec::new();
195        let mut early_done = false;
196
197        // Partition into read-only (parallel) and write (sequential) tool calls
198        let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
199            .tool_calls
200            .iter()
201            .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
202
203        // Phase 1: read-only tools in parallel
204        if !ro_calls.is_empty() {
205            let futs: Vec<_> = ro_calls
206                .iter()
207                .map(|tc| {
208                    let tool = tools.get(&tc.name).unwrap();
209                    let args = tc.arguments.clone();
210                    let name = tc.name.clone();
211                    let id = tc.id.clone();
212                    async move { (id, name, tool.execute_readonly(args).await) }
213                })
214                .collect();
215
216            for (id, name, result) in join_all(futs).await {
217                match result {
218                    Ok(output) => {
219                        on_event(LoopEvent::ToolResult {
220                            name: name.clone(),
221                            output: output.content.clone(),
222                        });
223                        step_outputs.push(output.content.clone());
224                        agent.after_action(ctx, &name, &output.content);
225                        if output.waiting {
226                            ctx.state = AgentState::WaitingInput;
227                            on_event(LoopEvent::WaitingForInput {
228                                question: output.content.clone(),
229                                tool_call_id: id.clone(),
230                            });
231                            messages.push(Message::tool(&id, "[waiting for user input]"));
232                            ctx.state = AgentState::Running;
233                        } else {
234                            messages.push(Message::tool(&id, &output.content));
235                        }
236                        if output.done {
237                            early_done = true;
238                        }
239                    }
240                    Err(e) => {
241                        let err_msg = format!("Tool error: {}", e);
242                        step_outputs.push(err_msg.clone());
243                        messages.push(Message::tool(&id, &err_msg));
244                        agent.after_action(ctx, &name, &err_msg);
245                        on_event(LoopEvent::ToolResult {
246                            name,
247                            output: err_msg,
248                        });
249                    }
250                }
251            }
252            if early_done && rw_calls.is_empty() {
253                // Only honor early done from read-only tools if no write tools pending
254                ctx.state = AgentState::Completed;
255                on_event(LoopEvent::Completed { steps: step });
256                return Ok(step);
257            }
258        }
259
260        // Phase 2: write tools sequentially (need &mut ctx)
261        for tc in &rw_calls {
262            if let Some(tool) = tools.get(&tc.name) {
263                match tool.execute(tc.arguments.clone(), ctx).await {
264                    Ok(output) => {
265                        on_event(LoopEvent::ToolResult {
266                            name: tc.name.clone(),
267                            output: output.content.clone(),
268                        });
269                        step_outputs.push(output.content.clone());
270                        agent.after_action(ctx, &tc.name, &output.content);
271                        if output.waiting {
272                            ctx.state = AgentState::WaitingInput;
273                            on_event(LoopEvent::WaitingForInput {
274                                question: output.content.clone(),
275                                tool_call_id: tc.id.clone(),
276                            });
277                            messages.push(Message::tool(&tc.id, "[waiting for user input]"));
278                            ctx.state = AgentState::Running;
279                        } else {
280                            messages.push(Message::tool(&tc.id, &output.content));
281                        }
282                        if output.done {
283                            ctx.state = AgentState::Completed;
284                            on_event(LoopEvent::Completed { steps: step });
285                            return Ok(step);
286                        }
287                    }
288                    Err(e) => {
289                        let err_msg = format!("Tool error: {}", e);
290                        step_outputs.push(err_msg.clone());
291                        messages.push(Message::tool(&tc.id, &err_msg));
292                        agent.after_action(ctx, &tc.name, &err_msg);
293                        on_event(LoopEvent::ToolResult {
294                            name: tc.name.clone(),
295                            output: err_msg,
296                        });
297                    }
298                }
299            } else {
300                let err_msg = format!("Unknown tool: {}", tc.name);
301                step_outputs.push(err_msg.clone());
302                messages.push(Message::tool(&tc.id, &err_msg));
303                on_event(LoopEvent::ToolResult {
304                    name: tc.name.clone(),
305                    output: err_msg,
306                });
307            }
308        }
309
310        // Tier 3: output stagnation
311        if detector.check_outputs(&step_outputs) {
312            ctx.state = AgentState::Failed;
313            on_event(LoopEvent::LoopDetected {
314                count: detector.output_repeat_count,
315            });
316            return Err(AgentError::LoopDetected(detector.output_repeat_count));
317        }
318    }
319
320    ctx.state = AgentState::Failed;
321    Err(AgentError::MaxSteps(config.max_steps))
322}
323
324/// Run the agent loop with interactive input support.
325///
326/// When a tool returns `ToolOutput::waiting`, the loop pauses and calls `on_input`
327/// with the question. The returned string is injected as the tool result, then the loop continues.
328///
329/// This is the interactive version of `run_loop` — use it when the agent may need
330/// to ask the user questions (via ClarificationTool or similar).
331pub async fn run_loop_interactive<F, Fut>(
332    agent: &dyn Agent,
333    tools: &ToolRegistry,
334    ctx: &mut AgentContext,
335    messages: &mut Vec<Message>,
336    config: &LoopConfig,
337    mut on_event: impl FnMut(LoopEvent),
338    mut on_input: F,
339) -> Result<usize, AgentError>
340where
341    F: FnMut(String) -> Fut,
342    Fut: std::future::Future<Output = String>,
343{
344    let mut detector = LoopDetector::new(config.loop_abort_threshold);
345    let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
346    let mut parse_retries: usize = 0;
347
348    for step in 1..=config.max_steps {
349        if config.max_messages > 0 && messages.len() > config.max_messages {
350            trim_messages(messages, config.max_messages);
351        }
352        ctx.iteration = step;
353        on_event(LoopEvent::StepStart { step });
354
355        agent.prepare_context(ctx, messages);
356
357        let active_tool_names = agent.prepare_tools(ctx, tools);
358        let filtered_tools = if active_tool_names.len() == tools.list().len() {
359            None
360        } else {
361            Some(active_tool_names)
362        };
363        let effective_tools = if let Some(ref names) = filtered_tools {
364            &tools.filter(names)
365        } else {
366            tools
367        };
368
369        let decision = match agent.decide(messages, effective_tools).await {
370            Ok(d) => {
371                parse_retries = 0;
372                d
373            }
374            Err(e) if is_recoverable_error(&e) => {
375                parse_retries += 1;
376                if parse_retries > MAX_PARSE_RETRIES {
377                    return Err(e);
378                }
379                let err_msg = format!(
380                    "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
381                    parse_retries, MAX_PARSE_RETRIES, e
382                );
383                on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
384                    err_msg.clone(),
385                ))));
386                messages.push(Message::user(&err_msg));
387                continue;
388            }
389            Err(e) => return Err(e),
390        };
391        on_event(LoopEvent::Decision(decision.clone()));
392
393        if completion_detector.check(&decision) {
394            ctx.state = AgentState::Completed;
395            if !decision.situation.is_empty() {
396                messages.push(Message::assistant(&decision.situation));
397            }
398            on_event(LoopEvent::Completed { steps: step });
399            return Ok(step);
400        }
401
402        if decision.completed || decision.tool_calls.is_empty() {
403            ctx.state = AgentState::Completed;
404            if !decision.situation.is_empty() {
405                messages.push(Message::assistant(&decision.situation));
406            }
407            on_event(LoopEvent::Completed { steps: step });
408            return Ok(step);
409        }
410
411        let sig: Vec<String> = decision
412            .tool_calls
413            .iter()
414            .map(|tc| tc.name.clone())
415            .collect();
416        match detector.check(&sig) {
417            LoopCheckResult::Abort => {
418                ctx.state = AgentState::Failed;
419                on_event(LoopEvent::LoopDetected {
420                    count: detector.consecutive,
421                });
422                return Err(AgentError::LoopDetected(detector.consecutive));
423            }
424            LoopCheckResult::Tier2Warning(dominant_tool) => {
425                let hint = format!(
426                    "LOOP WARNING: You are repeatedly using '{}' without making progress. \
427                     Try a different approach: re-read the file with read_file to see current contents, \
428                     use write_file instead of edit_file, or break the problem into smaller steps.",
429                    dominant_tool
430                );
431                messages.push(Message::system(&hint));
432            }
433            LoopCheckResult::Ok => {}
434        }
435
436        // Add assistant message with tool calls (Gemini requires model turn before function responses)
437        messages.push(Message::assistant_with_tool_calls(
438            &decision.situation,
439            decision.tool_calls.clone(),
440        ));
441
442        let mut step_outputs: Vec<String> = Vec::new();
443        let mut early_done = false;
444
445        // Partition into read-only (parallel) and write (sequential) tool calls
446        let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
447            .tool_calls
448            .iter()
449            .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
450
451        // Phase 1: read-only tools in parallel
452        if !ro_calls.is_empty() {
453            let futs: Vec<_> = ro_calls
454                .iter()
455                .map(|tc| {
456                    let tool = tools.get(&tc.name).unwrap();
457                    let args = tc.arguments.clone();
458                    let name = tc.name.clone();
459                    let id = tc.id.clone();
460                    async move { (id, name, tool.execute_readonly(args).await) }
461                })
462                .collect();
463
464            for (id, name, result) in join_all(futs).await {
465                match result {
466                    Ok(output) => {
467                        on_event(LoopEvent::ToolResult {
468                            name: name.clone(),
469                            output: output.content.clone(),
470                        });
471                        step_outputs.push(output.content.clone());
472                        agent.after_action(ctx, &name, &output.content);
473                        if output.waiting {
474                            ctx.state = AgentState::WaitingInput;
475                            on_event(LoopEvent::WaitingForInput {
476                                question: output.content.clone(),
477                                tool_call_id: id.clone(),
478                            });
479                            let response = on_input(output.content).await;
480                            ctx.state = AgentState::Running;
481                            messages.push(Message::tool(&id, &response));
482                        } else {
483                            messages.push(Message::tool(&id, &output.content));
484                        }
485                        if output.done {
486                            early_done = true;
487                        }
488                    }
489                    Err(e) => {
490                        let err_msg = format!("Tool error: {}", e);
491                        step_outputs.push(err_msg.clone());
492                        messages.push(Message::tool(&id, &err_msg));
493                        agent.after_action(ctx, &name, &err_msg);
494                        on_event(LoopEvent::ToolResult {
495                            name,
496                            output: err_msg,
497                        });
498                    }
499                }
500            }
501            if early_done && rw_calls.is_empty() {
502                // Only honor early done from read-only tools if no write tools pending
503                ctx.state = AgentState::Completed;
504                on_event(LoopEvent::Completed { steps: step });
505                return Ok(step);
506            }
507        }
508
509        // Phase 2: write tools sequentially (need &mut ctx)
510        for tc in &rw_calls {
511            if let Some(tool) = tools.get(&tc.name) {
512                match tool.execute(tc.arguments.clone(), ctx).await {
513                    Ok(output) => {
514                        on_event(LoopEvent::ToolResult {
515                            name: tc.name.clone(),
516                            output: output.content.clone(),
517                        });
518                        step_outputs.push(output.content.clone());
519                        agent.after_action(ctx, &tc.name, &output.content);
520                        if output.waiting {
521                            ctx.state = AgentState::WaitingInput;
522                            on_event(LoopEvent::WaitingForInput {
523                                question: output.content.clone(),
524                                tool_call_id: tc.id.clone(),
525                            });
526                            let response = on_input(output.content.clone()).await;
527                            ctx.state = AgentState::Running;
528                            messages.push(Message::tool(&tc.id, &response));
529                        } else {
530                            messages.push(Message::tool(&tc.id, &output.content));
531                        }
532                        if output.done {
533                            ctx.state = AgentState::Completed;
534                            on_event(LoopEvent::Completed { steps: step });
535                            return Ok(step);
536                        }
537                    }
538                    Err(e) => {
539                        let err_msg = format!("Tool error: {}", e);
540                        step_outputs.push(err_msg.clone());
541                        messages.push(Message::tool(&tc.id, &err_msg));
542                        agent.after_action(ctx, &tc.name, &err_msg);
543                        on_event(LoopEvent::ToolResult {
544                            name: tc.name.clone(),
545                            output: err_msg,
546                        });
547                    }
548                }
549            } else {
550                let err_msg = format!("Unknown tool: {}", tc.name);
551                step_outputs.push(err_msg.clone());
552                messages.push(Message::tool(&tc.id, &err_msg));
553                on_event(LoopEvent::ToolResult {
554                    name: tc.name.clone(),
555                    output: err_msg,
556                });
557            }
558        }
559
560        if detector.check_outputs(&step_outputs) {
561            ctx.state = AgentState::Failed;
562            on_event(LoopEvent::LoopDetected {
563                count: detector.output_repeat_count,
564            });
565            return Err(AgentError::LoopDetected(detector.output_repeat_count));
566        }
567    }
568
569    ctx.state = AgentState::Failed;
570    Err(AgentError::MaxSteps(config.max_steps))
571}
572
/// Result of loop detection check.
///
/// Returned by `LoopDetector::check`, which covers tiers 1 and 2. Tier 3
/// (output stagnation) is reported separately as a `bool` by
/// `LoopDetector::check_outputs`.
#[derive(Debug, PartialEq)]
enum LoopCheckResult {
    /// No loop detected.
    Ok,
    /// Tier 2 warning: a single tool category dominates. Contains the dominant tool name.
    /// Agent gets one more chance with a hint injected.
    Tier2Warning(String),
    /// Hard loop detected (tier 1 exact repeat, or tier 2 after warning).
    Abort,
}
584
/// 3-tier loop detection:
/// - Tier 1: exact action signature repeats N times consecutively
/// - Tier 2: single tool dominates >90% of all calls (warns first, aborts on second trigger)
/// - Tier 3: tool output stagnation — same results repeating
///
/// All three tiers share the single `threshold` value as their N.
struct LoopDetector {
    /// Repeat count at which a tier trips.
    threshold: usize,
    /// Tier 1: how many checks in a row carried the same signature as `last_sig`.
    consecutive: usize,
    /// Tier 1: signature seen by the most recent `check` that changed it.
    last_sig: Vec<String>,
    /// Tier 2: per-tool-name usage counter, updated by `check`.
    tool_freq: HashMap<String, usize>,
    /// Tier 2: number of `check` invocations so far (incremented once per call to `check`).
    total_calls: usize,
    /// Tier 3: hash of last tool outputs to detect stagnation
    /// (0 also serves as the "nothing recorded yet" marker).
    last_output_hash: u64,
    /// Tier 3: how many consecutive `check_outputs` calls produced `last_output_hash`.
    output_repeat_count: usize,
    /// Whether tier 2 warning has already been issued (next trigger aborts).
    tier2_warned: bool,
}
601
602impl LoopDetector {
603    fn new(threshold: usize) -> Self {
604        Self {
605            threshold,
606            consecutive: 0,
607            last_sig: vec![],
608            tool_freq: HashMap::new(),
609            total_calls: 0,
610            last_output_hash: 0,
611            output_repeat_count: 0,
612            tier2_warned: false,
613        }
614    }
615
616    /// Check action signature for loop.
617    /// Returns `Abort` for tier 1 (exact repeat) or tier 2 after warning.
618    /// Returns `Tier2Warning` on first tier 2 trigger (dominant tool detected).
619    fn check(&mut self, sig: &[String]) -> LoopCheckResult {
620        self.total_calls += 1;
621
622        // Tier 1: exact signature match
623        if sig == self.last_sig {
624            self.consecutive += 1;
625        } else {
626            self.consecutive = 1;
627            self.last_sig = sig.to_vec();
628        }
629        if self.consecutive >= self.threshold {
630            return LoopCheckResult::Abort;
631        }
632
633        // Tier 2: tool name frequency (single tool dominates)
634        for name in sig {
635            *self.tool_freq.entry(name.clone()).or_insert(0) += 1;
636        }
637        if self.total_calls >= self.threshold {
638            for (name, count) in &self.tool_freq {
639                if *count >= self.threshold && *count as f64 / self.total_calls as f64 > 0.9 {
640                    if self.tier2_warned {
641                        return LoopCheckResult::Abort;
642                    }
643                    self.tier2_warned = true;
644                    return LoopCheckResult::Tier2Warning(name.clone());
645                }
646            }
647        }
648
649        LoopCheckResult::Ok
650    }
651
652    /// Check tool outputs for stagnation (tier 3). Call after executing tools each step.
653    fn check_outputs(&mut self, outputs: &[String]) -> bool {
654        use std::collections::hash_map::DefaultHasher;
655        use std::hash::{Hash, Hasher};
656
657        let mut hasher = DefaultHasher::new();
658        outputs.hash(&mut hasher);
659        let hash = hasher.finish();
660
661        if hash == self.last_output_hash && self.last_output_hash != 0 {
662            self.output_repeat_count += 1;
663        } else {
664            self.output_repeat_count = 1;
665            self.last_output_hash = hash;
666        }
667
668        self.output_repeat_count >= self.threshold
669    }
670}
671
/// Auto-completion detector — catches when agent is done but doesn't call finish_task.
///
/// Signals completion when:
/// - Agent returns same situation text N times (stuck describing same state)
/// - Situation contains one of the `COMPLETION_KEYWORDS` phrases
///   (e.g. "task is complete", "all done", "no further action")
///
/// Decisions that are already explicitly complete, or that carry no tool
/// calls, are ignored by the detector.
struct CompletionDetector {
    /// Number of identical consecutive situations that signals completion
    /// (never below 2; see `new`).
    threshold: usize,
    /// Situation text of the most recently recorded decision.
    last_situation: String,
    /// How many consecutive decisions repeated `last_situation`.
    repeat_count: usize,
}
682
/// Keywords in situation text that suggest task is complete.
///
/// Matched as substrings against the lowercased situation, so every entry
/// must itself be lowercase.
const COMPLETION_KEYWORDS: &[&str] = &[
    "task is complete",
    "task is done",
    "task is finished",
    "all done",
    "successfully completed",
    "nothing more",
    "no further action",
    "no more steps",
];
694
695impl CompletionDetector {
696    fn new(threshold: usize) -> Self {
697        Self {
698            threshold: threshold.max(2),
699            last_situation: String::new(),
700            repeat_count: 0,
701        }
702    }
703
704    /// Check if the decision indicates implicit completion.
705    fn check(&mut self, decision: &Decision) -> bool {
706        // Don't interfere with explicit completion
707        if decision.completed || decision.tool_calls.is_empty() {
708            return false;
709        }
710
711        // Check for completion keywords in situation
712        let sit_lower = decision.situation.to_lowercase();
713        for keyword in COMPLETION_KEYWORDS {
714            if sit_lower.contains(keyword) {
715                return true;
716            }
717        }
718
719        // Check for repeated situation text (agent stuck describing same state)
720        if !decision.situation.is_empty() && decision.situation == self.last_situation {
721            self.repeat_count += 1;
722        } else {
723            self.repeat_count = 1;
724            self.last_situation = decision.situation.clone();
725        }
726
727        self.repeat_count >= self.threshold
728    }
729}
730
731/// Trim messages to fit within max_messages limit.
732/// Keeps: first 2 messages (system + initial user) + last (max - 2) messages.
733fn trim_messages(messages: &mut Vec<Message>, max: usize) {
734    use crate::types::Role;
735
736    if messages.len() <= max || max < 4 {
737        return;
738    }
739    let keep_start = 2; // system + user prompt
740    let remove_count = messages.len() - max + 1;
741    let mut trim_end = keep_start + remove_count;
742
743    // Don't break functionCall → functionResponse pairs.
744    // Gemini requires: model turn (functionCall) → user turn (functionResponse).
745    // If trim_end lands in the middle of such a pair, extend to skip the whole group.
746    //
747    // Case 1: trim_end points at Tool messages — extend past them (they'd be orphaned).
748    while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
749        trim_end += 1;
750    }
751    // Case 2: the first kept message is a Tool — it lost its preceding Assistant.
752    // (Already handled by Case 1, but double-check.)
753    //
754    // Case 3: the last removed message is an Assistant with tool_calls —
755    // the following Tool messages (now first in kept region) would be orphaned.
756    // Extend trim_end to also remove those Tool messages.
757    if trim_end > keep_start && trim_end < messages.len() {
758        let last_removed = trim_end - 1;
759        if messages[last_removed].role == Role::Assistant
760            && !messages[last_removed].tool_calls.is_empty()
761        {
762            // The assistant had tool_calls but we're keeping it... actually we're removing it.
763            // So remove all following Tool messages too.
764            while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
765                trim_end += 1;
766            }
767        }
768    }
769
770    let removed_range = keep_start..trim_end;
771
772    let summary = format!(
773        "[{} messages trimmed from context to stay within {} message limit]",
774        trim_end - keep_start,
775        max
776    );
777
778    messages.drain(removed_range);
779    messages.insert(keep_start, Message::system(&summary));
780}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785    use crate::agent::{Agent, AgentError, Decision};
786    use crate::agent_tool::{Tool, ToolError, ToolOutput};
787    use crate::context::AgentContext;
788    use crate::registry::ToolRegistry;
789    use crate::types::{Message, SgrError, ToolCall};
790    use serde_json::Value;
791    use std::sync::Arc;
792    use std::sync::atomic::{AtomicUsize, Ordering};
793
794    struct CountingAgent {
795        max_calls: usize,
796        call_count: Arc<AtomicUsize>,
797    }
798
799    #[async_trait::async_trait]
800    impl Agent for CountingAgent {
801        async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
802            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
803            if n >= self.max_calls {
804                Ok(Decision {
805                    situation: "done".into(),
806                    task: vec![],
807                    tool_calls: vec![],
808                    completed: true,
809                })
810            } else {
811                Ok(Decision {
812                    situation: format!("step {}", n),
813                    task: vec![],
814                    tool_calls: vec![ToolCall {
815                        id: format!("call_{}", n),
816                        name: "echo".into(),
817                        arguments: serde_json::json!({"msg": "hi"}),
818                    }],
819                    completed: false,
820                })
821            }
822        }
823    }
824
825    struct EchoTool;
826
827    #[async_trait::async_trait]
828    impl Tool for EchoTool {
829        fn name(&self) -> &str {
830            "echo"
831        }
832        fn description(&self) -> &str {
833            "echo"
834        }
835        fn parameters_schema(&self) -> Value {
836            serde_json::json!({"type": "object"})
837        }
838        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
839            Ok(ToolOutput::text("echoed"))
840        }
841    }
842
843    #[tokio::test]
844    async fn loop_runs_and_completes() {
845        let agent = CountingAgent {
846            max_calls: 3,
847            call_count: Arc::new(AtomicUsize::new(0)),
848        };
849        let tools = ToolRegistry::new().register(EchoTool);
850        let mut ctx = AgentContext::new();
851        let mut messages = vec![Message::user("go")];
852        let config = LoopConfig::default();
853
854        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
855            .await
856            .unwrap();
857        assert_eq!(steps, 4); // 3 tool calls + 1 completion
858        assert_eq!(ctx.state, AgentState::Completed);
859    }
860
861    #[tokio::test]
862    async fn loop_detects_repetition() {
863        // Agent always returns same tool call → loop detection
864        struct LoopingAgent;
865        #[async_trait::async_trait]
866        impl Agent for LoopingAgent {
867            async fn decide(
868                &self,
869                _: &[Message],
870                _: &ToolRegistry,
871            ) -> Result<Decision, AgentError> {
872                Ok(Decision {
873                    situation: "stuck".into(),
874                    task: vec![],
875                    tool_calls: vec![ToolCall {
876                        id: "1".into(),
877                        name: "echo".into(),
878                        arguments: serde_json::json!({}),
879                    }],
880                    completed: false,
881                })
882            }
883        }
884
885        let tools = ToolRegistry::new().register(EchoTool);
886        let mut ctx = AgentContext::new();
887        let mut messages = vec![Message::user("go")];
888        let config = LoopConfig {
889            max_steps: 50,
890            loop_abort_threshold: 3,
891            auto_complete_threshold: 100, // disable auto-complete for this test
892            ..Default::default()
893        };
894
895        let result = run_loop(
896            &LoopingAgent,
897            &tools,
898            &mut ctx,
899            &mut messages,
900            &config,
901            |_| {},
902        )
903        .await;
904        assert!(matches!(result, Err(AgentError::LoopDetected(3))));
905        assert_eq!(ctx.state, AgentState::Failed);
906    }
907
908    #[tokio::test]
909    async fn loop_max_steps() {
910        // Agent never completes
911        struct NeverDoneAgent;
912        #[async_trait::async_trait]
913        impl Agent for NeverDoneAgent {
914            async fn decide(
915                &self,
916                _: &[Message],
917                _: &ToolRegistry,
918            ) -> Result<Decision, AgentError> {
919                // Different tool names to avoid loop detection
920                static COUNTER: AtomicUsize = AtomicUsize::new(0);
921                let n = COUNTER.fetch_add(1, Ordering::SeqCst);
922                Ok(Decision {
923                    situation: String::new(),
924                    task: vec![],
925                    tool_calls: vec![ToolCall {
926                        id: format!("{}", n),
927                        name: format!("tool_{}", n),
928                        arguments: serde_json::json!({}),
929                    }],
930                    completed: false,
931                })
932            }
933        }
934
935        let tools = ToolRegistry::new().register(EchoTool);
936        let mut ctx = AgentContext::new();
937        let mut messages = vec![Message::user("go")];
938        let config = LoopConfig {
939            max_steps: 5,
940            loop_abort_threshold: 100,
941            ..Default::default()
942        };
943
944        let result = run_loop(
945            &NeverDoneAgent,
946            &tools,
947            &mut ctx,
948            &mut messages,
949            &config,
950            |_| {},
951        )
952        .await;
953        assert!(matches!(result, Err(AgentError::MaxSteps(5))));
954    }
955
956    #[test]
957    fn loop_detector_exact_sig() {
958        let mut d = LoopDetector::new(3);
959        let sig = vec!["bash".to_string()];
960        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
961        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
962        assert_eq!(d.check(&sig), LoopCheckResult::Abort); // 3rd consecutive
963    }
964
965    #[test]
966    fn loop_detector_different_sigs_reset() {
967        let mut d = LoopDetector::new(3);
968        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
969        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
970        assert_eq!(d.check(&["read".into()]), LoopCheckResult::Ok); // different → resets
971        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
972    }
973
974    #[test]
975    fn loop_detector_tier2_warning_then_abort() {
976        // Tier 2 requires: count >= threshold AND count/total > 0.9
977        // Use threshold=3. To avoid tier 1 (exact consecutive), alternate sigs.
978        let mut d = LoopDetector::new(3);
979        // Calls 1-2: build up frequency, total_calls < threshold so tier 2 not checked
980        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); // total=1, edit=1, cons=1
981        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); // total=2, edit=2, cons=2
982        // Call 3: break consecutive (different sig) but edit_file still in sig
983        // total=3, edit=3, cons=1 → tier 2: 3/3=1.0 > 0.9 → first warning
984        assert_eq!(
985            d.check(&["edit_file".into(), "read_file".into()]),
986            LoopCheckResult::Tier2Warning("edit_file".into())
987        );
988        // Call 4: tier 2 already warned → abort
989        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Abort);
990    }
991
992    #[test]
993    fn loop_config_default() {
994        let c = LoopConfig::default();
995        assert_eq!(c.max_steps, 50);
996        assert_eq!(c.loop_abort_threshold, 6);
997    }
998
999    #[test]
1000    fn loop_detector_output_stagnation() {
1001        let mut d = LoopDetector::new(3);
1002        let outputs = vec!["same result".to_string()];
1003        assert!(!d.check_outputs(&outputs));
1004        assert!(!d.check_outputs(&outputs));
1005        assert!(d.check_outputs(&outputs)); // 3rd repeat
1006    }
1007
1008    #[test]
1009    fn completion_detector_keyword() {
1010        let mut cd = CompletionDetector::new(3);
1011        let d = Decision {
1012            situation: "The task is complete, all files written.".into(),
1013            task: vec![],
1014            tool_calls: vec![ToolCall {
1015                id: "1".into(),
1016                name: "echo".into(),
1017                arguments: serde_json::json!({}),
1018            }],
1019            completed: false,
1020        };
1021        assert!(cd.check(&d));
1022    }
1023
1024    #[test]
1025    fn completion_detector_repeated_situation() {
1026        let mut cd = CompletionDetector::new(3);
1027        let d = Decision {
1028            situation: "working on it".into(),
1029            task: vec![],
1030            tool_calls: vec![ToolCall {
1031                id: "1".into(),
1032                name: "echo".into(),
1033                arguments: serde_json::json!({}),
1034            }],
1035            completed: false,
1036        };
1037        assert!(!cd.check(&d));
1038        assert!(!cd.check(&d));
1039        assert!(cd.check(&d)); // 3rd repeat
1040    }
1041
1042    #[test]
1043    fn completion_detector_ignores_explicit_completion() {
1044        let mut cd = CompletionDetector::new(2);
1045        let d = Decision {
1046            situation: "task is complete".into(),
1047            task: vec![],
1048            tool_calls: vec![],
1049            completed: true,
1050        };
1051        // Should return false — let normal completion handling take over
1052        assert!(!cd.check(&d));
1053    }
1054
1055    #[test]
1056    fn trim_messages_basic() {
1057        let mut msgs: Vec<Message> = (0..10).map(|i| Message::user(format!("msg {i}"))).collect();
1058        trim_messages(&mut msgs, 6);
1059        // first 2 + summary + last 3 = 6
1060        assert_eq!(msgs.len(), 6);
1061        assert!(msgs[2].content.contains("trimmed"));
1062    }
1063
1064    #[test]
1065    fn trim_messages_no_op_when_under_limit() {
1066        let mut msgs = vec![Message::user("a"), Message::user("b")];
1067        trim_messages(&mut msgs, 10);
1068        assert_eq!(msgs.len(), 2);
1069    }
1070
1071    #[test]
1072    fn trim_messages_preserves_assistant_tool_call_pair() {
1073        use crate::types::Role;
1074        // system, user, assistant(tool_calls), tool, tool, user, assistant
1075        let mut msgs = vec![
1076            Message::system("sys"),
1077            Message::user("prompt"),
1078            Message::assistant_with_tool_calls(
1079                "calling",
1080                vec![
1081                    ToolCall {
1082                        id: "c1".into(),
1083                        name: "read".into(),
1084                        arguments: serde_json::json!({}),
1085                    },
1086                    ToolCall {
1087                        id: "c2".into(),
1088                        name: "read".into(),
1089                        arguments: serde_json::json!({}),
1090                    },
1091                ],
1092            ),
1093            Message::tool("c1", "result1"),
1094            Message::tool("c2", "result2"),
1095            Message::user("next"),
1096            Message::assistant("done"),
1097        ];
1098        // Trim to 5 — should remove assistant+tools as a group, not split them
1099        trim_messages(&mut msgs, 5);
1100        // Verify no orphaned Tool messages remain
1101        for (i, msg) in msgs.iter().enumerate() {
1102            if msg.role == Role::Tool {
1103                // The previous message should be an Assistant with tool_calls
1104                assert!(i > 0, "Tool message at start");
1105                assert!(
1106                    msgs[i - 1].role == Role::Assistant && !msgs[i - 1].tool_calls.is_empty()
1107                        || msgs[i - 1].role == Role::Tool,
1108                    "Orphaned Tool at position {i}"
1109                );
1110            }
1111        }
1112    }
1113
1114    #[test]
1115    fn loop_detector_output_stagnation_resets_on_change() {
1116        let mut d = LoopDetector::new(3);
1117        let a = vec!["result A".to_string()];
1118        let b = vec!["result B".to_string()];
1119        assert!(!d.check_outputs(&a));
1120        assert!(!d.check_outputs(&a));
1121        assert!(!d.check_outputs(&b)); // different → resets
1122        assert!(!d.check_outputs(&a));
1123    }
1124
1125    #[tokio::test]
1126    async fn loop_handles_non_recoverable_llm_error() {
1127        struct FailingAgent;
1128        #[async_trait::async_trait]
1129        impl Agent for FailingAgent {
1130            async fn decide(
1131                &self,
1132                _: &[Message],
1133                _: &ToolRegistry,
1134            ) -> Result<Decision, AgentError> {
1135                Err(AgentError::Llm(SgrError::Api {
1136                    status: 500,
1137                    body: "internal server error".into(),
1138                }))
1139            }
1140        }
1141
1142        let tools = ToolRegistry::new().register(EchoTool);
1143        let mut ctx = AgentContext::new();
1144        let mut messages = vec![Message::user("go")];
1145        let config = LoopConfig::default();
1146
1147        let result = run_loop(
1148            &FailingAgent,
1149            &tools,
1150            &mut ctx,
1151            &mut messages,
1152            &config,
1153            |_| {},
1154        )
1155        .await;
1156        // Non-recoverable: should fail immediately, no retries
1157        assert!(result.is_err());
1158        assert_eq!(messages.len(), 1); // no feedback messages added
1159    }
1160
1161    #[tokio::test]
1162    async fn loop_recovers_from_parse_error() {
1163        // Agent fails with parse error on first call, succeeds on retry
1164        struct ParseRetryAgent {
1165            call_count: Arc<AtomicUsize>,
1166        }
1167        #[async_trait::async_trait]
1168        impl Agent for ParseRetryAgent {
1169            async fn decide(
1170                &self,
1171                msgs: &[Message],
1172                _: &ToolRegistry,
1173            ) -> Result<Decision, AgentError> {
1174                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1175                if n == 0 {
1176                    // First call: simulate parse error
1177                    Err(AgentError::Llm(SgrError::Schema(
1178                        "Missing required field: situation".into(),
1179                    )))
1180                } else {
1181                    // Second call: should see error feedback in messages
1182                    let last = msgs.last().unwrap();
1183                    assert!(
1184                        last.content.contains("Parse error"),
1185                        "expected parse error feedback, got: {}",
1186                        last.content
1187                    );
1188                    Ok(Decision {
1189                        situation: "recovered from parse error".into(),
1190                        task: vec![],
1191                        tool_calls: vec![],
1192                        completed: true,
1193                    })
1194                }
1195            }
1196        }
1197
1198        let tools = ToolRegistry::new().register(EchoTool);
1199        let mut ctx = AgentContext::new();
1200        let mut messages = vec![Message::user("go")];
1201        let config = LoopConfig::default();
1202        let agent = ParseRetryAgent {
1203            call_count: Arc::new(AtomicUsize::new(0)),
1204        };
1205
1206        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1207            .await
1208            .unwrap();
1209        assert_eq!(steps, 2); // step 1 failed parse, step 2 succeeded
1210        assert_eq!(ctx.state, AgentState::Completed);
1211    }
1212
1213    #[tokio::test]
1214    async fn loop_aborts_after_max_parse_retries() {
1215        struct AlwaysFailParseAgent;
1216        #[async_trait::async_trait]
1217        impl Agent for AlwaysFailParseAgent {
1218            async fn decide(
1219                &self,
1220                _: &[Message],
1221                _: &ToolRegistry,
1222            ) -> Result<Decision, AgentError> {
1223                Err(AgentError::Llm(SgrError::Schema("bad json".into())))
1224            }
1225        }
1226
1227        let tools = ToolRegistry::new().register(EchoTool);
1228        let mut ctx = AgentContext::new();
1229        let mut messages = vec![Message::user("go")];
1230        let config = LoopConfig::default();
1231
1232        let result = run_loop(
1233            &AlwaysFailParseAgent,
1234            &tools,
1235            &mut ctx,
1236            &mut messages,
1237            &config,
1238            |_| {},
1239        )
1240        .await;
1241        assert!(result.is_err());
1242        // Should have added MAX_PARSE_RETRIES feedback messages
1243        let feedback_count = messages
1244            .iter()
1245            .filter(|m| m.content.contains("Parse error"))
1246            .count();
1247        assert_eq!(feedback_count, MAX_PARSE_RETRIES);
1248    }
1249
1250    #[tokio::test]
1251    async fn loop_feeds_tool_errors_back() {
1252        // Agent calls unknown tool → error fed back → agent completes
1253        struct ErrorRecoveryAgent {
1254            call_count: Arc<AtomicUsize>,
1255        }
1256        #[async_trait::async_trait]
1257        impl Agent for ErrorRecoveryAgent {
1258            async fn decide(
1259                &self,
1260                msgs: &[Message],
1261                _: &ToolRegistry,
1262            ) -> Result<Decision, AgentError> {
1263                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1264                if n == 0 {
1265                    // First: call unknown tool
1266                    Ok(Decision {
1267                        situation: "trying".into(),
1268                        task: vec![],
1269                        tool_calls: vec![ToolCall {
1270                            id: "1".into(),
1271                            name: "nonexistent_tool".into(),
1272                            arguments: serde_json::json!({}),
1273                        }],
1274                        completed: false,
1275                    })
1276                } else {
1277                    // Second: should see error in messages, complete
1278                    let last = msgs.last().unwrap();
1279                    assert!(last.content.contains("Unknown tool"));
1280                    Ok(Decision {
1281                        situation: "recovered".into(),
1282                        task: vec![],
1283                        tool_calls: vec![],
1284                        completed: true,
1285                    })
1286                }
1287            }
1288        }
1289
1290        let tools = ToolRegistry::new().register(EchoTool);
1291        let mut ctx = AgentContext::new();
1292        let mut messages = vec![Message::user("go")];
1293        let config = LoopConfig::default();
1294        let agent = ErrorRecoveryAgent {
1295            call_count: Arc::new(AtomicUsize::new(0)),
1296        };
1297
1298        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1299            .await
1300            .unwrap();
1301        assert_eq!(steps, 2);
1302        assert_eq!(ctx.state, AgentState::Completed);
1303    }
1304
1305    #[tokio::test]
1306    async fn parallel_readonly_tools() {
1307        struct ReadOnlyTool {
1308            name: &'static str,
1309        }
1310
1311        #[async_trait::async_trait]
1312        impl Tool for ReadOnlyTool {
1313            fn name(&self) -> &str {
1314                self.name
1315            }
1316            fn description(&self) -> &str {
1317                "read-only tool"
1318            }
1319            fn is_read_only(&self) -> bool {
1320                true
1321            }
1322            fn parameters_schema(&self) -> Value {
1323                serde_json::json!({"type": "object"})
1324            }
1325            async fn execute(
1326                &self,
1327                _: Value,
1328                _: &mut AgentContext,
1329            ) -> Result<ToolOutput, ToolError> {
1330                Ok(ToolOutput::text(format!("{} result", self.name)))
1331            }
1332            async fn execute_readonly(&self, _: Value) -> Result<ToolOutput, ToolError> {
1333                Ok(ToolOutput::text(format!("{} result", self.name)))
1334            }
1335        }
1336
1337        struct ParallelAgent;
1338        #[async_trait::async_trait]
1339        impl Agent for ParallelAgent {
1340            async fn decide(
1341                &self,
1342                msgs: &[Message],
1343                _: &ToolRegistry,
1344            ) -> Result<Decision, AgentError> {
1345                if msgs.len() > 3 {
1346                    return Ok(Decision {
1347                        situation: "done".into(),
1348                        task: vec![],
1349                        tool_calls: vec![],
1350                        completed: true,
1351                    });
1352                }
1353                Ok(Decision {
1354                    situation: "reading".into(),
1355                    task: vec![],
1356                    tool_calls: vec![
1357                        ToolCall {
1358                            id: "1".into(),
1359                            name: "reader_a".into(),
1360                            arguments: serde_json::json!({}),
1361                        },
1362                        ToolCall {
1363                            id: "2".into(),
1364                            name: "reader_b".into(),
1365                            arguments: serde_json::json!({}),
1366                        },
1367                    ],
1368                    completed: false,
1369                })
1370            }
1371        }
1372
1373        let tools = ToolRegistry::new()
1374            .register(ReadOnlyTool { name: "reader_a" })
1375            .register(ReadOnlyTool { name: "reader_b" });
1376        let mut ctx = AgentContext::new();
1377        let mut messages = vec![Message::user("read stuff")];
1378        let config = LoopConfig::default();
1379
1380        let steps = run_loop(
1381            &ParallelAgent,
1382            &tools,
1383            &mut ctx,
1384            &mut messages,
1385            &config,
1386            |_| {},
1387        )
1388        .await
1389        .unwrap();
1390        assert!(steps > 0);
1391        assert_eq!(ctx.state, AgentState::Completed);
1392    }
1393
1394    #[tokio::test]
1395    async fn loop_events_are_emitted() {
1396        let agent = CountingAgent {
1397            max_calls: 1,
1398            call_count: Arc::new(AtomicUsize::new(0)),
1399        };
1400        let tools = ToolRegistry::new().register(EchoTool);
1401        let mut ctx = AgentContext::new();
1402        let mut messages = vec![Message::user("go")];
1403        let config = LoopConfig::default();
1404
1405        let mut events = Vec::new();
1406        run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |e| {
1407            events.push(format!("{:?}", std::mem::discriminant(&e)));
1408        })
1409        .await
1410        .unwrap();
1411
1412        // Should have: StepStart, Decision, ToolResult, StepStart, Decision, Completed
1413        assert!(events.len() >= 4);
1414    }
1415}