Skip to main content

scirs2_text/embeddings/
fasttext.rs

1//! FastText embeddings with character n-grams
2//!
3//! This module implements FastText, an extension of Word2Vec that learns
4//! word representations as bags of character n-grams. This approach handles
5//! out-of-vocabulary words and morphologically rich languages better.
6//!
7//! ## Overview
8//!
9//! FastText represents each word as a bag of character n-grams. For example:
10//! - word: "where"
11//! - 3-grams: "<wh", "whe", "her", "ere", "re>"
//! - The word embedding is the average of its n-gram embeddings (plus the whole-word vector when the word is in the vocabulary)
13//!
14//! Key advantages over vanilla Word2Vec:
15//! - Handles out-of-vocabulary (OOV) words via subword decomposition
16//! - Better representations for morphologically rich languages
17//! - Captures internal word structure (prefixes, suffixes, roots)
18//!
19//! ## Quick Start
20//!
21//! ```rust
22//! use scirs2_text::embeddings::fasttext::{FastText, FastTextConfig};
23//!
24//! // Create configuration
25//! let config = FastTextConfig {
26//!     vector_size: 100,
27//!     min_n: 3,
28//!     max_n: 6,
29//!     window_size: 5,
30//!     epochs: 5,
31//!     learning_rate: 0.05,
32//!     min_count: 1,
33//!     negative_samples: 5,
34//!     ..Default::default()
35//! };
36//!
37//! // Train model
38//! let documents = vec![
39//!     "the quick brown fox jumps over the lazy dog",
40//!     "a quick brown dog outpaces a quick fox"
41//! ];
42//!
43//! let mut model = FastText::with_config(config);
44//! model.train(&documents).expect("Training failed");
45//!
46//! // Get word vector (works even for OOV words!)
47//! if let Ok(vector) = model.get_word_vector("quickest") {
48//!     println!("Vector for OOV word 'quickest': {:?}", vector);
49//! }
50//! ```
51
52use crate::error::{Result, TextError};
53use crate::tokenize::{Tokenizer, WordTokenizer};
54use crate::vocabulary::Vocabulary;
55use scirs2_core::ndarray::{Array1, Array2};
56use scirs2_core::random::prelude::*;
57use std::collections::HashMap;
58use std::fmt::Debug;
59use std::fs::File;
60use std::io::{BufRead, BufReader, Write};
61use std::path::Path;
62
/// Hyper-parameters for [`FastText`].
///
/// Build one with struct-update syntax over the defaults, e.g.
/// `FastTextConfig { vector_size: 50, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct FastTextConfig {
    /// Dimensionality of every word and n-gram vector.
    pub vector_size: usize,
    /// Shortest character n-gram to extract, counted on the word wrapped in `<` and `>`.
    pub min_n: usize,
    /// Longest character n-gram to extract, counted on the word wrapped in `<` and `>`.
    pub max_n: usize,
    /// Maximum context radius; each target word samples an effective radius in
    /// `1..=window_size` during training.
    pub window_size: usize,
    /// Number of passes over the training texts.
    pub epochs: usize,
    /// Initial learning rate; training decays it linearly per epoch.
    pub learning_rate: f64,
    /// Words seen fewer times than this are left out of the vocabulary.
    pub min_count: usize,
    /// Negative samples drawn per (target, context) pair.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words.
    ///
    /// NOTE(review): not read by `build_vocabulary`, `train` or `train_sentence` as
    /// currently written — confirm before relying on it.
    pub subsample: f64,
    /// Number of hash buckets for n-gram embeddings. The n-gram matrix holds
    /// `bucket_size * vector_size` `f64` values, so the default (2,000,000 x 100)
    /// needs roughly 1.6 GB.
    pub bucket_size: usize,
}

