rust2vec/
similarity.rs

1//! Traits and trait implementations for similarity queries.
2
3use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashSet};
5
6use ndarray::{s, Array1, ArrayView1, ArrayView2};
7use ordered_float::NotNan;
8
9use crate::embeddings::Embeddings;
10use crate::storage::StorageView;
11use crate::util::l2_normalize;
12use crate::vocab::Vocab;
13
/// A word with its similarity.
///
/// This data structure is used to store a pair consisting of a word and
/// its similarity to a query word.
///
/// Note that the `Ord` implementation of this type sorts by *descending*
/// similarity (ties are broken by ascending word order), so the most
/// similar word compares as the smallest value.
#[derive(Debug, Eq, PartialEq)]
pub struct WordSimilarity<'a> {
    /// Similarity score of `word` with respect to the query. Wrapping the
    /// score in `NotNan` gives it the total order required by `Eq`/`Ord`.
    pub similarity: NotNan<f32>,
    /// The word the score belongs to, borrowed for the lifetime `'a`.
    pub word: &'a str,
}
23
24impl<'a> Ord for WordSimilarity<'a> {
25    fn cmp(&self, other: &Self) -> Ordering {
26        match other.similarity.cmp(&self.similarity) {
27            Ordering::Equal => self.word.cmp(other.word),
28            ordering => ordering,
29        }
30    }
31}
32
33impl<'a> PartialOrd for WordSimilarity<'a> {
34    fn partial_cmp(&self, other: &WordSimilarity) -> Option<Ordering> {
35        Some(self.cmp(other))
36    }
37}
38
/// Trait for analogy queries.
pub trait Analogy {
    /// Perform an analogy query.
    ///
    /// This method returns words that are close in vector space to the
    /// answer of the analogy query `word1` is to `word2` as `word3` is to
    /// `?`. More concretely, it searches embeddings that are similar to:
    ///
    /// *embedding(word2) - embedding(word1) + embedding(word3)*
    ///
    /// At most, `limit` results are returned.
    ///
    /// The implementation for `Embeddings` returns `None` when the embedding
    /// lookup fails for any of the three words, and it never includes the
    /// query words themselves in the results.
    fn analogy(
        &self,
        word1: &str,
        word2: &str,
        word3: &str,
        limit: usize,
    ) -> Option<Vec<WordSimilarity>>;
}
58
59impl<V, S> Analogy for Embeddings<V, S>
60where
61    V: Vocab,
62    S: StorageView,
63{
64    fn analogy(
65        &self,
66        word1: &str,
67        word2: &str,
68        word3: &str,
69        limit: usize,
70    ) -> Option<Vec<WordSimilarity>> {
71        self.analogy_by(word1, word2, word3, limit, |embeds, embed| {
72            embeds.dot(&embed)
73        })
74    }
75}
76
/// Trait for analogy queries with a custom similarity function.
pub trait AnalogyBy {
    /// Perform an analogy query using the given similarity function.
    ///
    /// This method returns words that are close in vector space to the
    /// answer of the analogy query `word1` is to `word2` as `word3` is to
    /// `?`. More concretely, it searches embeddings that are similar to:
    ///
    /// *embedding(word2) - embedding(word1) + embedding(word3)*
    ///
    /// The similarity function receives the embedding matrix and the query
    /// vector and returns a vector of similarity scores.
    ///
    /// At most, `limit` results are returned.
    ///
    /// The implementation for `Embeddings` returns `None` when the embedding
    /// lookup fails for any of the three words, and it never includes the
    /// query words themselves in the results.
    fn analogy_by<F>(
        &self,
        word1: &str,
        word2: &str,
        word3: &str,
        limit: usize,
        similarity: F,
    ) -> Option<Vec<WordSimilarity>>
    where
        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
}
99
100impl<V, S> AnalogyBy for Embeddings<V, S>
101where
102    V: Vocab,
103    S: StorageView,
104{
105    fn analogy_by<F>(
106        &self,
107        word1: &str,
108        word2: &str,
109        word3: &str,
110        limit: usize,
111        similarity: F,
112    ) -> Option<Vec<WordSimilarity>>
113    where
114        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
115    {
116        let embedding1 = self.embedding(word1)?;
117        let embedding2 = self.embedding(word2)?;
118        let embedding3 = self.embedding(word3)?;
119
120        let mut embedding = (&embedding2.as_view() - &embedding1.as_view()) + embedding3.as_view();
121        l2_normalize(embedding.view_mut());
122
123        let skip = [word1, word2, word3].iter().cloned().collect();
124
125        Some(self.similarity_(embedding.view(), &skip, limit, similarity))
126    }
127}
128
/// Trait for similarity queries.
pub trait Similarity {
    /// Find words that are similar to the query word.
    ///
    /// The similarity between two words is defined by the dot product of
    /// the embeddings. If the vectors are unit vectors (e.g. by virtue of
    /// calling `normalize`), this is the cosine similarity. At most, `limit`
    /// results are returned.
    ///
    /// The implementation for `Embeddings` returns `None` when the embedding
    /// lookup for `word` fails, and it never includes `word` itself in the
    /// results.
    fn similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarity>>;
}
139
140impl<V, S> Similarity for Embeddings<V, S>
141where
142    V: Vocab,
143    S: StorageView,
144{
145    fn similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarity>> {
146        self.similarity_by(word, limit, |embeds, embed| embeds.dot(&embed))
147    }
148}
149
/// Trait for similarity queries with a custom similarity function.
pub trait SimilarityBy {
    /// Find words that are similar to the query word using the given similarity
    /// function.
    ///
    /// The similarity function should return, given the embeddings matrix and
    /// the word vector, a vector of similarity scores. At most, `limit`
    /// results are returned.
    ///
    /// The implementation for `Embeddings` returns `None` when the embedding
    /// lookup for `word` fails, and it never includes `word` itself in the
    /// results.
    fn similarity_by<F>(
        &self,
        word: &str,
        limit: usize,
        similarity: F,
    ) -> Option<Vec<WordSimilarity>>
    where
        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
}
167
168impl<V, S> SimilarityBy for Embeddings<V, S>
169where
170    V: Vocab,
171    S: StorageView,
172{
173    fn similarity_by<F>(
174        &self,
175        word: &str,
176        limit: usize,
177        similarity: F,
178    ) -> Option<Vec<WordSimilarity>>
179    where
180        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
181    {
182        let embed = self.embedding(word)?;
183        let mut skip = HashSet::new();
184        skip.insert(word);
185
186        Some(self.similarity_(embed.as_view(), &skip, limit, similarity))
187    }
188}
189
/// Private helper trait that holds the ranking logic shared by the public
/// similarity and analogy traits.
trait SimilarityPrivate {
    /// Rank words by their similarity to the vector `embed`.
    ///
    /// `similarity` receives the embedding matrix and `embed` and returns
    /// one score per matrix row. Words contained in `skip` are left out of
    /// the result. At most `limit` results are returned.
    fn similarity_<F>(
        &self,
        embed: ArrayView1<f32>,
        skip: &HashSet<&str>,
        limit: usize,
        similarity: F,
    ) -> Vec<WordSimilarity>
    where
        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
}
201
202impl<V, S> SimilarityPrivate for Embeddings<V, S>
203where
204    V: Vocab,
205    S: StorageView,
206{
207    fn similarity_<F>(
208        &self,
209        embed: ArrayView1<f32>,
210        skip: &HashSet<&str>,
211        limit: usize,
212        mut similarity: F,
213    ) -> Vec<WordSimilarity>
214    where
215        F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
216    {
217        // ndarray#474
218        #[allow(clippy::deref_addrof)]
219        let sims = similarity(
220            self.storage().view().slice(s![0..self.vocab().len(), ..]),
221            embed.view(),
222        );
223
224        let mut results = BinaryHeap::with_capacity(limit);
225        for (idx, &sim) in sims.iter().enumerate() {
226            let word = &self.vocab().words()[idx];
227
228            // Don't add words that we are explicitly asked to skip.
229            if skip.contains(word.as_str()) {
230                continue;
231            }
232
233            let word_similarity = WordSimilarity {
234                word,
235                similarity: NotNan::new(sim).expect("Encountered NaN"),
236            };
237
238            if results.len() < limit {
239                results.push(word_similarity);
240            } else {
241                let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
242                if word_similarity < *peek {
243                    *peek = word_similarity
244                }
245            }
246        }
247
248        results.into_sorted_vec()
249    }
250}
251
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;

    use crate::embeddings::Embeddings;
    use crate::similarity::{Analogy, Similarity, WordSimilarity};
    use crate::word2vec::ReadWord2Vec;

    /// Expected 10 most similar words for "Stuttgart" in
    /// `testdata/similarity.bin`, most similar first.
    static SIMILARITY_ORDER_STUTTGART_10: &[&str] = &[
        "Karlsruhe",
        "Mannheim",
        "München",
        "Darmstadt",
        "Heidelberg",
        "Wiesbaden",
        "Kassel",
        "Düsseldorf",
        "Leipzig",
        "Berlin",
    ];

    /// Expected 40 most similar words for "Berlin" in
    /// `testdata/similarity.bin`, most similar first.
    static SIMILARITY_ORDER: &[&str] = &[
        "Potsdam",
        "Hamburg",
        "Leipzig",
        "Dresden",
        "München",
        "Düsseldorf",
        "Bonn",
        "Stuttgart",
        "Weimar",
        "Berlin-Charlottenburg",
        "Rostock",
        "Karlsruhe",
        "Chemnitz",
        "Breslau",
        "Wiesbaden",
        "Hannover",
        "Mannheim",
        "Kassel",
        "Köln",
        "Danzig",
        "Erfurt",
        "Dessau",
        "Bremen",
        "Charlottenburg",
        "Magdeburg",
        "Neuruppin",
        "Darmstadt",
        "Jena",
        "Wien",
        "Heidelberg",
        "Dortmund",
        "Stettin",
        "Schwerin",
        "Neubrandenburg",
        "Greifswald",
        "Göttingen",
        "Braunschweig",
        "Berliner",
        "Warschau",
        "Berlin-Spandau",
    ];

    /// Expected 40 results for the analogy query
    /// "Paris" : "Frankreich" :: "Berlin" : ? in `testdata/analogy.bin`,
    /// best match first.
    static ANALOGY_ORDER: &[&str] = &[
        "Deutschland",
        "Westdeutschland",
        "Sachsen",
        "Mitteldeutschland",
        "Brandenburg",
        "Polen",
        "Norddeutschland",
        "Dänemark",
        "Schleswig-Holstein",
        "Österreich",
        "Bayern",
        "Thüringen",
        "Bundesrepublik",
        "Ostdeutschland",
        "Preußen",
        "Deutschen",
        "Hessen",
        "Potsdam",
        "Mecklenburg",
        "Niedersachsen",
        "Hamburg",
        "Süddeutschland",
        "Bremen",
        "Russland",
        "Deutschlands",
        "BRD",
        "Litauen",
        "Mecklenburg-Vorpommern",
        "DDR",
        "West-Berlin",
        "Saarland",
        "Lettland",
        "Hannover",
        "Rostock",
        "Sachsen-Anhalt",
        "Pommern",
        "Schweden",
        "Deutsche",
        "deutschen",
        "Westfalen",
    ];

    /// Project a result list onto its words, preserving the result order.
    ///
    /// Comparing whole word lists checks both the number of results and
    /// their order, and prints the complete lists when an assertion fails.
    fn words<'a>(results: &[WordSimilarity<'a>]) -> Vec<&'a str> {
        results.iter().map(|result| result.word).collect()
    }

    #[test]
    fn test_similarity() {
        let f = File::open("testdata/similarity.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .similarity("Berlin", 40)
            .expect("no embedding found for query word");
        let found = words(&result);
        assert_eq!(SIMILARITY_ORDER, &found[..]);

        // A smaller limit must return a prefix of the full ranking.
        let result = embeddings
            .similarity("Berlin", 10)
            .expect("no embedding found for query word");
        let found = words(&result);
        assert_eq!(&SIMILARITY_ORDER[..10], &found[..]);
    }

    #[test]
    fn test_similarity_limit() {
        let f = File::open("testdata/similarity.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .similarity("Stuttgart", 10)
            .expect("no embedding found for query word");
        let found = words(&result);
        assert_eq!(SIMILARITY_ORDER_STUTTGART_10, &found[..]);
    }

    #[test]
    fn test_analogy() {
        let f = File::open("testdata/analogy.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .analogy("Paris", "Frankreich", "Berlin", 40)
            .expect("no embedding found for one of the query words");
        let found = words(&result);
        assert_eq!(ANALOGY_ORDER, &found[..]);
    }
}