Skip to main content

swink_agent/
tool_execution_policy.rs

//! Tool execution ordering policy.
//!
//! By default the agent loop executes all tool calls concurrently via
//! `tokio::spawn`. This module provides [`ToolExecutionPolicy`] to control
//! dispatch ordering — sequential, priority-based, or fully custom via the
//! [`ToolExecutionStrategy`] trait.
7
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use serde_json::Value;
13
14// ─── ToolCallSummary ─────────────────────────────────────────────────────────
15
16/// Lightweight view of a pending tool call, exposed to policy callbacks.
17///
18/// This is intentionally a borrowed view so priority functions do not need
19/// to clone arguments.
20#[derive(Debug)]
21pub struct ToolCallSummary<'a> {
22    /// Unique identifier for this tool call.
23    pub id: &'a str,
24    /// Name of the tool being invoked.
25    pub name: &'a str,
26    /// Arguments passed to the tool.
27    pub arguments: &'a Value,
28}
29
30// ─── PriorityFn ──────────────────────────────────────────────────────────────
31
32/// Callback that assigns an integer priority to a tool call.
33///
34/// Higher values execute first. Tool calls with the same priority execute
35/// concurrently within their group; groups execute sequentially from highest
36/// to lowest priority.
37pub type PriorityFn = dyn Fn(&ToolCallSummary<'_>) -> i32 + Send + Sync;
38
/// A boxed future returned by a [`ToolExecutionStrategy`].
///
/// It resolves to the execution groups: the outer `Vec` is the sequence of
/// groups and each inner `Vec<usize>` lists the tool-call indices that make up
/// one group. The future is `Send`, and `'a` bounds whatever data it borrows.
pub type ToolExecutionStrategyFuture<'a> =
    Pin<Box<dyn Future<Output = Vec<Vec<usize>>> + Send + 'a>>;
42
43// ─── ToolExecutionStrategy ───────────────────────────────────────────────────
44
45/// Fully custom tool execution strategy.
46///
47/// Implementations receive the post-preprocessing tool-call slice that is
48/// actually eligible for dispatch. Calls skipped by policies or rejected by
49/// approval are already removed before this hook runs. The strategy returns
50/// execution groups — each group is a `Vec<usize>` of indices into the
51/// provided `tool_calls` slice. Tools within a group execute concurrently;
52/// groups execute sequentially in order.
53pub trait ToolExecutionStrategy: Send + Sync {
54    /// Partition tool calls into sequential execution groups.
55    ///
56    /// Each inner `Vec<usize>` contains indices into the provided
57    /// `tool_calls` slice that should execute concurrently. The outer `Vec`
58    /// is processed sequentially — group 0 completes before group 1 starts,
59    /// etc. Every provided tool call must appear exactly once across all
60    /// groups; out-of-bounds, duplicate, or missing indices are rejected by
61    /// the dispatch layer as deterministic tool errors.
62    fn partition(&self, tool_calls: &[ToolCallSummary<'_>]) -> ToolExecutionStrategyFuture<'_>;
63}
64
65// ─── ToolExecutionPolicy ─────────────────────────────────────────────────────
66
67/// Controls how tool calls within a single turn are dispatched.
68///
69/// The default is [`Concurrent`](ToolExecutionPolicy::Concurrent), which
70/// preserves backward compatibility by spawning all tool calls at once.
71#[derive(Default)]
72pub enum ToolExecutionPolicy {
73    /// Execute all tool calls concurrently via `tokio::spawn` (default).
74    #[default]
75    Concurrent,
76
77    /// Execute tool calls one at a time, in the order the model returned them.
78    Sequential,
79
80    /// Sort tool calls by priority (higher first), then execute groups of
81    /// equal priority concurrently. Groups run sequentially from highest to
82    /// lowest.
83    Priority(Arc<PriorityFn>),
84
85    /// Fully custom execution strategy.
86    Custom(Arc<dyn ToolExecutionStrategy>),
87}
88
89impl Clone for ToolExecutionPolicy {
90    fn clone(&self) -> Self {
91        match self {
92            Self::Concurrent => Self::Concurrent,
93            Self::Sequential => Self::Sequential,
94            Self::Priority(f) => Self::Priority(Arc::clone(f)),
95            Self::Custom(s) => Self::Custom(Arc::clone(s)),
96        }
97    }
98}
99
100impl std::fmt::Debug for ToolExecutionPolicy {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        match self {
103            Self::Concurrent => write!(f, "Concurrent"),
104            Self::Sequential => write!(f, "Sequential"),
105            Self::Priority(_) => write!(f, "Priority(...)"),
106            Self::Custom(_) => write!(f, "Custom(...)"),
107        }
108    }
109}
110
// ─── Tests and compile-time Send + Sync assertions ───────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy used to build a `Custom` policy: a single group that
    /// contains every call.
    struct SingleGroup;

    impl ToolExecutionStrategy for SingleGroup {
        fn partition(
            &self,
            tool_calls: &[ToolCallSummary<'_>],
        ) -> ToolExecutionStrategyFuture<'_> {
            // Computed eagerly: the returned future must not borrow `tool_calls`.
            let everything: Vec<usize> = (0..tool_calls.len()).collect();
            Box::pin(std::future::ready(vec![everything]))
        }
    }

    /// Compiles only when `T` can be sent to and shared between threads.
    fn assert_send_sync<T: Send + Sync>() {}

    /// The assertions are checked at compile time; the test body is a no-op
    /// at run time.
    #[test]
    fn policy_types_are_send_and_sync() {
        assert_send_sync::<ToolExecutionPolicy>();
        assert_send_sync::<ToolCallSummary<'static>>();
        assert_send_sync::<Arc<PriorityFn>>();
        assert_send_sync::<Arc<dyn ToolExecutionStrategy>>();
    }

    #[test]
    fn default_is_concurrent() {
        assert!(matches!(
            ToolExecutionPolicy::default(),
            ToolExecutionPolicy::Concurrent
        ));
    }

    #[test]
    fn debug_formatting() {
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Concurrent),
            "Concurrent"
        );
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Sequential),
            "Sequential"
        );

        let pf: Arc<PriorityFn> = Arc::new(|_| 0);
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Priority(pf)),
            "Priority(...)"
        );

        let strategy: Arc<dyn ToolExecutionStrategy> = Arc::new(SingleGroup);
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Custom(strategy)),
            "Custom(...)"
        );
    }

    #[test]
    fn clone_preserves_variant_and_shares_callback() {
        assert!(matches!(
            ToolExecutionPolicy::Sequential.clone(),
            ToolExecutionPolicy::Sequential
        ));

        let pf: Arc<PriorityFn> = Arc::new(|_| 7);
        let policy = ToolExecutionPolicy::Priority(Arc::clone(&pf));
        let cloned = policy.clone();
        // `pf`, `policy` and `cloned` all share one allocation.
        assert_eq!(Arc::strong_count(&pf), 3);

        let args = serde_json::json!({});
        let summary = ToolCallSummary {
            id: "call_1",
            name: "bash",
            arguments: &args,
        };
        match cloned {
            ToolExecutionPolicy::Priority(f) => assert_eq!(f(&summary), 7),
            other => panic!("clone changed the variant: {other:?}"),
        }
    }

    #[test]
    fn tool_call_summary_debug() {
        let args = serde_json::json!({"cmd": "ls"});
        let summary = ToolCallSummary {
            id: "call_1",
            name: "bash",
            arguments: &args,
        };
        let debug = format!("{summary:?}");
        assert!(debug.contains("bash"));
        assert!(debug.contains("call_1"));
    }
}
151        let debug = format!("{summary:?}");
152        assert!(debug.contains("bash"));
153        assert!(debug.contains("call_1"));
154    }
155}