1pub mod contrastive;
135pub mod crosslingual;
136pub mod fasttext;
137pub mod glove;
138pub mod sentence;
139
140pub use fasttext::{FastText, FastTextConfig};
142pub use glove::{
143 cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
144 GloVeTrainerConfig,
145};
146
147use crate::error::{Result, TextError};
148use crate::tokenize::{Tokenizer, WordTokenizer};
149use crate::vocabulary::Vocabulary;
150use scirs2_core::ndarray::{Array1, Array2};
151use scirs2_core::random::prelude::*;
152use std::collections::HashMap;
153use std::fmt::Debug;
154use std::fs::File;
155use std::io::{BufRead, BufReader, Write};
156use std::path::Path;
157
/// Common interface over word-embedding models (GloVe, FastText, Word2Vec).
pub trait WordEmbedding {
    /// Return the embedding vector for `word`, or an error if the word
    /// cannot be represented by the model.
    fn embedding(&self, word: &str) -> Result<Array1<f64>>;

    /// Dimensionality of the vectors produced by this model.
    fn dimension(&self) -> usize;

    /// Cosine similarity between the embeddings of two words.
    ///
    /// Default implementation looks up both vectors via
    /// [`WordEmbedding::embedding`]; returns 0.0 when either vector has
    /// zero norm (see `embedding_cosine_similarity`).
    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
        let v1 = self.embedding(word1)?;
        let v2 = self.embedding(word2)?;
        Ok(embedding_cosine_similarity(&v1, &v2))
    }

    /// The `top_n` words most similar to `word`, with similarity scores.
    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Solve the analogy "a is to b as c is to ?" returning `top_n` candidates.
    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Number of words in the model's vocabulary.
    fn vocab_size(&self) -> usize;
}
188
189pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
191 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
192 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
193 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
194
195 if norm_a > 0.0 && norm_b > 0.0 {
196 dot_product / (norm_a * norm_b)
197 } else {
198 0.0
199 }
200}
201
202pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
204 let vectors: Vec<Array1<f64>> = words
205 .iter()
206 .map(|&w| model.embedding(w))
207 .collect::<Result<Vec<_>>>()?;
208
209 let n = vectors.len();
210 let mut matrix = vec![vec![0.0; n]; n];
211
212 for i in 0..n {
213 for j in i..n {
214 let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
215 matrix[i][j] = sim;
216 matrix[j][i] = sim;
217 }
218 }
219
220 Ok(matrix)
221}
222
/// [`WordEmbedding`] for GloVe: thin delegation to the inherent methods.
impl WordEmbedding for GloVe {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}
246
/// [`WordEmbedding`] for FastText: thin delegation to the inherent methods.
impl WordEmbedding for FastText {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}
268
/// Node in the Huffman tree used for hierarchical softmax.
#[derive(Debug, Clone)]
struct HuffmanNode {
    // Node id: leaves reuse the word index (0..vocab_size); internal nodes
    // continue the numbering from vocab_size upward.
    id: usize,
    // Frequency of this word (leaf) or merged subtree (internal node).
    frequency: usize,
    // Arena index of the left child (internal nodes only).
    left: Option<usize>,
    // Arena index of the right child (internal nodes only).
    right: Option<usize>,
    // True for word leaves, false for merged internal nodes.
    is_leaf: bool,
}
285
/// Precomputed Huffman codes and root-to-leaf paths for hierarchical softmax.
#[derive(Debug, Clone)]
struct HuffmanTree {
    // Binary code per word (0 = left, 1 = right), ordered root to leaf.
    codes: Vec<Vec<u8>>,
    // Internal-node indices (0..num_internal) visited per word, root to leaf.
    paths: Vec<Vec<usize>>,
    // Number of internal nodes (= rows of the HS parameter matrix).
    num_internal: usize,
}
296
impl HuffmanTree {
    /// Build a Huffman tree from per-word frequencies.
    ///
    /// `frequencies[i]` is the corpus count of vocabulary word `i` (zero
    /// counts are clamped to 1 so every word receives a code). Produces the
    /// bit code and internal-node path for every word, indexed by word id.
    ///
    /// # Errors
    /// Fails when `frequencies` is empty.
    fn build(frequencies: &[usize]) -> Result<Self> {
        let vocab_size = frequencies.len();
        if vocab_size == 0 {
            return Err(TextError::EmbeddingError(
                "Cannot build Huffman tree with empty vocabulary".into(),
            ));
        }
        // Degenerate single-word vocabulary: one internal "root" node and a
        // single-bit code.
        if vocab_size == 1 {
            return Ok(Self {
                codes: vec![vec![0]],
                paths: vec![vec![0]],
                num_internal: 1,
            });
        }

        // Leaf nodes, one per vocabulary word.
        let mut nodes: Vec<HuffmanNode> = frequencies
            .iter()
            .enumerate()
            .map(|(id, &freq)| HuffmanNode {
                id,
                frequency: freq.max(1), // avoid zero-frequency leaves
                left: None,
                right: None,
                is_leaf: true,
            })
            .collect();

        // Priority queue emulated by a Vec kept sorted by DESCENDING
        // frequency, so `pop()` removes the least frequent entry.
        let mut queue: Vec<(usize, usize)> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (i, n.frequency))
            .collect();
        queue.sort_by_key(|item| std::cmp::Reverse(item.1));

