Skip to main content

sgr_agent/
agent_loop.rs

1//! Generic agent loop — drives agent + tools until completion or limit.
2//!
3//! Includes 3-tier loop detection (exact signature, tool name frequency, output stagnation).
4//! Features from Claude Code (HitCC analysis):
5//! - Tool result pairing repair — ensures every tool_use has a matching tool_result
6//! - Context modifiers — tools can adjust runtime behavior via ToolOutput::modifier
7//! - Max output tokens recovery — auto-continuation when response is truncated
8
9use crate::agent::{Agent, AgentError, Decision};
10use crate::context::{AgentContext, AgentState};
11use crate::registry::ToolRegistry;
12use crate::retry::{RetryConfig, delay_for_attempt, is_retryable};
13use crate::types::{Message, Role, SgrError};
14use futures::future::join_all;
15use std::collections::HashMap;
16
/// Upper bound on back-to-back parse failures tolerated before the loop gives up.
const MAX_PARSE_RETRIES: usize = 3;

/// Upper bound on retries for transient LLM failures (rate limit, timeout, 5xx).
const MAX_TRANSIENT_RETRIES: usize = 3;

/// Upper bound on auto-continuation attempts after a truncated (max_output_tokens) response.
const MAX_OUTPUT_TOKENS_RECOVERIES: usize = 3;
25
26/// Check if an agent error is recoverable (parsing/empty response).
27fn is_recoverable_error(e: &AgentError) -> bool {
28    matches!(
29        e,
30        AgentError::Llm(SgrError::Json(_))
31            | AgentError::Llm(SgrError::EmptyResponse)
32            | AgentError::Llm(SgrError::Schema(_))
33    )
34}
35
36/// Wrap `agent.decide_stateful()` with retry for transient LLM errors (rate limit, timeout, 5xx).
37/// Parse errors and tool errors are NOT retried here (handled by the caller).
38async fn decide_with_retry(
39    agent: &dyn Agent,
40    messages: &[Message],
41    tools: &ToolRegistry,
42    previous_response_id: Option<&str>,
43) -> Result<(Decision, Option<String>), AgentError> {
44    let retry_config = RetryConfig {
45        max_retries: MAX_TRANSIENT_RETRIES,
46        base_delay_ms: 500,
47        max_delay_ms: 30_000,
48    };
49
50    for attempt in 0..=retry_config.max_retries {
51        match agent
52            .decide_stateful(messages, tools, previous_response_id)
53            .await
54        {
55            Ok(d) => return Ok(d),
56            Err(AgentError::Llm(sgr_err))
57                if is_retryable(&sgr_err) && attempt < retry_config.max_retries =>
58            {
59                let delay = delay_for_attempt(attempt, &retry_config, &sgr_err);
60                tracing::warn!(
61                    attempt = attempt + 1,
62                    max = retry_config.max_retries,
63                    delay_ms = delay.as_millis() as u64,
64                    "Retrying agent.decide(): {}",
65                    sgr_err
66                );
67                tokio::time::sleep(delay).await;
68                // Loop continues — on last attempt, fall through to return the error
69            }
70            Err(e) => return Err(e),
71        }
72    }
73    // If we exhausted all retries, do one final attempt and return its result directly
74    agent
75        .decide_stateful(messages, tools, previous_response_id)
76        .await
77}
78
79/// Ensure every tool_use in messages has a matching tool_result, and vice versa.
80///
81/// Repairs the transcript before sending to the API — prevents crashes from:
82/// - Tool panics/timeouts leaving orphaned tool_use without tool_result
83/// - Duplicate tool_results for the same tool_use_id
84/// - Orphaned tool_results without a preceding tool_use
85///
86/// This is called before each `agent.decide()` call to keep the transcript valid.
87pub fn ensure_tool_result_pairing(messages: &mut Vec<Message>) {
88    // Collect all tool_use IDs from assistant messages
89    let mut expected_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
90    for msg in messages.iter() {
91        if msg.role == Role::Assistant {
92            for tc in &msg.tool_calls {
93                expected_ids.insert(tc.id.clone());
94            }
95        }
96    }
97
98    // Collect all tool_result IDs already present
99    let mut seen_result_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
100    // Track indices of duplicate tool_results to remove
101    let mut to_remove: Vec<usize> = Vec::new();
102
103    for (i, msg) in messages.iter().enumerate() {
104        if msg.role == Role::Tool
105            && let Some(ref id) = msg.tool_call_id
106        {
107            if !seen_result_ids.insert(id.clone()) {
108                // Duplicate tool_result — mark for removal
109                to_remove.push(i);
110            } else if !expected_ids.contains(id) {
111                // Orphaned tool_result (no matching tool_use) — mark for removal
112                to_remove.push(i);
113            }
114        }
115    }
116
117    // Remove duplicates/orphans in reverse order to preserve indices
118    for i in to_remove.into_iter().rev() {
119        tracing::debug!(
120            tool_call_id = messages[i].tool_call_id.as_deref().unwrap_or("?"),
121            "Removing orphaned/duplicate tool_result"
122        );
123        messages.remove(i);
124    }
125
126    // Add synthetic tool_results for missing pairs
127    // Walk through messages and insert after each assistant+tool_calls block
128    let mut i = 0;
129    while i < messages.len() {
130        if messages[i].role == Role::Assistant && !messages[i].tool_calls.is_empty() {
131            let tool_call_ids: Vec<String> = messages[i]
132                .tool_calls
133                .iter()
134                .map(|tc| tc.id.clone())
135                .collect();
136
137            // Check which IDs have results in the subsequent Tool messages
138            let mut insert_pos = i + 1;
139            let mut found_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
140            while insert_pos < messages.len() && messages[insert_pos].role == Role::Tool {
141                if let Some(ref id) = messages[insert_pos].tool_call_id {
142                    found_ids.insert(id.clone());
143                }
144                insert_pos += 1;
145            }
146
147            // Insert synthetic results for missing IDs
148            for id in &tool_call_ids {
149                if !found_ids.contains(id) {
150                    tracing::debug!(
151                        tool_call_id = id.as_str(),
152                        "Inserting synthetic tool_result for orphaned tool_use"
153                    );
154                    messages.insert(
155                        insert_pos,
156                        Message::tool(id, "[Tool result missing due to internal error]"),
157                    );
158                    insert_pos += 1;
159                }
160            }
161            i = insert_pos;
162        } else {
163            i += 1;
164        }
165    }
166}
167
168/// Apply a context modifier from a tool output to the agent context and loop config.
169fn apply_context_modifier(
170    modifier: &crate::agent_tool::ContextModifier,
171    ctx: &mut AgentContext,
172    messages: &mut Vec<Message>,
173    effective_max_steps: &mut usize,
174) {
175    if let Some(ref injection) = modifier.system_injection {
176        // Use Role::User, not Role::System — mid-conversation system messages
177        // are unsupported by Gemini and silently dropped by some providers.
178        messages.push(Message::user(format!("[Context update]: {injection}")));
179    }
180    for (key, value) in &modifier.custom_context {
181        ctx.set(key.clone(), value.clone());
182    }
183    if let Some(delta) = modifier.max_steps_delta {
184        if delta > 0 {
185            *effective_max_steps = effective_max_steps.saturating_add(delta as usize);
186        } else {
187            *effective_max_steps =
188                effective_max_steps.saturating_sub(delta.unsigned_abs() as usize);
189        }
190    }
191    if let Some(tokens) = modifier.max_tokens_override {
192        ctx.set(
193            crate::agent_tool::MAX_TOKENS_OVERRIDE_KEY.to_string(),
194            serde_json::Value::Number(tokens.into()),
195        );
196    }
197}
198
/// Tunable limits for the agent loop.
#[derive(Clone, Debug)]
pub struct LoopConfig {
    /// Step budget; the loop fails with `MaxSteps` once it is used up.
    pub max_steps: usize,
    /// Number of consecutive repeats (of tool-call signatures or of tool outputs)
    /// at which loop detection aborts the run.
    pub loop_abort_threshold: usize,
    /// Cap on the number of messages kept in context (0 = unlimited).
    /// When exceeded, the first 2 messages (system + user prompt) and the most
    /// recent messages are kept.
    pub max_messages: usize,
    /// Auto-complete once the agent returns the same situation text this many times.
    pub auto_complete_threshold: usize,
}

