// synaptic_elasticsearch/vector_store.rs

use std::collections::HashMap;

use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};

// ---------------------------------------------------------------------------
// ElasticsearchConfig
// ---------------------------------------------------------------------------

/// Configuration for connecting to an Elasticsearch cluster.
///
/// Build one with [`ElasticsearchConfig::new`] and refine it with the
/// `with_*` builder methods, each of which consumes and returns the config.
///
/// The [`Debug`](std::fmt::Debug) implementation redacts the password so that
/// logging a config never leaks the credential.
#[derive(Clone)]
pub struct ElasticsearchConfig {
    /// Elasticsearch URL (default: `http://localhost:9200`).
    pub url: String,
    /// Name of the index to store documents in.
    pub index_name: String,
    /// Field name for storing embedding vectors (default: `embedding`).
    pub vector_field: String,
    /// Field name for storing document content (default: `content`).
    pub content_field: String,
    /// Vector dimensionality (required for index creation).
    pub dims: usize,
    /// Optional username for basic authentication.
    pub username: Option<String>,
    /// Optional password for basic authentication.
    pub password: Option<String>,
}

impl std::fmt::Debug for ElasticsearchConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a password is configured without printing its value.
        let password = self.password.as_ref().map(|_| "<redacted>");

        f.debug_struct("ElasticsearchConfig")
            .field("url", &self.url)
            .field("index_name", &self.index_name)
            .field("vector_field", &self.vector_field)
            .field("content_field", &self.content_field)
            .field("dims", &self.dims)
            .field("username", &self.username)
            .field("password", &password)
            .finish()
    }
}

impl ElasticsearchConfig {
    /// Create a new config with the required index name and vector dimensions.
    ///
    /// Uses default values for URL (`http://localhost:9200`), vector field
    /// (`embedding`), and content field (`content`). No credentials are set.
    pub fn new(index_name: impl Into<String>, dims: usize) -> Self {
        Self {
            url: "http://localhost:9200".to_string(),
            index_name: index_name.into(),
            vector_field: "embedding".to_string(),
            content_field: "content".to_string(),
            dims,
            username: None,
            password: None,
        }
    }

    /// Set the Elasticsearch URL.
    pub fn with_url(mut self, url: impl Into<String>) -> Self {
        self.url = url.into();
        self
    }

    /// Set the vector field name.
    pub fn with_vector_field(mut self, field: impl Into<String>) -> Self {
        self.vector_field = field.into();
        self
    }

    /// Set the content field name.
    pub fn with_content_field(mut self, field: impl Into<String>) -> Self {
        self.content_field = field.into();
        self
    }

    /// Set basic authentication credentials (both username and password).
    pub fn with_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self.password = Some(password.into());
        self
    }
}

