use std::any::Any;
use std::fmt::{Debug, Display};
use std::sync::Arc;

use vortex_array::compute::{BetweenOptions, between};
use vortex_array::{Array, ArrayRef};
use vortex_dtype::DType;
use vortex_dtype::DType::Bool;
use vortex_error::{VortexResult, vortex_bail};
10
11use crate::{BinaryExpr, ExprRef, VortexExpr};
12
13#[derive(Debug, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct Between {
16 arr: ExprRef,
17 lower: ExprRef,
18 upper: ExprRef,
19 options: BetweenOptions,
20}
21
22impl Between {
23 pub fn between(
24 arr: ExprRef,
25 lower: ExprRef,
26 upper: ExprRef,
27 options: BetweenOptions,
28 ) -> ExprRef {
29 Arc::new(Self {
30 arr,
31 lower,
32 upper,
33 options,
34 })
35 }
36
37 pub fn to_binary_expr(&self) -> ExprRef {
38 let lhs = BinaryExpr::new_expr(
39 self.lower.clone(),
40 self.options.lower_strict.to_operator().into(),
41 self.arr.clone(),
42 );
43 let rhs = BinaryExpr::new_expr(
44 self.arr.clone(),
45 self.options.upper_strict.to_operator().into(),
46 self.upper.clone(),
47 );
48 BinaryExpr::new_expr(lhs, crate::Operator::And, rhs)
49 }
50}
51
52impl Display for Between {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(
55 f,
56 "({} {} {} {} {})",
57 self.lower,
58 self.options.lower_strict.to_operator(),
59 self.arr,
60 self.options.upper_strict.to_operator(),
61 self.upper
62 )
63 }
64}
65
66impl PartialEq for Between {
67 fn eq(&self, other: &Between) -> bool {
68 self.arr.eq(&other.arr)
69 && other.lower.eq(&self.lower)
70 && other.upper.eq(&self.upper)
71 && self.options == other.options
72 }
73}
74
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::between::Between;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serialization handle for [`Between`] expressions, identified as `"between"`.
    pub(crate) struct BetweenSerde;

    impl Id for BetweenSerde {
        fn id(&self) -> &'static str {
            "between"
        }
    }

    impl ExprSerializable for Between {
        fn id(&self) -> &'static str {
            BetweenSerde.id()
        }

        /// Encodes only the two strictness flags. The three operand expressions are not
        /// part of the kind payload.
        /// NOTE(review): they are presumably carried as children in
        /// `[arr, lower, upper]` order — `deserialize` below expects that order.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::Between(kind::Between {
                lower_strict: self.options.lower_strict == StrictComparison::Strict,
                upper_strict: self.options.upper_strict == StrictComparison::Strict,
            }))
        }
    }

    /// Maps a serialized strictness flag back to a [`StrictComparison`].
    fn strict_from_flag(strict: bool) -> StrictComparison {
        if strict {
            StrictComparison::Strict
        } else {
            StrictComparison::NonStrict
        }
    }

    impl ExprDeserialize for BetweenSerde {
        /// Rebuilds a [`Between`] from its kind payload and exactly three children,
        /// ordered `[arr, lower, upper]`.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is not `Kind::Between`, or if `children` does not contain
        /// exactly three expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Between(between) = kind else {
                vortex_bail!("wrong kind {:?}, want between", kind)
            };

            // Serialized input may be malformed. Report a bad child count as an error
            // rather than panicking on an out-of-bounds index; moving the children out
            // also avoids cloning each `Arc`.
            let [arr, lower, upper]: [ExprRef; 3] = match children.try_into() {
                Ok(children) => children,
                Err(children) => vortex_bail!(
                    "between expects 3 children (arr, lower, upper), got {}",
                    children.len()
                ),
            };

            Ok(Between::between(
                arr,
                lower,
                upper,
                BetweenOptions {
                    lower_strict: strict_from_flag(between.lower_strict),
                    upper_strict: strict_from_flag(between.upper_strict),
                },
            ))
        }
    }
}
132
133impl VortexExpr for Between {
134 fn as_any(&self) -> &dyn Any {
135 self
136 }
137
138 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
139 let arr_val = self.arr.evaluate(batch)?;
140 let lower_arr_val = self.lower.evaluate(batch)?;
141 let upper_arr_val = self.upper.evaluate(batch)?;
142
143 between(&arr_val, &lower_arr_val, &upper_arr_val, &self.options)
144 }
145
146 fn children(&self) -> Vec<&ExprRef> {
147 vec![&self.arr, &self.lower, &self.upper]
148 }
149
150 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
151 Arc::new(Self {
152 arr: children[0].clone(),
153 lower: children[1].clone(),
154 upper: children[2].clone(),
155 options: self.options.clone(),
156 })
157 }
158
159 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
160 let arr_dt = self.arr.return_dtype(scope_dtype)?;
161 let lower_dt = self.lower.return_dtype(scope_dtype)?;
162 let upper_dt = self.upper.return_dtype(scope_dtype)?;
163
164 assert!(arr_dt.eq_ignore_nullability(&lower_dt));
165 assert!(arr_dt.eq_ignore_nullability(&upper_dt));
166
167 Ok(Bool(
168 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
169 ))
170 }
171}