1use std::collections::HashMap;
4use std::sync::Arc;
5
6use swink_agent::{Agent, AgentMessage, AgentResult, ContentBlock, LlmMessage};
7use tokio_util::sync::CancellationToken;
8
9use super::events::PipelineEvent;
10use super::output::{PipelineError, PipelineOutput, StepResult};
11use super::registry::PipelineRegistry;
12use super::types::{Pipeline, PipelineId};
13
/// Produces [`Agent`] instances by name for pipeline steps.
///
/// The executor stores its factory as an `Arc<dyn AgentFactory>`, and this trait requires
/// `Send + Sync` of every implementor.
pub trait AgentFactory: Send + Sync {
    /// Builds the agent registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns a [`PipelineError`] when no agent can be produced for `name`.
    /// [`SimpleAgentFactory`] returns [`PipelineError::AgentNotFound`] for unknown names;
    /// other implementations may return other variants.
    fn create(&self, name: &str) -> Result<Agent, PipelineError>;
}
21
/// An [`AgentFactory`] backed by a map of named builder closures.
///
/// Builders are added with `register`. Every `create` call invokes the matching builder
/// again, so each call returns a separately built [`Agent`].
pub struct SimpleAgentFactory {
    // Builder closures keyed by agent name.
    builders: HashMap<String, Arc<dyn Fn() -> Agent + Send + Sync>>,
}
28
29impl SimpleAgentFactory {
30 pub fn new() -> Self {
32 Self {
33 builders: HashMap::new(),
34 }
35 }
36
37 pub fn register(
39 &mut self,
40 name: impl Into<String>,
41 builder: impl Fn() -> Agent + Send + Sync + 'static,
42 ) {
43 self.builders.insert(name.into(), Arc::new(builder));
44 }
45}
46
47impl Default for SimpleAgentFactory {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl AgentFactory for SimpleAgentFactory {
54 fn create(&self, name: &str) -> Result<Agent, PipelineError> {
55 let builder = self
56 .builders
57 .get(name)
58 .ok_or_else(|| PipelineError::AgentNotFound {
59 name: name.to_owned(),
60 })?;
61 Ok(builder())
62 }
63}
64
/// Executes pipelines looked up in a [`PipelineRegistry`], building each step's agent
/// through an [`AgentFactory`] and optionally reporting progress as [`PipelineEvent`]s.
pub struct PipelineExecutor {
    // Builds the agent for every step/branch by name.
    factory: Arc<dyn AgentFactory>,
    // Source of pipeline definitions, looked up by `PipelineId` in `run`.
    registry: Arc<PipelineRegistry>,
    // Observer for emitted events; `None` means events are dropped.
    event_handler: Option<Arc<dyn Fn(PipelineEvent) + Send + Sync>>,
}
73
74impl PipelineExecutor {
75 pub fn new(factory: Arc<dyn AgentFactory>, registry: Arc<PipelineRegistry>) -> Self {
77 Self {
78 factory,
79 registry,
80 event_handler: None,
81 }
82 }
83
84 #[must_use]
86 pub fn with_event_handler(
87 mut self,
88 handler: impl Fn(PipelineEvent) + Send + Sync + 'static,
89 ) -> Self {
90 self.event_handler = Some(Arc::new(handler));
91 self
92 }
93
94 fn emit(&self, event: PipelineEvent) {
96 if let Some(handler) = &self.event_handler {
97 handler(event);
98 }
99 }
100
101 pub async fn run(
103 &self,
104 pipeline_id: &PipelineId,
105 input: String,
106 cancellation_token: CancellationToken,
107 ) -> Result<PipelineOutput, PipelineError> {
108 let pipeline = match self.registry.get(pipeline_id) {
109 Some(pipeline) => pipeline,
110 None => {
111 let err = PipelineError::PipelineNotFound {
112 id: pipeline_id.clone(),
113 };
114 self.emit(PipelineEvent::Failed {
115 pipeline_id: pipeline_id.clone(),
116 error_message: err.to_string(),
117 });
118 return Err(err);
119 }
120 };
121
122 let result = match pipeline {
123 Pipeline::Sequential {
124 id,
125 name,
126 steps,
127 pass_context,
128 } => {
129 self.run_sequential(id, name, steps, pass_context, input, cancellation_token)
130 .await
131 }
132 Pipeline::Parallel {
133 id,
134 name,
135 branches,
136 merge_strategy,
137 } => {
138 super::parallel::run_parallel(
139 &self.factory,
140 &self.event_handler,
141 id,
142 name,
143 branches,
144 merge_strategy,
145 input,
146 cancellation_token,
147 )
148 .await
149 }
150 Pipeline::Loop {
151 id,
152 name,
153 body,
154 exit_condition,
155 max_iterations,
156 } => {
157 super::loop_exec::run_loop(
158 &self.factory,
159 &self.event_handler,
160 id,
161 name,
162 body,
163 exit_condition,
164 max_iterations,
165 input,
166 cancellation_token,
167 )
168 .await
169 }
170 };
171
172 if let Err(err) = &result {
173 self.emit(PipelineEvent::Failed {
174 pipeline_id: pipeline_id.clone(),
175 error_message: err.to_string(),
176 });
177 }
178
179 result
180 }
181
182 async fn run_sequential(
183 &self,
184 id: PipelineId,
185 name: String,
186 steps: Vec<String>,
187 pass_context: bool,
188 input: String,
189 cancellation_token: CancellationToken,
190 ) -> Result<PipelineOutput, PipelineError> {
191 let start = std::time::Instant::now();
192 let mut step_results = Vec::new();
193 let mut current_input = input;
194 let mut total_usage = swink_agent::Usage::default();
195 let mut context_messages: Vec<LlmMessage> = Vec::new();
197
198 self.emit(PipelineEvent::Started {
199 pipeline_id: id.clone(),
200 pipeline_name: name.clone(),
201 });
202
203 for (index, agent_name) in steps.iter().enumerate() {
204 if cancellation_token.is_cancelled() {
205 return Err(PipelineError::Cancelled);
206 }
207
208 self.emit(PipelineEvent::StepStarted {
209 pipeline_id: id.clone(),
210 step_index: index,
211 agent_name: agent_name.clone(),
212 });
213
214 let step_start = std::time::Instant::now();
215 let mut agent = self.factory.create(agent_name)?;
216
217 let messages = if pass_context && !context_messages.is_empty() {
219 let mut msgs: Vec<AgentMessage> = context_messages
220 .iter()
221 .map(|llm| AgentMessage::Llm(llm.clone()))
222 .collect();
223 msgs.push(user_msg(¤t_input));
224 msgs
225 } else {
226 vec![user_msg(¤t_input)]
227 };
228
229 let result =
230 agent
231 .prompt_async(messages)
232 .await
233 .map_err(|e| PipelineError::StepFailed {
234 step_index: index,
235 agent_name: agent_name.clone(),
236 source: Box::new(e),
237 })?;
238
239 let response = extract_text_response(&result);
240 let step_duration = step_start.elapsed();
241
242 total_usage += result.usage.clone();
243
244 self.emit(PipelineEvent::StepCompleted {
245 pipeline_id: id.clone(),
246 step_index: index,
247 agent_name: agent_name.clone(),
248 duration: step_duration,
249 usage: result.usage.clone(),
250 });
251
252 step_results.push(StepResult {
253 agent_name: agent_name.clone(),
254 response: response.clone(),
255 duration: step_duration,
256 usage: result.usage.clone(),
257 });
258
259 if pass_context {
261 context_messages.push(LlmMessage::User(swink_agent::UserMessage {
263 content: vec![ContentBlock::Text {
264 text: current_input.clone(),
265 }],
266 timestamp: 0,
267 cache_hint: None,
268 }));
269 for msg in &result.messages {
271 if let AgentMessage::Llm(llm @ LlmMessage::Assistant(_)) = msg {
272 context_messages.push(llm.clone());
273 }
274 }
275 }
276
277 current_input = response;
278 }
279
280 let total_duration = start.elapsed();
281 let final_response = step_results
282 .last()
283 .map(|s| s.response.clone())
284 .unwrap_or_default();
285
286 self.emit(PipelineEvent::Completed {
287 pipeline_id: id.clone(),
288 total_duration,
289 total_usage: total_usage.clone(),
290 });
291
292 Ok(PipelineOutput {
293 pipeline_id: id,
294 final_response,
295 steps: step_results,
296 total_duration,
297 total_usage,
298 })
299 }
300}
301
302fn user_msg(text: &str) -> AgentMessage {
304 AgentMessage::Llm(LlmMessage::User(swink_agent::UserMessage {
305 content: vec![ContentBlock::Text {
306 text: text.to_string(),
307 }],
308 timestamp: 0,
309 cache_hint: None,
310 }))
311}
312
313fn extract_text_response(result: &AgentResult) -> String {
315 result
316 .messages
317 .iter()
318 .rev()
319 .find_map(|m| match m {
320 AgentMessage::Llm(LlmMessage::Assistant(msg)) => Some(msg),
321 _ => None,
322 })
323 .map(|msg| {
324 msg.content
325 .iter()
326 .filter_map(|b| match b {
327 ContentBlock::Text { text } => Some(text.as_str()),
328 _ => None,
329 })
330 .collect::<Vec<_>>()
331 .join("")
332 })
333 .unwrap_or_default()
334}
335
#[cfg(all(test, feature = "testkit"))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use swink_agent::AgentOptions;
    use swink_agent::testing::{MockStreamFn, default_convert, default_model, text_only_events};

    /// Builds an agent whose mock stream has no scripted responses.
    fn make_agent() -> Agent {
        let options = AgentOptions::new(
            "test",
            default_model(),
            Arc::new(MockStreamFn::new(vec![])),
            default_convert,
        );
        Agent::new(options)
    }

    /// Builds an agent whose mock stream replies once with `text`.
    fn make_text_agent(text: &str) -> Agent {
        let events = text_only_events(text);
        let options = AgentOptions::new(
            "test",
            default_model(),
            Arc::new(MockStreamFn::new(vec![events])),
            default_convert,
        );
        Agent::new(options)
    }

    #[test]
    fn factory_create_registered_agent_succeeds() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("test-agent", make_agent);

        let result = factory.create("test-agent");
        assert!(result.is_ok());
    }

    #[test]
    fn factory_create_unknown_returns_agent_not_found() {
        let factory = SimpleAgentFactory::new();

        let result = factory.create("nonexistent");
        assert!(matches!(
            result,
            Err(PipelineError::AgentNotFound { name }) if name == "nonexistent"
        ));
    }

    /// Wraps `factory` and `registry` in `Arc`s and builds an executor with no event handler.
    fn build_executor(factory: SimpleAgentFactory, registry: PipelineRegistry) -> PipelineExecutor {
        PipelineExecutor::new(Arc::new(factory), Arc::new(registry))
    }

    #[tokio::test]
    async fn sequential_two_step_pipeline() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("hello"));
        factory.register("agent-b", || make_text_agent("world"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("two-step", vec!["agent-a".into(), "agent-b".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let output = executor.run(&id, "input".into(), token).await.unwrap();
        assert_eq!(output.final_response, "world");
        assert_eq!(output.steps.len(), 2);
        assert_eq!(output.steps[0].agent_name, "agent-a");
        assert_eq!(output.steps[0].response, "hello");
        assert_eq!(output.steps[1].agent_name, "agent-b");
        assert_eq!(output.steps[1].response, "world");
    }

    #[tokio::test]
    async fn sequential_missing_step_agent_halts_with_error() {
        // "agent-b" is deliberately left unregistered so the second step fails.
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("step-one"));
        factory.register("agent-c", || make_text_agent("step-three"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential(
            "three-step",
            vec!["agent-a".into(), "agent-b".into(), "agent-c".into()],
        );
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(result.is_err(), "expected error when step agent not found");
        assert!(
            matches!(result.unwrap_err(), PipelineError::AgentNotFound { name } if name == "agent-b"),
            "expected AgentNotFound for agent-b"
        );
    }

    #[tokio::test]
    async fn sequential_missing_agent_returns_agent_not_found() {
        let factory = SimpleAgentFactory::new();
        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("missing", vec!["ghost".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(matches!(
            result,
            Err(PipelineError::AgentNotFound { name }) if name == "ghost"
        ));
    }

    #[tokio::test]
    async fn sequential_zero_steps_returns_empty() {
        let factory = SimpleAgentFactory::new();

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("empty", vec![]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();

        let output = executor.run(&id, "input".into(), token).await.unwrap();
        assert!(output.steps.is_empty());
        assert!(output.final_response.is_empty());
    }

    #[tokio::test]
    async fn sequential_cancelled_before_first_step_returns_cancelled() {
        let mut factory = SimpleAgentFactory::new();
        factory.register("agent-a", || make_text_agent("unused"));

        let registry = PipelineRegistry::new();
        let pipeline = Pipeline::sequential("cancelled", vec!["agent-a".into()]);
        let id = pipeline.id().clone();
        registry.register(pipeline);

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();
        // Cancel up front: the executor checks the token before starting each step.
        token.cancel();

        let result = executor.run(&id, "input".into(), token).await;
        assert!(matches!(result, Err(PipelineError::Cancelled)));
    }

    #[tokio::test]
    async fn run_unknown_pipeline_returns_not_found() {
        let factory = SimpleAgentFactory::new();
        let registry = PipelineRegistry::new();

        let executor = build_executor(factory, registry);
        let token = CancellationToken::new();
        let unknown_id = PipelineId::new("nonexistent");

        let result = executor.run(&unknown_id, "input".into(), token).await;
        assert!(matches!(
            result,
            Err(PipelineError::PipelineNotFound { id }) if id == unknown_id
        ));
    }

    #[tokio::test]
    async fn run_failure_emits_single_failed_event() {
        let failed_events = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&failed_events);

        let executor = build_executor(SimpleAgentFactory::new(), PipelineRegistry::new())
            .with_event_handler(move |event| {
                if matches!(event, PipelineEvent::Failed { .. }) {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            });

        let result = executor
            .run(
                &PipelineId::new("nonexistent"),
                "input".into(),
                CancellationToken::new(),
            )
            .await;
        assert!(result.is_err());
        assert_eq!(failed_events.load(Ordering::SeqCst), 1);
    }
}