tf_idf_vectorizer/vectorizer/compute/
compare.rs1use num::{traits::MulAdd, Num};
2use std::cmp::Ordering;
3
4pub trait Compare<N>
5where
6 N: Num + Copy,
7{
8 fn dot(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
11 fn cosine_similarity(vec: impl Iterator<Item = (usize, N)>, other: impl Iterator<Item = (usize, N)>) -> f64;
15 fn euclidean_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
18 fn manhattan_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
21 fn chebyshev_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
24}
25
/// Zero-sized marker type that carries the built-in [`Compare`] implementations
/// defined in this file.
///
/// It holds no data, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultCompare;
28
29impl Compare<u8> for DefaultCompare {
31 #[inline(always)]
32 fn dot(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
33 let max = u8::MAX as u32;
36 vec.zip(other)
37 .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
38 .sum()
39 }
40
41 #[inline(always)]
42 fn cosine_similarity(vec: impl Iterator<Item = (usize, u8)>, other: impl Iterator<Item = (usize, u8)>) -> f64 {
43 let max = u8::MAX as u32;
44 let mut a_it = vec.fuse();
45 let mut b_it = other.fuse();
46 let mut a_next = a_it.next();
47 let mut b_next = b_it.next();
48 let mut norm_a = 0_f64;
49 let mut norm_b = 0_f64;
50 let mut dot = 0_f64;
51 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
52 match ia.cmp(&ib) {
53 Ordering::Equal => {
54 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
55 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
56 dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
57 a_next = a_it.next();
58 b_next = b_it.next();
59 }
60 Ordering::Less => {
61 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
62 a_next = a_it.next();
63 }
64 Ordering::Greater => {
65 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
66 b_next = b_it.next();
67 }
68 }
69 }
70 while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
71 while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
72 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
73 }
74
75 #[inline(always)]
76 fn euclidean_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
77 vec.zip(other)
78 .map(|(a, b)| {
79 let diff = a as i32 - b as i32;
80 (diff * diff) as f64
81 })
82 .sum::<f64>()
83 .sqrt()
84 }
85
86 #[inline(always)]
87 fn manhattan_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
88 vec.zip(other)
89 .map(|(a, b)| {
90 let diff = a as i32 - b as i32;
91 diff.abs() as f64
92 })
93 .sum()
94 }
95
96 #[inline(always)]
97 fn chebyshev_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
98 vec.zip(other)
99 .map(|(a, b)| {
100 let diff = a as i32 - b as i32;
101 diff.abs() as f64
102 })
103 .max_by(|a, b| a.partial_cmp(b).unwrap())
104 .unwrap_or(0.0)
105 }
106}
107
108impl Compare<u16> for DefaultCompare {
109 #[inline(always)]
110 fn dot(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
111 let max = u16::MAX as u32;
113 vec.zip(other)
114 .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
115 .sum()
116 }
117
118 #[inline(always)]
119 fn cosine_similarity(vec: impl Iterator<Item = (usize, u16)>, other: impl Iterator<Item = (usize, u16)>) -> f64 {
120 let max = u16::MAX as u32;
121 let mut a_it = vec.fuse();
122 let mut b_it = other.fuse();
123 let mut a_next = a_it.next();
124 let mut b_next = b_it.next();
125 let mut norm_a = 0_f64;
126 let mut norm_b = 0_f64;
127 let mut dot = 0_f64;
128 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
129 match ia.cmp(&ib) {
130 Ordering::Equal => {
131 norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
132 norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
133 dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
134 a_next = a_it.next();
135 b_next = b_it.next();
136 }
137 Ordering::Less => { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
138 Ordering::Greater => { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
139 }
140 }
141 while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
142 while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
143 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
144 }
145
146 #[inline(always)]
148 fn euclidean_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
149 vec.zip(other)
150 .map(|(a, b)| {
151 let diff = a as i32 - b as i32;
152 (diff * diff) as f64
153 })
154 .sum::<f64>()
155 .sqrt()
156 }
157
158 #[inline(always)]
160 fn manhattan_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
161 vec.zip(other)
162 .map(|(a, b)| {
163 let diff = a as i32 - b as i32;
164 diff.abs() as f64
165 })
166 .sum()
167 }
168
169 #[inline(always)]
171 fn chebyshev_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
172 vec.zip(other)
173 .map(|(a, b)| {
174 let diff = a as i32 - b as i32;
175 diff.abs() as f64
176 })
177 .max_by(|a, b| a.partial_cmp(b).unwrap())
178 .unwrap_or(0.0)
179 }
180}
181
182impl Compare<u32> for DefaultCompare {
183 #[inline(always)]
184 fn dot(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
185 let max = u32::MAX as u128;
187 vec.zip(other)
188 .map(|(a, b)| {
189 let prod = (a as u128).mul_add(b as u128, max - 1);
190 (prod / max) as f64
191 })
192 .sum()
193 }
194
195 #[inline(always)]
196 fn cosine_similarity(vec: impl Iterator<Item = (usize, u32)>, other: impl Iterator<Item = (usize, u32)>) -> f64 {
197 let max = u32::MAX as u128;
198 let mut a_it = vec.fuse();
199 let mut b_it = other.fuse();
200 let mut a_next = a_it.next();
201 let mut b_next = b_it.next();
202 let mut norm_a = 0_f64;
203 let mut norm_b = 0_f64;
204 let mut dot = 0_f64;
205 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
206 match ia.cmp(&ib) {
207 Ordering::Equal => {
208 let aa = (va as u128).mul_add(va as u128, max - 1);
209 let bb = (vb as u128).mul_add(vb as u128, max - 1);
210 let ab = (va as u128).mul_add(vb as u128, max - 1);
211 norm_a += (aa / max) as f64;
212 norm_b += (bb / max) as f64;
213 dot += (ab / max) as f64;
214 a_next = a_it.next();
215 b_next = b_it.next();
216 }
217 Ordering::Less => {
218 let aa = (va as u128).mul_add(va as u128, max - 1);
219 norm_a += (aa / max) as f64;
220 a_next = a_it.next();
221 }
222 Ordering::Greater => {
223 let bb = (vb as u128).mul_add(vb as u128, max - 1);
224 norm_b += (bb / max) as f64;
225 b_next = b_it.next();
226 }
227 }
228 }
229 while let Some((_, va)) = a_next { let aa = (va as u128).mul_add(va as u128, max -1); norm_a += (aa / max) as f64; a_next = a_it.next(); }
230 while let Some((_, vb)) = b_next { let bb = (vb as u128).mul_add(vb as u128, max -1); norm_b += (bb / max) as f64; b_next = b_it.next(); }
231 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
232 }
233
234 #[inline(always)]
235 fn euclidean_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
236 vec.zip(other)
237 .map(|(a, b)| {
238 let diff = a as i64 - b as i64;
239 (diff * diff) as f64
240 })
241 .sum::<f64>()
242 .sqrt()
243 }
244
245 #[inline(always)]
246 fn manhattan_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
247 vec.zip(other)
248 .map(|(a, b)| {
249 let diff = a as i64 - b as i64;
250 diff.abs() as f64
251 })
252 .sum()
253 }
254
255 #[inline(always)]
256 fn chebyshev_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
257 vec.zip(other)
258 .map(|(a, b)| {
259 let diff = a as i64 - b as i64;
260 diff.abs() as f64
261 })
262 .max_by(|a, b| a.partial_cmp(b).unwrap())
263 .unwrap_or(0.0)
264 }
265}
266
267impl Compare<f32> for DefaultCompare {
268 #[inline(always)]
269 fn dot(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
270 let mut acc: f32 = 0.0;
271 for (a, b) in vec.zip(other) {
272 acc += a * b; }
274 acc as f64
275 }
276
277 #[inline(always)]
278 fn cosine_similarity(vec: impl Iterator<Item = (usize, f32)>, other: impl Iterator<Item = (usize, f32)>) -> f64 {
279 let mut a_it = vec.fuse();
280 let mut b_it = other.fuse();
281 let mut a_next = a_it.next();
282 let mut b_next = b_it.next();
283 let mut sum_a2: f64 = 0.0;
284 let mut sum_b2: f64 = 0.0;
285 let mut sum_ab: f64 = 0.0;
286 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
287 match ia.cmp(&ib) {
288 Ordering::Equal => { sum_a2 += (va * va) as f64; sum_b2 += (vb * vb) as f64; sum_ab += (va * vb) as f64; a_next = a_it.next(); b_next = b_it.next(); }
289 Ordering::Less => { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
290 Ordering::Greater => { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
291 }
292 }
293 while let Some((_, va)) = a_next { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
294 while let Some((_, vb)) = b_next { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
295 if sum_a2 == 0.0 || sum_b2 == 0.0 { 0.0 } else { sum_ab / (sum_a2.sqrt() * sum_b2.sqrt()) }
296 }
297
298 #[inline(always)]
299 fn chebyshev_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
300 let mut maxd: f32 = 0.0;
301 for (a, b) in vec.zip(other) {
302 let d = (a - b).abs();
303 if d > maxd { maxd = d; }
304 }
305 maxd as f64
306 }
307
308 #[inline(always)]
309 fn euclidean_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
310 let mut acc: f32 = 0.0;
311 for (a, b) in vec.zip(other) {
312 let d = a - b;
313 acc += d * d;
314 }
315 (acc.sqrt()) as f64
316 }
317
318 #[inline(always)]
319 fn manhattan_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
320 let mut acc: f32 = 0.0;
321 for (a, b) in vec.zip(other) {
322 acc += (a - b).abs();
323 }
324 acc as f64
325 }
326}
327
328impl Compare<f64> for DefaultCompare {
329 #[inline(always)]
330 fn dot(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
331 vec.zip(other)
332 .map(|(a, b)| a * b)
333 .sum()
334 }
335
336 #[inline(always)]
337 fn cosine_similarity(vec: impl Iterator<Item = (usize, f64)>, other: impl Iterator<Item = (usize, f64)>) -> f64 {
338 let mut a_it = vec.fuse();
339 let mut b_it = other.fuse();
340 let mut a_next = a_it.next();
341 let mut b_next = b_it.next();
342 let mut norm_a = 0_f64;
343 let mut norm_b = 0_f64;
344 let mut dot = 0_f64;
345 while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
346 match ia.cmp(&ib) {
347 Ordering::Equal => { norm_a += va * va; norm_b += vb * vb; dot += va * vb; a_next = a_it.next(); b_next = b_it.next(); }
348 Ordering::Less => { norm_a += va * va; a_next = a_it.next(); }
349 Ordering::Greater => { norm_b += vb * vb; b_next = b_it.next(); }
350 }
351 }
352 while let Some((_, va)) = a_next { norm_a += va * va; a_next = a_it.next(); }
353 while let Some((_, vb)) = b_next { norm_b += vb * vb; b_next = b_it.next(); }
354 if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
355 }
356
357 #[inline(always)]
358 fn euclidean_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
359 vec.zip(other)
360 .map(|(a, b)| {
361 let diff = a - b;
362 diff * diff
363 })
364 .sum::<f64>()
365 .sqrt()
366 }
367
368 #[inline(always)]
369 fn manhattan_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
370 vec.zip(other)
371 .map(|(a, b)| (a - b).abs())
372 .sum()
373 }
374
375 #[inline(always)]
376 fn chebyshev_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
377 vec.zip(other)
378 .map(|(a, b)| (a - b).abs())
379 .max_by(|a, b| a.partial_cmp(b).unwrap())
380 .unwrap_or(0.0)
381 }
382}