        // Standard Huffman construction: repeatedly merge the two least
        // frequent nodes until a single root remains.
        while queue.len() > 1 {
            let (idx1, freq1) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
            let (idx2, freq2) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;

            let new_id = nodes.len();
            let new_node = HuffmanNode {
                id: new_id,
                frequency: freq1 + freq2,
                left: Some(idx1),
                right: Some(idx2),
                is_leaf: false,
            };
            nodes.push(new_node);

            // Insert the merged node while preserving descending order. The
            // reversed comparator (`new_freq.cmp(f)`) makes the descending
            // Vec look ascending to `binary_search_by`.
            let new_freq = freq1 + freq2;
            let insert_pos = queue
                .binary_search_by(|(_, f)| new_freq.cmp(f))
                .unwrap_or_else(|pos| pos);
            queue.insert(insert_pos, (new_id, new_freq));
        }

        // Internal nodes were appended after the `vocab_size` leaves.
        let num_internal = nodes.len() - vocab_size;
        let mut codes = vec![Vec::new(); vocab_size];
        let mut paths = vec![Vec::new(); vocab_size];

        // Iterative DFS from the root, accumulating the bit code (0 = left,
        // 1 = right) and the internal-node indices along the way.
        let root_idx = nodes.len() - 1;
        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];

        while let Some((node_idx, code, path)) = stack.pop() {
            let node = &nodes[node_idx];

            if node.is_leaf {
                codes[node.id] = code;
                paths[node.id] = path;
            } else {
                // Internal nodes are re-indexed to 0..num_internal for the
                // hierarchical-softmax parameter matrix.
                let internal_idx = node.id - vocab_size;

                if let Some(left_idx) = node.left {
                    let mut left_code = code.clone();
                    left_code.push(0);
                    let mut left_path = path.clone();
                    left_path.push(internal_idx);
                    stack.push((left_idx, left_code, left_path));
                }

                if let Some(right_idx) = node.right {
                    let mut right_code = code.clone();
                    right_code.push(1);
                    let mut right_path = path.clone();
                    right_path.push(internal_idx);
                    stack.push((right_idx, right_code, right_path));
                }
            }
        }

        Ok(Self {
            codes,
            paths,
            num_internal,
        })
    }
}
411
/// Cumulative-distribution table for weighted random index sampling
/// (used for word2vec negative sampling with count^0.75 weights).
#[derive(Debug, Clone)]
struct SamplingTable {
    // Normalized cumulative distribution; the final entry is ~1.0.
    cdf: Vec<f64>,
    // Original (unnormalized) weights, kept for frequency lookups.
    weights: Vec<f64>,
}
420
421impl SamplingTable {
422 fn new(weights: &[f64]) -> Result<Self> {
424 if weights.is_empty() {
425 return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
426 }
427
428 if weights.iter().any(|&w| w < 0.0) {
430 return Err(TextError::EmbeddingError("Weights must be positive".into()));
431 }
432
433 let sum: f64 = weights.iter().sum();
435 if sum <= 0.0 {
436 return Err(TextError::EmbeddingError(
437 "Sum of _weights must be positive".into(),
438 ));
439 }
440
441 let mut cdf = Vec::with_capacity(weights.len());
442 let mut total = 0.0;
443
444 for &w in weights {
445 total += w;
446 cdf.push(total / sum);
447 }
448
449 Ok(Self {
450 cdf,
451 weights: weights.to_vec(),
452 })
453 }
454
455 fn sample<R: Rng>(&self, rng: &mut R) -> usize {
457 let r = rng.random::<f64>();
458
459 match self.cdf.binary_search_by(|&cdf_val| {
461 cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
462 }) {
463 Ok(idx) => idx,
464 Err(idx) => idx,
465 }
466 }
467
468 fn weights(&self) -> &[f64] {
470 &self.weights
471 }
472}
473
/// Training objective used by [`Word2Vec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous bag-of-words: predict the target word from its context.
    CBOW,
    /// Skip-gram: predict the context words from the target word.
    SkipGram,
}
482
/// Hyperparameters for [`Word2Vec`] training.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the embedding vectors.
    pub vector_size: usize,
    /// Maximum context-window radius (actual window is sampled per position).
    pub window_size: usize,
    /// Minimum corpus count for a word to enter the vocabulary.
    pub min_count: usize,
    /// Number of passes over the corpus.
    pub epochs: usize,
    /// Initial learning rate (decays linearly per epoch).
    pub learning_rate: f64,
    /// CBOW or Skip-gram.
    pub algorithm: Word2VecAlgorithm,
    /// Negative samples drawn per positive example.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words (0.0 disables subsampling).
    pub subsample: f64,
    /// Batch size.
    /// NOTE(review): not referenced by the visible training loop — confirm
    /// whether it is consumed elsewhere or vestigial.
    pub batch_size: usize,
    /// Use hierarchical softmax instead of negative sampling.
    pub hierarchical_softmax: bool,
}
507
impl Default for Word2VecConfig {
    /// Defaults in the spirit of the reference word2vec settings:
    /// 100-dim skip-gram, window 5, 5 negative samples, lr 0.025.
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}
524
/// Word2Vec model supporting CBOW and Skip-gram training with either
/// negative sampling or hierarchical softmax.
pub struct Word2Vec {
    // Training hyperparameters.
    config: Word2VecConfig,
    // Word <-> index mapping.
    vocabulary: Vocabulary,
    // Input ("projection") vectors, one row per word; `None` until built.
    input_embeddings: Option<Array2<f64>>,
    // Output ("context") vectors; `None` until built (and after `load`).
    output_embeddings: Option<Array2<f64>>,
    // Tokenizer used for vocabulary building and training.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    // Unigram^0.75 table for negative sampling.
    sampling_table: Option<SamplingTable>,
    // Huffman tree for hierarchical softmax (when enabled).
    huffman_tree: Option<HuffmanTree>,
    // Parameter vectors of the Huffman internal nodes (hierarchical softmax).
    hs_params: Option<Array2<f64>>,
    // Learning rate after per-epoch decay.
    current_learning_rate: f64,
}
554
555impl Debug for Word2Vec {
556 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557 f.debug_struct("Word2Vec")
558 .field("config", &self.config)
559 .field("vocabulary", &self.vocabulary)
560 .field("input_embeddings", &self.input_embeddings)
561 .field("output_embeddings", &self.output_embeddings)
562 .field("sampling_table", &self.sampling_table)
563 .field("huffman_tree", &self.huffman_tree)
564 .field("current_learning_rate", &self.current_learning_rate)
565 .finish()
566 }
567}
568
impl Default for Word2Vec {
    /// Equivalent to [`Word2Vec::new`].
    fn default() -> Self {
        Self::new()
    }
}
575
impl Clone for Word2Vec {
    /// Manual `Clone`: `Box<dyn Tokenizer>` cannot be cloned, so the clone
    /// receives a fresh default `WordTokenizer`.
    ///
    /// NOTE(review): a clone of a model built with a custom tokenizer will
    /// tokenize differently from the original — confirm callers accept this.
    fn clone(&self) -> Self {
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            huffman_tree: self.huffman_tree.clone(),
            hs_params: self.hs_params.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}
593
594impl Word2Vec {
595 pub fn new() -> Self {
597 Self {
598 config: Word2VecConfig::default(),
599 vocabulary: Vocabulary::new(),
600 input_embeddings: None,
601 output_embeddings: None,
602 tokenizer: Box::new(WordTokenizer::default()),
603 sampling_table: None,
604 huffman_tree: None,
605 hs_params: None,
606 current_learning_rate: 0.025,
607 }
608 }
609
610 pub fn with_config(config: Word2VecConfig) -> Self {
612 let learning_rate = config.learning_rate;
613 Self {
614 config,
615 vocabulary: Vocabulary::new(),
616 input_embeddings: None,
617 output_embeddings: None,
618 tokenizer: Box::new(WordTokenizer::default()),
619 sampling_table: None,
620 huffman_tree: None,
621 hs_params: None,
622 current_learning_rate: learning_rate,
623 }
624 }
625
    /// Replace the tokenizer used for vocabulary building and training.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }

    /// Set the embedding dimensionality.
    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
        self.config.vector_size = vectorsize;
        self
    }

