use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;

/// Weighted sampling table used for negative sampling and word-frequency lookups.
#[derive(Debug, Clone)]
struct SamplingTable {
    /// Cumulative distribution over the normalized weights.
    cdf: Vec<f64>,
    /// Raw sampling weights, kept for frequency lookups.
    weights: Vec<f64>,
}

impl SamplingTable {
    /// Builds a sampling table from non-negative weights.
    fn new(weights: &[f64]) -> Result<Self> {
        if weights.is_empty() {
            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
        }

        if weights.iter().any(|&w| w < 0.0) {
            return Err(TextError::EmbeddingError("Weights must be non-negative".into()));
        }

        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TextError::EmbeddingError(
                "Sum of weights must be positive".into(),
            ));
        }

        let mut cdf = Vec::with_capacity(weights.len());
        let mut total = 0.0;

        for &w in weights {
            total += w;
            cdf.push(total / sum);
        }

        Ok(Self {
            cdf,
            weights: weights.to_vec(),
        })
    }

    /// Draws a random index with probability proportional to its weight.
    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
        let r = rng.random::<f64>();

        match self.cdf.binary_search_by(|&cdf_val| {
            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
        }) {
            Ok(idx) => idx,
            Err(idx) => idx,
        }
    }

    /// Returns the raw sampling weights.
    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

/// Training algorithm for Word2Vec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous bag-of-words: predict the target word from its context.
    CBOW,
    /// Skip-gram: predict the context words from the target word.
    SkipGram,
}

/// Configuration for training a Word2Vec model.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the word vectors.
    pub vector_size: usize,
    /// Maximum context window size on each side of the target word.
    pub window_size: usize,
    /// Minimum word count required to enter the vocabulary.
    pub min_count: usize,
    /// Number of training epochs.
    pub epochs: usize,
    /// Initial learning rate.
    pub learning_rate: f64,
    /// Training algorithm (CBOW or skip-gram).
    pub algorithm: Word2VecAlgorithm,
    /// Number of negative samples per positive example.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words (0.0 disables subsampling).
    pub subsample: f64,
    /// Batch size used during training.
    pub batch_size: usize,
    /// Whether to use hierarchical softmax instead of negative sampling.
    pub hierarchical_softmax: bool,
}

impl Default for Word2VecConfig {
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}

/// Word2Vec word embedding model (CBOW or skip-gram with negative sampling).
pub struct Word2Vec {
    /// Training configuration.
    config: Word2VecConfig,
    /// Vocabulary built from the training corpus.
    vocabulary: Vocabulary,
    /// Input (word) embedding matrix, one row per vocabulary entry.
    input_embeddings: Option<Array2<f64>>,
    /// Output (context) embedding matrix, one row per vocabulary entry.
    output_embeddings: Option<Array2<f64>>,
    /// Tokenizer used to split texts into words.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Table used to draw negative samples by word frequency.
    sampling_table: Option<SamplingTable>,
    /// Learning rate for the current epoch (decayed during training).
    current_learning_rate: f64,
}

impl Debug for Word2Vec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Word2Vec")
            .field("config", &self.config)
            .field("vocabulary", &self.vocabulary)
            .field("input_embeddings", &self.input_embeddings)
            .field("output_embeddings", &self.output_embeddings)
            .field("sampling_table", &self.sampling_table)
            .field("current_learning_rate", &self.current_learning_rate)
            .finish()
    }
}

impl Default for Word2Vec {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for Word2Vec {
    fn clone(&self) -> Self {
        // The boxed tokenizer cannot be cloned, so the cloned model falls
        // back to the default word tokenizer.
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}

impl Word2Vec {
    /// Creates a new Word2Vec model with the default configuration.
    pub fn new() -> Self {
        Self {
            config: Word2VecConfig::default(),
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            current_learning_rate: 0.025,
        }
    }

    /// Creates a new Word2Vec model with the given configuration.
    pub fn with_config(config: Word2VecConfig) -> Self {
        let learning_rate = config.learning_rate;
        Self {
            config,
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            current_learning_rate: learning_rate,
        }
    }

    /// Sets the tokenizer used to split texts into words.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }

    /// Sets the dimensionality of the word vectors.
    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
        self.config.vector_size = vectorsize;
        self
    }

    /// Sets the maximum context window size.
    pub fn with_window_size(mut self, windowsize: usize) -> Self {
        self.config.window_size = windowsize;
        self
    }

    /// Sets the minimum word count for the vocabulary.
    pub fn with_min_count(mut self, mincount: usize) -> Self {
        self.config.min_count = mincount;
        self
    }

    /// Sets the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Sets the initial learning rate.
    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
        self.config.learning_rate = learningrate;
        self.current_learning_rate = learningrate;
        self
    }

    /// Sets the training algorithm (CBOW or skip-gram).
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Sets the number of negative samples per positive example.
    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
        self.config.negative_samples = negativesamples;
        self
    }

    /// Sets the subsampling threshold for frequent words.
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }

    /// Sets the training batch size.
    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
        self.config.batch_size = batchsize;
        self
    }

    /// Builds the vocabulary from raw texts, initializes the embedding
    /// matrices, and prepares the negative-sampling table.
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count word occurrences across all texts.
        let mut word_counts = HashMap::new();
        let mut _total_words = 0;

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
                _total_words += 1;
            }
        }

        // Keep only words that meet the minimum count threshold.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;

        // Initialize both embedding matrices with small random values.
        let mut rng = scirs2_core::random::rng();
        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });
        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.input_embeddings = Some(input_embeddings);
        self.output_embeddings = Some(output_embeddings);

        self.create_sampling_table(&word_counts)?;

        Ok(())
    }

    /// Builds the negative-sampling table using the unigram distribution
    /// raised to the 3/4 power.
    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
        let mut sampling_weights = vec![0.0; self.vocabulary.len()];

        for (word, &count) in wordcounts.iter() {
            if let Some(idx) = self.vocabulary.get_index(word) {
                sampling_weights[idx] = (count as f64).powf(0.75);
            }
        }

        match SamplingTable::new(&sampling_weights) {
            Ok(table) => {
                self.sampling_table = Some(table);
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    /// Trains the model on the given texts, building the vocabulary first
    /// if it has not been built yet.
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Embeddings not initialized. Call build_vocabulary() first".into(),
            ));
        }

        // Convert each text into a sentence of vocabulary indices, dropping
        // out-of-vocabulary tokens.
        let mut _total_tokens = 0;
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let filtered_tokens: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !filtered_tokens.is_empty() {
                _total_tokens += filtered_tokens.len();
                sentences.push(filtered_tokens);
            }
        }

        for epoch in 0..self.config.epochs {
            // Linearly decay the learning rate over epochs, with a small floor.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                let subsampled_sentence = if self.config.subsample > 0.0 {
                    self.subsample_sentence(sentence)?
                } else {
                    sentence.clone()
                };

                if subsampled_sentence.is_empty() {
                    continue;
                }

                match self.config.algorithm {
                    Word2VecAlgorithm::CBOW => {
                        self.train_cbow_sentence(&subsampled_sentence)?;
                    }
                    Word2VecAlgorithm::SkipGram => {
                        self.train_skipgram_sentence(&subsampled_sentence)?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Randomly drops frequent words from a sentence, following the
    /// subsampling scheme from the original word2vec paper.
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    return true;
                }
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }

    /// Returns the sampling weight recorded for a word, or 1.0 if the
    /// sampling table has not been built.
    fn get_word_frequency(&self, wordidx: usize) -> f64 {
        if let Some(table) = &self.sampling_table {
            table.weights()[wordidx]
        } else {
            1.0
        }
    }

    /// Trains one sentence with the CBOW objective and negative sampling.
    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().unwrap();
        let output_embeddings = self.output_embeddings.as_mut().unwrap();
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            let mut rng = scirs2_core::random::rng();
            // Use a randomly shrunk window, as in the original word2vec.
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect the context words around the target position.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue;
            }

            // Average the input vectors of the context words.
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive update for the target word's output vector.
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative updates for randomly sampled words.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue;
                    }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Propagate the error back to each context word's input vector.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }

    /// Trains one sentence with the skip-gram objective and negative sampling.
    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().unwrap();
        let output_embeddings = self.output_embeddings.as_mut().unwrap();
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            let mut rng = scirs2_core::random::rng();
            // Use a randomly shrunk window, as in the original word2vec.
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive update for the context word's output vector.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the update for the target word's input vector.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                // Negative updates for randomly sampled words.
                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue;
                        }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated update to the target word's input vector.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }

    /// Returns the configured vector size.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Returns the embedding vector for a word, if it is in the vocabulary.
    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        match self.vocabulary.get_index(word) {
            Some(idx) => Ok(self.input_embeddings.as_ref().unwrap().row(idx).to_owned()),
            None => Err(TextError::VocabularyError(format!(
                "Word '{word}' not in vocabulary"
            ))),
        }
    }

    /// Returns the `topn` words most similar to the given word.
    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, topn, &[word])
    }

    /// Returns the `top_n` words whose embeddings are most similar to the
    /// given vector, skipping the words in `exclude_words`.
    pub fn most_similar_by_vector(
        &self,
        vector: &Array1<f64>,
        top_n: usize,
        exclude_words: &[&str],
    ) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let input_embeddings = self.input_embeddings.as_ref().unwrap();
        let vocab_size = self.vocabulary.len();

        let exclude_indices: Vec<usize> = exclude_words
            .iter()
            .filter_map(|&word| self.vocabulary.get_index(word))
            .collect();

        let mut similarities = Vec::with_capacity(vocab_size);

        for i in 0..vocab_size {
            if exclude_indices.contains(&i) {
                continue;
            }

            let word_vec = input_embeddings.row(i);
            let similarity = cosine_similarity(vector, &word_vec.to_owned());

            if let Some(word) = self.vocabulary.get_token(i) {
                similarities.push((word.to_string(), similarity));
            }
        }

        // Sort by similarity in descending order and keep the top results.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let result = similarities.into_iter().take(top_n).collect();
        Ok(result)
    }

    /// Solves the analogy "a is to b as c is to ?" and returns the `topn`
    /// candidate words.
    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let a_vec = self.get_word_vector(a)?;
        let b_vec = self.get_word_vector(b)?;
        let c_vec = self.get_word_vector(c)?;

        // d = b - a + c, normalized to unit length.
        let mut d_vec = b_vec.clone();
        d_vec -= &a_vec;
        d_vec += &c_vec;

        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
        d_vec.mapv_inplace(|val| val / norm);

        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
    }

    /// Saves the model in the word2vec text format: a header line with the
    /// vocabulary size and vector size, followed by one word and its vector
    /// per line.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;

        writeln!(
            &mut file,
            "{} {}",
            self.vocabulary.len(),
            self.config.vector_size
        )
        .map_err(|e| TextError::IoError(e.to_string()))?;

        let input_embeddings = self.input_embeddings.as_ref().unwrap();

        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;

                let vector = input_embeddings.row(i);
                for j in 0..self.config.vector_size {
                    write!(&mut file, "{:.6} ", vector[j])
                        .map_err(|e| TextError::IoError(e.to_string()))?;
                }

                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
            }
        }

        Ok(())
    }

    /// Loads a model previously written by `save`. Only the input embeddings
    /// are restored; the output embeddings are left unset.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut reader = BufReader::new(file);

        // Parse the "<vocab_size> <vector_size>" header line.
        let mut header = String::new();
        reader
            .read_line(&mut header)
            .map_err(|e| TextError::IoError(e.to_string()))?;

        let parts: Vec<&str> = header.split_whitespace().collect();
        if parts.len() != 2 {
            return Err(TextError::EmbeddingError(
                "Invalid model file format".into(),
            ));
        }

        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
        })?;

        let vector_size = parts[1]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;

        let mut model = Self::new().with_vector_size(vector_size);
        let mut vocabulary = Vocabulary::new();
        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));

        // Parse one word and its vector per line.
        let mut i = 0;
        for line in reader.lines() {
            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
            let parts: Vec<&str> = line.split_whitespace().collect();

            if parts.len() != vector_size + 1 {
                let line_num = i + 2;
                return Err(TextError::EmbeddingError(format!(
                    "Invalid vector format at line {line_num}"
                )));
            }

            let word = parts[0];
            vocabulary.add_token(word);

            for j in 0..vector_size {
                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
                    TextError::EmbeddingError(format!(
                        "Invalid vector component at line {}, position {}",
                        i + 2,
                        j + 1
                    ))
                })?;
            }

            i += 1;
        }

        if i != vocab_size {
            return Err(TextError::EmbeddingError(format!(
                "Expected {vocab_size} words but found {i}"
            )));
        }

        model.vocabulary = vocabulary;
        model.input_embeddings = Some(input_embeddings);
        model.output_embeddings = None;

        Ok(model)
    }

    /// Returns all vocabulary tokens in index order.
    pub fn get_vocabulary(&self) -> Vec<String> {
        let mut vocab = Vec::new();
        for i in 0..self.vocabulary.len() {
            if let Some(token) = self.vocabulary.get_token(i) {
                vocab.push(token.to_string());
            }
        }
        vocab
    }

    /// Returns the configured vector size.
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Returns the configured training algorithm.
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }

    /// Returns the configured window size.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }

    /// Returns the configured minimum word count.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }

    /// Returns a copy of the input embedding matrix, if initialized.
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }

    /// Returns the configured number of negative samples.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }

    /// Returns the configured initial learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Returns the configured number of epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }

    /// Returns the configured subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }
}

/// Computes the cosine similarity between two vectors, returning 0.0 if
/// either vector has zero norm.
#[allow(dead_code)]
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let dot_product = (a * b).sum();
    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot_product / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let similarity = cosine_similarity(&a, &b);
        let expected = 0.9746318461970762;
        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
    }
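
    // A minimal sketch of the degenerate cases for cosine_similarity:
    // identical vectors should score 1.0, and a zero-norm vector falls back
    // to 0.0 per the guard in the implementation.
    #[test]
    fn test_cosine_similarity_edge_cases() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert_relative_eq!(cosine_similarity(&a, &a), 1.0, max_relative = 1e-10);

        let zero = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        assert_eq!(cosine_similarity(&a, &zero), 0.0);
    }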

    #[test]
    fn test_word2vec_config() {
        let config = Word2VecConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.min_count, 5);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
    }
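
    // A small sketch exercising the private SamplingTable helper directly:
    // empty and negative weights are rejected, and sampled indices stay
    // within the table bounds.
    #[test]
    fn test_sampling_table() {
        assert!(SamplingTable::new(&[]).is_err());
        assert!(SamplingTable::new(&[1.0, -1.0]).is_err());

        let table = SamplingTable::new(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(table.weights(), &[1.0, 2.0, 3.0]);

        let mut rng = scirs2_core::random::rng();
        for _ in 0..100 {
            assert!(table.sample(&mut rng) < 3);
        }
    }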

    #[test]
    fn test_word2vec_builder() {
        let model = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(model.config.vector_size, 200);
        assert_eq!(model.config.window_size, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
    }
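
    // A sketch checking that with_config wires the full configuration
    // through to the public getter methods.
    #[test]
    fn test_word2vec_with_config() {
        let config = Word2VecConfig {
            vector_size: 50,
            window_size: 3,
            min_count: 2,
            epochs: 2,
            learning_rate: 0.01,
            algorithm: Word2VecAlgorithm::CBOW,
            negative_samples: 3,
            subsample: 1e-4,
            batch_size: 64,
            hierarchical_softmax: false,
        };

        let model = Word2Vec::with_config(config);
        assert_eq!(model.get_vector_size(), 50);
        assert_eq!(model.get_window_size(), 3);
        assert_eq!(model.get_min_count(), 2);
        assert_eq!(model.get_epochs(), 2);
        assert_eq!(model.get_learning_rate(), 0.01);
        assert_eq!(model.get_algorithm(), Word2VecAlgorithm::CBOW);
        assert_eq!(model.get_negative_samples(), 3);
        assert_eq!(model.get_subsampling_threshold(), 1e-4);
    }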

    #[test]
    fn test_build_vocabulary() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new().with_min_count(1);
        let result = model.build_vocabulary(&texts);
        assert!(result.is_ok());

        assert_eq!(model.vocabulary.len(), 9);

        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(model.input_embeddings.as_ref().unwrap().shape(), &[9, 100]);
    }
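
    // A sketch of the min_count filter, assuming the default WordTokenizer
    // splits on whitespace as in the tests above: only "apple" occurs at
    // least twice, so it should be the only vocabulary entry.
    #[test]
    fn test_build_vocabulary_min_count_filter() {
        let texts = ["apple banana apple", "apple cherry"];

        let mut model = Word2Vec::new().with_min_count(2);
        assert!(model.build_vocabulary(&texts).is_ok());

        assert_eq!(model.vocabulary.len(), 1);
        assert!(model.vocabulary.get_index("apple").is_some());
        assert!(model.vocabulary.get_index("banana").is_none());
    }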

    #[test]
    fn test_skipgram_training_small() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        let result = model.train(&texts);
        assert!(result.is_ok());

        let result = model.get_word_vector("fox");
        assert!(result.is_ok());
        let vec = result.unwrap();
        assert_eq!(vec.len(), 10);
    }
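
    // A minimal CBOW counterpart to the skip-gram test above, plus a check
    // that most_similar caps the result at top_n and never returns the
    // query word itself.
    #[test]
    fn test_cbow_training_and_most_similar() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert!(model.train(&texts).is_ok());

        let neighbours = model.most_similar("fox", 3).unwrap();
        assert!(neighbours.len() <= 3);
        assert!(neighbours.iter().all(|(word, _)| word.as_str() != "fox"));
    }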
}