Skip to main content

traitclaw_team/
execution.rs

//! Team execution engine — [`TeamRunner::bind`] and [`TeamRunner::run`].
//!
//! Connects [`AgentRole`](crate::AgentRole)s to callable agent functions (bound
//! by role name) and executes them as a sequential pipeline in the order given
//! by [`TeamRunner::set_sequence`], feeding each agent's output to the next.
//! The module also provides [`run_verification_chain`], a generate–verify–retry
//! loop.
//!
//! # Design
//!
//! [`BoundAgent`] is an `Arc<dyn Fn(String) -> BoxFuture<'static, Result<String>> + Send + Sync>`
//! so that any callable async closure or agent can be bound without requiring
//! `traitclaw-core` to be imported by consumers — they just pass a closure.
//!
//! # Example
//!
//! ```rust
//! use traitclaw_team::execution::TeamRunner;
//!
//! # async fn example() -> traitclaw_core::Result<()> {
//! let mut runner = TeamRunner::new(2); // at most 2 agent invocations per run
//!
//! // Bind agents as async closures
//! runner.bind("researcher", |input: String| async move {
//!     Ok(format!("Research result for: {input}"))
//! });
//! runner.bind("writer", |input: String| async move {
//!     Ok(format!("Written summary of: {input}"))
//! });
//!
//! // Set sequential order
//! runner.set_sequence(&["researcher", "writer"]);
//!
//! let output = runner.run("Write a report on AI").await?;
//! assert!(output.contains("Written summary"));
//! # Ok(())
//! # }
//! ```

37use std::collections::HashMap;
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41
42use traitclaw_core::{Error, Result};
43
44/// Future alias for boxed async agent calls.
45pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
46
47/// Type-erased callable agent.
48pub type BoundAgent = Arc<dyn Fn(String) -> BoxFuture<'static, Result<String>> + Send + Sync>;
49
50/// Sequential team execution engine.
51///
52/// Binds agent roles to async callables and runs them in sequence,
53/// passing each agent's output as the next agent's input.
54pub struct TeamRunner {
55    agents: HashMap<String, BoundAgent>,
56    sequence: Vec<String>,
57    max_iterations: usize,
58}
59
60impl TeamRunner {
61    /// Create a new `TeamRunner` with the given max iteration limit.
62    ///
63    /// # Panics
64    ///
65    /// Panics if `max_iterations == 0`.
66    #[must_use]
67    pub fn new(max_iterations: usize) -> Self {
68        assert!(max_iterations > 0, "max_iterations must be > 0");
69        Self {
70            agents: HashMap::new(),
71            sequence: Vec::new(),
72            max_iterations,
73        }
74    }
75
76    /// Bind an agent closure to a role name.
77    ///
78    /// The closure receives the current input and returns the agent's response.
79    pub fn bind<F, Fut>(&mut self, role: impl Into<String>, agent: F)
80    where
81        F: Fn(String) -> Fut + Send + Sync + 'static,
82        Fut: Future<Output = Result<String>> + Send + 'static,
83    {
84        let name = role.into();
85        self.agents.insert(
86            name,
87            Arc::new(move |input: String| Box::pin(agent(input)) as BoxFuture<_>),
88        );
89    }
90
91    /// Set the execution sequence of role names.
92    pub fn set_sequence(&mut self, roles: &[&str]) {
93        self.sequence = roles.iter().map(|&r| r.to_string()).collect();
94    }
95
96    /// Check if a role has been bound.
97    #[must_use]
98    pub fn is_bound(&self, role: &str) -> bool {
99        self.agents.contains_key(role)
100    }
101
102    /// Execute the team pipeline sequentially.
103    ///
104    /// Each agent in the sequence receives the previous agent's output.
105    /// Returns the last agent's output as the final result.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if:
110    /// - A bound agent returns an error
111    /// - The sequence is empty
112    /// - A role in the sequence is not bound
113    /// - `max_iterations` is exceeded
114    pub async fn run(&self, input: &str) -> Result<String> {
115        if self.sequence.is_empty() {
116            return Err(Error::Runtime("TeamRunner has no sequence defined".into()));
117        }
118
119        let mut current_input = input.to_string();
120        let mut iterations = 0;
121
122        for role in &self.sequence {
123            if iterations >= self.max_iterations {
124                return Err(Error::Runtime(format!(
125                    "TeamRunner exceeded max_iterations ({}) at role '{}'",
126                    self.max_iterations, role
127                )));
128            }
129
130            let agent = self.agents.get(role).ok_or_else(|| {
131                Error::Runtime(format!("Role '{}' not bound in TeamRunner", role))
132            })?;
133
134            current_input = agent(current_input).await?;
135            iterations += 1;
136        }
137
138        Ok(current_input)
139    }
140}

