Skip to main content

scirs2_text/embeddings/
mod.rs

1//! # Word Embeddings Module
2//!
3//! This module provides comprehensive implementations for word embeddings, including
4//! Word2Vec, GloVe, and FastText.
5//!
6//! ## Overview
7//!
8//! Word embeddings are dense vector representations of words that capture semantic
9//! relationships. This module implements:
10//!
11//! - **Word2Vec Skip-gram**: Predicts context words from a target word
12//! - **Word2Vec CBOW**: Predicts a target word from context words
13//! - **GloVe**: Global Vectors for Word Representation
14//! - **FastText**: Character n-gram based embeddings (handles OOV words)
15//! - **Negative sampling**: Efficient training technique for large vocabularies
//! - **Hierarchical softmax**: Huffman-tree-based alternative to negative sampling (enable with `hierarchical_softmax: true`)
17//!
18//! ## Quick Start
19//!
20//! ```rust
21//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
22//!
23//! // Basic Word2Vec training
24//! let documents = vec![
25//!     "the quick brown fox jumps over the lazy dog",
26//!     "the dog was lazy but the fox was quick",
27//!     "brown fox and lazy dog are common phrases"
28//! ];
29//!
30//! let config = Word2VecConfig {
31//!     algorithm: Word2VecAlgorithm::SkipGram,
32//!     vector_size: 100,
33//!     window_size: 5,
34//!     min_count: 1,
35//!     learning_rate: 0.025,
36//!     epochs: 5,
37//!     negative_samples: 5,
38//!     ..Default::default()
39//! };
40//!
41//! let mut model = Word2Vec::with_config(config);
42//! model.train(&documents).expect("Training failed");
43//!
44//! // Get word vector
45//! if let Ok(vector) = model.get_word_vector("fox") {
46//!     println!("Vector for 'fox': {:?}", vector);
47//! }
48//!
49//! // Find similar words
50//! let similar = model.most_similar("fox", 3).expect("Failed to find similar words");
51//! for (word, similarity) in similar {
52//!     println!("{}: {:.4}", word, similarity);
53//! }
54//! ```
55//!
56//! ## Advanced Usage
57//!
58//! ### Custom Configuration
59//!
60//! ```rust
61//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
62//!
63//! let config = Word2VecConfig {
64//!     algorithm: Word2VecAlgorithm::CBOW,
65//!     vector_size: 300,        // Larger vectors for better quality
66//!     window_size: 10,         // Larger context window
67//!     min_count: 5,           // Filter rare words
68//!     learning_rate: 0.01,    // Lower learning rate for stability
69//!     epochs: 15,             // More training iterations
70//!     negative_samples: 10,   // More negative samples
71//!     subsample: 1e-4, // Subsample frequent words
72//!     batch_size: 128,
73//!     hierarchical_softmax: false,
74//! };
75//! ```
76//!
77//! ### Incremental Training
78//!
79//! ```rust
80//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
81//! # let mut model = Word2Vec::new().with_min_count(1);
82//! // Initial training
83//! let batch1 = vec!["the quick brown fox jumps over the lazy dog"];
84//! model.train(&batch1).expect("Training failed");
85//!
86//! // Continue training with new data
87//! let batch2 = vec!["the dog was lazy but the fox was quick"];
88//! model.train(&batch2).expect("Training failed");
89//! ```
90//!
91//! ### Saving and Loading Models
92//!
93//! ```rust
94//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
95//! # let mut model = Word2Vec::new().with_min_count(1);
96//! # let texts = vec!["the quick brown fox jumps over the lazy dog"];
97//! # model.train(&texts).expect("Training failed");
98//! // Save trained model
99//! model.save("my_model.w2v").expect("Failed to save model");
100//!
101//! // Load model
102//! let loaded_model = Word2Vec::load("my_model.w2v")
103//!     .expect("Failed to load model");
104//! ```
105//!
106//! ## Performance Tips
107//!
108//! 1. **Vocabulary Size**: Use `min_count` to filter rare words and reduce memory usage
109//! 2. **Vector Dimensions**: Balance between quality (higher dimensions) and speed (lower dimensions)
110//! 3. **Training Algorithm**: Skip-gram works better with rare words, CBOW is faster
111//! 4. **Negative Sampling**: Usually faster than hierarchical softmax for large vocabularies
//! 5. **Subsampling**: Set `subsample` to randomly drop very frequent words (`0.0` disables subsampling)
113//!
114//! ## Mathematical Background
115//!
116//! ### Skip-gram Model
117//!
118//! The Skip-gram model maximizes the probability of context words given a target word:
119//!
120//! P(context|target) = ∏ P(w_context|w_target)
121//!
122//! ### CBOW Model
123//!
124//! The CBOW model predicts a target word from its context:
125//!
126//! P(target|context) = P(w_target|w_context1, w_context2, ...)
127//!
128//! ### Negative Sampling
129//!
130//! Instead of computing the full softmax, negative sampling approximates it by
131//! sampling negative examples, making training much more efficient.
132
133// Sub-modules
134pub mod contrastive;
135pub mod crosslingual;
136pub mod fasttext;
137pub mod glove;
138pub mod sentence;
139pub mod sentence_encoder;
140pub mod universal;
141
142// Re-export
143pub use fasttext::{FastText, FastTextConfig};
144pub use glove::{
145    cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
146    GloVeTrainerConfig,
147};
148
149use crate::error::{Result, TextError};
150use crate::tokenize::{Tokenizer, WordTokenizer};
151use crate::vocabulary::Vocabulary;
152use scirs2_core::ndarray::{Array1, Array2};
153use scirs2_core::random::prelude::*;
154use std::collections::HashMap;
155use std::fmt::Debug;
156use std::fs::File;
157use std::io::{BufRead, BufReader, Write};
158use std::path::Path;
159
160// ─── WordEmbedding trait ─────────────────────────────────────────────────────
161
162/// Shared trait for word embedding models
163///
164/// Provides a common interface for querying word vectors, computing similarity,
165/// and solving analogies across different embedding implementations
166/// (Word2Vec, GloVe, FastText).
167pub trait WordEmbedding {
168    /// Get the embedding vector for a word
169    fn embedding(&self, word: &str) -> Result<Array1<f64>>;
170
171    /// Get the dimensionality of the embedding vectors
172    fn dimension(&self) -> usize;
173
174    /// Compute cosine similarity between two words
175    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
176        let v1 = self.embedding(word1)?;
177        let v2 = self.embedding(word2)?;
178        Ok(embedding_cosine_similarity(&v1, &v2))
179    }
180
181    /// Find the top-N most similar words to the query word
182    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
183
184    /// Solve the analogy: a is to b as c is to ?
185    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
186
187    /// Get the number of words in the vocabulary
188    fn vocab_size(&self) -> usize;
189}
190
191/// Calculate cosine similarity between two embedding vectors
192pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
193    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
194    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
195    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
196
197    if norm_a > 0.0 && norm_b > 0.0 {
198        dot_product / (norm_a * norm_b)
199    } else {
200        0.0
201    }
202}
203
204/// Compute pairwise cosine similarity matrix for a list of words
205pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
206    let vectors: Vec<Array1<f64>> = words
207        .iter()
208        .map(|&w| model.embedding(w))
209        .collect::<Result<Vec<_>>>()?;
210
211    let n = vectors.len();
212    let mut matrix = vec![vec![0.0; n]; n];
213
214    for i in 0..n {
215        for j in i..n {
216            let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
217            matrix[i][j] = sim;
218            matrix[j][i] = sim;
219        }
220    }
221
222    Ok(matrix)
223}
224
225// ─── WordEmbedding implementations ──────────────────────────────────────────
226
227impl WordEmbedding for GloVe {
228    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
229        self.get_word_vector(word)
230    }
231
232    fn dimension(&self) -> usize {
233        self.vector_size()
234    }
235
236    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
237        self.most_similar(word, top_n)
238    }
239
240    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
241        self.analogy(a, b, c, top_n)
242    }
243
244    fn vocab_size(&self) -> usize {
245        self.vocabulary_size()
246    }
247}
248
249impl WordEmbedding for FastText {
250    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
251        self.get_word_vector(word)
252    }
253
254    fn dimension(&self) -> usize {
255        self.vector_size()
256    }
257
258    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
259        self.most_similar(word, top_n)
260    }
261
262    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
263        self.analogy(a, b, c, top_n)
264    }
265
266    fn vocab_size(&self) -> usize {
267        self.vocabulary_size()
268    }
269}
270
271// ─── Huffman Tree for Hierarchical Softmax ──────────────────────────────────
272
273/// A node in the Huffman tree used for hierarchical softmax
274#[derive(Debug, Clone)]
275struct HuffmanNode {
276    /// Index in vocabulary (for leaf nodes) or internal node id
277    id: usize,
278    /// Frequency (for building the tree)
279    frequency: usize,
280    /// Left child index in the nodes array
281    left: Option<usize>,
282    /// Right child index in the nodes array
283    right: Option<usize>,
284    /// Is this a leaf node?
285    is_leaf: bool,
286}
287
288/// Huffman tree for hierarchical softmax training
289#[derive(Debug, Clone)]
290struct HuffmanTree {
291    /// Huffman codes for each vocabulary word: sequence of 0/1 decisions
292    codes: Vec<Vec<u8>>,
293    /// Path of internal node indices for each vocabulary word
294    paths: Vec<Vec<usize>>,
295    /// Number of internal nodes
296    num_internal: usize,
297}
298
299impl HuffmanTree {
300    /// Build a Huffman tree from word frequencies
301    ///
302    /// Returns codes and paths for each word in the vocabulary.
303    fn build(frequencies: &[usize]) -> Result<Self> {
304        let vocab_size = frequencies.len();
305        if vocab_size == 0 {
306            return Err(TextError::EmbeddingError(
307                "Cannot build Huffman tree with empty vocabulary".into(),
308            ));
309        }
310        if vocab_size == 1 {
311            // Special case: single word
312            return Ok(Self {
313                codes: vec![vec![0]],
314                paths: vec![vec![0]],
315                num_internal: 1,
316            });
317        }
318
319        // Create leaf nodes
320        let mut nodes: Vec<HuffmanNode> = frequencies
321            .iter()
322            .enumerate()
323            .map(|(id, &freq)| HuffmanNode {
324                id,
325                frequency: freq.max(1), // Avoid zero frequency
326                left: None,
327                right: None,
328                is_leaf: true,
329            })
330            .collect();
331
332        // Priority queue simulation using a sorted list
333        // (index_in_nodes, frequency)
334        let mut queue: Vec<(usize, usize)> = nodes
335            .iter()
336            .enumerate()
337            .map(|(i, n)| (i, n.frequency))
338            .collect();
339        queue.sort_by_key(|item| std::cmp::Reverse(item.1)); // Sort descending (pop from end = min)
340
341        // Build the tree bottom-up
342        while queue.len() > 1 {
343            // Pop two smallest
344            let (idx1, freq1) = queue
345                .pop()
346                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
347            let (idx2, freq2) = queue
348                .pop()
349                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
350
351            let new_id = nodes.len();
352            let new_node = HuffmanNode {
353                id: new_id,
354                frequency: freq1 + freq2,
355                left: Some(idx1),
356                right: Some(idx2),
357                is_leaf: false,
358            };
359            nodes.push(new_node);
360
361            // Insert new node maintaining sorted order
362            let new_freq = freq1 + freq2;
363            let insert_pos = queue
364                .binary_search_by(|(_, f)| new_freq.cmp(f))
365                .unwrap_or_else(|pos| pos);
366            queue.insert(insert_pos, (new_id, new_freq));
367        }
368
369        // Traverse tree to assign codes and paths
370        let num_internal = nodes.len() - vocab_size;
371        let mut codes = vec![Vec::new(); vocab_size];
372        let mut paths = vec![Vec::new(); vocab_size];
373
374        // DFS traversal
375        let root_idx = nodes.len() - 1;
376        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];
377
378        while let Some((node_idx, code, path)) = stack.pop() {
379            let node = &nodes[node_idx];
380
381            if node.is_leaf {
382                codes[node.id] = code;
383                paths[node.id] = path;
384            } else {
385                // Internal node index (0-based among internal nodes)
386                let internal_idx = node.id - vocab_size;
387
388                if let Some(left_idx) = node.left {
389                    let mut left_code = code.clone();
390                    left_code.push(0);
391                    let mut left_path = path.clone();
392                    left_path.push(internal_idx);
393                    stack.push((left_idx, left_code, left_path));
394                }
395
396                if let Some(right_idx) = node.right {
397                    let mut right_code = code.clone();
398                    right_code.push(1);
399                    let mut right_path = path.clone();
400                    right_path.push(internal_idx);
401                    stack.push((right_idx, right_code, right_path));
402                }
403            }
404        }
405
406        Ok(Self {
407            codes,
408            paths,
409            num_internal,
410        })
411    }
412}
413
414/// A simplified weighted sampling table
415#[derive(Debug, Clone)]
416struct SamplingTable {
417    /// The cumulative distribution function (CDF)
418    cdf: Vec<f64>,
419    /// The weights
420    weights: Vec<f64>,
421}
422
423impl SamplingTable {
424    /// Create a new sampling table from weights
425    fn new(weights: &[f64]) -> Result<Self> {
426        if weights.is_empty() {
427            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
428        }
429
430        // Ensure all _weights are positive
431        if weights.iter().any(|&w| w < 0.0) {
432            return Err(TextError::EmbeddingError("Weights must be positive".into()));
433        }
434
435        // Compute the CDF
436        let sum: f64 = weights.iter().sum();
437        if sum <= 0.0 {
438            return Err(TextError::EmbeddingError(
439                "Sum of _weights must be positive".into(),
440            ));
441        }
442
443        let mut cdf = Vec::with_capacity(weights.len());
444        let mut total = 0.0;
445
446        for &w in weights {
447            total += w;
448            cdf.push(total / sum);
449        }
450
451        Ok(Self {
452            cdf,
453            weights: weights.to_vec(),
454        })
455    }
456
457    /// Sample an index based on the weights
458    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
459        let r = rng.random::<f64>();
460
461        // Binary search for the insertion point
462        match self.cdf.binary_search_by(|&cdf_val| {
463            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
464        }) {
465            Ok(idx) => idx,
466            Err(idx) => idx,
467        }
468    }
469
470    /// Get the weights
471    fn weights(&self) -> &[f64] {
472        &self.weights
473    }
474}
475
476/// Word2Vec training algorithms
477#[derive(Debug, Clone, Copy, PartialEq, Eq)]
478pub enum Word2VecAlgorithm {
479    /// Continuous Bag of Words (CBOW) algorithm
480    CBOW,
481    /// Skip-gram algorithm
482    SkipGram,
483}
484
485/// Word2Vec training options
486#[derive(Debug, Clone)]
487pub struct Word2VecConfig {
488    /// Size of the word vectors
489    pub vector_size: usize,
490    /// Maximum distance between the current and predicted word within a sentence
491    pub window_size: usize,
492    /// Minimum count of words to consider for training
493    pub min_count: usize,
494    /// Number of iterations (epochs) over the corpus
495    pub epochs: usize,
496    /// Learning rate
497    pub learning_rate: f64,
498    /// Skip-gram or CBOW algorithm
499    pub algorithm: Word2VecAlgorithm,
500    /// Number of negative samples per positive sample
501    pub negative_samples: usize,
502    /// Threshold for subsampling frequent words
503    pub subsample: f64,
504    /// Batch size for training
505    pub batch_size: usize,
506    /// Whether to use hierarchical softmax instead of negative sampling
507    pub hierarchical_softmax: bool,
508}
509
510impl Default for Word2VecConfig {
511    fn default() -> Self {
512        Self {
513            vector_size: 100,
514            window_size: 5,
515            min_count: 5,
516            epochs: 5,
517            learning_rate: 0.025,
518            algorithm: Word2VecAlgorithm::SkipGram,
519            negative_samples: 5,
520            subsample: 1e-3,
521            batch_size: 128,
522            hierarchical_softmax: false,
523        }
524    }
525}
526
527/// Word2Vec model for training and using word embeddings
528///
529/// Word2Vec is an algorithm for learning vector representations of words,
530/// also known as word embeddings. These vectors capture semantic meanings
531/// of words, allowing operations like "king - man + woman" to result in
532/// a vector close to "queen".
533///
534/// This implementation supports both Continuous Bag of Words (CBOW) and
535/// Skip-gram models, with negative sampling for efficient training.
536pub struct Word2Vec {
537    /// Configuration options
538    config: Word2VecConfig,
539    /// Vocabulary
540    vocabulary: Vocabulary,
541    /// Input embeddings
542    input_embeddings: Option<Array2<f64>>,
543    /// Output embeddings
544    output_embeddings: Option<Array2<f64>>,
545    /// Tokenizer
546    tokenizer: Box<dyn Tokenizer + Send + Sync>,
547    /// Sampling table for negative sampling
548    sampling_table: Option<SamplingTable>,
549    /// Huffman tree for hierarchical softmax
550    huffman_tree: Option<HuffmanTree>,
551    /// Hierarchical softmax parameter vectors (one per internal node)
552    hs_params: Option<Array2<f64>>,
553    /// Current learning rate (gets updated during training)
554    current_learning_rate: f64,
555}
556
557impl Debug for Word2Vec {
558    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559        f.debug_struct("Word2Vec")
560            .field("config", &self.config)
561            .field("vocabulary", &self.vocabulary)
562            .field("input_embeddings", &self.input_embeddings)
563            .field("output_embeddings", &self.output_embeddings)
564            .field("sampling_table", &self.sampling_table)
565            .field("huffman_tree", &self.huffman_tree)
566            .field("current_learning_rate", &self.current_learning_rate)
567            .finish()
568    }
569}
570
571// Manual Clone implementation to handle the non-Clone tokenizer
572impl Default for Word2Vec {
573    fn default() -> Self {
574        Self::new()
575    }
576}
577
578impl Clone for Word2Vec {
579    fn clone(&self) -> Self {
580        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
581
582        Self {
583            config: self.config.clone(),
584            vocabulary: self.vocabulary.clone(),
585            input_embeddings: self.input_embeddings.clone(),
586            output_embeddings: self.output_embeddings.clone(),
587            tokenizer,
588            sampling_table: self.sampling_table.clone(),
589            huffman_tree: self.huffman_tree.clone(),
590            hs_params: self.hs_params.clone(),
591            current_learning_rate: self.current_learning_rate,
592        }
593    }
594}
595
596impl Word2Vec {
597    /// Create a new Word2Vec model with default configuration
598    pub fn new() -> Self {
599        Self {
600            config: Word2VecConfig::default(),
601            vocabulary: Vocabulary::new(),
602            input_embeddings: None,
603            output_embeddings: None,
604            tokenizer: Box::new(WordTokenizer::default()),
605            sampling_table: None,
606            huffman_tree: None,
607            hs_params: None,
608            current_learning_rate: 0.025,
609        }
610    }
611
612    /// Create a new Word2Vec model with the specified configuration
613    pub fn with_config(config: Word2VecConfig) -> Self {
614        let learning_rate = config.learning_rate;
615        Self {
616            config,
617            vocabulary: Vocabulary::new(),
618            input_embeddings: None,
619            output_embeddings: None,
620            tokenizer: Box::new(WordTokenizer::default()),
621            sampling_table: None,
622            huffman_tree: None,
623            hs_params: None,
624            current_learning_rate: learning_rate,
625        }
626    }
627
628    /// Set a custom tokenizer
629    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
630        self.tokenizer = tokenizer;
631        self
632    }
633
634    /// Set vector size
635    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
636        self.config.vector_size = vectorsize;
637        self
638    }
639
640    /// Set window size
641    pub fn with_window_size(mut self, windowsize: usize) -> Self {
642        self.config.window_size = windowsize;
643        self
644    }
645
646    /// Set minimum count
647    pub fn with_min_count(mut self, mincount: usize) -> Self {
648        self.config.min_count = mincount;
649        self
650    }
651
652    /// Set number of epochs
653    pub fn with_epochs(mut self, epochs: usize) -> Self {
654        self.config.epochs = epochs;
655        self
656    }
657
658    /// Set learning rate
659    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
660        self.config.learning_rate = learningrate;
661        self.current_learning_rate = learningrate;
662        self
663    }
664
665    /// Set algorithm (CBOW or Skip-gram)
666    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
667        self.config.algorithm = algorithm;
668        self
669    }
670
671    /// Set number of negative samples
672    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
673        self.config.negative_samples = negativesamples;
674        self
675    }
676
677    /// Set subsampling threshold
678    pub fn with_subsample(mut self, subsample: f64) -> Self {
679        self.config.subsample = subsample;
680        self
681    }
682
683    /// Set batch size
684    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
685        self.config.batch_size = batchsize;
686        self
687    }
688
689    /// Build vocabulary from a corpus
690    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
691        if texts.is_empty() {
692            return Err(TextError::InvalidInput(
693                "No texts provided for building vocabulary".into(),
694            ));
695        }
696
697        // Count word frequencies
698        let mut word_counts = HashMap::new();
699        let mut _total_words = 0;
700
701        for &text in texts {
702            let tokens = self.tokenizer.tokenize(text)?;
703            for token in tokens {
704                *word_counts.entry(token).or_insert(0) += 1;
705                _total_words += 1;
706            }
707        }
708
709        // Create vocabulary with words that meet minimum count
710        self.vocabulary = Vocabulary::new();
711        for (word, count) in &word_counts {
712            if *count >= self.config.min_count {
713                self.vocabulary.add_token(word);
714            }
715        }
716
717        if self.vocabulary.is_empty() {
718            return Err(TextError::VocabularyError(
719                "No words meet the minimum count threshold".into(),
720            ));
721        }
722
723        // Initialize embeddings
724        let vocab_size = self.vocabulary.len();
725        let vector_size = self.config.vector_size;
726
727        // Initialize input and output embeddings with small random values
728        let mut rng = scirs2_core::random::rng();
729        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
730            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
731        });
732        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
733            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
734        });
735
736        self.input_embeddings = Some(input_embeddings);
737        self.output_embeddings = Some(output_embeddings);
738
739        // Create sampling table for negative sampling
740        self.create_sampling_table(&word_counts)?;
741
742        // Build Huffman tree for hierarchical softmax if configured
743        if self.config.hierarchical_softmax {
744            let frequencies: Vec<usize> = (0..vocab_size)
745                .map(|i| {
746                    self.vocabulary
747                        .get_token(i)
748                        .and_then(|word| word_counts.get(word).copied())
749                        .unwrap_or(1)
750                })
751                .collect();
752
753            let tree = HuffmanTree::build(&frequencies)?;
754            let num_internal = tree.num_internal;
755
756            // Initialize hierarchical softmax parameter vectors
757            let hs_params = Array2::zeros((num_internal, vector_size));
758            self.hs_params = Some(hs_params);
759            self.huffman_tree = Some(tree);
760        }
761
762        Ok(())
763    }
764
765    /// Create sampling table for negative sampling based on word frequencies
766    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
767        // Prepare sampling weights (unigram distribution raised to 3/4 power)
768        let mut sampling_weights = vec![0.0; self.vocabulary.len()];
769
770        for (word, &count) in wordcounts.iter() {
771            if let Some(idx) = self.vocabulary.get_index(word) {
772                // Apply smoothing: frequency^0.75
773                sampling_weights[idx] = (count as f64).powf(0.75);
774            }
775        }
776
777        match SamplingTable::new(&sampling_weights) {
778            Ok(table) => {
779                self.sampling_table = Some(table);
780                Ok(())
781            }
782            Err(e) => Err(e),
783        }
784    }
785
786    /// Train the Word2Vec model on a corpus
787    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
788        if texts.is_empty() {
789            return Err(TextError::InvalidInput(
790                "No texts provided for training".into(),
791            ));
792        }
793
794        // Build vocabulary if not already built
795        if self.vocabulary.is_empty() {
796            self.build_vocabulary(texts)?;
797        }
798
799        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
800            return Err(TextError::EmbeddingError(
801                "Embeddings not initialized. Call build_vocabulary() first".into(),
802            ));
803        }
804
805        // Count total number of tokens for progress tracking
806        let mut _total_tokens = 0;
807        let mut sentences = Vec::new();
808        for &text in texts {
809            let tokens = self.tokenizer.tokenize(text)?;
810            let filtered_tokens: Vec<usize> = tokens
811                .iter()
812                .filter_map(|token| self.vocabulary.get_index(token))
813                .collect();
814            if !filtered_tokens.is_empty() {
815                _total_tokens += filtered_tokens.len();
816                sentences.push(filtered_tokens);
817            }
818        }
819
820        // Train for the specified number of epochs
821        for epoch in 0..self.config.epochs {
822            // Update learning rate for this epoch
823            self.current_learning_rate =
824                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
825            self.current_learning_rate = self
826                .current_learning_rate
827                .max(self.config.learning_rate * 0.0001);
828
829            // Process each sentence
830            for sentence in &sentences {
831                // Apply subsampling of frequent words
832                let subsampled_sentence = if self.config.subsample > 0.0 {
833                    self.subsample_sentence(sentence)?
834                } else {
835                    sentence.clone()
836                };
837
838                // Skip empty sentences
839                if subsampled_sentence.is_empty() {
840                    continue;
841                }
842
843                // Train on the sentence
844                if self.config.hierarchical_softmax {
845                    // Use hierarchical softmax
846                    match self.config.algorithm {
847                        Word2VecAlgorithm::SkipGram => {
848                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
849                        }
850                        Word2VecAlgorithm::CBOW => {
851                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
852                        }
853                    }
854                } else {
855                    // Use negative sampling
856                    match self.config.algorithm {
857                        Word2VecAlgorithm::CBOW => {
858                            self.train_cbow_sentence(&subsampled_sentence)?;
859                        }
860                        Word2VecAlgorithm::SkipGram => {
861                            self.train_skipgram_sentence(&subsampled_sentence)?;
862                        }
863                    }
864                }
865            }
866        }
867
868        Ok(())
869    }
870
871    /// Apply subsampling to a sentence
872    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
873        let mut rng = scirs2_core::random::rng();
874        let total_words: f64 = self.vocabulary.len() as f64;
875        let threshold = self.config.subsample * total_words;
876
877        // Filter words based on subsampling probability
878        let subsampled: Vec<usize> = sentence
879            .iter()
880            .filter(|&&word_idx| {
881                let word_freq = self.get_word_frequency(word_idx);
882                if word_freq == 0.0 {
883                    return true; // Keep rare words
884                }
885                // Probability of keeping the word
886                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
887                rng.random::<f64>() < keep_prob
888            })
889            .copied()
890            .collect();
891
892        Ok(subsampled)
893    }
894
895    /// Get the frequency of a word in the vocabulary
896    fn get_word_frequency(&self, wordidx: usize) -> f64 {
897        // This is a simplified version; ideal implementation would track actual frequencies
898        // For now, we'll use the sampling table weights as a proxy
899        if let Some(table) = &self.sampling_table {
900            table.weights()[wordidx]
901        } else {
902            1.0 // Default weight if no sampling table exists
903        }
904    }
905
906    /// Train CBOW model on a single sentence
907    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
908        if sentence.len() < 2 {
909            return Ok(()); // Need at least 2 words for context
910        }
911
912        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
913        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
914        let vector_size = self.config.vector_size;
915        let window_size = self.config.window_size;
916        let negative_samples = self.config.negative_samples;
917
918        // For each position in sentence, predict the word from its context
919        for pos in 0..sentence.len() {
920            // Determine context window (with random size)
921            let mut rng = scirs2_core::random::rng();
922            let window = 1 + rng.random_range(0..window_size);
923            let target_word = sentence[pos];
924
925            // Collect context words and average their vectors
926            let mut context_words = Vec::new();
927            #[allow(clippy::needless_range_loop)]
928            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
929                if i != pos {
930                    context_words.push(sentence[i]);
931                }
932            }
933
934            if context_words.is_empty() {
935                continue; // No context words
936            }
937
938            // Average the context word vectors
939            let mut context_sum = Array1::zeros(vector_size);
940            for &context_idx in &context_words {
941                context_sum += &input_embeddings.row(context_idx);
942            }
943            let context_avg = &context_sum / context_words.len() as f64;
944
945            // Update target word's output embedding with positive example
946            let mut target_output = output_embeddings.row_mut(target_word);
947            let dot_product = (&context_avg * &target_output).sum();
948            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
949            let error = (1.0 - sigmoid) * self.current_learning_rate;
950
951            // Create a copy for update
952            let mut target_update = target_output.to_owned();
953            target_update.scaled_add(error, &context_avg);
954            target_output.assign(&target_update);
955
956            // Negative sampling
957            if let Some(sampler) = &self.sampling_table {
958                for _ in 0..negative_samples {
959                    let negative_idx = sampler.sample(&mut rng);
960                    if negative_idx == target_word {
961                        continue; // Skip if we sample the target word
962                    }
963
964                    let mut negative_output = output_embeddings.row_mut(negative_idx);
965                    let dot_product = (&context_avg * &negative_output).sum();
966                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
967                    let error = -sigmoid * self.current_learning_rate;
968
969                    // Create a copy for update
970                    let mut negative_update = negative_output.to_owned();
971                    negative_update.scaled_add(error, &context_avg);
972                    negative_output.assign(&negative_update);
973                }
974            }
975
976            // Update context word vectors
977            for &context_idx in &context_words {
978                let mut input_vec = input_embeddings.row_mut(context_idx);
979
980                // Positive example
981                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
982                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
983                let error =
984                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
985
986                // Create a copy for update
987                let mut input_update = input_vec.to_owned();
988                input_update.scaled_add(error, &output_embeddings.row(target_word));
989
990                // Negative examples
991                if let Some(sampler) = &self.sampling_table {
992                    for _ in 0..negative_samples {
993                        let negative_idx = sampler.sample(&mut rng);
994                        if negative_idx == target_word {
995                            continue;
996                        }
997
998                        let dot_product =
999                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
1000                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1001                        let error =
1002                            -sigmoid * self.current_learning_rate / context_words.len() as f64;
1003
1004                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
1005                    }
1006                }
1007
1008                input_vec.assign(&input_update);
1009            }
1010        }
1011
1012        Ok(())
1013    }
1014
1015    /// Train Skip-gram model on a single sentence
1016    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1017        if sentence.len() < 2 {
1018            return Ok(()); // Need at least 2 words for context
1019        }
1020
1021        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
1022        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
1023        let vector_size = self.config.vector_size;
1024        let window_size = self.config.window_size;
1025        let negative_samples = self.config.negative_samples;
1026
1027        // For each position in sentence, predict the context from the word
1028        for pos in 0..sentence.len() {
1029            // Determine context window (with random size)
1030            let mut rng = scirs2_core::random::rng();
1031            let window = 1 + rng.random_range(0..window_size);
1032            let target_word = sentence[pos];
1033
1034            // For each context position
1035            #[allow(clippy::needless_range_loop)]
1036            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1037                if i == pos {
1038                    continue; // Skip the target word itself
1039                }
1040
1041                let context_word = sentence[i];
1042                let target_input = input_embeddings.row(target_word);
1043                let mut context_output = output_embeddings.row_mut(context_word);
1044
1045                // Update context word's output embedding with positive example
1046                let dot_product = (&target_input * &context_output).sum();
1047                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1048                let error = (1.0 - sigmoid) * self.current_learning_rate;
1049
1050                // Create a copy for update
1051                let mut context_update = context_output.to_owned();
1052                context_update.scaled_add(error, &target_input);
1053                context_output.assign(&context_update);
1054
1055                // Gradient for input word vector
1056                let mut input_update = Array1::zeros(vector_size);
1057                input_update.scaled_add(error, &context_output);
1058
1059                // Negative sampling
1060                if let Some(sampler) = &self.sampling_table {
1061                    for _ in 0..negative_samples {
1062                        let negative_idx = sampler.sample(&mut rng);
1063                        if negative_idx == context_word {
1064                            continue; // Skip if we sample the context word
1065                        }
1066
1067                        let mut negative_output = output_embeddings.row_mut(negative_idx);
1068                        let dot_product = (&target_input * &negative_output).sum();
1069                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1070                        let error = -sigmoid * self.current_learning_rate;
1071
1072                        // Create a copy for update
1073                        let mut negative_update = negative_output.to_owned();
1074                        negative_update.scaled_add(error, &target_input);
1075                        negative_output.assign(&negative_update);
1076
1077                        // Update input gradient
1078                        input_update.scaled_add(error, &negative_output);
1079                    }
1080                }
1081
1082                // Apply the accumulated gradient to the input word vector
1083                let mut target_input_mut = input_embeddings.row_mut(target_word);
1084                target_input_mut += &input_update;
1085            }
1086        }
1087
1088        Ok(())
1089    }
1090
1091    /// Get the vector size
1092    pub fn vector_size(&self) -> usize {
1093        self.config.vector_size
1094    }
1095
1096    /// Get the embedding vector for a word
1097    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
1098        if self.input_embeddings.is_none() {
1099            return Err(TextError::EmbeddingError(
1100                "Model not trained. Call train() first".into(),
1101            ));
1102        }
1103
1104        match self.vocabulary.get_index(word) {
1105            Some(idx) => Ok(self
1106                .input_embeddings
1107                .as_ref()
1108                .expect("Operation failed")
1109                .row(idx)
1110                .to_owned()),
1111            None => Err(TextError::VocabularyError(format!(
1112                "Word '{word}' not in vocabulary"
1113            ))),
1114        }
1115    }
1116
1117    /// Get the most similar words to a given word
1118    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1119        let word_vec = self.get_word_vector(word)?;
1120        self.most_similar_by_vector(&word_vec, topn, &[word])
1121    }
1122
1123    /// Get the most similar words to a given vector
1124    pub fn most_similar_by_vector(
1125        &self,
1126        vector: &Array1<f64>,
1127        top_n: usize,
1128        exclude_words: &[&str],
1129    ) -> Result<Vec<(String, f64)>> {
1130        if self.input_embeddings.is_none() {
1131            return Err(TextError::EmbeddingError(
1132                "Model not trained. Call train() first".into(),
1133            ));
1134        }
1135
1136        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1137        let vocab_size = self.vocabulary.len();
1138
1139        // Create a set of indices to exclude
1140        let exclude_indices: Vec<usize> = exclude_words
1141            .iter()
1142            .filter_map(|&word| self.vocabulary.get_index(word))
1143            .collect();
1144
1145        // Calculate cosine similarity for all _words
1146        let mut similarities = Vec::with_capacity(vocab_size);
1147
1148        for i in 0..vocab_size {
1149            if exclude_indices.contains(&i) {
1150                continue;
1151            }
1152
1153            let word_vec = input_embeddings.row(i);
1154            let similarity = cosine_similarity(vector, &word_vec.to_owned());
1155
1156            if let Some(word) = self.vocabulary.get_token(i) {
1157                similarities.push((word.to_string(), similarity));
1158            }
1159        }
1160
1161        // Sort by similarity (descending)
1162        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1163
1164        // Take top N
1165        let result = similarities.into_iter().take(top_n).collect();
1166        Ok(result)
1167    }
1168
1169    /// Compute the analogy: a is to b as c is to ?
1170    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1171        if self.input_embeddings.is_none() {
1172            return Err(TextError::EmbeddingError(
1173                "Model not trained. Call train() first".into(),
1174            ));
1175        }
1176
1177        // Get vectors for a, b, and c
1178        let a_vec = self.get_word_vector(a)?;
1179        let b_vec = self.get_word_vector(b)?;
1180        let c_vec = self.get_word_vector(c)?;
1181
1182        // Compute d_vec = b_vec - a_vec + c_vec
1183        let mut d_vec = b_vec.clone();
1184        d_vec -= &a_vec;
1185        d_vec += &c_vec;
1186
1187        // Normalize the vector
1188        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1189        d_vec.mapv_inplace(|val| val / norm);
1190
1191        // Find most similar words to d_vec
1192        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
1193    }
1194
1195    /// Save the Word2Vec model to a file
1196    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1197        if self.input_embeddings.is_none() {
1198            return Err(TextError::EmbeddingError(
1199                "Model not trained. Call train() first".into(),
1200            ));
1201        }
1202
1203        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
1204
1205        // Write header: vector_size and vocabulary size
1206        writeln!(
1207            &mut file,
1208            "{} {}",
1209            self.vocabulary.len(),
1210            self.config.vector_size
1211        )
1212        .map_err(|e| TextError::IoError(e.to_string()))?;
1213
1214        // Write each word and its vector
1215        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1216
1217        for i in 0..self.vocabulary.len() {
1218            if let Some(word) = self.vocabulary.get_token(i) {
1219                // Write the word
1220                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
1221
1222                // Write the vector components
1223                let vector = input_embeddings.row(i);
1224                for j in 0..self.config.vector_size {
1225                    write!(&mut file, "{:.6} ", vector[j])
1226                        .map_err(|e| TextError::IoError(e.to_string()))?;
1227                }
1228
1229                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
1230            }
1231        }
1232
1233        Ok(())
1234    }
1235
1236    /// Load a Word2Vec model from a file
1237    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
1238        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
1239        let mut reader = BufReader::new(file);
1240
1241        // Read header
1242        let mut header = String::new();
1243        reader
1244            .read_line(&mut header)
1245            .map_err(|e| TextError::IoError(e.to_string()))?;
1246
1247        let parts: Vec<&str> = header.split_whitespace().collect();
1248        if parts.len() != 2 {
1249            return Err(TextError::EmbeddingError(
1250                "Invalid model file format".into(),
1251            ));
1252        }
1253
1254        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
1255            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
1256        })?;
1257
1258        let vector_size = parts[1]
1259            .parse::<usize>()
1260            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
1261
1262        // Initialize model
1263        let mut model = Self::new().with_vector_size(vector_size);
1264        let mut vocabulary = Vocabulary::new();
1265        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
1266
1267        // Read each word and its vector
1268        let mut i = 0;
1269        for line in reader.lines() {
1270            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
1271            let parts: Vec<&str> = line.split_whitespace().collect();
1272
1273            if parts.len() != vector_size + 1 {
1274                let line_num = i + 2;
1275                return Err(TextError::EmbeddingError(format!(
1276                    "Invalid vector format at line {line_num}"
1277                )));
1278            }
1279
1280            let word = parts[0];
1281            vocabulary.add_token(word);
1282
1283            for j in 0..vector_size {
1284                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
1285                    TextError::EmbeddingError(format!(
1286                        "Invalid vector component at line {}, position {}",
1287                        i + 2,
1288                        j + 1
1289                    ))
1290                })?;
1291            }
1292
1293            i += 1;
1294        }
1295
1296        if i != vocab_size {
1297            return Err(TextError::EmbeddingError(format!(
1298                "Expected {vocab_size} words but found {i}"
1299            )));
1300        }
1301
1302        model.vocabulary = vocabulary;
1303        model.input_embeddings = Some(input_embeddings);
1304        model.output_embeddings = None; // Only input embeddings are saved
1305
1306        Ok(model)
1307    }
1308
1309    // Getter methods for model registry serialization
1310
1311    /// Get the vocabulary as a vector of strings
1312    pub fn get_vocabulary(&self) -> Vec<String> {
1313        let mut vocab = Vec::new();
1314        for i in 0..self.vocabulary.len() {
1315            if let Some(token) = self.vocabulary.get_token(i) {
1316                vocab.push(token.to_string());
1317            }
1318        }
1319        vocab
1320    }
1321
1322    /// Get the vector size
1323    pub fn get_vector_size(&self) -> usize {
1324        self.config.vector_size
1325    }
1326
1327    /// Get the algorithm
1328    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
1329        self.config.algorithm
1330    }
1331
1332    /// Get the window size
1333    pub fn get_window_size(&self) -> usize {
1334        self.config.window_size
1335    }
1336
1337    /// Get the minimum count
1338    pub fn get_min_count(&self) -> usize {
1339        self.config.min_count
1340    }
1341
1342    /// Get the embeddings matrix (input embeddings)
1343    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
1344        self.input_embeddings.clone()
1345    }
1346
1347    /// Get the number of negative samples
1348    pub fn get_negative_samples(&self) -> usize {
1349        self.config.negative_samples
1350    }
1351
1352    /// Get the learning rate
1353    pub fn get_learning_rate(&self) -> f64 {
1354        self.config.learning_rate
1355    }
1356
1357    /// Get the number of epochs
1358    pub fn get_epochs(&self) -> usize {
1359        self.config.epochs
1360    }
1361
1362    /// Get the subsampling threshold
1363    pub fn get_subsampling_threshold(&self) -> f64 {
1364        self.config.subsample
1365    }
1366
1367    /// Check if hierarchical softmax is enabled
1368    pub fn uses_hierarchical_softmax(&self) -> bool {
1369        self.config.hierarchical_softmax
1370    }
1371
1372    /// Restore vocabulary and input embeddings from serialized state.
1373    ///
1374    /// Validates that:
1375    /// - `embeddings` row count equals `vocabulary.len()`
1376    /// - `embeddings` column count equals `self.config.vector_size`
1377    ///
1378    /// Returns an error on any dimension mismatch; never panics.
1379    pub fn restore_weights(
1380        &mut self,
1381        vocabulary: Vec<String>,
1382        embeddings: Array2<f64>,
1383    ) -> Result<()> {
1384        let embed_shape = embeddings.shape();
1385        let n_words = vocabulary.len();
1386
1387        if embed_shape[0] != n_words {
1388            return Err(TextError::EmbeddingError(format!(
1389                "Embedding row count {} does not match vocabulary size {}",
1390                embed_shape[0], n_words
1391            )));
1392        }
1393
1394        if embed_shape[1] != self.config.vector_size {
1395            return Err(TextError::EmbeddingError(format!(
1396                "Embedding dimension {} does not match configured vector_size {}",
1397                embed_shape[1], self.config.vector_size
1398            )));
1399        }
1400
1401        // Rebuild the vocabulary index
1402        self.vocabulary = Vocabulary::new();
1403        for word in &vocabulary {
1404            self.vocabulary.add_token(word);
1405        }
1406
1407        self.input_embeddings = Some(embeddings);
1408        Ok(())
1409    }
1410
1411    // ─── Hierarchical Softmax Training ──────────────────────────────────
1412
1413    /// Train Skip-gram with hierarchical softmax on a single sentence
1414    fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1415        if sentence.len() < 2 {
1416            return Ok(());
1417        }
1418
1419        let input_embeddings = self
1420            .input_embeddings
1421            .as_mut()
1422            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1423        let hs_params = self
1424            .hs_params
1425            .as_mut()
1426            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1427        let tree = self
1428            .huffman_tree
1429            .as_ref()
1430            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1431
1432        let vector_size = self.config.vector_size;
1433        let window_size = self.config.window_size;
1434        let lr = self.current_learning_rate;
1435
1436        let codes = tree.codes.clone();
1437        let paths = tree.paths.clone();
1438
1439        let mut rng = scirs2_core::random::rng();
1440
1441        for pos in 0..sentence.len() {
1442            let window = 1 + rng.random_range(0..window_size);
1443            let target_word = sentence[pos];
1444
1445            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1446                if i == pos {
1447                    continue;
1448                }
1449
1450                let context_word = sentence[i];
1451                let code = &codes[context_word];
1452                let path = &paths[context_word];
1453
1454                let mut grad_input = Array1::zeros(vector_size);
1455
1456                // Walk the Huffman tree path for the context word
1457                for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1458                    if node_idx >= hs_params.nrows() {
1459                        continue;
1460                    }
1461
1462                    // Compute sigmoid(input_vec . hs_param)
1463                    let input_vec = input_embeddings.row(target_word);
1464                    let param_vec = hs_params.row(node_idx);
1465
1466                    let dot: f64 = input_vec
1467                        .iter()
1468                        .zip(param_vec.iter())
1469                        .map(|(a, b)| a * b)
1470                        .sum();
1471                    let sigmoid = 1.0 / (1.0 + (-dot).exp());
1472
1473                    // gradient: (1 - label - sigmoid) * lr
1474                    let target = if label == 0 { 1.0 } else { 0.0 };
1475                    let gradient = (target - sigmoid) * lr;
1476
1477                    // Update gradient for input vector
1478                    grad_input.scaled_add(gradient, &param_vec.to_owned());
1479
1480                    // Update HS parameter vector
1481                    let input_owned = input_vec.to_owned();
1482                    let mut param_mut = hs_params.row_mut(node_idx);
1483                    param_mut.scaled_add(gradient, &input_owned);
1484                }
1485
1486                // Update input embedding
1487                let mut input_mut = input_embeddings.row_mut(target_word);
1488                input_mut += &grad_input;
1489            }
1490        }
1491
1492        Ok(())
1493    }
1494
1495    /// Train CBOW with hierarchical softmax on a single sentence
1496    fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1497        if sentence.len() < 2 {
1498            return Ok(());
1499        }
1500
1501        let input_embeddings = self
1502            .input_embeddings
1503            .as_mut()
1504            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1505        let hs_params = self
1506            .hs_params
1507            .as_mut()
1508            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1509        let tree = self
1510            .huffman_tree
1511            .as_ref()
1512            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1513
1514        let vector_size = self.config.vector_size;
1515        let window_size = self.config.window_size;
1516        let lr = self.current_learning_rate;
1517
1518        let codes = tree.codes.clone();
1519        let paths = tree.paths.clone();
1520
1521        let mut rng = scirs2_core::random::rng();
1522
1523        for pos in 0..sentence.len() {
1524            let window = 1 + rng.random_range(0..window_size);
1525            let target_word = sentence[pos];
1526
1527            // Collect context words
1528            let mut context_words = Vec::new();
1529            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1530                if i != pos {
1531                    context_words.push(sentence[i]);
1532                }
1533            }
1534
1535            if context_words.is_empty() {
1536                continue;
1537            }
1538
1539            // Average context word vectors
1540            let mut context_avg = Array1::zeros(vector_size);
1541            for &ctx_idx in &context_words {
1542                context_avg += &input_embeddings.row(ctx_idx);
1543            }
1544            context_avg /= context_words.len() as f64;
1545
1546            // Walk Huffman path for target word
1547            let code = &codes[target_word];
1548            let path = &paths[target_word];
1549
1550            let mut grad_context = Array1::zeros(vector_size);
1551
1552            for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1553                if node_idx >= hs_params.nrows() {
1554                    continue;
1555                }
1556
1557                let param_vec = hs_params.row(node_idx);
1558
1559                let dot: f64 = context_avg
1560                    .iter()
1561                    .zip(param_vec.iter())
1562                    .map(|(a, b)| a * b)
1563                    .sum();
1564                let sigmoid = 1.0 / (1.0 + (-dot).exp());
1565
1566                let target = if label == 0 { 1.0 } else { 0.0 };
1567                let gradient = (target - sigmoid) * lr;
1568
1569                grad_context.scaled_add(gradient, &param_vec.to_owned());
1570
1571                // Update HS parameter
1572                let ctx_owned = context_avg.clone();
1573                let mut param_mut = hs_params.row_mut(node_idx);
1574                param_mut.scaled_add(gradient, &ctx_owned);
1575            }
1576
1577            // Distribute gradient back to context word input embeddings
1578            let grad_per_word = &grad_context / context_words.len() as f64;
1579            for &ctx_idx in &context_words {
1580                let mut input_mut = input_embeddings.row_mut(ctx_idx);
1581                input_mut += &grad_per_word;
1582            }
1583        }
1584
1585        Ok(())
1586    }
1587}
1588
1589// ─── WordEmbedding trait implementation for Word2Vec ─────────────────────────
1590
1591impl WordEmbedding for Word2Vec {
1592    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
1593        self.get_word_vector(word)
1594    }
1595
1596    fn dimension(&self) -> usize {
1597        self.vector_size()
1598    }
1599
1600    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1601        self.most_similar(word, top_n)
1602    }
1603
1604    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1605        self.analogy(a, b, c, top_n)
1606    }
1607
1608    fn vocab_size(&self) -> usize {
1609        self.vocabulary.len()
1610    }
1611}
1612
1613/// Calculate cosine similarity between two vectors
1614#[allow(dead_code)]
1615pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1616    let dot_product = (a * b).sum();
1617    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1618    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1619
1620    if norm_a > 0.0 && norm_b > 0.0 {
1621        dot_product / (norm_a * norm_b)
1622    } else {
1623        0.0
1624    }
1625}
1626
1627#[cfg(test)]
1628mod tests {
1629    use super::*;
1630    use approx::assert_relative_eq;
1631
1632    #[test]
1633    fn test_cosine_similarity() {
1634        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1635        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
1636
1637        let similarity = cosine_similarity(&a, &b);
1638        let expected = 0.9746318461970762;
1639        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
1640    }
1641
1642    #[test]
1643    fn test_word2vec_config() {
1644        let config = Word2VecConfig::default();
1645        assert_eq!(config.vector_size, 100);
1646        assert_eq!(config.window_size, 5);
1647        assert_eq!(config.min_count, 5);
1648        assert_eq!(config.epochs, 5);
1649        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
1650    }
1651
1652    #[test]
1653    fn test_word2vec_builder() {
1654        let model = Word2Vec::new()
1655            .with_vector_size(200)
1656            .with_window_size(10)
1657            .with_learning_rate(0.05)
1658            .with_algorithm(Word2VecAlgorithm::CBOW);
1659
1660        assert_eq!(model.config.vector_size, 200);
1661        assert_eq!(model.config.window_size, 10);
1662        assert_eq!(model.config.learning_rate, 0.05);
1663        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
1664    }
1665
1666    #[test]
1667    fn test_build_vocabulary() {
1668        let texts = [
1669            "the quick brown fox jumps over the lazy dog",
1670            "a quick brown fox jumps over a lazy dog",
1671        ];
1672
1673        let mut model = Word2Vec::new().with_min_count(1);
1674        let result = model.build_vocabulary(&texts);
1675        assert!(result.is_ok());
1676
1677        // Check vocabulary size (unique words: "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "a")
1678        assert_eq!(model.vocabulary.len(), 9);
1679
1680        // Check that embeddings were initialized
1681        assert!(model.input_embeddings.is_some());
1682        assert!(model.output_embeddings.is_some());
1683        assert_eq!(
1684            model
1685                .input_embeddings
1686                .as_ref()
1687                .expect("Operation failed")
1688                .shape(),
1689            &[9, 100]
1690        );
1691    }
1692
1693    #[test]
1694    fn test_skipgram_training_small() {
1695        let texts = [
1696            "the quick brown fox jumps over the lazy dog",
1697            "a quick brown fox jumps over a lazy dog",
1698        ];
1699
1700        let mut model = Word2Vec::new()
1701            .with_vector_size(10)
1702            .with_window_size(2)
1703            .with_min_count(1)
1704            .with_epochs(1)
1705            .with_algorithm(Word2VecAlgorithm::SkipGram);
1706
1707        let result = model.train(&texts);
1708        assert!(result.is_ok());
1709
1710        // Test getting a word vector
1711        let result = model.get_word_vector("fox");
1712        assert!(result.is_ok());
1713        let vec = result.expect("Operation failed");
1714        assert_eq!(vec.len(), 10);
1715    }
1716
1717    // ─── Hierarchical Softmax Tests ──────────────────────────────────
1718
1719    #[test]
1720    fn test_huffman_tree_build() {
1721        let frequencies = vec![5, 3, 8, 1, 2];
1722        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1723
1724        // Each word should have a code
1725        assert_eq!(tree.codes.len(), 5);
1726        assert_eq!(tree.paths.len(), 5);
1727
1728        // All codes should be non-empty
1729        for code in &tree.codes {
1730            assert!(!code.is_empty());
1731        }
1732
1733        // Internal nodes = vocab_size - 1 (for a binary tree)
1734        assert_eq!(tree.num_internal, 4);
1735    }
1736
1737    #[test]
1738    fn test_huffman_tree_single_word() {
1739        let frequencies = vec![10];
1740        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1741        assert_eq!(tree.codes.len(), 1);
1742        assert_eq!(tree.paths.len(), 1);
1743    }
1744
1745    #[test]
1746    fn test_skipgram_hierarchical_softmax() {
1747        let texts = [
1748            "the quick brown fox jumps over the lazy dog",
1749            "a quick brown fox jumps over a lazy dog",
1750        ];
1751
1752        let config = Word2VecConfig {
1753            vector_size: 10,
1754            window_size: 2,
1755            min_count: 1,
1756            epochs: 3,
1757            learning_rate: 0.025,
1758            algorithm: Word2VecAlgorithm::SkipGram,
1759            hierarchical_softmax: true,
1760            ..Default::default()
1761        };
1762
1763        let mut model = Word2Vec::with_config(config);
1764        let result = model.train(&texts);
1765        assert!(
1766            result.is_ok(),
1767            "HS skipgram training failed: {:?}",
1768            result.err()
1769        );
1770
1771        assert!(model.uses_hierarchical_softmax());
1772
1773        // Should produce valid word vectors
1774        let vec = model.get_word_vector("fox");
1775        assert!(vec.is_ok());
1776        assert_eq!(vec.expect("get vec").len(), 10);
1777    }
1778
1779    #[test]
1780    fn test_cbow_hierarchical_softmax() {
1781        let texts = [
1782            "the quick brown fox jumps over the lazy dog",
1783            "a quick brown fox jumps over a lazy dog",
1784        ];
1785
1786        let config = Word2VecConfig {
1787            vector_size: 10,
1788            window_size: 2,
1789            min_count: 1,
1790            epochs: 3,
1791            learning_rate: 0.025,
1792            algorithm: Word2VecAlgorithm::CBOW,
1793            hierarchical_softmax: true,
1794            ..Default::default()
1795        };
1796
1797        let mut model = Word2Vec::with_config(config);
1798        let result = model.train(&texts);
1799        assert!(
1800            result.is_ok(),
1801            "HS CBOW training failed: {:?}",
1802            result.err()
1803        );
1804
1805        let vec = model.get_word_vector("dog");
1806        assert!(vec.is_ok());
1807    }
1808
1809    // ─── WordEmbedding Trait Tests ──────────────────────────────────
1810
1811    #[test]
1812    fn test_word_embedding_trait_word2vec() {
1813        let texts = [
1814            "the quick brown fox jumps over the lazy dog",
1815            "a quick brown fox jumps over a lazy dog",
1816        ];
1817
1818        let mut model = Word2Vec::new()
1819            .with_vector_size(10)
1820            .with_min_count(1)
1821            .with_epochs(1);
1822
1823        model.train(&texts).expect("Training failed");
1824
1825        // Use via trait
1826        let emb: &dyn WordEmbedding = &model;
1827        assert_eq!(emb.dimension(), 10);
1828        assert!(emb.vocab_size() > 0);
1829
1830        let vec = emb.embedding("fox");
1831        assert!(vec.is_ok());
1832
1833        let sim = emb.similarity("fox", "dog");
1834        assert!(sim.is_ok());
1835        assert!(sim.expect("sim").is_finite());
1836
1837        let similar = emb.find_similar("fox", 2);
1838        assert!(similar.is_ok());
1839
1840        let analogy = emb.solve_analogy("the", "fox", "dog", 2);
1841        assert!(analogy.is_ok());
1842    }
1843
1844    #[test]
1845    fn test_embedding_cosine_similarity_fn() {
1846        let a = Array1::from_vec(vec![1.0, 0.0]);
1847        let b = Array1::from_vec(vec![0.0, 1.0]);
1848        assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
1849
1850        let c = Array1::from_vec(vec![1.0, 1.0]);
1851        let d = Array1::from_vec(vec![1.0, 1.0]);
1852        assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
1853    }
1854
1855    #[test]
1856    fn test_pairwise_similarity_fn() {
1857        let texts = ["the quick brown fox", "the lazy brown dog"];
1858
1859        let mut model = Word2Vec::new()
1860            .with_vector_size(10)
1861            .with_min_count(1)
1862            .with_epochs(1);
1863        model.train(&texts).expect("Training failed");
1864
1865        let words = vec!["the", "fox", "dog"];
1866        let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");
1867
1868        assert_eq!(matrix.len(), 3);
1869        assert_eq!(matrix[0].len(), 3);
1870
1871        // Diagonal should be 1.0
1872        for i in 0..3 {
1873            assert!((matrix[i][i] - 1.0).abs() < 1e-6);
1874        }
1875
1876        // Symmetric
1877        for i in 0..3 {
1878            for j in 0..3 {
1879                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
1880            }
1881        }
1882    }
1883}