// totalreclaw_memory/embedding.rs
//! Embedding pipeline with 4 modes.
//!
//! - `Local` — ONNX runtime (feature-gated with `local-embeddings`)
//! - `Ollama` — HTTP POST to local Ollama server
//! - `ZeroClaw` — Remote ZeroClaw embedding provider
//! - `LlmProvider` — OpenAI-compatible `/v1/embeddings` endpoint
//!
//! The local ONNX mode is gated behind `#[cfg(feature = "local-embeddings")]`
//! to avoid pulling in the `ort` and `tokenizers` dependencies by default.
use std::future::Future;
use std::pin::Pin;
use std::sync::OnceLock;

use crate::{Error, Result};
/// Selects which backend produces embeddings.
///
/// Passed to [`create_provider`] to construct the matching
/// [`EmbeddingProvider`] implementation.
#[derive(Debug, Clone)]
pub enum EmbeddingMode {
    /// Local ONNX model (requires the `local-embeddings` feature).
    Local { model_path: String },
    /// Ollama server reachable at `base_url`, using the named model.
    Ollama { base_url: String, model: String },
    /// ZeroClaw remote embedding provider (OpenAI-compatible API,
    /// fixed model — see `create_provider`).
    ZeroClaw { base_url: String, api_key: String },
    /// Any OpenAI-compatible `/v1/embeddings` endpoint.
    LlmProvider {
        base_url: String,
        api_key: String,
        model: String,
    },
}
32
/// Common interface implemented by every embedding backend.
pub trait EmbeddingProvider: Send + Sync {
    /// Embed a single text string into a dense vector.
    ///
    /// Returns a boxed future (rather than `async fn`) so the trait stays
    /// object-safe and usable as `Box<dyn EmbeddingProvider>`.
    fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>>;

    /// Return the embedding dimensionality (the value supplied at
    /// construction; it is not probed from the backend).
    fn dimensions(&self) -> usize;
}
41
42// ---------------------------------------------------------------------------
43// Ollama provider
44// ---------------------------------------------------------------------------
45
/// Ollama embedding provider.
///
/// Talks to an Ollama server's `/api/embeddings` endpoint.
pub struct OllamaProvider {
    // Server base URL with any trailing '/' stripped (see `new`).
    base_url: String,
    // Model name sent with every request.
    model: String,
    // Configured dimensionality; not validated against responses.
    dims: usize,
}
52
53impl OllamaProvider {
54    pub fn new(base_url: &str, model: &str, dims: usize) -> Self {
55        Self {
56            base_url: base_url.trim_end_matches('/').to_string(),
57            model: model.to_string(),
58            dims,
59        }
60    }
61}
62
63impl EmbeddingProvider for OllamaProvider {
64    fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
65        let url = format!("{}/api/embeddings", self.base_url);
66        let body = serde_json::json!({
67            "model": self.model,
68            "prompt": text,
69        });
70
71        Box::pin(async move {
72            let client = reqwest::Client::new();
73            let resp = client
74                .post(&url)
75                .json(&body)
76                .send()
77                .await
78                .map_err(|e| Error::Http(format!("Ollama request failed: {}", e)))?;
79
80            if !resp.status().is_success() {
81                let status = resp.status();
82                let text = resp.text().await.unwrap_or_default();
83                return Err(Error::Http(format!("Ollama returned {}: {}", status, text)));
84            }
85
86            let data: serde_json::Value = resp
87                .json()
88                .await
89                .map_err(|e| Error::Http(format!("Ollama JSON parse failed: {}", e)))?;
90
91            let embedding = data["embedding"]
92                .as_array()
93                .ok_or_else(|| {
94                    Error::Embedding("no 'embedding' array in Ollama response".into())
95                })?
96                .iter()
97                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
98                .collect();
99
100            Ok(embedding)
101        })
102    }
103
104    fn dimensions(&self) -> usize {
105        self.dims
106    }
107}
108
109// ---------------------------------------------------------------------------
110// OpenAI-compatible provider (ZeroClaw + LlmProvider)
111// ---------------------------------------------------------------------------
112
/// OpenAI-compatible embedding provider.
///
/// Works with any server that implements `/v1/embeddings` (OpenAI, ZeroClaw, etc.).
pub struct OpenAiCompatibleProvider {
    // Server base URL with any trailing '/' stripped (see `new`).
    base_url: String,
    // Bearer token sent in the `Authorization` header.
    api_key: String,
    // Model identifier sent in the request body.
    model: String,
    // Configured dimensionality; not validated against responses.
    dims: usize,
}
122
123impl OpenAiCompatibleProvider {
124    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
125        Self {
126            base_url: base_url.trim_end_matches('/').to_string(),
127            api_key: api_key.to_string(),
128            model: model.to_string(),
129            dims,
130        }
131    }
132}
133
134impl EmbeddingProvider for OpenAiCompatibleProvider {
135    fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
136        let url = format!("{}/v1/embeddings", self.base_url);
137        let body = serde_json::json!({
138            "model": self.model,
139            "input": text,
140        });
141        let api_key = self.api_key.clone();
142
143        Box::pin(async move {
144            let client = reqwest::Client::new();
145            let resp = client
146                .post(&url)
147                .header("Authorization", format!("Bearer {}", api_key))
148                .json(&body)
149                .send()
150                .await
151                .map_err(|e| Error::Http(format!("embedding request failed: {}", e)))?;
152
153            if !resp.status().is_success() {
154                let status = resp.status();
155                let text = resp.text().await.unwrap_or_default();
156                return Err(Error::Http(format!(
157                    "embedding provider returned {}: {}",
158                    status, text
159                )));
160            }
161
162            let data: serde_json::Value = resp
163                .json()
164                .await
165                .map_err(|e| Error::Http(format!("JSON parse failed: {}", e)))?;
166
167            let embedding = data["data"][0]["embedding"]
168                .as_array()
169                .ok_or_else(|| {
170                    Error::Embedding("no 'data[0].embedding' in response".into())
171                })?
172                .iter()
173                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
174                .collect();
175
176            Ok(embedding)
177        })
178    }
179
180    fn dimensions(&self) -> usize {
181        self.dims
182    }
183}
184
185// ---------------------------------------------------------------------------
186// Local ONNX provider (feature-gated)
187// ---------------------------------------------------------------------------
188
/// Local ONNX embedding provider (stub).
///
/// NOTE(review): inference is not implemented yet — `embed` always returns
/// an error. The model path is stored for the future implementation.
#[cfg(feature = "local-embeddings")]
pub struct LocalOnnxProvider {
    // Path to the ONNX model file; currently stored but unused.
    _model_path: String,
    // Configured dimensionality reported by `dimensions`.
    dims: usize,
}
194
#[cfg(feature = "local-embeddings")]
impl LocalOnnxProvider {
    /// Create a provider for the ONNX model at `model_path`.
    ///
    /// Returns `Result` presumably for forward compatibility with model
    /// loading; currently it cannot fail and the path is not validated.
    pub fn new(model_path: &str, dims: usize) -> Result<Self> {
        Ok(Self {
            _model_path: model_path.to_string(),
            dims,
        })
    }
}
204
#[cfg(feature = "local-embeddings")]
impl EmbeddingProvider for LocalOnnxProvider {
    /// Stub: local ONNX inference is not wired up yet, so every call
    /// resolves to an `Error::Embedding`.
    fn embed(&self, _text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
        Box::pin(async {
            let msg = "local ONNX embedding not yet fully implemented";
            Err(Error::Embedding(msg.into()))
        })
    }

