Skip to main content

scirs2_text/embeddings/
mod.rs

1//! # Word Embeddings Module
2//!
3//! This module provides comprehensive implementations for word embeddings, including
4//! Word2Vec, GloVe, and FastText.
5//!
6//! ## Overview
7//!
8//! Word embeddings are dense vector representations of words that capture semantic
9//! relationships. This module implements:
10//!
11//! - **Word2Vec Skip-gram**: Predicts context words from a target word
12//! - **Word2Vec CBOW**: Predicts a target word from context words
13//! - **GloVe**: Global Vectors for Word Representation
14//! - **FastText**: Character n-gram based embeddings (handles OOV words)
15//! - **Negative sampling**: Efficient training technique for large vocabularies
//! - **Hierarchical softmax**: Alternative to negative sampling (exposed as a configuration flag; not yet implemented)
17//!
18//! ## Quick Start
19//!
20//! ```rust
21//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
22//!
23//! // Basic Word2Vec training
24//! let documents = vec![
25//!     "the quick brown fox jumps over the lazy dog",
26//!     "the dog was lazy but the fox was quick",
27//!     "brown fox and lazy dog are common phrases"
28//! ];
29//!
30//! let config = Word2VecConfig {
31//!     algorithm: Word2VecAlgorithm::SkipGram,
32//!     vector_size: 100,
33//!     window_size: 5,
34//!     min_count: 1,
35//!     learning_rate: 0.025,
36//!     epochs: 5,
37//!     negative_samples: 5,
38//!     ..Default::default()
39//! };
40//!
41//! let mut model = Word2Vec::with_config(config);
42//! model.train(&documents).expect("Training failed");
43//!
44//! // Get word vector
45//! if let Ok(vector) = model.get_word_vector("fox") {
46//!     println!("Vector for 'fox': {:?}", vector);
47//! }
48//!
49//! // Find similar words
50//! let similar = model.most_similar("fox", 3).expect("Failed to find similar words");
51//! for (word, similarity) in similar {
52//!     println!("{}: {:.4}", word, similarity);
53//! }
54//! ```
55//!
56//! ## Advanced Usage
57//!
58//! ### Custom Configuration
59//!
60//! ```rust
61//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
62//!
63//! let config = Word2VecConfig {
64//!     algorithm: Word2VecAlgorithm::CBOW,
65//!     vector_size: 300,        // Larger vectors for better quality
66//!     window_size: 10,         // Larger context window
67//!     min_count: 5,           // Filter rare words
68//!     learning_rate: 0.01,    // Lower learning rate for stability
69//!     epochs: 15,             // More training iterations
70//!     negative_samples: 10,   // More negative samples
71//!     subsample: 1e-4, // Subsample frequent words
72//!     batch_size: 128,
73//!     hierarchical_softmax: false,
74//! };
75//! ```
76//!
77//! ### Incremental Training
78//!
79//! ```rust
80//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
81//! # let mut model = Word2Vec::new().with_min_count(1);
82//! // Initial training
83//! let batch1 = vec!["the quick brown fox jumps over the lazy dog"];
84//! model.train(&batch1).expect("Training failed");
85//!
86//! // Continue training with new data
87//! let batch2 = vec!["the dog was lazy but the fox was quick"];
88//! model.train(&batch2).expect("Training failed");
89//! ```
90//!
91//! ### Saving and Loading Models
92//!
93//! ```rust
94//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
95//! # let mut model = Word2Vec::new().with_min_count(1);
96//! # let texts = vec!["the quick brown fox jumps over the lazy dog"];
97//! # model.train(&texts).expect("Training failed");
98//! // Save trained model
99//! model.save("my_model.w2v").expect("Failed to save model");
100//!
101//! // Load model
102//! let loaded_model = Word2Vec::load("my_model.w2v")
103//!     .expect("Failed to load model");
104//! ```
105//!
106//! ## Performance Tips
107//!
108//! 1. **Vocabulary Size**: Use `min_count` to filter rare words and reduce memory usage
109//! 2. **Vector Dimensions**: Balance between quality (higher dimensions) and speed (lower dimensions)
110//! 3. **Training Algorithm**: Skip-gram works better with rare words, CBOW is faster
111//! 4. **Negative Sampling**: Usually faster than hierarchical softmax for large vocabularies
//! 5. **Subsampling**: Set the `subsample` threshold to handle frequent words efficiently
113//!
114//! ## Mathematical Background
115//!
116//! ### Skip-gram Model
117//!
118//! The Skip-gram model maximizes the probability of context words given a target word:
119//!
120//! P(context|target) = ∏ P(w_context|w_target)
121//!
122//! ### CBOW Model
123//!
124//! The CBOW model predicts a target word from its context:
125//!
126//! P(target|context) = P(w_target|w_context1, w_context2, ...)
127//!
128//! ### Negative Sampling
129//!
130//! Instead of computing the full softmax, negative sampling approximates it by
131//! sampling negative examples, making training much more efficient.
132
133// Sub-modules
134pub mod fasttext;
135pub mod glove;
136
137// Re-export
138pub use fasttext::{FastText, FastTextConfig};
139pub use glove::GloVe;
140
141use crate::error::{Result, TextError};
142use crate::tokenize::{Tokenizer, WordTokenizer};
143use crate::vocabulary::Vocabulary;
144use scirs2_core::ndarray::{Array1, Array2};
145use scirs2_core::random::prelude::*;
146use std::collections::HashMap;
147use std::fmt::Debug;
148use std::fs::File;
149use std::io::{BufRead, BufReader, Write};
150use std::path::Path;
151
/// Lookup table used to draw indices in proportion to a set of weights.
///
/// Both the cumulative distribution and the weights it was derived from are
/// kept, so the original weights can still be read back after construction.
#[derive(Clone, Debug)]
struct SamplingTable {
    /// Cumulative distribution function (CDF) over the weights
    cdf: Vec<f64>,
    /// The weights the table was built from
    weights: Vec<f64>,
}
160
161impl SamplingTable {
162    /// Create a new sampling table from weights
163    fn new(weights: &[f64]) -> Result<Self> {
164        if weights.is_empty() {
165            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
166        }
167
168        // Ensure all _weights are positive
169        if weights.iter().any(|&w| w < 0.0) {
170            return Err(TextError::EmbeddingError("Weights must be positive".into()));
171        }
172
173        // Compute the CDF
174        let sum: f64 = weights.iter().sum();
175        if sum <= 0.0 {
176            return Err(TextError::EmbeddingError(
177                "Sum of _weights must be positive".into(),
178            ));
179        }
180
181        let mut cdf = Vec::with_capacity(weights.len());
182        let mut total = 0.0;
183
184        for &w in weights {
185            total += w;
186            cdf.push(total / sum);
187        }
188
189        Ok(Self {
190            cdf,
191            weights: weights.to_vec(),
192        })
193    }
194
195    /// Sample an index based on the weights
196    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
197        let r = rng.random::<f64>();
198
199        // Binary search for the insertion point
200        match self.cdf.binary_search_by(|&cdf_val| {
201            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
202        }) {
203            Ok(idx) => idx,
204            Err(idx) => idx,
205        }
206    }
207
208    /// Get the weights
209    fn weights(&self) -> &[f64] {
210        &self.weights
211    }
212}
213
/// Training objective used by a Word2Vec model
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Word2VecAlgorithm {
    /// Continuous Bag of Words (CBOW): predicts a target word from its context words
    CBOW,
    /// Skip-gram: predicts context words from a target word
    SkipGram,
}
222
223/// Word2Vec training options
224#[derive(Debug, Clone)]
225pub struct Word2VecConfig {
226    /// Size of the word vectors
227    pub vector_size: usize,
228    /// Maximum distance between the current and predicted word within a sentence
229    pub window_size: usize,
230    /// Minimum count of words to consider for training
231    pub min_count: usize,
232    /// Number of iterations (epochs) over the corpus
233    pub epochs: usize,
234    /// Learning rate
235    pub learning_rate: f64,
236    /// Skip-gram or CBOW algorithm
237    pub algorithm: Word2VecAlgorithm,
238    /// Number of negative samples per positive sample
239    pub negative_samples: usize,
240    /// Threshold for subsampling frequent words
241    pub subsample: f64,
242    /// Batch size for training
243    pub batch_size: usize,
244    /// Whether to use hierarchical softmax (not yet implemented)
245    pub hierarchical_softmax: bool,
246}
247
248impl Default for Word2VecConfig {
249    fn default() -> Self {
250        Self {
251            vector_size: 100,
252            window_size: 5,
253            min_count: 5,
254            epochs: 5,
255            learning_rate: 0.025,
256            algorithm: Word2VecAlgorithm::SkipGram,
257            negative_samples: 5,
258            subsample: 1e-3,
259            batch_size: 128,
260            hierarchical_softmax: false,
261        }
262    }
263}
264
265/// Word2Vec model for training and using word embeddings
266///
267/// Word2Vec is an algorithm for learning vector representations of words,
268/// also known as word embeddings. These vectors capture semantic meanings
269/// of words, allowing operations like "king - man + woman" to result in
270/// a vector close to "queen".
271///
272/// This implementation supports both Continuous Bag of Words (CBOW) and
273/// Skip-gram models, with negative sampling for efficient training.
274pub struct Word2Vec {
275    /// Configuration options
276    config: Word2VecConfig,
277    /// Vocabulary
278    vocabulary: Vocabulary,
279    /// Input embeddings
280    input_embeddings: Option<Array2<f64>>,
281    /// Output embeddings
282    output_embeddings: Option<Array2<f64>>,
283    /// Tokenizer
284    tokenizer: Box<dyn Tokenizer + Send + Sync>,
285    /// Sampling table for negative sampling
286    sampling_table: Option<SamplingTable>,
287    /// Current learning rate (gets updated during training)
288    current_learning_rate: f64,
289}
290
291impl Debug for Word2Vec {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.debug_struct("Word2Vec")
294            .field("config", &self.config)
295            .field("vocabulary", &self.vocabulary)
296            .field("input_embeddings", &self.input_embeddings)
297            .field("output_embeddings", &self.output_embeddings)
298            .field("sampling_table", &self.sampling_table)
299            .field("current_learning_rate", &self.current_learning_rate)
300            .finish()
301    }
302}
303
// Manual `Default` and `Clone` implementations: the boxed `dyn Tokenizer` prevents deriving them
305impl Default for Word2Vec {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311impl Clone for Word2Vec {
312    fn clone(&self) -> Self {
313        // Create a new tokenizer of the same type
314        // For simplicity, we always use WordTokenizer when cloning
315        // A more sophisticated solution would be to add a clone method to the Tokenizer trait
316        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
317
318        Self {
319            config: self.config.clone(),
320            vocabulary: self.vocabulary.clone(),
321            input_embeddings: self.input_embeddings.clone(),
322            output_embeddings: self.output_embeddings.clone(),
323            tokenizer,
324            sampling_table: self.sampling_table.clone(),
325            current_learning_rate: self.current_learning_rate,
326        }
327    }
328}
329
330impl Word2Vec {
331    /// Create a new Word2Vec model with default configuration
332    pub fn new() -> Self {
333        Self {
334            config: Word2VecConfig::default(),
335            vocabulary: Vocabulary::new(),
336            input_embeddings: None,
337            output_embeddings: None,
338            tokenizer: Box::new(WordTokenizer::default()),
339            sampling_table: None,
340            current_learning_rate: 0.025,
341        }
342    }
343
344    /// Create a new Word2Vec model with the specified configuration
345    pub fn with_config(config: Word2VecConfig) -> Self {
346        let learning_rate = config.learning_rate;
347        Self {
348            config,
349            vocabulary: Vocabulary::new(),
350            input_embeddings: None,
351            output_embeddings: None,
352            tokenizer: Box::new(WordTokenizer::default()),
353            sampling_table: None,
354            current_learning_rate: learning_rate,
355        }
356    }
357
358    /// Set a custom tokenizer
359    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
360        self.tokenizer = tokenizer;
361        self
362    }
363
364    /// Set vector size
365    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
366        self.config.vector_size = vectorsize;
367        self
368    }
369
370    /// Set window size
371    pub fn with_window_size(mut self, windowsize: usize) -> Self {
372        self.config.window_size = windowsize;
373        self
374    }
375
376    /// Set minimum count
377    pub fn with_min_count(mut self, mincount: usize) -> Self {
378        self.config.min_count = mincount;
379        self
380    }
381
382    /// Set number of epochs
383    pub fn with_epochs(mut self, epochs: usize) -> Self {
384        self.config.epochs = epochs;
385        self
386    }
387
388    /// Set learning rate
389    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
390        self.config.learning_rate = learningrate;
391        self.current_learning_rate = learningrate;
392        self
393    }
394
395    /// Set algorithm (CBOW or Skip-gram)
396    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
397        self.config.algorithm = algorithm;
398        self
399    }
400
401    /// Set number of negative samples
402    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
403        self.config.negative_samples = negativesamples;
404        self
405    }
406
407    /// Set subsampling threshold
408    pub fn with_subsample(mut self, subsample: f64) -> Self {
409        self.config.subsample = subsample;
410        self
411    }
412
413    /// Set batch size
414    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
415        self.config.batch_size = batchsize;
416        self
417    }
418
419    /// Build vocabulary from a corpus
420    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
421        if texts.is_empty() {
422            return Err(TextError::InvalidInput(
423                "No texts provided for building vocabulary".into(),
424            ));
425        }
426
427        // Count word frequencies
428        let mut word_counts = HashMap::new();
429        let mut _total_words = 0;
430
431        for &text in texts {
432            let tokens = self.tokenizer.tokenize(text)?;
433            for token in tokens {
434                *word_counts.entry(token).or_insert(0) += 1;
435                _total_words += 1;
436            }
437        }
438
439        // Create vocabulary with words that meet minimum count
440        self.vocabulary = Vocabulary::new();
441        for (word, count) in &word_counts {
442            if *count >= self.config.min_count {
443                self.vocabulary.add_token(word);
444            }
445        }
446
447        if self.vocabulary.is_empty() {
448            return Err(TextError::VocabularyError(
449                "No words meet the minimum count threshold".into(),
450            ));
451        }
452
453        // Initialize embeddings
454        let vocab_size = self.vocabulary.len();
455        let vector_size = self.config.vector_size;
456
457        // Initialize input and output embeddings with small random values
458        let mut rng = scirs2_core::random::rng();
459        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
460            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
461        });
462        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
463            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
464        });
465
466        self.input_embeddings = Some(input_embeddings);
467        self.output_embeddings = Some(output_embeddings);
468
469        // Create sampling table for negative sampling
470        self.create_sampling_table(&word_counts)?;
471
472        Ok(())
473    }
474
475    /// Create sampling table for negative sampling based on word frequencies
476    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
477        // Prepare sampling weights (unigram distribution raised to 3/4 power)
478        let mut sampling_weights = vec![0.0; self.vocabulary.len()];
479
480        for (word, &count) in wordcounts.iter() {
481            if let Some(idx) = self.vocabulary.get_index(word) {
482                // Apply smoothing: frequency^0.75
483                sampling_weights[idx] = (count as f64).powf(0.75);
484            }
485        }
486
487        match SamplingTable::new(&sampling_weights) {
488            Ok(table) => {
489                self.sampling_table = Some(table);
490                Ok(())
491            }
492            Err(e) => Err(e),
493        }
494    }
495
496    /// Train the Word2Vec model on a corpus
497    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
498        if texts.is_empty() {
499            return Err(TextError::InvalidInput(
500                "No texts provided for training".into(),
501            ));
502        }
503
504        // Build vocabulary if not already built
505        if self.vocabulary.is_empty() {
506            self.build_vocabulary(texts)?;
507        }
508
509        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
510            return Err(TextError::EmbeddingError(
511                "Embeddings not initialized. Call build_vocabulary() first".into(),
512            ));
513        }
514
515        // Count total number of tokens for progress tracking
516        let mut _total_tokens = 0;
517        let mut sentences = Vec::new();
518        for &text in texts {
519            let tokens = self.tokenizer.tokenize(text)?;
520            let filtered_tokens: Vec<usize> = tokens
521                .iter()
522                .filter_map(|token| self.vocabulary.get_index(token))
523                .collect();
524            if !filtered_tokens.is_empty() {
525                _total_tokens += filtered_tokens.len();
526                sentences.push(filtered_tokens);
527            }
528        }
529
530        // Train for the specified number of epochs
531        for epoch in 0..self.config.epochs {
532            // Update learning rate for this epoch
533            self.current_learning_rate =
534                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
535            self.current_learning_rate = self
536                .current_learning_rate
537                .max(self.config.learning_rate * 0.0001);
538
539            // Process each sentence
540            for sentence in &sentences {
541                // Apply subsampling of frequent words
542                let subsampled_sentence = if self.config.subsample > 0.0 {
543                    self.subsample_sentence(sentence)?
544                } else {
545                    sentence.clone()
546                };
547
548                // Skip empty sentences
549                if subsampled_sentence.is_empty() {
550                    continue;
551                }
552
553                // Train on the sentence
554                match self.config.algorithm {
555                    Word2VecAlgorithm::CBOW => {
556                        self.train_cbow_sentence(&subsampled_sentence)?;
557                    }
558                    Word2VecAlgorithm::SkipGram => {
559                        self.train_skipgram_sentence(&subsampled_sentence)?;
560                    }
561                }
562            }
563        }
564
565        Ok(())
566    }
567
568    /// Apply subsampling to a sentence
569    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
570        let mut rng = scirs2_core::random::rng();
571        let total_words: f64 = self.vocabulary.len() as f64;
572        let threshold = self.config.subsample * total_words;
573
574        // Filter words based on subsampling probability
575        let subsampled: Vec<usize> = sentence
576            .iter()
577            .filter(|&&word_idx| {
578                let word_freq = self.get_word_frequency(word_idx);
579                if word_freq == 0.0 {
580                    return true; // Keep rare words
581                }
582                // Probability of keeping the word
583                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
584                rng.random::<f64>() < keep_prob
585            })
586            .copied()
587            .collect();
588
589        Ok(subsampled)
590    }
591
592    /// Get the frequency of a word in the vocabulary
593    fn get_word_frequency(&self, wordidx: usize) -> f64 {
594        // This is a simplified version; ideal implementation would track actual frequencies
595        // For now, we'll use the sampling table weights as a proxy
596        if let Some(table) = &self.sampling_table {
597            table.weights()[wordidx]
598        } else {
599            1.0 // Default weight if no sampling table exists
600        }
601    }
602
603    /// Train CBOW model on a single sentence
604    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
605        if sentence.len() < 2 {
606            return Ok(()); // Need at least 2 words for context
607        }
608
609        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
610        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
611        let vector_size = self.config.vector_size;
612        let window_size = self.config.window_size;
613        let negative_samples = self.config.negative_samples;
614
615        // For each position in sentence, predict the word from its context
616        for pos in 0..sentence.len() {
617            // Determine context window (with random size)
618            let mut rng = scirs2_core::random::rng();
619            let window = 1 + rng.random_range(0..window_size);
620            let target_word = sentence[pos];
621
622            // Collect context words and average their vectors
623            let mut context_words = Vec::new();
624            #[allow(clippy::needless_range_loop)]
625            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
626                if i != pos {
627                    context_words.push(sentence[i]);
628                }
629            }
630
631            if context_words.is_empty() {
632                continue; // No context words
633            }
634
635            // Average the context word vectors
636            let mut context_sum = Array1::zeros(vector_size);
637            for &context_idx in &context_words {
638                context_sum += &input_embeddings.row(context_idx);
639            }
640            let context_avg = &context_sum / context_words.len() as f64;
641
642            // Update target word's output embedding with positive example
643            let mut target_output = output_embeddings.row_mut(target_word);
644            let dot_product = (&context_avg * &target_output).sum();
645            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
646            let error = (1.0 - sigmoid) * self.current_learning_rate;
647
648            // Create a copy for update
649            let mut target_update = target_output.to_owned();
650            target_update.scaled_add(error, &context_avg);
651            target_output.assign(&target_update);
652
653            // Negative sampling
654            if let Some(sampler) = &self.sampling_table {
655                for _ in 0..negative_samples {
656                    let negative_idx = sampler.sample(&mut rng);
657                    if negative_idx == target_word {
658                        continue; // Skip if we sample the target word
659                    }
660
661                    let mut negative_output = output_embeddings.row_mut(negative_idx);
662                    let dot_product = (&context_avg * &negative_output).sum();
663                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
664                    let error = -sigmoid * self.current_learning_rate;
665
666                    // Create a copy for update
667                    let mut negative_update = negative_output.to_owned();
668                    negative_update.scaled_add(error, &context_avg);
669                    negative_output.assign(&negative_update);
670                }
671            }
672
673            // Update context word vectors
674            for &context_idx in &context_words {
675                let mut input_vec = input_embeddings.row_mut(context_idx);
676
677                // Positive example
678                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
679                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
680                let error =
681                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
682
683                // Create a copy for update
684                let mut input_update = input_vec.to_owned();
685                input_update.scaled_add(error, &output_embeddings.row(target_word));
686
687                // Negative examples
688                if let Some(sampler) = &self.sampling_table {
689                    for _ in 0..negative_samples {
690                        let negative_idx = sampler.sample(&mut rng);
691                        if negative_idx == target_word {
692                            continue;
693                        }
694
695                        let dot_product =
696                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
697                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
698                        let error =
699                            -sigmoid * self.current_learning_rate / context_words.len() as f64;
700
701                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
702                    }
703                }
704
705                input_vec.assign(&input_update);
706            }
707        }
708
709        Ok(())
710    }
711
712    /// Train Skip-gram model on a single sentence
713    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
714        if sentence.len() < 2 {
715            return Ok(()); // Need at least 2 words for context
716        }
717
718        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
719        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
720        let vector_size = self.config.vector_size;
721        let window_size = self.config.window_size;
722        let negative_samples = self.config.negative_samples;
723
724        // For each position in sentence, predict the context from the word
725        for pos in 0..sentence.len() {
726            // Determine context window (with random size)
727            let mut rng = scirs2_core::random::rng();
728            let window = 1 + rng.random_range(0..window_size);
729            let target_word = sentence[pos];
730
731            // For each context position
732            #[allow(clippy::needless_range_loop)]
733            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
734                if i == pos {
735                    continue; // Skip the target word itself
736                }
737
738                let context_word = sentence[i];
739                let target_input = input_embeddings.row(target_word);
740                let mut context_output = output_embeddings.row_mut(context_word);
741
742                // Update context word's output embedding with positive example
743                let dot_product = (&target_input * &context_output).sum();
744                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
745                let error = (1.0 - sigmoid) * self.current_learning_rate;
746
747                // Create a copy for update
748                let mut context_update = context_output.to_owned();
749                context_update.scaled_add(error, &target_input);
750                context_output.assign(&context_update);
751
752                // Gradient for input word vector
753                let mut input_update = Array1::zeros(vector_size);
754                input_update.scaled_add(error, &context_output);
755
756                // Negative sampling
757                if let Some(sampler) = &self.sampling_table {
758                    for _ in 0..negative_samples {
759                        let negative_idx = sampler.sample(&mut rng);
760                        if negative_idx == context_word {
761                            continue; // Skip if we sample the context word
762                        }
763
764                        let mut negative_output = output_embeddings.row_mut(negative_idx);
765                        let dot_product = (&target_input * &negative_output).sum();
766                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
767                        let error = -sigmoid * self.current_learning_rate;
768
769                        // Create a copy for update
770                        let mut negative_update = negative_output.to_owned();
771                        negative_update.scaled_add(error, &target_input);
772                        negative_output.assign(&negative_update);
773
774                        // Update input gradient
775                        input_update.scaled_add(error, &negative_output);
776                    }
777                }
778
779                // Apply the accumulated gradient to the input word vector
780                let mut target_input_mut = input_embeddings.row_mut(target_word);
781                target_input_mut += &input_update;
782            }
783        }
784
785        Ok(())
786    }
787
788    /// Get the vector size
789    pub fn vector_size(&self) -> usize {
790        self.config.vector_size
791    }
792
793    /// Get the embedding vector for a word
794    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
795        if self.input_embeddings.is_none() {
796            return Err(TextError::EmbeddingError(
797                "Model not trained. Call train() first".into(),
798            ));
799        }
800
801        match self.vocabulary.get_index(word) {
802            Some(idx) => Ok(self
803                .input_embeddings
804                .as_ref()
805                .expect("Operation failed")
806                .row(idx)
807                .to_owned()),
808            None => Err(TextError::VocabularyError(format!(
809                "Word '{word}' not in vocabulary"
810            ))),
811        }
812    }
813
814    /// Get the most similar words to a given word
815    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
816        let word_vec = self.get_word_vector(word)?;
817        self.most_similar_by_vector(&word_vec, topn, &[word])
818    }
819
820    /// Get the most similar words to a given vector
821    pub fn most_similar_by_vector(
822        &self,
823        vector: &Array1<f64>,
824        top_n: usize,
825        exclude_words: &[&str],
826    ) -> Result<Vec<(String, f64)>> {
827        if self.input_embeddings.is_none() {
828            return Err(TextError::EmbeddingError(
829                "Model not trained. Call train() first".into(),
830            ));
831        }
832
833        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
834        let vocab_size = self.vocabulary.len();
835
836        // Create a set of indices to exclude
837        let exclude_indices: Vec<usize> = exclude_words
838            .iter()
839            .filter_map(|&word| self.vocabulary.get_index(word))
840            .collect();
841
842        // Calculate cosine similarity for all _words
843        let mut similarities = Vec::with_capacity(vocab_size);
844
845        for i in 0..vocab_size {
846            if exclude_indices.contains(&i) {
847                continue;
848            }
849
850            let word_vec = input_embeddings.row(i);
851            let similarity = cosine_similarity(vector, &word_vec.to_owned());
852
853            if let Some(word) = self.vocabulary.get_token(i) {
854                similarities.push((word.to_string(), similarity));
855            }
856        }
857
858        // Sort by similarity (descending)
859        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
860
861        // Take top N
862        let result = similarities.into_iter().take(top_n).collect();
863        Ok(result)
864    }
865
866    /// Compute the analogy: a is to b as c is to ?
867    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
868        if self.input_embeddings.is_none() {
869            return Err(TextError::EmbeddingError(
870                "Model not trained. Call train() first".into(),
871            ));
872        }
873
874        // Get vectors for a, b, and c
875        let a_vec = self.get_word_vector(a)?;
876        let b_vec = self.get_word_vector(b)?;
877        let c_vec = self.get_word_vector(c)?;
878
879        // Compute d_vec = b_vec - a_vec + c_vec
880        let mut d_vec = b_vec.clone();
881        d_vec -= &a_vec;
882        d_vec += &c_vec;
883
884        // Normalize the vector
885        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
886        d_vec.mapv_inplace(|val| val / norm);
887
888        // Find most similar words to d_vec
889        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
890    }
891
892    /// Save the Word2Vec model to a file
893    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
894        if self.input_embeddings.is_none() {
895            return Err(TextError::EmbeddingError(
896                "Model not trained. Call train() first".into(),
897            ));
898        }
899
900        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
901
902        // Write header: vector_size and vocabulary size
903        writeln!(
904            &mut file,
905            "{} {}",
906            self.vocabulary.len(),
907            self.config.vector_size
908        )
909        .map_err(|e| TextError::IoError(e.to_string()))?;
910
911        // Write each word and its vector
912        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
913
914        for i in 0..self.vocabulary.len() {
915            if let Some(word) = self.vocabulary.get_token(i) {
916                // Write the word
917                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
918
919                // Write the vector components
920                let vector = input_embeddings.row(i);
921                for j in 0..self.config.vector_size {
922                    write!(&mut file, "{:.6} ", vector[j])
923                        .map_err(|e| TextError::IoError(e.to_string()))?;
924                }
925
926                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
927            }
928        }
929
930        Ok(())
931    }
932
933    /// Load a Word2Vec model from a file
934    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
935        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
936        let mut reader = BufReader::new(file);
937
938        // Read header
939        let mut header = String::new();
940        reader
941            .read_line(&mut header)
942            .map_err(|e| TextError::IoError(e.to_string()))?;
943
944        let parts: Vec<&str> = header.split_whitespace().collect();
945        if parts.len() != 2 {
946            return Err(TextError::EmbeddingError(
947                "Invalid model file format".into(),
948            ));
949        }
950
951        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
952            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
953        })?;
954
955        let vector_size = parts[1]
956            .parse::<usize>()
957            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
958
959        // Initialize model
960        let mut model = Self::new().with_vector_size(vector_size);
961        let mut vocabulary = Vocabulary::new();
962        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
963
964        // Read each word and its vector
965        let mut i = 0;
966        for line in reader.lines() {
967            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
968            let parts: Vec<&str> = line.split_whitespace().collect();
969
970            if parts.len() != vector_size + 1 {
971                let line_num = i + 2;
972                return Err(TextError::EmbeddingError(format!(
973                    "Invalid vector format at line {line_num}"
974                )));
975            }
976
977            let word = parts[0];
978            vocabulary.add_token(word);
979
980            for j in 0..vector_size {
981                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
982                    TextError::EmbeddingError(format!(
983                        "Invalid vector component at line {}, position {}",
984                        i + 2,
985                        j + 1
986                    ))
987                })?;
988            }
989
990            i += 1;
991        }
992
993        if i != vocab_size {
994            return Err(TextError::EmbeddingError(format!(
995                "Expected {vocab_size} words but found {i}"
996            )));
997        }
998
999        model.vocabulary = vocabulary;
1000        model.input_embeddings = Some(input_embeddings);
1001        model.output_embeddings = None; // Only input embeddings are saved
1002
1003        Ok(model)
1004    }
1005
1006    // Getter methods for model registry serialization
1007
1008    /// Get the vocabulary as a vector of strings
1009    pub fn get_vocabulary(&self) -> Vec<String> {
1010        let mut vocab = Vec::new();
1011        for i in 0..self.vocabulary.len() {
1012            if let Some(token) = self.vocabulary.get_token(i) {
1013                vocab.push(token.to_string());
1014            }
1015        }
1016        vocab
1017    }
1018
1019    /// Get the vector size
1020    pub fn get_vector_size(&self) -> usize {
1021        self.config.vector_size
1022    }
1023
1024    /// Get the algorithm
1025    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
1026        self.config.algorithm
1027    }
1028
1029    /// Get the window size
1030    pub fn get_window_size(&self) -> usize {
1031        self.config.window_size
1032    }
1033
1034    /// Get the minimum count
1035    pub fn get_min_count(&self) -> usize {
1036        self.config.min_count
1037    }
1038
1039    /// Get the embeddings matrix (input embeddings)
1040    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
1041        self.input_embeddings.clone()
1042    }
1043
1044    /// Get the number of negative samples
1045    pub fn get_negative_samples(&self) -> usize {
1046        self.config.negative_samples
1047    }
1048
1049    /// Get the learning rate
1050    pub fn get_learning_rate(&self) -> f64 {
1051        self.config.learning_rate
1052    }
1053
1054    /// Get the number of epochs
1055    pub fn get_epochs(&self) -> usize {
1056        self.config.epochs
1057    }
1058
1059    /// Get the subsampling threshold
1060    pub fn get_subsampling_threshold(&self) -> f64 {
1061        self.config.subsample
1062    }
1063}
1064
1065/// Calculate cosine similarity between two vectors
1066#[allow(dead_code)]
1067pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1068    let dot_product = (a * b).sum();
1069    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1070    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1071
1072    if norm_a > 0.0 && norm_b > 0.0 {
1073        dot_product / (norm_a * norm_b)
1074    } else {
1075        0.0
1076    }
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Small two-sentence corpus shared by the vocabulary and training tests.
    const CORPUS: [&str; 2] = [
        "the quick brown fox jumps over the lazy dog",
        "a quick brown fox jumps over a lazy dog",
    ];

    /// Cosine similarity of two fixed vectors matches a precomputed reference value.
    #[test]
    fn test_cosine_similarity() {
        let lhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        assert_relative_eq!(
            cosine_similarity(&lhs, &rhs),
            0.9746318461970762,
            max_relative = 1e-10
        );
    }

    /// The default configuration exposes the expected hyper-parameters.
    #[test]
    fn test_word2vec_config() {
        let defaults = Word2VecConfig::default();

        assert_eq!(defaults.vector_size, 100);
        assert_eq!(defaults.window_size, 5);
        assert_eq!(defaults.min_count, 5);
        assert_eq!(defaults.epochs, 5);
        assert_eq!(defaults.algorithm, Word2VecAlgorithm::SkipGram);
    }

    /// Each builder method stores its argument in the model configuration.
    #[test]
    fn test_word2vec_builder() {
        let built = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        let config = &built.config;
        assert_eq!(config.vector_size, 200);
        assert_eq!(config.window_size, 10);
        assert_eq!(config.learning_rate, 0.05);
        assert_eq!(config.algorithm, Word2VecAlgorithm::CBOW);
    }

    /// Building the vocabulary counts unique words and allocates both embedding matrices.
    #[test]
    fn test_build_vocabulary() {
        let mut model = Word2Vec::new().with_min_count(1);
        assert!(model.build_vocabulary(&CORPUS).is_ok());

        // Unique words: "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "a"
        assert_eq!(model.vocabulary.len(), 9);

        // Both matrices must exist, and the input matrix must be 9 words x 100 dimensions.
        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());

        let input = model.input_embeddings.as_ref().expect("Operation failed");
        assert_eq!(input.shape(), &[9, 100]);
    }

    /// A single skip-gram epoch on the small corpus yields vectors of the requested size.
    #[test]
    fn test_skipgram_training_small() {
        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        assert!(model.train(&CORPUS).is_ok());

        // A word from the corpus must have a vector of the configured length.
        let lookup = model.get_word_vector("fox");
        assert!(lookup.is_ok());
        assert_eq!(lookup.expect("Operation failed").len(), 10);
    }
}