Skip to main content

xml_3dm/
measure.rs

1//! Node similarity measurement.
2//!
3//! This module provides distance calculations between nodes using the Q-gram
4//! distance algorithm (Ukkonen92). It calculates content distance, child list
5//! distance, and matched child list distance between nodes.
6
7use std::collections::HashMap;
8
9use crate::node::{NodeRef, XmlContent};
10
/// Upper bound of every normalized distance; results lie between 0 and `MAX_DIST`.
pub const MAX_DIST: f64 = 1.0;

/// Value `child_list_distance` reports when neither node has any children.
pub const ZERO_CHILDREN_MATCH: f64 = 1.0;

/// Info bytes contributed by an element name ($c_e$ in thesis).
pub const ELEMENT_NAME_INFO: i32 = 1;

/// Info bytes contributed by the presence of an attribute ($c_a$ in thesis).
pub const ATTR_INFO: i32 = 2;

/// Attribute values whose length does not exceed this count as a single info
/// byte; longer values count their full length ($c_v$ in thesis).
pub const ATTR_VALUE_THRESHOLD: usize = 5;

/// Text nodes shorter than this have an info size of 1 ($c_t$ in thesis).
pub const TEXT_THRESHOLD: usize = 5;

/// Penalty term ($c_p$ in thesis): node pairs with fewer total info bytes
/// than this receive a distance penalty in `get_distance`.
const PENALTY_C: i32 = 20;
31
/// Picks the Q-gram width for a pair of sequences from their combined length
/// (thesis equation 6.1).
///
/// Combined lengths below 5 use single symbols (q = 1), lengths below 50 use
/// pairs (q = 2), and anything longer uses 4-grams.
fn decide_q(combined_len: usize) -> usize {
    match combined_len {
        0..=4 => 1,
        5..=49 => 2,
        _ => 4,
    }
}
43
/// Node similarity measurement calculator.
///
/// Accumulates mismatched/total "info bytes" for the node pairs it is given
/// and turns them into content, child list and matched child list distances
/// using the Q-gram distance algorithm.
///
/// The two Q-gram tables are scratch space that is cleared and refilled on
/// every string comparison, so one instance can be reused for many
/// measurements without reallocating.
#[derive(Debug)]
pub struct Measure {
    /// Info bytes found to differ so far.
    mismatched: i32,
    /// Info bytes inspected so far.
    total: i32,
    /// Set once incompatible nodes (e.g. text vs element) have been compared.
    total_mismatch: bool,
    /// Q-gram occurrence counts for the first operand of the current comparison.
    a_grams: HashMap<String, i32>,
    /// Q-gram occurrence counts for the second operand of the current comparison.
    b_grams: HashMap<String, i32>,
}
59
60impl Default for Measure {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl Measure {
67    /// Creates a new Measure instance.
68    pub fn new() -> Self {
69        Measure {
70            mismatched: 0,
71            total: 0,
72            total_mismatch: false,
73            a_grams: HashMap::with_capacity(2048),
74            b_grams: HashMap::with_capacity(2048),
75        }
76    }
77
78    /// Returns content distance between nodes.
79    ///
80    /// The distance is normalized between 0.0 and 1.0 (MAX_DIST).
81    /// Returns 1.0 if node types are incompatible (element vs text).
82    pub fn get_distance(&mut self, a: Option<&NodeRef>, b: Option<&NodeRef>) -> f64 {
83        if let (Some(a), Some(b)) = (a, b) {
84            self.include_nodes(a, b);
85        }
86
87        let result = if self.total_mismatch {
88            MAX_DIST
89        } else if self.total == 0 {
90            0.0
91        } else {
92            let penalty = (1.0 - (self.total as f64) / (PENALTY_C as f64)).max(0.0);
93            penalty + (1.0 - penalty) * (self.mismatched as f64) / (self.total as f64)
94        };
95
96        self.reset_distance();
97        result
98    }
99
100    /// Resets the distance calculation state.
101    fn reset_distance(&mut self) {
102        self.mismatched = 0;
103        self.total = 0;
104        self.total_mismatch = false;
105    }
106
107    /// Adds a node pair to the distance calculation state.
108    fn include_nodes(&mut self, a: &NodeRef, b: &NodeRef) {
109        if self.total_mismatch {
110            return;
111        }
112
113        let a_borrowed = a.borrow();
114        let b_borrowed = b.borrow();
115
116        let ca = a_borrowed.content();
117        let cb = b_borrowed.content();
118
119        match (ca, cb) {
120            (Some(XmlContent::Element(ea)), Some(XmlContent::Element(eb))) => {
121                // Compare element names
122                self.total += ELEMENT_NAME_INFO;
123                if ea.qname() != eb.qname() {
124                    self.mismatched += ELEMENT_NAME_INFO;
125                }
126
127                // Compare attributes present in a
128                let attrs_a = ea.attributes();
129                let attrs_b = eb.attributes();
130
131                for (name, v1) in attrs_a.iter() {
132                    if let Some(v2) = attrs_b.get(name) {
133                        // Attribute exists in both - compare values
134                        let amismatch = self.string_dist_str(v1, v2);
135                        let info = if v1.len() > ATTR_VALUE_THRESHOLD {
136                            v1.len() as i32
137                        } else {
138                            1
139                        } + if v2.len() > ATTR_VALUE_THRESHOLD {
140                            v2.len() as i32
141                        } else {
142                            1
143                        };
144                        self.mismatched += if amismatch > info { info } else { amismatch };
145                        self.total += info;
146                    } else {
147                        // Attribute only in a (deleted from b)
148                        self.mismatched += ATTR_INFO;
149                        self.total += ATTR_INFO;
150                    }
151                }
152
153                // Scan for attributes present in b but not in a
154                for name in attrs_b.keys() {
155                    if !attrs_a.contains_key(name) {
156                        self.mismatched += ATTR_INFO;
157                        self.total += ATTR_INFO;
158                    }
159                }
160            }
161            (Some(XmlContent::Text(ta)), Some(XmlContent::Text(tb))) => {
162                // Compare text content
163                let info = (ta.info_size() + tb.info_size()) / 2;
164                let amismatch = self.string_dist_chars(ta.text(), tb.text()) / 2;
165
166                self.mismatched += if amismatch > info { info } else { amismatch };
167                self.total += info;
168            }
169            _ => {
170                // Incompatible types (element vs text, or missing content)
171                self.total_mismatch = true;
172            }
173        }
174    }
175
176    /// Calculates string distance using Q-gram algorithm.
177    pub fn string_dist_str(&mut self, a: &str, b: &str) -> i32 {
178        self.q_dist_str(a, b)
179    }
180
181    /// Calculates string distance for char arrays using Q-gram algorithm.
182    pub fn string_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
183        self.q_dist_chars(a, b)
184    }
185
186    /// Calculates child list distance between nodes.
187    ///
188    /// This compares children by their content hashes to measure structural similarity.
189    /// Returns ZERO_CHILDREN_MATCH (1.0) if both nodes have no children.
190    pub fn child_list_distance(&mut self, a: &NodeRef, b: &NodeRef) -> f64 {
191        let a_borrowed = a.borrow();
192        let b_borrowed = b.borrow();
193
194        let a_count = a_borrowed.child_count();
195        let b_count = b_borrowed.child_count();
196
197        if a_count == 0 && b_count == 0 {
198            return ZERO_CHILDREN_MATCH;
199        }
200
201        // Build char arrays from content hashes (truncated to 16 bits)
202        let ac: Vec<char> = a_borrowed
203            .children()
204            .iter()
205            .map(|child| {
206                let child_borrowed = child.borrow();
207                if let Some(content) = child_borrowed.content() {
208                    char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
209                } else {
210                    '\0'
211                }
212            })
213            .collect();
214
215        let bc: Vec<char> = b_borrowed
216            .children()
217            .iter()
218            .map(|child| {
219                let child_borrowed = child.borrow();
220                if let Some(content) = child_borrowed.content() {
221                    char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
222                } else {
223                    '\0'
224                }
225            })
226            .collect();
227
228        let dist = self.string_dist_chars(&ac, &bc);
229        (dist as f64) / ((a_count + b_count) as f64)
230    }
231
232    /// Calculates matched child list distance.
233    ///
234    /// This measures how well children of a base node match children of a branch node,
235    /// based on their matching relationships (not content hashes).
236    ///
237    /// # Arguments
238    /// * `base` - The base node
239    /// * `branch` - The branch node
240    /// * `_is_left` - Whether the branch is the left branch (unused, for API compatibility)
241    pub fn matched_child_list_distance(
242        &mut self,
243        base: &NodeRef,
244        branch: &NodeRef,
245        _is_left: bool,
246    ) -> i32 {
247        let base_borrowed = base.borrow();
248        let branch_borrowed = branch.borrow();
249
250        // Build array for base children: position + 1 (to avoid 0 = -0 bug)
251        let ac: Vec<char> = (0..base_borrowed.child_count())
252            .map(|i| char::from_u32((i + 1) as u32).unwrap_or('\0'))
253            .collect();
254
255        // Build array for branch children based on their base match
256        let bc: Vec<char> = branch_borrowed
257            .children()
258            .iter()
259            .enumerate()
260            .map(|(i, child)| {
261                let child_borrowed = child.borrow();
262
263                // Get base match for this branch child
264                if let Some(base_match_weak) = child_borrowed.get_base_match() {
265                    if let Some(base_match) = base_match_weak.upgrade() {
266                        let base_match_borrowed = base_match.borrow();
267
268                        // Check if parent of base match is our base node
269                        if let Some(parent) = base_match_borrowed.parent().upgrade() {
270                            // Compare by ID
271                            if parent.borrow().id() == base_borrowed.id() {
272                                // Use child position (0-indexed, so add 1)
273                                return char::from_u32(
274                                    (base_match_borrowed.child_pos() + 1) as u32,
275                                )
276                                .unwrap_or('\0');
277                            }
278                        }
279                    }
280                }
281
282                // No match or different parent - use negative position
283                // Note: In Java, this uses -(i+1), but since we're using chars,
284                // we need to handle this differently. We'll use high values
285                // that won't collide with valid positions.
286                char::from_u32(0x10000 + i as u32).unwrap_or('\0')
287            })
288            .collect();
289
290        self.string_dist_chars(&ac, &bc)
291    }
292
293    /// Q-gram distance for strings (Ukkonen92 algorithm).
294    fn q_dist_str(&mut self, a: &str, b: &str) -> i32 {
295        self.a_grams.clear();
296        self.b_grams.clear();
297
298        // Adaptive Q based on combined length
299        let q = decide_q(a.len() + b.len());
300
301        // Build q-grams for string a
302        let chars_a: Vec<char> = a.chars().collect();
303        for i in 0..chars_a.len() {
304            let end = (i + q).min(chars_a.len());
305            let gram: String = chars_a[i..end].iter().collect();
306            *self.a_grams.entry(gram).or_insert(0) += 1;
307        }
308
309        // Build q-grams for string b
310        let chars_b: Vec<char> = b.chars().collect();
311        for i in 0..chars_b.len() {
312            let end = (i + q).min(chars_b.len());
313            let gram: String = chars_b[i..end].iter().collect();
314            *self.b_grams.entry(gram).or_insert(0) += 1;
315        }
316
317        self.calc_q_distance()
318    }
319
320    /// Q-gram distance for char arrays (Ukkonen92 algorithm).
321    fn q_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
322        self.a_grams.clear();
323        self.b_grams.clear();
324
325        // Adaptive Q based on combined length
326        let q = decide_q(a.len() + b.len());
327
328        // Build q-grams for array a
329        for i in 0..a.len() {
330            let end = (i + q).min(a.len());
331            let gram: String = a[i..end].iter().collect();
332            *self.a_grams.entry(gram).or_insert(0) += 1;
333        }
334
335        // Build q-grams for array b
336        for i in 0..b.len() {
337            let end = (i + q).min(b.len());
338            let gram: String = b[i..end].iter().collect();
339            *self.b_grams.entry(gram).or_insert(0) += 1;
340        }
341
342        self.calc_q_distance()
343    }
344
345    /// Builds Q-grams from a string into the provided map.
346    /// Used by tests to verify Q-gram construction.
347    /// Assumes comparing string to itself (combined_len = 2 * len).
348    #[cfg(test)]
349    fn build_q_grams_str(&self, s: &str, grams: &mut HashMap<String, i32>) {
350        grams.clear();
351        let chars: Vec<char> = s.chars().collect();
352        let q = decide_q(chars.len() * 2);
353        for i in 0..chars.len() {
354            let end = (i + q).min(chars.len());
355            let gram: String = chars[i..end].iter().collect();
356            *grams.entry(gram).or_insert(0) += 1;
357        }
358    }
359
360    /// Calculates the Q-gram distance from the built gram tables.
361    fn calc_q_distance(&self) -> i32 {
362        let mut dist = 0;
363
364        // Loop over a_grams
365        for (gram, count_a) in &self.a_grams {
366            let count_b = self.b_grams.get(gram).copied().unwrap_or(0);
367            dist += (count_a - count_b).abs();
368        }
369
370        // Add grams present in b but not in a
371        for (gram, count_b) in &self.b_grams {
372            if !self.a_grams.contains_key(gram) {
373                dist += *count_b;
374            }
375        }
376
377        dist
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{new_base_node, XmlElement, XmlText};
    use std::collections::HashMap as StdHashMap;

    /// Base node holding an element with the given name and no attributes.
    fn bare_element(name: &str) -> NodeRef {
        new_base_node(Some(XmlContent::Element(XmlElement::new(
            name.to_string(),
            StdHashMap::new(),
        ))))
    }

    /// Base node holding a text node with the given content.
    fn text_node(text: &str) -> NodeRef {
        new_base_node(Some(XmlContent::Text(XmlText::new(text))))
    }

    #[test]
    fn test_q_dist_identical_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("hello world", "hello world"), 0);
    }

    #[test]
    fn test_q_dist_different_strings() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "world") > 0);
    }

    #[test]
    fn test_q_dist_similar_strings() {
        let mut measure = Measure::new();
        // The two inputs differ in a single character
        let dist = measure.string_dist_str(
            "return stringDist( a, b, a.length()+b.length() );",
            "return stzingDist( a, b, a.length()+b.length() );",
        );
        // Small, but not zero
        assert!(dist > 0);
        assert!(dist < 20); // Reasonably similar
    }

    #[test]
    fn test_q_dist_empty_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("", ""), 0);
    }

    #[test]
    fn test_q_dist_one_empty() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "") > 0);
    }

    #[test]
    fn test_get_distance_same_elements() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = bare_element("div");

        let dist = measure.get_distance(Some(&a), Some(&b));
        // Only the element name is compared (total = 1), so the penalty dominates:
        // penalty = max(0, 1 - 1/20) = 0.95
        // distance = 0.95 + 0.05 * 0 = 0.95
        assert!((dist - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_element_names() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = bare_element("span");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_same_attributes() {
        let mut measure = Measure::new();

        let mut attrs = StdHashMap::new();
        attrs.insert("id".to_string(), "foo".to_string());
        attrs.insert("class".to_string(), "bar".to_string());

        let a = new_base_node(Some(XmlContent::Element(XmlElement::new(
            "div".to_string(),
            attrs.clone(),
        ))));
        let b = new_base_node(Some(XmlContent::Element(XmlElement::new(
            "div".to_string(),
            attrs,
        ))));

        let dist = measure.get_distance(Some(&a), Some(&b));
        // Element name plus two attributes with short values:
        // total = 1 + 2*2 = 5 (name + 2 attrs with info=1+1 each)
        // penalty = max(0, 1 - 5/20) = 0.75
        // Nothing mismatches, so distance = penalty = 0.75
        assert!((dist - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_attribute_values() {
        let mut measure = Measure::new();

        let mut attrs_a = StdHashMap::new();
        attrs_a.insert("id".to_string(), "foo".to_string());

        let mut attrs_b = StdHashMap::new();
        attrs_b.insert("id".to_string(), "bar".to_string());

        let a = new_base_node(Some(XmlContent::Element(XmlElement::new(
            "div".to_string(),
            attrs_a,
        ))));
        let b = new_base_node(Some(XmlContent::Element(XmlElement::new(
            "div".to_string(),
            attrs_b,
        ))));

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_missing_attribute() {
        let mut measure = Measure::new();

        let mut attrs_a = StdHashMap::new();
        attrs_a.insert("id".to_string(), "foo".to_string());

        let a = new_base_node(Some(XmlContent::Element(XmlElement::new(
            "div".to_string(),
            attrs_a,
        ))));
        let b = bare_element("div");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
    }

    #[test]
    fn test_get_distance_same_text() {
        let mut measure = Measure::new();

        let a = text_node("hello world");
        let b = text_node("hello world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        // "hello world" has 11 chars; assuming info_size is 11 per node:
        // total = (11 + 11) / 2 = 11
        // penalty = max(0, 1 - 11/20) = 0.45
        // Identical text means mismatched = 0, so distance = penalty.
        // Only the bounds are asserted here.
        assert!(dist >= 0.0);
        assert!(dist < 1.0); // Not a total mismatch

        // Longer identical texts carry more info and therefore less penalty
        let mut measure2 = Measure::new();
        let long_text = "This is a much longer piece of text that should have lower penalty";
        let c = text_node(long_text);
        let d = text_node(long_text);
        let dist2 = measure2.get_distance(Some(&c), Some(&d));
        assert!(dist2 < dist);
    }

    #[test]
    fn test_get_distance_different_text() {
        let mut measure = Measure::new();

        let a = text_node("hello");
        let b = text_node("world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_element_vs_text() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = text_node("hello");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert_eq!(dist, 1.0); // Total mismatch
    }

    #[test]
    fn test_child_list_distance_both_empty() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = bare_element("div");

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, ZERO_CHILDREN_MATCH);
    }

    #[test]
    fn test_child_list_distance_same_children() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = bare_element("div");

        // Give both parents an equal child
        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("span"));

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_child_list_distance_different_children() {
        let mut measure = Measure::new();

        let a = bare_element("div");
        let b = bare_element("div");

        // Give the parents children with different content
        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("p"));

        let dist = measure.child_list_distance(&a, &b);
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_penalty_for_short_content() {
        let mut measure = Measure::new();

        // Very short text nodes are subject to the penalty term
        let a = text_node("a");
        let b = text_node("b");

        let dist = measure.get_distance(Some(&a), Some(&b));
        // The penalty pushes the distance above zero
        assert!(dist > 0.0);
    }

    #[test]
    fn test_q_grams_built_correctly() {
        let measure = Measure::new();
        let mut grams = StdHashMap::new();
        measure.build_q_grams_str("hello", &mut grams);

        // "hello" (len=5) compared with itself → combined=10 → Q=2
        // Expected 2-grams: "he", "el", "ll", "lo", plus the truncated tail "o"
        for expected in ["he", "el", "ll", "lo", "o"] {
            assert!(grams.contains_key(expected));
        }
    }
}