Skip to main content

world_id_primitives/request/
constraints.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3
/// Maximum number of nodes a constraint tree may contain.
///
/// Both expression containers (`all` / `any` / `enumerate`) and leaf issuer
/// schema ids count toward this limit. Keeping it small bounds the amount of
/// work an oversized request body can cause.
pub const MAX_CONSTRAINT_NODES: usize = 12;
7
8/// Logical operator kinds supported in constraint expressions.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ConstraintKind {
12    /// All of the children must be satisfied
13    All,
14    /// Any of the children must be satisfied
15    Any,
16    /// All satisfiable children should be selected. Will fail if there are no matches
17    Enumerate,
18}
19
20/// Constraint expression tree: a list of types/expressions under `all`, `any`, or `enumerate`.
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(deny_unknown_fields)]
23#[serde(untagged)]
24pub enum ConstraintExpr<'a> {
25    /// All children must be satisfied
26    All {
27        /// Children nodes that must all be satisfied
28        all: Vec<ConstraintNode<'a>>,
29    },
30    /// Any child may satisfy the expression
31    Any {
32        /// Children nodes where any one must be satisfied
33        any: Vec<ConstraintNode<'a>>,
34    },
35    /// All satisfiable children should be selected
36    Enumerate {
37        /// Children nodes to evaluate and collect if satisfiable
38        enumerate: Vec<ConstraintNode<'a>>,
39    },
40}
41
42/// Node of a constraint expression.
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(deny_unknown_fields)]
45#[serde(untagged)]
46pub enum ConstraintNode<'a> {
47    /// Issuer schema id string
48    Type(Cow<'a, str>),
49    /// Expressions
50    Expr(ConstraintExpr<'a>),
51}
52
53impl ConstraintExpr<'_> {
54    /// Evaluate the constraint against a predicate that reports whether a issuer schema id was provided successfully
55    pub fn evaluate<F>(&self, has_type: &F) -> bool
56    where
57        F: Fn(&str) -> bool,
58    {
59        match self {
60            ConstraintExpr::All { all } => all.iter().all(|n| n.evaluate(has_type)),
61            ConstraintExpr::Any { any } => any.iter().any(|n| n.evaluate(has_type)),
62            ConstraintExpr::Enumerate { enumerate } => {
63                enumerate.iter().any(|n| n.evaluate(has_type))
64            }
65        }
66    }
67
68    /// Validate the maximum nesting depth. Depth counts the number of Expr nodes encountered.
69    /// A flat list has depth 1. Allow at most 2 (one nested level under root).
70    #[must_use]
71    pub fn validate_max_depth(&self, max_depth: usize) -> bool {
72        fn validate_expr(expr: &ConstraintExpr<'_>, depth: usize, max_depth: usize) -> bool {
73            if depth > max_depth {
74                return false;
75            }
76            match expr {
77                ConstraintExpr::All { all } => {
78                    all.iter().all(|n| validate_node(n, depth, max_depth))
79                }
80                ConstraintExpr::Any { any } => {
81                    any.iter().all(|n| validate_node(n, depth, max_depth))
82                }
83                ConstraintExpr::Enumerate { enumerate } => {
84                    enumerate.iter().all(|n| validate_node(n, depth, max_depth))
85                }
86            }
87        }
88        fn validate_node(node: &ConstraintNode<'_>, parent_depth: usize, max_depth: usize) -> bool {
89            match node {
90                ConstraintNode::Type(_) => true,
91                ConstraintNode::Expr(child) => validate_expr(child, parent_depth + 1, max_depth),
92            }
93        }
94        validate_expr(self, 1, max_depth)
95    }
96
97    /// Validate the maximum total number of nodes in the constraint AST.
98    /// Counts both expression containers and leaf type nodes. Short-circuits
99    /// once the running total exceeds `max_nodes` to avoid full traversal.
100    #[must_use]
101    pub fn validate_max_nodes(&self, max_nodes: usize) -> bool {
102        fn count_expr(expr: &ConstraintExpr<'_>, count: &mut usize, max_nodes: usize) -> bool {
103            // Count the expr node itself
104            *count += 1;
105            if *count > max_nodes {
106                return false;
107            }
108            match expr {
109                ConstraintExpr::All { all } => {
110                    for n in all {
111                        if !count_node(n, count, max_nodes) {
112                            return false;
113                        }
114                    }
115                    true
116                }
117                ConstraintExpr::Any { any } => {
118                    for n in any {
119                        if !count_node(n, count, max_nodes) {
120                            return false;
121                        }
122                    }
123                    true
124                }
125                ConstraintExpr::Enumerate { enumerate } => {
126                    for n in enumerate {
127                        if !count_node(n, count, max_nodes) {
128                            return false;
129                        }
130                    }
131                    true
132                }
133            }
134        }
135
136        fn count_node(node: &ConstraintNode<'_>, count: &mut usize, max_nodes: usize) -> bool {
137            match node {
138                ConstraintNode::Type(_) => {
139                    *count += 1;
140                    *count <= max_nodes
141                }
142                ConstraintNode::Expr(child) => count_expr(child, count, max_nodes),
143            }
144        }
145
146        let mut count = 0;
147        count_expr(self, &mut count, max_nodes)
148    }
149}
150
151impl ConstraintNode<'_> {
152    fn evaluate<F>(&self, has_type: &F) -> bool
153    where
154        F: Fn(&str) -> bool,
155    {
156        match self {
157            ConstraintNode::Type(t) => has_type(t),
158            ConstraintNode::Expr(expr) => expr.evaluate(has_type),
159        }
160    }
161}