Skip to main content

traitclaw_core/
pool.rs

//! Agent pool for managing and executing groups of agents.
//!
//! `AgentPool` holds an ordered collection of agents and provides
//! sequential pipeline execution, where each agent's text output becomes
//! the next agent's input (output chaining).
5
6use crate::agent::Agent;
7use crate::agent::AgentOutput;
8use crate::Result;
9
10/// A collection of agents for group execution.
11///
12/// `AgentPool` takes ownership of a `Vec<Agent>` and provides
13/// sequential pipeline execution where each agent's output feeds
14/// into the next agent's input.
15///
16/// # Example
17///
18/// ```rust,no_run
19/// use traitclaw_core::pool::AgentPool;
20/// use traitclaw_core::agent::Agent;
21///
22/// # fn example(agents: Vec<Agent>) {
23/// let pool = AgentPool::new(agents);
24/// assert_eq!(pool.len(), 3);
25/// # }
26/// ```
27pub struct AgentPool {
28    agents: Vec<Agent>,
29}
30
31impl std::fmt::Debug for AgentPool {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("AgentPool")
34            .field("len", &self.agents.len())
35            .finish()
36    }
37}
38
39impl AgentPool {
40    /// Create a new pool from a vector of agents.
41    #[must_use]
42    pub fn new(agents: Vec<Agent>) -> Self {
43        Self { agents }
44    }
45
46    /// Returns the number of agents in the pool.
47    #[must_use]
48    pub fn len(&self) -> usize {
49        self.agents.len()
50    }
51
52    /// Returns `true` if the pool contains no agents.
53    #[must_use]
54    pub fn is_empty(&self) -> bool {
55        self.agents.is_empty()
56    }
57
58    /// Get a reference to an agent by index.
59    ///
60    /// Returns `None` if the index is out of bounds.
61    #[must_use]
62    pub fn get(&self, index: usize) -> Option<&Agent> {
63        self.agents.get(index)
64    }
65
66    /// Run agents sequentially, chaining outputs.
67    ///
68    /// Each agent receives the previous agent's text output as input.
69    /// The first agent receives the provided `input` string.
70    ///
71    /// # Example
72    ///
73    /// ```rust,no_run
74    /// use traitclaw_core::pool::AgentPool;
75    /// use traitclaw_core::agent::Agent;
76    ///
77    /// # async fn example(pool: &AgentPool) -> traitclaw_core::Result<()> {
78    /// let output = pool.run_sequential("Research Rust async patterns").await?;
79    /// println!("Final output: {}", output.text());
80    /// # Ok(())
81    /// # }
82    /// ```
83    ///
84    /// # Errors
85    ///
86    /// Returns an error immediately if any agent in the pipeline fails.
87    /// Earlier agents' outputs are not available on error.
88    pub async fn run_sequential(&self, input: &str) -> Result<AgentOutput> {
89        if self.agents.is_empty() {
90            return Err(crate::Error::Runtime(
91                "AgentPool::run_sequential called on empty pool".into(),
92            ));
93        }
94
95        let mut current_input = input.to_string();
96        let mut last_output: Option<AgentOutput> = None;
97
98        for agent in &self.agents {
99            let output = agent.run(&current_input).await?;
100            current_input = output.text().to_string();
101            last_output = Some(output);
102        }
103
104        // SAFETY: We checked is_empty above, so last_output is always Some
105        Ok(last_output.expect("pool is non-empty"))
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::provider::Provider;
    use crate::types::completion::{CompletionRequest, CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Model metadata shared by the test providers.
    fn test_model_info() -> ModelInfo {
        ModelInfo::new("echo", ModelTier::Small, 4_096, false, false, false)
    }

    /// Provider that echoes the last user message back with a `[prefix]` tag
    /// and counts how many completions it served.
    struct EchoProvider {
        info: ModelInfo,
        prefix: String,
        call_count: Arc<AtomicUsize>,
    }

    impl EchoProvider {
        fn new(prefix: &str) -> Self {
            Self::with_counter(prefix, Arc::new(AtomicUsize::new(0)))
        }

        /// Like `new`, but shares `call_count` with the caller so a test can
        /// still observe calls after the provider is moved into an `Agent`.
        fn with_counter(prefix: &str, call_count: Arc<AtomicUsize>) -> Self {
            Self {
                info: test_model_info(),
                prefix: prefix.to_string(),
                call_count,
            }
        }
    }

    #[async_trait]
    impl Provider for EchoProvider {
        async fn complete(&self, req: CompletionRequest) -> crate::Result<CompletionResponse> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // Echo back the last user message with our prefix
            let last_msg = req
                .messages
                .iter()
                .rev()
                .find(|m| m.role == crate::types::message::MessageRole::User)
                .map(|m| m.content.clone())
                .unwrap_or_default();
            Ok(CompletionResponse {
                content: ResponseContent::Text(format!("[{}] {}", self.prefix, last_msg)),
                usage: Usage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                },
            })
        }
        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }
        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    /// Provider whose completions always fail; used to test error propagation.
    struct FailingProvider {
        info: ModelInfo,
    }

    #[async_trait]
    impl Provider for FailingProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Err(crate::Error::Runtime("provider failure".into()))
        }
        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }
        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[test]
    fn test_pool_new_and_len() {
        let agents = vec![
            Agent::with_system(EchoProvider::new("A"), "Agent A"),
            Agent::with_system(EchoProvider::new("B"), "Agent B"),
        ];
        let pool = AgentPool::new(agents);
        assert_eq!(pool.len(), 2);
        assert!(!pool.is_empty());
    }

    #[test]
    fn test_pool_empty() {
        let pool = AgentPool::new(vec![]);
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_pool_get() {
        let agents = vec![
            Agent::with_system(EchoProvider::new("A"), "Agent A"),
            Agent::with_system(EchoProvider::new("B"), "Agent B"),
            Agent::with_system(EchoProvider::new("C"), "Agent C"),
        ];
        let pool = AgentPool::new(agents);
        assert!(pool.get(0).is_some());
        assert!(pool.get(1).is_some());
        assert!(pool.get(2).is_some());
        assert!(pool.get(5).is_none());
    }

    #[tokio::test]
    async fn test_pool_run_sequential_single_agent() {
        let agents = vec![Agent::with_system(EchoProvider::new("Solo"), "Solo agent")];
        let pool = AgentPool::new(agents);
        let output = pool.run_sequential("Hello").await.unwrap();
        assert_eq!(output.text(), "[Solo] Hello");
    }

    #[tokio::test]
    async fn test_pool_run_sequential_pipeline() {
        let researcher_calls = Arc::new(AtomicUsize::new(0));
        let writer_calls = Arc::new(AtomicUsize::new(0));
        let agents = vec![
            Agent::with_system(
                EchoProvider::with_counter("R", Arc::clone(&researcher_calls)),
                "Researcher",
            ),
            Agent::with_system(
                EchoProvider::with_counter("W", Arc::clone(&writer_calls)),
                "Writer",
            ),
        ];
        let pool = AgentPool::new(agents);
        let output = pool.run_sequential("topic").await.unwrap();
        // First agent: "[R] topic" → Second agent: "[W] [R] topic"
        assert_eq!(output.text(), "[W] [R] topic");
        assert!(researcher_calls.load(Ordering::SeqCst) > 0);
        assert!(writer_calls.load(Ordering::SeqCst) > 0);
    }

    #[tokio::test]
    async fn test_pool_run_sequential_stops_at_first_error() {
        let before_calls = Arc::new(AtomicUsize::new(0));
        let after_calls = Arc::new(AtomicUsize::new(0));
        let agents = vec![
            Agent::with_system(
                EchoProvider::with_counter("A", Arc::clone(&before_calls)),
                "Before",
            ),
            Agent::with_system(
                FailingProvider {
                    info: test_model_info(),
                },
                "Failing",
            ),
            Agent::with_system(
                EchoProvider::with_counter("C", Arc::clone(&after_calls)),
                "After",
            ),
        ];
        let pool = AgentPool::new(agents);
        let result = pool.run_sequential("input").await;
        assert!(result.is_err());
        assert!(
            before_calls.load(Ordering::SeqCst) > 0,
            "agents before the failure should run"
        );
        assert_eq!(
            after_calls.load(Ordering::SeqCst),
            0,
            "agents after the failure must not run"
        );
    }

    #[tokio::test]
    async fn test_pool_run_sequential_empty_pool_errors() {
        let pool = AgentPool::new(vec![]);
        let result = pool.run_sequential("anything").await;
        assert!(matches!(result, Err(crate::Error::Runtime(_))));
    }

    #[test]
    fn test_pool_debug() {
        let pool = AgentPool::new(vec![Agent::with_system(EchoProvider::new("A"), "A")]);
        let debug = format!("{pool:?}");
        assert!(debug.contains("AgentPool"));
        assert!(debug.contains("len: 1"));
    }
}