1use crate::agent::{Agent, AgentError, Decision};
6use crate::context::{AgentContext, AgentState};
7use crate::registry::ToolRegistry;
8use crate::retry::{RetryConfig, delay_for_attempt, is_retryable};
9use crate::types::{Message, SgrError};
10use futures::future::join_all;
11use std::collections::HashMap;
12
13const MAX_PARSE_RETRIES: usize = 3;
15
16const MAX_TRANSIENT_RETRIES: usize = 3;
18
19fn is_recoverable_error(e: &AgentError) -> bool {
21 matches!(
22 e,
23 AgentError::Llm(SgrError::Json(_))
24 | AgentError::Llm(SgrError::EmptyResponse)
25 | AgentError::Llm(SgrError::Schema(_))
26 )
27}
28
29async fn decide_with_retry(
32 agent: &dyn Agent,
33 messages: &[Message],
34 tools: &ToolRegistry,
35 previous_response_id: Option<&str>,
36) -> Result<(Decision, Option<String>), AgentError> {
37 let retry_config = RetryConfig {
38 max_retries: MAX_TRANSIENT_RETRIES,
39 base_delay_ms: 500,
40 max_delay_ms: 30_000,
41 };
42
43 for attempt in 0..=retry_config.max_retries {
44 match agent
45 .decide_stateful(messages, tools, previous_response_id)
46 .await
47 {
48 Ok(d) => return Ok(d),
49 Err(AgentError::Llm(sgr_err))
50 if is_retryable(&sgr_err) && attempt < retry_config.max_retries =>
51 {
52 let delay = delay_for_attempt(attempt, &retry_config, &sgr_err);
53 tracing::warn!(
54 attempt = attempt + 1,
55 max = retry_config.max_retries,
56 delay_ms = delay.as_millis() as u64,
57 "Retrying agent.decide(): {}",
58 sgr_err
59 );
60 tokio::time::sleep(delay).await;
61 }
63 Err(e) => return Err(e),
64 }
65 }
66 agent
68 .decide_stateful(messages, tools, previous_response_id)
69 .await
70}
71
/// Tunables for the agent loop (`run_loop` / `run_loop_interactive`).
#[derive(Debug, Clone)]
pub struct LoopConfig {
    /// Maximum number of steps before the loop fails with `AgentError::MaxSteps`.
    pub max_steps: usize,
    /// Repetition threshold handed to the loop detector; reaching it aborts the run.
    pub loop_abort_threshold: usize,
    /// History length above which old messages are trimmed; `0` disables trimming.
    pub max_messages: usize,
    /// Consecutive identical `situation`s after which the loop auto-completes
    /// (the completion detector never uses a value below 2).
    pub auto_complete_threshold: usize,
}

