// sea_core/policy/expression.rs

1use rust_decimal::Decimal;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
6pub struct WindowSpec {
7    pub duration: u64,
8    pub unit: String,
9}
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum Expression {
13    Literal(serde_json::Value),
14
15    QuantityLiteral {
16        value: Decimal,
17        unit: String,
18    },
19
20    TimeLiteral(String), // ISO 8601 timestamp
21
22    IntervalLiteral {
23        start: String, // Time string (e.g., "09:00")
24        end: String,   // Time string (e.g., "17:00")
25    },
26
27    Variable(String),
28
29    GroupBy {
30        variable: String,
31        collection: Box<Expression>,
32        filter: Option<Box<Expression>>,
33        key: Box<Expression>,
34        condition: Box<Expression>,
35    },
36
37    Binary {
38        op: BinaryOp,
39        left: Box<Expression>,
40        right: Box<Expression>,
41    },
42
43    Unary {
44        op: UnaryOp,
45        operand: Box<Expression>,
46    },
47
48    Cast {
49        operand: Box<Expression>,
50        target_type: String,
51    },
52
53    Quantifier {
54        quantifier: Quantifier,
55        variable: String,
56        collection: Box<Expression>,
57        condition: Box<Expression>,
58    },
59
60    MemberAccess {
61        object: String,
62        member: String,
63    },
64
65    Aggregation {
66        function: AggregateFunction,
67        collection: Box<Expression>,
68        field: Option<String>,
69        filter: Option<Box<Expression>>,
70    },
71
72    AggregationComprehension {
73        function: AggregateFunction,
74        variable: String,
75        collection: Box<Expression>,
76        window: Option<WindowSpec>,
77        predicate: Box<Expression>,
78        projection: Box<Expression>,
79        target_unit: Option<String>,
80    },
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum BinaryOp {
85    And,
86    Or,
87
88    Equal,
89    NotEqual,
90    GreaterThan,
91    LessThan,
92    GreaterThanOrEqual,
93    LessThanOrEqual,
94
95    Plus,
96    Minus,
97    Multiply,
98    Divide,
99
100    Contains,
101    StartsWith,
102    EndsWith,
103    Matches,
104
105    HasRole,
106
107    // Temporal operators
108    Before,
109    After,
110    During,
111}
112
113#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
114pub enum UnaryOp {
115    Not,
116    Negate,
117}
118
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub enum Quantifier {
121    ForAll,
122    Exists,
123    ExistsUnique,
124}
125
126#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
127pub enum AggregateFunction {
128    Count,
129    Sum,
130    Min,
131    Max,
132    Avg,
133}
134
135impl Expression {
136    /// Returns a canonical normalized form of this expression.
137    ///
138    /// The normalized form applies various simplification rules:
139    /// - Identity elimination (`true AND x` → `x`)
140    /// - Domination (`false AND x` → `false`)
141    /// - Idempotence (`a AND a` → `a`)
142    /// - Absorption (`a OR (a AND b)` → `a`)
143    /// - Double negation (`NOT NOT x` → `x`)
144    /// - Commutative sorting (operands sorted lexicographically)
145    #[must_use]
146    pub fn normalize(&self) -> super::NormalizedExpression {
147        super::NormalizedExpression::new(self)
148    }
149
150    /// Returns true if two expressions are semantically equivalent.
151    ///
152    /// Two expressions are equivalent if their normalized forms are identical.
153    #[must_use]
154    pub fn is_equivalent(&self, other: &Expression) -> bool {
155        self.normalize() == other.normalize()
156    }
157
158    pub fn literal(value: impl Into<serde_json::Value>) -> Self {
159        Expression::Literal(value.into())
160    }
161
162    pub fn variable(name: &str) -> Self {
163        Expression::Variable(name.to_string())
164    }
165
166    pub fn binary(op: BinaryOp, left: Expression, right: Expression) -> Self {
167        Expression::Binary {
168            op,
169            left: Box::new(left),
170            right: Box::new(right),
171        }
172    }
173
174    pub fn unary(op: UnaryOp, operand: Expression) -> Self {
175        Expression::Unary {
176            op,
177            operand: Box::new(operand),
178        }
179    }
180
181    pub fn quantifier(
182        q: Quantifier,
183        var: &str,
184        collection: Expression,
185        condition: Expression,
186    ) -> Self {
187        Expression::Quantifier {
188            quantifier: q,
189            variable: var.to_string(),
190            collection: Box::new(collection),
191            condition: Box::new(condition),
192        }
193    }
194
195    pub fn cast(operand: Expression, target_type: impl Into<String>) -> Self {
196        Expression::Cast {
197            operand: Box::new(operand),
198            target_type: target_type.into(),
199        }
200    }
201
202    pub fn comparison(
203        var: &str,
204        op: &str,
205        value: impl Into<serde_json::Value>,
206    ) -> Result<Self, String> {
207        let op = match op {
208            ">" => BinaryOp::GreaterThan,
209            "<" => BinaryOp::LessThan,
210            ">=" => BinaryOp::GreaterThanOrEqual,
211            "<=" => BinaryOp::LessThanOrEqual,
212            "==" => BinaryOp::Equal,
213            "!=" => BinaryOp::NotEqual,
214            _ => return Err(format!("Unknown operator: {}", op)),
215        };
216
217        Ok(Expression::binary(
218            op,
219            Expression::variable(var),
220            Expression::literal(value),
221        ))
222    }
223
224    pub fn aggregation(
225        function: AggregateFunction,
226        collection: Expression,
227        field: Option<impl Into<String>>,
228        filter: Option<Expression>,
229    ) -> Self {
230        Expression::Aggregation {
231            function,
232            collection: Box::new(collection),
233            field: field.map(|f| f.into()),
234            filter: filter.map(Box::new),
235        }
236    }
237
238    pub fn member_access(object: &str, member: &str) -> Self {
239        Expression::MemberAccess {
240            object: object.to_string(),
241            member: member.to_string(),
242        }
243    }
244}
245
246impl fmt::Display for Expression {
247    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248        match self {
249            Expression::Literal(v) => write!(f, "{}", v),
250            Expression::QuantityLiteral { value, unit } => {
251                write!(f, "{} {}", value, unit)
252            }
253            Expression::TimeLiteral(timestamp) => write!(f, "\"{}\"", timestamp),
254            Expression::IntervalLiteral { start, end } => {
255                write!(f, "interval(\"{}\", \"{}\")", start, end)
256            }
257            Expression::Variable(n) => write!(f, "{}", n),
258            Expression::GroupBy {
259                variable,
260                collection,
261                filter,
262                key,
263                condition,
264            } => {
265                write!(f, "group_by({} in {}", variable, collection)?;
266                if let Some(flt) = filter {
267                    write!(f, " WHERE {}", flt)?;
268                }
269                write!(f, ": {}) {{ {} }}", key, condition)
270            }
271            Expression::Binary { op, left, right } => {
272                write!(f, "({} {} {})", left, op, right)
273            }
274            Expression::Unary { op, operand } => {
275                write!(f, "{} {}", op, operand)
276            }
277            Expression::Cast {
278                operand,
279                target_type,
280            } => {
281                write!(f, "{} as \"{}\"", operand, target_type)
282            }
283            Expression::Quantifier {
284                quantifier,
285                variable,
286                collection,
287                condition,
288            } => {
289                write!(
290                    f,
291                    "{}({} in {}: {})",
292                    quantifier, variable, collection, condition
293                )
294            }
295            Expression::MemberAccess { object, member } => {
296                write!(f, "{}.{}", object, member)
297            }
298            Expression::Aggregation {
299                function,
300                collection,
301                field,
302                filter,
303            } => {
304                write!(f, "{}({}", function, collection)?;
305                if let Some(fld) = field {
306                    write!(f, ".{}", fld)?;
307                }
308                if let Some(flt) = filter {
309                    write!(f, " WHERE {}", flt)?;
310                }
311                write!(f, ")")
312            }
313            Expression::AggregationComprehension {
314                function,
315                variable,
316                collection,
317                window,
318                predicate,
319                projection,
320                target_unit,
321            } => {
322                write!(f, "{}({} in {}", function, variable, collection)?;
323                if let Some(w) = window {
324                    write!(f, " OVER LAST {} \"{}\"", w.duration, w.unit)?;
325                }
326                match predicate.as_ref() {
327                    Expression::Literal(serde_json::Value::Bool(true)) => {
328                        write!(f, ": {}", projection)?;
329                    }
330                    _ => {
331                        write!(f, " WHERE {}: {}", predicate, projection)?;
332                    }
333                }
334                if let Some(unit) = target_unit {
335                    write!(f, " AS \"{}\"", unit)?;
336                }
337                write!(f, ")")
338            }
339        }
340    }
341}
342
343impl fmt::Display for BinaryOp {
344    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
345        match self {
346            BinaryOp::And => write!(f, "AND"),
347            BinaryOp::Or => write!(f, "OR"),
348            BinaryOp::Equal => write!(f, "=="),
349            BinaryOp::NotEqual => write!(f, "!="),
350            BinaryOp::GreaterThan => write!(f, ">"),
351            BinaryOp::LessThan => write!(f, "<"),
352            BinaryOp::GreaterThanOrEqual => write!(f, ">="),
353            BinaryOp::LessThanOrEqual => write!(f, "<="),
354            BinaryOp::Plus => write!(f, "+"),
355            BinaryOp::Minus => write!(f, "-"),
356            BinaryOp::Multiply => write!(f, "*"),
357            BinaryOp::Divide => write!(f, "/"),
358            BinaryOp::Contains => write!(f, "CONTAINS"),
359            BinaryOp::StartsWith => write!(f, "STARTS_WITH"),
360            BinaryOp::EndsWith => write!(f, "ENDS_WITH"),
361            BinaryOp::Matches => write!(f, "MATCHES"),
362            BinaryOp::HasRole => write!(f, "HAS_ROLE"),
363            BinaryOp::Before => write!(f, "BEFORE"),
364            BinaryOp::After => write!(f, "AFTER"),
365            BinaryOp::During => write!(f, "DURING"),
366        }
367    }
368}
369
370impl fmt::Display for UnaryOp {
371    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
372        match self {
373            UnaryOp::Not => write!(f, "NOT"),
374            UnaryOp::Negate => write!(f, "-"),
375        }
376    }
377}
378
379impl fmt::Display for Quantifier {
380    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
381        match self {
382            Quantifier::ForAll => write!(f, "ForAll"),
383            Quantifier::Exists => write!(f, "Exists"),
384            Quantifier::ExistsUnique => write!(f, "ExistsUnique"),
385        }
386    }
387}
388
389impl fmt::Display for AggregateFunction {
390    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
391        match self {
392            AggregateFunction::Count => write!(f, "COUNT"),
393            AggregateFunction::Sum => write!(f, "SUM"),
394            AggregateFunction::Min => write!(f, "MIN"),
395            AggregateFunction::Max => write!(f, "MAX"),
396            AggregateFunction::Avg => write!(f, "AVG"),
397        }
398    }
399}