    /// Set the maximum context-window radius.
    pub fn with_window_size(mut self, windowsize: usize) -> Self {
        self.config.window_size = windowsize;
        self
    }

    /// Set the minimum corpus count for vocabulary inclusion.
    pub fn with_min_count(mut self, mincount: usize) -> Self {
        self.config.min_count = mincount;
        self
    }

    /// Set the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Set the initial learning rate (also resets the current decayed rate).
    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
        self.config.learning_rate = learningrate;
        self.current_learning_rate = learningrate;
        self
    }

    /// Choose CBOW or Skip-gram.
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Set the number of negative samples per positive example.
    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
        self.config.negative_samples = negativesamples;
        self
    }

    /// Set the frequent-word subsampling threshold (0.0 disables).
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }

    /// Set the batch size.
    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
        self.config.batch_size = batchsize;
        self
    }
686
    /// Tokenize `texts`, build the vocabulary (words with count >=
    /// `min_count`), randomly initialize both embedding matrices, and
    /// prepare the negative-sampling table (plus the Huffman tree when
    /// hierarchical softmax is enabled).
    ///
    /// # Errors
    /// Fails when `texts` is empty, tokenization fails, or no word meets
    /// the minimum count threshold.
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count token occurrences across the whole corpus.
        let mut word_counts = HashMap::new();
        let mut _total_words = 0;

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
                _total_words += 1;
            }
        }

        // Keep only words appearing at least `min_count` times.
        // NOTE(review): HashMap iteration order is unspecified, so the
        // word -> index assignment is not deterministic across runs.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;

        // Initialize both matrices uniformly in (-1/dim, 1/dim).
        let mut rng = scirs2_core::random::rng();
        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });
        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.input_embeddings = Some(input_embeddings);
        self.output_embeddings = Some(output_embeddings);

        // Unigram^0.75 table used by negative sampling.
        self.create_sampling_table(&word_counts)?;

        // Hierarchical softmax needs a Huffman tree over word frequencies
        // and one parameter vector per internal tree node.
        if self.config.hierarchical_softmax {
            let frequencies: Vec<usize> = (0..vocab_size)
                .map(|i| {
                    self.vocabulary
                        .get_token(i)
                        .and_then(|word| word_counts.get(word).copied())
                        .unwrap_or(1)
                })
                .collect();

            let tree = HuffmanTree::build(&frequencies)?;
            let num_internal = tree.num_internal;

            let hs_params = Array2::zeros((num_internal, vector_size));
            self.hs_params = Some(hs_params);
            self.huffman_tree = Some(tree);
        }

        Ok(())
    }