impl Default for FastTextConfig {
    /// 100-dimensional vectors, 3..=6 character n-grams, a window of 5, 5 epochs at a
    /// learning rate of 0.05, `min_count` of 5, 5 negative samples, subsampling at 1e-3
    /// and 2,000,000 n-gram buckets.
    fn default() -> Self {
        let (min_n, max_n) = (3, 6);
        let (epochs, learning_rate) = (5, 0.05);
        Self {
            min_n,
            max_n,
            epochs,
            learning_rate,
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            negative_samples: 5,
            subsample: 1e-3,
            bucket_size: 2_000_000,
        }
    }
}
104
/// FastText model for learning word representations with character n-grams.
///
/// Decomposes each word into character n-grams (subwords) and learns embeddings
/// for both whole words and their constituent n-grams. This enables:
///
/// - Out-of-vocabulary word handling (any word can be represented via its n-grams)
/// - Morphological awareness (similar prefixes/suffixes produce similar vectors)
/// - Robustness to misspellings and rare word forms
///
/// A word's vector is the mean of its own row (when it is in the vocabulary) and the
/// rows of its n-gram buckets.
pub struct FastText {
    /// Hyper-parameters (vector size, n-gram range, training schedule, bucket count).
    config: FastTextConfig,
    /// Vocabulary of words; its indices are the row indices of `word_embeddings`,
    /// `output_embeddings` and `sampling_weights`.
    vocabulary: Vocabulary,
    /// Raw token counts from the texts the vocabulary was built from, including tokens
    /// below `min_count`.
    word_counts: HashMap<String, usize>,
    /// Input word embeddings, one row per vocabulary word (`None` until built or loaded).
    word_embeddings: Option<Array2<f64>>,
    /// Output ("context") embeddings used by skip-gram negative sampling; zero-initialised
    /// when the vocabulary is built.
    output_embeddings: Option<Array2<f64>>,
    /// N-gram (subword) embeddings, one row per hash bucket (`bucket_size` rows).
    ngram_embeddings: Option<Array2<f64>>,
    /// Cached n-gram -> bucket index for the n-grams of vocabulary words.
    ngram_to_bucket: HashMap<String, usize>,
    /// Tokenizer used to split texts into words. Note: `Clone` does not copy it; clones get
    /// `WordTokenizer::default()`.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Learning rate for the current epoch, updated by `train`.
    current_learning_rate: f64,
    /// Unnormalised negative-sampling weights per vocabulary index (`count^0.75`).
    sampling_weights: Vec<f64>,
}
135
136impl Debug for FastText {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("FastText")
139            .field("config", &self.config)
140            .field("vocabulary_size", &self.vocabulary.len())
141            .field("word_embeddings", &self.word_embeddings.is_some())
142            .field("ngram_embeddings", &self.ngram_embeddings.is_some())
143            .field("ngram_count", &self.ngram_to_bucket.len())
144            .finish()
145    }
146}
147
148impl Clone for FastText {
149    fn clone(&self) -> Self {
150        Self {
151            config: self.config.clone(),
152            vocabulary: self.vocabulary.clone(),
153            word_counts: self.word_counts.clone(),
154            word_embeddings: self.word_embeddings.clone(),
155            output_embeddings: self.output_embeddings.clone(),
156            ngram_embeddings: self.ngram_embeddings.clone(),
157            ngram_to_bucket: self.ngram_to_bucket.clone(),
158            tokenizer: Box::new(WordTokenizer::default()),
159            current_learning_rate: self.current_learning_rate,
160            sampling_weights: self.sampling_weights.clone(),
161        }
162    }
163}
164
165impl FastText {
166    /// Create a new FastText model with default configuration
167    pub fn new() -> Self {
168        Self {
169            config: FastTextConfig::default(),
170            vocabulary: Vocabulary::new(),
171            word_counts: HashMap::new(),
172            word_embeddings: None,
173            output_embeddings: None,
174            ngram_embeddings: None,
175            ngram_to_bucket: HashMap::new(),
176            tokenizer: Box::new(WordTokenizer::default()),
177            current_learning_rate: 0.05,
178            sampling_weights: Vec::new(),
179        }
180    }
181
182    /// Create a new FastText model with custom configuration
183    pub fn with_config(config: FastTextConfig) -> Self {
184        let learning_rate = config.learning_rate;
185        Self {
186            config,
187            vocabulary: Vocabulary::new(),
188            word_counts: HashMap::new(),
189            word_embeddings: None,
190            output_embeddings: None,
191            ngram_embeddings: None,
192            ngram_to_bucket: HashMap::new(),
193            tokenizer: Box::new(WordTokenizer::default()),
194            current_learning_rate: learning_rate,
195            sampling_weights: Vec::new(),
196        }
197    }
198
199    /// Set a custom tokenizer
200    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
201        self.tokenizer = tokenizer;
202        self
203    }
204
205    /// Extract character n-grams from a word
206    ///
207    /// Wraps the word with boundary markers < and > before extracting.
208    /// For example, "fox" with min_n=3, max_n=4 produces:
209    /// 3-grams: "<fo", "fox", "ox>"
210    /// 4-grams: "<fox", "fox>", "`<fox>`"(if len allows)
211    pub fn extract_ngrams(&self, word: &str) -> Vec<String> {
212        let word_with_boundaries = format!("<{}>", word);
213        let chars: Vec<char> = word_with_boundaries.chars().collect();
214        let mut ngrams = Vec::new();
215
216        for n in self.config.min_n..=self.config.max_n {
217            if chars.len() < n {
218                continue;
219            }
220
221            for i in 0..=(chars.len() - n) {
222                let ngram: String = chars[i..i + n].iter().collect();
223                ngrams.push(ngram);
224            }
225        }
226
227        ngrams
228    }
229
230    /// Hash an n-gram to a bucket index using FNV-1a
231    fn hash_ngram(&self, ngram: &str) -> usize {
232        let mut hash: u64 = 2166136261;
233        for byte in ngram.bytes() {
234            hash ^= u64::from(byte);
235            hash = hash.wrapping_mul(16777619);
236        }
237        (hash % (self.config.bucket_size as u64)) as usize
238    }
239
240    /// Build vocabulary from texts
241    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
242        if texts.is_empty() {
243            return Err(TextError::InvalidInput(
244                "No texts provided for building vocabulary".into(),
245            ));
246        }
247
248        // Count word frequencies
249        let mut word_counts = HashMap::new();
250
251        for &text in texts {
252            let tokens = self.tokenizer.tokenize(text)?;
253            for token in tokens {
254                *word_counts.entry(token).or_insert(0) += 1;
255            }
256        }
257
258        // Build vocabulary with min_count threshold
259        self.vocabulary = Vocabulary::new();
260        for (word, count) in &word_counts {
261            if *count >= self.config.min_count {
262                self.vocabulary.add_token(word);
263            }
264        }
265
266        if self.vocabulary.is_empty() {
267            return Err(TextError::VocabularyError(
268                "No words meet the minimum count threshold".into(),
269            ));
270        }
271
272        self.word_counts = word_counts;
273
274        // Initialize embeddings
275        let vocab_size = self.vocabulary.len();
276        let vector_size = self.config.vector_size;
277        let bucket_size = self.config.bucket_size;
278
279        let mut rng = scirs2_core::random::rng();
280
281        // Initialize word embeddings
282        let word_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
283            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
284        });
285
286        // Initialize output embeddings for negative sampling
287        let output_embeddings = Array2::zeros((vocab_size, vector_size));
288
289        // Initialize n-gram embeddings
290        let ngram_embeddings = Array2::from_shape_fn((bucket_size, vector_size), |_| {
291            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
292        });
293
294        self.word_embeddings = Some(word_embeddings);
295        self.output_embeddings = Some(output_embeddings);
296        self.ngram_embeddings = Some(ngram_embeddings);
297
298        // Build n-gram to bucket mapping
299        self.ngram_to_bucket.clear();
300        for i in 0..self.vocabulary.len() {
301            if let Some(word) = self.vocabulary.get_token(i) {
302                let ngrams = self.extract_ngrams(word);
303                for ngram in ngrams {
304                    if !self.ngram_to_bucket.contains_key(&ngram) {
305                        let bucket = self.hash_ngram(&ngram);
306                        self.ngram_to_bucket.insert(ngram, bucket);
307                    }
308                }
309            }
310        }
311
312        // Build negative sampling weights (unigram distribution ^ 0.75)
313        self.sampling_weights = vec![0.0; vocab_size];
314        for i in 0..vocab_size {
315            if let Some(word) = self.vocabulary.get_token(i) {
316                let count = self.word_counts.get(word).copied().unwrap_or(1);
317                self.sampling_weights[i] = (count as f64).powf(0.75);
318            }
319        }
320
321        Ok(())
322    }
323
324    /// Sample a negative example using the unigram distribution
325    fn sample_negative(&self, rng: &mut impl Rng) -> usize {
326        if self.sampling_weights.is_empty() {
327            return 0;
328        }
329        let total: f64 = self.sampling_weights.iter().sum();
330        if total <= 0.0 {
331            return rng.random_range(0..self.vocabulary.len().max(1));
332        }
333        let r = rng.random::<f64>() * total;
334        let mut cumulative = 0.0;
335        for (i, &w) in self.sampling_weights.iter().enumerate() {
336            cumulative += w;
337            if r <= cumulative {
338                return i;
339            }
340        }
341        self.sampling_weights.len() - 1
342    }
343
344    /// Compute the full subword-aware representation for a word index
345    ///
346    /// Returns the average of the word vector and all its n-gram vectors.
347    fn compute_word_representation(&self, word_idx: usize) -> Result<(Array1<f64>, Vec<usize>)> {
348        let word_embeddings = self
349            .word_embeddings
350            .as_ref()
351            .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
352        let ngram_embeddings = self
353            .ngram_embeddings
354            .as_ref()
355            .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;
356
357        let word = self
358            .vocabulary
359            .get_token(word_idx)
360            .ok_or_else(|| TextError::VocabularyError("Invalid word index".into()))?;
361
362        let ngrams = self.extract_ngrams(word);
363        let ngram_buckets: Vec<usize> = ngrams
364            .iter()
365            .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
366            .collect();
367
368        let mut vec = word_embeddings.row(word_idx).to_owned();
369        for &bucket in &ngram_buckets {
370            vec += &ngram_embeddings.row(bucket);
371        }
372        let divisor = 1.0 + ngram_buckets.len() as f64;
373        vec /= divisor;
374
375        Ok((vec, ngram_buckets))
376    }
377
378    /// Train the FastText model
379    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
380        if texts.is_empty() {
381            return Err(TextError::InvalidInput(
382                "No texts provided for training".into(),
383            ));
384        }
385
386        // Build vocabulary if not already built
387        if self.vocabulary.is_empty() {
388            self.build_vocabulary(texts)?;
389        }
390
391        // Prepare training data
392        let mut sentences = Vec::new();
393        for &text in texts {
394            let tokens = self.tokenizer.tokenize(text)?;
395            let word_indices: Vec<usize> = tokens
396                .iter()
397                .filter_map(|token| self.vocabulary.get_index(token))
398                .collect();
399            if !word_indices.is_empty() {
400                sentences.push(word_indices);
401            }
402        }
403
404        // Pre-compute n-gram buckets for all words
405        let mut word_ngram_buckets: Vec<Vec<usize>> = Vec::with_capacity(self.vocabulary.len());
406        for i in 0..self.vocabulary.len() {
407            if let Some(word) = self.vocabulary.get_token(i) {
408                let ngrams = self.extract_ngrams(word);
409                let buckets: Vec<usize> = ngrams
410                    .iter()
411                    .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
412                    .collect();
413                word_ngram_buckets.push(buckets);
414            } else {
415                word_ngram_buckets.push(Vec::new());
416            }
417        }
418
419        // Training loop
420        for epoch in 0..self.config.epochs {
421            // Update learning rate
422            self.current_learning_rate =
423                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
424            self.current_learning_rate = self
425                .current_learning_rate
426                .max(self.config.learning_rate * 0.0001);
427
428            // Train on each sentence
429            for sentence in &sentences {
430                self.train_sentence(sentence, &word_ngram_buckets)?;
431            }
432        }
433
434        Ok(())
435    }
436
437    /// Train on a single sentence using skip-gram with negative sampling
438    fn train_sentence(
439        &mut self,
440        sentence: &[usize],
441        word_ngram_buckets: &[Vec<usize>],
442    ) -> Result<()> {
443        if sentence.len() < 2 {
444            return Ok(());
445        }
446
447        // Clone sampling weights to avoid borrow conflict with self.sample_negative()
448        let sampling_weights = self.sampling_weights.clone();
449        let vocab_len = self.vocabulary.len().max(1);
450        let negative_samples = self.config.negative_samples;
451        let current_lr = self.current_learning_rate;
452
453        let word_embeddings = self
454            .word_embeddings
455            .as_mut()
456            .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
457        let output_embeddings = self
458            .output_embeddings
459            .as_mut()
460            .ok_or_else(|| TextError::EmbeddingError("Output embeddings not initialized".into()))?;
461        let ngram_embeddings = self
462            .ngram_embeddings
463            .as_mut()
464            .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;
465
466        let vector_size = self.config.vector_size;
467        let mut rng = scirs2_core::random::rng();
468
469        // Pre-compute cumulative distribution for negative sampling to avoid borrow conflict
470        let total_weight: f64 = sampling_weights.iter().sum();
471        let cumulative_weights: Vec<f64> = if total_weight > 0.0 {
472            let mut cum = Vec::with_capacity(sampling_weights.len());
473            let mut acc = 0.0;
474            for &w in &sampling_weights {
475                acc += w;
476                cum.push(acc);
477            }
478            cum
479        } else {
480            Vec::new()
481        };
482
483        // Skip-gram training for each word in the sentence
484        for (pos, &target_idx) in sentence.iter().enumerate() {
485            // Random window size
486            let window = 1 + rng.random_range(0..self.config.window_size);
487
488            // Get n-gram buckets for this target word
489            let ngram_buckets = &word_ngram_buckets[target_idx];
490
491            // Compute the subword-aware input representation
492            let mut input_vec = word_embeddings.row(target_idx).to_owned();
493            for &bucket in ngram_buckets {
494                input_vec += &ngram_embeddings.row(bucket);
495            }
496            let divisor = 1.0 + ngram_buckets.len() as f64;
497            input_vec /= divisor;
498
499            // For each context word in window
500            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
501                if i == pos {
502                    continue;
503                }
504
505                let context_idx = sentence[i];
506
507                // Accumulated gradient for the input vector
508                let mut grad_input = Array1::zeros(vector_size);
509
510                // Positive example
511                let output_vec = output_embeddings.row(context_idx).to_owned();
512                let dot_product: f64 = input_vec
513                    .iter()
514                    .zip(output_vec.iter())
515                    .map(|(a, b)| a * b)
516                    .sum();
517                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
518                let gradient = (1.0 - sigmoid) * current_lr;
519
520                // Accumulate gradient for input
521                grad_input.scaled_add(gradient, &output_vec);
522
523                // Update output embedding for positive example
524                let mut out_row = output_embeddings.row_mut(context_idx);
525                let update = &input_vec * gradient;
526                out_row += &update;
527
528                // Negative sampling
529                for _ in 0..negative_samples {
530                    let neg_idx = if cumulative_weights.is_empty() {
531                        if vocab_len > 0 {
532                            rng.random_range(0..vocab_len)
533                        } else {
534                            0
535                        }
536                    } else {
537                        let r = rng.random::<f64>() * total_weight;
538                        match cumulative_weights.binary_search_by(|w| {
539                            w.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
540                        }) {
541                            Ok(i) => i,
542                            Err(i) => i.min(cumulative_weights.len() - 1),
543                        }
544                    };
545                    if neg_idx == context_idx {
546                        continue;
547                    }
548
549                    let neg_vec = output_embeddings.row(neg_idx).to_owned();
550                    let dot_product: f64 = input_vec
551                        .iter()
552                        .zip(neg_vec.iter())
553                        .map(|(a, b)| a * b)
554                        .sum();
555                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
556                    let gradient = -sigmoid * current_lr;
557
558                    // Accumulate gradient for input
559                    grad_input.scaled_add(gradient, &neg_vec);
560
561                    // Update output embedding for negative example
562                    let mut neg_row = output_embeddings.row_mut(neg_idx);
563                    let neg_update = &input_vec * gradient;
564                    neg_row += &neg_update;
565                }
566
567                // Distribute gradient back to word embedding and n-gram embeddings
568                let scaled_grad = &grad_input / divisor;
569
570                let mut word_row = word_embeddings.row_mut(target_idx);
571                word_row += &scaled_grad;
572
573                for &bucket in ngram_buckets {
574                    let mut ngram_row = ngram_embeddings.row_mut(bucket);
575                    ngram_row += &scaled_grad;
576                }
577            }
578        }
579
580        Ok(())
581    }
582
583    /// Get the embedding vector for a word (handles OOV words via subwords)
584    ///
585    /// For in-vocabulary words, returns the average of the word vector and its n-gram vectors.
586    /// For OOV words, returns the average of matching n-gram vectors.
587    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
588        let word_embeddings = self
589            .word_embeddings
590            .as_ref()
591            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
592        let ngram_embeddings = self
593            .ngram_embeddings
594            .as_ref()
595            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
596
597        let ngrams = self.extract_ngrams(word);
598        let mut vector = Array1::zeros(self.config.vector_size);
599        let mut count = 0.0;
600
601        // Add word embedding if in vocabulary
602        if let Some(idx) = self.vocabulary.get_index(word) {
603            vector += &word_embeddings.row(idx);
604            count += 1.0;
605        }
606
607        // Add n-gram embeddings (always, even for in-vocab words)
608        for ngram in &ngrams {
609            if let Some(&bucket) = self.ngram_to_bucket.get(ngram) {
610                vector += &ngram_embeddings.row(bucket);
611                count += 1.0;
612            } else {
613                // For OOV words, hash the n-gram directly
614                let bucket = self.hash_ngram(ngram);
615                if bucket < self.config.bucket_size {
616                    vector += &ngram_embeddings.row(bucket);
617                    count += 1.0;
618                }
619            }
620        }
621
622        if count > 0.0 {
623            vector /= count;
624            Ok(vector)
625        } else {
626            Err(TextError::VocabularyError(format!(
627                "Cannot compute vector for word '{}': no n-grams found",
628                word
629            )))
630        }
631    }
632
633    /// Find most similar words to a given word
634    pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
635        let word_vec = self.get_word_vector(word)?;
636        self.most_similar_by_vector(&word_vec, top_n, &[word])
637    }
638
639    /// Find most similar words to a given vector
640    pub fn most_similar_by_vector(
641        &self,
642        vector: &Array1<f64>,
643        top_n: usize,
644        exclude_words: &[&str],
645    ) -> Result<Vec<(String, f64)>> {
646        let mut similarities = Vec::new();
647
648        for i in 0..self.vocabulary.len() {
649            if let Some(candidate) = self.vocabulary.get_token(i) {
650                if exclude_words.contains(&candidate) {
651                    continue;
652                }
653
654                if let Ok(candidate_vec) = self.get_word_vector(candidate) {
655                    let similarity = cosine_similarity(vector, &candidate_vec);
656                    similarities.push((candidate.to_string(), similarity));
657                }
658            }
659        }
660
661        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
662        Ok(similarities.into_iter().take(top_n).collect())
663    }
664
665    /// Compute word analogy: a is to b as c is to ?
666    ///
667    /// Uses vector arithmetic: result = b - a + c, then finds most similar words.
668    /// Works with OOV words since FastText can compute vectors for any word.
669    pub fn analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
670        let a_vec = self.get_word_vector(a)?;
671        let b_vec = self.get_word_vector(b)?;
672        let c_vec = self.get_word_vector(c)?;
673
674        // d = b - a + c
675        let mut d_vec = b_vec.clone();
676        d_vec -= &a_vec;
677        d_vec += &c_vec;
678
679        // Normalize
680        let norm = d_vec.iter().fold(0.0, |sum, &val| sum + val * val).sqrt();
681        if norm > 0.0 {
682            d_vec.mapv_inplace(|val| val / norm);
683        }
684
685        self.most_similar_by_vector(&d_vec, top_n, &[a, b, c])
686    }
687
688    /// Compute cosine similarity between two words
689    ///
690    /// Both words can be OOV.
691    pub fn word_similarity(&self, word1: &str, word2: &str) -> Result<f64> {
692        let vec1 = self.get_word_vector(word1)?;
693        let vec2 = self.get_word_vector(word2)?;
694        Ok(cosine_similarity(&vec1, &vec2))
695    }
696
697    /// Save the model to a file
698    ///
699    /// Saves in a format that includes word vectors, n-gram info, and config.
700    /// Uses a custom header format:
701    /// Line 1: FASTTEXT <vocab_size> <vector_size> <min_n> <max_n> <bucket_size>
702    /// Lines 2+: word vector_components...
703    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
704        let word_embeddings = self
705            .word_embeddings
706            .as_ref()
707            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
708
709        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
710
711        // Write extended header
712        writeln!(
713            &mut file,
714            "FASTTEXT {} {} {} {} {}",
715            self.vocabulary.len(),
716            self.config.vector_size,
717            self.config.min_n,
718            self.config.max_n,
719            self.config.bucket_size,
720        )
721        .map_err(|e| TextError::IoError(e.to_string()))?;
722
723        // Write each word and its full subword-aware vector
724        for i in 0..self.vocabulary.len() {
725            if let Some(word) = self.vocabulary.get_token(i) {
726                write!(&mut file, "{} ", word).map_err(|e| TextError::IoError(e.to_string()))?;
727
728                // Write the raw word embedding (subword info is reconstructed on load)
729                for j in 0..self.config.vector_size {
730                    write!(&mut file, "{:.6} ", word_embeddings[[i, j]])
731                        .map_err(|e| TextError::IoError(e.to_string()))?;
732                }
733
734                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
735            }
736        }
737
738        // Write n-gram mapping section
739        writeln!(&mut file, "NGRAMS {}", self.ngram_to_bucket.len())
740            .map_err(|e| TextError::IoError(e.to_string()))?;
741
742        for (ngram, &bucket) in &self.ngram_to_bucket {
743            writeln!(&mut file, "{} {}", ngram, bucket)
744                .map_err(|e| TextError::IoError(e.to_string()))?;
745        }
746
747        Ok(())
748    }
749
750    /// Load a FastText model from a file
751    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
752        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
753        let mut reader = BufReader::new(file);
754
755        // Read header
756        let mut header = String::new();
757        reader
758            .read_line(&mut header)
759            .map_err(|e| TextError::IoError(e.to_string()))?;
760
761        let parts: Vec<&str> = header.split_whitespace().collect();
762        if parts.len() < 6 || parts[0] != "FASTTEXT" {
763            return Err(TextError::EmbeddingError(
764                "Invalid FastText file format (expected FASTTEXT header)".into(),
765            ));
766        }
767
768        let vocab_size = parts[1]
769            .parse::<usize>()
770            .map_err(|_| TextError::EmbeddingError("Invalid vocab size".into()))?;
771        let vector_size = parts[2]
772            .parse::<usize>()
773            .map_err(|_| TextError::EmbeddingError("Invalid vector size".into()))?;
774        let min_n = parts[3]
775            .parse::<usize>()
776            .map_err(|_| TextError::EmbeddingError("Invalid min_n".into()))?;
777        let max_n = parts[4]
778            .parse::<usize>()
779            .map_err(|_| TextError::EmbeddingError("Invalid max_n".into()))?;
780        let bucket_size = parts[5]
781            .parse::<usize>()
782            .map_err(|_| TextError::EmbeddingError("Invalid bucket_size".into()))?;
783
784        let config = FastTextConfig {
785            vector_size,
786            min_n,
787            max_n,
788            bucket_size,
789            ..Default::default()
790        };
791
792        let mut vocabulary = Vocabulary::new();
793        let mut word_embeddings = Array2::zeros((vocab_size, vector_size));
794
795        // Read word vectors
796        for i in 0..vocab_size {
797            let mut line = String::new();
798            reader
799                .read_line(&mut line)
800                .map_err(|e| TextError::IoError(e.to_string()))?;
801
802            let parts: Vec<&str> = line.split_whitespace().collect();
803            if parts.len() < vector_size + 1 {
804                return Err(TextError::EmbeddingError(format!(
805                    "Invalid vector at line {}",
806                    i + 2
807                )));
808            }
809
810            vocabulary.add_token(parts[0]);
811
812            for j in 0..vector_size {
813                word_embeddings[[i, j]] = parts[j + 1].parse::<f64>().map_err(|_| {
814                    TextError::EmbeddingError(format!(
815                        "Invalid float at line {}, position {}",
816                        i + 2,
817                        j + 1
818                    ))
819                })?;
820            }
821        }
822
823        // Read n-gram mapping section (if present)
824        let mut ngram_to_bucket = HashMap::new();
825        let mut ngram_header = String::new();
826        if reader
827            .read_line(&mut ngram_header)
828            .map_err(|e| TextError::IoError(e.to_string()))?
829            > 0
830        {
831            let ngram_parts: Vec<&str> = ngram_header.split_whitespace().collect();
832            if ngram_parts.len() >= 2 && ngram_parts[0] == "NGRAMS" {
833                let ngram_count = ngram_parts[1]
834                    .parse::<usize>()
835                    .map_err(|_| TextError::EmbeddingError("Invalid ngram count".into()))?;
836
837                for _ in 0..ngram_count {
838                    let mut ngram_line = String::new();
839                    reader
840                        .read_line(&mut ngram_line)
841                        .map_err(|e| TextError::IoError(e.to_string()))?;
842
843                    let np: Vec<&str> = ngram_line.split_whitespace().collect();
844                    if np.len() >= 2 {
845                        let bucket = np[1]
846                            .parse::<usize>()
847                            .map_err(|_| TextError::EmbeddingError("Invalid bucket".into()))?;
848                        ngram_to_bucket.insert(np[0].to_string(), bucket);
849                    }
850                }
851            }
852        }
853
854        // Initialize n-gram embeddings (zeros since we don't save them in text format)
855        let ngram_embeddings = Array2::zeros((bucket_size, vector_size));
856
857        Ok(Self {
858            config,
859            vocabulary,
860            word_counts: HashMap::new(),
861            word_embeddings: Some(word_embeddings),
862            output_embeddings: None,
863            ngram_embeddings: Some(ngram_embeddings),
864            ngram_to_bucket,
865            tokenizer: Box::new(WordTokenizer::default()),
866            current_learning_rate: 0.05,
867            sampling_weights: Vec::new(),
868        })
869    }
870
871    /// Get the vocabulary size
872    pub fn vocabulary_size(&self) -> usize {
873        self.vocabulary.len()
874    }
875
876    /// Get the vector size
877    pub fn vector_size(&self) -> usize {
878        self.config.vector_size
879    }
880
881    /// Get the n-gram configuration (min_n, max_n)
882    pub fn ngram_range(&self) -> (usize, usize) {
883        (self.config.min_n, self.config.max_n)
884    }
885
886    /// Get the number of unique n-grams discovered
887    pub fn ngram_count(&self) -> usize {
888        self.ngram_to_bucket.len()
889    }
890
891    /// Check if a word is in the vocabulary
892    pub fn contains(&self, word: &str) -> bool {
893        self.vocabulary.contains(word)
894    }
895
896    /// Check if a word can be represented (either in vocab or has matching n-grams)
897    pub fn can_represent(&self, word: &str) -> bool {
898        if self.vocabulary.contains(word) {
899            return true;
900        }
901        // Check if any n-grams match
902        let ngrams = self.extract_ngrams(word);
903        ngrams
904            .iter()
905            .any(|ng| self.ngram_to_bucket.contains_key(ng))
906    }
907
908    /// Get all words in the vocabulary
909    pub fn get_vocabulary_words(&self) -> Vec<String> {
910        let mut words = Vec::with_capacity(self.vocabulary.len());
911        for i in 0..self.vocabulary.len() {
912            if let Some(word) = self.vocabulary.get_token(i) {
913                words.push(word.to_string());
914            }
915        }
916        words
917    }
918}
919
920impl Default for FastText {
921    fn default() -> Self {
922        Self::new()
923    }
924}
925
926/// Calculate cosine similarity between two vectors
927fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
928    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
929    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
930    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
931
932    if norm_a > 0.0 && norm_b > 0.0 {
933        dot_product / (norm_a * norm_b)
934    } else {
935        0.0
936    }
937}
938
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_ngrams() {
        let config = FastTextConfig {
            min_n: 3,
            max_n: 4,
            ..Default::default()
        };
        let model = FastText::with_config(config);

        let ngrams = model.extract_ngrams("test");
        assert!(!ngrams.is_empty());
        assert!(ngrams.contains(&"<te".to_string()));
        assert!(ngrams.contains(&"est".to_string()));
        assert!(ngrams.contains(&"st>".to_string()));
        // 4-grams
        assert!(ngrams.contains(&"<tes".to_string()));
        assert!(ngrams.contains(&"test".to_string()));
        assert!(ngrams.contains(&"est>".to_string()));
    }

    #[test]
    fn test_extract_ngrams_short_word() {
        let config = FastTextConfig {
            min_n: 3,
            max_n: 6,
            ..Default::default()
        };
        let model = FastText::with_config(config);

        let ngrams = model.extract_ngrams("a");
        // "<a>" has 3 chars, so only 3-gram possible: "<a>"
        assert_eq!(ngrams.len(), 1);
        assert_eq!(ngrams[0], "<a>");
    }

    #[test]
    fn test_fasttext_training() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown dog outpaces a quick fox",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 1,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        let result = model.train(&texts);
        assert!(result.is_ok());

        // Test getting vector for in-vocabulary word
        let vec = model.get_word_vector("quick");
        assert!(vec.is_ok());
        assert_eq!(vec.expect("Failed to get vector").len(), 10);

        // Test getting vector for OOV word (should work due to n-grams)
        let oov_vec = model.get_word_vector("quickest");
        assert!(oov_vec.is_ok());
    }

    #[test]
    fn test_fasttext_oov_handling() {
        let texts = ["hello world", "hello there"];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 1,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        // Get vector for OOV word that shares n-grams with "hello"
        let oov_vec = model.get_word_vector("helloworld");
        assert!(oov_vec.is_ok(), "FastText should handle OOV words");
    }

    #[test]
    fn test_fasttext_analogy() {
        let texts = [
            "the king ruled the kingdom wisely",
            "the queen ruled the kingdom wisely",
            "the man worked in the field",
            "the woman worked in the field",
            "the king and the queen were happy",
            "the man and the woman were happy",
        ];

        let config = FastTextConfig {
            vector_size: 20,
            window_size: 3,
            min_count: 1,
            epochs: 5,
            min_n: 3,
            max_n: 5,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        // Just verify analogy doesn't crash
        let result = model.analogy("king", "man", "woman", 3);
        assert!(result.is_ok());
        let answers = result.expect("analogy");
        assert!(!answers.is_empty());
    }

    #[test]
    fn test_fasttext_word_similarity() {
        let texts = [
            "the cat sat on the mat",
            "the dog sat on the rug",
            "the cat and dog played",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 3,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let sim = model.word_similarity("cat", "dog");
        assert!(sim.is_ok());
        // Both should have finite similarity
        assert!(sim.expect("similarity").is_finite());
    }

    #[test]
    fn test_fasttext_save_load() {
        /// Deletes the wrapped file when dropped, so the temp file is cleaned
        /// up even if a later `expect`/`assert` in this test panics.
        struct TempFileGuard(std::path::PathBuf);

        impl Drop for TempFileGuard {
            fn drop(&mut self) {
                std::fs::remove_file(&self.0).ok();
            }
        }

        let texts = ["the quick brown fox jumps", "the lazy brown dog sleeps"];

        let config = FastTextConfig {
            vector_size: 5,
            min_count: 1,
            epochs: 1,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        // Per-process file name: several test processes sharing one temp dir
        // must not overwrite each other's saved model.
        let save_path =
            std::env::temp_dir().join(format!("test_fasttext_save_{}.txt", std::process::id()));
        let _cleanup = TempFileGuard(save_path.clone());

        model.save(&save_path).expect("Failed to save");

        let loaded = FastText::load(&save_path).expect("Failed to load");
        assert_eq!(loaded.vocabulary_size(), model.vocabulary_size());
        assert_eq!(loaded.vector_size(), model.vector_size());
        assert_eq!(loaded.ngram_range(), model.ngram_range());
    }

    #[test]
    fn test_fasttext_can_represent() {
        let texts = ["hello world"];

        let config = FastTextConfig {
            vector_size: 5,
            min_count: 1,
            epochs: 1,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        assert!(model.contains("hello"));
        assert!(model.can_represent("hello"));
        assert!(!model.contains("helloworld"));
        assert!(model.can_represent("helloworld")); // OOV but has matching n-grams
    }

    #[test]
    fn test_fasttext_most_similar() {
        let texts = [
            "the dog runs fast",
            "the cat runs fast",
            "the bird flies high",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 5,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let similar = model.most_similar("dog", 2).expect("most_similar");
        assert!(!similar.is_empty());
        assert!(similar.len() <= 2);
    }

    #[test]
    fn test_fasttext_empty_input() {
        let texts: Vec<&str> = vec![];
        let mut model = FastText::new();
        let result = model.train(&texts);
        assert!(result.is_err());
    }

    #[test]
    fn test_fasttext_config_defaults() {
        let config = FastTextConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.min_n, 3);
        assert_eq!(config.max_n, 6);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.bucket_size, 2_000_000);
    }

    #[test]
    fn test_hash_ngram_deterministic() {
        let model = FastText::new();
        let h1 = model.hash_ngram("abc");
        let h2 = model.hash_ngram("abc");
        assert_eq!(h1, h2);

        let h3 = model.hash_ngram("xyz");
        // Different strings should (usually) hash differently
        // Not guaranteed but very likely
        assert_ne!(h1, h3);
    }

    #[test]
    fn test_cosine_similarity_basic() {
        let x = Array1::from(vec![1.0, 0.0]);
        let y = Array1::from(vec![0.0, 1.0]);
        // Orthogonal unit vectors.
        assert_eq!(cosine_similarity(&x, &y), 0.0);
        // A vector is perfectly aligned with itself.
        assert!((cosine_similarity(&x, &x) - 1.0).abs() < 1e-12);

        // Parallel vectors of different magnitude.
        let p = Array1::from(vec![1.0, 2.0]);
        let q = Array1::from(vec![2.0, 4.0]);
        assert!((cosine_similarity(&p, &q) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-norm operand must yield 0.0 instead of NaN from 0/0.
        let zero = Array1::from(vec![0.0, 0.0]);
        let x = Array1::from(vec![3.0, 4.0]);
        assert_eq!(cosine_similarity(&zero, &x), 0.0);
        assert_eq!(cosine_similarity(&x, &zero), 0.0);
    }
}