impl Default for LoopConfig {
    /// Defaults: 50 steps, abort after 6 repeats, keep 80 messages,
    /// auto-complete after 3 identical situations.
    fn default() -> Self {
        LoopConfig {
            max_steps: 50,
            loop_abort_threshold: 6,
            max_messages: 80,
            auto_complete_threshold: 3,
        }
    }
}
223
224/// Events emitted during the agent loop.
225#[derive(Debug)]
226pub enum LoopEvent {
227    StepStart {
228        step: usize,
229    },
230    Decision(Decision),
231    ToolResult {
232        name: String,
233        output: String,
234    },
235    Completed {
236        steps: usize,
237    },
238    LoopDetected {
239        count: usize,
240    },
241    Error(AgentError),
242    /// Agent needs user input. Content is the question.
243    WaitingForInput {
244        question: String,
245        tool_call_id: String,
246    },
247    /// Response was truncated, requesting auto-continuation.
248    MaxOutputTokensRecovery {
249        attempt: usize,
250    },
251    /// Prompt exceeded model's context limit.
252    PromptTooLong {
253        message: String,
254    },
255    /// A tool returned a context modifier that was applied.
256    ContextModified {
257        tool_name: String,
258    },
259}
260
261/// Run the agent loop: decide → execute tools → feed results → repeat.
262///
263/// Returns the number of steps taken.
264/// Non-interactive: when a tool returns `ToolOutput::waiting`, emits a
265/// `WaitingForInput` event and uses `"[waiting for user input]"` as placeholder.
266/// For interactive use (actual user input), use `run_loop_interactive`.
267pub async fn run_loop(
268    agent: &dyn Agent,
269    tools: &ToolRegistry,
270    ctx: &mut AgentContext,
271    messages: &mut Vec<Message>,
272    config: &LoopConfig,
273    on_event: impl FnMut(LoopEvent),
274) -> Result<usize, AgentError> {
275    // Delegate to the unified interactive loop with a passive input handler.
276    // When a tool needs input, it gets the placeholder string instead of blocking.
277    run_loop_interactive(
278        agent,
279        tools,
280        ctx,
281        messages,
282        config,
283        on_event,
284        |_question: String| async { "[waiting for user input]".to_string() },
285    )
286    .await
287}
288
289/// Core agent loop — single implementation for both interactive and non-interactive modes.
290///
291/// When a tool returns `ToolOutput::waiting`, calls `on_input` with the question.
292/// `run_loop` delegates here with a passive handler that returns a placeholder.
293pub async fn run_loop_interactive<F, Fut>(
294    agent: &dyn Agent,
295    tools: &ToolRegistry,
296    ctx: &mut AgentContext,
297    messages: &mut Vec<Message>,
298    config: &LoopConfig,
299    mut on_event: impl FnMut(LoopEvent),
300    mut on_input: F,
301) -> Result<usize, AgentError>
302where
303    F: FnMut(String) -> Fut,
304    Fut: std::future::Future<Output = String>,
305{
306    let mut detector = LoopDetector::new(config.loop_abort_threshold);
307    let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
308    let mut parse_retries: usize = 0;
309    let mut response_id: Option<String> = None;
310    let mut max_output_tokens_recoveries: usize = 0;
311    let mut effective_max_steps = config.max_steps;
312
313    let mut step = 0;
314    while {
315        step += 1;
316        step <= effective_max_steps
317    } {
318        if config.max_messages > 0 && messages.len() > config.max_messages {
319            trim_messages(messages, config.max_messages);
320        }
321
322        // Tool result pairing repair
323        ensure_tool_result_pairing(messages);
324
325        ctx.iteration = step;
326        on_event(LoopEvent::StepStart { step });
327
328        agent.prepare_context(ctx, messages);
329
330        let active_tool_names = agent.prepare_tools(ctx, tools);
331        let filtered_tools = if active_tool_names.len() == tools.list().len() {
332            None
333        } else {
334            Some(active_tool_names)
335        };
336        let effective_tools = if let Some(ref names) = filtered_tools {
337            &tools.filter(names)
338        } else {
339            tools
340        };
341
342        let decision = match decide_with_retry(
343            agent,
344            messages,
345            effective_tools,
346            response_id.as_deref(),
347        )
348        .await
349        {
350            Ok((d, new_rid)) => {
351                parse_retries = 0;
352                max_output_tokens_recoveries = 0;
353                response_id = new_rid;
354                d
355            }
356            Err(AgentError::Llm(SgrError::MaxOutputTokens { partial_content })) => {
357                max_output_tokens_recoveries += 1;
358                if max_output_tokens_recoveries > MAX_OUTPUT_TOKENS_RECOVERIES {
359                    return Err(AgentError::Llm(SgrError::MaxOutputTokens {
360                        partial_content,
361                    }));
362                }
363                if !partial_content.is_empty() {
364                    messages.push(Message::assistant(&partial_content));
365                }
366                messages.push(Message::user(
367                    "Your response was cut off. Resume directly from where you stopped. \
368                     No apology, no recap — pick up mid-thought.",
369                ));
370                on_event(LoopEvent::MaxOutputTokensRecovery {
371                    attempt: max_output_tokens_recoveries,
372                });
373                continue;
374            }
375            Err(AgentError::Llm(SgrError::PromptTooLong(msg))) => {
376                on_event(LoopEvent::PromptTooLong {
377                    message: msg.clone(),
378                });
379                return Err(AgentError::Llm(SgrError::PromptTooLong(msg)));
380            }
381            Err(e) if is_recoverable_error(&e) => {
382                parse_retries += 1;
383                if parse_retries > MAX_PARSE_RETRIES {
384                    return Err(e);
385                }
386                let err_msg = format!(
387                    "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
388                    parse_retries, MAX_PARSE_RETRIES, e
389                );
390                on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
391                    err_msg.clone(),
392                ))));
393                messages.push(Message::user(&err_msg));
394                continue;
395            }
396            Err(e) => return Err(e),
397        };
398        on_event(LoopEvent::Decision(decision.clone()));
399
400        if completion_detector.check(&decision) {
401            ctx.state = AgentState::Completed;
402            if !decision.situation.is_empty() {
403                messages.push(Message::assistant(&decision.situation));
404            }
405            on_event(LoopEvent::Completed { steps: step });
406            return Ok(step);
407        }
408
409        if decision.completed || decision.tool_calls.is_empty() {
410            ctx.state = AgentState::Completed;
411            if !decision.situation.is_empty() {
412                messages.push(Message::assistant(&decision.situation));
413            }
414            on_event(LoopEvent::Completed { steps: step });
415            return Ok(step);
416        }
417
418        let sig: Vec<String> = decision
419            .tool_calls
420            .iter()
421            .map(|tc| tc.name.clone())
422            .collect();
423        match detector.check(&sig) {
424            LoopCheckResult::Abort => {
425                ctx.state = AgentState::Failed;
426                on_event(LoopEvent::LoopDetected {
427                    count: detector.consecutive,
428                });
429                return Err(AgentError::LoopDetected(detector.consecutive));
430            }
431            LoopCheckResult::Tier2Warning(dominant_tool) => {
432                let hint = format!(
433                    "LOOP WARNING: You are repeatedly using '{}' without making progress. \
434                     Try a different approach: re-read the file with read_file to see current contents, \
435                     use write_file instead of edit_file, or break the problem into smaller steps.",
436                    dominant_tool
437                );
438                messages.push(Message::system(&hint));
439            }
440            LoopCheckResult::Ok => {}
441        }
442
443        // Add assistant message with tool calls (Gemini requires model turn before function responses)
444        messages.push(Message::assistant_with_tool_calls(
445            &decision.situation,
446            decision.tool_calls.clone(),
447        ));
448
449        let mut step_outputs: Vec<String> = Vec::new();
450        let mut early_done = false;
451
452        // Partition into read-only (parallel) and write (sequential) tool calls
453        let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
454            .tool_calls
455            .iter()
456            .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
457
458        // Phase 1: read-only tools in parallel (shared read-only context ref)
459        if !ro_calls.is_empty() {
460            let ctx_snapshot = ctx.clone(); // snapshot for read-only parallel access
461            let futs: Vec<_> = ro_calls
462                .iter()
463                .map(|tc| {
464                    let tool = tools.get(&tc.name).unwrap();
465                    let args = tc.arguments.clone();
466                    let name = tc.name.clone();
467                    let id = tc.id.clone();
468                    let ctx_ref = &ctx_snapshot;
469                    async move { (id, name, tool.execute_readonly(args, ctx_ref).await) }
470                })
471                .collect();
472
473            let mut pending_modifiers: Vec<(String, crate::agent_tool::ContextModifier)> =
474                Vec::new();
475
476            for (id, name, result) in join_all(futs).await {
477                match result {
478                    Ok(output) => {
479                        on_event(LoopEvent::ToolResult {
480                            name: name.clone(),
481                            output: output.content.clone(),
482                        });
483                        step_outputs.push(output.content.clone());
484                        agent.after_action(ctx, &name, &output.content);
485                        if let Some(modifier) = output.modifier.clone()
486                            && !modifier.is_empty()
487                        {
488                            pending_modifiers.push((name.clone(), modifier));
489                        }
490                        if output.waiting {
491                            ctx.state = AgentState::WaitingInput;
492                            on_event(LoopEvent::WaitingForInput {
493                                question: output.content.clone(),
494                                tool_call_id: id.clone(),
495                            });
496                            let response = on_input(output.content).await;
497                            ctx.state = AgentState::Running;
498                            messages.push(Message::tool(&id, &response));
499                        } else {
500                            messages.push(Message::tool(&id, &output.content));
501                        }
502                        if output.done {
503                            early_done = true;
504                        }
505                    }
506                    Err(e) => {
507                        let err_msg = format!("Tool error: {}", e);
508                        step_outputs.push(err_msg.clone());
509                        messages.push(Message::tool(&id, &err_msg));
510                        agent.after_action(ctx, &name, &err_msg);
511                        on_event(LoopEvent::ToolResult {
512                            name,
513                            output: err_msg,
514                        });
515                    }
516                }
517            }
518
519            for (name, modifier) in pending_modifiers {
520                apply_context_modifier(&modifier, ctx, messages, &mut effective_max_steps);
521                on_event(LoopEvent::ContextModified { tool_name: name });
522            }
523
524            if early_done && rw_calls.is_empty() {
525                ctx.state = AgentState::Completed;
526                on_event(LoopEvent::Completed { steps: step });
527                return Ok(step);
528            }
529        }
530
531        // Phase 2: write tools sequentially (need &mut ctx)
532        for tc in &rw_calls {
533            if let Some(tool) = tools.get(&tc.name) {
534                match tool.execute(tc.arguments.clone(), ctx).await {
535                    Ok(output) => {
536                        on_event(LoopEvent::ToolResult {
537                            name: tc.name.clone(),
538                            output: output.content.clone(),
539                        });
540                        step_outputs.push(output.content.clone());
541                        agent.after_action(ctx, &tc.name, &output.content);
542                        if let Some(ref modifier) = output.modifier
543                            && !modifier.is_empty()
544                        {
545                            apply_context_modifier(
546                                modifier,
547                                ctx,
548                                messages,
549                                &mut effective_max_steps,
550                            );
551                            on_event(LoopEvent::ContextModified {
552                                tool_name: tc.name.clone(),
553                            });
554                        }
555                        if output.waiting {
556                            ctx.state = AgentState::WaitingInput;
557                            on_event(LoopEvent::WaitingForInput {
558                                question: output.content.clone(),
559                                tool_call_id: tc.id.clone(),
560                            });
561                            let response = on_input(output.content.clone()).await;
562                            ctx.state = AgentState::Running;
563                            messages.push(Message::tool(&tc.id, &response));
564                        } else {
565                            messages.push(Message::tool(&tc.id, &output.content));
566                        }
567                        if output.done {
568                            ctx.state = AgentState::Completed;
569                            on_event(LoopEvent::Completed { steps: step });
570                            return Ok(step);
571                        }
572                    }
573                    Err(e) => {
574                        let err_msg = format!("Tool error: {}", e);
575                        step_outputs.push(err_msg.clone());
576                        messages.push(Message::tool(&tc.id, &err_msg));
577                        agent.after_action(ctx, &tc.name, &err_msg);
578                        on_event(LoopEvent::ToolResult {
579                            name: tc.name.clone(),
580                            output: err_msg,
581                        });
582                    }
583                }
584            } else {
585                let err_msg = format!("Unknown tool: {}", tc.name);
586                step_outputs.push(err_msg.clone());
587                messages.push(Message::tool(&tc.id, &err_msg));
588                on_event(LoopEvent::ToolResult {
589                    name: tc.name.clone(),
590                    output: err_msg,
591                });
592            }
593        }
594
595        if detector.check_outputs(&step_outputs) {
596            ctx.state = AgentState::Failed;
597            on_event(LoopEvent::LoopDetected {
598                count: detector.output_repeat_count,
599            });
600            return Err(AgentError::LoopDetected(detector.output_repeat_count));
601        }
602    }
603
604    ctx.state = AgentState::Failed;
605    Err(AgentError::MaxSteps(effective_max_steps))
606}
607
/// Result of loop detection check.
#[derive(Debug, PartialEq)]
enum LoopCheckResult {
    /// No loop detected.
    Ok,
    /// Tier 2 warning: a single tool dominates. Contains the dominant tool name.
    /// The next tier 2 trigger aborts.
    Tier2Warning(String),
    /// Hard loop detected (tier 1 exact repeat, or tier 2 after warning).
    Abort,
}

