1use crate::content::{ContentExpr, ContentRule, ParseExprError};
27use crate::node::Node;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30use std::fmt;
31
32#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
35pub struct Schema {
36 #[serde(default)]
38 pub nodes: HashMap<String, NodeSpec>,
39 #[serde(default)]
41 pub marks: HashMap<String, MarkSpec>,
42}
43
44#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
46pub struct NodeSpec {
47 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub content: Option<ContentRule>,
51 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub group: Option<String>,
55 #[serde(default, skip_serializing_if = "Option::is_none")]
57 pub marks: Option<HashSet<String>>,
58 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub attrs: Option<HashSet<String>>,
61 #[serde(default, skip_serializing_if = "HashSet::is_empty")]
63 pub required_attrs: HashSet<String>,
64}
65
66#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
68pub struct MarkSpec {
69 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub attrs: Option<HashSet<String>>,
72 #[serde(default, skip_serializing_if = "HashSet::is_empty")]
74 pub required_attrs: HashSet<String>,
75}
76
/// Converts any iterable of string-like items into an owned `HashSet<String>`.
///
/// Duplicate items collapse into a single entry.
fn into_set<I, S>(items: I) -> HashSet<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut set: HashSet<String> = HashSet::new();
    for item in items {
        set.insert(item.into());
    }
    set
}
84
85impl Schema {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn node(mut self, node_type: impl Into<String>, spec: NodeSpec) -> Self {
93 self.nodes.insert(node_type.into(), spec);
94 self
95 }
96
97 pub fn mark(mut self, mark_type: impl Into<String>, spec: MarkSpec) -> Self {
99 self.marks.insert(mark_type.into(), spec);
100 self
101 }
102
103 pub fn from_json_str(s: &str) -> crate::Result<Self> {
114 Ok(serde_json::from_str(s)?)
115 }
116}
117
118impl NodeSpec {
119 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn content<I, S>(mut self, types: I) -> Self
127 where
128 I: IntoIterator<Item = S>,
129 S: Into<String>,
130 {
131 self.content = Some(ContentRule::Types(into_set(types)));
132 self
133 }
134
135 pub fn content_match(self, expr: &str) -> Self {
139 self.try_content_match(expr)
140 .expect("invalid content expression")
141 }
142
143 pub fn try_content_match(mut self, expr: &str) -> Result<Self, ParseExprError> {
145 self.content = Some(ContentRule::Expr(ContentExpr::parse(expr)?));
146 Ok(self)
147 }
148
149 pub fn group(mut self, group: impl Into<String>) -> Self {
151 self.group = Some(group.into());
152 self
153 }
154
155 pub fn content_types(&self) -> Option<&HashSet<String>> {
158 match &self.content {
159 Some(ContentRule::Types(set)) => Some(set),
160 _ => None,
161 }
162 }
163
164 pub fn marks<I, S>(mut self, types: I) -> Self
166 where
167 I: IntoIterator<Item = S>,
168 S: Into<String>,
169 {
170 self.marks = Some(into_set(types));
171 self
172 }
173
174 pub fn attrs<I, S>(mut self, keys: I) -> Self
176 where
177 I: IntoIterator<Item = S>,
178 S: Into<String>,
179 {
180 self.attrs = Some(into_set(keys));
181 self
182 }
183
184 pub fn required_attrs<I, S>(mut self, keys: I) -> Self
186 where
187 I: IntoIterator<Item = S>,
188 S: Into<String>,
189 {
190 self.required_attrs = into_set(keys);
191 self
192 }
193}
194
195impl MarkSpec {
196 pub fn new() -> Self {
198 Self::default()
199 }
200
201 pub fn attrs<I, S>(mut self, keys: I) -> Self
203 where
204 I: IntoIterator<Item = S>,
205 S: Into<String>,
206 {
207 self.attrs = Some(into_set(keys));
208 self
209 }
210
211 pub fn required_attrs<I, S>(mut self, keys: I) -> Self
213 where
214 I: IntoIterator<Item = S>,
215 S: Into<String>,
216 {
217 self.required_attrs = into_set(keys);
218 self
219 }
220}
221
222#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct Violation {
225 pub path: Vec<usize>,
227 pub kind: ViolationKind,
229}
230
231#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
233pub enum ViolationKind {
234 MissingNodeType,
236 UnknownNodeType(String),
238 DisallowedChild { parent: String, child: String },
240 InvalidContent { parent: String, expr: String },
242 UnknownMark(String),
244 DisallowedMark { node: String, mark: String },
246 MissingAttr { key: String },
248 UnknownAttr { key: String },
250}
251
252impl fmt::Display for ViolationKind {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 match self {
255 ViolationKind::MissingNodeType => write!(f, "node has no type"),
256 ViolationKind::UnknownNodeType(t) => write!(f, "unknown node type `{t}`"),
257 ViolationKind::DisallowedChild { parent, child } => {
258 write!(f, "node type `{child}` not allowed inside `{parent}`")
259 }
260 ViolationKind::InvalidContent { parent, expr } => {
261 write!(
262 f,
263 "children of `{parent}` do not match content expression `{expr}`"
264 )
265 }
266 ViolationKind::UnknownMark(m) => write!(f, "unknown mark type `{m}`"),
267 ViolationKind::DisallowedMark { node, mark } => {
268 write!(f, "mark `{mark}` not allowed on `{node}`")
269 }
270 ViolationKind::MissingAttr { key } => write!(f, "missing required attribute `{key}`"),
271 ViolationKind::UnknownAttr { key } => write!(f, "unknown attribute `{key}`"),
272 }
273 }
274}
275
276impl fmt::Display for Violation {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 write!(f, "at {:?}: {}", self.path, self.kind)
279 }
280}
281
282impl Node {
283 pub fn validate(&self, schema: &Schema) -> Vec<Violation> {
299 let mut out = Vec::new();
300 let mut path = Vec::new();
301 validate_node(self, schema, &mut path, &mut out);
302 out
303 }
304
305 pub fn is_valid(&self, schema: &Schema) -> bool {
314 self.validate(schema).is_empty()
315 }
316}
317
318fn validate_node(node: &Node, schema: &Schema, path: &mut Vec<usize>, out: &mut Vec<Violation>) {
319 let push = |out: &mut Vec<Violation>, path: &[usize], kind: ViolationKind| {
320 out.push(Violation {
321 path: path.to_vec(),
322 kind,
323 });
324 };
325
326 let spec = match &node.node_type {
327 None => {
328 push(out, path, ViolationKind::MissingNodeType);
329 None
330 }
331 Some(t) => match schema.nodes.get(t) {
332 Some(spec) => Some(spec),
333 None => {
334 push(out, path, ViolationKind::UnknownNodeType(t.clone()));
335 None
336 }
337 },
338 };
339
340 if let Some(spec) = spec {
341 check_attrs(
343 node.attrs.as_ref(),
344 spec.attrs.as_ref(),
345 &spec.required_attrs,
346 path,
347 out,
348 );
349
350 if let Some(marks) = &node.marks {
352 let node_type = node.node_type.as_deref().unwrap_or_default();
353 for mark in marks {
354 match schema.marks.get(&mark.mark_type) {
355 None => push(
356 out,
357 path,
358 ViolationKind::UnknownMark(mark.mark_type.clone()),
359 ),
360 Some(mark_spec) => {
361 if let Some(allowed) = &spec.marks {
362 if !allowed.contains(&mark.mark_type) {
363 push(
364 out,
365 path,
366 ViolationKind::DisallowedMark {
367 node: node_type.to_string(),
368 mark: mark.mark_type.clone(),
369 },
370 );
371 }
372 }
373 check_attrs(
374 mark.attrs.as_ref(),
375 mark_spec.attrs.as_ref(),
376 &mark_spec.required_attrs,
377 path,
378 out,
379 );
380 }
381 }
382 }
383 }
384
385 if let Some(rule) = &spec.content {
387 let parent = node.node_type.as_deref().unwrap_or_default();
388 match rule {
389 ContentRule::Types(allowed) => {
391 if let Some(children) = &node.content {
392 for child in children {
393 if let Some(ct) = &child.node_type {
394 if !allowed.contains(ct) {
395 push(
396 out,
397 path,
398 ViolationKind::DisallowedChild {
399 parent: parent.to_string(),
400 child: ct.clone(),
401 },
402 );
403 }
404 }
405 }
406 }
407 }
408 ContentRule::Expr(expr) => {
410 let children = node.content.as_deref().unwrap_or(&[]);
411 if !expr.matches(children, schema) {
412 push(
413 out,
414 path,
415 ViolationKind::InvalidContent {
416 parent: parent.to_string(),
417 expr: expr.as_str().to_string(),
418 },
419 );
420 }
421 }
422 }
423 }
424 }
425
426 if let Some(children) = &node.content {
428 for (i, child) in children.iter().enumerate() {
429 path.push(i);
430 validate_node(child, schema, path, out);
431 path.pop();
432 }
433 }
434}
435
436fn check_attrs(
437 attrs: Option<&serde_json::Map<String, serde_json::Value>>,
438 allowed: Option<&HashSet<String>>,
439 required: &HashSet<String>,
440 path: &[usize],
441 out: &mut Vec<Violation>,
442) {
443 for key in required {
444 let present = attrs.is_some_and(|m| m.contains_key(key));
445 if !present {
446 out.push(Violation {
447 path: path.to_vec(),
448 kind: ViolationKind::MissingAttr { key: key.clone() },
449 });
450 }
451 }
452 if let (Some(allowed), Some(attrs)) = (allowed, attrs) {
453 for key in attrs.keys() {
454 if !allowed.contains(key) {
455 out.push(Violation {
456 path: path.to_vec(),
457 kind: ViolationKind::UnknownAttr { key: key.clone() },
458 });
459 }
460 }
461 }
462}