Skip to main content

scirs2_text/embeddings/
mod.rs

1//! # Word Embeddings Module
2//!
3//! This module provides comprehensive implementations for word embeddings, including
4//! Word2Vec, GloVe, and FastText.
5//!
6//! ## Overview
7//!
8//! Word embeddings are dense vector representations of words that capture semantic
9//! relationships. This module implements:
10//!
11//! - **Word2Vec Skip-gram**: Predicts context words from a target word
12//! - **Word2Vec CBOW**: Predicts a target word from context words
13//! - **GloVe**: Global Vectors for Word Representation
14//! - **FastText**: Character n-gram based embeddings (handles OOV words)
15//! - **Negative sampling**: Efficient training technique for large vocabularies
16//! - **Hierarchical softmax**: Alternative to negative sampling for optimization
17//!
18//! ## Quick Start
19//!
20//! ```rust
21//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
22//!
23//! // Basic Word2Vec training
24//! let documents = vec![
25//!     "the quick brown fox jumps over the lazy dog",
26//!     "the dog was lazy but the fox was quick",
27//!     "brown fox and lazy dog are common phrases"
28//! ];
29//!
30//! let config = Word2VecConfig {
31//!     algorithm: Word2VecAlgorithm::SkipGram,
32//!     vector_size: 100,
33//!     window_size: 5,
34//!     min_count: 1,
35//!     learning_rate: 0.025,
36//!     epochs: 5,
37//!     negative_samples: 5,
38//!     ..Default::default()
39//! };
40//!
41//! let mut model = Word2Vec::with_config(config);
42//! model.train(&documents).expect("Training failed");
43//!
44//! // Get word vector
45//! if let Ok(vector) = model.get_word_vector("fox") {
46//!     println!("Vector for 'fox': {:?}", vector);
47//! }
48//!
49//! // Find similar words
50//! let similar = model.most_similar("fox", 3).expect("Failed to find similar words");
51//! for (word, similarity) in similar {
52//!     println!("{}: {:.4}", word, similarity);
53//! }
54//! ```
55//!
56//! ## Advanced Usage
57//!
58//! ### Custom Configuration
59//!
60//! ```rust
61//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
62//!
63//! let config = Word2VecConfig {
64//!     algorithm: Word2VecAlgorithm::CBOW,
65//!     vector_size: 300,        // Larger vectors for better quality
66//!     window_size: 10,         // Larger context window
67//!     min_count: 5,           // Filter rare words
68//!     learning_rate: 0.01,    // Lower learning rate for stability
69//!     epochs: 15,             // More training iterations
70//!     negative_samples: 10,   // More negative samples
71//!     subsample: 1e-4, // Subsample frequent words
72//!     batch_size: 128,
73//!     hierarchical_softmax: false,
74//! };
75//! ```
76//!
77//! ### Incremental Training
78//!
79//! ```rust
80//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
81//! # let mut model = Word2Vec::new().with_min_count(1);
82//! // Initial training
83//! let batch1 = vec!["the quick brown fox jumps over the lazy dog"];
84//! model.train(&batch1).expect("Training failed");
85//!
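//! // Continue training with new data. The vocabulary built by the first call
//! // is reused, so words that are not already in it are skipped.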
87//! let batch2 = vec!["the dog was lazy but the fox was quick"];
88//! model.train(&batch2).expect("Training failed");
89//! ```
90//!
91//! ### Saving and Loading Models
92//!
93//! ```rust
94//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
95//! # let mut model = Word2Vec::new().with_min_count(1);
96//! # let texts = vec!["the quick brown fox jumps over the lazy dog"];
97//! # model.train(&texts).expect("Training failed");
98//! // Save trained model
99//! model.save("my_model.w2v").expect("Failed to save model");
100//!
101//! // Load model
102//! let loaded_model = Word2Vec::load("my_model.w2v")
103//!     .expect("Failed to load model");
104//! ```
105//!
106//! ## Performance Tips
107//!
108//! 1. **Vocabulary Size**: Use `min_count` to filter rare words and reduce memory usage
109//! 2. **Vector Dimensions**: Balance between quality (higher dimensions) and speed (lower dimensions)
110//! 3. **Training Algorithm**: Skip-gram works better with rare words, CBOW is faster
111//! 4. **Negative Sampling**: Usually faster than hierarchical softmax for large vocabularies
//! 5. **Subsampling**: Set `subsample` (the subsampling threshold) to randomly drop very frequent words during training
113//!
114//! ## Mathematical Background
115//!
116//! ### Skip-gram Model
117//!
118//! The Skip-gram model maximizes the probability of context words given a target word:
119//!
120//! P(context|target) = ∏ P(w_context|w_target)
121//!
122//! ### CBOW Model
123//!
124//! The CBOW model predicts a target word from its context:
125//!
126//! P(target|context) = P(w_target|w_context1, w_context2, ...)
127//!
128//! ### Negative Sampling
129//!
130//! Instead of computing the full softmax, negative sampling approximates it by
131//! sampling negative examples, making training much more efficient.
132
133// Sub-modules
134pub mod contrastive;
135pub mod crosslingual;
136pub mod fasttext;
137pub mod glove;
138pub mod sentence;
139
140// Re-export
141pub use fasttext::{FastText, FastTextConfig};
142pub use glove::{
143    cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
144    GloVeTrainerConfig,
145};
146
147use crate::error::{Result, TextError};
148use crate::tokenize::{Tokenizer, WordTokenizer};
149use crate::vocabulary::Vocabulary;
150use scirs2_core::ndarray::{Array1, Array2};
151use scirs2_core::random::prelude::*;
152use std::collections::HashMap;
153use std::fmt::Debug;
154use std::fs::File;
155use std::io::{BufRead, BufReader, Write};
156use std::path::Path;
157
158// ─── WordEmbedding trait ─────────────────────────────────────────────────────
159
160/// Shared trait for word embedding models
161///
162/// Provides a common interface for querying word vectors, computing similarity,
163/// and solving analogies across different embedding implementations
164/// (Word2Vec, GloVe, FastText).
165pub trait WordEmbedding {
166    /// Get the embedding vector for a word
167    fn embedding(&self, word: &str) -> Result<Array1<f64>>;
168
169    /// Get the dimensionality of the embedding vectors
170    fn dimension(&self) -> usize;
171
172    /// Compute cosine similarity between two words
173    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
174        let v1 = self.embedding(word1)?;
175        let v2 = self.embedding(word2)?;
176        Ok(embedding_cosine_similarity(&v1, &v2))
177    }
178
179    /// Find the top-N most similar words to the query word
180    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
181
182    /// Solve the analogy: a is to b as c is to ?
183    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
184
185    /// Get the number of words in the vocabulary
186    fn vocab_size(&self) -> usize;
187}
188
189/// Calculate cosine similarity between two embedding vectors
190pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
191    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
192    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
193    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
194
195    if norm_a > 0.0 && norm_b > 0.0 {
196        dot_product / (norm_a * norm_b)
197    } else {
198        0.0
199    }
200}
201
202/// Compute pairwise cosine similarity matrix for a list of words
203pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
204    let vectors: Vec<Array1<f64>> = words
205        .iter()
206        .map(|&w| model.embedding(w))
207        .collect::<Result<Vec<_>>>()?;
208
209    let n = vectors.len();
210    let mut matrix = vec![vec![0.0; n]; n];
211
212    for i in 0..n {
213        for j in i..n {
214            let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
215            matrix[i][j] = sim;
216            matrix[j][i] = sim;
217        }
218    }
219
220    Ok(matrix)
221}
222
223// ─── WordEmbedding implementations ──────────────────────────────────────────
224
225impl WordEmbedding for GloVe {
226    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
227        self.get_word_vector(word)
228    }
229
230    fn dimension(&self) -> usize {
231        self.vector_size()
232    }
233
234    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
235        self.most_similar(word, top_n)
236    }
237
238    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
239        self.analogy(a, b, c, top_n)
240    }
241
242    fn vocab_size(&self) -> usize {
243        self.vocabulary_size()
244    }
245}
246
247impl WordEmbedding for FastText {
248    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
249        self.get_word_vector(word)
250    }
251
252    fn dimension(&self) -> usize {
253        self.vector_size()
254    }
255
256    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
257        self.most_similar(word, top_n)
258    }
259
260    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
261        self.analogy(a, b, c, top_n)
262    }
263
264    fn vocab_size(&self) -> usize {
265        self.vocabulary_size()
266    }
267}
268
269// ─── Huffman Tree for Hierarchical Softmax ──────────────────────────────────
270
271/// A node in the Huffman tree used for hierarchical softmax
272#[derive(Debug, Clone)]
273struct HuffmanNode {
274    /// Index in vocabulary (for leaf nodes) or internal node id
275    id: usize,
276    /// Frequency (for building the tree)
277    frequency: usize,
278    /// Left child index in the nodes array
279    left: Option<usize>,
280    /// Right child index in the nodes array
281    right: Option<usize>,
282    /// Is this a leaf node?
283    is_leaf: bool,
284}
285
286/// Huffman tree for hierarchical softmax training
287#[derive(Debug, Clone)]
288struct HuffmanTree {
289    /// Huffman codes for each vocabulary word: sequence of 0/1 decisions
290    codes: Vec<Vec<u8>>,
291    /// Path of internal node indices for each vocabulary word
292    paths: Vec<Vec<usize>>,
293    /// Number of internal nodes
294    num_internal: usize,
295}
296
297impl HuffmanTree {
298    /// Build a Huffman tree from word frequencies
299    ///
300    /// Returns codes and paths for each word in the vocabulary.
301    fn build(frequencies: &[usize]) -> Result<Self> {
302        let vocab_size = frequencies.len();
303        if vocab_size == 0 {
304            return Err(TextError::EmbeddingError(
305                "Cannot build Huffman tree with empty vocabulary".into(),
306            ));
307        }
308        if vocab_size == 1 {
309            // Special case: single word
310            return Ok(Self {
311                codes: vec![vec![0]],
312                paths: vec![vec![0]],
313                num_internal: 1,
314            });
315        }
316
317        // Create leaf nodes
318        let mut nodes: Vec<HuffmanNode> = frequencies
319            .iter()
320            .enumerate()
321            .map(|(id, &freq)| HuffmanNode {
322                id,
323                frequency: freq.max(1), // Avoid zero frequency
324                left: None,
325                right: None,
326                is_leaf: true,
327            })
328            .collect();
329
330        // Priority queue simulation using a sorted list
331        // (index_in_nodes, frequency)
332        let mut queue: Vec<(usize, usize)> = nodes
333            .iter()
334            .enumerate()
335            .map(|(i, n)| (i, n.frequency))
336            .collect();
337        queue.sort_by_key(|item| std::cmp::Reverse(item.1)); // Sort descending (pop from end = min)
338
339        // Build the tree bottom-up
340        while queue.len() > 1 {
341            // Pop two smallest
342            let (idx1, freq1) = queue
343                .pop()
344                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
345            let (idx2, freq2) = queue
346                .pop()
347                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
348
349            let new_id = nodes.len();
350            let new_node = HuffmanNode {
351                id: new_id,
352                frequency: freq1 + freq2,
353                left: Some(idx1),
354                right: Some(idx2),
355                is_leaf: false,
356            };
357            nodes.push(new_node);
358
359            // Insert new node maintaining sorted order
360            let new_freq = freq1 + freq2;
361            let insert_pos = queue
362                .binary_search_by(|(_, f)| new_freq.cmp(f))
363                .unwrap_or_else(|pos| pos);
364            queue.insert(insert_pos, (new_id, new_freq));
365        }
366
367        // Traverse tree to assign codes and paths
368        let num_internal = nodes.len() - vocab_size;
369        let mut codes = vec![Vec::new(); vocab_size];
370        let mut paths = vec![Vec::new(); vocab_size];
371
372        // DFS traversal
373        let root_idx = nodes.len() - 1;
374        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];
375
376        while let Some((node_idx, code, path)) = stack.pop() {
377            let node = &nodes[node_idx];
378
379            if node.is_leaf {
380                codes[node.id] = code;
381                paths[node.id] = path;
382            } else {
383                // Internal node index (0-based among internal nodes)
384                let internal_idx = node.id - vocab_size;
385
386                if let Some(left_idx) = node.left {
387                    let mut left_code = code.clone();
388                    left_code.push(0);
389                    let mut left_path = path.clone();
390                    left_path.push(internal_idx);
391                    stack.push((left_idx, left_code, left_path));
392                }
393
394                if let Some(right_idx) = node.right {
395                    let mut right_code = code.clone();
396                    right_code.push(1);
397                    let mut right_path = path.clone();
398                    right_path.push(internal_idx);
399                    stack.push((right_idx, right_code, right_path));
400                }
401            }
402        }
403
404        Ok(Self {
405            codes,
406            paths,
407            num_internal,
408        })
409    }
410}
411
412/// A simplified weighted sampling table
413#[derive(Debug, Clone)]
414struct SamplingTable {
415    /// The cumulative distribution function (CDF)
416    cdf: Vec<f64>,
417    /// The weights
418    weights: Vec<f64>,
419}
420
421impl SamplingTable {
422    /// Create a new sampling table from weights
423    fn new(weights: &[f64]) -> Result<Self> {
424        if weights.is_empty() {
425            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
426        }
427
428        // Ensure all _weights are positive
429        if weights.iter().any(|&w| w < 0.0) {
430            return Err(TextError::EmbeddingError("Weights must be positive".into()));
431        }
432
433        // Compute the CDF
434        let sum: f64 = weights.iter().sum();
435        if sum <= 0.0 {
436            return Err(TextError::EmbeddingError(
437                "Sum of _weights must be positive".into(),
438            ));
439        }
440
441        let mut cdf = Vec::with_capacity(weights.len());
442        let mut total = 0.0;
443
444        for &w in weights {
445            total += w;
446            cdf.push(total / sum);
447        }
448
449        Ok(Self {
450            cdf,
451            weights: weights.to_vec(),
452        })
453    }
454
455    /// Sample an index based on the weights
456    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
457        let r = rng.random::<f64>();
458
459        // Binary search for the insertion point
460        match self.cdf.binary_search_by(|&cdf_val| {
461            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
462        }) {
463            Ok(idx) => idx,
464            Err(idx) => idx,
465        }
466    }
467
468    /// Get the weights
469    fn weights(&self) -> &[f64] {
470        &self.weights
471    }
472}
473
/// Training objective used by [`Word2Vec`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous Bag of Words: predict the centre word from the words around it
    CBOW,
    /// Skip-gram: predict the surrounding words from the centre word
    SkipGram,
}
482
483/// Word2Vec training options
484#[derive(Debug, Clone)]
485pub struct Word2VecConfig {
486    /// Size of the word vectors
487    pub vector_size: usize,
488    /// Maximum distance between the current and predicted word within a sentence
489    pub window_size: usize,
490    /// Minimum count of words to consider for training
491    pub min_count: usize,
492    /// Number of iterations (epochs) over the corpus
493    pub epochs: usize,
494    /// Learning rate
495    pub learning_rate: f64,
496    /// Skip-gram or CBOW algorithm
497    pub algorithm: Word2VecAlgorithm,
498    /// Number of negative samples per positive sample
499    pub negative_samples: usize,
500    /// Threshold for subsampling frequent words
501    pub subsample: f64,
502    /// Batch size for training
503    pub batch_size: usize,
504    /// Whether to use hierarchical softmax instead of negative sampling
505    pub hierarchical_softmax: bool,
506}
507
508impl Default for Word2VecConfig {
509    fn default() -> Self {
510        Self {
511            vector_size: 100,
512            window_size: 5,
513            min_count: 5,
514            epochs: 5,
515            learning_rate: 0.025,
516            algorithm: Word2VecAlgorithm::SkipGram,
517            negative_samples: 5,
518            subsample: 1e-3,
519            batch_size: 128,
520            hierarchical_softmax: false,
521        }
522    }
523}
524
525/// Word2Vec model for training and using word embeddings
526///
527/// Word2Vec is an algorithm for learning vector representations of words,
528/// also known as word embeddings. These vectors capture semantic meanings
529/// of words, allowing operations like "king - man + woman" to result in
530/// a vector close to "queen".
531///
532/// This implementation supports both Continuous Bag of Words (CBOW) and
533/// Skip-gram models, with negative sampling for efficient training.
534pub struct Word2Vec {
535    /// Configuration options
536    config: Word2VecConfig,
537    /// Vocabulary
538    vocabulary: Vocabulary,
539    /// Input embeddings
540    input_embeddings: Option<Array2<f64>>,
541    /// Output embeddings
542    output_embeddings: Option<Array2<f64>>,
543    /// Tokenizer
544    tokenizer: Box<dyn Tokenizer + Send + Sync>,
545    /// Sampling table for negative sampling
546    sampling_table: Option<SamplingTable>,
547    /// Huffman tree for hierarchical softmax
548    huffman_tree: Option<HuffmanTree>,
549    /// Hierarchical softmax parameter vectors (one per internal node)
550    hs_params: Option<Array2<f64>>,
551    /// Current learning rate (gets updated during training)
552    current_learning_rate: f64,
553}
554
555impl Debug for Word2Vec {
556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        f.debug_struct("Word2Vec")
558            .field("config", &self.config)
559            .field("vocabulary", &self.vocabulary)
560            .field("input_embeddings", &self.input_embeddings)
561            .field("output_embeddings", &self.output_embeddings)
562            .field("sampling_table", &self.sampling_table)
563            .field("huffman_tree", &self.huffman_tree)
564            .field("current_learning_rate", &self.current_learning_rate)
565            .finish()
566    }
567}
568
// `Default` and `Clone` impls. `Clone` is written by hand because the boxed tokenizer has no `Clone` bound.
570impl Default for Word2Vec {
571    fn default() -> Self {
572        Self::new()
573    }
574}
575
576impl Clone for Word2Vec {
577    fn clone(&self) -> Self {
578        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
579
580        Self {
581            config: self.config.clone(),
582            vocabulary: self.vocabulary.clone(),
583            input_embeddings: self.input_embeddings.clone(),
584            output_embeddings: self.output_embeddings.clone(),
585            tokenizer,
586            sampling_table: self.sampling_table.clone(),
587            huffman_tree: self.huffman_tree.clone(),
588            hs_params: self.hs_params.clone(),
589            current_learning_rate: self.current_learning_rate,
590        }
591    }
592}
593
594impl Word2Vec {
595    /// Create a new Word2Vec model with default configuration
596    pub fn new() -> Self {
597        Self {
598            config: Word2VecConfig::default(),
599            vocabulary: Vocabulary::new(),
600            input_embeddings: None,
601            output_embeddings: None,
602            tokenizer: Box::new(WordTokenizer::default()),
603            sampling_table: None,
604            huffman_tree: None,
605            hs_params: None,
606            current_learning_rate: 0.025,
607        }
608    }
609
610    /// Create a new Word2Vec model with the specified configuration
611    pub fn with_config(config: Word2VecConfig) -> Self {
612        let learning_rate = config.learning_rate;
613        Self {
614            config,
615            vocabulary: Vocabulary::new(),
616            input_embeddings: None,
617            output_embeddings: None,
618            tokenizer: Box::new(WordTokenizer::default()),
619            sampling_table: None,
620            huffman_tree: None,
621            hs_params: None,
622            current_learning_rate: learning_rate,
623        }
624    }
625
626    /// Set a custom tokenizer
627    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
628        self.tokenizer = tokenizer;
629        self
630    }
631
632    /// Set vector size
633    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
634        self.config.vector_size = vectorsize;
635        self
636    }
637
638    /// Set window size
639    pub fn with_window_size(mut self, windowsize: usize) -> Self {
640        self.config.window_size = windowsize;
641        self
642    }
643
644    /// Set minimum count
645    pub fn with_min_count(mut self, mincount: usize) -> Self {
646        self.config.min_count = mincount;
647        self
648    }
649
650    /// Set number of epochs
651    pub fn with_epochs(mut self, epochs: usize) -> Self {
652        self.config.epochs = epochs;
653        self
654    }
655
656    /// Set learning rate
657    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
658        self.config.learning_rate = learningrate;
659        self.current_learning_rate = learningrate;
660        self
661    }
662
663    /// Set algorithm (CBOW or Skip-gram)
664    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
665        self.config.algorithm = algorithm;
666        self
667    }
668
669    /// Set number of negative samples
670    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
671        self.config.negative_samples = negativesamples;
672        self
673    }
674
675    /// Set subsampling threshold
676    pub fn with_subsample(mut self, subsample: f64) -> Self {
677        self.config.subsample = subsample;
678        self
679    }
680
681    /// Set batch size
682    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
683        self.config.batch_size = batchsize;
684        self
685    }
686
687    /// Build vocabulary from a corpus
688    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
689        if texts.is_empty() {
690            return Err(TextError::InvalidInput(
691                "No texts provided for building vocabulary".into(),
692            ));
693        }
694
695        // Count word frequencies
696        let mut word_counts = HashMap::new();
697        let mut _total_words = 0;
698
699        for &text in texts {
700            let tokens = self.tokenizer.tokenize(text)?;
701            for token in tokens {
702                *word_counts.entry(token).or_insert(0) += 1;
703                _total_words += 1;
704            }
705        }
706
707        // Create vocabulary with words that meet minimum count
708        self.vocabulary = Vocabulary::new();
709        for (word, count) in &word_counts {
710            if *count >= self.config.min_count {
711                self.vocabulary.add_token(word);
712            }
713        }
714
715        if self.vocabulary.is_empty() {
716            return Err(TextError::VocabularyError(
717                "No words meet the minimum count threshold".into(),
718            ));
719        }
720
721        // Initialize embeddings
722        let vocab_size = self.vocabulary.len();
723        let vector_size = self.config.vector_size;
724
725        // Initialize input and output embeddings with small random values
726        let mut rng = scirs2_core::random::rng();
727        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
728            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
729        });
730        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
731            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
732        });
733
734        self.input_embeddings = Some(input_embeddings);
735        self.output_embeddings = Some(output_embeddings);
736
737        // Create sampling table for negative sampling
738        self.create_sampling_table(&word_counts)?;
739
740        // Build Huffman tree for hierarchical softmax if configured
741        if self.config.hierarchical_softmax {
742            let frequencies: Vec<usize> = (0..vocab_size)
743                .map(|i| {
744                    self.vocabulary
745                        .get_token(i)
746                        .and_then(|word| word_counts.get(word).copied())
747                        .unwrap_or(1)
748                })
749                .collect();
750
751            let tree = HuffmanTree::build(&frequencies)?;
752            let num_internal = tree.num_internal;
753
754            // Initialize hierarchical softmax parameter vectors
755            let hs_params = Array2::zeros((num_internal, vector_size));
756            self.hs_params = Some(hs_params);
757            self.huffman_tree = Some(tree);
758        }
759
760        Ok(())
761    }
762
763    /// Create sampling table for negative sampling based on word frequencies
764    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
765        // Prepare sampling weights (unigram distribution raised to 3/4 power)
766        let mut sampling_weights = vec![0.0; self.vocabulary.len()];
767
768        for (word, &count) in wordcounts.iter() {
769            if let Some(idx) = self.vocabulary.get_index(word) {
770                // Apply smoothing: frequency^0.75
771                sampling_weights[idx] = (count as f64).powf(0.75);
772            }
773        }
774
775        match SamplingTable::new(&sampling_weights) {
776            Ok(table) => {
777                self.sampling_table = Some(table);
778                Ok(())
779            }
780            Err(e) => Err(e),
781        }
782    }
783
784    /// Train the Word2Vec model on a corpus
785    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
786        if texts.is_empty() {
787            return Err(TextError::InvalidInput(
788                "No texts provided for training".into(),
789            ));
790        }
791
792        // Build vocabulary if not already built
793        if self.vocabulary.is_empty() {
794            self.build_vocabulary(texts)?;
795        }
796
797        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
798            return Err(TextError::EmbeddingError(
799                "Embeddings not initialized. Call build_vocabulary() first".into(),
800            ));
801        }
802
803        // Count total number of tokens for progress tracking
804        let mut _total_tokens = 0;
805        let mut sentences = Vec::new();
806        for &text in texts {
807            let tokens = self.tokenizer.tokenize(text)?;
808            let filtered_tokens: Vec<usize> = tokens
809                .iter()
810                .filter_map(|token| self.vocabulary.get_index(token))
811                .collect();
812            if !filtered_tokens.is_empty() {
813                _total_tokens += filtered_tokens.len();
814                sentences.push(filtered_tokens);
815            }
816        }
817
818        // Train for the specified number of epochs
819        for epoch in 0..self.config.epochs {
820            // Update learning rate for this epoch
821            self.current_learning_rate =
822                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
823            self.current_learning_rate = self
824                .current_learning_rate
825                .max(self.config.learning_rate * 0.0001);
826
827            // Process each sentence
828            for sentence in &sentences {
829                // Apply subsampling of frequent words
830                let subsampled_sentence = if self.config.subsample > 0.0 {
831                    self.subsample_sentence(sentence)?
832                } else {
833                    sentence.clone()
834                };
835
836                // Skip empty sentences
837                if subsampled_sentence.is_empty() {
838                    continue;
839                }
840
841                // Train on the sentence
842                if self.config.hierarchical_softmax {
843                    // Use hierarchical softmax
844                    match self.config.algorithm {
845                        Word2VecAlgorithm::SkipGram => {
846                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
847                        }
848                        Word2VecAlgorithm::CBOW => {
849                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
850                        }
851                    }
852                } else {
853                    // Use negative sampling
854                    match self.config.algorithm {
855                        Word2VecAlgorithm::CBOW => {
856                            self.train_cbow_sentence(&subsampled_sentence)?;
857                        }
858                        Word2VecAlgorithm::SkipGram => {
859                            self.train_skipgram_sentence(&subsampled_sentence)?;
860                        }
861                    }
862                }
863            }
864        }
865
866        Ok(())
867    }
868
869    /// Apply subsampling to a sentence
870    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
871        let mut rng = scirs2_core::random::rng();
872        let total_words: f64 = self.vocabulary.len() as f64;
873        let threshold = self.config.subsample * total_words;
874
875        // Filter words based on subsampling probability
876        let subsampled: Vec<usize> = sentence
877            .iter()
878            .filter(|&&word_idx| {
879                let word_freq = self.get_word_frequency(word_idx);
880                if word_freq == 0.0 {
881                    return true; // Keep rare words
882                }
883                // Probability of keeping the word
884                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
885                rng.random::<f64>() < keep_prob
886            })
887            .copied()
888            .collect();
889
890        Ok(subsampled)
891    }
892
893    /// Get the frequency of a word in the vocabulary
894    fn get_word_frequency(&self, wordidx: usize) -> f64 {
895        // This is a simplified version; ideal implementation would track actual frequencies
896        // For now, we'll use the sampling table weights as a proxy
897        if let Some(table) = &self.sampling_table {
898            table.weights()[wordidx]
899        } else {
900            1.0 // Default weight if no sampling table exists
901        }
902    }
903
904    /// Train CBOW model on a single sentence
905    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
906        if sentence.len() < 2 {
907            return Ok(()); // Need at least 2 words for context
908        }
909
910        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
911        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
912        let vector_size = self.config.vector_size;
913        let window_size = self.config.window_size;
914        let negative_samples = self.config.negative_samples;
915
916        // For each position in sentence, predict the word from its context
917        for pos in 0..sentence.len() {
918            // Determine context window (with random size)
919            let mut rng = scirs2_core::random::rng();
920            let window = 1 + rng.random_range(0..window_size);
921            let target_word = sentence[pos];
922
923            // Collect context words and average their vectors
924            let mut context_words = Vec::new();
925            #[allow(clippy::needless_range_loop)]
926            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
927                if i != pos {
928                    context_words.push(sentence[i]);
929                }
930            }
931
932            if context_words.is_empty() {
933                continue; // No context words
934            }
935
936            // Average the context word vectors
937            let mut context_sum = Array1::zeros(vector_size);
938            for &context_idx in &context_words {
939                context_sum += &input_embeddings.row(context_idx);
940            }
941            let context_avg = &context_sum / context_words.len() as f64;
942
943            // Update target word's output embedding with positive example
944            let mut target_output = output_embeddings.row_mut(target_word);
945            let dot_product = (&context_avg * &target_output).sum();
946            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
947            let error = (1.0 - sigmoid) * self.current_learning_rate;
948
949            // Create a copy for update
950            let mut target_update = target_output.to_owned();
951            target_update.scaled_add(error, &context_avg);
952            target_output.assign(&target_update);
953
954            // Negative sampling
955            if let Some(sampler) = &self.sampling_table {
956                for _ in 0..negative_samples {
957                    let negative_idx = sampler.sample(&mut rng);
958                    if negative_idx == target_word {
959                        continue; // Skip if we sample the target word
960                    }
961
962                    let mut negative_output = output_embeddings.row_mut(negative_idx);
963                    let dot_product = (&context_avg * &negative_output).sum();
964                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
965                    let error = -sigmoid * self.current_learning_rate;
966
967                    // Create a copy for update
968                    let mut negative_update = negative_output.to_owned();
969                    negative_update.scaled_add(error, &context_avg);
970                    negative_output.assign(&negative_update);
971                }
972            }
973
974            // Update context word vectors
975            for &context_idx in &context_words {
976                let mut input_vec = input_embeddings.row_mut(context_idx);
977
978                // Positive example
979                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
980                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
981                let error =
982                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
983
984                // Create a copy for update
985                let mut input_update = input_vec.to_owned();
986                input_update.scaled_add(error, &output_embeddings.row(target_word));
987
988                // Negative examples
989                if let Some(sampler) = &self.sampling_table {
990                    for _ in 0..negative_samples {
991                        let negative_idx = sampler.sample(&mut rng);
992                        if negative_idx == target_word {
993                            continue;
994                        }
995
996                        let dot_product =
997                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
998                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
999                        let error =
1000                            -sigmoid * self.current_learning_rate / context_words.len() as f64;
1001
1002                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
1003                    }
1004                }
1005
1006                input_vec.assign(&input_update);
1007            }
1008        }
1009
1010        Ok(())
1011    }
1012
1013    /// Train Skip-gram model on a single sentence
1014    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1015        if sentence.len() < 2 {
1016            return Ok(()); // Need at least 2 words for context
1017        }
1018
1019        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
1020        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
1021        let vector_size = self.config.vector_size;
1022        let window_size = self.config.window_size;
1023        let negative_samples = self.config.negative_samples;
1024
1025        // For each position in sentence, predict the context from the word
1026        for pos in 0..sentence.len() {
1027            // Determine context window (with random size)
1028            let mut rng = scirs2_core::random::rng();
1029            let window = 1 + rng.random_range(0..window_size);
1030            let target_word = sentence[pos];
1031
1032            // For each context position
1033            #[allow(clippy::needless_range_loop)]
1034            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1035                if i == pos {
1036                    continue; // Skip the target word itself
1037                }
1038
1039                let context_word = sentence[i];
1040                let target_input = input_embeddings.row(target_word);
1041                let mut context_output = output_embeddings.row_mut(context_word);
1042
1043                // Update context word's output embedding with positive example
1044                let dot_product = (&target_input * &context_output).sum();
1045                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1046                let error = (1.0 - sigmoid) * self.current_learning_rate;
1047
1048                // Create a copy for update
1049                let mut context_update = context_output.to_owned();
1050                context_update.scaled_add(error, &target_input);
1051                context_output.assign(&context_update);
1052
1053                // Gradient for input word vector
1054                let mut input_update = Array1::zeros(vector_size);
1055                input_update.scaled_add(error, &context_output);
1056
1057                // Negative sampling
1058                if let Some(sampler) = &self.sampling_table {
1059                    for _ in 0..negative_samples {
1060                        let negative_idx = sampler.sample(&mut rng);
1061                        if negative_idx == context_word {
1062                            continue; // Skip if we sample the context word
1063                        }
1064
1065                        let mut negative_output = output_embeddings.row_mut(negative_idx);
1066                        let dot_product = (&target_input * &negative_output).sum();
1067                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1068                        let error = -sigmoid * self.current_learning_rate;
1069
1070                        // Create a copy for update
1071                        let mut negative_update = negative_output.to_owned();
1072                        negative_update.scaled_add(error, &target_input);
1073                        negative_output.assign(&negative_update);
1074
1075                        // Update input gradient
1076                        input_update.scaled_add(error, &negative_output);
1077                    }
1078                }
1079
1080                // Apply the accumulated gradient to the input word vector
1081                let mut target_input_mut = input_embeddings.row_mut(target_word);
1082                target_input_mut += &input_update;
1083            }
1084        }
1085
1086        Ok(())
1087    }
1088
1089    /// Get the vector size
1090    pub fn vector_size(&self) -> usize {
1091        self.config.vector_size
1092    }
1093
1094    /// Get the embedding vector for a word
1095    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
1096        if self.input_embeddings.is_none() {
1097            return Err(TextError::EmbeddingError(
1098                "Model not trained. Call train() first".into(),
1099            ));
1100        }
1101
1102        match self.vocabulary.get_index(word) {
1103            Some(idx) => Ok(self
1104                .input_embeddings
1105                .as_ref()
1106                .expect("Operation failed")
1107                .row(idx)
1108                .to_owned()),
1109            None => Err(TextError::VocabularyError(format!(
1110                "Word '{word}' not in vocabulary"
1111            ))),
1112        }
1113    }
1114
1115    /// Get the most similar words to a given word
1116    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1117        let word_vec = self.get_word_vector(word)?;
1118        self.most_similar_by_vector(&word_vec, topn, &[word])
1119    }
1120
1121    /// Get the most similar words to a given vector
1122    pub fn most_similar_by_vector(
1123        &self,
1124        vector: &Array1<f64>,
1125        top_n: usize,
1126        exclude_words: &[&str],
1127    ) -> Result<Vec<(String, f64)>> {
1128        if self.input_embeddings.is_none() {
1129            return Err(TextError::EmbeddingError(
1130                "Model not trained. Call train() first".into(),
1131            ));
1132        }
1133
1134        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1135        let vocab_size = self.vocabulary.len();
1136
1137        // Create a set of indices to exclude
1138        let exclude_indices: Vec<usize> = exclude_words
1139            .iter()
1140            .filter_map(|&word| self.vocabulary.get_index(word))
1141            .collect();
1142
1143        // Calculate cosine similarity for all _words
1144        let mut similarities = Vec::with_capacity(vocab_size);
1145
1146        for i in 0..vocab_size {
1147            if exclude_indices.contains(&i) {
1148                continue;
1149            }
1150
1151            let word_vec = input_embeddings.row(i);
1152            let similarity = cosine_similarity(vector, &word_vec.to_owned());
1153
1154            if let Some(word) = self.vocabulary.get_token(i) {
1155                similarities.push((word.to_string(), similarity));
1156            }
1157        }
1158
1159        // Sort by similarity (descending)
1160        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1161
1162        // Take top N
1163        let result = similarities.into_iter().take(top_n).collect();
1164        Ok(result)
1165    }
1166
1167    /// Compute the analogy: a is to b as c is to ?
1168    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1169        if self.input_embeddings.is_none() {
1170            return Err(TextError::EmbeddingError(
1171                "Model not trained. Call train() first".into(),
1172            ));
1173        }
1174
1175        // Get vectors for a, b, and c
1176        let a_vec = self.get_word_vector(a)?;
1177        let b_vec = self.get_word_vector(b)?;
1178        let c_vec = self.get_word_vector(c)?;
1179
1180        // Compute d_vec = b_vec - a_vec + c_vec
1181        let mut d_vec = b_vec.clone();
1182        d_vec -= &a_vec;
1183        d_vec += &c_vec;
1184
1185        // Normalize the vector
1186        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1187        d_vec.mapv_inplace(|val| val / norm);
1188
1189        // Find most similar words to d_vec
1190        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
1191    }
1192
1193    /// Save the Word2Vec model to a file
1194    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1195        if self.input_embeddings.is_none() {
1196            return Err(TextError::EmbeddingError(
1197                "Model not trained. Call train() first".into(),
1198            ));
1199        }
1200
1201        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
1202
1203        // Write header: vector_size and vocabulary size
1204        writeln!(
1205            &mut file,
1206            "{} {}",
1207            self.vocabulary.len(),
1208            self.config.vector_size
1209        )
1210        .map_err(|e| TextError::IoError(e.to_string()))?;
1211
1212        // Write each word and its vector
1213        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1214
1215        for i in 0..self.vocabulary.len() {
1216            if let Some(word) = self.vocabulary.get_token(i) {
1217                // Write the word
1218                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
1219
1220                // Write the vector components
1221                let vector = input_embeddings.row(i);
1222                for j in 0..self.config.vector_size {
1223                    write!(&mut file, "{:.6} ", vector[j])
1224                        .map_err(|e| TextError::IoError(e.to_string()))?;
1225                }
1226
1227                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
1228            }
1229        }
1230
1231        Ok(())
1232    }
1233
1234    /// Load a Word2Vec model from a file
1235    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
1236        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
1237        let mut reader = BufReader::new(file);
1238
1239        // Read header
1240        let mut header = String::new();
1241        reader
1242            .read_line(&mut header)
1243            .map_err(|e| TextError::IoError(e.to_string()))?;
1244
1245        let parts: Vec<&str> = header.split_whitespace().collect();
1246        if parts.len() != 2 {
1247            return Err(TextError::EmbeddingError(
1248                "Invalid model file format".into(),
1249            ));
1250        }
1251
1252        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
1253            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
1254        })?;
1255
1256        let vector_size = parts[1]
1257            .parse::<usize>()
1258            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
1259
1260        // Initialize model
1261        let mut model = Self::new().with_vector_size(vector_size);
1262        let mut vocabulary = Vocabulary::new();
1263        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
1264
1265        // Read each word and its vector
1266        let mut i = 0;
1267        for line in reader.lines() {
1268            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
1269            let parts: Vec<&str> = line.split_whitespace().collect();
1270
1271            if parts.len() != vector_size + 1 {
1272                let line_num = i + 2;
1273                return Err(TextError::EmbeddingError(format!(
1274                    "Invalid vector format at line {line_num}"
1275                )));
1276            }
1277
1278            let word = parts[0];
1279            vocabulary.add_token(word);
1280
1281            for j in 0..vector_size {
1282                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
1283                    TextError::EmbeddingError(format!(
1284                        "Invalid vector component at line {}, position {}",
1285                        i + 2,
1286                        j + 1
1287                    ))
1288                })?;
1289            }
1290
1291            i += 1;
1292        }
1293
1294        if i != vocab_size {
1295            return Err(TextError::EmbeddingError(format!(
1296                "Expected {vocab_size} words but found {i}"
1297            )));
1298        }
1299
1300        model.vocabulary = vocabulary;
1301        model.input_embeddings = Some(input_embeddings);
1302        model.output_embeddings = None; // Only input embeddings are saved
1303
1304        Ok(model)
1305    }
1306
1307    // Getter methods for model registry serialization
1308
1309    /// Get the vocabulary as a vector of strings
1310    pub fn get_vocabulary(&self) -> Vec<String> {
1311        let mut vocab = Vec::new();
1312        for i in 0..self.vocabulary.len() {
1313            if let Some(token) = self.vocabulary.get_token(i) {
1314                vocab.push(token.to_string());
1315            }
1316        }
1317        vocab
1318    }
1319
1320    /// Get the vector size
1321    pub fn get_vector_size(&self) -> usize {
1322        self.config.vector_size
1323    }
1324
1325    /// Get the algorithm
1326    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
1327        self.config.algorithm
1328    }
1329
1330    /// Get the window size
1331    pub fn get_window_size(&self) -> usize {
1332        self.config.window_size
1333    }
1334
1335    /// Get the minimum count
1336    pub fn get_min_count(&self) -> usize {
1337        self.config.min_count
1338    }
1339
1340    /// Get the embeddings matrix (input embeddings)
1341    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
1342        self.input_embeddings.clone()
1343    }
1344
1345    /// Get the number of negative samples
1346    pub fn get_negative_samples(&self) -> usize {
1347        self.config.negative_samples
1348    }
1349
1350    /// Get the learning rate
1351    pub fn get_learning_rate(&self) -> f64 {
1352        self.config.learning_rate
1353    }
1354
1355    /// Get the number of epochs
1356    pub fn get_epochs(&self) -> usize {
1357        self.config.epochs
1358    }
1359
1360    /// Get the subsampling threshold
1361    pub fn get_subsampling_threshold(&self) -> f64 {
1362        self.config.subsample
1363    }
1364
1365    /// Check if hierarchical softmax is enabled
1366    pub fn uses_hierarchical_softmax(&self) -> bool {
1367        self.config.hierarchical_softmax
1368    }
1369
1370    // ─── Hierarchical Softmax Training ──────────────────────────────────
1371
1372    /// Train Skip-gram with hierarchical softmax on a single sentence
1373    fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1374        if sentence.len() < 2 {
1375            return Ok(());
1376        }
1377
1378        let input_embeddings = self
1379            .input_embeddings
1380            .as_mut()
1381            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1382        let hs_params = self
1383            .hs_params
1384            .as_mut()
1385            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1386        let tree = self
1387            .huffman_tree
1388            .as_ref()
1389            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1390
1391        let vector_size = self.config.vector_size;
1392        let window_size = self.config.window_size;
1393        let lr = self.current_learning_rate;
1394
1395        let codes = tree.codes.clone();
1396        let paths = tree.paths.clone();
1397
1398        let mut rng = scirs2_core::random::rng();
1399
1400        for pos in 0..sentence.len() {
1401            let window = 1 + rng.random_range(0..window_size);
1402            let target_word = sentence[pos];
1403
1404            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1405                if i == pos {
1406                    continue;
1407                }
1408
1409                let context_word = sentence[i];
1410                let code = &codes[context_word];
1411                let path = &paths[context_word];
1412
1413                let mut grad_input = Array1::zeros(vector_size);
1414
1415                // Walk the Huffman tree path for the context word
1416                for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1417                    if node_idx >= hs_params.nrows() {
1418                        continue;
1419                    }
1420
1421                    // Compute sigmoid(input_vec . hs_param)
1422                    let input_vec = input_embeddings.row(target_word);
1423                    let param_vec = hs_params.row(node_idx);
1424
1425                    let dot: f64 = input_vec
1426                        .iter()
1427                        .zip(param_vec.iter())
1428                        .map(|(a, b)| a * b)
1429                        .sum();
1430                    let sigmoid = 1.0 / (1.0 + (-dot).exp());
1431
1432                    // gradient: (1 - label - sigmoid) * lr
1433                    let target = if label == 0 { 1.0 } else { 0.0 };
1434                    let gradient = (target - sigmoid) * lr;
1435
1436                    // Update gradient for input vector
1437                    grad_input.scaled_add(gradient, &param_vec.to_owned());
1438
1439                    // Update HS parameter vector
1440                    let input_owned = input_vec.to_owned();
1441                    let mut param_mut = hs_params.row_mut(node_idx);
1442                    param_mut.scaled_add(gradient, &input_owned);
1443                }
1444
1445                // Update input embedding
1446                let mut input_mut = input_embeddings.row_mut(target_word);
1447                input_mut += &grad_input;
1448            }
1449        }
1450
1451        Ok(())
1452    }
1453
1454    /// Train CBOW with hierarchical softmax on a single sentence
1455    fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1456        if sentence.len() < 2 {
1457            return Ok(());
1458        }
1459
1460        let input_embeddings = self
1461            .input_embeddings
1462            .as_mut()
1463            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1464        let hs_params = self
1465            .hs_params
1466            .as_mut()
1467            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1468        let tree = self
1469            .huffman_tree
1470            .as_ref()
1471            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1472
1473        let vector_size = self.config.vector_size;
1474        let window_size = self.config.window_size;
1475        let lr = self.current_learning_rate;
1476
1477        let codes = tree.codes.clone();
1478        let paths = tree.paths.clone();
1479
1480        let mut rng = scirs2_core::random::rng();
1481
1482        for pos in 0..sentence.len() {
1483            let window = 1 + rng.random_range(0..window_size);
1484            let target_word = sentence[pos];
1485
1486            // Collect context words
1487            let mut context_words = Vec::new();
1488            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1489                if i != pos {
1490                    context_words.push(sentence[i]);
1491                }
1492            }
1493
1494            if context_words.is_empty() {
1495                continue;
1496            }
1497
1498            // Average context word vectors
1499            let mut context_avg = Array1::zeros(vector_size);
1500            for &ctx_idx in &context_words {
1501                context_avg += &input_embeddings.row(ctx_idx);
1502            }
1503            context_avg /= context_words.len() as f64;
1504
1505            // Walk Huffman path for target word
1506            let code = &codes[target_word];
1507            let path = &paths[target_word];
1508
1509            let mut grad_context = Array1::zeros(vector_size);
1510
1511            for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1512                if node_idx >= hs_params.nrows() {
1513                    continue;
1514                }
1515
1516                let param_vec = hs_params.row(node_idx);
1517
1518                let dot: f64 = context_avg
1519                    .iter()
1520                    .zip(param_vec.iter())
1521                    .map(|(a, b)| a * b)
1522                    .sum();
1523                let sigmoid = 1.0 / (1.0 + (-dot).exp());
1524
1525                let target = if label == 0 { 1.0 } else { 0.0 };
1526                let gradient = (target - sigmoid) * lr;
1527
1528                grad_context.scaled_add(gradient, &param_vec.to_owned());
1529
1530                // Update HS parameter
1531                let ctx_owned = context_avg.clone();
1532                let mut param_mut = hs_params.row_mut(node_idx);
1533                param_mut.scaled_add(gradient, &ctx_owned);
1534            }
1535
1536            // Distribute gradient back to context word input embeddings
1537            let grad_per_word = &grad_context / context_words.len() as f64;
1538            for &ctx_idx in &context_words {
1539                let mut input_mut = input_embeddings.row_mut(ctx_idx);
1540                input_mut += &grad_per_word;
1541            }
1542        }
1543
1544        Ok(())
1545    }
1546}
1547
1548// ─── WordEmbedding trait implementation for Word2Vec ─────────────────────────
1549
1550impl WordEmbedding for Word2Vec {
1551    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
1552        self.get_word_vector(word)
1553    }
1554
1555    fn dimension(&self) -> usize {
1556        self.vector_size()
1557    }
1558
1559    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1560        self.most_similar(word, top_n)
1561    }
1562
1563    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1564        self.analogy(a, b, c, top_n)
1565    }
1566
1567    fn vocab_size(&self) -> usize {
1568        self.vocabulary.len()
1569    }
1570}
1571
1572/// Calculate cosine similarity between two vectors
1573#[allow(dead_code)]
1574pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1575    let dot_product = (a * b).sum();
1576    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1577    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1578
1579    if norm_a > 0.0 && norm_b > 0.0 {
1580        dot_product / (norm_a * norm_b)
1581    } else {
1582        0.0
1583    }
1584}
1585
1586#[cfg(test)]
1587mod tests {
1588    use super::*;
1589    use approx::assert_relative_eq;
1590
1591    #[test]
1592    fn test_cosine_similarity() {
1593        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1594        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
1595
1596        let similarity = cosine_similarity(&a, &b);
1597        let expected = 0.9746318461970762;
1598        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
1599    }
1600
1601    #[test]
1602    fn test_word2vec_config() {
1603        let config = Word2VecConfig::default();
1604        assert_eq!(config.vector_size, 100);
1605        assert_eq!(config.window_size, 5);
1606        assert_eq!(config.min_count, 5);
1607        assert_eq!(config.epochs, 5);
1608        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
1609    }
1610
1611    #[test]
1612    fn test_word2vec_builder() {
1613        let model = Word2Vec::new()
1614            .with_vector_size(200)
1615            .with_window_size(10)
1616            .with_learning_rate(0.05)
1617            .with_algorithm(Word2VecAlgorithm::CBOW);
1618
1619        assert_eq!(model.config.vector_size, 200);
1620        assert_eq!(model.config.window_size, 10);
1621        assert_eq!(model.config.learning_rate, 0.05);
1622        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
1623    }
1624
1625    #[test]
1626    fn test_build_vocabulary() {
1627        let texts = [
1628            "the quick brown fox jumps over the lazy dog",
1629            "a quick brown fox jumps over a lazy dog",
1630        ];
1631
1632        let mut model = Word2Vec::new().with_min_count(1);
1633        let result = model.build_vocabulary(&texts);
1634        assert!(result.is_ok());
1635
1636        // Check vocabulary size (unique words: "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "a")
1637        assert_eq!(model.vocabulary.len(), 9);
1638
1639        // Check that embeddings were initialized
1640        assert!(model.input_embeddings.is_some());
1641        assert!(model.output_embeddings.is_some());
1642        assert_eq!(
1643            model
1644                .input_embeddings
1645                .as_ref()
1646                .expect("Operation failed")
1647                .shape(),
1648            &[9, 100]
1649        );
1650    }
1651
1652    #[test]
1653    fn test_skipgram_training_small() {
1654        let texts = [
1655            "the quick brown fox jumps over the lazy dog",
1656            "a quick brown fox jumps over a lazy dog",
1657        ];
1658
1659        let mut model = Word2Vec::new()
1660            .with_vector_size(10)
1661            .with_window_size(2)
1662            .with_min_count(1)
1663            .with_epochs(1)
1664            .with_algorithm(Word2VecAlgorithm::SkipGram);
1665
1666        let result = model.train(&texts);
1667        assert!(result.is_ok());
1668
1669        // Test getting a word vector
1670        let result = model.get_word_vector("fox");
1671        assert!(result.is_ok());
1672        let vec = result.expect("Operation failed");
1673        assert_eq!(vec.len(), 10);
1674    }
1675
1676    // ─── Hierarchical Softmax Tests ──────────────────────────────────
1677
1678    #[test]
1679    fn test_huffman_tree_build() {
1680        let frequencies = vec![5, 3, 8, 1, 2];
1681        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1682
1683        // Each word should have a code
1684        assert_eq!(tree.codes.len(), 5);
1685        assert_eq!(tree.paths.len(), 5);
1686
1687        // All codes should be non-empty
1688        for code in &tree.codes {
1689            assert!(!code.is_empty());
1690        }
1691
1692        // Internal nodes = vocab_size - 1 (for a binary tree)
1693        assert_eq!(tree.num_internal, 4);
1694    }
1695
1696    #[test]
1697    fn test_huffman_tree_single_word() {
1698        let frequencies = vec![10];
1699        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1700        assert_eq!(tree.codes.len(), 1);
1701        assert_eq!(tree.paths.len(), 1);
1702    }
1703
1704    #[test]
1705    fn test_skipgram_hierarchical_softmax() {
1706        let texts = [
1707            "the quick brown fox jumps over the lazy dog",
1708            "a quick brown fox jumps over a lazy dog",
1709        ];
1710
1711        let config = Word2VecConfig {
1712            vector_size: 10,
1713            window_size: 2,
1714            min_count: 1,
1715            epochs: 3,
1716            learning_rate: 0.025,
1717            algorithm: Word2VecAlgorithm::SkipGram,
1718            hierarchical_softmax: true,
1719            ..Default::default()
1720        };
1721
1722        let mut model = Word2Vec::with_config(config);
1723        let result = model.train(&texts);
1724        assert!(
1725            result.is_ok(),
1726            "HS skipgram training failed: {:?}",
1727            result.err()
1728        );
1729
1730        assert!(model.uses_hierarchical_softmax());
1731
1732        // Should produce valid word vectors
1733        let vec = model.get_word_vector("fox");
1734        assert!(vec.is_ok());
1735        assert_eq!(vec.expect("get vec").len(), 10);
1736    }
1737
1738    #[test]
1739    fn test_cbow_hierarchical_softmax() {
1740        let texts = [
1741            "the quick brown fox jumps over the lazy dog",
1742            "a quick brown fox jumps over a lazy dog",
1743        ];
1744
1745        let config = Word2VecConfig {
1746            vector_size: 10,
1747            window_size: 2,
1748            min_count: 1,
1749            epochs: 3,
1750            learning_rate: 0.025,
1751            algorithm: Word2VecAlgorithm::CBOW,
1752            hierarchical_softmax: true,
1753            ..Default::default()
1754        };
1755
1756        let mut model = Word2Vec::with_config(config);
1757        let result = model.train(&texts);
1758        assert!(
1759            result.is_ok(),
1760            "HS CBOW training failed: {:?}",
1761            result.err()
1762        );
1763
1764        let vec = model.get_word_vector("dog");
1765        assert!(vec.is_ok());
1766    }
1767
1768    // ─── WordEmbedding Trait Tests ──────────────────────────────────
1769
1770    #[test]
1771    fn test_word_embedding_trait_word2vec() {
1772        let texts = [
1773            "the quick brown fox jumps over the lazy dog",
1774            "a quick brown fox jumps over a lazy dog",
1775        ];
1776
1777        let mut model = Word2Vec::new()
1778            .with_vector_size(10)
1779            .with_min_count(1)
1780            .with_epochs(1);
1781
1782        model.train(&texts).expect("Training failed");
1783
1784        // Use via trait
1785        let emb: &dyn WordEmbedding = &model;
1786        assert_eq!(emb.dimension(), 10);
1787        assert!(emb.vocab_size() > 0);
1788
1789        let vec = emb.embedding("fox");
1790        assert!(vec.is_ok());
1791
1792        let sim = emb.similarity("fox", "dog");
1793        assert!(sim.is_ok());
1794        assert!(sim.expect("sim").is_finite());
1795
1796        let similar = emb.find_similar("fox", 2);
1797        assert!(similar.is_ok());
1798
1799        let analogy = emb.solve_analogy("the", "fox", "dog", 2);
1800        assert!(analogy.is_ok());
1801    }
1802
1803    #[test]
1804    fn test_embedding_cosine_similarity_fn() {
1805        let a = Array1::from_vec(vec![1.0, 0.0]);
1806        let b = Array1::from_vec(vec![0.0, 1.0]);
1807        assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
1808
1809        let c = Array1::from_vec(vec![1.0, 1.0]);
1810        let d = Array1::from_vec(vec![1.0, 1.0]);
1811        assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
1812    }
1813
1814    #[test]
1815    fn test_pairwise_similarity_fn() {
1816        let texts = ["the quick brown fox", "the lazy brown dog"];
1817
1818        let mut model = Word2Vec::new()
1819            .with_vector_size(10)
1820            .with_min_count(1)
1821            .with_epochs(1);
1822        model.train(&texts).expect("Training failed");
1823
1824        let words = vec!["the", "fox", "dog"];
1825        let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");
1826
1827        assert_eq!(matrix.len(), 3);
1828        assert_eq!(matrix[0].len(), 3);
1829
1830        // Diagonal should be 1.0
1831        for i in 0..3 {
1832            assert!((matrix[i][i] - 1.0).abs() < 1e-6);
1833        }
1834
1835        // Symmetric
1836        for i in 0..3 {
1837            for j in 0..3 {
1838                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
1839            }
1840        }
1841    }
1842}