Skip to main content

seqc/
ast.rs

1//! Abstract Syntax Tree for Seq
2//!
3//! Minimal AST sufficient for hello-world and basic programs.
4//! Will be extended as we add more language features.
5
6use crate::types::{Effect, StackType, Type};
7use std::path::PathBuf;
8
/// A range of lines inside a source file, used for error reporting and tooling.
///
/// Lines are stored 0-indexed (for LSP compatibility); the
/// [`std::fmt::Display`] implementation renders them 1-indexed.
#[derive(Debug, Clone, PartialEq)]
pub struct SourceLocation {
    /// File the location refers to
    pub file: PathBuf,
    /// First line of the range (0-indexed)
    pub start_line: usize,
    /// Last line of the range (0-indexed, inclusive)
    pub end_line: usize,
}

impl SourceLocation {
    /// Build a location covering exactly one line.
    ///
    /// This is the original single-line constructor, kept for backward
    /// compatibility; it is the same as `span(file, line, line)`.
    pub fn new(file: PathBuf, line: usize) -> Self {
        Self::span(file, line, line)
    }

    /// Build a location covering `start_line..=end_line`.
    ///
    /// In debug builds this asserts that `start_line <= end_line`; release
    /// builds store the values unchecked.
    pub fn span(file: PathBuf, start_line: usize, end_line: usize) -> Self {
        debug_assert!(
            start_line <= end_line,
            "SourceLocation: start_line ({}) must be <= end_line ({})",
            start_line,
            end_line
        );
        Self {
            file,
            start_line,
            end_line,
        }
    }
}

impl std::fmt::Display for SourceLocation {
    /// Formats as `path:line` for a single line, or `path:first-last` for a
    /// range, converting the stored 0-indexed lines to 1-indexed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.file.display();
        let first = self.start_line + 1;
        if self.end_line != self.start_line {
            return write!(f, "{}:{}-{}", path, first, self.end_line + 1);
        }
        write!(f, "{}:{}", path, first)
    }
}
60
/// An `include` directive, tagged by where the included code comes from.
#[derive(Debug, Clone, PartialEq)]
pub enum Include {
    /// Standard library module, e.g. `include std:http`
    Std(String),
    /// Relative path, e.g. `include "my-utils"`
    Relative(String),
    /// FFI library, e.g. `include ffi:readline`
    Ffi(String),
}
71
72// ============================================================================
73//                     ALGEBRAIC DATA TYPES (ADTs)
74// ============================================================================
75
/// One named field of a union variant.
///
/// Example: `response-chan: Int`
#[derive(Debug, Clone, PartialEq)]
pub struct UnionField {
    /// Field name as written in the variant definition
    pub name: String,
    /// Declared type, currently kept as its textual name only
    pub type_name: String,
}
83
84/// A variant in a union type
85/// Example: `Get { response-chan: Int }`
86#[derive(Debug, Clone, PartialEq)]
87pub struct UnionVariant {
88    pub name: String,
89    pub fields: Vec<UnionField>,
90    pub source: Option<SourceLocation>,
91}
92
93/// A union type definition
94/// Example:
95/// ```seq
96/// union Message {
97///   Get { response-chan: Int }
98///   Increment { response-chan: Int }
99///   Report { op: Int, delta: Int, total: Int }
100/// }
101/// ```
102#[derive(Debug, Clone, PartialEq)]
103pub struct UnionDef {
104    pub name: String,
105    pub variants: Vec<UnionVariant>,
106    pub source: Option<SourceLocation>,
107}
108
/// The pattern on the left-hand side of a match arm.
///
/// Two forms exist: a bare variant name (stack-based matching) and a variant
/// name with named field bindings.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// Match a variant by name, pushing all of its fields to the stack.
    /// Example: `Get ->` pushes response-chan to stack
    Variant(String),

    /// Match a variant and bind fields by name (Phase 5).
    /// Example: `Get { chan } ->` binds chan to the response-chan field
    VariantWithBindings { name: String, bindings: Vec<String> },
}
122
123/// A single arm in a match expression
124#[derive(Debug, Clone, PartialEq)]
125pub struct MatchArm {
126    pub pattern: Pattern,
127    pub body: Vec<Statement>,
128    /// Source span for error reporting (points to variant name)
129    pub span: Option<Span>,
130}
131
132#[derive(Debug, Clone, PartialEq)]
133pub struct Program {
134    pub includes: Vec<Include>,
135    pub unions: Vec<UnionDef>,
136    pub words: Vec<WordDef>,
137}
138
139#[derive(Debug, Clone, PartialEq)]
140pub struct WordDef {
141    pub name: String,
142    /// Optional stack effect declaration
143    /// Example: ( ..a Int -- ..a Bool )
144    pub effect: Option<Effect>,
145    pub body: Vec<Statement>,
146    /// Source location for error reporting (collision detection)
147    pub source: Option<SourceLocation>,
148    /// Lint IDs that are allowed (suppressed) for this word
149    /// Set via `# seq:allow(lint-id)` annotation before the word definition
150    pub allowed_lints: Vec<String>,
151}
152
/// Position and extent of a single token or expression on one line.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Span {
    /// Line the span is on (0-indexed)
    pub line: usize,
    /// Column where the span starts (0-indexed)
    pub column: usize,
    /// Number of characters the span covers
    pub length: usize,
}

