1use serde::{Deserialize, Serialize};
2use serde_json::{Map, Value};
3use std::collections::HashMap;
4use crate::index::SparseVector;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct HNSWSearchParams {
8 pub ef: u32,
9}
10
11impl HNSWSearchParams {
12 pub fn new(ef: u32) -> Self {
13 Self { ef }
14 }
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SearchParams {
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub ef: Option<u32>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub nprobe: Option<u32>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub radius: Option<f64>,
25}
26
27impl SearchParams {
28 pub fn new() -> Self {
29 Self {
30 ef: None,
31 nprobe: None,
32 radius: None,
33 }
34 }
35
36 pub fn with_ef(mut self, ef: u32) -> Self {
37 self.ef = Some(ef);
38 self
39 }
40
41 pub fn with_nprobe(mut self, nprobe: u32) -> Self {
42 self.nprobe = Some(nprobe);
43 self
44 }
45
46 pub fn with_radius(mut self, radius: f64) -> Self {
47 self.radius = Some(radius);
48 self
49 }
50}
51
52impl Default for SearchParams {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AnnSearch {
60 #[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
61 pub field_name: Option<String>,
62 #[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
63 pub document_ids: Option<Vec<String>>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 pub data: Option<Vec<Vec<f64>>>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub params: Option<SearchParams>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 pub limit: Option<u32>,
70}
71
72impl AnnSearch {
73 pub fn new() -> Self {
74 Self {
75 field_name: Some("vector".to_string()),
76 document_ids: None,
77 data: None,
78 params: None,
79 limit: None,
80 }
81 }
82
83 pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
84 self.field_name = Some(field_name.into());
85 self
86 }
87
88 pub fn with_document_ids(mut self, document_ids: Vec<String>) -> Self {
89 self.document_ids = Some(document_ids);
90 self
91 }
92
93 pub fn with_data(mut self, data: Vec<Vec<f64>>) -> Self {
94 self.data = Some(data);
95 self
96 }
97
98 pub fn with_params(mut self, params: SearchParams) -> Self {
99 self.params = Some(params);
100 self
101 }
102
103 pub fn with_limit(mut self, limit: u32) -> Self {
104 self.limit = Some(limit);
105 self
106 }
107}
108
109impl Default for AnnSearch {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct KeywordSearch {
117 #[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
118 pub field_name: Option<String>,
119 #[serde(skip_serializing_if = "Option::is_none")]
120 pub data: Option<Vec<SparseVector>>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 pub limit: Option<u32>,
123 #[serde(rename = "terminateAfter", skip_serializing_if = "Option::is_none")]
124 pub terminate_after: Option<u32>,
125 #[serde(rename = "cutoffFrequency", skip_serializing_if = "Option::is_none")]
126 pub cutoff_frequency: Option<f64>,
127}
128
129impl KeywordSearch {
130 pub fn new() -> Self {
131 Self {
132 field_name: Some("sparse_vector".to_string()),
133 data: None,
134 limit: None,
135 terminate_after: None,
136 cutoff_frequency: None,
137 }
138 }
139
140 pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
141 self.field_name = Some(field_name.into());
142 self
143 }
144
145 pub fn with_data(mut self, data: Vec<SparseVector>) -> Self {
146 self.data = Some(data);
147 self
148 }
149
150 pub fn with_limit(mut self, limit: u32) -> Self {
151 self.limit = Some(limit);
152 self
153 }
154
155 pub fn with_terminate_after(mut self, terminate_after: u32) -> Self {
156 self.terminate_after = Some(terminate_after);
157 self
158 }
159
160 pub fn with_cutoff_frequency(mut self, cutoff_frequency: f64) -> Self {
161 self.cutoff_frequency = Some(cutoff_frequency);
162 self
163 }
164}
165
166impl Default for KeywordSearch {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173#[serde(tag = "method")]
174pub enum Rerank {
175 #[serde(rename = "weighted")]
176 Weighted {
177 #[serde(rename = "fieldList", skip_serializing_if = "Option::is_none")]
178 field_list: Option<Vec<String>>,
179 #[serde(skip_serializing_if = "Option::is_none")]
180 weight: Option<Vec<f64>>,
181 },
182 #[serde(rename = "rrf")]
183 RRF {
184 #[serde(skip_serializing_if = "Option::is_none")]
185 k: Option<u32>,
186 },
187}
188
189impl Rerank {
190 pub fn weighted(field_list: Vec<String>, weight: Vec<f64>) -> Self {
191 Self::Weighted {
192 field_list: Some(field_list),
193 weight: Some(Self::normalize_weights(weight)),
194 }
195 }
196
197 pub fn rrf(k: u32) -> Self {
198 Self::RRF { k: Some(k) }
199 }
200
201 fn normalize_weights(weights: Vec<f64>) -> Vec<f64> {
202 let total: f64 = weights.iter().sum();
203 if total == 0.0 {
204 return weights;
205 }
206
207 let all_zero = weights.iter().all(|&w| w == 0.0);
208 if all_zero {
209 return weights;
210 }
211
212 let has_negative = weights.iter().any(|&w| w < 0.0);
213 if has_negative {
214 return weights;
215 }
216
217 weights.iter().map(|&w| w / total).collect()
218 }
219}
220
221#[derive(Debug, Clone, Default)]
222pub struct Document {
223 data: Map<String, Value>,
224 score: Option<f64>,
225}
226
227impl Document {
228 pub fn new() -> Self {
229 Self::default()
230 }
231
232 pub fn with_id(mut self, id: impl Into<String>) -> Self {
233 self.data.insert("id".to_string(), Value::String(id.into()));
234 self
235 }
236
237 pub fn with_vector(mut self, vector: Vec<f64>) -> Self {
238 self.data.insert("vector".to_string(), serde_json::to_value(vector).unwrap());
239 self
240 }
241
242 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
243 self.data.insert(key.into(), value.into());
244 self
245 }
246
247 pub fn with_score(mut self, score: f64) -> Self {
248 self.score = Some(score);
249 self
250 }
251
252 pub fn get(&self, key: &str) -> Option<&Value> {
253 self.data.get(key)
254 }
255
256 pub fn get_id(&self) -> Option<&str> {
257 self.data.get("id")?.as_str()
258 }
259
260 pub fn get_vector(&self) -> Option<Vec<f64>> {
261 let vector_value = self.data.get("vector")?;
262 serde_json::from_value(vector_value.clone()).ok()
263 }
264
265 pub fn get_score(&self) -> Option<f64> {
266 self.score
267 }
268
269 pub fn insert(&mut self, key: impl Into<String>, value: impl Into<Value>) {
270 self.data.insert(key.into(), value.into());
271 }
272
273 pub fn remove(&mut self, key: &str) -> Option<Value> {
274 self.data.remove(key)
275 }
276
277 pub fn keys(&self) -> impl Iterator<Item = &String> {
278 self.data.keys()
279 }
280
281 pub fn values(&self) -> impl Iterator<Item = &Value> {
282 self.data.values()
283 }
284
285 pub fn iter(&self) -> impl Iterator<Item = (&String, &Value)> {
286 self.data.iter()
287 }
288
289 pub fn is_empty(&self) -> bool {
290 self.data.is_empty()
291 }
292
293 pub fn len(&self) -> usize {
294 self.data.len()
295 }
296}
297
298impl Serialize for Document {
299 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
300 where
301 S: serde::Serializer,
302 {
303 self.data.serialize(serializer)
304 }
305}
306
307impl<'de> Deserialize<'de> for Document {
308 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
309 where
310 D: serde::Deserializer<'de>,
311 {
312 let mut data: Map<String, Value> = Map::deserialize(deserializer)?;
313 let score = data.remove("score").and_then(|v| v.as_f64());
314
315 Ok(Self { data, score })
316 }
317}
318
319impl From<HashMap<String, Value>> for Document {
320 fn from(map: HashMap<String, Value>) -> Self {
321 let mut data = Map::new();
322 let mut score = None;
323
324 for (k, v) in map {
325 if k == "score" {
326 score = v.as_f64();
327 } else {
328 data.insert(k, v);
329 }
330 }
331
332 Self { data, score }
333 }
334}
335
336impl From<Map<String, Value>> for Document {
337 fn from(mut data: Map<String, Value>) -> Self {
338 let score = data.remove("score").and_then(|v| v.as_f64());
339 Self { data, score }
340 }
341}