1use crate::DistanceMetric;
2
3pub struct TokenSet<D: DistanceMetric> {
9 inner: D,
11}
12
13impl<D: DistanceMetric> TokenSet<D> {
14 pub fn new(inner: D) -> Self {
16 Self { inner }
17 }
18}
19
20impl<D: DistanceMetric> DistanceMetric for TokenSet<D> {
21 type Dist = <D as DistanceMetric>::Dist;
22
23 fn distance<S, T>(&self, a: S, b: T) -> Self::Dist
24 where
25 S: IntoIterator,
26 T: IntoIterator,
27 <S as IntoIterator>::IntoIter: Clone,
28 <T as IntoIterator>::IntoIter: Clone,
29 <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
30 <T as IntoIterator>::Item: PartialEq,
31 {
32 let a = a.into_iter();
33 let b = b.into_iter();
34
35 let intersect = b.clone().filter(|x| a.clone().any(|y| y == *x));
36
37 if intersect.clone().count() == 0 {
38 return self.inner.distance(a, b);
39 }
40
41 let dist_inter_a = self.inner.distance(a.clone(), intersect.clone());
42 let dist_inter_b = self.inner.distance(intersect, b.clone());
43 let dist_a_b = self.inner.distance(a, b);
44
45 if dist_inter_a < dist_inter_b {
46 if dist_inter_a < dist_a_b {
47 dist_inter_a
48 } else {
49 dist_a_b
50 }
51 } else if dist_inter_b < dist_a_b {
52 dist_inter_b
53 } else {
54 dist_a_b
55 }
56 }
57
58 fn str_distance<S, T>(&self, a: S, b: T) -> Self::Dist
59 where
60 S: AsRef<str>,
61 T: AsRef<str>,
62 {
63 let a = a.as_ref();
64 let mut words_a: Vec<_> = a.split_whitespace().collect();
65 words_a.sort();
66 words_a.dedup_by(|a, b| a == b);
67
68 let b = b.as_ref();
69 let mut words_b: Vec<_> = b.split_whitespace().collect();
70 words_b.sort();
71 words_b.dedup_by(|a, b| a == b);
72
73 let words_intersect: Vec<_> = words_b
74 .iter()
75 .cloned()
76 .filter(|s| words_a.contains(s))
77 .collect();
78
79 if words_intersect.is_empty() {
80 return self.inner.str_distance(a, b);
81 }
82
83 let intersect = words_intersect.join(" ");
84 let a = words_a.join(" ");
85 let b = words_b.join(" ");
86
87 let dist_inter_a = self.inner.str_distance(&intersect, &a);
88 let dist_inter_b = self.inner.str_distance(intersect, &b);
89 let dist_a_b = self.inner.str_distance(a, &b);
90
91 if dist_inter_a < dist_inter_b {
92 if dist_inter_a < dist_a_b {
93 dist_inter_a
94 } else {
95 dist_a_b
96 }
97 } else if dist_inter_b < dist_a_b {
98 dist_inter_b
99 } else {
100 dist_a_b
101 }
102 }
103
104 fn normalized<S, T>(&self, a: S, b: T) -> f64
105 where
106 S: IntoIterator,
107 T: IntoIterator,
108 <S as IntoIterator>::IntoIter: Clone,
109 <T as IntoIterator>::IntoIter: Clone,
110 <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
111 <T as IntoIterator>::Item: PartialEq,
112 {
113 self.inner.normalized(a, b)
114 }
115}
116
117pub struct TokenSort<D: DistanceMetric> {
122 inner: D,
124}
125
126impl<D> DistanceMetric for TokenSort<D>
127where
128 D: DistanceMetric,
129{
130 type Dist = <D as DistanceMetric>::Dist;
131
132 fn distance<S, T>(&self, a: S, b: T) -> Self::Dist
133 where
134 S: IntoIterator,
135 T: IntoIterator,
136 <S as IntoIterator>::IntoIter: Clone,
137 <T as IntoIterator>::IntoIter: Clone,
138 <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
139 <T as IntoIterator>::Item: PartialEq,
140 {
141 self.inner.distance(a, b)
142 }
143
144 fn str_distance<S, T>(&self, a: S, b: T) -> Self::Dist
145 where
146 S: AsRef<str>,
147 T: AsRef<str>,
148 {
149 let mut a: Vec<_> = a.as_ref().split_whitespace().collect();
150 a.sort();
151 let mut b: Vec<_> = b.as_ref().split_whitespace().collect();
152 b.sort();
153 self.distance(a.join(" ").chars(), b.join(" ").chars())
154 }
155
156 fn normalized<S, T>(&self, a: S, b: T) -> f64
157 where
158 S: IntoIterator,
159 T: IntoIterator,
160 <S as IntoIterator>::IntoIter: Clone,
161 <T as IntoIterator>::IntoIter: Clone,
162 <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
163 <T as IntoIterator>::Item: PartialEq,
164 {
165 self.inner.normalized(a, b)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RatcliffObershelp;

    /// Token-set wrapper around Ratcliff/Obershelp shared by the cases below.
    fn token_set() -> TokenSet<RatcliffObershelp> {
        TokenSet::new(RatcliffObershelp)
    }

    #[test]
    fn token_set_ratcliff() {
        let reference = "Real Madrid vs FC Barcelona";

        // Same words in a different order, minus "FC": expected to score a distance of 0.
        let reordered = "Barcelona vs Real Madrid";
        assert_eq!(token_set().str_distance(reference, reordered), 0.0);

        // A one-letter typo ("Rel") yields a small non-zero distance.
        let misspelled = "Barcelona vs Rel Madrid";
        let typo_distance = token_set().str_distance(reference, misspelled);
        assert_eq!(format!("{:.6}", typo_distance), "0.080000");
    }
}