impl Default for LoopConfig {
    /// 50 steps, abort after 6 repeats, trim above 80 messages, auto-complete after 3.
    fn default() -> Self {
        let (max_steps, loop_abort_threshold) = (50, 6);
        let (max_messages, auto_complete_threshold) = (80, 3);
        LoopConfig {
            max_steps,
            loop_abort_threshold,
            max_messages,
            auto_complete_threshold,
        }
    }
}
96
97#[derive(Debug)]
99pub enum LoopEvent {
100 StepStart {
101 step: usize,
102 },
103 Decision(Decision),
104 ToolResult {
105 name: String,
106 output: String,
107 },
108 Completed {
109 steps: usize,
110 },
111 LoopDetected {
112 count: usize,
113 },
114 Error(AgentError),
115 WaitingForInput {
117 question: String,
118 tool_call_id: String,
119 },
120}
121
122pub async fn run_loop(
126 agent: &dyn Agent,
127 tools: &ToolRegistry,
128 ctx: &mut AgentContext,
129 messages: &mut Vec<Message>,
130 config: &LoopConfig,
131 mut on_event: impl FnMut(LoopEvent),
132) -> Result<usize, AgentError> {
133 let mut detector = LoopDetector::new(config.loop_abort_threshold);
134 let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
135 let mut parse_retries: usize = 0;
136 let mut response_id: Option<String> = None;
137
138 for step in 1..=config.max_steps {
139 if config.max_messages > 0 && messages.len() > config.max_messages {
141 trim_messages(messages, config.max_messages);
142 }
143 ctx.iteration = step;
144 on_event(LoopEvent::StepStart { step });
145
146 agent.prepare_context(ctx, messages);
148
149 let active_tool_names = agent.prepare_tools(ctx, tools);
151 let filtered_tools = if active_tool_names.len() == tools.list().len() {
152 None } else {
154 Some(active_tool_names)
155 };
156
157 let effective_tools = if let Some(ref names) = filtered_tools {
159 &tools.filter(names)
160 } else {
161 tools
162 };
163
164 let decision = match decide_with_retry(
165 agent,
166 messages,
167 effective_tools,
168 response_id.as_deref(),
169 )
170 .await
171 {
172 Ok((d, new_rid)) => {
173 parse_retries = 0;
174 response_id = new_rid;
175 d
176 }
177 Err(e) if is_recoverable_error(&e) => {
178 parse_retries += 1;
179 if parse_retries > MAX_PARSE_RETRIES {
180 return Err(e);
181 }
182 let err_msg = format!(
183 "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
184 parse_retries, MAX_PARSE_RETRIES, e
185 );
186 on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
187 err_msg.clone(),
188 ))));
189 messages.push(Message::user(&err_msg));
190 continue;
191 }
192 Err(e) => return Err(e),
193 };
194 on_event(LoopEvent::Decision(decision.clone()));
195
196 if completion_detector.check(&decision) {
198 ctx.state = AgentState::Completed;
199 if !decision.situation.is_empty() {
200 messages.push(Message::assistant(&decision.situation));
201 }
202 on_event(LoopEvent::Completed { steps: step });
203 return Ok(step);
204 }
205
206 if decision.completed || decision.tool_calls.is_empty() {
207 ctx.state = AgentState::Completed;
208 if !decision.situation.is_empty() {
210 messages.push(Message::assistant(&decision.situation));
211 }
212 on_event(LoopEvent::Completed { steps: step });
213 return Ok(step);
214 }
215
216 let sig: Vec<String> = decision
218 .tool_calls
219 .iter()
220 .map(|tc| tc.name.clone())
221 .collect();
222 match detector.check(&sig) {
223 LoopCheckResult::Abort => {
224 ctx.state = AgentState::Failed;
225 on_event(LoopEvent::LoopDetected {
226 count: detector.consecutive,
227 });
228 return Err(AgentError::LoopDetected(detector.consecutive));
229 }
230 LoopCheckResult::Tier2Warning(dominant_tool) => {
231 let hint = format!(
233 "LOOP WARNING: You are repeatedly using '{}' without making progress. \
234 Try a different approach: re-read the file with read_file to see current contents, \
235 use write_file instead of edit_file, or break the problem into smaller steps.",
236 dominant_tool
237 );
238 messages.push(Message::system(&hint));
239 }
240 LoopCheckResult::Ok => {}
241 }
242
243 messages.push(Message::assistant_with_tool_calls(
245 &decision.situation,
246 decision.tool_calls.clone(),
247 ));
248
249 let mut step_outputs: Vec<String> = Vec::new();
251 let mut early_done = false;
252
253 let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
255 .tool_calls
256 .iter()
257 .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
258
259 if !ro_calls.is_empty() {
261 let futs: Vec<_> = ro_calls
262 .iter()
263 .map(|tc| {
264 let tool = tools.get(&tc.name).unwrap();
265 let args = tc.arguments.clone();
266 let name = tc.name.clone();
267 let id = tc.id.clone();
268 async move { (id, name, tool.execute_readonly(args).await) }
269 })
270 .collect();
271
272 for (id, name, result) in join_all(futs).await {
273 match result {
274 Ok(output) => {
275 on_event(LoopEvent::ToolResult {
276 name: name.clone(),
277 output: output.content.clone(),
278 });
279 step_outputs.push(output.content.clone());
280 agent.after_action(ctx, &name, &output.content);
281 if output.waiting {
282 ctx.state = AgentState::WaitingInput;
283 on_event(LoopEvent::WaitingForInput {
284 question: output.content.clone(),
285 tool_call_id: id.clone(),
286 });
287 messages.push(Message::tool(&id, "[waiting for user input]"));
288 ctx.state = AgentState::Running;
289 } else {
290 messages.push(Message::tool(&id, &output.content));
291 }
292 if output.done {
293 early_done = true;
294 }
295 }
296 Err(e) => {
297 let err_msg = format!("Tool error: {}", e);
298 step_outputs.push(err_msg.clone());
299 messages.push(Message::tool(&id, &err_msg));
300 agent.after_action(ctx, &name, &err_msg);
301 on_event(LoopEvent::ToolResult {
302 name,
303 output: err_msg,
304 });
305 }
306 }
307 }
308 if early_done && rw_calls.is_empty() {
309 ctx.state = AgentState::Completed;
311 on_event(LoopEvent::Completed { steps: step });
312 return Ok(step);
313 }
314 }
315
316 for tc in &rw_calls {
318 if let Some(tool) = tools.get(&tc.name) {
319 match tool.execute(tc.arguments.clone(), ctx).await {
320 Ok(output) => {
321 on_event(LoopEvent::ToolResult {
322 name: tc.name.clone(),
323 output: output.content.clone(),
324 });
325 step_outputs.push(output.content.clone());
326 agent.after_action(ctx, &tc.name, &output.content);
327 if output.waiting {
328 ctx.state = AgentState::WaitingInput;
329 on_event(LoopEvent::WaitingForInput {
330 question: output.content.clone(),
331 tool_call_id: tc.id.clone(),
332 });
333 messages.push(Message::tool(&tc.id, "[waiting for user input]"));
334 ctx.state = AgentState::Running;
335 } else {
336 messages.push(Message::tool(&tc.id, &output.content));
337 }
338 if output.done {
339 ctx.state = AgentState::Completed;
340 on_event(LoopEvent::Completed { steps: step });
341 return Ok(step);
342 }
343 }
344 Err(e) => {
345 let err_msg = format!("Tool error: {}", e);
346 step_outputs.push(err_msg.clone());
347 messages.push(Message::tool(&tc.id, &err_msg));
348 agent.after_action(ctx, &tc.name, &err_msg);
349 on_event(LoopEvent::ToolResult {
350 name: tc.name.clone(),
351 output: err_msg,
352 });
353 }
354 }
355 } else {
356 let err_msg = format!("Unknown tool: {}", tc.name);
357 step_outputs.push(err_msg.clone());
358 messages.push(Message::tool(&tc.id, &err_msg));
359 on_event(LoopEvent::ToolResult {
360 name: tc.name.clone(),
361 output: err_msg,
362 });
363 }
364 }
365
366 if detector.check_outputs(&step_outputs) {
368 ctx.state = AgentState::Failed;
369 on_event(LoopEvent::LoopDetected {
370 count: detector.output_repeat_count,
371 });
372 return Err(AgentError::LoopDetected(detector.output_repeat_count));
373 }
374 }
375
376 ctx.state = AgentState::Failed;
377 Err(AgentError::MaxSteps(config.max_steps))
378}
379
380pub async fn run_loop_interactive<F, Fut>(
388 agent: &dyn Agent,
389 tools: &ToolRegistry,
390 ctx: &mut AgentContext,
391 messages: &mut Vec<Message>,
392 config: &LoopConfig,
393 mut on_event: impl FnMut(LoopEvent),
394 mut on_input: F,
395) -> Result<usize, AgentError>
396where
397 F: FnMut(String) -> Fut,
398 Fut: std::future::Future<Output = String>,
399{
400 let mut detector = LoopDetector::new(config.loop_abort_threshold);
401 let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
402 let mut parse_retries: usize = 0;
403 let mut response_id: Option<String> = None;
404
405 for step in 1..=config.max_steps {
406 if config.max_messages > 0 && messages.len() > config.max_messages {
407 trim_messages(messages, config.max_messages);
408 }
409 ctx.iteration = step;
410 on_event(LoopEvent::StepStart { step });
411
412 agent.prepare_context(ctx, messages);
413
414 let active_tool_names = agent.prepare_tools(ctx, tools);
415 let filtered_tools = if active_tool_names.len() == tools.list().len() {
416 None
417 } else {
418 Some(active_tool_names)
419 };
420 let effective_tools = if let Some(ref names) = filtered_tools {
421 &tools.filter(names)
422 } else {
423 tools
424 };
425
426 let decision = match decide_with_retry(
427 agent,
428 messages,
429 effective_tools,
430 response_id.as_deref(),
431 )
432 .await
433 {
434 Ok((d, new_rid)) => {
435 parse_retries = 0;
436 response_id = new_rid;
437 d
438 }
439 Err(e) if is_recoverable_error(&e) => {
440 parse_retries += 1;
441 if parse_retries > MAX_PARSE_RETRIES {
442 return Err(e);
443 }
444 let err_msg = format!(
445 "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
446 parse_retries, MAX_PARSE_RETRIES, e
447 );
448 on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
449 err_msg.clone(),
450 ))));
451 messages.push(Message::user(&err_msg));
452 continue;
453 }
454 Err(e) => return Err(e),
455 };
456 on_event(LoopEvent::Decision(decision.clone()));
457
458 if completion_detector.check(&decision) {
459 ctx.state = AgentState::Completed;
460 if !decision.situation.is_empty() {
461 messages.push(Message::assistant(&decision.situation));
462 }
463 on_event(LoopEvent::Completed { steps: step });
464 return Ok(step);
465 }
466
467 if decision.completed || decision.tool_calls.is_empty() {
468 ctx.state = AgentState::Completed;
469 if !decision.situation.is_empty() {
470 messages.push(Message::assistant(&decision.situation));
471 }
472 on_event(LoopEvent::Completed { steps: step });
473 return Ok(step);
474 }
475
476 let sig: Vec<String> = decision
477 .tool_calls
478 .iter()
479 .map(|tc| tc.name.clone())
480 .collect();
481 match detector.check(&sig) {
482 LoopCheckResult::Abort => {
483 ctx.state = AgentState::Failed;
484 on_event(LoopEvent::LoopDetected {
485 count: detector.consecutive,
486 });
487 return Err(AgentError::LoopDetected(detector.consecutive));
488 }
489 LoopCheckResult::Tier2Warning(dominant_tool) => {
490 let hint = format!(
491 "LOOP WARNING: You are repeatedly using '{}' without making progress. \
492 Try a different approach: re-read the file with read_file to see current contents, \
493 use write_file instead of edit_file, or break the problem into smaller steps.",
494 dominant_tool
495 );
496 messages.push(Message::system(&hint));
497 }
498 LoopCheckResult::Ok => {}
499 }
500
501 messages.push(Message::assistant_with_tool_calls(
503 &decision.situation,
504 decision.tool_calls.clone(),
505 ));
506
507 let mut step_outputs: Vec<String> = Vec::new();
508 let mut early_done = false;
509
510 let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
512 .tool_calls
513 .iter()
514 .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
515
516 if !ro_calls.is_empty() {
518 let futs: Vec<_> = ro_calls
519 .iter()
520 .map(|tc| {
521 let tool = tools.get(&tc.name).unwrap();
522 let args = tc.arguments.clone();
523 let name = tc.name.clone();
524 let id = tc.id.clone();
525 async move { (id, name, tool.execute_readonly(args).await) }
526 })
527 .collect();
528
529 for (id, name, result) in join_all(futs).await {
530 match result {
531 Ok(output) => {
532 on_event(LoopEvent::ToolResult {
533 name: name.clone(),
534 output: output.content.clone(),
535 });
536 step_outputs.push(output.content.clone());
537 agent.after_action(ctx, &name, &output.content);
538 if output.waiting {
539 ctx.state = AgentState::WaitingInput;
540 on_event(LoopEvent::WaitingForInput {
541 question: output.content.clone(),
542 tool_call_id: id.clone(),
543 });
544 let response = on_input(output.content).await;
545 ctx.state = AgentState::Running;
546 messages.push(Message::tool(&id, &response));
547 } else {
548 messages.push(Message::tool(&id, &output.content));
549 }
550 if output.done {
551 early_done = true;
552 }
553 }
554 Err(e) => {
555 let err_msg = format!("Tool error: {}", e);
556 step_outputs.push(err_msg.clone());
557 messages.push(Message::tool(&id, &err_msg));
558 agent.after_action(ctx, &name, &err_msg);
559 on_event(LoopEvent::ToolResult {
560 name,
561 output: err_msg,
562 });
563 }
564 }
565 }
566 if early_done && rw_calls.is_empty() {
567 ctx.state = AgentState::Completed;
569 on_event(LoopEvent::Completed { steps: step });
570 return Ok(step);
571 }
572 }
573
574 for tc in &rw_calls {
576 if let Some(tool) = tools.get(&tc.name) {
577 match tool.execute(tc.arguments.clone(), ctx).await {
578 Ok(output) => {
579 on_event(LoopEvent::ToolResult {
580 name: tc.name.clone(),
581 output: output.content.clone(),
582 });
583 step_outputs.push(output.content.clone());
584 agent.after_action(ctx, &tc.name, &output.content);
585 if output.waiting {
586 ctx.state = AgentState::WaitingInput;
587 on_event(LoopEvent::WaitingForInput {
588 question: output.content.clone(),
589 tool_call_id: tc.id.clone(),
590 });
591 let response = on_input(output.content.clone()).await;
592 ctx.state = AgentState::Running;
593 messages.push(Message::tool(&tc.id, &response));
594 } else {
595 messages.push(Message::tool(&tc.id, &output.content));
596 }
597 if output.done {
598 ctx.state = AgentState::Completed;
599 on_event(LoopEvent::Completed { steps: step });
600 return Ok(step);
601 }
602 }
603 Err(e) => {
604 let err_msg = format!("Tool error: {}", e);
605 step_outputs.push(err_msg.clone());
606 messages.push(Message::tool(&tc.id, &err_msg));
607 agent.after_action(ctx, &tc.name, &err_msg);
608 on_event(LoopEvent::ToolResult {
609 name: tc.name.clone(),
610 output: err_msg,
611 });
612 }
613 }
614 } else {
615 let err_msg = format!("Unknown tool: {}", tc.name);
616 step_outputs.push(err_msg.clone());
617 messages.push(Message::tool(&tc.id, &err_msg));
618 on_event(LoopEvent::ToolResult {
619 name: tc.name.clone(),
620 output: err_msg,
621 });
622 }
623 }
624
625 if detector.check_outputs(&step_outputs) {
626 ctx.state = AgentState::Failed;
627 on_event(LoopEvent::LoopDetected {
628 count: detector.output_repeat_count,
629 });
630 return Err(AgentError::LoopDetected(detector.output_repeat_count));
631 }
632 }
633
634 ctx.state = AgentState::Failed;
635 Err(AgentError::MaxSteps(config.max_steps))
636}
637
/// Verdict of [`LoopDetector::check`] for one step.
#[derive(Debug, PartialEq)]
enum LoopCheckResult {
    /// No loop suspected.
    Ok,
    /// One tool (named) dominates the run; the caller should nudge the model. This is
    /// issued at most once per detector.
    Tier2Warning(String),
    /// Give up: the same signature repeated `threshold` times in a row, or a tool
    /// still dominates after the tier-2 warning.
    Abort,
}

