1use anyhow::{anyhow, Result};
3use ndarray::Array1;
4use ndarray::Array2;
5
6use crate::util;
7use crate::Float;
8use crate::SentenceEmbedder;
9use crate::WordEmbeddings;
10use crate::WordProbabilities;
11use crate::DEFAULT_N_SAMPLES_TO_FIT;
12use crate::DEFAULT_SEPARATOR;
13
14pub const DEFAULT_N_COMPONENTS: usize = 5;
17
18const FLOAT_0_5: Float = 0.5;
19const MODEL_MAGIC: &[u8] = b"sif_embedding::USif 0.6\n";
20
21#[derive(Clone)]
155pub struct USif<'w, 'p, W, P> {
156 word_embeddings: &'w W,
157 word_probs: &'p P,
158 n_components: usize,
159 param_a: Option<Float>,
160 weights: Option<Array1<Float>>,
161 common_components: Option<Array2<Float>>,
162 separator: char,
163 n_samples_to_fit: usize,
164}
165
166impl<'w, 'p, W, P> USif<'w, 'p, W, P>
167where
168 W: WordEmbeddings,
169 P: WordProbabilities,
170{
171 pub const fn new(word_embeddings: &'w W, word_probs: &'p P) -> Self {
179 Self {
180 word_embeddings,
181 word_probs,
182 n_components: DEFAULT_N_COMPONENTS,
183 param_a: None,
184 weights: None,
185 common_components: None,
186 separator: DEFAULT_SEPARATOR,
187 n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
188 }
189 }
190
191 pub const fn with_parameters(
201 word_embeddings: &'w W,
202 word_probs: &'p P,
203 n_components: usize,
204 ) -> Self {
205 Self {
206 word_embeddings,
207 word_probs,
208 n_components,
209 param_a: None,
210 weights: None,
211 common_components: None,
212 separator: DEFAULT_SEPARATOR,
213 n_samples_to_fit: DEFAULT_N_SAMPLES_TO_FIT,
214 }
215 }
216
217 pub const fn separator(mut self, separator: char) -> Self {
219 self.separator = separator;
220 self
221 }
222
223 pub fn n_samples_to_fit(mut self, n_samples_to_fit: usize) -> Result<Self> {
229 if n_samples_to_fit == 0 {
230 return Err(anyhow!("n_samples_to_fit must not be 0."));
231 }
232 self.n_samples_to_fit = n_samples_to_fit;
233 Ok(self)
234 }
235
236 fn average_sentence_length<S>(&self, sentences: &[S]) -> Float
239 where
240 S: AsRef<str>,
241 {
242 let mut n_words = 0;
243 for sent in sentences {
244 let sent = sent.as_ref();
245 n_words += sent.split(self.separator).count();
246 }
247 n_words as Float / sentences.len() as Float
248 }
249
250 fn estimate_param_a(&self, sent_len: Float) -> Float {
254 debug_assert!(sent_len > 0.);
255 let vocab_size = self.word_probs.n_words() as Float;
256 let threshold = 1. - (1. - (1. / vocab_size)).powf(sent_len);
257 let n_greater = self
258 .word_probs
259 .entries()
260 .filter(|(_, prob)| *prob > threshold)
261 .count() as Float;
262 let alpha = n_greater / vocab_size;
263 let partiion = 0.5 * vocab_size;
264 let param_a = (1. - alpha) / alpha.mul_add(partiion, Float::EPSILON); param_a.max(Float::EPSILON) }
267
268 fn weighted_embeddings<I, S>(&self, sentences: I, param_a: Float) -> Array2<Float>
271 where
272 I: IntoIterator<Item = S>,
273 S: AsRef<str>,
274 {
275 debug_assert!(param_a > 0.);
276 let mut sent_embeddings = vec![];
277 let mut n_sentences = 0;
278 for sent in sentences {
279 let sent_embedding = self.weighted_embedding(sent.as_ref(), param_a);
280 sent_embeddings.extend(sent_embedding.iter());
281 n_sentences += 1;
282 }
283 Array2::from_shape_vec((n_sentences, self.embedding_size()), sent_embeddings).unwrap()
284 }
285
286 fn weighted_embedding(&self, sent: &str, param_a: Float) -> Array1<Float> {
289 debug_assert!(param_a > 0.);
290
291 let mut n_words = 0;
293 let mut word_embeddings: Vec<Float> = vec![];
294 let mut word_weights: Vec<Float> = vec![];
295 for word in sent.split(self.separator) {
296 if let Some(word_embedding) = self.word_embeddings.embedding(word) {
297 word_embeddings.extend(word_embedding.iter());
298 word_weights
299 .push(param_a / FLOAT_0_5.mul_add(param_a, self.word_probs.probability(word)));
300 n_words += 1;
301 }
302 }
303
304 if n_words == 0 {
306 return Array1::zeros(self.embedding_size()) + param_a;
307 }
308
309 let word_embeddings =
311 Array2::from_shape_vec((n_words, self.embedding_size()), word_embeddings).unwrap();
312 let word_weights = Array2::from_shape_vec((n_words, 1), word_weights).unwrap();
313
314 let axis = ndarray_linalg::norm::NormalizeAxis::Column; let (mut word_embeddings, _) = ndarray_linalg::norm::normalize(word_embeddings, axis);
317
318 word_embeddings.mapv_inplace(|x| if x.is_nan() { 0. } else { x });
321
322 word_embeddings *= &word_weights;
324
325 word_embeddings.mean_axis(ndarray::Axis(0)).unwrap()
327 }
328
329 fn estimate_principal_components(
334 &self,
335 sent_embeddings: &Array2<Float>,
336 ) -> (Array1<Float>, Array2<Float>) {
337 let (singular_values, singular_vectors) =
338 util::principal_components(sent_embeddings, self.n_components);
339 let singular_weights = singular_values.mapv(|v| v.powi(2));
340 let singular_weights = singular_weights.to_owned() / singular_weights.sum();
341 (singular_weights, singular_vectors)
342 }
343
344 pub fn serialize(&self) -> Result<Vec<u8>> {
346 let mut bytes = Vec::new();
347 bytes.extend_from_slice(MODEL_MAGIC);
348 bincode::serialize_into(&mut bytes, &self.n_components)?;
349 bincode::serialize_into(&mut bytes, &self.param_a)?;
350 bincode::serialize_into(&mut bytes, &self.weights)?;
351 bincode::serialize_into(&mut bytes, &self.common_components)?;
352 bincode::serialize_into(&mut bytes, &self.separator)?;
353 bincode::serialize_into(&mut bytes, &self.n_samples_to_fit)?;
354 Ok(bytes)
355 }
356
357 pub fn deserialize(bytes: &[u8], word_embeddings: &'w W, word_probs: &'p P) -> Result<Self> {
367 if !bytes.starts_with(MODEL_MAGIC) {
368 return Err(anyhow!("The magic number of the input model mismatches."));
369 }
370 let mut bytes = &bytes[MODEL_MAGIC.len()..];
371 let n_components = bincode::deserialize_from(&mut bytes)?;
372 let param_a = bincode::deserialize_from(&mut bytes)?;
373 let weights = bincode::deserialize_from(&mut bytes)?;
374 let common_components = bincode::deserialize_from(&mut bytes)?;
375 let separator = bincode::deserialize_from(&mut bytes)?;
376 let n_samples_to_fit = bincode::deserialize_from(&mut bytes)?;
377 Ok(Self {
378 word_embeddings,
379 word_probs,
380 n_components,
381 param_a,
382 weights,
383 common_components,
384 separator,
385 n_samples_to_fit,
386 })
387 }
388}
389
390impl<W, P> SentenceEmbedder for USif<'_, '_, W, P>
391where
392 W: WordEmbeddings,
393 P: WordProbabilities,
394{
395 fn embedding_size(&self) -> usize {
398 self.word_embeddings.embedding_size()
399 }
400
401 fn fit<S>(mut self, sentences: &[S]) -> Result<Self>
409 where
410 S: AsRef<str>,
411 {
412 if sentences.is_empty() {
413 return Err(anyhow!("Input sentences must not be empty."));
414 }
415
416 let sentences = util::sample_sentences(sentences, self.n_samples_to_fit);
417
418 let sent_len = self.average_sentence_length(&sentences);
420 if sent_len == 0. {
421 return Err(anyhow!("Input sentences must not be empty."));
422 }
423 let param_a = self.estimate_param_a(sent_len);
424 let sent_embeddings = self.weighted_embeddings(sentences, param_a);
425 self.param_a = Some(param_a);
426
427 if self.n_components != 0 {
429 let (weights, common_components) = self.estimate_principal_components(&sent_embeddings);
430 self.weights = Some(weights);
431 self.common_components = Some(common_components);
432 }
433 Ok(self)
437 }
438
439 fn embeddings<I, S>(&self, sentences: I) -> Result<Array2<Float>>
445 where
446 I: IntoIterator<Item = S>,
447 S: AsRef<str>,
448 {
449 if self.param_a.is_none() {
450 return Err(anyhow!("The model is not fitted."));
451 }
452 let sent_embeddings = self.weighted_embeddings(sentences, self.param_a.unwrap());
454 if sent_embeddings.is_empty() {
455 return Ok(sent_embeddings);
456 }
457 if self.n_components == 0 {
458 return Ok(sent_embeddings);
459 }
460 let weights = self.weights.as_ref().unwrap();
462 let common_components = self.common_components.as_ref().unwrap();
463 let sent_embeddings =
464 util::remove_principal_components(&sent_embeddings, common_components, Some(weights));
465 Ok(sent_embeddings)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use ndarray::{arr1, CowArray, Ix1};

    /// Toy embeddings: four known words with 3-dimensional vectors.
    struct SimpleWordEmbeddings {}

    impl WordEmbeddings for SimpleWordEmbeddings {
        fn embedding(&self, word: &str) -> Option<CowArray<Float, Ix1>> {
            match word {
                "A" => Some(arr1(&[1., 2., 3.]).into()),
                "BB" => Some(arr1(&[4., 5., 6.]).into()),
                "CCC" => Some(arr1(&[7., 8., 9.]).into()),
                "DDDD" => Some(arr1(&[10., 11., 12.]).into()),
                _ => None,
            }
        }

        fn embedding_size(&self) -> usize {
            3
        }
    }

    /// Toy probabilities for the four known words; unknown words get 0.
    struct SimpleWordProbabilities {}

    impl WordProbabilities for SimpleWordProbabilities {
        fn probability(&self, word: &str) -> Float {
            match word {
                "A" => 0.6,
                "BB" => 0.2,
                "CCC" => 0.1,
                "DDDD" => 0.1,
                _ => 0.,
            }
        }

        fn n_words(&self) -> usize {
            4
        }

        fn entries(&self) -> Box<dyn Iterator<Item = (String, Float)> + '_> {
            Box::new(
                [("A", 0.6), ("BB", 0.2), ("CCC", 0.1), ("DDDD", 0.1)]
                    .iter()
                    .map(|&(word, prob)| (word.to_string(), prob)),
            )
        }
    }

    #[test]
    fn test_basic() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = USif::new(&word_embeddings, &word_probs)
            .fit(&["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();

        let sent_embeddings = sif
            .embeddings(["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""])
            .unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((5, 3)));

        let sent_embeddings = sif.embeddings(Vec::<&str>::new()).unwrap();
        assert_eq!(sent_embeddings.shape(), &[0, 3]);

        let sent_embeddings = sif.embeddings([""]).unwrap();
        assert_ne!(sent_embeddings, Array2::zeros((1, 3)));
    }

    #[test]
    fn test_separator() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences_1 = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let sentences_2 = &["A,BB,CCC,DDDD", "BB,CCC", "A,B,C", "Z", ""];

        let sif = USif::new(&word_embeddings, &word_probs);

        let sif = sif.fit(sentences_1).unwrap();
        let embeddings_1 = sif.embeddings(sentences_1).unwrap();

        let sif = sif.separator(',');
        let embeddings_2 = sif.embeddings(sentences_2).unwrap();

        assert_relative_eq!(embeddings_1, embeddings_2);
    }

    #[test]
    fn test_no_fitted() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = &["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];

        let sif = USif::new(&word_embeddings, &word_probs);
        let embeddings = sif.embeddings(sentences);

        assert!(embeddings.is_err());
    }

    #[test]
    fn test_empty_fit() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = USif::new(&word_embeddings, &word_probs);
        let sif = sif.fit(&Vec::<&str>::new());

        assert!(sif.is_err());
    }

    #[test]
    fn test_zero_n_samples_to_fit() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sif = USif::new(&word_embeddings, &word_probs).n_samples_to_fit(0);

        assert!(sif.is_err());
    }

    #[test]
    fn test_no_components() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = ["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let sif = USif::with_parameters(&word_embeddings, &word_probs, 0)
            .fit(&sentences)
            .unwrap();
        let sent_embeddings = sif.embeddings(sentences).unwrap();

        assert_eq!(sent_embeddings.shape(), &[5, 3]);
    }

    #[test]
    fn test_io() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let sentences = ["A BB CCC DDDD", "BB CCC", "A B C", "Z", ""];
        let model_a = USif::new(&word_embeddings, &word_probs)
            .fit(&sentences)
            .unwrap();
        let bytes = model_a.serialize().unwrap();
        let model_b = USif::deserialize(&bytes, &word_embeddings, &word_probs).unwrap();

        let embeddings_a = model_a.embeddings(sentences).unwrap();
        let embeddings_b = model_b.embeddings(sentences).unwrap();

        assert_relative_eq!(embeddings_a, embeddings_b);
    }

    #[test]
    fn test_io_unfitted() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let model_a = USif::new(&word_embeddings, &word_probs);
        let bytes = model_a.serialize().unwrap();
        let model_b = USif::deserialize(&bytes, &word_embeddings, &word_probs).unwrap();

        // An unfitted model stays unfitted after a round trip.
        assert!(model_b.embeddings(["A BB"]).is_err());
    }

    #[test]
    fn test_invalid_magic() {
        let word_embeddings = SimpleWordEmbeddings {};
        let word_probs = SimpleWordProbabilities {};

        let model = USif::deserialize(b"not a model", &word_embeddings, &word_probs);

        assert!(model.is_err());
    }
}