/// 3-tier loop detection:
/// - Tier 1: exact action signature repeats `threshold` times consecutively
/// - Tier 2: single tool accounts for >90% of all tool calls and was called at least
///   `threshold` times (warns first, aborts on second trigger)
/// - Tier 3: tool output stagnation — same outputs `threshold` steps in a row
struct LoopDetector {
    /// Repeat count at which tiers 1 and 3 fire; also the minimum call count for tier 2.
    threshold: usize,
    /// Tier 1: how many times in a row the current signature has been seen.
    consecutive: usize,
    /// Tier 1: the most recent signature.
    last_sig: Vec<String>,
    /// Tier 2: number of calls per tool name.
    tool_freq: HashMap<String, usize>,
    /// Tier 2: total number of tool calls counted in `tool_freq`.
    total_calls: usize,
    /// Tier 3: hash of last tool outputs to detect stagnation (0 = nothing recorded yet).
    last_output_hash: u64,
    /// Tier 3: how many steps in a row produced the same outputs.
    output_repeat_count: usize,
    /// Whether tier 2 warning has already been issued (next trigger aborts).
    tier2_warned: bool,
}

impl LoopDetector {
    /// Create a detector that fires after `threshold` repeats.
    fn new(threshold: usize) -> Self {
        Self {
            threshold,
            consecutive: 0,
            last_sig: vec![],
            tool_freq: HashMap::new(),
            total_calls: 0,
            last_output_hash: 0,
            output_repeat_count: 0,
            tier2_warned: false,
        }
    }

