Skip to main content

seqc/
ast.rs

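//! Abstract Syntax Tree for Seq
//!
//! Covers includes, union (ADT) definitions, word definitions, and the statements
//! that make up word bodies (literals, word calls, `if`, quotations, and `match`),
//! plus the source-location types used for diagnostics and tooling.
//! Will be extended as we add more language features.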
5
6use crate::types::{Effect, StackType, Type};
7use std::path::PathBuf;
8
/// A file plus an inclusive, 0-indexed line range, used for error reporting and tooling.
#[derive(Clone, Debug, PartialEq)]
pub struct SourceLocation {
    /// File the location refers to.
    pub file: PathBuf,
    /// First line of the range (0-indexed, matching LSP conventions).
    pub start_line: usize,
    /// Last line of the range (0-indexed, inclusive).
    pub end_line: usize,
}

impl SourceLocation {
    /// Location covering exactly one line; kept for callers that predate multi-line spans.
    pub fn new(file: PathBuf, line: usize) -> Self {
        // A one-line location is simply a span whose two ends coincide.
        Self::span(file, line, line)
    }

    /// Location covering the lines `start_line..=end_line`.
    ///
    /// In debug builds this asserts that `start_line <= end_line`.
    pub fn span(file: PathBuf, start_line: usize, end_line: usize) -> Self {
        debug_assert!(
            start_line <= end_line,
            "SourceLocation: start_line ({}) must be <= end_line ({})",
            start_line,
            end_line
        );
        Self {
            file,
            start_line,
            end_line,
        }
    }
}

/// Renders as `path:line` for a single line or `path:first-last` for a range (1-indexed).
impl std::fmt::Display for SourceLocation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.file.display();
        // Stored lines are 0-indexed; human-facing output is 1-indexed.
        let first = self.start_line + 1;
        if self.start_line != self.end_line {
            return write!(f, "{}:{}-{}", path, first, self.end_line + 1);
        }
        write!(f, "{}:{}", path, first)
    }
}
60
/// A dependency named by an `include` statement, classified by how its target is written.
#[derive(Clone, Debug, PartialEq)]
pub enum Include {
    /// Standard-library module, written `include std:http`.
    Std(String),
    /// Relative path, written `include "my-utils"`.
    Relative(String),
    /// FFI library, written `include ffi:readline`.
    Ffi(String),
}
71
72// ============================================================================
73//                     ALGEBRAIC DATA TYPES (ADTs)
74// ============================================================================
75
/// One named, typed field inside a union variant, e.g. `response-chan: Int`.
#[derive(Clone, Debug, PartialEq)]
pub struct UnionField {
    /// Field name as written in the source.
    pub name: String,
    /// Declared type, kept for now as its raw spelling (e.g. `"Int"`).
    pub type_name: String,
}
83
84/// A variant in a union type
85/// Example: `Get { response-chan: Int }`
86#[derive(Debug, Clone, PartialEq)]
87pub struct UnionVariant {
88    pub name: String,
89    pub fields: Vec<UnionField>,
90    pub source: Option<SourceLocation>,
91}
92
93/// A union type definition
94/// Example:
95/// ```seq
96/// union Message {
97///   Get { response-chan: Int }
98///   Increment { response-chan: Int }
99///   Report { op: Int, delta: Int, total: Int }
100/// }
101/// ```
102#[derive(Debug, Clone, PartialEq)]
103pub struct UnionDef {
104    pub name: String,
105    pub variants: Vec<UnionVariant>,
106    pub source: Option<SourceLocation>,
107}
108
/// The pattern on the left-hand side of a `match` arm.
///
/// Two forms exist: a bare variant name (stack-based matching) and a variant
/// name with named field bindings.
#[derive(Clone, Debug, PartialEq)]
pub enum Pattern {
    /// Matches a variant by name and pushes all of its fields onto the stack.
    ///
    /// Example: `Get ->` pushes `response-chan`.
    Variant(String),

    /// Matches a variant and binds selected fields by name (Phase 5).
    ///
    /// Example: `Get { chan } ->` binds `chan` to the `response-chan` field.
    VariantWithBindings { name: String, bindings: Vec<String> },
}
122
123/// A single arm in a match expression
124#[derive(Debug, Clone, PartialEq)]
125pub struct MatchArm {
126    pub pattern: Pattern,
127    pub body: Vec<Statement>,
128    /// Source span for error reporting (points to variant name)
129    pub span: Option<Span>,
130}
131
132#[derive(Debug, Clone, PartialEq)]
133pub struct Program {
134    pub includes: Vec<Include>,
135    pub unions: Vec<UnionDef>,
136    pub words: Vec<WordDef>,
137}
138
139#[derive(Debug, Clone, PartialEq)]
140pub struct WordDef {
141    pub name: String,
142    /// Optional stack effect declaration
143    /// Example: ( ..a Int -- ..a Bool )
144    pub effect: Option<Effect>,
145    pub body: Vec<Statement>,
146    /// Source location for error reporting (collision detection)
147    pub source: Option<SourceLocation>,
148    /// Lint IDs that are allowed (suppressed) for this word
149    /// Set via `# seq:allow(lint-id)` annotation before the word definition
150    pub allowed_lints: Vec<String>,
151}
152
/// Location of a single token or expression on one line.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Span {
    /// 0-indexed line number.
    pub line: usize,
    /// 0-indexed starting column.
    pub column: usize,
    /// Number of characters covered.
    pub length: usize,
}

impl Span {
    /// Build a span from its line, starting column, and length.
    pub fn new(line: usize, column: usize, length: usize) -> Self {
        Self {
            line,
            column,
            length,
        }
    }
}
173
/// Source range of a quotation, which may span several lines.
///
/// The start position is inclusive and the end position is exclusive.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct QuotationSpan {
    /// 0-indexed first line.
    pub start_line: usize,
    /// 0-indexed column on `start_line` where the span begins (inclusive).
    pub start_column: usize,
    /// 0-indexed last line.
    pub end_line: usize,
    /// 0-indexed column on `end_line` where the span stops (exclusive).
    pub end_column: usize,
}

impl QuotationSpan {
    /// Build a span from its start and end positions.
    pub fn new(start_line: usize, start_column: usize, end_line: usize, end_column: usize) -> Self {
        Self {
            start_line,
            start_column,
            end_line,
            end_column,
        }
    }

    /// Whether the position (`line`, `column`) lies inside this span.
    pub fn contains(&self, line: usize, column: usize) -> bool {
        let within_lines = (self.start_line..=self.end_line).contains(&line);
        // Column bounds only constrain the first and last lines of the span.
        let after_start = line != self.start_line || column >= self.start_column;
        let before_end = line != self.end_line || column < self.end_column;
        within_lines && after_start && before_end
    }
}
211
212#[derive(Debug, Clone, PartialEq)]
213pub enum Statement {
214    /// Integer literal: pushes value onto stack
215    IntLiteral(i64),
216
217    /// Floating-point literal: pushes IEEE 754 double onto stack
218    FloatLiteral(f64),
219
220    /// Boolean literal: pushes true/false onto stack
221    BoolLiteral(bool),
222
223    /// String literal: pushes string onto stack
224    StringLiteral(String),
225
226    /// Symbol literal: pushes symbol onto stack
227    /// Syntax: :foo, :some-name, :ok
228    /// Used for dynamic variant construction and SON.
229    /// Note: Symbols are not currently interned (future optimization).
230    Symbol(String),
231
232    /// Word call: calls another word or built-in
233    /// Contains the word name and optional source span for precise diagnostics
234    WordCall { name: String, span: Option<Span> },
235
236    /// Conditional: if/else/then
237    ///
238    /// Pops an integer from the stack (0 = zero, non-zero = non-zero)
239    /// and executes the appropriate branch
240    If {
241        /// Statements to execute when condition is non-zero (the 'then' clause)
242        then_branch: Vec<Statement>,
243        /// Optional statements to execute when condition is zero (the 'else' clause)
244        else_branch: Option<Vec<Statement>>,
245        /// Source span for error reporting (points to 'if' keyword)
246        span: Option<Span>,
247    },
248
249    /// Quotation: [ ... ]
250    ///
251    /// A block of deferred code (quotation/lambda)
252    /// Quotations are first-class values that can be pushed onto the stack
253    /// and executed later with combinators like `call`, `times`, or `while`
254    ///
255    /// The id field is used by the typechecker to track the inferred type
256    /// (Quotation vs Closure) for this quotation. The id is assigned during parsing.
257    /// The span field records the source location for LSP hover support.
258    Quotation {
259        id: usize,
260        body: Vec<Statement>,
261        span: Option<QuotationSpan>,
262    },
263
264    /// Match expression: pattern matching on union types
265    ///
266    /// Pops a union value from the stack and dispatches to the
267    /// appropriate arm based on the variant tag.
268    ///
269    /// Example:
270    /// ```seq
271    /// match
272    ///   Get -> send-response
273    ///   Increment -> do-increment send-response
274    ///   Report -> aggregate-add
275    /// end
276    /// ```
277    Match {
278        /// The match arms in order
279        arms: Vec<MatchArm>,
280        /// Source span for error reporting (points to 'match' keyword)
281        span: Option<Span>,
282    },
283}
284
285impl Program {
286    pub fn new() -> Self {
287        Program {
288            includes: Vec::new(),
289            unions: Vec::new(),
290            words: Vec::new(),
291        }
292    }
293
294    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
295        self.words.iter().find(|w| w.name == name)
296    }
297
298    /// Validate that all word calls reference either a defined word or a built-in
299    pub fn validate_word_calls(&self) -> Result<(), String> {
300        self.validate_word_calls_with_externals(&[])
301    }
302
303    /// Validate that all word calls reference a defined word, built-in, or external word.
304    ///
305    /// The `external_words` parameter should contain names of words available from
306    /// external sources (e.g., included modules) that should be considered valid.
307    pub fn validate_word_calls_with_externals(
308        &self,
309        external_words: &[&str],
310    ) -> Result<(), String> {
311        // List of known runtime built-ins
312        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
313        let builtins = [
314            // I/O operations
315            "io.write",
316            "io.write-line",
317            "io.read-line",
318            "io.read-line+",
319            "io.read-n",
320            "int->string",
321            "symbol->string",
322            "string->symbol",
323            // Command-line arguments
324            "args.count",
325            "args.at",
326            // File operations
327            "file.slurp",
328            "file.exists?",
329            "file.for-each-line+",
330            "file.spit",
331            "file.append",
332            "file.delete",
333            "file.size",
334            // Directory operations
335            "dir.exists?",
336            "dir.make",
337            "dir.delete",
338            "dir.list",
339            // String operations
340            "string.concat",
341            "string.length",
342            "string.byte-length",
343            "string.char-at",
344            "string.substring",
345            "char->string",
346            "string.find",
347            "string.split",
348            "string.contains",
349            "string.starts-with",
350            "string.empty?",
351            "string.trim",
352            "string.chomp",
353            "string.to-upper",
354            "string.to-lower",
355            "string.equal?",
356            "string.join",
357            "string.json-escape",
358            "string->int",
359            // Symbol operations
360            "symbol.=",
361            // Encoding operations
362            "encoding.base64-encode",
363            "encoding.base64-decode",
364            "encoding.base64url-encode",
365            "encoding.base64url-decode",
366            "encoding.hex-encode",
367            "encoding.hex-decode",
368            // Crypto operations
369            "crypto.sha256",
370            "crypto.hmac-sha256",
371            "crypto.constant-time-eq",
372            "crypto.random-bytes",
373            "crypto.random-int",
374            "crypto.uuid4",
375            "crypto.aes-gcm-encrypt",
376            "crypto.aes-gcm-decrypt",
377            "crypto.pbkdf2-sha256",
378            "crypto.ed25519-keypair",
379            "crypto.ed25519-sign",
380            "crypto.ed25519-verify",
381            // HTTP client operations
382            "http.get",
383            "http.post",
384            "http.put",
385            "http.delete",
386            // List operations
387            "list.make",
388            "list.push",
389            "list.get",
390            "list.set",
391            "list.map",
392            "list.filter",
393            "list.fold",
394            "list.each",
395            "list.length",
396            "list.empty?",
397            "list.reverse",
398            // Map operations
399            "map.make",
400            "map.get",
401            "map.set",
402            "map.has?",
403            "map.remove",
404            "map.keys",
405            "map.values",
406            "map.size",
407            "map.empty?",
408            "map.each",
409            "map.fold",
410            // Variant operations
411            "variant.field-count",
412            "variant.tag",
413            "variant.field-at",
414            "variant.append",
415            "variant.last",
416            "variant.init",
417            "variant.make-0",
418            "variant.make-1",
419            "variant.make-2",
420            "variant.make-3",
421            "variant.make-4",
422            // SON wrap aliases
423            "wrap-0",
424            "wrap-1",
425            "wrap-2",
426            "wrap-3",
427            "wrap-4",
428            // Integer arithmetic operations
429            "i.add",
430            "i.subtract",
431            "i.multiply",
432            "i.divide",
433            "i.modulo",
434            // Terse integer arithmetic
435            "i.+",
436            "i.-",
437            "i.*",
438            "i./",
439            "i.%",
440            // Integer comparison operations (return 0 or 1)
441            "i.=",
442            "i.<",
443            "i.>",
444            "i.<=",
445            "i.>=",
446            "i.<>",
447            // Integer comparison operations (verbose form)
448            "i.eq",
449            "i.lt",
450            "i.gt",
451            "i.lte",
452            "i.gte",
453            "i.neq",
454            // Stack operations (simple - no parameters)
455            "dup",
456            "drop",
457            "swap",
458            "over",
459            "rot",
460            "nip",
461            "tuck",
462            "2dup",
463            "3drop",
464            "pick",
465            "roll",
466            // Aux stack operations
467            ">aux",
468            "aux>",
469            // Boolean operations
470            "and",
471            "or",
472            "not",
473            // Bitwise operations
474            "band",
475            "bor",
476            "bxor",
477            "bnot",
478            "i.neg",
479            "negate",
480            "shl",
481            "shr",
482            "popcount",
483            "clz",
484            "ctz",
485            "int-bits",
486            // Channel operations
487            "chan.make",
488            "chan.send",
489            "chan.receive",
490            "chan.close",
491            "chan.yield",
492            // Quotation operations
493            "call",
494            // Dataflow combinators
495            "dip",
496            "keep",
497            "bi",
498            "strand.spawn",
499            "strand.weave",
500            "strand.resume",
501            "strand.weave-cancel",
502            "yield",
503            "cond",
504            // TCP operations
505            "tcp.listen",
506            "tcp.accept",
507            "tcp.read",
508            "tcp.write",
509            "tcp.close",
510            // OS operations
511            "os.getenv",
512            "os.home-dir",
513            "os.current-dir",
514            "os.path-exists",
515            "os.path-is-file",
516            "os.path-is-dir",
517            "os.path-join",
518            "os.path-parent",
519            "os.path-filename",
520            "os.exit",
521            "os.name",
522            "os.arch",
523            // Signal handling
524            "signal.trap",
525            "signal.received?",
526            "signal.pending?",
527            "signal.default",
528            "signal.ignore",
529            "signal.clear",
530            "signal.SIGINT",
531            "signal.SIGTERM",
532            "signal.SIGHUP",
533            "signal.SIGPIPE",
534            "signal.SIGUSR1",
535            "signal.SIGUSR2",
536            "signal.SIGCHLD",
537            "signal.SIGALRM",
538            "signal.SIGCONT",
539            // Terminal operations
540            "terminal.raw-mode",
541            "terminal.read-char",
542            "terminal.read-char?",
543            "terminal.width",
544            "terminal.height",
545            "terminal.flush",
546            // Float arithmetic operations (verbose form)
547            "f.add",
548            "f.subtract",
549            "f.multiply",
550            "f.divide",
551            // Float arithmetic operations (terse form)
552            "f.+",
553            "f.-",
554            "f.*",
555            "f./",
556            // Float comparison operations (symbol form)
557            "f.=",
558            "f.<",
559            "f.>",
560            "f.<=",
561            "f.>=",
562            "f.<>",
563            // Float comparison operations (verbose form)
564            "f.eq",
565            "f.lt",
566            "f.gt",
567            "f.lte",
568            "f.gte",
569            "f.neq",
570            // Type conversions
571            "int->float",
572            "float->int",
573            "float->string",
574            "string->float",
575            // Test framework operations
576            "test.init",
577            "test.finish",
578            "test.has-failures",
579            "test.assert",
580            "test.assert-not",
581            "test.assert-eq",
582            "test.assert-eq-str",
583            "test.fail",
584            "test.pass-count",
585            "test.fail-count",
586            // Time operations
587            "time.now",
588            "time.nanos",
589            "time.sleep-ms",
590            // SON serialization
591            "son.dump",
592            "son.dump-pretty",
593            // Stack introspection (for REPL)
594            "stack.dump",
595            // Regex operations
596            "regex.match?",
597            "regex.find",
598            "regex.find-all",
599            "regex.replace",
600            "regex.replace-all",
601            "regex.captures",
602            "regex.split",
603            "regex.valid?",
604            // Compression operations
605            "compress.gzip",
606            "compress.gzip-level",
607            "compress.gunzip",
608            "compress.zstd",
609            "compress.zstd-level",
610            "compress.unzstd",
611        ];
612
613        for word in &self.words {
614            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
615        }
616
617        Ok(())
618    }
619
620    /// Helper to validate word calls in a list of statements (recursively)
621    fn validate_statements(
622        &self,
623        statements: &[Statement],
624        word_name: &str,
625        builtins: &[&str],
626        external_words: &[&str],
627    ) -> Result<(), String> {
628        for statement in statements {
629            match statement {
630                Statement::WordCall { name, .. } => {
631                    // Check if it's a built-in
632                    if builtins.contains(&name.as_str()) {
633                        continue;
634                    }
635                    // Check if it's a user-defined word
636                    if self.find_word(name).is_some() {
637                        continue;
638                    }
639                    // Check if it's an external word (from includes)
640                    if external_words.contains(&name.as_str()) {
641                        continue;
642                    }
643                    // Undefined word!
644                    return Err(format!(
645                        "Undefined word '{}' called in word '{}'. \
646                         Did you forget to define it or misspell a built-in?",
647                        name, word_name
648                    ));
649                }
650                Statement::If {
651                    then_branch,
652                    else_branch,
653                    span: _,
654                } => {
655                    // Recursively validate both branches
656                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
657                    if let Some(eb) = else_branch {
658                        self.validate_statements(eb, word_name, builtins, external_words)?;
659                    }
660                }
661                Statement::Quotation { body, .. } => {
662                    // Recursively validate quotation body
663                    self.validate_statements(body, word_name, builtins, external_words)?;
664                }
665                Statement::Match { arms, span: _ } => {
666                    // Recursively validate each match arm's body
667                    for arm in arms {
668                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
669                    }
670                }
671                _ => {} // Literals don't need validation
672            }
673        }
674        Ok(())
675    }
676
677    /// Generate constructor words for all union definitions
678    ///
679    /// Maximum number of fields a variant can have (limited by runtime support)
680    pub const MAX_VARIANT_FIELDS: usize = 12;
681
682    /// Generate helper words for union types:
683    /// 1. Constructors: `Make-VariantName` - creates variant instances
684    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
685    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
686    ///
687    /// Example: For `union Message { Get { chan: Int } }`
688    /// Generates:
689    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
690    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
691    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
692    ///
693    /// Returns an error if any variant exceeds the maximum field count.
694    pub fn generate_constructors(&mut self) -> Result<(), String> {
695        let mut new_words = Vec::new();
696
697        for union_def in &self.unions {
698            for variant in &union_def.variants {
699                let field_count = variant.fields.len();
700
701                // Check field count limit before generating constructor
702                if field_count > Self::MAX_VARIANT_FIELDS {
703                    return Err(format!(
704                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
705                         Consider grouping fields into nested union types.",
706                        variant.name,
707                        union_def.name,
708                        field_count,
709                        Self::MAX_VARIANT_FIELDS
710                    ));
711                }
712
713                // 1. Generate constructor: Make-VariantName
714                let constructor_name = format!("Make-{}", variant.name);
715                let mut input_stack = StackType::RowVar("a".to_string());
716                for field in &variant.fields {
717                    let field_type = parse_type_name(&field.type_name);
718                    input_stack = input_stack.push(field_type);
719                }
720                let output_stack =
721                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
722                let effect = Effect::new(input_stack, output_stack);
723                let body = vec![
724                    Statement::Symbol(variant.name.clone()),
725                    Statement::WordCall {
726                        name: format!("variant.make-{}", field_count),
727                        span: None,
728                    },
729                ];
730                new_words.push(WordDef {
731                    name: constructor_name,
732                    effect: Some(effect),
733                    body,
734                    source: variant.source.clone(),
735                    allowed_lints: vec![],
736                });
737
738                // 2. Generate predicate: is-VariantName?
739                // Effect: ( UnionType -- Bool )
740                // Body: variant.tag :VariantName symbol.=
741                let predicate_name = format!("is-{}?", variant.name);
742                let predicate_input =
743                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
744                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
745                let predicate_effect = Effect::new(predicate_input, predicate_output);
746                let predicate_body = vec![
747                    Statement::WordCall {
748                        name: "variant.tag".to_string(),
749                        span: None,
750                    },
751                    Statement::Symbol(variant.name.clone()),
752                    Statement::WordCall {
753                        name: "symbol.=".to_string(),
754                        span: None,
755                    },
756                ];
757                new_words.push(WordDef {
758                    name: predicate_name,
759                    effect: Some(predicate_effect),
760                    body: predicate_body,
761                    source: variant.source.clone(),
762                    allowed_lints: vec![],
763                });
764
765                // 3. Generate field accessors: VariantName-fieldname
766                // Effect: ( UnionType -- FieldType )
767                // Body: N variant.field-at
768                for (index, field) in variant.fields.iter().enumerate() {
769                    let accessor_name = format!("{}-{}", variant.name, field.name);
770                    let field_type = parse_type_name(&field.type_name);
771                    let accessor_input = StackType::RowVar("a".to_string())
772                        .push(Type::Union(union_def.name.clone()));
773                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
774                    let accessor_effect = Effect::new(accessor_input, accessor_output);
775                    let accessor_body = vec![
776                        Statement::IntLiteral(index as i64),
777                        Statement::WordCall {
778                            name: "variant.field-at".to_string(),
779                            span: None,
780                        },
781                    ];
782                    new_words.push(WordDef {
783                        name: accessor_name,
784                        effect: Some(accessor_effect),
785                        body: accessor_body,
786                        source: variant.source.clone(), // Use variant's source for field accessors
787                        allowed_lints: vec![],
788                    });
789                }
790            }
791        }
792
793        self.words.extend(new_words);
794        Ok(())
795    }
796
797    /// RFC #345: Fix up type variables in stack effects that should be union types
798    ///
799    /// When parsing files with includes, type variables like "Message" in
800    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
801    /// union definition is in an included file. After resolving includes,
802    /// we know all union names and can convert these to `Type::Union("Message")`.
803    ///
804    /// This ensures proper nominal type checking for union types across files.
805    pub fn fixup_union_types(&mut self) {
806        // Collect all union names from the program
807        let union_names: std::collections::HashSet<String> =
808            self.unions.iter().map(|u| u.name.clone()).collect();
809
810        // Fix up types in all word effects
811        for word in &mut self.words {
812            if let Some(ref mut effect) = word.effect {
813                Self::fixup_stack_type(&mut effect.inputs, &union_names);
814                Self::fixup_stack_type(&mut effect.outputs, &union_names);
815            }
816        }
817    }
818
819    /// Recursively fix up types in a stack type
820    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
821        match stack {
822            StackType::Empty | StackType::RowVar(_) => {}
823            StackType::Cons { rest, top } => {
824                Self::fixup_type(top, union_names);
825                Self::fixup_stack_type(rest, union_names);
826            }
827        }
828    }
829
830    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
831    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
832        match ty {
833            Type::Var(name) if union_names.contains(name) => {
834                *ty = Type::Union(name.clone());
835            }
836            Type::Quotation(effect) => {
837                Self::fixup_stack_type(&mut effect.inputs, union_names);
838                Self::fixup_stack_type(&mut effect.outputs, union_names);
839            }
840            Type::Closure { effect, captures } => {
841                Self::fixup_stack_type(&mut effect.inputs, union_names);
842                Self::fixup_stack_type(&mut effect.outputs, union_names);
843                for cap in captures {
844                    Self::fixup_type(cap, union_names);
845                }
846            }
847            _ => {}
848        }
849    }
850}
851
852/// Parse a type name string into a Type
853/// Used by constructor generation to build stack effects
854fn parse_type_name(name: &str) -> Type {
855    match name {
856        "Int" => Type::Int,
857        "Float" => Type::Float,
858        "Bool" => Type::Bool,
859        "String" => Type::String,
860        "Channel" => Type::Channel,
861        other => Type::Union(other.to_string()),
862    }
863}
864
865impl Default for Program {
866    fn default() -> Self {
867        Self::new()
868    }
869}
870
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a word definition with no declared effect, source location, or
    /// lint allowances.
    fn word_def(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// Build a call to the word `name` with no source span.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// Build a program with no includes or unions, containing only `words`.
    fn program_with_words(words: Vec<WordDef>) -> Program {
        Program {
            includes: vec![],
            unions: vec![],
            words,
        }
    }

    /// Assert that a generated constructor body is exactly `:<tag> <make_word>`.
    ///
    /// On failure, the offending statement is included in the panic message.
    fn assert_constructor_body(word: &WordDef, tag: &str, make_word: &str) {
        assert_eq!(
            word.body.len(),
            2,
            "unexpected body for {}: {:?}",
            word.name,
            word.body
        );
        match &word.body[0] {
            Statement::Symbol(s) if s == tag => {}
            other => panic!(
                "Expected Symbol({:?}) for variant tag, got {:?}",
                tag, other
            ),
        }
        match &word.body[1] {
            Statement::WordCall { name, span: None } if name == make_word => {}
            other => panic!("Expected WordCall({}), got {:?}", make_word, other),
        }
    }

    #[test]
    fn test_validate_builtin_words() {
        let program = program_with_words(vec![word_def(
            "main",
            vec![
                Statement::IntLiteral(2),
                Statement::IntLiteral(3),
                call("i.add"),
                call("io.write-line"),
            ],
        )]);

        // Should succeed - i.add and io.write-line are built-ins.
        if let Err(e) = program.validate_word_calls() {
            panic!("built-in calls should validate, got: {:?}", e);
        }
    }

    #[test]
    fn test_validate_user_defined_words() {
        let program = program_with_words(vec![
            word_def("helper", vec![Statement::IntLiteral(42)]),
            word_def("main", vec![call("helper")]),
        ]);

        // Should succeed - helper is defined.
        if let Err(e) = program.validate_word_calls() {
            panic!("calls to user-defined words should validate, got: {:?}", e);
        }
    }

    #[test]
    fn test_validate_undefined_word() {
        let program = program_with_words(vec![word_def("main", vec![call("undefined_word")])]);

        // Should fail - undefined_word is not a built-in or user-defined word.
        let error = program
            .validate_word_calls()
            .expect_err("undefined_word is neither a built-in nor user-defined");
        assert!(error.contains("undefined_word"), "unexpected error: {}", error);
        assert!(error.contains("main"), "unexpected error: {}", error);
    }

    #[test]
    fn test_validate_misspelled_builtin() {
        let program = program_with_words(vec![word_def(
            "main",
            vec![call("wrte_line")], // typo
        )]);

        // Should fail with a helpful "did you misspell" style message.
        let error = program
            .validate_word_calls()
            .expect_err("misspelled built-in should be rejected");
        assert!(error.contains("wrte_line"), "unexpected error: {}", error);
        assert!(error.contains("misspell"), "unexpected error: {}", error);
    }

    #[test]
    fn test_generate_constructors() {
        let field = |name: &str, type_name: &str| UnionField {
            name: name.to_string(),
            type_name: type_name.to_string(),
        };
        let mut program = Program {
            includes: vec![],
            unions: vec![UnionDef {
                name: "Message".to_string(),
                variants: vec![
                    UnionVariant {
                        name: "Get".to_string(),
                        fields: vec![field("response-chan", "Int")],
                        source: None,
                    },
                    UnionVariant {
                        name: "Put".to_string(),
                        fields: vec![field("value", "String"), field("response-chan", "Int")],
                        source: None,
                    },
                ],
                source: None,
            }],
            words: vec![],
        };

        // Generate constructors, predicates, and accessors.
        program
            .generate_constructors()
            .expect("constructor generation should succeed");

        // Should have 7 words:
        // - Get variant: Make-Get, is-Get?, Get-response-chan (1 field)
        // - Put variant: Make-Put, is-Put?, Put-value, Put-response-chan (2 fields)
        assert_eq!(program.words.len(), 7);

        // Make-Get constructor: ( ..a Int -- ..a Message ).
        // Only the output side is asserted here.
        let make_get = program
            .find_word("Make-Get")
            .expect("Make-Get should exist");
        assert_eq!(make_get.name, "Make-Get");
        let effect = make_get
            .effect
            .as_ref()
            .expect("Make-Get should have a stack effect");
        assert_eq!(
            format!("{:?}", effect.outputs),
            "Cons { rest: RowVar(\"a\"), top: Union(\"Message\") }"
        );
        // Body: :Get variant.make-1
        assert_constructor_body(make_get, "Get", "variant.make-1");

        // Make-Put constructor.
        let make_put = program
            .find_word("Make-Put")
            .expect("Make-Put should exist");
        assert_eq!(make_put.name, "Make-Put");
        assert!(
            make_put.effect.is_some(),
            "Make-Put should have a stack effect"
        );
        // Body: :Put variant.make-2
        assert_constructor_body(make_put, "Put", "variant.make-2");

        // is-Get? predicate: ( ..a Message -- ..a Bool ).
        let is_get = program.find_word("is-Get?").expect("is-Get? should exist");
        assert_eq!(is_get.name, "is-Get?");
        let effect = is_get
            .effect
            .as_ref()
            .expect("is-Get? should have a stack effect");
        assert_eq!(
            format!("{:?}", effect.outputs),
            "Cons { rest: RowVar(\"a\"), top: Bool }"
        );

        // Get-response-chan accessor: ( ..a Message -- ..a Int ).
        let get_chan = program
            .find_word("Get-response-chan")
            .expect("Get-response-chan should exist");
        assert_eq!(get_chan.name, "Get-response-chan");
        let effect = get_chan
            .effect
            .as_ref()
            .expect("Get-response-chan should have a stack effect");
        assert_eq!(
            format!("{:?}", effect.outputs),
            "Cons { rest: RowVar(\"a\"), top: Int }"
        );
    }
}