tf_idf_vectorizer/utils/math/vector/math.rs

1use std::{ops::{AddAssign, MulAssign}, ptr};
2
3use num::Num;
4
5use super::ZeroSpVec;
6
7impl<N> ZeroSpVec<N> 
8where
9    N: Num + AddAssign + MulAssign
10{
11    /// ドット積を計算するメソッド
12    ///
13    /// # Arguments
14    /// * `other` - 他のベクトル
15    /// 
16    /// # Returns
17    /// * `N` - ドット積の結果
18    #[inline]
19    pub fn dot<R>(&self, other: &Self) -> R
20    where R: Num + AddAssign, N: Into<R> {
21        debug_assert_eq!(
22            self.len(),
23            other.len(),
24            "Vectors must be of the same length to compute dot product."
25        );
26    
27        let mut result: R = R::zero(); // Updated to use R::zero() directly
28    
29        let self_nnz = self.nnz();
30        let other_nnz = other.nnz();
31    
32        // nnz == 0なら返す
33        if self_nnz == 0 {
34            return result;
35        }
36    
37        unsafe {
38            let mut i = 0;
39            let mut j = 0;
40            // キャッシュしたポインタを用いる
41            let self_ind_ptr = self.ind_ptr();
42            let self_val_ptr = self.val_ptr();
43            let other_ind_ptr = other.ind_ptr();
44            let other_val_ptr = other.val_ptr();
45            
46            while i < self_nnz && j < other_nnz {
47                let self_ind = ptr::read(self_ind_ptr.add(i));
48                let other_ind = ptr::read(other_ind_ptr.add(j));
49                if self_ind == other_ind {
50                    let value = ptr::read(self_val_ptr.add(i)).into() * ptr::read(other_val_ptr.add(j)).into();
51                    result += value;
52                    i += 1;
53                    j += 1;
54                } else if self_ind < other_ind {
55                    i += 1;
56                } else {
57                    j += 1;
58                }
59            }
60        }
61        result
62    }
63
64    /// アダマール積を計算するメソッド
65    /// 
66    /// # Arguments
67    /// * `other` - 他のベクトル
68    /// 
69    /// # Returns
70    /// * `ZeroSpVec<N>` - アダマール積の結果
71    #[inline]
72    pub fn hadamard(&self, other: &Self) -> Self {
73        debug_assert_eq!(
74            self.len(),
75            other.len(),
76            "Vectors must be of the same length to compute hadamard product."
77        );
78
79        let min_nnz = self.nnz().min(other.nnz());
80        let mut result: ZeroSpVec<N> = ZeroSpVec::with_capacity(min_nnz);
81        result.len = self.len();
82
83        // nnz == 0 ならゼロ埋めで返す
84        if self.nnz() == 0 {
85            result.len = self.len();
86            return result;
87        }
88
89        unsafe {
90            let mut i = 0;
91            let mut j = 0;
92            while i < self.nnz() && j < other.nnz() {
93                let self_ind = ptr::read(self.ind_ptr().add(i));
94                let other_ind = ptr::read(other.ind_ptr().add(j));
95                if self_ind == other_ind {
96                    // 同じインデックスの要素を掛け算して加算
97                    let value = ptr::read(self.val_ptr().add(i)) * ptr::read(other.val_ptr().add(j));
98                    result.raw_push(self_ind, value);
99                    i += 1;
100                    j += 1;
101                } else if self_ind < other_ind {
102                    i += 1;
103                } else {
104                    j += 1;
105                }
106            }
107        }
108        result
109    }
110}