1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use crate::database::Database;
5use crate::document::{Document, SearchParams};
6use crate::error::{Result, VectorDBError};
7use crate::filter::Filter;
8
9#[derive(Debug, Clone)]
10pub struct Collection {
11 database: Database,
12 name: String,
13}
14
15#[derive(Debug, Serialize)]
16struct UpsertRequest {
17 database: String,
18 collection: String,
19 documents: Vec<Document>,
20 #[serde(rename = "buildIndex", skip_serializing_if = "Option::is_none")]
21 build_index: Option<bool>,
22}
23
24#[derive(Debug, Serialize)]
25struct QueryRequest {
26 database: String,
27 collection: String,
28 #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
29 document_ids: Option<Vec<String>>,
30 #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
31 retrieve_vector: Option<bool>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 limit: Option<u32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 offset: Option<u32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 filter: Option<String>,
38 #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
39 output_fields: Option<Vec<String>>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 sort: Option<Value>,
42}
43
44#[derive(Debug, Serialize)]
45struct SearchRequest {
46 database: String,
47 collection: String,
48 vectors: Vec<Vec<f64>>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 filter: Option<String>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 params: Option<SearchParams>,
53 #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
54 retrieve_vector: Option<bool>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 limit: Option<u32>,
57 #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
58 output_fields: Option<Vec<String>>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 radius: Option<f64>,
61}
62
63#[derive(Debug, Serialize)]
64struct SearchByIdRequest {
65 database: String,
66 collection: String,
67 #[serde(rename = "documentIds")]
68 document_ids: Vec<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 filter: Option<String>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 params: Option<SearchParams>,
73 #[serde(rename = "retrieveVector", skip_serializing_if = "Option::is_none")]
74 retrieve_vector: Option<bool>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 limit: Option<u32>,
77 #[serde(rename = "outputFields", skip_serializing_if = "Option::is_none")]
78 output_fields: Option<Vec<String>>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 radius: Option<f64>,
81}
82
83#[derive(Debug, Serialize)]
84struct UpdateRequest {
85 database: String,
86 collection: String,
87 data: Document,
88 #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
89 document_ids: Option<Vec<String>>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 filter: Option<String>,
92}
93
94#[derive(Debug, Serialize)]
95struct DeleteRequest {
96 database: String,
97 collection: String,
98 #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
99 document_ids: Option<Vec<String>>,
100 #[serde(skip_serializing_if = "Option::is_none")]
101 filter: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
103 limit: Option<u32>,
104}
105
106#[derive(Debug, Serialize)]
107struct CountRequest {
108 database: String,
109 collection: String,
110 #[serde(skip_serializing_if = "Option::is_none")]
111 filter: Option<String>,
112}
113
114#[derive(Debug, Deserialize)]
115struct ApiResponse<T> {
116 code: i32,
117 msg: String,
118 #[serde(default)]
119 data: Option<T>,
120}
121
122impl Collection {
123 pub fn new(database: Database, name: String) -> Self {
124 Self { database, name }
125 }
126
127 pub fn name(&self) -> &str {
128 &self.name
129 }
130
131 pub fn database(&self) -> &Database {
132 &self.database
133 }
134
135 pub async fn upsert(
137 &self,
138 documents: Vec<Document>,
139 _timeout: Option<u64>,
140 build_index: bool,
141 ) -> Result<Value> {
142 let request = UpsertRequest {
143 database: self.database.name().to_string(),
144 collection: self.name.clone(),
145 documents,
146 build_index: Some(build_index),
147 };
148
149 let response = self
150 .database
151 .client()
152 .post("/document/upsert")
153 .await?
154 .json(&request)
155 .send()
156 .await?;
157
158 let api_response: ApiResponse<Value> =
159 self.database.client().handle_response(response).await?;
160
161 if api_response.code != 0 {
162 return Err(VectorDBError::server_error(
163 api_response.code,
164 api_response.msg,
165 ));
166 }
167
168 Ok(api_response.data.unwrap_or(Value::Null))
169 }
170
171 pub async fn query(
173 &self,
174 document_ids: Option<Vec<String>>,
175 retrieve_vector: bool,
176 limit: Option<u32>,
177 offset: Option<u32>,
178 filter: Option<Filter>,
179 output_fields: Option<Vec<String>>,
180 sort: Option<Value>,
181 ) -> Result<Vec<Document>> {
182 let request = QueryRequest {
183 database: self.database.name().to_string(),
184 collection: self.name.clone(),
185 document_ids,
186 retrieve_vector: Some(retrieve_vector),
187 limit,
188 offset,
189 filter: filter.map(|f| f.condition().to_string()),
190 output_fields,
191 sort,
192 };
193
194 let response = self
195 .database
196 .client()
197 .post("/document/query")
198 .await?
199 .json(&request)
200 .send()
201 .await?;
202
203 let api_response: ApiResponse<Vec<Value>> =
204 self.database.client().handle_response(response).await?;
205
206 if api_response.code != 0 {
207 return Err(VectorDBError::server_error(
208 api_response.code,
209 api_response.msg,
210 ));
211 }
212
213 let documents_data = api_response.data.unwrap_or_default();
214 let documents: Result<Vec<Document>> = documents_data
215 .into_iter()
216 .map(|v| serde_json::from_value(v).map_err(Into::into))
217 .collect();
218
219 documents
220 }
221
222 pub async fn search(
224 &self,
225 vectors: Vec<Vec<f64>>,
226 filter: Option<Filter>,
227 params: Option<SearchParams>,
228 retrieve_vector: bool,
229 limit: u32,
230 output_fields: Option<Vec<String>>,
231 _timeout: Option<u64>,
232 radius: Option<f64>,
233 ) -> Result<Vec<Vec<Document>>> {
234 let request = SearchRequest {
235 database: self.database.name().to_string(),
236 collection: self.name.clone(),
237 vectors,
238 filter: filter.map(|f| f.condition().to_string()),
239 params,
240 retrieve_vector: Some(retrieve_vector),
241 limit: Some(limit),
242 output_fields,
243 radius,
244 };
245
246 let response = self
247 .database
248 .client()
249 .post("/document/search")
250 .await?
251 .json(&request)
252 .send()
253 .await?;
254
255 let api_response: ApiResponse<Vec<Vec<Value>>> =
256 self.database.client().handle_response(response).await?;
257
258 if api_response.code != 0 {
259 return Err(VectorDBError::server_error(
260 api_response.code,
261 api_response.msg,
262 ));
263 }
264
265 let results_data = api_response.data.unwrap_or_default();
266 let results: Result<Vec<Vec<Document>>> = results_data
267 .into_iter()
268 .map(|batch| {
269 batch
270 .into_iter()
271 .map(|v| serde_json::from_value(v).map_err(Into::into))
272 .collect()
273 })
274 .collect();
275
276 results
277 }
278
279 pub async fn search_by_id(
281 &self,
282 document_ids: Vec<String>,
283 filter: Option<Filter>,
284 params: Option<SearchParams>,
285 retrieve_vector: bool,
286 limit: u32,
287 output_fields: Option<Vec<String>>,
288 _timeout: Option<u64>,
289 radius: Option<f64>,
290 ) -> Result<Vec<Vec<Document>>> {
291 let request = SearchByIdRequest {
292 database: self.database.name().to_string(),
293 collection: self.name.clone(),
294 document_ids,
295 filter: filter.map(|f| f.condition().to_string()),
296 params,
297 retrieve_vector: Some(retrieve_vector),
298 limit: Some(limit),
299 output_fields,
300 radius,
301 };
302
303 let response = self
304 .database
305 .client()
306 .post("/document/searchById")
307 .await?
308 .json(&request)
309 .send()
310 .await?;
311
312 let api_response: ApiResponse<Vec<Vec<Value>>> =
313 self.database.client().handle_response(response).await?;
314
315 if api_response.code != 0 {
316 return Err(VectorDBError::server_error(
317 api_response.code,
318 api_response.msg,
319 ));
320 }
321
322 let results_data = api_response.data.unwrap_or_default();
323 let results: Result<Vec<Vec<Document>>> = results_data
324 .into_iter()
325 .map(|batch| {
326 batch
327 .into_iter()
328 .map(|v| serde_json::from_value(v).map_err(Into::into))
329 .collect()
330 })
331 .collect();
332
333 results
334 }
335
336 pub async fn update(
338 &self,
339 data: Document,
340 document_ids: Option<Vec<String>>,
341 filter: Option<Filter>,
342 ) -> Result<Value> {
343 let request = UpdateRequest {
344 database: self.database.name().to_string(),
345 collection: self.name.clone(),
346 data,
347 document_ids,
348 filter: filter.map(|f| f.condition().to_string()),
349 };
350
351 let response = self
352 .database
353 .client()
354 .post("/document/update")
355 .await?
356 .json(&request)
357 .send()
358 .await?;
359
360 let api_response: ApiResponse<Value> =
361 self.database.client().handle_response(response).await?;
362
363 if api_response.code != 0 {
364 return Err(VectorDBError::server_error(
365 api_response.code,
366 api_response.msg,
367 ));
368 }
369
370 Ok(api_response.data.unwrap_or(Value::Null))
371 }
372
373 pub async fn delete(
375 &self,
376 document_ids: Option<Vec<String>>,
377 filter: Option<Filter>,
378 limit: Option<u32>,
379 ) -> Result<Value> {
380 let request = DeleteRequest {
381 database: self.database.name().to_string(),
382 collection: self.name.clone(),
383 document_ids,
384 filter: filter.map(|f| f.condition().to_string()),
385 limit,
386 };
387
388 let response = self
389 .database
390 .client()
391 .post("/document/delete")
392 .await?
393 .json(&request)
394 .send()
395 .await?;
396
397 let api_response: ApiResponse<Value> =
398 self.database.client().handle_response(response).await?;
399
400 if api_response.code != 0 {
401 return Err(VectorDBError::server_error(
402 api_response.code,
403 api_response.msg,
404 ));
405 }
406
407 Ok(api_response.data.unwrap_or(Value::Null))
408 }
409
410 pub async fn count(&self, filter: Option<Filter>) -> Result<u64> {
412 let request = CountRequest {
413 database: self.database.name().to_string(),
414 collection: self.name.clone(),
415 filter: filter.map(|f| f.condition().to_string()),
416 };
417
418 let response = self
419 .database
420 .client()
421 .post("/document/count")
422 .await?
423 .json(&request)
424 .send()
425 .await?;
426
427 let api_response: ApiResponse<Value> =
428 self.database.client().handle_response(response).await?;
429
430 if api_response.code != 0 {
431 return Err(VectorDBError::server_error(
432 api_response.code,
433 api_response.msg,
434 ));
435 }
436
437 let count_data = api_response.data.unwrap_or(Value::Number(0.into()));
438 let count: u64 = serde_json::from_value(count_data)?;
439 Ok(count)
440 }
441
442 pub async fn rebuild_index(&self) -> Result<Value> {
444 let path = format!(
445 "/index/rebuild?database={}&collection={}",
446 self.database.name(),
447 self.name
448 );
449 let response = self.database.client().post(&path).await?.send().await?;
450
451 let api_response: ApiResponse<Value> =
452 self.database.client().handle_response(response).await?;
453
454 if api_response.code != 0 {
455 return Err(VectorDBError::server_error(
456 api_response.code,
457 api_response.msg,
458 ));
459 }
460
461 Ok(api_response.data.unwrap_or(Value::Null))
462 }
463}
464