    /// Check the action signature (tool names of one step) for a loop.
    ///
    /// Returns `Abort` for tier 1 (exact repeat) or tier 2 after a warning,
    /// `Tier2Warning` on the first tier 2 trigger (dominant tool detected),
    /// and `Ok` otherwise.
    fn check(&mut self, sig: &[String]) -> LoopCheckResult {
        // Tier 1: exact signature match
        if sig == self.last_sig {
            self.consecutive += 1;
        } else {
            self.consecutive = 1;
            self.last_sig = sig.to_vec();
        }
        if self.consecutive >= self.threshold {
            return LoopCheckResult::Abort;
        }

        // Tier 2: tool name frequency (single tool dominates).
        // The denominator counts tool calls, not steps, so it is in the same unit as
        // the per-tool counters and the share can never exceed 100%.
        self.total_calls += sig.len();
        for name in sig {
            *self.tool_freq.entry(name.clone()).or_insert(0) += 1;
        }

        let total = self.total_calls as f64;
        let threshold = self.threshold;
        // At most one tool can hold more than 90% of all calls, so the map's
        // iteration order does not influence the result.
        let dominant = self
            .tool_freq
            .iter()
            .find(|&(_, &count)| count >= threshold && count as f64 / total > 0.9)
            .map(|(name, _)| name.clone());

        match dominant {
            Some(_) if self.tier2_warned => LoopCheckResult::Abort,
            Some(name) => {
                self.tier2_warned = true;
                LoopCheckResult::Tier2Warning(name)
            }
            None => LoopCheckResult::Ok,
        }
    }

    /// Check tool outputs for stagnation (tier 3). Call after executing tools each step.
    ///
    /// Returns true once the same list of outputs has been seen `threshold` steps in a row.
    fn check_outputs(&mut self, outputs: &[String]) -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        outputs.hash(&mut hasher);
        let hash = hasher.finish();

        // A stored hash of 0 means "no outputs recorded yet" and never counts as a repeat.
        if hash == self.last_output_hash && self.last_output_hash != 0 {
            self.output_repeat_count += 1;
        } else {
            self.output_repeat_count = 1;
            self.last_output_hash = hash;
        }

