zeph_core/pipeline/
mod.rs1pub mod builder;
5pub mod builtin;
6pub mod parallel;
7pub mod step;
8
9pub use builder::Pipeline;
10pub use parallel::ParallelStep;
11pub use step::Step;
12
13#[derive(Debug, thiserror::Error)]
14pub enum PipelineError {
15 #[error(transparent)]
16 Llm(#[from] zeph_llm::LlmError),
17
18 #[error(transparent)]
19 Memory(#[from] zeph_memory::MemoryError),
20
21 #[error("extraction failed: {0}")]
22 Extract(String),
23
24 #[error("{0}")]
25 Custom(String),
26}
27
28#[cfg(test)]
29mod tests {
30 use std::sync::Arc;
31
32 use super::builtin::{ExtractStep, LlmStep, MapStep, RetrievalStep};
33 use super::parallel::parallel;
34 use super::*;
35 use zeph_llm::mock::MockProvider;
36 use zeph_memory::in_memory_store::InMemoryVectorStore;
37 use zeph_memory::vector_store::{VectorPoint, VectorStore};
38
39 struct AddSuffix {
40 suffix: String,
41 }
42
43 impl Step for AddSuffix {
44 type Input = String;
45 type Output = String;
46
47 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
48 Ok(format!("{input}{}", self.suffix))
49 }
50 }
51
52 struct ParseLen;
53
54 impl Step for ParseLen {
55 type Input = String;
56 type Output = usize;
57
58 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
59 Ok(input.len())
60 }
61 }
62
63 #[tokio::test]
64 async fn single_step_pipeline() {
65 let result = Pipeline::start(AddSuffix { suffix: "!".into() })
66 .run("hello".into())
67 .await
68 .unwrap();
69 assert_eq!(result, "hello!");
70 }
71
72 #[tokio::test]
73 async fn chained_pipeline() {
74 let result = Pipeline::start(AddSuffix {
75 suffix: " world".into(),
76 })
77 .step(AddSuffix { suffix: "!".into() })
78 .run("hello".into())
79 .await
80 .unwrap();
81 assert_eq!(result, "hello world!");
82 }
83
84 #[tokio::test]
85 async fn heterogeneous_chain() {
86 let result = Pipeline::start(AddSuffix {
87 suffix: "abc".into(),
88 })
89 .step(ParseLen)
90 .run("".into())
91 .await
92 .unwrap();
93 assert_eq!(result, 3);
94 }
95
96 #[tokio::test]
97 async fn map_step() {
98 let result = Pipeline::start(MapStep::new(|s: String| s.to_uppercase()))
99 .run("hello".into())
100 .await
101 .unwrap();
102 assert_eq!(result, "HELLO");
103 }
104
105 #[tokio::test]
106 async fn parallel_step() {
107 let step = parallel(
108 AddSuffix {
109 suffix: "_a".into(),
110 },
111 AddSuffix {
112 suffix: "_b".into(),
113 },
114 );
115 let result = Pipeline::start(step).run("x".into()).await.unwrap();
116 assert_eq!(result, ("x_a".into(), "x_b".into()));
117 }
118
119 #[tokio::test]
120 async fn error_propagation() {
121 struct FailStep;
122
123 impl Step for FailStep {
124 type Input = String;
125 type Output = String;
126
127 async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
128 Err(PipelineError::Custom("boom".into()))
129 }
130 }
131
132 let result = Pipeline::start(AddSuffix {
133 suffix: "ok".into(),
134 })
135 .step(FailStep)
136 .run("hi".into())
137 .await;
138 assert!(result.is_err());
139 assert!(result.unwrap_err().to_string().contains("boom"));
140 }
141
142 #[tokio::test]
143 async fn extract_step() {
144 use super::builtin::ExtractStep;
145
146 let result = Pipeline::start(MapStep::new(|_: ()| r#"{"a":1,"b":"two"}"#.to_owned()))
147 .step(ExtractStep::<serde_json::Value>::new())
148 .run(())
149 .await
150 .unwrap();
151 assert_eq!(result["a"], 1);
152 assert_eq!(result["b"], "two");
153 }
154
155 #[tokio::test]
158 async fn llm_step_returns_response() {
159 let provider = Arc::new(MockProvider::with_responses(vec!["answer".into()]));
160 let result = Pipeline::start(LlmStep::new(provider))
161 .run("question".into())
162 .await
163 .unwrap();
164 assert_eq!(result, "answer");
165 }
166
167 #[tokio::test]
168 async fn llm_step_with_system_prompt() {
169 let provider = Arc::new(MockProvider::with_responses(vec!["ok".into()]));
170 let result = Pipeline::start(LlmStep::new(provider).with_system_prompt("sys"))
171 .run("input".into())
172 .await
173 .unwrap();
174 assert_eq!(result, "ok");
175 }
176
177 #[tokio::test]
178 async fn llm_step_propagates_error() {
179 let provider = Arc::new(MockProvider::failing());
180 let result = Pipeline::start(LlmStep::new(provider))
181 .run("input".into())
182 .await;
183 assert!(result.is_err());
184 assert!(
185 matches!(result.unwrap_err(), PipelineError::Llm(_)),
186 "expected PipelineError::Llm"
187 );
188 }
189
190 #[tokio::test]
193 async fn retrieval_step_returns_results() {
194 let store = Arc::new(InMemoryVectorStore::new());
195 store.ensure_collection("col", 3).await.unwrap();
196 store
197 .upsert(
198 "col",
199 vec![VectorPoint {
200 id: "p1".into(),
201 vector: vec![1.0, 0.0, 0.0],
202 payload: std::collections::HashMap::new(),
203 }],
204 )
205 .await
206 .unwrap();
207
208 let mut provider = MockProvider::default();
209 provider.supports_embeddings = true;
210 provider.embedding = vec![1.0, 0.0, 0.0];
211 let provider = Arc::new(provider);
212
213 let step = RetrievalStep::new(store, provider, "col", 5);
214 let results = Pipeline::start(step).run("query".into()).await.unwrap();
215 assert_eq!(results.len(), 1);
216 assert_eq!(results[0].id, "p1");
217 }
218
219 #[tokio::test]
220 async fn retrieval_step_embed_error_propagates() {
221 let store = Arc::new(InMemoryVectorStore::new());
222 store.ensure_collection("col", 3).await.unwrap();
223
224 let provider = Arc::new(MockProvider::default());
225
226 let step = RetrievalStep::new(store, provider, "col", 5);
227 let result = Pipeline::start(step).run("query".into()).await;
228 assert!(matches!(result.unwrap_err(), PipelineError::Llm(_)));
229 }
230
231 #[tokio::test]
234 async fn extract_step_invalid_json() {
235 let result = Pipeline::start(MapStep::new(|_: ()| "not json".to_owned()))
236 .step(ExtractStep::<serde_json::Value>::new())
237 .run(())
238 .await;
239 assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
240 }
241
242 #[tokio::test]
243 async fn extract_step_type_mismatch() {
244 #[derive(Debug, serde::Deserialize)]
245 struct Strict {
246 #[expect(dead_code)]
247 required_field: Vec<u32>,
248 }
249
250 let result = Pipeline::start(MapStep::new(|_: ()| r#"{"a":1}"#.to_owned()))
251 .step(ExtractStep::<Strict>::new())
252 .run(())
253 .await;
254 assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
255 }
256
257 #[tokio::test]
260 async fn parallel_step_first_fails() {
261 struct FailStep;
262 impl Step for FailStep {
263 type Input = String;
264 type Output = String;
265 async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
266 Err(PipelineError::Custom("fail_a".into()))
267 }
268 }
269
270 let step = parallel(
271 FailStep,
272 AddSuffix {
273 suffix: "_ok".into(),
274 },
275 );
276 let result = Pipeline::start(step).run("x".into()).await;
277 assert!(result.is_err());
278 }
279
280 #[tokio::test]
281 async fn parallel_step_both_fail() {
282 struct FailA;
283 impl Step for FailA {
284 type Input = String;
285 type Output = String;
286 async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
287 Err(PipelineError::Custom("fail_a".into()))
288 }
289 }
290 struct FailB;
291 impl Step for FailB {
292 type Input = String;
293 type Output = String;
294 async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
295 Err(PipelineError::Custom("fail_b".into()))
296 }
297 }
298
299 let step = parallel(FailA, FailB);
300 let result = Pipeline::start(step).run("x".into()).await;
301 assert!(result.is_err());
302 }
303}