1use crate::error::{FilterError, Result};
2use ordered_float::OrderedFloat;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::{BTreeMap, HashMap, HashSet};
6
/// Kind of payload index that can be built for a field.
///
/// Serialized and deserialized through serde using the lowercase variant
/// name (`"integer"`, `"float"`, `"keyword"`, `"bool"`, `"geo"`, `"text"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum IndexType {
    /// Index over JSON values readable as `i64`.
    Integer,
    /// Index over JSON values readable as `f64`.
    Float,
    /// Index over whole string values, matched exactly.
    Keyword,
    /// Index over boolean values.
    Bool,
    /// Index over objects carrying numeric `lat` and `lon` members.
    Geo,
    /// Index over the lowercased, whitespace-separated words of a string.
    Text,
}
18
/// Per-field index mapping payload values to the ids of the vectors that
/// carry them. Each variant corresponds to one [`IndexType`].
#[derive(Debug, Clone)]
pub enum PayloadIndex {
    /// Ordered map from integer value to the set of vector ids holding it.
    Integer(BTreeMap<i64, HashSet<String>>),
    /// Ordered map from float value to the set of vector ids holding it.
    /// `OrderedFloat` supplies the total ordering a `BTreeMap` key needs.
    Float(BTreeMap<OrderedFloat<f64>, HashSet<String>>),
    /// Hash map from exact string value to the set of vector ids holding it.
    Keyword(HashMap<String, HashSet<String>>),
    /// Hash map from boolean value to the set of vector ids holding it.
    Bool(HashMap<bool, HashSet<String>>),
    /// Flat list of `(vector_id, lat, lon)` entries.
    Geo(Vec<(String, f64, f64)>),
    /// Hash map from lowercased word to the set of vector ids whose text
    /// contains that word.
    Text(HashMap<String, HashSet<String>>),
}
29
30impl PayloadIndex {
31 pub fn new(index_type: IndexType) -> Self {
33 match index_type {
34 IndexType::Integer => Self::Integer(BTreeMap::new()),
35 IndexType::Float => Self::Float(BTreeMap::new()),
36 IndexType::Keyword => Self::Keyword(HashMap::new()),
37 IndexType::Bool => Self::Bool(HashMap::new()),
38 IndexType::Geo => Self::Geo(Vec::new()),
39 IndexType::Text => Self::Text(HashMap::new()),
40 }
41 }
42
43 pub fn index_type(&self) -> IndexType {
45 match self {
46 Self::Integer(_) => IndexType::Integer,
47 Self::Float(_) => IndexType::Float,
48 Self::Keyword(_) => IndexType::Keyword,
49 Self::Bool(_) => IndexType::Bool,
50 Self::Geo(_) => IndexType::Geo,
51 Self::Text(_) => IndexType::Text,
52 }
53 }
54
55 pub fn add(&mut self, vector_id: &str, value: &Value) -> Result<()> {
57 match self {
58 Self::Integer(index) => {
59 if let Some(num) = value.as_i64() {
60 index
61 .entry(num)
62 .or_insert_with(HashSet::new)
63 .insert(vector_id.to_string());
64 }
65 }
66 Self::Float(index) => {
67 if let Some(num) = value.as_f64() {
68 index
69 .entry(OrderedFloat(num))
70 .or_insert_with(HashSet::new)
71 .insert(vector_id.to_string());
72 }
73 }
74 Self::Keyword(index) => {
75 if let Some(s) = value.as_str() {
76 index
77 .entry(s.to_string())
78 .or_insert_with(HashSet::new)
79 .insert(vector_id.to_string());
80 }
81 }
82 Self::Bool(index) => {
83 if let Some(b) = value.as_bool() {
84 index
85 .entry(b)
86 .or_insert_with(HashSet::new)
87 .insert(vector_id.to_string());
88 }
89 }
90 Self::Geo(index) => {
91 if let Some(obj) = value.as_object() {
92 if let (Some(lat), Some(lon)) = (
93 obj.get("lat").and_then(|v| v.as_f64()),
94 obj.get("lon").and_then(|v| v.as_f64()),
95 ) {
96 index.push((vector_id.to_string(), lat, lon));
97 }
98 }
99 }
100 Self::Text(index) => {
101 if let Some(text) = value.as_str() {
102 for word in text.split_whitespace() {
104 let word = word.to_lowercase();
105 index
106 .entry(word)
107 .or_insert_with(HashSet::new)
108 .insert(vector_id.to_string());
109 }
110 }
111 }
112 }
113 Ok(())
114 }
115
116 pub fn remove(&mut self, vector_id: &str, value: &Value) -> Result<()> {
118 match self {
119 Self::Integer(index) => {
120 if let Some(num) = value.as_i64() {
121 if let Some(set) = index.get_mut(&num) {
122 set.remove(vector_id);
123 if set.is_empty() {
124 index.remove(&num);
125 }
126 }
127 }
128 }
129 Self::Float(index) => {
130 if let Some(num) = value.as_f64() {
131 if let Some(set) = index.get_mut(&OrderedFloat(num)) {
132 set.remove(vector_id);
133 if set.is_empty() {
134 index.remove(&OrderedFloat(num));
135 }
136 }
137 }
138 }
139 Self::Keyword(index) => {
140 if let Some(s) = value.as_str() {
141 if let Some(set) = index.get_mut(s) {
142 set.remove(vector_id);
143 if set.is_empty() {
144 index.remove(s);
145 }
146 }
147 }
148 }
149 Self::Bool(index) => {
150 if let Some(b) = value.as_bool() {
151 if let Some(set) = index.get_mut(&b) {
152 set.remove(vector_id);
153 if set.is_empty() {
154 index.remove(&b);
155 }
156 }
157 }
158 }
159 Self::Geo(index) => {
160 index.retain(|(id, _, _)| id != vector_id);
161 }
162 Self::Text(index) => {
163 if let Some(text) = value.as_str() {
164 for word in text.split_whitespace() {
165 let word = word.to_lowercase();
166 if let Some(set) = index.get_mut(&word) {
167 set.remove(vector_id);
168 if set.is_empty() {
169 index.remove(&word);
170 }
171 }
172 }
173 }
174 }
175 }
176 Ok(())
177 }
178
179 pub fn clear(&mut self, vector_id: &str) {
181 match self {
182 Self::Integer(index) => {
183 for set in index.values_mut() {
184 set.remove(vector_id);
185 }
186 index.retain(|_, set| !set.is_empty());
187 }
188 Self::Float(index) => {
189 for set in index.values_mut() {
190 set.remove(vector_id);
191 }
192 index.retain(|_, set| !set.is_empty());
193 }
194 Self::Keyword(index) => {
195 for set in index.values_mut() {
196 set.remove(vector_id);
197 }
198 index.retain(|_, set| !set.is_empty());
199 }
200 Self::Bool(index) => {
201 for set in index.values_mut() {
202 set.remove(vector_id);
203 }
204 index.retain(|_, set| !set.is_empty());
205 }
206 Self::Geo(index) => {
207 index.retain(|(id, _, _)| id != vector_id);
208 }
209 Self::Text(index) => {
210 for set in index.values_mut() {
211 set.remove(vector_id);
212 }
213 index.retain(|_, set| !set.is_empty());
214 }
215 }
216 }
217}
218
/// Registry of payload indices, keyed by the payload field each one covers.
#[derive(Debug, Default)]
pub struct PayloadIndexManager {
    /// Field name -> index built for that field.
    indices: HashMap<String, PayloadIndex>,
}
224
225impl PayloadIndexManager {
226 pub fn new() -> Self {
228 Self {
229 indices: HashMap::new(),
230 }
231 }
232
233 pub fn create_index(&mut self, field: &str, index_type: IndexType) -> Result<()> {
235 if self.indices.contains_key(field) {
236 return Err(FilterError::InvalidExpression(format!(
237 "Index already exists for field: {}",
238 field
239 )));
240 }
241 self.indices
242 .insert(field.to_string(), PayloadIndex::new(index_type));
243 Ok(())
244 }
245
246 pub fn drop_index(&mut self, field: &str) -> Result<()> {
248 if self.indices.remove(field).is_none() {
249 return Err(FilterError::IndexNotFound(field.to_string()));
250 }
251 Ok(())
252 }
253
254 pub fn has_index(&self, field: &str) -> bool {
256 self.indices.contains_key(field)
257 }
258
259 pub fn get_index(&self, field: &str) -> Option<&PayloadIndex> {
261 self.indices.get(field)
262 }
263
264 pub fn get_index_mut(&mut self, field: &str) -> Option<&mut PayloadIndex> {
266 self.indices.get_mut(field)
267 }
268
269 pub fn index_payload(&mut self, vector_id: &str, payload: &Value) -> Result<()> {
271 if let Some(obj) = payload.as_object() {
272 for (field, value) in obj {
273 if let Some(index) = self.indices.get_mut(field) {
274 index.add(vector_id, value)?;
275 }
276 }
277 }
278 Ok(())
279 }
280
281 pub fn remove_payload(&mut self, vector_id: &str, payload: &Value) -> Result<()> {
283 if let Some(obj) = payload.as_object() {
284 for (field, value) in obj {
285 if let Some(index) = self.indices.get_mut(field) {
286 index.remove(vector_id, value)?;
287 }
288 }
289 }
290 Ok(())
291 }
292
293 pub fn clear_vector(&mut self, vector_id: &str) {
295 for index in self.indices.values_mut() {
296 index.clear(vector_id);
297 }
298 }
299
300 pub fn indexed_fields(&self) -> Vec<String> {
302 self.indices.keys().cloned().collect()
303 }
304
305 pub fn index_count(&self) -> usize {
307 self.indices.len()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Ids sharing an integer value land in the same posting set.
    #[test]
    fn test_integer_index() {
        let mut index = PayloadIndex::new(IndexType::Integer);
        index.add("v1", &json!(42)).unwrap();
        index.add("v2", &json!(42)).unwrap();
        index.add("v3", &json!(100)).unwrap();

        if let PayloadIndex::Integer(map) = index {
            assert_eq!(map.get(&42).unwrap().len(), 2);
            assert_eq!(map.get(&100).unwrap().len(), 1);
        } else {
            panic!("Wrong index type");
        }
    }

    /// Ids sharing a keyword land in the same posting set.
    #[test]
    fn test_keyword_index() {
        let mut index = PayloadIndex::new(IndexType::Keyword);
        index.add("v1", &json!("active")).unwrap();
        index.add("v2", &json!("active")).unwrap();
        index.add("v3", &json!("inactive")).unwrap();

        if let PayloadIndex::Keyword(map) = index {
            assert_eq!(map.get("active").unwrap().len(), 2);
            assert_eq!(map.get("inactive").unwrap().len(), 1);
        } else {
            panic!("Wrong index type");
        }
    }

    /// Indexing a payload populates exactly the fields that have an index.
    #[test]
    fn test_index_manager() {
        let mut manager = PayloadIndexManager::new();
        manager.create_index("age", IndexType::Integer).unwrap();
        manager.create_index("status", IndexType::Keyword).unwrap();

        let payload = json!({
            "age": 25,
            "status": "active",
            "name": "Alice"
        });

        manager.index_payload("v1", &payload).unwrap();
        assert!(manager.has_index("age"));
        assert!(manager.has_index("status"));
        assert!(!manager.has_index("name"));
        assert_eq!(manager.index_count(), 2);

        // The payload values must actually be present in the indices.
        match manager.get_index("age") {
            Some(PayloadIndex::Integer(map)) => {
                assert!(map.get(&25).unwrap().contains("v1"));
            }
            other => panic!("Wrong index for age: {:?}", other),
        }
        match manager.get_index("status") {
            Some(PayloadIndex::Keyword(map)) => {
                assert!(map.get("active").unwrap().contains("v1"));
            }
            other => panic!("Wrong index for status: {:?}", other),
        }
    }

    /// Duplicate creation and dropping a missing index are both errors.
    #[test]
    fn test_index_manager_errors() {
        let mut manager = PayloadIndexManager::new();
        manager.create_index("age", IndexType::Integer).unwrap();

        assert!(manager.create_index("age", IndexType::Keyword).is_err());
        assert!(manager.drop_index("missing").is_err());

        manager.drop_index("age").unwrap();
        assert!(!manager.has_index("age"));
        assert_eq!(manager.index_count(), 0);
    }

    /// Geo points with numeric coordinates are stored; malformed ones are skipped.
    #[test]
    fn test_geo_index() {
        let mut index = PayloadIndex::new(IndexType::Geo);
        index
            .add("v1", &json!({"lat": 40.7128, "lon": -74.0060}))
            .unwrap();
        index
            .add("v2", &json!({"lat": 34.0522, "lon": -118.2437}))
            .unwrap();
        // Missing `lon`: accepted without error but not indexed.
        index.add("v3", &json!({"lat": 1.0})).unwrap();

        if let PayloadIndex::Geo(points) = index {
            assert_eq!(points.len(), 2);
        } else {
            panic!("Wrong index type");
        }
    }

    /// Text values are split on whitespace and lowercased; removal undoes it.
    #[test]
    fn test_text_index() {
        let mut index = PayloadIndex::new(IndexType::Text);
        index.add("v1", &json!("Hello World hello")).unwrap();

        if let PayloadIndex::Text(map) = &index {
            assert_eq!(map.len(), 2);
            assert!(map.get("hello").unwrap().contains("v1"));
            assert!(map.get("world").unwrap().contains("v1"));
        } else {
            panic!("Wrong index type");
        }

        index.remove("v1", &json!("Hello World hello")).unwrap();
        if let PayloadIndex::Text(map) = &index {
            assert!(map.is_empty());
        } else {
            panic!("Wrong index type");
        }
    }

    /// Removing the last id under a key drops the key itself.
    #[test]
    fn test_remove_prunes_empty_entries() {
        let mut index = PayloadIndex::new(IndexType::Keyword);
        index.add("v1", &json!("active")).unwrap();
        index.add("v2", &json!("active")).unwrap();

        index.remove("v1", &json!("active")).unwrap();
        if let PayloadIndex::Keyword(map) = &index {
            assert_eq!(map.get("active").unwrap().len(), 1);
        } else {
            panic!("Wrong index type");
        }

        index.remove("v2", &json!("active")).unwrap();
        if let PayloadIndex::Keyword(map) = &index {
            assert!(!map.contains_key("active"));
        } else {
            panic!("Wrong index type");
        }
    }

    /// `clear_vector` erases one id everywhere and leaves the others intact.
    #[test]
    fn test_clear_vector() {
        let mut manager = PayloadIndexManager::new();
        manager.create_index("age", IndexType::Integer).unwrap();
        manager.create_index("status", IndexType::Keyword).unwrap();

        manager
            .index_payload("v1", &json!({"age": 25, "status": "active"}))
            .unwrap();
        manager
            .index_payload("v2", &json!({"age": 25, "status": "inactive"}))
            .unwrap();

        manager.clear_vector("v1");

        match manager.get_index("age") {
            Some(PayloadIndex::Integer(map)) => {
                let ids = map.get(&25).unwrap();
                assert_eq!(ids.len(), 1);
                assert!(ids.contains("v2"));
            }
            other => panic!("Wrong index for age: {:?}", other),
        }
        match manager.get_index("status") {
            Some(PayloadIndex::Keyword(map)) => {
                assert!(!map.contains_key("active"));
                assert!(map.get("inactive").unwrap().contains("v2"));
            }
            other => panic!("Wrong index for status: {:?}", other),
        }
    }
}