762
763 fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
765 let mut sampling_weights = vec![0.0; self.vocabulary.len()];
767
768 for (word, &count) in wordcounts.iter() {
769 if let Some(idx) = self.vocabulary.get_index(word) {
770 sampling_weights[idx] = (count as f64).powf(0.75);
772 }
773 }
774
775 match SamplingTable::new(&sampling_weights) {
776 Ok(table) => {
777 self.sampling_table = Some(table);
778 Ok(())
779 }
780 Err(e) => Err(e),
781 }
782 }
783
    /// Train the model on `texts` for `config.epochs` epochs.
    ///
    /// Builds the vocabulary automatically when it is still empty. Each
    /// epoch decays the learning rate linearly, floored at 0.01% of the
    /// configured initial rate.
    ///
    /// # Errors
    /// Fails when `texts` is empty, tokenization fails, or the embedding
    /// matrices were never initialized.
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        // Lazily build the vocabulary on first call.
        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Embeddings not initialized. Call build_vocabulary() first".into(),
            ));
        }

        // Convert each text into a sentence of in-vocabulary word indices;
        // out-of-vocabulary tokens are silently dropped.
        let mut _total_tokens = 0;
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let filtered_tokens: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !filtered_tokens.is_empty() {
                _total_tokens += filtered_tokens.len();
                sentences.push(filtered_tokens);
            }
        }

        for epoch in 0..self.config.epochs {
            // Linear learning-rate decay with a small floor.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                // Optionally drop frequent words before training.
                let subsampled_sentence = if self.config.subsample > 0.0 {
                    self.subsample_sentence(sentence)?
                } else {
                    sentence.clone()
                };

                if subsampled_sentence.is_empty() {
                    continue;
                }

                // Dispatch on objective (hierarchical softmax vs negative
                // sampling) and algorithm (CBOW vs skip-gram).
                if self.config.hierarchical_softmax {
                    match self.config.algorithm {
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
                        }
                    }
                } else {
                    match self.config.algorithm {
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_sentence(&subsampled_sentence)?;
                        }
                    }
                }
            }
        }

        Ok(())
    }
868
    /// Randomly drop frequent words from a sentence (word2vec subsampling).
    ///
    /// NOTE(review): `threshold` scales `subsample` by the vocabulary SIZE
    /// rather than the total corpus count, and the word "frequency" here is
    /// the unnormalized count^0.75 sampling weight — confirm this matches
    /// the intended subsampling formula.
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    return true; // words with no recorded weight are always kept
                }
                // Keep probability from the word2vec paper:
                // (sqrt(f/t) + 1) * (t/f).
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }
892
893 fn get_word_frequency(&self, wordidx: usize) -> f64 {
895 if let Some(table) = &self.sampling_table {
898 table.weights()[wordidx]
899 } else {
900 1.0 }
902 }
903
    /// One CBOW pass over a sentence: predict each word from the average of
    /// its context vectors, using negative sampling.
    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(()); // nothing to train on
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Dynamic window in [1, window_size], as in word2vec.
            // NOTE(review): `random_range(0..window_size)` panics when
            // window_size == 0 — confirm configs always use a positive window.
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect the surrounding context word indices.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue; // e.g. one-word sentences after subsampling
            }

            // Average of the context input vectors (the CBOW hidden layer).
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive update: move the target's output vector toward the
            // context average (gradient of log sigmoid).
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            // Owned copy avoids aliasing the mutable row while updating it.
            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative updates: push sampled non-target outputs away.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue; // never treat the target as a negative
                    }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Propagate positive + negative gradients back to each context
            // word's input vector, split evenly across the context.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }
1012
    /// One skip-gram pass over a sentence: predict each context word from
    /// the target word, using negative sampling.
    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(()); // nothing to train on
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Dynamic window in [1, window_size], as in word2vec.
            // NOTE(review): `random_range(0..window_size)` panics when
            // window_size == 0 — confirm configs always use a positive window.
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue; // skip the target itself
                }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive pair: move the context's output vector toward
                // the target's input vector.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the input-side gradient; applied once below.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                // Negative samples: push sampled outputs away from target.
                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue; // never treat the positive as a negative
                        }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated gradient to the target input vector.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }
1088
    /// Dimensionality of the embedding vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }
1093
1094 pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
1096 if self.input_embeddings.is_none() {
1097 return Err(TextError::EmbeddingError(
1098 "Model not trained. Call train() first".into(),
1099 ));
1100 }
1101
1102 match self.vocabulary.get_index(word) {
1103 Some(idx) => Ok(self
1104 .input_embeddings
1105 .as_ref()
1106 .expect("Operation failed")
1107 .row(idx)
1108 .to_owned()),
1109 None => Err(TextError::VocabularyError(format!(
1110 "Word '{word}' not in vocabulary"
1111 ))),
1112 }
1113 }
1114
    /// The `topn` words most similar to `word` (excluding `word` itself).
    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, topn, &[word])
    }
1120
    /// The `top_n` vocabulary words most similar (by cosine) to `vector`,
    /// skipping any word listed in `exclude_words`.
    pub fn most_similar_by_vector(
        &self,
        vector: &Array1<f64>,
        top_n: usize,
        exclude_words: &[&str],
    ) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
        let vocab_size = self.vocabulary.len();

        // Resolve excluded words to indices once up front.
        let exclude_indices: Vec<usize> = exclude_words
            .iter()
            .filter_map(|&word| self.vocabulary.get_index(word))
            .collect();

        let mut similarities = Vec::with_capacity(vocab_size);

        // Brute-force scan over the whole vocabulary.
        for i in 0..vocab_size {
            if exclude_indices.contains(&i) {
                continue;
            }

            let word_vec = input_embeddings.row(i);
            // NOTE(review): `cosine_similarity` is not defined in this
            // chunk — presumably a sibling helper in this module; confirm it
            // matches `embedding_cosine_similarity` semantics.
            let similarity = cosine_similarity(vector, &word_vec.to_owned());

            if let Some(word) = self.vocabulary.get_token(i) {
                similarities.push((word.to_string(), similarity));
            }
        }

        // Sort descending by similarity; NaNs compare as equal.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let result = similarities.into_iter().take(top_n).collect();
        Ok(result)
    }
