1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::traits::guard::{Guard, GuardResult};
13use crate::traits::tool::ErasedTool;
14use crate::traits::tracker::Tracker;
15use crate::types::action::Action;
16use crate::types::agent_state::AgentState;
17use crate::types::tool_call::ToolCall;
18
/// A tool invocation queued for execution by an [`ExecutionStrategy`].
#[derive(Debug, Clone)]
pub struct PendingToolCall {
    /// Call identifier; copied verbatim into the resulting [`ToolResult::id`].
    pub id: String,
    /// Name of the tool to run; matched against `ErasedTool::name()`.
    pub name: String,
    /// JSON arguments, shown to every guard and then passed unchanged to the tool.
    pub arguments: serde_json::Value,
}
29
30impl From<&ToolCall> for PendingToolCall {
31 fn from(tc: &ToolCall) -> Self {
32 Self {
33 id: tc.id.clone(),
34 name: tc.name.clone(),
35 arguments: tc.arguments.clone(),
36 }
37 }
38}
39
/// Outcome of executing a single [`PendingToolCall`].
///
/// The strategies in this module report failures in-band: when a call is
/// blocked, its tool is missing, the tool errors, or its task panics, `output`
/// holds a human-readable message starting with `"Error"` instead of a tool
/// result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// Identifier of the [`PendingToolCall`] this result answers.
    pub id: String,
    /// On success, the tool's JSON output serialized to a string; otherwise an
    /// error message.
    pub output: String,
}
48
/// Policy for executing a batch of pending tool calls.
///
/// The implementations in this module return exactly one [`ToolResult`] per
/// input call, in input order, and report every failure in-band as an error
/// string rather than failing the batch. NOTE(review): callers presumably rely
/// on that one-result-per-call ordering; confirm before adding new strategies.
#[async_trait]
pub trait ExecutionStrategy: Send + Sync {
    /// Executes `calls` against `tools`, consulting every guard in `guards`
    /// before each call.
    ///
    /// `state` is available to strategies that adapt to the agent's current
    /// state (as [`AdaptiveStrategy`] does); other strategies may ignore it.
    async fn execute_batch(
        &self,
        calls: Vec<PendingToolCall>,
        tools: &[Arc<dyn ErasedTool>],
        guards: &[Arc<dyn Guard>],
        state: &AgentState,
    ) -> Vec<ToolResult>;
}
64
/// Executes tool calls one at a time, in the order they were requested.
///
/// Each call, including its guard checks, finishes before the next one
/// starts.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialStrategy;
71
72#[async_trait]
73impl ExecutionStrategy for SequentialStrategy {
74 async fn execute_batch(
75 &self,
76 calls: Vec<PendingToolCall>,
77 tools: &[Arc<dyn ErasedTool>],
78 guards: &[Arc<dyn Guard>],
79 _state: &AgentState,
80 ) -> Vec<ToolResult> {
81 let mut results = Vec::with_capacity(calls.len());
82 for call in calls {
83 let output = execute_single(&call, tools, guards).await;
84 results.push(ToolResult {
85 id: call.id,
86 output,
87 });
88 }
89 results
90 }
91}
92
/// Executes tool calls concurrently on spawned Tokio tasks, with a semaphore
/// bounding how many run at once.
///
/// Results are still returned in the order the calls were given.
#[derive(Debug, Clone)]
pub struct ParallelStrategy {
    /// Maximum number of calls allowed to run at the same time.
    ///
    /// [`ParallelStrategy::new`] guarantees this is at least 1. Because the
    /// field is public, struct-literal construction can bypass that clamp, so
    /// prefer `new`.
    pub max_concurrency: usize,
}