/// Detects an agent that is stuck repeating itself.
///
/// It tracks three signals:
/// - **Tier 1:** the same tool-name signature for `threshold` consecutive steps.
/// - **Tier 2:** after at least `threshold` steps, one tool is used in at least
///   `threshold` steps and in more than 90% of all steps. The first hit warns; a
///   later hit aborts.
/// - **Output stagnation:** identical step outputs for `threshold` consecutive steps
///   (see [`LoopDetector::check_outputs`]).
struct LoopDetector {
    /// Repetition count that triggers an abort.
    threshold: usize,
    /// Number of consecutive steps that produced `last_sig`.
    consecutive: usize,
    /// Tool-name signature of the previous step.
    last_sig: Vec<String>,
    /// Number of steps in which each tool was used (each tool counted once per step).
    tool_freq: HashMap<String, usize>,
    /// Number of steps seen by `check`.
    total_calls: usize,
    /// Hash of the previous step's outputs; `None` before the first step.
    last_output_hash: Option<u64>,
    /// Number of consecutive steps whose outputs hashed to `last_output_hash`.
    output_repeat_count: usize,
    /// Whether the single tier-2 warning has already been issued.
    tier2_warned: bool,
}

impl LoopDetector {
    fn new(threshold: usize) -> Self {
        Self {
            threshold,
            consecutive: 0,
            last_sig: vec![],
            tool_freq: HashMap::new(),
            total_calls: 0,
            last_output_hash: None,
            output_repeat_count: 0,
            tier2_warned: false,
        }
    }

