1use std::collections::hash_map::Entry;
15use std::collections::{HashMap, HashSet};
16
17use roxmltree::{Document, Node, NodeId};
18
19use super::types::{NodeSet, TransformData, TransformError};
20
/// Attribute names that are always indexed as ID attributes when a resolver
/// is built, regardless of any extra names the caller supplies.
const DEFAULT_ID_ATTRS: &[&str] = &["ID", "Id", "id"];
28
/// Resolves same-document URI references (`""`, `#id`, `#xpointer(...)`)
/// against a parsed XML document.
///
/// The resolver borrows the document and keeps an index from ID attribute
/// values to the element carrying them. The index is built once, by
/// [`UriReferenceResolver::with_id_attrs`].
pub struct UriReferenceResolver<'a> {
    /// Document that references are resolved against.
    doc: &'a Document<'a>,
    /// ID value -> element carrying it. A value found on more than one
    /// element is dropped during construction and therefore never present.
    id_map: HashMap<&'a str, Node<'a, 'a>>,
}
55
56impl<'a> UriReferenceResolver<'a> {
57 pub fn new(doc: &'a Document<'a>) -> Self {
59 Self::with_id_attrs(doc, DEFAULT_ID_ATTRS)
60 }
61
62 pub fn with_id_attrs(doc: &'a Document<'a>, extra_attrs: &[&str]) -> Self {
73 let mut id_map = HashMap::new();
74 let mut duplicate_ids: HashSet<&'a str> = HashSet::new();
77
78 let mut attr_names: Vec<&str> = DEFAULT_ID_ATTRS.to_vec();
80 for name in extra_attrs {
81 if !attr_names.contains(name) {
82 attr_names.push(name);
83 }
84 }
85
86 for node in doc.descendants() {
88 if node.is_element() {
89 for attr_name in &attr_names {
90 if let Some(value) = node.attribute(*attr_name) {
91 if duplicate_ids.contains(value) {
93 continue;
94 }
95
96 match id_map.entry(value) {
101 Entry::Vacant(v) => {
102 v.insert(node);
103 }
104 Entry::Occupied(o) => {
105 if o.get().id() != node.id() {
110 o.remove();
111 duplicate_ids.insert(value);
112 }
113 }
114 }
115 }
116 }
117 }
118 }
119
120 Self { doc, id_map }
121 }
122
123 pub fn dereference(&self, uri: &str) -> Result<TransformData<'a>, TransformError> {
135 if uri.is_empty() {
136 Ok(TransformData::NodeSet(
140 NodeSet::entire_document_without_comments(self.doc),
141 ))
142 } else if let Some(fragment) = uri.strip_prefix('#') {
143 self.dereference_fragment(fragment)
148 } else {
149 Err(TransformError::UnsupportedUri(uri.to_string()))
150 }
151 }
152
153 fn dereference_fragment(&self, fragment: &str) -> Result<TransformData<'a>, TransformError> {
160 if fragment.is_empty() {
161 return Err(TransformError::UnsupportedUri("#".to_string()));
163 }
164
165 if fragment == "xpointer(/)" {
166 Ok(TransformData::NodeSet(
170 NodeSet::entire_document_with_comments(self.doc),
171 ))
172 } else if let Some(id) = parse_xpointer_id_fragment(fragment) {
173 if id.is_empty() {
176 return Err(TransformError::UnsupportedUri(format!("#{fragment}")));
177 }
178 self.resolve_id(id)
179 } else if fragment.starts_with("xpointer(") {
180 Err(TransformError::UnsupportedUri(format!("#{fragment}")))
182 } else {
183 self.resolve_id(fragment)
185 }
186 }
187
188 fn resolve_id(&self, id: &str) -> Result<TransformData<'a>, TransformError> {
190 match self.id_map.get(id) {
191 Some(&element) => Ok(TransformData::NodeSet(NodeSet::subtree(element))),
192 None => Err(TransformError::ElementNotFound(id.to_string())),
193 }
194 }
195
196 pub fn has_id(&self, id: &str) -> bool {
198 self.id_map.contains_key(id)
199 }
200
201 pub(crate) fn node_id_for_id(&self, id: &str) -> Option<NodeId> {
206 self.id_map.get(id).map(|node| node.id())
207 }
208
209 pub fn id_count(&self) -> usize {
211 self.id_map.len()
212 }
213}
214
/// Extracts the quoted ID from an `xpointer(id('...'))` or
/// `xpointer(id("..."))` fragment.
///
/// `fragment` is the URI fragment without the leading `#`. The return value is
/// the text between the quotes (which may be empty), or `None` when the
/// fragment does not have exactly this shape: wrong prefix or suffix, an
/// unquoted argument, mismatched quotes, or an argument that contains its own
/// delimiter.
///
/// The last rule exists because an XPath string literal cannot contain the
/// quote character that delimits it. Without the check, an expression such as
/// `xpointer(id('a')/foo('b'))` would be misread as a lookup of the single ID
/// `a')/foo('b` instead of being rejected as something other than a plain
/// `id()` call.
pub(crate) fn parse_xpointer_id_fragment(fragment: &str) -> Option<&str> {
    let inner = fragment.strip_prefix("xpointer(id(")?.strip_suffix("))")?;

    // Strip one matching pair of `quote` and refuse a body that contains it again.
    let unquote = |quote: char| {
        inner
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
            .filter(|id| !id.contains(quote))
    };

    unquote('\'').or_else(|| unquote('"'))
}
230
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::types::NodeSet;
    use super::*;

    /// Parses `xml`, panicking if it is not well-formed.
    fn parse(xml: &str) -> Document<'_> {
        Document::parse(xml).unwrap()
    }

    /// First node in document order whose attribute `name` equals `value`.
    fn by_attr<'a>(doc: &'a Document<'a>, name: &str, value: &str) -> Node<'a, 'a> {
        doc.descendants()
            .find(|n| n.attribute(name) == Some(value))
            .unwrap()
    }

    /// First element in document order with tag name `tag`.
    fn by_tag<'a>(doc: &'a Document<'a>, tag: &str) -> Node<'a, 'a> {
        doc.descendants()
            .find(|n| n.is_element() && n.has_tag_name(tag))
            .unwrap()
    }

    /// Unwraps the payload of an `UnsupportedUri` error, panicking on anything else.
    fn unsupported_uri(result: Result<TransformData<'_>, TransformError>) -> String {
        match result {
            Err(TransformError::UnsupportedUri(uri)) => uri,
            other => panic!("expected UnsupportedUri, got: {other:?}"),
        }
    }

    /// Unwraps the payload of an `ElementNotFound` error, panicking on anything else.
    fn missing_id(result: Result<TransformData<'_>, TransformError>) -> String {
        match result {
            Err(TransformError::ElementNotFound(id)) => id,
            other => panic!("expected ElementNotFound, got: {other:?}"),
        }
    }

    #[test]
    fn empty_uri_returns_whole_document() {
        let doc = parse("<root><child>text</child></root>");
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver.dereference("").unwrap().into_node_set().unwrap();

        let root = doc.root_element();
        let child = root.first_child().unwrap();
        assert!(node_set.contains(root));
        assert!(node_set.contains(child));
    }

    #[test]
    fn empty_uri_excludes_comments() {
        let doc = parse("<root><!-- comment --><child/></root>");
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver.dereference("").unwrap().into_node_set().unwrap();

        for comment in doc.descendants().filter(|n| n.is_comment()) {
            assert!(
                !node_set.contains(comment),
                "comment should be excluded for empty URI"
            );
        }
        assert!(node_set.contains(doc.root_element()));
    }

    #[test]
    fn fragment_uri_resolves_by_id_attr() {
        let doc = parse(r#"<root><item ID="abc">content</item><item ID="def">other</item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver
            .dereference("#abc")
            .unwrap()
            .into_node_set()
            .unwrap();

        // The referenced element and its text child are in the set...
        let abc_elem = by_attr(&doc, "ID", "abc");
        assert!(node_set.contains(abc_elem));
        assert!(node_set.contains(abc_elem.first_child().unwrap()));

        // ...while its parent and sibling are not.
        assert!(!node_set.contains(doc.root_element()));
        assert!(!node_set.contains(by_attr(&doc, "ID", "def")));
    }

    #[test]
    fn fragment_uri_resolves_lowercase_id() {
        let doc = parse(r#"<root><item id="lower">text</item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver
            .dereference("#lower")
            .unwrap()
            .into_node_set()
            .unwrap();

        assert!(node_set.contains(by_attr(&doc, "id", "lower")));
    }

    #[test]
    fn fragment_uri_resolves_mixed_case_id() {
        let doc = parse(
            r#"<root><ds:Signature Id="sig1" xmlns:ds="http://www.w3.org/2000/09/xmldsig#"/></root>"#,
        );
        let resolver = UriReferenceResolver::new(&doc);

        assert!(resolver.has_id("sig1"));
        let data = resolver.dereference("#sig1").unwrap();
        assert!(data.into_node_set().is_ok());
    }

    #[test]
    fn fragment_uri_not_found() {
        let doc = parse("<root><child>text</child></root>");
        let resolver = UriReferenceResolver::new(&doc);

        let id = missing_id(resolver.dereference("#nonexistent"));
        assert_eq!(id, "nonexistent");
    }

    #[test]
    fn unsupported_external_uri() {
        let doc = parse("<root/>");
        let resolver = UriReferenceResolver::new(&doc);

        let uri = unsupported_uri(resolver.dereference("http://example.com/doc.xml"));
        assert_eq!(uri, "http://example.com/doc.xml");
    }

    #[test]
    fn unsupported_xpointer_expression() {
        let doc = parse("<root/>");
        let resolver = UriReferenceResolver::new(&doc);

        let uri = unsupported_uri(resolver.dereference("#xpointer(foo())"));
        assert_eq!(uri, "#xpointer(foo())");

        // Only the error kind is checked for the second expression.
        unsupported_uri(resolver.dereference("#xpointer(//element)"));
    }

    #[test]
    fn empty_fragment_rejected() {
        let doc = parse("<root/>");
        let resolver = UriReferenceResolver::new(&doc);

        let uri = unsupported_uri(resolver.dereference("#"));
        assert_eq!(uri, "#");
    }

    #[test]
    fn foreign_document_node_rejected() {
        let doc1 = parse("<root><child/></root>");
        let doc2 = parse("<other><item/></other>");

        let node_set = NodeSet::entire_document_without_comments(&doc1);

        assert!(
            !node_set.contains(doc2.root_element()),
            "foreign document node should be rejected"
        );
        assert!(node_set.contains(doc1.root_element()));
    }

    #[test]
    fn custom_id_attr_name() {
        let doc = parse(r#"<root><elem myid="custom1">data</elem></root>"#);

        // Not an ID attribute by default...
        let resolver_default = UriReferenceResolver::new(&doc);
        assert!(!resolver_default.has_id("custom1"));

        // ...but indexed once the caller opts in.
        let resolver_custom = UriReferenceResolver::with_id_attrs(&doc, &["myid"]);
        assert!(resolver_custom.has_id("custom1"));

        let data = resolver_custom.dereference("#custom1").unwrap();
        assert!(data.into_node_set().is_ok());
    }

    #[test]
    fn namespaced_id_attr_found_by_local_name() {
        let doc = parse(
            r#"<root><elem wsu:Id="ts1" xmlns:wsu="http://example.com/wsu">data</elem></root>"#,
        );

        let resolver = UriReferenceResolver::new(&doc);
        assert!(resolver.has_id("ts1"));
    }

    #[test]
    fn id_count_reports_unique_ids() {
        let doc = parse(r#"<root ID="r1"><a ID="a1"/><b Id="b1"/><c id="c1"/></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        assert_eq!(resolver.id_count(), 4);
    }

    #[test]
    fn duplicate_ids_are_rejected() {
        let doc = parse(r#"<root><a ID="dup">first</a><b ID="dup">second</b></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        assert!(!resolver.has_id("dup"));
        missing_id(resolver.dereference("#dup"));
    }

    #[test]
    fn triple_duplicate_ids_stay_rejected() {
        let doc = parse(r#"<root><a ID="dup">1</a><b ID="dup">2</b><c ID="dup">3</c></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        assert!(!resolver.has_id("dup"));
        assert!(resolver.dereference("#dup").is_err());
    }

    #[test]
    fn node_set_exclude_subtree() {
        let doc = parse(r#"<root><keep>yes</keep><remove><deep>no</deep></remove></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let mut node_set = resolver.dereference("").unwrap().into_node_set().unwrap();

        let remove_elem = by_tag(&doc, "remove");
        node_set.exclude_subtree(remove_elem);

        assert!(node_set.contains(by_tag(&doc, "keep")));
        assert!(!node_set.contains(remove_elem));
        assert!(!node_set.contains(by_tag(&doc, "deep")));
    }

    #[test]
    fn subtree_includes_comments() {
        let doc = parse(r#"<root><item ID="x"><!-- comment --><child/></item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver.dereference("#x").unwrap().into_node_set().unwrap();

        for comment in doc.descendants().filter(|n| n.is_comment()) {
            assert!(
                node_set.contains(comment),
                "comment should be included in #id subtree"
            );
        }
    }

    #[test]
    fn xpointer_root_returns_whole_document_with_comments() {
        let doc = parse("<root><!-- comment --><child/></root>");
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver
            .dereference("#xpointer(/)")
            .unwrap()
            .into_node_set()
            .unwrap();

        for comment in doc.descendants().filter(|n| n.is_comment()) {
            assert!(
                node_set.contains(comment),
                "comment should be included for #xpointer(/)"
            );
        }
        assert!(node_set.contains(doc.root_element()));
    }

    #[test]
    fn xpointer_id_single_quotes() {
        let doc = parse(r#"<root><item ID="abc">content</item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver
            .dereference("#xpointer(id('abc'))")
            .unwrap()
            .into_node_set()
            .unwrap();

        assert!(node_set.contains(by_attr(&doc, "ID", "abc")));
    }

    #[test]
    fn xpointer_id_double_quotes() {
        let doc = parse(r#"<root><item ID="xyz">content</item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        let node_set = resolver
            .dereference(r#"#xpointer(id("xyz"))"#)
            .unwrap()
            .into_node_set()
            .unwrap();

        assert!(node_set.contains(by_attr(&doc, "ID", "xyz")));
    }

    #[test]
    fn xpointer_id_not_found() {
        let doc = parse("<root/>");
        let resolver = UriReferenceResolver::new(&doc);

        let id = missing_id(resolver.dereference("#xpointer(id('missing'))"));
        assert_eq!(id, "missing");
    }

    #[test]
    fn xpointer_id_empty_value_rejected() {
        let doc = parse("<root/>");
        let resolver = UriReferenceResolver::new(&doc);

        unsupported_uri(resolver.dereference("#xpointer(id(''))"));
    }

    #[test]
    fn parse_xpointer_id_variants() {
        let cases: [(&str, Option<&str>); 8] = [
            // Well-formed, either quote style.
            ("xpointer(id('foo'))", Some("foo")),
            (r#"xpointer(id("bar"))"#, Some("bar")),
            // Not an id() call, unquoted argument, or not an xpointer at all.
            ("xpointer(/)", None),
            ("xpointer(id(foo))", None),
            ("not-xpointer", None),
            ("", None),
            // A lone quote is not a quoted string.
            ("xpointer(id('))", None),
            (r#"xpointer(id("))"#, None),
        ];

        for (fragment, expected) in cases {
            assert_eq!(super::parse_xpointer_id_fragment(fragment), expected);
        }
    }

    #[test]
    fn same_element_multiple_id_attrs_not_duplicate() {
        let doc = parse(r#"<root><item ID="x" Id="x">data</item></root>"#);
        let resolver = UriReferenceResolver::new(&doc);

        assert!(resolver.has_id("x"));
        assert!(resolver.dereference("#x").is_ok());
    }

    #[test]
    fn saml_style_document() {
        let xml = r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
                                     xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
                                     ID="_resp1">
            <saml:Assertion ID="_assert1">
                <saml:Subject>user@example.com</saml:Subject>
            </saml:Assertion>
            <ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#" Id="sig1">
                <ds:SignedInfo/>
            </ds:Signature>
        </samlp:Response>"#;

        let doc = parse(xml);
        let resolver = UriReferenceResolver::new(&doc);

        for id in ["_resp1", "_assert1", "sig1"] {
            assert!(resolver.has_id(id));
        }
        assert_eq!(resolver.id_count(), 3);

        let node_set = resolver
            .dereference("#_assert1")
            .unwrap()
            .into_node_set()
            .unwrap();

        let assertion = by_attr(&doc, "ID", "_assert1");
        assert!(node_set.contains(assertion));

        let subject = assertion
            .children()
            .find(|n| n.is_element() && n.has_tag_name("Subject"))
            .unwrap();
        assert!(node_set.contains(subject));

        assert!(!node_set.contains(doc.root_element()));
    }
}