Skip to main content

tiny_loop/tool/executor/
parallel.rs

1use crate::{
2    tool::{Tool, executor::ToolExecutor},
3    types::{Message, ToolCall},
4};
5use async_trait::async_trait;
6use futures::future::join_all;
7use std::collections::HashMap;
8
9/// Executes tools in parallel by grouping calls by tool name and using [`Tool::call_batch`]
10pub struct ParallelExecutor {
11    tools: HashMap<String, Box<dyn Tool + Sync>>,
12}
13
14impl ParallelExecutor {
15    pub fn new() -> Self {
16        Self {
17            tools: HashMap::new(),
18        }
19    }
20}
21
22#[async_trait]
23impl ToolExecutor for ParallelExecutor {
24    fn add(&mut self, name: String, tool: Box<dyn Tool + Sync>) -> Option<Box<dyn Tool + Sync>> {
25        tracing::trace!("Registering tool: {}", name);
26        self.tools.insert(name, tool)
27    }
28
29    async fn execute(&self, calls: Vec<ToolCall>) -> Vec<Message> {
30        tracing::debug!("Executing {} tool calls in parallel", calls.len());
31        let mut grouped: HashMap<String, Vec<ToolCall>> = HashMap::new();
32        for call in calls {
33            grouped
34                .entry(call.function.name.clone())
35                .or_default()
36                .push(call);
37        }
38
39        tracing::trace!("Grouped into {} unique tools", grouped.len());
40
41        let futures = grouped.into_iter().map(|(name, calls)| async move {
42            tracing::debug!("Executing {} calls for tool '{}'", calls.len(), name);
43            if let Some(tool) = self.tools.get(&name) {
44                tool.call_batch(calls).await
45            } else {
46                tracing::debug!("Tool '{}' not found", name);
47                calls
48                    .into_iter()
49                    .map(|call| Message::Tool {
50                        tool_call_id: call.id,
51                        content: format!("Tool '{}' not found", name),
52                    })
53                    .collect()
54            }
55        });
56
57        let results = join_all(futures).await.into_iter().flatten().collect();
58        tracing::debug!("Parallel execution completed");
59        results
60    }
61}