Skip to main content

tiptap_rusty_parser/
schema.rs

//! Opt-in schema validation.
//!
//! The crate is schema-agnostic by default; nothing here runs unless you call
//! [`Node::validate`]. A [`Schema`] is an allow-list of node types, marks,
//! attributes, and child types. Validation collects *all* problems as
//! [`Violation`]s (each carrying the offending node's index path), so a single
//! pass reports everything wrong.
//!
//! A schema can be built in Rust or loaded from JSON:
//!
//! ```
//! use tiptap_rusty_parser::{Document, Schema, NodeSpec, MarkSpec};
//!
//! let schema = Schema::new()
//!     .node("doc", NodeSpec::new().content(["paragraph"]))
//!     .node("paragraph", NodeSpec::new().content(["text"]).marks(["bold"]))
//!     .node("text", NodeSpec::new())
//!     .mark("bold", MarkSpec::new());
//!
//! let doc = Document::from_json_str(
//!     r#"{"type":"doc","content":[{"type":"paragraph","content":[{"type":"text","text":"hi"}]}]}"#,
//! ).unwrap();
//! assert!(doc.is_valid(&schema));
//! ```
25
26use crate::content::{ContentExpr, ContentRule, ParseExprError};
27use crate::node::Node;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30use std::fmt;
31
32/// An allow-list schema: which node/mark types, attributes, and children are
33/// permitted. Build with [`Schema::new`] or load with [`Schema::from_json_str`].
34#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
35pub struct Schema {
36    /// Node type -> its spec. Types absent here are reported as unknown.
37    #[serde(default)]
38    pub nodes: HashMap<String, NodeSpec>,
39    /// Mark type -> its spec. Marks absent here are reported as unknown.
40    #[serde(default)]
41    pub marks: HashMap<String, MarkSpec>,
42}
43
44/// Rules for one node type.
45#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
46pub struct NodeSpec {
47    /// Allowed content: a set of child types (array form) or a content
48    /// expression (string form). `None` = any child allowed.
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub content: Option<ContentRule>,
51    /// Groups this node belongs to (space-separated, ProseMirror-style), so
52    /// content expressions can reference it by group name (e.g. `block`).
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub group: Option<String>,
55    /// Allowed mark types on this node. `None` = any mark allowed.
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub marks: Option<HashSet<String>>,
58    /// Allowed attribute keys. `None` = any attrs allowed.
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub attrs: Option<HashSet<String>>,
61    /// Attribute keys that must be present.
62    #[serde(default, skip_serializing_if = "HashSet::is_empty")]
63    pub required_attrs: HashSet<String>,
64}
65
66/// Rules for one mark type.
67#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
68pub struct MarkSpec {
69    /// Allowed attribute keys. `None` = any attrs allowed.
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub attrs: Option<HashSet<String>>,
72    /// Attribute keys that must be present.
73    #[serde(default, skip_serializing_if = "HashSet::is_empty")]
74    pub required_attrs: HashSet<String>,
75}
76
/// Gather any iterable of string-like items into an owned `HashSet<String>`.
///
/// Duplicates collapse into a single entry, as with any set.
fn into_set<I, S>(items: I) -> HashSet<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut set = HashSet::new();
    for item in items {
        set.insert(item.into());
    }
    set
}
84
85impl Schema {
86    /// An empty schema.
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Register (or replace) a node type's spec.
92    pub fn node(mut self, node_type: impl Into<String>, spec: NodeSpec) -> Self {
93        self.nodes.insert(node_type.into(), spec);
94        self
95    }
96
97    /// Register (or replace) a mark type's spec.
98    pub fn mark(mut self, mark_type: impl Into<String>, spec: MarkSpec) -> Self {
99        self.marks.insert(mark_type.into(), spec);
100        self
101    }
102
103    /// Load a schema from its JSON definition.
104    ///
105    /// ```
106    /// use tiptap_rusty_parser::Schema;
107    /// let schema = Schema::from_json_str(r#"{
108    ///   "nodes": { "doc": { "content": ["paragraph"] }, "paragraph": { "content": ["text"] }, "text": {} },
109    ///   "marks": { "link": { "attrs": ["href"], "required_attrs": ["href"] } }
110    /// }"#).unwrap();
111    /// assert!(schema.nodes.contains_key("doc"));
112    /// ```
113    pub fn from_json_str(s: &str) -> crate::Result<Self> {
114        Ok(serde_json::from_str(s)?)
115    }
116}
117
118impl NodeSpec {
119    /// An unrestricted node spec (any attrs/marks/children allowed).
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    /// Restrict allowed child node types (any count/order). For ordering and
125    /// cardinality use [`content_match`](Self::content_match).
126    pub fn content<I, S>(mut self, types: I) -> Self
127    where
128        I: IntoIterator<Item = S>,
129        S: Into<String>,
130    {
131        self.content = Some(ContentRule::Types(into_set(types)));
132        self
133    }
134
135    /// Restrict content with a ProseMirror content expression (e.g.
136    /// `"heading paragraph+"`). Panics on an invalid expression — use
137    /// [`try_content_match`](Self::try_content_match) to handle the error.
138    pub fn content_match(self, expr: &str) -> Self {
139        self.try_content_match(expr)
140            .expect("invalid content expression")
141    }
142
143    /// Fallible [`content_match`](Self::content_match).
144    pub fn try_content_match(mut self, expr: &str) -> Result<Self, ParseExprError> {
145        self.content = Some(ContentRule::Expr(ContentExpr::parse(expr)?));
146        Ok(self)
147    }
148
149    /// Set the groups this node belongs to (space-separated, e.g. `"block"`).
150    pub fn group(mut self, group: impl Into<String>) -> Self {
151        self.group = Some(group.into());
152        self
153    }
154
155    /// The allowed child-type set, if `content` is the array form (not an
156    /// expression). Convenience accessor for the `content` field.
157    pub fn content_types(&self) -> Option<&HashSet<String>> {
158        match &self.content {
159            Some(ContentRule::Types(set)) => Some(set),
160            _ => None,
161        }
162    }
163
164    /// Restrict allowed mark types on this node.
165    pub fn marks<I, S>(mut self, types: I) -> Self
166    where
167        I: IntoIterator<Item = S>,
168        S: Into<String>,
169    {
170        self.marks = Some(into_set(types));
171        self
172    }
173
174    /// Restrict allowed attribute keys.
175    pub fn attrs<I, S>(mut self, keys: I) -> Self
176    where
177        I: IntoIterator<Item = S>,
178        S: Into<String>,
179    {
180        self.attrs = Some(into_set(keys));
181        self
182    }
183
184    /// Set attribute keys that must be present.
185    pub fn required_attrs<I, S>(mut self, keys: I) -> Self
186    where
187        I: IntoIterator<Item = S>,
188        S: Into<String>,
189    {
190        self.required_attrs = into_set(keys);
191        self
192    }
193}
194
195impl MarkSpec {
196    /// An unrestricted mark spec (any attrs allowed).
197    pub fn new() -> Self {
198        Self::default()
199    }
200
201    /// Restrict allowed attribute keys.
202    pub fn attrs<I, S>(mut self, keys: I) -> Self
203    where
204        I: IntoIterator<Item = S>,
205        S: Into<String>,
206    {
207        self.attrs = Some(into_set(keys));
208        self
209    }
210
211    /// Set attribute keys that must be present.
212    pub fn required_attrs<I, S>(mut self, keys: I) -> Self
213    where
214        I: IntoIterator<Item = S>,
215        S: Into<String>,
216    {
217        self.required_attrs = into_set(keys);
218        self
219    }
220}
221
222/// A single schema violation, located by the offending node's index path.
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct Violation {
225    /// Index path to the node (root = `[]`), as used by [`Node::node_at`].
226    pub path: Vec<usize>,
227    /// What's wrong.
228    pub kind: ViolationKind,
229}
230
231/// The kinds of schema violation.
232#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
233pub enum ViolationKind {
234    /// A node had no `type`.
235    MissingNodeType,
236    /// A node's type is not in the schema.
237    UnknownNodeType(String),
238    /// A child type is not allowed under its parent.
239    DisallowedChild { parent: String, child: String },
240    /// A node's children don't satisfy its content expression.
241    InvalidContent { parent: String, expr: String },
242    /// A mark type is not in the schema.
243    UnknownMark(String),
244    /// A mark is registered but not allowed on this node type.
245    DisallowedMark { node: String, mark: String },
246    /// A required attribute is missing (on a node, or a mark of the node).
247    MissingAttr { key: String },
248    /// An attribute key is not in the allowed set.
249    UnknownAttr { key: String },
250}
251
252impl fmt::Display for ViolationKind {
253    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254        match self {
255            ViolationKind::MissingNodeType => write!(f, "node has no type"),
256            ViolationKind::UnknownNodeType(t) => write!(f, "unknown node type `{t}`"),
257            ViolationKind::DisallowedChild { parent, child } => {
258                write!(f, "node type `{child}` not allowed inside `{parent}`")
259            }
260            ViolationKind::InvalidContent { parent, expr } => {
261                write!(
262                    f,
263                    "children of `{parent}` do not match content expression `{expr}`"
264                )
265            }
266            ViolationKind::UnknownMark(m) => write!(f, "unknown mark type `{m}`"),
267            ViolationKind::DisallowedMark { node, mark } => {
268                write!(f, "mark `{mark}` not allowed on `{node}`")
269            }
270            ViolationKind::MissingAttr { key } => write!(f, "missing required attribute `{key}`"),
271            ViolationKind::UnknownAttr { key } => write!(f, "unknown attribute `{key}`"),
272        }
273    }
274}
275
276impl fmt::Display for Violation {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        write!(f, "at {:?}: {}", self.path, self.kind)
279    }
280}
281
282impl Node {
283    /// Validate against `schema`, collecting every [`Violation`]. An empty
284    /// result means the document is valid.
285    ///
286    /// ```
287    /// use tiptap_rusty_parser::{Document, Schema, NodeSpec};
288    /// let schema = Schema::new()
289    ///     .node("doc", NodeSpec::new().content(["paragraph"]))
290    ///     .node("paragraph", NodeSpec::new())
291    ///     .node("heading", NodeSpec::new());
292    /// let doc = Document::from_json_str(
293    ///     r#"{"type":"doc","content":[{"type":"heading"}]}"#,
294    /// ).unwrap();
295    /// let v = doc.validate(&schema);
296    /// assert_eq!(v.len(), 1); // heading is a known type, but not allowed as a child of doc
297    /// ```
298    pub fn validate(&self, schema: &Schema) -> Vec<Violation> {
299        let mut out = Vec::new();
300        let mut path = Vec::new();
301        validate_node(self, schema, &mut path, &mut out);
302        out
303    }
304
305    /// True if the document has no schema violations.
306    ///
307    /// ```
308    /// use tiptap_rusty_parser::{Document, Schema, NodeSpec};
309    /// let schema = Schema::new().node("doc", NodeSpec::new());
310    /// let doc = Document::from_json_str(r#"{"type":"doc"}"#).unwrap();
311    /// assert!(doc.is_valid(&schema));
312    /// ```
313    pub fn is_valid(&self, schema: &Schema) -> bool {
314        self.validate(schema).is_empty()
315    }
316}
317
318fn validate_node(node: &Node, schema: &Schema, path: &mut Vec<usize>, out: &mut Vec<Violation>) {
319    let push = |out: &mut Vec<Violation>, path: &[usize], kind: ViolationKind| {
320        out.push(Violation {
321            path: path.to_vec(),
322            kind,
323        });
324    };
325
326    let spec = match &node.node_type {
327        None => {
328            push(out, path, ViolationKind::MissingNodeType);
329            None
330        }
331        Some(t) => match schema.nodes.get(t) {
332            Some(spec) => Some(spec),
333            None => {
334                push(out, path, ViolationKind::UnknownNodeType(t.clone()));
335                None
336            }
337        },
338    };
339
340    if let Some(spec) = spec {
341        // node attrs
342        check_attrs(
343            node.attrs.as_ref(),
344            spec.attrs.as_ref(),
345            &spec.required_attrs,
346            path,
347            out,
348        );
349
350        // marks
351        if let Some(marks) = &node.marks {
352            let node_type = node.node_type.as_deref().unwrap_or_default();
353            for mark in marks {
354                match schema.marks.get(&mark.mark_type) {
355                    None => push(
356                        out,
357                        path,
358                        ViolationKind::UnknownMark(mark.mark_type.clone()),
359                    ),
360                    Some(mark_spec) => {
361                        if let Some(allowed) = &spec.marks {
362                            if !allowed.contains(&mark.mark_type) {
363                                push(
364                                    out,
365                                    path,
366                                    ViolationKind::DisallowedMark {
367                                        node: node_type.to_string(),
368                                        mark: mark.mark_type.clone(),
369                                    },
370                                );
371                            }
372                        }
373                        check_attrs(
374                            mark.attrs.as_ref(),
375                            mark_spec.attrs.as_ref(),
376                            &mark_spec.required_attrs,
377                            path,
378                            out,
379                        );
380                    }
381                }
382            }
383        }
384
385        // content rules
386        if let Some(rule) = &spec.content {
387            let parent = node.node_type.as_deref().unwrap_or_default();
388            match rule {
389                // array form: per-child type membership (unchanged behavior)
390                ContentRule::Types(allowed) => {
391                    if let Some(children) = &node.content {
392                        for child in children {
393                            if let Some(ct) = &child.node_type {
394                                if !allowed.contains(ct) {
395                                    push(
396                                        out,
397                                        path,
398                                        ViolationKind::DisallowedChild {
399                                            parent: parent.to_string(),
400                                            child: ct.clone(),
401                                        },
402                                    );
403                                }
404                            }
405                        }
406                    }
407                }
408                // expression form: cardinality + ordering over the child sequence
409                ContentRule::Expr(expr) => {
410                    let children = node.content.as_deref().unwrap_or(&[]);
411                    if !expr.matches(children, schema) {
412                        push(
413                            out,
414                            path,
415                            ViolationKind::InvalidContent {
416                                parent: parent.to_string(),
417                                expr: expr.as_str().to_string(),
418                            },
419                        );
420                    }
421                }
422            }
423        }
424    }
425
426    // recurse
427    if let Some(children) = &node.content {
428        for (i, child) in children.iter().enumerate() {
429            path.push(i);
430            validate_node(child, schema, path, out);
431            path.pop();
432        }
433    }
434}
435
436fn check_attrs(
437    attrs: Option<&serde_json::Map<String, serde_json::Value>>,
438    allowed: Option<&HashSet<String>>,
439    required: &HashSet<String>,
440    path: &[usize],
441    out: &mut Vec<Violation>,
442) {
443    for key in required {
444        let present = attrs.is_some_and(|m| m.contains_key(key));
445        if !present {
446            out.push(Violation {
447                path: path.to_vec(),
448                kind: ViolationKind::MissingAttr { key: key.clone() },
449            });
450        }
451    }
452    if let (Some(allowed), Some(attrs)) = (allowed, attrs) {
453        for key in attrs.keys() {
454            if !allowed.contains(key) {
455                out.push(Violation {
456                    path: path.to_vec(),
457                    kind: ViolationKind::UnknownAttr { key: key.clone() },
458                });
459            }
460        }
461    }
462}