Skip to main content

shape_ast/transform/
desugar.rs

1//! AST desugaring pass
2//!
3//! Transforms high-level syntax into equivalent lower-level constructs.
4//! Currently handles LINQ-style from queries → method chains.
5
6use crate::ast::{
7    DestructurePattern, Expr, FromQueryExpr, FunctionParameter, Item, Literal, ObjectEntry,
8    Program, QueryClause, Span, Statement,
9};
10
11/// Desugar all high-level syntax in a program.
12/// This should be called before compilation.
13pub fn desugar_program(program: &mut Program) {
14    for item in &mut program.items {
15        desugar_item(item);
16    }
17}
18
19fn desugar_item(item: &mut Item) {
20    match item {
21        Item::Function(func, _) => {
22            for stmt in &mut func.body {
23                desugar_statement(stmt);
24            }
25        }
26        Item::VariableDecl(decl, _) => {
27            if let Some(value) = &mut decl.value {
28                desugar_expr(value);
29            }
30        }
31        Item::Assignment(assign, _) => {
32            desugar_expr(&mut assign.value);
33        }
34        Item::Expression(expr, _) => {
35            desugar_expr(expr);
36        }
37        Item::Statement(stmt, _) => {
38            desugar_statement(stmt);
39        }
40        Item::Export(export, _) => match &mut export.item {
41            crate::ast::ExportItem::Function(func) => {
42                for stmt in &mut func.body {
43                    desugar_statement(stmt);
44                }
45            }
46            crate::ast::ExportItem::TypeAlias(_)
47            | crate::ast::ExportItem::Named(_)
48            | crate::ast::ExportItem::Enum(_)
49            | crate::ast::ExportItem::Struct(_)
50            | crate::ast::ExportItem::Trait(_)
51            | crate::ast::ExportItem::BuiltinFunction(_)
52            | crate::ast::ExportItem::BuiltinType(_)
53            | crate::ast::ExportItem::Annotation(_)
54            | crate::ast::ExportItem::ForeignFunction(_) => {}
55        },
56        Item::Module(module, _) => {
57            for inner in &mut module.items {
58                desugar_item(inner);
59            }
60        }
61        Item::Extend(extend, _) => {
62            for method in &mut extend.methods {
63                for stmt in &mut method.body {
64                    desugar_statement(stmt);
65                }
66            }
67        }
68        Item::Stream(stream, _) => {
69            for decl in &mut stream.state {
70                if let Some(value) = &mut decl.value {
71                    desugar_expr(value);
72                }
73            }
74            if let Some(stmts) = &mut stream.on_connect {
75                for stmt in stmts {
76                    desugar_statement(stmt);
77                }
78            }
79            if let Some(stmts) = &mut stream.on_disconnect {
80                for stmt in stmts {
81                    desugar_statement(stmt);
82                }
83            }
84            if let Some(on_event) = &mut stream.on_event {
85                for stmt in &mut on_event.body {
86                    desugar_statement(stmt);
87                }
88            }
89            if let Some(on_window) = &mut stream.on_window {
90                for stmt in &mut on_window.body {
91                    desugar_statement(stmt);
92                }
93            }
94            if let Some(on_error) = &mut stream.on_error {
95                for stmt in &mut on_error.body {
96                    desugar_statement(stmt);
97                }
98            }
99        }
100        // Other items don't need desugaring
101        _ => {}
102    }
103}
104
105fn desugar_statement(stmt: &mut Statement) {
106    match stmt {
107        Statement::Return(Some(expr), _) => desugar_expr(expr),
108        Statement::VariableDecl(decl, _) => {
109            if let Some(value) = &mut decl.value {
110                desugar_expr(value);
111            }
112        }
113        Statement::Assignment(assign, _) => desugar_expr(&mut assign.value),
114        Statement::Expression(expr, _) => desugar_expr(expr),
115        Statement::For(for_loop, _) => {
116            match &mut for_loop.init {
117                crate::ast::ForInit::ForIn { iter, .. } => desugar_expr(iter),
118                crate::ast::ForInit::ForC {
119                    init,
120                    condition,
121                    update,
122                } => {
123                    desugar_statement(init);
124                    desugar_expr(condition);
125                    desugar_expr(update);
126                }
127            }
128            for s in &mut for_loop.body {
129                desugar_statement(s);
130            }
131        }
132        Statement::While(while_loop, _) => {
133            desugar_expr(&mut while_loop.condition);
134            for s in &mut while_loop.body {
135                desugar_statement(s);
136            }
137        }
138        Statement::If(if_stmt, _) => {
139            desugar_expr(&mut if_stmt.condition);
140            for s in &mut if_stmt.then_body {
141                desugar_statement(s);
142            }
143            if let Some(else_body) = &mut if_stmt.else_body {
144                for s in else_body {
145                    desugar_statement(s);
146                }
147            }
148        }
149        _ => {}
150    }
151}
152
153fn desugar_expr(expr: &mut Expr) {
154    // First recursively desugar nested expressions
155    match expr {
156        Expr::FromQuery(from_query, span) => {
157            // Desugar nested expressions in the query first
158            desugar_expr(&mut from_query.source);
159            for clause in &mut from_query.clauses {
160                match clause {
161                    QueryClause::Where(pred) => desugar_expr(pred),
162                    QueryClause::OrderBy(specs) => {
163                        for spec in specs {
164                            desugar_expr(&mut spec.key);
165                        }
166                    }
167                    QueryClause::GroupBy { element, key, .. } => {
168                        desugar_expr(element);
169                        desugar_expr(key);
170                    }
171                    QueryClause::Join {
172                        source,
173                        left_key,
174                        right_key,
175                        ..
176                    } => {
177                        desugar_expr(source);
178                        desugar_expr(left_key);
179                        desugar_expr(right_key);
180                    }
181                    QueryClause::Let { value, .. } => desugar_expr(value),
182                }
183            }
184            desugar_expr(&mut from_query.select);
185
186            // Now desugar the from query to method chains
187            let desugared = desugar_from_query(from_query, *span);
188            *expr = desugared;
189        }
190        // Recursively handle all other expression types
191        Expr::PropertyAccess { object, .. } => desugar_expr(object),
192        Expr::IndexAccess {
193            object,
194            index,
195            end_index,
196            ..
197        } => {
198            desugar_expr(object);
199            desugar_expr(index);
200            if let Some(end) = end_index {
201                desugar_expr(end);
202            }
203        }
204        Expr::BinaryOp { left, right, .. } => {
205            desugar_expr(left);
206            desugar_expr(right);
207        }
208        Expr::FuzzyComparison { left, right, .. } => {
209            desugar_expr(left);
210            desugar_expr(right);
211        }
212        Expr::UnaryOp { operand, .. } => desugar_expr(operand),
213        Expr::FunctionCall {
214            args, named_args, ..
215        } => {
216            for arg in args {
217                desugar_expr(arg);
218            }
219            for (_, val) in named_args {
220                desugar_expr(val);
221            }
222        }
223        Expr::QualifiedFunctionCall {
224            args, named_args, ..
225        } => {
226            for arg in args {
227                desugar_expr(arg);
228            }
229            for (_, val) in named_args {
230                desugar_expr(val);
231            }
232        }
233        Expr::MethodCall {
234            receiver,
235            args,
236            named_args,
237            ..
238        } => {
239            desugar_expr(receiver);
240            for arg in args {
241                desugar_expr(arg);
242            }
243            for (_, val) in named_args {
244                desugar_expr(val);
245            }
246        }
247        Expr::Conditional {
248            condition,
249            then_expr,
250            else_expr,
251            ..
252        } => {
253            desugar_expr(condition);
254            desugar_expr(then_expr);
255            if let Some(e) = else_expr {
256                desugar_expr(e);
257            }
258        }
259        Expr::Object(entries, _) => {
260            for entry in entries {
261                match entry {
262                    ObjectEntry::Field { value, .. } => desugar_expr(value),
263                    ObjectEntry::Spread(e) => desugar_expr(e),
264                }
265            }
266        }
267        Expr::Array(elements, _) => {
268            for elem in elements {
269                desugar_expr(elem);
270            }
271        }
272        Expr::ListComprehension(comp, _) => {
273            desugar_expr(&mut comp.element);
274            for clause in &mut comp.clauses {
275                desugar_expr(&mut clause.iterable);
276                if let Some(filter) = &mut clause.filter {
277                    desugar_expr(filter);
278                }
279            }
280        }
281        Expr::Block(block, _) => {
282            for item in &mut block.items {
283                match item {
284                    crate::ast::BlockItem::VariableDecl(decl) => {
285                        if let Some(value) = &mut decl.value {
286                            desugar_expr(value);
287                        }
288                    }
289                    crate::ast::BlockItem::Assignment(assign) => {
290                        desugar_expr(&mut assign.value);
291                    }
292                    crate::ast::BlockItem::Statement(stmt) => {
293                        desugar_statement(stmt);
294                    }
295                    crate::ast::BlockItem::Expression(e) => {
296                        desugar_expr(e);
297                    }
298                }
299            }
300        }
301        Expr::TypeAssertion { expr: inner, .. } => desugar_expr(inner),
302        Expr::InstanceOf { expr: inner, .. } => desugar_expr(inner),
303        Expr::FunctionExpr { body, .. } => {
304            for stmt in body {
305                desugar_statement(stmt);
306            }
307        }
308        Expr::Spread(inner, _) => desugar_expr(inner),
309        Expr::If(if_expr, _) => {
310            desugar_expr(&mut if_expr.condition);
311            desugar_expr(&mut if_expr.then_branch);
312            if let Some(e) = &mut if_expr.else_branch {
313                desugar_expr(e);
314            }
315        }
316        Expr::While(while_expr, _) => {
317            desugar_expr(&mut while_expr.condition);
318            desugar_expr(&mut while_expr.body);
319        }
320        Expr::For(for_expr, _) => {
321            desugar_expr(&mut for_expr.iterable);
322            desugar_expr(&mut for_expr.body);
323        }
324        Expr::Loop(loop_expr, _) => {
325            desugar_expr(&mut loop_expr.body);
326        }
327        Expr::Let(let_expr, _) => {
328            if let Some(val) = &mut let_expr.value {
329                desugar_expr(val);
330            }
331            desugar_expr(&mut let_expr.body);
332        }
333        Expr::Assign(assign, _) => {
334            desugar_expr(&mut assign.target);
335            desugar_expr(&mut assign.value);
336        }
337        Expr::Break(Some(e), _) => desugar_expr(e),
338        Expr::Return(Some(e), _) => desugar_expr(e),
339        Expr::Match(match_expr, _) => {
340            desugar_expr(&mut match_expr.scrutinee);
341            for arm in &mut match_expr.arms {
342                if let Some(guard) = &mut arm.guard {
343                    desugar_expr(guard);
344                }
345                desugar_expr(&mut arm.body);
346            }
347        }
348        Expr::Range { start, end, .. } => {
349            if let Some(s) = start {
350                desugar_expr(s);
351            }
352            if let Some(e) = end {
353                desugar_expr(e);
354            }
355        }
356        Expr::TimeframeContext { expr: inner, .. } => desugar_expr(inner),
357        Expr::TryOperator(inner, _) => desugar_expr(inner),
358        Expr::UsingImpl { expr: inner, .. } => desugar_expr(inner),
359        Expr::Await(inner, _) => desugar_expr(inner),
360        Expr::EnumConstructor { payload, .. } => match payload {
361            crate::ast::EnumConstructorPayload::Tuple(args) => {
362                for arg in args {
363                    desugar_expr(arg);
364                }
365            }
366            crate::ast::EnumConstructorPayload::Struct(fields) => {
367                for (_, val) in fields {
368                    desugar_expr(val);
369                }
370            }
371            crate::ast::EnumConstructorPayload::Unit => {}
372        },
373        Expr::SimulationCall { params, .. } => {
374            for (_, val) in params {
375                desugar_expr(val);
376            }
377        }
378        Expr::WindowExpr(window_expr, _) => {
379            // Desugar window function arguments
380            match &mut window_expr.function {
381                crate::ast::WindowFunction::Lead { expr, default, .. }
382                | crate::ast::WindowFunction::Lag { expr, default, .. } => {
383                    desugar_expr(expr);
384                    if let Some(d) = default {
385                        desugar_expr(d);
386                    }
387                }
388                crate::ast::WindowFunction::FirstValue(e)
389                | crate::ast::WindowFunction::LastValue(e)
390                | crate::ast::WindowFunction::Sum(e)
391                | crate::ast::WindowFunction::Avg(e)
392                | crate::ast::WindowFunction::Min(e)
393                | crate::ast::WindowFunction::Max(e) => {
394                    desugar_expr(e);
395                }
396                crate::ast::WindowFunction::NthValue(e, _) => {
397                    desugar_expr(e);
398                }
399                crate::ast::WindowFunction::Count(Some(e)) => {
400                    desugar_expr(e);
401                }
402                _ => {}
403            }
404            // Desugar partition_by and order_by expressions
405            for e in &mut window_expr.over.partition_by {
406                desugar_expr(e);
407            }
408            if let Some(order_by) = &mut window_expr.over.order_by {
409                for (e, _) in &mut order_by.columns {
410                    desugar_expr(e);
411                }
412            }
413        }
414        Expr::DataRef(data_ref, _) => match &mut data_ref.index {
415            crate::ast::DataIndex::Expression(e) => desugar_expr(e),
416            crate::ast::DataIndex::ExpressionRange(start, end) => {
417                desugar_expr(start);
418                desugar_expr(end);
419            }
420            _ => {}
421        },
422        Expr::DataRelativeAccess {
423            reference, index, ..
424        } => {
425            desugar_expr(reference);
426            match index {
427                crate::ast::DataIndex::Expression(e) => desugar_expr(e),
428                crate::ast::DataIndex::ExpressionRange(start, end) => {
429                    desugar_expr(start);
430                    desugar_expr(end);
431                }
432                _ => {}
433            }
434        }
435        Expr::StructLiteral { fields, .. } => {
436            for (_, value) in fields {
437                desugar_expr(value);
438            }
439        }
440        // Leaf nodes - no recursion needed
441        Expr::Literal(_, _)
442        | Expr::Identifier(_, _)
443        | Expr::DataDateTimeRef(_, _)
444        | Expr::TimeRef(_, _)
445        | Expr::DateTime(_, _)
446        | Expr::PatternRef(_, _)
447        | Expr::Duration(_, _)
448        | Expr::Unit(_)
449        | Expr::Continue(_)
450        | Expr::Break(None, _)
451        | Expr::Return(None, _)
452        | Expr::Join(_, _) => {}
453        Expr::Annotated { target, .. } => {
454            desugar_expr(target);
455        }
456        Expr::AsyncLet(async_let, _) => {
457            desugar_expr(&mut async_let.expr);
458        }
459        Expr::AsyncScope(inner, _) => {
460            desugar_expr(inner);
461        }
462        Expr::Comptime(stmts, _) => {
463            for stmt in stmts {
464                desugar_statement(stmt);
465            }
466        }
467        Expr::ComptimeFor(cf, _) => {
468            desugar_expr(&mut cf.iterable);
469            for stmt in &mut cf.body {
470                desugar_statement(stmt);
471            }
472        }
473        Expr::Reference { expr: inner, .. } => desugar_expr(inner),
474        Expr::TableRows(rows, _) => {
475            for row in rows {
476                for elem in row {
477                    desugar_expr(elem);
478                }
479            }
480        }
481    }
482}
483
484/// Desugar a from query expression into method chain calls.
485///
486/// Example transformation:
487/// ```text
488/// from t in trades where t.amount > 1000 order by t.date desc select t.price
489/// ```
490/// becomes:
491/// ```text
492/// trades.filter(|t| t.amount > 1000).orderBy(|t| t.date, "desc").map(|t| t.price)
493/// ```
494fn desugar_from_query(from_query: &FromQueryExpr, span: Span) -> Expr {
495    let current_var = &from_query.variable;
496    let mut result = (*from_query.source).clone();
497
498    // Track the current iteration variable (changes after group by into)
499    let mut iter_var = current_var.clone();
500
501    for clause in &from_query.clauses {
502        match clause {
503            QueryClause::Where(pred) => {
504                // source.filter(|var| predicate)
505                // Uses "filter" (Queryable interface) instead of "where" for
506                // consistency with trait-based dispatch (DbTable, DataTable, Array).
507                result = method_call(
508                    result,
509                    "filter",
510                    vec![make_lambda(&iter_var, pred, span)],
511                    span,
512                );
513            }
514            QueryClause::OrderBy(specs) => {
515                // source.orderBy(|var| key, "dir").thenBy(|var| key2, "dir2")...
516                for (i, spec) in specs.iter().enumerate() {
517                    let method = if i == 0 { "orderBy" } else { "thenBy" };
518                    let dir = if spec.descending { "desc" } else { "asc" };
519                    result = method_call(
520                        result,
521                        method,
522                        vec![
523                            make_lambda(&iter_var, &spec.key, span),
524                            string_lit(dir, span),
525                        ],
526                        span,
527                    );
528                }
529            }
530            QueryClause::GroupBy { key, into_var, .. } => {
531                // source.groupBy(|var| key)
532                result = method_call(
533                    result,
534                    "groupBy",
535                    vec![make_lambda(&iter_var, key, span)],
536                    span,
537                );
538                // After grouping, iteration variable changes to the group
539                if let Some(var) = into_var {
540                    iter_var = var.clone();
541                }
542            }
543            QueryClause::Join {
544                variable,
545                source: join_source,
546                left_key,
547                right_key,
548                into_var,
549            } => {
550                // source.innerJoin(other, |left| leftKey, |right| rightKey, |left, right| result)
551                // or with into: source.leftJoin(...)
552                let method = if into_var.is_some() {
553                    "leftJoin"
554                } else {
555                    "innerJoin"
556                };
557
558                // Build result selector that creates object with both variables
559                let result_selector = make_binary_lambda(
560                    &iter_var,
561                    variable,
562                    &make_object(
563                        vec![
564                            (iter_var.clone(), Expr::Identifier(iter_var.clone(), span)),
565                            (variable.clone(), Expr::Identifier(variable.clone(), span)),
566                        ],
567                        span,
568                    ),
569                    span,
570                );
571
572                result = method_call(
573                    result,
574                    method,
575                    vec![
576                        (**join_source).clone(),
577                        make_lambda(&iter_var, left_key, span),
578                        make_lambda(variable, right_key, span),
579                        result_selector,
580                    ],
581                    span,
582                );
583            }
584            QueryClause::Let {
585                variable: let_var,
586                value,
587            } => {
588                // Transform to intermediate select that adds the binding
589                // source.select(|var| { __orig: var, let_var: value })
590                let intermediate = make_object(
591                    vec![
592                        (
593                            "__orig".to_string(),
594                            Expr::Identifier(iter_var.clone(), span),
595                        ),
596                        (let_var.clone(), (**value).clone()),
597                    ],
598                    span,
599                );
600                result = method_call(
601                    result,
602                    "select",
603                    vec![make_lambda(&iter_var, &intermediate, span)],
604                    span,
605                );
606                // Update iter_var to access __orig for the original variable
607                // This is a simplification; in a full implementation we'd rewrite references
608                iter_var = "__x".to_string();
609            }
610        }
611    }
612
613    // Final select → desugars to .map() for typed table algebra
614    result = method_call(
615        result,
616        "map",
617        vec![make_lambda(&iter_var, &from_query.select, span)],
618        span,
619    );
620
621    result
622}
623
624/// Create a method call expression
625fn method_call(receiver: Expr, method: &str, args: Vec<Expr>, span: Span) -> Expr {
626    Expr::MethodCall {
627        receiver: Box::new(receiver),
628        method: method.to_string(),
629        args,
630        named_args: vec![],
631        optional: false,
632        span,
633    }
634}
635
636/// Create a lambda expression: |param| body
637fn make_lambda(param: &str, body: &Expr, span: Span) -> Expr {
638    Expr::FunctionExpr {
639        params: vec![FunctionParameter {
640            pattern: DestructurePattern::Identifier(param.to_string(), span),
641            is_const: false,
642            is_reference: false,
643            is_mut_reference: false,
644            is_out: false,
645            type_annotation: None,
646            default_value: None,
647        }],
648        return_type: None,
649        body: vec![Statement::Return(Some(body.clone()), span)],
650        span,
651    }
652}
653
654/// Create a binary lambda expression: |param1, param2| body
655fn make_binary_lambda(param1: &str, param2: &str, body: &Expr, span: Span) -> Expr {
656    Expr::FunctionExpr {
657        params: vec![
658            FunctionParameter {
659                pattern: DestructurePattern::Identifier(param1.to_string(), span),
660                is_const: false,
661                is_reference: false,
662                is_mut_reference: false,
663                is_out: false,
664                type_annotation: None,
665                default_value: None,
666            },
667            FunctionParameter {
668                pattern: DestructurePattern::Identifier(param2.to_string(), span),
669                is_const: false,
670                is_reference: false,
671                is_mut_reference: false,
672                is_out: false,
673                type_annotation: None,
674                default_value: None,
675            },
676        ],
677        return_type: None,
678        body: vec![Statement::Return(Some(body.clone()), span)],
679        span,
680    }
681}
682
683/// Create a string literal expression
684fn string_lit(s: &str, span: Span) -> Expr {
685    Expr::Literal(Literal::String(s.to_string()), span)
686}
687
688/// Create an object literal expression
689fn make_object(fields: Vec<(String, Expr)>, span: Span) -> Expr {
690    let entries = fields
691        .into_iter()
692        .map(|(key, value)| ObjectEntry::Field {
693            key,
694            value,
695            type_annotation: None,
696        })
697        .collect();
698    Expr::Object(entries, span)
699}
700
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// Helper to extract expression from a program item.
    ///
    /// Accepts both a bare expression item and an expression statement,
    /// whichever form the parser produces.
    fn get_expr(item: &Item) -> Option<&Expr> {
        match item {
            Item::Expression(expr, _) => Some(expr),
            Item::Statement(Statement::Expression(expr, _), _) => Some(expr),
            _ => None,
        }
    }

    /// Parse and desugar `code`, returning the expression of its first item.
    fn desugar_first_expr(code: &str) -> Expr {
        let mut program = parse_program(code).unwrap();
        desugar_program(&mut program);
        get_expr(&program.items[0])
            .cloned()
            .expect("Expected expression item")
    }

    /// Split a method call into its method name and receiver.
    fn split_method_call(expr: &Expr) -> (&str, &Expr) {
        match expr {
            Expr::MethodCall {
                method, receiver, ..
            } => (method.as_str(), &**receiver),
            _ => panic!("Expected method call"),
        }
    }

    #[test]
    fn test_basic_from_query_desugaring() {
        // Should be: [1,2,3].filter(|x| x > 1).map(|x| x * 2)
        let expr = desugar_first_expr("from x in [1, 2, 3] where x > 1 select x * 2");
        let (outer, receiver) = split_method_call(&expr);
        assert_eq!(outer, "map");
        let (inner, _) = split_method_call(receiver);
        assert_eq!(inner, "filter");
    }

    #[test]
    fn test_order_by_desugaring() {
        // Should be: arr.orderBy(|x| x.value, "desc").map(|x| x)
        let expr = desugar_first_expr("from x in arr order by x.value desc select x");
        // Final call is the select, lowered to map
        let (outer, receiver) = split_method_call(&expr);
        assert_eq!(outer, "map");
        match receiver {
            Expr::MethodCall { method, args, .. } => {
                assert_eq!(method, "orderBy");
                assert_eq!(args.len(), 2);
                assert!(matches!(
                    &args[1],
                    Expr::Literal(Literal::String(dir), _) if dir == "desc"
                ));
            }
            _ => panic!("Expected orderBy call"),
        }
    }
}