1use std::collections::HashSet;
27use std::fmt;
28
29use crate::tree::{Document, NodeId, NodeKind};
30
/// Namespace URI an element must carry to be treated as an XInclude element.
pub const XINCLUDE_NS: &str = "http://www.w3.org/2001/XInclude";

/// Local name of the element that requests an inclusion.
const INCLUDE_ELEMENT: &str = "include";

/// Local name of the element that supplies fallback content.
const FALLBACK_ELEMENT: &str = "fallback";
42
/// Settings that control XInclude processing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XIncludeOptions {
    /// Nesting limit for inclusions: an `xi:include` encountered at this
    /// depth (or deeper) is rejected with an error instead of being resolved.
    pub max_depth: usize,
}

impl Default for XIncludeOptions {
    /// Returns options with a nesting limit of 50.
    fn default() -> Self {
        Self { max_depth: 50 }
    }
}
71
/// A problem recorded while handling a single `xi:include` element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XIncludeError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// The `href` of the offending inclusion, when one is available.
    pub href: Option<String>,
}

impl fmt::Display for XIncludeError {
    /// Formats the error, mentioning the `href` when one was recorded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.href {
            Some(href) => write!(f, "XInclude error for '{href}': {}", self.message),
            None => write!(f, "XInclude error: {}", self.message),
        }
    }
}

