Skip to main content

tiny_loop/tool/executor/
parallel.rs

1use crate::{
2    tool::{Tool, executor::ToolExecutor},
3    types::ToolCall,
4};
5use async_trait::async_trait;
6use futures::future::join_all;
7use std::collections::HashMap;
8
9/// Executes tools in parallel by grouping calls by tool name and using [`Tool::call_batch`]
10///
11/// # How it works
12///
13/// 1. Groups tool calls by tool name
14/// 2. Executes each group in parallel using [`Tool::call_batch`]
15/// 3. Flattens and returns all results
16///
17/// # Example
18///
19/// Given tool calls:
20/// ```text
21/// [
22///   ToolCall { id: "1", function: { name: "weather", ... } },
23///   ToolCall { id: "2", function: { name: "search", ... } },
24///   ToolCall { id: "3", function: { name: "weather", ... } },
25/// ]
26/// ```
27///
28/// The executor will:
29/// 1. Group by name: `{ "weather": [call1, call3], "search": [call2] }`
30/// 2. Execute in parallel:
31///    - `weather_tool.call_batch([call1, call3])` (runs concurrently)
32///    - `search_tool.call_batch([call2])` (runs concurrently)
33/// 3. Return flattened results: `[result1, result3, result2]`
34pub struct ParallelExecutor {
35    tools: HashMap<String, Box<dyn Tool + Sync>>,
36}
37
38impl ParallelExecutor {
39    /// Create a new parallel executor
40    pub fn new() -> Self {
41        Self {
42            tools: HashMap::new(),
43        }
44    }
45}
46
47#[async_trait]
48impl ToolExecutor for ParallelExecutor {
49    fn add(&mut self, name: String, tool: Box<dyn Tool + Sync>) -> Option<Box<dyn Tool + Sync>> {
50        tracing::trace!("Registering tool: {}", name);
51        self.tools.insert(name, tool)
52    }
53
54    async fn execute(&self, calls: Vec<ToolCall>) -> Vec<crate::types::ToolMessage> {
55        tracing::debug!("Executing {} tool calls in parallel", calls.len());
56        let mut grouped: HashMap<String, Vec<ToolCall>> = HashMap::new();
57        for call in calls {
58            grouped
59                .entry(call.function.name.clone())
60                .or_default()
61                .push(call);
62        }
63
64        tracing::trace!("Grouped into {} unique tools", grouped.len());
65
66        let futures = grouped.into_iter().map(|(name, calls)| async move {
67            tracing::debug!("Executing {} calls for tool '{}'", calls.len(), name);
68            if let Some(tool) = self.tools.get(&name) {
69                tool.call_batch(calls).await
70            } else {
71                tracing::debug!("Tool '{}' not found", name);
72                calls
73                    .into_iter()
74                    .map(|call| crate::types::ToolMessage {
75                        tool_call_id: call.id,
76                        content: format!("Tool '{}' not found", name),
77                    })
78                    .collect::<Vec<_>>()
79            }
80        });
81
82        let results = join_all(futures).await.into_iter().flatten().collect();
83        tracing::debug!("Parallel execution completed");
84        results
85    }
86}