1use crate::contracts::runtime::{StopPolicy, StopPolicyInput, StopPolicyStats};
40use crate::contracts::thread::ToolCall;
41use crate::contracts::RunContext;
42use crate::contracts::StopConditionSpec;
43pub use crate::contracts::StopReason;
44use std::collections::VecDeque;
45use std::sync::Arc;
46use std::time::Duration;
47
/// Snapshot of run statistics evaluated by the built-in stop conditions in this module.
///
/// Every field is either a plain value or a shared borrow with lifetime `'a`, so the
/// context can be converted to and from a [`StopPolicyInput`] without copying data.
pub struct StopCheckContext<'a> {
    /// Current round number; exposed as `StopPolicyStats::step` and compared by [`MaxRounds`].
    pub rounds: usize,
    /// Input tokens counted so far; summed with `total_output_tokens` by [`TokenBudget`].
    pub total_input_tokens: usize,
    /// Output tokens counted so far; summed with `total_input_tokens` by [`TokenBudget`].
    pub total_output_tokens: usize,
    /// Consecutive-error counter; compared by [`ConsecutiveErrors`].
    pub consecutive_errors: usize,
    /// Elapsed run time; compared by [`Timeout`].
    pub elapsed: Duration,
    /// Tool calls made in the most recent round; inspected by [`StopOnTool`].
    pub last_tool_calls: &'a [ToolCall],
    /// Text produced in the most recent round; searched by [`ContentMatch`].
    pub last_text: &'a str,
    /// One entry per round (presumably tool names), newest at the back —
    /// [`LoopDetection`] reads it back-to-front. NOTE(review): confirm ordering with producer.
    pub tool_call_history: &'a VecDeque<Vec<String>>,
    /// Run context; exposed to policies as `StopPolicyInput::agent_state`.
    pub run_ctx: &'a RunContext,
}
80
81impl<'a> StopCheckContext<'a> {
82 pub fn as_policy_input(&'a self) -> StopPolicyInput<'a> {
84 StopPolicyInput {
85 agent_state: self.run_ctx,
86 stats: StopPolicyStats {
87 step: self.rounds,
88 step_tool_call_count: self.last_tool_calls.len(),
89 total_tool_call_count: self.tool_call_history.iter().map(std::vec::Vec::len).sum(),
90 total_input_tokens: self.total_input_tokens,
91 total_output_tokens: self.total_output_tokens,
92 consecutive_errors: self.consecutive_errors,
93 elapsed: self.elapsed,
94 last_tool_calls: self.last_tool_calls,
95 last_text: self.last_text,
96 tool_call_history: self.tool_call_history,
97 },
98 }
99 }
100}
101
102fn stop_check_context_from_policy_input<'a>(
107 input: &'a StopPolicyInput<'a>,
108) -> StopCheckContext<'a> {
109 StopCheckContext {
110 rounds: input.stats.step,
111 total_input_tokens: input.stats.total_input_tokens,
112 total_output_tokens: input.stats.total_output_tokens,
113 consecutive_errors: input.stats.consecutive_errors,
114 elapsed: input.stats.elapsed,
115 last_tool_calls: input.stats.last_tool_calls,
116 last_text: input.stats.last_text,
117 tool_call_history: input.stats.tool_call_history,
118 run_ctx: input.agent_state,
119 }
120}
121
/// Implements [`StopPolicy`] for a built-in condition type that provides an inherent
/// `fn check(&self, &StopCheckContext<'_>) -> Option<StopReason>`.
///
/// `$policy` is the implementing type; `$policy_id` is the string literal returned by `id()`.
macro_rules! impl_stop_policy_via_check {
    ($policy:ty, $policy_id:literal) => {
        impl StopPolicy for $policy {
            fn id(&self) -> &str {
                $policy_id
            }

            fn evaluate(&self, input: &StopPolicyInput<'_>) -> Option<StopReason> {
                // Adapt the generic policy input back into the local context type.
                self.check(&stop_check_context_from_policy_input(input))
            }
        }
    };
}
136
137pub(crate) fn check_stop_policies(
139 conditions: &[Arc<dyn StopPolicy>],
140 input: &StopPolicyInput<'_>,
141) -> Option<StopReason> {
142 for condition in conditions {
143 if let Some(reason) = StopPolicy::evaluate(condition.as_ref(), input) {
144 return Some(reason);
145 }
146 }
147 None
148}
149
150pub struct MaxRounds(pub usize);
156
157impl MaxRounds {
158 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
159 if ctx.rounds >= self.0 {
160 Some(StopReason::MaxRoundsReached)
161 } else {
162 None
163 }
164 }
165}
166
167impl_stop_policy_via_check!(MaxRounds, "max_rounds");
168
169pub struct Timeout(pub Duration);
171
172impl Timeout {
173 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
174 if ctx.elapsed >= self.0 {
175 Some(StopReason::TimeoutReached)
176 } else {
177 None
178 }
179 }
180}
181
182impl_stop_policy_via_check!(Timeout, "timeout");
183
184pub struct TokenBudget {
186 pub max_total: usize,
188}
189
190impl TokenBudget {
191 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
192 if self.max_total > 0
193 && (ctx.total_input_tokens + ctx.total_output_tokens) >= self.max_total
194 {
195 Some(StopReason::TokenBudgetExceeded)
196 } else {
197 None
198 }
199 }
200}
201
202impl_stop_policy_via_check!(TokenBudget, "token_budget");
203
204pub struct ConsecutiveErrors(pub usize);
206
207impl ConsecutiveErrors {
208 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
209 if self.0 > 0 && ctx.consecutive_errors >= self.0 {
210 Some(StopReason::ConsecutiveErrorsExceeded)
211 } else {
212 None
213 }
214 }
215}
216
217impl_stop_policy_via_check!(ConsecutiveErrors, "consecutive_errors");
218
219pub struct StopOnTool(pub String);
221
222impl StopOnTool {
223 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
224 for call in ctx.last_tool_calls {
225 if call.name == self.0 {
226 return Some(StopReason::ToolCalled(self.0.clone()));
227 }
228 }
229 None
230 }
231}
232
233impl_stop_policy_via_check!(StopOnTool, "stop_on_tool");
234
235pub struct ContentMatch(pub String);
237
238impl ContentMatch {
239 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
240 if !self.0.is_empty() && ctx.last_text.contains(&self.0) {
241 Some(StopReason::ContentMatched(self.0.clone()))
242 } else {
243 None
244 }
245 }
246}
247
248impl_stop_policy_via_check!(ContentMatch, "content_match");
249
250pub struct LoopDetection {
256 pub window: usize,
258}
259
260impl LoopDetection {
261 fn check(&self, ctx: &StopCheckContext<'_>) -> Option<StopReason> {
262 let window = self.window.max(2);
263 let history = ctx.tool_call_history;
264 if history.len() < 2 {
265 return None;
266 }
267
268 let recent: Vec<_> = history.iter().rev().take(window).collect();
270 for pair in recent.windows(2) {
271 if pair[0] == pair[1] {
272 return Some(StopReason::LoopDetected);
273 }
274 }
275 None
276 }
277}
278
279impl_stop_policy_via_check!(LoopDetection, "loop_detection");
280
281pub(crate) fn condition_from_spec(spec: StopConditionSpec) -> Arc<dyn StopPolicy> {
287 match spec {
288 StopConditionSpec::MaxRounds { rounds } => Arc::new(MaxRounds(rounds)),
289 StopConditionSpec::Timeout { seconds } => Arc::new(Timeout(Duration::from_secs(seconds))),
290 StopConditionSpec::TokenBudget { max_total } => Arc::new(TokenBudget { max_total }),
291 StopConditionSpec::ConsecutiveErrors { max } => Arc::new(ConsecutiveErrors(max)),
292 StopConditionSpec::StopOnTool { tool_name } => Arc::new(StopOnTool(tool_name)),
293 StopConditionSpec::ContentMatch { pattern } => Arc::new(ContentMatch(pattern)),
294 StopConditionSpec::LoopDetection { window } => Arc::new(LoopDetection { window }),
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::LazyLock;
    use tirea_contract::RunConfig;

    /// Shared run context for tests; policies in this module never read from it.
    static TEST_RUN_CTX: LazyLock<RunContext> =
        LazyLock::new(|| RunContext::new("test", json!({}), vec![], RunConfig::default()));

    /// A context with all counters at zero, no tool calls, empty text and empty history.
    /// Tests override individual fields via struct-update syntax.
    fn empty_context() -> StopCheckContext<'static> {
        static EMPTY_TOOL_CALLS: &[ToolCall] = &[];
        static EMPTY_HISTORY: VecDeque<Vec<String>> = VecDeque::new();
        StopCheckContext {
            rounds: 0,
            total_input_tokens: 0,
            total_output_tokens: 0,
            consecutive_errors: 0,
            elapsed: Duration::ZERO,
            last_tool_calls: EMPTY_TOOL_CALLS,
            last_text: "",
            tool_call_history: &EMPTY_HISTORY,
            run_ctx: &TEST_RUN_CTX,
        }
    }

    fn make_tool_call(name: &str) -> ToolCall {
        ToolCall {
            id: "tc-1".to_string(),
            name: name.to_string(),
            arguments: json!({}),
        }
    }

    // ---- MaxRounds ----

    #[test]
    fn max_rounds_none_when_under_limit() {
        let cond = MaxRounds(3);
        let mut ctx = empty_context();
        ctx.rounds = 2;
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn max_rounds_triggers_at_limit() {
        let cond = MaxRounds(3);
        let mut ctx = empty_context();
        ctx.rounds = 3;
        assert_eq!(cond.check(&ctx), Some(StopReason::MaxRoundsReached));
    }

    #[test]
    fn max_rounds_triggers_above_limit() {
        let cond = MaxRounds(3);
        let mut ctx = empty_context();
        ctx.rounds = 5;
        assert_eq!(cond.check(&ctx), Some(StopReason::MaxRoundsReached));
    }

    // ---- Timeout ----

    #[test]
    fn timeout_none_when_under() {
        let cond = Timeout(Duration::from_secs(10));
        let mut ctx = empty_context();
        ctx.elapsed = Duration::from_secs(5);
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn timeout_triggers_at_limit() {
        let cond = Timeout(Duration::from_secs(10));
        let mut ctx = empty_context();
        ctx.elapsed = Duration::from_secs(10);
        assert_eq!(cond.check(&ctx), Some(StopReason::TimeoutReached));
    }

    // ---- TokenBudget ----

    #[test]
    fn token_budget_none_when_under() {
        let cond = TokenBudget { max_total: 1000 };
        let mut ctx = empty_context();
        ctx.total_input_tokens = 400;
        ctx.total_output_tokens = 500;
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn token_budget_triggers_at_limit() {
        let cond = TokenBudget { max_total: 1000 };
        let mut ctx = empty_context();
        ctx.total_input_tokens = 600;
        ctx.total_output_tokens = 400;
        assert_eq!(cond.check(&ctx), Some(StopReason::TokenBudgetExceeded));
    }

    #[test]
    fn token_budget_zero_means_unlimited() {
        let cond = TokenBudget { max_total: 0 };
        let mut ctx = empty_context();
        ctx.total_input_tokens = 999_999;
        ctx.total_output_tokens = 999_999;
        assert!(cond.check(&ctx).is_none());
    }

    // ---- ConsecutiveErrors ----

    #[test]
    fn consecutive_errors_none_when_under() {
        let cond = ConsecutiveErrors(3);
        let mut ctx = empty_context();
        ctx.consecutive_errors = 2;
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn consecutive_errors_triggers_at_limit() {
        let cond = ConsecutiveErrors(3);
        let mut ctx = empty_context();
        ctx.consecutive_errors = 3;
        assert_eq!(
            cond.check(&ctx),
            Some(StopReason::ConsecutiveErrorsExceeded)
        );
    }

    #[test]
    fn consecutive_errors_zero_means_disabled() {
        let cond = ConsecutiveErrors(0);
        let mut ctx = empty_context();
        ctx.consecutive_errors = 100;
        assert!(cond.check(&ctx).is_none());
    }

    // ---- StopOnTool ----

    #[test]
    fn stop_on_tool_none_when_not_called() {
        let cond = StopOnTool("finish".to_string());
        let calls = vec![make_tool_call("search")];
        let ctx = StopCheckContext {
            last_tool_calls: &calls,
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn stop_on_tool_triggers_when_called() {
        let cond = StopOnTool("finish".to_string());
        let calls = vec![make_tool_call("search"), make_tool_call("finish")];
        let ctx = StopCheckContext {
            last_tool_calls: &calls,
            ..empty_context()
        };
        assert_eq!(
            cond.check(&ctx),
            Some(StopReason::ToolCalled("finish".to_string()))
        );
    }

    // ---- ContentMatch ----

    #[test]
    fn content_match_none_when_absent() {
        let cond = ContentMatch("FINAL_ANSWER".to_string());
        let ctx = StopCheckContext {
            last_text: "Here is some text",
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn content_match_triggers_when_present() {
        let cond = ContentMatch("FINAL_ANSWER".to_string());
        let ctx = StopCheckContext {
            last_text: "The result is: FINAL_ANSWER: 42",
            ..empty_context()
        };
        assert_eq!(
            cond.check(&ctx),
            Some(StopReason::ContentMatched("FINAL_ANSWER".to_string()))
        );
    }

    #[test]
    fn content_match_empty_pattern_never_triggers() {
        let cond = ContentMatch(String::new());
        let ctx = StopCheckContext {
            last_text: "anything",
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    // ---- LoopDetection ----

    #[test]
    fn loop_detection_none_when_insufficient_history() {
        let cond = LoopDetection { window: 3 };
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        let ctx = StopCheckContext {
            tool_call_history: &history,
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn loop_detection_none_when_different_patterns() {
        let cond = LoopDetection { window: 3 };
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        history.push_back(vec!["calculate".to_string()]);
        history.push_back(vec!["write".to_string()]);
        let ctx = StopCheckContext {
            tool_call_history: &history,
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn loop_detection_triggers_on_consecutive_duplicate() {
        let cond = LoopDetection { window: 3 };
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        history.push_back(vec!["calculate".to_string()]);
        history.push_back(vec!["calculate".to_string()]);
        let ctx = StopCheckContext {
            tool_call_history: &history,
            ..empty_context()
        };
        assert_eq!(cond.check(&ctx), Some(StopReason::LoopDetected));
    }

    #[test]
    fn loop_detection_ignores_duplicates_outside_window() {
        // Only the newest two entries ("write", "search") are inspected.
        let cond = LoopDetection { window: 2 };
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        history.push_back(vec!["search".to_string()]);
        history.push_back(vec!["write".to_string()]);
        let ctx = StopCheckContext {
            tool_call_history: &history,
            ..empty_context()
        };
        assert!(cond.check(&ctx).is_none());
    }

    #[test]
    fn loop_detection_window_below_two_is_clamped() {
        let cond = LoopDetection { window: 0 };
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        history.push_back(vec!["search".to_string()]);
        let ctx = StopCheckContext {
            tool_call_history: &history,
            ..empty_context()
        };
        assert_eq!(cond.check(&ctx), Some(StopReason::LoopDetected));
    }

    // ---- check_stop_policies ----

    #[test]
    fn check_stop_conditions_returns_first_match() {
        let conditions: Vec<Arc<dyn StopPolicy>> = vec![
            Arc::new(MaxRounds(5)),
            Arc::new(Timeout(Duration::from_secs(10))),
        ];
        let mut ctx = empty_context();
        ctx.rounds = 5;
        ctx.elapsed = Duration::from_secs(15);
        // Both conditions would fire; the first one in the list wins.
        assert_eq!(
            check_stop_policies(&conditions, &ctx.as_policy_input()),
            Some(StopReason::MaxRoundsReached)
        );
    }

    #[test]
    fn check_stop_conditions_returns_none_when_all_pass() {
        let conditions: Vec<Arc<dyn StopPolicy>> = vec![
            Arc::new(MaxRounds(10)),
            Arc::new(Timeout(Duration::from_secs(60))),
        ];
        let mut ctx = empty_context();
        ctx.rounds = 3;
        ctx.elapsed = Duration::from_secs(5);
        assert!(check_stop_policies(&conditions, &ctx.as_policy_input()).is_none());
    }

    #[test]
    fn check_stop_conditions_empty_always_none() {
        let conditions: Vec<Arc<dyn StopPolicy>> = vec![];
        let ctx = empty_context();
        assert!(check_stop_policies(&conditions, &ctx.as_policy_input()).is_none());
    }

    // ---- Policy input conversion ----

    #[test]
    fn as_policy_input_reports_tool_call_counts() {
        let calls = vec![make_tool_call("finish")];
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string(), "fetch".to_string()]);
        history.push_back(vec!["finish".to_string()]);
        let ctx = StopCheckContext {
            rounds: 2,
            last_tool_calls: &calls,
            tool_call_history: &history,
            ..empty_context()
        };
        let input = ctx.as_policy_input();
        assert_eq!(input.stats.step, 2);
        assert_eq!(input.stats.step_tool_call_count, 1);
        assert_eq!(input.stats.total_tool_call_count, 3);
    }

    #[test]
    fn policy_input_roundtrips_to_check_context() {
        let calls = vec![make_tool_call("search")];
        let mut history = VecDeque::new();
        history.push_back(vec!["search".to_string()]);
        let ctx = StopCheckContext {
            rounds: 4,
            total_input_tokens: 10,
            total_output_tokens: 20,
            consecutive_errors: 1,
            elapsed: Duration::from_secs(7),
            last_tool_calls: &calls,
            last_text: "hello",
            tool_call_history: &history,
            run_ctx: &TEST_RUN_CTX,
        };
        let input = ctx.as_policy_input();
        let back = stop_check_context_from_policy_input(&input);
        assert_eq!(back.rounds, 4);
        assert_eq!(back.total_input_tokens, 10);
        assert_eq!(back.total_output_tokens, 20);
        assert_eq!(back.consecutive_errors, 1);
        assert_eq!(back.elapsed, Duration::from_secs(7));
        assert_eq!(back.last_text, "hello");
        assert_eq!(back.last_tool_calls.len(), 1);
        assert_eq!(back.last_tool_calls[0].name, "search");
        assert!(std::ptr::eq(back.tool_call_history, &history));
        assert!(std::ptr::eq(back.run_ctx, &*TEST_RUN_CTX));
    }

    // ---- Serialization ----

    #[test]
    fn stop_reason_serialization_roundtrip() {
        let reasons = vec![
            StopReason::MaxRoundsReached,
            StopReason::TimeoutReached,
            StopReason::TokenBudgetExceeded,
            StopReason::ToolCalled("finish".to_string()),
            StopReason::ContentMatched("DONE".to_string()),
            StopReason::ConsecutiveErrorsExceeded,
            StopReason::LoopDetected,
            StopReason::Custom("my_reason".to_string()),
        ];
        for reason in reasons {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(reason, back);
        }
    }

    // ---- Custom policies ----

    struct AlwaysStop;

    impl StopPolicy for AlwaysStop {
        fn id(&self) -> &str {
            "always_stop"
        }
        fn evaluate(&self, _input: &StopPolicyInput<'_>) -> Option<StopReason> {
            Some(StopReason::Custom("always".to_string()))
        }
    }

    #[test]
    fn custom_stop_policy_works() {
        let conditions: Vec<Arc<dyn StopPolicy>> = vec![Arc::new(AlwaysStop)];
        let ctx = empty_context();
        assert_eq!(
            check_stop_policies(&conditions, &ctx.as_policy_input()),
            Some(StopReason::Custom("always".to_string()))
        );
    }

    // ---- StopConditionSpec ----

    #[test]
    fn stop_condition_spec_serialization_roundtrip() {
        let specs = vec![
            StopConditionSpec::MaxRounds { rounds: 5 },
            StopConditionSpec::Timeout { seconds: 30 },
            StopConditionSpec::TokenBudget { max_total: 1000 },
            StopConditionSpec::ConsecutiveErrors { max: 3 },
            StopConditionSpec::StopOnTool {
                tool_name: "finish".to_string(),
            },
            StopConditionSpec::ContentMatch {
                pattern: "DONE".to_string(),
            },
            StopConditionSpec::LoopDetection { window: 4 },
        ];
        for spec in specs {
            let json = serde_json::to_string(&spec).unwrap();
            let back: StopConditionSpec = serde_json::from_str(&json).unwrap();
            assert_eq!(spec, back);
        }
    }

    #[test]
    fn stop_condition_spec_json_format() {
        let spec = StopConditionSpec::MaxRounds { rounds: 5 };
        let json = serde_json::to_string(&spec).unwrap();
        assert_eq!(json, r#"{"type":"max_rounds","rounds":5}"#);

        let spec = StopConditionSpec::StopOnTool {
            tool_name: "done".to_string(),
        };
        let json = serde_json::to_string(&spec).unwrap();
        assert_eq!(json, r#"{"type":"stop_on_tool","tool_name":"done"}"#);
    }

    #[test]
    fn condition_from_spec_assigns_expected_ids() {
        let cases = vec![
            (StopConditionSpec::MaxRounds { rounds: 1 }, "max_rounds"),
            (StopConditionSpec::Timeout { seconds: 1 }, "timeout"),
            (StopConditionSpec::TokenBudget { max_total: 1 }, "token_budget"),
            (StopConditionSpec::ConsecutiveErrors { max: 1 }, "consecutive_errors"),
            (
                StopConditionSpec::StopOnTool {
                    tool_name: "finish".to_string(),
                },
                "stop_on_tool",
            ),
            (
                StopConditionSpec::ContentMatch {
                    pattern: "DONE".to_string(),
                },
                "content_match",
            ),
            (StopConditionSpec::LoopDetection { window: 2 }, "loop_detection"),
        ];
        for (spec, expected_id) in cases {
            assert_eq!(condition_from_spec(spec).id(), expected_id);
        }
    }

    #[test]
    fn stop_condition_spec_into_condition_max_rounds() {
        let spec = StopConditionSpec::MaxRounds { rounds: 3 };
        let cond = condition_from_spec(spec);
        assert_eq!(cond.id(), "max_rounds");
        let mut ctx = empty_context();
        ctx.rounds = 3;
        assert_eq!(
            cond.evaluate(&ctx.as_policy_input()),
            Some(StopReason::MaxRoundsReached)
        );
    }

    #[test]
    fn stop_condition_spec_into_condition_timeout() {
        let spec = StopConditionSpec::Timeout { seconds: 10 };
        let cond = condition_from_spec(spec);
        assert_eq!(cond.id(), "timeout");
        let mut ctx = empty_context();
        ctx.elapsed = Duration::from_secs(10);
        assert_eq!(
            cond.evaluate(&ctx.as_policy_input()),
            Some(StopReason::TimeoutReached)
        );
    }

    #[test]
    fn stop_condition_spec_into_condition_token_budget() {
        let spec = StopConditionSpec::TokenBudget { max_total: 100 };
        let cond = condition_from_spec(spec);
        let mut ctx = empty_context();
        ctx.total_input_tokens = 60;
        ctx.total_output_tokens = 50;
        assert_eq!(
            cond.evaluate(&ctx.as_policy_input()),
            Some(StopReason::TokenBudgetExceeded)
        );
    }

    #[test]
    fn stop_condition_spec_into_condition_stop_on_tool() {
        let spec = StopConditionSpec::StopOnTool {
            tool_name: "finish".to_string(),
        };
        let cond = condition_from_spec(spec);
        let calls = vec![make_tool_call("finish")];
        let ctx = StopCheckContext {
            last_tool_calls: &calls,
            ..empty_context()
        };
        assert_eq!(
            cond.evaluate(&ctx.as_policy_input()),
            Some(StopReason::ToolCalled("finish".to_string()))
        );
    }
}