Skip to main content

seqc/ast/
program.rs

1//! Program-level AST methods: word-call validation, auto-generated variant
2//! constructors (`Make-Variant`), and type fix-up for union types declared
3//! in stack effects.
4
5use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10    pub fn new() -> Self {
11        Program {
12            includes: Vec::new(),
13            unions: Vec::new(),
14            words: Vec::new(),
15        }
16    }
17
18    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19        self.words.iter().find(|w| w.name == name)
20    }
21
22    /// Validate that all word calls reference either a defined word or a built-in
23    pub fn validate_word_calls(&self) -> Result<(), String> {
24        self.validate_word_calls_with_externals(&[])
25    }
26
27    /// Validate that all word calls reference a defined word, built-in, or external word.
28    ///
29    /// The `external_words` parameter should contain names of words available from
30    /// external sources (e.g., included modules) that should be considered valid.
31    pub fn validate_word_calls_with_externals(
32        &self,
33        external_words: &[&str],
34    ) -> Result<(), String> {
35        // List of known runtime built-ins
36        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
37        let builtins = [
38            // I/O operations
39            "io.write",
40            "io.write-line",
41            "io.read-line",
42            "io.read-line+",
43            "io.read-n",
44            "int->string",
45            "symbol->string",
46            "string->symbol",
47            // Command-line arguments
48            "args.count",
49            "args.at",
50            // File operations
51            "file.slurp",
52            "file.exists?",
53            "file.for-each-line+",
54            "file.spit",
55            "file.append",
56            "file.delete",
57            "file.size",
58            // Directory operations
59            "dir.exists?",
60            "dir.make",
61            "dir.delete",
62            "dir.list",
63            // String operations
64            "string.concat",
65            "string.length",
66            "string.byte-length",
67            "string.char-at",
68            "string.substring",
69            "char->string",
70            "string.find",
71            "string.split",
72            "string.contains",
73            "string.starts-with",
74            "string.empty?",
75            "string.trim",
76            "string.chomp",
77            "string.to-upper",
78            "string.to-lower",
79            "string.equal?",
80            "string.join",
81            "string.json-escape",
82            "string->int",
83            // Symbol operations
84            "symbol.=",
85            // Encoding operations
86            "encoding.base64-encode",
87            "encoding.base64-decode",
88            "encoding.base64url-encode",
89            "encoding.base64url-decode",
90            "encoding.hex-encode",
91            "encoding.hex-decode",
92            // Crypto operations
93            "crypto.sha256",
94            "crypto.hmac-sha256",
95            "crypto.constant-time-eq",
96            "crypto.random-bytes",
97            "crypto.random-int",
98            "crypto.uuid4",
99            "crypto.aes-gcm-encrypt",
100            "crypto.aes-gcm-decrypt",
101            "crypto.pbkdf2-sha256",
102            "crypto.ed25519-keypair",
103            "crypto.ed25519-sign",
104            "crypto.ed25519-verify",
105            // HTTP client operations
106            "http.get",
107            "http.post",
108            "http.put",
109            "http.delete",
110            // List operations
111            "list.make",
112            "list.push",
113            "list.get",
114            "list.set",
115            "list.map",
116            "list.filter",
117            "list.fold",
118            "list.each",
119            "list.length",
120            "list.empty?",
121            "list.reverse",
122            // Map operations
123            "map.make",
124            "map.get",
125            "map.set",
126            "map.has?",
127            "map.remove",
128            "map.keys",
129            "map.values",
130            "map.size",
131            "map.empty?",
132            "map.each",
133            "map.fold",
134            // Variant operations
135            "variant.field-count",
136            "variant.tag",
137            "variant.field-at",
138            "variant.append",
139            "variant.last",
140            "variant.init",
141            "variant.make-0",
142            "variant.make-1",
143            "variant.make-2",
144            "variant.make-3",
145            "variant.make-4",
146            // SON wrap aliases
147            "wrap-0",
148            "wrap-1",
149            "wrap-2",
150            "wrap-3",
151            "wrap-4",
152            // Integer arithmetic operations
153            "i.add",
154            "i.subtract",
155            "i.multiply",
156            "i.divide",
157            "i.modulo",
158            // Terse integer arithmetic
159            "i.+",
160            "i.-",
161            "i.*",
162            "i./",
163            "i.%",
164            // Integer comparison operations (return 0 or 1)
165            "i.=",
166            "i.<",
167            "i.>",
168            "i.<=",
169            "i.>=",
170            "i.<>",
171            // Integer comparison operations (verbose form)
172            "i.eq",
173            "i.lt",
174            "i.gt",
175            "i.lte",
176            "i.gte",
177            "i.neq",
178            // Stack operations (simple - no parameters)
179            "dup",
180            "drop",
181            "swap",
182            "over",
183            "rot",
184            "nip",
185            "tuck",
186            "2dup",
187            "3drop",
188            "pick",
189            "roll",
190            // Aux stack operations
191            ">aux",
192            "aux>",
193            // Boolean operations
194            "and",
195            "or",
196            "not",
197            // Bitwise operations
198            "band",
199            "bor",
200            "bxor",
201            "bnot",
202            "i.neg",
203            "negate",
204            // Arithmetic sugar (resolved to concrete ops by typechecker)
205            "+",
206            "-",
207            "*",
208            "/",
209            "%",
210            "=",
211            "<",
212            ">",
213            "<=",
214            ">=",
215            "<>",
216            "shl",
217            "shr",
218            "popcount",
219            "clz",
220            "ctz",
221            "int-bits",
222            // Channel operations
223            "chan.make",
224            "chan.send",
225            "chan.receive",
226            "chan.close",
227            "chan.yield",
228            // Quotation operations
229            "call",
230            // Dataflow combinators
231            "dip",
232            "keep",
233            "bi",
234            "if",
235            "strand.spawn",
236            "strand.weave",
237            "strand.resume",
238            "strand.weave-cancel",
239            "yield",
240            "cond",
241            // TCP operations
242            "tcp.listen",
243            "tcp.accept",
244            "tcp.read",
245            "tcp.write",
246            "tcp.close",
247            // OS operations
248            "os.getenv",
249            "os.home-dir",
250            "os.current-dir",
251            "os.path-exists",
252            "os.path-is-file",
253            "os.path-is-dir",
254            "os.path-join",
255            "os.path-parent",
256            "os.path-filename",
257            "os.exit",
258            "os.name",
259            "os.arch",
260            // Signal handling
261            "signal.trap",
262            "signal.received?",
263            "signal.pending?",
264            "signal.default",
265            "signal.ignore",
266            "signal.clear",
267            "signal.SIGINT",
268            "signal.SIGTERM",
269            "signal.SIGHUP",
270            "signal.SIGPIPE",
271            "signal.SIGUSR1",
272            "signal.SIGUSR2",
273            "signal.SIGCHLD",
274            "signal.SIGALRM",
275            "signal.SIGCONT",
276            // Terminal operations
277            "terminal.raw-mode",
278            "terminal.read-char",
279            "terminal.read-char?",
280            "terminal.width",
281            "terminal.height",
282            "terminal.flush",
283            // Float arithmetic operations (verbose form)
284            "f.add",
285            "f.subtract",
286            "f.multiply",
287            "f.divide",
288            // Float arithmetic operations (terse form)
289            "f.+",
290            "f.-",
291            "f.*",
292            "f./",
293            // Float comparison operations (symbol form)
294            "f.=",
295            "f.<",
296            "f.>",
297            "f.<=",
298            "f.>=",
299            "f.<>",
300            // Float comparison operations (verbose form)
301            "f.eq",
302            "f.lt",
303            "f.gt",
304            "f.lte",
305            "f.gte",
306            "f.neq",
307            // Type conversions
308            "int->float",
309            "float->int",
310            "float->string",
311            "string->float",
312            // Test framework operations
313            "test.init",
314            "test.set-name",
315            "test.finish",
316            "test.has-failures",
317            "test.assert",
318            "test.assert-not",
319            "test.assert-eq",
320            "test.assert-eq-str",
321            "test.fail",
322            "test.pass-count",
323            "test.fail-count",
324            // Time operations
325            "time.now",
326            "time.nanos",
327            "time.sleep-ms",
328            // SON serialization
329            "son.dump",
330            "son.dump-pretty",
331            // Stack introspection (for REPL)
332            "stack.dump",
333            // Regex operations
334            "regex.match?",
335            "regex.find",
336            "regex.find-all",
337            "regex.replace",
338            "regex.replace-all",
339            "regex.captures",
340            "regex.split",
341            "regex.valid?",
342            // Compression operations
343            "compress.gzip",
344            "compress.gzip-level",
345            "compress.gunzip",
346            "compress.zstd",
347            "compress.zstd-level",
348            "compress.unzstd",
349        ];
350
351        for word in &self.words {
352            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
353        }
354
355        Ok(())
356    }
357
358    /// Helper to validate word calls in a list of statements (recursively)
359    fn validate_statements(
360        &self,
361        statements: &[Statement],
362        word_name: &str,
363        builtins: &[&str],
364        external_words: &[&str],
365    ) -> Result<(), String> {
366        for statement in statements {
367            match statement {
368                Statement::WordCall { name, .. } => {
369                    // Check if it's a built-in
370                    if builtins.contains(&name.as_str()) {
371                        continue;
372                    }
373                    // Check if it's a user-defined word
374                    if self.find_word(name).is_some() {
375                        continue;
376                    }
377                    // Check if it's an external word (from includes)
378                    if external_words.contains(&name.as_str()) {
379                        continue;
380                    }
381                    // Undefined word!
382                    return Err(format!(
383                        "Undefined word '{}' called in word '{}'. \
384                         Did you forget to define it or misspell a built-in?",
385                        name, word_name
386                    ));
387                }
388                Statement::If {
389                    then_branch,
390                    else_branch,
391                    span: _,
392                } => {
393                    // Recursively validate both branches
394                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
395                    if let Some(eb) = else_branch {
396                        self.validate_statements(eb, word_name, builtins, external_words)?;
397                    }
398                }
399                Statement::Quotation { body, .. } => {
400                    // Recursively validate quotation body
401                    self.validate_statements(body, word_name, builtins, external_words)?;
402                }
403                Statement::Match { arms, span: _ } => {
404                    // Recursively validate each match arm's body
405                    for arm in arms {
406                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
407                    }
408                }
409                _ => {} // Literals don't need validation
410            }
411        }
412        Ok(())
413    }
414
415    /// Generate constructor words for all union definitions
416    ///
417    /// Maximum number of fields a variant can have (limited by runtime support)
418    pub const MAX_VARIANT_FIELDS: usize = 12;
419
420    /// Generate helper words for union types:
421    /// 1. Constructors: `Make-VariantName` - creates variant instances
422    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
423    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
424    ///
425    /// Example: For `union Message { Get { chan: Int } }`
426    /// Generates:
427    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
428    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
429    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
430    ///
431    /// Returns an error if any variant exceeds the maximum field count.
432    pub fn generate_constructors(&mut self) -> Result<(), String> {
433        let mut new_words = Vec::new();
434
435        for union_def in &self.unions {
436            for variant in &union_def.variants {
437                let field_count = variant.fields.len();
438
439                // Check field count limit before generating constructor
440                if field_count > Self::MAX_VARIANT_FIELDS {
441                    return Err(format!(
442                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
443                         Consider grouping fields into nested union types.",
444                        variant.name,
445                        union_def.name,
446                        field_count,
447                        Self::MAX_VARIANT_FIELDS
448                    ));
449                }
450
451                // 1. Generate constructor: Make-VariantName
452                let constructor_name = format!("Make-{}", variant.name);
453                let mut input_stack = StackType::RowVar("a".to_string());
454                for field in &variant.fields {
455                    let field_type = parse_type_name(&field.type_name);
456                    input_stack = input_stack.push(field_type);
457                }
458                let output_stack =
459                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
460                let effect = Effect::new(input_stack, output_stack);
461                let body = vec![
462                    Statement::Symbol(variant.name.clone()),
463                    Statement::WordCall {
464                        name: format!("variant.make-{}", field_count),
465                        span: None,
466                    },
467                ];
468                new_words.push(WordDef {
469                    name: constructor_name,
470                    effect: Some(effect),
471                    body,
472                    source: variant.source.clone(),
473                    allowed_lints: vec![],
474                });
475
476                // 2. Generate predicate: is-VariantName?
477                // Effect: ( UnionType -- Bool )
478                // Body: variant.tag :VariantName symbol.=
479                let predicate_name = format!("is-{}?", variant.name);
480                let predicate_input =
481                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
482                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
483                let predicate_effect = Effect::new(predicate_input, predicate_output);
484                let predicate_body = vec![
485                    Statement::WordCall {
486                        name: "variant.tag".to_string(),
487                        span: None,
488                    },
489                    Statement::Symbol(variant.name.clone()),
490                    Statement::WordCall {
491                        name: "symbol.=".to_string(),
492                        span: None,
493                    },
494                ];
495                new_words.push(WordDef {
496                    name: predicate_name,
497                    effect: Some(predicate_effect),
498                    body: predicate_body,
499                    source: variant.source.clone(),
500                    allowed_lints: vec![],
501                });
502
503                // 3. Generate field accessors: VariantName-fieldname
504                // Effect: ( UnionType -- FieldType )
505                // Body: N variant.field-at
506                for (index, field) in variant.fields.iter().enumerate() {
507                    let accessor_name = format!("{}-{}", variant.name, field.name);
508                    let field_type = parse_type_name(&field.type_name);
509                    let accessor_input = StackType::RowVar("a".to_string())
510                        .push(Type::Union(union_def.name.clone()));
511                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
512                    let accessor_effect = Effect::new(accessor_input, accessor_output);
513                    let accessor_body = vec![
514                        Statement::IntLiteral(index as i64),
515                        Statement::WordCall {
516                            name: "variant.field-at".to_string(),
517                            span: None,
518                        },
519                    ];
520                    new_words.push(WordDef {
521                        name: accessor_name,
522                        effect: Some(accessor_effect),
523                        body: accessor_body,
524                        source: variant.source.clone(), // Use variant's source for field accessors
525                        allowed_lints: vec![],
526                    });
527                }
528            }
529        }
530
531        self.words.extend(new_words);
532        Ok(())
533    }
534
535    /// RFC #345: Fix up type variables in stack effects that should be union types
536    ///
537    /// When parsing files with includes, type variables like "Message" in
538    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
539    /// union definition is in an included file. After resolving includes,
540    /// we know all union names and can convert these to `Type::Union("Message")`.
541    ///
542    /// This ensures proper nominal type checking for union types across files.
543    pub fn fixup_union_types(&mut self) {
544        // Collect all union names from the program
545        let union_names: std::collections::HashSet<String> =
546            self.unions.iter().map(|u| u.name.clone()).collect();
547
548        // Fix up types in all word effects
549        for word in &mut self.words {
550            if let Some(ref mut effect) = word.effect {
551                Self::fixup_stack_type(&mut effect.inputs, &union_names);
552                Self::fixup_stack_type(&mut effect.outputs, &union_names);
553            }
554        }
555    }
556
557    /// Recursively fix up types in a stack type
558    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
559        match stack {
560            StackType::Empty | StackType::RowVar(_) => {}
561            StackType::Cons { rest, top } => {
562                Self::fixup_type(top, union_names);
563                Self::fixup_stack_type(rest, union_names);
564            }
565        }
566    }
567
568    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
569    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
570        match ty {
571            Type::Var(name) if union_names.contains(name) => {
572                *ty = Type::Union(name.clone());
573            }
574            Type::Quotation(effect) => {
575                Self::fixup_stack_type(&mut effect.inputs, union_names);
576                Self::fixup_stack_type(&mut effect.outputs, union_names);
577            }
578            Type::Closure { effect, captures } => {
579                Self::fixup_stack_type(&mut effect.inputs, union_names);
580                Self::fixup_stack_type(&mut effect.outputs, union_names);
581                for cap in captures {
582                    Self::fixup_type(cap, union_names);
583                }
584            }
585            _ => {}
586        }
587    }
588}
589
590/// Parse a type name string into a Type
591/// Used by constructor generation to build stack effects
592fn parse_type_name(name: &str) -> Type {
593    match name {
594        "Int" => Type::Int,
595        "Float" => Type::Float,
596        "Bool" => Type::Bool,
597        "String" => Type::String,
598        "Channel" => Type::Channel,
599        other => Type::Union(other.to_string()),
600    }
601}
602
603impl Default for Program {
604    fn default() -> Self {
605        Self::new()
606    }
607}