//! `tower_llm::concurrency`: run batches of tool invocations concurrently.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tower::{BoxError, Service, ServiceExt};

use crate::core::{ToolInvocation, ToolOutput};
41
/// Maximum number of tool invocations allowed to run at the same time.
///
/// A value of `0` is treated as `1` by [`ParallelToolRouter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConcurrencyLimit(pub usize);
45
/// How [`ParallelToolRouter`] handles failures within a batch of invocations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolJoinPolicy {
    /// Wait for every invocation; if any failed, return the error of the failing
    /// invocation with the lowest input index.
    JoinAll,
    /// Return early with the first error the router observes instead of waiting for the
    /// whole batch to finish.
    FailFast,
}
54
55#[derive(Clone)]
60pub struct ParallelToolRouter<R> {
61 inner: R,
62 limit: ConcurrencyLimit,
63 policy: ToolJoinPolicy,
64}
65
66impl<R> ParallelToolRouter<R> {
67 pub fn new(inner: R, limit: ConcurrencyLimit, policy: ToolJoinPolicy) -> Self {
68 Self {
69 inner,
70 limit,
71 policy,
72 }
73 }
74}
75
76impl<R> Service<Vec<ToolInvocation>> for ParallelToolRouter<R>
77where
78 R: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
79 R::Future: Send + 'static,
80{
81 type Response = Vec<ToolOutput>;
82 type Error = BoxError;
83 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
84
85 fn poll_ready(
86 &mut self,
87 _cx: &mut std::task::Context<'_>,
88 ) -> std::task::Poll<Result<(), Self::Error>> {
89 std::task::Poll::Ready(Ok(()))
90 }
91
92 fn call(&mut self, reqs: Vec<ToolInvocation>) -> Self::Future {
93 let limit = self.limit.0.max(1);
94 let policy = self.policy;
95 let router = self.inner.clone();
96 Box::pin(async move {
97 let sem = Arc::new(Semaphore::new(limit));
98 let mut handles = Vec::with_capacity(reqs.len());
99 for (idx, inv) in reqs.into_iter().enumerate() {
100 let permit = sem.clone().acquire_owned().await.expect("semaphore");
101 let mut svc = router.clone();
102 handles.push(tokio::spawn(async move {
103 let _p = permit;
104 let out = svc.ready().await?.call(inv).await;
105 out.map(|o| (idx, o))
106 }));
107 }
108
109 let mut slots: Vec<Option<ToolOutput>> = vec![None; handles.len()];
111 let mut first_err: Option<BoxError> = None;
112 for h in handles {
113 match h.await.expect("join") {
114 Ok((idx, out)) => {
115 slots[idx] = Some(out);
116 }
117 Err(e) => {
118 if first_err.is_none() {
119 first_err = Some(e);
120 if matches!(policy, ToolJoinPolicy::FailFast) {
121 break;
122 }
123 }
124 }
125 }
126 }
127 if let Some(e) = first_err {
128 return Err(e);
129 }
130 let mut outputs = Vec::with_capacity(slots.len());
131 for s in slots.into_iter() {
132 outputs.push(s.expect("missing output"));
133 }
134 Ok(outputs)
135 })
136 }
137}
138
139#[cfg(test)]
140mod tests {
141 use super::*;
142 use serde_json::Value;
143 use std::sync::atomic::{AtomicUsize, Ordering};
144 use tokio::time::{sleep, Duration};
145 use tower::service_fn;
146
147 #[tokio::test]
148 async fn preserves_order_with_concurrency() {
149 let router = service_fn(|inv: ToolInvocation| async move {
151 if inv.name == "slow" {
152 sleep(Duration::from_millis(50)).await;
153 } else {
154 sleep(Duration::from_millis(5)).await;
155 }
156 Ok::<_, BoxError>(ToolOutput {
157 id: inv.id,
158 result: Value::String(inv.name),
159 })
160 });
161 let mut svc = ParallelToolRouter::new(router, ConcurrencyLimit(2), ToolJoinPolicy::JoinAll);
162 let reqs = vec![
163 ToolInvocation {
164 id: "a".into(),
165 name: "slow".into(),
166 arguments: Value::Null,
167 },
168 ToolInvocation {
169 id: "b".into(),
170 name: "fast".into(),
171 arguments: Value::Null,
172 },
173 ];
174 let outputs = svc.ready().await.unwrap().call(reqs).await.unwrap();
175 assert_eq!(outputs.len(), 2);
176 assert_eq!(outputs[0].result, Value::String("slow".into()));
178 assert_eq!(outputs[1].result, Value::String("fast".into()));
179 }
180
181 #[tokio::test]
182 async fn fail_fast_returns_error() {
183 let router = service_fn(|inv: ToolInvocation| async move {
184 if inv.name == "bad" {
185 Err::<ToolOutput, BoxError>("boom".into())
186 } else {
187 Ok::<_, BoxError>(ToolOutput {
188 id: inv.id,
189 result: Value::Null,
190 })
191 }
192 });
193 let mut svc =
194 ParallelToolRouter::new(router, ConcurrencyLimit(4), ToolJoinPolicy::FailFast);
195 let reqs = vec![
196 ToolInvocation {
197 id: "1".into(),
198 name: "ok".into(),
199 arguments: Value::Null,
200 },
201 ToolInvocation {
202 id: "2".into(),
203 name: "bad".into(),
204 arguments: Value::Null,
205 },
206 ToolInvocation {
207 id: "3".into(),
208 name: "ok".into(),
209 arguments: Value::Null,
210 },
211 ];
212 let err = svc.ready().await.unwrap().call(reqs).await.unwrap_err();
213 assert!(format!("{}", err).contains("boom"));
214 }
215
216 #[tokio::test]
217 async fn enforces_concurrency_limit() {
218 static CURRENT: AtomicUsize = AtomicUsize::new(0);
219 static MAX_OBSERVED: AtomicUsize = AtomicUsize::new(0);
220 let router = service_fn(|inv: ToolInvocation| async move {
221 let now = CURRENT.fetch_add(1, Ordering::SeqCst) + 1;
222 let max = MAX_OBSERVED.load(Ordering::SeqCst);
223 if now > max {
224 let _ = MAX_OBSERVED.compare_exchange(max, now, Ordering::SeqCst, Ordering::SeqCst);
225 }
226 sleep(Duration::from_millis(10)).await;
227 CURRENT.fetch_sub(1, Ordering::SeqCst);
228 Ok::<_, BoxError>(ToolOutput {
229 id: inv.id,
230 result: Value::Null,
231 })
232 });
233
234 let mut svc = ParallelToolRouter::new(router, ConcurrencyLimit(2), ToolJoinPolicy::JoinAll);
235 let reqs: Vec<ToolInvocation> = (0..8)
236 .map(|i| ToolInvocation {
237 id: format!("{}", i),
238 name: "n".into(),
239 arguments: Value::Null,
240 })
241 .collect();
242 let _ = svc.ready().await.unwrap().call(reqs).await.unwrap();
243 assert!(MAX_OBSERVED.load(Ordering::SeqCst) <= 2);
244 }
245}