xml_3dm/
measure.rs

//! Node similarity measurement.
//!
//! This module computes distances between nodes using the Q-gram distance
//! algorithm (Ukkonen, 1992): the content distance, the child list distance,
//! and the matched child list distance between two nodes.
6
7use std::collections::HashMap;
8
9use crate::node::{NodeRef, XmlContent};
10
/// Maximum distance. Distances are normalized to the range `0.0..=MAX_DIST`.
pub const MAX_DIST: f64 = 1.0;

/// Distance returned by [`Measure::child_list_distance`] when both nodes have
/// no children.
pub const ZERO_CHILDREN_MATCH: f64 = 1.0;

/// Info bytes in an element name ($c_e$ in thesis).
pub const ELEMENT_NAME_INFO: i32 = 1;

/// Info bytes in the presence of an attribute ($c_a$ in thesis).
///
/// Charged (as both total and mismatch) when an attribute exists on only one
/// of the two compared elements.
pub const ATTR_INFO: i32 = 2;

/// Attribute values whose length does not exceed this have an info size of 1;
/// longer values count their full length ($c_v$ in thesis).
pub const ATTR_VALUE_THRESHOLD: usize = 5;

/// Text nodes shorter than this have an info size of 1 ($c_t$ in thesis).
pub const TEXT_THRESHOLD: usize = 5;

/// Penalty term ($c_p$ in thesis).
///
/// Comparisons involving fewer than this many info bytes receive a penalty
/// that pushes their distance towards `MAX_DIST`; the penalty shrinks linearly
/// as the amount of compared information grows and vanishes at this value.
const PENALTY_C: i32 = 20;

