Skip to main content

seqc/ast/
program.rs

//! Program-level AST methods: word-call validation, auto-generated union helper
//! words (`Make-Variant` constructors, `is-Variant?` predicates and
//! `Variant-field` accessors), and type fix-up for union types declared in
//! stack effects.
4
5use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10    pub fn new() -> Self {
11        Program {
12            includes: Vec::new(),
13            unions: Vec::new(),
14            words: Vec::new(),
15        }
16    }
17
18    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19        self.words.iter().find(|w| w.name == name)
20    }
21
22    /// Validate that all word calls reference either a defined word or a built-in
23    pub fn validate_word_calls(&self) -> Result<(), String> {
24        self.validate_word_calls_with_externals(&[])
25    }
26
27    /// Validate that all word calls reference a defined word, built-in, or external word.
28    ///
29    /// The `external_words` parameter should contain names of words available from
30    /// external sources (e.g., included modules) that should be considered valid.
31    pub fn validate_word_calls_with_externals(
32        &self,
33        external_words: &[&str],
34    ) -> Result<(), String> {
35        // List of known runtime built-ins
36        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
37        let builtins = [
38            // I/O operations
39            "io.write",
40            "io.write-line",
41            "io.read-line",
42            "io.read-line+",
43            "io.read-n",
44            "int->string",
45            "symbol->string",
46            "string->symbol",
47            // Command-line arguments
48            "args.count",
49            "args.at",
50            // File operations
51            "file.slurp",
52            "file.exists?",
53            "file.for-each-line+",
54            "file.spit",
55            "file.append",
56            "file.delete",
57            "file.size",
58            // Directory operations
59            "dir.exists?",
60            "dir.make",
61            "dir.delete",
62            "dir.list",
63            // String operations
64            "string.concat",
65            "string.length",
66            "string.byte-length",
67            "string.char-at",
68            "string.substring",
69            "char->string",
70            "string.find",
71            "string.split",
72            "string.contains",
73            "string.starts-with",
74            "string.empty?",
75            "string.trim",
76            "string.chomp",
77            "string.to-upper",
78            "string.to-lower",
79            "string.equal?",
80            "string.join",
81            "string.json-escape",
82            "string->int",
83            // Symbol operations
84            "symbol.=",
85            // Encoding operations
86            "encoding.base64-encode",
87            "encoding.base64-decode",
88            "encoding.base64url-encode",
89            "encoding.base64url-decode",
90            "encoding.hex-encode",
91            "encoding.hex-decode",
92            // Crypto operations
93            "crypto.sha256",
94            "crypto.hmac-sha256",
95            "crypto.constant-time-eq",
96            "crypto.random-bytes",
97            "crypto.random-int",
98            "crypto.uuid4",
99            "crypto.aes-gcm-encrypt",
100            "crypto.aes-gcm-decrypt",
101            "crypto.pbkdf2-sha256",
102            "crypto.ed25519-keypair",
103            "crypto.ed25519-sign",
104            "crypto.ed25519-verify",
105            // HTTP client operations
106            "http.get",
107            "http.post",
108            "http.put",
109            "http.delete",
110            // List operations
111            "list.make",
112            "list.push",
113            "list.get",
114            "list.set",
115            "list.map",
116            "list.filter",
117            "list.fold",
118            "list.each",
119            "list.length",
120            "list.empty?",
121            "list.reverse",
122            // Map operations
123            "map.make",
124            "map.get",
125            "map.set",
126            "map.has?",
127            "map.remove",
128            "map.keys",
129            "map.values",
130            "map.size",
131            "map.empty?",
132            "map.each",
133            "map.fold",
134            // Variant operations
135            "variant.field-count",
136            "variant.tag",
137            "variant.field-at",
138            "variant.append",
139            "variant.last",
140            "variant.init",
141            "variant.make-0",
142            "variant.make-1",
143            "variant.make-2",
144            "variant.make-3",
145            "variant.make-4",
146            // SON wrap aliases
147            "wrap-0",
148            "wrap-1",
149            "wrap-2",
150            "wrap-3",
151            "wrap-4",
152            // Integer arithmetic operations
153            "i.add",
154            "i.subtract",
155            "i.multiply",
156            "i.divide",
157            "i.modulo",
158            // Terse integer arithmetic
159            "i.+",
160            "i.-",
161            "i.*",
162            "i./",
163            "i.%",
164            // Integer comparison operations (return 0 or 1)
165            "i.=",
166            "i.<",
167            "i.>",
168            "i.<=",
169            "i.>=",
170            "i.<>",
171            // Integer comparison operations (verbose form)
172            "i.eq",
173            "i.lt",
174            "i.gt",
175            "i.lte",
176            "i.gte",
177            "i.neq",
178            // Stack operations (simple - no parameters)
179            "dup",
180            "drop",
181            "swap",
182            "over",
183            "rot",
184            "nip",
185            "tuck",
186            "2dup",
187            "3drop",
188            "pick",
189            "roll",
190            // Aux stack operations
191            ">aux",
192            "aux>",
193            // Boolean operations
194            "and",
195            "or",
196            "not",
197            // Bitwise operations
198            "band",
199            "bor",
200            "bxor",
201            "bnot",
202            "i.neg",
203            "negate",
204            // Arithmetic sugar (resolved to concrete ops by typechecker)
205            "+",
206            "-",
207            "*",
208            "/",
209            "%",
210            "=",
211            "<",
212            ">",
213            "<=",
214            ">=",
215            "<>",
216            "shl",
217            "shr",
218            "popcount",
219            "clz",
220            "ctz",
221            "int-bits",
222            // Channel operations
223            "chan.make",
224            "chan.send",
225            "chan.receive",
226            "chan.close",
227            "chan.yield",
228            // Quotation operations
229            "call",
230            // Dataflow combinators
231            "dip",
232            "keep",
233            "bi",
234            "strand.spawn",
235            "strand.weave",
236            "strand.resume",
237            "strand.weave-cancel",
238            "yield",
239            "cond",
240            // TCP operations
241            "tcp.listen",
242            "tcp.accept",
243            "tcp.read",
244            "tcp.write",
245            "tcp.close",
246            // OS operations
247            "os.getenv",
248            "os.home-dir",
249            "os.current-dir",
250            "os.path-exists",
251            "os.path-is-file",
252            "os.path-is-dir",
253            "os.path-join",
254            "os.path-parent",
255            "os.path-filename",
256            "os.exit",
257            "os.name",
258            "os.arch",
259            // Signal handling
260            "signal.trap",
261            "signal.received?",
262            "signal.pending?",
263            "signal.default",
264            "signal.ignore",
265            "signal.clear",
266            "signal.SIGINT",
267            "signal.SIGTERM",
268            "signal.SIGHUP",
269            "signal.SIGPIPE",
270            "signal.SIGUSR1",
271            "signal.SIGUSR2",
272            "signal.SIGCHLD",
273            "signal.SIGALRM",
274            "signal.SIGCONT",
275            // Terminal operations
276            "terminal.raw-mode",
277            "terminal.read-char",
278            "terminal.read-char?",
279            "terminal.width",
280            "terminal.height",
281            "terminal.flush",
282            // Float arithmetic operations (verbose form)
283            "f.add",
284            "f.subtract",
285            "f.multiply",
286            "f.divide",
287            // Float arithmetic operations (terse form)
288            "f.+",
289            "f.-",
290            "f.*",
291            "f./",
292            // Float comparison operations (symbol form)
293            "f.=",
294            "f.<",
295            "f.>",
296            "f.<=",
297            "f.>=",
298            "f.<>",
299            // Float comparison operations (verbose form)
300            "f.eq",
301            "f.lt",
302            "f.gt",
303            "f.lte",
304            "f.gte",
305            "f.neq",
306            // Type conversions
307            "int->float",
308            "float->int",
309            "float->string",
310            "string->float",
311            // Test framework operations
312            "test.init",
313            "test.set-name",
314            "test.finish",
315            "test.has-failures",
316            "test.assert",
317            "test.assert-not",
318            "test.assert-eq",
319            "test.assert-eq-str",
320            "test.fail",
321            "test.pass-count",
322            "test.fail-count",
323            // Time operations
324            "time.now",
325            "time.nanos",
326            "time.sleep-ms",
327            // SON serialization
328            "son.dump",
329            "son.dump-pretty",
330            // Stack introspection (for REPL)
331            "stack.dump",
332            // Regex operations
333            "regex.match?",
334            "regex.find",
335            "regex.find-all",
336            "regex.replace",
337            "regex.replace-all",
338            "regex.captures",
339            "regex.split",
340            "regex.valid?",
341            // Compression operations
342            "compress.gzip",
343            "compress.gzip-level",
344            "compress.gunzip",
345            "compress.zstd",
346            "compress.zstd-level",
347            "compress.unzstd",
348        ];
349
350        for word in &self.words {
351            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
352        }
353
354        Ok(())
355    }
356
357    /// Helper to validate word calls in a list of statements (recursively)
358    fn validate_statements(
359        &self,
360        statements: &[Statement],
361        word_name: &str,
362        builtins: &[&str],
363        external_words: &[&str],
364    ) -> Result<(), String> {
365        for statement in statements {
366            match statement {
367                Statement::WordCall { name, .. } => {
368                    // Check if it's a built-in
369                    if builtins.contains(&name.as_str()) {
370                        continue;
371                    }
372                    // Check if it's a user-defined word
373                    if self.find_word(name).is_some() {
374                        continue;
375                    }
376                    // Check if it's an external word (from includes)
377                    if external_words.contains(&name.as_str()) {
378                        continue;
379                    }
380                    // Undefined word!
381                    return Err(format!(
382                        "Undefined word '{}' called in word '{}'. \
383                         Did you forget to define it or misspell a built-in?",
384                        name, word_name
385                    ));
386                }
387                Statement::If {
388                    then_branch,
389                    else_branch,
390                    span: _,
391                } => {
392                    // Recursively validate both branches
393                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
394                    if let Some(eb) = else_branch {
395                        self.validate_statements(eb, word_name, builtins, external_words)?;
396                    }
397                }
398                Statement::Quotation { body, .. } => {
399                    // Recursively validate quotation body
400                    self.validate_statements(body, word_name, builtins, external_words)?;
401                }
402                Statement::Match { arms, span: _ } => {
403                    // Recursively validate each match arm's body
404                    for arm in arms {
405                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
406                    }
407                }
408                _ => {} // Literals don't need validation
409            }
410        }
411        Ok(())
412    }
413
414    /// Generate constructor words for all union definitions
415    ///
416    /// Maximum number of fields a variant can have (limited by runtime support)
417    pub const MAX_VARIANT_FIELDS: usize = 12;
418
419    /// Generate helper words for union types:
420    /// 1. Constructors: `Make-VariantName` - creates variant instances
421    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
422    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
423    ///
424    /// Example: For `union Message { Get { chan: Int } }`
425    /// Generates:
426    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
427    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
428    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
429    ///
430    /// Returns an error if any variant exceeds the maximum field count.
431    pub fn generate_constructors(&mut self) -> Result<(), String> {
432        let mut new_words = Vec::new();
433
434        for union_def in &self.unions {
435            for variant in &union_def.variants {
436                let field_count = variant.fields.len();
437
438                // Check field count limit before generating constructor
439                if field_count > Self::MAX_VARIANT_FIELDS {
440                    return Err(format!(
441                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
442                         Consider grouping fields into nested union types.",
443                        variant.name,
444                        union_def.name,
445                        field_count,
446                        Self::MAX_VARIANT_FIELDS
447                    ));
448                }
449
450                // 1. Generate constructor: Make-VariantName
451                let constructor_name = format!("Make-{}", variant.name);
452                let mut input_stack = StackType::RowVar("a".to_string());
453                for field in &variant.fields {
454                    let field_type = parse_type_name(&field.type_name);
455                    input_stack = input_stack.push(field_type);
456                }
457                let output_stack =
458                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
459                let effect = Effect::new(input_stack, output_stack);
460                let body = vec![
461                    Statement::Symbol(variant.name.clone()),
462                    Statement::WordCall {
463                        name: format!("variant.make-{}", field_count),
464                        span: None,
465                    },
466                ];
467                new_words.push(WordDef {
468                    name: constructor_name,
469                    effect: Some(effect),
470                    body,
471                    source: variant.source.clone(),
472                    allowed_lints: vec![],
473                });
474
475                // 2. Generate predicate: is-VariantName?
476                // Effect: ( UnionType -- Bool )
477                // Body: variant.tag :VariantName symbol.=
478                let predicate_name = format!("is-{}?", variant.name);
479                let predicate_input =
480                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
481                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
482                let predicate_effect = Effect::new(predicate_input, predicate_output);
483                let predicate_body = vec![
484                    Statement::WordCall {
485                        name: "variant.tag".to_string(),
486                        span: None,
487                    },
488                    Statement::Symbol(variant.name.clone()),
489                    Statement::WordCall {
490                        name: "symbol.=".to_string(),
491                        span: None,
492                    },
493                ];
494                new_words.push(WordDef {
495                    name: predicate_name,
496                    effect: Some(predicate_effect),
497                    body: predicate_body,
498                    source: variant.source.clone(),
499                    allowed_lints: vec![],
500                });
501
502                // 3. Generate field accessors: VariantName-fieldname
503                // Effect: ( UnionType -- FieldType )
504                // Body: N variant.field-at
505                for (index, field) in variant.fields.iter().enumerate() {
506                    let accessor_name = format!("{}-{}", variant.name, field.name);
507                    let field_type = parse_type_name(&field.type_name);
508                    let accessor_input = StackType::RowVar("a".to_string())
509                        .push(Type::Union(union_def.name.clone()));
510                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
511                    let accessor_effect = Effect::new(accessor_input, accessor_output);
512                    let accessor_body = vec![
513                        Statement::IntLiteral(index as i64),
514                        Statement::WordCall {
515                            name: "variant.field-at".to_string(),
516                            span: None,
517                        },
518                    ];
519                    new_words.push(WordDef {
520                        name: accessor_name,
521                        effect: Some(accessor_effect),
522                        body: accessor_body,
523                        source: variant.source.clone(), // Use variant's source for field accessors
524                        allowed_lints: vec![],
525                    });
526                }
527            }
528        }
529
530        self.words.extend(new_words);
531        Ok(())
532    }
533
534    /// RFC #345: Fix up type variables in stack effects that should be union types
535    ///
536    /// When parsing files with includes, type variables like "Message" in
537    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
538    /// union definition is in an included file. After resolving includes,
539    /// we know all union names and can convert these to `Type::Union("Message")`.
540    ///
541    /// This ensures proper nominal type checking for union types across files.
542    pub fn fixup_union_types(&mut self) {
543        // Collect all union names from the program
544        let union_names: std::collections::HashSet<String> =
545            self.unions.iter().map(|u| u.name.clone()).collect();
546
547        // Fix up types in all word effects
548        for word in &mut self.words {
549            if let Some(ref mut effect) = word.effect {
550                Self::fixup_stack_type(&mut effect.inputs, &union_names);
551                Self::fixup_stack_type(&mut effect.outputs, &union_names);
552            }
553        }
554    }
555
556    /// Recursively fix up types in a stack type
557    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
558        match stack {
559            StackType::Empty | StackType::RowVar(_) => {}
560            StackType::Cons { rest, top } => {
561                Self::fixup_type(top, union_names);
562                Self::fixup_stack_type(rest, union_names);
563            }
564        }
565    }
566
567    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
568    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
569        match ty {
570            Type::Var(name) if union_names.contains(name) => {
571                *ty = Type::Union(name.clone());
572            }
573            Type::Quotation(effect) => {
574                Self::fixup_stack_type(&mut effect.inputs, union_names);
575                Self::fixup_stack_type(&mut effect.outputs, union_names);
576            }
577            Type::Closure { effect, captures } => {
578                Self::fixup_stack_type(&mut effect.inputs, union_names);
579                Self::fixup_stack_type(&mut effect.outputs, union_names);
580                for cap in captures {
581                    Self::fixup_type(cap, union_names);
582                }
583            }
584            _ => {}
585        }
586    }
587}
588
589/// Parse a type name string into a Type
590/// Used by constructor generation to build stack effects
591fn parse_type_name(name: &str) -> Type {
592    match name {
593        "Int" => Type::Int,
594        "Float" => Type::Float,
595        "Bool" => Type::Bool,
596        "String" => Type::String,
597        "Channel" => Type::Channel,
598        other => Type::Union(other.to_string()),
599    }
600}
601
602impl Default for Program {
603    fn default() -> Self {
604        Self::new()
605    }
606}