// tf_idf_vectorizer/utils/math/vector/math.rs
use std::{cmp::Ordering, ops::{AddAssign, MulAssign}, ptr};

use num::Num;

use super::ZeroSpVec;

7impl<N> ZeroSpVec<N>
8where
9 N: Num + AddAssign + MulAssign
10{
11 #[inline]
19 pub fn dot<R>(&self, other: &Self) -> R
20 where
21 R: Num + AddAssign,
22 N: Into<R> + Copy,
23 {
24 debug_assert_eq!(
25 self.len(),
26 other.len(),
27 "Vectors must be of the same length to compute dot product."
28 );
29
30 let mut result = R::zero();
31 let self_nnz = self.nnz();
32 let other_nnz = other.nnz();
33
34 if self_nnz == 0 || other_nnz == 0 {
35 return result;
36 }
37
38 unsafe {
39 let self_inds = std::slice::from_raw_parts(self.ind_ptr(), self_nnz);
40 let self_vals = std::slice::from_raw_parts(self.val_ptr(), self_nnz);
41 let other_inds = std::slice::from_raw_parts(other.ind_ptr(), other_nnz);
42 let other_vals = std::slice::from_raw_parts(other.val_ptr(), other_nnz);
43
44 let mut i = 0;
45 let mut j = 0;
46
47 while i < self_nnz && j < other_nnz {
48 let s_ind = *self_inds.get_unchecked(i);
49 let o_ind = *other_inds.get_unchecked(j);
50 match s_ind.cmp(&o_ind) {
51 Ordering::Equal => {
52 result += self_vals[i].into() * other_vals[j].into();
53 i += 1;
54 j += 1;
55 }
56 Ordering::Less => i += 1,
57 Ordering::Greater => j += 1,
58 }
59 }
60 }
61
62 result
63 }
64
65 #[inline]
66 pub fn norm_sq<R>(&self) -> R
67 where
68 R: Num + AddAssign + Copy,
69 N: Into<R> + Copy,
70 {
71 let mut result = R::zero();
72 let self_nnz = self.nnz();
73
74 if self_nnz == 0 {
75 return result;
76 }
77
78 unsafe {
79 let self_val_ptr = self.val_ptr();
80
81 for i in 0..self_nnz {
82 let val: R = (*self_val_ptr.add(i)).into();
83 result += val * val;
84 }
85 }
86
87 result
88 }
89
90 #[inline]
98 pub fn hadamard(&self, other: &Self) -> Self {
99 debug_assert_eq!(
100 self.len(),
101 other.len(),
102 "Vectors must be of the same length to compute hadamard product."
103 );
104
105 let min_nnz = self.nnz().min(other.nnz());
106 let mut result: ZeroSpVec<N> = ZeroSpVec::with_capacity(min_nnz);
107 result.len = self.len();
108
109 if self.nnz() == 0 {
111 result.len = self.len();
112 return result;
113 }
114
115 unsafe {
116 let mut i = 0;
117 let mut j = 0;
118 while i < self.nnz() && j < other.nnz() {
119 let self_ind = ptr::read(self.ind_ptr().add(i));
120 let other_ind = ptr::read(other.ind_ptr().add(j));
121 if self_ind == other_ind {
122 let value = ptr::read(self.val_ptr().add(i)) * ptr::read(other.val_ptr().add(j));
124 result.raw_push(self_ind, value);
125 i += 1;
126 j += 1;
127 } else if self_ind < other_ind {
128 i += 1;
129 } else {
130 j += 1;
131 }
132 }
133 }
134 result
135 }
136}