Skip to main content

traitclaw_team/
lib.rs

1//! Multi-agent orchestration for the `TraitClaw` AI agent framework.
2//!
3//! Provides `Team` structs for composing agents, `Router` protocol for
4//! message routing, delegation between agents, and `VerificationChain`
5//! for generate-then-verify patterns.
6//!
7//! # Quick Start
8//!
9//! ```rust
10//! use traitclaw_team::{AgentRole, Team, VerificationChain, VerifyResult};
11//! use traitclaw_team::router::SequentialRouter;
12//!
13//! let team = Team::new("research_team")
14//!     .add_role(AgentRole::new("researcher", "Research topics in depth"))
15//!     .add_role(AgentRole::new("writer", "Write clear summaries"));
16//!
17//! assert_eq!(team.name(), "research_team");
18//! assert_eq!(team.roles().len(), 2);
19//! ```
20
21#![deny(missing_docs)]
22#![allow(clippy::redundant_closure)]
23
24pub mod conditional_router;
25pub mod execution;
26pub mod group_chat;
27pub mod router;
28pub mod team_context;
29
30#[cfg(test)]
31pub(crate) mod tests_common;
32
33use serde::{Deserialize, Serialize};
34use std::sync::Arc;
35use traitclaw_core::traits::provider::Provider;
36
37pub use conditional_router::ConditionalRouter;
38pub use execution::{run_verification_chain, TeamRunner};
39pub use team_context::TeamContext;
40
41/// Create an [`AgentPool`](traitclaw_core::pool::AgentPool) from a [`Team`] and a provider.
42///
43/// Each role's `system_prompt` is used as the agent's system prompt.
44/// Roles without a `system_prompt` cause an error listing all missing roles.
45///
46/// # Example
47///
48/// ```rust
49/// use traitclaw_team::{AgentRole, Team, pool_from_team};
50/// use traitclaw_core::traits::provider::Provider;
51///
52/// # fn example(provider: impl Provider) -> traitclaw_core::Result<()> {
53/// let team = Team::new("content_team")
54///     .add_role(AgentRole::new("researcher", "Research").with_system_prompt("You research topics."))
55///     .add_role(AgentRole::new("writer", "Write").with_system_prompt("You write articles."));
56///
57/// let pool = pool_from_team(&team, provider)?;
58/// assert_eq!(pool.len(), 2);
59/// # Ok(())
60/// # }
61/// ```
62///
63/// # Errors
64///
65/// Returns an error if any role in the team is missing a `system_prompt`.
66pub fn pool_from_team(
67    team: &Team,
68    provider: impl Provider,
69) -> traitclaw_core::Result<traitclaw_core::pool::AgentPool> {
70    // Check for missing system_prompts first
71    let missing: Vec<&str> = team
72        .roles()
73        .iter()
74        .filter(|r| r.system_prompt.is_none())
75        .map(|r| r.name.as_str())
76        .collect();
77
78    if !missing.is_empty() {
79        return Err(traitclaw_core::Error::Config(format!(
80            "Cannot create AgentPool from team '{}': roles missing system_prompt: [{}]",
81            team.name(),
82            missing.join(", ")
83        )));
84    }
85
86    let factory = traitclaw_core::factory::AgentFactory::new(provider);
87    let agents: Vec<traitclaw_core::Agent> = team
88        .roles()
89        .iter()
90        .map(|role| factory.spawn(role.system_prompt.as_ref().expect("checked above")))
91        .collect();
92
93    Ok(traitclaw_core::pool::AgentPool::new(agents))
94}
95
96/// Create an [`AgentPool`](traitclaw_core::pool::AgentPool) from a [`Team`]
97/// using a pre-wrapped `Arc<dyn Provider>`.
98///
99/// Same as [`pool_from_team`] but accepts a shared provider reference.
100pub fn pool_from_team_arc(
101    team: &Team,
102    provider: Arc<dyn Provider>,
103) -> traitclaw_core::Result<traitclaw_core::pool::AgentPool> {
104    let missing: Vec<&str> = team
105        .roles()
106        .iter()
107        .filter(|r| r.system_prompt.is_none())
108        .map(|r| r.name.as_str())
109        .collect();
110
111    if !missing.is_empty() {
112        return Err(traitclaw_core::Error::Config(format!(
113            "Cannot create AgentPool from team '{}': roles missing system_prompt: [{}]",
114            team.name(),
115            missing.join(", ")
116        )));
117    }
118
119    let factory = traitclaw_core::factory::AgentFactory::from_arc(provider);
120    let agents: Vec<traitclaw_core::Agent> = team
121        .roles()
122        .iter()
123        .map(|role| factory.spawn(role.system_prompt.as_ref().expect("checked above")))
124        .collect();
125
126    Ok(traitclaw_core::pool::AgentPool::new(agents))
127}
128
129/// A team of agents working together.
130pub struct Team {
131    name: String,
132    roles: Vec<AgentRole>,
133    router: Box<dyn router::Router>,
134}
135
136impl Team {
137    /// Create a new team with the given name.
138    #[must_use]
139    pub fn new(name: impl Into<String>) -> Self {
140        Self {
141            name: name.into(),
142            roles: Vec::new(),
143            router: Box::new(router::SequentialRouter::new()),
144        }
145    }
146
147    /// Add a role to the team.
148    #[must_use]
149    pub fn add_role(mut self, role: AgentRole) -> Self {
150        self.roles.push(role);
151        self
152    }
153
154    /// Get the team name.
155    #[must_use]
156    pub fn name(&self) -> &str {
157        &self.name
158    }
159
160    /// Get the team's roles.
161    #[must_use]
162    pub fn roles(&self) -> &[AgentRole] {
163        &self.roles
164    }
165
166    /// Find a role by name.
167    #[must_use]
168    pub fn find_role(&self, name: &str) -> Option<&AgentRole> {
169        self.roles.iter().find(|r| r.name == name)
170    }
171
172    /// Set a custom router for this team.
173    ///
174    /// Default: [`SequentialRouter`](router::SequentialRouter).
175    #[must_use]
176    pub fn with_router(mut self, router: impl router::Router) -> Self {
177        self.router = Box::new(router);
178        self
179    }
180
181    /// Get a reference to the team's router.
182    #[must_use]
183    pub fn router(&self) -> &dyn router::Router {
184        &*self.router
185    }
186}
187
188/// A role within a team — describes what an agent specializes in.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct AgentRole {
191    /// Role name (used for routing).
192    pub name: String,
193    /// Description of the role's responsibilities.
194    pub description: String,
195    /// Optional system prompt override for this role.
196    pub system_prompt: Option<String>,
197}
198
199impl AgentRole {
200    /// Create a new agent role.
201    #[must_use]
202    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
203        Self {
204            name: name.into(),
205            description: description.into(),
206            system_prompt: None,
207        }
208    }
209
210    /// Set a custom system prompt for this role.
211    #[must_use]
212    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
213        self.system_prompt = Some(prompt.into());
214        self
215    }
216}
217
/// A verification chain: generate with one agent, verify with another.
///
/// If verification fails, the generation is retried with feedback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerificationChain {
    /// Maximum number of generate-verify cycles.
    // NOTE(review): the doc says "cycles" but the name says "retries"; whether the
    // first attempt counts against this budget is presumably decided by
    // `run_verification_chain` — confirm and align the wording.
    pub max_retries: usize,
}

