1use itertools::Itertools;
5use vortex_error::{VortexExpect, VortexResult};
6
7use super::nnf::nnf;
8use crate::traversal::{Node as _, NodeVisitor, TraversalOrder};
9use crate::{BinaryVTable, ExprRef, Operator, lit, or};
10
11pub fn cnf(expr: ExprRef) -> Vec<ExprRef> {
53 if expr == lit(true) {
54 return vec![];
56 }
57 let nnf = nnf(expr);
58
59 let mut visitor = CNFVisitor::default();
60 nnf.accept(&mut visitor).vortex_expect("cannot fail");
61 visitor
62 .finish()
63 .into_iter()
64 .filter_map(|disjunction| disjunction.into_iter().reduce(or))
65 .collect_vec()
66}
67
68#[derive(Default)]
69struct CNFVisitor {
70 conjuncts_of_disjuncts: Vec<Vec<ExprRef>>,
71}
72
73impl CNFVisitor {
74 fn finish(self) -> Vec<Vec<ExprRef>> {
75 self.conjuncts_of_disjuncts
76 }
77}
78
79impl NodeVisitor<'_> for CNFVisitor {
80 type NodeTy = ExprRef;
81
82 fn visit_down(&mut self, node: &ExprRef) -> VortexResult<TraversalOrder> {
83 if let Some(binary_expr) = node.as_opt::<BinaryVTable>() {
84 match binary_expr.op() {
85 Operator::And => return Ok(TraversalOrder::Continue),
86 Operator::Or => {
87 let mut visitor = CNFVisitor::default();
88 binary_expr.lhs().accept(&mut visitor)?;
89 let lhs_conjuncts = visitor.finish();
90
91 let mut visitor = CNFVisitor::default();
92 binary_expr.rhs().accept(&mut visitor)?;
93 let rhs_conjuncts = visitor.finish();
94
95 self.conjuncts_of_disjuncts
96 .extend(lhs_conjuncts.iter().flat_map(|lhs_disjunct| {
97 rhs_conjuncts.iter().map(|rhs_disjunct| {
98 let mut lhs_copy = lhs_disjunct.clone();
99 lhs_copy.extend(rhs_disjunct.iter().cloned());
100 lhs_copy
101 })
102 }));
103
104 return Ok(TraversalOrder::Skip);
105 }
106 _ => {}
107 }
108 }
109 self.conjuncts_of_disjuncts.push(vec![node.clone()]);
111 Ok(TraversalOrder::Skip)
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::cnf;
    use crate::{and, col, eq, gt_eq, lit, lt, not_eq, or};

    #[test]
    fn test_cnf_true_is_empty_conjunction() {
        // The empty conjunction is `true`, so the literal `true` yields no clauses.
        assert!(cnf(lit(true)).is_empty());
    }

    #[test]
    fn test_cnf_single_atom() {
        assert_eq!(cnf(col("a")), vec![col("a")]);
    }

    #[test]
    fn test_cnf_flattens_nested_and() {
        assert_eq!(
            cnf(and(and(col("a"), col("b")), col("c"))),
            vec![col("a"), col("b"), col("c")]
        );
    }

    #[test]
    fn test_cnf_simple() {
        assert_eq!(
            cnf(or(or(and(col("a"), col("b")), col("c")), col("d"))),
            vec![
                or(or(col("a"), col("c")), col("d")),
                or(or(col("b"), col("c")), col("d"))
            ]
        );
    }

    #[test]
    fn test_with_lit() {
        assert_eq!(
            cnf(or(
                and(
                    gt_eq(col("earnings"), lit(50_000)),
                    not_eq(col("role"), lit("Manager"))
                ),
                col("special_flag")
            )),
            vec![
                or(gt_eq(col("earnings"), lit(50_000)), col("special_flag")),
                or(not_eq(col("role"), lit("Manager")), col("special_flag"))
            ]
        );
    }

    #[test]
    fn test_cnf() {
        assert_eq!(
            cnf(or(
                or(
                    and(
                        gt_eq(col("earnings"), lit(50_000)),
                        not_eq(col("role"), lit("Manager"))
                    ),
                    col("special_flag")
                ),
                and(lt(col("tenure"), lit(5)), eq(col("role"), lit("Engineer"))),
            )),
            vec![
                or(
                    or(gt_eq(col("earnings"), lit(50_000)), col("special_flag")),
                    lt(col("tenure"), lit(5))
                ),
                or(
                    or(gt_eq(col("earnings"), lit(50_000)), col("special_flag")),
                    eq(col("role"), lit("Engineer"))
                ),
                or(
                    or(not_eq(col("role"), lit("Manager")), col("special_flag")),
                    lt(col("tenure"), lit(5))
                ),
                or(
                    or(not_eq(col("role"), lit("Manager")), col("special_flag")),
                    eq(col("role"), lit("Engineer"))
                )
            ]
        );
    }
}