Skip to main content

trs_dataframe/filter/
mod.rs

1use std::str::FromStr;
2
3use data_value::DataValue;
4use pest::{iterators::Pair, Parser};
5use regex::Regex;
6use tracing::trace;
7/// Filter-specific error types.
8pub mod error;
9/// Operator matching and function-based filtering logic.
10pub mod filtering;
11pub use filtering::*;
12type Result<T> = std::result::Result<T, error::Error>;
13
14#[derive(pest_derive::Parser)]
15#[grammar = "filter/grammar/data.pest"]
16struct DataParser;
17
18/// Trait for data sources that can evaluate filter expressions.
19pub trait Filtering {
20    /// Returns row indices matching `expression`.
21    fn prepare_indicies(&self, expression: &Expression) -> Result<Vec<usize>>;
22    /// Returns row indices for function-based expressions (e.g. `.len()`).
23    fn apply_function(&self, expression: &Expression) -> Result<Vec<usize>>;
24}
25
/// Comparison operator placed between the two sides of a filter expression.
///
/// The textual form noted on each variant is the token accepted by the
/// `FromStr` implementation of this type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FilterOperator {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `>`
    Greater,
    /// `<=`
    LeOrEq,
    /// `>=`
    GrOrEq,
    /// `~=` — the right-hand side is treated as a regular expression.
    Regex,
    /// `in` — the right-hand side must be a vector of values.
    In,
    /// `notIn` — the right-hand side must be a vector of values.
    NotIn,
}
39
/// Logical connective that joins two filter expressions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FilterJoin {
    /// `&&`
    And,
    /// `||`
    Or,
}
46/// A single filter comparison: `left operator right`.
47#[derive(Debug, Clone, PartialEq)]
48pub struct Expression {
49    pub left: DataInput,
50    pub operator: FilterOperator,
51    pub right: DataInput,
52}
53
54/// The right-hand side of a filter expression, resolved to a concrete form.
55#[derive(Debug)]
56pub enum FilterArgument {
57    Value(DataValue),
58    Regex(regex::Regex),
59    Vec(Vec<DataValue>),
60}
61
62impl FilterArgument {
63    /// Returns the scalar value, or [`DataValue::Null`] for non-scalar variants.
64    pub fn value(&self) -> &DataValue {
65        match self {
66            FilterArgument::Value(value) => value,
67            FilterArgument::Regex(_) => &DataValue::Null, // Regex does not have a value
68            FilterArgument::Vec(_vec) => &DataValue::Null,
69        }
70    }
71
72    /// Returns the inner vector for `In`/`NotIn` arguments, if available.
73    pub fn vec(&self) -> Option<&Vec<DataValue>> {
74        match self {
75            FilterArgument::Value(value) => {
76                if let DataValue::Vec(vec) = value {
77                    Some(vec)
78                } else {
79                    None
80                }
81            }
82            FilterArgument::Regex(_) => None, // Regex does not have a value
83            FilterArgument::Vec(vec) => Some(vec),
84        }
85    }
86
87    /// Returns the compiled regex for `Regex` arguments, if available.
88    pub fn regex(&self) -> Option<&Regex> {
89        match self {
90            FilterArgument::Value(_value) => None,
91            FilterArgument::Regex(regex) => Some(regex),
92            FilterArgument::Vec(_) => None, // Vec does not have a regex
93        }
94    }
95}
96
97impl Expression {
98    /// Resolves the right-hand side into a [`FilterArgument`] suitable for
99    /// the expression's operator.
100    pub fn filter_argument(&self) -> Result<FilterArgument> {
101        match self.operator {
102            FilterOperator::Equal
103            | FilterOperator::NotEqual
104            | FilterOperator::Less
105            | FilterOperator::Greater
106            | FilterOperator::LeOrEq
107            | FilterOperator::GrOrEq => Ok(FilterArgument::Value(self.right.value())),
108            FilterOperator::Regex => {
109                if let DataValue::String(ref regex) = self.right.value() {
110                    Ok(FilterArgument::Regex(regex::Regex::new(regex)?))
111                } else {
112                    Err(error::parser_error(
113                        "Expected a regex string for Regex operator",
114                    ))
115                }
116            }
117            FilterOperator::In | FilterOperator::NotIn => {
118                if let DataValue::Vec(ref vec) = self.right.value() {
119                    Ok(FilterArgument::Vec(vec.clone()))
120                } else {
121                    Err(error::parser_error(
122                        "Expected a vector for In/NotIn operator",
123                    ))
124                }
125            }
126        }
127    }
128}
129
130/// A tree of filter expressions connected by `&&` / `||` operators.
131#[derive(Debug, Clone, PartialEq)]
132pub enum FilterCombinantion {
133    Simple(Expression),
134    /// and with &&
135    And(Expression, Box<FilterCombinantion>),
136    /// or with ||
137    Or(Expression, Box<FilterCombinantion>),
138    Grouped(Vec<FilterCombinantion>),
139}
140
/// Column functions that may be applied on the left-hand side of a filter
/// expression.
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum Function {
    /// Written `.len()` in a filter string.
    Len,
    /// Written `.to_datetime_us()` in a filter string.
    ToDateTimeUs,
}
147
148/// One side of a filter expression — a literal value, column key, or function call.
149#[derive(Debug, Clone, PartialEq)]
150pub enum DataInput {
151    Value(DataValue),
152    Key(String),
153    Function(String, Function),
154    Mod(String, DataValue),
155}
156
157impl DataInput {
158    /// Returns the column name if this is a `Key`, `Function`, or `Mod` variant.
159    pub fn as_key(&self) -> Option<&str> {
160        match self {
161            DataInput::Key(key) => Some(key),
162            DataInput::Value(_) => None,
163            DataInput::Function(key, _) => Some(key), // Functions do not have a key
164            DataInput::Mod(key, _) => Some(key),
165        }
166    }
167
168    /// Returns the contained [`DataValue`], or [`DataValue::Null`] for
169    /// function/mod variants.
170    pub fn value(&self) -> DataValue {
171        match self {
172            DataInput::Value(value) => value.clone(),
173            DataInput::Key(key) => DataValue::String(key.into()),
174            DataInput::Function(_, _) => DataValue::Null,
175            DataInput::Mod(..) => DataValue::Null,
176        }
177    }
178
179    /// Returns `true` if this input is a function call (e.g. `.len()`).
180    pub fn is_function(&self) -> bool {
181        matches!(self, DataInput::Function(_, _))
182    }
183
184    /// Returns `true` if this input is a modulo operation (e.g. `col % 2`).
185    pub fn is_mod(&self) -> bool {
186        matches!(self, DataInput::Mod(_, _))
187    }
188}
189
190/// Parsed filter expression tree, ready to be evaluated against a data source.
191#[derive(Debug, Clone, PartialEq)]
192pub struct FilterRules {
193    pub rules: Vec<FilterCombinantion>,
194}
195
196impl TryFrom<&str> for FilterRules {
197    type Error = error::Error;
198
199    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
200        DataParser::parse(Rule::full_expression, value)
201            .map_err(|e| error::parser_error(format!("Failed to parse DataInput: {e}")))?
202            .next()
203            .ok_or(error::parser_error(
204                "Expected a Rule::atom but found nothing",
205            ))
206            .and_then(parse_full_expression)
207    }
208}
209
210fn parse_left(rule: Pair<Rule>) -> Result<DataInput> {
211    trace!("Parsing left expression: {rule:?}");
212    let mut inner = rule.into_inner();
213    trace!("Parsing left inner: {inner:?}");
214    let key = inner
215        .next()
216        .ok_or(error::parser_error("Expected a key in left expression"))?
217        .as_str()
218        .to_string();
219    if let Some(function) = inner.next() {
220        let function_name = function.as_str();
221
222        if function_name.contains("%") {
223            let mut inn = function.into_inner();
224            // let _ = inn
225            //     .next()
226            //     .ok_or(error::parser_error("Expected a key in left expression"))?;
227            let atom = inn
228                .next()
229                .ok_or(error::parser_error("Expected a key in left expression"))?;
230            trace!("Atom {atom:?}");
231            return Ok(DataInput::Mod(key, parse_atom(atom)?.value()));
232        }
233        let function = match function_name {
234            ".len()" => Function::Len,
235            ".to_datetime_us()" => Function::ToDateTimeUs,
236            _ => return Err(error::parser_error("Unknown function: {function_name}")),
237        };
238        return Ok(DataInput::Function(key, function));
239    }
240    Ok(DataInput::Key(key)) // Placeholder for Function
241}
242
243fn parse_expression(pair: Pair<Rule>) -> Result<Expression> {
244    trace!("Parsing expression: {pair:?}");
245    match pair.as_rule() {
246        Rule::expression => {
247            let mut pairs = pair.into_inner();
248            trace!("Parsing expression pairs: {pairs:?}");
249            let left = parse_left(
250                pairs
251                    .next()
252                    .ok_or(error::parser_error("Expected a left expression"))?,
253            )?;
254            trace!("Parsing expression left: {left:?}");
255
256            let operator = pairs
257                .next()
258                .and_then(|s| s.as_str().parse::<FilterOperator>().ok())
259                .ok_or(error::parser_error("Expected a valid filter operator"))?;
260            trace!("Parsing expression operator: {operator:?}");
261            let right = parse_atom(
262                pairs
263                    .next()
264                    .ok_or(error::parser_error("Expected a right expression"))?,
265            )?;
266
267            trace!("Parsing expression right: {right:?}");
268            Ok(Expression {
269                left,
270                operator,
271                right,
272            })
273        }
274        e => Err(error::parser_error(format!(
275            "Unexpected rule in expression {e:?}"
276        ))),
277    }
278}
279fn parse_operator(pair: Pair<Rule>) -> Result<FilterJoin> {
280    match pair.as_str() {
281        "&&" => Ok(FilterJoin::And),
282        "||" => Ok(FilterJoin::Or),
283        _ => Err(error::parser_error(format!(
284            "Unknown operator: {}",
285            pair.as_str()
286        ))),
287    }
288}
289fn parse_filter_combination(pair: Pair<Rule>) -> Result<FilterCombinantion> {
290    if pair.as_rule() == Rule::expression {
291        return Ok(FilterCombinantion::Simple(parse_expression(pair)?));
292    }
293    let mut pairs = pair.into_inner();
294    trace!("Parsing filter combo expression pairs: {pairs:?}");
295    let first = parse_expression(pairs.next().ok_or(error::parser_error(
296        "Expected at least one expression in the pair",
297    ))?)?;
298    if let Some(op) = pairs.next() {
299        trace!("Parsing filter combo expression: {op:?} vs pairs {pairs:?}");
300        let op = parse_operator(op)?;
301        match op {
302            FilterJoin::And => {
303                return Ok(FilterCombinantion::And(
304                    first,
305                    Box::new(parse_filter_combination(pairs.next().ok_or(
306                        error::parser_error("Expected a next expression after '&&'"),
307                    )?)?),
308                ));
309            }
310            FilterJoin::Or => {
311                return Ok(FilterCombinantion::Or(
312                    first,
313                    Box::new(parse_filter_combination(pairs.next().ok_or(
314                        error::parser_error("Expected a next expression after '||'"),
315                    )?)?),
316                ));
317            }
318        }
319    }
320    Ok(FilterCombinantion::Simple(first))
321}
322fn parse_full_expression(pair: Pair<Rule>) -> Result<FilterRules> {
323    let mut rules = Vec::new();
324    trace!("Parsing full expression: {pair:?}");
325    match pair.as_rule() {
326        Rule::full_expression => {
327            let mut pairs = pair.into_inner();
328            trace!("Parsing full expression pairs: {pairs:?}");
329            let left = parse_expression(pairs.next().ok_or(error::parser_error(
330                "Expected at least one expression in the pair",
331            ))?)?;
332
333            if let Some(op) = pairs.next() {
334                trace!("Parsing operator: {op:?}");
335                let op = parse_operator(op)?;
336                let right = pairs.next().ok_or(error::parser_error(
337                    "Expected a next expression after operator",
338                ))?;
339                let ops = |op: FilterJoin,
340                           right: FilterCombinantion,
341                           rules: &mut Vec<FilterCombinantion>|
342                 -> Result<()> {
343                    match op {
344                        FilterJoin::And => {
345                            rules.push(FilterCombinantion::And(left, Box::new(right)));
346                        }
347                        FilterJoin::Or => {
348                            rules.push(FilterCombinantion::Or(left, Box::new(right)));
349                        }
350                    }
351                    Ok(())
352                };
353                match right.as_rule() {
354                    Rule::expression => {
355                        let right_expr = parse_expression(right)?;
356                        ops(op, FilterCombinantion::Simple(right_expr), &mut rules)?;
357                    }
358                    Rule::grouped_expression => {
359                        let grouped_expr = parse_filter_combination(right)?;
360                        ops(op, grouped_expr, &mut rules)?;
361                    }
362                    _ => return Err(error::parser_error("Expected an expression after operator")),
363                }
364            } else {
365                rules.push(FilterCombinantion::Simple(left));
366            }
367        }
368        _ => return Err(error::parser_error("Expected a full expression rule")),
369    }
370
371    Ok(FilterRules { rules })
372}
373
374impl TryFrom<&str> for DataInput {
375    type Error = error::Error;
376
377    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
378        DataParser::parse(Rule::atom, value)
379            .map_err(|e| error::parser_error(format!("Failed to parse DataInput: {e}")))?
380            .next()
381            .ok_or(error::parser_error(
382                "Expected a Rule::atom but found nothing",
383            ))
384            .and_then(parse_atom)
385    }
386}
387
388fn number_to_value<T: FromStr>(number: &str, post_fix: &str) -> Result<T> {
389    num_to_value(number.split(post_fix).next().ok_or_else(|| {
390        error::parser_error("Expected a number with postfix '{post_fix}' but found: {number}")
391    })?)
392}
393
394fn num_to_value<T: FromStr>(number: &str) -> Result<T> {
395    match number.parse::<T>() {
396        Ok(value) => Ok(value),
397        Err(_e) => Err(error::parser_error(format!(
398            "Failed to parse number {number}"
399        ))),
400    }
401}
402
403fn parse_atom(rule: Pair<Rule>) -> Result<DataInput> {
404    match rule.as_rule() {
405        Rule::atom => {
406            let inner = rule.into_inner().next().ok_or(error::parser_error(
407                "Expected a Rule::atom but found nothing",
408            ))?;
409            parse_atom(inner)
410        }
411        Rule::u32 => number_to_value::<u32>(rule.as_str(), "u32")
412            .map(|value| DataInput::Value(DataValue::from(value))),
413        Rule::i32 => number_to_value::<i32>(rule.as_str(), "i32")
414            .map(|value| DataInput::Value(DataValue::from(value))),
415        Rule::u64 => number_to_value::<u64>(rule.as_str(), "u64")
416            .map(|value| DataInput::Value(DataValue::from(value))),
417        Rule::i64 => {
418            let str_rule = rule.as_str();
419            if str_rule.contains("i64") {
420                number_to_value::<i64>(str_rule, "i64")
421                    .map(|value| DataInput::Value(DataValue::from(value)))
422            } else {
423                num_to_value::<i64>(str_rule).map(|val| DataInput::Value(DataValue::from(val)))
424            }
425        }
426        Rule::f32 => number_to_value::<f32>(rule.as_str(), "f32")
427            .map(|value| DataInput::Value(DataValue::from(value))),
428        Rule::f64 => number_to_value::<f64>(rule.as_str(), "f64")
429            .map(|value| DataInput::Value(DataValue::from(value))),
430        Rule::float => number_to_value::<f64>(rule.as_str(), "f64")
431            .map(|value| DataInput::Value(DataValue::from(value))),
432        Rule::string_qt => {
433            let value = rule.as_str().trim_matches('\'');
434            Ok(DataInput::Value(DataValue::String(value.into())))
435        }
436        Rule::boolean => {
437            let value = rule.as_str();
438            match value {
439                "true" => Ok(DataInput::Value(DataValue::Bool(true))),
440                "false" => Ok(DataInput::Value(DataValue::Bool(false))),
441                _ => Err(error::parser_error(
442                    "Expected boolean value but found: {value}",
443                )),
444            }
445        }
446        Rule::null => Ok(DataInput::Value(DataValue::Null)),
447        Rule::key => Ok(DataInput::Key(rule.as_str().to_string())),
448        Rule::array => {
449            let mut values = Vec::new();
450            for pair in rule.into_inner() {
451                match parse_atom(pair)? {
452                    DataInput::Value(value) => values.push(value),
453                    DataInput::Key(key) => {
454                        values.push(DataValue::String(key.into()));
455                    }
456                    DataInput::Function(_, _) => {
457                        return Err(error::parser_error("Function in array is not supported"));
458                    }
459                    DataInput::Mod(_, _) => {
460                        return Err(error::parser_error("Function in array is not supported"));
461                    }
462                }
463            }
464            Ok(DataInput::Value(DataValue::Vec(values)))
465        }
466        Rule::left => parse_left(rule),
467        _ => Err(error::parser_error("{rule} did not match any 'Rule' ")),
468    }
469}
470
471impl std::str::FromStr for FilterOperator {
472    type Err = error::Error;
473
474    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
475        match s {
476            "==" => Ok(FilterOperator::Equal),
477            "!=" => Ok(FilterOperator::NotEqual),
478            "<" => Ok(FilterOperator::Less),
479            ">" => Ok(FilterOperator::Greater),
480            "<=" => Ok(FilterOperator::LeOrEq),
481            ">=" => Ok(FilterOperator::GrOrEq),
482            "~=" => Ok(FilterOperator::Regex),
483            "in" => Ok(FilterOperator::In),
484            "notIn" => Ok(FilterOperator::NotIn),
485            _ => Err(error::parser_error(format!("Unknown filter operator: {s}"))),
486        }
487    }
488}
489
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Literal-value operand built from anything convertible to `DataValue`.
    fn val(value: impl Into<DataValue>) -> DataInput {
        DataInput::Value(value.into())
    }

    /// Column-key operand.
    fn key(name: &str) -> DataInput {
        DataInput::Key(name.to_string())
    }

    /// Function-call operand on the named column.
    fn func(name: &str, function: Function) -> DataInput {
        DataInput::Function(name.to_string(), function)
    }

    /// `left operator right` expression.
    fn expr(left: DataInput, operator: FilterOperator, right: DataInput) -> Expression {
        Expression {
            left,
            operator,
            right,
        }
    }

    /// Single-expression combination.
    fn simple(expression: Expression) -> FilterCombinantion {
        FilterCombinantion::Simple(expression)
    }

    /// `first && rest` combination.
    fn and(first: Expression, rest: FilterCombinantion) -> FilterCombinantion {
        FilterCombinantion::And(first, Box::new(rest))
    }

    /// `first || rest` combination.
    fn or(first: Expression, rest: FilterCombinantion) -> FilterCombinantion {
        FilterCombinantion::Or(first, Box::new(rest))
    }

    /// Rules holding exactly one combination.
    fn rules(combination: FilterCombinantion) -> FilterRules {
        FilterRules {
            rules: vec![combination],
        }
    }

    // Every supported atom form parses into the expected `DataInput`.
    #[rstest]
    #[case("abc", key("abc"))]
    #[case("'abc'", val("abc"))]
    #[case("1u32", val(1u32))]
    #[case("1i32", val(1i32))]
    #[case("1u64", val(1u64))]
    #[case("1i64", val(1i64))]
    #[case("1f64", val(1f64))]
    #[case("null", val(DataValue::Null))]
    #[case("true", val(true))]
    #[case("false", val(false))]
    #[case("1.0", val(1f64))]
    #[case("[1u32, 1f64, 'abc', notakey]", val(DataValue::Vec(vec![
        DataValue::from(1u32),
        DataValue::from(1f64),
        DataValue::from("abc"),
        DataValue::from("notakey"),
    ])))]
    #[case("1.0f32", val(1f32))]
    #[case("1", val(1i64))]
    fn test_parser(#[case] input: &str, #[case] expected: DataInput) {
        let result = DataInput::try_from(input);
        assert!(result.is_ok(), "Failed to parse '{input}' {result:?}");
        assert_eq!(result.unwrap(), expected);
    }

    // Plain, joined and grouped filter strings produce the expected tree.
    #[rstest]
    #[case(
        "abc > 1u32",
        rules(simple(expr(key("abc"), FilterOperator::Greater, val(1u32))))
    )]
    #[case(
        "abc > 1u32 && c == 'a'",
        rules(and(
            expr(key("abc"), FilterOperator::Greater, val(1u32)),
            simple(expr(key("c"), FilterOperator::Equal, val("a"))),
        ))
    )]
    #[case(
        "abc > 1u32 || c <= 12.0f64",
        rules(or(
            expr(key("abc"), FilterOperator::Greater, val(1u32)),
            simple(expr(key("c"), FilterOperator::LeOrEq, val(12f64))),
        ))
    )]
    #[case(
        "abc in [1i32] && (g >= 1u64 || c ~= '.*')",
        rules(and(
            expr(
                key("abc"),
                FilterOperator::In,
                val(DataValue::Vec(vec![1i32.into()])),
            ),
            or(
                expr(key("g"), FilterOperator::GrOrEq, val(1u64)),
                simple(expr(key("c"), FilterOperator::Regex, val(".*"))),
            ),
        ))
    )]
    fn test_parser_filter(#[case] input: &str, #[case] expected: FilterRules) {
        let result = FilterRules::try_from(input);
        assert!(result.is_ok(), "Failed to parse '{input}' {result:?}");
        assert_eq!(result.unwrap(), expected);
    }

    // Function and modulo forms on the left-hand side.
    #[rstest]
    #[case(
        "abc.len() > 1u32",
        rules(simple(expr(
            func("abc", Function::Len),
            FilterOperator::Greater,
            val(1u32),
        )))
    )]
    #[case(
        "abc.to_datetime_us() > '2025-07-01 00:00:00' && c == 'a'",
        rules(and(
            expr(
                func("abc", Function::ToDateTimeUs),
                FilterOperator::Greater,
                val("2025-07-01 00:00:00"),
            ),
            simple(expr(key("c"), FilterOperator::Equal, val("a"))),
        ))
    )]
    #[case(
        "abc % 1u32 == 1u32",
        rules(simple(expr(
            DataInput::Mod("abc".to_string(), DataValue::U32(1)),
            FilterOperator::Equal,
            val(1u32),
        )))
    )]
    fn test_functions(#[case] input: &str, #[case] expected: FilterRules) {
        let result = FilterRules::try_from(input);
        assert!(result.is_ok(), "Failed to parse '{input}' {result:?}");
        assert_eq!(result.unwrap(), expected);
    }

    // Malformed filter strings must be rejected.
    #[rstest]
    #[case("")]
    #[case("a >>= 1i32")]
    #[case("a.unknown_fn() == 1i32")]
    #[case("garbage")]
    fn parser_rejects_invalid_input(#[case] input: &str) {
        assert!(
            FilterRules::try_from(input).is_err(),
            "input {input:?} should not parse"
        );
    }

    // Smoke test of the In/NotIn parser path; the exact tree is not pinned.
    #[rstest]
    #[case("a in [1i32, 2i32, 3i32]")]
    #[case("a notIn [10i32, 20i32]")]
    fn parser_accepts_in_and_not_in(#[case] input: &str) {
        let parsed = FilterRules::try_from(input);
        assert!(parsed.is_ok(), "input {input:?} should parse: {parsed:?}");
    }

    // `~=` is the regex operator.
    #[rstest]
    fn parser_accepts_regex() {
        let parsed = FilterRules::try_from("a ~= 'foo'");
        assert!(parsed.is_ok(), "parsed = {parsed:?}");
    }

    // The grammar requires a leading non-grouped expression before the
    // first parenthesized group.
    #[rstest]
    fn parser_accepts_grouped_combination() {
        let parsed = FilterRules::try_from("a == 1i32 && (b == 2i32 || c == 3i32)");
        assert!(parsed.is_ok(), "parsed = {parsed:?}");
    }
}
632        // first parenthesized group.
633        let parsed = FilterRules::try_from("a == 1i32 && (b == 2i32 || c == 3i32)");
634        assert!(parsed.is_ok(), "parsed = {parsed:?}");
635    }
636}