scirs2_text/
embeddings.rs

1//! # Word Embeddings Module
2//!
3//! This module provides comprehensive implementations for word embeddings, including
4//! Word2Vec with both Skip-gram and CBOW (Continuous Bag of Words) models.
5//!
6//! ## Overview
7//!
8//! Word embeddings are dense vector representations of words that capture semantic
9//! relationships. This module implements:
10//!
11//! - **Word2Vec Skip-gram**: Predicts context words from a target word
12//! - **Word2Vec CBOW**: Predicts a target word from context words
13//! - **Negative sampling**: Efficient training technique for large vocabularies
//! - **Hierarchical softmax**: Configuration flag for an alternative to negative sampling (not yet implemented)
15//!
16//! ## Quick Start
17//!
18//! ```rust
19//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
20//!
21//! // Basic Word2Vec training
22//! let documents = vec![
23//!     "the quick brown fox jumps over the lazy dog",
24//!     "the dog was lazy but the fox was quick",
25//!     "brown fox and lazy dog are common phrases"
26//! ];
27//!
28//! let config = Word2VecConfig {
29//!     algorithm: Word2VecAlgorithm::SkipGram,
30//!     vector_size: 100,
31//!     window_size: 5,
32//!     min_count: 1,
33//!     learning_rate: 0.025,
34//!     epochs: 5,
35//!     negative_samples: 5,
36//!     ..Default::default()
37//! };
38//!
39//! let mut model = Word2Vec::with_config(config);
40//! model.train(&documents).expect("Training failed");
41//!
42//! // Get word vector
43//! if let Ok(vector) = model.get_word_vector("fox") {
44//!     println!("Vector for 'fox': {:?}", vector);
45//! }
46//!
47//! // Find similar words
48//! let similar = model.most_similar("fox", 3).expect("Failed to find similar words");
49//! for (word, similarity) in similar {
50//!     println!("{}: {:.4}", word, similarity);
51//! }
52//! ```
53//!
54//! ## Advanced Usage
55//!
56//! ### Custom Configuration
57//!
58//! ```rust
59//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
60//!
61//! let config = Word2VecConfig {
62//!     algorithm: Word2VecAlgorithm::CBOW,
63//!     vector_size: 300,        // Larger vectors for better quality
64//!     window_size: 10,         // Larger context window
65//!     min_count: 5,           // Filter rare words
66//!     learning_rate: 0.01,    // Lower learning rate for stability
67//!     epochs: 15,             // More training iterations
68//!     negative_samples: 10,   // More negative samples
69//!     subsample: 1e-4, // Subsample frequent words
70//!     batch_size: 128,
71//!     hierarchical_softmax: false,
72//! };
73//! ```
74//!
75//! ### Incremental Training
76//!
77//! ```rust
78//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
79//! # let mut model = Word2Vec::new().with_min_count(1);
80//! // Initial training
81//! let batch1 = vec!["the quick brown fox jumps over the lazy dog"];
82//! model.train(&batch1).expect("Training failed");
83//!
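//! // Continue training with new data. The vocabulary built by the first call
//! // is reused, so words that did not appear in `batch1` are ignored.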
85//! let batch2 = vec!["the dog was lazy but the fox was quick"];
86//! model.train(&batch2).expect("Training failed");
87//! ```
88//!
89//! ### Saving and Loading Models
90//!
91//! ```rust
92//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
93//! # let mut model = Word2Vec::new().with_min_count(1);
94//! # let texts = vec!["the quick brown fox jumps over the lazy dog"];
95//! # model.train(&texts).expect("Training failed");
96//! // Save trained model
97//! model.save("my_model.w2v").expect("Failed to save model");
98//!
99//! // Load model
100//! let loaded_model = Word2Vec::load("my_model.w2v")
101//!     .expect("Failed to load model");
102//! ```
103//!
104//! ## Performance Tips
105//!
106//! 1. **Vocabulary Size**: Use `min_count` to filter rare words and reduce memory usage
107//! 2. **Vector Dimensions**: Balance between quality (higher dimensions) and speed (lower dimensions)
108//! 3. **Training Algorithm**: Skip-gram works better with rare words, CBOW is faster
109//! 4. **Negative Sampling**: Usually faster than hierarchical softmax for large vocabularies
//! 5. **Subsampling**: Set `subsample` (a frequency threshold; `0.0` disables it) to randomly drop very frequent words
111//!
112//! ## Mathematical Background
113//!
114//! ### Skip-gram Model
115//!
116//! The Skip-gram model maximizes the probability of context words given a target word:
117//!
118//! P(context|target) = ∏ P(w_context|w_target)
119//!
120//! ### CBOW Model
121//!
122//! The CBOW model predicts a target word from its context:
123//!
124//! P(target|context) = P(w_target|w_context1, w_context2, ...)
125//!
126//! ### Negative Sampling
127//!
128//! Instead of computing the full softmax, negative sampling approximates it by
129//! sampling negative examples, making training much more efficient.
130
131use crate::error::{Result, TextError};
132use crate::tokenize::{Tokenizer, WordTokenizer};
133use crate::vocabulary::Vocabulary;
134use scirs2_core::ndarray::{Array1, Array2};
135use scirs2_core::random::prelude::*;
136use std::collections::HashMap;
137use std::fmt::Debug;
138use std::fs::File;
139use std::io::{BufRead, BufReader, Write};
140use std::path::Path;
141
142/// A simplified weighted sampling table
143#[derive(Debug, Clone)]
144struct SamplingTable {
145    /// The cumulative distribution function (CDF)
146    cdf: Vec<f64>,
147    /// The weights
148    weights: Vec<f64>,
149}
150
151impl SamplingTable {
152    /// Create a new sampling table from weights
153    fn new(weights: &[f64]) -> Result<Self> {
154        if weights.is_empty() {
155            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
156        }
157
158        // Ensure all _weights are positive
159        if weights.iter().any(|&w| w < 0.0) {
160            return Err(TextError::EmbeddingError("Weights must be positive".into()));
161        }
162
163        // Compute the CDF
164        let sum: f64 = weights.iter().sum();
165        if sum <= 0.0 {
166            return Err(TextError::EmbeddingError(
167                "Sum of _weights must be positive".into(),
168            ));
169        }
170
171        let mut cdf = Vec::with_capacity(weights.len());
172        let mut total = 0.0;
173
174        for &w in weights {
175            total += w;
176            cdf.push(total / sum);
177        }
178
179        Ok(Self {
180            cdf,
181            weights: weights.to_vec(),
182        })
183    }
184
185    /// Sample an index based on the weights
186    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
187        let r = rng.random::<f64>();
188
189        // Binary search for the insertion point
190        match self.cdf.binary_search_by(|&cdf_val| {
191            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
192        }) {
193            Ok(idx) => idx,
194            Err(idx) => idx,
195        }
196    }
197
198    /// Get the weights
199    fn weights(&self) -> &[f64] {
200        &self.weights
201    }
202}
203
/// The training objective used by a Word2Vec model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Word2VecAlgorithm {
    /// Continuous Bag of Words: predict the centre word from the averaged context vectors.
    CBOW,
    /// Skip-gram: predict each context word from the centre word.
    SkipGram,
}
212
213/// Word2Vec training options
214#[derive(Debug, Clone)]
215pub struct Word2VecConfig {
216    /// Size of the word vectors
217    pub vector_size: usize,
218    /// Maximum distance between the current and predicted word within a sentence
219    pub window_size: usize,
220    /// Minimum count of words to consider for training
221    pub min_count: usize,
222    /// Number of iterations (epochs) over the corpus
223    pub epochs: usize,
224    /// Learning rate
225    pub learning_rate: f64,
226    /// Skip-gram or CBOW algorithm
227    pub algorithm: Word2VecAlgorithm,
228    /// Number of negative samples per positive sample
229    pub negative_samples: usize,
230    /// Threshold for subsampling frequent words
231    pub subsample: f64,
232    /// Batch size for training
233    pub batch_size: usize,
234    /// Whether to use hierarchical softmax (not yet implemented)
235    pub hierarchical_softmax: bool,
236}
237
238impl Default for Word2VecConfig {
239    fn default() -> Self {
240        Self {
241            vector_size: 100,
242            window_size: 5,
243            min_count: 5,
244            epochs: 5,
245            learning_rate: 0.025,
246            algorithm: Word2VecAlgorithm::SkipGram,
247            negative_samples: 5,
248            subsample: 1e-3,
249            batch_size: 128,
250            hierarchical_softmax: false,
251        }
252    }
253}
254
255/// Word2Vec model for training and using word embeddings
256///
257/// Word2Vec is an algorithm for learning vector representations of words,
258/// also known as word embeddings. These vectors capture semantic meanings
259/// of words, allowing operations like "king - man + woman" to result in
260/// a vector close to "queen".
261///
262/// This implementation supports both Continuous Bag of Words (CBOW) and
263/// Skip-gram models, with negative sampling for efficient training.
264pub struct Word2Vec {
265    /// Configuration options
266    config: Word2VecConfig,
267    /// Vocabulary
268    vocabulary: Vocabulary,
269    /// Input embeddings
270    input_embeddings: Option<Array2<f64>>,
271    /// Output embeddings
272    output_embeddings: Option<Array2<f64>>,
273    /// Tokenizer
274    tokenizer: Box<dyn Tokenizer + Send + Sync>,
275    /// Sampling table for negative sampling
276    sampling_table: Option<SamplingTable>,
277    /// Current learning rate (gets updated during training)
278    current_learning_rate: f64,
279}
280
281impl Debug for Word2Vec {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("Word2Vec")
284            .field("config", &self.config)
285            .field("vocabulary", &self.vocabulary)
286            .field("input_embeddings", &self.input_embeddings)
287            .field("output_embeddings", &self.output_embeddings)
288            .field("sampling_table", &self.sampling_table)
289            .field("current_learning_rate", &self.current_learning_rate)
290            .finish()
291    }
292}
293
294// Manual Clone implementation to handle the non-Clone tokenizer
295impl Default for Word2Vec {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
301impl Clone for Word2Vec {
302    fn clone(&self) -> Self {
303        // Create a new tokenizer of the same type
304        // For simplicity, we always use WordTokenizer when cloning
305        // A more sophisticated solution would be to add a clone method to the Tokenizer trait
306        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
307
308        Self {
309            config: self.config.clone(),
310            vocabulary: self.vocabulary.clone(),
311            input_embeddings: self.input_embeddings.clone(),
312            output_embeddings: self.output_embeddings.clone(),
313            tokenizer,
314            sampling_table: self.sampling_table.clone(),
315            current_learning_rate: self.current_learning_rate,
316        }
317    }
318}
319
320impl Word2Vec {
321    /// Create a new Word2Vec model with default configuration
322    pub fn new() -> Self {
323        Self {
324            config: Word2VecConfig::default(),
325            vocabulary: Vocabulary::new(),
326            input_embeddings: None,
327            output_embeddings: None,
328            tokenizer: Box::new(WordTokenizer::default()),
329            sampling_table: None,
330            current_learning_rate: 0.025,
331        }
332    }
333
334    /// Create a new Word2Vec model with the specified configuration
335    pub fn with_config(config: Word2VecConfig) -> Self {
336        let learning_rate = config.learning_rate;
337        Self {
338            config,
339            vocabulary: Vocabulary::new(),
340            input_embeddings: None,
341            output_embeddings: None,
342            tokenizer: Box::new(WordTokenizer::default()),
343            sampling_table: None,
344            current_learning_rate: learning_rate,
345        }
346    }
347
348    /// Set a custom tokenizer
349    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
350        self.tokenizer = tokenizer;
351        self
352    }
353
354    /// Set vector size
355    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
356        self.config.vector_size = vectorsize;
357        self
358    }
359
360    /// Set window size
361    pub fn with_window_size(mut self, windowsize: usize) -> Self {
362        self.config.window_size = windowsize;
363        self
364    }
365
366    /// Set minimum count
367    pub fn with_min_count(mut self, mincount: usize) -> Self {
368        self.config.min_count = mincount;
369        self
370    }
371
372    /// Set number of epochs
373    pub fn with_epochs(mut self, epochs: usize) -> Self {
374        self.config.epochs = epochs;
375        self
376    }
377
378    /// Set learning rate
379    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
380        self.config.learning_rate = learningrate;
381        self.current_learning_rate = learningrate;
382        self
383    }
384
385    /// Set algorithm (CBOW or Skip-gram)
386    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
387        self.config.algorithm = algorithm;
388        self
389    }
390
391    /// Set number of negative samples
392    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
393        self.config.negative_samples = negativesamples;
394        self
395    }
396
397    /// Set subsampling threshold
398    pub fn with_subsample(mut self, subsample: f64) -> Self {
399        self.config.subsample = subsample;
400        self
401    }
402
403    /// Set batch size
404    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
405        self.config.batch_size = batchsize;
406        self
407    }
408
409    /// Build vocabulary from a corpus
410    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
411        if texts.is_empty() {
412            return Err(TextError::InvalidInput(
413                "No texts provided for building vocabulary".into(),
414            ));
415        }
416
417        // Count word frequencies
418        let mut word_counts = HashMap::new();
419        let mut _total_words = 0;
420
421        for &text in texts {
422            let tokens = self.tokenizer.tokenize(text)?;
423            for token in tokens {
424                *word_counts.entry(token).or_insert(0) += 1;
425                _total_words += 1;
426            }
427        }
428
429        // Create vocabulary with words that meet minimum count
430        self.vocabulary = Vocabulary::new();
431        for (word, count) in &word_counts {
432            if *count >= self.config.min_count {
433                self.vocabulary.add_token(word);
434            }
435        }
436
437        if self.vocabulary.is_empty() {
438            return Err(TextError::VocabularyError(
439                "No words meet the minimum count threshold".into(),
440            ));
441        }
442
443        // Initialize embeddings
444        let vocab_size = self.vocabulary.len();
445        let vector_size = self.config.vector_size;
446
447        // Initialize input and output embeddings with small random values
448        let mut rng = scirs2_core::random::rng();
449        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
450            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
451        });
452        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
453            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
454        });
455
456        self.input_embeddings = Some(input_embeddings);
457        self.output_embeddings = Some(output_embeddings);
458
459        // Create sampling table for negative sampling
460        self.create_sampling_table(&word_counts)?;
461
462        Ok(())
463    }
464
465    /// Create sampling table for negative sampling based on word frequencies
466    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
467        // Prepare sampling weights (unigram distribution raised to 3/4 power)
468        let mut sampling_weights = vec![0.0; self.vocabulary.len()];
469
470        for (word, &count) in wordcounts.iter() {
471            if let Some(idx) = self.vocabulary.get_index(word) {
472                // Apply smoothing: frequency^0.75
473                sampling_weights[idx] = (count as f64).powf(0.75);
474            }
475        }
476
477        match SamplingTable::new(&sampling_weights) {
478            Ok(table) => {
479                self.sampling_table = Some(table);
480                Ok(())
481            }
482            Err(e) => Err(e),
483        }
484    }
485
486    /// Train the Word2Vec model on a corpus
487    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
488        if texts.is_empty() {
489            return Err(TextError::InvalidInput(
490                "No texts provided for training".into(),
491            ));
492        }
493
494        // Build vocabulary if not already built
495        if self.vocabulary.is_empty() {
496            self.build_vocabulary(texts)?;
497        }
498
499        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
500            return Err(TextError::EmbeddingError(
501                "Embeddings not initialized. Call build_vocabulary() first".into(),
502            ));
503        }
504
505        // Count total number of tokens for progress tracking
506        let mut _total_tokens = 0;
507        let mut sentences = Vec::new();
508        for &text in texts {
509            let tokens = self.tokenizer.tokenize(text)?;
510            let filtered_tokens: Vec<usize> = tokens
511                .iter()
512                .filter_map(|token| self.vocabulary.get_index(token))
513                .collect();
514            if !filtered_tokens.is_empty() {
515                _total_tokens += filtered_tokens.len();
516                sentences.push(filtered_tokens);
517            }
518        }
519
520        // Train for the specified number of epochs
521        for epoch in 0..self.config.epochs {
522            // Update learning rate for this epoch
523            self.current_learning_rate =
524                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
525            self.current_learning_rate = self
526                .current_learning_rate
527                .max(self.config.learning_rate * 0.0001);
528
529            // Process each sentence
530            for sentence in &sentences {
531                // Apply subsampling of frequent words
532                let subsampled_sentence = if self.config.subsample > 0.0 {
533                    self.subsample_sentence(sentence)?
534                } else {
535                    sentence.clone()
536                };
537
538                // Skip empty sentences
539                if subsampled_sentence.is_empty() {
540                    continue;
541                }
542
543                // Train on the sentence
544                match self.config.algorithm {
545                    Word2VecAlgorithm::CBOW => {
546                        self.train_cbow_sentence(&subsampled_sentence)?;
547                    }
548                    Word2VecAlgorithm::SkipGram => {
549                        self.train_skipgram_sentence(&subsampled_sentence)?;
550                    }
551                }
552            }
553        }
554
555        Ok(())
556    }
557
558    /// Apply subsampling to a sentence
559    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
560        let mut rng = scirs2_core::random::rng();
561        let total_words: f64 = self.vocabulary.len() as f64;
562        let threshold = self.config.subsample * total_words;
563
564        // Filter words based on subsampling probability
565        let subsampled: Vec<usize> = sentence
566            .iter()
567            .filter(|&&word_idx| {
568                let word_freq = self.get_word_frequency(word_idx);
569                if word_freq == 0.0 {
570                    return true; // Keep rare words
571                }
572                // Probability of keeping the word
573                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
574                rng.random::<f64>() < keep_prob
575            })
576            .copied()
577            .collect();
578
579        Ok(subsampled)
580    }
581
582    /// Get the frequency of a word in the vocabulary
583    fn get_word_frequency(&self, wordidx: usize) -> f64 {
584        // This is a simplified version; ideal implementation would track actual frequencies
585        // For now, we'll use the sampling table weights as a proxy
586        if let Some(table) = &self.sampling_table {
587            table.weights()[wordidx]
588        } else {
589            1.0 // Default weight if no sampling table exists
590        }
591    }
592
593    /// Train CBOW model on a single sentence
594    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
595        if sentence.len() < 2 {
596            return Ok(()); // Need at least 2 words for context
597        }
598
599        let input_embeddings = self.input_embeddings.as_mut().unwrap();
600        let output_embeddings = self.output_embeddings.as_mut().unwrap();
601        let vector_size = self.config.vector_size;
602        let window_size = self.config.window_size;
603        let negative_samples = self.config.negative_samples;
604
605        // For each position in sentence, predict the word from its context
606        for pos in 0..sentence.len() {
607            // Determine context window (with random size)
608            let mut rng = scirs2_core::random::rng();
609            let window = 1 + rng.random_range(0..window_size);
610            let target_word = sentence[pos];
611
612            // Collect context words and average their vectors
613            let mut context_words = Vec::new();
614            #[allow(clippy::needless_range_loop)]
615            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
616                if i != pos {
617                    context_words.push(sentence[i]);
618                }
619            }
620
621            if context_words.is_empty() {
622                continue; // No context words
623            }
624
625            // Average the context word vectors
626            let mut context_sum = Array1::zeros(vector_size);
627            for &context_idx in &context_words {
628                context_sum += &input_embeddings.row(context_idx);
629            }
630            let context_avg = &context_sum / context_words.len() as f64;
631
632            // Update target word's output embedding with positive example
633            let mut target_output = output_embeddings.row_mut(target_word);
634            let dot_product = (&context_avg * &target_output).sum();
635            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
636            let error = (1.0 - sigmoid) * self.current_learning_rate;
637
638            // Create a copy for update
639            let mut target_update = target_output.to_owned();
640            target_update.scaled_add(error, &context_avg);
641            target_output.assign(&target_update);
642
643            // Negative sampling
644            if let Some(sampler) = &self.sampling_table {
645                for _ in 0..negative_samples {
646                    let negative_idx = sampler.sample(&mut rng);
647                    if negative_idx == target_word {
648                        continue; // Skip if we sample the target word
649                    }
650
651                    let mut negative_output = output_embeddings.row_mut(negative_idx);
652                    let dot_product = (&context_avg * &negative_output).sum();
653                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
654                    let error = -sigmoid * self.current_learning_rate;
655
656                    // Create a copy for update
657                    let mut negative_update = negative_output.to_owned();
658                    negative_update.scaled_add(error, &context_avg);
659                    negative_output.assign(&negative_update);
660                }
661            }
662
663            // Update context word vectors
664            for &context_idx in &context_words {
665                let mut input_vec = input_embeddings.row_mut(context_idx);
666
667                // Positive example
668                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
669                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
670                let error =
671                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
672
673                // Create a copy for update
674                let mut input_update = input_vec.to_owned();
675                input_update.scaled_add(error, &output_embeddings.row(target_word));
676
677                // Negative examples
678                if let Some(sampler) = &self.sampling_table {
679                    for _ in 0..negative_samples {
680                        let negative_idx = sampler.sample(&mut rng);
681                        if negative_idx == target_word {
682                            continue;
683                        }
684
685                        let dot_product =
686                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
687                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
688                        let error =
689                            -sigmoid * self.current_learning_rate / context_words.len() as f64;
690
691                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
692                    }
693                }
694
695                input_vec.assign(&input_update);
696            }
697        }
698
699        Ok(())
700    }
701
702    /// Train Skip-gram model on a single sentence
703    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
704        if sentence.len() < 2 {
705            return Ok(()); // Need at least 2 words for context
706        }
707
708        let input_embeddings = self.input_embeddings.as_mut().unwrap();
709        let output_embeddings = self.output_embeddings.as_mut().unwrap();
710        let vector_size = self.config.vector_size;
711        let window_size = self.config.window_size;
712        let negative_samples = self.config.negative_samples;
713
714        // For each position in sentence, predict the context from the word
715        for pos in 0..sentence.len() {
716            // Determine context window (with random size)
717            let mut rng = scirs2_core::random::rng();
718            let window = 1 + rng.random_range(0..window_size);
719            let target_word = sentence[pos];
720
721            // For each context position
722            #[allow(clippy::needless_range_loop)]
723            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
724                if i == pos {
725                    continue; // Skip the target word itself
726                }
727
728                let context_word = sentence[i];
729                let target_input = input_embeddings.row(target_word);
730                let mut context_output = output_embeddings.row_mut(context_word);
731
732                // Update context word's output embedding with positive example
733                let dot_product = (&target_input * &context_output).sum();
734                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
735                let error = (1.0 - sigmoid) * self.current_learning_rate;
736
737                // Create a copy for update
738                let mut context_update = context_output.to_owned();
739                context_update.scaled_add(error, &target_input);
740                context_output.assign(&context_update);
741
742                // Gradient for input word vector
743                let mut input_update = Array1::zeros(vector_size);
744                input_update.scaled_add(error, &context_output);
745
746                // Negative sampling
747                if let Some(sampler) = &self.sampling_table {
748                    for _ in 0..negative_samples {
749                        let negative_idx = sampler.sample(&mut rng);
750                        if negative_idx == context_word {
751                            continue; // Skip if we sample the context word
752                        }
753
754                        let mut negative_output = output_embeddings.row_mut(negative_idx);
755                        let dot_product = (&target_input * &negative_output).sum();
756                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
757                        let error = -sigmoid * self.current_learning_rate;
758
759                        // Create a copy for update
760                        let mut negative_update = negative_output.to_owned();
761                        negative_update.scaled_add(error, &target_input);
762                        negative_output.assign(&negative_update);
763
764                        // Update input gradient
765                        input_update.scaled_add(error, &negative_output);
766                    }
767                }
768
769                // Apply the accumulated gradient to the input word vector
770                let mut target_input_mut = input_embeddings.row_mut(target_word);
771                target_input_mut += &input_update;
772            }
773        }
774
775        Ok(())
776    }
777
778    /// Get the vector size
779    pub fn vector_size(&self) -> usize {
780        self.config.vector_size
781    }
782
783    /// Get the embedding vector for a word
784    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
785        if self.input_embeddings.is_none() {
786            return Err(TextError::EmbeddingError(
787                "Model not trained. Call train() first".into(),
788            ));
789        }
790
791        match self.vocabulary.get_index(word) {
792            Some(idx) => Ok(self.input_embeddings.as_ref().unwrap().row(idx).to_owned()),
793            None => Err(TextError::VocabularyError(format!(
794                "Word '{word}' not in vocabulary"
795            ))),
796        }
797    }
798
799    /// Get the most similar words to a given word
800    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
801        let word_vec = self.get_word_vector(word)?;
802        self.most_similar_by_vector(&word_vec, topn, &[word])
803    }
804
805    /// Get the most similar words to a given vector
806    pub fn most_similar_by_vector(
807        &self,
808        vector: &Array1<f64>,
809        top_n: usize,
810        exclude_words: &[&str],
811    ) -> Result<Vec<(String, f64)>> {
812        if self.input_embeddings.is_none() {
813            return Err(TextError::EmbeddingError(
814                "Model not trained. Call train() first".into(),
815            ));
816        }
817
818        let input_embeddings = self.input_embeddings.as_ref().unwrap();
819        let vocab_size = self.vocabulary.len();
820
821        // Create a set of indices to exclude
822        let exclude_indices: Vec<usize> = exclude_words
823            .iter()
824            .filter_map(|&word| self.vocabulary.get_index(word))
825            .collect();
826
827        // Calculate cosine similarity for all _words
828        let mut similarities = Vec::with_capacity(vocab_size);
829
830        for i in 0..vocab_size {
831            if exclude_indices.contains(&i) {
832                continue;
833            }
834
835            let word_vec = input_embeddings.row(i);
836            let similarity = cosine_similarity(vector, &word_vec.to_owned());
837
838            if let Some(word) = self.vocabulary.get_token(i) {
839                similarities.push((word.to_string(), similarity));
840            }
841        }
842
843        // Sort by similarity (descending)
844        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
845
846        // Take top N
847        let result = similarities.into_iter().take(top_n).collect();
848        Ok(result)
849    }
850
851    /// Compute the analogy: a is to b as c is to ?
852    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
853        if self.input_embeddings.is_none() {
854            return Err(TextError::EmbeddingError(
855                "Model not trained. Call train() first".into(),
856            ));
857        }
858
859        // Get vectors for a, b, and c
860        let a_vec = self.get_word_vector(a)?;
861        let b_vec = self.get_word_vector(b)?;
862        let c_vec = self.get_word_vector(c)?;
863
864        // Compute d_vec = b_vec - a_vec + c_vec
865        let mut d_vec = b_vec.clone();
866        d_vec -= &a_vec;
867        d_vec += &c_vec;
868
869        // Normalize the vector
870        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
871        d_vec.mapv_inplace(|val| val / norm);
872
873        // Find most similar words to d_vec
874        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
875    }
876
877    /// Save the Word2Vec model to a file
878    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
879        if self.input_embeddings.is_none() {
880            return Err(TextError::EmbeddingError(
881                "Model not trained. Call train() first".into(),
882            ));
883        }
884
885        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
886
887        // Write header: vector_size and vocabulary size
888        writeln!(
889            &mut file,
890            "{} {}",
891            self.vocabulary.len(),
892            self.config.vector_size
893        )
894        .map_err(|e| TextError::IoError(e.to_string()))?;
895
896        // Write each word and its vector
897        let input_embeddings = self.input_embeddings.as_ref().unwrap();
898
899        for i in 0..self.vocabulary.len() {
900            if let Some(word) = self.vocabulary.get_token(i) {
901                // Write the word
902                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
903
904                // Write the vector components
905                let vector = input_embeddings.row(i);
906                for j in 0..self.config.vector_size {
907                    write!(&mut file, "{:.6} ", vector[j])
908                        .map_err(|e| TextError::IoError(e.to_string()))?;
909                }
910
911                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
912            }
913        }
914
915        Ok(())
916    }
917
918    /// Load a Word2Vec model from a file
919    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
920        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
921        let mut reader = BufReader::new(file);
922
923        // Read header
924        let mut header = String::new();
925        reader
926            .read_line(&mut header)
927            .map_err(|e| TextError::IoError(e.to_string()))?;
928
929        let parts: Vec<&str> = header.split_whitespace().collect();
930        if parts.len() != 2 {
931            return Err(TextError::EmbeddingError(
932                "Invalid model file format".into(),
933            ));
934        }
935
936        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
937            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
938        })?;
939
940        let vector_size = parts[1]
941            .parse::<usize>()
942            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
943
944        // Initialize model
945        let mut model = Self::new().with_vector_size(vector_size);
946        let mut vocabulary = Vocabulary::new();
947        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
948
949        // Read each word and its vector
950        let mut i = 0;
951        for line in reader.lines() {
952            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
953            let parts: Vec<&str> = line.split_whitespace().collect();
954
955            if parts.len() != vector_size + 1 {
956                let line_num = i + 2;
957                return Err(TextError::EmbeddingError(format!(
958                    "Invalid vector format at line {line_num}"
959                )));
960            }
961
962            let word = parts[0];
963            vocabulary.add_token(word);
964
965            for j in 0..vector_size {
966                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
967                    TextError::EmbeddingError(format!(
968                        "Invalid vector component at line {}, position {}",
969                        i + 2,
970                        j + 1
971                    ))
972                })?;
973            }
974
975            i += 1;
976        }
977
978        if i != vocab_size {
979            return Err(TextError::EmbeddingError(format!(
980                "Expected {vocab_size} words but found {i}"
981            )));
982        }
983
984        model.vocabulary = vocabulary;
985        model.input_embeddings = Some(input_embeddings);
986        model.output_embeddings = None; // Only input embeddings are saved
987
988        Ok(model)
989    }
990
991    // Getter methods for model registry serialization
992
993    /// Get the vocabulary as a vector of strings
994    pub fn get_vocabulary(&self) -> Vec<String> {
995        let mut vocab = Vec::new();
996        for i in 0..self.vocabulary.len() {
997            if let Some(token) = self.vocabulary.get_token(i) {
998                vocab.push(token.to_string());
999            }
1000        }
1001        vocab
1002    }
1003
1004    /// Get the vector size
1005    pub fn get_vector_size(&self) -> usize {
1006        self.config.vector_size
1007    }
1008
1009    /// Get the algorithm
1010    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
1011        self.config.algorithm
1012    }
1013
1014    /// Get the window size
1015    pub fn get_window_size(&self) -> usize {
1016        self.config.window_size
1017    }
1018
1019    /// Get the minimum count
1020    pub fn get_min_count(&self) -> usize {
1021        self.config.min_count
1022    }
1023
1024    /// Get the embeddings matrix (input embeddings)
1025    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
1026        self.input_embeddings.clone()
1027    }
1028
1029    /// Get the number of negative samples
1030    pub fn get_negative_samples(&self) -> usize {
1031        self.config.negative_samples
1032    }
1033
1034    /// Get the learning rate
1035    pub fn get_learning_rate(&self) -> f64 {
1036        self.config.learning_rate
1037    }
1038
1039    /// Get the number of epochs
1040    pub fn get_epochs(&self) -> usize {
1041        self.config.epochs
1042    }
1043
1044    /// Get the subsampling threshold
1045    pub fn get_subsampling_threshold(&self) -> f64 {
1046        self.config.subsample
1047    }
1048}
1049
1050/// Calculate cosine similarity between two vectors
1051#[allow(dead_code)]
1052pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1053    let dot_product = (a * b).sum();
1054    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1055    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1056
1057    if norm_a > 0.0 && norm_b > 0.0 {
1058        dot_product / (norm_a * norm_b)
1059    } else {
1060        0.0
1061    }
1062}
1063
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tiny corpus shared by the training-based tests.
    const TEXTS: [&str; 2] = [
        "the quick brown fox jumps over the lazy dog",
        "a quick brown fox jumps over a lazy dog",
    ];

    /// Train a small skip-gram model on `TEXTS` (fast enough for unit tests).
    fn train_small_model() -> Word2Vec {
        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);
        model
            .train(&TEXTS)
            .expect("training on the toy corpus should succeed");
        model
    }

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let similarity = cosine_similarity(&a, &b);
        let expected = 0.9746318461970762;
        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_word2vec_config() {
        let config = Word2VecConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.min_count, 5);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
    }

    #[test]
    fn test_word2vec_builder() {
        let model = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(model.config.vector_size, 200);
        assert_eq!(model.config.window_size, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    #[test]
    fn test_build_vocabulary() {
        let mut model = Word2Vec::new().with_min_count(1);
        let result = model.build_vocabulary(&TEXTS);
        assert!(result.is_ok());

        // Unique words: "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "a"
        assert_eq!(model.vocabulary.len(), 9);

        // Check that embeddings were initialized
        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(model.input_embeddings.as_ref().unwrap().shape(), &[9, 100]);
    }

    #[test]
    fn test_skipgram_training_small() {
        let model = train_small_model();

        let vec = model
            .get_word_vector("fox")
            .expect("'fox' is in the training corpus");
        assert_eq!(vec.len(), 10);
    }

    #[test]
    fn test_untrained_model_errors() {
        let model = Word2Vec::new();
        assert!(model.get_word_vector("fox").is_err());
        assert!(model.most_similar("fox", 3).is_err());
        assert!(model.get_embeddings_matrix().is_none());
    }

    #[test]
    fn test_most_similar_excludes_query_and_is_sorted() {
        let model = train_small_model();
        let similar = model
            .most_similar("fox", 3)
            .expect("'fox' is in the vocabulary");

        assert_eq!(similar.len(), 3);
        assert!(similar.iter().all(|(word, _)| word != "fox"));
        assert!(similar.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_analogy() {
        let model = train_small_model();
        assert!(model.analogy("quick", "brown", "unicorn", 2).is_err());

        let result = model
            .analogy("quick", "brown", "fox", 2)
            .expect("all analogy words are in the vocabulary");
        assert_eq!(result.len(), 2);
        assert!(result
            .iter()
            .all(|(word, _)| word != "quick" && word != "brown" && word != "fox"));
    }

    #[test]
    fn test_save_load_roundtrip() {
        let model = train_small_model();
        let path = std::env::temp_dir().join(format!(
            "scirs2_text_word2vec_roundtrip_{}.txt",
            std::process::id()
        ));

        model.save(&path).expect("saving a trained model should succeed");
        let loaded = Word2Vec::load(&path);
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.expect("loading a freshly saved model should succeed");

        assert_eq!(loaded.get_vocabulary().len(), model.get_vocabulary().len());
        for word in model.get_vocabulary() {
            let original = model.get_word_vector(&word).unwrap();
            let restored = loaded.get_word_vector(&word).unwrap();
            assert_eq!(restored.len(), original.len());
            for (x, y) in original.iter().zip(restored.iter()) {
                // `save` writes 6 decimal places.
                assert!((x - y).abs() <= 1e-6, "{word}: {x} vs {y}");
            }
        }
    }
}