zeph_core/pipeline/
builtin.rs1use std::marker::PhantomData;
5use std::sync::Arc;
6
7use zeph_llm::provider::{LlmProvider, Message, Role};
8use zeph_memory::vector_store::{ScoredVectorPoint, VectorStore};
9
10use super::PipelineError;
11use super::step::Step;
12
/// Pipeline step backed by a shared provider handle.
///
/// Holds the provider behind an [`Arc`] together with an optional system
/// prompt that can be attached through [`LlmStep::with_system_prompt`].
pub struct LlmStep<P> {
    /// Shared handle to the provider this step delegates to.
    provider: Arc<P>,
    /// Optional system prompt; `None` until one is configured.
    system_prompt: Option<String>,
}

impl<P> LlmStep<P> {
    /// Creates a step around `provider` with no system prompt configured.
    #[must_use]
    pub fn new(provider: Arc<P>) -> Self {
        let system_prompt = None;
        Self {
            provider,
            system_prompt,
        }
    }

    /// Returns the step with its system prompt set to `prompt`, replacing
    /// any prompt that was configured before.
    #[must_use]
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }
}
33
34impl<P: LlmProvider> Step for LlmStep<P> {
35 type Input = String;
36 type Output = String;
37
38 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
39 let mut messages = Vec::new();
40 if let Some(sys) = &self.system_prompt {
41 messages.push(Message::from_legacy(Role::System, sys.clone()));
42 }
43 messages.push(Message::from_legacy(Role::User, input));
44 self.provider
45 .chat(&messages)
46 .await
47 .map_err(PipelineError::Llm)
48 }
49}
50
/// Pipeline step that pairs a shared store handle with a shared provider
/// handle, plus the collection name and result limit used when it runs.
pub struct RetrievalStep<P, V> {
    /// Shared handle to the store that is searched.
    store: Arc<V>,
    /// Shared handle to the provider used to embed the query.
    provider: Arc<P>,
    /// Name of the collection passed to the store's `search`.
    collection: String,
    /// Limit value passed to the store's `search`.
    limit: u64,
}

impl<P, V> RetrievalStep<P, V> {
    /// Creates a retrieval step over `store` and `provider` that targets
    /// `collection` and passes `limit` along with every search.
    #[must_use]
    pub fn new(store: Arc<V>, provider: Arc<P>, collection: impl Into<String>, limit: u64) -> Self {
        let collection: String = collection.into();
        Self {
            store,
            provider,
            collection,
            limit,
        }
    }
}
69
70impl<P: LlmProvider, V: VectorStore> Step for RetrievalStep<P, V> {
71 type Input = String;
72 type Output = Vec<ScoredVectorPoint>;
73
74 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
75 let embedding = self
76 .provider
77 .embed(&input)
78 .await
79 .map_err(PipelineError::Llm)?;
80 self.store
81 .search(&self.collection, embedding, self.limit, None)
82 .await
83 .map_err(|e| PipelineError::Memory(e.into()))
84 }
85}
86
/// Zero-sized pipeline step parameterised by the type `T` it produces.
///
/// The step stores no `T`; the type parameter only selects the output type.
pub struct ExtractStep<T> {
    // `fn() -> T` marks `T` as something this step *produces* rather than
    // owns. Unlike `PhantomData<T>`, this keeps the step `Send + Sync`
    // regardless of `T` while remaining covariant in `T`, and matches the
    // marker form used by `MapStep` in this module.
    _marker: PhantomData<fn() -> T>,
}

impl<T> ExtractStep<T> {
    /// Creates a new extraction step for the target type `T`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

/// `Default` is implemented by hand so that it does not require `T: Default`.
impl<T> Default for ExtractStep<T> {
    fn default() -> Self {
        Self::new()
    }
}
105
106impl<T: serde::de::DeserializeOwned + Send + Sync> Step for ExtractStep<T> {
107 type Input = String;
108 type Output = T;
109
110 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
111 serde_json::from_str(&input).map_err(|e| PipelineError::Extract(e.to_string()))
112 }
113}
114
/// Pipeline step that wraps a function value `f`.
///
/// `In` and `Out` are carried only at the type level; the `fn(In) -> Out`
/// marker means no value of either type is stored.
pub struct MapStep<F, In, Out> {
    /// The wrapped function value.
    f: F,
    /// Type-level record of the input and output types.
    _marker: PhantomData<fn(In) -> Out>,
}

impl<F, In, Out> MapStep<F, In, Out> {
    /// Wraps `f` in a new step.
    #[must_use]
    pub fn new(f: F) -> Self {
        Self {
            _marker: PhantomData,
            f,
        }
    }
}
129
130impl<F, In, Out> Step for MapStep<F, In, Out>
131where
132 F: Fn(In) -> Out + Send + Sync,
133 In: Send,
134 Out: Send,
135{
136 type Input = In;
137 type Output = Out;
138
139 async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
140 Ok((self.f)(input))
141 }
142}