sif_embedding/
sif.rs

1//! SIF: Smooth Inverse Frequency + Common Component Removal.
2use anyhow::{anyhow, Result};
3use ndarray::Array1;
4use ndarray::Array2;
5
6use crate::util;
7use crate::Float;
8use crate::SentenceEmbedder;
9use crate::WordEmbeddings;
10use crate::WordProbabilities;
11use crate::DEFAULT_N_SAMPLES_TO_FIT;
12use crate::DEFAULT_SEPARATOR;
13
14/// Default value of the SIF-weighting parameter `a`,
15/// following the original setting.
16pub const DEFAULT_PARAM_A: Float = 1e-3;
17
18/// Default value of the number of principal components to remove,
19/// following the original setting.
20pub const DEFAULT_N_COMPONENTS: usize = 1;
21
22const MODEL_MAGIC: &[u8] = b"sif_embedding::Sif 0.6\n";
23
24/// An implementation of SIF.
25///
26/// SIF is *Smooth Inverse Frequency* and *Common Component Removal*,
/// simple but powerful techniques for sentence embeddings described in the paper:
28/// Sanjeev Arora, Yingyu Liang, and Tengyu Ma,
29/// [A Simple but Tough-to-Beat Baseline for Sentence Embeddings](https://openreview.net/forum?id=SyK00v5xx),
30/// ICLR 2017.
31///
32/// # Brief description of API
33///
34/// The algorithm consists of two steps:
35///
36/// 1. Compute sentence embeddings with the SIF weighting.
37/// 2. Remove the common components from the sentence embeddings.
38///
39/// The common components are computed from input sentences.
40///
/// Our API is designed to allow reuse of the common components once they are computed,
/// because a sufficient number of query sentences is not always available to compute them.
43///
44/// [`Sif::fit`] computes the common components from input sentences and returns a fitted instance of [`Sif`].
45/// [`Sif::embeddings`] computes sentence embeddings with the fitted components.
46///
47/// # Examples
48///
49/// ```
50/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
51/// use std::io::BufReader;
52///
53/// use finalfusion::compat::text::ReadText;
54/// use finalfusion::embeddings::Embeddings;
55/// use wordfreq::WordFreq;
56///
57/// use sif_embedding::{Sif, SentenceEmbedder};
58///
59/// // Loads word embeddings from a pretrained model.
60/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
61/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
62/// let word_embeddings = Embeddings::read_text(&mut reader)?;
63///
64/// // Loads word probabilities from a pretrained model.
65/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
66///
67/// // Prepares input sentences.
68/// let sentences = ["las vegas", "mega vegas"];
69///
70/// // Fits the model with input sentences.
71/// let model = Sif::new(&word_embeddings, &word_probs);
72/// let model = model.fit(&sentences)?;
73///
74/// // Computes sentence embeddings in shape (n, m),
75/// // where n is the number of sentences and m is the number of dimensions.
76/// let sent_embeddings = model.embeddings(sentences)?;
77/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
78/// # Ok(())
79/// # }
80/// ```
81///
82/// ## Only SIF weighting
83///
84/// If you want to apply only the SIF weighting to avoid the computation of common components,
85/// use [`Sif::with_parameters`] and set `n_components` to `0`.
86/// In this case, you can skip [`Sif::fit`] and directly perform [`Sif::embeddings`]
87/// because there is no parameter to fit
88/// (although the quality of the embeddings may be worse).
89///
90/// ```
91/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
92/// use std::io::BufReader;
93///
94/// use finalfusion::compat::text::ReadText;
95/// use finalfusion::embeddings::Embeddings;
96/// use wordfreq::WordFreq;
97///
98/// use sif_embedding::{Sif, SentenceEmbedder};
99///
100/// // Loads word embeddings from a pretrained model.
101/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
102/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
103/// let word_embeddings = Embeddings::read_text(&mut reader)?;
104///
105/// // Loads word probabilities from a pretrained model.
106/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
107///
108/// // When setting `n_components` to `0`, no common components are removed, and
109/// // the sentence embeddings can be computed without `fit`.
110/// let model = Sif::with_parameters(&word_embeddings, &word_probs, 1e-3, 0)?;
111/// let sent_embeddings = model.embeddings(["las vegas", "mega vegas"])?;
112/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
113/// # Ok(())
114/// # }
115/// ```
116///
117/// ## Serialization of fitted parameters
118///
119/// If you want to serialize and deserialize the fitted parameters,
120/// use [`Sif::serialize`] and [`Sif::deserialize`].
121///
122/// ```
123/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
124/// use std::io::BufReader;
125///
126/// use approx::assert_relative_eq;
127/// use finalfusion::compat::text::ReadText;
128/// use finalfusion::embeddings::Embeddings;
129/// use wordfreq::WordFreq;
130///
131/// use sif_embedding::{Sif, SentenceEmbedder};
132///
133/// // Loads word embeddings from a pretrained model.
134/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
135/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
136/// let word_embeddings = Embeddings::read_text(&mut reader)?;
137///
138/// // Loads word probabilities from a pretrained model.
139/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
140///
141/// // Prepares input sentences.
142/// let sentences = ["las vegas", "mega vegas"];
143///
144/// // Fits the model and computes sentence embeddings.
145/// let model = Sif::new(&word_embeddings, &word_probs);
146/// let model = model.fit(&sentences)?;
147/// let sent_embeddings = model.embeddings(&sentences)?;
148///
149/// // Serializes and deserializes the fitted parameters.
150/// let bytes = model.serialize()?;
151/// let other = Sif::deserialize(&bytes, &word_embeddings, &word_probs)?;
152/// let other_embeddings = other.embeddings(&sentences)?;
153/// assert_relative_eq!(sent_embeddings, other_embeddings);
154/// # Ok(())
155/// # }
156/// ```
157#[derive(Clone)]
158pub struct Sif<'w, 'p, W, P> {
159    word_embeddings: &'w W,
160    word_probs: &'p P,
161    param_a: Float,
162    n_components: usize,
163    common_components: Option<Array2<Float>>,
164    separator: char,
165    n_samples_to_fit: usize,
166}
167
168impl<'w, 'p, W, P> Sif<'w, 'p, W, P>
169where
170    W: WordEmbeddings,
171    P: WordProbabilities,
172{
173    /// Creates a new instance with default parameters defined by
174    /// [`DEFAULT_PARAM_A`] and [`DEFAULT_N_COMPONENTS`].
175    ///
176    /// # Arguments
177    ///
178    /// * `word_embeddings` - Word embeddings.
179    /// * `word_probs` - Word probabilities.
180    pub const fn new(word_embeddings: &'w W, word_probs: &'p P) -> Self {
181        Self {
182            word_embeddings,
183            word_probs,
184            param_a: DEFAULT_PARAM_A,
185            n_components: DEFAULT_N_COMPONENTS,
186            common_components: None,
187            separator: DEFAULT_SEPARATOR,
188            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
189        }
190    }
191
192    /// Creates a new instance with manually specified parameters.
193    ///
194    /// # Arguments
195    ///
196    /// * `word_embeddings` - Word embeddings.
197    /// * `word_probs` - Word probabilities.
198    /// * `param_a` - A parameter `a` for SIF-weighting that should be positive.
199    /// * `n_components` - The number of principal components to remove.
200    ///
201    /// When setting `n_components` to `0`, no principal components are removed.
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if `param_a` is not positive.
206    pub fn with_parameters(
207        word_embeddings: &'w W,
208        word_probs: &'p P,
209        param_a: Float,
210        n_components: usize,
211    ) -> Result<Self> {
212        if param_a <= 0. {
213            return Err(anyhow!("param_a must be positive."));
214        }
215        Ok(Self {
216            word_embeddings,
217            word_probs,
218            param_a,
219            n_components,
220            common_components: None,
221            separator: DEFAULT_SEPARATOR,
222            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
223        })
224    }
225
226    /// Sets a separator for sentence segmentation (default: [`DEFAULT_SEPARATOR`]).
227    pub const fn separator(mut self, separator: char) -> Self {
228        self.separator = separator;
229        self
230    }
231
232    /// Sets the number of samples to fit the model (default: [`DEFAULT_N_SAMPLES_TO_FIT`]).
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if `n_samples_to_fit` is 0.
237    pub fn n_samples_to_fit(mut self, n_samples_to_fit: usize) -> Result<Self> {
238        if n_samples_to_fit == 0 {
239            return Err(anyhow!("n_samples_to_fit must not be 0."));
240        }
241        self.n_samples_to_fit = n_samples_to_fit;
242        Ok(self)
243    }
244
245    /// Applies SIF-weighting.
246    /// (Lines 1--3 in Algorithm 1)
247    ///
248    /// # Complexities
249    ///
250    /// * Time complexity: `O(avg_num_words * embedding_size * num_sentences)`
251    /// * Space complexity: `O(embedding_size * num_sentences)`
252    fn weighted_embeddings<I, S>(&self, sentences: I) -> Array2<Float>
253    where
254        I: IntoIterator<Item = S>,
255        S: AsRef<str>,
256    {
257        let mut sent_embeddings = vec![];
258        let mut n_sentences = 0;
259        // O(num_words * embedding_size * num_sentences)
260        for sent in sentences {
261            let sent = sent.as_ref();
262            let mut n_words = 0;
263            let mut sent_embedding = Array1::zeros(self.embedding_size());
264            // O(avg_num_words * embedding_size)
265            for word in sent.split(self.separator) {
266                if let Some(word_embedding) = self.word_embeddings.embedding(word) {
267                    let weight = self.param_a / (self.param_a + self.word_probs.probability(word));
268                    sent_embedding += &(word_embedding.to_owned() * weight);
269                    n_words += 1;
270                }
271            }
272            if n_words != 0 {
273                sent_embedding /= n_words as Float;
274            } else {
275                // If no parseable tokens, return a vector of a's
276                sent_embedding += self.param_a;
277            }
278            sent_embeddings.extend(sent_embedding.iter());
279            n_sentences += 1;
280        }
281        Array2::from_shape_vec((n_sentences, self.embedding_size()), sent_embeddings).unwrap()
282    }
283
284    /// Serializes the model.
285    pub fn serialize(&self) -> Result<Vec<u8>> {
286        let mut bytes = Vec::new();
287        bytes.extend_from_slice(MODEL_MAGIC);
288        bincode::serialize_into(&mut bytes, &self.param_a)?;
289        bincode::serialize_into(&mut bytes, &self.n_components)?;
290        bincode::serialize_into(&mut bytes, &self.common_components)?;
291        bincode::serialize_into(&mut bytes, &self.separator)?;
292        bincode::serialize_into(&mut bytes, &self.n_samples_to_fit)?;
293        Ok(bytes)
294    }
295
296    /// Deserializes the model.
297    ///
298    /// # Arguments
299    ///
300    /// * `bytes` - Byte sequence exported by [`Self::serialize`].
301    /// * `word_embeddings` - Word embeddings.
302    /// * `word_probs` - Word probabilities.
303    ///
304    /// `word_embeddings` and `word_probs` must be the same as those used in serialization.
305    pub fn deserialize(bytes: &[u8], word_embeddings: &'w W, word_probs: &'p P) -> Result<Self> {
306        if !bytes.starts_with(MODEL_MAGIC) {
307            return Err(anyhow!("The magic number of the input model mismatches."));
308        }
309        let mut bytes = &bytes[MODEL_MAGIC.len()..];
310        let param_a = bincode::deserialize_from(&mut bytes)?;
311        let n_components = bincode::deserialize_from(&mut bytes)?;
312        let common_components = bincode::deserialize_from(&mut bytes)?;
313        let separator = bincode::deserialize_from(&mut bytes)?;
314        let n_samples_to_fit = bincode::deserialize_from(&mut bytes)?;
315        Ok(Self {
316            word_embeddings,
317            word_probs,
318            param_a,
319            n_components,
320            common_components,
321            separator,
322            n_samples_to_fit,
323        })
324    }
325}
326
327impl<W, P> SentenceEmbedder for Sif<'_, '_, W, P>
328where
329    W: WordEmbeddings,
330    P: WordProbabilities,
331{
332    /// Returns the number of dimensions for sentence embeddings,
333    /// which is the same as the number of dimensions for word embeddings.
334    fn embedding_size(&self) -> usize {
335        self.word_embeddings.embedding_size()
336    }
337
338    /// Fits the model with input sentences.
339    ///
340    /// Sentences to fit are randomly sampled from `sentences` with [`Self::n_samples_to_fit`].
341    ///
342    /// If `n_components` is 0, does nothing and returns `self`.
343    ///
344    /// # Errors
345    ///
346    /// Returns an error if `sentences` is empty.
347    ///
348    /// # Complexities
349    ///
350    /// * Time complexity: `O(L*D*S + max(D,S)^3)`
351    /// * Space complexity: `O(D*S + max(D,S)^2)`
352    ///
353    /// where
354    ///
355    /// * `L` is the average number of words in a sentence.
356    /// * `D` is the number of dimensions for word embeddings (`embedding_size`).
357    /// * `S` is the number of sentences used to fit (`n_samples_to_fit`).
358    fn fit<S>(mut self, sentences: &[S]) -> Result<Self>
359    where
360        S: AsRef<str>,
361    {
362        if sentences.is_empty() {
363            return Err(anyhow!("Input sentences must not be empty."));
364        }
365        if self.n_components == 0 {
366            eprintln!("Warning: Nothing to fit since n_components is 0.");
367            return Ok(self);
368        }
369
370        // Time: O(n_samples_to_fit)
371        let sentences = util::sample_sentences(sentences, self.n_samples_to_fit);
372
373        // SIF-weighting.
374        //
375        // Time: O(avg_num_words * embedding_size * n_samples_to_fit)
376        // Space: O(embedding_size * n_samples_to_fit)
377        let sent_embeddings = self.weighted_embeddings(sentences);
378
379        // Common component removal.
380        //
381        // Time: O(max(embedding_size, n_samples_to_fit)^3)
382        // Space: O(max(embedding_size, n_samples_to_fit)^2)
383        let (_, common_components) =
384            util::principal_components(&sent_embeddings, self.n_components);
385        self.common_components = Some(common_components);
386
387        Ok(self)
388    }
389
390    /// Computes embeddings for input sentences using the fitted model.
391    ///
392    /// If `n_components` is 0, the fitting is not required.
393    ///
394    /// # Errors
395    ///
396    /// Returns an error if the model is not fitted.
397    ///
398    /// # Complexities
399    ///
400    /// * Time complexity: `O(L*D*N + C*D*N)`
401    /// * Space complexity: `O(D*N)`
402    ///
403    /// where
404    ///
405    /// * `L` is the average number of words in a sentence.
406    /// * `D` is the number of dimensions for word embeddings (`embedding_size`).
407    /// * `N` is the number of sentences (`sentences.len()`).
408    /// * `C` is the number of components to remove (`n_components`).
409    fn embeddings<I, S>(&self, sentences: I) -> Result<Array2<Float>>
410    where
411        I: IntoIterator<Item = S>,
412        S: AsRef<str>,
413    {
414        if self.n_components != 0 && self.common_components.is_none() {
415            return Err(anyhow!("The model is not fitted."));
416        }
417
418        // SIF-weighting.
419        //
420        // Time: O(avg_num_words * embedding_size * n_sentences)
421        // Space: O(embedding_size * n_sentences)
422        let sent_embeddings = self.weighted_embeddings(sentences);
423        if sent_embeddings.is_empty() {
424            return Ok(sent_embeddings);
425        }
426        if self.n_components == 0 {
427            return Ok(sent_embeddings);
428        }
429
430        // Common component removal.
431        //
432        // Time: O(embedding_size * n_sentences * n_components)
433        // Space: O(embedding_size * n_sentences)
434        let common_components = self.common_components.as_ref().unwrap();
435        let sent_embeddings =
436            util::remove_principal_components(&sent_embeddings, common_components, None);
437        Ok(sent_embeddings)
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use ndarray::{arr1, CowArray, Ix1};

    /// Toy vocabulary shared by both mocks: `(word, probability, embedding)`.
    static VOCAB: [(&str, Float, [Float; 3]); 4] = [
        ("A", 0.6, [1., 2., 3.]),
        ("BB", 0.2, [4., 5., 6.]),
        ("CCC", 0.1, [7., 8., 9.]),
        ("DDDD", 0.1, [10., 11., 12.]),
    ];

    /// Sentences used by most tests; they include unknown words and an empty sentence.
    const SENTENCES: [&str; 5] = ["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];

    /// Finds the vocabulary entry of `word`, if any.
    fn lookup(word: &str) -> Option<&'static (&'static str, Float, [Float; 3])> {
        VOCAB.iter().find(|entry| entry.0 == word)
    }

    /// Mock word embeddings backed by [`VOCAB`].
    struct SimpleWordEmbeddings;

    impl WordEmbeddings for SimpleWordEmbeddings {
        fn embedding(&self, word: &str) -> Option<CowArray<Float, Ix1>> {
            lookup(word).map(|(_, _, vector)| arr1(&vector[..]).into())
        }

        fn embedding_size(&self) -> usize {
            3
        }
    }

    /// Mock word probabilities backed by [`VOCAB`]; unknown words have probability 0.
    struct SimpleWordProbabilities;

    impl WordProbabilities for SimpleWordProbabilities {
        fn probability(&self, word: &str) -> Float {
            lookup(word).map_or(0., |&(_, prob, _)| prob)
        }

        fn n_words(&self) -> usize {
            4
        }

        fn entries(&self) -> Box<dyn Iterator<Item = (String, Float)> + '_> {
            Box::new(
                VOCAB
                    .iter()
                    .map(|&(word, prob, _)| (word.to_string(), prob)),
            )
        }
    }

    static WORD_EMBEDDINGS: SimpleWordEmbeddings = SimpleWordEmbeddings;
    static WORD_PROBS: SimpleWordProbabilities = SimpleWordProbabilities;

    type TestSif = Sif<'static, 'static, SimpleWordEmbeddings, SimpleWordProbabilities>;

    /// Builds an unfitted model with default parameters over the mocks.
    fn new_sif() -> TestSif {
        Sif::new(&WORD_EMBEDDINGS, &WORD_PROBS)
    }

    #[test]
    fn test_basic() {
        let sif = new_sif().fit(&SENTENCES).unwrap();

        let sent_embeddings = sif.embeddings(SENTENCES).unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((5, 3)));

        // No sentences: an empty matrix that still has the embedding width.
        let sent_embeddings = sif.embeddings(Vec::<&str>::new()).unwrap();
        assert_eq!(sent_embeddings.shape(), &[0, 3]);

        // A sentence without any known word still yields a non-zero row.
        let sent_embeddings = sif.embeddings([""]).unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((1, 3)));
    }

    #[test]
    fn test_separator() {
        let comma_separated = ["A,BB,CCC,DDDD", "BB,CCC", "A,B,C", "Z", ""];

        let sif = new_sif().fit(&SENTENCES).unwrap();
        let embeddings_1 = sif.embeddings(&SENTENCES).unwrap();

        // Only the separator changes; the fitted components are reused.
        let sif = sif.separator(',');
        let embeddings_2 = sif.embeddings(&comma_separated).unwrap();

        assert_relative_eq!(embeddings_1, embeddings_2);
    }

    #[test]
    fn test_invalid_param_a() {
        let sif = Sif::with_parameters(&WORD_EMBEDDINGS, &WORD_PROBS, 0., DEFAULT_N_COMPONENTS);
        assert!(sif.is_err());
    }

    #[test]
    fn test_no_fitted() {
        let embeddings = new_sif().embeddings(&SENTENCES);
        assert!(embeddings.is_err());
    }

    #[test]
    fn test_empty_fit() {
        let no_sentences: &[&str] = &[];
        let sif = new_sif().fit(no_sentences);
        assert!(sif.is_err());
    }

    #[test]
    fn test_io() {
        let model_a = new_sif().fit(&SENTENCES).unwrap();
        let bytes = model_a.serialize().unwrap();
        let model_b = Sif::deserialize(&bytes, &WORD_EMBEDDINGS, &WORD_PROBS).unwrap();

        let embeddings_a = model_a.embeddings(SENTENCES).unwrap();
        let embeddings_b = model_b.embeddings(SENTENCES).unwrap();

        assert_relative_eq!(embeddings_a, embeddings_b);
    }
}