tf_idf_vectorizer/vectorizer/compute/
compare.rs1use num::{traits::MulAdd, Num};
2use std::cmp::Ordering;
3
4pub trait Compare<N>
5where
6 N: Num + Copy,
7{
8 fn dot(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
11 fn cosine_similarity(vec: impl Iterator<Item = (usize, N)>, other: impl Iterator<Item = (usize, N)>) -> f64;
15 fn euclidean_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
18 fn manhattan_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
21 fn chebyshev_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
24}
25
/// Zero-sized marker type that carries the built-in `Compare` implementations.
///
/// The type holds no data, so it is cheap to copy, has a trivial default and
/// all values compare equal. Deriving these traits lets it be used as a type
/// parameter of containers that themselves derive `Clone`, `Default`, etc.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultCompare;
28
29impl Compare<u8> for DefaultCompare {
31 #[inline(always)]
32 fn dot(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
33 let max = u8::MAX as u32;
36 vec.zip(other)
38 .map(|(a, b)| ((a as u32 * b as u32) as f64 / max as f64))
39 .sum()
40 }
41
42 #[inline(always)]
43 fn cosine_similarity(vec: impl Iterator<Item = (usize, u8)>, other: impl Iterator<Item = (usize, u8)>) -> f64 {
44 let max = u8::MAX as u32;
45 let mut a_it = vec.fuse();
46 let mut b_it = other.fuse();
47 let mut a_next = a_it.next();
48 let mut b_next = b_it.next();
49 let mut norm_a = 0_f64;
50 let mut norm_b = 0_f64;
51 let mut dot = 0_f64;
52 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
53 match ia.cmp(&ib) {
54 Ordering::Equal => {
55 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
56 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
57 dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
58 a_next = a_it.next();
59 b_next = b_it.next();
60 }
61 Ordering::Less => {
62 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
63 a_next = a_it.next();
64 }
65 Ordering::Greater => {
66 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
67 b_next = b_it.next();
68 }
69 }
70 }
71 while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
72 while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
73 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
74 }
75
76 #[inline(always)]
77 fn euclidean_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
78 vec.zip(other)
79 .map(|(a, b)| {
80 let diff = a as i32 - b as i32;
81 (diff * diff) as f64
82 })
83 .sum::<f64>()
84 .sqrt()
85 }
86
87 #[inline(always)]
88 fn manhattan_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
89 vec.zip(other)
90 .map(|(a, b)| {
91 let diff = a as i32 - b as i32;
92 diff.abs() as f64
93 })
94 .sum()
95 }
96
97 #[inline(always)]
98 fn chebyshev_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
99 vec.zip(other)
100 .map(|(a, b)| {
101 let diff = a as i32 - b as i32;
102 diff.abs() as f64
103 })
104 .max_by(|a, b| a.partial_cmp(b).unwrap())
105 .unwrap_or(0.0)
106 }
107}
108
109impl Compare<u16> for DefaultCompare {
110 #[inline(always)]
111 fn dot(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
112 let max = u16::MAX as u32;
114 vec.zip(other)
115 .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
116 .sum()
117 }
118
119 #[inline(always)]
120 fn cosine_similarity(vec: impl Iterator<Item = (usize, u16)>, other: impl Iterator<Item = (usize, u16)>) -> f64 {
121 let max = u16::MAX as u32;
122 let mut a_it = vec.fuse();
123 let mut b_it = other.fuse();
124 let mut a_next = a_it.next();
125 let mut b_next = b_it.next();
126 let mut norm_a = 0_f64;
127 let mut norm_b = 0_f64;
128 let mut dot = 0_f64;
129 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
130 match ia.cmp(&ib) {
131 Ordering::Equal => {
132 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
133 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
134 dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
135 a_next = a_it.next();
136 b_next = b_it.next();
137 }
138 Ordering::Less => { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
139 Ordering::Greater => { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
140 }
141 }
142 while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
143 while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
144 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
145 }
146
147 #[inline(always)]
149 fn euclidean_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
150 vec.zip(other)
151 .map(|(a, b)| {
152 let diff = a as i32 - b as i32;
153 (diff * diff) as f64
154 })
155 .sum::<f64>()
156 .sqrt()
157 }
158
159 #[inline(always)]
161 fn manhattan_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
162 vec.zip(other)
163 .map(|(a, b)| {
164 let diff = a as i32 - b as i32;
165 diff.abs() as f64
166 })
167 .sum()
168 }
169
170 #[inline(always)]
172 fn chebyshev_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
173 vec.zip(other)
174 .map(|(a, b)| {
175 let diff = a as i32 - b as i32;
176 diff.abs() as f64
177 })
178 .max_by(|a, b| a.partial_cmp(b).unwrap())
179 .unwrap_or(0.0)
180 }
181}
182
183impl Compare<u32> for DefaultCompare {
184 #[inline(always)]
185 fn dot(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
186 let max = u32::MAX as u128;
188 vec.zip(other)
189 .map(|(a, b)| {
190 let prod = (a as u128).mul_add(b as u128, max - 1);
191 (prod / max) as f64
192 })
193 .sum()
194 }
195
196 #[inline(always)]
197 fn cosine_similarity(vec: impl Iterator<Item = (usize, u32)>, other: impl Iterator<Item = (usize, u32)>) -> f64 {
198 let max = u32::MAX as u128;
199 let mut a_it = vec.fuse();
200 let mut b_it = other.fuse();
201 let mut a_next = a_it.next();
202 let mut b_next = b_it.next();
203 let mut norm_a = 0_f64;
204 let mut norm_b = 0_f64;
205 let mut dot = 0_f64;
206 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
207 match ia.cmp(&ib) {
208 Ordering::Equal => {
209 let aa = (va as u128).mul_add(va as u128, max - 1);
210 let bb = (vb as u128).mul_add(vb as u128, max - 1);
211 let ab = (va as u128).mul_add(vb as u128, max - 1);
212 norm_a += (aa / max) as f64;
213 norm_b += (bb / max) as f64;
214 dot += (ab / max) as f64;
215 a_next = a_it.next();
216 b_next = b_it.next();
217 }
218 Ordering::Less => {
219 let aa = (va as u128).mul_add(va as u128, max - 1);
220 norm_a += (aa / max) as f64;
221 a_next = a_it.next();
222 }
223 Ordering::Greater => {
224 let bb = (vb as u128).mul_add(vb as u128, max - 1);
225 norm_b += (bb / max) as f64;
226 b_next = b_it.next();
227 }
228 }
229 }
230 while let Some((_, va)) = a_next { let aa = (va as u128).mul_add(va as u128, max -1); norm_a += (aa / max) as f64; a_next = a_it.next(); }
231 while let Some((_, vb)) = b_next { let bb = (vb as u128).mul_add(vb as u128, max -1); norm_b += (bb / max) as f64; b_next = b_it.next(); }
232 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
233 }
234
235 #[inline(always)]
236 fn euclidean_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
237 vec.zip(other)
238 .map(|(a, b)| {
239 let diff = a as i64 - b as i64;
240 (diff * diff) as f64
241 })
242 .sum::<f64>()
243 .sqrt()
244 }
245
246 #[inline(always)]
247 fn manhattan_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
248 vec.zip(other)
249 .map(|(a, b)| {
250 let diff = a as i64 - b as i64;
251 diff.abs() as f64
252 })
253 .sum()
254 }
255
256 #[inline(always)]
257 fn chebyshev_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
258 vec.zip(other)
259 .map(|(a, b)| {
260 let diff = a as i64 - b as i64;
261 diff.abs() as f64
262 })
263 .max_by(|a, b| a.partial_cmp(b).unwrap())
264 .unwrap_or(0.0)
265 }
266}
267
268impl Compare<f32> for DefaultCompare {
269 #[inline(always)]
270 fn dot(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
271 let mut acc: f32 = 0.0;
272 for (a, b) in vec.zip(other) {
273 acc += a * b; }
275 acc as f64
276 }
277
278 #[inline(always)]
279 fn cosine_similarity(vec: impl Iterator<Item = (usize, f32)>, other: impl Iterator<Item = (usize, f32)>) -> f64 {
280 let mut a_it = vec.fuse();
281 let mut b_it = other.fuse();
282 let mut a_next = a_it.next();
283 let mut b_next = b_it.next();
284 let mut sum_a2: f64 = 0.0;
285 let mut sum_b2: f64 = 0.0;
286 let mut sum_ab: f64 = 0.0;
287 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
288 match ia.cmp(&ib) {
289 Ordering::Equal => { sum_a2 += (va * va) as f64; sum_b2 += (vb * vb) as f64; sum_ab += (va * vb) as f64; a_next = a_it.next(); b_next = b_it.next(); }
290 Ordering::Less => { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
291 Ordering::Greater => { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
292 }
293 }
294 while let Some((_, va)) = a_next { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
295 while let Some((_, vb)) = b_next { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
296 if sum_a2 == 0.0 || sum_b2 == 0.0 { 0.0 } else { sum_ab / (sum_a2.sqrt() * sum_b2.sqrt()) }
297 }
298
299 #[inline(always)]
300 fn chebyshev_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
301 let mut maxd: f32 = 0.0;
302 for (a, b) in vec.zip(other) {
303 let d = (a - b).abs();
304 if d > maxd { maxd = d; }
305 }
306 maxd as f64
307 }
308
309 #[inline(always)]
310 fn euclidean_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
311 let mut acc: f32 = 0.0;
312 for (a, b) in vec.zip(other) {
313 let d = a - b;
314 acc += d * d;
315 }
316 (acc.sqrt()) as f64
317 }
318
319 #[inline(always)]
320 fn manhattan_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
321 let mut acc: f32 = 0.0;
322 for (a, b) in vec.zip(other) {
323 acc += (a - b).abs();
324 }
325 acc as f64
326 }
327}
328
329impl Compare<f64> for DefaultCompare {
330 #[inline(always)]
331 fn dot(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
332 vec.zip(other)
333 .map(|(a, b)| a * b)
334 .sum()
335 }
336
337 #[inline(always)]
338 fn cosine_similarity(vec: impl Iterator<Item = (usize, f64)>, other: impl Iterator<Item = (usize, f64)>) -> f64 {
339 let mut a_it = vec.fuse();
340 let mut b_it = other.fuse();
341 let mut a_next = a_it.next();
342 let mut b_next = b_it.next();
343 let mut norm_a = 0_f64;
344 let mut norm_b = 0_f64;
345 let mut dot = 0_f64;
346 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
347 match ia.cmp(&ib) {
348 Ordering::Equal => { norm_a += va * va; norm_b += vb * vb; dot += va * vb; a_next = a_it.next(); b_next = b_it.next(); }
349 Ordering::Less => { norm_a += va * va; a_next = a_it.next(); }
350 Ordering::Greater => { norm_b += vb * vb; b_next = b_it.next(); }
351 }
352 }
353 while let Some((_, va)) = a_next { norm_a += va * va; a_next = a_it.next(); }
354 while let Some((_, vb)) = b_next { norm_b += vb * vb; b_next = b_it.next(); }
355 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
356 }
357
358 #[inline(always)]
359 fn euclidean_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
360 vec.zip(other)
361 .map(|(a, b)| {
362 let diff = a - b;
363 diff * diff
364 })
365 .sum::<f64>()
366 .sqrt()
367 }
368
369 #[inline(always)]
370 fn manhattan_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
371 vec.zip(other)
372 .map(|(a, b)| (a - b).abs())
373 .sum()
374 }
375
376 #[inline(always)]
377 fn chebyshev_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
378 vec.zip(other)
379 .map(|(a, b)| (a - b).abs())
380 .max_by(|a, b| a.partial_cmp(b).unwrap())
381 .unwrap_or(0.0)
382 }
383}