// ─────────────────────────────────────────────────────────────────────────────
// VerificationChain execution
// ─────────────────────────────────────────────────────────────────────────────

146/// Execute a generate-verify-retry loop.
147///
148/// - Generator produces output from the current prompt
149/// - Verifier returns `Ok(output)` to accept, or `Err(feedback)` to reject
150/// - On rejection, the prompt is augmented with feedback and generation retried
151///
152/// # Example
153///
154/// ```rust
155/// use traitclaw_team::execution::run_verification_chain;
156///
157/// # async fn example() -> traitclaw_core::Result<()> {
158/// let result = run_verification_chain(
159///     "Write a haiku",
160///     3,
161///     |input: String| async move {
162///         // Pretend the first attempt fails
163///         if !input.contains("retry") {
164///             Ok(format!("Draft from: {input}"))
165///         } else {
166///             Ok(format!("Improved draft from: {input}"))
167///         }
168///     },
169///     |output: String| async move {
170///         // Accept on 2nd attempt (when output mentions "retry")
171///         if output.contains("retry") {
172///             Err(format!("Needs improvement: {output}"))
173///         } else {
174///             Ok(output) // Accept on first try here
175///         }
176///     },
177/// ).await;
178/// assert!(result.is_ok());
179/// # Ok(())
180/// # }
181/// ```
182pub async fn run_verification_chain<G, GFut, V, VFut>(
183    initial_input: &str,
184    max_retries: usize,
185    generator: G,
186    verifier: V,
187) -> Result<String>
188where
189    G: Fn(String) -> GFut,
190    GFut: Future<Output = Result<String>>,
191    V: Fn(String) -> VFut,
192    VFut: Future<Output = std::result::Result<String, String>>,
193{
194    let mut prompt = initial_input.to_string();
195    let mut last_output = String::new();
196
197    for attempt in 0..=max_retries {
198        let output = generator(prompt.clone()).await?;
199        last_output = output.clone();
200
201        match verifier(output.clone()).await {
202            Ok(accepted) => return Ok(accepted),
203            Err(feedback) => {
204                if attempt == max_retries {
205                    // All retries exhausted
206                    return Err(Error::Runtime(format!(
207                        "VerificationChain exhausted {max_retries} retries. Last output: {last_output}. Last feedback: {feedback}"
208                    )));
209                }
210                // Augment prompt with feedback for retry
211                prompt = format!("{initial_input}\n\nPrevious attempt:\n{output}\n\nFeedback: {feedback}\n\nPlease improve.");
212            }
213        }
214    }
215
216    // Unreachable but satisfies compiler
217    Ok(last_output)
218}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    // ── TeamRunner ────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_sequential_team_two_agents() {
        // AC #7: researcher → writer sequential pipeline
        let mut runner = TeamRunner::new(10);

        runner.bind("researcher", |input: String| async move {
            Ok(format!("Research: {input}"))
        });
        runner.bind("writer", |input: String| async move {
            Ok(format!("Written: {input}"))
        });
        runner.set_sequence(&["researcher", "writer"]);

        let result = runner.run("AI history").await.unwrap();
        // Each agent must see exactly the previous agent's output.
        assert_eq!(result, "Written: Research: AI history");
    }

    #[tokio::test]
    async fn test_max_iterations_exceeded() {
        // AC #8: max_iterations exceeded → error
        let b_calls = Arc::new(AtomicUsize::new(0));
        let b_calls_agent = Arc::clone(&b_calls);

        let mut runner = TeamRunner::new(1); // only 1 allowed
        runner.bind("a", |i: String| async move { Ok(i) });
        runner.bind("b", move |i: String| {
            let calls = Arc::clone(&b_calls_agent);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(i)
            }
        });
        runner.set_sequence(&["a", "b"]); // 2 agents, limit=1

        let result = runner.run("test").await;
        assert!(
            result.is_err(),
            "expected error for max_iterations exceeded"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("max_iterations"),
            "error should mention max_iterations: {msg}"
        );
        assert!(
            msg.contains("'b'"),
            "error should name the first role over the limit: {msg}"
        );
        // The role beyond the limit must never be invoked.
        assert_eq!(b_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_unbound_role_returns_error() {
        let mut runner = TeamRunner::new(10);
        runner.set_sequence(&["missing_role"]);

        let result = runner.run("test").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("missing_role"),
            "error should name the unbound role: {msg}"
        );
    }

    #[tokio::test]
    async fn test_empty_sequence_returns_error() {
        let runner = TeamRunner::new(10);
        let result = runner.run("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("no sequence"));
    }

    // ── VerificationChain ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_verification_chain_accepts_on_second_try() {
        // AC #7: verifier rejects the first attempt and accepts the second
        let generate_calls = Arc::new(AtomicUsize::new(0));
        let generate_calls_gen = Arc::clone(&generate_calls);
        let verify_calls = Arc::new(AtomicUsize::new(0));
        let verify_calls_ver = Arc::clone(&verify_calls);

        let result = run_verification_chain(
            "Write something",
            3,
            move |_input: String| {
                let calls = Arc::clone(&generate_calls_gen);
                async move {
                    let n = calls.fetch_add(1, Ordering::SeqCst);
                    Ok(format!("generated-attempt-{n}"))
                }
            },
            move |output: String| {
                let calls = Arc::clone(&verify_calls_ver);
                async move {
                    if calls.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err("Too brief".to_string())
                    } else {
                        Ok(output) // accept on 2nd verify call
                    }
                }
            },
        )
        .await;

        assert!(
            result.is_ok(),
            "expected success on retry, got: {:?}",
            result
        );
        // Second attempt's output should be returned
        assert_eq!(result.unwrap(), "generated-attempt-1");
        assert_eq!(generate_calls.load(Ordering::SeqCst), 2);
        assert_eq!(verify_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_verification_chain_all_retries_exhausted() {
        // AC #8: all retries exhausted → error with last output
        let generate_calls = Arc::new(AtomicUsize::new(0));
        let generate_calls_gen = Arc::clone(&generate_calls);

        let result = run_verification_chain(
            "Write something",
            2, // max 2 retries
            move |input: String| {
                let calls = Arc::clone(&generate_calls_gen);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(format!("draft: {input}"))
                }
            },
            |output: String| async move { Err(format!("not good enough: {output}")) },
        )
        .await;

        assert!(result.is_err(), "expected error when all retries exhausted");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("exhausted"),
            "error should mention exhausted: {msg}"
        );
        assert!(
            msg.contains("not good enough"),
            "error should carry the last feedback: {msg}"
        );
        // One initial attempt plus `max_retries` retries.
        assert_eq!(generate_calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_verification_chain_accepts_immediately() {
        let result = run_verification_chain(
            "input",
            3,
            |_| async { Ok("perfect output".to_string()) },
            |output: String| async move { Ok(output) },
        )
        .await;

        assert_eq!(result.unwrap(), "perfect output");
    }

    #[tokio::test]
    async fn test_verification_chain_feedback_included_in_retry() {
        // AC #7: on rejection, generator retries with feedback appended
        let saw_feedback = Arc::new(AtomicBool::new(false));
        let saw_feedback_gen = Arc::clone(&saw_feedback);

        let result = run_verification_chain(
            "Write",
            1,
            move |input: String| {
                let flag = Arc::clone(&saw_feedback_gen);
                async move {
                    if input.contains("Feedback: needs work") {
                        flag.store(true, Ordering::SeqCst);
                    }
                    Ok(format!("output: {input}"))
                }
            },
            |_| async move { Err("needs work".to_string()) },
        )
        .await;

        assert!(
            result.is_err(),
            "verifier never accepts, so the chain must fail"
        );
        assert!(
            saw_feedback.load(Ordering::SeqCst),
            "retry prompt should contain 'Feedback:'"
        );
    }
}
393}