world_id_core/requests/
constraints.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3
/// Maximum number of nodes permitted in a constraint AST, counting both
/// expression containers and type leaves.
///
/// Capping the node count keeps extremely large request bodies from causing
/// excessive work.
pub const MAX_CONSTRAINT_NODES: usize = 12;
7
8/// Logical operator kinds supported in constraint expressions.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ConstraintKind {
12    /// All of the children must be satisfied
13    All,
14    /// Any of the children must be satisfied
15    Any,
16}
17
18/// Constraint expression tree: either a list of types/expressions under `all` or `any`.
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(deny_unknown_fields)]
21#[serde(untagged)]
22pub enum ConstraintExpr<'a> {
23    /// All children must be satisfied
24    All {
25        /// Children nodes that must all be satisfied
26        all: Vec<ConstraintNode<'a>>,
27    },
28    /// Any child may satisfy the expression
29    Any {
30        /// Children nodes where any one must be satisfied
31        any: Vec<ConstraintNode<'a>>,
32    },
33}
34
35/// Node of a constraint expression.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37#[serde(deny_unknown_fields)]
38#[serde(untagged)]
39pub enum ConstraintNode<'a> {
40    /// Issuer schema id string
41    Type(Cow<'a, str>),
42    /// Expressions
43    Expr(ConstraintExpr<'a>),
44}
45
46impl ConstraintExpr<'_> {
47    /// Evaluate the constraint against a predicate that reports whether a issuer schema id was provided successfully
48    pub fn evaluate<F>(&self, has_type: &F) -> bool
49    where
50        F: Fn(&str) -> bool,
51    {
52        match self {
53            ConstraintExpr::All { all } => all.iter().all(|n| n.evaluate(has_type)),
54            ConstraintExpr::Any { any } => any.iter().any(|n| n.evaluate(has_type)),
55        }
56    }
57
58    /// Validate the maximum nesting depth. Depth counts the number of Expr nodes encountered.
59    /// A flat list has depth 1. Allow at most 2 (one nested level under root).
60    #[must_use]
61    pub fn validate_max_depth(&self, max_depth: usize) -> bool {
62        fn validate_expr(expr: &ConstraintExpr<'_>, depth: usize, max_depth: usize) -> bool {
63            if depth > max_depth {
64                return false;
65            }
66            match expr {
67                ConstraintExpr::All { all } => {
68                    all.iter().all(|n| validate_node(n, depth, max_depth))
69                }
70                ConstraintExpr::Any { any } => {
71                    any.iter().all(|n| validate_node(n, depth, max_depth))
72                }
73            }
74        }
75        fn validate_node(node: &ConstraintNode<'_>, parent_depth: usize, max_depth: usize) -> bool {
76            match node {
77                ConstraintNode::Type(_) => true,
78                ConstraintNode::Expr(child) => validate_expr(child, parent_depth + 1, max_depth),
79            }
80        }
81        validate_expr(self, 1, max_depth)
82    }
83
84    /// Validate the maximum total number of nodes in the constraint AST.
85    /// Counts both expression containers and leaf type nodes. Short-circuits
86    /// once the running total exceeds `max_nodes` to avoid full traversal.
87    #[must_use]
88    pub fn validate_max_nodes(&self, max_nodes: usize) -> bool {
89        fn count_expr(expr: &ConstraintExpr<'_>, count: &mut usize, max_nodes: usize) -> bool {
90            // Count the expr node itself
91            *count += 1;
92            if *count > max_nodes {
93                return false;
94            }
95            match expr {
96                ConstraintExpr::All { all } => {
97                    for n in all {
98                        if !count_node(n, count, max_nodes) {
99                            return false;
100                        }
101                    }
102                    true
103                }
104                ConstraintExpr::Any { any } => {
105                    for n in any {
106                        if !count_node(n, count, max_nodes) {
107                            return false;
108                        }
109                    }
110                    true
111                }
112            }
113        }
114
115        fn count_node(node: &ConstraintNode<'_>, count: &mut usize, max_nodes: usize) -> bool {
116            match node {
117                ConstraintNode::Type(_) => {
118                    *count += 1;
119                    *count <= max_nodes
120                }
121                ConstraintNode::Expr(child) => count_expr(child, count, max_nodes),
122            }
123        }
124
125        let mut count = 0;
126        count_expr(self, &mut count, max_nodes)
127    }
128}
129
130impl ConstraintNode<'_> {
131    fn evaluate<F>(&self, has_type: &F) -> bool
132    where
133        F: Fn(&str) -> bool,
134    {
135        match self {
136            ConstraintNode::Type(t) => has_type(t),
137            ConstraintNode::Expr(expr) => expr.evaluate(has_type),
138        }
139    }
140}