1use std::future::Future;
12use std::pin::Pin;
13
14use crate::{Error, Result};
15
/// Selects which backend produces embeddings, together with its connection settings.
///
/// `Debug` is implemented by hand (instead of derived) so that API keys are
/// replaced with `<redacted>` and never end up in logs or panic messages.
#[derive(Clone)]
pub enum EmbeddingMode {
    /// On-device ONNX model loaded from `model_path`
    /// (only usable with the `local-embeddings` feature).
    Local { model_path: String },
    /// An Ollama server reachable at `base_url`, embedding with `model`.
    Ollama { base_url: String, model: String },
    /// ZeroClaw endpoint speaking the OpenAI-compatible embeddings API;
    /// the model name is fixed by `create_provider`.
    ZeroClaw { base_url: String, api_key: String },
    /// Any OpenAI-compatible embeddings endpoint with an explicit model name.
    LlmProvider {
        base_url: String,
        api_key: String,
        model: String,
    },
}

impl std::fmt::Debug for EmbeddingMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Printed in place of secrets; every field is still listed so the shape
        // of the config stays visible when debugging.
        const REDACTED: &str = "<redacted>";

        match self {
            Self::Local { model_path } => f
                .debug_struct("Local")
                .field("model_path", model_path)
                .finish(),
            Self::Ollama { base_url, model } => f
                .debug_struct("Ollama")
                .field("base_url", base_url)
                .field("model", model)
                .finish(),
            // Fields are named explicitly (not `..`) so adding a new field to a
            // variant forces a decision about whether it is secret.
            Self::ZeroClaw {
                base_url,
                api_key: _,
            } => f
                .debug_struct("ZeroClaw")
                .field("base_url", base_url)
                .field("api_key", &REDACTED)
                .finish(),
            Self::LlmProvider {
                base_url,
                api_key: _,
                model,
            } => f
                .debug_struct("LlmProvider")
                .field("base_url", base_url)
                .field("api_key", &REDACTED)
                .field("model", model)
                .finish(),
        }
    }
}
32
33pub trait EmbeddingProvider: Send + Sync {
35 fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>>;
37
38 fn dimensions(&self) -> usize;
40}
41
42pub struct OllamaProvider {
48 base_url: String,
49 model: String,
50 dims: usize,
51}
52
53impl OllamaProvider {
54 pub fn new(base_url: &str, model: &str, dims: usize) -> Self {
55 Self {
56 base_url: base_url.trim_end_matches('/').to_string(),
57 model: model.to_string(),
58 dims,
59 }
60 }
61}
62
63impl EmbeddingProvider for OllamaProvider {
64 fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
65 let url = format!("{}/api/embeddings", self.base_url);
66 let body = serde_json::json!({
67 "model": self.model,
68 "prompt": text,
69 });
70
71 Box::pin(async move {
72 let client = reqwest::Client::new();
73 let resp = client
74 .post(&url)
75 .json(&body)
76 .send()
77 .await
78 .map_err(|e| Error::Http(format!("Ollama request failed: {}", e)))?;
79
80 if !resp.status().is_success() {
81 let status = resp.status();
82 let text = resp.text().await.unwrap_or_default();
83 return Err(Error::Http(format!("Ollama returned {}: {}", status, text)));
84 }
85
86 let data: serde_json::Value = resp
87 .json()
88 .await
89 .map_err(|e| Error::Http(format!("Ollama JSON parse failed: {}", e)))?;
90
91 let embedding = data["embedding"]
92 .as_array()
93 .ok_or_else(|| {
94 Error::Embedding("no 'embedding' array in Ollama response".into())
95 })?
96 .iter()
97 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
98 .collect();
99
100 Ok(embedding)
101 })
102 }
103
104 fn dimensions(&self) -> usize {
105 self.dims
106 }
107}
108
109pub struct OpenAiCompatibleProvider {
117 base_url: String,
118 api_key: String,
119 model: String,
120 dims: usize,
121}
122
123impl OpenAiCompatibleProvider {
124 pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
125 Self {
126 base_url: base_url.trim_end_matches('/').to_string(),
127 api_key: api_key.to_string(),
128 model: model.to_string(),
129 dims,
130 }
131 }
132}
133
134impl EmbeddingProvider for OpenAiCompatibleProvider {
135 fn embed(&self, text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
136 let url = format!("{}/v1/embeddings", self.base_url);
137 let body = serde_json::json!({
138 "model": self.model,
139 "input": text,
140 });
141 let api_key = self.api_key.clone();
142
143 Box::pin(async move {
144 let client = reqwest::Client::new();
145 let resp = client
146 .post(&url)
147 .header("Authorization", format!("Bearer {}", api_key))
148 .json(&body)
149 .send()
150 .await
151 .map_err(|e| Error::Http(format!("embedding request failed: {}", e)))?;
152
153 if !resp.status().is_success() {
154 let status = resp.status();
155 let text = resp.text().await.unwrap_or_default();
156 return Err(Error::Http(format!(
157 "embedding provider returned {}: {}",
158 status, text
159 )));
160 }
161
162 let data: serde_json::Value = resp
163 .json()
164 .await
165 .map_err(|e| Error::Http(format!("JSON parse failed: {}", e)))?;
166
167 let embedding = data["data"][0]["embedding"]
168 .as_array()
169 .ok_or_else(|| {
170 Error::Embedding("no 'data[0].embedding' in response".into())
171 })?
172 .iter()
173 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
174 .collect();
175
176 Ok(embedding)
177 })
178 }
179
180 fn dimensions(&self) -> usize {
181 self.dims
182 }
183}
184
/// Placeholder provider for on-device ONNX inference, compiled only with the
/// `local-embeddings` feature.
///
/// Construction always succeeds, but every `embed` call currently resolves to
/// an `Error::Embedding` because inference is not implemented yet.
#[cfg(feature = "local-embeddings")]
pub struct LocalOnnxProvider {
    // Kept for the eventual ONNX implementation; not read yet (hence the `_` prefix).
    _model_path: String,
    dims: usize,
}