// Lets callers box the error or propagate it with `?` alongside other errors.
impl std::error::Error for XIncludeError {}
92
93pub struct XIncludeResult {
98 pub inclusions: usize,
100 pub errors: Vec<XIncludeError>,
105}
106
107pub fn process_xincludes<F>(
153 doc: &mut Document,
154 resolver: F,
155 options: &XIncludeOptions,
156) -> XIncludeResult
157where
158 F: Fn(&str) -> Option<String>,
159{
160 let mut state = ProcessingState {
161 inclusions: 0,
162 errors: Vec::new(),
163 active_hrefs: HashSet::new(),
164 max_depth: options.max_depth,
165 };
166
167 process_node(doc, doc.root(), &resolver, &mut state, 0);
168
169 XIncludeResult {
170 inclusions: state.inclusions,
171 errors: state.errors,
172 }
173}
174
175struct ProcessingState {
177 inclusions: usize,
179 errors: Vec<XIncludeError>,
181 active_hrefs: HashSet<String>,
183 max_depth: usize,
185}
186
187fn process_node<F>(
192 doc: &mut Document,
193 node: NodeId,
194 resolver: &F,
195 state: &mut ProcessingState,
196 depth: usize,
197) where
198 F: Fn(&str) -> Option<String>,
199{
200 let children: Vec<NodeId> = doc.children(node).collect();
202
203 for child in children {
204 if is_xinclude_element(doc, child) {
205 process_include_element(doc, child, resolver, state, depth);
206 } else {
207 process_node(doc, child, resolver, state, depth);
209 }
210 }
211}
212
213fn is_xinclude_element(doc: &Document, node: NodeId) -> bool {
215 if let NodeKind::Element {
216 name, namespace, ..
217 } = &doc.node(node).kind
218 {
219 name == INCLUDE_ELEMENT && namespace.as_deref() == Some(XINCLUDE_NS)
220 } else {
221 false
222 }
223}
224
225fn is_fallback_element(doc: &Document, node: NodeId) -> bool {
227 if let NodeKind::Element {
228 name, namespace, ..
229 } = &doc.node(node).kind
230 {
231 name == FALLBACK_ELEMENT && namespace.as_deref() == Some(XINCLUDE_NS)
232 } else {
233 false
234 }
235}
236
237fn process_include_element<F>(
242 doc: &mut Document,
243 include_node: NodeId,
244 resolver: &F,
245 state: &mut ProcessingState,
246 depth: usize,
247) where
248 F: Fn(&str) -> Option<String>,
249{
250 let href = doc.attribute(include_node, "href").map(str::to_owned);
252 let parse = doc
253 .attribute(include_node, "parse")
254 .unwrap_or("xml")
255 .to_owned();
256
257 let Some(href) = href else {
259 state.errors.push(XIncludeError {
260 message: "xi:include element is missing required 'href' attribute".to_string(),
261 href: None,
262 });
263 doc.detach(include_node);
265 return;
266 };
267
268 if parse != "xml" && parse != "text" {
270 state.errors.push(XIncludeError {
271 message: format!("invalid parse attribute value '{parse}'; expected 'xml' or 'text'"),
272 href: Some(href),
273 });
274 doc.detach(include_node);
275 return;
276 }
277
278 if depth >= state.max_depth {
280 state.errors.push(XIncludeError {
281 message: format!(
282 "maximum XInclude nesting depth ({}) exceeded",
283 state.max_depth
284 ),
285 href: Some(href),
286 });
287 doc.detach(include_node);
288 return;
289 }
290
291 let (base_href, _fragment) = split_fragment(&href);
294
295 if state.active_hrefs.contains(base_href) {
297 state.errors.push(XIncludeError {
298 message: "circular inclusion detected".to_string(),
299 href: Some(href),
300 });
301 doc.detach(include_node);
302 return;
303 }
304
305 let content = resolver(base_href);
307
308 match content {
309 Some(content) => {
310 state.active_hrefs.insert(base_href.to_owned());
312
313 let success = match parse.as_str() {
314 "xml" => process_xml_include(doc, include_node, &content, resolver, state, depth),
315 "text" => process_text_include(doc, include_node, &content),
316 _ => false, };
318
319 state.active_hrefs.remove(base_href);
321
322 if success {
323 state.inclusions += 1;
324 }
325 }
326 None => {
327 if !try_fallback(doc, include_node, resolver, state, depth) {
329 state.errors.push(XIncludeError {
330 message: "resource not found and no xi:fallback provided".to_string(),
331 href: Some(href),
332 });
333 doc.detach(include_node);
334 }
335 }
336 }
337}
338
339fn process_xml_include<F>(
344 doc: &mut Document,
345 include_node: NodeId,
346 content: &str,
347 resolver: &F,
348 state: &mut ProcessingState,
349 depth: usize,
350) -> bool
351where
352 F: Fn(&str) -> Option<String>,
353{
354 let included_doc = match Document::parse_str(content) {
356 Ok(d) => d,
357 Err(e) => {
358 if try_fallback(doc, include_node, resolver, state, depth) {
360 return false;
361 }
362 state.errors.push(XIncludeError {
363 message: format!("failed to parse included XML: {e}"),
364 href: None,
365 });
366 doc.detach(include_node);
367 return false;
368 }
369 };
370
371 let included_root = included_doc.root();
374 let included_children: Vec<NodeId> = included_doc.children(included_root).collect();
375
376 let parent = doc.parent(include_node);
378
379 let mut inserted_nodes = Vec::new();
382 for inc_child in &included_children {
383 let new_node = deep_copy_node(doc, &included_doc, *inc_child);
384 inserted_nodes.push(new_node);
385 }
386
387 for new_node in &inserted_nodes {
389 doc.insert_before(include_node, *new_node);
390 }
391
392 doc.detach(include_node);
394
395 if parent.is_some() {
397 for new_node in inserted_nodes {
399 process_node(doc, new_node, resolver, state, depth + 1);
400 }
401 }
402
403 true
404}
405
406fn process_text_include(doc: &mut Document, include_node: NodeId, content: &str) -> bool {
411 let text_node = doc.create_node(NodeKind::Text {
412 content: content.to_string(),
413 });
414
415 doc.insert_before(include_node, text_node);
416 doc.detach(include_node);
417
418 true
419}
420
421fn try_fallback<F>(
426 doc: &mut Document,
427 include_node: NodeId,
428 resolver: &F,
429 state: &mut ProcessingState,
430 depth: usize,
431) -> bool
432where
433 F: Fn(&str) -> Option<String>,
434{
435 let fallback_node = {
437 let children: Vec<NodeId> = doc.children(include_node).collect();
438 children
439 .into_iter()
440 .find(|&child| is_fallback_element(doc, child))
441 };
442
443 let Some(fallback) = fallback_node else {
444 return false;
445 };
446
447 let fallback_children: Vec<NodeId> = doc.children(fallback).collect();
449
450 let mut inserted_nodes = Vec::new();
452 for child in fallback_children {
453 doc.detach(child);
454 doc.insert_before(include_node, child);
455 inserted_nodes.push(child);
456 }
457
458 doc.detach(include_node);
460
461 for node in inserted_nodes {
463 process_node(doc, node, resolver, state, depth + 1);
464 }
465
466 true
467}
468
469fn deep_copy_node(target: &mut Document, source: &Document, source_id: NodeId) -> NodeId {
475 let source_node = source.node(source_id);
476 let new_id = target.create_node(source_node.kind.clone());
477
478 let children: Vec<NodeId> = source.children(source_id).collect();
480 for child_id in children {
481 let new_child = deep_copy_node(target, source, child_id);
482 target.append_child(new_id, new_child);
483 }
484
485 new_id
486}
487
/// Splits `href` at its first `#` into the part before it and the fragment
/// after it.
///
/// The fragment is `None` when `href` contains no `#`, and `Some("")` when
/// the `#` is the last character.
fn split_fragment(href: &str) -> (&str, Option<&str>) {
    match href.split_once('#') {
        Some((base, fragment)) => (base, Some(fragment)),
        None => (href, None),
    }
}
501
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parses `xml`, runs XInclude processing with default options and
    /// returns the processed document together with the result summary.
    fn process_with_resolver<F>(xml: &str, resolver: F) -> (Document, XIncludeResult)
    where
        F: Fn(&str) -> Option<String>,
    {
        let mut parsed = Document::parse_str(xml).unwrap();
        let outcome = process_xincludes(&mut parsed, resolver, &XIncludeOptions::default());
        (parsed, outcome)
    }

    /// Text content of the document's root element.
    fn doc_text_content(doc: &Document) -> String {
        doc.text_content(doc.root_element().unwrap())
    }

    /// Resolver backend: looks `href` up in a fixed `(href, content)` table.
    fn lookup(table: &[(&str, &str)], href: &str) -> Option<String> {
        table
            .iter()
            .find(|(key, _)| *key == href)
            .map(|(_, body)| (*body).to_string())
    }

    /// Names of the direct children of `parent`, in document order.
    fn child_names(doc: &Document, parent: NodeId) -> Vec<Option<&str>> {
        doc.children(parent).map(|c| doc.node_name(c)).collect()
    }

    #[test]
    fn test_basic_xml_include() {
        let xml =
            r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="inc.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| {
            lookup(&[("inc.xml", "<greeting>hello</greeting>")], href)
        });

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        assert_eq!(child_names(&doc, root), vec![Some("greeting")]);

        let greeting = doc.children(root).next().unwrap();
        assert_eq!(doc.text_content(greeting), "hello");
    }

    #[test]
    fn test_basic_text_include() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="msg.txt" parse="text"/></doc>"#;
        let (doc, result) =
            process_with_resolver(xml, |href| lookup(&[("msg.txt", "Hello, World!")], href));

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "Hello, World!");
    }

    #[test]
    fn test_fallback_when_resource_not_found() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="missing.xml"><xi:fallback><alt>fallback content</alt></xi:fallback></xi:include></doc>"#;
        let (doc, result) = process_with_resolver(xml, |_| None);

        // A fallback is not counted as an inclusion, nor as an error.
        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        assert_eq!(child_names(&doc, root), vec![Some("alt")]);

        let alt = doc.children(root).next().unwrap();
        assert_eq!(doc.text_content(alt), "fallback content");
    }

    #[test]
    fn test_fallback_with_text_content() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="missing.xml"><xi:fallback>plain fallback</xi:fallback></xi:include></doc>"#;
        let (doc, result) = process_with_resolver(xml, |_| None);

        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "plain fallback");
    }

    #[test]
    fn test_missing_href_attribute() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include/></doc>"#;
        let (_doc, result) = process_with_resolver(xml, |_| None);

        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);

        let only_error = &result.errors[0];
        assert!(only_error.message.contains("missing required 'href'"));
        assert!(only_error.href.is_none());
    }

    #[test]
    fn test_circular_inclusion_detection() {
        let xml =
            r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/></doc>"#;
        let self_referencing =
            r#"<a xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/></a>"#;
        let (_, result) =
            process_with_resolver(xml, |href| lookup(&[("a.xml", self_referencing)], href));

        // The outer inclusion succeeds; the nested self-reference is rejected.
        assert_eq!(result.inclusions, 1);
        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("circular inclusion"));
    }

    #[test]
    fn test_max_depth_exceeded() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deep.xml"/></doc>"#;
        let resources = [
            (
                "deep.xml",
                r#"<level xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deeper.xml"/></level>"#,
            ),
            (
                "deeper.xml",
                r#"<level xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="deepest.xml"/></level>"#,
            ),
            ("deepest.xml", "<leaf/>"),
        ];

        let mut doc = Document::parse_str(xml).unwrap();
        let opts = XIncludeOptions { max_depth: 2 };
        let result = process_xincludes(&mut doc, |href| lookup(&resources, href), &opts);

        assert!(result.errors.iter().any(|e| e.message.contains("depth")));
    }

    #[test]
    fn test_multiple_includes_in_same_document() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="a.xml"/><xi:include href="b.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| {
            lookup(&[("a.xml", "<first/>"), ("b.xml", "<second/>")], href)
        });

        assert_eq!(result.inclusions, 2);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        assert_eq!(
            child_names(&doc, root),
            vec![Some("first"), Some("second")]
        );
    }

    #[test]
    fn test_nested_includes() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="outer.xml"/></doc>"#;
        let resources = [
            (
                "outer.xml",
                r#"<outer xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="inner.xml"/></outer>"#,
            ),
            ("inner.xml", "<inner>nested</inner>"),
        ];
        let (doc, result) = process_with_resolver(xml, |href| lookup(&resources, href));

        assert_eq!(result.inclusions, 2);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        let outer_level: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(outer_level[0]), Some("outer"));

        let inner_level: Vec<NodeId> = doc.children(outer_level[0]).collect();
        assert_eq!(doc.node_name(inner_level[0]), Some("inner"));
        assert_eq!(doc.text_content(inner_level[0]), "nested");
    }

    #[test]
    fn test_default_parse_attribute_is_xml() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="data.xml"/></doc>"#;
        let (doc, result) =
            process_with_resolver(xml, |href| lookup(&[("data.xml", "<item>value</item>")], href));

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        let first_child: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(first_child[0]), Some("item"));
    }

    #[test]
    fn test_include_replaces_entire_xi_include_element() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><before/><xi:include href="mid.xml"/><after/></doc>"#;
        let (doc, result) =
            process_with_resolver(xml, |href| lookup(&[("mid.xml", "<middle/>")], href));

        assert_eq!(result.inclusions, 1);

        let root = doc.root_element().unwrap();
        assert_eq!(
            child_names(&doc, root),
            vec![Some("before"), Some("middle"), Some("after")]
        );
    }

    #[test]
    fn test_text_include_preserves_whitespace() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="ws.txt" parse="text"/></doc>"#;
        let content = " line1\n line2\n";
        let (doc, result) =
            process_with_resolver(xml, |href| lookup(&[("ws.txt", content)], href));

        assert_eq!(result.inclusions, 1);
        assert_eq!(doc_text_content(&doc), content);
    }

    #[test]
    fn test_empty_include_content() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="empty.txt" parse="text"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| lookup(&[("empty.txt", "")], href));

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());
        assert_eq!(doc_text_content(&doc), "");
    }

    #[test]
    fn test_include_with_fragment_identifier() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="data.xml#section1"/></doc>"#;
        // The resolver is expected to see the href without its fragment.
        let (doc, result) = process_with_resolver(xml, |href| {
            lookup(&[("data.xml", "<section>content</section>")], href)
        });

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        let first_child: Vec<NodeId> = doc.children(root).collect();
        assert_eq!(doc.node_name(first_child[0]), Some("section"));
    }

    #[test]
    fn test_xinclude_namespace_detection() {
        // `include` without the XInclude namespace must be left alone.
        let xml = r#"<doc><include href="should-ignore.xml"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| {
            panic!("resolver should not be called for non-XInclude elements");
        });

        assert_eq!(result.inclusions, 0);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_split_fragment() {
        let cases = [
            ("file.xml#sec", ("file.xml", Some("sec"))),
            ("file.xml", ("file.xml", None)),
            ("file.xml#", ("file.xml", Some(""))),
            ("#frag", ("", Some("frag"))),
        ];
        for (input, expected) in cases {
            assert_eq!(split_fragment(input), expected);
        }
    }

    #[test]
    fn test_no_fallback_records_error() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="nope.xml"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| None);

        assert_eq!(result.inclusions, 0);
        assert_eq!(result.errors.len(), 1);

        let only_error = &result.errors[0];
        assert!(only_error.message.contains("resource not found"));
        assert_eq!(only_error.href.as_deref(), Some("nope.xml"));
    }

    #[test]
    fn test_invalid_parse_attribute() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="x.xml" parse="json"/></doc>"#;
        let (_, result) = process_with_resolver(xml, |_| None);

        assert_eq!(result.errors.len(), 1);
        assert!(result.errors[0].message.contains("invalid parse attribute"));
    }

    #[test]
    fn test_xml_include_with_wrapper_element() {
        let xml = r#"<doc xmlns:xi="http://www.w3.org/2001/XInclude"><xi:include href="multi.xml"/></doc>"#;
        let (doc, result) = process_with_resolver(xml, |href| {
            lookup(
                &[("multi.xml", "<wrapper><first/><second/></wrapper>")],
                href,
            )
        });

        assert_eq!(result.inclusions, 1);
        assert!(result.errors.is_empty());

        let root = doc.root_element().unwrap();
        assert_eq!(child_names(&doc, root), vec![Some("wrapper")]);

        let wrapper = doc.children(root).next().unwrap();
        assert_eq!(
            child_names(&doc, wrapper),
            vec![Some("first"), Some("second")]
        );
    }

    #[test]
    fn test_options_default() {
        assert_eq!(XIncludeOptions::default().max_depth, 50);
    }

    #[test]
    fn test_error_display() {
        let with_href = XIncludeError {
            message: "resource not found".to_string(),
            href: Some("file.xml".to_string()),
        };
        assert_eq!(
            with_href.to_string(),
            "XInclude error for 'file.xml': resource not found"
        );

        let without_href = XIncludeError {
            message: "bad element".to_string(),
            href: None,
        };
        assert_eq!(without_href.to_string(), "XInclude error: bad element");
    }
}