yass/
hybrid_jaccard.rs

1use crate::error::StrSimError;
2
3use super::{ExpectTokenizerType, JaroWinkler, StrSim, TokenizerType};
4use anyhow::Result;
5use derive_more::Display;
6use lsap::get_assigned_cost;
7
8#[derive(Display)]
9#[display(fmt = "HybridJaccard")]
10pub struct HybridJaccard<S: StrSim<Vec<char>> + ExpectTokenizerType> {
11    pub threshold: f64,
12    pub lower_bound: f64,
13    pub strsim: S,
14}
15
16impl HybridJaccard<JaroWinkler> {
17    pub fn default() -> Self {
18        HybridJaccard {
19            threshold: 0.5,
20            lower_bound: 0.0,
21            strsim: JaroWinkler::default(),
22        }
23    }
24}
25
26impl<S: StrSim<Vec<char>> + ExpectTokenizerType> HybridJaccard<S> {
27    pub fn new(strsim: S, threshold: Option<f64>, lower_bound: Option<f64>) -> Self {
28        HybridJaccard {
29            threshold: threshold.unwrap_or(0.5),
30            lower_bound: lower_bound.unwrap_or(0.0),
31            strsim,
32        }
33    }
34
35    pub fn similarity<'t>(
36        &self,
37        mut set1: &'t Vec<Vec<char>>,
38        mut set2: &'t Vec<Vec<char>>,
39    ) -> Result<f64, StrSimError> {
40        if set1.len() > set2.len() {
41            (set1, set2) = (set2, set1);
42        }
43        let total_num_matches = set1.len();
44        let mut matching_score = vec![1.0; set1.len() * set2.len()];
45        // let mut matching_score = Array2::from_elem((set1.len(), set2.len()), 1.0);
46        let mut row_max: Vec<f64> = vec![0.0; set1.len()];
47
48        for (i, s1) in set1.iter().enumerate() {
49            for (j, s2) in set2.iter().enumerate() {
50                let mut score: f64 = self.strsim.similarity_pre_tok2(s1, s2)?;
51                if score < self.threshold {
52                    score = 0.0;
53                }
54                row_max[i] = row_max[i].max(score);
55                // matching_score[[i, j]] = 1.0 - score // munkres finds out the smallest element
56                // matching_score[[i, j]] = score
57                matching_score[i * set2.len() + j] = score
58            }
59
60            if self.lower_bound > 0.0 {
61                let max_possible_score_sum: f64 =
62                    row_max[..i + 1].iter().sum::<f64>() + (total_num_matches - i - 1) as f64;
63                let max_possible =
64                    max_possible_score_sum / (set1.len() + set2.len() - total_num_matches) as f64;
65                if max_possible < self.lower_bound {
66                    return Ok(0.0);
67                }
68            }
69        }
70
71        let score_sum = get_assigned_cost(set1.len(), set2.len(), &matching_score, true)?;
72
73        if set1.len() + set2.len() - total_num_matches == 0 {
74            return Ok(1.0);
75        }
76        let sim = score_sum / (set1.len() + set2.len() - total_num_matches) as f64;
77        if self.lower_bound > 0.0 && sim < self.lower_bound {
78            Ok(0.0)
79        } else {
80            Ok(sim)
81        }
82    }
83
84    // /**
85    //  *
86    //  */
87    // fn similarity_impl_v1(&self, mut set1: &Vec<Vec<char>>, mut set2: &Vec<Vec<char>>) -> f64 {
88    //     if set1.len() > set2.len() {
89    //         let tmp = set1;
90    //         set1 = set2;
91    //         set2 = set1;
92    //     }
93
94    //     let mut match_score = 0.0;
95    //     let mut match_count = 0.0;
96    //     let mut matches = vec![];
97
98    //     for (i, s1) in set1.iter().enumerate() {
99    //         for (j, s2) in set2.iter().enumerate() {
100    //             let mut score = self.strsim.similarity(s1, s2);
101    //             if score > self.threshold {
102    //                 matches.push((s1, s2, score));
103    //             }
104    //         }
105    //     }
106
107    //     // sort the score of all the pairs
108    //     matches.sort_by(|a, b| b[2].partial_cmp(&a[2]).unwrap());
109
110    //     // select score in increasing order of their weightage
111    //     // do not reselect the same element from either set.
112    //     let mut set1x = HashSet::new();
113    //     let mut set2x = HashSet::new();
114    //     for (s1, s2, score) in matches {
115    //         if !set1x.contains(s1) && !set2x.contains(s2) {
116    //             set1x.add(s1);
117    //             set2x.add(s2);
118    //             match_score += score;
119    //             match_count += 1.0;
120    //         }
121    //     }
122
123    //     match_score / (set1.len() + set2.len() - match_count)
124    // }
125}
126
127impl<S: StrSim<Vec<char>> + ExpectTokenizerType> StrSim<Vec<Vec<char>>> for HybridJaccard<S> {
128    fn similarity_pre_tok2(
129        &self,
130        set1: &Vec<Vec<char>>,
131        set2: &Vec<Vec<char>>,
132    ) -> Result<f64, StrSimError> {
133        self.similarity(set1, set2)
134    }
135}
136
137impl<S: StrSim<Vec<char>> + ExpectTokenizerType> ExpectTokenizerType for HybridJaccard<S> {
138    fn get_expected_tokenizer_type(&self) -> TokenizerType {
139        TokenizerType::Set(Box::new(Some(self.strsim.get_expected_tokenizer_type())))
140    }
141}