1use std::collections::HashMap;
8
9use crate::node::{NodeRef, XmlContent};
10
/// Largest distance [`Measure::get_distance`] reports; returned when the
/// compared nodes were flagged as a total mismatch.
pub const MAX_DIST: f64 = 1.0;

/// Value returned by [`Measure::child_list_distance`] when neither node has
/// any children.
pub const ZERO_CHILDREN_MATCH: f64 = 1.0;

/// Information units contributed by an element's qualified name.
pub const ELEMENT_NAME_INFO: i32 = 1;

/// Information (and mismatch) units contributed by an attribute that is
/// present on only one of the two compared elements.
pub const ATTR_INFO: i32 = 2;

/// Attribute values longer than this many bytes contribute their length as
/// information; shorter values contribute a single unit.
pub const ATTR_VALUE_THRESHOLD: usize = 5;

/// NOTE(review): not referenced anywhere in this file; presumably the
/// text-node counterpart of [`ATTR_VALUE_THRESHOLD`] used elsewhere — confirm.
pub const TEXT_THRESHOLD: usize = 5;

/// Amount of accumulated information below which [`Measure::get_distance`]
/// applies a penalty of `max(0, 1 - total / PENALTY_C)`.
const PENALTY_C: i32 = 20;
31
/// Chooses the q-gram width for a comparison from the combined length of the
/// two inputs.
///
/// Returns `1` for combined lengths below 5, `2` for lengths below 50 and `4`
/// for everything else.
fn decide_q(combined_len: usize) -> usize {
    match combined_len {
        0..=4 => 1,
        5..=49 => 2,
        _ => 4,
    }
}
43
/// Accumulates mismatch/information counters while nodes are compared and
/// owns the scratch tables used for q-gram string comparison.
///
/// Derives `Debug` and `Clone` so the public type can be logged and
/// duplicated like other plain data holders.
#[derive(Debug, Clone)]
pub struct Measure {
    /// Mismatch units accumulated since the last reset.
    mismatched: i32,
    /// Information units accumulated since the last reset.
    total: i32,
    /// Set when two nodes with incompatible content were compared; forces the
    /// maximum distance.
    total_mismatch: bool,
    /// Scratch table: q-gram -> occurrence count for the first input.
    a_grams: HashMap<String, i32>,
    /// Scratch table: q-gram -> occurrence count for the second input.
    b_grams: HashMap<String, i32>,
}
59
60impl Default for Measure {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Measure {
67 pub fn new() -> Self {
69 Measure {
70 mismatched: 0,
71 total: 0,
72 total_mismatch: false,
73 a_grams: HashMap::with_capacity(2048),
74 b_grams: HashMap::with_capacity(2048),
75 }
76 }
77
78 pub fn get_distance(&mut self, a: Option<&NodeRef>, b: Option<&NodeRef>) -> f64 {
83 if let (Some(a), Some(b)) = (a, b) {
84 self.include_nodes(a, b);
85 }
86
87 let result = if self.total_mismatch {
88 MAX_DIST
89 } else if self.total == 0 {
90 0.0
91 } else {
92 let penalty = (1.0 - (self.total as f64) / (PENALTY_C as f64)).max(0.0);
93 penalty + (1.0 - penalty) * (self.mismatched as f64) / (self.total as f64)
94 };
95
96 self.reset_distance();
97 result
98 }
99
100 fn reset_distance(&mut self) {
102 self.mismatched = 0;
103 self.total = 0;
104 self.total_mismatch = false;
105 }
106
107 fn include_nodes(&mut self, a: &NodeRef, b: &NodeRef) {
109 if self.total_mismatch {
110 return;
111 }
112
113 let a_borrowed = a.borrow();
114 let b_borrowed = b.borrow();
115
116 let ca = a_borrowed.content();
117 let cb = b_borrowed.content();
118
119 match (ca, cb) {
120 (Some(XmlContent::Element(ea)), Some(XmlContent::Element(eb))) => {
121 self.total += ELEMENT_NAME_INFO;
123 if ea.qname() != eb.qname() {
124 self.mismatched += ELEMENT_NAME_INFO;
125 }
126
127 let attrs_a = ea.attributes();
129 let attrs_b = eb.attributes();
130
131 for (name, v1) in attrs_a.iter() {
132 if let Some(v2) = attrs_b.get(name) {
133 let amismatch = self.string_dist_str(v1, v2);
135 let info = if v1.len() > ATTR_VALUE_THRESHOLD {
136 v1.len() as i32
137 } else {
138 1
139 } + if v2.len() > ATTR_VALUE_THRESHOLD {
140 v2.len() as i32
141 } else {
142 1
143 };
144 self.mismatched += if amismatch > info { info } else { amismatch };
145 self.total += info;
146 } else {
147 self.mismatched += ATTR_INFO;
149 self.total += ATTR_INFO;
150 }
151 }
152
153 for name in attrs_b.keys() {
155 if !attrs_a.contains_key(name) {
156 self.mismatched += ATTR_INFO;
157 self.total += ATTR_INFO;
158 }
159 }
160 }
161 (Some(XmlContent::Text(ta)), Some(XmlContent::Text(tb))) => {
162 let info = (ta.info_size() + tb.info_size()) / 2;
164 let amismatch = self.string_dist_chars(ta.text(), tb.text()) / 2;
165
166 self.mismatched += if amismatch > info { info } else { amismatch };
167 self.total += info;
168 }
169 _ => {
170 self.total_mismatch = true;
172 }
173 }
174 }
175
    /// Returns the q-gram distance between two strings.
    ///
    /// Thin public entry point that forwards to `q_dist_str`; the internal
    /// q-gram tables are overwritten by the call.
    pub fn string_dist_str(&mut self, a: &str, b: &str) -> i32 {
        self.q_dist_str(a, b)
    }
180
    /// Returns the q-gram distance between two character sequences.
    ///
    /// Thin public entry point that forwards to `q_dist_chars`; the internal
    /// q-gram tables are overwritten by the call.
    pub fn string_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
        self.q_dist_chars(a, b)
    }
185
186 pub fn child_list_distance(&mut self, a: &NodeRef, b: &NodeRef) -> f64 {
191 let a_borrowed = a.borrow();
192 let b_borrowed = b.borrow();
193
194 let a_count = a_borrowed.child_count();
195 let b_count = b_borrowed.child_count();
196
197 if a_count == 0 && b_count == 0 {
198 return ZERO_CHILDREN_MATCH;
199 }
200
201 let ac: Vec<char> = a_borrowed
203 .children()
204 .iter()
205 .map(|child| {
206 let child_borrowed = child.borrow();
207 if let Some(content) = child_borrowed.content() {
208 char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
209 } else {
210 '\0'
211 }
212 })
213 .collect();
214
215 let bc: Vec<char> = b_borrowed
216 .children()
217 .iter()
218 .map(|child| {
219 let child_borrowed = child.borrow();
220 if let Some(content) = child_borrowed.content() {
221 char::from_u32((content.content_hash() & 0xffff) as u32).unwrap_or('\0')
222 } else {
223 '\0'
224 }
225 })
226 .collect();
227
228 let dist = self.string_dist_chars(&ac, &bc);
229 (dist as f64) / ((a_count + b_count) as f64)
230 }
231
232 pub fn matched_child_list_distance(
242 &mut self,
243 base: &NodeRef,
244 branch: &NodeRef,
245 _is_left: bool,
246 ) -> i32 {
247 let base_borrowed = base.borrow();
248 let branch_borrowed = branch.borrow();
249
250 let ac: Vec<char> = (0..base_borrowed.child_count())
252 .map(|i| char::from_u32((i + 1) as u32).unwrap_or('\0'))
253 .collect();
254
255 let bc: Vec<char> = branch_borrowed
257 .children()
258 .iter()
259 .enumerate()
260 .map(|(i, child)| {
261 let child_borrowed = child.borrow();
262
263 if let Some(base_match_weak) = child_borrowed.get_base_match() {
265 if let Some(base_match) = base_match_weak.upgrade() {
266 let base_match_borrowed = base_match.borrow();
267
268 if let Some(parent) = base_match_borrowed.parent().upgrade() {
270 if parent.borrow().id() == base_borrowed.id() {
272 return char::from_u32(
274 (base_match_borrowed.child_pos() + 1) as u32,
275 )
276 .unwrap_or('\0');
277 }
278 }
279 }
280 }
281
282 char::from_u32(0x10000 + i as u32).unwrap_or('\0')
287 })
288 .collect();
289
290 self.string_dist_chars(&ac, &bc)
291 }
292
293 fn q_dist_str(&mut self, a: &str, b: &str) -> i32 {
295 self.a_grams.clear();
296 self.b_grams.clear();
297
298 let q = decide_q(a.len() + b.len());
300
301 let chars_a: Vec<char> = a.chars().collect();
303 for i in 0..chars_a.len() {
304 let end = (i + q).min(chars_a.len());
305 let gram: String = chars_a[i..end].iter().collect();
306 *self.a_grams.entry(gram).or_insert(0) += 1;
307 }
308
309 let chars_b: Vec<char> = b.chars().collect();
311 for i in 0..chars_b.len() {
312 let end = (i + q).min(chars_b.len());
313 let gram: String = chars_b[i..end].iter().collect();
314 *self.b_grams.entry(gram).or_insert(0) += 1;
315 }
316
317 self.calc_q_distance()
318 }
319
320 fn q_dist_chars(&mut self, a: &[char], b: &[char]) -> i32 {
322 self.a_grams.clear();
323 self.b_grams.clear();
324
325 let q = decide_q(a.len() + b.len());
327
328 for i in 0..a.len() {
330 let end = (i + q).min(a.len());
331 let gram: String = a[i..end].iter().collect();
332 *self.a_grams.entry(gram).or_insert(0) += 1;
333 }
334
335 for i in 0..b.len() {
337 let end = (i + q).min(b.len());
338 let gram: String = b[i..end].iter().collect();
339 *self.b_grams.entry(gram).or_insert(0) += 1;
340 }
341
342 self.calc_q_distance()
343 }
344
345 #[cfg(test)]
349 fn build_q_grams_str(&self, s: &str, grams: &mut HashMap<String, i32>) {
350 grams.clear();
351 let chars: Vec<char> = s.chars().collect();
352 let q = decide_q(chars.len() * 2);
353 for i in 0..chars.len() {
354 let end = (i + q).min(chars.len());
355 let gram: String = chars[i..end].iter().collect();
356 *grams.entry(gram).or_insert(0) += 1;
357 }
358 }
359
360 fn calc_q_distance(&self) -> i32 {
362 let mut dist = 0;
363
364 for (gram, count_a) in &self.a_grams {
366 let count_b = self.b_grams.get(gram).copied().unwrap_or(0);
367 dist += (count_a - count_b).abs();
368 }
369
370 for (gram, count_b) in &self.b_grams {
372 if !self.a_grams.contains_key(gram) {
373 dist += *count_b;
374 }
375 }
376
377 dist
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{new_base_node, XmlElement, XmlText};
    use std::collections::HashMap as StdHashMap;

    /// Builds an attribute map from `(name, value)` pairs.
    fn attrs_of(pairs: &[(&str, &str)]) -> StdHashMap<String, String> {
        let mut map = StdHashMap::new();
        for (name, value) in pairs {
            map.insert(name.to_string(), value.to_string());
        }
        map
    }

    /// Builds a detached element node with the given name and attributes.
    fn element(name: &str, attrs: StdHashMap<String, String>) -> NodeRef {
        new_base_node(Some(XmlContent::Element(XmlElement::new(
            name.to_string(),
            attrs,
        ))))
    }

    /// Builds a detached element node without attributes.
    fn bare_element(name: &str) -> NodeRef {
        element(name, StdHashMap::new())
    }

    /// Builds a detached text node.
    fn text(content: &str) -> NodeRef {
        new_base_node(Some(XmlContent::Text(XmlText::new(content))))
    }

    #[test]
    fn test_q_dist_identical_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("hello world", "hello world"), 0);
    }

    #[test]
    fn test_q_dist_different_strings() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "world") > 0);
    }

    #[test]
    fn test_q_dist_similar_strings() {
        let mut measure = Measure::new();
        let dist = measure.string_dist_str(
            "return stringDist( a, b, a.length()+b.length() );",
            "return stzingDist( a, b, a.length()+b.length() );",
        );
        // A single changed character must register, but only slightly.
        assert!(dist > 0);
        assert!(dist < 20);
    }

    #[test]
    fn test_q_dist_empty_strings() {
        let mut measure = Measure::new();
        assert_eq!(measure.string_dist_str("", ""), 0);
    }

    #[test]
    fn test_q_dist_one_empty() {
        let mut measure = Measure::new();
        assert!(measure.string_dist_str("hello", "") > 0);
    }

    #[test]
    fn test_get_distance_same_elements() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = bare_element("div");

        // One info unit and no mismatch: only the penalty 1 - 1/20 remains.
        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!((dist - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_element_names() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = bare_element("span");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_same_attributes() {
        let mut measure = Measure::new();
        let attrs = attrs_of(&[("id", "foo"), ("class", "bar")]);
        let a = element("div", attrs.clone());
        let b = element("div", attrs);

        // Five info units and no mismatch: only the penalty 1 - 5/20 remains.
        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!((dist - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_get_distance_different_attribute_values() {
        let mut measure = Measure::new();
        let a = element("div", attrs_of(&[("id", "foo")]));
        let b = element("div", attrs_of(&[("id", "bar")]));

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_missing_attribute() {
        let mut measure = Measure::new();
        let a = element("div", attrs_of(&[("id", "foo")]));
        let b = bare_element("div");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
    }

    #[test]
    fn test_get_distance_same_text() {
        let mut measure = Measure::new();
        let a = text("hello world");
        let b = text("hello world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist >= 0.0);
        assert!(dist < 1.0);

        // Longer identical text must end up closer than short identical text.
        let mut measure2 = Measure::new();
        let long_text = "This is a much longer piece of text that should have lower penalty";
        let c = text(long_text);
        let d = text(long_text);
        let dist2 = measure2.get_distance(Some(&c), Some(&d));
        assert!(dist2 < dist);
    }

    #[test]
    fn test_get_distance_different_text() {
        let mut measure = Measure::new();
        let a = text("hello");
        let b = text("world");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_get_distance_element_vs_text() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = text("hello");

        // Different content kinds are a total mismatch.
        let dist = measure.get_distance(Some(&a), Some(&b));
        assert_eq!(dist, 1.0);
    }

    #[test]
    fn test_child_list_distance_both_empty() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = bare_element("div");

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, ZERO_CHILDREN_MATCH);
    }

    #[test]
    fn test_child_list_distance_same_children() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = bare_element("div");

        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("span"));

        let dist = measure.child_list_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_child_list_distance_different_children() {
        let mut measure = Measure::new();
        let a = bare_element("div");
        let b = bare_element("div");

        crate::node::NodeInner::add_child_to_ref(&a, bare_element("span"));
        crate::node::NodeInner::add_child_to_ref(&b, bare_element("p"));

        let dist = measure.child_list_distance(&a, &b);
        assert!(dist > 0.0);
        assert!(dist <= 1.0);
    }

    #[test]
    fn test_penalty_for_short_content() {
        let mut measure = Measure::new();
        let a = text("a");
        let b = text("b");

        let dist = measure.get_distance(Some(&a), Some(&b));
        assert!(dist > 0.0);
    }

    #[test]
    fn test_q_grams_built_correctly() {
        let measure = Measure::new();
        let mut grams = StdHashMap::new();
        measure.build_q_grams_str("hello", &mut grams);

        // Width 2 is used for "hello"; the last gram is the shortened tail.
        for expected in ["he", "el", "ll", "lo", "o"].iter() {
            assert!(grams.contains_key(*expected));
        }
    }
}