Skip to main content

shape_runtime/
visitor.rs

//! AST Visitor trait and walk functions for Shape.
//!
//! This module provides a visitor pattern for traversing the AST.
//! All variants are explicitly handled — no wildcards.
//!
//! ## Per-Variant Expression Methods
//!
//! The `Visitor` trait provides fine-grained per-variant methods for expressions.
//! Each method has a default implementation that returns `true` (continue into
//! children). Override only the variants you care about.
//!
//! The visit order for each expression is:
//! 1. `visit_expr(expr)` — coarse pre-visit hook; return `false` to skip the node
//!    entirely (neither the per-variant hook, the children, nor `leave_expr` run)
//! 2. `visit_expr_<variant>(expr, span)` — per-variant hook; return `false` to skip children
//! 3. Walk children recursively
//! 4. `leave_expr(expr)` — post-visit hook
17
18use shape_ast::ast::*;
19
20/// A visitor trait for traversing Shape AST nodes.
21///
22/// All `visit_*` methods return `bool`:
23/// - `true`: continue visiting children
24/// - `false`: skip children
25///
26/// The `leave_*` methods are called after visiting all children.
27///
28/// ## Per-Variant Expression Methods
29///
30/// For finer granularity, override the per-variant expression methods
31/// (e.g., `visit_identifier`, `visit_binary_op`, `visit_method_call`).
32/// These are called from `walk_expr` after the coarse `visit_expr` hook.
33/// Each receives the full `&Expr` node and its `Span`.
34pub trait Visitor: Sized {
35    // ===== Coarse-grained visitors (called on every node) =====
36
37    /// Called before visiting any expression. Return `false` to skip entirely
38    /// (neither per-variant method nor children will be visited).
39    fn visit_expr(&mut self, _expr: &Expr) -> bool {
40        true
41    }
42    /// Called after visiting an expression and all its children.
43    fn leave_expr(&mut self, _expr: &Expr) {}
44
45    // Statement visitors
46    fn visit_stmt(&mut self, _stmt: &Statement) -> bool {
47        true
48    }
49    fn leave_stmt(&mut self, _stmt: &Statement) {}
50
51    // Item visitors
52    fn visit_item(&mut self, _item: &Item) -> bool {
53        true
54    }
55    fn leave_item(&mut self, _item: &Item) {}
56
57    // Function definition visitors
58    fn visit_function(&mut self, _func: &FunctionDef) -> bool {
59        true
60    }
61    fn leave_function(&mut self, _func: &FunctionDef) {}
62
63    // Literal visitors (kept for backward compat — also called from walk_expr)
64    fn visit_literal(&mut self, _lit: &Literal) -> bool {
65        true
66    }
67    fn leave_literal(&mut self, _lit: &Literal) {}
68
69    // Block visitors (kept for backward compat — also called from walk_expr)
70    fn visit_block(&mut self, _block: &BlockExpr) -> bool {
71        true
72    }
73    fn leave_block(&mut self, _block: &BlockExpr) {}
74
75    // ===== Per-variant expression visitors =====
76    //
77    // Each method receives the full &Expr and its Span. Return `true` to
78    // continue walking children, `false` to skip children.
79    //
80    // Default implementations return `true` (walk children).
81
82    fn visit_expr_literal(&mut self, _expr: &Expr, _span: Span) -> bool {
83        true
84    }
85    fn visit_expr_identifier(&mut self, _expr: &Expr, _span: Span) -> bool {
86        true
87    }
88    fn visit_expr_data_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
89        true
90    }
91    fn visit_expr_data_datetime_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
92        true
93    }
94    fn visit_expr_data_relative_access(&mut self, _expr: &Expr, _span: Span) -> bool {
95        true
96    }
97    fn visit_expr_property_access(&mut self, _expr: &Expr, _span: Span) -> bool {
98        true
99    }
100    fn visit_expr_index_access(&mut self, _expr: &Expr, _span: Span) -> bool {
101        true
102    }
103    fn visit_expr_binary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
104        true
105    }
106    fn visit_expr_fuzzy_comparison(&mut self, _expr: &Expr, _span: Span) -> bool {
107        true
108    }
109    fn visit_expr_unary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
110        true
111    }
112    fn visit_expr_function_call(&mut self, _expr: &Expr, _span: Span) -> bool {
113        true
114    }
115    fn visit_expr_enum_constructor(&mut self, _expr: &Expr, _span: Span) -> bool {
116        true
117    }
118    fn visit_expr_time_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
119        true
120    }
121    fn visit_expr_datetime(&mut self, _expr: &Expr, _span: Span) -> bool {
122        true
123    }
124    fn visit_expr_pattern_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
125        true
126    }
127    fn visit_expr_conditional(&mut self, _expr: &Expr, _span: Span) -> bool {
128        true
129    }
130    fn visit_expr_object(&mut self, _expr: &Expr, _span: Span) -> bool {
131        true
132    }
133    fn visit_expr_array(&mut self, _expr: &Expr, _span: Span) -> bool {
134        true
135    }
136    fn visit_expr_list_comprehension(&mut self, _expr: &Expr, _span: Span) -> bool {
137        true
138    }
139    fn visit_expr_block(&mut self, _expr: &Expr, _span: Span) -> bool {
140        true
141    }
142    fn visit_expr_type_assertion(&mut self, _expr: &Expr, _span: Span) -> bool {
143        true
144    }
145    fn visit_expr_instance_of(&mut self, _expr: &Expr, _span: Span) -> bool {
146        true
147    }
148    fn visit_expr_function_expr(&mut self, _expr: &Expr, _span: Span) -> bool {
149        true
150    }
151    fn visit_expr_duration(&mut self, _expr: &Expr, _span: Span) -> bool {
152        true
153    }
154    fn visit_expr_spread(&mut self, _expr: &Expr, _span: Span) -> bool {
155        true
156    }
157    fn visit_expr_if(&mut self, _expr: &Expr, _span: Span) -> bool {
158        true
159    }
160    fn visit_expr_while(&mut self, _expr: &Expr, _span: Span) -> bool {
161        true
162    }
163    fn visit_expr_for(&mut self, _expr: &Expr, _span: Span) -> bool {
164        true
165    }
166    fn visit_expr_loop(&mut self, _expr: &Expr, _span: Span) -> bool {
167        true
168    }
169    fn visit_expr_let(&mut self, _expr: &Expr, _span: Span) -> bool {
170        true
171    }
172    fn visit_expr_assign(&mut self, _expr: &Expr, _span: Span) -> bool {
173        true
174    }
175    fn visit_expr_break(&mut self, _expr: &Expr, _span: Span) -> bool {
176        true
177    }
178    fn visit_expr_continue(&mut self, _expr: &Expr, _span: Span) -> bool {
179        true
180    }
181    fn visit_expr_return(&mut self, _expr: &Expr, _span: Span) -> bool {
182        true
183    }
184    fn visit_expr_method_call(&mut self, _expr: &Expr, _span: Span) -> bool {
185        true
186    }
187    fn visit_expr_match(&mut self, _expr: &Expr, _span: Span) -> bool {
188        true
189    }
190    fn visit_expr_unit(&mut self, _expr: &Expr, _span: Span) -> bool {
191        true
192    }
193    fn visit_expr_range(&mut self, _expr: &Expr, _span: Span) -> bool {
194        true
195    }
196    fn visit_expr_timeframe_context(&mut self, _expr: &Expr, _span: Span) -> bool {
197        true
198    }
199    fn visit_expr_try_operator(&mut self, _expr: &Expr, _span: Span) -> bool {
200        true
201    }
202    fn visit_expr_using_impl(&mut self, _expr: &Expr, _span: Span) -> bool {
203        true
204    }
205    fn visit_expr_simulation_call(&mut self, _expr: &Expr, _span: Span) -> bool {
206        true
207    }
208    fn visit_expr_window_expr(&mut self, _expr: &Expr, _span: Span) -> bool {
209        true
210    }
211    fn visit_expr_from_query(&mut self, _expr: &Expr, _span: Span) -> bool {
212        true
213    }
214    fn visit_expr_struct_literal(&mut self, _expr: &Expr, _span: Span) -> bool {
215        true
216    }
217    fn visit_expr_await(&mut self, _expr: &Expr, _span: Span) -> bool {
218        true
219    }
220    fn visit_expr_join(&mut self, _expr: &Expr, _span: Span) -> bool {
221        true
222    }
223    fn visit_expr_annotated(&mut self, _expr: &Expr, _span: Span) -> bool {
224        true
225    }
226    fn visit_expr_async_let(&mut self, _expr: &Expr, _span: Span) -> bool {
227        true
228    }
229    fn visit_expr_async_scope(&mut self, _expr: &Expr, _span: Span) -> bool {
230        true
231    }
232    fn visit_expr_comptime(&mut self, _expr: &Expr, _span: Span) -> bool {
233        true
234    }
235    fn visit_expr_comptime_for(&mut self, _expr: &Expr, _span: Span) -> bool {
236        true
237    }
238    fn visit_expr_reference(&mut self, _expr: &Expr, _span: Span) -> bool {
239        true
240    }
241}
242
243// ===== Walk Functions =====
244
245/// Walk a program, visiting all items
246pub fn walk_program<V: Visitor>(visitor: &mut V, program: &Program) {
247    for item in &program.items {
248        walk_item(visitor, item);
249    }
250}
251
252/// Walk an item
253pub fn walk_item<V: Visitor>(visitor: &mut V, item: &Item) {
254    if !visitor.visit_item(item) {
255        return;
256    }
257
258    match item {
259        Item::Import(_, _) => {}
260        Item::Module(module_def, _) => {
261            for inner in &module_def.items {
262                walk_item(visitor, inner);
263            }
264        }
265        Item::Export(export, _) => match &export.item {
266            ExportItem::Function(func) => walk_function(visitor, func),
267            ExportItem::BuiltinFunction(_) => {}
268            ExportItem::BuiltinType(_) => {}
269            ExportItem::TypeAlias(_) => {}
270            ExportItem::Named(_) => {}
271            ExportItem::Enum(_) => {}
272            ExportItem::Struct(_) => {}
273            ExportItem::Interface(_) => {}
274            ExportItem::Trait(_) => {}
275            ExportItem::Annotation(annotation_def) => {
276                for handler in &annotation_def.handlers {
277                    walk_expr(visitor, &handler.body);
278                }
279            }
280            ExportItem::ForeignFunction(_) => {} // foreign bodies are opaque
281        },
282        Item::TypeAlias(_, _) => {}
283        Item::Interface(_, _) => {}
284        Item::Trait(_, _) => {}
285        Item::Enum(_, _) => {}
286        Item::Extend(extend, _) => {
287            for method in &extend.methods {
288                for stmt in &method.body {
289                    walk_stmt(visitor, stmt);
290                }
291            }
292        }
293        Item::Impl(impl_block, _) => {
294            for method in &impl_block.methods {
295                for stmt in &method.body {
296                    walk_stmt(visitor, stmt);
297                }
298            }
299        }
300        Item::Function(func, _) => walk_function(visitor, func),
301        Item::Query(query, _) => walk_query(visitor, query),
302        Item::VariableDecl(decl, _) => {
303            if let Some(value) = &decl.value {
304                walk_expr(visitor, value);
305            }
306        }
307        Item::Assignment(assign, _) => {
308            walk_expr(visitor, &assign.value);
309        }
310        Item::Expression(expr, _) => walk_expr(visitor, expr),
311        Item::Stream(stream, _) => {
312            for decl in &stream.state {
313                if let Some(value) = &decl.value {
314                    walk_expr(visitor, value);
315                }
316            }
317            if let Some(stmts) = &stream.on_connect {
318                for stmt in stmts {
319                    walk_stmt(visitor, stmt);
320                }
321            }
322            if let Some(stmts) = &stream.on_disconnect {
323                for stmt in stmts {
324                    walk_stmt(visitor, stmt);
325                }
326            }
327            if let Some(on_event) = &stream.on_event {
328                for stmt in &on_event.body {
329                    walk_stmt(visitor, stmt);
330                }
331            }
332            if let Some(on_window) = &stream.on_window {
333                for stmt in &on_window.body {
334                    walk_stmt(visitor, stmt);
335                }
336            }
337            if let Some(on_error) = &stream.on_error {
338                for stmt in &on_error.body {
339                    walk_stmt(visitor, stmt);
340                }
341            }
342        }
343        Item::Test(test, _) => {
344            if let Some(setup) = &test.setup {
345                for stmt in setup {
346                    walk_stmt(visitor, stmt);
347                }
348            }
349            if let Some(teardown) = &test.teardown {
350                for stmt in teardown {
351                    walk_stmt(visitor, stmt);
352                }
353            }
354            for case in &test.test_cases {
355                for test_stmt in &case.body {
356                    walk_test_statement(visitor, test_stmt);
357                }
358            }
359        }
360        Item::Optimize(opt, _) => {
361            walk_expr(visitor, &opt.range.0);
362            walk_expr(visitor, &opt.range.1);
363            if let OptimizationMetric::Custom(expr) = &opt.metric {
364                walk_expr(visitor, expr);
365            }
366        }
367        Item::Statement(stmt, _) => walk_stmt(visitor, stmt),
368        Item::AnnotationDef(ann_def, _) => {
369            // Walk the lifecycle handlers of the annotation definition
370            for handler in &ann_def.handlers {
371                walk_expr(visitor, &handler.body);
372            }
373        }
374        Item::StructType(_, _) => {
375            // No expressions to walk in struct type definitions
376        }
377        Item::DataSource(ds, _) => {
378            walk_expr(visitor, &ds.provider_expr);
379        }
380        Item::QueryDecl(_, _) => {
381            // Query declarations have no walkable expressions (SQL is a string literal)
382        }
383        Item::Comptime(stmts, _) => {
384            for stmt in stmts {
385                walk_stmt(visitor, stmt);
386            }
387        }
388        Item::BuiltinTypeDecl(_, _) => {
389            // Declaration-only intrinsic
390        }
391        Item::BuiltinFunctionDecl(_, _) => {
392            // Declaration-only intrinsic
393        }
394        Item::ForeignFunction(_, _) => {
395            // Foreign function bodies are opaque to the Shape visitor
396        }
397    }
398
399    visitor.leave_item(item);
400}
401
402/// Walk a function definition
403pub fn walk_function<V: Visitor>(visitor: &mut V, func: &FunctionDef) {
404    if !visitor.visit_function(func) {
405        return;
406    }
407
408    // Visit parameter default values
409    for param in &func.params {
410        if let Some(default) = &param.default_value {
411            walk_expr(visitor, default);
412        }
413    }
414
415    // Visit body statements
416    for stmt in &func.body {
417        walk_stmt(visitor, stmt);
418    }
419
420    visitor.leave_function(func);
421}
422
423/// Walk a query
424pub fn walk_query<V: Visitor>(visitor: &mut V, query: &Query) {
425    match query {
426        Query::Backtest(backtest) => {
427            for (_, expr) in &backtest.parameters {
428                walk_expr(visitor, expr);
429            }
430        }
431        Query::Alert(alert) => {
432            walk_expr(visitor, &alert.condition);
433        }
434        Query::With(with_query) => {
435            // Walk CTEs
436            for cte in &with_query.ctes {
437                walk_query(visitor, &cte.query);
438            }
439            // Walk main query
440            walk_query(visitor, &with_query.query);
441        }
442    }
443}
444
445/// Walk a statement
446pub fn walk_stmt<V: Visitor>(visitor: &mut V, stmt: &Statement) {
447    if !visitor.visit_stmt(stmt) {
448        return;
449    }
450
451    match stmt {
452        Statement::Return(expr, _) => {
453            if let Some(e) = expr {
454                walk_expr(visitor, e);
455            }
456        }
457        Statement::Break(_) => {}
458        Statement::Continue(_) => {}
459        Statement::VariableDecl(decl, _) => {
460            if let Some(value) = &decl.value {
461                walk_expr(visitor, value);
462            }
463        }
464        Statement::Assignment(assign, _) => {
465            walk_expr(visitor, &assign.value);
466        }
467        Statement::Expression(expr, _) => walk_expr(visitor, expr),
468        Statement::For(for_loop, _) => {
469            match &for_loop.init {
470                ForInit::ForIn { iter, .. } => walk_expr(visitor, iter),
471                ForInit::ForC {
472                    init,
473                    condition,
474                    update,
475                } => {
476                    walk_stmt(visitor, init);
477                    walk_expr(visitor, condition);
478                    walk_expr(visitor, update);
479                }
480            }
481            for stmt in &for_loop.body {
482                walk_stmt(visitor, stmt);
483            }
484        }
485        Statement::While(while_loop, _) => {
486            walk_expr(visitor, &while_loop.condition);
487            for stmt in &while_loop.body {
488                walk_stmt(visitor, stmt);
489            }
490        }
491        Statement::If(if_stmt, _) => {
492            walk_expr(visitor, &if_stmt.condition);
493            for stmt in &if_stmt.then_body {
494                walk_stmt(visitor, stmt);
495            }
496            if let Some(else_body) = &if_stmt.else_body {
497                for stmt in else_body {
498                    walk_stmt(visitor, stmt);
499                }
500            }
501        }
502        Statement::Extend(ext, _) => {
503            for method in &ext.methods {
504                for stmt in &method.body {
505                    walk_stmt(visitor, stmt);
506                }
507            }
508        }
509        Statement::RemoveTarget(_) => {}
510        Statement::SetParamType { .. }
511        | Statement::SetReturnType { .. }
512        | Statement::SetReturnExpr { .. } => {}
513        Statement::SetParamValue { expression, .. } => {
514            walk_expr(visitor, expression);
515        }
516        Statement::ReplaceModuleExpr { expression, .. } => {
517            walk_expr(visitor, expression);
518        }
519        Statement::ReplaceBodyExpr { expression, .. } => {
520            walk_expr(visitor, expression);
521        }
522        Statement::ReplaceBody { body, .. } => {
523            for stmt in body {
524                walk_stmt(visitor, stmt);
525            }
526        }
527    }
528
529    visitor.leave_stmt(stmt);
530}
531
532/// Walk an expression - ALL VARIANTS HANDLED EXPLICITLY
533///
534/// Visit order:
535/// 1. `visit_expr(expr)` — return `false` to skip entirely
536/// 2. `visit_expr_<variant>(expr, span)` — return `false` to skip children
537/// 3. Walk children recursively
538/// 4. `leave_expr(expr)`
539pub fn walk_expr<V: Visitor>(visitor: &mut V, expr: &Expr) {
540    if !visitor.visit_expr(expr) {
541        return;
542    }
543
544    match expr {
545        // Leaf nodes (no children)
546        Expr::Literal(lit, span) => {
547            if visitor.visit_expr_literal(expr, *span) {
548                visitor.visit_literal(lit);
549                visitor.leave_literal(lit);
550            }
551        }
552        Expr::Identifier(_, span) => {
553            visitor.visit_expr_identifier(expr, *span);
554        }
555        Expr::DataRef(data_ref, span) => {
556            if visitor.visit_expr_data_ref(expr, *span) {
557                match &data_ref.index {
558                    DataIndex::Expression(e) => walk_expr(visitor, e),
559                    DataIndex::ExpressionRange(start, end) => {
560                        walk_expr(visitor, start);
561                        walk_expr(visitor, end);
562                    }
563                    DataIndex::Single(_) | DataIndex::Range(_, _) => {}
564                }
565            }
566        }
567        Expr::DataDateTimeRef(_, span) => {
568            visitor.visit_expr_data_datetime_ref(expr, *span);
569        }
570        Expr::DataRelativeAccess {
571            reference,
572            index,
573            span,
574        } => {
575            if visitor.visit_expr_data_relative_access(expr, *span) {
576                walk_expr(visitor, reference);
577                match index {
578                    DataIndex::Expression(e) => walk_expr(visitor, e),
579                    DataIndex::ExpressionRange(start, end) => {
580                        walk_expr(visitor, start);
581                        walk_expr(visitor, end);
582                    }
583                    DataIndex::Single(_) | DataIndex::Range(_, _) => {}
584                }
585            }
586        }
587        Expr::PropertyAccess { object, span, .. } => {
588            if visitor.visit_expr_property_access(expr, *span) {
589                walk_expr(visitor, object);
590            }
591        }
592        Expr::IndexAccess {
593            object,
594            index,
595            end_index,
596            span,
597        } => {
598            if visitor.visit_expr_index_access(expr, *span) {
599                walk_expr(visitor, object);
600                walk_expr(visitor, index);
601                if let Some(end) = end_index {
602                    walk_expr(visitor, end);
603                }
604            }
605        }
606        Expr::BinaryOp {
607            left, right, span, ..
608        } => {
609            if visitor.visit_expr_binary_op(expr, *span) {
610                walk_expr(visitor, left);
611                walk_expr(visitor, right);
612            }
613        }
614        Expr::FuzzyComparison {
615            left, right, span, ..
616        } => {
617            if visitor.visit_expr_fuzzy_comparison(expr, *span) {
618                walk_expr(visitor, left);
619                walk_expr(visitor, right);
620            }
621        }
622        Expr::UnaryOp { operand, span, .. } => {
623            if visitor.visit_expr_unary_op(expr, *span) {
624                walk_expr(visitor, operand);
625            }
626        }
627        Expr::FunctionCall {
628            args,
629            named_args,
630            span,
631            ..
632        } => {
633            if visitor.visit_expr_function_call(expr, *span) {
634                for arg in args {
635                    walk_expr(visitor, arg);
636                }
637                for (_, value) in named_args {
638                    walk_expr(visitor, value);
639                }
640            }
641        }
642        Expr::QualifiedFunctionCall {
643            args,
644            named_args,
645            span,
646            ..
647        } => {
648            if visitor.visit_expr_function_call(expr, *span) {
649                for arg in args {
650                    walk_expr(visitor, arg);
651                }
652                for (_, value) in named_args {
653                    walk_expr(visitor, value);
654                }
655            }
656        }
657        Expr::EnumConstructor { payload, span, .. } => {
658            if visitor.visit_expr_enum_constructor(expr, *span) {
659                match payload {
660                    EnumConstructorPayload::Unit => {}
661                    EnumConstructorPayload::Tuple(values) => {
662                        for value in values {
663                            walk_expr(visitor, value);
664                        }
665                    }
666                    EnumConstructorPayload::Struct(fields) => {
667                        for (_, value) in fields {
668                            walk_expr(visitor, value);
669                        }
670                    }
671                }
672            }
673        }
674        Expr::TimeRef(_, span) => {
675            visitor.visit_expr_time_ref(expr, *span);
676        }
677        Expr::DateTime(_, span) => {
678            visitor.visit_expr_datetime(expr, *span);
679        }
680        Expr::PatternRef(_, span) => {
681            visitor.visit_expr_pattern_ref(expr, *span);
682        }
683        Expr::Conditional {
684            condition,
685            then_expr,
686            else_expr,
687            span,
688        } => {
689            if visitor.visit_expr_conditional(expr, *span) {
690                walk_expr(visitor, condition);
691                walk_expr(visitor, then_expr);
692                if let Some(else_e) = else_expr {
693                    walk_expr(visitor, else_e);
694                }
695            }
696        }
697        Expr::Object(entries, span) => {
698            if visitor.visit_expr_object(expr, *span) {
699                for entry in entries {
700                    match entry {
701                        ObjectEntry::Field { value, .. } => walk_expr(visitor, value),
702                        ObjectEntry::Spread(spread_expr) => walk_expr(visitor, spread_expr),
703                    }
704                }
705            }
706        }
707        Expr::Array(elements, span) => {
708            if visitor.visit_expr_array(expr, *span) {
709                for elem in elements {
710                    walk_expr(visitor, elem);
711                }
712            }
713        }
714        Expr::TableRows(rows, _span) => {
715            for row in rows {
716                for elem in row {
717                    walk_expr(visitor, elem);
718                }
719            }
720        }
721        Expr::ListComprehension(comp, span) => {
722            if visitor.visit_expr_list_comprehension(expr, *span) {
723                walk_expr(visitor, &comp.element);
724                for clause in &comp.clauses {
725                    walk_expr(visitor, &clause.iterable);
726                    if let Some(filter) = &clause.filter {
727                        walk_expr(visitor, filter);
728                    }
729                }
730            }
731        }
732        Expr::Block(block, span) => {
733            if visitor.visit_expr_block(expr, *span) {
734                if visitor.visit_block(block) {
735                    for item in &block.items {
736                        match item {
737                            BlockItem::VariableDecl(decl) => {
738                                if let Some(value) = &decl.value {
739                                    walk_expr(visitor, value);
740                                }
741                            }
742                            BlockItem::Assignment(assign) => {
743                                walk_expr(visitor, &assign.value);
744                            }
745                            BlockItem::Statement(stmt) => {
746                                walk_stmt(visitor, stmt);
747                            }
748                            BlockItem::Expression(e) => walk_expr(visitor, e),
749                        }
750                    }
751                    visitor.leave_block(block);
752                }
753            }
754        }
755        Expr::TypeAssertion {
756            expr: inner, span, ..
757        } => {
758            if visitor.visit_expr_type_assertion(expr, *span) {
759                walk_expr(visitor, inner);
760            }
761        }
762        Expr::InstanceOf {
763            expr: inner, span, ..
764        } => {
765            if visitor.visit_expr_instance_of(expr, *span) {
766                walk_expr(visitor, inner);
767            }
768        }
769        Expr::FunctionExpr {
770            params, body, span, ..
771        } => {
772            if visitor.visit_expr_function_expr(expr, *span) {
773                for param in params {
774                    if let Some(default) = &param.default_value {
775                        walk_expr(visitor, default);
776                    }
777                }
778                for stmt in body {
779                    walk_stmt(visitor, stmt);
780                }
781            }
782        }
783        Expr::Duration(_, span) => {
784            visitor.visit_expr_duration(expr, *span);
785        }
786        Expr::Spread(inner, span) => {
787            if visitor.visit_expr_spread(expr, *span) {
788                walk_expr(visitor, inner);
789            }
790        }
791        Expr::If(if_expr, span) => {
792            if visitor.visit_expr_if(expr, *span) {
793                walk_expr(visitor, &if_expr.condition);
794                walk_expr(visitor, &if_expr.then_branch);
795                if let Some(else_branch) = &if_expr.else_branch {
796                    walk_expr(visitor, else_branch);
797                }
798            }
799        }
800        Expr::While(while_expr, span) => {
801            if visitor.visit_expr_while(expr, *span) {
802                walk_expr(visitor, &while_expr.condition);
803                walk_expr(visitor, &while_expr.body);
804            }
805        }
806        Expr::For(for_expr, span) => {
807            if visitor.visit_expr_for(expr, *span) {
808                walk_expr(visitor, &for_expr.iterable);
809                walk_expr(visitor, &for_expr.body);
810            }
811        }
812        Expr::Loop(loop_expr, span) => {
813            if visitor.visit_expr_loop(expr, *span) {
814                walk_expr(visitor, &loop_expr.body);
815            }
816        }
817        Expr::Let(let_expr, span) => {
818            if visitor.visit_expr_let(expr, *span) {
819                if let Some(value) = &let_expr.value {
820                    walk_expr(visitor, value);
821                }
822                walk_expr(visitor, &let_expr.body);
823            }
824        }
825        Expr::Assign(assign, span) => {
826            if visitor.visit_expr_assign(expr, *span) {
827                walk_expr(visitor, &assign.target);
828                walk_expr(visitor, &assign.value);
829            }
830        }
831        Expr::Break(inner, span) => {
832            if visitor.visit_expr_break(expr, *span) {
833                if let Some(e) = inner {
834                    walk_expr(visitor, e);
835                }
836            }
837        }
838        Expr::Continue(span) => {
839            visitor.visit_expr_continue(expr, *span);
840        }
841        Expr::Return(inner, span) => {
842            if visitor.visit_expr_return(expr, *span) {
843                if let Some(e) = inner {
844                    walk_expr(visitor, e);
845                }
846            }
847        }
848        Expr::MethodCall {
849            receiver,
850            args,
851            named_args,
852            span,
853            ..
854        } => {
855            if visitor.visit_expr_method_call(expr, *span) {
856                walk_expr(visitor, receiver);
857                for arg in args {
858                    walk_expr(visitor, arg);
859                }
860                for (_, value) in named_args {
861                    walk_expr(visitor, value);
862                }
863            }
864        }
865        Expr::Match(match_expr, span) => {
866            if visitor.visit_expr_match(expr, *span) {
867                walk_expr(visitor, &match_expr.scrutinee);
868                for arm in &match_expr.arms {
869                    if let Some(guard) = &arm.guard {
870                        walk_expr(visitor, guard);
871                    }
872                    walk_expr(visitor, &arm.body);
873                }
874            }
875        }
876        Expr::Unit(span) => {
877            visitor.visit_expr_unit(expr, *span);
878        }
879        Expr::Range {
880            start, end, span, ..
881        } => {
882            if visitor.visit_expr_range(expr, *span) {
883                if let Some(s) = start {
884                    walk_expr(visitor, s);
885                }
886                if let Some(e) = end {
887                    walk_expr(visitor, e);
888                }
889            }
890        }
891        Expr::TimeframeContext {
892            expr: inner, span, ..
893        } => {
894            if visitor.visit_expr_timeframe_context(expr, *span) {
895                walk_expr(visitor, inner);
896            }
897        }
898        Expr::TryOperator(inner, span) => {
899            if visitor.visit_expr_try_operator(expr, *span) {
900                walk_expr(visitor, inner);
901            }
902        }
903        Expr::UsingImpl {
904            expr: inner, span, ..
905        } => {
906            if visitor.visit_expr_using_impl(expr, *span) {
907                walk_expr(visitor, inner);
908            }
909        }
910        Expr::SimulationCall { params, span, .. } => {
911            if visitor.visit_expr_simulation_call(expr, *span) {
912                for (_, value) in params {
913                    walk_expr(visitor, value);
914                }
915            }
916        }
917        Expr::WindowExpr(window_expr, span) => {
918            if visitor.visit_expr_window_expr(expr, *span) {
919                // Walk function argument expressions
920                match &window_expr.function {
921                    WindowFunction::Lead { expr, default, .. }
922                    | WindowFunction::Lag { expr, default, .. } => {
923                        walk_expr(visitor, expr);
924                        if let Some(d) = default {
925                            walk_expr(visitor, d);
926                        }
927                    }
928                    WindowFunction::FirstValue(e)
929                    | WindowFunction::LastValue(e)
930                    | WindowFunction::Sum(e)
931                    | WindowFunction::Avg(e)
932                    | WindowFunction::Min(e)
933                    | WindowFunction::Max(e) => {
934                        walk_expr(visitor, e);
935                    }
936                    WindowFunction::NthValue(e, _) => {
937                        walk_expr(visitor, e);
938                    }
939                    WindowFunction::Count(opt_e) => {
940                        if let Some(e) = opt_e {
941                            walk_expr(visitor, e);
942                        }
943                    }
944                    WindowFunction::RowNumber
945                    | WindowFunction::Rank
946                    | WindowFunction::DenseRank
947                    | WindowFunction::Ntile(_) => {}
948                }
949                // Walk partition_by expressions
950                for e in &window_expr.over.partition_by {
951                    walk_expr(visitor, e);
952                }
953                // Walk order_by expressions
954                if let Some(order_by) = &window_expr.over.order_by {
955                    for (e, _) in &order_by.columns {
956                        walk_expr(visitor, e);
957                    }
958                }
959            }
960        }
961        Expr::FromQuery(from_query, span) => {
962            if visitor.visit_expr_from_query(expr, *span) {
963                // Walk source expression
964                walk_expr(visitor, &from_query.source);
965                // Walk each clause
966                for clause in &from_query.clauses {
967                    match clause {
968                        QueryClause::Where(pred) => {
969                            walk_expr(visitor, pred);
970                        }
971                        QueryClause::OrderBy(specs) => {
972                            for spec in specs {
973                                walk_expr(visitor, &spec.key);
974                            }
975                        }
976                        QueryClause::GroupBy { element, key, .. } => {
977                            walk_expr(visitor, element);
978                            walk_expr(visitor, key);
979                        }
980                        QueryClause::Join {
981                            source,
982                            left_key,
983                            right_key,
984                            ..
985                        } => {
986                            walk_expr(visitor, source);
987                            walk_expr(visitor, left_key);
988                            walk_expr(visitor, right_key);
989                        }
990                        QueryClause::Let { value, .. } => {
991                            walk_expr(visitor, value);
992                        }
993                    }
994                }
995                // Walk select expression
996                walk_expr(visitor, &from_query.select);
997            }
998        }
999        Expr::StructLiteral { fields, span, .. } => {
1000            if visitor.visit_expr_struct_literal(expr, *span) {
1001                for (_, value_expr) in fields {
1002                    walk_expr(visitor, value_expr);
1003                }
1004            }
1005        }
1006        Expr::Await(inner, span) => {
1007            if visitor.visit_expr_await(expr, *span) {
1008                walk_expr(visitor, inner);
1009            }
1010        }
1011        Expr::Join(join_expr, span) => {
1012            if visitor.visit_expr_join(expr, *span) {
1013                for branch in &join_expr.branches {
1014                    walk_expr(visitor, &branch.expr);
1015                }
1016            }
1017        }
1018        Expr::Annotated { target, span, .. } => {
1019            if visitor.visit_expr_annotated(expr, *span) {
1020                walk_expr(visitor, target);
1021            }
1022        }
1023        Expr::AsyncLet(async_let, span) => {
1024            if visitor.visit_expr_async_let(expr, *span) {
1025                walk_expr(visitor, &async_let.expr);
1026            }
1027        }
1028        Expr::AsyncScope(inner, span) => {
1029            if visitor.visit_expr_async_scope(expr, *span) {
1030                walk_expr(visitor, inner);
1031            }
1032        }
1033        Expr::Comptime(stmts, span) => {
1034            if visitor.visit_expr_comptime(expr, *span) {
1035                for stmt in stmts {
1036                    walk_stmt(visitor, stmt);
1037                }
1038            }
1039        }
1040        Expr::ComptimeFor(cf, span) => {
1041            if visitor.visit_expr_comptime_for(expr, *span) {
1042                walk_expr(visitor, &cf.iterable);
1043                for stmt in &cf.body {
1044                    walk_stmt(visitor, stmt);
1045                }
1046            }
1047        }
1048        Expr::Reference {
1049            expr: inner, span, ..
1050        } => {
1051            if visitor.visit_expr_reference(expr, *span) {
1052                walk_expr(visitor, inner);
1053            }
1054        }
1055    }
1056
1057    visitor.leave_expr(expr);
1058}
1059
1060/// Walk a test statement
1061fn walk_test_statement<V: Visitor>(visitor: &mut V, test_stmt: &TestStatement) {
1062    match test_stmt {
1063        TestStatement::Statement(stmt) => walk_stmt(visitor, stmt),
1064        TestStatement::Assert(assert) => {
1065            walk_expr(visitor, &assert.condition);
1066        }
1067        TestStatement::Expect(expect) => {
1068            walk_expr(visitor, &expect.actual);
1069            match &expect.matcher {
1070                ExpectationMatcher::ToBe(e) => walk_expr(visitor, e),
1071                ExpectationMatcher::ToEqual(e) => walk_expr(visitor, e),
1072                ExpectationMatcher::ToBeCloseTo { expected, .. } => walk_expr(visitor, expected),
1073                ExpectationMatcher::ToBeGreaterThan(e) => walk_expr(visitor, e),
1074                ExpectationMatcher::ToBeLessThan(e) => walk_expr(visitor, e),
1075                ExpectationMatcher::ToContain(e) => walk_expr(visitor, e),
1076                ExpectationMatcher::ToBeTruthy => {}
1077                ExpectationMatcher::ToBeFalsy => {}
1078                ExpectationMatcher::ToThrow(_) => {}
1079                ExpectationMatcher::ToMatchPattern { .. } => {}
1080            }
1081        }
1082        TestStatement::Should(should) => {
1083            walk_expr(visitor, &should.subject);
1084            match &should.matcher {
1085                ShouldMatcher::Be(e) => walk_expr(visitor, e),
1086                ShouldMatcher::Equal(e) => walk_expr(visitor, e),
1087                ShouldMatcher::Contain(e) => walk_expr(visitor, e),
1088                ShouldMatcher::Match(_) => {}
1089                ShouldMatcher::BeCloseTo { expected, .. } => walk_expr(visitor, expected),
1090            }
1091        }
1092        TestStatement::Fixture(fixture) => match fixture {
1093            TestFixture::WithData { data, body } => {
1094                walk_expr(visitor, data);
1095                for stmt in body {
1096                    walk_stmt(visitor, stmt);
1097                }
1098            }
1099            TestFixture::WithMock {
1100                mock_value, body, ..
1101            } => {
1102                if let Some(value) = mock_value {
1103                    walk_expr(visitor, value);
1104                }
1105                for stmt in body {
1106                    walk_stmt(visitor, stmt);
1107                }
1108            }
1109        },
1110    }
1111}
1112
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Small AST builders shared by the tests below -----

    /// Wrap a single expression in a one-item `Program`.
    fn program_of(expr: Expr) -> Program {
        Program {
            items: vec![Item::Expression(expr, Span::DUMMY)],
            docs: shape_ast::ast::ProgramDocs::default(),
        }
    }

    /// Build an identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string(), Span::DUMMY)
    }

    /// Build a numeric literal expression.
    fn num(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value), Span::DUMMY)
    }

    /// Build `left + right`.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOp::Add,
            right: Box::new(right),
            span: Span::DUMMY,
        }
    }

    /// Build `some_function("arg")?`.
    fn try_call() -> Expr {
        Expr::TryOperator(
            Box::new(Expr::FunctionCall {
                name: "some_function".to_string(),
                args: vec![Expr::Literal(
                    Literal::String("arg".to_string()),
                    Span::DUMMY,
                )],
                named_args: vec![],
                span: Span::DUMMY,
            }),
            Span::DUMMY,
        )
    }

    /// Simple visitor that counts expressions
    struct ExprCounter {
        count: usize,
    }

    impl Visitor for ExprCounter {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.count += 1;
            true
        }
    }

    #[test]
    fn test_visitor_counts_expressions() {
        let program = program_of(add(ident("x"), num(1.0)));

        let mut counter = ExprCounter { count: 0 };
        walk_program(&mut counter, &program);

        // Should count: BinaryOp, Identifier, Literal = 3
        assert_eq!(counter.count, 3);
    }

    #[test]
    fn test_visitor_handles_try_operator() {
        let program = program_of(try_call());

        let mut counter = ExprCounter { count: 0 };
        walk_program(&mut counter, &program);

        // Should count: TryOperator, FunctionCall, Literal = 3
        assert_eq!(counter.count, 3);
    }

    /// Test that per-variant visitor methods work
    struct IdentifierCollector {
        names: Vec<String>,
    }

    impl Visitor for IdentifierCollector {
        fn visit_expr_identifier(&mut self, expr: &Expr, _span: Span) -> bool {
            if let Expr::Identifier(name, _) = expr {
                self.names.push(name.clone());
            }
            true
        }
    }

    #[test]
    fn test_per_variant_visitor_identifier() {
        let program = program_of(add(ident("x"), ident("y")));

        let mut collector = IdentifierCollector { names: vec![] };
        walk_program(&mut collector, &program);

        assert_eq!(collector.names, vec!["x", "y"]);
    }

    /// Test that per-variant method can skip children
    struct SkippingVisitor {
        count: usize,
    }

    impl Visitor for SkippingVisitor {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.count += 1;
            true
        }
        // Skip children of BinaryOp
        fn visit_expr_binary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
            false
        }
    }

    #[test]
    fn test_per_variant_skip_children() {
        let program = program_of(add(ident("x"), num(1.0)));

        let mut v = SkippingVisitor { count: 0 };
        walk_program(&mut v, &program);

        // Only BinaryOp counted, children skipped
        assert_eq!(v.count, 1);
    }

    /// Test combined coarse + per-variant
    struct MatchCollector {
        match_count: usize,
        total_expr_count: usize,
    }

    impl Visitor for MatchCollector {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.total_expr_count += 1;
            true
        }
        fn visit_expr_match(&mut self, _expr: &Expr, _span: Span) -> bool {
            self.match_count += 1;
            true
        }
    }

    #[test]
    fn test_coarse_and_per_variant_combined() {
        let program = program_of(add(ident("x"), ident("y")));

        let mut mc = MatchCollector {
            match_count: 0,
            total_expr_count: 0,
        };
        walk_program(&mut mc, &program);

        assert_eq!(mc.total_expr_count, 3); // BinaryOp + x + y
        assert_eq!(mc.match_count, 0); // No Match expressions
    }

    /// Counts every expression via the coarse hook and, separately, every
    /// `TryOperator` via its per-variant hook. Unlike `MatchCollector` above,
    /// this proves the per-variant hook actually fires (not just that it
    /// stays silent when its variant is absent).
    struct TryOperatorCollector {
        try_count: usize,
        total_expr_count: usize,
    }

    impl Visitor for TryOperatorCollector {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.total_expr_count += 1;
            true
        }
        fn visit_expr_try_operator(&mut self, _expr: &Expr, _span: Span) -> bool {
            self.try_count += 1;
            true
        }
    }

    #[test]
    fn test_per_variant_hook_fires_alongside_coarse() {
        let program = program_of(try_call());

        let mut tc = TryOperatorCollector {
            try_count: 0,
            total_expr_count: 0,
        };
        walk_program(&mut tc, &program);

        assert_eq!(tc.total_expr_count, 3); // TryOperator + FunctionCall + Literal
        assert_eq!(tc.try_count, 1); // per-variant hook hit exactly once
    }

    /// Records a tag for each expression as `leave_expr` fires, to check that
    /// the post-visit hook runs after an expression's children.
    struct LeaveRecorder {
        left: Vec<String>,
    }

    impl Visitor for LeaveRecorder {
        fn leave_expr(&mut self, expr: &Expr) {
            let tag = match expr {
                Expr::Identifier(name, _) => name.clone(),
                Expr::Literal(..) => "literal".to_string(),
                Expr::BinaryOp { .. } => "binary_op".to_string(),
                _ => "other".to_string(),
            };
            self.left.push(tag);
        }
    }

    #[test]
    fn test_leave_expr_is_post_order() {
        let program = program_of(add(ident("x"), num(1.0)));

        let mut recorder = LeaveRecorder { left: vec![] };
        walk_program(&mut recorder, &program);

        // Children are left before their parent, left operand before right.
        assert_eq!(recorder.left, vec!["x", "literal", "binary_op"]);
    }

    /// Rejects every expression at the coarse hook; per the `Visitor::visit_expr`
    /// contract, neither the per-variant hook nor any child may then be visited.
    struct CoarseSkipper {
        visited: usize,
        binary_op_hook_called: bool,
    }

    impl Visitor for CoarseSkipper {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.visited += 1;
            false
        }
        fn visit_expr_binary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
            self.binary_op_hook_called = true;
            true
        }
    }

    #[test]
    fn test_coarse_skip_bypasses_per_variant_and_children() {
        let program = program_of(add(ident("x"), num(1.0)));

        let mut skipper = CoarseSkipper {
            visited: 0,
            binary_op_hook_called: false,
        };
        walk_program(&mut skipper, &program);

        assert_eq!(skipper.visited, 1); // only the root BinaryOp reached visit_expr
        assert!(!skipper.binary_op_hook_called);
    }
}