uiua/
check.rs

1//! Signature checker implementation
2
3use std::{
4    array,
5    cell::RefCell,
6    collections::HashMap,
7    fmt,
8    hash::{DefaultHasher, Hash, Hasher},
9    slice,
10};
11
12use serde::*;
13
14use crate::{ImplPrimitive, Node, Primitive, SigNode, Signature, SysOp};
15
16impl Node {
17    /// Get the signature of this node
18    pub fn sig(&self) -> Result<Signature, SigCheckError> {
19        nodes_all_sigs(self.as_slice())
20    }
21    /// Convert this node to a [`SigNode`]
22    pub fn sig_node(self) -> Result<SigNode, SigCheckError> {
23        let sig = self.sig()?;
24        Ok(SigNode::new(sig, self.clone()))
25    }
26    /// Get the signature of this node if there is no net temp stack change
27    pub fn clean_sig(&self) -> Option<Signature> {
28        nodes_clean_sig(slice::from_ref(self))
29    }
30}
31
32pub fn nodes_sig(nodes: &[Node]) -> Result<Signature, SigCheckError> {
33    nodes_all_sigs(nodes)
34}
35
36pub fn nodes_clean_sig(nodes: &[Node]) -> Option<Signature> {
37    let sig = nodes_all_sigs(nodes).ok()?;
38    if sig.under_args() != 0 || sig.under_outputs() != 0 {
39        None
40    } else {
41        Some(sig)
42    }
43}
44
45fn nodes_all_sigs(nodes: &[Node]) -> Result<Signature, SigCheckError> {
46    type AllSigsCache = HashMap<u64, Signature>;
47    thread_local! {
48        static CACHE: RefCell<AllSigsCache> = RefCell::new(AllSigsCache::new());
49    }
50    let mut hasher = DefaultHasher::new();
51    nodes.hash(&mut hasher);
52    let hash = hasher.finish();
53    CACHE.with(|cache| {
54        if let Some(sigs) = cache.borrow().get(&hash) {
55            return Ok(*sigs);
56        }
57        let env = VirtualEnv::from_nodes(nodes)?;
58        let under_sig = env.under.sig();
59        let sig = (env.stack.sig()).with_under(under_sig.args(), under_sig.outputs());
60        cache.borrow_mut().insert(hash, sig);
61        Ok(sig)
62    })
63}
64
65/// An environment that emulates the runtime but only keeps track of the stack.
66struct VirtualEnv {
67    stack: Stack,
68    under: Stack,
69    node_depth: usize,
70}
71
72#[derive(Debug, Default)]
73struct Stack {
74    height: i32,
75    min_height: usize,
76}
77
78impl Stack {
79    // Simulate popping a value. Errors if the stack is empty, which means the function has too many args.
80    fn pop(&mut self) {
81        self.height -= 1;
82        self.set_min_height();
83    }
84    fn pop_n(&mut self, n: usize) {
85        self.height -= n as i32;
86        self.set_min_height();
87    }
88    fn push(&mut self) {
89        self.height += 1;
90    }
91    fn push_n(&mut self, n: usize) {
92        self.height += n as i32;
93    }
94    fn handle_args_outputs(&mut self, args: usize, outputs: usize) {
95        self.pop_n(args);
96        self.push_n(outputs);
97    }
98    /// Set the current stack height as a potential minimum.
99    /// At the end of checking, the minimum stack height is a component in calculating the signature.
100    fn set_min_height(&mut self) {
101        self.min_height = self.min_height.max((-self.height).max(0) as usize);
102    }
103    fn sig(&self) -> Signature {
104        Signature::new(
105            self.min_height,
106            (self.height + self.min_height as i32).max(0) as usize,
107        )
108    }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
112pub struct SigCheckError {
113    pub message: String,
114    pub kind: SigCheckErrorKind,
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
118pub enum SigCheckErrorKind {
119    Incorrect,
120    NoInverse,
121}
122
123impl SigCheckError {
124    pub fn no_inverse(self) -> Self {
125        Self {
126            kind: SigCheckErrorKind::NoInverse,
127            ..self
128        }
129    }
130}
131
132impl<'a> From<&'a str> for SigCheckError {
133    fn from(s: &'a str) -> Self {
134        Self {
135            message: s.to_string(),
136            kind: SigCheckErrorKind::Incorrect,
137        }
138    }
139}
140
141impl From<String> for SigCheckError {
142    fn from(s: String) -> Self {
143        Self {
144            message: s,
145            kind: SigCheckErrorKind::Incorrect,
146        }
147    }
148}
149
150impl fmt::Display for SigCheckError {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        self.message.fmt(f)
153    }
154}
155
/// Maximum node nesting depth before checking fails with "Function is too complex".
///
/// Debug builds use a lower limit, presumably because unoptimized recursion uses more stack
/// per frame (NOTE(review): confirm).
#[cfg(debug_assertions)]
const MAX_NODE_DEPTH: usize = 26;
/// Maximum node nesting depth before checking fails with "Function is too complex".
#[cfg(not(debug_assertions))]
const MAX_NODE_DEPTH: usize = 50;
157
158impl VirtualEnv {
159    fn from_nodes(nodes: &[Node]) -> Result<Self, SigCheckError> {
160        // println!("\ncheck sig: {nodes:?}");
161        let mut env = VirtualEnv {
162            stack: Stack::default(),
163            under: Stack::default(),
164            node_depth: 0,
165        };
166        env.nodes(nodes)?;
167        Ok(env)
168    }
169    fn nodes(&mut self, nodes: &[Node]) -> Result<(), SigCheckError> {
170        nodes.iter().try_for_each(|node| self.node(node))
171    }
172    fn sig_node(&mut self, sn: &SigNode) -> Result<(), SigCheckError> {
173        self.node(&sn.node)
174    }
175    fn node(&mut self, node: &Node) -> Result<(), SigCheckError> {
176        use {ImplPrimitive::*, Primitive::*};
177        if self.node_depth > MAX_NODE_DEPTH {
178            return Err("Function is too complex".into());
179        }
180        self.node_depth += 1;
181        match node {
182            Node::Run(nodes) => nodes.iter().try_for_each(|node| self.node(node))?,
183            Node::Push(_) => self.push(),
184            Node::Array { len, inner, .. } => {
185                self.node(inner)?;
186                self.stack.pop_n(*len);
187                self.stack.push();
188            }
189            Node::Label(..) | Node::RemoveLabel(..) => self.handle_args_outputs(1, 1),
190            Node::Call(func, _) => self.handle_sig(func.sig),
191            Node::CallMacro { sig, .. } | Node::CallGlobal(_, sig) => self.handle_sig(*sig),
192            Node::BindGlobal { .. } => self.handle_args_outputs(1, 0),
193            Node::CustomInverse(cust, _) => self.handle_sig(cust.sig()?),
194            Node::Dynamic(dy) => self.handle_sig(dy.sig),
195            &Node::Switch {
196                sig, under_cond, ..
197            } => {
198                self.pop();
199                self.handle_sig(sig);
200                if under_cond {
201                    self.under.push();
202                }
203            }
204            Node::Format(parts, ..) => self.handle_args_outputs(parts.len().saturating_sub(1), 1),
205            Node::MatchFormatPattern(parts, ..) => {
206                self.handle_args_outputs(1, parts.len().saturating_sub(1))
207            }
208            Node::Unpack { count, .. } => self.handle_args_outputs(1, *count),
209            Node::ImplMod(Astar, args, _)
210            | Node::ImplMod(AstarFirst, args, _)
211            | Node::ImplMod(AstarSignLen, args, _) => {
212                self.pop();
213                let [neighbors, heuristic, is_goal] = get_args(args)?;
214                let has_costs = neighbors.outputs() == 2;
215                let args = neighbors
216                    .args()
217                    .max(heuristic.args())
218                    .max(is_goal.args())
219                    .saturating_sub(1);
220                self.handle_args_outputs(args, 1 + has_costs as usize);
221            }
222            Node::ImplMod(AstarTake, args, _) => {
223                self.pop();
224                self.pop();
225                let [neighbors, heuristic, is_goal] = get_args(args)?;
226                let has_costs = neighbors.outputs() == 2;
227                let args = neighbors
228                    .args()
229                    .max(heuristic.args())
230                    .max(is_goal.args())
231                    .saturating_sub(1);
232                self.handle_args_outputs(args, 1 + has_costs as usize);
233            }
234            Node::ImplMod(AstarPop, args, _) => {
235                self.pop();
236                let [neighbors, heuristic, is_goal] = get_args(args)?;
237                let has_costs = neighbors.outputs() == 2;
238                let args = neighbors
239                    .args()
240                    .max(heuristic.args())
241                    .max(is_goal.args())
242                    .saturating_sub(1);
243                self.handle_args_outputs(args, has_costs as usize);
244            }
245            Node::Mod(Path, args, _)
246            | Node::ImplMod(PathFirst, args, _)
247            | Node::ImplMod(PathSignLen, args, _) => {
248                self.pop();
249                let [neighbors, is_goal] = get_args(args)?;
250                let has_costs = neighbors.outputs() == 2;
251                let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
252                self.handle_args_outputs(args, 1 + has_costs as usize);
253            }
254            Node::ImplMod(PathTake, args, _) => {
255                self.pop();
256                self.pop();
257                let [neighbors, is_goal] = get_args(args)?;
258                let has_costs = neighbors.outputs() == 2;
259                let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
260                self.handle_args_outputs(args, 1 + has_costs as usize);
261            }
262            Node::ImplMod(PathPop, args, _) => {
263                self.pop();
264                let [neighbors, is_goal] = get_args(args)?;
265                let has_costs = neighbors.outputs() == 2;
266                let args = neighbors.args().max(is_goal.args()).saturating_sub(1);
267                self.handle_args_outputs(args, has_costs as usize);
268            }
269            Node::Prim(prim, _) => {
270                let args = prim
271                    .args()
272                    .ok_or_else(|| format!("{prim} has indeterminate args"))?;
273                let outputs = prim
274                    .outputs()
275                    .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
276                self.handle_args_outputs(args, outputs);
277            }
278            Node::ImplPrim(prim, _) => {
279                let args = prim
280                    .args()
281                    .ok_or_else(|| format!("{prim} has indeterminate args"))?;
282                let outputs = prim
283                    .outputs()
284                    .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
285                self.handle_args_outputs(args, outputs);
286            }
287            Node::Mod(prim, args, _) => match prim {
288                Reduce | Scan => {
289                    let [sig] = get_args(args)?;
290                    let args = sig.args().saturating_sub(sig.outputs()).max(1);
291                    self.handle_args_outputs(args, sig.outputs());
292                }
293                Each | Rows | Inventory => {
294                    let [f] = get_args_nodes(args)?;
295                    self.sig_node(f)?;
296                }
297                Table | Tuples => {
298                    let [sig] = get_args(args)?;
299                    self.handle_sig(sig);
300                }
301                Stencil => {
302                    let [sig] = get_args(args)?;
303                    if sig.args() <= 1 {
304                        self.pop();
305                    }
306                    self.handle_args_outputs(1, sig.outputs());
307                }
308                Group | Partition => {
309                    let [sig] = get_args(args)?;
310                    self.handle_args_outputs(sig.args().max(1) + 1, sig.outputs());
311                }
312                Spawn | Pool => {
313                    let [sig] = get_args(args)?;
314                    self.handle_args_outputs(sig.args(), 1);
315                }
316                Repeat => {
317                    let [f] = get_args_nodes(args)?;
318                    self.pop();
319                    self.repeat(f)?;
320                }
321                Do => {
322                    let [body, cond] = get_args(args)?;
323                    let copy_count = cond.args().saturating_sub(cond.outputs().saturating_sub(1));
324                    let cond_sub_sig = Signature::new(
325                        cond.args(),
326                        (cond.outputs() + copy_count).saturating_sub(1),
327                    );
328                    let comp_sig = body.compose(cond_sub_sig);
329                    self.handle_args_outputs(
330                        comp_sig.args(),
331                        comp_sig.outputs() + cond_sub_sig.outputs().saturating_sub(cond.args()),
332                    );
333                    if comp_sig.args() < comp_sig.outputs() {
334                        self.stack.pop_n(comp_sig.args());
335                    }
336                }
337                Un => {
338                    let [sig] = get_args(args)?;
339                    self.handle_sig(sig.inverse());
340                }
341                Anti => {
342                    let [sig] = get_args(args)?;
343                    self.handle_sig(sig.anti().unwrap_or(sig));
344                }
345                Fold => {
346                    let [f] = get_args(args)?;
347                    if f == (0, 0) {
348                    } else if f.outputs() >= f.args() {
349                        self.handle_args_outputs(f.args(), f.outputs() + 1 - f.args());
350                    } else {
351                        self.handle_sig(f);
352                    }
353                }
354                Try => {
355                    let [mut f_sig, handler_sig] = get_args(args)?;
356                    f_sig.update_outputs(|o| o.max(handler_sig.outputs()));
357                    self.handle_sig(f_sig);
358                }
359                Case => {
360                    let [f] = get_args(args)?;
361                    self.handle_sig(f);
362                }
363                Fill => self.fill(args)?,
364                Content | Memo | Comptime => {
365                    let [f] = get_args(args)?;
366                    self.handle_sig(f);
367                }
368                Dump => {
369                    let [_] = get_args(args)?;
370                }
371                Fork => {
372                    let [f, g] = get_args(args)?;
373                    self.handle_args_outputs(f.args().max(g.args()), f.outputs() + g.outputs());
374                }
375                Bracket => {
376                    let (args, outputs) = args.iter().fold((0, 0), |(a, o), sn| {
377                        (a + sn.sig.args(), o + sn.sig.outputs())
378                    });
379                    self.handle_args_outputs(args, outputs);
380                }
381                Both => {
382                    let [f] = get_args_nodes(args)?;
383                    self.stack.pop_n(f.sig.args());
384                    self.sig_node(f)?;
385                    self.stack.push_n(f.sig.args());
386                    self.sig_node(f)?;
387                }
388                Dip => {
389                    let [f] = get_args_nodes(args)?;
390                    self.pop();
391                    self.sig_node(f)?;
392                    self.push();
393                }
394                Gap => {
395                    let [f] = get_args_nodes(args)?;
396                    self.pop();
397                    self.sig_node(f)?;
398                }
399                Reach => {
400                    let [f] = get_args_nodes(args)?;
401                    self.pop();
402                    self.pop();
403                    self.push();
404                    self.sig_node(f)?;
405                }
406                On => {
407                    let [f] = get_args_nodes(args)?;
408                    self.pop();
409                    self.push();
410                    self.sig_node(f)?;
411                    self.push();
412                }
413                By => {
414                    let [f] = get_args_nodes(args)?;
415                    self.sig_node(f)?;
416                    self.push();
417                }
418                Above | Below => {
419                    let [f] = get_args(args)?;
420                    self.handle_args_outputs(f.args(), f.args() + f.outputs());
421                }
422                With | Off => {
423                    let [f] = get_args(args)?;
424                    self.handle_args_outputs(f.args(), f.outputs() + 1);
425                }
426                Recur => {
427                    let [is_leaf, children, combine] = get_args(args)?;
428                    let args = is_leaf
429                        .args()
430                        .max(children.args())
431                        .max(combine.args().saturating_sub(1).max(1));
432                    self.handle_args_outputs(args, 1);
433                }
434                Sys(SysOp::ReadLines) => {
435                    let [f] = get_args(args)?;
436                    self.handle_sig(f);
437                }
438                Sys(SysOp::AudioStream) => {
439                    let [f] = get_args(args)?;
440                    self.handle_args_outputs(
441                        f.args().saturating_sub(1),
442                        f.outputs().saturating_sub(1),
443                    );
444                }
445                prim if prim.modifier_args().is_some() => {
446                    if let Some(sig) = prim.sig() {
447                        self.handle_sig(sig);
448                    } else {
449                        return Err(SigCheckError::from(format!(
450                            "{} was not checked. This is a bug in the interpreter",
451                            prim.format()
452                        )));
453                    }
454                }
455                prim => {
456                    return Err(SigCheckError::from(format!(
457                        "{} was checked as a modifier. This is a bug in the interpreter",
458                        prim.format()
459                    )));
460                }
461            },
462            Node::ImplMod(prim, args, _) => match prim {
463                &OnSub(n) | &BySub(n) | &WithSub(n) | &OffSub(n) => {
464                    let [sn] = get_args_nodes(args)?;
465                    let args = sn.sig.args().max(n);
466                    self.handle_args_outputs(args, args);
467                    self.sig_node(sn)?;
468                    self.handle_args_outputs(0, n);
469                }
470                &DipN(n) => {
471                    let [mut sig] = get_args(args)?;
472                    sig.update_args_outputs(|a, o| (a + n, o + n));
473                    self.handle_sig(sig);
474                }
475                ReduceContent | ReduceDepth(_) => {
476                    let [sig] = get_args(args)?;
477                    let args = sig.args().saturating_sub(sig.outputs());
478                    self.handle_args_outputs(args, sig.outputs());
479                }
480                ReduceConjoinInventory => {
481                    let [sig] = get_args(args)?;
482                    self.handle_sig(sig);
483                }
484                RepeatWithInverse => {
485                    let [f, inv] = get_args_nodes(args)?;
486                    if f.sig.inverse() != inv.sig {
487                        return Err(SigCheckError::from(
488                            "repeat inverse does not have inverse signature",
489                        ));
490                    }
491                    self.pop();
492                    self.repeat(f)?;
493                }
494                RepeatCountConvergence => {
495                    let [f] = get_args_nodes(args)?;
496                    self.repeat(f)?;
497                    self.push();
498                }
499                UnFill | SidedFill(_) => self.fill(args)?,
500                UnBracket => {
501                    let (args, outputs) = args.iter().fold((0, 0), |(a, o), sn| {
502                        (a + sn.sig.args(), o + sn.sig.outputs())
503                    });
504                    self.handle_args_outputs(args, outputs);
505                }
506                BothImpl(sub) | UnBothImpl(sub) => {
507                    let [f] = get_args(args)?;
508                    let reused = sub.side.map(|side| side.n.unwrap_or(1)).unwrap_or(0);
509                    let n = sub.num.unwrap_or(2) as usize;
510                    let unique = f.args().saturating_sub(reused) * n;
511                    let sig = Signature::new(unique + reused, n * f.outputs())
512                        .with_under(n * f.under_args(), n * f.under_outputs());
513                    self.handle_sig(sig);
514                }
515                EachSub(_) => {
516                    let [f] = get_args_nodes(args)?;
517                    self.sig_node(f)?;
518                }
519                RowsSub(sub, _) => {
520                    let [mut f] = get_args(args)?;
521                    f.update_args_outputs(|a, o| {
522                        let new_a = a.max(sub.side.and_then(|side| side.n).unwrap_or(0));
523                        let new_o = o + new_a - a;
524                        (new_a, new_o)
525                    });
526                    self.handle_sig(f);
527                }
528                UndoRows | UndoInventory => {
529                    let [f] = get_args_nodes(args)?;
530                    self.stack.pop();
531                    self.sig_node(f)?;
532                }
533                UnScan => self.handle_args_outputs(1, 1),
534                SplitBy | SplitByScalar | SplitByKeepEmpty => {
535                    let [f] = get_args(args)?;
536                    self.handle_args_outputs(2, f.outputs());
537                }
538                prim => {
539                    let args = prim
540                        .args()
541                        .ok_or_else(|| format!("{prim} has indeterminate args"))?;
542                    let outputs = prim
543                        .outputs()
544                        .ok_or_else(|| format!("{prim} has indeterminate outputs"))?;
545                    for _ in 0..args {
546                        self.pop();
547                    }
548                    for _ in 0..outputs {
549                        self.push();
550                    }
551                }
552            },
553            Node::SetOutputComment { .. } => {}
554            Node::ValidateType { .. } => self.handle_args_outputs(1, 1),
555            Node::PushUnder(n, _) => {
556                for _ in 0..*n {
557                    self.stack.pop();
558                    self.under.push();
559                }
560            }
561            Node::CopyToUnder(n, _) => {
562                for _ in 0..*n {
563                    self.stack.pop();
564                    self.under.push();
565                    self.stack.push();
566                }
567            }
568            Node::PopUnder(n, _) => {
569                for _ in 0..*n {
570                    self.under.pop();
571                    self.stack.push();
572                }
573            }
574            Node::TrackCaller(inner) | Node::NoInline(inner) => self.node(inner)?,
575        }
576        self.node_depth -= 1;
577        // println!("{node:?} -> {} ({})", self.stack.sig(), self.under.sig());
578        Ok(())
579    }
580    fn push(&mut self) {
581        self.stack.push();
582    }
583    fn pop(&mut self) {
584        self.stack.pop()
585    }
586    fn handle_args_outputs(&mut self, args: usize, outputs: usize) {
587        self.stack.handle_args_outputs(args, outputs);
588    }
589    fn handle_sig(&mut self, sig: Signature) {
590        self.stack.handle_args_outputs(sig.args(), sig.outputs());
591        self.under
592            .handle_args_outputs(sig.under_args(), sig.under_outputs());
593    }
594    fn fill(&mut self, args: &[SigNode]) -> Result<(), SigCheckError> {
595        let [fill, f] = get_args_nodes(args)?;
596        if fill.sig.outputs() > 0 || fill.sig.args() > 0 && fill.sig.outputs() != 0 {
597            self.sig_node(fill)?;
598        }
599        self.handle_args_outputs(fill.sig.outputs(), 0);
600        self.sig_node(f)
601    }
602    fn repeat(&mut self, sn: &SigNode) -> Result<(), SigCheckError> {
603        let sig = sn.sig;
604        self.sig_node(sn)?;
605        if sig.outputs() > sig.args() {
606            self.stack.pop_n(sig.args());
607        }
608        Ok(())
609    }
610}
611
612fn get_args_nodes<const N: usize>(args: &[SigNode]) -> Result<[&SigNode; N], SigCheckError> {
613    if args.len() != N {
614        return Err(format!(
615            "Expected {} operand{}, but got {}",
616            N,
617            if N == 1 { "" } else { "s" },
618            args.len()
619        )
620        .into());
621    }
622    Ok(array::from_fn(|i| &args[i]))
623}
624
625fn get_args<const N: usize>(args: &[SigNode]) -> Result<[Signature; N], SigCheckError> {
626    let mut res = [Signature::default(); N];
627    if args.len() != N {
628        return Err(format!(
629            "Expected {} operand{}, but got {}",
630            N,
631            if N == 1 { "" } else { "s" },
632            args.len()
633        )
634        .into());
635    }
636    for (i, arg) in args.iter().enumerate() {
637        res[i] = arg.sig;
638    }
639    Ok(res)
640}