Skip to main content

ruvector_filter/
expression.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4/// Filter expression for querying vectors by payload
5#[derive(Debug, Clone, Serialize, Deserialize)]
6#[serde(tag = "type", rename_all = "snake_case")]
7pub enum FilterExpression {
8    // Comparison operators
9    Eq {
10        field: String,
11        value: Value,
12    },
13    Ne {
14        field: String,
15        value: Value,
16    },
17    Gt {
18        field: String,
19        value: Value,
20    },
21    Gte {
22        field: String,
23        value: Value,
24    },
25    Lt {
26        field: String,
27        value: Value,
28    },
29    Lte {
30        field: String,
31        value: Value,
32    },
33
34    // Range
35    Range {
36        field: String,
37        gte: Option<Value>,
38        lte: Option<Value>,
39    },
40
41    // Array operations
42    In {
43        field: String,
44        values: Vec<Value>,
45    },
46
47    // Text matching
48    Match {
49        field: String,
50        text: String,
51    },
52
53    // Geo operations (basic)
54    GeoRadius {
55        field: String,
56        lat: f64,
57        lon: f64,
58        radius_m: f64,
59    },
60    GeoBoundingBox {
61        field: String,
62        top_left: (f64, f64),
63        bottom_right: (f64, f64),
64    },
65
66    // Logical operators
67    And(Vec<FilterExpression>),
68    Or(Vec<FilterExpression>),
69    Not(Box<FilterExpression>),
70
71    // Existence check
72    Exists {
73        field: String,
74    },
75    IsNull {
76        field: String,
77    },
78}
79
80impl FilterExpression {
81    /// Create an equality filter
82    pub fn eq(field: impl Into<String>, value: Value) -> Self {
83        Self::Eq {
84            field: field.into(),
85            value,
86        }
87    }
88
89    /// Create a not-equal filter
90    pub fn ne(field: impl Into<String>, value: Value) -> Self {
91        Self::Ne {
92            field: field.into(),
93            value,
94        }
95    }
96
97    /// Create a greater-than filter
98    pub fn gt(field: impl Into<String>, value: Value) -> Self {
99        Self::Gt {
100            field: field.into(),
101            value,
102        }
103    }
104
105    /// Create a greater-than-or-equal filter
106    pub fn gte(field: impl Into<String>, value: Value) -> Self {
107        Self::Gte {
108            field: field.into(),
109            value,
110        }
111    }
112
113    /// Create a less-than filter
114    pub fn lt(field: impl Into<String>, value: Value) -> Self {
115        Self::Lt {
116            field: field.into(),
117            value,
118        }
119    }
120
121    /// Create a less-than-or-equal filter
122    pub fn lte(field: impl Into<String>, value: Value) -> Self {
123        Self::Lte {
124            field: field.into(),
125            value,
126        }
127    }
128
129    /// Create a range filter
130    pub fn range(field: impl Into<String>, gte: Option<Value>, lte: Option<Value>) -> Self {
131        Self::Range {
132            field: field.into(),
133            gte,
134            lte,
135        }
136    }
137
138    /// Create an IN filter
139    pub fn in_values(field: impl Into<String>, values: Vec<Value>) -> Self {
140        Self::In {
141            field: field.into(),
142            values,
143        }
144    }
145
146    /// Create a text match filter
147    pub fn match_text(field: impl Into<String>, text: impl Into<String>) -> Self {
148        Self::Match {
149            field: field.into(),
150            text: text.into(),
151        }
152    }
153
154    /// Create a geo radius filter
155    pub fn geo_radius(field: impl Into<String>, lat: f64, lon: f64, radius_m: f64) -> Self {
156        Self::GeoRadius {
157            field: field.into(),
158            lat,
159            lon,
160            radius_m,
161        }
162    }
163
164    /// Create a geo bounding box filter
165    pub fn geo_bounding_box(
166        field: impl Into<String>,
167        top_left: (f64, f64),
168        bottom_right: (f64, f64),
169    ) -> Self {
170        Self::GeoBoundingBox {
171            field: field.into(),
172            top_left,
173            bottom_right,
174        }
175    }
176
177    /// Create an AND filter
178    pub fn and(filters: Vec<FilterExpression>) -> Self {
179        Self::And(filters)
180    }
181
182    /// Create an OR filter
183    pub fn or(filters: Vec<FilterExpression>) -> Self {
184        Self::Or(filters)
185    }
186
187    /// Create a NOT filter
188    // Public API constructor mirrors `and`/`or`; not the `std::ops::Not` trait.
189    #[allow(clippy::should_implement_trait)]
190    pub fn not(filter: FilterExpression) -> Self {
191        Self::Not(Box::new(filter))
192    }
193
194    /// Create an EXISTS filter
195    pub fn exists(field: impl Into<String>) -> Self {
196        Self::Exists {
197            field: field.into(),
198        }
199    }
200
201    /// Create an IS NULL filter
202    pub fn is_null(field: impl Into<String>) -> Self {
203        Self::IsNull {
204            field: field.into(),
205        }
206    }
207
208    /// Get all field names referenced in this expression
209    pub fn get_fields(&self) -> Vec<String> {
210        let mut fields = Vec::new();
211        self.collect_fields(&mut fields);
212        fields.sort();
213        fields.dedup();
214        fields
215    }
216
217    fn collect_fields(&self, fields: &mut Vec<String>) {
218        match self {
219            Self::Eq { field, .. }
220            | Self::Ne { field, .. }
221            | Self::Gt { field, .. }
222            | Self::Gte { field, .. }
223            | Self::Lt { field, .. }
224            | Self::Lte { field, .. }
225            | Self::Range { field, .. }
226            | Self::In { field, .. }
227            | Self::Match { field, .. }
228            | Self::GeoRadius { field, .. }
229            | Self::GeoBoundingBox { field, .. }
230            | Self::Exists { field }
231            | Self::IsNull { field } => {
232                fields.push(field.clone());
233            }
234            Self::And(exprs) | Self::Or(exprs) => {
235                for expr in exprs {
236                    expr.collect_fields(fields);
237                }
238            }
239            Self::Not(expr) => {
240                expr.collect_fields(fields);
241            }
242        }
243    }
244}
245
// Unit tests for the builders, field collection, and a basic serde round trip.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_filter_builders() {
        let single = FilterExpression::eq("status", json!("active"));
        assert!(matches!(single, FilterExpression::Eq { .. }));

        let active = FilterExpression::eq("status", json!("active"));
        let adult = FilterExpression::gte("age", json!(18));
        let combined = FilterExpression::and(vec![active, adult]);
        assert!(matches!(combined, FilterExpression::And(_)));
    }

    #[test]
    fn test_get_fields() {
        let either = FilterExpression::or(vec![
            FilterExpression::gte("age", json!(18)),
            FilterExpression::lt("score", json!(100)),
        ]);
        let filter = FilterExpression::and(vec![
            FilterExpression::eq("status", json!("active")),
            either,
        ]);

        let expected = vec!["age", "score", "status"];
        assert_eq!(filter.get_fields(), expected);
    }

    #[test]
    fn test_serialization() {
        let original = FilterExpression::eq("status", json!("active"));
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: FilterExpression = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, FilterExpression::Eq { .. }));
    }
}