tower_llm/concurrency/
mod.rs

1//! Parallel tool execution and concurrency controls
2//!
3//! What this module provides (spec)
//! - A layer that fans out tool_calls concurrently and fans in results deterministically
5//! - Configurable concurrency limits and join/failure policies
6//!
7//! Exports
8//! - Models
9//!   - `ConcurrencyLimit(usize)`
//!   - `ToolJoinPolicy::{JoinAll, FailFast}` (a per-tool `TimeoutPerTool(Duration)` variant is part of the spec but is not defined in this module yet)
11//! - Layers
12//!   - `ParallelToolsLayer<S, R>` where `S: Service<RawChatRequest, Response=StepOutcome>` and `R: Service<ToolInvocation,...>`
13//! - Utils
14//!   - Ordering helper to map completed outputs back to requested order
15//!
16//! Implementation strategy
17//! - Wrap the tool router with `tower::buffer::Buffer` to acquire readiness per invocation
18//! - On `StepOutcome::Next` with `invoked_tools`, spawn invocations concurrently:
19//!   - Use `FuturesUnordered` or `join_all` with a semaphore set by `ConcurrencyLimit`
20//!   - Apply `ToolJoinPolicy` (wait all, fail fast, per-invocation timeout)
21//! - Serialize outputs as `tool` messages in the same order as original tool_calls
22//! - Return a rewritten `StepOutcome::Next` with appended tool messages
23//!
24//! Composition
25//! - `ServiceBuilder::new().layer(ParallelToolsLayer::new(limit, policy)).service(step)`
26//! - Combine with resilience layers for per-tool retry/timeout if desired
27//!
28//! Testing strategy
29//! - Fake tools with injected latency and error behavior
30//! - Assert that with `JoinAll` all succeed and order is preserved
31//! - Assert that with `FailFast` layer aborts on first error and surfaces it
32//! - Assert that limit `N` bounds concurrent calls (use atomic counters)
33
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use tokio::sync::Semaphore;
38use tower::{BoxError, Service, ServiceExt};
39
40use crate::core::{ToolInvocation, ToolOutput};
41
/// Maximum number of concurrent tool invocations.
///
/// The wrapped value is the number of invocations that may be in flight at
/// once. The router in this module clamps it to at least 1 when it builds its
/// semaphore, so `ConcurrencyLimit(0)` behaves like `ConcurrencyLimit(1)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConcurrencyLimit(pub usize);
45
/// Policy describing how to join multiple tool invocations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolJoinPolicy {
    /// Wait for every invocation to complete; if any of them failed, return the
    /// error of the earliest failing invocation in request order.
    JoinAll,
    /// Return the first error that is observed without waiting for the
    /// remaining invocations. Invocations that were already started are not
    /// cancelled; they run to completion in the background.
    FailFast,
}
54
55/// Wraps a tool router service `R` to execute batches of tool invocations concurrently.
56///
57/// Note: The inner service R should be wrapped in tower::buffer::Buffer if it doesn't
58/// support concurrent access (e.g., if it doesn't implement Clone).
59#[derive(Clone)]
60pub struct ParallelToolRouter<R> {
61    inner: R,
62    limit: ConcurrencyLimit,
63    policy: ToolJoinPolicy,
64}
65
66impl<R> ParallelToolRouter<R> {
67    pub fn new(inner: R, limit: ConcurrencyLimit, policy: ToolJoinPolicy) -> Self {
68        Self {
69            inner,
70            limit,
71            policy,
72        }
73    }
74}
75
76impl<R> Service<Vec<ToolInvocation>> for ParallelToolRouter<R>
77where
78    R: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
79    R::Future: Send + 'static,
80{
81    type Response = Vec<ToolOutput>;
82    type Error = BoxError;
83    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
84
85    fn poll_ready(
86        &mut self,
87        _cx: &mut std::task::Context<'_>,
88    ) -> std::task::Poll<Result<(), Self::Error>> {
89        std::task::Poll::Ready(Ok(()))
90    }
91
92    fn call(&mut self, reqs: Vec<ToolInvocation>) -> Self::Future {
93        let limit = self.limit.0.max(1);
94        let policy = self.policy;
95        let router = self.inner.clone();
96        Box::pin(async move {
97            let sem = Arc::new(Semaphore::new(limit));
98            let mut handles = Vec::with_capacity(reqs.len());
99            for (idx, inv) in reqs.into_iter().enumerate() {
100                let permit = sem.clone().acquire_owned().await.expect("semaphore");
101                let mut svc = router.clone();
102                handles.push(tokio::spawn(async move {
103                    let _p = permit;
104                    let out = svc.ready().await?.call(inv).await;
105                    out.map(|o| (idx, o))
106                }));
107            }
108
109            // Collect results; preserve original order
110            let mut slots: Vec<Option<ToolOutput>> = vec![None; handles.len()];
111            let mut first_err: Option<BoxError> = None;
112            for h in handles {
113                match h.await.expect("join") {
114                    Ok((idx, out)) => {
115                        slots[idx] = Some(out);
116                    }
117                    Err(e) => {
118                        if first_err.is_none() {
119                            first_err = Some(e);
120                            if matches!(policy, ToolJoinPolicy::FailFast) {
121                                break;
122                            }
123                        }
124                    }
125                }
126            }
127            if let Some(e) = first_err {
128                return Err(e);
129            }
130            let mut outputs = Vec::with_capacity(slots.len());
131            for s in slots.into_iter() {
132                outputs.push(s.expect("missing output"));
133            }
134            Ok(outputs)
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration};
    use tower::service_fn;

    /// Builds an invocation with the given id and tool name and null arguments.
    fn invocation(id: impl Into<String>, name: &str) -> ToolInvocation {
        ToolInvocation {
            id: id.into(),
            name: name.into(),
            arguments: Value::Null,
        }
    }

    /// Outputs must come back in request order even when a later invocation
    /// finishes before an earlier one.
    #[tokio::test]
    async fn preserves_order_with_concurrency() {
        // Fake router that delays based on invocation name
        let router = service_fn(|inv: ToolInvocation| async move {
            let delay_ms = if inv.name == "slow" { 50 } else { 5 };
            sleep(Duration::from_millis(delay_ms)).await;
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: Value::String(inv.name),
            })
        });
        let mut svc = ParallelToolRouter::new(router, ConcurrencyLimit(2), ToolJoinPolicy::JoinAll);
        let reqs = vec![invocation("a", "slow"), invocation("b", "fast")];
        let outputs = svc.ready().await.unwrap().call(reqs).await.unwrap();
        assert_eq!(outputs.len(), 2);
        // Order must match inputs even though execution time differs
        assert_eq!(outputs[0].result, Value::String("slow".into()));
        assert_eq!(outputs[1].result, Value::String("fast".into()));
    }

    /// A failing invocation in the batch surfaces its error under `FailFast`.
    #[tokio::test]
    async fn fail_fast_returns_error() {
        let router = service_fn(|inv: ToolInvocation| async move {
            if inv.name == "bad" {
                Err::<ToolOutput, BoxError>("boom".into())
            } else {
                Ok::<_, BoxError>(ToolOutput {
                    id: inv.id,
                    result: Value::Null,
                })
            }
        });
        let mut svc =
            ParallelToolRouter::new(router, ConcurrencyLimit(4), ToolJoinPolicy::FailFast);
        let reqs = vec![
            invocation("1", "ok"),
            invocation("2", "bad"),
            invocation("3", "ok"),
        ];
        let err = svc.ready().await.unwrap().call(reqs).await.unwrap_err();
        assert!(err.to_string().contains("boom"));
    }

    /// No more than `limit` invocations may be inside the router at once.
    #[tokio::test]
    async fn enforces_concurrency_limit() {
        static CURRENT: AtomicUsize = AtomicUsize::new(0);
        static MAX_OBSERVED: AtomicUsize = AtomicUsize::new(0);
        let router = service_fn(|inv: ToolInvocation| async move {
            let now = CURRENT.fetch_add(1, Ordering::SeqCst) + 1;
            // `fetch_max` records the high-water mark atomically; a separate
            // load + compare_exchange can lose an update when two tasks race,
            // which would hide a violation of the limit.
            MAX_OBSERVED.fetch_max(now, Ordering::SeqCst);
            sleep(Duration::from_millis(10)).await;
            CURRENT.fetch_sub(1, Ordering::SeqCst);
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: Value::Null,
            })
        });

        let mut svc = ParallelToolRouter::new(router, ConcurrencyLimit(2), ToolJoinPolicy::JoinAll);
        let reqs: Vec<ToolInvocation> = (0..8).map(|i| invocation(i.to_string(), "n")).collect();
        let _ = svc.ready().await.unwrap().call(reqs).await.unwrap();
        assert!(MAX_OBSERVED.load(Ordering::SeqCst) <= 2);
    }
}