Skip to main content

zeph_core/pipeline/
mod.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

4pub mod builder;
5pub mod builtin;
6pub mod parallel;
7pub mod step;
8
9pub use builder::Pipeline;
10pub use parallel::ParallelStep;
11pub use step::Step;
12
13#[derive(Debug, thiserror::Error)]
14pub enum PipelineError {
15    #[error(transparent)]
16    Llm(#[from] zeph_llm::LlmError),
17
18    #[error(transparent)]
19    Memory(#[from] zeph_memory::MemoryError),
20
21    #[error("extraction failed: {0}")]
22    Extract(String),
23
24    #[error("{0}")]
25    Custom(String),
26}
27
#[cfg(test)]
mod tests {
    // NOTE(review): the tests below already use explicit `(): ()` closure patterns
    // and `String::new()`; confirm these allows are still required.
    #![allow(clippy::ignored_unit_patterns, clippy::manual_string_new)]

    use std::sync::Arc;

    use super::builtin::{ExtractStep, LlmStep, MapStep, RetrievalStep};
    use super::parallel::parallel;
    use super::*;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::in_memory_store::InMemoryVectorStore;
    use zeph_memory::vector_store::{VectorPoint, VectorStore};

    // --- Shared test fixtures ---

    /// Appends `suffix` to its input string.
    struct AddSuffix {
        suffix: String,
    }

    impl Step for AddSuffix {
        type Input = String;
        type Output = String;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(format!("{input}{}", self.suffix))
        }
    }

    /// Maps a string to its byte length; used to chain steps of different types.
    struct ParseLen;

    impl Step for ParseLen {
        type Input = String;
        type Output = usize;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(input.len())
        }
    }

    /// Always fails with `PipelineError::Custom` carrying the given message.
    struct FailWith(&'static str);

    impl Step for FailWith {
        type Input = String;
        type Output = String;

        async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
            Err(PipelineError::Custom(self.0.to_owned()))
        }
    }

    // --- Pipeline composition tests ---

    #[tokio::test]
    async fn single_step_pipeline() {
        let result = Pipeline::start(AddSuffix { suffix: "!".into() })
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "hello!");
    }

    #[tokio::test]
    async fn chained_pipeline() {
        let result = Pipeline::start(AddSuffix {
            suffix: " world".into(),
        })
        .step(AddSuffix { suffix: "!".into() })
        .run("hello".into())
        .await
        .unwrap();
        assert_eq!(result, "hello world!");
    }

    #[tokio::test]
    async fn heterogeneous_chain() {
        let result = Pipeline::start(AddSuffix {
            suffix: "abc".into(),
        })
        .step(ParseLen)
        .run(String::new())
        .await
        .unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn map_step() {
        let result = Pipeline::start(MapStep::new(|s: String| s.to_uppercase()))
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "HELLO");
    }

    #[tokio::test]
    async fn parallel_step() {
        let step = parallel(
            AddSuffix {
                suffix: "_a".into(),
            },
            AddSuffix {
                suffix: "_b".into(),
            },
        );
        let result = Pipeline::start(step).run("x".into()).await.unwrap();
        assert_eq!(result, ("x_a".into(), "x_b".into()));
    }

    #[tokio::test]
    async fn error_propagation() {
        let err = Pipeline::start(AddSuffix {
            suffix: "ok".into(),
        })
        .step(FailWith("boom"))
        .run("hi".into())
        .await
        .expect_err("a failing step must abort the pipeline");
        assert!(err.to_string().contains("boom"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn extract_step() {
        let result = Pipeline::start(MapStep::new(|(): ()| r#"{"a":1,"b":"two"}"#.to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await
            .unwrap();
        assert_eq!(result["a"], 1);
        assert_eq!(result["b"], "two");
    }

    // --- LlmStep tests ---

    #[tokio::test]
    async fn llm_step_returns_response() {
        let provider = Arc::new(MockProvider::with_responses(vec!["answer".into()]));
        let result = Pipeline::start(LlmStep::new(provider))
            .run("question".into())
            .await
            .unwrap();
        assert_eq!(result, "answer");
    }

    #[tokio::test]
    async fn llm_step_with_system_prompt() {
        let provider = Arc::new(MockProvider::with_responses(vec!["ok".into()]));
        let result = Pipeline::start(LlmStep::new(provider).with_system_prompt("sys"))
            .run("input".into())
            .await
            .unwrap();
        assert_eq!(result, "ok");
    }

    #[tokio::test]
    async fn llm_step_propagates_error() {
        let provider = Arc::new(MockProvider::failing());
        let err = Pipeline::start(LlmStep::new(provider))
            .run("input".into())
            .await
            .expect_err("a failing provider must fail the step");
        assert!(
            matches!(err, PipelineError::Llm(_)),
            "expected PipelineError::Llm"
        );
    }

    // --- RetrievalStep tests ---

    #[tokio::test]
    async fn retrieval_step_returns_results() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();
        store
            .upsert(
                "col",
                vec![VectorPoint {
                    id: "p1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    payload: std::collections::HashMap::new(),
                }],
            )
            .await
            .unwrap();

        let mut provider = MockProvider::default();
        provider.supports_embeddings = true;
        provider.embedding = vec![1.0, 0.0, 0.0];
        let provider = Arc::new(provider);

        let step = RetrievalStep::new(store, provider, "col", 5);
        let results = Pipeline::start(step).run("query".into()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "p1");
    }

    #[tokio::test]
    async fn retrieval_step_embed_error_propagates() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();

        // The default mock is expected not to support embeddings (the success
        // test above opts in explicitly), so embedding the query should fail.
        let provider = Arc::new(MockProvider::default());

        let step = RetrievalStep::new(store, provider, "col", 5);
        let result = Pipeline::start(step).run("query".into()).await;
        assert!(matches!(result.unwrap_err(), PipelineError::Llm(_)));
    }

    // --- ExtractStep failure tests ---

    #[tokio::test]
    async fn extract_step_invalid_json() {
        let result = Pipeline::start(MapStep::new(|(): ()| "not json".to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    #[tokio::test]
    async fn extract_step_type_mismatch() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            #[expect(dead_code)]
            required_field: Vec<u32>,
        }

        let result = Pipeline::start(MapStep::new(|(): ()| r#"{"a":1}"#.to_owned()))
            .step(ExtractStep::<Strict>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    // --- ParallelStep error tests ---

    #[tokio::test]
    async fn parallel_step_first_fails() {
        let step = parallel(
            FailWith("fail_a"),
            AddSuffix {
                suffix: "_ok".into(),
            },
        );
        let err = Pipeline::start(step)
            .run("x".into())
            .await
            .expect_err("one failing branch must fail the parallel step");
        // Only the first branch fails, so its error is the one that must surface.
        assert!(err.to_string().contains("fail_a"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn parallel_step_both_fail() {
        let step = parallel(FailWith("fail_a"), FailWith("fail_b"));
        let err = Pipeline::start(step)
            .run("x".into())
            .await
            .expect_err("failing branches must fail the parallel step");
        // Which branch's error wins is not pinned here; either one is acceptable,
        // but it must be one of the branch errors rather than something else.
        let msg = err.to_string();
        assert!(
            msg.contains("fail_a") || msg.contains("fail_b"),
            "unexpected error: {msg}"
        );
    }
}
305}