1use itertools::Itertools;
2use vortex_error::{VortexExpect, VortexResult};
3
4use super::nnf::nnf;
5use crate::traversal::{Node as _, NodeVisitor, TraversalOrder};
6use crate::{BinaryExpr, ExprRef, Operator, lit, or};
7
8pub fn cnf(expr: ExprRef) -> Vec<ExprRef> {
50 if expr == lit(true) {
51 return vec![];
53 }
54 let nnf = nnf(expr);
55
56 let mut visitor = CNFVisitor::default();
57 nnf.accept(&mut visitor).vortex_expect("cannot fail");
58 visitor
59 .finish()
60 .into_iter()
61 .filter_map(|disjunction| disjunction.into_iter().reduce(or))
62 .collect_vec()
63}
64
65#[derive(Default)]
66struct CNFVisitor {
67 conjuncts_of_disjuncts: Vec<Vec<ExprRef>>,
68}
69
70impl CNFVisitor {
71 fn finish(self) -> Vec<Vec<ExprRef>> {
72 self.conjuncts_of_disjuncts
73 }
74}
75
76impl NodeVisitor<'_> for CNFVisitor {
77 type NodeTy = ExprRef;
78
79 fn visit_down(&mut self, node: &ExprRef) -> VortexResult<TraversalOrder> {
80 if let Some(binary_expr) = node.as_any().downcast_ref::<BinaryExpr>() {
81 match binary_expr.op() {
82 Operator::And => return Ok(TraversalOrder::Continue),
83 Operator::Or => {
84 let mut visitor = CNFVisitor::default();
85 binary_expr.lhs().accept(&mut visitor)?;
86 let lhs_conjuncts = visitor.finish();
87
88 let mut visitor = CNFVisitor::default();
89 binary_expr.rhs().accept(&mut visitor)?;
90 let rhs_conjuncts = visitor.finish();
91
92 self.conjuncts_of_disjuncts
93 .extend(lhs_conjuncts.iter().flat_map(|lhs_disjunct| {
94 rhs_conjuncts.iter().map(|rhs_disjunct| {
95 let mut lhs_copy = lhs_disjunct.clone();
96 lhs_copy.extend(rhs_disjunct.iter().cloned());
97 lhs_copy
98 })
99 }));
100
101 return Ok(TraversalOrder::Skip);
102 }
103 _ => {}
104 }
105 }
106 self.conjuncts_of_disjuncts.push(vec![node.clone()]);
108 Ok(TraversalOrder::Skip)
109 }
110}
111
#[cfg(test)]
mod tests {

    use vortex_expr::forms::cnf::cnf;
    use vortex_expr::{and, col, eq, gt_eq, lit, lt, not_eq, or};

    /// `((a AND b) OR c) OR d` distributes into two clauses, one per conjunct.
    #[test]
    fn test_cnf_simple() {
        let input = or(or(and(col("a"), col("b")), col("c")), col("d"));

        let expected = vec![
            or(or(col("a"), col("c")), col("d")),
            or(or(col("b"), col("c")), col("d")),
        ];

        assert_eq!(cnf(input), expected);
    }

    /// Comparisons against literals are kept intact as terms of the resulting clauses.
    #[test]
    fn test_with_lit() {
        let high_earner = || gt_eq(col("earnings"), lit(50_000));
        let not_manager = || not_eq(col("role"), lit("Manager"));
        let flagged = || col("special_flag");

        let input = or(and(high_earner(), not_manager()), flagged());

        let expected = vec![
            or(high_earner(), flagged()),
            or(not_manager(), flagged()),
        ];

        assert_eq!(cnf(input), expected);
    }

    /// Conjunctions on both sides of an `or` produce the full cross product of clauses,
    /// ordered by the left-hand clause first.
    #[test]
    fn test_cnf() {
        let high_earner = || gt_eq(col("earnings"), lit(50_000));
        let not_manager = || not_eq(col("role"), lit("Manager"));
        let flagged = || col("special_flag");
        let short_tenure = || lt(col("tenure"), lit(5));
        let engineer = || eq(col("role"), lit("Engineer"));

        let input = or(
            or(and(high_earner(), not_manager()), flagged()),
            and(short_tenure(), engineer()),
        );

        let expected = vec![
            or(or(high_earner(), flagged()), short_tenure()),
            or(or(high_earner(), flagged()), engineer()),
            or(or(not_manager(), flagged()), short_tenure()),
            or(or(not_manager(), flagged()), engineer()),
        ];

        assert_eq!(cnf(input), expected);
    }
}