1pub mod fasttext;
135pub mod glove;
136
137pub use fasttext::{FastText, FastTextConfig};
139pub use glove::GloVe;
140
use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
151
/// Cumulative-distribution table for drawing word indices with probability
/// proportional to a weight vector (used for negative sampling).
#[derive(Debug, Clone)]
struct SamplingTable {
    // Normalized cumulative distribution; monotonically increasing, last entry 1.0.
    cdf: Vec<f64>,
    // Raw (unnormalized) weights, kept so callers can look up per-word weights.
    weights: Vec<f64>,
}
160
161impl SamplingTable {
162 fn new(weights: &[f64]) -> Result<Self> {
164 if weights.is_empty() {
165 return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
166 }
167
168 if weights.iter().any(|&w| w < 0.0) {
170 return Err(TextError::EmbeddingError("Weights must be positive".into()));
171 }
172
173 let sum: f64 = weights.iter().sum();
175 if sum <= 0.0 {
176 return Err(TextError::EmbeddingError(
177 "Sum of _weights must be positive".into(),
178 ));
179 }
180
181 let mut cdf = Vec::with_capacity(weights.len());
182 let mut total = 0.0;
183
184 for &w in weights {
185 total += w;
186 cdf.push(total / sum);
187 }
188
189 Ok(Self {
190 cdf,
191 weights: weights.to_vec(),
192 })
193 }
194
195 fn sample<R: Rng>(&self, rng: &mut R) -> usize {
197 let r = rng.random::<f64>();
198
199 match self.cdf.binary_search_by(|&cdf_val| {
201 cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
202 }) {
203 Ok(idx) => idx,
204 Err(idx) => idx,
205 }
206 }
207
208 fn weights(&self) -> &[f64] {
210 &self.weights
211 }
212}
213
/// Training objective used by [`Word2Vec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous bag-of-words: predict a word from its surrounding context.
    CBOW,
    /// Skip-gram: predict surrounding context words from a word.
    SkipGram,
}
222
/// Hyperparameters controlling Word2Vec training.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the learned word vectors.
    pub vector_size: usize,
    /// Maximum context window radius around the target word.
    pub window_size: usize,
    /// Words occurring fewer times than this are dropped from the vocabulary.
    pub min_count: usize,
    /// Number of passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate (linearly decayed across epochs).
    pub learning_rate: f64,
    /// Training objective: CBOW or skip-gram.
    pub algorithm: Word2VecAlgorithm,
    /// Number of negative samples drawn per positive example.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words; values <= 0.0 disable it.
    pub subsample: f64,
    /// Batch size; NOTE(review): not referenced by the training code in this file.
    pub batch_size: usize,
    /// Hierarchical softmax toggle; NOTE(review): not referenced by the
    /// training code in this file.
    pub hierarchical_softmax: bool,
}
247
impl Default for Word2VecConfig {
    /// Defaults matching common word2vec reference settings: 100-dim vectors,
    /// window 5, min count 5, 5 epochs, skip-gram with 5 negative samples.
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}
264
/// Word2Vec model supporting CBOW and skip-gram training with negative sampling.
pub struct Word2Vec {
    // Training hyperparameters.
    config: Word2VecConfig,
    // Token-to-index mapping built from the corpus.
    vocabulary: Vocabulary,
    // Input (word) embedding matrix; `None` until `build_vocabulary` runs.
    input_embeddings: Option<Array2<f64>>,
    // Output (context) embedding matrix; `None` until `build_vocabulary` runs.
    output_embeddings: Option<Array2<f64>>,
    // Tokenizer used for vocabulary building and training.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    // Negative-sampling distribution; built alongside the vocabulary.
    sampling_table: Option<SamplingTable>,
    // Learning rate for the current epoch (decayed during training).
    current_learning_rate: f64,
}
290
impl Debug for Word2Vec {
    /// Manual `Debug` impl: the boxed `dyn Tokenizer` cannot be derived, so
    /// the tokenizer field is omitted from the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Word2Vec")
            .field("config", &self.config)
            .field("vocabulary", &self.vocabulary)
            .field("input_embeddings", &self.input_embeddings)
            .field("output_embeddings", &self.output_embeddings)
            .field("sampling_table", &self.sampling_table)
            .field("current_learning_rate", &self.current_learning_rate)
            .finish()
    }
}
303
impl Default for Word2Vec {
    /// Equivalent to [`Word2Vec::new`].
    fn default() -> Self {
        Self::new()
    }
}
310
impl Clone for Word2Vec {
    /// Clones the model state.
    ///
    /// NOTE: `dyn Tokenizer` is not `Clone`, so the clone always receives a
    /// fresh default `WordTokenizer` — a custom tokenizer installed via
    /// `with_tokenizer` is NOT carried over to the clone.
    fn clone(&self) -> Self {
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}
329
330impl Word2Vec {
    /// Creates an untrained model with default configuration and a default
    /// word tokenizer.
    pub fn new() -> Self {
        Self {
            config: Word2VecConfig::default(),
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            // Mirrors Word2VecConfig::default().learning_rate.
            current_learning_rate: 0.025,
        }
    }
343
    /// Creates an untrained model from an explicit configuration.
    pub fn with_config(config: Word2VecConfig) -> Self {
        // The current learning rate starts at the configured initial rate.
        let learning_rate = config.learning_rate;
        Self {
            config,
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            current_learning_rate: learning_rate,
        }
    }
357
    /// Replaces the tokenizer used for vocabulary building and training.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }

    /// Sets the embedding dimensionality.
    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
        self.config.vector_size = vectorsize;
        self
    }