    /// Feeds one step's tool-name signature (in call order) and returns the verdict.
    fn check(&mut self, sig: &[String]) -> LoopCheckResult {
        self.total_calls += 1;

        if sig == self.last_sig {
            self.consecutive += 1;
        } else {
            self.consecutive = 1;
            self.last_sig = sig.to_vec();
        }
        if self.consecutive >= self.threshold {
            return LoopCheckResult::Abort;
        }

        // Count each tool once per step so the ratio below really is "fraction of
        // steps". Otherwise a single wide parallel batch could look like a loop.
        let mut distinct: Vec<&String> = sig.iter().collect();
        distinct.sort_unstable();
        distinct.dedup();
        for name in distinct {
            *self.tool_freq.entry(name.clone()).or_default() += 1;
        }

        if self.total_calls >= self.threshold {
            // The criterion is monotone in the count, so testing the most frequent tool
            // is enough. Ties are broken by name to keep the reported tool deterministic
            // (HashMap iteration order is not).
            let dominant = self
                .tool_freq
                .iter()
                .max_by(|(a_name, a_count), (b_name, b_count)| {
                    a_count.cmp(b_count).then_with(|| b_name.cmp(a_name))
                });
            if let Some((name, &count)) = dominant {
                if count >= self.threshold && count as f64 / self.total_calls as f64 > 0.9 {
                    if self.tier2_warned {
                        return LoopCheckResult::Abort;
                    }
                    self.tier2_warned = true;
                    return LoopCheckResult::Tier2Warning(name.clone());
                }
            }
        }

        LoopCheckResult::Ok
    }

    /// Feeds one step's tool outputs.
    ///
    /// Returns `true` once identical outputs have been seen for `threshold`
    /// consecutive steps.
    fn check_outputs(&mut self, outputs: &[String]) -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        outputs.hash(&mut hasher);
        let hash = hasher.finish();

        if self.last_output_hash == Some(hash) {
            self.output_repeat_count += 1;
        } else {
            self.output_repeat_count = 1;
            self.last_output_hash = Some(hash);
        }

