zeph_core/pipeline/
mod.rs1pub mod builder;
5pub mod builtin;
6pub mod parallel;
7pub mod step;
8
9pub use builder::Pipeline;
10pub use parallel::ParallelStep;
11pub use step::Step;
12
13#[derive(Debug, thiserror::Error)]
14pub enum PipelineError {
15 #[error(transparent)]
16 Llm(#[from] zeph_llm::LlmError),
17
18 #[error(transparent)]
19 Memory(#[from] zeph_memory::MemoryError),
20
21 #[error("extraction failed: {0}")]
22 Extract(String),
23
24 #[error("{0}")]
25 Custom(String),
26}
27
#[cfg(test)]
mod tests {
    // NOTE(review): no code visible in this module uses `Ok(_)`-style unit
    // patterns or `"".into()`, so these allowances may be stale — confirm
    // before removing.
    #![allow(clippy::ignored_unit_patterns, clippy::manual_string_new)]

    use std::sync::Arc;

    use super::builtin::{ExtractStep, LlmStep, MapStep, RetrievalStep};
    use super::parallel::parallel;
    use super::*;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::in_memory_store::InMemoryVectorStore;
    use zeph_memory::vector_store::{VectorPoint, VectorStore};

    /// Test step: appends a fixed suffix to its `String` input.
    struct AddSuffix {
        suffix: String,
    }

    impl Step for AddSuffix {
        type Input = String;
        type Output = String;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(format!("{input}{}", self.suffix))
        }
    }

    /// Test step: maps a `String` to its byte length, so a chain can be
    /// checked to change type between steps.
    struct ParseLen;

    impl Step for ParseLen {
        type Input = String;
        type Output = usize;

        async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
            Ok(input.len())
        }
    }

    /// Test step: always fails with `PipelineError::Custom` carrying the given
    /// message. Shared by the error-path tests so each one only has to pick a
    /// distinguishable message instead of declaring its own failing step.
    struct FailWith(&'static str);

    impl Step for FailWith {
        type Input = String;
        type Output = String;

        async fn run(&self, _input: Self::Input) -> Result<Self::Output, PipelineError> {
            Err(PipelineError::Custom(self.0.to_owned()))
        }
    }

    // --- Pipeline composition ---

    #[tokio::test]
    async fn single_step_pipeline() {
        let result = Pipeline::start(AddSuffix { suffix: "!".into() })
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "hello!");
    }

    #[tokio::test]
    async fn chained_pipeline() {
        let result = Pipeline::start(AddSuffix {
            suffix: " world".into(),
        })
        .step(AddSuffix { suffix: "!".into() })
        .run("hello".into())
        .await
        .unwrap();
        assert_eq!(result, "hello world!");
    }

    #[tokio::test]
    async fn heterogeneous_chain() {
        let result = Pipeline::start(AddSuffix {
            suffix: "abc".into(),
        })
        .step(ParseLen)
        .run(String::new())
        .await
        .unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn map_step() {
        let result = Pipeline::start(MapStep::new(|s: String| s.to_uppercase()))
            .run("hello".into())
            .await
            .unwrap();
        assert_eq!(result, "HELLO");
    }

    #[tokio::test]
    async fn parallel_step() {
        let step = parallel(
            AddSuffix {
                suffix: "_a".into(),
            },
            AddSuffix {
                suffix: "_b".into(),
            },
        );
        let result = Pipeline::start(step).run("x".into()).await.unwrap();
        assert_eq!(result, ("x_a".into(), "x_b".into()));
    }

    #[tokio::test]
    async fn error_propagation() {
        let result = Pipeline::start(AddSuffix {
            suffix: "ok".into(),
        })
        .step(FailWith("boom"))
        .run("hi".into())
        .await;
        let err = result.expect_err("a failing second step must fail the pipeline");
        assert!(err.to_string().contains("boom"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn extract_step() {
        let result = Pipeline::start(MapStep::new(|(): ()| r#"{"a":1,"b":"two"}"#.to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await
            .unwrap();
        assert_eq!(result["a"], 1);
        assert_eq!(result["b"], "two");
    }

    // --- LlmStep ---

    #[tokio::test]
    async fn llm_step_returns_response() {
        let provider = Arc::new(MockProvider::with_responses(vec!["answer".into()]));
        let result = Pipeline::start(LlmStep::new(provider))
            .run("question".into())
            .await
            .unwrap();
        assert_eq!(result, "answer");
    }

    #[tokio::test]
    async fn llm_step_with_system_prompt() {
        let provider = Arc::new(MockProvider::with_responses(vec!["ok".into()]));
        let result = Pipeline::start(LlmStep::new(provider).with_system_prompt("sys"))
            .run("input".into())
            .await
            .unwrap();
        assert_eq!(result, "ok");
    }

    #[tokio::test]
    async fn llm_step_propagates_error() {
        let provider = Arc::new(MockProvider::failing());
        let result = Pipeline::start(LlmStep::new(provider))
            .run("input".into())
            .await;
        let err = result.expect_err("a failing provider must fail the pipeline");
        assert!(
            matches!(err, PipelineError::Llm(_)),
            "expected PipelineError::Llm, got {err:?}"
        );
    }

    // --- RetrievalStep ---

    #[tokio::test]
    async fn retrieval_step_returns_results() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();
        store
            .upsert(
                "col",
                vec![VectorPoint {
                    id: "p1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    payload: std::collections::HashMap::new(),
                }],
            )
            .await
            .unwrap();

        let mut provider = MockProvider::default();
        provider.supports_embeddings = true;
        provider.embedding = vec![1.0, 0.0, 0.0];
        let provider = Arc::new(provider);

        let step = RetrievalStep::new(store, provider, "col", 5);
        let results = Pipeline::start(step).run("query".into()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "p1");
    }

    #[tokio::test]
    async fn retrieval_step_embed_error_propagates() {
        let store = Arc::new(InMemoryVectorStore::new());
        store.ensure_collection("col", 3).await.unwrap();

        // Unlike `retrieval_step_returns_results`, embeddings are not enabled
        // on the mock here; the test relies on the default mock failing the
        // embed call, which must surface as an LLM error.
        let provider = Arc::new(MockProvider::default());

        let step = RetrievalStep::new(store, provider, "col", 5);
        let result = Pipeline::start(step).run("query".into()).await;
        assert!(matches!(result.unwrap_err(), PipelineError::Llm(_)));
    }

    // --- ExtractStep error paths ---

    #[tokio::test]
    async fn extract_step_invalid_json() {
        let result = Pipeline::start(MapStep::new(|(): ()| "not json".to_owned()))
            .step(ExtractStep::<serde_json::Value>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    #[tokio::test]
    async fn extract_step_type_mismatch() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            #[expect(dead_code)]
            required_field: Vec<u32>,
        }

        let result = Pipeline::start(MapStep::new(|(): ()| r#"{"a":1}"#.to_owned()))
            .step(ExtractStep::<Strict>::new())
            .run(())
            .await;
        assert!(matches!(result.unwrap_err(), PipelineError::Extract(_)));
    }

    // --- ParallelStep error paths ---

    #[tokio::test]
    async fn parallel_step_first_fails() {
        let step = parallel(
            FailWith("fail_a"),
            AddSuffix {
                suffix: "_ok".into(),
            },
        );
        let result = Pipeline::start(step).run("x".into()).await;
        let err = result.expect_err("one failing branch must fail the parallel step");
        // The only failing branch is the first one, so its message must surface.
        assert!(err.to_string().contains("fail_a"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn parallel_step_both_fail() {
        let step = parallel(FailWith("fail_a"), FailWith("fail_b"));
        let result = Pipeline::start(step).run("x".into()).await;
        let err = result.expect_err("two failing branches must fail the parallel step");
        // Which branch's error wins is up to the ParallelStep implementation;
        // only require that it is one of the two.
        let msg = err.to_string();
        assert!(
            msg.contains("fail_a") || msg.contains("fail_b"),
            "unexpected error: {msg}"
        );
    }
}