1use rustc_hash::FxHashMap;
8
9use crate::node::{NodeRef, XmlContent};
10
/// Largest distance `Measure::get_distance` reports; it is the value returned
/// when the compared nodes have been flagged as a total mismatch.
pub const MAX_DIST: f64 = 1.0;

/// Value `Measure::child_list_distance` returns when neither node has children.
pub const ZERO_CHILDREN_MATCH: f64 = 1.0;

/// Information weight added to the total (and to the mismatch count when the
/// names differ) for an element's name.
pub const ELEMENT_NAME_INFO: i32 = 1;

/// Information weight added to both the total and the mismatch count for an
/// attribute that is present on only one of the two elements.
pub const ATTR_INFO: i32 = 2;

/// Attribute values whose `len()` exceeds this contribute their length as
/// information weight; shorter values contribute 1.
pub const ATTR_VALUE_THRESHOLD: usize = 5;

/// NOTE(review): not referenced in the visible part of this file; presumably
/// the text-node counterpart of `ATTR_VALUE_THRESHOLD` — confirm against the
/// code that computes text info sizes.
pub const TEXT_THRESHOLD: usize = 5;

/// Divisor of the small-sample penalty in `Measure::get_distance`:
/// `penalty = max(0, 1 - total / PENALTY_C)`.
const PENALTY_C: i32 = 20;
31
/// Picks the q-gram length for a comparison from the combined length of the
/// two inputs: 1 below 5, 2 below 50, and 4 from 50 upwards.
fn decide_q(combined_len: usize) -> usize {
    match combined_len {
        0..=4 => 1,
        5..=49 => 2,
        _ => 4,
    }
}
43
44pub struct Measure {
49 mismatched: i32,
51 total: i32,
53 total_mismatch: bool,
55 a_grams: FxHashMap<String, i32>,
57 b_grams: FxHashMap<String, i32>,
58}
59
60impl Default for Measure {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Measure {
67 pub fn new() -> Self {
69 Measure {
70 mismatched: 0,
71 total: 0,
72 total_mismatch: false,
73 a_grams: FxHashMap::with_capacity_and_hasher(2048, Default::default()),
74 b_grams: FxHashMap::with_capacity_and_hasher(2048, Default::default()),
75 }
76 }
77
78 pub fn get_distance(&mut self, a: Option<&NodeRef>, b: Option<&NodeRef>) -> f64 {
83 if let (Some(a), Some(b)) = (a, b) {
84 self.include_nodes(a, b);
85 }
86
87 let result = if self.total_mismatch {
88 MAX_DIST
89 } else if self.total == 0 {
90 0.0
91 } else {
92 let penalty = (1.0 - (self.total as f64) / (PENALTY_C as f64)).max(0.0);
93 penalty + (1.0 - penalty) * (self.mismatched as f64) / (self.total as f64)
94 };
95
96 self.reset_distance();
97 result
98 }
99
100 fn reset_distance(&mut self) {
102 self.mismatched = 0;
103 self.total = 0;
104 self.total_mismatch = false;
105 }
106
107 fn include_nodes(&mut self, a: &NodeRef, b: &NodeRef) {
109 if self.total_mismatch {
110 return;
111 }
112
113 let a_borrowed = a.borrow();
114 let b_borrowed = b.borrow();
115
116 let ca = a_borrowed.content();
117 let cb = b_borrowed.content();
118
119 match (ca, cb) {
120 (Some(XmlContent::Element(ea)), Some(XmlContent::Element(eb))) => {
121 self.total += ELEMENT_NAME_INFO;
123 if !ea.names_match(eb) {
124 self.mismatched += ELEMENT_NAME_INFO;
125 }
126
127 let attrs_a = ea.attributes();
129 let attrs_b = eb.attributes();
130
131 for (name, v1) in attrs_a.iter() {
132 if let Some(v2) = attrs_b.get(name) {
133 let amismatch = self.string_dist_str(v1, v2);
135 let info = if v1.len() > ATTR_VALUE_THRESHOLD {
136 v1.len() as i32
137 } else {
138 1
139 } + if v2.len() > ATTR_VALUE_THRESHOLD {
140 v2.len() as i32
141 } else {
142 1
143 };
144 self.mismatched += if amismatch > info { info } else { amismatch };
145 self.total += info;
146 } else {
147 self.mismatched += ATTR_INFO;
149 self.total += ATTR_INFO;
150 }
151 }
152
153 for name in attrs_b.keys() {
155 if !attrs_a.contains_key(name) {
156 self.mismatched += ATTR_INFO;
157 self.total += ATTR_INFO;
158 }
159 }
160 }
161 (Some(XmlContent::Text(ta)), Some(XmlContent::Text(tb))) => {
162 let info = (ta.info_size() + tb.info_size()) / 2;
164 let amismatch = self.string_dist_chars(ta.text(), tb.text()) / 2;
165
166 self.mismatched += if amismatch > info { info } else { amismatch };
167 self.total += info;
168 }
169 (
170 Some(XmlContent::ProcessingInstruction(pa)),
171 Some(XmlContent::ProcessingInstruction(pb)),
172 ) => {
173 self.total += 1;
175 if pa.target() != pb.target() || !pa.content_equals(pb) {
176 self.mismatched += 1;
177 }
178 }
179 _ => {
180 self.total_mismatch = true;
182 }
183 }
184 }
185
186 pub fn string_dist_str(&mut self, a: &str, b: &str) -> i32 {
188 self.q_dist_str(a, b)
189 }
190
191 pub fn string_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
193 self.q_dist_chars(a, b)
194 }
195
196 pub fn child_list_distance(&mut self, a: &NodeRef, b: &NodeRef) -> f64 {
201 let a_borrowed = a.borrow();
202 let b_borrowed = b.borrow();
203
204 let a_count = a_borrowed.child_count();
205 let b_count = b_borrowed.child_count();
206
207 if a_count == 0 && b_count == 0 {
208 return ZERO_CHILDREN_MATCH;
209 }
210
211 let ac: Vec<char> = a_borrowed
213 .children()
214 .iter()
215 .map(|child| {
216 let child_borrowed = child.borrow();
217 if let Some(content) = child_borrowed.content() {
218 char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
219 } else {
220 '\0'
221 }
222 })
223 .collect();
224
225 let bc: Vec<char> = b_borrowed
226 .children()
227 .iter()
228 .map(|child| {
229 let child_borrowed = child.borrow();
230 if let Some(content) = child_borrowed.content() {
231 char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
232 } else {
233 '\0'
234 }
235 })
236 .collect();
237
238 let dist = self.string_dist_chars(&ac, &bc);
239 (dist as f64) / ((a_count + b_count) as f64)
240 }
241
242 pub fn matched_child_list_distance(
252 &mut self,
253 base: &NodeRef,
254 branch: &NodeRef,
255 _is_left: bool,
256 ) -> i32 {
257 let base_borrowed = base.borrow();
258 let branch_borrowed = branch.borrow();
259
260 let ac: Vec<char> = (0..base_borrowed.child_count())
262 .map(|i| char::from_u32((i + 1) as u32).unwrap_or('\0'))
263 .collect();
264
265 let bc: Vec<char> = branch_borrowed
267 .children()
268 .iter()
269 .enumerate()
270 .map(|(i, child)| {
271 let child_borrowed = child.borrow();
272
273 if let Some(base_match_weak) = child_borrowed.get_base_match() {
275 if let Some(base_match) = base_match_weak.upgrade() {
276 let base_match_borrowed = base_match.borrow();
277
278 if let Some(parent) = base_match_borrowed.parent().upgrade() {
280 if parent.borrow().id() == base_borrowed.id() {
282 return char::from_u32(
284 (base_match_borrowed.child_pos() + 1) as u32,
285 )
286 .unwrap_or('\0');
287 }
288 }
289 }
290 }
291
292 char::from_u32(0x10000 + i as u32).unwrap_or('\0')
297 })
298 .collect();
299
300 self.string_dist_chars(&ac, &bc)
301 }
302
303 fn q_dist_str(&mut self, a: &str, b: &str) -> i32 {
305 self.a_grams.clear();
306 self.b_grams.clear();
307
308 let q = decide_q(a.len() + b.len());
310
311 let chars_a: Vec<char> = a.chars().collect();
313 for i in 0..chars_a.len() {
314 let end = (i + q).min(chars_a.len());
315 let gram: String = chars_a[i..end].iter().collect();
316 *self.a_grams.entry(gram).or_insert(0) += 1;
317 }
318
319 let chars_b: Vec<char> = b.chars().collect();
321 for i in 0..chars_b.len() {
322 let end = (i + q).min(chars_b.len());
323 let gram: String = chars_b[i..end].iter().collect();
324 *self.b_grams.entry(gram).or_insert(0) += 1;
325 }
326
327 self.calc_q_distance()
328 }
329
330 fn q_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
332 self.a_grams.clear();
333 self.b_grams.clear();
334
335 let q = decide_q(a.len() + b.len());
337
338 for i in 0..a.len() {
340 let end = (i + q).min(a.len());
341 let gram: String = a[i..end].iter().collect();
342 *self.a_grams.entry(gram).or_insert(0) += 1;
343 }
344
345 for i in 0..b.len() {
347 let end = (i + q).min(b.len());
348 let gram: String = b[i..end].iter().collect();
349 *self.b_grams.entry(gram).or_insert(0) += 1;
350 }
351
352 self.calc_q_distance()
353 }
354
/// Test helper: replaces the contents of `grams` with the q-gram counts of
/// `s`, with the gram length chosen as if `s` were compared against a
/// string of the same character count.
#[cfg(test)]
fn build_q_grams_str(&self, s: &str, grams: &mut FxHashMap<String, i32>) {
    let symbols: Vec<char> = s.chars().collect();
    let len = symbols.len();
    let q = decide_q(len * 2);

    grams.clear();
    for start in 0..len {
        let window = &symbols[start..len.min(start + q)];
        *grams.entry(window.iter().collect()).or_insert(0) += 1;
    }
}
369
370 fn calc_q_distance(&self) -> i32 {
372 let mut dist = 0;
373
374 for (gram, count_a) in &self.a_grams {
376 let count_b = self.b_grams.get(gram).copied().unwrap_or(0);
377 dist += (count_a - count_b).abs();
378 }
379
380 for (gram, count_b) in &self.b_grams {
382 if !self.a_grams.contains_key(gram) {
383 dist += *count_b;
384 }
385 }
386
387 dist
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{new_base_node, XmlElement, XmlText};
    use std::collections::HashMap as StdHashMap;

    /// Builds an element node with the given name and attributes.
    fn element(name: &str, attrs: StdHashMap<String, String>) -> NodeRef {
        new_base_node(Some(XmlContent::Element(XmlElement::new(
            name.to_string(),
            attrs,
        ))))
    }

    /// Builds an element node without attributes.
    fn bare_element(name: &str) -> NodeRef {
        element(name, StdHashMap::new())
    }

    /// Builds a text node.
    fn text_node(content: &str) -> NodeRef {
        new_base_node(Some(XmlContent::Text(XmlText::new(content))))
    }

    /// Builds an attribute map containing only `id=<value>`.
    fn id_attr(value: &str) -> StdHashMap<String, String> {
        let mut attrs = StdHashMap::new();
        attrs.insert("id".to_string(), value.to_string());
        attrs
    }

    #[test]
    fn test_q_dist_identical_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("hello world", "hello world"), 0);
    }

    #[test]
    fn test_q_dist_different_strings() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "world") > 0);
    }

    #[test]
    fn test_q_dist_similar_strings() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str(
            "return stringDist( a, b, a.length()+b.length() );",
            "return stzingDist( a, b, a.length()+b.length() );",
        );
        // A single changed character: nonzero but small.
        assert!(dist > 0);
        assert!(dist < 20);
    }

    #[test]
    fn test_q_dist_empty_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("", ""), 0);
    }

    #[test]
    fn test_q_dist_one_empty() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "") > 0);
    }

    #[test]
    fn test_get_distance_same_elements() {
        let mut measure = Measure::new();
        let (a, b) = (bare_element("div"), bare_element("div"));

        // Only the name contributes information, so the penalty dominates.
        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!((dist - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_element_names() {
        let mut measure = Measure::new();
        let (a, b) = (bare_element("div"), bare_element("span"));

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_same_attributes() {
        let mut measure = Measure::new();

        let mut attrs = StdHashMap::new();
        attrs.insert("id".to_string(), "foo".to_string());
        attrs.insert("class".to_string(), "bar".to_string());

        let a = element("div", attrs.clone());
        let b = element("div", attrs);

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!((dist - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_attribute_values() {
        let mut measure = Measure::new();
        let a = element("div", id_attr("foo"));
        let b = element("div", id_attr("bar"));

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_missing_attribute() {
        let mut measure = Measure::new();
        let a = element("div", id_attr("foo"));
        let b = element("div", StdHashMap::new());

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
    }

    #[test]
    fn test_get_distance_same_text() {
        let mut measure = Measure::new();
        let a = text_node("hello world");
        let b = text_node("hello world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist >= 0.0);
        assert!(dist < 1.0);

        // Identical longer text must come out closer than identical short text.
        let mut measure2 = Measure::new();
        let long_text = "This is a much longer piece of text that should have lower penalty";
        let c = text_node(long_text);
        let d = text_node(long_text);
        let dist2 = measure2.get_distance(Some(&c), Some(&d));
        assert!(dist2 < dist);
    }

    #[test]
    fn test_get_distance_different_text() {
        let mut measure = Measure::new();
        let a = text_node("hello");
        let b = text_node("world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_element_vs_text() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = text_node("hello");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert_eq!(dist, 1.0);
    }

    #[test]
    fn test_child_list_distance_both_empty() {
        let mut measure = Measure::new();
        let (a, b) = (bare_element("div"), bare_element("div"));

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, ZERO_CHILDREN_MATCH);
    }

    #[test]
    fn test_child_list_distance_same_children() {
        let mut measure = Measure::new();
        let (a, b) = (bare_element("div"), bare_element("div"));

        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("span"));

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_child_list_distance_different_children() {
        let mut measure = Measure::new();
        let (a, b) = (bare_element("div"), bare_element("div"));

        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("p"));

        let dist = measure.child_list_distance(&a, &b);
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_penalty_for_short_content() {
        let mut measure = Measure::new();
        let a = text_node("a");
        let b = text_node("b");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
    }

    #[test]
    fn test_q_grams_built_correctly() {
        let measure = Measure::new();
        let mut grams: FxHashMap<String, i32> = FxHashMap::default();
        measure.build_q_grams_str("hello", &mut grams);

        // Bigrams for every start position, plus the shortened trailing gram.
        for expected in ["he", "el", "ll", "lo", "o"] {
            assert!(grams.contains_key(expected));
        }
    }
}