    /// Configured dimensionality supplied at construction.
    fn dimensions(&self) -> usize {
        self.dims
    }
}
219
220// ---------------------------------------------------------------------------
221// Factory
222// ---------------------------------------------------------------------------
223
224/// Create an embedding provider from a mode configuration.
225pub fn create_provider(mode: EmbeddingMode, dims: usize) -> Result<Box<dyn EmbeddingProvider>> {
226    match mode {
227        EmbeddingMode::Ollama { base_url, model } => {
228            Ok(Box::new(OllamaProvider::new(&base_url, &model, dims)))
229        }
230        EmbeddingMode::ZeroClaw { base_url, api_key } => Ok(Box::new(
231            OpenAiCompatibleProvider::new(&base_url, &api_key, "harrier-oss-v1-270m", dims),
232        )),
233        EmbeddingMode::LlmProvider {
234            base_url,
235            api_key,
236            model,
237        } => Ok(Box::new(OpenAiCompatibleProvider::new(
238            &base_url, &api_key, &model, dims,
239        ))),
240        #[cfg(feature = "local-embeddings")]
241        EmbeddingMode::Local { model_path } => {
242            Ok(Box::new(LocalOnnxProvider::new(&model_path, dims)?))
243        }
244        #[cfg(not(feature = "local-embeddings"))]
245        EmbeddingMode::Local { .. } => Err(Error::Embedding(
246            "local embeddings require the 'local-embeddings' feature".into(),
247        )),
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a provider from `mode` and check it reports `dims`.
    fn assert_provider_dims(mode: EmbeddingMode, dims: usize) {
        let provider = create_provider(mode, dims).expect("provider creation should succeed");
        assert_eq!(provider.dimensions(), dims);
    }

    #[test]
    fn test_create_provider_ollama() {
        let mode = EmbeddingMode::Ollama {
            base_url: "http://localhost:11434".into(),
            model: "harrier-oss-v1-270m".into(),
        };
        assert_provider_dims(mode, 640);
    }

    #[test]
    fn test_create_provider_zeroclaw() {
        let mode = EmbeddingMode::ZeroClaw {
            base_url: "https://api.example.com".into(),
            api_key: "test-key".into(),
        };
        assert_provider_dims(mode, 640);
    }

    #[test]
    fn test_create_provider_llm() {
        let mode = EmbeddingMode::LlmProvider {
            base_url: "https://api.openai.com".into(),
            api_key: "test-key".into(),
            model: "text-embedding-3-small".into(),
        };
        assert_provider_dims(mode, 1536);
    }

    #[test]
    fn test_create_provider_local_without_feature() {
        let result = create_provider(
            EmbeddingMode::Local {
                model_path: "/tmp/model".into(),
            },
            640,
        );
        // Local mode must fail when the feature is off and succeed when on.
        #[cfg(not(feature = "local-embeddings"))]
        assert!(result.is_err());
        #[cfg(feature = "local-embeddings")]
        assert!(result.is_ok());
    }
}