1use std::fmt;
2
3use super::{Operand, PredicateVisitor, ScalarValue, VisitOutcome};
4
/// Binary comparison operator used by [`PredicateNode::Compare`].
///
/// The comment on each variant gives the symbol its `Display` impl writes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ComparisonOp {
    /// `=`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
}
21
22impl ComparisonOp {
23 #[must_use]
25 pub fn flipped(self) -> Self {
26 match self {
27 ComparisonOp::Equal => ComparisonOp::Equal,
28 ComparisonOp::NotEqual => ComparisonOp::NotEqual,
29 ComparisonOp::LessThan => ComparisonOp::GreaterThan,
30 ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThanOrEqual,
31 ComparisonOp::GreaterThan => ComparisonOp::LessThan,
32 ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThanOrEqual,
33 }
34 }
35
36 #[must_use]
38 fn negated(self) -> Self {
39 match self {
40 ComparisonOp::Equal => ComparisonOp::NotEqual,
41 ComparisonOp::NotEqual => ComparisonOp::Equal,
42 ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual,
43 ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan,
44 ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual,
45 ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan,
46 }
47 }
48}
49
50impl fmt::Display for ComparisonOp {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.write_str(match self {
53 ComparisonOp::Equal => "=",
54 ComparisonOp::NotEqual => "!=",
55 ComparisonOp::LessThan => "<",
56 ComparisonOp::LessThanOrEqual => "<=",
57 ComparisonOp::GreaterThan => ">",
58 ComparisonOp::GreaterThanOrEqual => ">=",
59 })
60 }
61}
62
/// The node kinds a [`Predicate`] tree is built from.
///
/// `True`, `Compare`, `InList` and `IsNull` are leaves. `Not`, `And` and `Or` own child
/// predicates.
#[derive(Clone, Debug, PartialEq)]
pub enum PredicateNode {
    /// Literal true leaf. There is no `False` variant; negating this node produces
    /// `Not(True)`.
    True,
    /// Binary comparison `left op right`.
    Compare {
        /// Left-hand operand.
        left: Operand,
        /// Operator applied between `left` and `right`.
        op: ComparisonOp,
        /// Right-hand operand.
        right: Operand,
    },
    /// Membership test of `expr` against the values in `list`.
    InList {
        /// Operand being tested.
        expr: Operand,
        /// Candidate values.
        list: Vec<ScalarValue>,
        /// Negation flips this flag. Presumably `true` means `NOT IN`; confirm against
        /// the evaluator.
        negated: bool,
    },
    /// Null check on `expr`.
    IsNull {
        /// Operand being checked.
        expr: Operand,
        /// Negation flips this flag. Presumably `true` means `IS NOT NULL`; confirm
        /// against the evaluator.
        negated: bool,
    },
    /// Logical negation of the boxed child.
    Not(Box<Predicate>),
    /// Conjunction of the children. `Predicate::and` always builds this with at least two
    /// clauses. Nodes wrapped directly (e.g. via `Predicate::from_node`) are not checked.
    And(Vec<Predicate>),
    /// Disjunction of the children. `Predicate::or` always builds this with at least two
    /// clauses. Nodes wrapped directly (e.g. via `Predicate::from_node`) are not checked.
    Or(Vec<Predicate>),
}
100
101impl PredicateNode {
102 #[must_use]
104 pub(crate) fn is_leaf(&self) -> bool {
105 matches!(
106 self,
107 PredicateNode::True
108 | PredicateNode::Compare { .. }
109 | PredicateNode::InList { .. }
110 | PredicateNode::IsNull { .. }
111 )
112 }
113}
114
/// A boolean predicate: an owned tree of [`PredicateNode`]s.
///
/// The root node is private. Read it with `kind()` and build predicates with
/// `from_node`, `and` or `or`.
#[derive(Clone, Debug, PartialEq)]
pub struct Predicate {
    /// Root node of the tree.
    kind: PredicateNode,
}
120
121impl Predicate {
122 #[must_use]
124 pub fn kind(&self) -> &PredicateNode {
125 &self.kind
126 }
127
128 #[must_use]
134 pub fn and<I>(clauses: I) -> Self
135 where
136 I: IntoIterator<Item = Predicate>,
137 {
138 let mut acc = Vec::new();
139 for clause in clauses {
140 match clause.into_kind() {
141 PredicateNode::And(mut nested) => acc.append(&mut nested),
142 other => acc.push(Predicate::from_kind(other)),
143 }
144 }
145
146 assert!(
147 !acc.is_empty(),
148 "Predicate::and requires at least one clause"
149 );
150
151 if acc.len() == 1 {
152 acc.pop().expect("length checked")
153 } else {
154 Self::from_kind(PredicateNode::And(acc))
155 }
156 }
157
158 #[must_use]
164 pub fn or<I>(clauses: I) -> Self
165 where
166 I: IntoIterator<Item = Predicate>,
167 {
168 let mut acc = Vec::new();
169 for clause in clauses {
170 match clause.into_kind() {
171 PredicateNode::Or(mut nested) => acc.append(&mut nested),
172 other => acc.push(Predicate::from_kind(other)),
173 }
174 }
175
176 assert!(
177 !acc.is_empty(),
178 "Predicate::or requires at least one clause"
179 );
180
181 if acc.len() == 1 {
182 acc.pop().expect("length checked")
183 } else {
184 Self::from_kind(PredicateNode::Or(acc))
185 }
186 }
187
188 #[must_use]
190 pub fn simplify(self) -> Self {
191 match self.kind {
192 PredicateNode::True
193 | PredicateNode::Compare { .. }
194 | PredicateNode::InList { .. }
195 | PredicateNode::IsNull { .. } => self,
196 PredicateNode::Not(inner) => {
197 let simplified_child = inner.simplify();
198 match simplified_child.into_kind() {
199 PredicateNode::Not(grandchild) => *grandchild,
200 other => Self::from_kind(PredicateNode::Not(Box::new(Self::from_kind(other)))),
201 }
202 }
203 PredicateNode::And(clauses) => {
204 Predicate::and(clauses.into_iter().map(Predicate::simplify))
205 }
206 PredicateNode::Or(clauses) => {
207 Predicate::or(clauses.into_iter().map(Predicate::simplify))
208 }
209 }
210 }
211
212 #[must_use]
214 pub fn negate(self) -> Self {
215 let negated = match self.kind {
216 PredicateNode::True
217 | PredicateNode::Compare { .. }
218 | PredicateNode::InList { .. }
219 | PredicateNode::IsNull { .. } => Predicate::negate_leaf(self.into_kind()),
220 PredicateNode::Not(inner) => *inner,
221 PredicateNode::And(children) => {
222 let negated_children: Vec<_> =
223 children.into_iter().map(Predicate::negate).collect();
224 Predicate::or(negated_children)
225 }
226 PredicateNode::Or(children) => {
227 let negated_children: Vec<_> =
228 children.into_iter().map(Predicate::negate).collect();
229 Predicate::and(negated_children)
230 }
231 };
232 negated.simplify()
233 }
234
235 #[must_use]
237 pub fn conjunction(predicates: Vec<Predicate>) -> Option<Predicate> {
238 match predicates.len() {
239 0 => None,
240 1 => predicates.into_iter().next(),
241 _ => Some(Predicate::and(predicates).simplify()),
242 }
243 }
244
245 #[must_use]
247 pub fn disjunction(predicates: Vec<Predicate>) -> Option<Predicate> {
248 match predicates.len() {
249 0 => None,
250 1 => predicates.into_iter().next(),
251 _ => Some(Predicate::or(predicates).simplify()),
252 }
253 }
254
255 #[must_use]
257 pub fn from_node(node: PredicateNode) -> Self {
258 Self::from_kind(node)
259 }
260
261 pub fn accept<V>(&self, visitor: &mut V) -> Result<VisitOutcome<V::Value>, V::Error>
263 where
264 V: PredicateVisitor + ?Sized,
265 {
266 visitor.visit_predicate(self)
267 }
268
269 pub(crate) fn from_kind(kind: PredicateNode) -> Self {
270 Self { kind }
271 }
272
273 fn into_kind(self) -> PredicateNode {
274 self.kind
275 }
276
277 fn negate_leaf(leaf: PredicateNode) -> Predicate {
278 let negated = match leaf {
279 PredicateNode::True => {
280 return Predicate::from_kind(PredicateNode::Not(Box::new(Predicate::from_kind(
282 PredicateNode::True,
283 ))));
284 }
285 PredicateNode::Compare { left, op, right } => PredicateNode::Compare {
286 left,
287 op: op.negated(),
288 right,
289 },
290 PredicateNode::InList {
291 expr,
292 list,
293 negated,
294 } => PredicateNode::InList {
295 expr,
296 list,
297 negated: !negated,
298 },
299 PredicateNode::IsNull { expr, negated } => PredicateNode::IsNull {
300 expr,
301 negated: !negated,
302 },
303 PredicateNode::Not(_) | PredicateNode::And(_) | PredicateNode::Or(_) => {
304 unreachable!("negate_leaf only handles leaf variants")
305 }
306 };
307 Predicate::from_kind(negated)
308 }
309}