1use crate::node::Node;
27use serde::{Deserialize, Serialize};
28use std::collections::{HashMap, HashSet};
29use std::fmt;
30
31#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
34pub struct Schema {
35 #[serde(default)]
37 pub nodes: HashMap<String, NodeSpec>,
38 #[serde(default)]
40 pub marks: HashMap<String, MarkSpec>,
41}
42
43#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
45pub struct NodeSpec {
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub content: Option<HashSet<String>>,
49 #[serde(default, skip_serializing_if = "Option::is_none")]
51 pub marks: Option<HashSet<String>>,
52 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub attrs: Option<HashSet<String>>,
55 #[serde(default, skip_serializing_if = "HashSet::is_empty")]
57 pub required_attrs: HashSet<String>,
58}
59
60#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
62pub struct MarkSpec {
63 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub attrs: Option<HashSet<String>>,
66 #[serde(default, skip_serializing_if = "HashSet::is_empty")]
68 pub required_attrs: HashSet<String>,
69}
70
/// Converts any sequence of string-like items into an owned set of `String`s.
///
/// Duplicate items collapse into a single entry.
fn into_set<I, S>(items: I) -> HashSet<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut set = HashSet::new();
    for item in items {
        set.insert(item.into());
    }
    set
}
78
79impl Schema {
80 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn node(mut self, node_type: impl Into<String>, spec: NodeSpec) -> Self {
87 self.nodes.insert(node_type.into(), spec);
88 self
89 }
90
91 pub fn mark(mut self, mark_type: impl Into<String>, spec: MarkSpec) -> Self {
93 self.marks.insert(mark_type.into(), spec);
94 self
95 }
96
97 pub fn from_json_str(s: &str) -> crate::Result<Self> {
108 Ok(serde_json::from_str(s)?)
109 }
110}
111
112impl NodeSpec {
113 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn content<I, S>(mut self, types: I) -> Self
120 where
121 I: IntoIterator<Item = S>,
122 S: Into<String>,
123 {
124 self.content = Some(into_set(types));
125 self
126 }
127
128 pub fn marks<I, S>(mut self, types: I) -> Self
130 where
131 I: IntoIterator<Item = S>,
132 S: Into<String>,
133 {
134 self.marks = Some(into_set(types));
135 self
136 }
137
138 pub fn attrs<I, S>(mut self, keys: I) -> Self
140 where
141 I: IntoIterator<Item = S>,
142 S: Into<String>,
143 {
144 self.attrs = Some(into_set(keys));
145 self
146 }
147
148 pub fn required_attrs<I, S>(mut self, keys: I) -> Self
150 where
151 I: IntoIterator<Item = S>,
152 S: Into<String>,
153 {
154 self.required_attrs = into_set(keys);
155 self
156 }
157}
158
159impl MarkSpec {
160 pub fn new() -> Self {
162 Self::default()
163 }
164
165 pub fn attrs<I, S>(mut self, keys: I) -> Self
167 where
168 I: IntoIterator<Item = S>,
169 S: Into<String>,
170 {
171 self.attrs = Some(into_set(keys));
172 self
173 }
174
175 pub fn required_attrs<I, S>(mut self, keys: I) -> Self
177 where
178 I: IntoIterator<Item = S>,
179 S: Into<String>,
180 {
181 self.required_attrs = into_set(keys);
182 self
183 }
184}
185
/// A single schema violation found while validating a document tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Violation {
    /// Child indices leading from the validated root to the node the
    /// violation is reported on; empty for the root itself.
    pub path: Vec<usize>,
    /// What went wrong at that node.
    pub kind: ViolationKind,
}

