Skip to main content

scirs2_text/embeddings/
mod.rs

1//! # Word Embeddings Module
2//!
3//! This module provides comprehensive implementations for word embeddings, including
4//! Word2Vec, GloVe, and FastText.
5//!
6//! ## Overview
7//!
8//! Word embeddings are dense vector representations of words that capture semantic
9//! relationships. This module implements:
10//!
11//! - **Word2Vec Skip-gram**: Predicts context words from a target word
12//! - **Word2Vec CBOW**: Predicts a target word from context words
13//! - **GloVe**: Global Vectors for Word Representation
14//! - **FastText**: Character n-gram based embeddings (handles OOV words)
15//! - **Negative sampling**: Efficient training technique for large vocabularies
//! - **Hierarchical softmax**: Alternative to negative sampling (enable with `hierarchical_softmax: true`)
17//!
18//! ## Quick Start
19//!
20//! ```rust
21//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
22//!
23//! // Basic Word2Vec training
24//! let documents = vec![
25//!     "the quick brown fox jumps over the lazy dog",
26//!     "the dog was lazy but the fox was quick",
27//!     "brown fox and lazy dog are common phrases"
28//! ];
29//!
30//! let config = Word2VecConfig {
31//!     algorithm: Word2VecAlgorithm::SkipGram,
32//!     vector_size: 100,
33//!     window_size: 5,
34//!     min_count: 1,
35//!     learning_rate: 0.025,
36//!     epochs: 5,
37//!     negative_samples: 5,
38//!     ..Default::default()
39//! };
40//!
41//! let mut model = Word2Vec::with_config(config);
42//! model.train(&documents).expect("Training failed");
43//!
44//! // Get word vector
45//! if let Ok(vector) = model.get_word_vector("fox") {
46//!     println!("Vector for 'fox': {:?}", vector);
47//! }
48//!
49//! // Find similar words
50//! let similar = model.most_similar("fox", 3).expect("Failed to find similar words");
51//! for (word, similarity) in similar {
52//!     println!("{}: {:.4}", word, similarity);
53//! }
54//! ```
55//!
56//! ## Advanced Usage
57//!
58//! ### Custom Configuration
59//!
60//! ```rust
61//! use scirs2_text::embeddings::{Word2Vec, Word2VecConfig, Word2VecAlgorithm};
62//!
63//! let config = Word2VecConfig {
64//!     algorithm: Word2VecAlgorithm::CBOW,
65//!     vector_size: 300,        // Larger vectors for better quality
66//!     window_size: 10,         // Larger context window
67//!     min_count: 5,           // Filter rare words
68//!     learning_rate: 0.01,    // Lower learning rate for stability
69//!     epochs: 15,             // More training iterations
70//!     negative_samples: 10,   // More negative samples
71//!     subsample: 1e-4, // Subsample frequent words
72//!     batch_size: 128,
73//!     hierarchical_softmax: false,
74//! };
75//! ```
76//!
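//! ### Incremental Training
//!
//! Calling `train` again continues from the current weights. The vocabulary is only
//! built on the first call (while it is still empty), so tokens that did not occur in
//! that first corpus (here "was" and "but") are ignored by later calls.
//!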
79//! ```rust
80//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
81//! # let mut model = Word2Vec::new().with_min_count(1);
82//! // Initial training
83//! let batch1 = vec!["the quick brown fox jumps over the lazy dog"];
84//! model.train(&batch1).expect("Training failed");
85//!
86//! // Continue training with new data
87//! let batch2 = vec!["the dog was lazy but the fox was quick"];
88//! model.train(&batch2).expect("Training failed");
89//! ```
90//!
91//! ### Saving and Loading Models
92//!
93//! ```rust
94//! # use scirs2_text::embeddings::{Word2Vec, Word2VecConfig};
95//! # let mut model = Word2Vec::new().with_min_count(1);
96//! # let texts = vec!["the quick brown fox jumps over the lazy dog"];
97//! # model.train(&texts).expect("Training failed");
98//! // Save trained model
99//! model.save("my_model.w2v").expect("Failed to save model");
100//!
101//! // Load model
102//! let loaded_model = Word2Vec::load("my_model.w2v")
103//!     .expect("Failed to load model");
104//! ```
105//!
106//! ## Performance Tips
107//!
108//! 1. **Vocabulary Size**: Use `min_count` to filter rare words and reduce memory usage
109//! 2. **Vector Dimensions**: Balance between quality (higher dimensions) and speed (lower dimensions)
110//! 3. **Training Algorithm**: Skip-gram works better with rare words, CBOW is faster
111//! 4. **Negative Sampling**: Usually faster than hierarchical softmax for large vocabularies
//! 5. **Subsampling**: Set `subsample` (e.g. `1e-3`) to randomly discard very frequent words; `0.0` disables it
113//!
114//! ## Mathematical Background
115//!
116//! ### Skip-gram Model
117//!
118//! The Skip-gram model maximizes the probability of context words given a target word:
119//!
120//! P(context|target) = ∏ P(w_context|w_target)
121//!
122//! ### CBOW Model
123//!
124//! The CBOW model predicts a target word from its context:
125//!
126//! P(target|context) = P(w_target|w_context1, w_context2, ...)
127//!
128//! ### Negative Sampling
129//!
130//! Instead of computing the full softmax, negative sampling approximates it by
131//! sampling negative examples, making training much more efficient.
132
133// Sub-modules
134pub mod fasttext;
135pub mod glove;
136
137// Re-export
138pub use fasttext::{FastText, FastTextConfig};
139pub use glove::{
140    cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
141    GloVeTrainerConfig,
142};
143
144use crate::error::{Result, TextError};
145use crate::tokenize::{Tokenizer, WordTokenizer};
146use crate::vocabulary::Vocabulary;
147use scirs2_core::ndarray::{Array1, Array2};
148use scirs2_core::random::prelude::*;
149use std::collections::HashMap;
150use std::fmt::Debug;
151use std::fs::File;
152use std::io::{BufRead, BufReader, Write};
153use std::path::Path;
154
155// ─── WordEmbedding trait ─────────────────────────────────────────────────────
156
157/// Shared trait for word embedding models
158///
159/// Provides a common interface for querying word vectors, computing similarity,
160/// and solving analogies across different embedding implementations
161/// (Word2Vec, GloVe, FastText).
162pub trait WordEmbedding {
163    /// Get the embedding vector for a word
164    fn embedding(&self, word: &str) -> Result<Array1<f64>>;
165
166    /// Get the dimensionality of the embedding vectors
167    fn dimension(&self) -> usize;
168
169    /// Compute cosine similarity between two words
170    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
171        let v1 = self.embedding(word1)?;
172        let v2 = self.embedding(word2)?;
173        Ok(embedding_cosine_similarity(&v1, &v2))
174    }
175
176    /// Find the top-N most similar words to the query word
177    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
178
179    /// Solve the analogy: a is to b as c is to ?
180    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
181
182    /// Get the number of words in the vocabulary
183    fn vocab_size(&self) -> usize;
184}
185
186/// Calculate cosine similarity between two embedding vectors
187pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
188    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
189    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
190    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
191
192    if norm_a > 0.0 && norm_b > 0.0 {
193        dot_product / (norm_a * norm_b)
194    } else {
195        0.0
196    }
197}
198
199/// Compute pairwise cosine similarity matrix for a list of words
200pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
201    let vectors: Vec<Array1<f64>> = words
202        .iter()
203        .map(|&w| model.embedding(w))
204        .collect::<Result<Vec<_>>>()?;
205
206    let n = vectors.len();
207    let mut matrix = vec![vec![0.0; n]; n];
208
209    for i in 0..n {
210        for j in i..n {
211            let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
212            matrix[i][j] = sim;
213            matrix[j][i] = sim;
214        }
215    }
216
217    Ok(matrix)
218}
219
220// ─── WordEmbedding implementations ──────────────────────────────────────────
221
222impl WordEmbedding for GloVe {
223    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
224        self.get_word_vector(word)
225    }
226
227    fn dimension(&self) -> usize {
228        self.vector_size()
229    }
230
231    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
232        self.most_similar(word, top_n)
233    }
234
235    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
236        self.analogy(a, b, c, top_n)
237    }
238
239    fn vocab_size(&self) -> usize {
240        self.vocabulary_size()
241    }
242}
243
244impl WordEmbedding for FastText {
245    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
246        self.get_word_vector(word)
247    }
248
249    fn dimension(&self) -> usize {
250        self.vector_size()
251    }
252
253    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
254        self.most_similar(word, top_n)
255    }
256
257    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
258        self.analogy(a, b, c, top_n)
259    }
260
261    fn vocab_size(&self) -> usize {
262        self.vocabulary_size()
263    }
264}
265
266// ─── Huffman Tree for Hierarchical Softmax ──────────────────────────────────
267
268/// A node in the Huffman tree used for hierarchical softmax
269#[derive(Debug, Clone)]
270struct HuffmanNode {
271    /// Index in vocabulary (for leaf nodes) or internal node id
272    id: usize,
273    /// Frequency (for building the tree)
274    frequency: usize,
275    /// Left child index in the nodes array
276    left: Option<usize>,
277    /// Right child index in the nodes array
278    right: Option<usize>,
279    /// Is this a leaf node?
280    is_leaf: bool,
281}
282
283/// Huffman tree for hierarchical softmax training
284#[derive(Debug, Clone)]
285struct HuffmanTree {
286    /// Huffman codes for each vocabulary word: sequence of 0/1 decisions
287    codes: Vec<Vec<u8>>,
288    /// Path of internal node indices for each vocabulary word
289    paths: Vec<Vec<usize>>,
290    /// Number of internal nodes
291    num_internal: usize,
292}
293
294impl HuffmanTree {
295    /// Build a Huffman tree from word frequencies
296    ///
297    /// Returns codes and paths for each word in the vocabulary.
298    fn build(frequencies: &[usize]) -> Result<Self> {
299        let vocab_size = frequencies.len();
300        if vocab_size == 0 {
301            return Err(TextError::EmbeddingError(
302                "Cannot build Huffman tree with empty vocabulary".into(),
303            ));
304        }
305        if vocab_size == 1 {
306            // Special case: single word
307            return Ok(Self {
308                codes: vec![vec![0]],
309                paths: vec![vec![0]],
310                num_internal: 1,
311            });
312        }
313
314        // Create leaf nodes
315        let mut nodes: Vec<HuffmanNode> = frequencies
316            .iter()
317            .enumerate()
318            .map(|(id, &freq)| HuffmanNode {
319                id,
320                frequency: freq.max(1), // Avoid zero frequency
321                left: None,
322                right: None,
323                is_leaf: true,
324            })
325            .collect();
326
327        // Priority queue simulation using a sorted list
328        // (index_in_nodes, frequency)
329        let mut queue: Vec<(usize, usize)> = nodes
330            .iter()
331            .enumerate()
332            .map(|(i, n)| (i, n.frequency))
333            .collect();
334        queue.sort_by(|a, b| b.1.cmp(&a.1)); // Sort descending (pop from end = min)
335
336        // Build the tree bottom-up
337        while queue.len() > 1 {
338            // Pop two smallest
339            let (idx1, freq1) = queue
340                .pop()
341                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
342            let (idx2, freq2) = queue
343                .pop()
344                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
345
346            let new_id = nodes.len();
347            let new_node = HuffmanNode {
348                id: new_id,
349                frequency: freq1 + freq2,
350                left: Some(idx1),
351                right: Some(idx2),
352                is_leaf: false,
353            };
354            nodes.push(new_node);
355
356            // Insert new node maintaining sorted order
357            let new_freq = freq1 + freq2;
358            let insert_pos = queue
359                .binary_search_by(|(_, f)| new_freq.cmp(f))
360                .unwrap_or_else(|pos| pos);
361            queue.insert(insert_pos, (new_id, new_freq));
362        }
363
364        // Traverse tree to assign codes and paths
365        let num_internal = nodes.len() - vocab_size;
366        let mut codes = vec![Vec::new(); vocab_size];
367        let mut paths = vec![Vec::new(); vocab_size];
368
369        // DFS traversal
370        let root_idx = nodes.len() - 1;
371        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];
372
373        while let Some((node_idx, code, path)) = stack.pop() {
374            let node = &nodes[node_idx];
375
376            if node.is_leaf {
377                codes[node.id] = code;
378                paths[node.id] = path;
379            } else {
380                // Internal node index (0-based among internal nodes)
381                let internal_idx = node.id - vocab_size;
382
383                if let Some(left_idx) = node.left {
384                    let mut left_code = code.clone();
385                    left_code.push(0);
386                    let mut left_path = path.clone();
387                    left_path.push(internal_idx);
388                    stack.push((left_idx, left_code, left_path));
389                }
390
391                if let Some(right_idx) = node.right {
392                    let mut right_code = code.clone();
393                    right_code.push(1);
394                    let mut right_path = path.clone();
395                    right_path.push(internal_idx);
396                    stack.push((right_idx, right_code, right_path));
397                }
398            }
399        }
400
401        Ok(Self {
402            codes,
403            paths,
404            num_internal,
405        })
406    }
407}
408
409/// A simplified weighted sampling table
410#[derive(Debug, Clone)]
411struct SamplingTable {
412    /// The cumulative distribution function (CDF)
413    cdf: Vec<f64>,
414    /// The weights
415    weights: Vec<f64>,
416}
417
418impl SamplingTable {
419    /// Create a new sampling table from weights
420    fn new(weights: &[f64]) -> Result<Self> {
421        if weights.is_empty() {
422            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
423        }
424
425        // Ensure all _weights are positive
426        if weights.iter().any(|&w| w < 0.0) {
427            return Err(TextError::EmbeddingError("Weights must be positive".into()));
428        }
429
430        // Compute the CDF
431        let sum: f64 = weights.iter().sum();
432        if sum <= 0.0 {
433            return Err(TextError::EmbeddingError(
434                "Sum of _weights must be positive".into(),
435            ));
436        }
437
438        let mut cdf = Vec::with_capacity(weights.len());
439        let mut total = 0.0;
440
441        for &w in weights {
442            total += w;
443            cdf.push(total / sum);
444        }
445
446        Ok(Self {
447            cdf,
448            weights: weights.to_vec(),
449        })
450    }
451
452    /// Sample an index based on the weights
453    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
454        let r = rng.random::<f64>();
455
456        // Binary search for the insertion point
457        match self.cdf.binary_search_by(|&cdf_val| {
458            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
459        }) {
460            Ok(idx) => idx,
461            Err(idx) => idx,
462        }
463    }
464
465    /// Get the weights
466    fn weights(&self) -> &[f64] {
467        &self.weights
468    }
469}
470
/// Word2Vec training algorithms
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous Bag of Words: predict the target word from its averaged context
    CBOW,
    /// Skip-gram: predict the context words from the target word
    SkipGram,
}

/// Word2Vec training options
///
/// Usually built with struct-update syntax on top of [`Default`]:
/// `Word2VecConfig { vector_size: 300, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the word vectors
    pub vector_size: usize,
    /// Maximum distance between the current and predicted word within a sentence
    pub window_size: usize,
    /// Words occurring fewer times than this are left out of the vocabulary
    pub min_count: usize,
    /// Number of passes (epochs) over the corpus
    pub epochs: usize,
    /// Initial learning rate; training decays it linearly over the epochs
    pub learning_rate: f64,
    /// Training algorithm: Skip-gram or CBOW
    pub algorithm: Word2VecAlgorithm,
    /// Number of negative samples drawn per positive example
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words; `0.0` disables subsampling
    pub subsample: f64,
    /// Batch size for training
    pub batch_size: usize,
    /// Use hierarchical softmax instead of negative sampling
    pub hierarchical_softmax: bool,
}

impl Default for Word2VecConfig {
    /// The defaults are:
    /// - 100-dimensional Skip-gram vectors, window 5, `min_count` 5;
    /// - 5 epochs at learning rate 0.025;
    /// - 5 negative samples, subsampling threshold `1e-3`;
    /// - batch size 128, negative sampling (no hierarchical softmax).
    fn default() -> Self {
        Word2VecConfig {
            // Model shape
            algorithm: Word2VecAlgorithm::SkipGram,
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            // Optimisation schedule
            epochs: 5,
            learning_rate: 0.025,
            batch_size: 128,
            // Output-layer approximation and frequent-word handling
            hierarchical_softmax: false,
            negative_samples: 5,
            subsample: 1e-3,
        }
    }
}
521
522/// Word2Vec model for training and using word embeddings
523///
524/// Word2Vec is an algorithm for learning vector representations of words,
525/// also known as word embeddings. These vectors capture semantic meanings
526/// of words, allowing operations like "king - man + woman" to result in
527/// a vector close to "queen".
528///
529/// This implementation supports both Continuous Bag of Words (CBOW) and
530/// Skip-gram models, with negative sampling for efficient training.
531pub struct Word2Vec {
532    /// Configuration options
533    config: Word2VecConfig,
534    /// Vocabulary
535    vocabulary: Vocabulary,
536    /// Input embeddings
537    input_embeddings: Option<Array2<f64>>,
538    /// Output embeddings
539    output_embeddings: Option<Array2<f64>>,
540    /// Tokenizer
541    tokenizer: Box<dyn Tokenizer + Send + Sync>,
542    /// Sampling table for negative sampling
543    sampling_table: Option<SamplingTable>,
544    /// Huffman tree for hierarchical softmax
545    huffman_tree: Option<HuffmanTree>,
546    /// Hierarchical softmax parameter vectors (one per internal node)
547    hs_params: Option<Array2<f64>>,
548    /// Current learning rate (gets updated during training)
549    current_learning_rate: f64,
550}
551
552impl Debug for Word2Vec {
553    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
554        f.debug_struct("Word2Vec")
555            .field("config", &self.config)
556            .field("vocabulary", &self.vocabulary)
557            .field("input_embeddings", &self.input_embeddings)
558            .field("output_embeddings", &self.output_embeddings)
559            .field("sampling_table", &self.sampling_table)
560            .field("huffman_tree", &self.huffman_tree)
561            .field("current_learning_rate", &self.current_learning_rate)
562            .finish()
563    }
564}
565
566// Manual Clone implementation to handle the non-Clone tokenizer
567impl Default for Word2Vec {
568    fn default() -> Self {
569        Self::new()
570    }
571}
572
573impl Clone for Word2Vec {
574    fn clone(&self) -> Self {
575        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
576
577        Self {
578            config: self.config.clone(),
579            vocabulary: self.vocabulary.clone(),
580            input_embeddings: self.input_embeddings.clone(),
581            output_embeddings: self.output_embeddings.clone(),
582            tokenizer,
583            sampling_table: self.sampling_table.clone(),
584            huffman_tree: self.huffman_tree.clone(),
585            hs_params: self.hs_params.clone(),
586            current_learning_rate: self.current_learning_rate,
587        }
588    }
589}
590
591impl Word2Vec {
592    /// Create a new Word2Vec model with default configuration
593    pub fn new() -> Self {
594        Self {
595            config: Word2VecConfig::default(),
596            vocabulary: Vocabulary::new(),
597            input_embeddings: None,
598            output_embeddings: None,
599            tokenizer: Box::new(WordTokenizer::default()),
600            sampling_table: None,
601            huffman_tree: None,
602            hs_params: None,
603            current_learning_rate: 0.025,
604        }
605    }
606
607    /// Create a new Word2Vec model with the specified configuration
608    pub fn with_config(config: Word2VecConfig) -> Self {
609        let learning_rate = config.learning_rate;
610        Self {
611            config,
612            vocabulary: Vocabulary::new(),
613            input_embeddings: None,
614            output_embeddings: None,
615            tokenizer: Box::new(WordTokenizer::default()),
616            sampling_table: None,
617            huffman_tree: None,
618            hs_params: None,
619            current_learning_rate: learning_rate,
620        }
621    }
622
623    /// Set a custom tokenizer
624    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
625        self.tokenizer = tokenizer;
626        self
627    }
628
629    /// Set vector size
630    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
631        self.config.vector_size = vectorsize;
632        self
633    }
634
635    /// Set window size
636    pub fn with_window_size(mut self, windowsize: usize) -> Self {
637        self.config.window_size = windowsize;
638        self
639    }
640
641    /// Set minimum count
642    pub fn with_min_count(mut self, mincount: usize) -> Self {
643        self.config.min_count = mincount;
644        self
645    }
646
647    /// Set number of epochs
648    pub fn with_epochs(mut self, epochs: usize) -> Self {
649        self.config.epochs = epochs;
650        self
651    }
652
653    /// Set learning rate
654    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
655        self.config.learning_rate = learningrate;
656        self.current_learning_rate = learningrate;
657        self
658    }
659
660    /// Set algorithm (CBOW or Skip-gram)
661    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
662        self.config.algorithm = algorithm;
663        self
664    }
665
666    /// Set number of negative samples
667    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
668        self.config.negative_samples = negativesamples;
669        self
670    }
671
672    /// Set subsampling threshold
673    pub fn with_subsample(mut self, subsample: f64) -> Self {
674        self.config.subsample = subsample;
675        self
676    }
677
678    /// Set batch size
679    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
680        self.config.batch_size = batchsize;
681        self
682    }
683
684    /// Build vocabulary from a corpus
685    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
686        if texts.is_empty() {
687            return Err(TextError::InvalidInput(
688                "No texts provided for building vocabulary".into(),
689            ));
690        }
691
692        // Count word frequencies
693        let mut word_counts = HashMap::new();
694        let mut _total_words = 0;
695
696        for &text in texts {
697            let tokens = self.tokenizer.tokenize(text)?;
698            for token in tokens {
699                *word_counts.entry(token).or_insert(0) += 1;
700                _total_words += 1;
701            }
702        }
703
704        // Create vocabulary with words that meet minimum count
705        self.vocabulary = Vocabulary::new();
706        for (word, count) in &word_counts {
707            if *count >= self.config.min_count {
708                self.vocabulary.add_token(word);
709            }
710        }
711
712        if self.vocabulary.is_empty() {
713            return Err(TextError::VocabularyError(
714                "No words meet the minimum count threshold".into(),
715            ));
716        }
717
718        // Initialize embeddings
719        let vocab_size = self.vocabulary.len();
720        let vector_size = self.config.vector_size;
721
722        // Initialize input and output embeddings with small random values
723        let mut rng = scirs2_core::random::rng();
724        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
725            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
726        });
727        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
728            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
729        });
730
731        self.input_embeddings = Some(input_embeddings);
732        self.output_embeddings = Some(output_embeddings);
733
734        // Create sampling table for negative sampling
735        self.create_sampling_table(&word_counts)?;
736
737        // Build Huffman tree for hierarchical softmax if configured
738        if self.config.hierarchical_softmax {
739            let frequencies: Vec<usize> = (0..vocab_size)
740                .map(|i| {
741                    self.vocabulary
742                        .get_token(i)
743                        .and_then(|word| word_counts.get(word).copied())
744                        .unwrap_or(1)
745                })
746                .collect();
747
748            let tree = HuffmanTree::build(&frequencies)?;
749            let num_internal = tree.num_internal;
750
751            // Initialize hierarchical softmax parameter vectors
752            let hs_params = Array2::zeros((num_internal, vector_size));
753            self.hs_params = Some(hs_params);
754            self.huffman_tree = Some(tree);
755        }
756
757        Ok(())
758    }
759
760    /// Create sampling table for negative sampling based on word frequencies
761    fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
762        // Prepare sampling weights (unigram distribution raised to 3/4 power)
763        let mut sampling_weights = vec![0.0; self.vocabulary.len()];
764
765        for (word, &count) in wordcounts.iter() {
766            if let Some(idx) = self.vocabulary.get_index(word) {
767                // Apply smoothing: frequency^0.75
768                sampling_weights[idx] = (count as f64).powf(0.75);
769            }
770        }
771
772        match SamplingTable::new(&sampling_weights) {
773            Ok(table) => {
774                self.sampling_table = Some(table);
775                Ok(())
776            }
777            Err(e) => Err(e),
778        }
779    }
780
781    /// Train the Word2Vec model on a corpus
782    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
783        if texts.is_empty() {
784            return Err(TextError::InvalidInput(
785                "No texts provided for training".into(),
786            ));
787        }
788
789        // Build vocabulary if not already built
790        if self.vocabulary.is_empty() {
791            self.build_vocabulary(texts)?;
792        }
793
794        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
795            return Err(TextError::EmbeddingError(
796                "Embeddings not initialized. Call build_vocabulary() first".into(),
797            ));
798        }
799
800        // Count total number of tokens for progress tracking
801        let mut _total_tokens = 0;
802        let mut sentences = Vec::new();
803        for &text in texts {
804            let tokens = self.tokenizer.tokenize(text)?;
805            let filtered_tokens: Vec<usize> = tokens
806                .iter()
807                .filter_map(|token| self.vocabulary.get_index(token))
808                .collect();
809            if !filtered_tokens.is_empty() {
810                _total_tokens += filtered_tokens.len();
811                sentences.push(filtered_tokens);
812            }
813        }
814
815        // Train for the specified number of epochs
816        for epoch in 0..self.config.epochs {
817            // Update learning rate for this epoch
818            self.current_learning_rate =
819                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
820            self.current_learning_rate = self
821                .current_learning_rate
822                .max(self.config.learning_rate * 0.0001);
823
824            // Process each sentence
825            for sentence in &sentences {
826                // Apply subsampling of frequent words
827                let subsampled_sentence = if self.config.subsample > 0.0 {
828                    self.subsample_sentence(sentence)?
829                } else {
830                    sentence.clone()
831                };
832
833                // Skip empty sentences
834                if subsampled_sentence.is_empty() {
835                    continue;
836                }
837
838                // Train on the sentence
839                if self.config.hierarchical_softmax {
840                    // Use hierarchical softmax
841                    match self.config.algorithm {
842                        Word2VecAlgorithm::SkipGram => {
843                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
844                        }
845                        Word2VecAlgorithm::CBOW => {
846                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
847                        }
848                    }
849                } else {
850                    // Use negative sampling
851                    match self.config.algorithm {
852                        Word2VecAlgorithm::CBOW => {
853                            self.train_cbow_sentence(&subsampled_sentence)?;
854                        }
855                        Word2VecAlgorithm::SkipGram => {
856                            self.train_skipgram_sentence(&subsampled_sentence)?;
857                        }
858                    }
859                }
860            }
861        }
862
863        Ok(())
864    }
865
866    /// Apply subsampling to a sentence
867    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
868        let mut rng = scirs2_core::random::rng();
869        let total_words: f64 = self.vocabulary.len() as f64;
870        let threshold = self.config.subsample * total_words;
871
872        // Filter words based on subsampling probability
873        let subsampled: Vec<usize> = sentence
874            .iter()
875            .filter(|&&word_idx| {
876                let word_freq = self.get_word_frequency(word_idx);
877                if word_freq == 0.0 {
878                    return true; // Keep rare words
879                }
880                // Probability of keeping the word
881                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
882                rng.random::<f64>() < keep_prob
883            })
884            .copied()
885            .collect();
886
887        Ok(subsampled)
888    }
889
890    /// Get the frequency of a word in the vocabulary
891    fn get_word_frequency(&self, wordidx: usize) -> f64 {
892        // This is a simplified version; ideal implementation would track actual frequencies
893        // For now, we'll use the sampling table weights as a proxy
894        if let Some(table) = &self.sampling_table {
895            table.weights()[wordidx]
896        } else {
897            1.0 // Default weight if no sampling table exists
898        }
899    }
900
901    /// Train CBOW model on a single sentence
902    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
903        if sentence.len() < 2 {
904            return Ok(()); // Need at least 2 words for context
905        }
906
907        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
908        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
909        let vector_size = self.config.vector_size;
910        let window_size = self.config.window_size;
911        let negative_samples = self.config.negative_samples;
912
913        // For each position in sentence, predict the word from its context
914        for pos in 0..sentence.len() {
915            // Determine context window (with random size)
916            let mut rng = scirs2_core::random::rng();
917            let window = 1 + rng.random_range(0..window_size);
918            let target_word = sentence[pos];
919
920            // Collect context words and average their vectors
921            let mut context_words = Vec::new();
922            #[allow(clippy::needless_range_loop)]
923            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
924                if i != pos {
925                    context_words.push(sentence[i]);
926                }
927            }
928
929            if context_words.is_empty() {
930                continue; // No context words
931            }
932
933            // Average the context word vectors
934            let mut context_sum = Array1::zeros(vector_size);
935            for &context_idx in &context_words {
936                context_sum += &input_embeddings.row(context_idx);
937            }
938            let context_avg = &context_sum / context_words.len() as f64;
939
940            // Update target word's output embedding with positive example
941            let mut target_output = output_embeddings.row_mut(target_word);
942            let dot_product = (&context_avg * &target_output).sum();
943            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
944            let error = (1.0 - sigmoid) * self.current_learning_rate;
945
946            // Create a copy for update
947            let mut target_update = target_output.to_owned();
948            target_update.scaled_add(error, &context_avg);
949            target_output.assign(&target_update);
950
951            // Negative sampling
952            if let Some(sampler) = &self.sampling_table {
953                for _ in 0..negative_samples {
954                    let negative_idx = sampler.sample(&mut rng);
955                    if negative_idx == target_word {
956                        continue; // Skip if we sample the target word
957                    }
958
959                    let mut negative_output = output_embeddings.row_mut(negative_idx);
960                    let dot_product = (&context_avg * &negative_output).sum();
961                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
962                    let error = -sigmoid * self.current_learning_rate;
963
964                    // Create a copy for update
965                    let mut negative_update = negative_output.to_owned();
966                    negative_update.scaled_add(error, &context_avg);
967                    negative_output.assign(&negative_update);
968                }
969            }
970
971            // Update context word vectors
972            for &context_idx in &context_words {
973                let mut input_vec = input_embeddings.row_mut(context_idx);
974
975                // Positive example
976                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
977                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
978                let error =
979                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
980
981                // Create a copy for update
982                let mut input_update = input_vec.to_owned();
983                input_update.scaled_add(error, &output_embeddings.row(target_word));
984
985                // Negative examples
986                if let Some(sampler) = &self.sampling_table {
987                    for _ in 0..negative_samples {
988                        let negative_idx = sampler.sample(&mut rng);
989                        if negative_idx == target_word {
990                            continue;
991                        }
992
993                        let dot_product =
994                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
995                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
996                        let error =
997                            -sigmoid * self.current_learning_rate / context_words.len() as f64;
998
999                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
1000                    }
1001                }
1002
1003                input_vec.assign(&input_update);
1004            }
1005        }
1006
1007        Ok(())
1008    }
1009
1010    /// Train Skip-gram model on a single sentence
1011    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1012        if sentence.len() < 2 {
1013            return Ok(()); // Need at least 2 words for context
1014        }
1015
1016        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
1017        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
1018        let vector_size = self.config.vector_size;
1019        let window_size = self.config.window_size;
1020        let negative_samples = self.config.negative_samples;
1021
1022        // For each position in sentence, predict the context from the word
1023        for pos in 0..sentence.len() {
1024            // Determine context window (with random size)
1025            let mut rng = scirs2_core::random::rng();
1026            let window = 1 + rng.random_range(0..window_size);
1027            let target_word = sentence[pos];
1028
1029            // For each context position
1030            #[allow(clippy::needless_range_loop)]
1031            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1032                if i == pos {
1033                    continue; // Skip the target word itself
1034                }
1035
1036                let context_word = sentence[i];
1037                let target_input = input_embeddings.row(target_word);
1038                let mut context_output = output_embeddings.row_mut(context_word);
1039
1040                // Update context word's output embedding with positive example
1041                let dot_product = (&target_input * &context_output).sum();
1042                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1043                let error = (1.0 - sigmoid) * self.current_learning_rate;
1044
1045                // Create a copy for update
1046                let mut context_update = context_output.to_owned();
1047                context_update.scaled_add(error, &target_input);
1048                context_output.assign(&context_update);
1049
1050                // Gradient for input word vector
1051                let mut input_update = Array1::zeros(vector_size);
1052                input_update.scaled_add(error, &context_output);
1053
1054                // Negative sampling
1055                if let Some(sampler) = &self.sampling_table {
1056                    for _ in 0..negative_samples {
1057                        let negative_idx = sampler.sample(&mut rng);
1058                        if negative_idx == context_word {
1059                            continue; // Skip if we sample the context word
1060                        }
1061
1062                        let mut negative_output = output_embeddings.row_mut(negative_idx);
1063                        let dot_product = (&target_input * &negative_output).sum();
1064                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
1065                        let error = -sigmoid * self.current_learning_rate;
1066
1067                        // Create a copy for update
1068                        let mut negative_update = negative_output.to_owned();
1069                        negative_update.scaled_add(error, &target_input);
1070                        negative_output.assign(&negative_update);
1071
1072                        // Update input gradient
1073                        input_update.scaled_add(error, &negative_output);
1074                    }
1075                }
1076
1077                // Apply the accumulated gradient to the input word vector
1078                let mut target_input_mut = input_embeddings.row_mut(target_word);
1079                target_input_mut += &input_update;
1080            }
1081        }
1082
1083        Ok(())
1084    }
1085
1086    /// Get the vector size
1087    pub fn vector_size(&self) -> usize {
1088        self.config.vector_size
1089    }
1090
1091    /// Get the embedding vector for a word
1092    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
1093        if self.input_embeddings.is_none() {
1094            return Err(TextError::EmbeddingError(
1095                "Model not trained. Call train() first".into(),
1096            ));
1097        }
1098
1099        match self.vocabulary.get_index(word) {
1100            Some(idx) => Ok(self
1101                .input_embeddings
1102                .as_ref()
1103                .expect("Operation failed")
1104                .row(idx)
1105                .to_owned()),
1106            None => Err(TextError::VocabularyError(format!(
1107                "Word '{word}' not in vocabulary"
1108            ))),
1109        }
1110    }
1111
1112    /// Get the most similar words to a given word
1113    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1114        let word_vec = self.get_word_vector(word)?;
1115        self.most_similar_by_vector(&word_vec, topn, &[word])
1116    }
1117
1118    /// Get the most similar words to a given vector
1119    pub fn most_similar_by_vector(
1120        &self,
1121        vector: &Array1<f64>,
1122        top_n: usize,
1123        exclude_words: &[&str],
1124    ) -> Result<Vec<(String, f64)>> {
1125        if self.input_embeddings.is_none() {
1126            return Err(TextError::EmbeddingError(
1127                "Model not trained. Call train() first".into(),
1128            ));
1129        }
1130
1131        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1132        let vocab_size = self.vocabulary.len();
1133
1134        // Create a set of indices to exclude
1135        let exclude_indices: Vec<usize> = exclude_words
1136            .iter()
1137            .filter_map(|&word| self.vocabulary.get_index(word))
1138            .collect();
1139
1140        // Calculate cosine similarity for all _words
1141        let mut similarities = Vec::with_capacity(vocab_size);
1142
1143        for i in 0..vocab_size {
1144            if exclude_indices.contains(&i) {
1145                continue;
1146            }
1147
1148            let word_vec = input_embeddings.row(i);
1149            let similarity = cosine_similarity(vector, &word_vec.to_owned());
1150
1151            if let Some(word) = self.vocabulary.get_token(i) {
1152                similarities.push((word.to_string(), similarity));
1153            }
1154        }
1155
1156        // Sort by similarity (descending)
1157        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1158
1159        // Take top N
1160        let result = similarities.into_iter().take(top_n).collect();
1161        Ok(result)
1162    }
1163
1164    /// Compute the analogy: a is to b as c is to ?
1165    pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1166        if self.input_embeddings.is_none() {
1167            return Err(TextError::EmbeddingError(
1168                "Model not trained. Call train() first".into(),
1169            ));
1170        }
1171
1172        // Get vectors for a, b, and c
1173        let a_vec = self.get_word_vector(a)?;
1174        let b_vec = self.get_word_vector(b)?;
1175        let c_vec = self.get_word_vector(c)?;
1176
1177        // Compute d_vec = b_vec - a_vec + c_vec
1178        let mut d_vec = b_vec.clone();
1179        d_vec -= &a_vec;
1180        d_vec += &c_vec;
1181
1182        // Normalize the vector
1183        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1184        d_vec.mapv_inplace(|val| val / norm);
1185
1186        // Find most similar words to d_vec
1187        self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
1188    }
1189
1190    /// Save the Word2Vec model to a file
1191    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1192        if self.input_embeddings.is_none() {
1193            return Err(TextError::EmbeddingError(
1194                "Model not trained. Call train() first".into(),
1195            ));
1196        }
1197
1198        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
1199
1200        // Write header: vector_size and vocabulary size
1201        writeln!(
1202            &mut file,
1203            "{} {}",
1204            self.vocabulary.len(),
1205            self.config.vector_size
1206        )
1207        .map_err(|e| TextError::IoError(e.to_string()))?;
1208
1209        // Write each word and its vector
1210        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
1211
1212        for i in 0..self.vocabulary.len() {
1213            if let Some(word) = self.vocabulary.get_token(i) {
1214                // Write the word
1215                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
1216
1217                // Write the vector components
1218                let vector = input_embeddings.row(i);
1219                for j in 0..self.config.vector_size {
1220                    write!(&mut file, "{:.6} ", vector[j])
1221                        .map_err(|e| TextError::IoError(e.to_string()))?;
1222                }
1223
1224                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
1225            }
1226        }
1227
1228        Ok(())
1229    }
1230
1231    /// Load a Word2Vec model from a file
1232    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
1233        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
1234        let mut reader = BufReader::new(file);
1235
1236        // Read header
1237        let mut header = String::new();
1238        reader
1239            .read_line(&mut header)
1240            .map_err(|e| TextError::IoError(e.to_string()))?;
1241
1242        let parts: Vec<&str> = header.split_whitespace().collect();
1243        if parts.len() != 2 {
1244            return Err(TextError::EmbeddingError(
1245                "Invalid model file format".into(),
1246            ));
1247        }
1248
1249        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
1250            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
1251        })?;
1252
1253        let vector_size = parts[1]
1254            .parse::<usize>()
1255            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
1256
1257        // Initialize model
1258        let mut model = Self::new().with_vector_size(vector_size);
1259        let mut vocabulary = Vocabulary::new();
1260        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
1261
1262        // Read each word and its vector
1263        let mut i = 0;
1264        for line in reader.lines() {
1265            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
1266            let parts: Vec<&str> = line.split_whitespace().collect();
1267
1268            if parts.len() != vector_size + 1 {
1269                let line_num = i + 2;
1270                return Err(TextError::EmbeddingError(format!(
1271                    "Invalid vector format at line {line_num}"
1272                )));
1273            }
1274
1275            let word = parts[0];
1276            vocabulary.add_token(word);
1277
1278            for j in 0..vector_size {
1279                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
1280                    TextError::EmbeddingError(format!(
1281                        "Invalid vector component at line {}, position {}",
1282                        i + 2,
1283                        j + 1
1284                    ))
1285                })?;
1286            }
1287
1288            i += 1;
1289        }
1290
1291        if i != vocab_size {
1292            return Err(TextError::EmbeddingError(format!(
1293                "Expected {vocab_size} words but found {i}"
1294            )));
1295        }
1296
1297        model.vocabulary = vocabulary;
1298        model.input_embeddings = Some(input_embeddings);
1299        model.output_embeddings = None; // Only input embeddings are saved
1300
1301        Ok(model)
1302    }
1303
1304    // Getter methods for model registry serialization
1305
1306    /// Get the vocabulary as a vector of strings
1307    pub fn get_vocabulary(&self) -> Vec<String> {
1308        let mut vocab = Vec::new();
1309        for i in 0..self.vocabulary.len() {
1310            if let Some(token) = self.vocabulary.get_token(i) {
1311                vocab.push(token.to_string());
1312            }
1313        }
1314        vocab
1315    }
1316
1317    /// Get the vector size
1318    pub fn get_vector_size(&self) -> usize {
1319        self.config.vector_size
1320    }
1321
1322    /// Get the algorithm
1323    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
1324        self.config.algorithm
1325    }
1326
1327    /// Get the window size
1328    pub fn get_window_size(&self) -> usize {
1329        self.config.window_size
1330    }
1331
1332    /// Get the minimum count
1333    pub fn get_min_count(&self) -> usize {
1334        self.config.min_count
1335    }
1336
1337    /// Get the embeddings matrix (input embeddings)
1338    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
1339        self.input_embeddings.clone()
1340    }
1341
1342    /// Get the number of negative samples
1343    pub fn get_negative_samples(&self) -> usize {
1344        self.config.negative_samples
1345    }
1346
1347    /// Get the learning rate
1348    pub fn get_learning_rate(&self) -> f64 {
1349        self.config.learning_rate
1350    }
1351
1352    /// Get the number of epochs
1353    pub fn get_epochs(&self) -> usize {
1354        self.config.epochs
1355    }
1356
1357    /// Get the subsampling threshold
1358    pub fn get_subsampling_threshold(&self) -> f64 {
1359        self.config.subsample
1360    }
1361
1362    /// Check if hierarchical softmax is enabled
1363    pub fn uses_hierarchical_softmax(&self) -> bool {
1364        self.config.hierarchical_softmax
1365    }
1366
1367    // ─── Hierarchical Softmax Training ──────────────────────────────────
1368
1369    /// Train Skip-gram with hierarchical softmax on a single sentence
1370    fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1371        if sentence.len() < 2 {
1372            return Ok(());
1373        }
1374
1375        let input_embeddings = self
1376            .input_embeddings
1377            .as_mut()
1378            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1379        let hs_params = self
1380            .hs_params
1381            .as_mut()
1382            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1383        let tree = self
1384            .huffman_tree
1385            .as_ref()
1386            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1387
1388        let vector_size = self.config.vector_size;
1389        let window_size = self.config.window_size;
1390        let lr = self.current_learning_rate;
1391
1392        let codes = tree.codes.clone();
1393        let paths = tree.paths.clone();
1394
1395        let mut rng = scirs2_core::random::rng();
1396
1397        for pos in 0..sentence.len() {
1398            let window = 1 + rng.random_range(0..window_size);
1399            let target_word = sentence[pos];
1400
1401            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1402                if i == pos {
1403                    continue;
1404                }
1405
1406                let context_word = sentence[i];
1407                let code = &codes[context_word];
1408                let path = &paths[context_word];
1409
1410                let mut grad_input = Array1::zeros(vector_size);
1411
1412                // Walk the Huffman tree path for the context word
1413                for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1414                    if node_idx >= hs_params.nrows() {
1415                        continue;
1416                    }
1417
1418                    // Compute sigmoid(input_vec . hs_param)
1419                    let input_vec = input_embeddings.row(target_word);
1420                    let param_vec = hs_params.row(node_idx);
1421
1422                    let dot: f64 = input_vec
1423                        .iter()
1424                        .zip(param_vec.iter())
1425                        .map(|(a, b)| a * b)
1426                        .sum();
1427                    let sigmoid = 1.0 / (1.0 + (-dot).exp());
1428
1429                    // gradient: (1 - label - sigmoid) * lr
1430                    let target = if label == 0 { 1.0 } else { 0.0 };
1431                    let gradient = (target - sigmoid) * lr;
1432
1433                    // Update gradient for input vector
1434                    grad_input.scaled_add(gradient, &param_vec.to_owned());
1435
1436                    // Update HS parameter vector
1437                    let input_owned = input_vec.to_owned();
1438                    let mut param_mut = hs_params.row_mut(node_idx);
1439                    param_mut.scaled_add(gradient, &input_owned);
1440                }
1441
1442                // Update input embedding
1443                let mut input_mut = input_embeddings.row_mut(target_word);
1444                input_mut += &grad_input;
1445            }
1446        }
1447
1448        Ok(())
1449    }
1450
1451    /// Train CBOW with hierarchical softmax on a single sentence
1452    fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1453        if sentence.len() < 2 {
1454            return Ok(());
1455        }
1456
1457        let input_embeddings = self
1458            .input_embeddings
1459            .as_mut()
1460            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1461        let hs_params = self
1462            .hs_params
1463            .as_mut()
1464            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1465        let tree = self
1466            .huffman_tree
1467            .as_ref()
1468            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1469
1470        let vector_size = self.config.vector_size;
1471        let window_size = self.config.window_size;
1472        let lr = self.current_learning_rate;
1473
1474        let codes = tree.codes.clone();
1475        let paths = tree.paths.clone();
1476
1477        let mut rng = scirs2_core::random::rng();
1478
1479        for pos in 0..sentence.len() {
1480            let window = 1 + rng.random_range(0..window_size);
1481            let target_word = sentence[pos];
1482
1483            // Collect context words
1484            let mut context_words = Vec::new();
1485            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1486                if i != pos {
1487                    context_words.push(sentence[i]);
1488                }
1489            }
1490
1491            if context_words.is_empty() {
1492                continue;
1493            }
1494
1495            // Average context word vectors
1496            let mut context_avg = Array1::zeros(vector_size);
1497            for &ctx_idx in &context_words {
1498                context_avg += &input_embeddings.row(ctx_idx);
1499            }
1500            context_avg /= context_words.len() as f64;
1501
1502            // Walk Huffman path for target word
1503            let code = &codes[target_word];
1504            let path = &paths[target_word];
1505
1506            let mut grad_context = Array1::zeros(vector_size);
1507
1508            for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1509                if node_idx >= hs_params.nrows() {
1510                    continue;
1511                }
1512
1513                let param_vec = hs_params.row(node_idx);
1514
1515                let dot: f64 = context_avg
1516                    .iter()
1517                    .zip(param_vec.iter())
1518                    .map(|(a, b)| a * b)
1519                    .sum();
1520                let sigmoid = 1.0 / (1.0 + (-dot).exp());
1521
1522                let target = if label == 0 { 1.0 } else { 0.0 };
1523                let gradient = (target - sigmoid) * lr;
1524
1525                grad_context.scaled_add(gradient, &param_vec.to_owned());
1526
1527                // Update HS parameter
1528                let ctx_owned = context_avg.clone();
1529                let mut param_mut = hs_params.row_mut(node_idx);
1530                param_mut.scaled_add(gradient, &ctx_owned);
1531            }
1532
1533            // Distribute gradient back to context word input embeddings
1534            let grad_per_word = &grad_context / context_words.len() as f64;
1535            for &ctx_idx in &context_words {
1536                let mut input_mut = input_embeddings.row_mut(ctx_idx);
1537                input_mut += &grad_per_word;
1538            }
1539        }
1540
1541        Ok(())
1542    }
1543}
1544
1545// ─── WordEmbedding trait implementation for Word2Vec ─────────────────────────
1546
1547impl WordEmbedding for Word2Vec {
1548    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
1549        self.get_word_vector(word)
1550    }
1551
1552    fn dimension(&self) -> usize {
1553        self.vector_size()
1554    }
1555
1556    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1557        self.most_similar(word, top_n)
1558    }
1559
1560    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1561        self.analogy(a, b, c, top_n)
1562    }
1563
1564    fn vocab_size(&self) -> usize {
1565        self.vocabulary.len()
1566    }
1567}
1568
1569/// Calculate cosine similarity between two vectors
1570#[allow(dead_code)]
1571pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1572    let dot_product = (a * b).sum();
1573    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1574    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1575
1576    if norm_a > 0.0 && norm_b > 0.0 {
1577        dot_product / (norm_a * norm_b)
1578    } else {
1579        0.0
1580    }
1581}
1582
1583#[cfg(test)]
1584mod tests {
1585    use super::*;
1586    use approx::assert_relative_eq;
1587
1588    #[test]
1589    fn test_cosine_similarity() {
1590        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1591        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
1592
1593        let similarity = cosine_similarity(&a, &b);
1594        let expected = 0.9746318461970762;
1595        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
1596    }
1597
1598    #[test]
1599    fn test_word2vec_config() {
1600        let config = Word2VecConfig::default();
1601        assert_eq!(config.vector_size, 100);
1602        assert_eq!(config.window_size, 5);
1603        assert_eq!(config.min_count, 5);
1604        assert_eq!(config.epochs, 5);
1605        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
1606    }
1607
1608    #[test]
1609    fn test_word2vec_builder() {
1610        let model = Word2Vec::new()
1611            .with_vector_size(200)
1612            .with_window_size(10)
1613            .with_learning_rate(0.05)
1614            .with_algorithm(Word2VecAlgorithm::CBOW);
1615
1616        assert_eq!(model.config.vector_size, 200);
1617        assert_eq!(model.config.window_size, 10);
1618        assert_eq!(model.config.learning_rate, 0.05);
1619        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
1620    }
1621
1622    #[test]
1623    fn test_build_vocabulary() {
1624        let texts = [
1625            "the quick brown fox jumps over the lazy dog",
1626            "a quick brown fox jumps over a lazy dog",
1627        ];
1628
1629        let mut model = Word2Vec::new().with_min_count(1);
1630        let result = model.build_vocabulary(&texts);
1631        assert!(result.is_ok());
1632
1633        // Check vocabulary size (unique words: "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", "a")
1634        assert_eq!(model.vocabulary.len(), 9);
1635
1636        // Check that embeddings were initialized
1637        assert!(model.input_embeddings.is_some());
1638        assert!(model.output_embeddings.is_some());
1639        assert_eq!(
1640            model
1641                .input_embeddings
1642                .as_ref()
1643                .expect("Operation failed")
1644                .shape(),
1645            &[9, 100]
1646        );
1647    }
1648
1649    #[test]
1650    fn test_skipgram_training_small() {
1651        let texts = [
1652            "the quick brown fox jumps over the lazy dog",
1653            "a quick brown fox jumps over a lazy dog",
1654        ];
1655
1656        let mut model = Word2Vec::new()
1657            .with_vector_size(10)
1658            .with_window_size(2)
1659            .with_min_count(1)
1660            .with_epochs(1)
1661            .with_algorithm(Word2VecAlgorithm::SkipGram);
1662
1663        let result = model.train(&texts);
1664        assert!(result.is_ok());
1665
1666        // Test getting a word vector
1667        let result = model.get_word_vector("fox");
1668        assert!(result.is_ok());
1669        let vec = result.expect("Operation failed");
1670        assert_eq!(vec.len(), 10);
1671    }
1672
1673    // ─── Hierarchical Softmax Tests ──────────────────────────────────
1674
1675    #[test]
1676    fn test_huffman_tree_build() {
1677        let frequencies = vec![5, 3, 8, 1, 2];
1678        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1679
1680        // Each word should have a code
1681        assert_eq!(tree.codes.len(), 5);
1682        assert_eq!(tree.paths.len(), 5);
1683
1684        // All codes should be non-empty
1685        for code in &tree.codes {
1686            assert!(!code.is_empty());
1687        }
1688
1689        // Internal nodes = vocab_size - 1 (for a binary tree)
1690        assert_eq!(tree.num_internal, 4);
1691    }
1692
1693    #[test]
1694    fn test_huffman_tree_single_word() {
1695        let frequencies = vec![10];
1696        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
1697        assert_eq!(tree.codes.len(), 1);
1698        assert_eq!(tree.paths.len(), 1);
1699    }
1700
1701    #[test]
1702    fn test_skipgram_hierarchical_softmax() {
1703        let texts = [
1704            "the quick brown fox jumps over the lazy dog",
1705            "a quick brown fox jumps over a lazy dog",
1706        ];
1707
1708        let config = Word2VecConfig {
1709            vector_size: 10,
1710            window_size: 2,
1711            min_count: 1,
1712            epochs: 3,
1713            learning_rate: 0.025,
1714            algorithm: Word2VecAlgorithm::SkipGram,
1715            hierarchical_softmax: true,
1716            ..Default::default()
1717        };
1718
1719        let mut model = Word2Vec::with_config(config);
1720        let result = model.train(&texts);
1721        assert!(
1722            result.is_ok(),
1723            "HS skipgram training failed: {:?}",
1724            result.err()
1725        );
1726
1727        assert!(model.uses_hierarchical_softmax());
1728
1729        // Should produce valid word vectors
1730        let vec = model.get_word_vector("fox");
1731        assert!(vec.is_ok());
1732        assert_eq!(vec.expect("get vec").len(), 10);
1733    }
1734
1735    #[test]
1736    fn test_cbow_hierarchical_softmax() {
1737        let texts = [
1738            "the quick brown fox jumps over the lazy dog",
1739            "a quick brown fox jumps over a lazy dog",
1740        ];
1741
1742        let config = Word2VecConfig {
1743            vector_size: 10,
1744            window_size: 2,
1745            min_count: 1,
1746            epochs: 3,
1747            learning_rate: 0.025,
1748            algorithm: Word2VecAlgorithm::CBOW,
1749            hierarchical_softmax: true,
1750            ..Default::default()
1751        };
1752
1753        let mut model = Word2Vec::with_config(config);
1754        let result = model.train(&texts);
1755        assert!(
1756            result.is_ok(),
1757            "HS CBOW training failed: {:?}",
1758            result.err()
1759        );
1760
1761        let vec = model.get_word_vector("dog");
1762        assert!(vec.is_ok());
1763    }
1764
1765    // ─── WordEmbedding Trait Tests ──────────────────────────────────
1766
1767    #[test]
1768    fn test_word_embedding_trait_word2vec() {
1769        let texts = [
1770            "the quick brown fox jumps over the lazy dog",
1771            "a quick brown fox jumps over a lazy dog",
1772        ];
1773
1774        let mut model = Word2Vec::new()
1775            .with_vector_size(10)
1776            .with_min_count(1)
1777            .with_epochs(1);
1778
1779        model.train(&texts).expect("Training failed");
1780
1781        // Use via trait
1782        let emb: &dyn WordEmbedding = &model;
1783        assert_eq!(emb.dimension(), 10);
1784        assert!(emb.vocab_size() > 0);
1785
1786        let vec = emb.embedding("fox");
1787        assert!(vec.is_ok());
1788
1789        let sim = emb.similarity("fox", "dog");
1790        assert!(sim.is_ok());
1791        assert!(sim.expect("sim").is_finite());
1792
1793        let similar = emb.find_similar("fox", 2);
1794        assert!(similar.is_ok());
1795
1796        let analogy = emb.solve_analogy("the", "fox", "dog", 2);
1797        assert!(analogy.is_ok());
1798    }
1799
1800    #[test]
1801    fn test_embedding_cosine_similarity_fn() {
1802        let a = Array1::from_vec(vec![1.0, 0.0]);
1803        let b = Array1::from_vec(vec![0.0, 1.0]);
1804        assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
1805
1806        let c = Array1::from_vec(vec![1.0, 1.0]);
1807        let d = Array1::from_vec(vec![1.0, 1.0]);
1808        assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
1809    }
1810
1811    #[test]
1812    fn test_pairwise_similarity_fn() {
1813        let texts = ["the quick brown fox", "the lazy brown dog"];
1814
1815        let mut model = Word2Vec::new()
1816            .with_vector_size(10)
1817            .with_min_count(1)
1818            .with_epochs(1);
1819        model.train(&texts).expect("Training failed");
1820
1821        let words = vec!["the", "fox", "dog"];
1822        let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");
1823
1824        assert_eq!(matrix.len(), 3);
1825        assert_eq!(matrix[0].len(), 3);
1826
1827        // Diagonal should be 1.0
1828        for i in 0..3 {
1829            assert!((matrix[i][i] - 1.0).abs() < 1e-6);
1830        }
1831
1832        // Symmetric
1833        for i in 0..3 {
1834            for j in 0..3 {
1835                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
1836            }
1837        }
1838    }
1839}