Skip to main content

swink_agent/
policy.rs

1//! Configurable policy slots for the agent loop.
2//!
3//! Provides four policy slots at natural seam points in the agent loop:
4//! - **`PreTurn`** (Slot 1): Before each LLM call — guards and pre-conditions.
5//! - **`PreDispatch`** (Slot 2): Per tool call, before approval — validation and argument mutation.
6//! - **`PostTurn`** (Slot 3): After each completed turn — persistence, steering, stop conditions.
7//! - **`PostLoop`** (Slot 4): After the inner loop exits — cleanup before follow-up polling.
8//!
9//! Each slot accepts a `Vec<Arc<dyn Trait>>` of policy implementations, evaluated in order.
10//! The default is empty vecs — no policies, anything goes.
11//!
12//! Two verdict enums enforce Skip-only-in-PreDispatch at compile time:
13//! - [`PolicyVerdict`]: Used by `PreTurn`, `PostTurn`, and `PostLoop` (no Skip variant).
14//! - [`PreDispatchVerdict`]: Used by `PreDispatch` (includes Skip).
15//!
16//! The slot runner catches panics via `catch_unwind` (using `AssertUnwindSafe`),
17//! so policy traits only require `Send + Sync` — implementors do not need `UnwindSafe`.
18#![forbid(unsafe_code)]
19
20use std::panic::AssertUnwindSafe;
21use std::path::Path;
22use std::sync::Arc;
23
24use tracing::{debug, warn};
25
26use crate::types::{
27    AgentMessage, AssistantMessage, Cost, ModelSpec, StopReason, ToolResultMessage, Usage,
28};
29
30// ─── Verdict Enums ──────────────────────────────────────────────────────────
31
32/// Outcome of a policy evaluation for `PreTurn`, `PostTurn`, and `PostLoop` slots.
33///
34/// Does not include `Skip` — that is only available in [`PreDispatchVerdict`].
35#[derive(Debug)]
36pub enum PolicyVerdict {
37    /// Proceed normally.
38    Continue,
39    /// Stop the loop gracefully with a reason.
40    Stop(String),
41    /// Add messages to the pending queue and continue.
42    Inject(Vec<AgentMessage>),
43}
44
45/// Outcome of a `PreDispatch` policy evaluation.
46///
47/// Includes `Skip` for per-tool-call rejection, in addition to the
48/// verdicts available in [`PolicyVerdict`].
49#[derive(Debug)]
50pub enum PreDispatchVerdict {
51    /// Proceed normally.
52    Continue,
53    /// Abort the entire tool batch and stop the loop.
54    Stop(String),
55    /// Add messages to the pending queue and continue.
56    Inject(Vec<AgentMessage>),
57    /// Skip this tool call, returning the error text to the LLM.
58    Skip(String),
59}
60
61// ─── Context Structs ────────────────────────────────────────────────────────
62
63/// Shared read-only context available to every policy evaluation.
64#[derive(Debug)]
65pub struct PolicyContext<'a> {
66    /// Zero-based index of the current/completed turn.
67    pub turn_index: usize,
68    /// Accumulated token usage across all turns.
69    pub accumulated_usage: &'a Usage,
70    /// Accumulated cost across all turns.
71    pub accumulated_cost: &'a Cost,
72    /// Number of messages in context.
73    pub message_count: usize,
74    /// Whether context overflow was signaled.
75    pub overflow_signal: bool,
76    /// Messages added since the last policy evaluation for this slot.
77    ///
78    /// - **`PreTurn`**: user/pending messages appended since the previous turn.
79    /// - **`PostTurn`** / **`PostLoop`**: empty — current-turn data is in [`TurnPolicyContext`].
80    ///
81    /// Policies should only scan this slice, never the full session history,
82    /// to avoid redundant work on messages that have already been evaluated.
83    pub new_messages: &'a [AgentMessage],
84    /// Read-only access to the session state.
85    pub state: &'a crate::SessionState,
86}
87
88/// Combined context for `PreDispatch` policies.
89///
90/// Contains only the data reliably available during tool dispatch — the per-call
91/// fields and read-only session state. Loop-level metrics (turn index, accumulated
92/// usage/cost, message count, overflow signal) are intentionally excluded: they are
93/// not tracked at the tool dispatch call site, and fabricating placeholder values
94/// would give policies incorrect data to reason from.
95pub struct ToolDispatchContext<'a> {
96    /// Name of the tool being called.
97    pub tool_name: &'a str,
98    /// Unique identifier for this tool call.
99    pub tool_call_id: &'a str,
100    /// Mutable reference to tool call arguments (policies may rewrite them).
101    pub arguments: &'a mut serde_json::Value,
102    /// Working directory the tool will resolve relative paths against, when known.
103    pub execution_root: Option<&'a Path>,
104    /// Read-only access to the session state.
105    pub state: &'a crate::SessionState,
106}
107
108impl std::fmt::Debug for ToolDispatchContext<'_> {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.debug_struct("ToolDispatchContext")
111            .field("tool_name", &self.tool_name)
112            .field("tool_call_id", &self.tool_call_id)
113            .field("execution_root", &self.execution_root)
114            .field("arguments", &"<redacted>")
115            .finish()
116    }
117}
118
119/// Per-turn context for `PostTurn` policies.
120#[derive(Debug)]
121pub struct TurnPolicyContext<'a> {
122    /// The assistant message from the completed turn.
123    pub assistant_message: &'a AssistantMessage,
124    /// Tool results produced during this turn.
125    pub tool_results: &'a [ToolResultMessage],
126    /// Why the turn ended.
127    pub stop_reason: StopReason,
128    /// The system prompt active during this turn.
129    pub system_prompt: &'a str,
130    /// The model specification active during this turn.
131    pub model_spec: &'a ModelSpec,
132    /// The committed conversation history for the completed turn.
133    ///
134    /// This always includes the current turn's assistant message and any tool
135    /// results before `PostTurn` policies run, regardless of whether the turn
136    /// ended with plain text, tool execution, or transfer termination.
137    pub context_messages: &'a [AgentMessage],
138}
139
140// ─── Slot Traits ────────────────────────────────────────────────────────────
141
142/// Slot 1: Evaluated before each LLM call.
143///
144/// Use for guards and pre-conditions (budget limits, turn caps, rate limiting).
145/// Trait bounds are `Send + Sync` only — the slot runner handles `catch_unwind`
146/// via `AssertUnwindSafe`, so implementors do not need `UnwindSafe`.
147///
148/// Stateful policies should use interior mutability (`Mutex`, atomics).
149pub trait PreTurnPolicy: Send + Sync {
150    /// Policy identifier for tracing and debugging.
151    fn name(&self) -> &str;
152    /// Evaluate the policy. Returns [`PolicyVerdict`].
153    fn evaluate(&self, ctx: &PolicyContext<'_>) -> PolicyVerdict;
154}
155
156/// Slot 2: Evaluated per tool call, before approval and execution.
157///
158/// Can inspect and mutate tool call arguments via [`ToolDispatchContext`].
159/// Returns [`PreDispatchVerdict`] which includes `Skip` for per-tool rejection.
160pub trait PreDispatchPolicy: Send + Sync {
161    /// Policy identifier for tracing and debugging.
162    fn name(&self) -> &str;
163    /// Evaluate the policy. Returns [`PreDispatchVerdict`].
164    fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict;
165}
166
167/// Slot 3: Evaluated after each completed turn.
168///
169/// Use for persistence, loop detection, dynamic stop conditions, or steering injection.
170pub trait PostTurnPolicy: Send + Sync {
171    /// Policy identifier for tracing and debugging.
172    fn name(&self) -> &str;
173    /// Evaluate the policy. Returns [`PolicyVerdict`].
174    fn evaluate(&self, ctx: &PolicyContext<'_>, turn: &TurnPolicyContext<'_>) -> PolicyVerdict;
175}
176
177/// Slot 4: Evaluated after the inner loop exits, before follow-up polling.
178///
179/// Use for cleanup, cooldown, or rate limiting between outer loop iterations.
180pub trait PostLoopPolicy: Send + Sync {
181    /// Policy identifier for tracing and debugging.
182    fn name(&self) -> &str;
183    /// Evaluate the policy. Returns [`PolicyVerdict`].
184    fn evaluate(&self, ctx: &PolicyContext<'_>) -> PolicyVerdict;
185}
186
187// ─── Slot Runners ───────────────────────────────────────────────────────────
188
189/// Evaluate `PreTurn`, `PostTurn`, or `PostLoop` policies in order.
190///
191/// - **Stop** short-circuits: first Stop wins, remaining policies don't run.
192/// - **Inject** accumulates: all non-short-circuited policies contribute messages.
193/// - **Panics** are caught via `catch_unwind` and treated as Continue.
194pub fn run_policies(policies: &[Arc<dyn PreTurnPolicy>], ctx: &PolicyContext<'_>) -> PolicyVerdict {
195    run_policies_inner(policies.iter().map(std::convert::AsRef::as_ref), ctx)
196}
197
198/// Evaluate `PostTurn` policies in order.
199pub fn run_post_turn_policies(
200    policies: &[Arc<dyn PostTurnPolicy>],
201    ctx: &PolicyContext<'_>,
202    turn: &TurnPolicyContext<'_>,
203) -> PolicyVerdict {
204    let mut injections: Vec<AgentMessage> = Vec::new();
205
206    for policy in policies {
207        let policy_name = policy.name().to_string();
208        let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx, turn)));
209
210        match result {
211            Ok(PolicyVerdict::Continue) => {}
212            Ok(PolicyVerdict::Stop(reason)) => {
213                debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
214                return PolicyVerdict::Stop(reason);
215            }
216            Ok(PolicyVerdict::Inject(msgs)) => {
217                injections.extend(msgs);
218            }
219            Err(_) => {
220                warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
221            }
222        }
223    }
224
225    if injections.is_empty() {
226        PolicyVerdict::Continue
227    } else {
228        PolicyVerdict::Inject(injections)
229    }
230}
231
232/// Evaluate `PostLoop` policies in order.
233pub fn run_post_loop_policies(
234    policies: &[Arc<dyn PostLoopPolicy>],
235    ctx: &PolicyContext<'_>,
236) -> PolicyVerdict {
237    let mut injections: Vec<AgentMessage> = Vec::new();
238
239    for policy in policies {
240        let policy_name = policy.name().to_string();
241        let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
242
243        match result {
244            Ok(PolicyVerdict::Continue) => {}
245            Ok(PolicyVerdict::Stop(reason)) => {
246                debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
247                return PolicyVerdict::Stop(reason);
248            }
249            Ok(PolicyVerdict::Inject(msgs)) => {
250                injections.extend(msgs);
251            }
252            Err(_) => {
253                warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
254            }
255        }
256    }
257
258    if injections.is_empty() {
259        PolicyVerdict::Continue
260    } else {
261        PolicyVerdict::Inject(injections)
262    }
263}
264
265/// Internal runner for `PreTurn` policies (same signature pattern).
266fn run_policies_inner<'a>(
267    policies: impl Iterator<Item = &'a dyn PreTurnPolicy>,
268    ctx: &PolicyContext<'_>,
269) -> PolicyVerdict {
270    let mut injections: Vec<AgentMessage> = Vec::new();
271
272    for policy in policies {
273        let policy_name = policy.name().to_string();
274        let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
275
276        match result {
277            Ok(PolicyVerdict::Continue) => {}
278            Ok(PolicyVerdict::Stop(reason)) => {
279                debug!(policy = %policy_name, reason = %reason, "policy stopped loop");
280                return PolicyVerdict::Stop(reason);
281            }
282            Ok(PolicyVerdict::Inject(msgs)) => {
283                injections.extend(msgs);
284            }
285            Err(_) => {
286                warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
287            }
288        }
289    }
290
291    if injections.is_empty() {
292        PolicyVerdict::Continue
293    } else {
294        PolicyVerdict::Inject(injections)
295    }
296}
297
298/// Evaluate `PreDispatch` policies for a single tool call.
299///
300/// - **Stop** short-circuits: aborts the entire tool batch.
301/// - **Skip** short-circuits: skips this tool call with error text.
302/// - **Inject** accumulates.
303/// - **Panics** are caught, argument mutations are rolled back, and evaluation continues.
304pub fn run_pre_dispatch_policies(
305    policies: &[Arc<dyn PreDispatchPolicy>],
306    ctx: &mut ToolDispatchContext<'_>,
307) -> PreDispatchVerdict {
308    let mut injections: Vec<AgentMessage> = Vec::new();
309
310    for policy in policies {
311        let policy_name = policy.name().to_string();
312        let argument_snapshot = ctx.arguments.clone();
313        let result = std::panic::catch_unwind(AssertUnwindSafe(|| policy.evaluate(ctx)));
314
315        match result {
316            Ok(PreDispatchVerdict::Continue) => {}
317            Ok(PreDispatchVerdict::Stop(reason)) => {
318                debug!(policy = %policy_name, reason = %reason, "policy stopped loop (pre-dispatch)");
319                return PreDispatchVerdict::Stop(reason);
320            }
321            Ok(PreDispatchVerdict::Skip(error_text)) => {
322                debug!(policy = %policy_name, "policy skipped tool call");
323                return PreDispatchVerdict::Skip(error_text);
324            }
325            Ok(PreDispatchVerdict::Inject(msgs)) => {
326                injections.extend(msgs);
327            }
328            Err(_) => {
329                *ctx.arguments = argument_snapshot;
330                warn!(policy = %policy_name, "policy panicked during evaluation, skipping");
331            }
332        }
333    }
334
335    if injections.is_empty() {
336        PreDispatchVerdict::Continue
337    } else {
338        PreDispatchVerdict::Inject(injections)
339    }
340}
341
342// ─── Tests ──────────────────────────────────────────────────────────────────
343
#[cfg(test)]
mod tests {
    //! Unit tests for the verdict types, context structs, and slot runners.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ── Test helpers ──

