tf_idf_vectorizer/utils/math/vector/
math.rs1use std::{cmp::Ordering, ops::{AddAssign, MulAssign}, ptr};
2
3use num::Num;
4
5use super::ZeroSpVec;
6
7impl<N> ZeroSpVec<N>
8where
9 N: Num + AddAssign + MulAssign
10{
11 #[inline(always)]
19 pub fn dot<R>(&self, other: &Self) -> R
20 where
21 R: Num + AddAssign,
22 N: Into<R> + Copy,
23 {
24 debug_assert_eq!(
25 self.len(),
26 other.len(),
27 "Vectors must be of the same length to compute dot product."
28 );
29
30 let mut result = R::zero();
31 let self_nnz = self.nnz();
32 let other_nnz = other.nnz();
33
34 if self_nnz == 0 || other_nnz == 0 {
35 return result;
36 }
37
38 unsafe {
39 let self_inds = std::slice::from_raw_parts(self.ind_ptr(), self_nnz);
40 let self_vals = std::slice::from_raw_parts(self.val_ptr(), self_nnz);
41 let other_inds = std::slice::from_raw_parts(other.ind_ptr(), other_nnz);
42 let other_vals = std::slice::from_raw_parts(other.val_ptr(), other_nnz);
43
44 let mut i = 0;
45 let mut j = 0;
46
47 while i < self_nnz && j < other_nnz {
48 match self_inds[i].cmp(&other_inds[j]) {
49 Ordering::Equal => {
50 result += self_vals[i].into() * other_vals[j].into();
51 i += 1;
52 j += 1;
53 }
54 Ordering::Less => i += 1,
55 Ordering::Greater => j += 1,
56 }
57 }
58 }
59
60 result
61 }
62
63 #[inline(always)]
64 pub fn norm_sq<R>(&self) -> R
65 where
66 R: Num + AddAssign + Copy,
67 N: Into<R> + Copy,
68 {
69 let mut result = R::zero();
70 let self_nnz = self.nnz();
71
72 if self_nnz == 0 {
73 return result;
74 }
75
76 unsafe {
77 let self_val_ptr = self.val_ptr();
78
79 for i in 0..self_nnz {
80 let val: R = (*self_val_ptr.add(i)).into();
81 result += val * val;
82 }
83 }
84
85 result
86 }
87
88 #[inline]
96 pub fn hadamard(&self, other: &Self) -> Self {
97 debug_assert_eq!(
98 self.len(),
99 other.len(),
100 "Vectors must be of the same length to compute hadamard product."
101 );
102
103 let min_nnz = self.nnz().min(other.nnz());
104 let mut result: ZeroSpVec<N> = ZeroSpVec::with_capacity(min_nnz);
105 result.len = self.len();
106
107 if self.nnz() == 0 {
109 result.len = self.len();
110 return result;
111 }
112
113 unsafe {
114 let mut i = 0;
115 let mut j = 0;
116 while i < self.nnz() && j < other.nnz() {
117 let self_ind = ptr::read(self.ind_ptr().add(i));
118 let other_ind = ptr::read(other.ind_ptr().add(j));
119 if self_ind == other_ind {
120 let value = ptr::read(self.val_ptr().add(i)) * ptr::read(other.val_ptr().add(j));
122 result.raw_push(self_ind, value);
123 i += 1;
124 j += 1;
125 } else if self_ind < other_ind {
126 i += 1;
127 } else {
128 j += 1;
129 }
130 }
131 }
132 result
133 }
134}