Skip to main content

swink_agent_patterns/pipeline/
types.rs

1//! Core pipeline types: PipelineId, Pipeline, MergeStrategy, ExitCondition.
2
3use std::fmt;
4
5use regex::Regex;
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9// ─── PipelineId ─────────────────────────────────────────────────────────────
10
11/// Unique identifier for a pipeline definition.
12#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct PipelineId(String);
14
15impl PipelineId {
16    /// Create a pipeline ID from a string.
17    pub fn new(id: impl Into<String>) -> Self {
18        Self(id.into())
19    }
20
21    /// Generate a unique pipeline ID using UUID v4.
22    pub fn generate() -> Self {
23        Self(Uuid::new_v4().to_string())
24    }
25}
26
27impl fmt::Display for PipelineId {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.write_str(&self.0)
30    }
31}
32
33// ─── MergeStrategy ──────────────────────────────────────────────────────────
34
35/// Controls how parallel branch outputs are combined.
36#[derive(Clone, Debug, Serialize, Deserialize)]
37pub enum MergeStrategy {
38    /// Join all outputs in declaration order with a separator.
39    Concat { separator: String },
40    /// Return the first branch to complete.
41    First,
42    /// Return the first N branches to complete.
43    Fastest { n: usize },
44    /// Pass all outputs to a named aggregator agent.
45    Custom { aggregator: String },
46}
47
48// ─── ExitCondition ──────────────────────────────────────────────────────────
49
50/// Controls when a loop pipeline terminates.
51#[derive(Clone, Debug)]
52pub enum ExitCondition {
53    /// Exit when the body agent invokes the named tool.
54    ToolCalled { tool_name: String },
55    /// Exit when the output matches the regex pattern.
56    OutputContains {
57        pattern: String,
58        #[allow(dead_code)]
59        compiled: Regex,
60    },
61    /// Always run to the max_iterations cap.
62    MaxIterations,
63}
64
65impl ExitCondition {
66    /// Create an `OutputContains` condition, eagerly validating the regex.
67    ///
68    /// Returns `Err` if the pattern is not a valid regex.
69    pub fn output_contains(pattern: impl Into<String>) -> Result<Self, String> {
70        let pattern = pattern.into();
71        let compiled =
72            Regex::new(&pattern).map_err(|e| format!("invalid regex '{pattern}': {e}"))?;
73        Ok(Self::OutputContains { pattern, compiled })
74    }
75}
76
77impl Serialize for ExitCondition {
78    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
79        #[derive(Serialize)]
80        #[serde(tag = "type")]
81        enum Helper<'a> {
82            ToolCalled { tool_name: &'a str },
83            OutputContains { pattern: &'a str },
84            MaxIterations,
85        }
86
87        match self {
88            Self::ToolCalled { tool_name } => {
89                Helper::ToolCalled { tool_name }.serialize(serializer)
90            }
91            Self::OutputContains { pattern, .. } => {
92                Helper::OutputContains { pattern }.serialize(serializer)
93            }
94            Self::MaxIterations => Helper::MaxIterations.serialize(serializer),
95        }
96    }
97}
98
99impl<'de> Deserialize<'de> for ExitCondition {
100    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
101        #[derive(Deserialize)]
102        #[serde(tag = "type")]
103        enum Helper {
104            ToolCalled { tool_name: String },
105            OutputContains { pattern: String },
106            MaxIterations,
107        }
108
109        let h = Helper::deserialize(deserializer)?;
110        match h {
111            Helper::ToolCalled { tool_name } => Ok(Self::ToolCalled { tool_name }),
112            Helper::OutputContains { pattern } => {
113                let compiled = Regex::new(&pattern).map_err(serde::de::Error::custom)?;
114                Ok(Self::OutputContains { pattern, compiled })
115            }
116            Helper::MaxIterations => Ok(Self::MaxIterations),
117        }
118    }
119}
120
121// ─── Pipeline ───────────────────────────────────────────────────────────────
122
123/// A pipeline definition describing how to compose multiple agents.
124#[derive(Clone, Debug, Serialize, Deserialize)]
125#[serde(tag = "type")]
126pub enum Pipeline {
127    /// Execute agents in declared order, passing output forward.
128    Sequential {
129        id: PipelineId,
130        name: String,
131        steps: Vec<String>,
132        pass_context: bool,
133    },
134    /// Execute agents concurrently and merge results.
135    Parallel {
136        id: PipelineId,
137        name: String,
138        branches: Vec<String>,
139        merge_strategy: MergeStrategy,
140    },
141    /// Execute an agent repeatedly until an exit condition is met.
142    Loop {
143        id: PipelineId,
144        name: String,
145        body: String,
146        exit_condition: ExitCondition,
147        max_iterations: usize,
148    },
149}
150
151impl Pipeline {
152    /// Create a sequential pipeline without context passing.
153    pub fn sequential(name: impl Into<String>, steps: Vec<String>) -> Self {
154        Self::Sequential {
155            id: PipelineId::generate(),
156            name: name.into(),
157            steps,
158            pass_context: false,
159        }
160    }
161
162    /// Create a sequential pipeline with context passing enabled.
163    pub fn sequential_with_context(name: impl Into<String>, steps: Vec<String>) -> Self {
164        Self::Sequential {
165            id: PipelineId::generate(),
166            name: name.into(),
167            steps,
168            pass_context: true,
169        }
170    }
171
172    /// Create a parallel pipeline.
173    pub fn parallel(
174        name: impl Into<String>,
175        branches: Vec<String>,
176        merge_strategy: MergeStrategy,
177    ) -> Self {
178        Self::Parallel {
179            id: PipelineId::generate(),
180            name: name.into(),
181            branches,
182            merge_strategy,
183        }
184    }
185
186    /// Create a loop pipeline.
187    pub fn loop_(
188        name: impl Into<String>,
189        body: impl Into<String>,
190        exit_condition: ExitCondition,
191    ) -> Self {
192        Self::Loop {
193            id: PipelineId::generate(),
194            name: name.into(),
195            body: body.into(),
196            exit_condition,
197            max_iterations: 10,
198        }
199    }
200
201    /// Create a loop pipeline with a custom max iterations cap.
202    pub fn loop_with_max(
203        name: impl Into<String>,
204        body: impl Into<String>,
205        exit_condition: ExitCondition,
206        max_iterations: usize,
207    ) -> Self {
208        Self::Loop {
209            id: PipelineId::generate(),
210            name: name.into(),
211            body: body.into(),
212            exit_condition,
213            max_iterations,
214        }
215    }
216
217    /// Override the auto-generated ID.
218    #[must_use]
219    pub fn with_id(mut self, id: PipelineId) -> Self {
220        match &mut self {
221            Self::Sequential { id: i, .. }
222            | Self::Parallel { id: i, .. }
223            | Self::Loop { id: i, .. } => *i = id,
224        }
225        self
226    }
227
228    /// Returns the pipeline's unique identifier.
229    pub fn id(&self) -> &PipelineId {
230        match self {
231            Self::Sequential { id, .. } | Self::Parallel { id, .. } | Self::Loop { id, .. } => id,
232        }
233    }
234
235    /// Returns the pipeline's human-readable name.
236    pub fn name(&self) -> &str {
237        match self {
238            Self::Sequential { name, .. }
239            | Self::Parallel { name, .. }
240            | Self::Loop { name, .. } => name,
241        }
242    }
243}
244
/// Unit tests for the pipeline types: ID behaviour, exit-condition
/// construction and (de)serialization, and the `Pipeline` constructors.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // T014: PipelineId tests

    #[test]
    fn pipeline_id_new_and_display() {
        let id = PipelineId::new("test-pipeline");
        assert_eq!(id.to_string(), "test-pipeline");
    }

    #[test]
    fn pipeline_id_generate_is_unique() {
        let a = PipelineId::generate();
        let b = PipelineId::generate();
        assert_ne!(a, b);
    }

    #[test]
    fn pipeline_id_equality_and_hashing() {
        let a = PipelineId::new("same");
        let b = PipelineId::new("same");
        assert_eq!(a, b);

        let mut set = HashSet::new();
        set.insert(a);
        assert!(set.contains(&b));
    }

    #[test]
    fn pipeline_id_serde_roundtrip() {
        let id = PipelineId::new("round-trip");
        let json = serde_json::to_string(&id).unwrap();
        let parsed: PipelineId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    // T015: ExitCondition tests

    #[test]
    fn exit_condition_output_contains_valid_regex() {
        let cond = ExitCondition::output_contains(r"\bDONE\b").unwrap();
        match cond {
            ExitCondition::OutputContains { pattern, compiled } => {
                assert_eq!(pattern, r"\bDONE\b");
                assert!(compiled.is_match("task DONE here"));
            }
            _ => panic!("expected OutputContains"),
        }
    }

    #[test]
    fn exit_condition_output_contains_invalid_regex() {
        let err = ExitCondition::output_contains("[invalid").unwrap_err();
        // The message must name the offending pattern so callers can report it.
        assert!(err.contains("[invalid"));
    }

    #[test]
    fn exit_condition_serde_roundtrip_recompiles() {
        let cond = ExitCondition::output_contains(r"done|finished").unwrap();
        let json = serde_json::to_string(&cond).unwrap();
        let parsed: ExitCondition = serde_json::from_str(&json).unwrap();
        match parsed {
            ExitCondition::OutputContains { pattern, compiled } => {
                assert_eq!(pattern, "done|finished");
                assert!(compiled.is_match("all done"));
            }
            _ => panic!("expected OutputContains"),
        }
    }

    #[test]
    fn exit_condition_deserialize_rejects_invalid_regex() {
        // Error path of the hand-written Deserialize impl.
        let json = r#"{"type":"OutputContains","pattern":"[invalid"}"#;
        assert!(serde_json::from_str::<ExitCondition>(json).is_err());
    }

    #[test]
    fn exit_condition_serde_roundtrip_tool_called() {
        let cond = ExitCondition::ToolCalled {
            tool_name: "finish".into(),
        };
        let json = serde_json::to_string(&cond).unwrap();
        let parsed: ExitCondition = serde_json::from_str(&json).unwrap();
        match parsed {
            ExitCondition::ToolCalled { tool_name } => assert_eq!(tool_name, "finish"),
            _ => panic!("expected ToolCalled"),
        }
    }

    #[test]
    fn exit_condition_serde_roundtrip_max_iterations() {
        let json = serde_json::to_string(&ExitCondition::MaxIterations).unwrap();
        let parsed: ExitCondition = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, ExitCondition::MaxIterations));
    }

    // T016: Pipeline constructor tests

    #[test]
    fn sequential_constructor() {
        let p = Pipeline::sequential("test", vec!["a".into(), "b".into()]);
        assert_eq!(p.name(), "test");
        match &p {
            Pipeline::Sequential {
                pass_context,
                steps,
                ..
            } => {
                assert!(!pass_context);
                assert_eq!(steps.len(), 2);
            }
            _ => panic!("expected Sequential"),
        }
    }

    #[test]
    fn sequential_with_context_enables_pass_context() {
        let p = Pipeline::sequential_with_context("ctx", vec!["a".into()]);
        assert_eq!(p.name(), "ctx");
        match &p {
            Pipeline::Sequential {
                pass_context,
                steps,
                ..
            } => {
                assert!(*pass_context);
                assert_eq!(steps.len(), 1);
            }
            _ => panic!("expected Sequential"),
        }
    }

    #[test]
    fn parallel_constructor() {
        let p = Pipeline::parallel("par", vec!["x".into(), "y".into()], MergeStrategy::First);
        assert_eq!(p.name(), "par");
        assert!(matches!(p, Pipeline::Parallel { .. }));
    }

    #[test]
    fn loop_constructor() {
        let p = Pipeline::loop_("lp", "body-agent", ExitCondition::MaxIterations);
        assert_eq!(p.name(), "lp");
        match &p {
            Pipeline::Loop { max_iterations, .. } => assert_eq!(*max_iterations, 10),
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn loop_with_max_uses_custom_cap() {
        let p = Pipeline::loop_with_max("lp", "body-agent", ExitCondition::MaxIterations, 3);
        match &p {
            Pipeline::Loop {
                body,
                max_iterations,
                ..
            } => {
                assert_eq!(body, "body-agent");
                assert_eq!(*max_iterations, 3);
            }
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn with_id_overrides_generated_id() {
        let custom = PipelineId::new("custom-id");
        let p = Pipeline::sequential("s", vec![]).with_id(custom.clone());
        assert_eq!(*p.id(), custom);
    }

    #[test]
    fn with_id_applies_to_parallel_and_loop() {
        let custom = PipelineId::new("custom-id");

        let par = Pipeline::parallel("p", vec![], MergeStrategy::First).with_id(custom.clone());
        assert_eq!(*par.id(), custom);

        let lp = Pipeline::loop_("l", "body", ExitCondition::MaxIterations).with_id(custom.clone());
        assert_eq!(*lp.id(), custom);
    }

    #[test]
    fn auto_generated_ids_are_unique() {
        let a = Pipeline::sequential("a", vec![]);
        let b = Pipeline::sequential("b", vec![]);
        assert_ne!(a.id(), b.id());
    }

    #[test]
    fn pipeline_serde_roundtrip_preserves_loop_fields() {
        let original =
            Pipeline::loop_with_max("lp", "body-agent", ExitCondition::MaxIterations, 3)
                .with_id(PipelineId::new("loop-1"));
        let json = serde_json::to_string(&original).unwrap();
        let parsed: Pipeline = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id(), original.id());
        assert_eq!(parsed.name(), "lp");
        match parsed {
            Pipeline::Loop {
                body,
                exit_condition,
                max_iterations,
                ..
            } => {
                assert_eq!(body, "body-agent");
                assert!(matches!(exit_condition, ExitCondition::MaxIterations));
                assert_eq!(max_iterations, 3);
            }
            _ => panic!("expected Loop"),
        }
    }
}