1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::traits::guard::{Guard, GuardResult};
13use crate::traits::tool::ErasedTool;
14use crate::traits::tracker::Tracker;
15use crate::types::action::Action;
16use crate::types::agent_state::AgentState;
17use crate::types::tool_call::ToolCall;
18
19#[derive(Debug, Clone)]
21pub struct PendingToolCall {
22 pub id: String,
24 pub name: String,
26 pub arguments: serde_json::Value,
28}
29
30impl From<&ToolCall> for PendingToolCall {
31 fn from(tc: &ToolCall) -> Self {
32 Self {
33 id: tc.id.clone(),
34 name: tc.name.clone(),
35 arguments: tc.arguments.clone(),
36 }
37 }
38}
39
/// The outcome of one executed tool call, correlated with its request through `id`.
///
/// `output` holds either the tool's JSON-serialized return value or a human-readable message
/// starting with `Error` (guard denial, unknown tool, tool failure, failed task).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// Id of the [`PendingToolCall`] this result answers.
    pub id: String,
    /// Serialized tool output, or an `Error…` message.
    pub output: String,
}
48
49#[async_trait]
54pub trait ExecutionStrategy: Send + Sync {
55 async fn execute_batch(
57 &self,
58 calls: Vec<PendingToolCall>,
59 tools: &[Arc<dyn ErasedTool>],
60 guards: &[Arc<dyn Guard>],
61 state: &AgentState,
62 ) -> Vec<ToolResult>;
63}
64
/// Executes tool calls strictly one at a time, in the order they were given.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialStrategy;
71
72#[async_trait]
73impl ExecutionStrategy for SequentialStrategy {
74 async fn execute_batch(
75 &self,
76 calls: Vec<PendingToolCall>,
77 tools: &[Arc<dyn ErasedTool>],
78 guards: &[Arc<dyn Guard>],
79 _state: &AgentState,
80 ) -> Vec<ToolResult> {
81 let mut results = Vec::with_capacity(calls.len());
82 for call in calls {
83 let output = execute_single(&call, tools, guards).await;
84 results.push(ToolResult {
85 id: call.id,
86 output,
87 });
88 }
89 results
90 }
91}
92
/// Executes tool calls concurrently on tokio tasks, bounded by a concurrency limit.
#[derive(Debug, Clone)]
pub struct ParallelStrategy {
    /// Maximum number of calls allowed to run at the same time.
    ///
    /// Prefer [`ParallelStrategy::new`], which guarantees the value is at least 1.
    pub max_concurrency: usize,
}

