1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
6
/// Connection and index settings for an Elasticsearch-backed vector store.
///
/// Create one with [`ElasticsearchConfig::new`] and adjust it with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct ElasticsearchConfig {
    /// Base URL of the Elasticsearch endpoint (defaults to `http://localhost:9200`).
    pub url: String,
    /// Name of the index that documents are written to and searched in.
    pub index_name: String,
    /// Name of the document field holding the embedding (defaults to `embedding`).
    pub vector_field: String,
    /// Name of the document field holding the text (defaults to `content`).
    pub content_field: String,
    /// Number of dimensions of the stored vectors.
    pub dims: usize,
    /// User name for HTTP basic authentication, if any.
    pub username: Option<String>,
    /// Password for HTTP basic authentication, if any.
    pub password: Option<String>,
}

// Written by hand rather than derived so that the password never ends up in
// logs, panic messages or other `{:?}` output.
impl std::fmt::Debug for ElasticsearchConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ElasticsearchConfig")
            .field("url", &self.url)
            .field("index_name", &self.index_name)
            .field("vector_field", &self.vector_field)
            .field("content_field", &self.content_field)
            .field("dims", &self.dims)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl ElasticsearchConfig {
    /// Creates a configuration for `index_name` with vectors of `dims` dimensions.
    ///
    /// All other settings start at their defaults: URL `http://localhost:9200`,
    /// vector field `embedding`, content field `content`, no credentials.
    pub fn new(index_name: impl Into<String>, dims: usize) -> Self {
        Self {
            url: "http://localhost:9200".to_string(),
            index_name: index_name.into(),
            vector_field: "embedding".to_string(),
            content_field: "content".to_string(),
            dims,
            username: None,
            password: None,
        }
    }

    /// Replaces the base URL of the Elasticsearch endpoint.
    pub fn with_url(mut self, url: impl Into<String>) -> Self {
        self.url = url.into();
        self
    }

    /// Replaces the name of the field that stores the embedding.
    pub fn with_vector_field(mut self, field: impl Into<String>) -> Self {
        self.vector_field = field.into();
        self
    }

    /// Replaces the name of the field that stores the document text.
    pub fn with_content_field(mut self, field: impl Into<String>) -> Self {
        self.content_field = field.into();
        self
    }

    /// Sets the user name and password used for HTTP basic authentication.
    pub fn with_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self.password = Some(password.into());
        self
    }
}
72
73pub struct ElasticsearchVectorStore {
89 config: ElasticsearchConfig,
90 client: reqwest::Client,
91}
92
93impl ElasticsearchVectorStore {
94 pub fn new(config: ElasticsearchConfig) -> Self {
96 Self {
97 config,
98 client: reqwest::Client::new(),
99 }
100 }
101
102 pub fn config(&self) -> &ElasticsearchConfig {
104 &self.config
105 }
106
107 fn url(&self, path: &str) -> String {
109 let base = self.config.url.trim_end_matches('/');
110 format!("{base}{path}")
111 }
112
113 fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
115 if let (Some(ref user), Some(ref pass)) = (&self.config.username, &self.config.password) {
116 builder.basic_auth(user, Some(pass))
117 } else {
118 builder
119 }
120 }
121
122 pub async fn ensure_index(&self) -> Result<(), SynapticError> {
127 let index_url = self.url(&format!("/{}", self.config.index_name));
128
129 let head_req = self.apply_auth(self.client.head(&index_url));
131 let head_resp = head_req.send().await.map_err(|e| {
132 SynapticError::VectorStore(format!("Elasticsearch HEAD request failed: {e}"))
133 })?;
134
135 if head_resp.status().is_success() {
136 return Ok(());
138 }
139
140 let mappings = serde_json::json!({
142 "mappings": {
143 "properties": {
144 &self.config.content_field: {
145 "type": "text"
146 },
147 &self.config.vector_field: {
148 "type": "dense_vector",
149 "dims": self.config.dims,
150 "index": true,
151 "similarity": "cosine"
152 },
153 "metadata": {
154 "type": "object",
155 "enabled": false
156 }
157 }
158 }
159 });
160
161 let put_req = self
162 .apply_auth(self.client.put(&index_url))
163 .header("Content-Type", "application/json")
164 .json(&mappings);
165
166 let put_resp = put_req.send().await.map_err(|e| {
167 SynapticError::VectorStore(format!("Elasticsearch PUT index failed: {e}"))
168 })?;
169
170 let status = put_resp.status();
171 if !status.is_success() {
172 let text = put_resp.text().await.unwrap_or_default();
173 return Err(SynapticError::VectorStore(format!(
174 "Elasticsearch create index error (HTTP {status}): {text}"
175 )));
176 }
177
178 Ok(())
179 }
180}
181
182#[async_trait]
187impl VectorStore for ElasticsearchVectorStore {
188 async fn add_documents(
189 &self,
190 docs: Vec<Document>,
191 embeddings: &dyn Embeddings,
192 ) -> Result<Vec<String>, SynapticError> {
193 if docs.is_empty() {
194 return Ok(Vec::new());
195 }
196
197 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
199 let vectors = embeddings.embed_documents(&texts).await?;
200
201 let mut ids = Vec::with_capacity(docs.len());
202 let mut bulk_body = String::new();
203
204 for (doc, vector) in docs.into_iter().zip(vectors) {
205 let id = if doc.id.is_empty() {
206 generate_id()
207 } else {
208 doc.id.clone()
209 };
210
211 let action = serde_json::json!({
213 "index": {
214 "_index": self.config.index_name,
215 "_id": id,
216 }
217 });
218 bulk_body.push_str(&action.to_string());
219 bulk_body.push('\n');
220
221 let doc_body = serde_json::json!({
223 &self.config.content_field: doc.content,
224 &self.config.vector_field: vector,
225 "metadata": doc.metadata,
226 });
227 bulk_body.push_str(&doc_body.to_string());
228 bulk_body.push('\n');
229
230 ids.push(id);
231 }
232
233 let bulk_url = self.url("/_bulk");
234 let req = self
235 .apply_auth(self.client.post(&bulk_url))
236 .header("Content-Type", "application/x-ndjson")
237 .body(bulk_body);
238
239 let resp = req.send().await.map_err(|e| {
240 SynapticError::VectorStore(format!("Elasticsearch bulk request failed: {e}"))
241 })?;
242
243 let status = resp.status();
244 let text = resp.text().await.map_err(|e| {
245 SynapticError::VectorStore(format!("failed to read Elasticsearch response: {e}"))
246 })?;
247
248 if !status.is_success() {
249 return Err(SynapticError::VectorStore(format!(
250 "Elasticsearch bulk error (HTTP {status}): {text}"
251 )));
252 }
253
254 let parsed: Value = serde_json::from_str(&text).map_err(|e| {
256 SynapticError::VectorStore(format!("failed to parse Elasticsearch bulk response: {e}"))
257 })?;
258
259 if parsed
260 .get("errors")
261 .and_then(|v| v.as_bool())
262 .unwrap_or(false)
263 {
264 return Err(SynapticError::VectorStore(format!(
265 "Elasticsearch bulk operation had errors: {text}"
266 )));
267 }
268
269 Ok(ids)
270 }
271
272 async fn similarity_search(
273 &self,
274 query: &str,
275 k: usize,
276 embeddings: &dyn Embeddings,
277 ) -> Result<Vec<Document>, SynapticError> {
278 let results = self
279 .similarity_search_with_score(query, k, embeddings)
280 .await?;
281 Ok(results.into_iter().map(|(doc, _)| doc).collect())
282 }
283
284 async fn similarity_search_with_score(
285 &self,
286 query: &str,
287 k: usize,
288 embeddings: &dyn Embeddings,
289 ) -> Result<Vec<(Document, f32)>, SynapticError> {
290 let query_vec = embeddings.embed_query(query).await?;
291 self.similarity_search_by_vector_with_score(&query_vec, k)
292 .await
293 }
294
295 async fn similarity_search_by_vector(
296 &self,
297 embedding: &[f32],
298 k: usize,
299 ) -> Result<Vec<Document>, SynapticError> {
300 let results = self
301 .similarity_search_by_vector_with_score(embedding, k)
302 .await?;
303 Ok(results.into_iter().map(|(doc, _)| doc).collect())
304 }
305
306 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
307 if ids.is_empty() {
308 return Ok(());
309 }
310
311 let mut bulk_body = String::new();
312 for id in ids {
313 let action = serde_json::json!({
314 "delete": {
315 "_index": self.config.index_name,
316 "_id": id,
317 }
318 });
319 bulk_body.push_str(&action.to_string());
320 bulk_body.push('\n');
321 }
322
323 let bulk_url = self.url("/_bulk");
324 let req = self
325 .apply_auth(self.client.post(&bulk_url))
326 .header("Content-Type", "application/x-ndjson")
327 .body(bulk_body);
328
329 let resp = req.send().await.map_err(|e| {
330 SynapticError::VectorStore(format!("Elasticsearch delete request failed: {e}"))
331 })?;
332
333 let status = resp.status();
334 if !status.is_success() {
335 let text = resp.text().await.unwrap_or_default();
336 return Err(SynapticError::VectorStore(format!(
337 "Elasticsearch delete error (HTTP {status}): {text}"
338 )));
339 }
340
341 Ok(())
342 }
343}
344
345impl ElasticsearchVectorStore {
346 async fn similarity_search_by_vector_with_score(
348 &self,
349 embedding: &[f32],
350 k: usize,
351 ) -> Result<Vec<(Document, f32)>, SynapticError> {
352 let num_candidates = std::cmp::max(k * 10, 100);
353
354 let search_body = serde_json::json!({
355 "size": k,
356 "knn": {
357 "field": self.config.vector_field,
358 "query_vector": embedding,
359 "k": k,
360 "num_candidates": num_candidates,
361 },
362 "_source": [&self.config.content_field, "metadata"],
363 });
364
365 let search_url = self.url(&format!("/{}/_search", self.config.index_name));
366 let req = self
367 .apply_auth(self.client.post(&search_url))
368 .header("Content-Type", "application/json")
369 .json(&search_body);
370
371 let resp = req
372 .send()
373 .await
374 .map_err(|e| SynapticError::VectorStore(format!("Elasticsearch search failed: {e}")))?;
375
376 let status = resp.status();
377 let text = resp.text().await.map_err(|e| {
378 SynapticError::VectorStore(format!("failed to read Elasticsearch response: {e}"))
379 })?;
380
381 if !status.is_success() {
382 return Err(SynapticError::VectorStore(format!(
383 "Elasticsearch search error (HTTP {status}): {text}"
384 )));
385 }
386
387 let parsed: Value = serde_json::from_str(&text).map_err(|e| {
388 SynapticError::VectorStore(format!("failed to parse Elasticsearch response: {e}"))
389 })?;
390
391 let hits = parsed["hits"]["hits"]
392 .as_array()
393 .cloned()
394 .unwrap_or_default();
395
396 let mut results = Vec::with_capacity(hits.len());
397
398 for hit in hits {
399 let id = hit
400 .get("_id")
401 .and_then(|v| v.as_str())
402 .unwrap_or("")
403 .to_string();
404
405 let score = hit.get("_score").and_then(|v| v.as_f64()).unwrap_or(0.0) as f32;
406
407 let source = hit
408 .get("_source")
409 .cloned()
410 .unwrap_or(Value::Object(serde_json::Map::new()));
411
412 let content = source
413 .get(&self.config.content_field)
414 .and_then(|v| v.as_str())
415 .unwrap_or("")
416 .to_string();
417
418 let metadata: HashMap<String, Value> = source
419 .get("metadata")
420 .and_then(|v| v.as_object())
421 .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
422 .unwrap_or_default();
423
424 let doc = Document::with_metadata(id, content, metadata);
425 results.push((doc, score));
426 }
427
428 Ok(results)
429 }
430}
431
/// Builds an identifier of the form `<nanos>-<sequence>`, both in lowercase hex.
///
/// `<nanos>` is the number of nanoseconds since the Unix epoch (or `0` when the
/// system clock reads earlier than the epoch); `<sequence>` comes from a
/// process-wide counter that is incremented on every call.
fn generate_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("{nanos:x}-{sequence:x}")
}
447
#[cfg(test)]
mod tests {
    //! Unit tests for the configuration builder, URL construction and id
    //! generation. None of them contacts an Elasticsearch server.

