use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;

/// Cumulative-distribution table used to draw negative samples in proportion to word weight.
#[derive(Debug, Clone)]
struct SamplingTable {
    /// Cumulative distribution over the weights, normalized to end at 1.0.
    cdf: Vec<f64>,
    /// The raw (unnormalized) weights, kept for frequency lookups.
    weights: Vec<f64>,
}

impl SamplingTable {
    /// Builds a sampling table from non-negative weights.
    fn new(weights: &[f64]) -> Result<Self> {
        if weights.is_empty() {
            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
        }

        if weights.iter().any(|&w| w < 0.0) {
            return Err(TextError::EmbeddingError(
                "Weights must be non-negative".into(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TextError::EmbeddingError(
                "Sum of weights must be positive".into(),
            ));
        }

        let mut cdf = Vec::with_capacity(weights.len());
        let mut total = 0.0;

        for &w in weights {
            total += w;
            cdf.push(total / sum);
        }

        Ok(Self {
            cdf,
            weights: weights.to_vec(),
        })
    }

    /// Draws an index with probability proportional to its weight.
    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
        let r = rng.random::<f64>();

        // Binary search for the first CDF entry >= r; clamp the insertion point
        // so floating-point rounding cannot produce an out-of-bounds index.
        let idx = match self.cdf.binary_search_by(|&cdf_val| {
            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
        }) {
            Ok(idx) => idx,
            Err(idx) => idx,
        };
        idx.min(self.cdf.len() - 1)
    }

    /// Returns the raw weights backing the table.
    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

/// Training algorithm used by Word2Vec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous Bag of Words: predicts the target word from its context.
    CBOW,
    /// Skip-gram: predicts the context words from the target word.
    SkipGram,
}

/// Configuration for training a Word2Vec model.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the word vectors.
    pub vector_size: usize,
    /// Maximum distance between the target word and a context word.
    pub window_size: usize,
    /// Minimum number of occurrences for a word to be kept in the vocabulary.
    pub min_count: usize,
    /// Number of passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate (decayed linearly over epochs).
    pub learning_rate: f64,
    /// Training algorithm (CBOW or skip-gram).
    pub algorithm: Word2VecAlgorithm,
    /// Number of negative samples drawn per positive example.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words (0.0 disables subsampling).
    pub subsample: f64,
    /// Number of examples per training batch.
    pub batch_size: usize,
    /// Whether to use hierarchical softmax instead of negative sampling.
    pub hierarchical_softmax: bool,
}

impl Default for Word2VecConfig {
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}

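/// Word2Vec model that learns dense word embeddings with either the CBOW or
/// skip-gram objective, trained with negative sampling.
///
/// The example below is an illustrative sketch rather than a compiled doctest;
/// the crate and module path in the `use` line are assumed and may differ.
///
/// ```ignore
/// use scirs2_text::embeddings::Word2Vec;
///
/// let texts = ["the quick brown fox jumps over the lazy dog"];
/// let mut model = Word2Vec::new()
///     .with_vector_size(50)
///     .with_window_size(3)
///     .with_min_count(1)
///     .with_epochs(5);
/// model.train(&texts).expect("training succeeded");
///
/// let vector = model.get_word_vector("fox").expect("word in vocabulary");
/// println!("fox -> {} dims", vector.len());
///
/// for (word, score) in model.most_similar("fox", 5).expect("model trained") {
///     println!("{word}: {score:.3}");
/// }
/// ```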
pub struct Word2Vec {
    /// Training configuration.
    config: Word2VecConfig,
    /// Vocabulary built from the training corpus.
    vocabulary: Vocabulary,
    /// Input (word) embedding matrix, one row per vocabulary entry.
    input_embeddings: Option<Array2<f64>>,
    /// Output (context) embedding matrix, one row per vocabulary entry.
    output_embeddings: Option<Array2<f64>>,
    /// Tokenizer used to split texts into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Table for drawing negative samples proportional to count^0.75.
    sampling_table: Option<SamplingTable>,
    /// Learning rate for the current epoch (decayed during training).
    current_learning_rate: f64,
}

impl Debug for Word2Vec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Word2Vec")
            .field("config", &self.config)
            .field("vocabulary", &self.vocabulary)
            .field("input_embeddings", &self.input_embeddings)
            .field("output_embeddings", &self.output_embeddings)
            .field("sampling_table", &self.sampling_table)
            .field("current_learning_rate", &self.current_learning_rate)
            .finish()
    }
}

impl Default for Word2Vec {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for Word2Vec {
    fn clone(&self) -> Self {
        // Trait objects cannot be cloned, so the clone falls back to a default
        // word tokenizer.
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}

impl Word2Vec {
    /// Creates a new model with the default configuration.
    pub fn new() -> Self {
        Self {
            config: Word2VecConfig::default(),
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            current_learning_rate: 0.025,
        }
    }

    /// Creates a new model from an explicit configuration.
    pub fn with_config(config: Word2VecConfig) -> Self {
        let learning_rate = config.learning_rate;
        Self {
            config,
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            current_learning_rate: learning_rate,
        }
    }

    /// Sets the tokenizer used to split texts into tokens.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }

    /// Sets the dimensionality of the word vectors.
    pub fn with_vector_size(mut self, vector_size: usize) -> Self {
        self.config.vector_size = vector_size;
        self
    }

    /// Sets the maximum context window size.
    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.config.window_size = window_size;
        self
    }

    /// Sets the minimum word count required for inclusion in the vocabulary.
    pub fn with_min_count(mut self, min_count: usize) -> Self {
        self.config.min_count = min_count;
        self
    }

    /// Sets the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Sets the initial learning rate.
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.config.learning_rate = learning_rate;
        self.current_learning_rate = learning_rate;
        self
    }

    /// Sets the training algorithm (CBOW or skip-gram).
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Sets the number of negative samples per positive example.
    pub fn with_negative_samples(mut self, negative_samples: usize) -> Self {
        self.config.negative_samples = negative_samples;
        self
    }

    /// Sets the subsampling threshold for frequent words.
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }

    /// Sets the training batch size.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.config.batch_size = batch_size;
        self
    }

    /// Builds the vocabulary from the given texts and initializes the embedding
    /// matrices and the negative-sampling table.
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count word occurrences across all texts.
        let mut word_counts = HashMap::new();
        let mut _total_words = 0;

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
                _total_words += 1;
            }
        }

        // Keep only words that meet the minimum count threshold.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;

        // Initialize both embedding matrices with small random values.
        let mut rng = scirs2_core::random::rng();
        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });
        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.input_embeddings = Some(input_embeddings);
        self.output_embeddings = Some(output_embeddings);

        self.create_sampling_table(&word_counts)?;

        Ok(())
    }

    /// Builds the negative-sampling table using the unigram distribution
    /// raised to the 3/4 power.
    fn create_sampling_table(&mut self, word_counts: &HashMap<String, usize>) -> Result<()> {
        let mut sampling_weights = vec![0.0; self.vocabulary.len()];

        for (word, &count) in word_counts.iter() {
            if let Some(idx) = self.vocabulary.get_index(word) {
                sampling_weights[idx] = (count as f64).powf(0.75);
            }
        }

        match SamplingTable::new(&sampling_weights) {
            Ok(table) => {
                self.sampling_table = Some(table);
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    /// Trains the model on the given texts, building the vocabulary first if
    /// it has not been built yet.
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Embeddings not initialized. Call build_vocabulary() first".into(),
            ));
        }

        // Convert texts into sentences of vocabulary indices, dropping
        // out-of-vocabulary tokens.
        let mut _total_tokens = 0;
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let filtered_tokens: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !filtered_tokens.is_empty() {
                _total_tokens += filtered_tokens.len();
                sentences.push(filtered_tokens);
            }
        }

        for epoch in 0..self.config.epochs {
            // Decay the learning rate linearly over epochs, keeping a small floor.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                // Optionally drop frequent words via subsampling.
                let subsampled_sentence = if self.config.subsample > 0.0 {
                    self.subsample_sentence(sentence)?
                } else {
                    sentence.clone()
                };

                if subsampled_sentence.is_empty() {
                    continue;
                }

                match self.config.algorithm {
                    Word2VecAlgorithm::CBOW => {
                        self.train_cbow_sentence(&subsampled_sentence)?;
                    }
                    Word2VecAlgorithm::SkipGram => {
                        self.train_skipgram_sentence(&subsampled_sentence)?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Randomly drops frequent words from a sentence using the word2vec
    /// subsampling heuristic.
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    return true;
                }
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }

    /// Returns the sampling weight of a word, or 1.0 if no table has been built.
    fn get_word_frequency(&self, word_idx: usize) -> f64 {
        if let Some(table) = &self.sampling_table {
            table.weights()[word_idx]
        } else {
            1.0
        }
    }

    /// Trains one sentence with the CBOW objective and negative sampling.
    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self
            .input_embeddings
            .as_mut()
            .expect("input embeddings initialized");
        let output_embeddings = self
            .output_embeddings
            .as_mut()
            .expect("output embeddings initialized");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Draw a random effective window size in [1, window_size].
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect the context words around the target position.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue;
            }

            // Average the input vectors of the context words.
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive example: pull the target's output vector towards the context average.
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative examples: push sampled words' output vectors away.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue;
                    }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Update the input vectors of the context words.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }

    /// Trains one sentence with the skip-gram objective and negative sampling.
    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self
            .input_embeddings
            .as_mut()
            .expect("input embeddings initialized");
        let output_embeddings = self
            .output_embeddings
            .as_mut()
            .expect("output embeddings initialized");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Draw a random effective window size in [1, window_size].
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive example: pull the context word's output vector
                // towards the target word's input vector.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the gradient for the target word's input vector.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                // Negative examples: push sampled words' output vectors away.
                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue;
                        }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated update to the target word's input vector.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }

    /// Returns the dimensionality of the word vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Returns the embedding vector for a word, if the model has been trained
    /// and the word is in the vocabulary.
    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        match self.vocabulary.get_index(word) {
            Some(idx) => Ok(self
                .input_embeddings
                .as_ref()
                .expect("input embeddings initialized")
                .row(idx)
                .to_owned()),
            None => Err(TextError::VocabularyError(format!(
                "Word '{word}' not in vocabulary"
            ))),
        }
    }

    /// Returns the `top_n` words most similar to the given word by cosine similarity.
    pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, top_n, &[word])
    }

    /// Returns the `top_n` words most similar to the given vector by cosine
    /// similarity, excluding the listed words.
    pub fn most_similar_by_vector(
        &self,
        vector: &Array1<f64>,
        top_n: usize,
        exclude_words: &[&str],
    ) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let input_embeddings = self
            .input_embeddings
            .as_ref()
            .expect("input embeddings initialized");
        let vocab_size = self.vocabulary.len();

        let exclude_indices: Vec<usize> = exclude_words
            .iter()
            .filter_map(|&word| self.vocabulary.get_index(word))
            .collect();

        // Compute the cosine similarity against every vocabulary word.
        let mut similarities = Vec::with_capacity(vocab_size);

        for i in 0..vocab_size {
            if exclude_indices.contains(&i) {
                continue;
            }

            let word_vec = input_embeddings.row(i);
            let similarity = cosine_similarity(vector, &word_vec.to_owned());

            if let Some(word) = self.vocabulary.get_token(i) {
                similarities.push((word.to_string(), similarity));
            }
        }

        // Sort by similarity in descending order and keep the top results.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let result = similarities.into_iter().take(top_n).collect();
        Ok(result)
    }

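    /// Solves analogies of the form "`a` is to `b` as `c` is to ?" by searching
    /// for words near the vector `b - a + c`.
    ///
    /// The snippet below is an illustrative sketch (marked `ignore`, not a
    /// compiled doctest); it assumes `model` has already been trained on a
    /// corpus large enough for analogies to be meaningful.
    ///
    /// ```ignore
    /// // "man" is to "king" as "woman" is to ? (ideally "queen")
    /// let candidates = model.analogy("man", "king", "woman", 5).expect("words in vocabulary");
    /// for (word, score) in candidates {
    ///     println!("{word}: {score:.3}");
    /// }
    /// ```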
    pub fn analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let a_vec = self.get_word_vector(a)?;
        let b_vec = self.get_word_vector(b)?;
        let c_vec = self.get_word_vector(c)?;

        // d = b - a + c
        let mut d_vec = b_vec.clone();
        d_vec -= &a_vec;
        d_vec += &c_vec;

        // Normalize the query vector (skipped if it is all zeros).
        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
        if norm > 0.0 {
            d_vec.mapv_inplace(|val| val / norm);
        }

        self.most_similar_by_vector(&d_vec, top_n, &[a, b, c])
    }

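    /// Saves the model in a simple text format: a header line with the
    /// vocabulary size and vector size, followed by one `word v1 v2 ...` line
    /// per vocabulary entry.
    ///
    /// A minimal round-trip sketch (marked `ignore`, not a compiled doctest;
    /// the file name is arbitrary):
    ///
    /// ```ignore
    /// model.save("vectors.txt").expect("model saved");
    /// let reloaded = Word2Vec::load("vectors.txt").expect("model loaded");
    /// assert_eq!(reloaded.get_vector_size(), model.get_vector_size());
    /// ```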
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;

        // Header: vocabulary size and vector size.
        writeln!(
            &mut file,
            "{} {}",
            self.vocabulary.len(),
            self.config.vector_size
        )
        .map_err(|e| TextError::IoError(e.to_string()))?;

        let input_embeddings = self
            .input_embeddings
            .as_ref()
            .expect("input embeddings initialized");

        // One line per word: the word followed by its vector components.
        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;

                let vector = input_embeddings.row(i);
                for j in 0..self.config.vector_size {
                    write!(&mut file, "{:.6} ", vector[j])
                        .map_err(|e| TextError::IoError(e.to_string()))?;
                }

                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
            }
        }

        Ok(())
    }

    /// Loads a model previously saved with [`Self::save`].
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut reader = BufReader::new(file);

        // Read the header line containing the vocabulary size and vector size.
        let mut header = String::new();
        reader
            .read_line(&mut header)
            .map_err(|e| TextError::IoError(e.to_string()))?;

        let parts: Vec<&str> = header.split_whitespace().collect();
        if parts.len() != 2 {
            return Err(TextError::EmbeddingError(
                "Invalid model file format".into(),
            ));
        }

        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
        })?;

        let vector_size = parts[1]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;

        let mut model = Self::new().with_vector_size(vector_size);
        let mut vocabulary = Vocabulary::new();
        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));

        // Parse one word vector per line.
        let mut i = 0;
        for line in reader.lines() {
            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
            let parts: Vec<&str> = line.split_whitespace().collect();

            if parts.len() != vector_size + 1 {
                let line_num = i + 2;
                return Err(TextError::EmbeddingError(format!(
                    "Invalid vector format at line {line_num}"
                )));
            }

            let word = parts[0];
            vocabulary.add_token(word);

            for j in 0..vector_size {
                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
                    TextError::EmbeddingError(format!(
                        "Invalid vector component at line {}, position {}",
                        i + 2,
                        j + 1
                    ))
                })?;
            }

            i += 1;
        }

        if i != vocab_size {
            return Err(TextError::EmbeddingError(format!(
                "Expected {vocab_size} words but found {i}"
            )));
        }

        model.vocabulary = vocabulary;
        model.input_embeddings = Some(input_embeddings);
        // Output embeddings are not stored in the file, so the loaded model
        // cannot be trained further without rebuilding them.
        model.output_embeddings = None;

        Ok(model)
    }

    /// Returns all vocabulary words in index order.
    pub fn get_vocabulary(&self) -> Vec<String> {
        let mut vocab = Vec::new();
        for i in 0..self.vocabulary.len() {
            if let Some(token) = self.vocabulary.get_token(i) {
                vocab.push(token.to_string());
            }
        }
        vocab
    }

    /// Returns the dimensionality of the word vectors.
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Returns the configured training algorithm.
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }

    /// Returns the configured window size.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }

    /// Returns the configured minimum word count.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }

    /// Returns a copy of the input embedding matrix, if trained.
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }

    /// Returns the configured number of negative samples.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }

    /// Returns the configured initial learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Returns the configured number of epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }

    /// Returns the configured subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }
}

/// Computes the cosine similarity between two vectors, returning 0.0 if either
/// vector has zero norm.
#[allow(dead_code)]
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let dot_product = (a * b).sum();
    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot_product / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let similarity = cosine_similarity(&a, &b);
        let expected = 0.9746318461970762;
        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
    }

    #[test]
    fn test_word2vec_config() {
        let config = Word2VecConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.min_count, 5);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
    }

    #[test]
    fn test_word2vec_builder() {
        let model = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(model.config.vector_size, 200);
        assert_eq!(model.config.window_size, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    #[test]
    fn test_build_vocabulary() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new().with_min_count(1);
        let result = model.build_vocabulary(&texts);
        assert!(result.is_ok());

        // Unique words: the, quick, brown, fox, jumps, over, lazy, dog, a.
        assert_eq!(model.vocabulary.len(), 9);

        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(
            model
                .input_embeddings
                .as_ref()
                .expect("input embeddings initialized")
                .shape(),
            &[9, 100]
        );
    }

    #[test]
    fn test_skipgram_training_small() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        let result = model.train(&texts);
        assert!(result.is_ok());

        let result = model.get_word_vector("fox");
        assert!(result.is_ok());
        let vec = result.expect("word vector available");
        assert_eq!(vec.len(), 10);
    }
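
    // Illustrative smoke test added alongside the skip-gram test above: the
    // same tiny corpus trained with the CBOW objective. It only checks that
    // training succeeds and that vectors have the configured dimensionality.
    #[test]
    fn test_cbow_training_small() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        let result = model.train(&texts);
        assert!(result.is_ok());

        let vec = model.get_word_vector("dog").expect("word vector available");
        assert_eq!(vec.len(), 10);
    }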
}