Skip to main content

spire_ai/
collection.rs

1//! Typed document collections with automatic embedding and vector search.
2
3use std::marker::PhantomData;
4
5use spire_proto::spiredb::cluster::{
6    ColumnDef, ColumnType, CreateTableRequest, schema_service_client::SchemaServiceClient,
7};
8use spiresql::vector::types::{Algorithm, IndexParams};
9
10use crate::client::Spire;
11use crate::document::Doc;
12use crate::error::{Error, Result};
13use crate::search::{Filter, Search};
14use crate::watch::WatchStream;
15
16fn doc_cache_key(collection: &str, id: &str) -> u64 {
17    ahash::RandomState::with_seeds(0, 0, 0, 0).hash_one((collection, id))
18}
19
20/// A typed document collection stored in SpireDB.
21///
22/// Documents are stored as JSON in a SpireDB table, with vector embeddings
23/// in a separate vector index for semantic search.
24pub struct Collection<T: Doc> {
25    pub(crate) spire: Spire,
26    pub(crate) name: String,
27    pub(crate) _phantom: PhantomData<T>,
28}
29
30// Manual Clone impl: T doesn't need to be Clone since we only hold PhantomData<T>
31impl<T: Doc> Clone for Collection<T> {
32    fn clone(&self) -> Self {
33        Self {
34            spire: self.spire.clone(),
35            name: self.name.clone(),
36            _phantom: PhantomData,
37        }
38    }
39}
40
41impl<T: Doc> Collection<T> {
42    pub(crate) fn new(spire: Spire, name: String) -> Self {
43        Self {
44            spire,
45            name,
46            _phantom: PhantomData,
47        }
48    }
49
50    /// The internal table name used in SpireDB.
51    pub fn table_name(&self) -> String {
52        format!("_ai_{}", self.name)
53    }
54
55    /// The internal vector index name used in SpireDB.
56    pub fn index_name(&self) -> String {
57        format!("_ai_{}_vec", self.name)
58    }
59
60    /// Ensure the backing table and vector index exist.
61    ///
62    /// Creates them if they don't exist. Safe to call multiple times.
63    pub async fn ensure(&self) -> Result<()> {
64        let table = self.table_name();
65        let index = self.index_name();
66        let dims = self.spire.inner.embedder.dimensions() as u32;
67
68        // Create table via SchemaService
69        let mut schema_client = SchemaServiceClient::new(self.spire.inner.pd_channel.clone());
70
71        let columns = vec![
72            ColumnDef {
73                name: "id".to_string(),
74                r#type: ColumnType::TypeString.into(),
75                nullable: false,
76                ..Default::default()
77            },
78            ColumnDef {
79                name: "doc".to_string(),
80                r#type: ColumnType::TypeBytes.into(),
81                nullable: false,
82                ..Default::default()
83            },
84            ColumnDef {
85                name: "embed_text".to_string(),
86                r#type: ColumnType::TypeString.into(),
87                nullable: true,
88                ..Default::default()
89            },
90            ColumnDef {
91                name: "created_at".to_string(),
92                r#type: ColumnType::TypeTimestamp.into(),
93                nullable: true,
94                ..Default::default()
95            },
96        ];
97
98        let request = CreateTableRequest {
99            name: table.clone(),
100            columns,
101            primary_key: vec!["id".to_string()],
102        };
103
104        match schema_client.create_table(request).await {
105            Ok(_) => {}
106            Err(status) if status.code() == tonic::Code::AlreadyExists => {
107                // Table already exists, that's fine
108            }
109            Err(e) => return Err(Error::Grpc(e)),
110        }
111
112        // Create vector index if embedder is configured (dims > 0)
113        if dims > 0 {
114            let params = IndexParams::new(&index, &table, "embedding")
115                .algorithm(Algorithm::Manode)
116                .dimensions(dims);
117
118            match self.spire.inner.vector.create_index(params).await {
119                Ok(_) => {}
120                Err(spiresql::vector::error::VectorError::IndexAlreadyExists(_)) => {}
121                Err(e) => return Err(Error::Vector(e)),
122            }
123        }
124
125        Ok(())
126    }
127
128    /// Insert a document. Automatically generates embedding if `embed_text()` is non-empty.
129    pub async fn insert(&self, doc: &T) -> Result<String> {
130        let id = doc.id().to_string();
131        let doc_json = serde_json::to_vec(doc)?;
132        let embed_text = doc.embed_text();
133
134        // Cache the doc for later get() lookups
135        let cache_key = doc_cache_key(&self.name, &id);
136        self.spire
137            .inner
138            .doc_cache
139            .insert(cache_key, doc_json.clone());
140
141        // Generate embedding if text is non-empty
142        let embedding = if !embed_text.is_empty() {
143            Some(self.spire.inner.embedder.embed(&embed_text).await?)
144        } else {
145            None
146        };
147
148        // Insert vector with doc JSON as payload
149        if let Some(ref vec) = embedding {
150            self.vector_insert(id.as_bytes(), vec, &doc_json).await?;
151        }
152
153        Ok(id)
154    }
155
156    /// Insert multiple documents in a batch.
157    pub async fn insert_many(&self, docs: &[T]) -> Result<Vec<String>> {
158        if docs.is_empty() {
159            return Ok(Vec::new());
160        }
161
162        let ids: Vec<String> = docs.iter().map(|d| d.id().to_string()).collect();
163        let texts: Vec<String> = docs.iter().map(|d| d.embed_text()).collect();
164
165        // Batch embed non-empty texts
166        let non_empty: Vec<String> = texts.iter().filter(|t| !t.is_empty()).cloned().collect();
167
168        let embeddings = if !non_empty.is_empty() {
169            self.spire.inner.embedder.embed_batch(&non_empty).await?
170        } else {
171            Vec::new()
172        };
173
174        // Map embeddings back to docs
175        let mut embed_iter = embeddings.into_iter();
176
177        for (i, doc) in docs.iter().enumerate() {
178            let doc_json = serde_json::to_vec(doc)?;
179
180            // Cache the doc
181            let cache_key = doc_cache_key(&self.name, &ids[i]);
182            self.spire
183                .inner
184                .doc_cache
185                .insert(cache_key, doc_json.clone());
186
187            if !texts[i].is_empty()
188                && let Some(vec) = embed_iter.next()
189            {
190                self.vector_insert(ids[i].as_bytes(), &vec, &doc_json)
191                    .await?;
192            }
193        }
194
195        Ok(ids)
196    }
197
198    /// Insert into the vector index, re-creating it on `IndexNotFound`
199    async fn vector_insert(&self, doc_id: &[u8], vec: &[f32], payload: &[u8]) -> Result<u64> {
200        let index_name = self.index_name();
201        match self
202            .spire
203            .inner
204            .vector
205            .insert(&index_name, doc_id, vec, Some(payload))
206            .await
207        {
208            Ok(id) => Ok(id),
209            Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
210                // Index was lost, recreate and retry once.
211                self.ensure().await?;
212                Ok(self
213                    .spire
214                    .inner
215                    .vector
216                    .insert(&index_name, doc_id, vec, Some(payload))
217                    .await?)
218            }
219            Err(e) => Err(Error::Vector(e)),
220        }
221    }
222
223    /// Upsert a document (insert or replace).
224    pub async fn upsert(&self, doc: &T) -> Result<String> {
225        let id = doc.id().to_string();
226
227        // Delete existing vector if present (ignore not-found)
228        let _ = self
229            .spire
230            .inner
231            .vector
232            .delete(&self.index_name(), id.as_bytes())
233            .await;
234
235        // Insert the new version
236        self.insert(doc).await
237    }
238
239    /// Delete a document by ID.
240    pub async fn delete(&self, id: &str) -> Result<bool> {
241        // Remove from cache
242        let cache_key = doc_cache_key(&self.name, id);
243        self.spire.inner.doc_cache.remove(&cache_key);
244
245        match self
246            .spire
247            .inner
248            .vector
249            .delete(&self.index_name(), id.as_bytes())
250            .await
251        {
252            Ok(_) => Ok(true),
253            Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => Ok(false),
254            Err(e) => Err(Error::Vector(e)),
255        }
256    }
257
258    /// Get a document by ID.
259    ///
260    /// Checks the in-memory cache first, then falls back to a GetPayload RPC
261    /// to retrieve the payload from SpireDB.
262    pub async fn get(&self, id: &str) -> Result<Option<T>> {
263        // Fast path: check in-memory cache
264        let cache_key = doc_cache_key(&self.name, id);
265        if let Some(bytes) = self.spire.inner.doc_cache.get(&cache_key)
266            && let Ok(doc) = serde_json::from_slice::<T>(&bytes)
267        {
268            return Ok(Some(doc));
269        }
270
271        // Slow path: fetch from SpireDB via GetPayload RPC
272        match self
273            .spire
274            .inner
275            .vector
276            .get_payload(&self.index_name(), id.as_bytes())
277            .await
278        {
279            Ok(Some(payload)) => {
280                // Cache for next time
281                self.spire
282                    .inner
283                    .doc_cache
284                    .insert(cache_key, payload.clone());
285                match serde_json::from_slice::<T>(&payload) {
286                    Ok(doc) => Ok(Some(doc)),
287                    Err(_) => Ok(None),
288                }
289            }
290            Ok(None) => Ok(None),
291            Err(_) => Ok(None),
292        }
293    }
294
295    /// Get multiple documents by IDs.
296    pub async fn get_many(&self, ids: &[&str]) -> Result<Vec<T>> {
297        let mut docs = Vec::new();
298        for id in ids {
299            if let Some(doc) = self.get(id).await? {
300                docs.push(doc);
301            }
302        }
303        Ok(docs)
304    }
305
306    /// List all documents in the collection.
307    ///
308    /// Performs a broad vector search to retrieve all stored documents.
309    /// Use [`filter`](Self::filter) for SQL-based filtering once implemented.
310    pub async fn all(&self) -> Result<Vec<T>> {
311        let dims = self.spire.inner.embedder.dimensions();
312        if dims == 0 {
313            return Ok(Vec::new());
314        }
315
316        // Use a uniform normalized vector — equal components in all dimensions
317        // gives an unbiased search that returns all docs by proximity.
318        let val = 1.0 / (dims as f32).sqrt();
319        let query_vec = vec![val; dims];
320
321        let index_name = self.index_name();
322        let opts = spiresql::vector::types::SearchOptions::default()
323            .k(10_000)
324            .with_payload();
325
326        let results = match self
327            .spire
328            .inner
329            .vector
330            .search(&index_name, &query_vec, opts.clone())
331            .await
332        {
333            Ok(r) => r,
334            Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
335                self.ensure().await?;
336                self.spire
337                    .inner
338                    .vector
339                    .search(&index_name, &query_vec, opts)
340                    .await?
341            }
342            Err(e) => return Err(Error::Vector(e)),
343        };
344
345        let mut docs = Vec::with_capacity(results.len());
346        for result in results {
347            if let Some(payload) = &result.payload
348                && let Ok(doc) = serde_json::from_slice::<T>(payload)
349            {
350                docs.push(doc);
351            }
352        }
353
354        Ok(docs)
355    }
356
357    /// Start a semantic search.
358    pub fn search(&self, query: &str) -> Search<T> {
359        Search::query(self.clone(), query.to_string())
360    }
361
362    /// Find documents similar to an existing document.
363    pub fn similar(&self, id: &str) -> Search<T> {
364        Search::similar_id(self.clone(), id.to_string())
365    }
366
367    /// Find documents similar to a raw vector.
368    pub fn similar_vec(&self, vec: &[f32]) -> Search<T> {
369        Search::similar_vec(self.clone(), vec.to_vec())
370    }
371
372    /// Filter documents using SQL WHERE clause (no vector search).
373    pub fn filter(&self, sql_where: &str) -> Filter<T> {
374        Filter::new(self.clone(), sql_where.to_string())
375    }
376
377    /// Watch for changes to this collection via CDC.
378    pub async fn watch(&self) -> Result<WatchStream<T>> {
379        WatchStream::new(&self.spire.inner.stream_addr, &self.table_name()).await
380    }
381}