impl Span {
    /// Build a span from its line, starting column, and length.
    pub fn new(line: usize, column: usize, length: usize) -> Self {
        Self {
            line,
            column,
            length,
        }
    }
}
173
/// Extent of a quotation in the source; may cover several lines.
///
/// The start position is inclusive and the end column is exclusive.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct QuotationSpan {
    /// Line where the quotation starts (0-indexed)
    pub start_line: usize,
    /// Column where the quotation starts (0-indexed)
    pub start_column: usize,
    /// Line where the quotation ends (0-indexed)
    pub end_line: usize,
    /// Column where the quotation ends (0-indexed, exclusive)
    pub end_column: usize,
}

impl QuotationSpan {
    /// Build a span from its start and end positions.
    pub fn new(start_line: usize, start_column: usize, end_line: usize, end_column: usize) -> Self {
        Self {
            start_line,
            start_column,
            end_line,
            end_column,
        }
    }

    /// Report whether the position (`line`, `column`) lies inside this span.
    ///
    /// Columns are only constrained on the first and last lines: on the start
    /// line the column must be at or after `start_column`, and on the end line
    /// it must be strictly before `end_column`.
    pub fn contains(&self, line: usize, column: usize) -> bool {
        let within_lines = (self.start_line..=self.end_line).contains(&line);
        let at_or_after_start = line != self.start_line || column >= self.start_column;
        let before_end = line != self.end_line || column < self.end_column;
        within_lines && at_or_after_start && before_end
    }
}
211
212#[derive(Debug, Clone, PartialEq)]
213pub enum Statement {
214    /// Integer literal: pushes value onto stack
215    IntLiteral(i64),
216
217    /// Floating-point literal: pushes IEEE 754 double onto stack
218    FloatLiteral(f64),
219
220    /// Boolean literal: pushes true/false onto stack
221    BoolLiteral(bool),
222
223    /// String literal: pushes string onto stack
224    StringLiteral(String),
225
226    /// Symbol literal: pushes symbol onto stack
227    /// Syntax: :foo, :some-name, :ok
228    /// Used for dynamic variant construction and SON.
229    /// Note: Symbols are not currently interned (future optimization).
230    Symbol(String),
231
232    /// Word call: calls another word or built-in
233    /// Contains the word name and optional source span for precise diagnostics
234    WordCall { name: String, span: Option<Span> },
235
236    /// Conditional: if/else/then
237    ///
238    /// Pops an integer from the stack (0 = zero, non-zero = non-zero)
239    /// and executes the appropriate branch
240    If {
241        /// Statements to execute when condition is non-zero (the 'then' clause)
242        then_branch: Vec<Statement>,
243        /// Optional statements to execute when condition is zero (the 'else' clause)
244        else_branch: Option<Vec<Statement>>,
245        /// Source span for error reporting (points to 'if' keyword)
246        span: Option<Span>,
247    },
248
249    /// Quotation: [ ... ]
250    ///
251    /// A block of deferred code (quotation/lambda)
252    /// Quotations are first-class values that can be pushed onto the stack
253    /// and executed later with combinators like `call`, `times`, or `while`
254    ///
255    /// The id field is used by the typechecker to track the inferred type
256    /// (Quotation vs Closure) for this quotation. The id is assigned during parsing.
257    /// The span field records the source location for LSP hover support.
258    Quotation {
259        id: usize,
260        body: Vec<Statement>,
261        span: Option<QuotationSpan>,
262    },
263
264    /// Match expression: pattern matching on union types
265    ///
266    /// Pops a union value from the stack and dispatches to the
267    /// appropriate arm based on the variant tag.
268    ///
269    /// Example:
270    /// ```seq
271    /// match
272    ///   Get -> send-response
273    ///   Increment -> do-increment send-response
274    ///   Report -> aggregate-add
275    /// end
276    /// ```
277    Match {
278        /// The match arms in order
279        arms: Vec<MatchArm>,
280        /// Source span for error reporting (points to 'match' keyword)
281        span: Option<Span>,
282    },
283}
284
285impl Program {
286    pub fn new() -> Self {
287        Program {
288            includes: Vec::new(),
289            unions: Vec::new(),
290            words: Vec::new(),
291        }
292    }
293
294    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
295        self.words.iter().find(|w| w.name == name)
296    }
297
298    /// Validate that all word calls reference either a defined word or a built-in
299    pub fn validate_word_calls(&self) -> Result<(), String> {
300        self.validate_word_calls_with_externals(&[])
301    }
302
303    /// Validate that all word calls reference a defined word, built-in, or external word.
304    ///
305    /// The `external_words` parameter should contain names of words available from
306    /// external sources (e.g., included modules) that should be considered valid.
307    pub fn validate_word_calls_with_externals(
308        &self,
309        external_words: &[&str],
310    ) -> Result<(), String> {
311        // List of known runtime built-ins
312        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
313        let builtins = [
314            // I/O operations
315            "io.write",
316            "io.write-line",
317            "io.read-line",
318            "io.read-line+",
319            "io.read-n",
320            "int->string",
321            "symbol->string",
322            "string->symbol",
323            // Command-line arguments
324            "args.count",
325            "args.at",
326            // File operations
327            "file.slurp",
328            "file.exists?",
329            "file.for-each-line+",
330            "file.spit",
331            "file.append",
332            "file.delete",
333            "file.size",
334            // Directory operations
335            "dir.exists?",
336            "dir.make",
337            "dir.delete",
338            "dir.list",
339            // String operations
340            "string.concat",
341            "string.length",
342            "string.byte-length",
343            "string.char-at",
344            "string.substring",
345            "char->string",
346            "string.find",
347            "string.split",
348            "string.contains",
349            "string.starts-with",
350            "string.empty?",
351            "string.trim",
352            "string.chomp",
353            "string.to-upper",
354            "string.to-lower",
355            "string.equal?",
356            "string.join",
357            "string.json-escape",
358            "string->int",
359            // Symbol operations
360            "symbol.=",
361            // Encoding operations
362            "encoding.base64-encode",
363            "encoding.base64-decode",
364            "encoding.base64url-encode",
365            "encoding.base64url-decode",
366            "encoding.hex-encode",
367            "encoding.hex-decode",
368            // Crypto operations
369            "crypto.sha256",
370            "crypto.hmac-sha256",
371            "crypto.constant-time-eq",
372            "crypto.random-bytes",
373            "crypto.random-int",
374            "crypto.uuid4",
375            "crypto.aes-gcm-encrypt",
376            "crypto.aes-gcm-decrypt",
377            "crypto.pbkdf2-sha256",
378            "crypto.ed25519-keypair",
379            "crypto.ed25519-sign",
380            "crypto.ed25519-verify",
381            // HTTP client operations
382            "http.get",
383            "http.post",
384            "http.put",
385            "http.delete",
386            // List operations
387            "list.make",
388            "list.push",
389            "list.get",
390            "list.set",
391            "list.map",
392            "list.filter",
393            "list.fold",
394            "list.each",
395            "list.length",
396            "list.empty?",
397            "list.reverse",
398            // Map operations
399            "map.make",
400            "map.get",
401            "map.set",
402            "map.has?",
403            "map.remove",
404            "map.keys",
405            "map.values",
406            "map.size",
407            "map.empty?",
408            "map.each",
409            "map.fold",
410            // Variant operations
411            "variant.field-count",
412            "variant.tag",
413            "variant.field-at",
414            "variant.append",
415            "variant.last",
416            "variant.init",
417            "variant.make-0",
418            "variant.make-1",
419            "variant.make-2",
420            "variant.make-3",
421            "variant.make-4",
422            // SON wrap aliases
423            "wrap-0",
424            "wrap-1",
425            "wrap-2",
426            "wrap-3",
427            "wrap-4",
428            // Integer arithmetic operations
429            "i.add",
430            "i.subtract",
431            "i.multiply",
432            "i.divide",
433            "i.modulo",
434            // Terse integer arithmetic
435            "i.+",
436            "i.-",
437            "i.*",
438            "i./",
439            "i.%",
440            // Integer comparison operations (return 0 or 1)
441            "i.=",
442            "i.<",
443            "i.>",
444            "i.<=",
445            "i.>=",
446            "i.<>",
447            // Integer comparison operations (verbose form)
448            "i.eq",
449            "i.lt",
450            "i.gt",
451            "i.lte",
452            "i.gte",
453            "i.neq",
454            // Stack operations (simple - no parameters)
455            "dup",
456            "drop",
457            "swap",
458            "over",
459            "rot",
460            "nip",
461            "tuck",
462            "2dup",
463            "3drop",
464            "pick",
465            "roll",
466            // Aux stack operations
467            ">aux",
468            "aux>",
469            // Boolean operations
470            "and",
471            "or",
472            "not",
473            // Bitwise operations
474            "band",
475            "bor",
476            "bxor",
477            "bnot",
478            "i.neg",
479            "negate",
480            // Arithmetic sugar (resolved to concrete ops by typechecker)
481            "+",
482            "-",
483            "*",
484            "/",
485            "%",
486            "=",
487            "<",
488            ">",
489            "<=",
490            ">=",
491            "<>",
492            "shl",
493            "shr",
494            "popcount",
495            "clz",
496            "ctz",
497            "int-bits",
498            // Channel operations
499            "chan.make",
500            "chan.send",
501            "chan.receive",
502            "chan.close",
503            "chan.yield",
504            // Quotation operations
505            "call",
506            // Dataflow combinators
507            "dip",
508            "keep",
509            "bi",
510            "strand.spawn",
511            "strand.weave",
512            "strand.resume",
513            "strand.weave-cancel",
514            "yield",
515            "cond",
516            // TCP operations
517            "tcp.listen",
518            "tcp.accept",
519            "tcp.read",
520            "tcp.write",
521            "tcp.close",
522            // OS operations
523            "os.getenv",
524            "os.home-dir",
525            "os.current-dir",
526            "os.path-exists",
527            "os.path-is-file",
528            "os.path-is-dir",
529            "os.path-join",
530            "os.path-parent",
531            "os.path-filename",
532            "os.exit",
533            "os.name",
534            "os.arch",
535            // Signal handling
536            "signal.trap",
537            "signal.received?",
538            "signal.pending?",
539            "signal.default",
540            "signal.ignore",
541            "signal.clear",
542            "signal.SIGINT",
543            "signal.SIGTERM",
544            "signal.SIGHUP",
545            "signal.SIGPIPE",
546            "signal.SIGUSR1",
547            "signal.SIGUSR2",
548            "signal.SIGCHLD",
549            "signal.SIGALRM",
550            "signal.SIGCONT",
551            // Terminal operations
552            "terminal.raw-mode",
553            "terminal.read-char",
554            "terminal.read-char?",
555            "terminal.width",
556            "terminal.height",
557            "terminal.flush",
558            // Float arithmetic operations (verbose form)
559            "f.add",
560            "f.subtract",
561            "f.multiply",
562            "f.divide",
563            // Float arithmetic operations (terse form)
564            "f.+",
565            "f.-",
566            "f.*",
567            "f./",
568            // Float comparison operations (symbol form)
569            "f.=",
570            "f.<",
571            "f.>",
572            "f.<=",
573            "f.>=",
574            "f.<>",
575            // Float comparison operations (verbose form)
576            "f.eq",
577            "f.lt",
578            "f.gt",
579            "f.lte",
580            "f.gte",
581            "f.neq",
582            // Type conversions
583            "int->float",
584            "float->int",
585            "float->string",
586            "string->float",
587            // Test framework operations
588            "test.init",
589            "test.finish",
590            "test.has-failures",
591            "test.assert",
592            "test.assert-not",
593            "test.assert-eq",
594            "test.assert-eq-str",
595            "test.fail",
596            "test.pass-count",
597            "test.fail-count",
598            // Time operations
599            "time.now",
600            "time.nanos",
601            "time.sleep-ms",
602            // SON serialization
603            "son.dump",
604            "son.dump-pretty",
605            // Stack introspection (for REPL)
606            "stack.dump",
607            // Regex operations
608            "regex.match?",
609            "regex.find",
610            "regex.find-all",
611            "regex.replace",
612            "regex.replace-all",
613            "regex.captures",
614            "regex.split",
615            "regex.valid?",
616            // Compression operations
617            "compress.gzip",
618            "compress.gzip-level",
619            "compress.gunzip",
620            "compress.zstd",
621            "compress.zstd-level",
622            "compress.unzstd",
623        ];
624
625        for word in &self.words {
626            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
627        }
628
629        Ok(())
630    }
631
632    /// Helper to validate word calls in a list of statements (recursively)
633    fn validate_statements(
634        &self,
635        statements: &[Statement],
636        word_name: &str,
637        builtins: &[&str],
638        external_words: &[&str],
639    ) -> Result<(), String> {
640        for statement in statements {
641            match statement {
642                Statement::WordCall { name, .. } => {
643                    // Check if it's a built-in
644                    if builtins.contains(&name.as_str()) {
645                        continue;
646                    }
647                    // Check if it's a user-defined word
648                    if self.find_word(name).is_some() {
649                        continue;
650                    }
651                    // Check if it's an external word (from includes)
652                    if external_words.contains(&name.as_str()) {
653                        continue;
654                    }
655                    // Undefined word!
656                    return Err(format!(
657                        "Undefined word '{}' called in word '{}'. \
658                         Did you forget to define it or misspell a built-in?",
659                        name, word_name
660                    ));
661                }
662                Statement::If {
663                    then_branch,
664                    else_branch,
665                    span: _,
666                } => {
667                    // Recursively validate both branches
668                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
669                    if let Some(eb) = else_branch {
670                        self.validate_statements(eb, word_name, builtins, external_words)?;
671                    }
672                }
673                Statement::Quotation { body, .. } => {
674                    // Recursively validate quotation body
675                    self.validate_statements(body, word_name, builtins, external_words)?;
676                }
677                Statement::Match { arms, span: _ } => {
678                    // Recursively validate each match arm's body
679                    for arm in arms {
680                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
681                    }
682                }
683                _ => {} // Literals don't need validation
684            }
685        }
686        Ok(())
687    }
688
689    /// Generate constructor words for all union definitions
690    ///
691    /// Maximum number of fields a variant can have (limited by runtime support)
692    pub const MAX_VARIANT_FIELDS: usize = 12;
693
694    /// Generate helper words for union types:
695    /// 1. Constructors: `Make-VariantName` - creates variant instances
696    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
697    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
698    ///
699    /// Example: For `union Message { Get { chan: Int } }`
700    /// Generates:
701    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
702    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
703    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
704    ///
705    /// Returns an error if any variant exceeds the maximum field count.
706    pub fn generate_constructors(&mut self) -> Result<(), String> {
707        let mut new_words = Vec::new();
708
709        for union_def in &self.unions {
710            for variant in &union_def.variants {
711                let field_count = variant.fields.len();
712
713                // Check field count limit before generating constructor
714                if field_count > Self::MAX_VARIANT_FIELDS {
715                    return Err(format!(
716                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
717                         Consider grouping fields into nested union types.",
718                        variant.name,
719                        union_def.name,
720                        field_count,
721                        Self::MAX_VARIANT_FIELDS
722                    ));
723                }
724
725                // 1. Generate constructor: Make-VariantName
726                let constructor_name = format!("Make-{}", variant.name);
727                let mut input_stack = StackType::RowVar("a".to_string());
728                for field in &variant.fields {
729                    let field_type = parse_type_name(&field.type_name);
730                    input_stack = input_stack.push(field_type);
731                }
732                let output_stack =
733                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
734                let effect = Effect::new(input_stack, output_stack);
735                let body = vec![
736                    Statement::Symbol(variant.name.clone()),
737                    Statement::WordCall {
738                        name: format!("variant.make-{}", field_count),
739                        span: None,
740                    },
741                ];
742                new_words.push(WordDef {
743                    name: constructor_name,
744                    effect: Some(effect),
745                    body,
746                    source: variant.source.clone(),
747                    allowed_lints: vec![],
748                });
749
750                // 2. Generate predicate: is-VariantName?
751                // Effect: ( UnionType -- Bool )
752                // Body: variant.tag :VariantName symbol.=
753                let predicate_name = format!("is-{}?", variant.name);
754                let predicate_input =
755                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
756                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
757                let predicate_effect = Effect::new(predicate_input, predicate_output);
758                let predicate_body = vec![
759                    Statement::WordCall {
760                        name: "variant.tag".to_string(),
761                        span: None,
762                    },
763                    Statement::Symbol(variant.name.clone()),
764                    Statement::WordCall {
765                        name: "symbol.=".to_string(),
766                        span: None,
767                    },
768                ];
769                new_words.push(WordDef {
770                    name: predicate_name,
771                    effect: Some(predicate_effect),
772                    body: predicate_body,
773                    source: variant.source.clone(),
774                    allowed_lints: vec![],
775                });
776
777                // 3. Generate field accessors: VariantName-fieldname
778                // Effect: ( UnionType -- FieldType )
779                // Body: N variant.field-at
780                for (index, field) in variant.fields.iter().enumerate() {
781                    let accessor_name = format!("{}-{}", variant.name, field.name);
782                    let field_type = parse_type_name(&field.type_name);
783                    let accessor_input = StackType::RowVar("a".to_string())
784                        .push(Type::Union(union_def.name.clone()));
785                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
786                    let accessor_effect = Effect::new(accessor_input, accessor_output);
787                    let accessor_body = vec![
788                        Statement::IntLiteral(index as i64),
789                        Statement::WordCall {
790                            name: "variant.field-at".to_string(),
791                            span: None,
792                        },
793                    ];
794                    new_words.push(WordDef {
795                        name: accessor_name,
796                        effect: Some(accessor_effect),
797                        body: accessor_body,
798                        source: variant.source.clone(), // Use variant's source for field accessors
799                        allowed_lints: vec![],
800                    });
801                }
802            }
803        }
804
805        self.words.extend(new_words);
806        Ok(())
807    }
808
809    /// RFC #345: Fix up type variables in stack effects that should be union types
810    ///
811    /// When parsing files with includes, type variables like "Message" in
812    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
813    /// union definition is in an included file. After resolving includes,
814    /// we know all union names and can convert these to `Type::Union("Message")`.
815    ///
816    /// This ensures proper nominal type checking for union types across files.
817    pub fn fixup_union_types(&mut self) {
818        // Collect all union names from the program
819        let union_names: std::collections::HashSet<String> =
820            self.unions.iter().map(|u| u.name.clone()).collect();
821
822        // Fix up types in all word effects
823        for word in &mut self.words {
824            if let Some(ref mut effect) = word.effect {
825                Self::fixup_stack_type(&mut effect.inputs, &union_names);
826                Self::fixup_stack_type(&mut effect.outputs, &union_names);
827            }
828        }
829    }
830
831    /// Recursively fix up types in a stack type
832    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
833        match stack {
834            StackType::Empty | StackType::RowVar(_) => {}
835            StackType::Cons { rest, top } => {
836                Self::fixup_type(top, union_names);
837                Self::fixup_stack_type(rest, union_names);
838            }
839        }
840    }
841
842    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
843    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
844        match ty {
845            Type::Var(name) if union_names.contains(name) => {
846                *ty = Type::Union(name.clone());
847            }
848            Type::Quotation(effect) => {
849                Self::fixup_stack_type(&mut effect.inputs, union_names);
850                Self::fixup_stack_type(&mut effect.outputs, union_names);
851            }
852            Type::Closure { effect, captures } => {
853                Self::fixup_stack_type(&mut effect.inputs, union_names);
854                Self::fixup_stack_type(&mut effect.outputs, union_names);
855                for cap in captures {
856                    Self::fixup_type(cap, union_names);
857                }
858            }
859            _ => {}
860        }
861    }
862}
863
864/// Parse a type name string into a Type
865/// Used by constructor generation to build stack effects
866fn parse_type_name(name: &str) -> Type {
867    match name {
868        "Int" => Type::Int,
869        "Float" => Type::Float,
870        "Bool" => Type::Bool,
871        "String" => Type::String,
872        "Channel" => Type::Channel,
873        other => Type::Union(other.to_string()),
874    }
875}
876
877impl Default for Program {
878    fn default() -> Self {
879        Self::new()
880    }
881}
882
883#[cfg(test)]
884mod tests {
885    use super::*;
886
887    #[test]
888    fn test_validate_builtin_words() {
889        let program = Program {
890            includes: vec![],
891            unions: vec![],
892            words: vec![WordDef {
893                name: "main".to_string(),
894                effect: None,
895                body: vec![
896                    Statement::IntLiteral(2),
897                    Statement::IntLiteral(3),
898                    Statement::WordCall {
899                        name: "i.add".to_string(),
900                        span: None,
901                    },
902                    Statement::WordCall {
903                        name: "io.write-line".to_string(),
904                        span: None,
905                    },
906                ],
907                source: None,
908                allowed_lints: vec![],
909            }],
910        };
911
912        // Should succeed - i.add and io.write-line are built-ins
913        assert!(program.validate_word_calls().is_ok());
914    }
915
916    #[test]
917    fn test_validate_user_defined_words() {
918        let program = Program {
919            includes: vec![],
920            unions: vec![],
921            words: vec![
922                WordDef {
923                    name: "helper".to_string(),
924                    effect: None,
925                    body: vec![Statement::IntLiteral(42)],
926                    source: None,
927                    allowed_lints: vec![],
928                },
929                WordDef {
930                    name: "main".to_string(),
931                    effect: None,
932                    body: vec![Statement::WordCall {
933                        name: "helper".to_string(),
934                        span: None,
935                    }],
936                    source: None,
937                    allowed_lints: vec![],
938                },
939            ],
940        };
941
942        // Should succeed - helper is defined
943        assert!(program.validate_word_calls().is_ok());
944    }
945
946    #[test]
947    fn test_validate_undefined_word() {
948        let program = Program {
949            includes: vec![],
950            unions: vec![],
951            words: vec![WordDef {
952                name: "main".to_string(),
953                effect: None,
954                body: vec![Statement::WordCall {
955                    name: "undefined_word".to_string(),
956                    span: None,
957                }],
958                source: None,
959                allowed_lints: vec![],
960            }],
961        };
962
963        // Should fail - undefined_word is not a built-in or user-defined word
964        let result = program.validate_word_calls();
965        assert!(result.is_err());
966        let error = result.unwrap_err();
967        assert!(error.contains("undefined_word"));
968        assert!(error.contains("main"));
969    }
970
971    #[test]
972    fn test_validate_misspelled_builtin() {
973        let program = Program {
974            includes: vec![],
975            unions: vec![],
976            words: vec![WordDef {
977                name: "main".to_string(),
978                effect: None,
979                body: vec![Statement::WordCall {
980                    name: "wrte_line".to_string(),
981                    span: None,
982                }], // typo
983                source: None,
984                allowed_lints: vec![],
985            }],
986        };
987
988        // Should fail with helpful message
989        let result = program.validate_word_calls();
990        assert!(result.is_err());
991        let error = result.unwrap_err();
992        assert!(error.contains("wrte_line"));
993        assert!(error.contains("misspell"));
994    }
995
996    #[test]
997    fn test_generate_constructors() {
998        let mut program = Program {
999            includes: vec![],
1000            unions: vec![UnionDef {
1001                name: "Message".to_string(),
1002                variants: vec![
1003                    UnionVariant {
1004                        name: "Get".to_string(),
1005                        fields: vec![UnionField {
1006                            name: "response-chan".to_string(),
1007                            type_name: "Int".to_string(),
1008                        }],
1009                        source: None,
1010                    },
1011                    UnionVariant {
1012                        name: "Put".to_string(),
1013                        fields: vec![
1014                            UnionField {
1015                                name: "value".to_string(),
1016                                type_name: "String".to_string(),
1017                            },
1018                            UnionField {
1019                                name: "response-chan".to_string(),
1020                                type_name: "Int".to_string(),
1021                            },
1022                        ],
1023                        source: None,
1024                    },
1025                ],
1026                source: None,
1027            }],
1028            words: vec![],
1029        };
1030
1031        // Generate constructors, predicates, and accessors
1032        program.generate_constructors().unwrap();
1033
1034        // Should have 7 words:
1035        // - Get variant: Make-Get, is-Get?, Get-response-chan (1 field)
1036        // - Put variant: Make-Put, is-Put?, Put-value, Put-response-chan (2 fields)
1037        assert_eq!(program.words.len(), 7);
1038
1039        // Check Make-Get constructor
1040        let make_get = program
1041            .find_word("Make-Get")
1042            .expect("Make-Get should exist");
1043        assert_eq!(make_get.name, "Make-Get");
1044        assert!(make_get.effect.is_some());
1045        let effect = make_get.effect.as_ref().unwrap();
1046        // Input: ( ..a Int -- )
1047        // Output: ( ..a Message -- )
1048        assert_eq!(
1049            format!("{:?}", effect.outputs),
1050            "Cons { rest: RowVar(\"a\"), top: Union(\"Message\") }"
1051        );
1052
1053        // Check Make-Put constructor
1054        let make_put = program
1055            .find_word("Make-Put")
1056            .expect("Make-Put should exist");
1057        assert_eq!(make_put.name, "Make-Put");
1058        assert!(make_put.effect.is_some());
1059
1060        // Check the body generates correct code
1061        // Make-Get should be: :Get variant.make-1
1062        assert_eq!(make_get.body.len(), 2);
1063        match &make_get.body[0] {
1064            Statement::Symbol(s) if s == "Get" => {}
1065            other => panic!("Expected Symbol(\"Get\") for variant tag, got {:?}", other),
1066        }
1067        match &make_get.body[1] {
1068            Statement::WordCall { name, span: None } if name == "variant.make-1" => {}
1069            _ => panic!("Expected WordCall(variant.make-1)"),
1070        }
1071
1072        // Make-Put should be: :Put variant.make-2
1073        assert_eq!(make_put.body.len(), 2);
1074        match &make_put.body[0] {
1075            Statement::Symbol(s) if s == "Put" => {}
1076            other => panic!("Expected Symbol(\"Put\") for variant tag, got {:?}", other),
1077        }
1078        match &make_put.body[1] {
1079            Statement::WordCall { name, span: None } if name == "variant.make-2" => {}
1080            _ => panic!("Expected WordCall(variant.make-2)"),
1081        }
1082
1083        // Check is-Get? predicate
1084        let is_get = program.find_word("is-Get?").expect("is-Get? should exist");
1085        assert_eq!(is_get.name, "is-Get?");
1086        assert!(is_get.effect.is_some());
1087        let effect = is_get.effect.as_ref().unwrap();
1088        // Input: ( ..a Message -- )
1089        // Output: ( ..a Bool -- )
1090        assert_eq!(
1091            format!("{:?}", effect.outputs),
1092            "Cons { rest: RowVar(\"a\"), top: Bool }"
1093        );
1094
1095        // Check Get-response-chan accessor
1096        let get_chan = program
1097            .find_word("Get-response-chan")
1098            .expect("Get-response-chan should exist");
1099        assert_eq!(get_chan.name, "Get-response-chan");
1100        assert!(get_chan.effect.is_some());
1101        let effect = get_chan.effect.as_ref().unwrap();
1102        // Input: ( ..a Message -- )
1103        // Output: ( ..a Int -- )
1104        assert_eq!(
1105            format!("{:?}", effect.outputs),
1106            "Cons { rest: RowVar(\"a\"), top: Int }"
1107        );
1108    }
1109}