    /// Sets the maximum context window radius.
    pub fn with_window_size(mut self, windowsize: usize) -> Self {
        self.config.window_size = windowsize;
        self
    }

    /// Sets the minimum word frequency for vocabulary inclusion.
    pub fn with_min_count(mut self, mincount: usize) -> Self {
        self.config.min_count = mincount;
        self
    }

    /// Sets the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Sets the initial learning rate; also resets the current (decayed)
    /// learning rate to the new value.
    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
        self.config.learning_rate = learningrate;
        self.current_learning_rate = learningrate;
        self
    }

    /// Selects the training objective (CBOW or skip-gram).
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Sets the number of negative samples per positive example.
    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
        self.config.negative_samples = negativesamples;
        self
    }

    /// Sets the frequent-word subsampling threshold (<= 0.0 disables it).
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }

    /// Sets the batch size (not referenced by the training loop in this file).
    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
        self.config.batch_size = batchsize;
        self
    }
418
419 pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
421 if texts.is_empty() {
422 return Err(TextError::InvalidInput(
423 "No texts provided for building vocabulary".into(),
424 ));
425 }
426
427 let mut word_counts = HashMap::new();
429 let mut _total_words = 0;
430
431 for &text in texts {
432 let tokens = self.tokenizer.tokenize(text)?;
433 for token in tokens {
434 *word_counts.entry(token).or_insert(0) += 1;
435 _total_words += 1;
436 }
437 }
438
439 self.vocabulary = Vocabulary::new();
441 for (word, count) in &word_counts {
442 if *count >= self.config.min_count {
443 self.vocabulary.add_token(word);
444 }
445 }
446
447 if self.vocabulary.is_empty() {
448 return Err(TextError::VocabularyError(
449 "No words meet the minimum count threshold".into(),
450 ));
451 }
452
453 let vocab_size = self.vocabulary.len();
455 let vector_size = self.config.vector_size;
456
457 let mut rng = scirs2_core::random::rng();
459 let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
460 (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
461 });
462 let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
463 (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
464 });
465
466 self.input_embeddings = Some(input_embeddings);
467 self.output_embeddings = Some(output_embeddings);
468
469 self.create_sampling_table(&word_counts)?;
471
472 Ok(())
473 }
474
475 fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
477 let mut sampling_weights = vec![0.0; self.vocabulary.len()];
479
480 for (word, &count) in wordcounts.iter() {
481 if let Some(idx) = self.vocabulary.get_index(word) {
482 sampling_weights[idx] = (count as f64).powf(0.75);
484 }
485 }
486
487 match SamplingTable::new(&sampling_weights) {
488 Ok(table) => {
489 self.sampling_table = Some(table);
490 Ok(())
491 }
492 Err(e) => Err(e),
493 }
494 }
495
496 pub fn train(&mut self, texts: &[&str]) -> Result<()> {
498 if texts.is_empty() {
499 return Err(TextError::InvalidInput(
500 "No texts provided for training".into(),
501 ));
502 }
503
504 if self.vocabulary.is_empty() {
506 self.build_vocabulary(texts)?;
507 }
508
509 if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
510 return Err(TextError::EmbeddingError(
511 "Embeddings not initialized. Call build_vocabulary() first".into(),
512 ));
513 }
514
515 let mut _total_tokens = 0;
517 let mut sentences = Vec::new();
518 for &text in texts {
519 let tokens = self.tokenizer.tokenize(text)?;
520 let filtered_tokens: Vec<usize> = tokens
521 .iter()
522 .filter_map(|token| self.vocabulary.get_index(token))
523 .collect();
524 if !filtered_tokens.is_empty() {
525 _total_tokens += filtered_tokens.len();
526 sentences.push(filtered_tokens);
527 }
528 }
529
530 for epoch in 0..self.config.epochs {
532 self.current_learning_rate =
534 self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
535 self.current_learning_rate = self
536 .current_learning_rate
537 .max(self.config.learning_rate * 0.0001);
538
539 for sentence in &sentences {
541 let subsampled_sentence = if self.config.subsample > 0.0 {
543 self.subsample_sentence(sentence)?
544 } else {
545 sentence.clone()
546 };
547
548 if subsampled_sentence.is_empty() {
550 continue;
551 }
552
553 match self.config.algorithm {
555 Word2VecAlgorithm::CBOW => {
556 self.train_cbow_sentence(&subsampled_sentence)?;
557 }
558 Word2VecAlgorithm::SkipGram => {
559 self.train_skipgram_sentence(&subsampled_sentence)?;
560 }
561 }
562 }
563 }
564
565 Ok(())
566 }
567
    /// Randomly drops frequent words from `sentence`, using the word2vec
    /// keep probability `p = (sqrt(f/t) + 1) * t/f`.
    ///
    /// NOTE(review): `threshold` is derived from the vocabulary *size* and
    /// `word_freq` is the raw `count^0.75` sampling weight rather than a
    /// normalized corpus frequency — confirm this matches the intended
    /// subsampling scale.
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    // No frequency information: always keep the word.
                    return true; }
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }
591
592 fn get_word_frequency(&self, wordidx: usize) -> f64 {
594 if let Some(table) = &self.sampling_table {
597 table.weights()[wordidx]
598 } else {
599 1.0 }
601 }
602
    /// Runs one CBOW pass over `sentence` (vocabulary indices): for each
    /// position, the averaged context vectors predict the center word, and
    /// both embedding matrices are updated via SGD with negative sampling.
    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        // A single word has no context to learn from.
        if sentence.len() < 2 {
            return Ok(()); }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Random window shrink in [1, window_size], as in classic word2vec.
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect context indices around `pos`, clamped to the sentence.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue; }

            // CBOW input: average of the context input embeddings.
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive update: push the target output vector toward the
            // context average (gradient of log sigmoid).
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative updates: push sampled output vectors away from the
            // context average.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue; }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Input-side updates, spread evenly across the context words.
            // NOTE(review): the dot products below use the output vectors
            // AFTER the updates above, and negatives are re-sampled rather
            // than reused — sequential SGD, but confirm this is intended.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }
711
    /// Runs one skip-gram pass over `sentence` (vocabulary indices): each
    /// word predicts its context words, with negative-sampling SGD updates
    /// to both embedding matrices.
    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        // A single word has no context to learn from.
        if sentence.len() < 2 {
            return Ok(()); }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Random window shrink in [1, window_size], as in classic word2vec.
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue; }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive update: pull the context output vector toward the
                // target input vector.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the input-side gradient; applied once per pair.
                // NOTE(review): this uses `context_output` AFTER the assign
                // above (and the updated negative vectors below) — classic
                // word2vec uses the pre-update vectors; confirm intended.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                // Negative updates: push sampled output vectors away from
                // the target input vector.
                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue; }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated gradient to the target's input vector.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }
787
    /// Configured embedding dimensionality.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }
792
793 pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
795 if self.input_embeddings.is_none() {
796 return Err(TextError::EmbeddingError(
797 "Model not trained. Call train() first".into(),
798 ));
799 }
800
801 match self.vocabulary.get_index(word) {
802 Some(idx) => Ok(self
803 .input_embeddings
804 .as_ref()
805 .expect("Operation failed")
806 .row(idx)
807 .to_owned()),
808 None => Err(TextError::VocabularyError(format!(
809 "Word '{word}' not in vocabulary"
810 ))),
811 }
812 }
813
    /// Returns the `topn` words most similar to `word` by cosine similarity,
    /// excluding `word` itself from the results.
    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, topn, &[word])
    }