/// Q-gram size. Note: in the Java implementation `decideQ()` never actually
/// updates this value, so it is effectively always 4; this port matches that
/// behavior.
const Q: usize = 4;
35
/// Node similarity measurement calculator.
///
/// Calculates the content, child list and matched child list distance between
/// nodes using the Q-gram distance algorithm. One instance can be reused for
/// many comparisons: the counters are reset at the end of every
/// [`Measure::get_distance`] call, and the Q-gram tables are cleared before
/// every string comparison.
#[derive(Debug, Clone)]
pub struct Measure {
    /// Mismatched info bytes accumulated for the current comparison.
    mismatched: i32,
    /// Total info bytes accumulated for the current comparison.
    total: i32,
    /// Set when a total mismatch occurs (e.g. an element compared with a text node).
    total_mismatch: bool,
    /// Q-gram occurrence counts of the first operand of the latest string comparison.
    a_grams: HashMap<String, i32>,
    /// Q-gram occurrence counts of the second operand of the latest string comparison.
    b_grams: HashMap<String, i32>,
}
51
52impl Default for Measure {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl Measure {
59    /// Creates a new Measure instance.
60    pub fn new() -> Self {
61        Measure {
62            mismatched: 0,
63            total: 0,
64            total_mismatch: false,
65            a_grams: HashMap::with_capacity(2048),
66            b_grams: HashMap::with_capacity(2048),
67        }
68    }
69
70    /// Returns content distance between nodes.
71    ///
72    /// The distance is normalized between 0.0 and 1.0 (MAX_DIST).
73    /// Returns 1.0 if node types are incompatible (element vs text).
74    pub fn get_distance(&mut self, a: Option<&NodeRef>, b: Option<&NodeRef>) -> f64 {
75        if let (Some(a), Some(b)) = (a, b) {
76            self.include_nodes(a, b);
77        }
78
79        let result = if self.total_mismatch {
80            MAX_DIST
81        } else if self.total == 0 {
82            0.0
83        } else {
84            let penalty = (1.0 - (self.total as f64) / (PENALTY_C as f64)).max(0.0);
85            penalty + (1.0 - penalty) * (self.mismatched as f64) / (self.total as f64)
86        };
87
88        self.reset_distance();
89        result
90    }
91
92    /// Resets the distance calculation state.
93    fn reset_distance(&mut self) {
94        self.mismatched = 0;
95        self.total = 0;
96        self.total_mismatch = false;
97    }
98
99    /// Adds a node pair to the distance calculation state.
100    fn include_nodes(&mut self, a: &NodeRef, b: &NodeRef) {
101        if self.total_mismatch {
102            return;
103        }
104
105        let a_borrowed = a.borrow();
106        let b_borrowed = b.borrow();
107
108        let ca = a_borrowed.content();
109        let cb = b_borrowed.content();
110
111        match (ca, cb) {
112            (Some(XmlContent::Element(ea)), Some(XmlContent::Element(eb))) => {
113                // Compare element names
114                self.total += ELEMENT_NAME_INFO;
115                if ea.qname() != eb.qname() {
116                    self.mismatched += ELEMENT_NAME_INFO;
117                }
118
119                // Compare attributes present in a
120                let attrs_a = ea.attributes();
121                let attrs_b = eb.attributes();
122
123                for (name, v1) in attrs_a.iter() {
124                    if let Some(v2) = attrs_b.get(name) {
125                        // Attribute exists in both - compare values
126                        let amismatch = self.string_dist_str(v1, v2);
127                        let info = if v1.len() > ATTR_VALUE_THRESHOLD {
128                            v1.len() as i32
129                        } else {
130                            1
131                        } + if v2.len() > ATTR_VALUE_THRESHOLD {
132                            v2.len() as i32
133                        } else {
134                            1
135                        };
136                        self.mismatched += if amismatch > info { info } else { amismatch };
137                        self.total += info;
138                    } else {
139                        // Attribute only in a (deleted from b)
140                        self.mismatched += ATTR_INFO;
141                        self.total += ATTR_INFO;
142                    }
143                }
144
145                // Scan for attributes present in b but not in a
146                for name in attrs_b.keys() {
147                    if !attrs_a.contains_key(name) {
148                        self.mismatched += ATTR_INFO;
149                        self.total += ATTR_INFO;
150                    }
151                }
152            }
153            (Some(XmlContent::Text(ta)), Some(XmlContent::Text(tb))) => {
154                // Compare text content
155                let info = (ta.info_size() + tb.info_size()) / 2;
156                let amismatch = self.string_dist_chars(ta.text(), tb.text()) / 2;
157
158                self.mismatched += if amismatch > info { info } else { amismatch };
159                self.total += info;
160            }
161            _ => {
162                // Incompatible types (element vs text, or missing content)
163                self.total_mismatch = true;
164            }
165        }
166    }
167
168    /// Calculates string distance using Q-gram algorithm.
169    pub fn string_dist_str(&mut self, a: &str, b: &str) -> i32 {
170        self.q_dist_str(a, b)
171    }
172
173    /// Calculates string distance for char arrays using Q-gram algorithm.
174    pub fn string_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
175        self.q_dist_chars(a, b)
176    }
177
178    /// Calculates child list distance between nodes.
179    ///
180    /// This compares children by their content hashes to measure structural similarity.
181    /// Returns ZERO_CHILDREN_MATCH (1.0) if both nodes have no children.
182    pub fn child_list_distance(&mut self, a: &NodeRef, b: &NodeRef) -> f64 {
183        let a_borrowed = a.borrow();
184        let b_borrowed = b.borrow();
185
186        let a_count = a_borrowed.child_count();
187        let b_count = b_borrowed.child_count();
188
189        if a_count == 0 && b_count == 0 {
190            return ZERO_CHILDREN_MATCH;
191        }
192
193        // Build char arrays from content hashes (truncated to 16 bits)
194        let ac: Vec<char> = a_borrowed
195            .children()
196            .iter()
197            .map(|child| {
198                let child_borrowed = child.borrow();
199                if let Some(content) = child_borrowed.content() {
200                    char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
201                } else {
202                    '\0'
203                }
204            })
205            .collect();
206
207        let bc: Vec<char> = b_borrowed
208            .children()
209            .iter()
210            .map(|child| {
211                let child_borrowed = child.borrow();
212                if let Some(content) = child_borrowed.content() {
213                    char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
214                } else {
215                    '\0'
216                }
217            })
218            .collect();
219
220        let dist = self.string_dist_chars(&ac, &bc);
221        (dist as f64) / ((a_count + b_count) as f64)
222    }
223
224    /// Calculates matched child list distance.
225    ///
226    /// This measures how well children of a base node match children of a branch node,
227    /// based on their matching relationships (not content hashes).
228    ///
229    /// # Arguments
230    /// * `base` - The base node
231    /// * `branch` - The branch node
232    /// * `_is_left` - Whether the branch is the left branch (unused, for API compatibility)
233    pub fn matched_child_list_distance(
234        &mut self,
235        base: &NodeRef,
236        branch: &NodeRef,
237        _is_left: bool,
238    ) -> i32 {
239        let base_borrowed = base.borrow();
240        let branch_borrowed = branch.borrow();
241
242        // Build array for base children: position + 1 (to avoid 0 = -0 bug)
243        let ac: Vec<char> = (0..base_borrowed.child_count())
244            .map(|i| char::from_u32((i + 1) as u32).unwrap_or('\0'))
245            .collect();
246
247        // Build array for branch children based on their base match
248        let bc: Vec<char> = branch_borrowed
249            .children()
250            .iter()
251            .enumerate()
252            .map(|(i, child)| {
253                let child_borrowed = child.borrow();
254
255                // Get base match for this branch child
256                if let Some(base_match_weak) = child_borrowed.get_base_match() {
257                    if let Some(base_match) = base_match_weak.upgrade() {
258                        let base_match_borrowed = base_match.borrow();
259
260                        // Check if parent of base match is our base node
261                        if let Some(parent) = base_match_borrowed.parent().upgrade() {
262                            // Compare by ID
263                            if parent.borrow().id() == base_borrowed.id() {
264                                // Use child position (0-indexed, so add 1)
265                                return char::from_u32(
266                                    (base_match_borrowed.child_pos() + 1) as u32,
267                                )
268                                .unwrap_or('\0');
269                            }
270                        }
271                    }
272                }
273
274                // No match or different parent - use negative position
275                // Note: In Java, this uses -(i+1), but since we're using chars,
276                // we need to handle this differently. We'll use high values
277                // that won't collide with valid positions.
278                char::from_u32(0x10000 + i as u32).unwrap_or('\0')
279            })
280            .collect();
281
282        self.string_dist_chars(&ac, &bc)
283    }
284
285    /// Q-gram distance for strings (Ukkonen92 algorithm).
286    fn q_dist_str(&mut self, a: &str, b: &str) -> i32 {
287        // Note: Java's decideQ doesn't actually update Q, so it's always 4
288        // We match that behavior
289        self.a_grams.clear();
290        self.b_grams.clear();
291
292        // Build q-grams for string a
293        let chars_a: Vec<char> = a.chars().collect();
294        for i in 0..chars_a.len() {
295            let end = (i + Q).min(chars_a.len());
296            let gram: String = chars_a[i..end].iter().collect();
297            *self.a_grams.entry(gram).or_insert(0) += 1;
298        }
299
300        // Build q-grams for string b
301        let chars_b: Vec<char> = b.chars().collect();
302        for i in 0..chars_b.len() {
303            let end = (i + Q).min(chars_b.len());
304            let gram: String = chars_b[i..end].iter().collect();
305            *self.b_grams.entry(gram).or_insert(0) += 1;
306        }
307
308        self.calc_q_distance()
309    }
310
311    /// Q-gram distance for char arrays (Ukkonen92 algorithm).
312    fn q_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
313        self.a_grams.clear();
314        self.b_grams.clear();
315
316        // Build q-grams for array a
317        for i in 0..a.len() {
318            let end = (i + Q).min(a.len());
319            let gram: String = a[i..end].iter().collect();
320            *self.a_grams.entry(gram).or_insert(0) += 1;
321        }
322
323        // Build q-grams for array b
324        for i in 0..b.len() {
325            let end = (i + Q).min(b.len());
326            let gram: String = b[i..end].iter().collect();
327            *self.b_grams.entry(gram).or_insert(0) += 1;
328        }
329
330        self.calc_q_distance()
331    }
332
333    /// Builds Q-grams from a string into the provided map.
334    /// Used by tests to verify Q-gram construction.
335    #[cfg(test)]
336    fn build_q_grams_str(&self, s: &str, grams: &mut HashMap<String, i32>) {
337        grams.clear();
338        let chars: Vec<char> = s.chars().collect();
339        for i in 0..chars.len() {
340            let end = (i + Q).min(chars.len());
341            let gram: String = chars[i..end].iter().collect();
342            *grams.entry(gram).or_insert(0) += 1;
343        }
344    }
345
346    /// Calculates the Q-gram distance from the built gram tables.
347    fn calc_q_distance(&self) -> i32 {
348        let mut dist = 0;
349
350        // Loop over a_grams
351        for (gram, count_a) in &self.a_grams {
352            let count_b = self.b_grams.get(gram).copied().unwrap_or(0);
353            dist += (count_a - count_b).abs();
354        }
355
356        // Add grams present in b but not in a
357        for (gram, count_b) in &self.b_grams {
358            if !self.a_grams.contains_key(gram) {
359                dist += *count_b;
360            }
361        }
362
363        dist
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{new_base_node, XmlElement, XmlText};
    use std::collections::HashMap as StdHashMap;

    /// Builds an element node with the given name and attributes.
    fn element(name: &str, attrs: StdHashMap<String, String>) -> NodeRef {
        new_base_node(Some(XmlContent::Element(XmlElement::new(
            name.to_string(),
            attrs,
        ))))
    }

    /// Builds a text node.
    fn text(s: &str) -> NodeRef {
        new_base_node(Some(XmlContent::Text(XmlText::new(s))))
    }

    #[test]
    fn test_q_dist_identical_strings() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str("hello world", "hello world");
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_q_dist_different_strings() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str("hello", "world");
        // The five grams of each word ("hell", "ello", "llo", "lo", "o" vs
        // "worl", "orld", "rld", "ld", "d") are all distinct.
        assert_eq!(dist, 10);
    }

    #[test]
    fn test_q_dist_similar_strings() {
        let mut measure = Measure::new();
        // These differ by one character
        let dist = measure.string_dist_str(
            "return stringDist( a, b, a.length()+b.length() );",
            "return stzingDist( a, b, a.length()+b.length() );",
        );
        // One substitution touches Q = 4 grams in each string, all unique.
        assert_eq!(dist, 8);
    }

    #[test]
    fn test_q_dist_empty_strings() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str("", "");
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_q_dist_one_empty() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str("hello", "");
        // Every one of the five grams of "hello" is unmatched.
        assert_eq!(dist, 5);
    }

    #[test]
    fn test_q_dist_is_symmetric() {
        let mut measure = Measure::new();
        // "ab" -> {"ab", "b"} and "ba" -> {"ba", "a"} share no gram.
        assert_eq!(measure.string_dist_str("ab", "ba"), 4);
        assert_eq!(measure.string_dist_str("ba", "ab"), 4);
    }

    #[test]
    fn test_get_distance_same_elements() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = element("div", StdHashMap::new());

        let dist = measure.get_distance(Some(&a), Some(&b));
        // Only the element name is compared (total = 1), so the penalty
        // dominates: penalty = max(0, 1 - 1/20) = 0.95 and
        // distance = 0.95 + 0.05 * 0 = 0.95.
        assert!((dist - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_element_names() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = element("span", StdHashMap::new());

        let dist = measure.get_distance(Some(&a), Some(&b));
        // total = 1, mismatched = 1: 0.95 + 0.05 * 1 = 1.0.
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
        assert!((dist - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_get_distance_same_attributes() {
        let mut measure = Measure::new();

        let mut attrs = StdHashMap::new();
        attrs.insert("id".to_string(), "foo".to_string());
        attrs.insert("class".to_string(), "bar".to_string());

        let a = element("div", attrs.clone());
        let b = element("div", attrs);

        let dist = measure.get_distance(Some(&a), Some(&b));
        // With short content (element name + 2 attrs with short values):
        // total = 1 + 2*2 = 5 (name + 2 attrs with info=1+1 each)
        // penalty = max(0, 1 - 5/20) = 0.75
        // Since identical, mismatched = 0, so distance = penalty = 0.75
        assert!((dist - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_attribute_values() {
        let mut measure = Measure::new();

        let mut attrs_a = StdHashMap::new();
        attrs_a.insert("id".to_string(), "foo".to_string());

        let mut attrs_b = StdHashMap::new();
        attrs_b.insert("id".to_string(), "bar".to_string());

        let a = element("div", attrs_a);
        let b = element("div", attrs_b);

        let dist = measure.get_distance(Some(&a), Some(&b));
        // total = 1 + (1 + 1) = 3, mismatched = min(6, 2) = 2:
        // 0.85 + 0.15 * 2/3 = 0.95.
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
        assert!((dist - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_get_distance_missing_attribute() {
        let mut measure = Measure::new();

        let mut attrs_a = StdHashMap::new();
        attrs_a.insert("id".to_string(), "foo".to_string());

        let a = element("div", attrs_a);
        let b = element("div", StdHashMap::new());

        let dist = measure.get_distance(Some(&a), Some(&b));
        // total = 1 + ATTR_INFO = 3, mismatched = ATTR_INFO = 2: 0.95.
        assert!(dist > 0.0);
        assert!((dist - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_get_distance_same_text() {
        let mut measure = Measure::new();

        let a = text("hello world");
        let b = text("hello world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        // Identical texts never mismatch, so only the short-content penalty
        // remains; it is not a total mismatch.
        assert!(dist >= 0.0);
        assert!(dist < 1.0);

        // Longer identical texts carry more information and therefore a
        // smaller penalty.
        let mut measure2 = Measure::new();
        let long_text = "This is a much longer piece of text that should have lower penalty";
        let c = text(long_text);
        let d = text(long_text);
        let dist2 = measure2.get_distance(Some(&c), Some(&d));
        assert!(dist2 < dist);
    }

    #[test]
    fn test_get_distance_different_text() {
        let mut measure = Measure::new();

        let a = text("hello");
        let b = text("world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_element_vs_text() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = text("hello");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert_eq!(dist, 1.0); // Total mismatch
    }

    #[test]
    fn test_get_distance_none_arguments() {
        let mut measure = Measure::new();
        // Nothing is compared, so no information is measured.
        assert_eq!(measure.get_distance(None, None), 0.0);

        let a = element("div", StdHashMap::new());
        assert_eq!(measure.get_distance(Some(&a), None), 0.0);
    }

    #[test]
    fn test_get_distance_resets_between_calls() {
        let mut measure = Measure::new();

        let div = element("div", StdHashMap::new());
        let txt = text("hello");
        assert_eq!(measure.get_distance(Some(&div), Some(&txt)), 1.0);

        // The previous total mismatch must not leak into the next comparison.
        let dist = measure.get_distance(Some(&div), Some(&div));
        assert!((dist - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_child_list_distance_both_empty() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = element("div", StdHashMap::new());

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, ZERO_CHILDREN_MATCH);
    }

    #[test]
    fn test_child_list_distance_same_children() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = element("div", StdHashMap::new());

        // Add equal (but distinct) children to both
        let child1 = element("span", StdHashMap::new());
        let child2 = element("span", StdHashMap::new());

        crate::node::NodeInner::add_child_to_ref(&a, child1);
        crate::node::NodeInner::add_child_to_ref(&b, child2);

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_child_list_distance_different_children() {
        let mut measure = Measure::new();

        let a = element("div", StdHashMap::new());
        let b = element("div", StdHashMap::new());

        // Add different children
        let child1 = element("span", StdHashMap::new());
        let child2 = element("p", StdHashMap::new());

        crate::node::NodeInner::add_child_to_ref(&a, child1);
        crate::node::NodeInner::add_child_to_ref(&b, child2);

        let dist = measure.child_list_distance(&a, &b);
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_penalty_for_short_content() {
        let mut measure = Measure::new();

        // Very short text nodes should have penalty applied
        let a = text("a");
        let b = text("b");

        let dist = measure.get_distance(Some(&a), Some(&b));
        // With short content, penalty increases distance
        assert!(dist > 0.0);
    }

    #[test]
    fn test_q_grams_built_correctly() {
        let measure = Measure::new();
        let mut grams = StdHashMap::new();
        measure.build_q_grams_str("hello", &mut grams);

        // With Q=4, "hello" should produce grams: "hell", "ello", "llo", "lo", "o"
        assert!(grams.contains_key("hell"));
        assert!(grams.contains_key("ello"));
        assert!(grams.contains_key("llo"));
        assert!(grams.contains_key("lo"));
        assert!(grams.contains_key("o"));
        assert_eq!(grams.len(), 5);
        assert!(grams.values().all(|&count| count == 1));

        // Repeated grams are counted, and the map is cleared before reuse:
        // six 'a's give "aaaa" three times, then "aaa", "aa", "a".
        measure.build_q_grams_str("aaaaaa", &mut grams);
        assert_eq!(grams.get("aaaa"), Some(&3));
        assert_eq!(grams.get("aaa"), Some(&1));
        assert_eq!(grams.len(), 4);
    }
}