1use std::marker::PhantomData;
4
5use spire_proto::spiredb::cluster::{
6 ColumnDef, ColumnType, CreateTableRequest, schema_service_client::SchemaServiceClient,
7};
8use spiresql::vector::types::{Algorithm, IndexParams};
9
10use crate::client::Spire;
11use crate::document::Doc;
12use crate::error::{Error, Result};
13use crate::search::{Filter, Search};
14use crate::watch::WatchStream;
15
16fn doc_cache_key(collection: &str, id: &str) -> u64 {
17 ahash::RandomState::with_seeds(0, 0, 0, 0).hash_one((collection, id))
18}
19
20pub struct Collection<T: Doc> {
25 pub(crate) spire: Spire,
26 pub(crate) name: String,
27 pub(crate) _phantom: PhantomData<T>,
28}
29
30impl<T: Doc> Clone for Collection<T> {
32 fn clone(&self) -> Self {
33 Self {
34 spire: self.spire.clone(),
35 name: self.name.clone(),
36 _phantom: PhantomData,
37 }
38 }
39}
40
41impl<T: Doc> Collection<T> {
42 pub(crate) fn new(spire: Spire, name: String) -> Self {
43 Self {
44 spire,
45 name,
46 _phantom: PhantomData,
47 }
48 }
49
50 pub fn table_name(&self) -> String {
52 format!("_ai_{}", self.name)
53 }
54
55 pub fn index_name(&self) -> String {
57 format!("_ai_{}_vec", self.name)
58 }
59
60 pub async fn ensure(&self) -> Result<()> {
64 let table = self.table_name();
65 let index = self.index_name();
66 let dims = self.spire.inner.embedder.dimensions() as u32;
67
68 let mut schema_client = SchemaServiceClient::new(self.spire.inner.pd_channel.clone());
70
71 let columns = vec![
72 ColumnDef {
73 name: "id".to_string(),
74 r#type: ColumnType::TypeString.into(),
75 nullable: false,
76 ..Default::default()
77 },
78 ColumnDef {
79 name: "doc".to_string(),
80 r#type: ColumnType::TypeBytes.into(),
81 nullable: false,
82 ..Default::default()
83 },
84 ColumnDef {
85 name: "embed_text".to_string(),
86 r#type: ColumnType::TypeString.into(),
87 nullable: true,
88 ..Default::default()
89 },
90 ColumnDef {
91 name: "created_at".to_string(),
92 r#type: ColumnType::TypeTimestamp.into(),
93 nullable: true,
94 ..Default::default()
95 },
96 ];
97
98 let request = CreateTableRequest {
99 name: table.clone(),
100 columns,
101 primary_key: vec!["id".to_string()],
102 };
103
104 match schema_client.create_table(request).await {
105 Ok(_) => {}
106 Err(status) if status.code() == tonic::Code::AlreadyExists => {
107 }
109 Err(e) => return Err(Error::Grpc(e)),
110 }
111
112 if dims > 0 {
114 let params = IndexParams::new(&index, &table, "embedding")
115 .algorithm(Algorithm::Manode)
116 .dimensions(dims);
117
118 match self.spire.inner.vector.create_index(params).await {
119 Ok(_) => {}
120 Err(spiresql::vector::error::VectorError::IndexAlreadyExists(_)) => {}
121 Err(e) => return Err(Error::Vector(e)),
122 }
123 }
124
125 Ok(())
126 }
127
128 pub async fn insert(&self, doc: &T) -> Result<String> {
130 let id = doc.id().to_string();
131 let doc_json = serde_json::to_vec(doc)?;
132 let embed_text = doc.embed_text();
133
134 let cache_key = doc_cache_key(&self.name, &id);
136 self.spire
137 .inner
138 .doc_cache
139 .insert(cache_key, doc_json.clone());
140
141 let embedding = if !embed_text.is_empty() {
143 Some(self.spire.inner.embedder.embed(&embed_text).await?)
144 } else {
145 None
146 };
147
148 if let Some(ref vec) = embedding {
150 self.vector_insert(id.as_bytes(), vec, &doc_json).await?;
151 }
152
153 Ok(id)
154 }
155
156 pub async fn insert_many(&self, docs: &[T]) -> Result<Vec<String>> {
158 if docs.is_empty() {
159 return Ok(Vec::new());
160 }
161
162 let ids: Vec<String> = docs.iter().map(|d| d.id().to_string()).collect();
163 let texts: Vec<String> = docs.iter().map(|d| d.embed_text()).collect();
164
165 let non_empty: Vec<String> = texts.iter().filter(|t| !t.is_empty()).cloned().collect();
167
168 let embeddings = if !non_empty.is_empty() {
169 self.spire.inner.embedder.embed_batch(&non_empty).await?
170 } else {
171 Vec::new()
172 };
173
174 let mut embed_iter = embeddings.into_iter();
176
177 for (i, doc) in docs.iter().enumerate() {
178 let doc_json = serde_json::to_vec(doc)?;
179
180 let cache_key = doc_cache_key(&self.name, &ids[i]);
182 self.spire
183 .inner
184 .doc_cache
185 .insert(cache_key, doc_json.clone());
186
187 if !texts[i].is_empty()
188 && let Some(vec) = embed_iter.next()
189 {
190 self.vector_insert(ids[i].as_bytes(), &vec, &doc_json)
191 .await?;
192 }
193 }
194
195 Ok(ids)
196 }
197
198 async fn vector_insert(&self, doc_id: &[u8], vec: &[f32], payload: &[u8]) -> Result<u64> {
200 let index_name = self.index_name();
201 match self
202 .spire
203 .inner
204 .vector
205 .insert(&index_name, doc_id, vec, Some(payload))
206 .await
207 {
208 Ok(id) => Ok(id),
209 Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
210 self.ensure().await?;
212 Ok(self
213 .spire
214 .inner
215 .vector
216 .insert(&index_name, doc_id, vec, Some(payload))
217 .await?)
218 }
219 Err(e) => Err(Error::Vector(e)),
220 }
221 }
222
223 pub async fn upsert(&self, doc: &T) -> Result<String> {
225 let id = doc.id().to_string();
226
227 let _ = self
229 .spire
230 .inner
231 .vector
232 .delete(&self.index_name(), id.as_bytes())
233 .await;
234
235 self.insert(doc).await
237 }
238
239 pub async fn delete(&self, id: &str) -> Result<bool> {
241 let cache_key = doc_cache_key(&self.name, id);
243 self.spire.inner.doc_cache.remove(&cache_key);
244
245 match self
246 .spire
247 .inner
248 .vector
249 .delete(&self.index_name(), id.as_bytes())
250 .await
251 {
252 Ok(_) => Ok(true),
253 Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => Ok(false),
254 Err(e) => Err(Error::Vector(e)),
255 }
256 }
257
258 pub async fn get(&self, id: &str) -> Result<Option<T>> {
263 let cache_key = doc_cache_key(&self.name, id);
265 if let Some(bytes) = self.spire.inner.doc_cache.get(&cache_key)
266 && let Ok(doc) = serde_json::from_slice::<T>(&bytes)
267 {
268 return Ok(Some(doc));
269 }
270
271 match self
273 .spire
274 .inner
275 .vector
276 .get_payload(&self.index_name(), id.as_bytes())
277 .await
278 {
279 Ok(Some(payload)) => {
280 self.spire
282 .inner
283 .doc_cache
284 .insert(cache_key, payload.clone());
285 match serde_json::from_slice::<T>(&payload) {
286 Ok(doc) => Ok(Some(doc)),
287 Err(_) => Ok(None),
288 }
289 }
290 Ok(None) => Ok(None),
291 Err(_) => Ok(None),
292 }
293 }
294
295 pub async fn get_many(&self, ids: &[&str]) -> Result<Vec<T>> {
297 let mut docs = Vec::new();
298 for id in ids {
299 if let Some(doc) = self.get(id).await? {
300 docs.push(doc);
301 }
302 }
303 Ok(docs)
304 }
305
306 pub async fn all(&self) -> Result<Vec<T>> {
311 let dims = self.spire.inner.embedder.dimensions();
312 if dims == 0 {
313 return Ok(Vec::new());
314 }
315
316 let val = 1.0 / (dims as f32).sqrt();
319 let query_vec = vec![val; dims];
320
321 let index_name = self.index_name();
322 let opts = spiresql::vector::types::SearchOptions::default()
323 .k(10_000)
324 .with_payload();
325
326 let results = match self
327 .spire
328 .inner
329 .vector
330 .search(&index_name, &query_vec, opts.clone())
331 .await
332 {
333 Ok(r) => r,
334 Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
335 self.ensure().await?;
336 self.spire
337 .inner
338 .vector
339 .search(&index_name, &query_vec, opts)
340 .await?
341 }
342 Err(e) => return Err(Error::Vector(e)),
343 };
344
345 let mut docs = Vec::with_capacity(results.len());
346 for result in results {
347 if let Some(payload) = &result.payload
348 && let Ok(doc) = serde_json::from_slice::<T>(payload)
349 {
350 docs.push(doc);
351 }
352 }
353
354 Ok(docs)
355 }
356
357 pub fn search(&self, query: &str) -> Search<T> {
359 Search::query(self.clone(), query.to_string())
360 }
361
362 pub fn similar(&self, id: &str) -> Search<T> {
364 Search::similar_id(self.clone(), id.to_string())
365 }
366
367 pub fn similar_vec(&self, vec: &[f32]) -> Search<T> {
369 Search::similar_vec(self.clone(), vec.to_vec())
370 }
371
372 pub fn filter(&self, sql_where: &str) -> Filter<T> {
374 Filter::new(self.clone(), sql_where.to_string())
375 }
376
377 pub async fn watch(&self) -> Result<WatchStream<T>> {
379 WatchStream::new(&self.spire.inner.stream_addr, &self.table_name()).await
380 }
381}