1#![forbid(unsafe_code)]
19
20use std::panic::AssertUnwindSafe;
21use std::path::Path;
22use std::sync::Arc;
23
24use tracing::{debug, warn};
25
26use crate::types::{
27 AgentMessage, AssistantMessage, Cost, ModelSpec, StopReason, ToolResultMessage, Usage,
28};
29
30#[derive(Debug)]
36pub enum PolicyVerdict {
37 Continue,
39 Stop(String),
41 Inject(Vec<AgentMessage>),
43}
44
45#[derive(Debug)]
50pub enum PreDispatchVerdict {
51 Continue,
53 Stop(String),
55 Inject(Vec<AgentMessage>),
57 Skip(String),
59}
60
61#[derive(Debug)]
65pub struct PolicyContext<'a> {
66 pub turn_index: usize,
68 pub accumulated_usage: &'a Usage,
70 pub accumulated_cost: &'a Cost,
72 pub message_count: usize,
74 pub overflow_signal: bool,
76 pub new_messages: &'a [AgentMessage],
84 pub state: &'a crate::SessionState,
86}
87
88pub struct ToolDispatchContext<'a> {
96 pub tool_name: &'a str,
98 pub tool_call_id: &'a str,
100 pub arguments: &'a mut serde_json::Value,
102 pub execution_root: Option<&'a Path>,
104 pub state: &'a crate::SessionState,
106}
107
108impl std::fmt::Debug for ToolDispatchContext<'_> {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.debug_struct("ToolDispatchContext")
111 .field("tool_name", &self.tool_name)
112 .field("tool_call_id", &self.tool_call_id)
113 .field("execution_root", &self.execution_root)
114 .field("arguments", &"<redacted>")
115 .finish()
116 }
117}
118
119#[derive(Debug)]
121pub struct TurnPolicyContext<'a> {
122 pub assistant_message: &'a AssistantMessage,
124 pub tool_results: &'a [ToolResultMessage],
126 pub stop_reason: StopReason,
128 pub system_prompt: &'a str,
130 pub model_spec: &'a ModelSpec,
132 pub context_messages: &'a [AgentMessage],
138}
139
140pub trait PreTurnPolicy: Send + Sync {
150 fn name(&self) -> &str;
152 fn evaluate(&self, ctx: &PolicyContext<'_>) -> PolicyVerdict;
154}
155
156pub trait PreDispatchPolicy: Send + Sync {
161 fn name(&self) -> &str;
163 fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict;
165}
166
167pub trait PostTurnPolicy: Send + Sync {
171 fn name(&self) -> &str;
173 fn evaluate(&self, ctx: &PolicyContext<'_>, turn: &TurnPolicyContext<'_>) -> PolicyVerdict;
175}
176
177pub trait PostLoopPolicy: Send + Sync {
181 fn name(&self) -> &str;
183 fn evaluate(&self, ctx: &PolicyContext<'_>) -> PolicyVerdict;
185}
186
187pub fn run_policies(policies: &[Arc<dyn PreTurnPolicy>], ctx: &PolicyContext<'_>) -> PolicyVerdict {
195 run_policies_inner(policies.iter().map(std::convert::AsRef::as_ref), ctx)
196}
197
198pub fn run_post_turn_policies(
200 policies: &[Arc<dyn PostTurnPolicy>],
201 ctx: &PolicyContext<'_>,
202 turn: &TurnPolicyContext<'_>,
203) -> PolicyVerdict {
204 let mut injections: Vec<AgentMessage> = Vec::new();
205
206 for policy in policies {
207 let policy_name = policy.name().to_string();
208 let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx, turn)));
209
210 match result {
211 Ok(PolicyVerdict::Continue) => {}
212 Ok(PolicyVerdict::Stop(reason)) => {
213 debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
214 return PolicyVerdict::Stop(reason);
215 }
216 Ok(PolicyVerdict::Inject(msgs)) => {
217 injections.extend(msgs);
218 }
219 Err(_) => {
220 warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
221 }
222 }
223 }
224
225 if injections.is_empty() {
226 PolicyVerdict::Continue
227 } else {
228 PolicyVerdict::Inject(injections)
229 }
230}
231
232pub fn run_post_loop_policies(
234 policies: &[Arc<dyn PostLoopPolicy>],
235 ctx: &PolicyContext<'_>,
236) -> PolicyVerdict {
237 let mut injections: Vec<AgentMessage> = Vec::new();
238
239 for policy in policies {
240 let policy_name = policy.name().to_string();
241 let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
242
243 match result {
244 Ok(PolicyVerdict::Continue) => {}
245 Ok(PolicyVerdict::Stop(reason)) => {
246 debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
247 return PolicyVerdict::Stop(reason);
248 }
249 Ok(PolicyVerdict::Inject(msgs)) => {
250 injections.extend(msgs);
251 }
252 Err(_) => {
253 warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
254 }
255 }
256 }
257
258 if injections.is_empty() {
259 PolicyVerdict::Continue
260 } else {
261 PolicyVerdict::Inject(injections)
262 }
263}
264
265fn run_policies_inner<'a>(
267 policies: impl Iterator<Item = &'a dyn PreTurnPolicy>,
268 ctx: &PolicyContext<'_>,
269) -> PolicyVerdict {
270 let mut injections: Vec<AgentMessage> = Vec::new();
271
272 for policy in policies {
273 let policy_name = policy.name().to_string();
274 let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
275
276 match result {
277 Ok(PolicyVerdict::Continue) => {}
278 Ok(PolicyVerdict::Stop(reason)) => {
279 debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
280 return PolicyVerdict::Stop(reason);
281 }
282 Ok(PolicyVerdict::Inject(msgs)) => {
283 injections.extend(msgs);
284 }
285 Err(_) => {
286 warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
287 }
288 }
289 }
290
291 if injections.is_empty() {
292 PolicyVerdict::Continue
293 } else {
294 PolicyVerdict::Inject(injections)
295 }
296}
297
298pub fn run_pre_dispatch_policies(
305 policies: &[Arc<dyn PreDispatchPolicy>],
306 ctx: &mut ToolDispatchContext<'_>,
307) -> PreDispatchVerdict {
308 let mut injections: Vec<AgentMessage> = Vec::new();
309
310 for policy in policies {
311 let policy_name = policy.name().to_string();
312 let argument_snapshot = ctx.arguments.clone();
313 let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
314
315 match result {
316 Ok(PreDispatchVerdict::Continue) => {}
317 Ok(PreDispatchVerdict::Stop(reason)) => {
318 debug!(policy = %policy_name, reason = %reason, "policy stopped loop (pre-dispatch)");
319 return PreDispatchVerdict::Stop(reason);
320 }
321 Ok(PreDispatchVerdict::Skip(error_text)) => {
322 debug!(policy = %policy_name, "policy skipped tool call");
323 return PreDispatchVerdict::Skip(error_text);
324 }
325 Ok(PreDispatchVerdict::Inject(msgs)) => {
326 injections.extend(msgs);
327 }
328 Err(_) => {
329 *ctx.arguments = argument_snapshot;
330 warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
331 }
332 }
333 }
334
335 if injections.is_empty() {
336 PreDispatchVerdict::Continue
337 } else {
338 PreDispatchVerdict::Inject(injections)
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ─── Test doubles ───────────────────────────────────────────────────────────────────────

    /// Pre-turn policy that returns whatever its closure builds and counts its evaluations.
    struct TestPolicy {
        policy_name: String,
        make_verdict: Box<dyn Fn() -> PolicyVerdict + Send + Sync>,
        call_count: AtomicUsize,
    }

    impl TestPolicy {
        fn new(name: &str, make: impl Fn() -> PolicyVerdict + Send + Sync + 'static) -> Self {
            Self {
                policy_name: name.to_string(),
                make_verdict: Box::new(make),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `evaluate` has been invoked.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl PreTurnPolicy for TestPolicy {
        fn name(&self) -> &str {
            &self.policy_name
        }
        fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            (self.make_verdict)()
        }
    }

    /// Pre-turn policy that always panics.
    struct PanickingPolicy;
    impl PreTurnPolicy for PanickingPolicy {
        fn name(&self) -> &'static str {
            "panicker"
        }
        fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
            panic!("policy intentionally panicked");
        }
    }

    /// Pre-dispatch policy that returns whatever its closure builds and counts its evaluations.
    struct TestPreDispatchPolicy {
        policy_name: String,
        make_verdict: Box<dyn Fn() -> PreDispatchVerdict + Send + Sync>,
        call_count: AtomicUsize,
    }

    impl TestPreDispatchPolicy {
        fn new(name: &str, make: impl Fn() -> PreDispatchVerdict + Send + Sync + 'static) -> Self {
            Self {
                policy_name: name.to_string(),
                make_verdict: Box::new(make),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `evaluate` has been invoked.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl PreDispatchPolicy for TestPreDispatchPolicy {
        fn name(&self) -> &str {
            &self.policy_name
        }
        fn evaluate(&self, _ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            (self.make_verdict)()
        }
    }

    /// Pre-dispatch policy that always panics without touching the arguments.
    struct PanickingPreDispatchPolicy;
    impl PreDispatchPolicy for PanickingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "panicker"
        }
        fn evaluate(&self, _ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            panic!("pre-dispatch policy panicked");
        }
    }

    /// Pre-dispatch policy that adds `"injected": "by_policy"` to object arguments.
    struct MutatingPreDispatchPolicy;
    impl PreDispatchPolicy for MutatingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "mutator"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if let Some(obj) = ctx.arguments.as_object_mut() {
                obj.insert("injected".to_string(), serde_json::json!("by_policy"));
            }
            PreDispatchVerdict::Continue
        }
    }

    /// Pre-dispatch policy that edits the arguments and then panics, to exercise the rollback.
    struct MutateThenPanicPreDispatchPolicy;
    impl PreDispatchPolicy for MutateThenPanicPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "mutate_then_panic"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if let Some(obj) = ctx.arguments.as_object_mut() {
                obj.insert("tainted".to_string(), serde_json::json!(true));
            }
            panic!("pre-dispatch policy panicked after mutating arguments");
        }
    }

    /// Pre-dispatch policy that skips the call unless `expected_key` is present in the arguments.
    struct VerifyingPreDispatchPolicy {
        expected_key: String,
    }
    impl PreDispatchPolicy for VerifyingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "verifier"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if ctx.arguments.get(&self.expected_key).is_some() {
                PreDispatchVerdict::Continue
            } else {
                PreDispatchVerdict::Skip(format!("missing key: {}", self.expected_key))
            }
        }
    }

    // ─── Fixtures ───────────────────────────────────────────────────────────────────────────

    /// Minimal user message used as an injection payload.
    fn test_message() -> AgentMessage {
        AgentMessage::Llm(crate::types::LlmMessage::User(crate::types::UserMessage {
            content: vec![],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Default usage and cost values for `make_ctx` to borrow.
    fn test_context() -> (Usage, Cost) {
        (Usage::default(), Cost::default())
    }

    /// Builds a `PolicyContext` with fixed scalar fields over the borrowed inputs.
    fn make_ctx<'a>(
        usage: &'a Usage,
        cost: &'a Cost,
        state: &'a crate::SessionState,
    ) -> PolicyContext<'a> {
        PolicyContext {
            turn_index: 0,
            accumulated_usage: usage,
            accumulated_cost: cost,
            message_count: 5,
            overflow_signal: false,
            new_messages: &[],
            state,
        }
    }

    /// Builds a `ToolDispatchContext` for a fixed tool name and call id.
    fn make_dispatch_ctx<'a>(
        args: &'a mut serde_json::Value,
        state: &'a crate::SessionState,
    ) -> ToolDispatchContext<'a> {
        ToolDispatchContext {
            tool_name: "test_tool",
            tool_call_id: "id1",
            arguments: args,
            execution_root: None,
            state,
        }
    }

    // ─── Type-level checks ──────────────────────────────────────────────────────────────────

    #[test]
    fn policy_verdict_debug() {
        let v = PolicyVerdict::Continue;
        assert!(format!("{v:?}").contains("Continue"));

        let v = PolicyVerdict::Stop("budget exceeded".to_string());
        assert!(format!("{v:?}").contains("budget exceeded"));

        let v = PolicyVerdict::Inject(vec![]);
        assert!(format!("{v:?}").contains("Inject"));
    }

    #[test]
    fn pre_dispatch_verdict_debug() {
        let v = PreDispatchVerdict::Skip("denied".to_string());
        assert!(format!("{v:?}").contains("denied"));

        let v = PreDispatchVerdict::Stop("halt".to_string());
        assert!(format!("{v:?}").contains("halt"));
    }

    #[test]
    fn policy_context_construction() {
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        assert_eq!(ctx.turn_index, 0);
        assert_eq!(ctx.message_count, 5);
        assert!(!ctx.overflow_signal);
    }

    // ─── run_policies ───────────────────────────────────────────────────────────────────────

    #[test]
    fn empty_vec_returns_continue() {
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
    }

    #[test]
    fn single_continue() {
        let p = Arc::new(TestPolicy::new("a", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
        assert_eq!(p.calls(), 1);
    }

    #[test]
    fn single_stop_short_circuits() {
        let p1 = Arc::new(TestPolicy::new("stopper", || {
            PolicyVerdict::Stop("done".into())
        }));
        let p2 = Arc::new(TestPolicy::new("never_called", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1.clone(), p2.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Stop(ref r) if r == "done"));
        assert_eq!(p1.calls(), 1);
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn inject_accumulates_across_policies() {
        let p1 = Arc::new(TestPolicy::new("a", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPolicy::new("b", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        match result {
            PolicyVerdict::Inject(msgs) => assert_eq!(msgs.len(), 2),
            _ => panic!("expected Inject"),
        }
    }

    #[test]
    fn stop_after_inject_returns_stop() {
        let p1 = Arc::new(TestPolicy::new("injector", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPolicy::new("stopper", || {
            PolicyVerdict::Stop("halt".into())
        }));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Stop(ref r) if r == "halt"));
    }

    #[test]
    fn panic_caught_returns_continue() {
        let p1: Arc<dyn PreTurnPolicy> = Arc::new(PanickingPolicy);
        let p2 = Arc::new(TestPolicy::new("after_panic", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
        // The policy after the panicking one must still have been evaluated.
        assert_eq!(p2.calls(), 1);
    }

    // ─── run_pre_dispatch_policies ──────────────────────────────────────────────────────────

    #[test]
    fn pre_dispatch_empty_vec_returns_continue() {
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
    }

    #[test]
    fn pre_dispatch_skip_short_circuits() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("skipper", || {
            PreDispatchVerdict::Skip("denied".into())
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("never", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1.clone(), p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Skip(ref e) if e == "denied"));
        assert_eq!(p1.calls(), 1);
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn pre_dispatch_stop_short_circuits() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("stopper", || {
            PreDispatchVerdict::Stop("halt".into())
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("never", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Stop(ref r) if r == "halt"));
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn pre_dispatch_inject_accumulates() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("a", || {
            PreDispatchVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("b", || {
            PreDispatchVerdict::Inject(vec![test_message()])
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        match result {
            PreDispatchVerdict::Inject(msgs) => assert_eq!(msgs.len(), 2),
            _ => panic!("expected Inject"),
        }
    }

    #[test]
    fn pre_dispatch_panic_caught_returns_continue() {
        let p1: Arc<dyn PreDispatchPolicy> = Arc::new(PanickingPreDispatchPolicy);
        let p2 = Arc::new(TestPreDispatchPolicy::new("after", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
        assert_eq!(p2.calls(), 1);
    }

    #[test]
    fn pre_dispatch_panic_rolls_back_only_the_panicking_policy() {
        // The well-behaved mutator runs first, so its edit must survive. The panicking policy's
        // edit must be undone, and the verifier proves that evaluation carried on afterwards.
        let mutator: Arc<dyn PreDispatchPolicy> = Arc::new(MutatingPreDispatchPolicy);
        let faulty: Arc<dyn PreDispatchPolicy> = Arc::new(MutateThenPanicPreDispatchPolicy);
        let verifier: Arc<dyn PreDispatchPolicy> = Arc::new(VerifyingPreDispatchPolicy {
            expected_key: "injected".to_string(),
        });
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![mutator, faulty, verifier];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"original": "value"});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
        assert_eq!(
            args,
            serde_json::json!({"original": "value", "injected": "by_policy"})
        );
    }

    #[test]
    fn argument_mutation_visible_to_next_policy() {
        let mutator: Arc<dyn PreDispatchPolicy> = Arc::new(MutatingPreDispatchPolicy);
        let verifier: Arc<dyn PreDispatchPolicy> = Arc::new(VerifyingPreDispatchPolicy {
            expected_key: "injected".to_string(),
        });
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![mutator, verifier];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"original": "value"});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
        // The mutation is also visible to the caller once the run has finished.
        assert_eq!(args["injected"], "by_policy");
    }

    #[test]
    fn tool_dispatch_context_contains_only_reliable_fields() {
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"path": "/tmp/file"});
        let ctx = ToolDispatchContext {
            tool_name: "write_file",
            tool_call_id: "call-123",
            arguments: &mut args,
            execution_root: None,
            state: &state,
        };
        assert_eq!(ctx.tool_name, "write_file");
        assert_eq!(ctx.tool_call_id, "call-123");
        assert_eq!(ctx.arguments["path"], "/tmp/file");
        let debug_str = format!("{ctx:?}");
        assert!(debug_str.contains("write_file"));
        assert!(
            !debug_str.contains("/tmp/file"),
            "arguments must be redacted in Debug"
        );
    }
}