Skip to main content

tiptap_rusty_parser/
content.rs

//! ProseMirror **content expressions** — cardinality + ordering for schema
//! validation (e.g. `paragraph+`, `heading{1,3}`, `(text | image)*`, `block+`).
//!
//! A [`NodeSpec`](crate::NodeSpec)'s `content` is a [`ContentRule`]: either a
//! set of allowed child types (the array form, order/count-insensitive) or a
//! [`ContentExpr`] parsed from a content-expression string. Expressions compile
//! to an NFA and are matched against a node's child-type sequence; a name in an
//! expression matches a child whose `type` equals it **or** whose
//! [`NodeSpec::group`](crate::NodeSpec) contains it.
//!
//! Supported grammar (the full ProseMirror content-expression language):
//! ```text
//! Choice  := Seq ('|' Seq)*
//! Seq     := Postfix*                  // whitespace-separated; empty allowed
//! Postfix := Atom ('*' | '+' | '?' | '{' n (',' m?)? '}')?
//! Atom    := '(' Choice ')' | Name     // Name = [A-Za-z0-9_-]+
//! ```
18
19use crate::node::Node;
20use crate::schema::Schema;
21use serde::de::{self, Deserializer, Visitor};
22use serde::{Deserialize, Serialize, Serializer};
23use std::collections::HashSet;
24use std::fmt;
25
/// Upper limit for any count written in `{n}` / `{n,m}`.
///
/// Bounded repeats are compiled by copying the repeated fragment, so this cap
/// keeps a single range from blowing up the NFA.
const MAX_REPEAT: u32 = 1_000;
/// Maximum number of states a compiled NFA may contain; compilation fails with
/// an error once an expression would need more than this.
const MAX_NFA_STATES: usize = 100_000;
30
31/// Error parsing (or compiling) a content expression.
32#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
33#[error("invalid content expression at byte {pos}: {msg}")]
34pub struct ParseExprError {
35    /// Byte offset into the source string where the problem was detected.
36    pub pos: usize,
37    /// Human-readable reason.
38    pub msg: String,
39}
40
// ---- AST ----------------------------------------------------------------
42
/// Syntax tree of a parsed content expression.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Expr {
    /// The empty sequence (no children allowed at this position).
    Empty,
    /// A single child, identified by node type or group name.
    Name(String),
    /// Sub-expressions that must appear one after another.
    Seq(Vec<Expr>),
    /// Alternatives separated by `|`.
    Choice(Vec<Expr>),
    /// `inner*` — zero or more repetitions.
    Star(Box<Expr>),
    /// `inner+` — one or more repetitions.
    Plus(Box<Expr>),
    /// `inner?` — zero or one occurrence.
    Opt(Box<Expr>),
    /// `inner{min}`, `inner{min,}` or `inner{min,max}`.
    Range {
        /// Required number of repetitions.
        min: u32,
        /// Inclusive upper bound; `None` for the open-ended `{min,}` form.
        max: Option<u32>,
        /// The repeated sub-expression.
        inner: Box<Expr>,
    },
}
58
/// Whether `c` may appear in a node-type or group name.
///
/// Names are ASCII-only — `[A-Za-z0-9_-]` — as stated by the module grammar.
/// The Unicode-aware `char::is_alphanumeric` is deliberately not used: it
/// also accepts characters such as `é`, `²` or `½`, which are outside that
/// grammar.
fn is_name_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, '_' | '-')
}
62
63struct Parser<'a> {
64    input: &'a str,
65    pos: usize,
66}
67
68impl<'a> Parser<'a> {
69    fn new(input: &'a str) -> Self {
70        Self { input, pos: 0 }
71    }
72    fn peek(&self) -> Option<char> {
73        self.input[self.pos..].chars().next()
74    }
75    fn bump(&mut self) -> Option<char> {
76        let c = self.peek()?;
77        self.pos += c.len_utf8();
78        Some(c)
79    }
80    fn skip_ws(&mut self) {
81        while let Some(c) = self.peek() {
82            if c.is_whitespace() {
83                self.pos += c.len_utf8();
84            } else {
85                break;
86            }
87        }
88    }
89    fn err(&self, msg: impl Into<String>) -> ParseExprError {
90        ParseExprError {
91            pos: self.pos,
92            msg: msg.into(),
93        }
94    }
95
96    fn parse_choice(&mut self) -> Result<Expr, ParseExprError> {
97        let mut opts = vec![self.parse_seq()?];
98        loop {
99            self.skip_ws();
100            if self.peek() == Some('|') {
101                self.bump();
102                opts.push(self.parse_seq()?);
103            } else {
104                break;
105            }
106        }
107        Ok(if opts.len() == 1 {
108            opts.pop().unwrap()
109        } else {
110            Expr::Choice(opts)
111        })
112    }
113
114    fn parse_seq(&mut self) -> Result<Expr, ParseExprError> {
115        let mut items = Vec::new();
116        loop {
117            self.skip_ws();
118            match self.peek() {
119                Some(c) if c == '(' || is_name_char(c) => items.push(self.parse_postfix()?),
120                _ => break,
121            }
122        }
123        Ok(match items.len() {
124            0 => Expr::Empty,
125            1 => items.pop().unwrap(),
126            _ => Expr::Seq(items),
127        })
128    }
129
130    fn parse_postfix(&mut self) -> Result<Expr, ParseExprError> {
131        let atom = self.parse_atom()?;
132        self.skip_ws();
133        match self.peek() {
134            Some('*') => {
135                self.bump();
136                Ok(Expr::Star(Box::new(atom)))
137            }
138            Some('+') => {
139                self.bump();
140                Ok(Expr::Plus(Box::new(atom)))
141            }
142            Some('?') => {
143                self.bump();
144                Ok(Expr::Opt(Box::new(atom)))
145            }
146            Some('{') => {
147                self.bump();
148                let (min, max) = self.parse_range()?;
149                Ok(Expr::Range {
150                    min,
151                    max,
152                    inner: Box::new(atom),
153                })
154            }
155            _ => Ok(atom),
156        }
157    }
158
159    fn parse_atom(&mut self) -> Result<Expr, ParseExprError> {
160        self.skip_ws();
161        match self.peek() {
162            Some('(') => {
163                self.bump();
164                let e = self.parse_choice()?;
165                self.skip_ws();
166                if self.peek() != Some(')') {
167                    return Err(self.err("expected ')'"));
168                }
169                self.bump();
170                Ok(e)
171            }
172            Some(c) if is_name_char(c) => {
173                let start = self.pos;
174                while self.peek().is_some_and(is_name_char) {
175                    self.bump();
176                }
177                Ok(Expr::Name(self.input[start..self.pos].to_string()))
178            }
179            _ => Err(self.err("expected a name or '('")),
180        }
181    }
182
183    /// Parse a `{n}` / `{n,}` / `{n,m}` body (the opening `{` already consumed).
184    fn parse_range(&mut self) -> Result<(u32, Option<u32>), ParseExprError> {
185        self.skip_ws();
186        let min = self.parse_num()?;
187        self.skip_ws();
188        let max = match self.peek() {
189            Some(',') => {
190                self.bump();
191                self.skip_ws();
192                match self.peek() {
193                    Some('}') => None,
194                    Some(c) if c.is_ascii_digit() => Some(self.parse_num()?),
195                    _ => return Err(self.err("expected a number or '}' in range")),
196                }
197            }
198            Some('}') => Some(min),
199            _ => return Err(self.err("expected ',' or '}' in range")),
200        };
201        self.skip_ws();
202        if self.peek() != Some('}') {
203            return Err(self.err("expected '}'"));
204        }
205        self.bump();
206        if min > MAX_REPEAT || max.is_some_and(|m| m > MAX_REPEAT) {
207            return Err(self.err(format!("repeat count exceeds cap of {MAX_REPEAT}")));
208        }
209        if max.is_some_and(|m| m < min) {
210            return Err(self.err("range maximum is less than minimum"));
211        }
212        Ok((min, max))
213    }
214
215    fn parse_num(&mut self) -> Result<u32, ParseExprError> {
216        let start = self.pos;
217        while self.peek().is_some_and(|c| c.is_ascii_digit()) {
218            self.bump();
219        }
220        if self.pos == start {
221            return Err(self.err("expected a number"));
222        }
223        self.input[start..self.pos]
224            .parse()
225            .map_err(|_| self.err("number too large"))
226    }
227}
228
229fn parse(input: &str) -> Result<Expr, ParseExprError> {
230    let mut p = Parser::new(input);
231    let e = p.parse_choice()?;
232    p.skip_ws();
233    if p.pos != input.len() {
234        return Err(p.err("unexpected trailing input"));
235    }
236    Ok(e)
237}
238
// ---- NFA ----------------------------------------------------------------
240
/// One NFA state, identified by its index in the owning state list.
#[derive(Debug, Clone)]
struct State {
    /// Targets of epsilon transitions (taken without consuming a child).
    eps: Vec<usize>,
    /// Labelled transitions: `(name, target)`, taken by consuming one child
    /// that matches `name`.
    edges: Vec<(String, usize)>,
}
246
247#[derive(Debug, Clone)]
248struct Nfa {
249    states: Vec<State>,
250    start: usize,
251    accept: usize,
252}
253
254struct Builder {
255    states: Vec<State>,
256}
257
258impl Builder {
259    fn new_state(&mut self) -> Result<usize, ParseExprError> {
260        if self.states.len() >= MAX_NFA_STATES {
261            return Err(ParseExprError {
262                pos: 0,
263                msg: format!("content expression too large (> {MAX_NFA_STATES} states)"),
264            });
265        }
266        self.states.push(State {
267            eps: Vec::new(),
268            edges: Vec::new(),
269        });
270        Ok(self.states.len() - 1)
271    }
272    fn eps(&mut self, from: usize, to: usize) {
273        self.states[from].eps.push(to);
274    }
275
276    /// Build a Thompson fragment for `e`, returning its `(start, out)` states.
277    fn build(&mut self, e: &Expr) -> Result<(usize, usize), ParseExprError> {
278        match e {
279            Expr::Empty => {
280                let (s, o) = (self.new_state()?, self.new_state()?);
281                self.eps(s, o);
282                Ok((s, o))
283            }
284            Expr::Name(n) => {
285                let (s, o) = (self.new_state()?, self.new_state()?);
286                self.states[s].edges.push((n.clone(), o));
287                Ok((s, o))
288            }
289            Expr::Seq(items) => {
290                let s = self.new_state()?;
291                let mut cur = s;
292                for it in items {
293                    let (fs, fo) = self.build(it)?;
294                    self.eps(cur, fs);
295                    cur = fo;
296                }
297                Ok((s, cur))
298            }
299            Expr::Choice(opts) => {
300                let (s, o) = (self.new_state()?, self.new_state()?);
301                for opt in opts {
302                    let (fs, fo) = self.build(opt)?;
303                    self.eps(s, fs);
304                    self.eps(fo, o);
305                }
306                Ok((s, o))
307            }
308            Expr::Star(inner) => {
309                let (s, o) = (self.new_state()?, self.new_state()?);
310                let (fs, fo) = self.build(inner)?;
311                self.eps(s, fs);
312                self.eps(s, o);
313                self.eps(fo, fs);
314                self.eps(fo, o);
315                Ok((s, o))
316            }
317            Expr::Plus(inner) => {
318                let (s, o) = (self.new_state()?, self.new_state()?);
319                let (fs, fo) = self.build(inner)?;
320                self.eps(s, fs);
321                self.eps(fo, fs);
322                self.eps(fo, o);
323                Ok((s, o))
324            }
325            Expr::Opt(inner) => {
326                let (s, o) = (self.new_state()?, self.new_state()?);
327                let (fs, fo) = self.build(inner)?;
328                self.eps(s, fs);
329                self.eps(s, o);
330                self.eps(fo, o);
331                Ok((s, o))
332            }
333            Expr::Range { min, max, inner } => {
334                let s = self.new_state()?;
335                let mut cur = s;
336                for _ in 0..*min {
337                    let (fs, fo) = self.build(inner)?;
338                    self.eps(cur, fs);
339                    cur = fo;
340                }
341                match max {
342                    None => {
343                        // open `{n,}` => append `inner*` (loop-back, no expansion)
344                        let (fs, fo) = self.build(inner)?;
345                        let (ss, so) = (self.new_state()?, self.new_state()?);
346                        self.eps(ss, fs);
347                        self.eps(ss, so);
348                        self.eps(fo, fs);
349                        self.eps(fo, so);
350                        self.eps(cur, ss);
351                        cur = so;
352                    }
353                    Some(m) => {
354                        // `(m - min)` optional copies
355                        for _ in *min..*m {
356                            let (fs, fo) = self.build(inner)?;
357                            let (os, oo) = (self.new_state()?, self.new_state()?);
358                            self.eps(os, fs);
359                            self.eps(os, oo);
360                            self.eps(fo, oo);
361                            self.eps(cur, os);
362                            cur = oo;
363                        }
364                    }
365                }
366                Ok((s, cur))
367            }
368        }
369    }
370}
371
372fn compile(ast: &Expr) -> Result<Nfa, ParseExprError> {
373    let mut b = Builder { states: Vec::new() };
374    let (start, accept) = b.build(ast)?;
375    Ok(Nfa {
376        states: b.states,
377        start,
378        accept,
379    })
380}
381
382/// A label matches a child if it equals the child's type or one of its groups.
383fn label_matches(label: &str, child_type: &str, schema: &Schema) -> bool {
384    label == child_type
385        || schema
386            .nodes
387            .get(child_type)
388            .and_then(|spec| spec.group.as_deref())
389            .is_some_and(|g| g.split_whitespace().any(|grp| grp == label))
390}
391
392impl Nfa {
393    fn eps_closure(&self, set: &mut [bool], stack: &mut Vec<usize>) {
394        while let Some(s) = stack.pop() {
395            for &t in &self.states[s].eps {
396                if !set[t] {
397                    set[t] = true;
398                    stack.push(t);
399                }
400            }
401        }
402    }
403
404    fn matches(&self, children: &[Node], schema: &Schema) -> bool {
405        let n = self.states.len();
406        let mut current = vec![false; n];
407        let mut stack = vec![self.start];
408        current[self.start] = true;
409        self.eps_closure(&mut current, &mut stack);
410
411        for child in children {
412            let Some(ct) = child.node_type.as_deref() else {
413                return false; // an untyped child can't satisfy a named slot
414            };
415            let mut next = vec![false; n];
416            let mut nstack = Vec::new();
417            for (s, &active) in current.iter().enumerate() {
418                if active {
419                    for (label, dst) in &self.states[s].edges {
420                        if !next[*dst] && label_matches(label, ct, schema) {
421                            next[*dst] = true;
422                            nstack.push(*dst);
423                        }
424                    }
425                }
426            }
427            self.eps_closure(&mut next, &mut nstack);
428            if !next.iter().any(|&b| b) {
429                return false; // dead — no state survives this child
430            }
431            current = next;
432        }
433        current[self.accept]
434    }
435}
436
// ---- public types -------------------------------------------------------
438
439/// A compiled ProseMirror content expression (e.g. `paragraph+`).
440///
441/// Parse with [`ContentExpr::parse`]; (de)serializes as its source string.
442#[derive(Debug, Clone)]
443pub struct ContentExpr {
444    raw: String,
445    ast: Expr,
446    nfa: Nfa,
447}
448
449impl ContentExpr {
450    /// Parse and compile a content-expression string.
451    ///
452    /// ```
453    /// use tiptap_rusty_parser::ContentExpr;
454    /// assert!(ContentExpr::parse("paragraph+").is_ok());
455    /// assert!(ContentExpr::parse("(a |").is_err());
456    /// ```
457    pub fn parse(s: &str) -> Result<Self, ParseExprError> {
458        let ast = parse(s)?;
459        let nfa = compile(&ast)?;
460        Ok(Self {
461            raw: s.to_string(),
462            ast,
463            nfa,
464        })
465    }
466
467    /// The original expression source.
468    pub fn as_str(&self) -> &str {
469        &self.raw
470    }
471
472    /// Whether `children`'s type sequence satisfies this expression.
473    pub(crate) fn matches(&self, children: &[Node], schema: &Schema) -> bool {
474        self.nfa.matches(children, schema)
475    }
476}
477
478impl PartialEq for ContentExpr {
479    fn eq(&self, other: &Self) -> bool {
480        self.ast == other.ast // compare structure, not the compiled NFA
481    }
482}
483
484impl Serialize for ContentExpr {
485    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
486        s.serialize_str(&self.raw)
487    }
488}
489
490impl<'de> Deserialize<'de> for ContentExpr {
491    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
492        let s = String::deserialize(d)?;
493        ContentExpr::parse(&s).map_err(de::Error::custom)
494    }
495}
496
497/// A node's allowed content: a set of child types (array form) or an ordered
498/// content [expression](ContentExpr) (string form).
499#[derive(Debug, Clone, PartialEq)]
500pub enum ContentRule {
501    /// Allowed child types, any count/order. Emits `DisallowedChild`.
502    Types(HashSet<String>),
503    /// A content expression (cardinality + ordering). Emits `InvalidContent`.
504    Expr(ContentExpr),
505}
506
507impl Serialize for ContentRule {
508    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
509        match self {
510            ContentRule::Types(set) => set.serialize(s), // JSON array
511            ContentRule::Expr(e) => e.serialize(s),      // JSON string
512        }
513    }
514}
515
516impl<'de> Deserialize<'de> for ContentRule {
517    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
518        struct RuleVisitor;
519        impl<'de> Visitor<'de> for RuleVisitor {
520            type Value = ContentRule;
521            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
522                f.write_str("an array of child type names or a content-expression string")
523            }
524            fn visit_str<E: de::Error>(self, v: &str) -> Result<ContentRule, E> {
525                ContentExpr::parse(v)
526                    .map(ContentRule::Expr)
527                    .map_err(E::custom)
528            }
529            fn visit_string<E: de::Error>(self, v: String) -> Result<ContentRule, E> {
530                self.visit_str(&v)
531            }
532            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ContentRule, A::Error> {
533                let mut set = HashSet::new();
534                while let Some(s) = seq.next_element::<String>()? {
535                    set.insert(s);
536                }
537                Ok(ContentRule::Types(set))
538            }
539        }
540        d.deserialize_any(RuleVisitor)
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an [`Expr::Name`] leaf.
    fn name(s: &str) -> Expr {
        Expr::Name(s.to_string())
    }

    /// Shorthand for a `{min,max}` repetition of the single name `s`.
    fn range(s: &str, min: u32, max: Option<u32>) -> Expr {
        Expr::Range {
            min,
            max,
            inner: Box::new(name(s)),
        }
    }

    #[test]
    fn precedence_and_shape() {
        // `|` binds loosest; sequence by whitespace; postfix binds to its atom.
        let a_then_b = || Expr::Seq(vec![name("a"), name("b")]);
        let cases = [
            ("a b | c", Expr::Choice(vec![a_then_b(), name("c")])),
            (
                "a b+",
                Expr::Seq(vec![name("a"), Expr::Plus(Box::new(name("b")))]),
            ),
            ("(a b)+", Expr::Plus(Box::new(a_then_b()))),
            ("", Expr::Empty),
            ("h{2,3}", range("h", 2, Some(3))),
            ("h{2,}", range("h", 2, None)),
        ];
        for (src, want) in cases {
            assert_eq!(parse(src).unwrap(), want, "source: {src:?}");
        }
    }

    #[test]
    fn range_cap_and_errors() {
        // over the repeat cap, max below min, doubled postfix operator
        for bad in ["a{2000}", "a{3,1}", "a**"] {
            assert!(parse(bad).is_err(), "expected an error for {bad:?}");
        }
        // the source text is kept verbatim by ContentExpr
        let src = "(a | b) c*";
        let compiled = ContentExpr::parse(src).unwrap();
        assert_eq!(compiled.as_str(), src);
    }
}
595}