impl VerificationChain {
    /// Retry budget used by [`VerificationChain::new`] and [`Default::default`].
    pub const DEFAULT_MAX_RETRIES: usize = 3;

    /// Create a new verification chain with
    /// [`DEFAULT_MAX_RETRIES`](Self::DEFAULT_MAX_RETRIES) (3) retries.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: Self::DEFAULT_MAX_RETRIES,
        }
    }

    /// Set the maximum number of retries.
    #[must_use]
    pub fn with_max_retries(mut self, n: usize) -> Self {
        self.max_retries = n;
        self
    }
}

impl Default for VerificationChain {
    /// Equivalent to [`VerificationChain::new`].
    fn default() -> Self {
        Self::new()
    }
}
246
/// Result of a verification step.
///
/// Both variants carry text: the accepted output, or the feedback to use on
/// the next retry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyResult {
    /// Verification passed — output is acceptable.
    Accepted(String),
    /// Verification failed — include feedback for retry.
    Rejected(String),
}

impl VerifyResult {
    /// Check if the result was accepted.
    #[must_use]
    pub fn is_accepted(&self) -> bool {
        matches!(self, Self::Accepted(_))
    }

    /// Check if the result was rejected (always the inverse of
    /// [`is_accepted`](Self::is_accepted)).
    #[must_use]
    pub fn is_rejected(&self) -> bool {
        matches!(self, Self::Rejected(_))
    }

    /// Borrow the text carried by either variant: the accepted output or the
    /// rejection feedback.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            Self::Accepted(text) | Self::Rejected(text) => text,
        }
    }
}
263
#[cfg(test)]
mod tests {
    // Unit tests for the team builder, role builder, and verification types.
    use super::*;

    #[test]
    fn test_team_builder() {
        let team = Team::new("test_team")
            .add_role(AgentRole::new("role1", "First role"))
            .add_role(AgentRole::new("role2", "Second role"));

        let (name, role_count) = (team.name(), team.roles().len());
        assert_eq!(name, "test_team");
        assert_eq!(role_count, 2);
    }

    #[test]
    fn test_find_role() {
        let researcher = AgentRole::new("researcher", "Research");
        let team = Team::new("team").add_role(researcher);

        assert!(matches!(team.find_role("researcher"), Some(_)));
        assert!(matches!(team.find_role("unknown"), None));
    }

    #[test]
    fn test_agent_role_with_prompt() {
        let expected = "You are a technical writer";
        let role = AgentRole::new("writer", "Write docs").with_system_prompt(expected);
        assert_eq!(role.system_prompt.as_deref(), Some(expected));
    }

    #[test]
    fn test_verification_chain_default() {
        assert_eq!(VerificationChain::new().max_retries, 3);
    }

    #[test]
    fn test_verification_chain_custom_retries() {
        let retries = VerificationChain::new().with_max_retries(5).max_retries;
        assert_eq!(retries, 5);
    }

    #[test]
    fn test_verify_result() {
        let accepted = VerifyResult::Accepted(String::from("ok"));
        let rejected = VerifyResult::Rejected(String::from("bad"));
        assert!(accepted.is_accepted());
        assert!(!rejected.is_accepted());
    }
}