tf_idf_vectorizer/vectorizer/compute/
compare.rs

1use num::{traits::MulAdd, Num};
2use std::cmp::Ordering;
3
4pub trait Compare<N>
5where
6    N: Num + Copy,
7{
8    /// dot積 
9    /// d(a, b) = Σ(a_i * b_i)
10    fn dot(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
11    /// コサイン類似度
12    /// cos(θ) = Σ(a_i * b_i) / (||a|| * ||b||)
13    /// ||a|| = sqrt(Σ(a_i^2))
14    fn cosine_similarity(vec: impl Iterator<Item = (usize, N)>, other: impl Iterator<Item = (usize, N)>) -> f64;
15    /// ユークリッド距離
16    /// d(a, b) = sqrt(Σ((a_i - b_i)^2))
17    fn euclidean_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
18    /// マンハッタン距離
19    /// d(a, b) = Σ(|a_i - b_i|)
20    fn manhattan_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
21    /// チェビシェフ距離
22    /// d(a, b) = max(|a_i - b_i|)
23    fn chebyshev_distance(vec: impl Iterator<Item = N>, other: impl Iterator<Item = N>) -> f64;
24}
25
/// Stateless default implementation of [`Compare`] for `u8`, `u16`, `u32`,
/// `f32` and `f64` elements.
#[derive(Debug)]
pub struct DefaultCompare;
28
29/// impl Compare for u8, u16, u32, f32, f64
30impl Compare<u8> for DefaultCompare {
31    #[inline(always)]
32    fn dot(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
33        // NOTE: 元は u32::MAX で割っていたため常に 0 になり比較不能だった。
34        // u8 の量子化レンジに合わせ u8::MAX を使用。
35        let max = u8::MAX as u32;
36        // Normalize dot product to [0, 1] range for quantized u8 values.
37        vec.zip(other)
38            .map(|(a, b)| ((a as u32 * b as u32) as f64 / max as f64))
39            .sum()
40    }
41
42    #[inline(always)]
43    fn cosine_similarity(vec: impl Iterator<Item = (usize, u8)>, other: impl Iterator<Item = (usize, u8)>) -> f64 {
44        let max = u8::MAX as u32;
45        let mut a_it = vec.fuse();
46        let mut b_it = other.fuse();
47        let mut a_next = a_it.next();
48        let mut b_next = b_it.next();
49        let mut norm_a = 0_f64;
50        let mut norm_b = 0_f64;
51        let mut dot = 0_f64;
52        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
53            match ia.cmp(&ib) {
54                Ordering::Equal => {
55                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
56                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
57                    dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
58                    a_next = a_it.next();
59                    b_next = b_it.next();
60                }
61                Ordering::Less => {
62                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
63                    a_next = a_it.next();
64                }
65                Ordering::Greater => {
66                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
67                    b_next = b_it.next();
68                }
69            }
70        }
71        while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
72        while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
73        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
74    }
75
76    #[inline(always)]
77    fn euclidean_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
78        vec.zip(other)
79            .map(|(a, b)| {
80                let diff = a as i32 - b as i32;
81                (diff * diff) as f64
82            })
83            .sum::<f64>()
84            .sqrt()
85    }
86
87    #[inline(always)]
88    fn manhattan_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
89        vec.zip(other)
90            .map(|(a, b)| {
91                let diff = a as i32 - b as i32;
92                diff.abs() as f64
93            })
94            .sum()
95    }
96
97    #[inline(always)]
98    fn chebyshev_distance(vec: impl Iterator<Item = u8>, other: impl Iterator<Item = u8>) -> f64 {
99        vec.zip(other)
100            .map(|(a, b)| {
101                let diff = a as i32 - b as i32;
102                diff.abs() as f64
103            })
104            .max_by(|a, b| a.partial_cmp(b).unwrap())
105            .unwrap_or(0.0)
106    }
107}
108
109impl Compare<u16> for DefaultCompare {
110    #[inline(always)]
111    fn dot(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
112        // u16 でも正しいスケール (u16::MAX) を使用。
113        let max = u16::MAX as u32;
114        vec.zip(other)
115            .map(|(a, b)| (( (a as u32).mul_add(b as u32, max - 1) ) / max) as f64)
116            .sum()
117    }
118
119    #[inline(always)]
120    fn cosine_similarity(vec: impl Iterator<Item = (usize, u16)>, other: impl Iterator<Item = (usize, u16)>) -> f64 {
121        let max = u16::MAX as u32;
122        let mut a_it = vec.fuse();
123        let mut b_it = other.fuse();
124        let mut a_next = a_it.next();
125        let mut b_next = b_it.next();
126        let mut norm_a = 0_f64;
127        let mut norm_b = 0_f64;
128        let mut dot = 0_f64;
129        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
130            match ia.cmp(&ib) {
131                Ordering::Equal => {
132                    norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64;
133                    norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64;
134                    dot += (((va as u32).mul_add(vb as u32, max - 1)) / max) as f64;
135                    a_next = a_it.next();
136                    b_next = b_it.next();
137                }
138                Ordering::Less => { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
139                Ordering::Greater => { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
140            }
141        }
142        while let Some((_, va)) = a_next { norm_a += (((va as u32).mul_add(va as u32, max - 1)) / max) as f64; a_next = a_it.next(); }
143        while let Some((_, vb)) = b_next { norm_b += (((vb as u32).mul_add(vb as u32, max - 1)) / max) as f64; b_next = b_it.next(); }
144        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
145    }
146
147    /// ?
148    #[inline(always)]
149    fn euclidean_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
150        vec.zip(other)
151            .map(|(a, b)| {
152                let diff = a as i32 - b as i32;
153                (diff * diff) as f64
154            })
155            .sum::<f64>()
156            .sqrt()
157    }
158
159    /// ?
160    #[inline(always)]
161    fn manhattan_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
162        vec.zip(other)
163            .map(|(a, b)| {
164                let diff = a as i32 - b as i32;
165                diff.abs() as f64
166            })
167            .sum()
168    }
169
170    /// ?
171    #[inline(always)]
172    fn chebyshev_distance(vec: impl Iterator<Item = u16>, other: impl Iterator<Item = u16>) -> f64 {
173        vec.zip(other)
174            .map(|(a, b)| {
175                let diff = a as i32 - b as i32;
176                diff.abs() as f64
177            })
178            .max_by(|a, b| a.partial_cmp(b).unwrap())
179            .unwrap_or(0.0)
180    }
181}
182
183impl Compare<u32> for DefaultCompare {
184    #[inline(always)]
185    fn dot(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
186        // u32 の a*b は 64bit を超える可能性があるので u128 で計算。
187        let max = u32::MAX as u128;
188        vec.zip(other)
189            .map(|(a, b)| {
190                let prod = (a as u128).mul_add(b as u128, max - 1);
191                (prod / max) as f64
192            })
193            .sum()
194    }
195
196    #[inline(always)]
197    fn cosine_similarity(vec: impl Iterator<Item = (usize, u32)>, other: impl Iterator<Item = (usize, u32)>) -> f64 {
198        let max = u32::MAX as u128;
199        let mut a_it = vec.fuse();
200        let mut b_it = other.fuse();
201        let mut a_next = a_it.next();
202        let mut b_next = b_it.next();
203        let mut norm_a = 0_f64;
204        let mut norm_b = 0_f64;
205        let mut dot = 0_f64;
206        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
207            match ia.cmp(&ib) {
208                Ordering::Equal => {
209                    let aa = (va as u128).mul_add(va as u128, max - 1);
210                    let bb = (vb as u128).mul_add(vb as u128, max - 1);
211                    let ab = (va as u128).mul_add(vb as u128, max - 1);
212                    norm_a += (aa / max) as f64;
213                    norm_b += (bb / max) as f64;
214                    dot += (ab / max) as f64;
215                    a_next = a_it.next();
216                    b_next = b_it.next();
217                }
218                Ordering::Less => {
219                    let aa = (va as u128).mul_add(va as u128, max - 1);
220                    norm_a += (aa / max) as f64;
221                    a_next = a_it.next();
222                }
223                Ordering::Greater => {
224                    let bb = (vb as u128).mul_add(vb as u128, max - 1);
225                    norm_b += (bb / max) as f64;
226                    b_next = b_it.next();
227                }
228            }
229        }
230        while let Some((_, va)) = a_next { let aa = (va as u128).mul_add(va as u128, max -1); norm_a += (aa / max) as f64; a_next = a_it.next(); }
231        while let Some((_, vb)) = b_next { let bb = (vb as u128).mul_add(vb as u128, max -1); norm_b += (bb / max) as f64; b_next = b_it.next(); }
232        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
233    }
234
235    #[inline(always)]
236    fn euclidean_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
237        vec.zip(other)
238            .map(|(a, b)| {
239                let diff = a as i64 - b as i64;
240                (diff * diff) as f64
241            })
242            .sum::<f64>()
243            .sqrt()
244    }
245
246    #[inline(always)]
247    fn manhattan_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
248        vec.zip(other)
249            .map(|(a, b)| {
250                let diff = a as i64 - b as i64;
251                diff.abs() as f64
252            })
253            .sum()
254    }
255
256    #[inline(always)]
257    fn chebyshev_distance(vec: impl Iterator<Item = u32>, other: impl Iterator<Item = u32>) -> f64 {
258        vec.zip(other)
259            .map(|(a, b)| {
260                let diff = a as i64 - b as i64;
261                diff.abs() as f64
262            })
263            .max_by(|a, b| a.partial_cmp(b).unwrap())
264            .unwrap_or(0.0)
265    }
266}
267
268impl Compare<f32> for DefaultCompare {
269    #[inline(always)]
270    fn dot(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
271        let mut acc: f32 = 0.0;
272        for (a, b) in vec.zip(other) {
273            acc += a * b; // f32 のまま蓄積
274        }
275        acc as f64
276    }
277
278    #[inline(always)]
279    fn cosine_similarity(vec: impl Iterator<Item = (usize, f32)>, other: impl Iterator<Item = (usize, f32)>) -> f64 {
280        let mut a_it = vec.fuse();
281        let mut b_it = other.fuse();
282        let mut a_next = a_it.next();
283        let mut b_next = b_it.next();
284        let mut sum_a2: f64 = 0.0;
285        let mut sum_b2: f64 = 0.0;
286        let mut sum_ab: f64 = 0.0;
287        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
288            match ia.cmp(&ib) {
289                Ordering::Equal => { sum_a2 += (va * va) as f64; sum_b2 += (vb * vb) as f64; sum_ab += (va * vb) as f64; a_next = a_it.next(); b_next = b_it.next(); }
290                Ordering::Less => { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
291                Ordering::Greater => { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
292            }
293        }
294        while let Some((_, va)) = a_next { sum_a2 += (va * va) as f64; a_next = a_it.next(); }
295        while let Some((_, vb)) = b_next { sum_b2 += (vb * vb) as f64; b_next = b_it.next(); }
296        if sum_a2 == 0.0 || sum_b2 == 0.0 { 0.0 } else { sum_ab / (sum_a2.sqrt() * sum_b2.sqrt()) }
297    }
298
299    #[inline(always)]
300    fn chebyshev_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
301        let mut maxd: f32 = 0.0;
302        for (a, b) in vec.zip(other) {
303            let d = (a - b).abs();
304            if d > maxd { maxd = d; }
305        }
306        maxd as f64
307    }
308
309    #[inline(always)]
310    fn euclidean_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
311        let mut acc: f32 = 0.0;
312        for (a, b) in vec.zip(other) {
313            let d = a - b;
314            acc += d * d;
315        }
316        (acc.sqrt()) as f64
317    }
318
319    #[inline(always)]
320    fn manhattan_distance(vec: impl Iterator<Item = f32>, other: impl Iterator<Item = f32>) -> f64 {
321        let mut acc: f32 = 0.0;
322        for (a, b) in vec.zip(other) {
323            acc += (a - b).abs();
324        }
325        acc as f64
326    }
327}
328
329impl Compare<f64> for DefaultCompare {
330    #[inline(always)]
331    fn dot(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
332        vec.zip(other)
333            .map(|(a, b)| a * b)
334            .sum()
335    }
336
337    #[inline(always)]
338    fn cosine_similarity(vec: impl Iterator<Item = (usize, f64)>, other: impl Iterator<Item = (usize, f64)>) -> f64 {
339        let mut a_it = vec.fuse();
340        let mut b_it = other.fuse();
341        let mut a_next = a_it.next();
342        let mut b_next = b_it.next();
343        let mut norm_a = 0_f64;
344        let mut norm_b = 0_f64;
345        let mut dot = 0_f64;
346        while let (Some((ia, va)), Some((ib, vb))) = (a_next, b_next) {
347            match ia.cmp(&ib) {
348                Ordering::Equal => { norm_a += va * va; norm_b += vb * vb; dot += va * vb; a_next = a_it.next(); b_next = b_it.next(); }
349                Ordering::Less => { norm_a += va * va; a_next = a_it.next(); }
350                Ordering::Greater => { norm_b += vb * vb; b_next = b_it.next(); }
351            }
352        }
353        while let Some((_, va)) = a_next { norm_a += va * va; a_next = a_it.next(); }
354        while let Some((_, vb)) = b_next { norm_b += vb * vb; b_next = b_it.next(); }
355        if norm_a == 0.0 || norm_b == 0.0 { 0.0 } else { dot / (norm_a.sqrt() * norm_b.sqrt()) }
356    }
357
358    #[inline(always)]
359    fn euclidean_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
360        vec.zip(other)
361            .map(|(a, b)| {
362                let diff = a - b;
363                diff * diff
364            })
365            .sum::<f64>()
366            .sqrt()
367    }
368
369    #[inline(always)]
370    fn manhattan_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
371        vec.zip(other)
372            .map(|(a, b)| (a - b).abs())
373            .sum()
374    }
375
376    #[inline(always)]
377    fn chebyshev_distance(vec: impl Iterator<Item = f64>, other: impl Iterator<Item = f64>) -> f64 {
378        vec.zip(other)
379            .map(|(a, b)| (a - b).abs())
380            .max_by(|a, b| a.partial_cmp(b).unwrap())
381            .unwrap_or(0.0)
382    }
383}