    /// Configurable policy that returns whatever `make_verdict` produces and
    /// counts how often it was evaluated. Usable in the `PreTurn` and
    /// `PostLoop` slots.
    struct TestPolicy {
        policy_name: String,
        make_verdict: Box<dyn Fn() -> PolicyVerdict + Send + Sync>,
        call_count: AtomicUsize,
    }

    impl TestPolicy {
        fn new(name: &str, make: impl Fn() -> PolicyVerdict + Send + Sync + 'static) -> Self {
            Self {
                policy_name: name.to_string(),
                make_verdict: Box::new(make),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `evaluate` has been invoked.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl PreTurnPolicy for TestPolicy {
        fn name(&self) -> &str {
            &self.policy_name
        }
        fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            (self.make_verdict)()
        }
    }

    impl PostLoopPolicy for TestPolicy {
        fn name(&self) -> &str {
            &self.policy_name
        }
        fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            (self.make_verdict)()
        }
    }

    /// `PreTurn` policy that always panics.
    struct PanickingPolicy;
    impl PreTurnPolicy for PanickingPolicy {
        fn name(&self) -> &'static str {
            "panicker"
        }
        fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
            panic!("policy intentionally panicked");
        }
    }

    /// Configurable `PreDispatch` policy that counts its evaluations.
    struct TestPreDispatchPolicy {
        policy_name: String,
        make_verdict: Box<dyn Fn() -> PreDispatchVerdict + Send + Sync>,
        call_count: AtomicUsize,
    }

    impl TestPreDispatchPolicy {
        fn new(name: &str, make: impl Fn() -> PreDispatchVerdict + Send + Sync + 'static) -> Self {
            Self {
                policy_name: name.to_string(),
                make_verdict: Box::new(make),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `evaluate` has been invoked.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl PreDispatchPolicy for TestPreDispatchPolicy {
        fn name(&self) -> &str {
            &self.policy_name
        }
        fn evaluate(&self, _ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            (self.make_verdict)()
        }
    }

    /// `PreDispatch` policy that always panics without touching the arguments.
    struct PanickingPreDispatchPolicy;
    impl PreDispatchPolicy for PanickingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "panicker"
        }
        fn evaluate(&self, _ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            panic!("pre-dispatch policy panicked");
        }
    }

    /// `PreDispatch` policy that rewrites the arguments and then panics, to
    /// exercise the runner's rollback path.
    struct MutateThenPanicPreDispatchPolicy;
    impl PreDispatchPolicy for MutateThenPanicPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "mutate_then_panic"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if let Some(obj) = ctx.arguments.as_object_mut() {
                obj.insert("half_written".to_string(), serde_json::json!(true));
            }
            panic!("pre-dispatch policy panicked after mutating arguments");
        }
    }

    /// `PreDispatch` policy that inserts an `"injected"` key into object arguments.
    struct MutatingPreDispatchPolicy;
    impl PreDispatchPolicy for MutatingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "mutator"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if let Some(obj) = ctx.arguments.as_object_mut() {
                obj.insert("injected".to_string(), serde_json::json!("by_policy"));
            }
            PreDispatchVerdict::Continue
        }
    }

    /// `PreDispatch` policy that skips the call unless `expected_key` is present.
    struct VerifyingPreDispatchPolicy {
        expected_key: String,
    }
    impl PreDispatchPolicy for VerifyingPreDispatchPolicy {
        fn name(&self) -> &'static str {
            "verifier"
        }
        fn evaluate(&self, ctx: &mut ToolDispatchContext<'_>) -> PreDispatchVerdict {
            if ctx.arguments.get(&self.expected_key).is_some() {
                PreDispatchVerdict::Continue
            } else {
                PreDispatchVerdict::Skip(format!("missing key: {}", self.expected_key))
            }
        }
    }

    fn test_message() -> AgentMessage {
        AgentMessage::Llm(crate::types::LlmMessage::User(crate::types::UserMessage {
            content: vec![],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    fn test_context() -> (Usage, Cost) {
        (Usage::default(), Cost::default())
    }

    fn make_ctx<'a>(
        usage: &'a Usage,
        cost: &'a Cost,
        state: &'a crate::SessionState,
    ) -> PolicyContext<'a> {
        PolicyContext {
            turn_index: 0,
            accumulated_usage: usage,
            accumulated_cost: cost,
            message_count: 5,
            overflow_signal: false,
            new_messages: &[],
            state,
        }
    }

    fn make_dispatch_ctx<'a>(
        args: &'a mut serde_json::Value,
        state: &'a crate::SessionState,
    ) -> ToolDispatchContext<'a> {
        ToolDispatchContext {
            tool_name: "test_tool",
            tool_call_id: "id1",
            arguments: args,
            execution_root: None,
            state,
        }
    }

    // ── T006: PolicyVerdict and PreDispatchVerdict debug + PolicyContext construction ──

    #[test]
    fn policy_verdict_debug() {
        let v = PolicyVerdict::Continue;
        assert!(format!("{v:?}").contains("Continue"));

        let v = PolicyVerdict::Stop("budget exceeded".to_string());
        assert!(format!("{v:?}").contains("budget exceeded"));

        let v = PolicyVerdict::Inject(vec![]);
        assert!(format!("{v:?}").contains("Inject"));
    }

    #[test]
    fn pre_dispatch_verdict_debug() {
        let v = PreDispatchVerdict::Skip("denied".to_string());
        assert!(format!("{v:?}").contains("denied"));

        let v = PreDispatchVerdict::Stop("halt".to_string());
        assert!(format!("{v:?}").contains("halt"));
    }

    #[test]
    fn policy_context_construction() {
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        assert_eq!(ctx.turn_index, 0);
        assert_eq!(ctx.message_count, 5);
        assert!(!ctx.overflow_signal);
    }

    // ── T007: run_policies tests ──

    #[test]
    fn empty_vec_returns_continue() {
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
    }

    #[test]
    fn single_continue() {
        let p = Arc::new(TestPolicy::new("a", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
        assert_eq!(p.calls(), 1);
    }

    #[test]
    fn single_stop_short_circuits() {
        let p1 = Arc::new(TestPolicy::new("stopper", || {
            PolicyVerdict::Stop("done".into())
        }));
        let p2 = Arc::new(TestPolicy::new("never_called", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1.clone(), p2.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Stop(ref r) if r == "done"));
        assert_eq!(p1.calls(), 1);
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn inject_accumulates_across_policies() {
        let p1 = Arc::new(TestPolicy::new("a", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPolicy::new("b", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        match result {
            PolicyVerdict::Inject(msgs) => assert_eq!(msgs.len(), 2),
            _ => panic!("expected Inject"),
        }
    }

    #[test]
    fn stop_after_inject_returns_stop() {
        let p1 = Arc::new(TestPolicy::new("injector", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPolicy::new("stopper", || {
            PolicyVerdict::Stop("halt".into())
        }));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Stop(ref r) if r == "halt"));
    }

    #[test]
    fn panic_caught_returns_continue() {
        let p1: Arc<dyn PreTurnPolicy> = Arc::new(PanickingPolicy);
        let p2 = Arc::new(TestPolicy::new("after_panic", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PreTurnPolicy>> = vec![p1, p2.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Continue));
        assert_eq!(p2.calls(), 1); // panicking policy skipped, next one runs
    }

    // ── run_post_loop_policies tests ──

    #[test]
    fn post_loop_stop_short_circuits_after_inject() {
        let p1 = Arc::new(TestPolicy::new("injector", || {
            PolicyVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPolicy::new("stopper", || {
            PolicyVerdict::Stop("cooldown".into())
        }));
        let p3 = Arc::new(TestPolicy::new("never_called", || PolicyVerdict::Continue));
        let policies: Vec<Arc<dyn PostLoopPolicy>> = vec![p1.clone(), p2.clone(), p3.clone()];
        let (usage, cost) = test_context();
        let state = crate::SessionState::new();
        let ctx = make_ctx(&usage, &cost, &state);
        let result = run_post_loop_policies(&policies, &ctx);
        assert!(matches!(result, PolicyVerdict::Stop(ref r) if r == "cooldown"));
        assert_eq!(p1.calls(), 1);
        assert_eq!(p2.calls(), 1);
        assert_eq!(p3.calls(), 0);
    }

    // ── T008: run_pre_dispatch_policies tests ──

    #[test]
    fn pre_dispatch_empty_vec_returns_continue() {
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
    }

    #[test]
    fn pre_dispatch_skip_short_circuits() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("skipper", || {
            PreDispatchVerdict::Skip("denied".into())
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("never", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1.clone(), p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Skip(ref e) if e == "denied"));
        assert_eq!(p1.calls(), 1);
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn pre_dispatch_stop_short_circuits() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("stopper", || {
            PreDispatchVerdict::Stop("halt".into())
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("never", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Stop(ref r) if r == "halt"));
        assert_eq!(p2.calls(), 0);
    }

    #[test]
    fn pre_dispatch_inject_accumulates() {
        let p1 = Arc::new(TestPreDispatchPolicy::new("a", || {
            PreDispatchVerdict::Inject(vec![test_message()])
        }));
        let p2 = Arc::new(TestPreDispatchPolicy::new("b", || {
            PreDispatchVerdict::Inject(vec![test_message()])
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        match result {
            PreDispatchVerdict::Inject(msgs) => assert_eq!(msgs.len(), 2),
            _ => panic!("expected Inject"),
        }
    }

    #[test]
    fn pre_dispatch_panic_caught_returns_continue() {
        let p1: Arc<dyn PreDispatchPolicy> = Arc::new(PanickingPreDispatchPolicy);
        let p2 = Arc::new(TestPreDispatchPolicy::new("after", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![p1, p2.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
        assert_eq!(p2.calls(), 1);
    }

    #[test]
    fn pre_dispatch_panic_rolls_back_argument_mutation() {
        // Regression: a policy that panics after rewriting the arguments must
        // not leave its partial rewrite behind for later policies or the tool.
        let panicker: Arc<dyn PreDispatchPolicy> = Arc::new(MutateThenPanicPreDispatchPolicy);
        let after = Arc::new(TestPreDispatchPolicy::new("after", || {
            PreDispatchVerdict::Continue
        }));
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![panicker, after.clone()];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"original": "value"});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        assert!(matches!(result, PreDispatchVerdict::Continue));
        assert_eq!(after.calls(), 1);
        assert_eq!(args, serde_json::json!({"original": "value"}));
    }

    #[test]
    fn argument_mutation_visible_to_next_policy() {
        let mutator: Arc<dyn PreDispatchPolicy> = Arc::new(MutatingPreDispatchPolicy);
        let verifier: Arc<dyn PreDispatchPolicy> = Arc::new(VerifyingPreDispatchPolicy {
            expected_key: "injected".to_string(),
        });
        let policies: Vec<Arc<dyn PreDispatchPolicy>> = vec![mutator, verifier];
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"original": "value"});
        let mut ctx = make_dispatch_ctx(&mut args, &state);
        let result = run_pre_dispatch_policies(&policies, &mut ctx);
        // If mutator didn't inject "injected" key, verifier would return Skip
        assert!(matches!(result, PreDispatchVerdict::Continue));
        // Verify the mutation is visible in the original args after dispatch
        assert_eq!(args["injected"], "by_policy");
    }

    #[test]
    fn tool_dispatch_context_contains_only_reliable_fields() {
        // Regression: ToolDispatchContext must not include loop-level metrics
        // (turn_index, usage, cost, message_count, overflow_signal, new_messages)
        // because those are not tracked at the tool dispatch call site.
        let state = crate::SessionState::new();
        let mut args = serde_json::json!({"path": "/tmp/file"});
        let ctx = ToolDispatchContext {
            tool_name: "write_file",
            tool_call_id: "call-123",
            arguments: &mut args,
            execution_root: None,
            state: &state,
        };
        assert_eq!(ctx.tool_name, "write_file");
        assert_eq!(ctx.tool_call_id, "call-123");
        assert_eq!(ctx.arguments["path"], "/tmp/file");
        // Debug output does not expose argument values
        let debug_str = format!("{ctx:?}");
        assert!(debug_str.contains("write_file"));
        assert!(
            !debug_str.contains("/tmp/file"),
            "arguments must be redacted in Debug"
        );
    }
}