Skip to main content

shape_ast/transform/
desugar.rs

1//! AST desugaring pass
2//!
3//! Transforms high-level syntax into equivalent lower-level constructs.
4//! Currently handles LINQ-style from queries → method chains.
5
6use crate::ast::{
7    DestructurePattern, Expr, FromQueryExpr, FunctionParameter, Item, Literal, ObjectEntry,
8    Program, QueryClause, Span, Statement,
9};
10
11/// Desugar all high-level syntax in a program.
12/// This should be called before compilation.
13pub fn desugar_program(program: &mut Program) {
14    for item in &mut program.items {
15        desugar_item(item);
16    }
17}
18
19fn desugar_item(item: &mut Item) {
20    match item {
21        Item::Function(func, _) => {
22            for stmt in &mut func.body {
23                desugar_statement(stmt);
24            }
25        }
26        Item::VariableDecl(decl, _) => {
27            if let Some(value) = &mut decl.value {
28                desugar_expr(value);
29            }
30        }
31        Item::Assignment(assign, _) => {
32            desugar_expr(&mut assign.value);
33        }
34        Item::Expression(expr, _) => {
35            desugar_expr(expr);
36        }
37        Item::Statement(stmt, _) => {
38            desugar_statement(stmt);
39        }
40        Item::Export(export, _) => match &mut export.item {
41            crate::ast::ExportItem::Function(func) => {
42                for stmt in &mut func.body {
43                    desugar_statement(stmt);
44                }
45            }
46            crate::ast::ExportItem::TypeAlias(_)
47            | crate::ast::ExportItem::Named(_)
48            | crate::ast::ExportItem::Enum(_)
49            | crate::ast::ExportItem::Struct(_)
50            | crate::ast::ExportItem::Interface(_)
51            | crate::ast::ExportItem::Trait(_)
52            | crate::ast::ExportItem::ForeignFunction(_) => {}
53        },
54        Item::Module(module, _) => {
55            for inner in &mut module.items {
56                desugar_item(inner);
57            }
58        }
59        Item::Extend(extend, _) => {
60            for method in &mut extend.methods {
61                for stmt in &mut method.body {
62                    desugar_statement(stmt);
63                }
64            }
65        }
66        Item::Stream(stream, _) => {
67            for decl in &mut stream.state {
68                if let Some(value) = &mut decl.value {
69                    desugar_expr(value);
70                }
71            }
72            if let Some(stmts) = &mut stream.on_connect {
73                for stmt in stmts {
74                    desugar_statement(stmt);
75                }
76            }
77            if let Some(stmts) = &mut stream.on_disconnect {
78                for stmt in stmts {
79                    desugar_statement(stmt);
80                }
81            }
82            if let Some(on_event) = &mut stream.on_event {
83                for stmt in &mut on_event.body {
84                    desugar_statement(stmt);
85                }
86            }
87            if let Some(on_window) = &mut stream.on_window {
88                for stmt in &mut on_window.body {
89                    desugar_statement(stmt);
90                }
91            }
92            if let Some(on_error) = &mut stream.on_error {
93                for stmt in &mut on_error.body {
94                    desugar_statement(stmt);
95                }
96            }
97        }
98        // Other items don't need desugaring
99        _ => {}
100    }
101}
102
103fn desugar_statement(stmt: &mut Statement) {
104    match stmt {
105        Statement::Return(Some(expr), _) => desugar_expr(expr),
106        Statement::VariableDecl(decl, _) => {
107            if let Some(value) = &mut decl.value {
108                desugar_expr(value);
109            }
110        }
111        Statement::Assignment(assign, _) => desugar_expr(&mut assign.value),
112        Statement::Expression(expr, _) => desugar_expr(expr),
113        Statement::For(for_loop, _) => {
114            match &mut for_loop.init {
115                crate::ast::ForInit::ForIn { iter, .. } => desugar_expr(iter),
116                crate::ast::ForInit::ForC {
117                    init,
118                    condition,
119                    update,
120                } => {
121                    desugar_statement(init);
122                    desugar_expr(condition);
123                    desugar_expr(update);
124                }
125            }
126            for s in &mut for_loop.body {
127                desugar_statement(s);
128            }
129        }
130        Statement::While(while_loop, _) => {
131            desugar_expr(&mut while_loop.condition);
132            for s in &mut while_loop.body {
133                desugar_statement(s);
134            }
135        }
136        Statement::If(if_stmt, _) => {
137            desugar_expr(&mut if_stmt.condition);
138            for s in &mut if_stmt.then_body {
139                desugar_statement(s);
140            }
141            if let Some(else_body) = &mut if_stmt.else_body {
142                for s in else_body {
143                    desugar_statement(s);
144                }
145            }
146        }
147        _ => {}
148    }
149}
150
151fn desugar_expr(expr: &mut Expr) {
152    // First recursively desugar nested expressions
153    match expr {
154        Expr::FromQuery(from_query, span) => {
155            // Desugar nested expressions in the query first
156            desugar_expr(&mut from_query.source);
157            for clause in &mut from_query.clauses {
158                match clause {
159                    QueryClause::Where(pred) => desugar_expr(pred),
160                    QueryClause::OrderBy(specs) => {
161                        for spec in specs {
162                            desugar_expr(&mut spec.key);
163                        }
164                    }
165                    QueryClause::GroupBy { element, key, .. } => {
166                        desugar_expr(element);
167                        desugar_expr(key);
168                    }
169                    QueryClause::Join {
170                        source,
171                        left_key,
172                        right_key,
173                        ..
174                    } => {
175                        desugar_expr(source);
176                        desugar_expr(left_key);
177                        desugar_expr(right_key);
178                    }
179                    QueryClause::Let { value, .. } => desugar_expr(value),
180                }
181            }
182            desugar_expr(&mut from_query.select);
183
184            // Now desugar the from query to method chains
185            let desugared = desugar_from_query(from_query, *span);
186            *expr = desugared;
187        }
188        // Recursively handle all other expression types
189        Expr::PropertyAccess { object, .. } => desugar_expr(object),
190        Expr::IndexAccess {
191            object,
192            index,
193            end_index,
194            ..
195        } => {
196            desugar_expr(object);
197            desugar_expr(index);
198            if let Some(end) = end_index {
199                desugar_expr(end);
200            }
201        }
202        Expr::BinaryOp { left, right, .. } => {
203            desugar_expr(left);
204            desugar_expr(right);
205        }
206        Expr::FuzzyComparison { left, right, .. } => {
207            desugar_expr(left);
208            desugar_expr(right);
209        }
210        Expr::UnaryOp { operand, .. } => desugar_expr(operand),
211        Expr::FunctionCall {
212            args, named_args, ..
213        } => {
214            for arg in args {
215                desugar_expr(arg);
216            }
217            for (_, val) in named_args {
218                desugar_expr(val);
219            }
220        }
221        Expr::MethodCall {
222            receiver,
223            args,
224            named_args,
225            ..
226        } => {
227            desugar_expr(receiver);
228            for arg in args {
229                desugar_expr(arg);
230            }
231            for (_, val) in named_args {
232                desugar_expr(val);
233            }
234        }
235        Expr::Conditional {
236            condition,
237            then_expr,
238            else_expr,
239            ..
240        } => {
241            desugar_expr(condition);
242            desugar_expr(then_expr);
243            if let Some(e) = else_expr {
244                desugar_expr(e);
245            }
246        }
247        Expr::Object(entries, _) => {
248            for entry in entries {
249                match entry {
250                    ObjectEntry::Field { value, .. } => desugar_expr(value),
251                    ObjectEntry::Spread(e) => desugar_expr(e),
252                }
253            }
254        }
255        Expr::Array(elements, _) => {
256            for elem in elements {
257                desugar_expr(elem);
258            }
259        }
260        Expr::ListComprehension(comp, _) => {
261            desugar_expr(&mut comp.element);
262            for clause in &mut comp.clauses {
263                desugar_expr(&mut clause.iterable);
264                if let Some(filter) = &mut clause.filter {
265                    desugar_expr(filter);
266                }
267            }
268        }
269        Expr::Block(block, _) => {
270            for item in &mut block.items {
271                match item {
272                    crate::ast::BlockItem::VariableDecl(decl) => {
273                        if let Some(value) = &mut decl.value {
274                            desugar_expr(value);
275                        }
276                    }
277                    crate::ast::BlockItem::Assignment(assign) => {
278                        desugar_expr(&mut assign.value);
279                    }
280                    crate::ast::BlockItem::Statement(stmt) => {
281                        desugar_statement(stmt);
282                    }
283                    crate::ast::BlockItem::Expression(e) => {
284                        desugar_expr(e);
285                    }
286                }
287            }
288        }
289        Expr::TypeAssertion { expr: inner, .. } => desugar_expr(inner),
290        Expr::InstanceOf { expr: inner, .. } => desugar_expr(inner),
291        Expr::FunctionExpr { body, .. } => {
292            for stmt in body {
293                desugar_statement(stmt);
294            }
295        }
296        Expr::Spread(inner, _) => desugar_expr(inner),
297        Expr::If(if_expr, _) => {
298            desugar_expr(&mut if_expr.condition);
299            desugar_expr(&mut if_expr.then_branch);
300            if let Some(e) = &mut if_expr.else_branch {
301                desugar_expr(e);
302            }
303        }
304        Expr::While(while_expr, _) => {
305            desugar_expr(&mut while_expr.condition);
306            desugar_expr(&mut while_expr.body);
307        }
308        Expr::For(for_expr, _) => {
309            desugar_expr(&mut for_expr.iterable);
310            desugar_expr(&mut for_expr.body);
311        }
312        Expr::Loop(loop_expr, _) => {
313            desugar_expr(&mut loop_expr.body);
314        }
315        Expr::Let(let_expr, _) => {
316            if let Some(val) = &mut let_expr.value {
317                desugar_expr(val);
318            }
319            desugar_expr(&mut let_expr.body);
320        }
321        Expr::Assign(assign, _) => {
322            desugar_expr(&mut assign.target);
323            desugar_expr(&mut assign.value);
324        }
325        Expr::Break(Some(e), _) => desugar_expr(e),
326        Expr::Return(Some(e), _) => desugar_expr(e),
327        Expr::Match(match_expr, _) => {
328            desugar_expr(&mut match_expr.scrutinee);
329            for arm in &mut match_expr.arms {
330                if let Some(guard) = &mut arm.guard {
331                    desugar_expr(guard);
332                }
333                desugar_expr(&mut arm.body);
334            }
335        }
336        Expr::Range { start, end, .. } => {
337            if let Some(s) = start {
338                desugar_expr(s);
339            }
340            if let Some(e) = end {
341                desugar_expr(e);
342            }
343        }
344        Expr::TimeframeContext { expr: inner, .. } => desugar_expr(inner),
345        Expr::TryOperator(inner, _) => desugar_expr(inner),
346        Expr::UsingImpl { expr: inner, .. } => desugar_expr(inner),
347        Expr::Await(inner, _) => desugar_expr(inner),
348        Expr::EnumConstructor { payload, .. } => match payload {
349            crate::ast::EnumConstructorPayload::Tuple(args) => {
350                for arg in args {
351                    desugar_expr(arg);
352                }
353            }
354            crate::ast::EnumConstructorPayload::Struct(fields) => {
355                for (_, val) in fields {
356                    desugar_expr(val);
357                }
358            }
359            crate::ast::EnumConstructorPayload::Unit => {}
360        },
361        Expr::SimulationCall { params, .. } => {
362            for (_, val) in params {
363                desugar_expr(val);
364            }
365        }
366        Expr::WindowExpr(window_expr, _) => {
367            // Desugar window function arguments
368            match &mut window_expr.function {
369                crate::ast::WindowFunction::Lead { expr, default, .. }
370                | crate::ast::WindowFunction::Lag { expr, default, .. } => {
371                    desugar_expr(expr);
372                    if let Some(d) = default {
373                        desugar_expr(d);
374                    }
375                }
376                crate::ast::WindowFunction::FirstValue(e)
377                | crate::ast::WindowFunction::LastValue(e)
378                | crate::ast::WindowFunction::Sum(e)
379                | crate::ast::WindowFunction::Avg(e)
380                | crate::ast::WindowFunction::Min(e)
381                | crate::ast::WindowFunction::Max(e) => {
382                    desugar_expr(e);
383                }
384                crate::ast::WindowFunction::NthValue(e, _) => {
385                    desugar_expr(e);
386                }
387                crate::ast::WindowFunction::Count(Some(e)) => {
388                    desugar_expr(e);
389                }
390                _ => {}
391            }
392            // Desugar partition_by and order_by expressions
393            for e in &mut window_expr.over.partition_by {
394                desugar_expr(e);
395            }
396            if let Some(order_by) = &mut window_expr.over.order_by {
397                for (e, _) in &mut order_by.columns {
398                    desugar_expr(e);
399                }
400            }
401        }
402        Expr::DataRef(data_ref, _) => match &mut data_ref.index {
403            crate::ast::DataIndex::Expression(e) => desugar_expr(e),
404            crate::ast::DataIndex::ExpressionRange(start, end) => {
405                desugar_expr(start);
406                desugar_expr(end);
407            }
408            _ => {}
409        },
410        Expr::DataRelativeAccess {
411            reference, index, ..
412        } => {
413            desugar_expr(reference);
414            match index {
415                crate::ast::DataIndex::Expression(e) => desugar_expr(e),
416                crate::ast::DataIndex::ExpressionRange(start, end) => {
417                    desugar_expr(start);
418                    desugar_expr(end);
419                }
420                _ => {}
421            }
422        }
423        Expr::StructLiteral { fields, .. } => {
424            for (_, value) in fields {
425                desugar_expr(value);
426            }
427        }
428        // Leaf nodes - no recursion needed
429        Expr::Literal(_, _)
430        | Expr::Identifier(_, _)
431        | Expr::DataDateTimeRef(_, _)
432        | Expr::TimeRef(_, _)
433        | Expr::DateTime(_, _)
434        | Expr::PatternRef(_, _)
435        | Expr::Duration(_, _)
436        | Expr::Unit(_)
437        | Expr::Continue(_)
438        | Expr::Break(None, _)
439        | Expr::Return(None, _)
440        | Expr::Join(_, _) => {}
441        Expr::Annotated { target, .. } => {
442            desugar_expr(target);
443        }
444        Expr::AsyncLet(async_let, _) => {
445            desugar_expr(&mut async_let.expr);
446        }
447        Expr::AsyncScope(inner, _) => {
448            desugar_expr(inner);
449        }
450        Expr::Comptime(stmts, _) => {
451            for stmt in stmts {
452                desugar_statement(stmt);
453            }
454        }
455        Expr::ComptimeFor(cf, _) => {
456            desugar_expr(&mut cf.iterable);
457            for stmt in &mut cf.body {
458                desugar_statement(stmt);
459            }
460        }
461        Expr::Reference { expr: inner, .. } => desugar_expr(inner),
462        Expr::TableRows(rows, _) => {
463            for row in rows {
464                for elem in row {
465                    desugar_expr(elem);
466                }
467            }
468        }
469    }
470}
471
472/// Desugar a from query expression into method chain calls.
473///
474/// Example transformation:
475/// ```text
476/// from t in trades where t.amount > 1000 order by t.date desc select t.price
477/// ```
478/// becomes:
479/// ```text
480/// trades.filter(|t| t.amount > 1000).orderBy(|t| t.date, "desc").map(|t| t.price)
481/// ```
482fn desugar_from_query(from_query: &FromQueryExpr, span: Span) -> Expr {
483    let current_var = &from_query.variable;
484    let mut result = (*from_query.source).clone();
485
486    // Track the current iteration variable (changes after group by into)
487    let mut iter_var = current_var.clone();
488
489    for clause in &from_query.clauses {
490        match clause {
491            QueryClause::Where(pred) => {
492                // source.filter(|var| predicate)
493                // Uses "filter" (Queryable interface) instead of "where" for
494                // consistency with trait-based dispatch (DbTable, DataTable, Array).
495                result = method_call(
496                    result,
497                    "filter",
498                    vec![make_lambda(&iter_var, pred, span)],
499                    span,
500                );
501            }
502            QueryClause::OrderBy(specs) => {
503                // source.orderBy(|var| key, "dir").thenBy(|var| key2, "dir2")...
504                for (i, spec) in specs.iter().enumerate() {
505                    let method = if i == 0 { "orderBy" } else { "thenBy" };
506                    let dir = if spec.descending { "desc" } else { "asc" };
507                    result = method_call(
508                        result,
509                        method,
510                        vec![
511                            make_lambda(&iter_var, &spec.key, span),
512                            string_lit(dir, span),
513                        ],
514                        span,
515                    );
516                }
517            }
518            QueryClause::GroupBy { key, into_var, .. } => {
519                // source.groupBy(|var| key)
520                result = method_call(
521                    result,
522                    "groupBy",
523                    vec![make_lambda(&iter_var, key, span)],
524                    span,
525                );
526                // After grouping, iteration variable changes to the group
527                if let Some(var) = into_var {
528                    iter_var = var.clone();
529                }
530            }
531            QueryClause::Join {
532                variable,
533                source: join_source,
534                left_key,
535                right_key,
536                into_var,
537            } => {
538                // source.innerJoin(other, |left| leftKey, |right| rightKey, |left, right| result)
539                // or with into: source.leftJoin(...)
540                let method = if into_var.is_some() {
541                    "leftJoin"
542                } else {
543                    "innerJoin"
544                };
545
546                // Build result selector that creates object with both variables
547                let result_selector = make_binary_lambda(
548                    &iter_var,
549                    variable,
550                    &make_object(
551                        vec![
552                            (iter_var.clone(), Expr::Identifier(iter_var.clone(), span)),
553                            (variable.clone(), Expr::Identifier(variable.clone(), span)),
554                        ],
555                        span,
556                    ),
557                    span,
558                );
559
560                result = method_call(
561                    result,
562                    method,
563                    vec![
564                        (**join_source).clone(),
565                        make_lambda(&iter_var, left_key, span),
566                        make_lambda(variable, right_key, span),
567                        result_selector,
568                    ],
569                    span,
570                );
571            }
572            QueryClause::Let {
573                variable: let_var,
574                value,
575            } => {
576                // Transform to intermediate select that adds the binding
577                // source.select(|var| { __orig: var, let_var: value })
578                let intermediate = make_object(
579                    vec![
580                        (
581                            "__orig".to_string(),
582                            Expr::Identifier(iter_var.clone(), span),
583                        ),
584                        (let_var.clone(), (**value).clone()),
585                    ],
586                    span,
587                );
588                result = method_call(
589                    result,
590                    "select",
591                    vec![make_lambda(&iter_var, &intermediate, span)],
592                    span,
593                );
594                // Update iter_var to access __orig for the original variable
595                // This is a simplification; in a full implementation we'd rewrite references
596                iter_var = "__x".to_string();
597            }
598        }
599    }
600
601    // Final select → desugars to .map() for typed table algebra
602    result = method_call(
603        result,
604        "map",
605        vec![make_lambda(&iter_var, &from_query.select, span)],
606        span,
607    );
608
609    result
610}
611
612/// Create a method call expression
613fn method_call(receiver: Expr, method: &str, args: Vec<Expr>, span: Span) -> Expr {
614    Expr::MethodCall {
615        receiver: Box::new(receiver),
616        method: method.to_string(),
617        args,
618        named_args: vec![],
619        span,
620    }
621}
622
623/// Create a lambda expression: |param| body
624fn make_lambda(param: &str, body: &Expr, span: Span) -> Expr {
625    Expr::FunctionExpr {
626        params: vec![FunctionParameter {
627            pattern: DestructurePattern::Identifier(param.to_string(), span),
628            is_const: false,
629            is_reference: false,
630            is_mut_reference: false,
631            is_out: false,
632            type_annotation: None,
633            default_value: None,
634        }],
635        return_type: None,
636        body: vec![Statement::Return(Some(body.clone()), span)],
637        span,
638    }
639}
640
641/// Create a binary lambda expression: |param1, param2| body
642fn make_binary_lambda(param1: &str, param2: &str, body: &Expr, span: Span) -> Expr {
643    Expr::FunctionExpr {
644        params: vec![
645            FunctionParameter {
646                pattern: DestructurePattern::Identifier(param1.to_string(), span),
647                is_const: false,
648                is_reference: false,
649                is_mut_reference: false,
650                is_out: false,
651                type_annotation: None,
652                default_value: None,
653            },
654            FunctionParameter {
655                pattern: DestructurePattern::Identifier(param2.to_string(), span),
656                is_const: false,
657                is_reference: false,
658                is_mut_reference: false,
659                is_out: false,
660                type_annotation: None,
661                default_value: None,
662            },
663        ],
664        return_type: None,
665        body: vec![Statement::Return(Some(body.clone()), span)],
666        span,
667    }
668}
669
670/// Create a string literal expression
671fn string_lit(s: &str, span: Span) -> Expr {
672    Expr::Literal(Literal::String(s.to_string()), span)
673}
674
675/// Create an object literal expression
676fn make_object(fields: Vec<(String, Expr)>, span: Span) -> Expr {
677    let entries = fields
678        .into_iter()
679        .map(|(key, value)| ObjectEntry::Field {
680            key,
681            value,
682            type_annotation: None,
683        })
684        .collect();
685    Expr::Object(entries, span)
686}
687
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// Helper to extract expression from a program item
    fn get_expr(item: &Item) -> Option<&Expr> {
        match item {
            Item::Expression(expr, _) => Some(expr),
            Item::Statement(Statement::Expression(expr, _), _) => Some(expr),
            _ => None,
        }
    }

    /// Parse `code` and run the desugaring pass over the resulting program.
    fn desugared(code: &str) -> Program {
        let mut program = parse_program(code).unwrap();
        desugar_program(&mut program);
        program
    }

    /// Assert that `expr` is a call to method `expected` and hand back the
    /// receiver, so a chain can be checked link by link from the outside in.
    fn expect_method_call<'a>(expr: &'a Expr, expected: &str) -> &'a Expr {
        match expr {
            Expr::MethodCall {
                receiver, method, ..
            } => {
                assert_eq!(method.as_str(), expected);
                &**receiver
            }
            _ => panic!("expected a call to `{}`", expected),
        }
    }

    #[test]
    fn test_basic_from_query_desugaring() {
        let program = desugared("from x in [1, 2, 3] where x > 1 select x * 2");
        let expr = get_expr(&program.items[0]).expect("Expected expression item");

        // Should be: [1,2,3].filter(|x| x > 1).map(|x| x * 2)
        let filtered = expect_method_call(expr, "map");
        expect_method_call(filtered, "filter");
    }

    #[test]
    fn test_order_by_desugaring() {
        let program = desugared("from x in arr order by x.value desc select x");
        let expr = get_expr(&program.items[0]).expect("Expected expression item");

        // Should be: arr.orderBy(|x| x.value, "desc").map(|x| x)
        // The final `select` is emitted as `map`, applied on top of `orderBy`.
        let ordered = expect_method_call(expr, "map");
        expect_method_call(ordered, "orderBy");
    }
}