1use ndarray::{self, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
4use ndarray_linalg::{lobpcg::TruncatedOrder, TruncatedSvd};
5
6use crate::Float;
7
8pub fn cosine_similarity<S, T>(a: &ArrayBase<S, Ix1>, b: &ArrayBase<T, Ix1>) -> Option<Float>
10where
11 S: Data<Elem = Float>,
12 T: Data<Elem = Float>,
13{
14 let dot_product = a.dot(b);
15 let norm_a = a.dot(a).sqrt();
16 let norm_b = b.dot(b).sqrt();
17 if norm_a == 0. || norm_b == 0. {
18 None
19 } else {
20 Some(dot_product / (norm_a * norm_b))
21 }
22}
23
24const SVD_MAX_ITER: usize = 7;
27
28pub(crate) fn principal_components<S>(
51 vectors: &ArrayBase<S, Ix2>,
52 n_components: usize,
53) -> (Array1<Float>, Array2<Float>)
54where
55 S: Data<Elem = Float>,
56{
57 debug_assert_ne!(n_components, 0);
58 debug_assert!(!vectors.iter().any(|&x| x.is_nan()));
59
60 let n_components = n_components.min(vectors.ncols()).min(vectors.nrows());
61 let svd = TruncatedSvd::new(vectors.to_owned(), TruncatedOrder::Largest)
62 .maxiter(SVD_MAX_ITER)
63 .decompose(n_components)
64 .unwrap();
65 let (_, s, vt) = svd.values_vectors();
66 (s, vt)
67}
68
69pub(crate) fn remove_principal_components<S>(
83 vectors: &ArrayBase<S, Ix2>,
84 components: &ArrayBase<S, Ix2>,
85 weights: Option<&ArrayBase<S, Ix1>>,
86) -> Array2<Float>
87where
88 S: Data<Elem = Float>,
89{
90 debug_assert!(!components.is_empty());
93 debug_assert_eq!(vectors.ncols(), components.ncols());
94
95 let weighted_components = weights.map_or_else(
97 || components.to_owned(),
98 |weights| {
99 debug_assert_eq!(components.nrows(), weights.len());
100 let weights = weights.to_owned().insert_axis(Axis(1));
101 components * &weights
102 },
103 );
104
105 let projection = vectors
110 .dot(&weighted_components.t())
111 .dot(&weighted_components);
112 vectors.to_owned() - &projection
113}
114
115pub(crate) fn sample_sentences<'a, S>(sentences: &'a [S], sample_size: usize) -> Vec<&'a str>
117where
118 S: AsRef<str> + 'a,
119{
120 let n_sentences = sentences.len();
121 if n_sentences <= sample_size {
122 sentences.iter().map(|s| s.as_ref()).collect()
123 } else {
124 let indices = rand::seq::index::sample(&mut rand::thread_rng(), n_sentences, sample_size);
125 indices.into_iter().map(|i| sentences[i].as_ref()).collect()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// 7x5 matrix of rank 3 shared by the principal-component tests.
    fn rank3_vectors() -> Array2<Float> {
        ndarray::arr2(&[
            [1., 1., 1., 0., 0.],
            [3., 3., 3., 0., 0.],
            [4., 4., 4., 0., 0.],
            [5., 5., 5., 0., 0.],
            [0., 2., 0., 4., 4.],
            [0., 0., 0., 5., 5.],
            [0., 1., 0., 2., 2.],
        ])
    }

    #[test]
    fn test_cosine_similarity() {
        let a = ndarray::arr1(&[3., 4.]);
        assert_eq!(cosine_similarity(&a, &a), Some(1.));

        let x = ndarray::arr1(&[1., 0.]);
        let y = ndarray::arr1(&[0., 1.]);
        assert_eq!(cosine_similarity(&x, &y), Some(0.));

        let opposite = ndarray::arr1(&[-2., 0.]);
        assert_eq!(cosine_similarity(&x, &opposite), Some(-1.));
    }

    #[test]
    fn test_cosine_similarity_zero_norm() {
        let zero = ndarray::arr1(&[0., 0.]);
        let x = ndarray::arr1(&[1., 0.]);
        assert_eq!(cosine_similarity(&zero, &x), None);
        assert_eq!(cosine_similarity(&x, &zero), None);
    }

    #[test]
    fn test_principal_components_k1() {
        let (s, vt) = principal_components(&rank3_vectors(), 1);
        assert_eq!(s.shape(), &[1]);
        assert_eq!(vt.shape(), &[1, 5]);
    }

    #[test]
    fn test_principal_components_k2() {
        let (s, vt) = principal_components(&rank3_vectors(), 2);
        assert_eq!(s.shape(), &[2]);
        assert_eq!(vt.shape(), &[2, 5]);
    }

    #[test]
    fn test_principal_components_k10() {
        // More components than the rank: only the 3 non-trivial ones survive.
        let (s, vt) = principal_components(&rank3_vectors(), 10);
        assert_eq!(s.shape(), &[3]);
        assert_eq!(vt.shape(), &[3, 5]);
    }

    #[test]
    fn test_principal_components_zeros() {
        let vectors = ndarray::arr2(&[
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
        ]);
        let (s, vt) = principal_components(&vectors, 5);
        assert_eq!(s.shape(), &[0]);
        assert_eq!(vt.shape(), &[0, 5]);
    }

    #[test]
    fn test_remove_principal_components_k1() {
        let components = ndarray::arr2(&[[1., 1., 1., 0., 0.]]);
        let weights = ndarray::arr1(&[1.]);
        let result = remove_principal_components(&rank3_vectors(), &components, Some(&weights));
        assert_eq!(result.shape(), &[7, 5]);
    }

    #[test]
    fn test_remove_principal_components_k3() {
        let components = ndarray::arr2(&[
            [1., 1., 1., 0., 0.],
            [1., 2., 3., 4., 5.],
            [0., 1., 0., 3., 3.],
        ]);
        let weights = ndarray::arr1(&[1., 2., 4.]);
        let result = remove_principal_components(&rank3_vectors(), &components, Some(&weights));
        assert_eq!(result.shape(), &[7, 5]);
    }

    #[test]
    fn test_remove_principal_components_d1() {
        let vectors = ndarray::arr2(&[[1.], [2.], [3.]]);
        let components = ndarray::arr2(&[[1.]]);
        let weights = ndarray::arr1(&[1.]);
        let result = remove_principal_components(&vectors, &components, Some(&weights));
        assert_eq!(result.shape(), &[3, 1]);
        // Removing the only (unit) direction leaves nothing behind.
        assert!(result.iter().all(|&x| x == 0.));
    }

    #[test]
    fn test_remove_principal_components_unweighted_values() {
        let vectors = ndarray::arr2(&[[3., 4.], [1., 0.]]);
        let components = ndarray::arr2(&[[1., 0.]]);
        let result = remove_principal_components(&vectors, &components, None);
        let expected: Array2<Float> = ndarray::arr2(&[[0., 4.], [0., 0.]]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sample_sentences() {
        let sentences = vec!["a", "b", "c", "d", "e", "f", "g"];
        let sample_size = 3;
        let sampled = sample_sentences(&sentences, sample_size);
        assert_eq!(sampled.len(), sample_size);
        assert!(sampled.iter().all(|s| sentences.contains(s)));
        // Sampling is without replacement, so the picks are distinct.
        let unique: std::collections::HashSet<_> = sampled.iter().collect();
        assert_eq!(unique.len(), sample_size);
    }

    #[test]
    fn test_sample_sentences_small_input() {
        let sentences = vec!["a", "b"];
        assert_eq!(sample_sentences(&sentences, 2), vec!["a", "b"]);
        assert_eq!(sample_sentences(&sentences, 5), vec!["a", "b"]);
    }
}