vortex_expr/forms/
nnf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
5
6use crate::traversal::{Node as _, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
7use crate::{
8    BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralVTable, NotVTable, Operator,
9    not,
10};
11
12/// Return an equivalent expression in Negative Normal Form ([NNF](https://en.wikipedia.org/wiki/Negation_normal_form)).
13///
14/// In NNF, [crate::NotExpr] expressions may only contain terminal nodes such as [Literal](crate::LiteralExpr) or
15/// [GetItem](crate::GetItemExpr). They *may not* contain [crate::BinaryExpr], [crate::NotExpr], etc.
16///
17/// # Examples
18///
19/// Double negation is removed entirely:
20///
21/// ```
22/// use vortex_expr::{not, col};
23/// use vortex_expr::forms::nnf::nnf;
24///
25/// let double_negation = not(not(col("a")));
26/// let nnfed = nnf(double_negation);
27/// assert_eq!(&nnfed, &col("a"));
28/// ```
29///
30/// Triple negation becomes single negation:
31///
32/// ```
33/// use vortex_expr::{not, col};
34/// use vortex_expr::forms::nnf::nnf;
35///
36/// let triple_negation = not(not(not(col("a"))));
37/// let nnfed = nnf(triple_negation);
38/// assert_eq!(&nnfed, &not(col("a")));
39/// ```
40///
41/// Negation at a high-level is pushed to the leaves, likely increasing the total number nodes:
42///
43/// ```
44/// use vortex_expr::{not, col, and, or};
45/// use vortex_expr::forms::nnf::nnf;
46///
47/// assert_eq!(
48///     &nnf(not(and(col("a"), col("b")))),
49///     &or(not(col("a")), not(col("b")))
50/// );
51/// ```
52///
53/// In Vortex, NNF is extended beyond simple Boolean operators to any Boolean-valued operator:
54///
55/// ```
56/// use vortex_expr::{not, col, and, or, lt, lit, gt_eq};
57/// use vortex_expr::forms::nnf::nnf;
58///
59/// assert_eq!(
60///     &nnf(not(and(gt_eq(col("a"), lit(3)), col("b")))),
61///     &or(lt(col("a"), lit(3)), not(col("b")))
62/// );
63/// ```
64pub fn nnf(expr: ExprRef) -> ExprRef {
65    let mut rewriter = NNFRewriter::default();
66    expr.rewrite(&mut rewriter)
67        .vortex_expect("cannot fail")
68        .value
69}
70
71/// Verifies whether the expression is in Negative Normal Form ([NNF](https://en.wikipedia.org/wiki/Negation_normal_form)).
72///
73/// Note that NNF isn't canonical, different expressions might be logically equivalent but different.
74pub fn is_nnf(expr: &ExprRef) -> bool {
75    let mut visitor = NNFValidationVisitor::default();
76    expr.accept(&mut visitor).vortex_expect("never fails");
77    visitor.is_nnf
78}
79
/// Stateless rewriter that pushes negations towards the leaves of an expression tree.
#[derive(Default)]
struct NNFRewriter;
82
83impl NodeRewriter for NNFRewriter {
84    type NodeTy = ExprRef;
85
86    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
87        match node.as_opt::<NotVTable>() {
88            None => Ok(Transformed::no(node)),
89            Some(not_expr) => {
90                let child = not_expr.child();
91                if let Some(binary_expr) = child.as_opt::<BinaryVTable>() {
92                    let new_op = match binary_expr.op() {
93                        Operator::Eq => Operator::NotEq,
94                        Operator::NotEq => Operator::Eq,
95                        Operator::Gt => Operator::Lte,
96                        Operator::Gte => Operator::Lt,
97                        Operator::Lt => Operator::Gte,
98                        Operator::Lte => Operator::Gt,
99                        Operator::And => Operator::Or,
100                        Operator::Or => Operator::And,
101                        Operator::Add => {
102                            vortex_bail!("nnf: type mismatch: cannot negate addition")
103                        }
104                    };
105                    let (lhs, rhs) = match binary_expr.op() {
106                        Operator::Or | Operator::And => (
107                            not(binary_expr.lhs().clone()),
108                            not(binary_expr.rhs().clone()),
109                        ),
110                        _ => (binary_expr.lhs().clone(), binary_expr.rhs().clone()),
111                    };
112
113                    Ok(Transformed::yes(
114                        BinaryExpr::new(lhs, new_op, rhs).into_expr(),
115                    ))
116                } else if let Some(inner_not_expr) = child.as_opt::<NotVTable>() {
117                    Ok(Transformed::yes(inner_not_expr.child().clone()))
118                } else {
119                    Ok(Transformed::no(node))
120                }
121            }
122        }
123    }
124}
125
/// Visitor that checks whether every negation in an expression wraps only a terminal node.
struct NNFValidationVisitor {
    /// Starts as `true` and is cleared as soon as a negation of a non-terminal node is seen.
    is_nnf: bool,
}