#[cfg(feature = "local-embeddings")]
impl LocalOnnxProvider {
    /// Records `model_path` and `dims`; does not open or validate the model file.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn new(model_path: &str, dims: usize) -> Result<Self> {
        let provider = Self {
            _model_path: String::from(model_path),
            dims,
        };
        Ok(provider)
    }
}

#[cfg(feature = "local-embeddings")]
impl EmbeddingProvider for LocalOnnxProvider {
    fn dimensions(&self) -> usize {
        self.dims
    }

    /// Always resolves to `Error::Embedding`: ONNX inference is not wired up yet.
    fn embed(&self, _text: &str) -> Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + '_>> {
        const NOT_IMPLEMENTED: &str = "local ONNX embedding not yet fully implemented";
        Box::pin(async move { Err(Error::Embedding(NOT_IMPLEMENTED.into())) })
    }
}
219
220pub fn create_provider(mode: EmbeddingMode, dims: usize) -> Result<Box<dyn EmbeddingProvider>> {
226 match mode {
227 EmbeddingMode::Ollama { base_url, model } => {
228 Ok(Box::new(OllamaProvider::new(&base_url, &model, dims)))
229 }
230 EmbeddingMode::ZeroClaw { base_url, api_key } => Ok(Box::new(
231 OpenAiCompatibleProvider::new(&base_url, &api_key, "harrier-oss-v1-270m", dims),
232 )),
233 EmbeddingMode::LlmProvider {
234 base_url,
235 api_key,
236 model,
237 } => Ok(Box::new(OpenAiCompatibleProvider::new(
238 &base_url, &api_key, &model, dims,
239 ))),
240 #[cfg(feature = "local-embeddings")]
241 EmbeddingMode::Local { model_path } => {
242 Ok(Box::new(LocalOnnxProvider::new(&model_path, dims)?))
243 }
244 #[cfg(not(feature = "local-embeddings"))]
245 EmbeddingMode::Local { .. } => Err(Error::Embedding(
246 "local embeddings require the 'local-embeddings' feature".into(),
247 )),
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider for `mode` with `dims` and checks that construction
    /// succeeds and that the provider reports exactly `dims`.
    fn assert_reports_dims(mode: EmbeddingMode, dims: usize) {
        let result = create_provider(mode, dims);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().dimensions(), dims);
    }

    #[test]
    fn test_create_provider_ollama() {
        assert_reports_dims(
            EmbeddingMode::Ollama {
                base_url: "http://localhost:11434".into(),
                model: "harrier-oss-v1-270m".into(),
            },
            640,
        );
    }

    #[test]
    fn test_create_provider_zeroclaw() {
        assert_reports_dims(
            EmbeddingMode::ZeroClaw {
                base_url: "https://api.example.com".into(),
                api_key: "test-key".into(),
            },
            640,
        );
    }

    #[test]
    fn test_create_provider_llm() {
        assert_reports_dims(
            EmbeddingMode::LlmProvider {
                base_url: "https://api.openai.com".into(),
                api_key: "test-key".into(),
                model: "text-embedding-3-small".into(),
            },
            1536,
        );
    }

    #[test]
    fn test_create_provider_local_without_feature() {
        let result = create_provider(
            EmbeddingMode::Local {
                model_path: "/tmp/model".into(),
            },
            640,
        );
        // Without the feature the factory must refuse; with it, construction succeeds.
        assert_eq!(result.is_ok(), cfg!(feature = "local-embeddings"));
    }
}