1use rust_decimal::Decimal;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
6pub struct WindowSpec {
7 pub duration: u64,
8 pub unit: String,
9}
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum Expression {
13 Literal(serde_json::Value),
14
15 QuantityLiteral {
16 value: Decimal,
17 unit: String,
18 },
19
20 TimeLiteral(String), IntervalLiteral {
23 start: String, end: String, },
26
27 Variable(String),
28
29 GroupBy {
30 variable: String,
31 collection: Box<Expression>,
32 filter: Option<Box<Expression>>,
33 key: Box<Expression>,
34 condition: Box<Expression>,
35 },
36
37 Binary {
38 op: BinaryOp,
39 left: Box<Expression>,
40 right: Box<Expression>,
41 },
42
43 Unary {
44 op: UnaryOp,
45 operand: Box<Expression>,
46 },
47
48 Cast {
49 operand: Box<Expression>,
50 target_type: String,
51 },
52
53 Quantifier {
54 quantifier: Quantifier,
55 variable: String,
56 collection: Box<Expression>,
57 condition: Box<Expression>,
58 },
59
60 MemberAccess {
61 object: String,
62 member: String,
63 },
64
65 Aggregation {
66 function: AggregateFunction,
67 collection: Box<Expression>,
68 field: Option<String>,
69 filter: Option<Box<Expression>>,
70 },
71
72 AggregationComprehension {
73 function: AggregateFunction,
74 variable: String,
75 collection: Box<Expression>,
76 window: Option<WindowSpec>,
77 predicate: Box<Expression>,
78 projection: Box<Expression>,
79 target_unit: Option<String>,
80 },
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum BinaryOp {
85 And,
86 Or,
87
88 Equal,
89 NotEqual,
90 GreaterThan,
91 LessThan,
92 GreaterThanOrEqual,
93 LessThanOrEqual,
94
95 Plus,
96 Minus,
97 Multiply,
98 Divide,
99
100 Contains,
101 StartsWith,
102 EndsWith,
103 Matches,
104
105 HasRole,
106
107 Before,
109 After,
110 During,
111}
112
113#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
114pub enum UnaryOp {
115 Not,
116 Negate,
117}
118
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub enum Quantifier {
121 ForAll,
122 Exists,
123 ExistsUnique,
124}
125
126#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
127pub enum AggregateFunction {
128 Count,
129 Sum,
130 Min,
131 Max,
132 Avg,
133}
134
135impl Expression {
136 #[must_use]
146 pub fn normalize(&self) -> super::NormalizedExpression {
147 super::NormalizedExpression::new(self)
148 }
149
150 #[must_use]
154 pub fn is_equivalent(&self, other: &Expression) -> bool {
155 self.normalize() == other.normalize()
156 }
157
158 pub fn literal(value: impl Into<serde_json::Value>) -> Self {
159 Expression::Literal(value.into())
160 }
161
162 pub fn variable(name: &str) -> Self {
163 Expression::Variable(name.to_string())
164 }
165
166 pub fn binary(op: BinaryOp, left: Expression, right: Expression) -> Self {
167 Expression::Binary {
168 op,
169 left: Box::new(left),
170 right: Box::new(right),
171 }
172 }
173
174 pub fn unary(op: UnaryOp, operand: Expression) -> Self {
175 Expression::Unary {
176 op,
177 operand: Box::new(operand),
178 }
179 }
180
181 pub fn quantifier(
182 q: Quantifier,
183 var: &str,
184 collection: Expression,
185 condition: Expression,
186 ) -> Self {
187 Expression::Quantifier {
188 quantifier: q,
189 variable: var.to_string(),
190 collection: Box::new(collection),
191 condition: Box::new(condition),
192 }
193 }
194
195 pub fn cast(operand: Expression, target_type: impl Into<String>) -> Self {
196 Expression::Cast {
197 operand: Box::new(operand),
198 target_type: target_type.into(),
199 }
200 }
201
202 pub fn comparison(
203 var: &str,
204 op: &str,
205 value: impl Into<serde_json::Value>,
206 ) -> Result<Self, String> {
207 let op = match op {
208 ">" => BinaryOp::GreaterThan,
209 "<" => BinaryOp::LessThan,
210 ">=" => BinaryOp::GreaterThanOrEqual,
211 "<=" => BinaryOp::LessThanOrEqual,
212 "==" => BinaryOp::Equal,
213 "!=" => BinaryOp::NotEqual,
214 _ => return Err(format!("Unknown operator: {}", op)),
215 };
216
217 Ok(Expression::binary(
218 op,
219 Expression::variable(var),
220 Expression::literal(value),
221 ))
222 }
223
224 pub fn aggregation(
225 function: AggregateFunction,
226 collection: Expression,
227 field: Option<impl Into<String>>,
228 filter: Option<Expression>,
229 ) -> Self {
230 Expression::Aggregation {
231 function,
232 collection: Box::new(collection),
233 field: field.map(|f| f.into()),
234 filter: filter.map(Box::new),
235 }
236 }
237
238 pub fn member_access(object: &str, member: &str) -> Self {
239 Expression::MemberAccess {
240 object: object.to_string(),
241 member: member.to_string(),
242 }
243 }
244}
245
246impl fmt::Display for Expression {
247 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248 match self {
249 Expression::Literal(v) => write!(f, "{}", v),
250 Expression::QuantityLiteral { value, unit } => {
251 write!(f, "{} {}", value, unit)
252 }
253 Expression::TimeLiteral(timestamp) => write!(f, "\"{}\"", timestamp),
254 Expression::IntervalLiteral { start, end } => {
255 write!(f, "interval(\"{}\", \"{}\")", start, end)
256 }
257 Expression::Variable(n) => write!(f, "{}", n),
258 Expression::GroupBy {
259 variable,
260 collection,
261 filter,
262 key,
263 condition,
264 } => {
265 write!(f, "group_by({} in {}", variable, collection)?;
266 if let Some(flt) = filter {
267 write!(f, " WHERE {}", flt)?;
268 }
269 write!(f, ": {}) {{ {} }}", key, condition)
270 }
271 Expression::Binary { op, left, right } => {
272 write!(f, "({} {} {})", left, op, right)
273 }
274 Expression::Unary { op, operand } => {
275 write!(f, "{} {}", op, operand)
276 }
277 Expression::Cast {
278 operand,
279 target_type,
280 } => {
281 write!(f, "{} as \"{}\"", operand, target_type)
282 }
283 Expression::Quantifier {
284 quantifier,
285 variable,
286 collection,
287 condition,
288 } => {
289 write!(
290 f,
291 "{}({} in {}: {})",
292 quantifier, variable, collection, condition
293 )
294 }
295 Expression::MemberAccess { object, member } => {
296 write!(f, "{}.{}", object, member)
297 }
298 Expression::Aggregation {
299 function,
300 collection,
301 field,
302 filter,
303 } => {
304 write!(f, "{}({}", function, collection)?;
305 if let Some(fld) = field {
306 write!(f, ".{}", fld)?;
307 }
308 if let Some(flt) = filter {
309 write!(f, " WHERE {}", flt)?;
310 }
311 write!(f, ")")
312 }
313 Expression::AggregationComprehension {
314 function,
315 variable,
316 collection,
317 window,
318 predicate,
319 projection,
320 target_unit,
321 } => {
322 write!(f, "{}({} in {}", function, variable, collection)?;
323 if let Some(w) = window {
324 write!(f, " OVER LAST {} \"{}\"", w.duration, w.unit)?;
325 }
326 match predicate.as_ref() {
327 Expression::Literal(serde_json::Value::Bool(true)) => {
328 write!(f, ": {}", projection)?;
329 }
330 _ => {
331 write!(f, " WHERE {}: {}", predicate, projection)?;
332 }
333 }
334 if let Some(unit) = target_unit {
335 write!(f, " AS \"{}\"", unit)?;
336 }
337 write!(f, ")")
338 }
339 }
340 }
341}
342
343impl fmt::Display for BinaryOp {
344 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
345 match self {
346 BinaryOp::And => write!(f, "AND"),
347 BinaryOp::Or => write!(f, "OR"),
348 BinaryOp::Equal => write!(f, "=="),
349 BinaryOp::NotEqual => write!(f, "!="),
350 BinaryOp::GreaterThan => write!(f, ">"),
351 BinaryOp::LessThan => write!(f, "<"),
352 BinaryOp::GreaterThanOrEqual => write!(f, ">="),
353 BinaryOp::LessThanOrEqual => write!(f, "<="),
354 BinaryOp::Plus => write!(f, "+"),
355 BinaryOp::Minus => write!(f, "-"),
356 BinaryOp::Multiply => write!(f, "*"),
357 BinaryOp::Divide => write!(f, "/"),
358 BinaryOp::Contains => write!(f, "CONTAINS"),
359 BinaryOp::StartsWith => write!(f, "STARTS_WITH"),
360 BinaryOp::EndsWith => write!(f, "ENDS_WITH"),
361 BinaryOp::Matches => write!(f, "MATCHES"),
362 BinaryOp::HasRole => write!(f, "HAS_ROLE"),
363 BinaryOp::Before => write!(f, "BEFORE"),
364 BinaryOp::After => write!(f, "AFTER"),
365 BinaryOp::During => write!(f, "DURING"),
366 }
367 }
368}
369
370impl fmt::Display for UnaryOp {
371 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
372 match self {
373 UnaryOp::Not => write!(f, "NOT"),
374 UnaryOp::Negate => write!(f, "-"),
375 }
376 }
377}
378
379impl fmt::Display for Quantifier {
380 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
381 match self {
382 Quantifier::ForAll => write!(f, "ForAll"),
383 Quantifier::Exists => write!(f, "Exists"),
384 Quantifier::ExistsUnique => write!(f, "ExistsUnique"),
385 }
386 }
387}
388
389impl fmt::Display for AggregateFunction {
390 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
391 match self {
392 AggregateFunction::Count => write!(f, "COUNT"),
393 AggregateFunction::Sum => write!(f, "SUM"),
394 AggregateFunction::Min => write!(f, "MIN"),
395 AggregateFunction::Max => write!(f, "MAX"),
396 AggregateFunction::Avg => write!(f, "AVG"),
397 }
398 }
399}