traitclaw_team/
execution.rs1use std::collections::HashMap;
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41
42use traitclaw_core::{Error, Result};
43
/// A heap-allocated, pinned, type-erased future that is `Send`.
///
/// `'a` bounds how long the future may borrow from its environment and `T` is
/// the value it resolves to.
pub type BoxFuture<'a, T> = Pin<
    Box<
        dyn Future<Output = T>
            + Send
            + 'a,
    >,
>;
46
47pub type BoundAgent = Arc<dyn Fn(String) -> BoxFuture<'static, Result<String>> + Send + Sync>;
49
50pub struct TeamRunner {
55 agents: HashMap<String, BoundAgent>,
56 sequence: Vec<String>,
57 max_iterations: usize,
58}
59
60impl TeamRunner {
61 #[must_use]
67 pub fn new(max_iterations: usize) -> Self {
68 assert!(max_iterations > 0, "max_iterations must be > 0");
69 Self {
70 agents: HashMap::new(),
71 sequence: Vec::new(),
72 max_iterations,
73 }
74 }
75
76 pub fn bind<F, Fut>(&mut self, role: impl Into<String>, agent: F)
80 where
81 F: Fn(String) -> Fut + Send + Sync + 'static,
82 Fut: Future<Output = Result<String>> + Send + 'static,
83 {
84 let name = role.into();
85 self.agents.insert(
86 name,
87 Arc::new(move |input: String| Box::pin(agent(input)) as BoxFuture<_>),
88 );
89 }
90
91 pub fn set_sequence(&mut self, roles: &[&str]) {
93 self.sequence = roles.iter().map(|&r| r.to_string()).collect();
94 }
95
96 #[must_use]
98 pub fn is_bound(&self, role: &str) -> bool {
99 self.agents.contains_key(role)
100 }
101
102 pub async fn run(&self, input: &str) -> Result<String> {
115 if self.sequence.is_empty() {
116 return Err(Error::Runtime("TeamRunner has no sequence defined".into()));
117 }
118
119 let mut current_input = input.to_string();
120 let mut iterations = 0;
121
122 for role in &self.sequence {
123 if iterations >= self.max_iterations {
124 return Err(Error::Runtime(format!(
125 "TeamRunner exceeded max_iterations ({}) at role '{}'",
126 self.max_iterations, role
127 )));
128 }
129
130 let agent = self.agents.get(role).ok_or_else(|| {
131 Error::Runtime(format!("Role '{}' not bound in TeamRunner", role))
132 })?;
133
134 current_input = agent(current_input).await?;
135 iterations += 1;
136 }
137
138 Ok(current_input)
139 }
140}
141
142pub async fn run_verification_chain<G, GFut, V, VFut>(
183 initial_input: &str,
184 max_retries: usize,
185 generator: G,
186 verifier: V,
187) -> Result<String>
188where
189 G: Fn(String) -> GFut,
190 GFut: Future<Output = Result<String>>,
191 V: Fn(String) -> VFut,
192 VFut: Future<Output = std::result::Result<String, String>>,
193{
194 let mut prompt = initial_input.to_string();
195 let mut last_output = String::new();
196
197 for attempt in 0..=max_retries {
198 let output = generator(prompt.clone()).await?;
199 last_output = output.clone();
200
201 match verifier(output.clone()).await {
202 Ok(accepted) => return Ok(accepted),
203 Err(feedback) => {
204 if attempt == max_retries {
205 return Err(Error::Runtime(format!(
207 "VerificationChain exhausted {max_retries} retries. Last output: {last_output}. Last feedback: {feedback}"
208 )));
209 }
210 prompt = format!("{initial_input}\n\nPrevious attempt:\n{output}\n\nFeedback: {feedback}\n\nPlease improve.");
212 }
213 }
214 }
215
216 Ok(last_output)
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Two bound agents run in order and the second sees the first's output.
    #[tokio::test]
    async fn test_sequential_team_two_agents() {
        let mut team = TeamRunner::new(10);
        team.set_sequence(&["researcher", "writer"]);
        team.bind("researcher", |input: String| async move {
            Ok(format!("Research: {input}"))
        });
        team.bind("writer", |input: String| async move {
            Ok(format!("Written: {input}"))
        });

        let result = team.run("AI history").await.unwrap();

        assert!(
            result.starts_with("Written: Research: AI history"),
            "got: {result}"
        );
    }

    /// A sequence longer than `max_iterations` fails with a descriptive error.
    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        let mut team = TeamRunner::new(1);
        for role in ["a", "b"] {
            team.bind(role, |i: String| async move { Ok(i) });
        }
        team.set_sequence(&["a", "b"]);

        let result = team.run("test").await;

        assert!(
            result.is_err(),
            "expected error for max_iterations exceeded"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("max_iterations"),
            "error should mention max_iterations: {msg}"
        );
    }

    /// Referencing a role that was never bound is an error.
    #[tokio::test]
    async fn test_unbound_role_returns_error() {
        let mut team = TeamRunner::new(10);
        team.set_sequence(&["missing_role"]);

        let outcome = team.run("test").await;

        assert!(outcome.is_err());
    }

    /// Running without a sequence is an error that says so.
    #[tokio::test]
    async fn test_empty_sequence_returns_error() {
        let team = TeamRunner::new(10);

        let outcome = team.run("test").await;

        assert!(outcome.is_err());
        let text = outcome.unwrap_err().to_string();
        assert!(text.contains("no sequence"));
    }

    /// The verifier rejects the first output and accepts the second.
    #[tokio::test]
    async fn test_verification_chain_accepts_on_second_try() {
        let generated = Arc::new(AtomicUsize::new(0));
        let verified = Arc::new(AtomicUsize::new(0));

        let generator = {
            let generated = Arc::clone(&generated);
            move |_input: String| {
                let counter = Arc::clone(&generated);
                async move {
                    let n = counter.fetch_add(1, Ordering::Relaxed);
                    Ok(format!("generated-attempt-{n}"))
                }
            }
        };
        let verifier = {
            let verified = Arc::clone(&verified);
            move |output: String| {
                let counter = Arc::clone(&verified);
                async move {
                    // Only the very first verification is rejected.
                    match counter.fetch_add(1, Ordering::Relaxed) {
                        0 => Err("Too brief".to_string()),
                        _ => Ok(output),
                    }
                }
            }
        };

        let result = run_verification_chain("Write something", 3, generator, verifier).await;

        assert!(
            result.is_ok(),
            "expected success on retry, got: {:?}",
            result
        );
        let output = result.unwrap();
        assert!(
            output.contains("attempt-1"),
            "expected 2nd attempt output, got: {output}"
        );
    }

    /// A verifier that never accepts exhausts the retry budget.
    #[tokio::test]
    async fn test_verification_chain_all_retries_exhausted() {
        let generator = |input: String| async move { Ok(format!("draft: {input}")) };
        let verifier = |output: String| async move { Err(format!("not good enough: {output}")) };

        let result = run_verification_chain("Write something", 2, generator, verifier).await;

        assert!(result.is_err(), "expected error when all retries exhausted");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("exhausted"),
            "error should mention exhausted: {msg}"
        );
    }

    /// A verifier that accepts straight away returns the first output.
    #[tokio::test]
    async fn test_verification_chain_accepts_immediately() {
        let result = run_verification_chain(
            "input",
            3,
            |_| async { Ok("perfect output".to_string()) },
            |output: String| async move { Ok(output) },
        )
        .await;

        assert_eq!(result.unwrap(), "perfect output");
    }

    /// The retry prompt handed to the generator carries the verifier feedback.
    #[tokio::test]
    async fn test_verification_chain_feedback_included_in_retry() {
        let got_feedback = Arc::new(AtomicBool::new(false));

        let generator = {
            let got_feedback = Arc::clone(&got_feedback);
            move |input: String| {
                let flag = Arc::clone(&got_feedback);
                async move {
                    if input.contains("Feedback:") {
                        flag.store(true, Ordering::Relaxed);
                    }
                    Ok(format!("output: {input}"))
                }
            }
        };

        // The chain result is irrelevant here; only the generator's input matters.
        let _ = run_verification_chain("Write", 1, generator, |_| async move {
            Err("needs work".to_string())
        })
        .await;

        assert!(
            got_feedback.load(Ordering::Relaxed),
            "retry prompt should contain 'Feedback:'"
        );
    }
}
393}