    use super::*;

    /// A fresh config must carry the documented defaults.
    #[test]
    fn config_new_sets_defaults() {
        let cfg = ElasticsearchConfig::new("my_index", 1536);

        assert_eq!(cfg.index_name, "my_index");
        assert_eq!(cfg.dims, 1536);
        assert_eq!(cfg.url, "http://localhost:9200");
        assert_eq!(cfg.vector_field, "embedding");
        assert_eq!(cfg.content_field, "content");
        assert!(cfg.username.is_none());
        assert!(cfg.password.is_none());
    }

    /// `with_url` overrides the endpoint.
    #[test]
    fn config_with_url() {
        let cfg = ElasticsearchConfig::new("idx", 768);
        let cfg = cfg.with_url("https://es.example.com:9200");
        assert_eq!(cfg.url, "https://es.example.com:9200");
    }

    /// `with_vector_field` overrides the embedding field name.
    #[test]
    fn config_with_vector_field() {
        let cfg = ElasticsearchConfig::new("idx", 768);
        let cfg = cfg.with_vector_field("vec");
        assert_eq!(cfg.vector_field, "vec");
    }

    /// `with_content_field` overrides the text field name.
    #[test]
    fn config_with_content_field() {
        let cfg = ElasticsearchConfig::new("idx", 768);
        let cfg = cfg.with_content_field("text");
        assert_eq!(cfg.content_field, "text");
    }

