1use varpulis_core::ast::*;
8use varpulis_core::span::Spanned;
9
10pub fn fold_program(program: Program) -> Program {
12 Program {
13 statements: program
14 .statements
15 .into_iter()
16 .map(fold_spanned_stmt)
17 .collect(),
18 }
19}
20
21fn fold_spanned_stmt(s: Spanned<Stmt>) -> Spanned<Stmt> {
22 Spanned::new(fold_stmt(s.node), s.span)
23}
24
25fn fold_stmt(stmt: Stmt) -> Stmt {
26 match stmt {
27 Stmt::VarDecl {
28 mutable,
29 name,
30 ty,
31 value,
32 } => Stmt::VarDecl {
33 mutable,
34 name,
35 ty,
36 value: fold_expr(value),
37 },
38 Stmt::ConstDecl { name, ty, value } => Stmt::ConstDecl {
39 name,
40 ty,
41 value: fold_expr(value),
42 },
43 Stmt::FnDecl {
44 name,
45 params,
46 ret,
47 body,
48 } => Stmt::FnDecl {
49 name,
50 params,
51 ret,
52 body: body.into_iter().map(fold_spanned_stmt).collect(),
53 },
54 Stmt::StreamDecl {
55 name,
56 type_annotation,
57 source,
58 ops,
59 op_spans,
60 } => Stmt::StreamDecl {
61 name,
62 type_annotation,
63 source,
64 ops: ops.into_iter().map(fold_stream_op).collect(),
65 op_spans,
66 },
67 Stmt::If {
68 cond,
69 then_branch,
70 elif_branches,
71 else_branch,
72 } => Stmt::If {
73 cond: fold_expr(cond),
74 then_branch: then_branch.into_iter().map(fold_spanned_stmt).collect(),
75 elif_branches: elif_branches
76 .into_iter()
77 .map(|(c, b)| (fold_expr(c), b.into_iter().map(fold_spanned_stmt).collect()))
78 .collect(),
79 else_branch: else_branch.map(|b| b.into_iter().map(fold_spanned_stmt).collect()),
80 },
81 Stmt::For { var, iter, body } => Stmt::For {
82 var,
83 iter: fold_expr(iter),
84 body: body.into_iter().map(fold_spanned_stmt).collect(),
85 },
86 Stmt::While { cond, body } => Stmt::While {
87 cond: fold_expr(cond),
88 body: body.into_iter().map(fold_spanned_stmt).collect(),
89 },
90 Stmt::Return(Some(expr)) => Stmt::Return(Some(fold_expr(expr))),
91 Stmt::Expr(expr) => Stmt::Expr(fold_expr(expr)),
92 Stmt::Assignment { name, value } => Stmt::Assignment {
93 name,
94 value: fold_expr(value),
95 },
96 Stmt::Emit { event_type, fields } => Stmt::Emit {
97 event_type,
98 fields: fields.into_iter().map(fold_named_arg).collect(),
99 },
100 other => other,
102 }
103}
104
105fn fold_named_arg(arg: NamedArg) -> NamedArg {
106 NamedArg {
107 name: arg.name,
108 value: fold_expr(arg.value),
109 }
110}
111
112fn fold_stream_op(op: StreamOp) -> StreamOp {
113 match op {
114 StreamOp::Where(expr) => StreamOp::Where(fold_expr(expr)),
115 StreamOp::Filter(expr) => StreamOp::Filter(fold_expr(expr)),
116 StreamOp::Map(expr) => StreamOp::Map(fold_expr(expr)),
117 StreamOp::Process(expr) => StreamOp::Process(fold_expr(expr)),
118 StreamOp::OnError(expr) => StreamOp::OnError(fold_expr(expr)),
119 StreamOp::Having(expr) => StreamOp::Having(fold_expr(expr)),
120 StreamOp::PartitionBy(expr) => StreamOp::PartitionBy(fold_expr(expr)),
121 StreamOp::Limit(expr) => StreamOp::Limit(fold_expr(expr)),
122 StreamOp::Distinct(opt) => StreamOp::Distinct(opt.map(fold_expr)),
123 StreamOp::On(expr) => StreamOp::On(fold_expr(expr)),
124 StreamOp::Within(expr) => StreamOp::Within(fold_expr(expr)),
125 StreamOp::AllowedLateness(expr) => StreamOp::AllowedLateness(fold_expr(expr)),
126 StreamOp::ToExpr(expr) => StreamOp::ToExpr(fold_expr(expr)),
127 StreamOp::Emit {
128 output_type,
129 fields,
130 target_context,
131 } => StreamOp::Emit {
132 output_type,
133 fields: fields.into_iter().map(fold_named_arg).collect(),
134 target_context,
135 },
136 StreamOp::Print(exprs) => StreamOp::Print(exprs.into_iter().map(fold_expr).collect()),
137 StreamOp::Log(args) => StreamOp::Log(args.into_iter().map(fold_named_arg).collect()),
138 StreamOp::Tap(args) => StreamOp::Tap(args.into_iter().map(fold_named_arg).collect()),
139 other => other,
141 }
142}
143
144fn fold_expr(expr: Expr) -> Expr {
146 match expr {
147 Expr::Binary { op, left, right } => {
149 let left = fold_expr(*left);
150 let right = fold_expr(*right);
151 fold_binary(op, left, right)
152 }
153 Expr::Unary { op, expr } => {
155 let inner = fold_expr(*expr);
156 fold_unary(op, inner)
157 }
158 Expr::Call { func, args } => Expr::Call {
160 func: Box::new(fold_expr(*func)),
161 args: args.into_iter().map(fold_arg).collect(),
162 },
163 Expr::Array(elems) => Expr::Array(elems.into_iter().map(fold_expr).collect()),
165 Expr::Map(entries) => Expr::Map(
167 entries
168 .into_iter()
169 .map(|(k, v)| (k, fold_expr(v)))
170 .collect(),
171 ),
172 Expr::Lambda { params, body } => Expr::Lambda {
174 params,
175 body: Box::new(fold_expr(*body)),
176 },
177 Expr::If {
179 cond,
180 then_branch,
181 else_branch,
182 } => Expr::If {
183 cond: Box::new(fold_expr(*cond)),
184 then_branch: Box::new(fold_expr(*then_branch)),
185 else_branch: Box::new(fold_expr(*else_branch)),
186 },
187 Expr::Coalesce { expr, default } => Expr::Coalesce {
189 expr: Box::new(fold_expr(*expr)),
190 default: Box::new(fold_expr(*default)),
191 },
192 Expr::Range {
194 start,
195 end,
196 inclusive,
197 } => Expr::Range {
198 start: Box::new(fold_expr(*start)),
199 end: Box::new(fold_expr(*end)),
200 inclusive,
201 },
202 Expr::Member { expr, member } => Expr::Member {
204 expr: Box::new(fold_expr(*expr)),
205 member,
206 },
207 Expr::OptionalMember { expr, member } => Expr::OptionalMember {
209 expr: Box::new(fold_expr(*expr)),
210 member,
211 },
212 Expr::Index { expr, index } => Expr::Index {
214 expr: Box::new(fold_expr(*expr)),
215 index: Box::new(fold_expr(*index)),
216 },
217 Expr::Slice { expr, start, end } => Expr::Slice {
219 expr: Box::new(fold_expr(*expr)),
220 start: start.map(|s| Box::new(fold_expr(*s))),
221 end: end.map(|e| Box::new(fold_expr(*e))),
222 },
223 Expr::Block { stmts, result } => Expr::Block {
225 stmts: stmts
226 .into_iter()
227 .map(|(name, ty, val, mutable)| (name, ty, fold_expr(val), mutable))
228 .collect(),
229 result: Box::new(fold_expr(*result)),
230 },
231 other => other,
233 }
234}
235
236fn fold_arg(arg: Arg) -> Arg {
237 match arg {
238 Arg::Positional(expr) => Arg::Positional(fold_expr(expr)),
239 Arg::Named(name, expr) => Arg::Named(name, fold_expr(expr)),
240 }
241}
242
243fn fold_binary(op: BinOp, left: Expr, right: Expr) -> Expr {
245 match (&op, &left, &right) {
247 (BinOp::Add, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_add(*b)),
249 (BinOp::Sub, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_sub(*b)),
250 (BinOp::Mul, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_mul(*b)),
251 (BinOp::Div, Expr::Int(a), Expr::Int(b)) if *b != 0 => return Expr::Int(a / b),
252 (BinOp::Mod, Expr::Int(a), Expr::Int(b)) if *b != 0 => return Expr::Int(a % b),
253 (BinOp::Pow, Expr::Int(a), Expr::Int(b)) if *b >= 0 => {
254 return Expr::Int(a.wrapping_pow(*b as u32));
255 }
256
257 (BinOp::Add, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a + b),
259 (BinOp::Sub, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a - b),
260 (BinOp::Mul, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a * b),
261 (BinOp::Div, Expr::Float(a), Expr::Float(b)) if *b != 0.0 => {
262 return Expr::Float(a / b);
263 }
264
265 _ => {}
266 }
267
268 match (&op, &left, &right) {
270 (BinOp::Mul, _, Expr::Int(0)) | (BinOp::Mul, Expr::Int(0), _) => {
272 return Expr::Int(0);
273 }
274 (BinOp::Mul, _, Expr::Int(1)) => return left,
276 (BinOp::Mul, Expr::Int(1), _) => return right,
278 (BinOp::Add, _, Expr::Int(0)) => return left,
280 (BinOp::Add, Expr::Int(0), _) => return right,
282 (BinOp::Sub, _, Expr::Int(0)) => return left,
284 (BinOp::Div, _, Expr::Int(1)) => return left,
286
287 _ => {}
288 }
289
290 Expr::Binary {
292 op,
293 left: Box::new(left),
294 right: Box::new(right),
295 }
296}
297
298fn fold_unary(op: UnaryOp, inner: Expr) -> Expr {
300 match (&op, &inner) {
301 (UnaryOp::Neg, Expr::Int(a)) => Expr::Int(-a),
302 (UnaryOp::Neg, Expr::Float(a)) => Expr::Float(-a),
303 _ => Expr::Unary {
304 op,
305 expr: Box::new(inner),
306 },
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unfolded `lhs op rhs` node.
    fn binary(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
    }

    /// Builds an unfolded `op operand` node.
    fn unary_node(op: UnaryOp, operand: Expr) -> Expr {
        Expr::Unary {
            op,
            expr: Box::new(operand),
        }
    }

    /// Shorthand for an identifier expression.
    fn ident(name: &'static str) -> Expr {
        Expr::Ident(name.into())
    }

    /// Asserts that folding `input` yields exactly `expected`.
    #[track_caller]
    fn assert_folds_to(input: Expr, expected: Expr) {
        assert_eq!(fold_expr(input), expected);
    }

    #[test]
    fn fold_int_addition() {
        assert_folds_to(binary(BinOp::Add, Expr::Int(1), Expr::Int(2)), Expr::Int(3));
    }

    #[test]
    fn fold_int_subtraction() {
        assert_folds_to(binary(BinOp::Sub, Expr::Int(5), Expr::Int(3)), Expr::Int(2));
    }

    #[test]
    fn fold_int_multiplication() {
        assert_folds_to(binary(BinOp::Mul, Expr::Int(2), Expr::Int(3)), Expr::Int(6));
    }

    #[test]
    fn fold_int_division() {
        // Integer division truncates.
        assert_folds_to(binary(BinOp::Div, Expr::Int(10), Expr::Int(3)), Expr::Int(3));
    }

    #[test]
    fn fold_int_pow() {
        assert_folds_to(binary(BinOp::Pow, Expr::Int(2), Expr::Int(10)), Expr::Int(1024));
    }

    #[test]
    fn fold_float_arithmetic() {
        assert_folds_to(
            binary(BinOp::Mul, Expr::Float(2.0), Expr::Float(3.0)),
            Expr::Float(6.0),
        );
    }

    #[test]
    fn fold_identity_mul_zero() {
        assert_folds_to(binary(BinOp::Mul, ident("x"), Expr::Int(0)), Expr::Int(0));
    }

    #[test]
    fn fold_identity_mul_one() {
        assert_folds_to(binary(BinOp::Mul, ident("x"), Expr::Int(1)), ident("x"));
    }

    #[test]
    fn fold_identity_add_zero() {
        assert_folds_to(binary(BinOp::Add, ident("x"), Expr::Int(0)), ident("x"));
    }

    #[test]
    fn fold_nested_expression() {
        // (2 + 3) * 4: the inner sum folds first, then the product.
        let sum = binary(BinOp::Add, Expr::Int(2), Expr::Int(3));
        assert_folds_to(binary(BinOp::Mul, sum, Expr::Int(4)), Expr::Int(20));
    }

    #[test]
    fn fold_unary_neg_int() {
        assert_folds_to(unary_node(UnaryOp::Neg, Expr::Int(5)), Expr::Int(-5));
    }

    #[test]
    fn fold_unary_neg_float() {
        assert_folds_to(unary_node(UnaryOp::Neg, Expr::Float(2.75)), Expr::Float(-2.75));
    }

    #[test]
    fn preserves_non_constant() {
        let original = binary(BinOp::Add, ident("x"), Expr::Int(1));
        assert_folds_to(original.clone(), original);
    }

    #[test]
    fn fold_in_call_args() {
        let call_with = |arg: Expr| Expr::Call {
            func: Box::new(ident("f")),
            args: vec![Arg::Positional(arg)],
        };
        assert_folds_to(
            call_with(binary(BinOp::Mul, Expr::Int(2), Expr::Int(3))),
            call_with(Expr::Int(6)),
        );
    }

    #[test]
    fn div_by_zero_not_folded() {
        let folded = fold_expr(binary(BinOp::Div, Expr::Int(1), Expr::Int(0)));
        assert!(matches!(folded, Expr::Binary { .. }));
    }

    #[test]
    fn fold_identity_sub_zero() {
        assert_folds_to(binary(BinOp::Sub, ident("x"), Expr::Int(0)), ident("x"));
    }

    #[test]
    fn fold_identity_div_one() {
        assert_folds_to(binary(BinOp::Div, ident("x"), Expr::Int(1)), ident("x"));
    }
}