impl Default for NNFValidationVisitor {
    /// Assumes the expression is in NNF until the visit finds a counterexample.
    fn default() -> Self {
        let is_nnf = true;
        Self { is_nnf }
    }
}
135
136impl<'a> NodeVisitor<'a> for NNFValidationVisitor {
137    type NodeTy = ExprRef;
138
139    fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
140        fn expr_is_var(expr: &ExprRef) -> bool {
141            expr.is::<LiteralVTable>() || expr.is::<GetItemVTable>()
142        }
143
144        if let Some(not_expr) = node.as_opt::<NotVTable>() {
145            let is_var = expr_is_var(not_expr.child());
146            self.is_nnf &= is_var;
147            if !is_var {
148                return Ok(TraversalOrder::Stop);
149            }
150        }
151
152        Ok(TraversalOrder::Continue)
153    }
154}
155
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{and, col, gt_eq, lit, lt, or};

    #[rstest]
    #[case(
        and(not(and(lit(true), lit(true))), and(lit(true), lit(true))),
        and(or(not(lit(true)), not(lit(true))), and(lit(true), lit(true)))
    )]
    #[case(not(and(col("a"), col("b"))), or(not(col("a")), not(col("b"))))]
    #[case(
        not(and(gt_eq(col("a"), lit(3)), col("b"))),
        or(lt(col("a"), lit(3)), not(col("b")))
    )]
    #[case::double_negation(not(not(col("a"))), col("a"))]
    #[case::triple_negation(not(not(not(col("a")))), not(col("a")))]
    #[case(
        not(and(not(gt_eq(col("a"), lit(3))), col("b"))),
        or(gt_eq(col("a"), lit(3)), not(col("b")))
    )]
    #[case(
        not(and(and(not(gt_eq(col("a"), lit(3))), col("c")), col("b"))),
        or(or(gt_eq(col("a"), lit(3)), not(col("c"))), not(col("b")))
    )]
    fn basic_nnf_test(#[case] input: ExprRef, #[case] expected: ExprRef) {
        let output = nnf(input.clone());

        assert_eq!(
            &output, &expected,
            "\nOriginal expr: {input}\nRewritten expr: {output}\nexpected expr: {expected}"
        );
        assert!(is_nnf(&output));
    }

    #[rstest]
    #[case(not(not(col("a"))), false)]
    #[case(not(not(not(col("a")))), false)]
    #[case(not(col("a")), true)]
    #[case(col("a"), true)]
    // A negated literal is a negated terminal, which NNF allows.
    #[case(not(lit(true)), true)]
    // A negation wrapping a binary expression is not in NNF.
    #[case(not(and(col("a"), col("b"))), false)]
    // Negations of terminals nested under connectives are fine.
    #[case(and(not(col("a")), not(col("b"))), true)]
    #[case(or(lt(col("a"), lit(3)), not(col("b"))), true)]
    // A violation nested below a connective must still be detected.
    #[case(and(col("a"), not(or(col("b"), col("c")))), false)]
    fn test_nnf_validation(#[case] expr: ExprRef, #[case] valid: bool) {
        assert_eq!(is_nnf(&expr), valid);
    }
}
201}