        self.output_repeat_count >= self.threshold
    }
}
736
737struct CompletionDetector {
743 threshold: usize,
744 last_situation: String,
745 repeat_count: usize,
746}
747
748const COMPLETION_KEYWORDS: &[&str] = &[
750 "task is complete",
751 "task is done",
752 "task is finished",
753 "all done",
754 "successfully completed",
755 "nothing more",
756 "no further action",
757 "no more steps",
758];
759
760impl CompletionDetector {
761 fn new(threshold: usize) -> Self {
762 Self {
763 threshold: threshold.max(2),
764 last_situation: String::new(),
765 repeat_count: 0,
766 }
767 }
768
769 fn check(&mut self, decision: &Decision) -> bool {
771 if decision.completed || decision.tool_calls.is_empty() {
773 return false;
774 }
775
776 let sit_lower = decision.situation.to_lowercase();
778 for keyword in COMPLETION_KEYWORDS {
779 if sit_lower.contains(keyword) {
780 return true;
781 }
782 }
783
784 if !decision.situation.is_empty() && decision.situation == self.last_situation {
786 self.repeat_count += 1;
787 } else {
788 self.repeat_count = 1;
789 self.last_situation = decision.situation.clone();
790 }
791
792 self.repeat_count >= self.threshold
793 }
794}
795
796fn trim_messages(messages: &mut Vec<Message>, max: usize) {
799 use crate::types::Role;
800
801 if messages.len() <= max || max < 4 {
802 return;
803 }
804 let keep_start = 2; let remove_count = messages.len() - max + 1;
806 let mut trim_end = keep_start + remove_count;
807
808 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
814 trim_end += 1;
815 }
816 if trim_end > keep_start && trim_end < messages.len() {
823 let last_removed = trim_end - 1;
824 if messages[last_removed].role == Role::Assistant
825 && !messages[last_removed].tool_calls.is_empty()
826 {
827 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
830 trim_end += 1;
831 }
832 }
833 }
834
835 let removed_range = keep_start..trim_end;
836
837 let summary = format!(
838 "[{} messages trimmed from context to stay within {} message limit]",
839 trim_end - keep_start,
840 max
841 );
842
843 messages.drain(removed_range);
844 messages.insert(keep_start, Message::system(&summary));
845}
846
847#[cfg(test)]
848mod tests {
849 use super::*;
850 use crate::agent::{Agent, AgentError, Decision};
851 use crate::agent_tool::{Tool, ToolError, ToolOutput};
852 use crate::context::AgentContext;
853 use crate::registry::ToolRegistry;
854 use crate::types::{Message, SgrError, ToolCall};
855 use serde_json::Value;
856 use std::sync::Arc;
857 use std::sync::atomic::{AtomicUsize, Ordering};
858
859 struct CountingAgent {
860 max_calls: usize,
861 call_count: Arc<AtomicUsize>,
862 }
863
864 #[async_trait::async_trait]
865 impl Agent for CountingAgent {
866 async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
867 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
868 if n >= self.max_calls {
869 Ok(Decision {
870 situation: "done".into(),
871 task: vec![],
872 tool_calls: vec![],
873 completed: true,
874 })
875 } else {
876 Ok(Decision {
877 situation: format!("step {}", n),
878 task: vec![],
879 tool_calls: vec![ToolCall {
880 id: format!("call_{}", n),
881 name: "echo".into(),
882 arguments: serde_json::json!({"msg": "hi"}),
883 }],
884 completed: false,
885 })
886 }
887 }
888 }
889
890 struct EchoTool;
891
892 #[async_trait::async_trait]
893 impl Tool for EchoTool {
894 fn name(&self) -> &str {
895 "echo"
896 }
897 fn description(&self) -> &str {
898 "echo"
899 }
900 fn parameters_schema(&self) -> Value {
901 serde_json::json!({"type": "object"})
902 }
903 async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
904 Ok(ToolOutput::text("echoed"))
905 }
906 }
907
908 #[tokio::test]
909 async fn loop_runs_and_completes() {
910 let agent = CountingAgent {
911 max_calls: 3,
912 call_count: Arc::new(AtomicUsize::new(0)),
913 };
914 let tools = ToolRegistry::new().register(EchoTool);
915 let mut ctx = AgentContext::new();
916 let mut messages = vec![Message::user("go")];
917 let config = LoopConfig::default();
918
919 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
920 .await
921 .unwrap();
922 assert_eq!(steps, 4); assert_eq!(ctx.state, AgentState::Completed);
924 }
925
926 #[tokio::test]
927 async fn loop_detects_repetition() {
928 struct LoopingAgent;
930 #[async_trait::async_trait]
931 impl Agent for LoopingAgent {
932 async fn decide(
933 &self,
934 _: &[Message],
935 _: &ToolRegistry,
936 ) -> Result<Decision, AgentError> {
937 Ok(Decision {
938 situation: "stuck".into(),
939 task: vec![],
940 tool_calls: vec![ToolCall {
941 id: "1".into(),
942 name: "echo".into(),
943 arguments: serde_json::json!({}),
944 }],
945 completed: false,
946 })
947 }
948 }
949
950 let tools = ToolRegistry::new().register(EchoTool);
951 let mut ctx = AgentContext::new();
952 let mut messages = vec![Message::user("go")];
953 let config = LoopConfig {
954 max_steps: 50,
955 loop_abort_threshold: 3,
956 auto_complete_threshold: 100, ..Default::default()
958 };
959
960 let result = run_loop(
961 &LoopingAgent,
962 &tools,
963 &mut ctx,
964 &mut messages,
965 &config,
966 |_| {},
967 )
968 .await;
969 assert!(matches!(result, Err(AgentError::LoopDetected(3))));
970 assert_eq!(ctx.state, AgentState::Failed);
971 }
972
973 #[tokio::test]
974 async fn loop_max_steps() {
975 struct NeverDoneAgent;
977 #[async_trait::async_trait]
978 impl Agent for NeverDoneAgent {
979 async fn decide(
980 &self,
981 _: &[Message],
982 _: &ToolRegistry,
983 ) -> Result<Decision, AgentError> {
984 static COUNTER: AtomicUsize = AtomicUsize::new(0);
986 let n = COUNTER.fetch_add(1, Ordering::SeqCst);
987 Ok(Decision {
988 situation: String::new(),
989 task: vec![],
990 tool_calls: vec![ToolCall {
991 id: format!("{}", n),
992 name: format!("tool_{}", n),
993 arguments: serde_json::json!({}),
994 }],
995 completed: false,
996 })
997 }
998 }
999
1000 let tools = ToolRegistry::new().register(EchoTool);
1001 let mut ctx = AgentContext::new();
1002 let mut messages = vec![Message::user("go")];
1003 let config = LoopConfig {
1004 max_steps: 5,
1005 loop_abort_threshold: 100,
1006 ..Default::default()
1007 };
1008
1009 let result = run_loop(
1010 &NeverDoneAgent,
1011 &tools,
1012 &mut ctx,
1013 &mut messages,
1014 &config,
1015 |_| {},
1016 )
1017 .await;
1018 assert!(matches!(result, Err(AgentError::MaxSteps(5))));
1019 }
1020
1021 #[test]
1022 fn loop_detector_exact_sig() {
1023 let mut d = LoopDetector::new(3);
1024 let sig = vec!["bash".to_string()];
1025 assert_eq!(d.check(&sig), LoopCheckResult::Ok);
1026 assert_eq!(d.check(&sig), LoopCheckResult::Ok);
1027 assert_eq!(d.check(&sig), LoopCheckResult::Abort); }
1029
1030 #[test]
1031 fn loop_detector_different_sigs_reset() {
1032 let mut d = LoopDetector::new(3);
1033 assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1034 assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1035 assert_eq!(d.check(&["read".into()]), LoopCheckResult::Ok); assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1037 }
1038
1039 #[test]
1040 fn loop_detector_tier2_warning_then_abort() {
1041 let mut d = LoopDetector::new(3);
1044 assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); assert_eq!(
1050 d.check(&["edit_file".into(), "read_file".into()]),
1051 LoopCheckResult::Tier2Warning("edit_file".into())
1052 );
1053 assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Abort);
1055 }
1056
1057 #[test]
1058 fn loop_config_default() {
1059 let c = LoopConfig::default();
1060 assert_eq!(c.max_steps, 50);
1061 assert_eq!(c.loop_abort_threshold, 6);
1062 }
1063
1064 #[test]
1065 fn loop_detector_output_stagnation() {
1066 let mut d = LoopDetector::new(3);
1067 let outputs = vec!["same result".to_string()];
1068 assert!(!d.check_outputs(&outputs));
1069 assert!(!d.check_outputs(&outputs));
1070 assert!(d.check_outputs(&outputs)); }
1072
1073 #[test]
1074 fn completion_detector_keyword() {
1075 let mut cd = CompletionDetector::new(3);
1076 let d = Decision {
1077 situation: "The task is complete, all files written.".into(),
1078 task: vec![],
1079 tool_calls: vec![ToolCall {
1080 id: "1".into(),
1081 name: "echo".into(),
1082 arguments: serde_json::json!({}),
1083 }],
1084 completed: false,
1085 };
1086 assert!(cd.check(&d));
1087 }
1088
1089 #[test]
1090 fn completion_detector_repeated_situation() {
1091 let mut cd = CompletionDetector::new(3);
1092 let d = Decision {
1093 situation: "working on it".into(),
1094 task: vec![],
1095 tool_calls: vec![ToolCall {
1096 id: "1".into(),
1097 name: "echo".into(),
1098 arguments: serde_json::json!({}),
1099 }],
1100 completed: false,
1101 };
1102 assert!(!cd.check(&d));
1103 assert!(!cd.check(&d));
1104 assert!(cd.check(&d)); }
1106
1107 #[test]
1108 fn completion_detector_ignores_explicit_completion() {
1109 let mut cd = CompletionDetector::new(2);
1110 let d = Decision {
1111 situation: "task is complete".into(),
1112 task: vec![],
1113 tool_calls: vec![],
1114 completed: true,
1115 };
1116 assert!(!cd.check(&d));
1118 }
1119
1120 #[test]
1121 fn trim_messages_basic() {
1122 let mut msgs: Vec<Message> = (0..10).map(|i| Message::user(format!("msg {i}"))).collect();
1123 trim_messages(&mut msgs, 6);
1124 assert_eq!(msgs.len(), 6);
1126 assert!(msgs[2].content.contains("trimmed"));
1127 }
1128
1129 #[test]
1130 fn trim_messages_no_op_when_under_limit() {
1131 let mut msgs = vec![Message::user("a"), Message::user("b")];
1132 trim_messages(&mut msgs, 10);
1133 assert_eq!(msgs.len(), 2);
1134 }
1135
1136 #[test]
1137 fn trim_messages_preserves_assistant_tool_call_pair() {
1138 use crate::types::Role;
1139 let mut msgs = vec![
1141 Message::system("sys"),
1142 Message::user("prompt"),
1143 Message::assistant_with_tool_calls(
1144 "calling",
1145 vec![
1146 ToolCall {
1147 id: "c1".into(),
1148 name: "read".into(),
1149 arguments: serde_json::json!({}),
1150 },
1151 ToolCall {
1152 id: "c2".into(),
1153 name: "read".into(),
1154 arguments: serde_json::json!({}),
1155 },
1156 ],
1157 ),
1158 Message::tool("c1", "result1"),
1159 Message::tool("c2", "result2"),
1160 Message::user("next"),
1161 Message::assistant("done"),
1162 ];
1163 trim_messages(&mut msgs, 5);
1165 for (i, msg) in msgs.iter().enumerate() {
1167 if msg.role == Role::Tool {
1168 assert!(i > 0, "Tool message at start");
1170 assert!(
1171 msgs[i - 1].role == Role::Assistant && !msgs[i - 1].tool_calls.is_empty()
1172 || msgs[i - 1].role == Role::Tool,
1173 "Orphaned Tool at position {i}"
1174 );
1175 }
1176 }
1177 }
1178
1179 #[test]
1180 fn loop_detector_output_stagnation_resets_on_change() {
1181 let mut d = LoopDetector::new(3);
1182 let a = vec!["result A".to_string()];
1183 let b = vec!["result B".to_string()];
1184 assert!(!d.check_outputs(&a));
1185 assert!(!d.check_outputs(&a));
1186 assert!(!d.check_outputs(&b)); assert!(!d.check_outputs(&a));
1188 }
1189
1190 #[tokio::test]
1191 async fn loop_handles_non_recoverable_llm_error() {
1192 struct FailingAgent;
1193 #[async_trait::async_trait]
1194 impl Agent for FailingAgent {
1195 async fn decide(
1196 &self,
1197 _: &[Message],
1198 _: &ToolRegistry,
1199 ) -> Result<Decision, AgentError> {
1200 Err(AgentError::Llm(SgrError::Api {
1201 status: 500,
1202 body: "internal server error".into(),
1203 }))
1204 }
1205 }
1206
1207 let tools = ToolRegistry::new().register(EchoTool);
1208 let mut ctx = AgentContext::new();
1209 let mut messages = vec![Message::user("go")];
1210 let config = LoopConfig::default();
1211
1212 let result = run_loop(
1213 &FailingAgent,
1214 &tools,
1215 &mut ctx,
1216 &mut messages,
1217 &config,
1218 |_| {},
1219 )
1220 .await;
1221 assert!(result.is_err());
1223 assert_eq!(messages.len(), 1); }
1225
1226 #[tokio::test]
1227 async fn loop_recovers_from_parse_error() {
1228 struct ParseRetryAgent {
1230 call_count: Arc<AtomicUsize>,
1231 }
1232 #[async_trait::async_trait]
1233 impl Agent for ParseRetryAgent {
1234 async fn decide(
1235 &self,
1236 msgs: &[Message],
1237 _: &ToolRegistry,
1238 ) -> Result<Decision, AgentError> {
1239 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1240 if n == 0 {
1241 Err(AgentError::Llm(SgrError::Schema(
1243 "Missing required field: situation".into(),
1244 )))
1245 } else {
1246 let last = msgs.last().unwrap();
1248 assert!(
1249 last.content.contains("Parse error"),
1250 "expected parse error feedback, got: {}",
1251 last.content
1252 );
1253 Ok(Decision {
1254 situation: "recovered from parse error".into(),
1255 task: vec![],
1256 tool_calls: vec![],
1257 completed: true,
1258 })
1259 }
1260 }
1261 }
1262
1263 let tools = ToolRegistry::new().register(EchoTool);
1264 let mut ctx = AgentContext::new();
1265 let mut messages = vec![Message::user("go")];
1266 let config = LoopConfig::default();
1267 let agent = ParseRetryAgent {
1268 call_count: Arc::new(AtomicUsize::new(0)),
1269 };
1270
1271 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1272 .await
1273 .unwrap();
1274 assert_eq!(steps, 2); assert_eq!(ctx.state, AgentState::Completed);
1276 }
1277
1278 #[tokio::test]
1279 async fn loop_aborts_after_max_parse_retries() {
1280 struct AlwaysFailParseAgent;
1281 #[async_trait::async_trait]
1282 impl Agent for AlwaysFailParseAgent {
1283 async fn decide(
1284 &self,
1285 _: &[Message],
1286 _: &ToolRegistry,
1287 ) -> Result<Decision, AgentError> {
1288 Err(AgentError::Llm(SgrError::Schema("bad json".into())))
1289 }
1290 }
1291
1292 let tools = ToolRegistry::new().register(EchoTool);
1293 let mut ctx = AgentContext::new();
1294 let mut messages = vec![Message::user("go")];
1295 let config = LoopConfig::default();
1296
1297 let result = run_loop(
1298 &AlwaysFailParseAgent,
1299 &tools,
1300 &mut ctx,
1301 &mut messages,
1302 &config,
1303 |_| {},
1304 )
1305 .await;
1306 assert!(result.is_err());
1307 let feedback_count = messages
1309 .iter()
1310 .filter(|m| m.content.contains("Parse error"))
1311 .count();
1312 assert_eq!(feedback_count, MAX_PARSE_RETRIES);
1313 }
1314
1315 #[tokio::test]
1316 async fn loop_feeds_tool_errors_back() {
1317 struct ErrorRecoveryAgent {
1319 call_count: Arc<AtomicUsize>,
1320 }
1321 #[async_trait::async_trait]
1322 impl Agent for ErrorRecoveryAgent {
1323 async fn decide(
1324 &self,
1325 msgs: &[Message],
1326 _: &ToolRegistry,
1327 ) -> Result<Decision, AgentError> {
1328 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1329 if n == 0 {
1330 Ok(Decision {
1332 situation: "trying".into(),
1333 task: vec![],
1334 tool_calls: vec![ToolCall {
1335 id: "1".into(),
1336 name: "nonexistent_tool".into(),
1337 arguments: serde_json::json!({}),
1338 }],
1339 completed: false,
1340 })
1341 } else {
1342 let last = msgs.last().unwrap();
1344 assert!(last.content.contains("Unknown tool"));
1345 Ok(Decision {
1346 situation: "recovered".into(),
1347 task: vec![],
1348 tool_calls: vec![],
1349 completed: true,
1350 })
1351 }
1352 }
1353 }
1354
1355 let tools = ToolRegistry::new().register(EchoTool);
1356 let mut ctx = AgentContext::new();
1357 let mut messages = vec![Message::user("go")];
1358 let config = LoopConfig::default();
1359 let agent = ErrorRecoveryAgent {
1360 call_count: Arc::new(AtomicUsize::new(0)),
1361 };
1362
1363 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1364 .await
1365 .unwrap();
1366 assert_eq!(steps, 2);
1367 assert_eq!(ctx.state, AgentState::Completed);
1368 }
1369
1370 #[tokio::test]
1371 async fn parallel_readonly_tools() {
1372 struct ReadOnlyTool {
1373 name: &'static str,
1374 }
1375
1376 #[async_trait::async_trait]
1377 impl Tool for ReadOnlyTool {
1378 fn name(&self) -> &str {
1379 self.name
1380 }
1381 fn description(&self) -> &str {
1382 "read-only tool"
1383 }
1384 fn is_read_only(&self) -> bool {
1385 true
1386 }
1387 fn parameters_schema(&self) -> Value {
1388 serde_json::json!({"type": "object"})
1389 }
1390 async fn execute(
1391 &self,
1392 _: Value,
1393 _: &mut AgentContext,
1394 ) -> Result<ToolOutput, ToolError> {
1395 Ok(ToolOutput::text(format!("{} result", self.name)))
1396 }
1397 async fn execute_readonly(&self, _: Value) -> Result<ToolOutput, ToolError> {
1398 Ok(ToolOutput::text(format!("{} result", self.name)))
1399 }
1400 }
1401
1402 struct ParallelAgent;
1403 #[async_trait::async_trait]
1404 impl Agent for ParallelAgent {
1405 async fn decide(
1406 &self,
1407 msgs: &[Message],
1408 _: &ToolRegistry,
1409 ) -> Result<Decision, AgentError> {
1410 if msgs.len() > 3 {
1411 return Ok(Decision {
1412 situation: "done".into(),
1413 task: vec![],
1414 tool_calls: vec![],
1415 completed: true,
1416 });
1417 }
1418 Ok(Decision {
1419 situation: "reading".into(),
1420 task: vec![],
1421 tool_calls: vec![
1422 ToolCall {
1423 id: "1".into(),
1424 name: "reader_a".into(),
1425 arguments: serde_json::json!({}),
1426 },
1427 ToolCall {
1428 id: "2".into(),
1429 name: "reader_b".into(),
1430 arguments: serde_json::json!({}),
1431 },
1432 ],
1433 completed: false,
1434 })
1435 }
1436 }
1437
1438 let tools = ToolRegistry::new()
1439 .register(ReadOnlyTool { name: "reader_a" })
1440 .register(ReadOnlyTool { name: "reader_b" });
1441 let mut ctx = AgentContext::new();
1442 let mut messages = vec![Message::user("read stuff")];
1443 let config = LoopConfig::default();
1444
1445 let steps = run_loop(
1446 &ParallelAgent,
1447 &tools,
1448 &mut ctx,
1449 &mut messages,
1450 &config,
1451 |_| {},
1452 )
1453 .await
1454 .unwrap();
1455 assert!(steps > 0);
1456 assert_eq!(ctx.state, AgentState::Completed);
1457 }
1458
1459 #[tokio::test]
1460 async fn loop_events_are_emitted() {
1461 let agent = CountingAgent {
1462 max_calls: 1,
1463 call_count: Arc::new(AtomicUsize::new(0)),
1464 };
1465 let tools = ToolRegistry::new().register(EchoTool);
1466 let mut ctx = AgentContext::new();
1467 let mut messages = vec![Message::user("go")];
1468 let config = LoopConfig::default();
1469
1470 let mut events = Vec::new();
1471 run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |e| {
1472 events.push(format!("{:?}", std::mem::discriminant(&e)));
1473 })
1474 .await
1475 .unwrap();
1476
1477 assert!(events.len() >= 4);
1479 }
1480
1481 #[tokio::test]
1482 async fn tool_output_done_stops_loop() {
1483 struct DoneTool;
1485 #[async_trait::async_trait]
1486 impl Tool for DoneTool {
1487 fn name(&self) -> &str {
1488 "done_tool"
1489 }
1490 fn description(&self) -> &str {
1491 "returns done"
1492 }
1493 fn parameters_schema(&self) -> Value {
1494 serde_json::json!({"type": "object"})
1495 }
1496 async fn execute(
1497 &self,
1498 _: Value,
1499 _: &mut AgentContext,
1500 ) -> Result<ToolOutput, ToolError> {
1501 Ok(ToolOutput::done("final answer"))
1502 }
1503 }
1504
1505 struct OneShotAgent;
1506 #[async_trait::async_trait]
1507 impl Agent for OneShotAgent {
1508 async fn decide(
1509 &self,
1510 _: &[Message],
1511 _: &ToolRegistry,
1512 ) -> Result<Decision, AgentError> {
1513 Ok(Decision {
1514 situation: "calling done tool".into(),
1515 task: vec![],
1516 tool_calls: vec![ToolCall {
1517 id: "1".into(),
1518 name: "done_tool".into(),
1519 arguments: serde_json::json!({}),
1520 }],
1521 completed: false,
1522 })
1523 }
1524 }
1525
1526 let tools = ToolRegistry::new().register(DoneTool);
1527 let mut ctx = AgentContext::new();
1528 let mut messages = vec![Message::user("go")];
1529 let config = LoopConfig::default();
1530
1531 let steps = run_loop(
1532 &OneShotAgent,
1533 &tools,
1534 &mut ctx,
1535 &mut messages,
1536 &config,
1537 |_| {},
1538 )
1539 .await
1540 .unwrap();
1541 assert_eq!(
1542 steps, 1,
1543 "Loop should stop on first step when tool returns done"
1544 );
1545 assert_eq!(ctx.state, AgentState::Completed);
1546 }
1547
1548 #[tokio::test]
1549 async fn tool_messages_formatted_correctly() {
1550 let agent = CountingAgent {
1553 max_calls: 1,
1554 call_count: Arc::new(AtomicUsize::new(0)),
1555 };
1556 let tools = ToolRegistry::new().register(EchoTool);
1557 let mut ctx = AgentContext::new();
1558 let mut messages = vec![Message::user("go")];
1559 let config = LoopConfig::default();
1560
1561 run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1562 .await
1563 .unwrap();
1564
1565 assert!(messages.len() >= 4);
1568
1569 let assistant_tc = messages
1571 .iter()
1572 .find(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty());
1573 assert!(
1574 assistant_tc.is_some(),
1575 "Should have an assistant message with tool_calls"
1576 );
1577 let atc = assistant_tc.unwrap();
1578 assert_eq!(atc.tool_calls[0].name, "echo");
1579 assert_eq!(atc.tool_calls[0].id, "call_0");
1580
1581 let tc_idx = messages
1583 .iter()
1584 .position(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty())
1585 .unwrap();
1586 let tool_msg = &messages[tc_idx + 1];
1587 assert_eq!(tool_msg.role, crate::types::Role::Tool);
1588 assert_eq!(tool_msg.tool_call_id.as_deref(), Some("call_0"));
1589 assert_eq!(tool_msg.content, "echoed");
1590 }
1591}