    /// `with_auth` stores both credentials.
    #[test]
    fn config_with_auth() {
        let cfg = ElasticsearchConfig::new("idx", 768);
        let cfg = cfg.with_auth("elastic", "secret123");
        assert_eq!(cfg.username.as_deref(), Some("elastic"));
        assert_eq!(cfg.password.as_deref(), Some("secret123"));
    }

    /// Chained builder calls compose without clobbering one another.
    #[test]
    fn config_builder_chain() {
        let cfg = ElasticsearchConfig::new("documents", 1536)
            .with_url("https://es-cluster:9200")
            .with_vector_field("doc_embedding")
            .with_content_field("doc_text")
            .with_auth("admin", "password");

        assert_eq!(cfg.index_name, "documents");
        assert_eq!(cfg.dims, 1536);
        assert_eq!(cfg.url, "https://es-cluster:9200");
        assert_eq!(cfg.vector_field, "doc_embedding");
        assert_eq!(cfg.content_field, "doc_text");
        assert_eq!(cfg.username.as_deref(), Some("admin"));
        assert_eq!(cfg.password.as_deref(), Some("password"));
    }

    /// The store exposes the config it was built from.
    #[test]
    fn store_new_creates_instance() {
        let store = ElasticsearchVectorStore::new(ElasticsearchConfig::new("test_idx", 768));
        let cfg = store.config();
        assert_eq!(cfg.index_name, "test_idx");
        assert_eq!(cfg.dims, 768);
    }

    /// Paths are appended directly to the base URL.
    #[test]
    fn url_construction() {
        let store = ElasticsearchVectorStore::new(ElasticsearchConfig::new("idx", 768));

        let bulk = store.url("/_bulk");
        assert_eq!(bulk, "http://localhost:9200/_bulk");

        let search = store.url("/idx/_search");
        assert_eq!(search, "http://localhost:9200/idx/_search");
    }

    /// A trailing slash on the base URL must not produce a double slash.
    #[test]
    fn url_construction_trailing_slash() {
        let cfg = ElasticsearchConfig::new("idx", 768).with_url("http://localhost:9200/");
        let store = ElasticsearchVectorStore::new(cfg);
        let bulk = store.url("/_bulk");
        assert_eq!(bulk, "http://localhost:9200/_bulk");
    }

    /// Two consecutive ids differ.
    #[test]
    fn generate_id_is_unique() {
        let first = generate_id();
        let second = generate_id();
        assert_ne!(first, second);
    }

    /// A generated id is never the empty string.
    #[test]
    fn generate_id_is_non_empty() {
        let generated = generate_id();
        assert!(!generated.is_empty());
    }
}