sif_embedding/
usif.rs

1//! uSIF: Unsupervised Smooth Inverse Frequency + Piecewise Common Component Removal.
2use anyhow::{anyhow, Result};
3use ndarray::Array1;
4use ndarray::Array2;
5
6use crate::util;
7use crate::Float;
8use crate::SentenceEmbedder;
9use crate::WordEmbeddings;
10use crate::WordProbabilities;
11use crate::DEFAULT_N_SAMPLES_TO_FIT;
12use crate::DEFAULT_SEPARATOR;
13
/// Default value of the number of principal components,
/// following the original setting.
pub const DEFAULT_N_COMPONENTS: usize = 5;

// Coefficient applied to `param_a` in the denominator of the uSIF word weight
// (used by `weighted_embedding`).
const FLOAT_0_5: Float = 0.5;
// Header written at the start of every serialized model; `deserialize` rejects
// any input that does not begin with exactly these bytes.
const MODEL_MAGIC: &[u8] = b"sif_embedding::USif 0.6\n";
20
21/// An implementation of uSIF.
22///
23/// uSIF is *Unsupervised Smooth Inverse Frequency* and *Piecewise Common Component Removal*,
/// simple but powerful techniques for sentence embeddings described in the paper:
25/// Kawin Ethayarajh,
26/// [Unsupervised Random Walk Sentence Embeddings: A Strong but Simple Baseline](https://aclanthology.org/W18-3012/),
27/// RepL4NLP 2018.
28///
29/// # Brief description of API
30///
31/// The algorithm consists of two steps:
32///
33/// 1. Compute sentence embeddings with the uSIF weighting.
34/// 2. Remove the common components from the sentence embeddings.
35///
36/// The weighting parameter and common components are computed from input sentences.
37///
38/// Our API is designed to allow reuse of these values once computed
39/// because it is not always possible to obtain a sufficient number of sentences as queries to compute.
40///
41/// [`USif::fit`] computes these values from input sentences and returns a fitted instance of [`USif`].
42/// [`USif::embeddings`] computes sentence embeddings with the fitted values.
43///
44/// # Examples
45///
46/// ```
47/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
48/// use std::io::BufReader;
49///
50/// use finalfusion::compat::text::ReadText;
51/// use finalfusion::embeddings::Embeddings;
52/// use wordfreq::WordFreq;
53///
54/// use sif_embedding::{USif, SentenceEmbedder};
55///
56/// // Loads word embeddings from a pretrained model.
57/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
58/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
59/// let word_embeddings = Embeddings::read_text(&mut reader)?;
60///
61/// // Loads word probabilities from a pretrained model.
62/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
63///
64/// // Prepares input sentences.
65/// let sentences = ["las vegas", "mega vegas"];
66///
67/// // Fits the model with input sentences.
68/// let model = USif::new(&word_embeddings, &word_probs);
69/// let model = model.fit(&sentences)?;
70///
71/// // Computes sentence embeddings in shape (n, m),
72/// // where n is the number of sentences and m is the number of dimensions.
73/// let sent_embeddings = model.embeddings(sentences)?;
74/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
75/// # Ok(())
76/// # }
77/// ```
78///
79/// ## Only uSIF weighting
80///
81/// If you want to apply only the uSIF weighting to avoid the computation of common components,
82/// use [`USif::with_parameters`] and set `n_components` to `0`.
83///
84/// ```
85/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
86/// use std::io::BufReader;
87///
88/// use finalfusion::compat::text::ReadText;
89/// use finalfusion::embeddings::Embeddings;
90/// use wordfreq::WordFreq;
91///
92/// use sif_embedding::{USif, SentenceEmbedder};
93///
94/// // Loads word embeddings from a pretrained model.
95/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
96/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
97/// let word_embeddings = Embeddings::read_text(&mut reader)?;
98///
99/// // Loads word probabilities from a pretrained model.
100/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
101///
102/// // Prepares input sentences.
103/// let sentences = ["las vegas", "mega vegas"];
104///
105/// // When setting `n_components` to `0`, no common components are removed.
106/// let model = USif::with_parameters(&word_embeddings, &word_probs, 0);
107/// let model = model.fit(&sentences)?;
108/// let sent_embeddings = model.embeddings(sentences)?;
109/// assert_eq!(sent_embeddings.shape(), &[2, 3]);
110/// # Ok(())
111/// # }
112/// ```
113///
114/// ## Serialization of fitted parameters
115///
116/// If you want to serialize and deserialize the fitted parameters,
117/// use [`USif::serialize`] and [`USif::deserialize`].
118///
119/// ```
120/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
121/// use std::io::BufReader;
122///
123/// use approx::assert_relative_eq;
124/// use finalfusion::compat::text::ReadText;
125/// use finalfusion::embeddings::Embeddings;
126/// use wordfreq::WordFreq;
127///
128/// use sif_embedding::{USif, SentenceEmbedder};
129///
130/// // Loads word embeddings from a pretrained model.
131/// let word_embeddings_text = "las 0.0 1.0 2.0\nvegas -3.0 -4.0 -5.0\n";
132/// let mut reader = BufReader::new(word_embeddings_text.as_bytes());
133/// let word_embeddings = Embeddings::read_text(&mut reader)?;
134///
135/// // Loads word probabilities from a pretrained model.
136/// let word_probs = WordFreq::new([("las", 0.4), ("vegas", 0.6)]);
137///
138/// // Prepares input sentences.
139/// let sentences = ["las vegas", "mega vegas"];
140///
141/// // Fits the model and computes sentence embeddings.
142/// let model = USif::new(&word_embeddings, &word_probs);
143/// let model = model.fit(&sentences)?;
144/// let sent_embeddings = model.embeddings(&sentences)?;
145///
146/// // Serializes and deserializes the fitted parameters.
147/// let bytes = model.serialize()?;
148/// let other = USif::deserialize(&bytes, &word_embeddings, &word_probs)?;
149/// let other_embeddings = other.embeddings(&sentences)?;
150/// assert_relative_eq!(sent_embeddings, other_embeddings);
151/// # Ok(())
152/// # }
153/// ```
#[derive(Clone)]
pub struct USif<'w, 'p, W, P> {
    // Borrowed source of per-word embedding vectors.
    word_embeddings: &'w W,
    // Borrowed source of word probabilities, used for weighting and for estimating `param_a`.
    word_probs: &'p P,
    // Number of principal components to remove; `0` skips common component removal.
    n_components: usize,
    // Weighting parameter `a`; `None` until `fit` has been called.
    param_a: Option<Float>,
    // Normalized squared singular values; set by `fit` only when `n_components != 0`.
    weights: Option<Array1<Float>>,
    // Principal components to remove; set by `fit` only when `n_components != 0`.
    common_components: Option<Array2<Float>>,
    // Character on which sentences are split into words.
    separator: char,
    // Sample-size limit handed to `util::sample_sentences` in `fit`.
    n_samples_to_fit: usize,
}
165
166impl<'w, 'p, W, P> USif<'w, 'p, W, P>
167where
168    W: WordEmbeddings,
169    P: WordProbabilities,
170{
171    /// Creates a new instance with default parameters defined by
172    /// [`DEFAULT_N_COMPONENTS`].
173    ///
174    /// # Arguments
175    ///
176    /// * `word_embeddings` - Word embeddings.
177    /// * `word_probs` - Word probabilities.
178    pub const fn new(word_embeddings: &'w W, word_probs: &'p P) -> Self {
179        Self {
180            word_embeddings,
181            word_probs,
182            n_components: DEFAULT_N_COMPONENTS,
183            param_a: None,
184            weights: None,
185            common_components: None,
186            separator: DEFAULT_SEPARATOR,
187            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
188        }
189    }
190
191    /// Creates a new instance with manually specified parameters.
192    ///
193    /// # Arguments
194    ///
195    /// * `word_embeddings` - Word embeddings.
196    /// * `word_probs` - Word probabilities.
197    /// * `n_components` - The number of principal components to remove.
198    ///
199    /// When setting `n_components` to `0`, no principal components are removed.
200    pub const fn with_parameters(
201        word_embeddings: &'w W,
202        word_probs: &'p P,
203        n_components: usize,
204    ) -> Self {
205        Self {
206            word_embeddings,
207            word_probs,
208            n_components,
209            param_a: None,
210            weights: None,
211            common_components: None,
212            separator: DEFAULT_SEPARATOR,
213            n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
214        }
215    }
216
217    /// Sets a separator for sentence segmentation (default: [`DEFAULT_SEPARATOR`]).
218    pub const fn separator(mut self, separator: char) -> Self {
219        self.separator = separator;
220        self
221    }
222
223    /// Sets the number of samples to fit the model (default: [`DEFAULT_N_SAMPLES_TO_FIT`]).
224    ///
225    /// # Errors
226    ///
227    /// Returns an error if `n_samples_to_fit` is 0.
228    pub fn n_samples_to_fit(mut self, n_samples_to_fit: usize) -> Result<Self> {
229        if n_samples_to_fit == 0 {
230            return Err(anyhow!("n_samples_to_fit must not be 0."));
231        }
232        self.n_samples_to_fit = n_samples_to_fit;
233        Ok(self)
234    }
235
236    /// Computes the average length of sentences.
237    /// (Line 3 in Algorithm 1)
238    fn average_sentence_length<S>(&self, sentences: &[S]) -> Float
239    where
240        S: AsRef<str>,
241    {
242        let mut n_words = 0;
243        for sent in sentences {
244            let sent = sent.as_ref();
245            n_words += sent.split(self.separator).count();
246        }
247        n_words as Float / sentences.len() as Float
248    }
249
250    /// Estimates the parameter `a` for the weight function.
251    /// The returned value is always a positive number.
252    /// (Lines 5--7 in Algorithm 1)
253    fn estimate_param_a(&self, sent_len: Float) -> Float {
254        debug_assert!(sent_len > 0.);
255        let vocab_size = self.word_probs.n_words() as Float;
256        let threshold = 1. - (1. - (1. / vocab_size)).powf(sent_len);
257        let n_greater = self
258            .word_probs
259            .entries()
260            .filter(|(_, prob)| *prob > threshold)
261            .count() as Float;
262        let alpha = n_greater / vocab_size;
263        let partiion = 0.5 * vocab_size;
264        let param_a = (1. - alpha) / alpha.mul_add(partiion, Float::EPSILON); // avoid division by zero.
265        param_a.max(Float::EPSILON) // avoid returning zero.
266    }
267
268    /// Applies SIF-weighting for sentences.
269    /// (Line 8 in Algorithm 1)
270    fn weighted_embeddings<I, S>(&self, sentences: I, param_a: Float) -> Array2<Float>
271    where
272        I: IntoIterator<Item = S>,
273        S: AsRef<str>,
274    {
275        debug_assert!(param_a > 0.);
276        let mut sent_embeddings = vec![];
277        let mut n_sentences = 0;
278        for sent in sentences {
279            let sent_embedding = self.weighted_embedding(sent.as_ref(), param_a);
280            sent_embeddings.extend(sent_embedding.iter());
281            n_sentences += 1;
282        }
283        Array2::from_shape_vec((n_sentences, self.embedding_size()), sent_embeddings).unwrap()
284    }
285
286    /// Applies SIF-weighting for a sentence.
287    /// (Line 8 in Algorithm 1)
288    fn weighted_embedding(&self, sent: &str, param_a: Float) -> Array1<Float> {
289        debug_assert!(param_a > 0.);
290
291        // 1. Extract word embeddings and weights.
292        let mut n_words = 0;
293        let mut word_embeddings: Vec<Float> = vec![];
294        let mut word_weights: Vec<Float> = vec![];
295        for word in sent.split(self.separator) {
296            if let Some(word_embedding) = self.word_embeddings.embedding(word) {
297                word_embeddings.extend(word_embedding.iter());
298                word_weights
299                    .push(param_a / FLOAT_0_5.mul_add(param_a, self.word_probs.probability(word)));
300                n_words += 1;
301            }
302        }
303
304        // If no parseable tokens, return a vector of a's
305        if n_words == 0 {
306            return Array1::zeros(self.embedding_size()) + param_a;
307        }
308
309        // 2. Convert to nd-arrays.
310        let word_embeddings =
311            Array2::from_shape_vec((n_words, self.embedding_size()), word_embeddings).unwrap();
312        let word_weights = Array2::from_shape_vec((n_words, 1), word_weights).unwrap();
313
314        // 3. Normalize word embeddings.
315        let axis = ndarray_linalg::norm::NormalizeAxis::Column; // equivalent to Axis(0)
316        let (mut word_embeddings, _) = ndarray_linalg::norm::normalize(word_embeddings, axis);
317
318        // NOTE: It appears that the normalization above sometimes produces NaNs.
319        // This is a workaround, but I don't know this is correct.
320        word_embeddings.mapv_inplace(|x| if x.is_nan() { 0. } else { x });
321
322        // 4. Weight word embeddings.
323        word_embeddings *= &word_weights;
324
325        // 5. Average word embeddings.
326        word_embeddings.mean_axis(ndarray::Axis(0)).unwrap()
327    }
328
329    /// Estimates the principal components of sentence embeddings.
330    /// (Lines 11--17 in Algorithm 1)
331    ///
332    /// NOTE: Principal components can be empty iff sentence embeddings are all zeros.
333    fn estimate_principal_components(
334        &self,
335        sent_embeddings: &Array2<Float>,
336    ) -> (Array1<Float>, Array2<Float>) {
337        let (singular_values, singular_vectors) =
338            util::principal_components(sent_embeddings, self.n_components);
339        let singular_weights = singular_values.mapv(|v| v.powi(2));
340        let singular_weights = singular_weights.to_owned() / singular_weights.sum();
341        (singular_weights, singular_vectors)
342    }
343
344    /// Serializes the model.
345    pub fn serialize(&self) -> Result<Vec<u8>> {
346        let mut bytes = Vec::new();
347        bytes.extend_from_slice(MODEL_MAGIC);
348        bincode::serialize_into(&mut bytes, &self.n_components)?;
349        bincode::serialize_into(&mut bytes, &self.param_a)?;
350        bincode::serialize_into(&mut bytes, &self.weights)?;
351        bincode::serialize_into(&mut bytes, &self.common_components)?;
352        bincode::serialize_into(&mut bytes, &self.separator)?;
353        bincode::serialize_into(&mut bytes, &self.n_samples_to_fit)?;
354        Ok(bytes)
355    }
356
357    /// Deserializes the model.
358    ///
359    /// # Arguments
360    ///
361    /// * `bytes` - Byte sequence exported by [`Self::serialize`].
362    /// * `word_embeddings` - Word embeddings.
363    /// * `word_probs` - Word probabilities.
364    ///
365    /// `word_embeddings` and `word_probs` must be the same as those used in serialization.
366    pub fn deserialize(bytes: &[u8], word_embeddings: &'w W, word_probs: &'p P) -> Result<Self> {
367        if !bytes.starts_with(MODEL_MAGIC) {
368            return Err(anyhow!("The magic number of the input model mismatches."));
369        }
370        let mut bytes = &bytes[MODEL_MAGIC.len()..];
371        let n_components = bincode::deserialize_from(&mut bytes)?;
372        let param_a = bincode::deserialize_from(&mut bytes)?;
373        let weights = bincode::deserialize_from(&mut bytes)?;
374        let common_components = bincode::deserialize_from(&mut bytes)?;
375        let separator = bincode::deserialize_from(&mut bytes)?;
376        let n_samples_to_fit = bincode::deserialize_from(&mut bytes)?;
377        Ok(Self {
378            word_embeddings,
379            word_probs,
380            n_components,
381            param_a,
382            weights,
383            common_components,
384            separator,
385            n_samples_to_fit,
386        })
387    }
388}
389
390impl<W, P> SentenceEmbedder for USif<'_, '_, W, P>
391where
392    W: WordEmbeddings,
393    P: WordProbabilities,
394{
    /// Returns the number of dimensions for sentence embeddings,
    /// which is the same as the number of dimensions for word embeddings.
    ///
    /// The value is obtained by delegating to the borrowed word embeddings.
    fn embedding_size(&self) -> usize {
        self.word_embeddings.embedding_size()
    }
400
401    /// Fits the model with input sentences.
402    ///
403    /// Sentences to fit are randomly sampled from `sentences` with [`Self::n_samples_to_fit`].
404    ///
405    /// # Errors
406    ///
407    /// Returns an error if `sentences` is empty.
408    fn fit<S>(mut self, sentences: &[S]) -> Result<Self>
409    where
410        S: AsRef<str>,
411    {
412        if sentences.is_empty() {
413            return Err(anyhow!("Input sentences must not be empty."));
414        }
415
416        let sentences = util::sample_sentences(sentences, self.n_samples_to_fit);
417
418        // SIF-weighting.
419        let sent_len = self.average_sentence_length(&sentences);
420        if sent_len == 0. {
421            return Err(anyhow!("Input sentences must not be empty."));
422        }
423        let param_a = self.estimate_param_a(sent_len);
424        let sent_embeddings = self.weighted_embeddings(sentences, param_a);
425        self.param_a = Some(param_a);
426
427        // Common component removal.
428        if self.n_components != 0 {
429            let (weights, common_components) = self.estimate_principal_components(&sent_embeddings);
430            self.weights = Some(weights);
431            self.common_components = Some(common_components);
432        }
433        // NOTE: There is no need to set weights and common_components to None.
434        //       because n_components can be set up only in initialization.
435
436        Ok(self)
437    }
438
439    /// Computes embeddings for input sentences using the fitted model.
440    ///
441    /// # Errors
442    ///
443    /// Returns an error if the model is not fitted.
444    fn embeddings<I, S>(&self, sentences: I) -> Result<Array2<Float>>
445    where
446        I: IntoIterator<Item = S>,
447        S: AsRef<str>,
448    {
449        if self.param_a.is_none() {
450            return Err(anyhow!("The model is not fitted."));
451        }
452        // SIF-weighting.
453        let sent_embeddings = self.weighted_embeddings(sentences, self.param_a.unwrap());
454        if sent_embeddings.is_empty() {
455            return Ok(sent_embeddings);
456        }
457        if self.n_components == 0 {
458            return Ok(sent_embeddings);
459        }
460        // Common component removal.
461        let weights = self.weights.as_ref().unwrap();
462        let common_components = self.common_components.as_ref().unwrap();
463        let sent_embeddings =
464            util::remove_principal_components(&sent_embeddings, common_components, Some(weights));
465        Ok(sent_embeddings)
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use ndarray::{arr1, CowArray, Ix1};

    // Fixture: a four-word vocabulary with fixed 3-dimensional vectors.
    // Any other word has no embedding.
    struct SimpleWordEmbeddings {}

    impl WordEmbeddings for SimpleWordEmbeddings {
        fn embedding(&self, word: &str) -> Option<CowArray<Float, Ix1>> {
            match word {
                "A" => Some(arr1(&[1., 2., 3.]).into()),
                "BB" => Some(arr1(&[4., 5., 6.]).into()),
                "CCC" => Some(arr1(&[7., 8., 9.]).into()),
                "DDDD" => Some(arr1(&[10., 11., 12.]).into()),
                _ => None,
            }
        }

        fn embedding_size(&self) -> usize {
            3
        }
    }

    // Fixture: probabilities for the same four words; unknown words get zero.
    struct SimpleWordProbabilities {}

    impl WordProbabilities for SimpleWordProbabilities {
        fn probability(&self, word: &str) -> Float {
            match word {
                "A" => 0.6,
                "BB" => 0.2,
                "CCC" => 0.1,
                "DDDD" => 0.1,
                _ => 0.,
            }
        }

        fn n_words(&self) -> usize {
            4
        }

        fn entries(&self) -> Box<dyn Iterator<Item = (String, Float)> + '_> {
            Box::new(
                [("A", 0.6), ("BB", 0.2), ("CCC", 0.1), ("DDDD", 0.1)]
                    .iter()
                    .map(|&(word, prob)| (word.to_string(), prob)),
            )
        }
    }

    // Fits on sentences that include unknown words and an empty string, then
    // checks the output against an all-zero matrix and checks the output shape
    // for an empty input.
    #[test]
    fn test_basic() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = USif::new(&word_embeddings, &word_probs)
            .fit(&["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();

        let sent_embeddings = sif
            .embeddings(["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((5, 3)));

        // No sentences: the result keeps the embedding dimension but has no rows.
        let sent_embeddings = sif.embeddings(Vec::<&str>::new()).unwrap();
        assert_eq!(sent_embeddings.shape(), &[0, 3]);

        // A single empty sentence must not produce an all-zero row.
        let sent_embeddings = sif.embeddings([""]).unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((1, 3)));
    }

    // Changing the separator after fitting must give the same embeddings for
    // sentences that contain the same words joined by the new separator.
    #[test]
    fn test_separator() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences_1 = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let sentences_2 = &["A,BB,CCC,DDDD", "BB,CCC", "A,B,C", "Z", ""];

        let sif = USif::new(&word_embeddings, &word_probs);

        let sif = sif.fit(sentences_1).unwrap();
        let embeddings_1 = sif.embeddings(sentences_1).unwrap();

        let sif = sif.separator(',');
        let embeddings_2 = sif.embeddings(sentences_2).unwrap();

        assert_relative_eq!(embeddings_1, embeddings_2);
    }

    // Calling `embeddings` before `fit` must return an error.
    #[test]
    fn test_no_fitted() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];

        let sif = USif::new(&word_embeddings, &word_probs);
        let embeddings = sif.embeddings(sentences);

        assert!(embeddings.is_err());
    }

    // Fitting on an empty list of sentences must return an error.
    #[test]
    fn test_empty_fit() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = USif::new(&word_embeddings, &word_probs);
        let sif = sif.fit(&Vec::<&str>::new());

        assert!(sif.is_err());
    }

    // A model restored from its serialized bytes must produce the same
    // embeddings as the model it was serialized from.
    #[test]
    fn test_io() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = ["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let model_a = USif::new(&word_embeddings, &word_probs)
            .fit(&sentences)
            .unwrap();
        let bytes = model_a.serialize().unwrap();
        let model_b = USif::deserialize(&bytes, &word_embeddings, &word_probs).unwrap();

        let embeddings_a = model_a.embeddings(sentences).unwrap();
        let embeddings_b = model_b.embeddings(sentences).unwrap();

        assert_relative_eq!(embeddings_a, embeddings_b);
    }
}