/// The category of a schema violation, with the names involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ViolationKind {
    /// The node does not declare a type.
    MissingNodeType,
    /// The node's type is not registered in the schema.
    UnknownNodeType(String),
    /// A child's type is not in the parent's allowed content set.
    DisallowedChild { parent: String, child: String },
    /// The mark's type is not registered in the schema.
    UnknownMark(String),
    /// The mark is known to the schema but not allowed on this node type.
    DisallowedMark { node: String, mark: String },
    /// A required attribute key is absent.
    MissingAttr { key: String },
    /// An attribute key is present but not in the allowed set.
    UnknownAttr { key: String },
}
213
214impl fmt::Display for ViolationKind {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 match self {
217 ViolationKind::MissingNodeType => write!(f, "node has no type"),
218 ViolationKind::UnknownNodeType(t) => write!(f, "unknown node type `{t}`"),
219 ViolationKind::DisallowedChild { parent, child } => {
220 write!(f, "node type `{child}` not allowed inside `{parent}`")
221 }
222 ViolationKind::UnknownMark(m) => write!(f, "unknown mark type `{m}`"),
223 ViolationKind::DisallowedMark { node, mark } => {
224 write!(f, "mark `{mark}` not allowed on `{node}`")
225 }
226 ViolationKind::MissingAttr { key } => write!(f, "missing required attribute `{key}`"),
227 ViolationKind::UnknownAttr { key } => write!(f, "unknown attribute `{key}`"),
228 }
229 }
230}
231
232impl fmt::Display for Violation {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 write!(f, "at {:?}: {}", self.path, self.kind)
235 }
236}
237
238impl Node {
239 pub fn validate(&self, schema: &Schema) -> Vec<Violation> {
255 let mut out = Vec::new();
256 let mut path = Vec::new();
257 validate_node(self, schema, &mut path, &mut out);
258 out
259 }
260
261 pub fn is_valid(&self, schema: &Schema) -> bool {
270 self.validate(schema).is_empty()
271 }
272}
273
274fn validate_node(node: &Node, schema: &Schema, path: &mut Vec<usize>, out: &mut Vec<Violation>) {
275 let push = |out: &mut Vec<Violation>, path: &[usize], kind: ViolationKind| {
276 out.push(Violation {
277 path: path.to_vec(),
278 kind,
279 });
280 };
281
282 let spec = match &node.node_type {
283 None => {
284 push(out, path, ViolationKind::MissingNodeType);
285 None
286 }
287 Some(t) => match schema.nodes.get(t) {
288 Some(spec) => Some(spec),
289 None => {
290 push(out, path, ViolationKind::UnknownNodeType(t.clone()));
291 None
292 }
293 },
294 };
295
296 if let Some(spec) = spec {
297 check_attrs(
299 node.attrs.as_ref(),
300 spec.attrs.as_ref(),
301 &spec.required_attrs,
302 path,
303 out,
304 );
305
306 if let Some(marks) = &node.marks {
308 let node_type = node.node_type.as_deref().unwrap_or_default();
309 for mark in marks {
310 match schema.marks.get(&mark.mark_type) {
311 None => push(
312 out,
313 path,
314 ViolationKind::UnknownMark(mark.mark_type.clone()),
315 ),
316 Some(mark_spec) => {
317 if let Some(allowed) = &spec.marks {
318 if !allowed.contains(&mark.mark_type) {
319 push(
320 out,
321 path,
322 ViolationKind::DisallowedMark {
323 node: node_type.to_string(),
324 mark: mark.mark_type.clone(),
325 },
326 );
327 }
328 }
329 check_attrs(
330 mark.attrs.as_ref(),
331 mark_spec.attrs.as_ref(),
332 &mark_spec.required_attrs,
333 path,
334 out,
335 );
336 }
337 }
338 }
339 }
340
341 if let (Some(allowed), Some(children)) = (&spec.content, &node.content) {
343 let parent = node.node_type.as_deref().unwrap_or_default();
344 for child in children {
345 if let Some(ct) = &child.node_type {
346 if !allowed.contains(ct) {
347 push(
348 out,
349 path,
350 ViolationKind::DisallowedChild {
351 parent: parent.to_string(),
352 child: ct.clone(),
353 },
354 );
355 }
356 }
357 }
358 }
359 }
360
361 if let Some(children) = &node.content {
363 for (i, child) in children.iter().enumerate() {
364 path.push(i);
365 validate_node(child, schema, path, out);
366 path.pop();
367 }
368 }
369}
370
371fn check_attrs(
372 attrs: Option<&serde_json::Map<String, serde_json::Value>>,
373 allowed: Option<&HashSet<String>>,
374 required: &HashSet<String>,
375 path: &[usize],
376 out: &mut Vec<Violation>,
377) {
378 for key in required {
379 let present = attrs.is_some_and(|m| m.contains_key(key));
380 if !present {
381 out.push(Violation {
382 path: path.to_vec(),
383 kind: ViolationKind::MissingAttr { key: key.clone() },
384 });
385 }
386 }
387 if let (Some(allowed), Some(attrs)) = (allowed, attrs) {
388 for key in attrs.keys() {
389 if !allowed.contains(key) {
390 out.push(Violation {
391 path: path.to_vec(),
392 kind: ViolationKind::UnknownAttr { key: key.clone() },
393 });
394 }
395 }
396 }
397}