impl ParallelStrategy {
    /// Creates a parallel strategy running at most `max_concurrency` calls at once.
    ///
    /// A limit of 0 is raised to 1 so that progress is always possible.
    #[must_use]
    pub fn new(max_concurrency: usize) -> Self {
        Self {
            max_concurrency: max_concurrency.max(1),
        }
    }
}
110
111#[async_trait]
112impl ExecutionStrategy for ParallelStrategy {
113 async fn execute_batch(
114 &self,
115 calls: Vec<PendingToolCall>,
116 tools: &[Arc<dyn ErasedTool>],
117 guards: &[Arc<dyn Guard>],
118 _state: &AgentState,
119 ) -> Vec<ToolResult> {
120 use tokio::sync::Semaphore;
121
122 let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
123 let tools = Arc::new(tools.to_vec());
124 let guards = Arc::new(guards.to_vec());
125
126 let call_ids: Vec<String> = calls.iter().map(|c| c.id.clone()).collect();
129 let mut handles = Vec::with_capacity(calls.len());
130
131 for call in calls {
132 let sem = semaphore.clone();
133 let tools = tools.clone();
134 let guards = guards.clone();
135
136 handles.push(tokio::spawn(async move {
137 let _permit = sem.acquire().await.expect("semaphore closed");
138 let output = execute_single(&call, &tools, &guards).await;
139 ToolResult {
140 id: call.id,
141 output,
142 }
143 }));
144 }
145
146 let mut results = Vec::with_capacity(handles.len());
147 for (i, handle) in handles.into_iter().enumerate() {
148 match handle.await {
149 Ok(result) => results.push(result),
150 Err(e) => results.push(ToolResult {
151 id: call_ids[i].clone(),
152 output: format!("Error: task panicked: {e}"),
153 }),
154 }
155 }
156 results
157 }
158}
159
160pub struct AdaptiveStrategy {
165 tracker: Arc<dyn Tracker>,
166}
167
168impl AdaptiveStrategy {
169 #[must_use]
171 pub fn new(tracker: Arc<dyn Tracker>) -> Self {
172 Self { tracker }
173 }
174}
175
176#[async_trait]
177impl ExecutionStrategy for AdaptiveStrategy {
178 async fn execute_batch(
179 &self,
180 calls: Vec<PendingToolCall>,
181 tools: &[Arc<dyn ErasedTool>],
182 guards: &[Arc<dyn Guard>],
183 state: &AgentState,
184 ) -> Vec<ToolResult> {
185 let concurrency = self.tracker.recommended_concurrency(state);
186 if concurrency <= 1 {
187 SequentialStrategy
188 .execute_batch(calls, tools, guards, state)
189 .await
190 } else {
191 ParallelStrategy::new(concurrency)
192 .execute_batch(calls, tools, guards, state)
193 .await
194 }
195 }
196}
197
198async fn execute_single(
202 call: &PendingToolCall,
203 tools: &[Arc<dyn ErasedTool>],
204 guards: &[Arc<dyn Guard>],
205) -> String {
206 let action = Action::ToolCall {
207 name: call.name.clone(),
208 arguments: call.arguments.clone(),
209 };
210
211 for guard in guards {
213 let guard_name = guard.name().to_string();
214 let action_ref = &action;
215 let result =
216 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| guard.check(action_ref)));
217
218 match result {
219 Ok(GuardResult::Allow) => {}
220 Ok(GuardResult::Deny { reason, .. }) => {
221 return format!("Error: Action blocked by guard: {reason}");
222 }
223 Ok(GuardResult::Sanitize { warning, .. }) => {
224 tracing::info!(guard = guard_name.as_str(), "Guard sanitized: {warning}");
225 }
226 Err(_) => {
227 tracing::warn!(
229 guard = guard_name.as_str(),
230 "Guard panicked — denying action for safety"
231 );
232 return format!("Error: Action blocked — guard '{guard_name}' panicked");
233 }
234 }
235 }
236
237 if let Some(tool) = tools.iter().find(|t| t.name() == call.name) {
239 match tool.execute_json(call.arguments.clone()).await {
240 Ok(output) => serde_json::to_string(&output)
241 .unwrap_or_else(|e| format!("Error serializing output: {e}")),
242 Err(e) => format!("Error executing tool: {e}"),
243 }
244 } else {
245 let available: Vec<_> = tools.iter().map(|t| t.name().to_string()).collect();
246 format!(
247 "Error: Tool '{}' not found. Available: {}",
248 call.name,
249 available.join(", ")
250 )
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::guard::NoopGuard;

    /// Minimal tool registered as `"add"`; ignores its arguments and always returns `"result"`.
    struct AddTool;

    #[async_trait]
    impl ErasedTool for AddTool {
        fn name(&self) -> &'static str {
            "add"
        }
        fn description(&self) -> &'static str {
            "Adds two numbers"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: "add".into(),
                description: "add".into(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _args: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, crate::Error> {
            Ok(serde_json::json!("result"))
        }
    }

    /// `n` calls to the `add` tool with ids `call-0` .. `call-{n-1}`.
    fn make_calls(n: usize) -> Vec<PendingToolCall> {
        (0..n)
            .map(|i| PendingToolCall {
                id: format!("call-{i}"),
                name: "add".into(),
                arguments: serde_json::json!({}),
            })
            .collect()
    }

    /// Agent state shared by all tests.
    fn test_state() -> AgentState {
        AgentState::new(crate::types::model_info::ModelTier::Small, 4096)
    }

    /// Tool registry containing only [`AddTool`].
    fn add_tools() -> Vec<Arc<dyn ErasedTool>> {
        vec![Arc::new(AddTool)]
    }

    /// Guard list containing only a [`NoopGuard`].
    fn noop_guards() -> Vec<Arc<dyn Guard>> {
        vec![Arc::new(NoopGuard)]
    }

    #[tokio::test]
    async fn test_sequential_executes_in_order() {
        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards = noop_guards();
        let state = test_state();

        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "call-0");
        assert_eq!(results[1].id, "call-1");
        assert_eq!(results[2].id, "call-2");
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        let strategy = ParallelStrategy::new(4);
        let tools = add_tools();
        let guards = noop_guards();
        let state = test_state();

        let results = strategy
            .execute_batch(make_calls(5), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_preserves_call_order() {
        let strategy = ParallelStrategy::new(2);
        let tools = add_tools();
        let guards = noop_guards();
        let state = test_state();

        let results = strategy
            .execute_batch(make_calls(6), &tools, &guards, &state)
            .await;

        let ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
        let expected: Vec<String> = (0..6).map(|i| format!("call-{i}")).collect();
        assert_eq!(ids, expected);
    }

    #[tokio::test]
    async fn test_parallel_empty_batch_returns_no_results() {
        let strategy = ParallelStrategy::new(4);
        let tools = add_tools();
        let guards = noop_guards();
        let state = test_state();

        let results = strategy
            .execute_batch(Vec::new(), &tools, &guards, &state)
            .await;

        assert!(results.is_empty());
    }

    #[test]
    fn test_parallel_new_clamps_zero_concurrency() {
        assert_eq!(ParallelStrategy::new(0).max_concurrency, 1);
    }

    #[tokio::test]
    async fn test_guard_blocks_propagate() {
        struct DenyGuard;
        impl Guard for DenyGuard {
            fn name(&self) -> &'static str {
                "deny"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                GuardResult::Deny {
                    reason: "blocked".into(),
                    severity: crate::traits::guard::GuardSeverity::High,
                }
            }
        }

        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(DenyGuard)];
        let state = test_state();

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_guard_panic_defaults_to_deny() {
        struct PanicGuard;
        impl Guard for PanicGuard {
            fn name(&self) -> &'static str {
                "panic_guard"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                panic!("intentional panic in guard");
            }
        }

        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(PanicGuard)];
        let state = test_state();

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        // A panicking guard must fail closed: the action is denied, not executed.
        assert!(
            results[0].output.contains("panicked"),
            "Expected deny on panic, got: {}",
            results[0].output
        );
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![];
        let guards = noop_guards();
        let state = test_state();

        let calls = vec![PendingToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        }];

        let results = strategy.execute_batch(calls, &tools, &guards, &state).await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("not found"),
            "Expected 'not found', got: {}",
            results[0].output
        );
    }
}