1use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
5
6use crate::traversal::{Node as _, NodeRewriter, NodeVisitor, Transformed, TraversalOrder};
7use crate::{
8 BinaryExpr, BinaryVTable, ExprRef, GetItemVTable, IntoExpr, LiteralVTable, NotVTable, Operator,
9 not,
10};
11
12pub fn nnf(expr: ExprRef) -> ExprRef {
65 let mut rewriter = NNFRewriter::default();
66 expr.rewrite(&mut rewriter)
67 .vortex_expect("cannot fail")
68 .value
69}
70
71pub fn is_nnf(expr: &ExprRef) -> bool {
75 let mut visitor = NNFValidationVisitor::default();
76 expr.accept(&mut visitor).vortex_expect("never fails");
77 visitor.is_nnf
78}
79
80#[derive(Default)]
81struct NNFRewriter {}
82
83impl NodeRewriter for NNFRewriter {
84 type NodeTy = ExprRef;
85
86 fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
87 match node.as_opt::<NotVTable>() {
88 None => Ok(Transformed::no(node)),
89 Some(not_expr) => {
90 let child = not_expr.child();
91 if let Some(binary_expr) = child.as_opt::<BinaryVTable>() {
92 let new_op = match binary_expr.op() {
93 Operator::Eq => Operator::NotEq,
94 Operator::NotEq => Operator::Eq,
95 Operator::Gt => Operator::Lte,
96 Operator::Gte => Operator::Lt,
97 Operator::Lt => Operator::Gte,
98 Operator::Lte => Operator::Gt,
99 Operator::And => Operator::Or,
100 Operator::Or => Operator::And,
101 Operator::Add => {
102 vortex_bail!("nnf: type mismatch: cannot negate addition")
103 }
104 };
105 let (lhs, rhs) = match binary_expr.op() {
106 Operator::Or | Operator::And => (
107 not(binary_expr.lhs().clone()),
108 not(binary_expr.rhs().clone()),
109 ),
110 _ => (binary_expr.lhs().clone(), binary_expr.rhs().clone()),
111 };
112
113 Ok(Transformed::yes(
114 BinaryExpr::new(lhs, new_op, rhs).into_expr(),
115 ))
116 } else if let Some(inner_not_expr) = child.as_opt::<NotVTable>() {
117 Ok(Transformed::yes(inner_not_expr.child().clone()))
118 } else {
119 Ok(Transformed::no(node))
120 }
121 }
122 }
123 }
124}
125
126struct NNFValidationVisitor {
127 is_nnf: bool,
128}
129
130impl Default for NNFValidationVisitor {
131 fn default() -> Self {
132 Self { is_nnf: true }
133 }
134}
135
136impl<'a> NodeVisitor<'a> for NNFValidationVisitor {
137 type NodeTy = ExprRef;
138
139 fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
140 fn expr_is_var(expr: &ExprRef) -> bool {
141 expr.is::<LiteralVTable>() || expr.is::<GetItemVTable>()
142 }
143
144 if let Some(not_expr) = node.as_opt::<NotVTable>() {
145 let is_var = expr_is_var(not_expr.child());
146 self.is_nnf &= is_var;
147 if !is_var {
148 return Ok(TraversalOrder::Stop);
149 }
150 }
151
152 Ok(TraversalOrder::Continue)
153 }
154}
155
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{and, col, gt_eq, lit, lt, or};

    #[rstest]
    #[case::negated_literal_conjunction(
        and(not(and(lit(true), lit(true))), and(lit(true), lit(true))),
        and(or(not(lit(true)), not(lit(true))), and(lit(true), lit(true)))
    )]
    #[case::de_morgan_over_columns(
        not(and(col("a"), col("b"))),
        or(not(col("a")), not(col("b")))
    )]
    #[case::complemented_comparison(
        not(and(gt_eq(col("a"), lit(3)), col("b"))),
        or(lt(col("a"), lit(3)), not(col("b")))
    )]
    #[case::double_negation(not(not(col("a"))), col("a"))]
    #[case::triple_negation(not(not(not(col("a")))), not(col("a")))]
    #[case::double_negation_after_de_morgan(
        not(and(not(gt_eq(col("a"), lit(3))), col("b"))),
        or(gt_eq(col("a"), lit(3)), not(col("b")))
    )]
    #[case::nested_conjunctions(
        not(and(and(not(gt_eq(col("a"), lit(3))), col("c")), col("b"))),
        or(or(gt_eq(col("a"), lit(3)), not(col("c"))), not(col("b")))
    )]
    fn basic_nnf_test(#[case] input: ExprRef, #[case] expected: ExprRef) {
        let output = nnf(input.clone());
        let normalized = is_nnf(&output);

        assert_eq!(
            &output, &expected,
            "\nOriginal expr: {input}\nRewritten expr: {output}\nexpected expr: {expected}"
        );
        assert!(normalized);
    }

    #[rstest]
    #[case::double_negation_is_rejected(not(not(col("a"))), false)]
    #[case::triple_negation_is_rejected(not(not(not(col("a")))), false)]
    #[case::negated_column_is_accepted(not(col("a")), true)]
    #[case::bare_column_is_accepted(col("a"), true)]
    fn test_nnf_validation(#[case] expr: ExprRef, #[case] valid: bool) {
        let verdict = is_nnf(&expr);
        assert_eq!(verdict, valid);
    }
}