tf_idf_vectorizer/vectorizer/compute/
compare.rs

1use num::{traits::MulAdd, Num};
2use std::cmp::Ordering;
3
4pub trait Compare<N>
5where
6    N: Num + Copy,
7{
8    /// dot積 
9    /// d(a, b) = Σ(a_i * b_i)
10    fn dot(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
11    /// コサイン類似度
12    /// cos(θ) = Σ(a_i * b_i) / (||a|| * ||b||)
13    /// ||a|| = sqrt(Σ(a_i^2))
14    fn cosine_similarity(vec: impl Iterator<Item = (usize, N)>, other: impl Iterator<Item = (usize, N)>) -> f64;
15    /// ユークリッド距離
16    /// d(a, b) = sqrt(Σ((a_i - b_i)^2))
17    fn euclidean_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
18    /// マンハッタン距離
19    /// d(a, b) = Σ(|a_i - b_i|)
20    fn manhattan_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
21    /// チェビシェフ距離
22    /// d(a, b) = max(|a_i - b_i|)
23    fn chebyshev_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
24}
25
/// Stateless marker type that carries the built-in `Compare` implementations.
///
/// It has no fields, so it is free to copy, compare and default-construct.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultCompare;
28
29/// impl Compare for u8, u16, u32, f32, f64
30impl Compare<u8> for DefaultCompare {
31    #[inline(always)]
32    fn dot(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
33        // NOTE: 元は u32::MAX で割っていたため常に 0 になり比較不能だった。
34        // u8 の量子化レンジに合わせ u8::MAX を使用。
35        let max = u8::MAX as u32;
36        vec.zip(other)
37            .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
38            .sum()
39    }
40
41    #[inline(always)]
42    fn cosine_similarity(vec: impl Iterator<Item = (usize, u8)>, other: impl Iterator<Item = (usize, u8)>) -> f64 {
43        let max = u8::MAX as u32;
44        let mut a_it = vec.fuse();
45        let mut b_it = other.fuse();
46        let mut a_next = a_it.next();
47        let mut b_next = b_it.next();
48        let mut norm_a = 0_f64;
49        let mut norm_b = 0_f64;
50        let mut dot = 0_f64;
51        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
52            match ia.cmp(&ib) {
53                Ordering::Equal => {
54                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
55                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
56                    dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
57                    a_next = a_it.next();
58                    b_next = b_it.next();
59                }
60                Ordering::Less => {
61                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
62                    a_next = a_it.next();
63                }
64                Ordering::Greater => {
65                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
66                    b_next = b_it.next();
67                }
68            }
69        }
70        while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
71        while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
72        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
73    }
74
75    #[inline(always)]
76    fn euclidean_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
77        vec.zip(other)
78            .map(|(a, b)| {
79                let diff = a as i32 - b as i32;
80                (diff * diff) as f64
81            })
82            .sum::<f64>()
83            .sqrt()
84    }
85
86    #[inline(always)]
87    fn manhattan_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
88        vec.zip(other)
89            .map(|(a, b)| {
90                let diff = a as i32 - b as i32;
91                diff.abs() as f64
92            })
93            .sum()
94    }
95
96    #[inline(always)]
97    fn chebyshev_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
98        vec.zip(other)
99            .map(|(a, b)| {
100                let diff = a as i32 - b as i32;
101                diff.abs() as f64
102            })
103            .max_by(|a, b| a.partial_cmp(b).unwrap())
104            .unwrap_or(0.0)
105    }
106}
107
108impl Compare<u16> for DefaultCompare {
109    #[inline(always)]
110    fn dot(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
111        // u16 でも正しいスケール (u16::MAX) を使用。
112        let max = u16::MAX as u32;
113        vec.zip(other)
114            .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
115            .sum()
116    }
117
118    #[inline(always)]
119    fn cosine_similarity(vec: impl Iterator<Item = (usize, u16)>, other: impl Iterator<Item = (usize, u16)>) -> f64 {
120        let max = u16::MAX as u32;
121        let mut a_it = vec.fuse();
122        let mut b_it = other.fuse();
123        let mut a_next = a_it.next();
124        let mut b_next = b_it.next();
125        let mut norm_a = 0_f64;
126        let mut norm_b = 0_f64;
127        let mut dot = 0_f64;
128        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
129            match ia.cmp(&ib) {
130                Ordering::Equal => {
131                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
132                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
133                    dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
134                    a_next = a_it.next();
135                    b_next = b_it.next();
136                }
137                Ordering::Less => { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
138                Ordering::Greater => { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
139            }
140        }
141        while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
142        while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
143        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
144    }
145
146    /// ?
147    #[inline(always)]
148    fn euclidean_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
149        vec.zip(other)
150            .map(|(a, b)| {
151                let diff = a as i32 - b as i32;
152                (diff * diff) as f64
153            })
154            .sum::<f64>()
155            .sqrt()
156    }
157
158    /// ?
159    #[inline(always)]
160    fn manhattan_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
161        vec.zip(other)
162            .map(|(a, b)| {
163                let diff = a as i32 - b as i32;
164                diff.abs() as f64
165            })
166            .sum()
167    }
168
169    /// ?
170    #[inline(always)]
171    fn chebyshev_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
172        vec.zip(other)
173            .map(|(a, b)| {
174                let diff = a as i32 - b as i32;
175                diff.abs() as f64
176            })
177            .max_by(|a, b| a.partial_cmp(b).unwrap())
178            .unwrap_or(0.0)
179    }
180}
181
182impl Compare<u32> for DefaultCompare {
183    #[inline(always)]
184    fn dot(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
185        // u32 の a*b は 64bit を超える可能性があるので u128 で計算。
186        let max = u32::MAX as u128;
187        vec.zip(other)
188            .map(|(a, b)| {
189                let prod = (a as u128).mul_add(b as u128, max - 1);
190                (prod / max) as f64
191            })
192            .sum()
193    }
194
195    #[inline(always)]
196    fn cosine_similarity(vec: impl Iterator<Item = (usize, u32)>, other: impl Iterator<Item = (usize, u32)>) -> f64 {
197        let max = u32::MAX as u128;
198        let mut a_it = vec.fuse();
199        let mut b_it = other.fuse();
200        let mut a_next = a_it.next();
201        let mut b_next = b_it.next();
202        let mut norm_a = 0_f64;
203        let mut norm_b = 0_f64;
204        let mut dot = 0_f64;
205        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
206            match ia.cmp(&ib) {
207                Ordering::Equal => {
208                    let aa = (va as u128).mul_add(va as u128, max - 1);
209                    let bb = (vb as u128).mul_add(vb as u128, max - 1);
210                    let ab = (va as u128).mul_add(vb as u128, max - 1);
211                    norm_a += (aa / max) as f64;
212                    norm_b += (bb / max) as f64;
213                    dot += (ab / max) as f64;
214                    a_next = a_it.next();
215                    b_next = b_it.next();
216                }
217                Ordering::Less => {
218                    let aa = (va as u128).mul_add(va as u128, max - 1);
219                    norm_a += (aa / max) as f64;
220                    a_next = a_it.next();
221                }
222                Ordering::Greater => {
223                    let bb = (vb as u128).mul_add(vb as u128, max - 1);
224                    norm_b += (bb / max) as f64;
225                    b_next = b_it.next();
226                }
227            }
228        }
229        while let Some((_, va)) = a_next { let aa = (va as u128).mul_add(va as u128, max -1); norm_a += (aa / max) as f64; a_next = a_it.next(); }
230        while let Some((_, vb)) = b_next { let bb = (vb as u128).mul_add(vb as u128, max -1); norm_b += (bb / max) as f64; b_next = b_it.next(); }
231        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
232    }
233
234    #[inline(always)]
235    fn euclidean_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
236        vec.zip(other)
237            .map(|(a, b)| {
238                let diff = a as i64 - b as i64;
239                (diff * diff) as f64
240            })
241            .sum::<f64>()
242            .sqrt()
243    }
244
245    #[inline(always)]
246    fn manhattan_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
247        vec.zip(other)
248            .map(|(a, b)| {
249                let diff = a as i64 - b as i64;
250                diff.abs() as f64
251            })
252            .sum()
253    }
254
255    #[inline(always)]
256    fn chebyshev_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
257        vec.zip(other)
258            .map(|(a, b)| {
259                let diff = a as i64 - b as i64;
260                diff.abs() as f64
261            })
262            .max_by(|a, b| a.partial_cmp(b).unwrap())
263            .unwrap_or(0.0)
264    }
265}
266
267impl Compare<f32> for DefaultCompare {
268    #[inline(always)]
269    fn dot(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
270        let mut acc: f32 = 0.0;
271        for (a, b) in vec.zip(other) {
272            acc += a * b; // f32 のまま蓄積
273        }
274        acc as f64
275    }
276
277    #[inline(always)]
278    fn cosine_similarity(vec: impl Iterator<Item = (usize, f32)>, other: impl Iterator<Item = (usize, f32)>) -> f64 {
279        let mut a_it = vec.fuse();
280        let mut b_it = other.fuse();
281        let mut a_next = a_it.next();
282        let mut b_next = b_it.next();
283        let mut sum_a2: f64 = 0.0;
284        let mut sum_b2: f64 = 0.0;
285        let mut sum_ab: f64 = 0.0;
286        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
287            match ia.cmp(&ib) {
288                Ordering::Equal => { sum_a2 += (va * va) as f64; sum_b2 += (vb * vb) as f64; sum_ab += (va * vb) as f64; a_next = a_it.next(); b_next = b_it.next(); }
289                Ordering::Less => { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
290                Ordering::Greater => { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
291            }
292        }
293        while let Some((_, va)) = a_next { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
294        while let Some((_, vb)) = b_next { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
295        if sum_a2 == 0.0 || sum_b2 == 0.0 { 0.0 } else { sum_ab / (sum_a2.sqrt() * sum_b2.sqrt()) }
296    }
297
298    #[inline(always)]
299    fn chebyshev_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
300        let mut maxd: f32 = 0.0;
301        for (a, b) in vec.zip(other) {
302            let d = (a - b).abs();
303            if d > maxd { maxd = d; }
304        }
305        maxd as f64
306    }
307
308    #[inline(always)]
309    fn euclidean_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
310        let mut acc: f32 = 0.0;
311        for (a, b) in vec.zip(other) {
312            let d = a - b;
313            acc += d * d;
314        }
315        (acc.sqrt()) as f64
316    }
317
318    #[inline(always)]
319    fn manhattan_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
320        let mut acc: f32 = 0.0;
321        for (a, b) in vec.zip(other) {
322            acc += (a - b).abs();
323        }
324        acc as f64
325    }
326}
327
328impl Compare<f64> for DefaultCompare {
329    #[inline(always)]
330    fn dot(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
331        vec.zip(other)
332            .map(|(a, b)| a * b)
333            .sum()
334    }
335
336    #[inline(always)]
337    fn cosine_similarity(vec: impl Iterator<Item = (usize, f64)>, other: impl Iterator<Item = (usize, f64)>) -> f64 {
338        let mut a_it = vec.fuse();
339        let mut b_it = other.fuse();
340        let mut a_next = a_it.next();
341        let mut b_next = b_it.next();
342        let mut norm_a = 0_f64;
343        let mut norm_b = 0_f64;
344        let mut dot = 0_f64;
345        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
346            match ia.cmp(&ib) {
347                Ordering::Equal => { norm_a += va * va; norm_b += vb * vb; dot += va * vb; a_next = a_it.next(); b_next = b_it.next(); }
348                Ordering::Less => { norm_a += va * va; a_next = a_it.next(); }
349                Ordering::Greater => { norm_b += vb * vb; b_next = b_it.next(); }
350            }
351        }
352        while let Some((_, va)) = a_next { norm_a += va * va; a_next = a_it.next(); }
353        while let Some((_, vb)) = b_next { norm_b += vb * vb; b_next = b_it.next(); }
354        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
355    }
356
357    #[inline(always)]
358    fn euclidean_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
359        vec.zip(other)
360            .map(|(a, b)| {
361                let diff = a - b;
362                diff * diff
363            })
364            .sum::<f64>()
365            .sqrt()
366    }
367
368    #[inline(always)]
369    fn manhattan_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
370        vec.zip(other)
371            .map(|(a, b)| (a - b).abs())
372            .sum()
373    }
374
375    #[inline(always)]
376    fn chebyshev_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
377        vec.zip(other)
378            .map(|(a, b)| (a - b).abs())
379            .max_by(|a, b| a.partial_cmp(b).unwrap())
380            .unwrap_or(0.0)
381    }
382}