sif_embedding/
util.rs

1//! Utilities.
2
3use ndarray::{self, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
4use ndarray_linalg::{lobpcg::TruncatedOrder, TruncatedSvd};
5
6use crate::Float;
7
8/// Computes the cosine similarity in `[-1,1]`.
9pub fn cosine_similarity<S, T>(a: &ArrayBase<S, Ix1>, b: &ArrayBase<T, Ix1>) -> Option<Float>
10where
11    S: Data<Elem = Float>,
12    T: Data<Elem = Float>,
13{
14    let dot_product = a.dot(b);
15    let norm_a = a.dot(a).sqrt();
16    let norm_b = b.dot(b).sqrt();
17    if norm_a == 0. || norm_b == 0. {
18        None
19    } else {
20        Some(dot_product / (norm_a * norm_b))
21    }
22}
23
/// Iteration cap handed to the truncated SVD solver.
///
/// The solver's default `maxiter` takes a long time to converge, so a small
/// value is used instead
/// (cf. <https://github.com/oborchers/Fast_Sentence_Embeddings/blob/master/fse/models/utils.py>).
const SVD_MAX_ITER: usize = 7;
27
28/// Computes the principal components of the input matrix.
29///
30/// # Arguments
31///
32/// - `vectors`: 2D-array of shape `(n, m)`
33/// - `n_components`: Number of components
34///
35/// # Returns
36///
37/// - Singular values of shape `(k,)`
38/// - Right singular vectors of shape `(k, m)`
39///
40/// where `k` is the smaller one of `n_components` and `Rank(vectors)`.
41///
42/// # Complexities
43///
44/// For `m > n`,
45///
46/// * Time complexity: `O(2mn^2 + n^3 + n + mn) = O(m^3)`
47/// * Space complexity: `O(3n^2 + 3n + 2mn) = O(m^2)`
48///
49/// cf. https://arxiv.org/abs/1906.12085
50pub(crate) fn principal_components<S>(
51    vectors: &ArrayBase<S, Ix2>,
52    n_components: usize,
53) -> (Array1<Float>, Array2<Float>)
54where
55    S: Data<Elem = Float>,
56{
57    debug_assert_ne!(n_components, 0);
58    debug_assert!(!vectors.iter().any(|&x| x.is_nan()));
59
60    let n_components = n_components.min(vectors.ncols()).min(vectors.nrows());
61    let svd = TruncatedSvd::new(vectors.to_owned(), TruncatedOrder::Largest)
62        .maxiter(SVD_MAX_ITER)
63        .decompose(n_components)
64        .unwrap();
65    let (_, s, vt) = svd.values_vectors();
66    (s, vt)
67}
68
69/// Removes the principal components from the input vectors,
70/// returning the 2D-array of shape `(n, m)`.
71///
72/// # Arguments
73///
74/// - `vectors`: Sentence vectors to remove components from, of shape `(n, m)`
75/// - `components`: `k` principal components of shape `(k, m)`
76/// - `weights`: Weights of shape `(k,)`
77///
78/// # Complexities
79///
80/// * Time complexity: `O(nmk)`
81/// * Space complexity: `O(nm)`
82pub(crate) fn remove_principal_components<S>(
83    vectors: &ArrayBase<S, Ix2>,
84    components: &ArrayBase<S, Ix2>,
85    weights: Option<&ArrayBase<S, Ix1>>,
86) -> Array2<Float>
87where
88    S: Data<Elem = Float>,
89{
90    // Principal components can be empty if the input matrix is zero.
91    // But, it is not assumed in this crate.
92    debug_assert!(!components.is_empty());
93    debug_assert_eq!(vectors.ncols(), components.ncols());
94
95    // weighted_components of shape (k, m)
96    let weighted_components = weights.map_or_else(
97        || components.to_owned(),
98        |weights| {
99            debug_assert_eq!(components.nrows(), weights.len());
100            let weights = weights.to_owned().insert_axis(Axis(1));
101            components * &weights
102        },
103    );
104
105    // (n,m).dot((k,m).t()).dot((k,m) = (n,m)
106    //
107    // * Time complexity: O(nmk)
108    // * Space complexity: O(nm)
109    let projection = vectors
110        .dot(&weighted_components.t())
111        .dot(&weighted_components);
112    vectors.to_owned() - &projection
113}
114
115/// Time complexity: O(sample_size)
116pub(crate) fn sample_sentences<'a, S>(sentences: &'a [S], sample_size: usize) -> Vec<&'a str>
117where
118    S: AsRef<str> + 'a,
119{
120    let n_sentences = sentences.len();
121    if n_sentences <= sample_size {
122        sentences.iter().map(|s| s.as_ref()).collect()
123    } else {
124        let indices = rand::seq::index::sample(&mut rand::thread_rng(), n_sentences, sample_size);
125        indices.into_iter().map(|i| sentences[i].as_ref()).collect()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 7x5 fixture. Its rows span a 3-dimensional space, i.e. Rank(x) = 3.
    fn fixture_vectors() -> Array2<Float> {
        ndarray::arr2(&[
            [1., 1., 1., 0., 0.],
            [3., 3., 3., 0., 0.],
            [4., 4., 4., 0., 0.],
            [5., 5., 5., 0., 0.],
            [0., 2., 0., 4., 4.],
            [0., 0., 0., 5., 5.],
            [0., 1., 0., 2., 2.],
        ])
    }

    /// Decomposes `vectors` (which must have 5 columns) and checks that
    /// exactly `expected_k` components are returned.
    fn assert_component_shapes(vectors: &Array2<Float>, n_components: usize, expected_k: usize) {
        let (s, vt) = principal_components(vectors, n_components);
        assert_eq!(s.shape(), &[expected_k]);
        assert_eq!(vt.shape(), &[expected_k, 5]);
    }

    #[test]
    fn test_principal_components_k1() {
        assert_component_shapes(&fixture_vectors(), 1, 1);
    }

    #[test]
    fn test_principal_components_k2() {
        assert_component_shapes(&fixture_vectors(), 2, 2);
    }

    #[test]
    fn test_principal_components_k10() {
        // Rank(x) = 3, so asking for 10 components yields only 3.
        assert_component_shapes(&fixture_vectors(), 10, 3);
    }

    #[test]
    fn test_principal_components_zeros() {
        // Rank(x) = 0, so no component is returned at all.
        let vectors = ndarray::arr2(&[[0.; 5]; 4]);
        assert_component_shapes(&vectors, 5, 0);
    }

    #[test]
    fn test_remove_principal_components_k1() {
        let vectors = fixture_vectors();
        let components = ndarray::arr2(&[[1., 1., 1., 0., 0.]]);
        let weights = ndarray::arr1(&[1.]);
        let result = remove_principal_components(&vectors, &components, Some(&weights));
        assert_eq!(result.shape(), &[7, 5]);
    }

    #[test]
    fn test_remove_principal_components_k3() {
        let vectors = fixture_vectors();
        let components = ndarray::arr2(&[
            [1., 1., 1., 0., 0.],
            [1., 2., 3., 4., 5.],
            [0., 1., 0., 3., 3.],
        ]);
        let weights = ndarray::arr1(&[1., 2., 4.]);
        let result = remove_principal_components(&vectors, &components, Some(&weights));
        assert_eq!(result.shape(), &[7, 5]);
    }

    #[test]
    fn test_remove_principal_components_d1() {
        // Single-column input: every vector is one-dimensional.
        let vectors = ndarray::arr2(&[[1.], [2.], [3.]]);
        let components = ndarray::arr2(&[[1.]]);
        let weights = ndarray::arr1(&[1.]);
        let result = remove_principal_components(&vectors, &components, Some(&weights));
        assert_eq!(result.shape(), &[3, 1]);
    }

    #[test]
    fn test_sample_sentences() {
        let sentences = vec!["a", "b", "c", "d", "e", "f", "g"];
        let sample_size = 3;
        let sampled = sample_sentences(&sentences, sample_size);
        // The sample has the requested size and only contains known sentences.
        assert_eq!(sampled.len(), sample_size);
        assert!(sampled.iter().all(|s| sentences.contains(s)));
    }
}