1use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10 pub fn new() -> Self {
11 Program {
12 includes: Vec::new(),
13 unions: Vec::new(),
14 words: Vec::new(),
15 }
16 }
17
18 pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19 self.words.iter().find(|w| w.name == name)
20 }
21
22 pub fn validate_word_calls(&self) -> Result<(), String> {
24 self.validate_word_calls_with_externals(&[])
25 }
26
27 pub fn validate_word_calls_with_externals(
32 &self,
33 external_words: &[&str],
34 ) -> Result<(), String> {
35 let builtins = [
38 "io.write",
40 "io.write-line",
41 "io.read-line",
42 "io.read-line+",
43 "io.read-n",
44 "int->string",
45 "symbol->string",
46 "string->symbol",
47 "args.count",
49 "args.at",
50 "file.slurp",
52 "file.exists?",
53 "file.for-each-line+",
54 "file.spit",
55 "file.append",
56 "file.delete",
57 "file.size",
58 "dir.exists?",
60 "dir.make",
61 "dir.delete",
62 "dir.list",
63 "string.concat",
65 "string.length",
66 "string.byte-length",
67 "string.char-at",
68 "string.substring",
69 "char->string",
70 "string.find",
71 "string.split",
72 "string.contains",
73 "string.starts-with",
74 "string.empty?",
75 "string.trim",
76 "string.chomp",
77 "string.to-upper",
78 "string.to-lower",
79 "string.equal?",
80 "string.join",
81 "string.json-escape",
82 "string->int",
83 "symbol.=",
85 "encoding.base64-encode",
87 "encoding.base64-decode",
88 "encoding.base64url-encode",
89 "encoding.base64url-decode",
90 "encoding.hex-encode",
91 "encoding.hex-decode",
92 "crypto.sha256",
94 "crypto.hmac-sha256",
95 "crypto.constant-time-eq",
96 "crypto.random-bytes",
97 "crypto.random-int",
98 "crypto.uuid4",
99 "crypto.aes-gcm-encrypt",
100 "crypto.aes-gcm-decrypt",
101 "crypto.pbkdf2-sha256",
102 "crypto.ed25519-keypair",
103 "crypto.ed25519-sign",
104 "crypto.ed25519-verify",
105 "http.get",
107 "http.post",
108 "http.put",
109 "http.delete",
110 "list.make",
112 "list.push",
113 "list.get",
114 "list.set",
115 "list.map",
116 "list.filter",
117 "list.fold",
118 "list.each",
119 "list.length",
120 "list.empty?",
121 "list.reverse",
122 "map.make",
124 "map.get",
125 "map.set",
126 "map.has?",
127 "map.remove",
128 "map.keys",
129 "map.values",
130 "map.size",
131 "map.empty?",
132 "map.each",
133 "map.fold",
134 "variant.field-count",
136 "variant.tag",
137 "variant.field-at",
138 "variant.append",
139 "variant.last",
140 "variant.init",
141 "variant.make-0",
142 "variant.make-1",
143 "variant.make-2",
144 "variant.make-3",
145 "variant.make-4",
146 "wrap-0",
148 "wrap-1",
149 "wrap-2",
150 "wrap-3",
151 "wrap-4",
152 "i.add",
154 "i.subtract",
155 "i.multiply",
156 "i.divide",
157 "i.modulo",
158 "i.+",
160 "i.-",
161 "i.*",
162 "i./",
163 "i.%",
164 "i.=",
166 "i.<",
167 "i.>",
168 "i.<=",
169 "i.>=",
170 "i.<>",
171 "i.eq",
173 "i.lt",
174 "i.gt",
175 "i.lte",
176 "i.gte",
177 "i.neq",
178 "dup",
180 "drop",
181 "swap",
182 "over",
183 "rot",
184 "nip",
185 "tuck",
186 "2dup",
187 "3drop",
188 "pick",
189 "roll",
190 ">aux",
192 "aux>",
193 "and",
195 "or",
196 "not",
197 "band",
199 "bor",
200 "bxor",
201 "bnot",
202 "i.neg",
203 "negate",
204 "+",
206 "-",
207 "*",
208 "/",
209 "%",
210 "=",
211 "<",
212 ">",
213 "<=",
214 ">=",
215 "<>",
216 "shl",
217 "shr",
218 "popcount",
219 "clz",
220 "ctz",
221 "int-bits",
222 "chan.make",
224 "chan.send",
225 "chan.receive",
226 "chan.close",
227 "chan.yield",
228 "call",
230 "dip",
232 "keep",
233 "bi",
234 "strand.spawn",
235 "strand.weave",
236 "strand.resume",
237 "strand.weave-cancel",
238 "yield",
239 "cond",
240 "tcp.listen",
242 "tcp.accept",
243 "tcp.read",
244 "tcp.write",
245 "tcp.close",
246 "os.getenv",
248 "os.home-dir",
249 "os.current-dir",
250 "os.path-exists",
251 "os.path-is-file",
252 "os.path-is-dir",
253 "os.path-join",
254 "os.path-parent",
255 "os.path-filename",
256 "os.exit",
257 "os.name",
258 "os.arch",
259 "signal.trap",
261 "signal.received?",
262 "signal.pending?",
263 "signal.default",
264 "signal.ignore",
265 "signal.clear",
266 "signal.SIGINT",
267 "signal.SIGTERM",
268 "signal.SIGHUP",
269 "signal.SIGPIPE",
270 "signal.SIGUSR1",
271 "signal.SIGUSR2",
272 "signal.SIGCHLD",
273 "signal.SIGALRM",
274 "signal.SIGCONT",
275 "terminal.raw-mode",
277 "terminal.read-char",
278 "terminal.read-char?",
279 "terminal.width",
280 "terminal.height",
281 "terminal.flush",
282 "f.add",
284 "f.subtract",
285 "f.multiply",
286 "f.divide",
287 "f.+",
289 "f.-",
290 "f.*",
291 "f./",
292 "f.=",
294 "f.<",
295 "f.>",
296 "f.<=",
297 "f.>=",
298 "f.<>",
299 "f.eq",
301 "f.lt",
302 "f.gt",
303 "f.lte",
304 "f.gte",
305 "f.neq",
306 "int->float",
308 "float->int",
309 "float->string",
310 "string->float",
311 "test.init",
313 "test.finish",
314 "test.has-failures",
315 "test.assert",
316 "test.assert-not",
317 "test.assert-eq",
318 "test.assert-eq-str",
319 "test.fail",
320 "test.pass-count",
321 "test.fail-count",
322 "time.now",
324 "time.nanos",
325 "time.sleep-ms",
326 "son.dump",
328 "son.dump-pretty",
329 "stack.dump",
331 "regex.match?",
333 "regex.find",
334 "regex.find-all",
335 "regex.replace",
336 "regex.replace-all",
337 "regex.captures",
338 "regex.split",
339 "regex.valid?",
340 "compress.gzip",
342 "compress.gzip-level",
343 "compress.gunzip",
344 "compress.zstd",
345 "compress.zstd-level",
346 "compress.unzstd",
347 ];
348
349 for word in &self.words {
350 self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
351 }
352
353 Ok(())
354 }
355
356 fn validate_statements(
358 &self,
359 statements: &[Statement],
360 word_name: &str,
361 builtins: &[&str],
362 external_words: &[&str],
363 ) -> Result<(), String> {
364 for statement in statements {
365 match statement {
366 Statement::WordCall { name, .. } => {
367 if builtins.contains(&name.as_str()) {
369 continue;
370 }
371 if self.find_word(name).is_some() {
373 continue;
374 }
375 if external_words.contains(&name.as_str()) {
377 continue;
378 }
379 return Err(format!(
381 "Undefined word '{}' called in word '{}'. \
382 Did you forget to define it or misspell a built-in?",
383 name, word_name
384 ));
385 }
386 Statement::If {
387 then_branch,
388 else_branch,
389 span: _,
390 } => {
391 self.validate_statements(then_branch, word_name, builtins, external_words)?;
393 if let Some(eb) = else_branch {
394 self.validate_statements(eb, word_name, builtins, external_words)?;
395 }
396 }
397 Statement::Quotation { body, .. } => {
398 self.validate_statements(body, word_name, builtins, external_words)?;
400 }
401 Statement::Match { arms, span: _ } => {
402 for arm in arms {
404 self.validate_statements(&arm.body, word_name, builtins, external_words)?;
405 }
406 }
407 _ => {} }
409 }
410 Ok(())
411 }
412
413 pub const MAX_VARIANT_FIELDS: usize = 12;
417
418 pub fn generate_constructors(&mut self) -> Result<(), String> {
431 let mut new_words = Vec::new();
432
433 for union_def in &self.unions {
434 for variant in &union_def.variants {
435 let field_count = variant.fields.len();
436
437 if field_count > Self::MAX_VARIANT_FIELDS {
439 return Err(format!(
440 "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
441 Consider grouping fields into nested union types.",
442 variant.name,
443 union_def.name,
444 field_count,
445 Self::MAX_VARIANT_FIELDS
446 ));
447 }
448
449 let constructor_name = format!("Make-{}", variant.name);
451 let mut input_stack = StackType::RowVar("a".to_string());
452 for field in &variant.fields {
453 let field_type = parse_type_name(&field.type_name);
454 input_stack = input_stack.push(field_type);
455 }
456 let output_stack =
457 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
458 let effect = Effect::new(input_stack, output_stack);
459 let body = vec![
460 Statement::Symbol(variant.name.clone()),
461 Statement::WordCall {
462 name: format!("variant.make-{}", field_count),
463 span: None,
464 },
465 ];
466 new_words.push(WordDef {
467 name: constructor_name,
468 effect: Some(effect),
469 body,
470 source: variant.source.clone(),
471 allowed_lints: vec![],
472 });
473
474 let predicate_name = format!("is-{}?", variant.name);
478 let predicate_input =
479 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
480 let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
481 let predicate_effect = Effect::new(predicate_input, predicate_output);
482 let predicate_body = vec![
483 Statement::WordCall {
484 name: "variant.tag".to_string(),
485 span: None,
486 },
487 Statement::Symbol(variant.name.clone()),
488 Statement::WordCall {
489 name: "symbol.=".to_string(),
490 span: None,
491 },
492 ];
493 new_words.push(WordDef {
494 name: predicate_name,
495 effect: Some(predicate_effect),
496 body: predicate_body,
497 source: variant.source.clone(),
498 allowed_lints: vec![],
499 });
500
501 for (index, field) in variant.fields.iter().enumerate() {
505 let accessor_name = format!("{}-{}", variant.name, field.name);
506 let field_type = parse_type_name(&field.type_name);
507 let accessor_input = StackType::RowVar("a".to_string())
508 .push(Type::Union(union_def.name.clone()));
509 let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
510 let accessor_effect = Effect::new(accessor_input, accessor_output);
511 let accessor_body = vec![
512 Statement::IntLiteral(index as i64),
513 Statement::WordCall {
514 name: "variant.field-at".to_string(),
515 span: None,
516 },
517 ];
518 new_words.push(WordDef {
519 name: accessor_name,
520 effect: Some(accessor_effect),
521 body: accessor_body,
522 source: variant.source.clone(), allowed_lints: vec![],
524 });
525 }
526 }
527 }
528
529 self.words.extend(new_words);
530 Ok(())
531 }
532
533 pub fn fixup_union_types(&mut self) {
542 let union_names: std::collections::HashSet<String> =
544 self.unions.iter().map(|u| u.name.clone()).collect();
545
546 for word in &mut self.words {
548 if let Some(ref mut effect) = word.effect {
549 Self::fixup_stack_type(&mut effect.inputs, &union_names);
550 Self::fixup_stack_type(&mut effect.outputs, &union_names);
551 }
552 }
553 }
554
555 fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
557 match stack {
558 StackType::Empty | StackType::RowVar(_) => {}
559 StackType::Cons { rest, top } => {
560 Self::fixup_type(top, union_names);
561 Self::fixup_stack_type(rest, union_names);
562 }
563 }
564 }
565
566 fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
568 match ty {
569 Type::Var(name) if union_names.contains(name) => {
570 *ty = Type::Union(name.clone());
571 }
572 Type::Quotation(effect) => {
573 Self::fixup_stack_type(&mut effect.inputs, union_names);
574 Self::fixup_stack_type(&mut effect.outputs, union_names);
575 }
576 Type::Closure { effect, captures } => {
577 Self::fixup_stack_type(&mut effect.inputs, union_names);
578 Self::fixup_stack_type(&mut effect.outputs, union_names);
579 for cap in captures {
580 Self::fixup_type(cap, union_names);
581 }
582 }
583 _ => {}
584 }
585 }
586}
587
588fn parse_type_name(name: &str) -> Type {
591 match name {
592 "Int" => Type::Int,
593 "Float" => Type::Float,
594 "Bool" => Type::Bool,
595 "String" => Type::String,
596 "Channel" => Type::Channel,
597 other => Type::Union(other.to_string()),
598 }
599}
600
601impl Default for Program {
602 fn default() -> Self {
603 Self::new()
604 }
605}