1use std::cmp::Ordering;
4use std::collections::{BinaryHeap, HashSet};
5
6use ndarray::{s, Array1, ArrayView1, ArrayView2};
7use ordered_float::NotNan;
8
9use crate::embeddings::Embeddings;
10use crate::storage::StorageView;
11use crate::util::l2_normalize;
12use crate::vocab::Vocab;
13
14#[derive(Debug, Eq, PartialEq)]
19pub struct WordSimilarity<'a> {
20 pub similarity: NotNan<f32>,
21 pub word: &'a str,
22}
23
24impl<'a> Ord for WordSimilarity<'a> {
25 fn cmp(&self, other: &Self) -> Ordering {
26 match other.similarity.cmp(&self.similarity) {
27 Ordering::Equal => self.word.cmp(other.word),
28 ordering => ordering,
29 }
30 }
31}
32
33impl<'a> PartialOrd for WordSimilarity<'a> {
34 fn partial_cmp(&self, other: &WordSimilarity) -> Option<Ordering> {
35 Some(self.cmp(other))
36 }
37}
38
39pub trait Analogy {
41 fn analogy(
51 &self,
52 word1: &str,
53 word2: &str,
54 word3: &str,
55 limit: usize,
56 ) -> Option<Vec<WordSimilarity>>;
57}
58
59impl<V, S> Analogy for Embeddings<V, S>
60where
61 V: Vocab,
62 S: StorageView,
63{
64 fn analogy(
65 &self,
66 word1: &str,
67 word2: &str,
68 word3: &str,
69 limit: usize,
70 ) -> Option<Vec<WordSimilarity>> {
71 self.analogy_by(word1, word2, word3, limit, |embeds, embed| {
72 embeds.dot(&embed)
73 })
74 }
75}
76
77pub trait AnalogyBy {
79 fn analogy_by<F>(
89 &self,
90 word1: &str,
91 word2: &str,
92 word3: &str,
93 limit: usize,
94 similarity: F,
95 ) -> Option<Vec<WordSimilarity>>
96 where
97 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
98}
99
100impl<V, S> AnalogyBy for Embeddings<V, S>
101where
102 V: Vocab,
103 S: StorageView,
104{
105 fn analogy_by<F>(
106 &self,
107 word1: &str,
108 word2: &str,
109 word3: &str,
110 limit: usize,
111 similarity: F,
112 ) -> Option<Vec<WordSimilarity>>
113 where
114 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
115 {
116 let embedding1 = self.embedding(word1)?;
117 let embedding2 = self.embedding(word2)?;
118 let embedding3 = self.embedding(word3)?;
119
120 let mut embedding = (&embedding2.as_view() - &embedding1.as_view()) + embedding3.as_view();
121 l2_normalize(embedding.view_mut());
122
123 let skip = [word1, word2, word3].iter().cloned().collect();
124
125 Some(self.similarity_(embedding.view(), &skip, limit, similarity))
126 }
127}
128
129pub trait Similarity {
131 fn similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarity>>;
138}
139
140impl<V, S> Similarity for Embeddings<V, S>
141where
142 V: Vocab,
143 S: StorageView,
144{
145 fn similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarity>> {
146 self.similarity_by(word, limit, |embeds, embed| embeds.dot(&embed))
147 }
148}
149
150pub trait SimilarityBy {
152 fn similarity_by<F>(
159 &self,
160 word: &str,
161 limit: usize,
162 similarity: F,
163 ) -> Option<Vec<WordSimilarity>>
164 where
165 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
166}
167
168impl<V, S> SimilarityBy for Embeddings<V, S>
169where
170 V: Vocab,
171 S: StorageView,
172{
173 fn similarity_by<F>(
174 &self,
175 word: &str,
176 limit: usize,
177 similarity: F,
178 ) -> Option<Vec<WordSimilarity>>
179 where
180 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
181 {
182 let embed = self.embedding(word)?;
183 let mut skip = HashSet::new();
184 skip.insert(word);
185
186 Some(self.similarity_(embed.as_view(), &skip, limit, similarity))
187 }
188}
189
190trait SimilarityPrivate {
191 fn similarity_<F>(
192 &self,
193 embed: ArrayView1<f32>,
194 skip: &HashSet<&str>,
195 limit: usize,
196 similarity: F,
197 ) -> Vec<WordSimilarity>
198 where
199 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>;
200}
201
202impl<V, S> SimilarityPrivate for Embeddings<V, S>
203where
204 V: Vocab,
205 S: StorageView,
206{
207 fn similarity_<F>(
208 &self,
209 embed: ArrayView1<f32>,
210 skip: &HashSet<&str>,
211 limit: usize,
212 mut similarity: F,
213 ) -> Vec<WordSimilarity>
214 where
215 F: FnMut(ArrayView2<f32>, ArrayView1<f32>) -> Array1<f32>,
216 {
217 #[allow(clippy::deref_addrof)]
219 let sims = similarity(
220 self.storage().view().slice(s![0..self.vocab().len(), ..]),
221 embed.view(),
222 );
223
224 let mut results = BinaryHeap::with_capacity(limit);
225 for (idx, &sim) in sims.iter().enumerate() {
226 let word = &self.vocab().words()[idx];
227
228 if skip.contains(word.as_str()) {
230 continue;
231 }
232
233 let word_similarity = WordSimilarity {
234 word,
235 similarity: NotNan::new(sim).expect("Encountered NaN"),
236 };
237
238 if results.len() < limit {
239 results.push(word_similarity);
240 } else {
241 let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
242 if word_similarity < *peek {
243 *peek = word_similarity
244 }
245 }
246 }
247
248 results.into_sorted_vec()
249 }
250}
251
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;

    use crate::embeddings::Embeddings;
    use crate::similarity::{Analogy, Similarity, WordSimilarity};
    use crate::word2vec::ReadWord2Vec;

    /// Expected 10 nearest neighbours of "Stuttgart" in `testdata/similarity.bin`.
    static SIMILARITY_ORDER_STUTTGART_10: &[&str] = &[
        "Karlsruhe",
        "Mannheim",
        "München",
        "Darmstadt",
        "Heidelberg",
        "Wiesbaden",
        "Kassel",
        "Düsseldorf",
        "Leipzig",
        "Berlin",
    ];

    /// Expected 40 nearest neighbours of "Berlin" in `testdata/similarity.bin`.
    static SIMILARITY_ORDER: &[&str] = &[
        "Potsdam",
        "Hamburg",
        "Leipzig",
        "Dresden",
        "München",
        "Düsseldorf",
        "Bonn",
        "Stuttgart",
        "Weimar",
        "Berlin-Charlottenburg",
        "Rostock",
        "Karlsruhe",
        "Chemnitz",
        "Breslau",
        "Wiesbaden",
        "Hannover",
        "Mannheim",
        "Kassel",
        "Köln",
        "Danzig",
        "Erfurt",
        "Dessau",
        "Bremen",
        "Charlottenburg",
        "Magdeburg",
        "Neuruppin",
        "Darmstadt",
        "Jena",
        "Wien",
        "Heidelberg",
        "Dortmund",
        "Stettin",
        "Schwerin",
        "Neubrandenburg",
        "Greifswald",
        "Göttingen",
        "Braunschweig",
        "Berliner",
        "Warschau",
        "Berlin-Spandau",
    ];

    /// Expected 40 answers to "Paris : Frankreich :: Berlin : ?" in
    /// `testdata/analogy.bin`.
    static ANALOGY_ORDER: &[&str] = &[
        "Deutschland",
        "Westdeutschland",
        "Sachsen",
        "Mitteldeutschland",
        "Brandenburg",
        "Polen",
        "Norddeutschland",
        "Dänemark",
        "Schleswig-Holstein",
        "Österreich",
        "Bayern",
        "Thüringen",
        "Bundesrepublik",
        "Ostdeutschland",
        "Preußen",
        "Deutschen",
        "Hessen",
        "Potsdam",
        "Mecklenburg",
        "Niedersachsen",
        "Hamburg",
        "Süddeutschland",
        "Bremen",
        "Russland",
        "Deutschlands",
        "BRD",
        "Litauen",
        "Mecklenburg-Vorpommern",
        "DDR",
        "West-Berlin",
        "Saarland",
        "Lettland",
        "Hannover",
        "Rostock",
        "Sachsen-Anhalt",
        "Pommern",
        "Schweden",
        "Deutsche",
        "deutschen",
        "Westfalen",
    ];

    /// Assert that the words in `result` are exactly `expected`, in order.
    ///
    /// Comparing whole slices reports every mismatch at once, instead of
    /// stopping at the first differing index.
    fn assert_word_order(expected: &[&str], result: &[WordSimilarity]) {
        let words: Vec<&str> = result.iter().map(|ws| ws.word).collect();
        assert_eq!(expected, words.as_slice());
    }

    #[test]
    fn test_similarity() {
        let f = File::open("testdata/similarity.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .similarity("Berlin", 40)
            .expect("no embedding for Berlin");
        assert_eq!(40, result.len());
        assert_word_order(SIMILARITY_ORDER, &result);

        // A smaller limit must yield a prefix of the larger result.
        let result = embeddings
            .similarity("Berlin", 10)
            .expect("no embedding for Berlin");
        assert_eq!(10, result.len());
        assert_word_order(&SIMILARITY_ORDER[..10], &result);
    }

    #[test]
    fn test_similarity_limit() {
        let f = File::open("testdata/similarity.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .similarity("Stuttgart", 10)
            .expect("no embedding for Stuttgart");
        assert_eq!(10, result.len());
        assert_word_order(SIMILARITY_ORDER_STUTTGART_10, &result);
    }

    #[test]
    fn test_analogy() {
        let f = File::open("testdata/analogy.bin").unwrap();
        let mut reader = BufReader::new(f);
        let embeddings = Embeddings::read_word2vec_binary(&mut reader, true).unwrap();

        let result = embeddings
            .analogy("Paris", "Frankreich", "Berlin", 40)
            .expect("no embedding for one of the analogy words");
        assert_eq!(40, result.len());
        assert_word_order(ANALOGY_ORDER, &result);
    }
}