1166
1167 pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1169 if self.input_embeddings.is_none() {
1170 return Err(TextError::EmbeddingError(
1171 "Model not trained. Call train() first".into(),
1172 ));
1173 }
1174
1175 let a_vec = self.get_word_vector(a)?;
1177 let b_vec = self.get_word_vector(b)?;
1178 let c_vec = self.get_word_vector(c)?;
1179
1180 let mut d_vec = b_vec.clone();
1182 d_vec -= &a_vec;
1183 d_vec += &c_vec;
1184
1185 let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1187 d_vec.mapv_inplace(|val| val / norm);
1188
1189 self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
1191 }
1192
1193 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1195 if self.input_embeddings.is_none() {
1196 return Err(TextError::EmbeddingError(
1197 "Model not trained. Call train() first".into(),
1198 ));
1199 }
1200
1201 let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
1202
1203 writeln!(
1205 &mut file,
1206 "{} {}",
1207 self.vocabulary.len(),
1208 self.config.vector_size
1209 )
1210 .map_err(|e| TextError::IoError(e.to_string()))?;
1211
1212 let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1214
1215 for i in 0..self.vocabulary.len() {
1216 if let Some(word) = self.vocabulary.get_token(i) {
1217 write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
1219
1220 let vector = input_embeddings.row(i);
1222 for j in 0..self.config.vector_size {
1223 write!(&mut file, "{:.6} ", vector[j])
1224 .map_err(|e| TextError::IoError(e.to_string()))?;
1225 }
1226
1227 writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
1228 }
1229 }
1230
1231 Ok(())
1232 }
1233
1234 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
1236 let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
1237 let mut reader = BufReader::new(file);
1238
1239 let mut header = String::new();
1241 reader
1242 .read_line(&mut header)
1243 .map_err(|e| TextError::IoError(e.to_string()))?;
1244
1245 let parts: Vec<&str> = header.split_whitespace().collect();
1246 if parts.len() != 2 {
1247 return Err(TextError::EmbeddingError(
1248 "Invalid model file format".into(),
1249 ));
1250 }
1251
1252 let vocab_size = parts[0].parse::<usize>().map_err(|_| {
1253 TextError::EmbeddingError("Invalid vocabulary size in model file".into())
1254 })?;
1255
1256 let vector_size = parts[1]
1257 .parse::<usize>()
1258 .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
1259
1260 let mut model = Self::new().with_vector_size(vector_size);
1262 let mut vocabulary = Vocabulary::new();
1263 let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
1264
1265 let mut i = 0;
1267 for line in reader.lines() {
1268 let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
1269 let parts: Vec<&str> = line.split_whitespace().collect();
1270
1271 if parts.len() != vector_size + 1 {
1272 let line_num = i + 2;
1273 return Err(TextError::EmbeddingError(format!(
1274 "Invalid vector format at line {line_num}"
1275 )));
1276 }
1277
1278 let word = parts[0];
1279 vocabulary.add_token(word);
1280
1281 for j in 0..vector_size {
1282 input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
1283 TextError::EmbeddingError(format!(
1284 "Invalid vector component at line {}, position {}",
1285 i + 2,
1286 j + 1
1287 ))
1288 })?;
1289 }
1290
1291 i += 1;
1292 }
1293
1294 if i != vocab_size {
1295 return Err(TextError::EmbeddingError(format!(
1296 "Expected {vocab_size} words but found {i}"
1297 )));
1298 }
1299
1300 model.vocabulary = vocabulary;
1301 model.input_embeddings = Some(input_embeddings);
1302 model.output_embeddings = None; Ok(model)
1305 }
1306
1307 pub fn get_vocabulary(&self) -> Vec<String> {
1311 let mut vocab = Vec::new();
1312 for i in 0..self.vocabulary.len() {
1313 if let Some(token) = self.vocabulary.get_token(i) {
1314 vocab.push(token.to_string());
1315 }
1316 }
1317 vocab
1318 }
1319
    /// Configured embedding dimensionality (alias of [`Word2Vec::vector_size`]).
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Configured training algorithm.
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }

    /// Configured context-window radius.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }

    /// Configured minimum word count.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }

    /// Clone of the input-embedding matrix (`None` before training).
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }

    /// Configured number of negative samples.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }

    /// Configured initial learning rate (not the decayed current rate).
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Configured number of training epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }

    /// Configured frequent-word subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }

    /// Whether hierarchical softmax is enabled.
    pub fn uses_hierarchical_softmax(&self) -> bool {
        self.config.hierarchical_softmax
    }
