sea_core/policy/
expression.rs

1use rust_decimal::Decimal;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
6pub struct WindowSpec {
7    pub duration: u64,
8    pub unit: String,
9}
10
/// Abstract syntax tree for a policy expression.
///
/// Each variant is one node kind; child expressions are boxed so the type can
/// be recursive. The `Display` impl in this file renders a textual form for
/// every variant, which is quoted in the per-variant notes below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    /// A literal JSON value; displayed using the value's own `Display` output.
    Literal(serde_json::Value),

    /// A decimal magnitude with a unit label; displayed as `<value> <unit>`.
    QuantityLiteral {
        value: Decimal,
        unit: String,
    },

    /// A point in time; displayed wrapped in double quotes.
    TimeLiteral(String), // ISO 8601 timestamp

    /// A range between two times; displayed as `interval("<start>", "<end>")`.
    IntervalLiteral {
        start: String, // Time string (e.g., "09:00")
        end: String,   // Time string (e.g., "17:00")
    },

    /// A named variable; displayed as the bare name.
    Variable(String),

    /// A grouping construct; displayed as
    /// `group_by(<variable> in <collection>[ WHERE <filter>]: <key>) { <condition> }`.
    GroupBy {
        variable: String,
        collection: Box<Expression>,
        // Optional; the ` WHERE ...` part is printed only when present.
        filter: Option<Box<Expression>>,
        key: Box<Expression>,
        condition: Box<Expression>,
    },

    /// A binary operation; displayed as `(<left> <op> <right>)`.
    Binary {
        op: BinaryOp,
        left: Box<Expression>,
        right: Box<Expression>,
    },

    /// A unary operation; displayed as `<op> <operand>`.
    Unary {
        op: UnaryOp,
        operand: Box<Expression>,
    },

    /// A cast of `operand` to a named target type; displayed as
    /// `<operand> as "<target_type>"`.
    Cast {
        operand: Box<Expression>,
        target_type: String,
    },

    /// A quantified condition over a collection; displayed as
    /// `<quantifier>(<variable> in <collection>: <condition>)`.
    Quantifier {
        quantifier: Quantifier,
        variable: String,
        collection: Box<Expression>,
        condition: Box<Expression>,
    },

    /// Access to a member of a named object; displayed as `<object>.<member>`.
    /// Both parts are plain strings rather than nested expressions.
    MemberAccess {
        object: String,
        member: String,
    },

    /// An aggregate over a collection; displayed as
    /// `<function>(<collection>[.<field>][ WHERE <filter>])`.
    Aggregation {
        function: AggregateFunction,
        collection: Box<Expression>,
        field: Option<String>,
        filter: Option<Box<Expression>>,
    },

    /// An aggregate written as a comprehension; displayed as
    /// `<function>(<variable> in <collection>[ OVER LAST <duration> "<unit>"][ WHERE <predicate>]: <projection>[ AS "<target_unit>"])`.
    /// The `WHERE` clause is omitted from the output when `predicate` is the
    /// literal JSON `true`.
    AggregationComprehension {
        function: AggregateFunction,
        variable: String,
        collection: Box<Expression>,
        window: Option<WindowSpec>,
        predicate: Box<Expression>,
        projection: Box<Expression>,
        target_unit: Option<String>,
    },
}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum BinaryOp {
85    And,
86    Or,
87
88    Equal,
89    NotEqual,
90    GreaterThan,
91    LessThan,
92    GreaterThanOrEqual,
93    LessThanOrEqual,
94
95    Plus,
96    Minus,
97    Multiply,
98    Divide,
99
100    Contains,
101    StartsWith,
102    EndsWith,
103    Matches,
104
105    HasRole,
106
107    // Temporal operators
108    Before,
109    After,
110    During,
111}
112
113#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
114pub enum UnaryOp {
115    Not,
116    Negate,
117}
118
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub enum Quantifier {
121    ForAll,
122    Exists,
123    ExistsUnique,
124}
125
/// Aggregate function carried by `Expression::Aggregation` and
/// `Expression::AggregationComprehension`.
///
/// The per-variant notes give the token produced by this type's `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregateFunction {
    /// Displayed as `COUNT`.
    Count,
    /// Displayed as `SUM`.
    Sum,
    /// Displayed as `MIN`.
    Min,
    /// Displayed as `MAX`.
    Max,
    /// Displayed as `AVG`.
    Avg,
}
134
135impl Expression {
136    pub fn literal(value: impl Into<serde_json::Value>) -> Self {
137        Expression::Literal(value.into())
138    }
139
140    pub fn variable(name: &str) -> Self {
141        Expression::Variable(name.to_string())
142    }
143
144    pub fn binary(op: BinaryOp, left: Expression, right: Expression) -> Self {
145        Expression::Binary {
146            op,
147            left: Box::new(left),
148            right: Box::new(right),
149        }
150    }
151
152    pub fn unary(op: UnaryOp, operand: Expression) -> Self {
153        Expression::Unary {
154            op,
155            operand: Box::new(operand),
156        }
157    }
158
159    pub fn quantifier(
160        q: Quantifier,
161        var: &str,
162        collection: Expression,
163        condition: Expression,
164    ) -> Self {
165        Expression::Quantifier {
166            quantifier: q,
167            variable: var.to_string(),
168            collection: Box::new(collection),
169            condition: Box::new(condition),
170        }
171    }
172
173    pub fn cast(operand: Expression, target_type: impl Into<String>) -> Self {
174        Expression::Cast {
175            operand: Box::new(operand),
176            target_type: target_type.into(),
177        }
178    }
179
180    pub fn comparison(
181        var: &str,
182        op: &str,
183        value: impl Into<serde_json::Value>,
184    ) -> Result<Self, String> {
185        let op = match op {
186            ">" => BinaryOp::GreaterThan,
187            "<" => BinaryOp::LessThan,
188            ">=" => BinaryOp::GreaterThanOrEqual,
189            "<=" => BinaryOp::LessThanOrEqual,
190            "==" => BinaryOp::Equal,
191            "!=" => BinaryOp::NotEqual,
192            _ => return Err(format!("Unknown operator: {}", op)),
193        };
194
195        Ok(Expression::binary(
196            op,
197            Expression::variable(var),
198            Expression::literal(value),
199        ))
200    }
201
202    pub fn aggregation(
203        function: AggregateFunction,
204        collection: Expression,
205        field: Option<impl Into<String>>,
206        filter: Option<Expression>,
207    ) -> Self {
208        Expression::Aggregation {
209            function,
210            collection: Box::new(collection),
211            field: field.map(|f| f.into()),
212            filter: filter.map(Box::new),
213        }
214    }
215
216    pub fn member_access(object: &str, member: &str) -> Self {
217        Expression::MemberAccess {
218            object: object.to_string(),
219            member: member.to_string(),
220        }
221    }
222}
223
224impl fmt::Display for Expression {
225    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226        match self {
227            Expression::Literal(v) => write!(f, "{}", v),
228            Expression::QuantityLiteral { value, unit } => {
229                write!(f, "{} {}", value, unit)
230            }
231            Expression::TimeLiteral(timestamp) => write!(f, "\"{}\"", timestamp),
232            Expression::IntervalLiteral { start, end } => {
233                write!(f, "interval(\"{}\", \"{}\")", start, end)
234            }
235            Expression::Variable(n) => write!(f, "{}", n),
236            Expression::GroupBy {
237                variable,
238                collection,
239                filter,
240                key,
241                condition,
242            } => {
243                write!(f, "group_by({} in {}", variable, collection)?;
244                if let Some(flt) = filter {
245                    write!(f, " WHERE {}", flt)?;
246                }
247                write!(f, ": {}) {{ {} }}", key, condition)
248            }
249            Expression::Binary { op, left, right } => {
250                write!(f, "({} {} {})", left, op, right)
251            }
252            Expression::Unary { op, operand } => {
253                write!(f, "{} {}", op, operand)
254            }
255            Expression::Cast {
256                operand,
257                target_type,
258            } => {
259                write!(f, "{} as \"{}\"", operand, target_type)
260            }
261            Expression::Quantifier {
262                quantifier,
263                variable,
264                collection,
265                condition,
266            } => {
267                write!(
268                    f,
269                    "{}({} in {}: {})",
270                    quantifier, variable, collection, condition
271                )
272            }
273            Expression::MemberAccess { object, member } => {
274                write!(f, "{}.{}", object, member)
275            }
276            Expression::Aggregation {
277                function,
278                collection,
279                field,
280                filter,
281            } => {
282                write!(f, "{}({}", function, collection)?;
283                if let Some(fld) = field {
284                    write!(f, ".{}", fld)?;
285                }
286                if let Some(flt) = filter {
287                    write!(f, " WHERE {}", flt)?;
288                }
289                write!(f, ")")
290            }
291            Expression::AggregationComprehension {
292                function,
293                variable,
294                collection,
295                window,
296                predicate,
297                projection,
298                target_unit,
299            } => {
300                write!(f, "{}({} in {}", function, variable, collection)?;
301                if let Some(w) = window {
302                    write!(f, " OVER LAST {} \"{}\"", w.duration, w.unit)?;
303                }
304                match predicate.as_ref() {
305                    Expression::Literal(serde_json::Value::Bool(true)) => {
306                        write!(f, ": {}", projection)?;
307                    }
308                    _ => {
309                        write!(f, " WHERE {}: {}", predicate, projection)?;
310                    }
311                }
312                if let Some(unit) = target_unit {
313                    write!(f, " AS \"{}\"", unit)?;
314                }
315                write!(f, ")")
316            }
317        }
318    }
319}
320
321impl fmt::Display for BinaryOp {
322    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323        match self {
324            BinaryOp::And => write!(f, "AND"),
325            BinaryOp::Or => write!(f, "OR"),
326            BinaryOp::Equal => write!(f, "=="),
327            BinaryOp::NotEqual => write!(f, "!="),
328            BinaryOp::GreaterThan => write!(f, ">"),
329            BinaryOp::LessThan => write!(f, "<"),
330            BinaryOp::GreaterThanOrEqual => write!(f, ">="),
331            BinaryOp::LessThanOrEqual => write!(f, "<="),
332            BinaryOp::Plus => write!(f, "+"),
333            BinaryOp::Minus => write!(f, "-"),
334            BinaryOp::Multiply => write!(f, "*"),
335            BinaryOp::Divide => write!(f, "/"),
336            BinaryOp::Contains => write!(f, "CONTAINS"),
337            BinaryOp::StartsWith => write!(f, "STARTS_WITH"),
338            BinaryOp::EndsWith => write!(f, "ENDS_WITH"),
339            BinaryOp::Matches => write!(f, "MATCHES"),
340            BinaryOp::HasRole => write!(f, "HAS_ROLE"),
341            BinaryOp::Before => write!(f, "BEFORE"),
342            BinaryOp::After => write!(f, "AFTER"),
343            BinaryOp::During => write!(f, "DURING"),
344        }
345    }
346}
347
348impl fmt::Display for UnaryOp {
349    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
350        match self {
351            UnaryOp::Not => write!(f, "NOT"),
352            UnaryOp::Negate => write!(f, "-"),
353        }
354    }
355}
356
357impl fmt::Display for Quantifier {
358    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
359        match self {
360            Quantifier::ForAll => write!(f, "ForAll"),
361            Quantifier::Exists => write!(f, "Exists"),
362            Quantifier::ExistsUnique => write!(f, "ExistsUnique"),
363        }
364    }
365}
366
367impl fmt::Display for AggregateFunction {
368    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
369        match self {
370            AggregateFunction::Count => write!(f, "COUNT"),
371            AggregateFunction::Sum => write!(f, "SUM"),
372            AggregateFunction::Min => write!(f, "MIN"),
373            AggregateFunction::Max => write!(f, "MAX"),
374            AggregateFunction::Avg => write!(f, "AVG"),
375        }
376    }
377}