Skip to main content

zeph_core/pipeline/
mod.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4pub mod builder;
5pub mod builtin;
6pub mod parallel;
7pub mod step;
8
9pub use builder::Pipeline;
10pub use parallel::ParallelStep;
11pub use step::Step;
12
13#[derive(Debug, thiserror::Error)]
14pub enum PipelineError {
15    #[error(transparent)]
16    Llm(#[from] zeph_llm::LlmError),
17
18    #[error(transparent)]
19    Memory(#[from] zeph_memory::MemoryError),
20
21    #[error("extraction failed: {0}")]
22    Extract(String),
23
24    #[error("{0}")]
25    Custom(String),
26}
27
#[cfg(test)]
mod tests {
    //! Tests for pipeline composition (`Pipeline::start` / `.step` / `.run`) and for the
    //! built-in, parallel and user-defined steps, including their error paths.

    use std::sync::Arc;

    use super::builtin::{ExtractStep, LlmStep, MapStep, RetrievalStep};
    use super::parallel::parallel;
    use super::*;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::in_memory_store::InMemoryVectorStore;
    use zeph_memory::vector_store::{VectorPoint, VectorStore};

    // --- Test-only steps ---

    /// Appends a fixed suffix to the incoming string.
    struct AddSuffix {
        suffix: String,
    }

    impl Step for AddSuffix {
        type Input = String;
        type Output = String;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(format!("{input}{}", self.suffix))
        }
    }

    /// Maps a string to its length in bytes; used to change the value type mid-chain.
    struct ParseLen;

    impl Step for ParseLen {
        type Input = String;
        type Output = usize;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(input.len())
        }
    }

    /// Ignores its input and always fails with `PipelineError::Custom` carrying the
    /// wrapped message. Shared by every failure-path test so the failing step is
    /// defined once instead of being re-declared inside each test.
    struct FailWith(&'static str);

    impl Step for FailWith {
        type Input = String;
        type Output = String;

        async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
            Err(PipelineError::Custom(self.0.into()))
        }
    }

    // --- Composition tests ---

    /// A pipeline with one step returns that step's output.
    #[tokio::test]
    async fn single_step_pipeline() {
        let result = Pipeline::start(AddSuffix { suffix: "!".into() })
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "hello!");
    }

    /// Steps run in the order they were added, each receiving the previous output.
    #[tokio::test]
    async fn chained_pipeline() {
        let result = Pipeline::start(AddSuffix {
            suffix: " world".into(),
        })
        .step(AddSuffix { suffix: "!".into() })
        .run("hello".into())
        .await
        .unwrap();
        assert_eq!(result, "hello world!");
    }

    /// A chain may change the value type between steps (`String` -> `usize`).
    #[tokio::test]
    async fn heterogeneous_chain() {
        let result = Pipeline::start(AddSuffix {
            suffix: "abc".into(),
        })
        .step(ParseLen)
        .run(String::new())
        .await
        .unwrap();
        assert_eq!(result, 3);
    }

    /// `MapStep` applies the wrapped closure to the input.
    #[tokio::test]
    async fn map_step() {
        let result = Pipeline::start(MapStep::new(|s: String| s.to_uppercase()))
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "HELLO");
    }

    /// `parallel` feeds the same input to both steps and yields both outputs as a pair.
    #[tokio::test]
    async fn parallel_step() {
        let step = parallel(
            AddSuffix {
                suffix: "_a".into(),
            },
            AddSuffix {
                suffix: "_b".into(),
            },
        );
        let result = Pipeline::start(step).run("x".into()).await.unwrap();
        assert_eq!(result, ("x_a".to_owned(), "x_b".to_owned()));
    }

    /// An error from a later step surfaces from `run` with its message intact.
    #[tokio::test]
    async fn error_propagation() {
        let err = Pipeline::start(AddSuffix {
            suffix: "ok".into(),
        })
        .step(FailWith("boom"))
        .run("hi".into())
        .await
        .expect_err("a failing step must fail the whole pipeline");
        assert!(err.to_string().contains("boom"));
    }

    /// `ExtractStep` parses the JSON text produced by the previous step.
    #[tokio::test]
    async fn extract_step() {
        let result = Pipeline::start(MapStep::new(|_: ()| r#"{"a":1,"b":"two"}"#.to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await
            .unwrap();
        assert_eq!(result["a"], 1);
        assert_eq!(result["b"], "two");
    }

    // --- LlmStep tests ---

    /// The step returns the provider's response for the given input.
    #[tokio::test]
    async fn llm_step_returns_response() {
        let provider = Arc::new(MockProvider::with_responses(vec!["answer".into()]));
        let result = Pipeline::start(LlmStep::new(provider))
            .run("question".into())
            .await
            .unwrap();
        assert_eq!(result, "answer");
    }

    /// Configuring a system prompt still yields the provider's response.
    #[tokio::test]
    async fn llm_step_with_system_prompt() {
        let provider = Arc::new(MockProvider::with_responses(vec!["ok".into()]));
        let result = Pipeline::start(LlmStep::new(provider).with_system_prompt("sys"))
            .run("input".into())
            .await
            .unwrap();
        assert_eq!(result, "ok");
    }

    /// A failing provider is reported as `PipelineError::Llm`.
    #[tokio::test]
    async fn llm_step_propagates_error() {
        let provider = Arc::new(MockProvider::failing());
        let err = Pipeline::start(LlmStep::new(provider))
            .run("input".into())
            .await
            .expect_err("a failing provider must fail the pipeline");
        assert!(
            matches!(err, PipelineError::Llm(_)),
            "expected PipelineError::Llm"
        );
    }

    // --- RetrievalStep tests ---

    /// With one stored point and a mock embedding equal to that point's vector,
    /// retrieval returns exactly that point.
    #[tokio::test]
    async fn retrieval_step_returns_results() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();
        store
            .upsert(
                "col",
                vec![VectorPoint {
                    id: "p1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    payload: std::collections::HashMap::new(),
                }],
            )
            .await
            .unwrap();

        let mut provider = MockProvider::default();
        provider.supports_embeddings = true;
        provider.embedding = vec![1.0, 0.0, 0.0];
        let provider = Arc::new(provider);

        let step = RetrievalStep::new(store, provider, "col", 5);
        let results = Pipeline::start(step).run("query".into()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "p1");
    }

    /// A default mock provider (embeddings not switched on, unlike the test above) is
    /// expected to fail the embedding call, which must surface as `PipelineError::Llm`.
    #[tokio::test]
    async fn retrieval_step_embed_error_propagates() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();

        let provider = Arc::new(MockProvider::default());

        let step = RetrievalStep::new(store, provider, "col", 5);
        let result = Pipeline::start(step).run("query".into()).await;
        assert!(matches!(result.unwrap_err(), PipelineError::Llm(_)));
    }

    // --- ExtractStep failure tests ---

    /// Text that is not JSON is reported as `PipelineError::Extract`.
    #[tokio::test]
    async fn extract_step_invalid_json() {
        let result = Pipeline::start(MapStep::new(|_: ()| "not json".to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    /// Valid JSON that does not fit the target type is also `PipelineError::Extract`.
    #[tokio::test]
    async fn extract_step_type_mismatch() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            #[expect(dead_code)]
            required_field: Vec<u32>,
        }

        let result = Pipeline::start(MapStep::new(|_: ()| r#"{"a":1}"#.to_owned()))
            .step(ExtractStep::<Strict>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    // --- ParallelStep error tests ---

    /// If the first branch fails, the combined step fails even though the second succeeds.
    #[tokio::test]
    async fn parallel_step_first_fails() {
        let step = parallel(
            FailWith("fail_a"),
            AddSuffix {
                suffix: "_ok".into(),
            },
        );
        let result = Pipeline::start(step).run("x".into()).await;
        assert!(result.is_err());
    }

    /// If both branches fail, the combined step fails.
    // NOTE(review): which of the two errors is reported is not asserted here; it depends
    // on `ParallelStep` internals that this module does not show.
    #[tokio::test]
    async fn parallel_step_both_fail() {
        let step = parallel(FailWith("fail_a"), FailWith("fail_b"));
        let result = Pipeline::start(step).run("x".into()).await;
        assert!(result.is_err());
    }
}