        self.output_repeat_count >= self.threshold
    }
}
706
707/// Auto-completion detector — catches when agent is done but doesn't call finish_task.
708///
709/// Signals completion when:
710/// - Agent returns same situation text N times (stuck describing same state)
711/// - Situation contains completion keywords ("complete", "finished", "done", "no more")
712struct CompletionDetector {
713    threshold: usize,
714    last_situation: String,
715    repeat_count: usize,
716}
717
718/// Keywords in situation text that suggest task is complete.
719const COMPLETION_KEYWORDS: &[&str] = &[
720    "task is complete",
721    "task is done",
722    "task is finished",
723    "all done",
724    "successfully completed",
725    "nothing more",
726    "no further action",
727    "no more steps",
728];
729
730impl CompletionDetector {
731    fn new(threshold: usize) -> Self {
732        Self {
733            threshold: threshold.max(2),
734            last_situation: String::new(),
735            repeat_count: 0,
736        }
737    }
738
739    /// Check if the decision indicates implicit completion.
740    fn check(&mut self, decision: &Decision) -> bool {
741        // Don't interfere with explicit completion
742        if decision.completed || decision.tool_calls.is_empty() {
743            return false;
744        }
745
746        // Check for completion keywords in situation
747        let sit_lower = decision.situation.to_lowercase();
748        for keyword in COMPLETION_KEYWORDS {
749            if sit_lower.contains(keyword) {
750                return true;
751            }
752        }
753
754        // Check for repeated situation text (agent stuck describing same state)
755        if !decision.situation.is_empty() && decision.situation == self.last_situation {
756            self.repeat_count += 1;
757        } else {
758            self.repeat_count = 1;
759            self.last_situation = decision.situation.clone();
760        }
761
762        self.repeat_count >= self.threshold
763    }
764}
765
766/// Trim messages to fit within max_messages limit.
767/// Keeps: first 2 messages (system + initial user) + last (max - 2) messages.
768fn trim_messages(messages: &mut Vec<Message>, max: usize) {
769    if messages.len() <= max || max < 4 {
770        return;
771    }
772    let keep_start = 2; // system + user prompt
773    let remove_count = messages.len() - max + 1;
774    let mut trim_end = keep_start + remove_count;
775
776    // Don't break functionCall → functionResponse pairs.
777    // Gemini requires: model turn (functionCall) → user turn (functionResponse).
778    // If trim_end lands in the middle of such a pair, extend to skip the whole group.
779    //
780    // Case 1: trim_end points at Tool messages — extend past them (they'd be orphaned).
781    while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
782        trim_end += 1;
783    }
784    // Case 2: the first kept message is a Tool — it lost its preceding Assistant.
785    // (Already handled by Case 1, but double-check.)
786    //
787    // Case 3: the last removed message is an Assistant with tool_calls —
788    // the following Tool messages (now first in kept region) would be orphaned.
789    // Extend trim_end to also remove those Tool messages.
790    if trim_end > keep_start && trim_end < messages.len() {
791        let last_removed = trim_end - 1;
792        if messages[last_removed].role == Role::Assistant
793            && !messages[last_removed].tool_calls.is_empty()
794        {
795            // The assistant had tool_calls but we're keeping it... actually we're removing it.
796            // So remove all following Tool messages too.
797            while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
798                trim_end += 1;
799            }
800        }
801    }
802
803    let removed_range = keep_start..trim_end;
804
805    let summary = format!(
806        "[{} messages trimmed from context to stay within {} message limit]",
807        trim_end - keep_start,
808        max
809    );
810
811    messages.drain(removed_range);
812    messages.insert(keep_start, Message::system(&summary));
813}
814
815#[cfg(test)]
816mod tests {
817    use super::*;
818    use crate::agent::{Agent, AgentError, Decision};
819    use crate::agent_tool::{Tool, ToolError, ToolOutput};
820    use crate::context::AgentContext;
821    use crate::registry::ToolRegistry;
822    use crate::types::{Message, SgrError, ToolCall};
823    use serde_json::Value;
824    use std::sync::Arc;
825    use std::sync::atomic::{AtomicUsize, Ordering};
826
827    struct CountingAgent {
828        max_calls: usize,
829        call_count: Arc<AtomicUsize>,
830    }
831
832    #[async_trait::async_trait]
833    impl Agent for CountingAgent {
834        async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
835            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
836            if n >= self.max_calls {
837                Ok(Decision {
838                    situation: "done".into(),
839                    task: vec![],
840                    tool_calls: vec![],
841                    completed: true,
842                })
843            } else {
844                Ok(Decision {
845                    situation: format!("step {}", n),
846                    task: vec![],
847                    tool_calls: vec![ToolCall {
848                        id: format!("call_{}", n),
849                        name: "echo".into(),
850                        arguments: serde_json::json!({"msg": "hi"}),
851                    }],
852                    completed: false,
853                })
854            }
855        }
856    }
857
858    struct EchoTool;
859
860    #[async_trait::async_trait]
861    impl Tool for EchoTool {
862        fn name(&self) -> &str {
863            "echo"
864        }
865        fn description(&self) -> &str {
866            "echo"
867        }
868        fn parameters_schema(&self) -> Value {
869            serde_json::json!({"type": "object"})
870        }
871        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
872            Ok(ToolOutput::text("echoed"))
873        }
874    }
875
876    #[tokio::test]
877    async fn loop_runs_and_completes() {
878        let agent = CountingAgent {
879            max_calls: 3,
880            call_count: Arc::new(AtomicUsize::new(0)),
881        };
882        let tools = ToolRegistry::new().register(EchoTool);
883        let mut ctx = AgentContext::new();
884        let mut messages = vec![Message::user("go")];
885        let config = LoopConfig::default();
886
887        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
888            .await
889            .unwrap();
890        assert_eq!(steps, 4); // 3 tool calls + 1 completion
891        assert_eq!(ctx.state, AgentState::Completed);
892    }
893
894    #[tokio::test]
895    async fn loop_detects_repetition() {
896        // Agent always returns same tool call → loop detection
897        struct LoopingAgent;
898        #[async_trait::async_trait]
899        impl Agent for LoopingAgent {
900            async fn decide(
901                &self,
902                _: &[Message],
903                _: &ToolRegistry,
904            ) -> Result<Decision, AgentError> {
905                Ok(Decision {
906                    situation: "stuck".into(),
907                    task: vec![],
908                    tool_calls: vec![ToolCall {
909                        id: "1".into(),
910                        name: "echo".into(),
911                        arguments: serde_json::json!({}),
912                    }],
913                    completed: false,
914                })
915            }
916        }
917
918        let tools = ToolRegistry::new().register(EchoTool);
919        let mut ctx = AgentContext::new();
920        let mut messages = vec![Message::user("go")];
921        let config = LoopConfig {
922            max_steps: 50,
923            loop_abort_threshold: 3,
924            auto_complete_threshold: 100, // disable auto-complete for this test
925            ..Default::default()
926        };
927
928        let result = run_loop(
929            &LoopingAgent,
930            &tools,
931            &mut ctx,
932            &mut messages,
933            &config,
934            |_| {},
935        )
936        .await;
937        assert!(matches!(result, Err(AgentError::LoopDetected(3))));
938        assert_eq!(ctx.state, AgentState::Failed);
939    }
940
941    #[tokio::test]
942    async fn loop_max_steps() {
943        // Agent never completes
944        struct NeverDoneAgent;
945        #[async_trait::async_trait]
946        impl Agent for NeverDoneAgent {
947            async fn decide(
948                &self,
949                _: &[Message],
950                _: &ToolRegistry,
951            ) -> Result<Decision, AgentError> {
952                // Different tool names to avoid loop detection
953                static COUNTER: AtomicUsize = AtomicUsize::new(0);
954                let n = COUNTER.fetch_add(1, Ordering::SeqCst);
955                Ok(Decision {
956                    situation: String::new(),
957                    task: vec![],
958                    tool_calls: vec![ToolCall {
959                        id: format!("{}", n),
960                        name: format!("tool_{}", n),
961                        arguments: serde_json::json!({}),
962                    }],
963                    completed: false,
964                })
965            }
966        }
967
968        let tools = ToolRegistry::new().register(EchoTool);
969        let mut ctx = AgentContext::new();
970        let mut messages = vec![Message::user("go")];
971        let config = LoopConfig {
972            max_steps: 5,
973            loop_abort_threshold: 100,
974            ..Default::default()
975        };
976
977        let result = run_loop(
978            &NeverDoneAgent,
979            &tools,
980            &mut ctx,
981            &mut messages,
982            &config,
983            |_| {},
984        )
985        .await;
986        assert!(matches!(result, Err(AgentError::MaxSteps(5))));
987    }
988
989    #[test]
990    fn loop_detector_exact_sig() {
991        let mut d = LoopDetector::new(3);
992        let sig = vec!["bash".to_string()];
993        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
994        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
995        assert_eq!(d.check(&sig), LoopCheckResult::Abort); // 3rd consecutive
996    }
997
998    #[test]
999    fn loop_detector_different_sigs_reset() {
1000        let mut d = LoopDetector::new(3);
1001        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1002        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1003        assert_eq!(d.check(&["read".into()]), LoopCheckResult::Ok); // different → resets
1004        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1005    }
1006
1007    #[test]
1008    fn loop_detector_tier2_warning_then_abort() {
1009        // Tier 2 requires: count >= threshold AND count/total > 0.9
1010        // Use threshold=3. To avoid tier 1 (exact consecutive), alternate sigs.
1011        let mut d = LoopDetector::new(3);
1012        // Calls 1-2: build up frequency, total_calls < threshold so tier 2 not checked
1013        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); // total=1, edit=1, cons=1
1014        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); // total=2, edit=2, cons=2
1015        // Call 3: break consecutive (different sig) but edit_file still in sig
1016        // total=3, edit=3, cons=1 → tier 2: 3/3=1.0 > 0.9 → first warning
1017        assert_eq!(
1018            d.check(&["edit_file".into(), "read_file".into()]),
1019            LoopCheckResult::Tier2Warning("edit_file".into())
1020        );
1021        // Call 4: tier 2 already warned → abort
1022        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Abort);
1023    }
1024
1025    #[test]
1026    fn loop_config_default() {
1027        let c = LoopConfig::default();
1028        assert_eq!(c.max_steps, 50);
1029        assert_eq!(c.loop_abort_threshold, 6);
1030    }
1031
1032    #[test]
1033    fn loop_detector_output_stagnation() {
1034        let mut d = LoopDetector::new(3);
1035        let outputs = vec!["same result".to_string()];
1036        assert!(!d.check_outputs(&outputs));
1037        assert!(!d.check_outputs(&outputs));
1038        assert!(d.check_outputs(&outputs)); // 3rd repeat
1039    }
1040
1041    #[test]
1042    fn completion_detector_keyword() {
1043        let mut cd = CompletionDetector::new(3);
1044        let d = Decision {
1045            situation: "The task is complete, all files written.".into(),
1046            task: vec![],
1047            tool_calls: vec![ToolCall {
1048                id: "1".into(),
1049                name: "echo".into(),
1050                arguments: serde_json::json!({}),
1051            }],
1052            completed: false,
1053        };
1054        assert!(cd.check(&d));
1055    }
1056
1057    #[test]
1058    fn completion_detector_repeated_situation() {
1059        let mut cd = CompletionDetector::new(3);
1060        let d = Decision {
1061            situation: "working on it".into(),
1062            task: vec![],
1063            tool_calls: vec![ToolCall {
1064                id: "1".into(),
1065                name: "echo".into(),
1066                arguments: serde_json::json!({}),
1067            }],
1068            completed: false,
1069        };
1070        assert!(!cd.check(&d));
1071        assert!(!cd.check(&d));
1072        assert!(cd.check(&d)); // 3rd repeat
1073    }
1074
1075    #[test]
1076    fn completion_detector_ignores_explicit_completion() {
1077        let mut cd = CompletionDetector::new(2);
1078        let d = Decision {
1079            situation: "task is complete".into(),
1080            task: vec![],
1081            tool_calls: vec![],
1082            completed: true,
1083        };
1084        // Should return false — let normal completion handling take over
1085        assert!(!cd.check(&d));
1086    }
1087
1088    #[test]
1089    fn trim_messages_basic() {
1090        let mut msgs: Vec<Message> = (0..10).map(|i| Message::user(format!("msg {i}"))).collect();
1091        trim_messages(&mut msgs, 6);
1092        // first 2 + summary + last 3 = 6
1093        assert_eq!(msgs.len(), 6);
1094        assert!(msgs[2].content.contains("trimmed"));
1095    }
1096
1097    #[test]
1098    fn trim_messages_no_op_when_under_limit() {
1099        let mut msgs = vec![Message::user("a"), Message::user("b")];
1100        trim_messages(&mut msgs, 10);
1101        assert_eq!(msgs.len(), 2);
1102    }
1103
1104    #[test]
1105    fn trim_messages_preserves_assistant_tool_call_pair() {
1106        use crate::types::Role;
1107        // system, user, assistant(tool_calls), tool, tool, user, assistant
1108        let mut msgs = vec![
1109            Message::system("sys"),
1110            Message::user("prompt"),
1111            Message::assistant_with_tool_calls(
1112                "calling",
1113                vec![
1114                    ToolCall {
1115                        id: "c1".into(),
1116                        name: "read".into(),
1117                        arguments: serde_json::json!({}),
1118                    },
1119                    ToolCall {
1120                        id: "c2".into(),
1121                        name: "read".into(),
1122                        arguments: serde_json::json!({}),
1123                    },
1124                ],
1125            ),
1126            Message::tool("c1", "result1"),
1127            Message::tool("c2", "result2"),
1128            Message::user("next"),
1129            Message::assistant("done"),
1130        ];
1131        // Trim to 5 — should remove assistant+tools as a group, not split them
1132        trim_messages(&mut msgs, 5);
1133        // Verify no orphaned Tool messages remain
1134        for (i, msg) in msgs.iter().enumerate() {
1135            if msg.role == Role::Tool {
1136                // The previous message should be an Assistant with tool_calls
1137                assert!(i > 0, "Tool message at start");
1138                assert!(
1139                    msgs[i - 1].role == Role::Assistant && !msgs[i - 1].tool_calls.is_empty()
1140                        || msgs[i - 1].role == Role::Tool,
1141                    "Orphaned Tool at position {i}"
1142                );
1143            }
1144        }
1145    }
1146
1147    #[test]
1148    fn loop_detector_output_stagnation_resets_on_change() {
1149        let mut d = LoopDetector::new(3);
1150        let a = vec!["result A".to_string()];
1151        let b = vec!["result B".to_string()];
1152        assert!(!d.check_outputs(&a));
1153        assert!(!d.check_outputs(&a));
1154        assert!(!d.check_outputs(&b)); // different → resets
1155        assert!(!d.check_outputs(&a));
1156    }
1157
1158    #[tokio::test]
1159    async fn loop_handles_non_recoverable_llm_error() {
1160        struct FailingAgent;
1161        #[async_trait::async_trait]
1162        impl Agent for FailingAgent {
1163            async fn decide(
1164                &self,
1165                _: &[Message],
1166                _: &ToolRegistry,
1167            ) -> Result<Decision, AgentError> {
1168                Err(AgentError::Llm(SgrError::Api {
1169                    status: 500,
1170                    body: "internal server error".into(),
1171                }))
1172            }
1173        }
1174
1175        let tools = ToolRegistry::new().register(EchoTool);
1176        let mut ctx = AgentContext::new();
1177        let mut messages = vec![Message::user("go")];
1178        let config = LoopConfig::default();
1179
1180        let result = run_loop(
1181            &FailingAgent,
1182            &tools,
1183            &mut ctx,
1184            &mut messages,
1185            &config,
1186            |_| {},
1187        )
1188        .await;
1189        // Non-recoverable: should fail immediately, no retries
1190        assert!(result.is_err());
1191        assert_eq!(messages.len(), 1); // no feedback messages added
1192    }
1193
1194    #[tokio::test]
1195    async fn loop_recovers_from_parse_error() {
1196        // Agent fails with parse error on first call, succeeds on retry
1197        struct ParseRetryAgent {
1198            call_count: Arc<AtomicUsize>,
1199        }
1200        #[async_trait::async_trait]
1201        impl Agent for ParseRetryAgent {
1202            async fn decide(
1203                &self,
1204                msgs: &[Message],
1205                _: &ToolRegistry,
1206            ) -> Result<Decision, AgentError> {
1207                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1208                if n == 0 {
1209                    // First call: simulate parse error
1210                    Err(AgentError::Llm(SgrError::Schema(
1211                        "Missing required field: situation".into(),
1212                    )))
1213                } else {
1214                    // Second call: should see error feedback in messages
1215                    let last = msgs.last().unwrap();
1216                    assert!(
1217                        last.content.contains("Parse error"),
1218                        "expected parse error feedback, got: {}",
1219                        last.content
1220                    );
1221                    Ok(Decision {
1222                        situation: "recovered from parse error".into(),
1223                        task: vec![],
1224                        tool_calls: vec![],
1225                        completed: true,
1226                    })
1227                }
1228            }
1229        }
1230
1231        let tools = ToolRegistry::new().register(EchoTool);
1232        let mut ctx = AgentContext::new();
1233        let mut messages = vec![Message::user("go")];
1234        let config = LoopConfig::default();
1235        let agent = ParseRetryAgent {
1236            call_count: Arc::new(AtomicUsize::new(0)),
1237        };
1238
1239        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1240            .await
1241            .unwrap();
1242        assert_eq!(steps, 2); // step 1 failed parse, step 2 succeeded
1243        assert_eq!(ctx.state, AgentState::Completed);
1244    }
1245
1246    #[tokio::test]
1247    async fn loop_aborts_after_max_parse_retries() {
1248        struct AlwaysFailParseAgent;
1249        #[async_trait::async_trait]
1250        impl Agent for AlwaysFailParseAgent {
1251            async fn decide(
1252                &self,
1253                _: &[Message],
1254                _: &ToolRegistry,
1255            ) -> Result<Decision, AgentError> {
1256                Err(AgentError::Llm(SgrError::Schema("bad json".into())))
1257            }
1258        }
1259
1260        let tools = ToolRegistry::new().register(EchoTool);
1261        let mut ctx = AgentContext::new();
1262        let mut messages = vec![Message::user("go")];
1263        let config = LoopConfig::default();
1264
1265        let result = run_loop(
1266            &AlwaysFailParseAgent,
1267            &tools,
1268            &mut ctx,
1269            &mut messages,
1270            &config,
1271            |_| {},
1272        )
1273        .await;
1274        assert!(result.is_err());
1275        // Should have added MAX_PARSE_RETRIES feedback messages
1276        let feedback_count = messages
1277            .iter()
1278            .filter(|m| m.content.contains("Parse error"))
1279            .count();
1280        assert_eq!(feedback_count, MAX_PARSE_RETRIES);
1281    }
1282
1283    #[tokio::test]
1284    async fn loop_feeds_tool_errors_back() {
1285        // Agent calls unknown tool → error fed back → agent completes
1286        struct ErrorRecoveryAgent {
1287            call_count: Arc<AtomicUsize>,
1288        }
1289        #[async_trait::async_trait]
1290        impl Agent for ErrorRecoveryAgent {
1291            async fn decide(
1292                &self,
1293                msgs: &[Message],
1294                _: &ToolRegistry,
1295            ) -> Result<Decision, AgentError> {
1296                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1297                if n == 0 {
1298                    // First: call unknown tool
1299                    Ok(Decision {
1300                        situation: "trying".into(),
1301                        task: vec![],
1302                        tool_calls: vec![ToolCall {
1303                            id: "1".into(),
1304                            name: "nonexistent_tool".into(),
1305                            arguments: serde_json::json!({}),
1306                        }],
1307                        completed: false,
1308                    })
1309                } else {
1310                    // Second: should see error in messages, complete
1311                    let last = msgs.last().unwrap();
1312                    assert!(last.content.contains("Unknown tool"));
1313                    Ok(Decision {
1314                        situation: "recovered".into(),
1315                        task: vec![],
1316                        tool_calls: vec![],
1317                        completed: true,
1318                    })
1319                }
1320            }
1321        }
1322
1323        let tools = ToolRegistry::new().register(EchoTool);
1324        let mut ctx = AgentContext::new();
1325        let mut messages = vec![Message::user("go")];
1326        let config = LoopConfig::default();
1327        let agent = ErrorRecoveryAgent {
1328            call_count: Arc::new(AtomicUsize::new(0)),
1329        };
1330
1331        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1332            .await
1333            .unwrap();
1334        assert_eq!(steps, 2);
1335        assert_eq!(ctx.state, AgentState::Completed);
1336    }
1337
1338    #[tokio::test]
1339    async fn parallel_readonly_tools() {
1340        struct ReadOnlyTool {
1341            name: &'static str,
1342        }
1343
1344        #[async_trait::async_trait]
1345        impl Tool for ReadOnlyTool {
1346            fn name(&self) -> &str {
1347                self.name
1348            }
1349            fn description(&self) -> &str {
1350                "read-only tool"
1351            }
1352            fn is_read_only(&self) -> bool {
1353                true
1354            }
1355            fn parameters_schema(&self) -> Value {
1356                serde_json::json!({"type": "object"})
1357            }
1358            async fn execute(
1359                &self,
1360                _: Value,
1361                _: &mut AgentContext,
1362            ) -> Result<ToolOutput, ToolError> {
1363                Ok(ToolOutput::text(format!("{} result", self.name)))
1364            }
1365            async fn execute_readonly(
1366                &self,
1367                _: Value,
1368                _ctx: &crate::context::AgentContext,
1369            ) -> Result<ToolOutput, ToolError> {
1370                Ok(ToolOutput::text(format!("{} result", self.name)))
1371            }
1372        }
1373
1374        struct ParallelAgent;
1375        #[async_trait::async_trait]
1376        impl Agent for ParallelAgent {
1377            async fn decide(
1378                &self,
1379                msgs: &[Message],
1380                _: &ToolRegistry,
1381            ) -> Result<Decision, AgentError> {
1382                if msgs.len() > 3 {
1383                    return Ok(Decision {
1384                        situation: "done".into(),
1385                        task: vec![],
1386                        tool_calls: vec![],
1387                        completed: true,
1388                    });
1389                }
1390                Ok(Decision {
1391                    situation: "reading".into(),
1392                    task: vec![],
1393                    tool_calls: vec![
1394                        ToolCall {
1395                            id: "1".into(),
1396                            name: "reader_a".into(),
1397                            arguments: serde_json::json!({}),
1398                        },
1399                        ToolCall {
1400                            id: "2".into(),
1401                            name: "reader_b".into(),
1402                            arguments: serde_json::json!({}),
1403                        },
1404                    ],
1405                    completed: false,
1406                })
1407            }
1408        }
1409
1410        let tools = ToolRegistry::new()
1411            .register(ReadOnlyTool { name: "reader_a" })
1412            .register(ReadOnlyTool { name: "reader_b" });
1413        let mut ctx = AgentContext::new();
1414        let mut messages = vec![Message::user("read stuff")];
1415        let config = LoopConfig::default();
1416
1417        let steps = run_loop(
1418            &ParallelAgent,
1419            &tools,
1420            &mut ctx,
1421            &mut messages,
1422            &config,
1423            |_| {},
1424        )
1425        .await
1426        .unwrap();
1427        assert!(steps > 0);
1428        assert_eq!(ctx.state, AgentState::Completed);
1429    }
1430
1431    #[tokio::test]
1432    async fn loop_events_are_emitted() {
1433        let agent = CountingAgent {
1434            max_calls: 1,
1435            call_count: Arc::new(AtomicUsize::new(0)),
1436        };
1437        let tools = ToolRegistry::new().register(EchoTool);
1438        let mut ctx = AgentContext::new();
1439        let mut messages = vec![Message::user("go")];
1440        let config = LoopConfig::default();
1441
1442        let mut events = Vec::new();
1443        run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |e| {
1444            events.push(format!("{:?}", std::mem::discriminant(&e)));
1445        })
1446        .await
1447        .unwrap();
1448
1449        // Should have: StepStart, Decision, ToolResult, StepStart, Decision, Completed
1450        assert!(events.len() >= 4);
1451    }
1452
1453    #[tokio::test]
1454    async fn tool_output_done_stops_loop() {
1455        // A tool that returns ToolOutput::done() should stop the loop immediately.
1456        struct DoneTool;
1457        #[async_trait::async_trait]
1458        impl Tool for DoneTool {
1459            fn name(&self) -> &str {
1460                "done_tool"
1461            }
1462            fn description(&self) -> &str {
1463                "returns done"
1464            }
1465            fn parameters_schema(&self) -> Value {
1466                serde_json::json!({"type": "object"})
1467            }
1468            async fn execute(
1469                &self,
1470                _: Value,
1471                _: &mut AgentContext,
1472            ) -> Result<ToolOutput, ToolError> {
1473                Ok(ToolOutput::done("final answer"))
1474            }
1475        }
1476
1477        struct OneShotAgent;
1478        #[async_trait::async_trait]
1479        impl Agent for OneShotAgent {
1480            async fn decide(
1481                &self,
1482                _: &[Message],
1483                _: &ToolRegistry,
1484            ) -> Result<Decision, AgentError> {
1485                Ok(Decision {
1486                    situation: "calling done tool".into(),
1487                    task: vec![],
1488                    tool_calls: vec![ToolCall {
1489                        id: "1".into(),
1490                        name: "done_tool".into(),
1491                        arguments: serde_json::json!({}),
1492                    }],
1493                    completed: false,
1494                })
1495            }
1496        }
1497
1498        let tools = ToolRegistry::new().register(DoneTool);
1499        let mut ctx = AgentContext::new();
1500        let mut messages = vec![Message::user("go")];
1501        let config = LoopConfig::default();
1502
1503        let steps = run_loop(
1504            &OneShotAgent,
1505            &tools,
1506            &mut ctx,
1507            &mut messages,
1508            &config,
1509            |_| {},
1510        )
1511        .await
1512        .unwrap();
1513        assert_eq!(
1514            steps, 1,
1515            "Loop should stop on first step when tool returns done"
1516        );
1517        assert_eq!(ctx.state, AgentState::Completed);
1518    }
1519
1520    #[tokio::test]
1521    async fn tool_messages_formatted_correctly() {
1522        // Verify that assistant messages with tool_calls are preserved in the message list,
1523        // followed by tool result messages.
1524        let agent = CountingAgent {
1525            max_calls: 1,
1526            call_count: Arc::new(AtomicUsize::new(0)),
1527        };
1528        let tools = ToolRegistry::new().register(EchoTool);
1529        let mut ctx = AgentContext::new();
1530        let mut messages = vec![Message::user("go")];
1531        let config = LoopConfig::default();
1532
1533        run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1534            .await
1535            .unwrap();
1536
1537        // After 1 tool call + completion, messages should be:
1538        // [user("go"), assistant_with_tool_calls("step 0", [echo]), tool("echoed"), assistant("done")]
1539        assert!(messages.len() >= 4);
1540
1541        // Find the assistant message with tool calls
1542        let assistant_tc = messages
1543            .iter()
1544            .find(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty());
1545        assert!(
1546            assistant_tc.is_some(),
1547            "Should have an assistant message with tool_calls"
1548        );
1549        let atc = assistant_tc.unwrap();
1550        assert_eq!(atc.tool_calls[0].name, "echo");
1551        assert_eq!(atc.tool_calls[0].id, "call_0");
1552
1553        // The next message should be a Tool result
1554        let tc_idx = messages
1555            .iter()
1556            .position(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty())
1557            .unwrap();
1558        let tool_msg = &messages[tc_idx + 1];
1559        assert_eq!(tool_msg.role, crate::types::Role::Tool);
1560        assert_eq!(tool_msg.tool_call_id.as_deref(), Some("call_0"));
1561        assert_eq!(tool_msg.content, "echoed");
1562    }
1563
1564    // --- Tool result pairing repair tests ---
1565
1566    #[test]
1567    fn pairing_adds_missing_tool_result() {
1568        let mut msgs = vec![
1569            Message::user("go"),
1570            Message::assistant_with_tool_calls(
1571                "calling",
1572                vec![
1573                    ToolCall {
1574                        id: "c1".into(),
1575                        name: "bash".into(),
1576                        arguments: serde_json::json!({}),
1577                    },
1578                    ToolCall {
1579                        id: "c2".into(),
1580                        name: "read".into(),
1581                        arguments: serde_json::json!({}),
1582                    },
1583                ],
1584            ),
1585            // Only c1 has a result — c2 is missing
1586            Message::tool("c1", "ok"),
1587        ];
1588        ensure_tool_result_pairing(&mut msgs);
1589
1590        // c2 should now have a synthetic result
1591        let c2_result = msgs
1592            .iter()
1593            .find(|m| m.tool_call_id.as_deref() == Some("c2"));
1594        assert!(c2_result.is_some(), "Should have synthetic result for c2");
1595        assert!(c2_result.unwrap().content.contains("missing"));
1596    }
1597
1598    #[test]
1599    fn pairing_removes_duplicate_tool_result() {
1600        let mut msgs = vec![
1601            Message::user("go"),
1602            Message::assistant_with_tool_calls(
1603                "calling",
1604                vec![ToolCall {
1605                    id: "c1".into(),
1606                    name: "bash".into(),
1607                    arguments: serde_json::json!({}),
1608                }],
1609            ),
1610            Message::tool("c1", "first"),
1611            Message::tool("c1", "duplicate"), // duplicate
1612        ];
1613        ensure_tool_result_pairing(&mut msgs);
1614
1615        let c1_count = msgs
1616            .iter()
1617            .filter(|m| m.tool_call_id.as_deref() == Some("c1"))
1618            .count();
1619        assert_eq!(c1_count, 1, "Should remove duplicate tool_result");
1620    }
1621
1622    #[test]
1623    fn pairing_removes_orphaned_tool_result() {
1624        let mut msgs = vec![
1625            Message::user("go"),
1626            Message::tool("orphan_id", "orphaned result"), // no matching tool_use
1627            Message::assistant("done"),
1628        ];
1629        ensure_tool_result_pairing(&mut msgs);
1630
1631        let orphan = msgs
1632            .iter()
1633            .find(|m| m.tool_call_id.as_deref() == Some("orphan_id"));
1634        assert!(orphan.is_none(), "Should remove orphaned tool_result");
1635    }
1636
1637    #[test]
1638    fn pairing_noop_for_valid_transcript() {
1639        let mut msgs = vec![
1640            Message::user("go"),
1641            Message::assistant_with_tool_calls(
1642                "calling",
1643                vec![ToolCall {
1644                    id: "c1".into(),
1645                    name: "bash".into(),
1646                    arguments: serde_json::json!({}),
1647                }],
1648            ),
1649            Message::tool("c1", "result"),
1650            Message::assistant("done"),
1651        ];
1652        let len_before = msgs.len();
1653        ensure_tool_result_pairing(&mut msgs);
1654        assert_eq!(msgs.len(), len_before, "Valid transcript should not change");
1655    }
1656
1657    // --- Context modifier tests ---
1658
1659    #[test]
1660    fn context_modifier_system_injection() {
1661        use crate::agent_tool::ContextModifier;
1662
1663        let modifier = ContextModifier::system("Extra instructions for next step");
1664        let mut ctx = AgentContext::new();
1665        let mut messages = vec![Message::user("go")];
1666        let mut max_steps = 50;
1667
1668        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
1669
1670        assert_eq!(messages.len(), 2);
1671        assert_eq!(messages[1].role, Role::User); // User, not System — Gemini compat
1672        assert!(messages[1].content.contains("Extra instructions"));
1673    }
1674
1675    #[test]
1676    fn context_modifier_extra_steps() {
1677        use crate::agent_tool::ContextModifier;
1678
1679        let mut ctx = AgentContext::new();
1680        let mut messages = vec![];
1681        let mut max_steps = 50;
1682
1683        let modifier = ContextModifier::extra_steps(20);
1684        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
1685        assert_eq!(max_steps, 70);
1686
1687        let modifier = ContextModifier::extra_steps(-10);
1688        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
1689        assert_eq!(max_steps, 60);
1690    }
1691
1692    #[test]
1693    fn context_modifier_custom_context() {
1694        use crate::agent_tool::ContextModifier;
1695
1696        let modifier = ContextModifier::custom("my_key", serde_json::json!("my_value"));
1697        let mut ctx = AgentContext::new();
1698        let mut messages = vec![];
1699        let mut max_steps = 50;
1700
1701        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
1702
1703        assert_eq!(ctx.get("my_key").unwrap(), "my_value");
1704    }
1705
1706    #[test]
1707    fn context_modifier_is_empty() {
1708        use crate::agent_tool::ContextModifier;
1709
1710        assert!(ContextModifier::default().is_empty());
1711        assert!(!ContextModifier::system("hi").is_empty());
1712        assert!(!ContextModifier::max_tokens(100).is_empty());
1713        assert!(!ContextModifier::extra_steps(5).is_empty());
1714        assert!(!ContextModifier::custom("k", serde_json::json!("v")).is_empty());
1715    }
1716
1717    #[test]
1718    fn context_modifier_max_tokens_stored_in_context() {
1719        use crate::agent_tool::ContextModifier;
1720
1721        let modifier = ContextModifier::max_tokens(4096);
1722        let mut ctx = AgentContext::new();
1723        let mut messages = vec![];
1724        let mut max_steps = 50;
1725
1726        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
1727
1728        assert_eq!(ctx.max_tokens_override(), Some(4096));
1729    }
1730}