819
820 pub fn most_similar_by_vector(
822 &self,
823 vector: &Array1<f64>,
824 top_n: usize,
825 exclude_words: &[&str],
826 ) -> Result<Vec<(String, f64)>> {
827 if self.input_embeddings.is_none() {
828 return Err(TextError::EmbeddingError(
829 "Model not trained. Call train() first".into(),
830 ));
831 }
832
833 let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
834 let vocab_size = self.vocabulary.len();
835
836 let exclude_indices: Vec<usize> = exclude_words
838 .iter()
839 .filter_map(|&word| self.vocabulary.get_index(word))
840 .collect();
841
842 let mut similarities = Vec::with_capacity(vocab_size);
844
845 for i in 0..vocab_size {
846 if exclude_indices.contains(&i) {
847 continue;
848 }
849
850 let word_vec = input_embeddings.row(i);
851 let similarity = cosine_similarity(vector, &word_vec.to_owned());
852
853 if let Some(word) = self.vocabulary.get_token(i) {
854 similarities.push((word.to_string(), similarity));
855 }
856 }
857
858 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
860
861 let result = similarities.into_iter().take(top_n).collect();
863 Ok(result)
864 }
865
866 pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
868 if self.input_embeddings.is_none() {
869 return Err(TextError::EmbeddingError(
870 "Model not trained. Call train() first".into(),
871 ));
872 }
873
874 let a_vec = self.get_word_vector(a)?;
876 let b_vec = self.get_word_vector(b)?;
877 let c_vec = self.get_word_vector(c)?;
878
879 let mut d_vec = b_vec.clone();
881 d_vec -= &a_vec;
882 d_vec += &c_vec;
883
884 let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
886 d_vec.mapv_inplace(|val| val / norm);
887
888 self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
890 }
891
892 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
894 if self.input_embeddings.is_none() {
895 return Err(TextError::EmbeddingError(
896 "Model not trained. Call train() first".into(),
897 ));
898 }
899
900 let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
901
902 writeln!(
904 &mut file,
905 "{} {}",
906 self.vocabulary.len(),
907 self.config.vector_size
908 )
909 .map_err(|e| TextError::IoError(e.to_string()))?;
910
911 let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
913
914 for i in 0..self.vocabulary.len() {
915 if let Some(word) = self.vocabulary.get_token(i) {
916 write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
918
919 let vector = input_embeddings.row(i);
921 for j in 0..self.config.vector_size {
922 write!(&mut file, "{:.6} ", vector[j])
923 .map_err(|e| TextError::IoError(e.to_string()))?;
924 }
925
926 writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
927 }
928 }
929
930 Ok(())
931 }
932
933 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
935 let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
936 let mut reader = BufReader::new(file);
937
938 let mut header = String::new();
940 reader
941 .read_line(&mut header)
942 .map_err(|e| TextError::IoError(e.to_string()))?;
943
944 let parts: Vec<&str> = header.split_whitespace().collect();
945 if parts.len() != 2 {
946 return Err(TextError::EmbeddingError(
947 "Invalid model file format".into(),
948 ));
949 }
950
951 let vocab_size = parts[0].parse::<usize>().map_err(|_| {
952 TextError::EmbeddingError("Invalid vocabulary size in model file".into())
953 })?;
954
955 let vector_size = parts[1]
956 .parse::<usize>()
957 .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
958
959 let mut model = Self::new().with_vector_size(vector_size);
961 let mut vocabulary = Vocabulary::new();
962 let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
963
964 let mut i = 0;
966 for line in reader.lines() {
967 let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
968 let parts: Vec<&str> = line.split_whitespace().collect();
969
970 if parts.len() != vector_size + 1 {
971 let line_num = i + 2;
972 return Err(TextError::EmbeddingError(format!(
973 "Invalid vector format at line {line_num}"
974 )));
975 }
976
977 let word = parts[0];
978 vocabulary.add_token(word);
979
980 for j in 0..vector_size {
981 input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
982 TextError::EmbeddingError(format!(
983 "Invalid vector component at line {}, position {}",
984 i + 2,
985 j + 1
986 ))
987 })?;
988 }
989
990 i += 1;
991 }
992
993 if i != vocab_size {
994 return Err(TextError::EmbeddingError(format!(
995 "Expected {vocab_size} words but found {i}"
996 )));
997 }
998
999 model.vocabulary = vocabulary;
1000 model.input_embeddings = Some(input_embeddings);
1001 model.output_embeddings = None; Ok(model)
1004 }
1005
1006 pub fn get_vocabulary(&self) -> Vec<String> {
1010 let mut vocab = Vec::new();
1011 for i in 0..self.vocabulary.len() {
1012 if let Some(token) = self.vocabulary.get_token(i) {
1013 vocab.push(token.to_string());
1014 }
1015 }
1016 vocab
1017 }
1018
    /// Configured embedding dimensionality (same as [`Self::vector_size`]).
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Configured training algorithm.
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }

    /// Configured context window radius.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }

    /// Configured minimum word count for vocabulary inclusion.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }

    /// Copy of the input embedding matrix, or `None` before training.
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }

    /// Configured number of negative samples per positive example.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }

    /// Configured initial learning rate (not the current decayed rate).
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Configured number of training epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }

    /// Configured subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }
1063}
1064
1065#[allow(dead_code)]
1067pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1068 let dot_product = (a * b).sum();
1069 let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1070 let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1071
1072 if norm_a > 0.0 && norm_b > 0.0 {
1073 dot_product / (norm_a * norm_b)
1074 } else {
1075 0.0
1076 }
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Checks cosine_similarity against a hand-computed reference value.
    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let similarity = cosine_similarity(&a, &b);
        let expected = 0.9746318461970762;
        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
    }

    // Verifies the documented default hyperparameters.
    #[test]
    fn test_word2vec_config() {
        let config = Word2VecConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.min_count, 5);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
    }

    // Verifies the builder setters write through to the config.
    #[test]
    fn test_word2vec_builder() {
        let model = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(model.config.vector_size, 200);
        assert_eq!(model.config.window_size, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    // Vocabulary building: 9 unique words across both texts with min_count=1,
    // and both embedding matrices must be initialized with matching shape.
    #[test]
    fn test_build_vocabulary() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new().with_min_count(1);
        let result = model.build_vocabulary(&texts);
        assert!(result.is_ok());

        assert_eq!(model.vocabulary.len(), 9);

        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(
            model
                .input_embeddings
                .as_ref()
                .expect("Operation failed")
                .shape(),
            &[9, 100]
        );
    }

    // Smoke test: one skip-gram epoch on a tiny corpus trains without error
    // and yields a vector of the configured dimensionality.
    #[test]
    fn test_skipgram_training_small() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        let result = model.train(&texts);
        assert!(result.is_ok());

        let result = model.get_word_vector("fox");
        assert!(result.is_ok());
        let vec = result.expect("Operation failed");
        assert_eq!(vec.len(), 10);
    }
}