Skip to main content

zeph_core/pipeline/
builtin.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7use zeph_llm::provider::{LlmProvider, Message, Role};
8use zeph_memory::vector_store::{ScoredVectorPoint, VectorStore};
9
10use super::PipelineError;
11use super::step::Step;
12
/// Pipeline step that forwards its input to a language-model provider.
///
/// A system prompt is optional and can be attached with
/// [`LlmStep::with_system_prompt`].
pub struct LlmStep<P> {
    /// Shared handle to the provider that receives the chat request.
    provider: Arc<P>,
    /// System prompt placed ahead of the user input; `None` when not configured.
    system_prompt: Option<String>,
}

impl<P> LlmStep<P> {
    /// Creates a step backed by `provider` with no system prompt set.
    #[must_use]
    pub fn new(provider: Arc<P>) -> Self {
        let system_prompt = None;
        Self {
            provider,
            system_prompt,
        }
    }

    /// Returns the step with its system prompt set to `prompt`.
    ///
    /// Any previously configured prompt is replaced.
    #[must_use]
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }
}
33
34impl<P: LlmProvider> Step for LlmStep<P> {
35    type Input = String;
36    type Output = String;
37
38    async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
39        let mut messages = Vec::new();
40        if let Some(sys) = &self.system_prompt {
41            messages.push(Message::from_legacy(Role::System, sys.clone()));
42        }
43        messages.push(Message::from_legacy(Role::User, input));
44        self.provider
45            .chat(&messages)
46            .await
47            .map_err(PipelineError::Llm)
48    }
49}
50
/// Pipeline step that embeds its input and searches a vector store with the result.
pub struct RetrievalStep<P, V> {
    /// Shared handle to the vector store that is searched.
    store: Arc<V>,
    /// Shared handle to the provider used to embed the input text.
    provider: Arc<P>,
    /// Name of the collection passed to the store's search call.
    collection: String,
    /// Limit value forwarded to the store's search call.
    limit: u64,
}

impl<P, V> RetrievalStep<P, V> {
    /// Creates a retrieval step over `collection` in `store`, using `provider`
    /// for embeddings and forwarding `limit` to every search.
    #[must_use]
    pub fn new(store: Arc<V>, provider: Arc<P>, collection: impl Into<String>, limit: u64) -> Self {
        let collection = collection.into();
        Self {
            store,
            provider,
            collection,
            limit,
        }
    }
}
69
70impl<P: LlmProvider, V: VectorStore> Step for RetrievalStep<P, V> {
71    type Input = String;
72    type Output = Vec<ScoredVectorPoint>;
73
74    async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
75        let embedding = self
76            .provider
77            .embed(&input)
78            .await
79            .map_err(PipelineError::Llm)?;
80        self.store
81            .search(&self.collection, embedding, self.limit, None)
82            .await
83            .map_err(|e| PipelineError::Memory(e.into()))
84    }
85}
86
/// Pipeline step that parses a JSON string into a value of type `T`.
///
/// The step holds no data of its own; `T` only selects the deserialization
/// target.
pub struct ExtractStep<T> {
    // `fn() -> T` rather than `T`: the step produces `T` values but never stores
    // one, so its auto traits (`Send`, `Sync`) must not depend on `T`. This is
    // the same marker form `MapStep` uses.
    _marker: PhantomData<fn() -> T>,
}

impl<T> ExtractStep<T> {
    /// Creates an extraction step targeting `T`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T> Default for ExtractStep<T> {
    /// Equivalent to [`ExtractStep::new`].
    fn default() -> Self {
        Self::new()
    }
}
105
106impl<T: serde::de::DeserializeOwned + Send + Sync> Step for ExtractStep<T> {
107    type Input = String;
108    type Output = T;
109
110    async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
111        serde_json::from_str(&input).map_err(|e| PipelineError::Extract(e.to_string()))
112    }
113}
114
/// Pipeline step that applies a plain function or closure to its input.
pub struct MapStep<F, In, Out> {
    /// The wrapped function.
    f: F,
    // Records the `In -> Out` signature without storing a value of either type.
    _marker: PhantomData<fn(In) -> Out>,
}

impl<F, In, Out> MapStep<F, In, Out> {
    /// Wraps `f` in a step.
    #[must_use]
    pub fn new(f: F) -> Self {
        let _marker = PhantomData;
        Self { f, _marker }
    }
}
129
130impl<F, In, Out> Step for MapStep<F, In, Out>
131where
132    F: Fn(In) -> Out + Send + Sync,
133    In: Send,
134    Out: Send,
135{
136    type Input = In;
137    type Output = Out;
138
139    async fn run(&self, input: Self::Input) -> Result<Self::Output, PipelineError> {
140        Ok((self.f)(input))
141    }
142}