tcvectordb_rust/
collection.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use crate::database::Database;
5use crate::document::{Document, SearchParams};
6use crate::error::{Result, VectorDBError};
7use crate::filter::Filter;
8
9#[derive(Debug, Clone)]
10pub struct Collection {
11    database: Database,
12    name: String,
13}
14
15#[derive(Debug, Serialize)]
16struct UpsertRequest {
17    database: String,
18    collection: String,
19    documents: Vec<Document>,
20    #[serde(rename = "buildIndex", skip_serializing_if = "Option::is_none")]
21    build_index: Option<bool>,
22}
23
24#[derive(Debug, Serialize)]
25struct QueryRequest {
26    database: String,
27    collection: String,
28    #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
29    document_ids: Option<Vec<String>>,
30    #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
31    retrieve_vector: Option<bool>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    limit: Option<u32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    offset: Option<u32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    filter: Option<String>,
38    #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
39    output_fields: Option<Vec<String>>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    sort: Option<Value>,
42}
43
44#[derive(Debug, Serialize)]
45struct SearchRequest {
46    database: String,
47    collection: String,
48    vectors: Vec<Vec<f64>>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    filter: Option<String>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    params: Option<SearchParams>,
53    #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
54    retrieve_vector: Option<bool>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    limit: Option<u32>,
57    #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
58    output_fields: Option<Vec<String>>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    radius: Option<f64>,
61}
62
63#[derive(Debug, Serialize)]
64struct SearchByIdRequest {
65    database: String,
66    collection: String,
67    #[serde(rename = "documentIds")]
68    document_ids: Vec<String>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    filter: Option<String>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    params: Option<SearchParams>,
73    #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
74    retrieve_vector: Option<bool>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    limit: Option<u32>,
77    #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
78    output_fields: Option<Vec<String>>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    radius: Option<f64>,
81}
82
83#[derive(Debug, Serialize)]
84struct UpdateRequest {
85    database: String,
86    collection: String,
87    data: Document,
88    #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
89    document_ids: Option<Vec<String>>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    filter: Option<String>,
92}
93
94#[derive(Debug, Serialize)]
95struct DeleteRequest {
96    database: String,
97    collection: String,
98    #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
99    document_ids: Option<Vec<String>>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    filter: Option<String>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    limit: Option<u32>,
104}
105
106#[derive(Debug, Serialize)]
107struct CountRequest {
108    database: String,
109    collection: String,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    filter: Option<String>,
112}
113
114#[derive(Debug, Deserialize)]
115struct ApiResponse<T> {
116    code: i32,
117    msg: String,
118    #[serde(default)]
119    data: Option<T>,
120}
121
122impl Collection {
123    pub fn new(database: Database, name: String) -> Self {
124        Self { database, name }
125    }
126
127    pub fn name(&self) -> &str {
128        &self.name
129    }
130
131    pub fn database(&self) -> &Database {
132        &self.database
133    }
134
135    /// Upsert documents into the collection
136    pub async fn upsert(
137        &self,
138        documents: Vec<Document>,
139        _timeout: Option<u64>,
140        build_index: bool,
141    ) -> Result<Value> {
142        let request = UpsertRequest {
143            database: self.database.name().to_string(),
144            collection: self.name.clone(),
145            documents,
146            build_index: Some(build_index),
147        };
148
149        let response = self
150            .database
151            .client()
152            .post("/document/upsert")
153            .await?
154            .json(&request)
155            .send()
156            .await?;
157
158        let api_response: ApiResponse<Value> =
159            self.database.client().handle_response(response).await?;
160
161        if api_response.code != 0 {
162            return Err(VectorDBError::server_error(
163                api_response.code,
164                api_response.msg,
165            ));
166        }
167
168        Ok(api_response.data.unwrap_or(Value::Null))
169    }
170
171    /// Query documents by conditions
172    pub async fn query(
173        &self,
174        document_ids: Option<Vec<String>>,
175        retrieve_vector: bool,
176        limit: Option<u32>,
177        offset: Option<u32>,
178        filter: Option<Filter>,
179        output_fields: Option<Vec<String>>,
180        sort: Option<Value>,
181    ) -> Result<Vec<Document>> {
182        let request = QueryRequest {
183            database: self.database.name().to_string(),
184            collection: self.name.clone(),
185            document_ids,
186            retrieve_vector: Some(retrieve_vector),
187            limit,
188            offset,
189            filter: filter.map(|f| f.condition().to_string()),
190            output_fields,
191            sort,
192        };
193
194        let response = self
195            .database
196            .client()
197            .post("/document/query")
198            .await?
199            .json(&request)
200            .send()
201            .await?;
202
203        let api_response: ApiResponse<Vec<Value>> =
204            self.database.client().handle_response(response).await?;
205
206        if api_response.code != 0 {
207            return Err(VectorDBError::server_error(
208                api_response.code,
209                api_response.msg,
210            ));
211        }
212
213        let documents_data = api_response.data.unwrap_or_default();
214        let documents: Result<Vec<Document>> = documents_data
215            .into_iter()
216            .map(|v| serde_json::from_value(v).map_err(Into::into))
217            .collect();
218
219        documents
220    }
221
222    /// Search similar vectors
223    pub async fn search(
224        &self,
225        vectors: Vec<Vec<f64>>,
226        filter: Option<Filter>,
227        params: Option<SearchParams>,
228        retrieve_vector: bool,
229        limit: u32,
230        output_fields: Option<Vec<String>>,
231        _timeout: Option<u64>,
232        radius: Option<f64>,
233    ) -> Result<Vec<Vec<Document>>> {
234        let request = SearchRequest {
235            database: self.database.name().to_string(),
236            collection: self.name.clone(),
237            vectors,
238            filter: filter.map(|f| f.condition().to_string()),
239            params,
240            retrieve_vector: Some(retrieve_vector),
241            limit: Some(limit),
242            output_fields,
243            radius,
244        };
245
246        let response = self
247            .database
248            .client()
249            .post("/document/search")
250            .await?
251            .json(&request)
252            .send()
253            .await?;
254
255        let api_response: ApiResponse<Vec<Vec<Value>>> =
256            self.database.client().handle_response(response).await?;
257
258        if api_response.code != 0 {
259            return Err(VectorDBError::server_error(
260                api_response.code,
261                api_response.msg,
262            ));
263        }
264
265        let results_data = api_response.data.unwrap_or_default();
266        let results: Result<Vec<Vec<Document>>> = results_data
267            .into_iter()
268            .map(|batch| {
269                batch
270                    .into_iter()
271                    .map(|v| serde_json::from_value(v).map_err(Into::into))
272                    .collect()
273            })
274            .collect();
275
276        results
277    }
278
279    /// Search by document IDs
280    pub async fn search_by_id(
281        &self,
282        document_ids: Vec<String>,
283        filter: Option<Filter>,
284        params: Option<SearchParams>,
285        retrieve_vector: bool,
286        limit: u32,
287        output_fields: Option<Vec<String>>,
288        _timeout: Option<u64>,
289        radius: Option<f64>,
290    ) -> Result<Vec<Vec<Document>>> {
291        let request = SearchByIdRequest {
292            database: self.database.name().to_string(),
293            collection: self.name.clone(),
294            document_ids,
295            filter: filter.map(|f| f.condition().to_string()),
296            params,
297            retrieve_vector: Some(retrieve_vector),
298            limit: Some(limit),
299            output_fields,
300            radius,
301        };
302
303        let response = self
304            .database
305            .client()
306            .post("/document/searchById")
307            .await?
308            .json(&request)
309            .send()
310            .await?;
311
312        let api_response: ApiResponse<Vec<Vec<Value>>> =
313            self.database.client().handle_response(response).await?;
314
315        if api_response.code != 0 {
316            return Err(VectorDBError::server_error(
317                api_response.code,
318                api_response.msg,
319            ));
320        }
321
322        let results_data = api_response.data.unwrap_or_default();
323        let results: Result<Vec<Vec<Document>>> = results_data
324            .into_iter()
325            .map(|batch| {
326                batch
327                    .into_iter()
328                    .map(|v| serde_json::from_value(v).map_err(Into::into))
329                    .collect()
330            })
331            .collect();
332
333        results
334    }
335
336    /// Update documents
337    pub async fn update(
338        &self,
339        data: Document,
340        document_ids: Option<Vec<String>>,
341        filter: Option<Filter>,
342    ) -> Result<Value> {
343        let request = UpdateRequest {
344            database: self.database.name().to_string(),
345            collection: self.name.clone(),
346            data,
347            document_ids,
348            filter: filter.map(|f| f.condition().to_string()),
349        };
350
351        let response = self
352            .database
353            .client()
354            .post("/document/update")
355            .await?
356            .json(&request)
357            .send()
358            .await?;
359
360        let api_response: ApiResponse<Value> =
361            self.database.client().handle_response(response).await?;
362
363        if api_response.code != 0 {
364            return Err(VectorDBError::server_error(
365                api_response.code,
366                api_response.msg,
367            ));
368        }
369
370        Ok(api_response.data.unwrap_or(Value::Null))
371    }
372
373    /// Delete documents
374    pub async fn delete(
375        &self,
376        document_ids: Option<Vec<String>>,
377        filter: Option<Filter>,
378        limit: Option<u32>,
379    ) -> Result<Value> {
380        let request = DeleteRequest {
381            database: self.database.name().to_string(),
382            collection: self.name.clone(),
383            document_ids,
384            filter: filter.map(|f| f.condition().to_string()),
385            limit,
386        };
387
388        let response = self
389            .database
390            .client()
391            .post("/document/delete")
392            .await?
393            .json(&request)
394            .send()
395            .await?;
396
397        let api_response: ApiResponse<Value> =
398            self.database.client().handle_response(response).await?;
399
400        if api_response.code != 0 {
401            return Err(VectorDBError::server_error(
402                api_response.code,
403                api_response.msg,
404            ));
405        }
406
407        Ok(api_response.data.unwrap_or(Value::Null))
408    }
409
410    /// Count documents
411    pub async fn count(&self, filter: Option<Filter>) -> Result<u64> {
412        let request = CountRequest {
413            database: self.database.name().to_string(),
414            collection: self.name.clone(),
415            filter: filter.map(|f| f.condition().to_string()),
416        };
417
418        let response = self
419            .database
420            .client()
421            .post("/document/count")
422            .await?
423            .json(&request)
424            .send()
425            .await?;
426
427        let api_response: ApiResponse<Value> =
428            self.database.client().handle_response(response).await?;
429
430        if api_response.code != 0 {
431            return Err(VectorDBError::server_error(
432                api_response.code,
433                api_response.msg,
434            ));
435        }
436
437        let count_data = api_response.data.unwrap_or(Value::Number(0.into()));
438        let count: u64 = serde_json::from_value(count_data)?;
439        Ok(count)
440    }
441
442    /// Rebuild index
443    pub async fn rebuild_index(&self) -> Result<Value> {
444        let path = format!(
445            "/index/rebuild?database={}&collection={}",
446            self.database.name(),
447            self.name
448        );
449        let response = self.database.client().post(&path).await?.send().await?;
450
451        let api_response: ApiResponse<Value> =
452            self.database.client().handle_response(response).await?;
453
454        if api_response.code != 0 {
455            return Err(VectorDBError::server_error(
456                api_response.code,
457                api_response.msg,
458            ));
459        }
460
461        Ok(api_response.data.unwrap_or(Value::Null))
462    }
463}
464