impl ParallelStrategy {
    /// Creates a strategy that runs at most `max_concurrency` calls at once.
    ///
    /// A `max_concurrency` of 0 is raised to 1 so that calls can always make
    /// progress.
    #[must_use]
    pub fn new(max_concurrency: usize) -> Self {
        Self {
            max_concurrency: max_concurrency.max(1),
        }
    }
}
110
111#[async_trait]
112impl ExecutionStrategy for ParallelStrategy {
113 async fn execute_batch(
114 &self,
115 calls: Vec<PendingToolCall>,
116 tools: &[Arc<dyn ErasedTool>],
117 guards: &[Arc<dyn Guard>],
118 _state: &AgentState,
119 ) -> Vec<ToolResult> {
120 use tokio::sync::Semaphore;
121
122 let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
123 let tools = Arc::new(tools.to_vec());
124 let guards = Arc::new(guards.to_vec());
125
126 let call_ids: Vec<String> = calls.iter().map(|c| c.id.clone()).collect();
129 let mut handles = Vec::with_capacity(calls.len());
130
131 for call in calls {
132 let sem = semaphore.clone();
133 let tools = tools.clone();
134 let guards = guards.clone();
135
136 handles.push(tokio::spawn(async move {
137 let _permit = sem.acquire().await.expect("semaphore closed");
138 let output = execute_single(&call, &tools, &guards).await;
139 ToolResult {
140 id: call.id,
141 output,
142 }
143 }));
144 }
145
146 let mut results = Vec::with_capacity(handles.len());
147 for (i, handle) in handles.into_iter().enumerate() {
148 match handle.await {
149 Ok(result) => results.push(result),
150 Err(e) => results.push(ToolResult {
151 id: call_ids[i].clone(),
152 output: format!("Error: task panicked: {e}"),
153 }),
154 }
155 }
156 results
157 }
158}
159
160pub struct AdaptiveStrategy {
165 tracker: Arc<dyn Tracker>,
166}
167
168impl AdaptiveStrategy {
169 #[must_use]
171 pub fn new(tracker: Arc<dyn Tracker>) -> Self {
172 Self { tracker }
173 }
174}
175
176#[async_trait]
177impl ExecutionStrategy for AdaptiveStrategy {
178 async fn execute_batch(
179 &self,
180 calls: Vec<PendingToolCall>,
181 tools: &[Arc<dyn ErasedTool>],
182 guards: &[Arc<dyn Guard>],
183 state: &AgentState,
184 ) -> Vec<ToolResult> {
185 let concurrency = self.tracker.recommended_concurrency(state);
186 if concurrency <= 1 {
187 SequentialStrategy
188 .execute_batch(calls, tools, guards, state)
189 .await
190 } else {
191 ParallelStrategy::new(concurrency)
192 .execute_batch(calls, tools, guards, state)
193 .await
194 }
195 }
196}
197
198async fn execute_single(
202 call: &PendingToolCall,
203 tools: &[Arc<dyn ErasedTool>],
204 guards: &[Arc<dyn Guard>],
205) -> String {
206 let action = Action::ToolCall {
207 name: call.name.clone(),
208 arguments: call.arguments.clone(),
209 };
210
211 for guard in guards {
213 let guard_name = guard.name().to_string();
214 let action_ref = &action;
215
216 let guard_span = tracing::info_span!(
217 target: "traitclaw::guard",
218 "guard.check",
219 guard.name = guard_name.as_str(),
220 guard.result = tracing::field::Empty,
221 );
222 let _g = guard_span.enter();
223
224 let result =
225 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| guard.check(action_ref)));
226
227 match result {
228 Ok(GuardResult::Allow) => {
229 guard_span.record("guard.result", "allow");
230 }
231 Ok(GuardResult::Deny { reason, .. }) => {
232 guard_span.record("guard.result", "deny");
233 return format!("Error: Action blocked by guard: {reason}");
234 }
235 Ok(GuardResult::Sanitize { warning, .. }) => {
236 guard_span.record("guard.result", "sanitize");
237 tracing::info!(
238 target: "traitclaw::guard",
239 guard = guard_name.as_str(),
240 "Guard sanitized: {warning}"
241 );
242 }
243 Err(_) => {
244 guard_span.record("guard.result", "panic");
245 tracing::warn!(
247 target: "traitclaw::guard",
248 guard = guard_name.as_str(),
249 "Guard panicked — denying action for safety"
250 );
251 return format!("Error: Action blocked — guard '{guard_name}' panicked");
252 }
253 }
254 }
255
256 let tool_span = tracing::info_span!(
258 target: "traitclaw::tool",
259 "tool.call",
260 tool.name = call.name.as_str(),
261 tool.success = tracing::field::Empty,
262 );
263 let _t = tool_span.enter();
264
265 if let Some(tool) = tools.iter().find(|t| t.name() == call.name) {
266 match tool.execute_json(call.arguments.clone()).await {
267 Ok(output) => {
268 tool_span.record("tool.success", true);
269 serde_json::to_string(&output)
270 .unwrap_or_else(|e| format!("Error serializing output: {e}"))
271 }
272 Err(e) => {
273 tool_span.record("tool.success", false);
274 format!("Error executing tool: {e}")
275 }
276 }
277 } else {
278 tool_span.record("tool.success", false);
279 let available: Vec<_> = tools.iter().map(|t| t.name().to_string()).collect();
280 format!(
281 "Error: Tool '{}' not found. Available: {}",
282 call.name,
283 available.join(", ")
284 )
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::guard::{GuardSeverity, NoopGuard};
    use crate::types::model_info::ModelTier;

    /// Tool named `"add"` that ignores its arguments and always returns the
    /// JSON string `"result"`.
    struct AddTool;

    #[async_trait]
    impl ErasedTool for AddTool {
        fn name(&self) -> &'static str {
            "add"
        }
        fn description(&self) -> &'static str {
            "Adds two numbers"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: "add".into(),
                description: "add".into(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _args: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, crate::Error> {
            Ok(serde_json::json!("result"))
        }
    }

    /// `n` calls to the `"add"` tool with ids `call-0` .. `call-{n-1}`.
    fn make_calls(n: usize) -> Vec<PendingToolCall> {
        (0..n)
            .map(|i| PendingToolCall {
                id: format!("call-{i}"),
                name: "add".into(),
                arguments: serde_json::json!({}),
            })
            .collect()
    }

    /// Agent state used by every test; only `AdaptiveStrategy` reads it.
    fn state() -> AgentState {
        AgentState::new(ModelTier::Small, 4096)
    }

    #[tokio::test]
    async fn test_sequential_executes_in_order() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state())
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "call-0");
        assert_eq!(results[1].id, "call-1");
        assert_eq!(results[2].id, "call-2");
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
            // The tool's JSON output is serialized, so the string keeps its quotes.
            assert_eq!(r.output, "\"result\"");
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        let strategy = ParallelStrategy::new(4);
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let results = strategy
            .execute_batch(make_calls(5), &tools, &guards, &state())
            .await;

        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_preserves_input_order() {
        let strategy = ParallelStrategy::new(2);
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let results = strategy
            .execute_batch(make_calls(6), &tools, &guards, &state())
            .await;

        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(
            ids,
            ["call-0", "call-1", "call-2", "call-3", "call-4", "call-5"]
        );
    }

    #[tokio::test]
    async fn test_guard_blocks_propagate() {
        struct DenyGuard;
        impl Guard for DenyGuard {
            fn name(&self) -> &'static str {
                "deny"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                GuardResult::Deny {
                    reason: "blocked".into(),
                    severity: GuardSeverity::High,
                }
            }
        }

        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(DenyGuard)];

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state())
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_guard_panic_defaults_to_deny() {
        struct PanicGuard;
        impl Guard for PanicGuard {
            fn name(&self) -> &'static str {
                "panic_guard"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                panic!("intentional panic in guard");
            }
        }

        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(PanicGuard)];

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state())
            .await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("panicked"),
            "Expected deny on panic, got: {}",
            results[0].output
        );
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let calls = vec![PendingToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        }];

        let results = strategy
            .execute_batch(calls, &tools, &guards, &state())
            .await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("not found"),
            "Expected 'not found', got: {}",
            results[0].output
        );
    }
}