tcvectordb_rust/
document.rs

1use serde::{Deserialize, Serialize};
2use serde_json::{Map, Value};
3use std::collections::HashMap;
4use crate::index::SparseVector;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct HNSWSearchParams {
8    pub ef: u32,
9}
10
11impl HNSWSearchParams {
12    pub fn new(ef: u32) -> Self {
13        Self { ef }
14    }
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SearchParams {
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub ef: Option<u32>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub nprobe: Option<u32>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub radius: Option<f64>,
25}
26
27impl SearchParams {
28    pub fn new() -> Self {
29        Self {
30            ef: None,
31            nprobe: None,
32            radius: None,
33        }
34    }
35
36    pub fn with_ef(mut self, ef: u32) -> Self {
37        self.ef = Some(ef);
38        self
39    }
40
41    pub fn with_nprobe(mut self, nprobe: u32) -> Self {
42        self.nprobe = Some(nprobe);
43        self
44    }
45
46    pub fn with_radius(mut self, radius: f64) -> Self {
47        self.radius = Some(radius);
48        self
49    }
50}
51
52impl Default for SearchParams {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AnnSearch {
60    #[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
61    pub field_name: Option<String>,
62    #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
63    pub document_ids: Option<Vec<String>>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub data: Option<Vec<Vec<f64>>>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub params: Option<SearchParams>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub limit: Option<u32>,
70}
71
72impl AnnSearch {
73    pub fn new() -> Self {
74        Self {
75            field_name: Some("vector".to_string()),
76            document_ids: None,
77            data: None,
78            params: None,
79            limit: None,
80        }
81    }
82
83    pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
84        self.field_name = Some(field_name.into());
85        self
86    }
87
88    pub fn with_document_ids(mut self, document_ids: Vec<String>) -> Self {
89        self.document_ids = Some(document_ids);
90        self
91    }
92
93    pub fn with_data(mut self, data: Vec<Vec<f64>>) -> Self {
94        self.data = Some(data);
95        self
96    }
97
98    pub fn with_params(mut self, params: SearchParams) -> Self {
99        self.params = Some(params);
100        self
101    }
102
103    pub fn with_limit(mut self, limit: u32) -> Self {
104        self.limit = Some(limit);
105        self
106    }
107}
108
109impl Default for AnnSearch {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct KeywordSearch {
117    #[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
118    pub field_name: Option<String>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub data: Option<Vec<SparseVector>>,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub limit: Option<u32>,
123    #[serde(rename = "terminateAfter", skip_serializing_if = "Option::is_none")]
124    pub terminate_after: Option<u32>,
125    #[serde(rename = "cutoffFrequency", skip_serializing_if = "Option::is_none")]
126    pub cutoff_frequency: Option<f64>,
127}
128
129impl KeywordSearch {
130    pub fn new() -> Self {
131        Self {
132            field_name: Some("sparse_vector".to_string()),
133            data: None,
134            limit: None,
135            terminate_after: None,
136            cutoff_frequency: None,
137        }
138    }
139
140    pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
141        self.field_name = Some(field_name.into());
142        self
143    }
144
145    pub fn with_data(mut self, data: Vec<SparseVector>) -> Self {
146        self.data = Some(data);
147        self
148    }
149
150    pub fn with_limit(mut self, limit: u32) -> Self {
151        self.limit = Some(limit);
152        self
153    }
154
155    pub fn with_terminate_after(mut self, terminate_after: u32) -> Self {
156        self.terminate_after = Some(terminate_after);
157        self
158    }
159
160    pub fn with_cutoff_frequency(mut self, cutoff_frequency: f64) -> Self {
161        self.cutoff_frequency = Some(cutoff_frequency);
162        self
163    }
164}
165
166impl Default for KeywordSearch {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173#[serde(tag = "method")]
174pub enum Rerank {
175    #[serde(rename = "weighted")]
176    Weighted {
177        #[serde(rename = "fieldList", skip_serializing_if = "Option::is_none")]
178        field_list: Option<Vec<String>>,
179        #[serde(skip_serializing_if = "Option::is_none")]
180        weight: Option<Vec<f64>>,
181    },
182    #[serde(rename = "rrf")]
183    RRF {
184        #[serde(skip_serializing_if = "Option::is_none")]
185        k: Option<u32>,
186    },
187}
188
189impl Rerank {
190    pub fn weighted(field_list: Vec<String>, weight: Vec<f64>) -> Self {
191        Self::Weighted {
192            field_list: Some(field_list),
193            weight: Some(Self::normalize_weights(weight)),
194        }
195    }
196
197    pub fn rrf(k: u32) -> Self {
198        Self::RRF { k: Some(k) }
199    }
200
201    fn normalize_weights(weights: Vec<f64>) -> Vec<f64> {
202        let total: f64 = weights.iter().sum();
203        if total == 0.0 {
204            return weights;
205        }
206
207        let all_zero = weights.iter().all(|&w| w == 0.0);
208        if all_zero {
209            return weights;
210        }
211
212        let has_negative = weights.iter().any(|&w| w < 0.0);
213        if has_negative {
214            return weights;
215        }
216
217        weights.iter().map(|&w| w / total).collect()
218    }
219}
220
221#[derive(Debug, Clone, Default)]
222pub struct Document {
223    data: Map<String, Value>,
224    score: Option<f64>,
225}
226
227impl Document {
228    pub fn new() -> Self {
229        Self::default()
230    }
231
232    pub fn with_id(mut self, id: impl Into<String>) -> Self {
233        self.data.insert("id".to_string(), Value::String(id.into()));
234        self
235    }
236
237    pub fn with_vector(mut self, vector: Vec<f64>) -> Self {
238        self.data.insert("vector".to_string(), serde_json::to_value(vector).unwrap());
239        self
240    }
241
242    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
243        self.data.insert(key.into(), value.into());
244        self
245    }
246
247    pub fn with_score(mut self, score: f64) -> Self {
248        self.score = Some(score);
249        self
250    }
251
252    pub fn get(&self, key: &str) -> Option<&Value> {
253        self.data.get(key)
254    }
255
256    pub fn get_id(&self) -> Option<&str> {
257        self.data.get("id")?.as_str()
258    }
259
260    pub fn get_vector(&self) -> Option<Vec<f64>> {
261        let vector_value = self.data.get("vector")?;
262        serde_json::from_value(vector_value.clone()).ok()
263    }
264
265    pub fn get_score(&self) -> Option<f64> {
266        self.score
267    }
268
269    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<Value>) {
270        self.data.insert(key.into(), value.into());
271    }
272
273    pub fn remove(&mut self, key: &str) -> Option<Value> {
274        self.data.remove(key)
275    }
276
277    pub fn keys(&self) -> impl Iterator<Item = &String> {
278        self.data.keys()
279    }
280
281    pub fn values(&self) -> impl Iterator<Item = &Value> {
282        self.data.values()
283    }
284
285    pub fn iter(&self) -> impl Iterator<Item = (&String, &Value)> {
286        self.data.iter()
287    }
288
289    pub fn is_empty(&self) -> bool {
290        self.data.is_empty()
291    }
292
293    pub fn len(&self) -> usize {
294        self.data.len()
295    }
296}
297
298impl Serialize for Document {
299    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
300    where
301        S: serde::Serializer,
302    {
303        self.data.serialize(serializer)
304    }
305}
306
307impl<'de> Deserialize<'de> for Document {
308    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
309    where
310        D: serde::Deserializer<'de>,
311    {
312        let mut data: Map<String, Value> = Map::deserialize(deserializer)?;
313        let score = data.remove("score").and_then(|v| v.as_f64());
314        
315        Ok(Self { data, score })
316    }
317}
318
319impl From<HashMap<String, Value>> for Document {
320    fn from(map: HashMap<String, Value>) -> Self {
321        let mut data = Map::new();
322        let mut score = None;
323        
324        for (k, v) in map {
325            if k == "score" {
326                score = v.as_f64();
327            } else {
328                data.insert(k, v);
329            }
330        }
331        
332        Self { data, score }
333    }
334}
335
336impl From<Map<String, Value>> for Document {
337    fn from(mut data: Map<String, Value>) -> Self {
338        let score = data.remove("score").and_then(|v| v.as_f64());
339        Self { data, score }
340    }
341}