73// ---------------------------------------------------------------------------
74// ElasticsearchVectorStore
75// ---------------------------------------------------------------------------
76
77/// A [`VectorStore`] implementation backed by [Elasticsearch](https://www.elastic.co/).
78///
79/// Uses the Elasticsearch REST API with `dense_vector` field type and kNN
80/// search for similarity queries. Documents are stored with:
81/// - `_id`: the document ID
82/// - `content`: the document text
83/// - `embedding`: the vector (dense_vector type)
84/// - `metadata`: an object field with arbitrary metadata
85///
86/// Call [`ensure_index`](ElasticsearchVectorStore::ensure_index) to create
87/// the index with proper mappings before inserting documents.
88pub struct ElasticsearchVectorStore {
89    config: ElasticsearchConfig,
90    client: reqwest::Client,
91}
92
93impl ElasticsearchVectorStore {
94    /// Create a new store with the given configuration.
95    pub fn new(config: ElasticsearchConfig) -> Self {
96        Self {
97            config,
98            client: reqwest::Client::new(),
99        }
100    }
101
102    /// Return a reference to the configuration.
103    pub fn config(&self) -> &ElasticsearchConfig {
104        &self.config
105    }
106
107    /// Build a full URL for the given path.
108    fn url(&self, path: &str) -> String {
109        let base = self.config.url.trim_end_matches('/');
110        format!("{base}{path}")
111    }
112
113    /// Apply basic auth to a request builder if credentials are configured.
114    fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
115        if let (Some(ref user), Some(ref pass)) = (&self.config.username, &self.config.password) {
116            builder.basic_auth(user, Some(pass))
117        } else {
118            builder
119        }
120    }
121
122    /// Ensure the index exists with the correct mappings.
123    ///
124    /// Creates the index if it does not exist. If the index already exists,
125    /// this is a no-op. Idempotent and safe to call on every startup.
126    pub async fn ensure_index(&self) -> Result<(), SynapticError> {
127        let index_url = self.url(&format!("/{}", self.config.index_name));
128
129        // Check if index exists.
130        let head_req = self.apply_auth(self.client.head(&index_url));
131        let head_resp = head_req.send().await.map_err(|e| {
132            SynapticError::VectorStore(format!("Elasticsearch HEAD request failed: {e}"))
133        })?;
134
135        if head_resp.status().is_success() {
136            // Index already exists.
137            return Ok(());
138        }
139
140        // Create the index with mappings.
141        let mappings = serde_json::json!({
142            "mappings": {
143                "properties": {
144                    &self.config.content_field: {
145                        "type": "text"
146                    },
147                    &self.config.vector_field: {
148                        "type": "dense_vector",
149                        "dims": self.config.dims,
150                        "index": true,
151                        "similarity": "cosine"
152                    },
153                    "metadata": {
154                        "type": "object",
155                        "enabled": false
156                    }
157                }
158            }
159        });
160
161        let put_req = self
162            .apply_auth(self.client.put(&index_url))
163            .header("Content-Type", "application/json")
164            .json(&mappings);
165
166        let put_resp = put_req.send().await.map_err(|e| {
167            SynapticError::VectorStore(format!("Elasticsearch PUT index failed: {e}"))
168        })?;
169
170        let status = put_resp.status();
171        if !status.is_success() {
172            let text = put_resp.text().await.unwrap_or_default();
173            return Err(SynapticError::VectorStore(format!(
174                "Elasticsearch create index error (HTTP {status}): {text}"
175            )));
176        }
177
178        Ok(())
179    }
180}
181
182// ---------------------------------------------------------------------------
183// VectorStore implementation
184// ---------------------------------------------------------------------------
185
186#[async_trait]
187impl VectorStore for ElasticsearchVectorStore {
188    async fn add_documents(
189        &self,
190        docs: Vec<Document>,
191        embeddings: &dyn Embeddings,
192    ) -> Result<Vec<String>, SynapticError> {
193        if docs.is_empty() {
194            return Ok(Vec::new());
195        }
196
197        // Compute embeddings for all documents.
198        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
199        let vectors = embeddings.embed_documents(&texts).await?;
200
201        let mut ids = Vec::with_capacity(docs.len());
202        let mut bulk_body = String::new();
203
204        for (doc, vector) in docs.into_iter().zip(vectors) {
205            let id = if doc.id.is_empty() {
206                generate_id()
207            } else {
208                doc.id.clone()
209            };
210
211            // Build the action line.
212            let action = serde_json::json!({
213                "index": {
214                    "_index": self.config.index_name,
215                    "_id": id,
216                }
217            });
218            bulk_body.push_str(&action.to_string());
219            bulk_body.push('\n');
220
221            // Build the document line.
222            let doc_body = serde_json::json!({
223                &self.config.content_field: doc.content,
224                &self.config.vector_field: vector,
225                "metadata": doc.metadata,
226            });
227            bulk_body.push_str(&doc_body.to_string());
228            bulk_body.push('\n');
229
230            ids.push(id);
231        }
232
233        let bulk_url = self.url("/_bulk");
234        let req = self
235            .apply_auth(self.client.post(&bulk_url))
236            .header("Content-Type", "application/x-ndjson")
237            .body(bulk_body);
238
239        let resp = req.send().await.map_err(|e| {
240            SynapticError::VectorStore(format!("Elasticsearch bulk request failed: {e}"))
241        })?;
242
243        let status = resp.status();
244        let text = resp.text().await.map_err(|e| {
245            SynapticError::VectorStore(format!("failed to read Elasticsearch response: {e}"))
246        })?;
247
248        if !status.is_success() {
249            return Err(SynapticError::VectorStore(format!(
250                "Elasticsearch bulk error (HTTP {status}): {text}"
251            )));
252        }
253
254        // Check for item-level errors in the bulk response.
255        let parsed: Value = serde_json::from_str(&text).map_err(|e| {
256            SynapticError::VectorStore(format!("failed to parse Elasticsearch bulk response: {e}"))
257        })?;
258
259        if parsed
260            .get("errors")
261            .and_then(|v| v.as_bool())
262            .unwrap_or(false)
263        {
264            return Err(SynapticError::VectorStore(format!(
265                "Elasticsearch bulk operation had errors: {text}"
266            )));
267        }
268
269        Ok(ids)
270    }
271
272    async fn similarity_search(
273        &self,
274        query: &str,
275        k: usize,
276        embeddings: &dyn Embeddings,
277    ) -> Result<Vec<Document>, SynapticError> {
278        let results = self
279            .similarity_search_with_score(query, k, embeddings)
280            .await?;
281        Ok(results.into_iter().map(|(doc, _)| doc).collect())
282    }
283
284    async fn similarity_search_with_score(
285        &self,
286        query: &str,
287        k: usize,
288        embeddings: &dyn Embeddings,
289    ) -> Result<Vec<(Document, f32)>, SynapticError> {
290        let query_vec = embeddings.embed_query(query).await?;
291        self.similarity_search_by_vector_with_score(&query_vec, k)
292            .await
293    }
294
295    async fn similarity_search_by_vector(
296        &self,
297        embedding: &[f32],
298        k: usize,
299    ) -> Result<Vec<Document>, SynapticError> {
300        let results = self
301            .similarity_search_by_vector_with_score(embedding, k)
302            .await?;
303        Ok(results.into_iter().map(|(doc, _)| doc).collect())
304    }
305
306    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
307        if ids.is_empty() {
308            return Ok(());
309        }
310
311        let mut bulk_body = String::new();
312        for id in ids {
313            let action = serde_json::json!({
314                "delete": {
315                    "_index": self.config.index_name,
316                    "_id": id,
317                }
318            });
319            bulk_body.push_str(&action.to_string());
320            bulk_body.push('\n');
321        }
322
323        let bulk_url = self.url("/_bulk");
324        let req = self
325            .apply_auth(self.client.post(&bulk_url))
326            .header("Content-Type", "application/x-ndjson")
327            .body(bulk_body);
328
329        let resp = req.send().await.map_err(|e| {
330            SynapticError::VectorStore(format!("Elasticsearch delete request failed: {e}"))
331        })?;
332
333        let status = resp.status();
334        if !status.is_success() {
335            let text = resp.text().await.unwrap_or_default();
336            return Err(SynapticError::VectorStore(format!(
337                "Elasticsearch delete error (HTTP {status}): {text}"
338            )));
339        }
340
341        Ok(())
342    }
343}
344
345impl ElasticsearchVectorStore {
346    /// Search by vector and return documents with their similarity scores.
347    async fn similarity_search_by_vector_with_score(
348        &self,
349        embedding: &[f32],
350        k: usize,
351    ) -> Result<Vec<(Document, f32)>, SynapticError> {
352        let num_candidates = std::cmp::max(k * 10, 100);
353
354        let search_body = serde_json::json!({
355            "size": k,
356            "knn": {
357                "field": self.config.vector_field,
358                "query_vector": embedding,
359                "k": k,
360                "num_candidates": num_candidates,
361            },
362            "_source": [&self.config.content_field, "metadata"],
363        });
364
365        let search_url = self.url(&format!("/{}/_search", self.config.index_name));
366        let req = self
367            .apply_auth(self.client.post(&search_url))
368            .header("Content-Type", "application/json")
369            .json(&search_body);
370
371        let resp = req
372            .send()
373            .await
374            .map_err(|e| SynapticError::VectorStore(format!("Elasticsearch search failed: {e}")))?;
375
376        let status = resp.status();
377        let text = resp.text().await.map_err(|e| {
378            SynapticError::VectorStore(format!("failed to read Elasticsearch response: {e}"))
379        })?;
380
381        if !status.is_success() {
382            return Err(SynapticError::VectorStore(format!(
383                "Elasticsearch search error (HTTP {status}): {text}"
384            )));
385        }
386
387        let parsed: Value = serde_json::from_str(&text).map_err(|e| {
388            SynapticError::VectorStore(format!("failed to parse Elasticsearch response: {e}"))
389        })?;
390
391        let hits = parsed["hits"]["hits"]
392            .as_array()
393            .cloned()
394            .unwrap_or_default();
395
396        let mut results = Vec::with_capacity(hits.len());
397
398        for hit in hits {
399            let id = hit
400                .get("_id")
401                .and_then(|v| v.as_str())
402                .unwrap_or("")
403                .to_string();
404
405            let score = hit.get("_score").and_then(|v| v.as_f64()).unwrap_or(0.0) as f32;
406
407            let source = hit
408                .get("_source")
409                .cloned()
410                .unwrap_or(Value::Object(serde_json::Map::new()));
411
412            let content = source
413                .get(&self.config.content_field)
414                .and_then(|v| v.as_str())
415                .unwrap_or("")
416                .to_string();
417
418            let metadata: HashMap<String, Value> = source
419                .get("metadata")
420                .and_then(|v| v.as_object())
421                .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
422                .unwrap_or_default();
423
424            let doc = Document::with_metadata(id, content, metadata);
425            results.push((doc, score));
426        }
427
428        Ok(results)
429    }
430}
431
/// Generate a simple unique ID.
///
/// The ID has the form `<timestamp>-<counter>`, both in lowercase hex:
/// the nanoseconds elapsed since the Unix epoch (0 if the system clock reads
/// earlier than the epoch) and a process-wide counter that increases by one on
/// every call. The counter alone makes IDs distinct within a process.
fn generate_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Only the uniqueness of each fetched value matters, not its ordering
    // relative to other memory operations, so `Relaxed` is sufficient.
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    format!("{timestamp:x}-{count:x}")
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

