Skip to main content

shape_ast/transform/
desugar.rs

1//! AST desugaring pass
2//!
3//! Transforms high-level syntax into equivalent lower-level constructs.
4//! Currently handles LINQ-style from queries → method chains.
5
6use crate::ast::{
7    DestructurePattern, Expr, FromQueryExpr, FunctionParameter, Item, Literal, ObjectEntry,
8    Program, QueryClause, Span, Statement,
9};
10
11/// Desugar all high-level syntax in a program.
12/// This should be called before compilation.
13pub fn desugar_program(program: &mut Program) {
14    for item in &mut program.items {
15        desugar_item(item);
16    }
17}
18
19fn desugar_item(item: &mut Item) {
20    match item {
21        Item::Function(func, _) => {
22            for stmt in &mut func.body {
23                desugar_statement(stmt);
24            }
25        }
26        Item::VariableDecl(decl, _) => {
27            if let Some(value) = &mut decl.value {
28                desugar_expr(value);
29            }
30        }
31        Item::Assignment(assign, _) => {
32            desugar_expr(&mut assign.value);
33        }
34        Item::Expression(expr, _) => {
35            desugar_expr(expr);
36        }
37        Item::Statement(stmt, _) => {
38            desugar_statement(stmt);
39        }
40        Item::Export(export, _) => match &mut export.item {
41            crate::ast::ExportItem::Function(func) => {
42                for stmt in &mut func.body {
43                    desugar_statement(stmt);
44                }
45            }
46            crate::ast::ExportItem::TypeAlias(_)
47            | crate::ast::ExportItem::Named(_)
48            | crate::ast::ExportItem::Enum(_)
49            | crate::ast::ExportItem::Struct(_)
50            | crate::ast::ExportItem::Interface(_)
51            | crate::ast::ExportItem::Trait(_)
52            | crate::ast::ExportItem::BuiltinFunction(_)
53            | crate::ast::ExportItem::BuiltinType(_)
54            | crate::ast::ExportItem::Annotation(_)
55            | crate::ast::ExportItem::ForeignFunction(_) => {}
56        },
57        Item::Module(module, _) => {
58            for inner in &mut module.items {
59                desugar_item(inner);
60            }
61        }
62        Item::Extend(extend, _) => {
63            for method in &mut extend.methods {
64                for stmt in &mut method.body {
65                    desugar_statement(stmt);
66                }
67            }
68        }
69        Item::Stream(stream, _) => {
70            for decl in &mut stream.state {
71                if let Some(value) = &mut decl.value {
72                    desugar_expr(value);
73                }
74            }
75            if let Some(stmts) = &mut stream.on_connect {
76                for stmt in stmts {
77                    desugar_statement(stmt);
78                }
79            }
80            if let Some(stmts) = &mut stream.on_disconnect {
81                for stmt in stmts {
82                    desugar_statement(stmt);
83                }
84            }
85            if let Some(on_event) = &mut stream.on_event {
86                for stmt in &mut on_event.body {
87                    desugar_statement(stmt);
88                }
89            }
90            if let Some(on_window) = &mut stream.on_window {
91                for stmt in &mut on_window.body {
92                    desugar_statement(stmt);
93                }
94            }
95            if let Some(on_error) = &mut stream.on_error {
96                for stmt in &mut on_error.body {
97                    desugar_statement(stmt);
98                }
99            }
100        }
101        // Other items don't need desugaring
102        _ => {}
103    }
104}
105
106fn desugar_statement(stmt: &mut Statement) {
107    match stmt {
108        Statement::Return(Some(expr), _) => desugar_expr(expr),
109        Statement::VariableDecl(decl, _) => {
110            if let Some(value) = &mut decl.value {
111                desugar_expr(value);
112            }
113        }
114        Statement::Assignment(assign, _) => desugar_expr(&mut assign.value),
115        Statement::Expression(expr, _) => desugar_expr(expr),
116        Statement::For(for_loop, _) => {
117            match &mut for_loop.init {
118                crate::ast::ForInit::ForIn { iter, .. } => desugar_expr(iter),
119                crate::ast::ForInit::ForC {
120                    init,
121                    condition,
122                    update,
123                } => {
124                    desugar_statement(init);
125                    desugar_expr(condition);
126                    desugar_expr(update);
127                }
128            }
129            for s in &mut for_loop.body {
130                desugar_statement(s);
131            }
132        }
133        Statement::While(while_loop, _) => {
134            desugar_expr(&mut while_loop.condition);
135            for s in &mut while_loop.body {
136                desugar_statement(s);
137            }
138        }
139        Statement::If(if_stmt, _) => {
140            desugar_expr(&mut if_stmt.condition);
141            for s in &mut if_stmt.then_body {
142                desugar_statement(s);
143            }
144            if let Some(else_body) = &mut if_stmt.else_body {
145                for s in else_body {
146                    desugar_statement(s);
147                }
148            }
149        }
150        _ => {}
151    }
152}
153
154fn desugar_expr(expr: &mut Expr) {
155    // First recursively desugar nested expressions
156    match expr {
157        Expr::FromQuery(from_query, span) => {
158            // Desugar nested expressions in the query first
159            desugar_expr(&mut from_query.source);
160            for clause in &mut from_query.clauses {
161                match clause {
162                    QueryClause::Where(pred) => desugar_expr(pred),
163                    QueryClause::OrderBy(specs) => {
164                        for spec in specs {
165                            desugar_expr(&mut spec.key);
166                        }
167                    }
168                    QueryClause::GroupBy { element, key, .. } => {
169                        desugar_expr(element);
170                        desugar_expr(key);
171                    }
172                    QueryClause::Join {
173                        source,
174                        left_key,
175                        right_key,
176                        ..
177                    } => {
178                        desugar_expr(source);
179                        desugar_expr(left_key);
180                        desugar_expr(right_key);
181                    }
182                    QueryClause::Let { value, .. } => desugar_expr(value),
183                }
184            }
185            desugar_expr(&mut from_query.select);
186
187            // Now desugar the from query to method chains
188            let desugared = desugar_from_query(from_query, *span);
189            *expr = desugared;
190        }
191        // Recursively handle all other expression types
192        Expr::PropertyAccess { object, .. } => desugar_expr(object),
193        Expr::IndexAccess {
194            object,
195            index,
196            end_index,
197            ..
198        } => {
199            desugar_expr(object);
200            desugar_expr(index);
201            if let Some(end) = end_index {
202                desugar_expr(end);
203            }
204        }
205        Expr::BinaryOp { left, right, .. } => {
206            desugar_expr(left);
207            desugar_expr(right);
208        }
209        Expr::FuzzyComparison { left, right, .. } => {
210            desugar_expr(left);
211            desugar_expr(right);
212        }
213        Expr::UnaryOp { operand, .. } => desugar_expr(operand),
214        Expr::FunctionCall {
215            args, named_args, ..
216        } => {
217            for arg in args {
218                desugar_expr(arg);
219            }
220            for (_, val) in named_args {
221                desugar_expr(val);
222            }
223        }
224        Expr::QualifiedFunctionCall {
225            args, named_args, ..
226        } => {
227            for arg in args {
228                desugar_expr(arg);
229            }
230            for (_, val) in named_args {
231                desugar_expr(val);
232            }
233        }
234        Expr::MethodCall {
235            receiver,
236            args,
237            named_args,
238            ..
239        } => {
240            desugar_expr(receiver);
241            for arg in args {
242                desugar_expr(arg);
243            }
244            for (_, val) in named_args {
245                desugar_expr(val);
246            }
247        }
248        Expr::Conditional {
249            condition,
250            then_expr,
251            else_expr,
252            ..
253        } => {
254            desugar_expr(condition);
255            desugar_expr(then_expr);
256            if let Some(e) = else_expr {
257                desugar_expr(e);
258            }
259        }
260        Expr::Object(entries, _) => {
261            for entry in entries {
262                match entry {
263                    ObjectEntry::Field { value, .. } => desugar_expr(value),
264                    ObjectEntry::Spread(e) => desugar_expr(e),
265                }
266            }
267        }
268        Expr::Array(elements, _) => {
269            for elem in elements {
270                desugar_expr(elem);
271            }
272        }
273        Expr::ListComprehension(comp, _) => {
274            desugar_expr(&mut comp.element);
275            for clause in &mut comp.clauses {
276                desugar_expr(&mut clause.iterable);
277                if let Some(filter) = &mut clause.filter {
278                    desugar_expr(filter);
279                }
280            }
281        }
282        Expr::Block(block, _) => {
283            for item in &mut block.items {
284                match item {
285                    crate::ast::BlockItem::VariableDecl(decl) => {
286                        if let Some(value) = &mut decl.value {
287                            desugar_expr(value);
288                        }
289                    }
290                    crate::ast::BlockItem::Assignment(assign) => {
291                        desugar_expr(&mut assign.value);
292                    }
293                    crate::ast::BlockItem::Statement(stmt) => {
294                        desugar_statement(stmt);
295                    }
296                    crate::ast::BlockItem::Expression(e) => {
297                        desugar_expr(e);
298                    }
299                }
300            }
301        }
302        Expr::TypeAssertion { expr: inner, .. } => desugar_expr(inner),
303        Expr::InstanceOf { expr: inner, .. } => desugar_expr(inner),
304        Expr::FunctionExpr { body, .. } => {
305            for stmt in body {
306                desugar_statement(stmt);
307            }
308        }
309        Expr::Spread(inner, _) => desugar_expr(inner),
310        Expr::If(if_expr, _) => {
311            desugar_expr(&mut if_expr.condition);
312            desugar_expr(&mut if_expr.then_branch);
313            if let Some(e) = &mut if_expr.else_branch {
314                desugar_expr(e);
315            }
316        }
317        Expr::While(while_expr, _) => {
318            desugar_expr(&mut while_expr.condition);
319            desugar_expr(&mut while_expr.body);
320        }
321        Expr::For(for_expr, _) => {
322            desugar_expr(&mut for_expr.iterable);
323            desugar_expr(&mut for_expr.body);
324        }
325        Expr::Loop(loop_expr, _) => {
326            desugar_expr(&mut loop_expr.body);
327        }
328        Expr::Let(let_expr, _) => {
329            if let Some(val) = &mut let_expr.value {
330                desugar_expr(val);
331            }
332            desugar_expr(&mut let_expr.body);
333        }
334        Expr::Assign(assign, _) => {
335            desugar_expr(&mut assign.target);
336            desugar_expr(&mut assign.value);
337        }
338        Expr::Break(Some(e), _) => desugar_expr(e),
339        Expr::Return(Some(e), _) => desugar_expr(e),
340        Expr::Match(match_expr, _) => {
341            desugar_expr(&mut match_expr.scrutinee);
342            for arm in &mut match_expr.arms {
343                if let Some(guard) = &mut arm.guard {
344                    desugar_expr(guard);
345                }
346                desugar_expr(&mut arm.body);
347            }
348        }
349        Expr::Range { start, end, .. } => {
350            if let Some(s) = start {
351                desugar_expr(s);
352            }
353            if let Some(e) = end {
354                desugar_expr(e);
355            }
356        }
357        Expr::TimeframeContext { expr: inner, .. } => desugar_expr(inner),
358        Expr::TryOperator(inner, _) => desugar_expr(inner),
359        Expr::UsingImpl { expr: inner, .. } => desugar_expr(inner),
360        Expr::Await(inner, _) => desugar_expr(inner),
361        Expr::EnumConstructor { payload, .. } => match payload {
362            crate::ast::EnumConstructorPayload::Tuple(args) => {
363                for arg in args {
364                    desugar_expr(arg);
365                }
366            }
367            crate::ast::EnumConstructorPayload::Struct(fields) => {
368                for (_, val) in fields {
369                    desugar_expr(val);
370                }
371            }
372            crate::ast::EnumConstructorPayload::Unit => {}
373        },
374        Expr::SimulationCall { params, .. } => {
375            for (_, val) in params {
376                desugar_expr(val);
377            }
378        }
379        Expr::WindowExpr(window_expr, _) => {
380            // Desugar window function arguments
381            match &mut window_expr.function {
382                crate::ast::WindowFunction::Lead { expr, default, .. }
383                | crate::ast::WindowFunction::Lag { expr, default, .. } => {
384                    desugar_expr(expr);
385                    if let Some(d) = default {
386                        desugar_expr(d);
387                    }
388                }
389                crate::ast::WindowFunction::FirstValue(e)
390                | crate::ast::WindowFunction::LastValue(e)
391                | crate::ast::WindowFunction::Sum(e)
392                | crate::ast::WindowFunction::Avg(e)
393                | crate::ast::WindowFunction::Min(e)
394                | crate::ast::WindowFunction::Max(e) => {
395                    desugar_expr(e);
396                }
397                crate::ast::WindowFunction::NthValue(e, _) => {
398                    desugar_expr(e);
399                }
400                crate::ast::WindowFunction::Count(Some(e)) => {
401                    desugar_expr(e);
402                }
403                _ => {}
404            }
405            // Desugar partition_by and order_by expressions
406            for e in &mut window_expr.over.partition_by {
407                desugar_expr(e);
408            }
409            if let Some(order_by) = &mut window_expr.over.order_by {
410                for (e, _) in &mut order_by.columns {
411                    desugar_expr(e);
412                }
413            }
414        }
415        Expr::DataRef(data_ref, _) => match &mut data_ref.index {
416            crate::ast::DataIndex::Expression(e) => desugar_expr(e),
417            crate::ast::DataIndex::ExpressionRange(start, end) => {
418                desugar_expr(start);
419                desugar_expr(end);
420            }
421            _ => {}
422        },
423        Expr::DataRelativeAccess {
424            reference, index, ..
425        } => {
426            desugar_expr(reference);
427            match index {
428                crate::ast::DataIndex::Expression(e) => desugar_expr(e),
429                crate::ast::DataIndex::ExpressionRange(start, end) => {
430                    desugar_expr(start);
431                    desugar_expr(end);
432                }
433                _ => {}
434            }
435        }
436        Expr::StructLiteral { fields, .. } => {
437            for (_, value) in fields {
438                desugar_expr(value);
439            }
440        }
441        // Leaf nodes - no recursion needed
442        Expr::Literal(_, _)
443        | Expr::Identifier(_, _)
444        | Expr::DataDateTimeRef(_, _)
445        | Expr::TimeRef(_, _)
446        | Expr::DateTime(_, _)
447        | Expr::PatternRef(_, _)
448        | Expr::Duration(_, _)
449        | Expr::Unit(_)
450        | Expr::Continue(_)
451        | Expr::Break(None, _)
452        | Expr::Return(None, _)
453        | Expr::Join(_, _) => {}
454        Expr::Annotated { target, .. } => {
455            desugar_expr(target);
456        }
457        Expr::AsyncLet(async_let, _) => {
458            desugar_expr(&mut async_let.expr);
459        }
460        Expr::AsyncScope(inner, _) => {
461            desugar_expr(inner);
462        }
463        Expr::Comptime(stmts, _) => {
464            for stmt in stmts {
465                desugar_statement(stmt);
466            }
467        }
468        Expr::ComptimeFor(cf, _) => {
469            desugar_expr(&mut cf.iterable);
470            for stmt in &mut cf.body {
471                desugar_statement(stmt);
472            }
473        }
474        Expr::Reference { expr: inner, .. } => desugar_expr(inner),
475        Expr::TableRows(rows, _) => {
476            for row in rows {
477                for elem in row {
478                    desugar_expr(elem);
479                }
480            }
481        }
482    }
483}
484
485/// Desugar a from query expression into method chain calls.
486///
487/// Example transformation:
488/// ```text
489/// from t in trades where t.amount > 1000 order by t.date desc select t.price
490/// ```
491/// becomes:
492/// ```text
493/// trades.filter(|t| t.amount > 1000).orderBy(|t| t.date, "desc").map(|t| t.price)
494/// ```
495fn desugar_from_query(from_query: &FromQueryExpr, span: Span) -> Expr {
496    let current_var = &from_query.variable;
497    let mut result = (*from_query.source).clone();
498
499    // Track the current iteration variable (changes after group by into)
500    let mut iter_var = current_var.clone();
501
502    for clause in &from_query.clauses {
503        match clause {
504            QueryClause::Where(pred) => {
505                // source.filter(|var| predicate)
506                // Uses "filter" (Queryable interface) instead of "where" for
507                // consistency with trait-based dispatch (DbTable, DataTable, Array).
508                result = method_call(
509                    result,
510                    "filter",
511                    vec![make_lambda(&iter_var, pred, span)],
512                    span,
513                );
514            }
515            QueryClause::OrderBy(specs) => {
516                // source.orderBy(|var| key, "dir").thenBy(|var| key2, "dir2")...
517                for (i, spec) in specs.iter().enumerate() {
518                    let method = if i == 0 { "orderBy" } else { "thenBy" };
519                    let dir = if spec.descending { "desc" } else { "asc" };
520                    result = method_call(
521                        result,
522                        method,
523                        vec![
524                            make_lambda(&iter_var, &spec.key, span),
525                            string_lit(dir, span),
526                        ],
527                        span,
528                    );
529                }
530            }
531            QueryClause::GroupBy { key, into_var, .. } => {
532                // source.groupBy(|var| key)
533                result = method_call(
534                    result,
535                    "groupBy",
536                    vec![make_lambda(&iter_var, key, span)],
537                    span,
538                );
539                // After grouping, iteration variable changes to the group
540                if let Some(var) = into_var {
541                    iter_var = var.clone();
542                }
543            }
544            QueryClause::Join {
545                variable,
546                source: join_source,
547                left_key,
548                right_key,
549                into_var,
550            } => {
551                // source.innerJoin(other, |left| leftKey, |right| rightKey, |left, right| result)
552                // or with into: source.leftJoin(...)
553                let method = if into_var.is_some() {
554                    "leftJoin"
555                } else {
556                    "innerJoin"
557                };
558
559                // Build result selector that creates object with both variables
560                let result_selector = make_binary_lambda(
561                    &iter_var,
562                    variable,
563                    &make_object(
564                        vec![
565                            (iter_var.clone(), Expr::Identifier(iter_var.clone(), span)),
566                            (variable.clone(), Expr::Identifier(variable.clone(), span)),
567                        ],
568                        span,
569                    ),
570                    span,
571                );
572
573                result = method_call(
574                    result,
575                    method,
576                    vec![
577                        (**join_source).clone(),
578                        make_lambda(&iter_var, left_key, span),
579                        make_lambda(variable, right_key, span),
580                        result_selector,
581                    ],
582                    span,
583                );
584            }
585            QueryClause::Let {
586                variable: let_var,
587                value,
588            } => {
589                // Transform to intermediate select that adds the binding
590                // source.select(|var| { __orig: var, let_var: value })
591                let intermediate = make_object(
592                    vec![
593                        (
594                            "__orig".to_string(),
595                            Expr::Identifier(iter_var.clone(), span),
596                        ),
597                        (let_var.clone(), (**value).clone()),
598                    ],
599                    span,
600                );
601                result = method_call(
602                    result,
603                    "select",
604                    vec![make_lambda(&iter_var, &intermediate, span)],
605                    span,
606                );
607                // Update iter_var to access __orig for the original variable
608                // This is a simplification; in a full implementation we'd rewrite references
609                iter_var = "__x".to_string();
610            }
611        }
612    }
613
614    // Final select → desugars to .map() for typed table algebra
615    result = method_call(
616        result,
617        "map",
618        vec![make_lambda(&iter_var, &from_query.select, span)],
619        span,
620    );
621
622    result
623}
624
625/// Create a method call expression
626fn method_call(receiver: Expr, method: &str, args: Vec<Expr>, span: Span) -> Expr {
627    Expr::MethodCall {
628        receiver: Box::new(receiver),
629        method: method.to_string(),
630        args,
631        named_args: vec![],
632        optional: false,
633        span,
634    }
635}
636
637/// Create a lambda expression: |param| body
638fn make_lambda(param: &str, body: &Expr, span: Span) -> Expr {
639    Expr::FunctionExpr {
640        params: vec![FunctionParameter {
641            pattern: DestructurePattern::Identifier(param.to_string(), span),
642            is_const: false,
643            is_reference: false,
644            is_mut_reference: false,
645            is_out: false,
646            type_annotation: None,
647            default_value: None,
648        }],
649        return_type: None,
650        body: vec![Statement::Return(Some(body.clone()), span)],
651        span,
652    }
653}
654
655/// Create a binary lambda expression: |param1, param2| body
656fn make_binary_lambda(param1: &str, param2: &str, body: &Expr, span: Span) -> Expr {
657    Expr::FunctionExpr {
658        params: vec![
659            FunctionParameter {
660                pattern: DestructurePattern::Identifier(param1.to_string(), span),
661                is_const: false,
662                is_reference: false,
663                is_mut_reference: false,
664                is_out: false,
665                type_annotation: None,
666                default_value: None,
667            },
668            FunctionParameter {
669                pattern: DestructurePattern::Identifier(param2.to_string(), span),
670                is_const: false,
671                is_reference: false,
672                is_mut_reference: false,
673                is_out: false,
674                type_annotation: None,
675                default_value: None,
676            },
677        ],
678        return_type: None,
679        body: vec![Statement::Return(Some(body.clone()), span)],
680        span,
681    }
682}
683
684/// Create a string literal expression
685fn string_lit(s: &str, span: Span) -> Expr {
686    Expr::Literal(Literal::String(s.to_string()), span)
687}
688
689/// Create an object literal expression
690fn make_object(fields: Vec<(String, Expr)>, span: Span) -> Expr {
691    let entries = fields
692        .into_iter()
693        .map(|(key, value)| ObjectEntry::Field {
694            key,
695            value,
696            type_annotation: None,
697        })
698        .collect();
699    Expr::Object(entries, span)
700}
701
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// Extract the expression carried by a top-level item, whether the parser
    /// produced a bare expression item or an expression statement.
    fn get_expr(item: &Item) -> Option<&Expr> {
        match item {
            Item::Expression(expr, _) => Some(expr),
            Item::Statement(Statement::Expression(expr, _), _) => Some(expr),
            _ => None,
        }
    }

    /// Parse `code`, desugar it, and return the first item's expression.
    fn desugar_first_expr(code: &str) -> Expr {
        let mut program = parse_program(code).unwrap();
        desugar_program(&mut program);
        get_expr(&program.items[0])
            .cloned()
            .expect("Expected expression item")
    }

    /// Split a method call into `(method name, receiver, args)`.
    ///
    /// Panics if `expr` is not a method call.
    fn as_method_call(expr: &Expr) -> (&str, &Expr, &[Expr]) {
        match expr {
            Expr::MethodCall {
                method,
                receiver,
                args,
                ..
            } => (method.as_str(), &**receiver, args.as_slice()),
            _ => panic!("expected a method call"),
        }
    }

    #[test]
    fn test_basic_from_query_desugaring() {
        // Expected: [1, 2, 3].filter(|x| x > 1).map(|x| x * 2)
        let expr = desugar_first_expr("from x in [1, 2, 3] where x > 1 select x * 2");

        let (method, receiver, args) = as_method_call(&expr);
        assert_eq!(method, "map");
        assert_eq!(args.len(), 1);
        assert!(matches!(&args[0], Expr::FunctionExpr { params, .. } if params.len() == 1));

        let (inner_method, _, inner_args) = as_method_call(receiver);
        assert_eq!(inner_method, "filter");
        assert_eq!(inner_args.len(), 1);
    }

    #[test]
    fn test_order_by_desugaring() {
        // Expected: arr.orderBy(|x| x.value, "desc").map(|x| x)
        let expr = desugar_first_expr("from x in arr order by x.value desc select x");

        // The final select lowers to `map`.
        let (method, receiver, _) = as_method_call(&expr);
        assert_eq!(method, "map");

        // The ordering call carries the key lambda and the direction string.
        let (inner_method, _, inner_args) = as_method_call(receiver);
        assert_eq!(inner_method, "orderBy");
        assert_eq!(inner_args.len(), 2);
        assert!(matches!(
            &inner_args[1],
            Expr::Literal(Literal::String(dir), _) if dir == "desc"
        ));
    }
}