1use std::{
4 array,
5 cell::RefCell,
6 collections::HashMap,
7 fmt,
8 hash::{DefaultHasher, Hash, Hasher},
9 slice,
10};
11
12use serde::*;
13
14use crate::{ImplPrimitive, Node, Primitive, SigNode, Signature, SysOp};
15
16impl Node {
17 pub fn sig(&self) -> Result<Signature, SigCheckError> {
19 nodes_all_sigs(self.as_slice())
20 }
21 pub fn sig_node(self) -> Result<SigNode, SigCheckError> {
23 let sig = self.sig()?;
24 Ok(SigNode::new(sig, self.clone()))
25 }
26 pub fn clean_sig(&self) -> Option<Signature> {
28 nodes_clean_sig(slice::from_ref(self))
29 }
30}
31
32pub fn nodes_sig(nodes: &[Node]) -> Result<Signature, SigCheckError> {
33 nodes_all_sigs(nodes)
34}
35
36pub fn nodes_clean_sig(nodes: &[Node]) -> Option<Signature> {
37 let sig = nodes_all_sigs(nodes).ok()?;
38 if sig.under_args() != 0 || sig.under_outputs() != 0 {
39 None
40 } else {
41 Some(sig)
42 }
43}
44
45fn nodes_all_sigs(nodes: &[Node]) -> Result<Signature, SigCheckError> {
46 type AllSigsCache = HashMap<u64, Signature>;
47 thread_local! {
48 static CACHE: RefCell<AllSigsCache> = RefCell::new(AllSigsCache::new());
49 }
50 let mut hasher = DefaultHasher::new();
51 nodes.hash(&mut hasher);
52 let hash = hasher.finish();
53 CACHE.with(|cache| {
54 if let Some(sigs) = cache.borrow().get(&hash) {
55 return Ok(*sigs);
56 }
57 let env = VirtualEnv::from_nodes(nodes)?;
58 let under_sig = env.under.sig();
59 let sig = (env.stack.sig()).with_under(under_sig.args(), under_sig.outputs());
60 cache.borrow_mut().insert(hash, sig);
61 Ok(sig)
62 })
63}
64
65struct VirtualEnv {
67 stack: Stack,
68 under: Stack,
69 node_depth: usize,
70}
71
72#[derive(Debug, Default)]
73struct Stack {
74 height: i32,
75 min_height: usize,
76}
77
78impl Stack {
79 fn pop(&mut self) {
81 self.height -= 1;
82 self.set_min_height();
83 }
84 fn pop_n(&mut self, n: usize) {
85 self.height -= n as i32;
86 self.set_min_height();
87 }
88 fn push(&mut self) {
89 self.height += 1;
90 }
91 fn push_n(&mut self, n: usize) {
92 self.height += n as i32;
93 }
94 fn handle_args_outputs(&mut self, args: usize, outputs: usize) {
95 self.pop_n(args);
96 self.push_n(outputs);
97 }
98 fn set_min_height(&mut self) {
101 self.min_height = self.min_height.max((-self.height).max(0) as usize);
102 }
103 fn sig(&self) -> Signature {
104 Signature::new(
105 self.min_height,
106 (self.height + self.min_height as i32).max(0) as usize,
107 )
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
112pub struct SigCheckError {
113 pub message: String,
114 pub kind: SigCheckErrorKind,
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
118pub enum SigCheckErrorKind {
119 Incorrect,
120 NoInverse,
121}
122
123impl SigCheckError {
124 pub fn no_inverse(self) -> Self {
125 Self {
126 kind: SigCheckErrorKind::NoInverse,
127 ..self
128 }
129 }
130}
131
132impl<'a> From<&'a str> for SigCheckError {
133 fn from(s: &'a str) -> Self {
134 Self {
135 message: s.to_string(),
136 kind: SigCheckErrorKind::Incorrect,
137 }
138 }
139}
140
141impl From<String> for SigCheckError {
142 fn from(s: String) -> Self {
143 Self {
144 message: s,
145 kind: SigCheckErrorKind::Incorrect,
146 }
147 }
148}
149
150impl fmt::Display for SigCheckError {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 self.message.fmt(f)
153 }
154}
155
/// Maximum nesting depth of node checking before a function is rejected as
/// "too complex". The limit is lower in builds with debug assertions enabled.
#[cfg(debug_assertions)]
const MAX_NODE_DEPTH: usize = 26;
/// Maximum nesting depth of node checking before a function is rejected as
/// "too complex". The limit is lower in builds with debug assertions enabled.
#[cfg(not(debug_assertions))]
const MAX_NODE_DEPTH: usize = 50;
157
158impl VirtualEnv {
159 fn from_nodes(nodes: &[Node]) -> Result<Self, SigCheckError> {
160 let mut env = VirtualEnv {
162 stack: Stack::default(),
163 under: Stack::default(),
164 node_depth: 0,
165 };
166 env.nodes(nodes)?;
167 Ok(env)
168 }
169 fn nodes(&mut self, nodes: &[Node]) -> Result<(), SigCheckError> {
170 nodes.iter().try_for_each(|node| self.node(node))
171 }
172 fn sig_node(&mut self, sn: &SigNode) -> Result<(), SigCheckError> {
173 self.node(&sn.node)
174 }
175 fn node(&mut self, node: &Node) -> Result<(), SigCheckError> {
176 use {ImplPrimitive::*, Primitive::*};
177 if self.node_depth > MAX_NODE_DEPTH {
178 return Err("Function is too complex".into());
179 }
180 self.node_depth += 1;
181 match node {
182 Node::Run(nodes) => nodes.iter().try_for_each(|node| self.node(node))?,
183 Node::Push(_) => self.push(),
184 Node::Array { len, inner, .. } => {
185 self.node(inner)?;
186 self.stack.pop_n(*len);
187 self.stack.push();
188 }
189 Node::Label(..) | Node::RemoveLabel(..) => self.handle_args_outputs(1, 1),
190 Node::Call(func, _) => self.handle_sig(func.sig),
191 Node::CallMacro { sig, .. } | Node::CallGlobal(_, sig) => self.handle_sig(*sig),
192 Node::BindGlobal { .. } => self.handle_args_outputs(1, 0),
193 Node::CustomInverse(cust, _) => self.handle_sig(cust.sig()?),
194 Node::Dynamic(dy) => self.handle_sig(dy.sig),
195 &Node::Switch {
196 sig, under_cond, ..
197 } => {
198 self.pop();
199 self.handle_sig(sig);
200 if under_cond {
201 self.under.push();
202 }
203 }
204 Node::Format(parts, ..) => self.handle_args_outputs(parts.len().saturating_sub(1), 1),
205 Node::MatchFormatPattern(parts, ..) => {
206 self.handle_args_outputs(1, parts.len().saturating_sub(1))
207 }
208 Node::Unpack { count, .. } => self.handle_args_outputs(1, *count),
209 Node::ImplMod(Astar, args, _)
210 | Node::ImplMod(AstarFirst, args, _)
211 | Node::ImplMod(AstarSignLen, args, _) => {
212 self.pop();
213 let [neighbors, heuristic, is_goal] = get_args(args)?;
214 let has_costs = neighbors.outputs() == 2;
215 let args = neighbors
216 .args()
217 .max(heuristic.args())
218 .max(is_goal.args())
219 .saturating_sub(1);
220 self.handle_args_outputs(args, 1 + has_costs as usize);
221 }
222 Node::ImplMod(AstarTake, args, _) => {
223 self.pop();
224 self.pop();
225 let [neighbors, heuristic, is_goal] = get_args(args)?;
226 let has_costs = neighbors.outputs() == 2;
227 let args = neighbors
228 .args()
229 .max(heuristic.args())
230 .max(is_goal.args())
231 .saturating_sub(1);
232 self.handle_args_outputs(args, 1 + has_costs as usize);
233 }
234 Node::ImplMod(AstarPop, args, _) => {
235 self.pop();
236 let [neighbors, heuristic, is_goal] = get_args(args)?;
237 let has_costs = neighbors.outputs() == 2;
238 let args = neighbors
239 .args()
240 .max(heuristic.args())
241 .max(is_goal.args())
242 .saturating_sub(1);
243 self.handle_args_outputs(args, has_costs as usize);
244 }
245 Node::Mod(Path, args, _)
246 | Node::ImplMod(PathFirst, args, _)
247 | Node::ImplMod(PathSignLen, args, _) => {
248 self.pop();
249 let [neighbors, is_goal] = get_args(args)?;
250 let has_costs = neighbors.outputs() == 2;
251 let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
252 self.handle_args_outputs(args, 1 + has_costs as usize);
253 }
254 Node::ImplMod(PathTake, args, _) => {
255 self.pop();
256 self.pop();
257 let [neighbors, is_goal] = get_args(args)?;
258 let has_costs = neighbors.outputs() == 2;
259 let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
260 self.handle_args_outputs(args, 1 + has_costs as usize);
261 }
262 Node::ImplMod(PathPop, args, _) => {
263 self.pop();
264 let [neighbors, is_goal] = get_args(args)?;
265 let has_costs = neighbors.outputs() == 2;
266 let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
267 self.handle_args_outputs(args, has_costs as usize);
268 }
269 Node::Prim(prim, _) => {
270 let args = prim
271 .args()
272 .ok_or_else(|| format!("{prim} has indeterminate args"))?;
273 let outputs = prim
274 .outputs()
275 .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
276 self.handle_args_outputs(args, outputs);
277 }
278 Node::ImplPrim(prim, _) => {
279 let args = prim
280 .args()
281 .ok_or_else(|| format!("{prim} has indeterminate args"))?;
282 let outputs = prim
283 .outputs()
284 .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
285 self.handle_args_outputs(args, outputs);
286 }
287 Node::Mod(prim, args, _) => match prim {
288 Reduce | Scan => {
289 let [sig] = get_args(args)?;
290 let args = sig.args().saturating_sub(sig.outputs()).max(1);
291 self.handle_args_outputs(args, sig.outputs());
292 }
293 Each | Rows | Inventory => {
294 let [f] = get_args_nodes(args)?;
295 self.sig_node(f)?;
296 }
297 Table | Tuples => {
298 let [sig] = get_args(args)?;
299 self.handle_sig(sig);
300 }
301 Stencil => {
302 let [sig] = get_args(args)?;
303 if sig.args() <= 1 {
304 self.pop();
305 }
306 self.handle_args_outputs(1, sig.outputs());
307 }
308 Group | Partition => {
309 let [sig] = get_args(args)?;
310 self.handle_args_outputs(sig.args().max(1) + 1, sig.outputs());
311 }
312 Spawn | Pool => {
313 let [sig] = get_args(args)?;
314 self.handle_args_outputs(sig.args(), 1);
315 }
316 Repeat => {
317 let [f] = get_args_nodes(args)?;
318 self.pop();
319 self.repeat(f)?;
320 }
321 Do => {
322 let [body, cond] = get_args(args)?;
323 let copy_count = cond.args().saturating_sub(cond.outputs().saturating_sub(1));
324 let cond_sub_sig = Signature::new(
325 cond.args(),
326 (cond.outputs() + copy_count).saturating_sub(1),
327 );
328 let comp_sig = body.compose(cond_sub_sig);
329 self.handle_args_outputs(
330 comp_sig.args(),
331 comp_sig.outputs() + cond_sub_sig.outputs().saturating_sub(cond.args()),
332 );
333 if comp_sig.args() < comp_sig.outputs() {
334 self.stack.pop_n(comp_sig.args());
335 }
336 }
337 Un => {
338 let [sig] = get_args(args)?;
339 self.handle_sig(sig.inverse());
340 }
341 Anti => {
342 let [sig] = get_args(args)?;
343 self.handle_sig(sig.anti().unwrap_or(sig));
344 }
345 Fold => {
346 let [f] = get_args(args)?;
347 if f == (0, 0) {
348 } else if f.outputs() >= f.args() {
349 self.handle_args_outputs(f.args(), f.outputs() + 1 - f.args());
350 } else {
351 self.handle_sig(f);
352 }
353 }
354 Try => {
355 let [mut f_sig, handler_sig] = get_args(args)?;
356 f_sig.update_outputs(|o| o.max(handler_sig.outputs()));
357 self.handle_sig(f_sig);
358 }
359 Case => {
360 let [f] = get_args(args)?;
361 self.handle_sig(f);
362 }
363 Fill => self.fill(args)?,
364 Content | Memo | Comptime => {
365 let [f] = get_args(args)?;
366 self.handle_sig(f);
367 }
368 Dump => {
369 let [_] = get_args(args)?;
370 }
371 Fork => {
372 let [f, g] = get_args(args)?;
373 self.handle_args_outputs(f.args().max(g.args()), f.outputs() + g.outputs());
374 }
375 Bracket => {
376 let (args, outputs) = args.iter().fold((0, 0), |(a, o), sn| {
377 (a + sn.sig.args(), o + sn.sig.outputs())
378 });
379 self.handle_args_outputs(args, outputs);
380 }
381 Both => {
382 let [f] = get_args_nodes(args)?;
383 self.stack.pop_n(f.sig.args());
384 self.sig_node(f)?;
385 self.stack.push_n(f.sig.args());
386 self.sig_node(f)?;
387 }
388 Dip => {
389 let [f] = get_args_nodes(args)?;
390 self.pop();
391 self.sig_node(f)?;
392 self.push();
393 }
394 Gap => {
395 let [f] = get_args_nodes(args)?;
396 self.pop();
397 self.sig_node(f)?;
398 }
399 Reach => {
400 let [f] = get_args_nodes(args)?;
401 self.pop();
402 self.pop();
403 self.push();
404 self.sig_node(f)?;
405 }
406 On => {
407 let [f] = get_args_nodes(args)?;
408 self.pop();
409 self.push();
410 self.sig_node(f)?;
411 self.push();
412 }
413 By => {
414 let [f] = get_args_nodes(args)?;
415 self.sig_node(f)?;
416 self.push();
417 }
418 Above | Below => {
419 let [f] = get_args(args)?;
420 self.handle_args_outputs(f.args(), f.args() + f.outputs());
421 }
422 With | Off => {
423 let [f] = get_args(args)?;
424 self.handle_args_outputs(f.args(), f.outputs() + 1);
425 }
426 Recur => {
427 let [is_leaf, children, combine] = get_args(args)?;
428 let args = is_leaf
429 .args()
430 .max(children.args())
431 .max(combine.args().saturating_sub(1).max(1));
432 self.handle_args_outputs(args, 1);
433 }
434 Sys(SysOp::ReadLines) => {
435 let [f] = get_args(args)?;
436 self.handle_sig(f);
437 }
438 Sys(SysOp::AudioStream) => {
439 let [f] = get_args(args)?;
440 self.handle_args_outputs(
441 f.args().saturating_sub(1),
442 f.outputs().saturating_sub(1),
443 );
444 }
445 prim if prim.modifier_args().is_some() => {
446 if let Some(sig) = prim.sig() {
447 self.handle_sig(sig);
448 } else {
449 return Err(SigCheckError::from(format!(
450 "{} was not checked. This is a bug in the interpreter",
451 prim.format()
452 )));
453 }
454 }
455 prim => {
456 return Err(SigCheckError::from(format!(
457 "{} was checked as a modifier. This is a bug in the interpreter",
458 prim.format()
459 )));
460 }
461 },
462 Node::ImplMod(prim, args, _) => match prim {
463 &OnSub(n) | &BySub(n) | &WithSub(n) | &OffSub(n) => {
464 let [sn] = get_args_nodes(args)?;
465 let args = sn.sig.args().max(n);
466 self.handle_args_outputs(args, args);
467 self.sig_node(sn)?;
468 self.handle_args_outputs(0, n);
469 }
470 &DipN(n) => {
471 let [mut sig] = get_args(args)?;
472 sig.update_args_outputs(|a, o| (a + n, o + n));
473 self.handle_sig(sig);
474 }
475 ReduceContent | ReduceDepth(_) => {
476 let [sig] = get_args(args)?;
477 let args = sig.args().saturating_sub(sig.outputs());
478 self.handle_args_outputs(args, sig.outputs());
479 }
480 ReduceConjoinInventory => {
481 let [sig] = get_args(args)?;
482 self.handle_sig(sig);
483 }
484 RepeatWithInverse => {
485 let [f, inv] = get_args_nodes(args)?;
486 if f.sig.inverse() != inv.sig {
487 return Err(SigCheckError::from(
488 "repeat inverse does not have inverse signature",
489 ));
490 }
491 self.pop();
492 self.repeat(f)?;
493 }
494 RepeatCountConvergence => {
495 let [f] = get_args_nodes(args)?;
496 self.repeat(f)?;
497 self.push();
498 }
499 UnFill | SidedFill(_) => self.fill(args)?,
500 UnBracket => {
501 let (args, outputs) = args.iter().fold((0, 0), |(a, o), sn| {
502 (a + sn.sig.args(), o + sn.sig.outputs())
503 });
504 self.handle_args_outputs(args, outputs);
505 }
506 BothImpl(sub) | UnBothImpl(sub) => {
507 let [f] = get_args(args)?;
508 let reused = sub.side.map(|side| side.n.unwrap_or(1)).unwrap_or(0);
509 let n = sub.num.unwrap_or(2) as usize;
510 let unique = f.args().saturating_sub(reused) * n;
511 let sig = Signature::new(unique + reused, n * f.outputs())
512 .with_under(n * f.under_args(), n * f.under_outputs());
513 self.handle_sig(sig);
514 }
515 EachSub(_) => {
516 let [f] = get_args_nodes(args)?;
517 self.sig_node(f)?;
518 }
519 RowsSub(sub, _) => {
520 let [mut f] = get_args(args)?;
521 f.update_args_outputs(|a, o| {
522 let new_a = a.max(sub.side.and_then(|side| side.n).unwrap_or(0));
523 let new_o = o + new_a - a;
524 (new_a, new_o)
525 });
526 self.handle_sig(f);
527 }
528 UndoRows | UndoInventory => {
529 let [f] = get_args_nodes(args)?;
530 self.stack.pop();
531 self.sig_node(f)?;
532 }
533 UnScan => self.handle_args_outputs(1, 1),
534 SplitBy | SplitByScalar | SplitByKeepEmpty => {
535 let [f] = get_args(args)?;
536 self.handle_args_outputs(2, f.outputs());
537 }
538 prim => {
539 let args = prim
540 .args()
541 .ok_or_else(|| format!("{prim} has indeterminate args"))?;
542 let outputs = prim
543 .outputs()
544 .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
545 for _ in 0..args {
546 self.pop();
547 }
548 for _ in 0..outputs {
549 self.push();
550 }
551 }
552 },
553 Node::SetOutputComment { .. } => {}
554 Node::ValidateType { .. } => self.handle_args_outputs(1, 1),
555 Node::PushUnder(n, _) => {
556 for _ in 0..*n {
557 self.stack.pop();
558 self.under.push();
559 }
560 }
561 Node::CopyToUnder(n, _) => {
562 for _ in 0..*n {
563 self.stack.pop();
564 self.under.push();
565 self.stack.push();
566 }
567 }
568 Node::PopUnder(n, _) => {
569 for _ in 0..*n {
570 self.under.pop();
571 self.stack.push();
572 }
573 }
574 Node::TrackCaller(inner) | Node::NoInline(inner) => self.node(inner)?,
575 }
576 self.node_depth -= 1;
577 Ok(())
579 }
580 fn push(&mut self) {
581 self.stack.push();
582 }
583 fn pop(&mut self) {
584 self.stack.pop()
585 }
586 fn handle_args_outputs(&mut self, args: usize, outputs: usize) {
587 self.stack.handle_args_outputs(args, outputs);
588 }
589 fn handle_sig(&mut self, sig: Signature) {
590 self.stack.handle_args_outputs(sig.args(), sig.outputs());
591 self.under
592 .handle_args_outputs(sig.under_args(), sig.under_outputs());
593 }
594 fn fill(&mut self, args: &[SigNode]) -> Result<(), SigCheckError> {
595 let [fill, f] = get_args_nodes(args)?;
596 if fill.sig.outputs() > 0 || fill.sig.args() > 0 && fill.sig.outputs() != 0 {
597 self.sig_node(fill)?;
598 }
599 self.handle_args_outputs(fill.sig.outputs(), 0);
600 self.sig_node(f)
601 }
602 fn repeat(&mut self, sn: &SigNode) -> Result<(), SigCheckError> {
603 let sig = sn.sig;
604 self.sig_node(sn)?;
605 if sig.outputs() > sig.args() {
606 self.stack.pop_n(sig.args());
607 }
608 Ok(())
609 }
610}
611
612fn get_args_nodes<const N: usize>(args: &[SigNode]) -> Result<[&SigNode; N], SigCheckError> {
613 if args.len() != N {
614 return Err(format!(
615 "Expected {} operand{}, but got {}",
616 N,
617 if N == 1 { "" } else { "s" },
618 args.len()
619 )
620 .into());
621 }
622 Ok(array::from_fn(|i| &args[i]))
623}
624
625fn get_args<const N: usize>(args: &[SigNode]) -> Result<[Signature; N], SigCheckError> {
626 let mut res = [Signature::default(); N];
627 if args.len() != N {
628 return Err(format!(
629 "Expected {} operand{}, but got {}",
630 N,
631 if N == 1 { "" } else { "s" },
632 args.len()
633 )
634 .into());
635 }
636 for (i, arg) in args.iter().enumerate() {
637 res[i] = arg.sig;
638 }
639 Ok(res)
640}