/// Unit tests for config building, URL construction and ID generation.
/// None of them contact an Elasticsearch server.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_new_sets_defaults() {
        let config = ElasticsearchConfig::new("my_index", 1536);
        assert_eq!(config.index_name, "my_index");
        assert_eq!(config.dims, 1536);
        assert_eq!(config.url, "http://localhost:9200");
        assert_eq!(config.vector_field, "embedding");
        assert_eq!(config.content_field, "content");
        assert!(config.username.is_none());
        assert!(config.password.is_none());
    }

    #[test]
    fn config_with_url() {
        let config = ElasticsearchConfig::new("idx", 768).with_url("https://es.example.com:9200");
        assert_eq!(config.url, "https://es.example.com:9200");
    }

    #[test]
    fn config_with_vector_field() {
        let config = ElasticsearchConfig::new("idx", 768).with_vector_field("vec");
        assert_eq!(config.vector_field, "vec");
    }

    #[test]
    fn config_with_content_field() {
        let config = ElasticsearchConfig::new("idx", 768).with_content_field("text");
        assert_eq!(config.content_field, "text");
    }

    #[test]
    fn config_with_auth() {
        let config = ElasticsearchConfig::new("idx", 768).with_auth("elastic", "secret123");
        assert_eq!(config.username.as_deref(), Some("elastic"));
        assert_eq!(config.password.as_deref(), Some("secret123"));
    }

    #[test]
    fn config_builder_chain() {
        let config = ElasticsearchConfig::new("documents", 1536)
            .with_url("https://es-cluster:9200")
            .with_vector_field("doc_embedding")
            .with_content_field("doc_text")
            .with_auth("admin", "password");

        assert_eq!(config.index_name, "documents");
        assert_eq!(config.dims, 1536);
        assert_eq!(config.url, "https://es-cluster:9200");
        assert_eq!(config.vector_field, "doc_embedding");
        assert_eq!(config.content_field, "doc_text");
        assert_eq!(config.username.as_deref(), Some("admin"));
        assert_eq!(config.password.as_deref(), Some("password"));
    }

    #[test]
    fn store_new_creates_instance() {
        let config = ElasticsearchConfig::new("test_idx", 768);
        let store = ElasticsearchVectorStore::new(config);
        assert_eq!(store.config().index_name, "test_idx");
        assert_eq!(store.config().dims, 768);
    }

    #[test]
    fn url_construction() {
        let config = ElasticsearchConfig::new("idx", 768);
        let store = ElasticsearchVectorStore::new(config);
        assert_eq!(store.url("/_bulk"), "http://localhost:9200/_bulk");
        assert_eq!(
            store.url("/idx/_search"),
            "http://localhost:9200/idx/_search"
        );
    }

    #[test]
    fn url_construction_trailing_slash() {
        let config = ElasticsearchConfig::new("idx", 768).with_url("http://localhost:9200/");
        let store = ElasticsearchVectorStore::new(config);
        assert_eq!(store.url("/_bulk"), "http://localhost:9200/_bulk");
    }

    #[test]
    fn generate_id_is_unique() {
        // A batch catches collisions that a single pair could miss.
        let ids: std::collections::HashSet<String> = (0..100).map(|_| generate_id()).collect();
        assert_eq!(ids.len(), 100);
    }

    #[test]
    fn generate_id_is_non_empty() {
        let id = generate_id();
        assert!(!id.is_empty());
    }

    #[test]
    fn generate_id_has_hex_timestamp_and_counter() {
        let id = generate_id();
        let (timestamp, counter) = id.split_once('-').expect("id must contain '-'");
        assert!(u128::from_str_radix(timestamp, 16).is_ok());
        assert!(u64::from_str_radix(counter, 16).is_ok());
    }
}