1369
1370 fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1374 if sentence.len() < 2 {
1375 return Ok(());
1376 }
1377
1378 let input_embeddings = self
1379 .input_embeddings
1380 .as_mut()
1381 .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1382 let hs_params = self
1383 .hs_params
1384 .as_mut()
1385 .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1386 let tree = self
1387 .huffman_tree
1388 .as_ref()
1389 .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1390
1391 let vector_size = self.config.vector_size;
1392 let window_size = self.config.window_size;
1393 let lr = self.current_learning_rate;
1394
1395 let codes = tree.codes.clone();
1396 let paths = tree.paths.clone();
1397
1398 let mut rng = scirs2_core::random::rng();
1399
1400 for pos in 0..sentence.len() {
1401 let window = 1 + rng.random_range(0..window_size);
1402 let target_word = sentence[pos];
1403
1404 for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1405 if i == pos {
1406 continue;
1407 }
1408
1409 let context_word = sentence[i];
1410 let code = &codes[context_word];
1411 let path = &paths[context_word];
1412
1413 let mut grad_input = Array1::zeros(vector_size);
1414
1415 for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1417 if node_idx >= hs_params.nrows() {
1418 continue;
1419 }
1420
1421 let input_vec = input_embeddings.row(target_word);
1423 let param_vec = hs_params.row(node_idx);
1424
1425 let dot: f64 = input_vec
1426 .iter()
1427 .zip(param_vec.iter())
1428 .map(|(a, b)| a * b)
1429 .sum();
1430 let sigmoid = 1.0 / (1.0 + (-dot).exp());
1431
1432 let target = if label == 0 { 1.0 } else { 0.0 };
1434 let gradient = (target - sigmoid) * lr;
1435
1436 grad_input.scaled_add(gradient, ¶m_vec.to_owned());
1438
1439 let input_owned = input_vec.to_owned();
1441 let mut param_mut = hs_params.row_mut(node_idx);
1442 param_mut.scaled_add(gradient, &input_owned);
1443 }
1444
1445 let mut input_mut = input_embeddings.row_mut(target_word);
1447 input_mut += &grad_input;
1448 }
1449 }
1450
1451 Ok(())
1452 }
1453
1454 fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1456 if sentence.len() < 2 {
1457 return Ok(());
1458 }
1459
1460 let input_embeddings = self
1461 .input_embeddings
1462 .as_mut()
1463 .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1464 let hs_params = self
1465 .hs_params
1466 .as_mut()
1467 .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1468 let tree = self
1469 .huffman_tree
1470 .as_ref()
1471 .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1472
1473 let vector_size = self.config.vector_size;
1474 let window_size = self.config.window_size;
1475 let lr = self.current_learning_rate;
1476
1477 let codes = tree.codes.clone();
1478 let paths = tree.paths.clone();
1479
1480 let mut rng = scirs2_core::random::rng();
1481
1482 for pos in 0..sentence.len() {
1483 let window = 1 + rng.random_range(0..window_size);
1484 let target_word = sentence[pos];
1485
1486 let mut context_words = Vec::new();
1488 for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1489 if i != pos {
1490 context_words.push(sentence[i]);
1491 }
1492 }
1493
1494 if context_words.is_empty() {
1495 continue;
1496 }
1497
1498 let mut context_avg = Array1::zeros(vector_size);
1500 for &ctx_idx in &context_words {
1501 context_avg += &input_embeddings.row(ctx_idx);
1502 }
1503 context_avg /= context_words.len() as f64;
1504
1505 let code = &codes[target_word];
1507 let path = &paths[target_word];
1508
1509 let mut grad_context = Array1::zeros(vector_size);
1510
1511 for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1512 if node_idx >= hs_params.nrows() {
1513 continue;
1514 }
1515
1516 let param_vec = hs_params.row(node_idx);
1517
1518 let dot: f64 = context_avg
1519 .iter()
1520 .zip(param_vec.iter())
1521 .map(|(a, b)| a * b)
1522 .sum();
1523 let sigmoid = 1.0 / (1.0 + (-dot).exp());
1524
1525 let target = if label == 0 { 1.0 } else { 0.0 };
1526 let gradient = (target - sigmoid) * lr;
1527
1528 grad_context.scaled_add(gradient, ¶m_vec.to_owned());
1529
1530 let ctx_owned = context_avg.clone();
1532 let mut param_mut = hs_params.row_mut(node_idx);
1533 param_mut.scaled_add(gradient, &ctx_owned);
1534 }
1535
1536 let grad_per_word = &grad_context / context_words.len() as f64;
1538 for &ctx_idx in &context_words {
1539 let mut input_mut = input_embeddings.row_mut(ctx_idx);
1540 input_mut += &grad_per_word;
1541 }
1542 }
1543
1544 Ok(())
1545 }
1546}
1547
1548impl WordEmbedding for Word2Vec {
1551 fn embedding(&self, word: &str) -> Result<Array1<f64>> {
1552 self.get_word_vector(word)
1553 }
1554
1555 fn dimension(&self) -> usize {
1556 self.vector_size()
1557 }
1558
1559 fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1560 self.most_similar(word, top_n)
1561 }
1562
1563 fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1564 self.analogy(a, b, c, top_n)
1565 }
1566
1567 fn vocab_size(&self) -> usize {
1568 self.vocabulary.len()
1569 }
1570}
1571
1572#[allow(dead_code)]
1574pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1575 let dot_product = (a * b).sum();
1576 let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1577 let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1578
1579 if norm_a > 0.0 && norm_b > 0.0 {
1580 dot_product / (norm_a * norm_b)
1581 } else {
1582 0.0
1583 }
1584}
1585
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cosine_similarity() {
        let u = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let v = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // Hand-computed: 32 / (sqrt(14) * sqrt(77)).
        assert_relative_eq!(
            cosine_similarity(&u, &v),
            0.9746318461970762,
            max_relative = 1e-10
        );
    }

    #[test]
    fn test_word2vec_config() {
        // Defaults should match the classic word2vec settings.
        let cfg = Word2VecConfig::default();
        assert_eq!(cfg.vector_size, 100);
        assert_eq!(cfg.window_size, 5);
        assert_eq!(cfg.min_count, 5);
        assert_eq!(cfg.epochs, 5);
        assert_eq!(cfg.algorithm, Word2VecAlgorithm::SkipGram);
    }

    #[test]
    fn test_word2vec_builder() {
        let w2v = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        // Each builder call must be reflected in the stored config.
        assert_eq!(w2v.config.vector_size, 200);
        assert_eq!(w2v.config.window_size, 10);
        assert_eq!(w2v.config.learning_rate, 0.05);
        assert_eq!(w2v.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    #[test]
    fn test_build_vocabulary() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut w2v = Word2Vec::new().with_min_count(1);
        assert!(w2v.build_vocabulary(&corpus).is_ok());

        // Nine distinct tokens across the two sentences.
        assert_eq!(w2v.vocabulary.len(), 9);

        // Embedding matrices are allocated eagerly at the default size.
        assert!(w2v.input_embeddings.is_some());
        assert!(w2v.output_embeddings.is_some());
        assert_eq!(
            w2v.input_embeddings
                .as_ref()
                .expect("Operation failed")
                .shape(),
            &[9, 100]
        );
    }

    #[test]
    fn test_skipgram_training_small() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut w2v = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        assert!(w2v.train(&corpus).is_ok());

        // An in-vocabulary word yields a vector of the configured size.
        let fox = w2v.get_word_vector("fox");
        assert!(fox.is_ok());
        assert_eq!(fox.expect("Operation failed").len(), 10);
    }

    #[test]
    fn test_huffman_tree_build() {
        let freqs = vec![5, 3, 8, 1, 2];
        let tree = HuffmanTree::build(&freqs).expect("Huffman build failed");

        // One code and one root-to-leaf path per input word.
        assert_eq!(tree.codes.len(), 5);
        assert_eq!(tree.paths.len(), 5);

        // No leaf sits at the root, so every code is non-empty.
        for code in &tree.codes {
            assert!(!code.is_empty());
        }

        // A binary tree over n leaves has n - 1 internal nodes.
        assert_eq!(tree.num_internal, 4);
    }

    #[test]
    fn test_huffman_tree_single_word() {
        // Degenerate single-word vocabulary still produces a valid tree.
        let freqs = vec![10];
        let tree = HuffmanTree::build(&freqs).expect("Huffman build failed");
        assert_eq!(tree.codes.len(), 1);
        assert_eq!(tree.paths.len(), 1);
    }

    #[test]
    fn test_skipgram_hierarchical_softmax() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let cfg = Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            hierarchical_softmax: true,
            ..Default::default()
        };

        let mut w2v = Word2Vec::with_config(cfg);
        let outcome = w2v.train(&corpus);
        assert!(
            outcome.is_ok(),
            "HS skipgram training failed: {:?}",
            outcome.err()
        );

        assert!(w2v.uses_hierarchical_softmax());

        let fox = w2v.get_word_vector("fox");
        assert!(fox.is_ok());
        assert_eq!(fox.expect("get vec").len(), 10);
    }

    #[test]
    fn test_cbow_hierarchical_softmax() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let cfg = Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::CBOW,
            hierarchical_softmax: true,
            ..Default::default()
        };

        let mut w2v = Word2Vec::with_config(cfg);
        let outcome = w2v.train(&corpus);
        assert!(
            outcome.is_ok(),
            "HS CBOW training failed: {:?}",
            outcome.err()
        );

        assert!(w2v.get_word_vector("dog").is_ok());
    }

    #[test]
    fn test_word_embedding_trait_word2vec() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut w2v = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);
        w2v.train(&corpus).expect("Training failed");

        // Exercise every method through the trait object.
        let emb: &dyn WordEmbedding = &w2v;
        assert_eq!(emb.dimension(), 10);
        assert!(emb.vocab_size() > 0);

        assert!(emb.embedding("fox").is_ok());

        let sim = emb.similarity("fox", "dog");
        assert!(sim.is_ok());
        assert!(sim.expect("sim").is_finite());

        assert!(emb.find_similar("fox", 2).is_ok());
        assert!(emb.solve_analogy("the", "fox", "dog", 2).is_ok());
    }

    #[test]
    fn test_embedding_cosine_similarity_fn() {
        // Orthogonal vectors have zero similarity.
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);
        assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);

        // Identical vectors have similarity one.
        let c = Array1::from_vec(vec![1.0, 1.0]);
        let d = Array1::from_vec(vec![1.0, 1.0]);
        assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_pairwise_similarity_fn() {
        let corpus = ["the quick brown fox", "the lazy brown dog"];

        let mut w2v = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);
        w2v.train(&corpus).expect("Training failed");

        let words = vec!["the", "fox", "dog"];
        let matrix = pairwise_similarity(&w2v, &words).expect("pairwise failed");

        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Self-similarity is 1 on the diagonal...
        for i in 0..3 {
            assert!((matrix[i][i] - 1.0).abs() < 1e-6);
        }

        // ...and the matrix is symmetric.
        for i in 0..3 {
            for j in 0..3 {
                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
            }
        }
    }
}