Skip to main content

swink_agent_patterns/pipeline/
executor.rs

//! Pipeline executor, the `AgentFactory` trait, and a simple name-keyed agent factory.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use swink_agent::{Agent, AgentMessage, AgentResult, ContentBlock, LlmMessage};
7use tokio_util::sync::CancellationToken;
8
9use super::events::PipelineEvent;
10use super::output::{PipelineError, PipelineOutput, StepResult};
11use super::registry::PipelineRegistry;
12use super::types::{Pipeline, PipelineId};
13
14// ─── AgentFactory ───────────────────────────────────────────────────────────
15
16/// Trait for creating agents by name during pipeline execution.
17pub trait AgentFactory: Send + Sync {
18    /// Create an agent with the given name.
19    fn create(&self, name: &str) -> Result<Agent, PipelineError>;
20}
21
22// ─── SimpleAgentFactory ─────────────────────────────────────────────────────
23
24/// A basic agent factory backed by a name → builder-fn registry.
25pub struct SimpleAgentFactory {
26    builders: HashMap<String, Arc<dyn Fn() -> Agent + Send + Sync>>,
27}
28
29impl SimpleAgentFactory {
30    /// Create an empty factory.
31    pub fn new() -> Self {
32        Self {
33            builders: HashMap::new(),
34        }
35    }
36
37    /// Register a builder function for the given agent name.
38    pub fn register(
39        &mut self,
40        name: impl Into<String>,
41        builder: impl Fn() -> Agent + Send + Sync + 'static,
42    ) {
43        self.builders.insert(name.into(), Arc::new(builder));
44    }
45}
46
47impl Default for SimpleAgentFactory {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl AgentFactory for SimpleAgentFactory {
54    fn create(&self, name: &str) -> Result<Agent, PipelineError> {
55        let builder = self
56            .builders
57            .get(name)
58            .ok_or_else(|| PipelineError::AgentNotFound {
59                name: name.to_owned(),
60            })?;
61        Ok(builder())
62    }
63}
64
65// ─── PipelineExecutor ───────────────────────────────────────────────────────
66
67/// Orchestrates pipeline execution using an agent factory and registry.
68pub struct PipelineExecutor {
69    factory: Arc<dyn AgentFactory>,
70    registry: Arc<PipelineRegistry>,
71    event_handler: Option<Arc<dyn Fn(PipelineEvent) + Send + Sync>>,
72}
73
74impl PipelineExecutor {
75    /// Create a new executor with the given factory and registry.
76    pub fn new(factory: Arc<dyn AgentFactory>, registry: Arc<PipelineRegistry>) -> Self {
77        Self {
78            factory,
79            registry,
80            event_handler: None,
81        }
82    }
83
84    /// Set an event handler that receives pipeline lifecycle events.
85    #[must_use]
86    pub fn with_event_handler(
87        mut self,
88        handler: impl Fn(PipelineEvent) + Send + Sync + 'static,
89    ) -> Self {
90        self.event_handler = Some(Arc::new(handler));
91        self
92    }
93
94    /// Emit a pipeline event to the handler (if set).
95    fn emit(&self, event: PipelineEvent) {
96        if let Some(handler) = &self.event_handler {
97            handler(event);
98        }
99    }
100
101    /// Run a pipeline by ID.
102    pub async fn run(
103        &self,
104        pipeline_id: &PipelineId,
105        input: String,
106        cancellation_token: CancellationToken,
107    ) -> Result<PipelineOutput, PipelineError> {
108        let pipeline = match self.registry.get(pipeline_id) {
109            Some(pipeline) => pipeline,
110            None => {
111                let err = PipelineError::PipelineNotFound {
112                    id: pipeline_id.clone(),
113                };
114                self.emit(PipelineEvent::Failed {
115                    pipeline_id: pipeline_id.clone(),
116                    error_message: err.to_string(),
117                });
118                return Err(err);
119            }
120        };
121
122        let result = match pipeline {
123            Pipeline::Sequential {
124                id,
125                name,
126                steps,
127                pass_context,
128            } => {
129                self.run_sequential(id, name, steps, pass_context, input, cancellation_token)
130                    .await
131            }
132            Pipeline::Parallel {
133                id,
134                name,
135                branches,
136                merge_strategy,
137            } => {
138                super::parallel::run_parallel(
139                    &self.factory,
140                    &self.event_handler,
141                    id,
142                    name,
143                    branches,
144                    merge_strategy,
145                    input,
146                    cancellation_token,
147                )
148                .await
149            }
150            Pipeline::Loop {
151                id,
152                name,
153                body,
154                exit_condition,
155                max_iterations,
156            } => {
157                super::loop_exec::run_loop(
158                    &self.factory,
159                    &self.event_handler,
160                    id,
161                    name,
162                    body,
163                    exit_condition,
164                    max_iterations,
165                    input,
166                    cancellation_token,
167                )
168                .await
169            }
170        };
171
172        if let Err(err) = &result {
173            self.emit(PipelineEvent::Failed {
174                pipeline_id: pipeline_id.clone(),
175                error_message: err.to_string(),
176            });
177        }
178
179        result
180    }
181
182    async fn run_sequential(
183        &self,
184        id: PipelineId,
185        name: String,
186        steps: Vec<String>,
187        pass_context: bool,
188        input: String,
189        cancellation_token: CancellationToken,
190    ) -> Result<PipelineOutput, PipelineError> {
191        let start = std::time::Instant::now();
192        let mut step_results = Vec::new();
193        let mut current_input = input;
194        let mut total_usage = swink_agent::Usage::default();
195        // Accumulated message history for pass_context mode.
196        let mut context_messages: Vec<LlmMessage> = Vec::new();
197
198        self.emit(PipelineEvent::Started {
199            pipeline_id: id.clone(),
200            pipeline_name: name.clone(),
201        });
202
203        for (index, agent_name) in steps.iter().enumerate() {
204            if cancellation_token.is_cancelled() {
205                return Err(PipelineError::Cancelled);
206            }
207
208            self.emit(PipelineEvent::StepStarted {
209                pipeline_id: id.clone(),
210                step_index: index,
211                agent_name: agent_name.clone(),
212            });
213
214            let step_start = std::time::Instant::now();
215            let mut agent = self.factory.create(agent_name)?;
216
217            // Build input messages: either accumulated context or just the current input.
218            let messages = if pass_context && !context_messages.is_empty() {
219                let mut msgs: Vec<AgentMessage> = context_messages
220                    .iter()
221                    .map(|llm| AgentMessage::Llm(llm.clone()))
222                    .collect();
223                msgs.push(user_msg(&current_input));
224                msgs
225            } else {
226                vec![user_msg(&current_input)]
227            };
228
229            let result =
230                agent
231                    .prompt_async(messages)
232                    .await
233                    .map_err(|e| PipelineError::StepFailed {
234                        step_index: index,
235                        agent_name: agent_name.clone(),
236                        source: Box::new(e),
237                    })?;
238
239            let response = extract_text_response(&result);
240            let step_duration = step_start.elapsed();
241
242            total_usage += result.usage.clone();
243
244            self.emit(PipelineEvent::StepCompleted {
245                pipeline_id: id.clone(),
246                step_index: index,
247                agent_name: agent_name.clone(),
248                duration: step_duration,
249                usage: result.usage.clone(),
250            });
251
252            step_results.push(StepResult {
253                agent_name: agent_name.clone(),
254                response: response.clone(),
255                duration: step_duration,
256                usage: result.usage.clone(),
257            });
258
259            // In pass_context mode, accumulate the user message and assistant response.
260            if pass_context {
261                // Push the user message as an LlmMessage
262                context_messages.push(LlmMessage::User(swink_agent::UserMessage {
263                    content: vec![ContentBlock::Text {
264                        text: current_input.clone(),
265                    }],
266                    timestamp: 0,
267                    cache_hint: None,
268                }));
269                // Add the assistant messages from the result.
270                for msg in &result.messages {
271                    if let AgentMessage::Llm(llm @ LlmMessage::Assistant(_)) = msg {
272                        context_messages.push(llm.clone());
273                    }
274                }
275            }
276
277            current_input = response;
278        }
279
280        let total_duration = start.elapsed();
281        let final_response = step_results
282            .last()
283            .map(|s| s.response.clone())
284            .unwrap_or_default();
285
286        self.emit(PipelineEvent::Completed {
287            pipeline_id: id.clone(),
288            total_duration,
289            total_usage: total_usage.clone(),
290        });
291
292        Ok(PipelineOutput {
293            pipeline_id: id,
294            final_response,
295            steps: step_results,
296            total_duration,
297            total_usage,
298        })
299    }
300}
301
302/// Build a user message from text (local helper to avoid testkit dependency).
303fn user_msg(text: &str) -> AgentMessage {
304    AgentMessage::Llm(LlmMessage::User(swink_agent::UserMessage {
305        content: vec![ContentBlock::Text {
306            text: text.to_string(),
307        }],
308        timestamp: 0,
309        cache_hint: None,
310    }))
311}
312
313/// Extract concatenated text content from an agent result's last assistant message.
314fn extract_text_response(result: &AgentResult) -> String {
315    result
316        .messages
317        .iter()
318        .rev()
319        .find_map(|m| match m {
320            AgentMessage::Llm(LlmMessage::Assistant(msg)) => Some(msg),
321            _ => None,
322        })
323        .map(|msg| {
324            msg.content
325                .iter()
326                .filter_map(|b| match b {
327                    ContentBlock::Text { text } => Some(text.as_str()),
328                    _ => None,
329                })
330                .collect::<Vec<_>>()
331                .join("")
332        })
333        .unwrap_or_default()
334}
335
#[cfg(all(test, feature = "testkit"))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use swink_agent::AgentOptions;
    use swink_agent::testing::{MockStreamFn, default_convert, default_model, text_only_events};

    /// Agent whose mock stream has no scripted responses.
    fn make_agent() -> Agent {
        let options = AgentOptions::new(
            "test",
            default_model(),
            Arc::new(MockStreamFn::new(vec![])),
            default_convert,
        );
        Agent::new(options)
    }

    /// Agent whose mock stream replies once with `text`.
    fn make_text_agent(text: &str) -> Agent {
        let events = text_only_events(text);
        let options = AgentOptions::new(
            "test",
            default_model(),
            Arc::new(MockStreamFn::new(vec![events])),
            default_convert,
        );
        Agent::new(options)
    }

    // T017: SimpleAgentFactory tests

    #[test]
    fn factory_create_registered_agent_succeeds() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("test-agent", make_agent);

        let result = factory.create("test-agent");
        assert!(result.is_ok());
    }

    #[test]
    fn factory_create_unknown_returns_agent_not_found() {
        let factory = SimpleAgentFactory::new();

        let result = factory.create("nonexistent");
        assert!(matches!(
            result,
            Err(PipelineError::AgentNotFound { name }) if name == "nonexistent"
        ));
    }

    // T020-T024: Sequential pipeline tests

    fn build_executor(factory: SimpleAgentFactory, registry: PipelineRegistry) -> PipelineExecutor {
        PipelineExecutor::new(Arc::new(factory), Arc::new(registry))
    }

    #[tokio::test]
    async fn sequential_two_step_pipeline() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("hello"));
        factory.register("agent-b", || make_text_agent("world"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("two-step", vec!["agent-a".into(), "agent-b".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let output = executor.run(&id, "input".into(), token).await.unwrap();
        assert_eq!(output.final_response, "world");
        assert_eq!(output.steps.len(), 2);
        assert_eq!(output.steps[0].agent_name, "agent-a");
        assert_eq!(output.steps[0].response, "hello");
        assert_eq!(output.steps[1].agent_name, "agent-b");
        assert_eq!(output.steps[1].response, "world");
    }

    #[tokio::test]
    async fn sequential_missing_step_agent_halts_with_error() {
        // agent-b is not registered in the factory, causing AgentNotFound.
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("step-one"));
        // agent-b intentionally not registered
        factory.register("agent-c", || make_text_agent("step-three"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential(
            "three-step",
            vec!["agent-a".into(), "agent-b".into(), "agent-c".into()],
        );
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(result.is_err(), "expected error when step agent not found");
        assert!(
            matches!(result.unwrap_err(), PipelineError::AgentNotFound { name } if name == "agent-b"),
            "expected AgentNotFound for agent-b"
        );
    }

    #[tokio::test]
    async fn sequential_missing_agent_returns_agent_not_found() {
        let factory = SimpleAgentFactory::new(); // no agents registered

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("missing", vec!["ghost".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(matches!(
            result,
            Err(PipelineError::AgentNotFound { name }) if name == "ghost"
        ));
    }

    #[tokio::test]
    async fn sequential_zero_steps_returns_empty() {
        let factory = SimpleAgentFactory::new();

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("empty", vec![]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let output = executor.run(&id, "input".into(), token).await.unwrap();
        assert!(output.steps.is_empty());
        assert!(output.final_response.is_empty());
    }

    #[tokio::test]
    async fn sequential_cancelled_before_start_returns_cancelled() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("never"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("cancelled", vec!["agent-a".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();
        token.cancel();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(matches!(result, Err(PipelineError::Cancelled)));
    }

    #[tokio::test]
    async fn run_unknown_pipeline_returns_not_found() {
        let factory = SimpleAgentFactory::new();
        let registry = PipelineRegistry::new();

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();
        let unknown_id = PipelineId::new("nonexistent");

        let result = executor.run(&unknown_id, "input".into(), token).await;
        assert!(matches!(
            result,
            Err(PipelineError::PipelineNotFound { id }) if id == unknown_id
        ));
    }

    #[tokio::test]
    async fn run_unknown_pipeline_emits_failed_event() {
        let failures = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&failures);
        let executor = build_executor(SimpleAgentFactory::new(), PipelineRegistry::new())
            .with_event_handler(move |event| {
                if matches!(event, PipelineEvent::Failed { .. }) {
                    seen.fetch_add(1, Ordering::SeqCst);
                }
            });

        let result = executor
            .run(
                &PipelineId::new("nonexistent"),
                "input".into(),
                CancellationToken::new(),
            )
            .await;
        assert!(result.is_err());
        assert_eq!(failures.load(Ordering::SeqCst), 1);
    }
}