Skip to main content

scirs2_text/embeddings/
fasttext.rs

1//! FastText embeddings with character n-grams
2//!
3//! This module implements FastText, an extension of Word2Vec that learns
4//! word representations as bags of character n-grams. This approach handles
5//! out-of-vocabulary words and morphologically rich languages better.
6//!
7//! ## Overview
8//!
9//! FastText represents each word as a bag of character n-grams. For example:
10//! - word: "where"
11//! - 3-grams: "<wh", "whe", "her", "ere", "re>"
12//! - The word embedding is the average of the word's own vector (when it is in the vocabulary) and the embeddings of its known n-grams
13//!
14//! ## Quick Start
15//!
16//! ```rust
17//! use scirs2_text::embeddings::fasttext::{FastText, FastTextConfig};
18//!
19//! // Create configuration
20//! let config = FastTextConfig {
21//!     vector_size: 100,
22//!     min_n: 3,
23//!     max_n: 6,
24//!     window_size: 5,
25//!     epochs: 5,
26//!     learning_rate: 0.05,
27//!     min_count: 1,
28//!     negative_samples: 5,
29//!     ..Default::default()
30//! };
31//!
32//! // Train model
33//! let documents = vec![
34//!     "the quick brown fox jumps over the lazy dog",
35//!     "a quick brown dog outpaces a quick fox"
36//! ];
37//!
38//! let mut model = FastText::with_config(config);
39//! model.train(&documents).expect("Training failed");
40//!
41//! // Get word vector (works even for OOV words!)
42//! if let Ok(vector) = model.get_word_vector("quickest") {
43//!     println!("Vector for OOV word 'quickest': {:?}", vector);
44//! }
45//! ```
46
47use crate::error::{Result, TextError};
48use crate::tokenize::{Tokenizer, WordTokenizer};
49use crate::vocabulary::Vocabulary;
50use scirs2_core::ndarray::{Array1, Array2};
51use scirs2_core::random::prelude::*;
52use std::collections::{HashMap, HashSet};
53use std::fmt::Debug;
54use std::fs::File;
55use std::io::{BufRead, BufReader, Write};
56use std::path::Path;
57
/// Hyperparameters for a [`FastText`] model.
///
/// Construct with struct-update syntax on top of [`Default`], e.g.
/// `FastTextConfig { vector_size: 50, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct FastTextConfig {
    /// Dimensionality of every word and n-gram embedding.
    pub vector_size: usize,
    /// Shortest character n-gram, counted in `char`s including the `<`/`>` word markers.
    pub min_n: usize,
    /// Longest character n-gram, counted in `char`s including the `<`/`>` word markers.
    pub max_n: usize,
    /// Maximum distance (in words) between a target word and a context word;
    /// the effective radius is drawn at random for each target.
    pub window_size: usize,
    /// Number of passes over the training texts.
    pub epochs: usize,
    /// Initial learning rate; decayed linearly over the epochs.
    pub learning_rate: f64,
    /// Words seen fewer than this many times are left out of the vocabulary.
    pub min_count: usize,
    /// Number of negative samples drawn per (target, context) pair.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words.
    /// NOTE(review): not read by the training code in this module — confirm before relying on it.
    pub subsample: f64,
    /// Number of hash buckets that character n-grams are hashed into.
    pub bucket_size: usize,
}

impl Default for FastTextConfig {
    /// 100-dimensional vectors, 3–6 character n-grams, a window of 5, five epochs
    /// at learning rate 0.05, `min_count` 5, five negative samples, subsampling
    /// threshold 1e-3 and 2,000,000 hash buckets.
    fn default() -> Self {
        FastTextConfig {
            // Subword settings.
            min_n: 3,
            max_n: 6,
            bucket_size: 2_000_000,
            // Model size.
            vector_size: 100,
            // Optimisation settings.
            window_size: 5,
            epochs: 5,
            learning_rate: 0.05,
            negative_samples: 5,
            // Vocabulary filtering.
            min_count: 5,
            subsample: 1e-3,
        }
    }
}
99
100/// FastText model for learning word representations with character n-grams
101pub struct FastText {
102    /// Configuration
103    config: FastTextConfig,
104    /// Vocabulary of words
105    vocabulary: Vocabulary,
106    /// Word frequencies
107    word_counts: HashMap<String, usize>,
108    /// Word embeddings (for words in vocabulary)
109    word_embeddings: Option<Array2<f64>>,
110    /// N-gram embeddings (subword information)
111    ngram_embeddings: Option<Array2<f64>>,
112    /// N-gram to bucket index mapping
113    ngram_to_bucket: HashMap<String, usize>,
114    /// Tokenizer
115    tokenizer: Box<dyn Tokenizer + Send + Sync>,
116    /// Current learning rate
117    current_learning_rate: f64,
118}
119
120impl Debug for FastText {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("FastText")
123            .field("config", &self.config)
124            .field("vocabulary_size", &self.vocabulary.len())
125            .field("word_embeddings", &self.word_embeddings.is_some())
126            .field("ngram_embeddings", &self.ngram_embeddings.is_some())
127            .finish()
128    }
129}
130
131impl Clone for FastText {
132    fn clone(&self) -> Self {
133        Self {
134            config: self.config.clone(),
135            vocabulary: self.vocabulary.clone(),
136            word_counts: self.word_counts.clone(),
137            word_embeddings: self.word_embeddings.clone(),
138            ngram_embeddings: self.ngram_embeddings.clone(),
139            ngram_to_bucket: self.ngram_to_bucket.clone(),
140            tokenizer: Box::new(WordTokenizer::default()),
141            current_learning_rate: self.current_learning_rate,
142        }
143    }
144}
145
146impl FastText {
147    /// Create a new FastText model with default configuration
148    pub fn new() -> Self {
149        Self {
150            config: FastTextConfig::default(),
151            vocabulary: Vocabulary::new(),
152            word_counts: HashMap::new(),
153            word_embeddings: None,
154            ngram_embeddings: None,
155            ngram_to_bucket: HashMap::new(),
156            tokenizer: Box::new(WordTokenizer::default()),
157            current_learning_rate: 0.05,
158        }
159    }
160
161    /// Create a new FastText model with custom configuration
162    pub fn with_config(config: FastTextConfig) -> Self {
163        let learning_rate = config.learning_rate;
164        Self {
165            config,
166            vocabulary: Vocabulary::new(),
167            word_counts: HashMap::new(),
168            word_embeddings: None,
169            ngram_embeddings: None,
170            ngram_to_bucket: HashMap::new(),
171            tokenizer: Box::new(WordTokenizer::default()),
172            current_learning_rate: learning_rate,
173        }
174    }
175
176    /// Extract character n-grams from a word
177    fn extract_ngrams(&self, word: &str) -> Vec<String> {
178        let word_with_boundaries = format!("<{}>", word);
179        let chars: Vec<char> = word_with_boundaries.chars().collect();
180        let mut ngrams = Vec::new();
181
182        for n in self.config.min_n..=self.config.max_n {
183            if chars.len() < n {
184                continue;
185            }
186
187            for i in 0..=(chars.len() - n) {
188                let ngram: String = chars[i..i + n].iter().collect();
189                ngrams.push(ngram);
190            }
191        }
192
193        ngrams
194    }
195
196    /// Hash an n-gram to a bucket index
197    fn hash_ngram(&self, ngram: &str) -> usize {
198        // Simple hash function (FNV-1a)
199        let mut hash: u64 = 2166136261;
200        for byte in ngram.bytes() {
201            hash ^= u64::from(byte);
202            hash = hash.wrapping_mul(16777619);
203        }
204        (hash % (self.config.bucket_size as u64)) as usize
205    }
206
207    /// Build vocabulary from texts
208    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
209        if texts.is_empty() {
210            return Err(TextError::InvalidInput(
211                "No texts provided for building vocabulary".into(),
212            ));
213        }
214
215        // Count word frequencies
216        let mut word_counts = HashMap::new();
217
218        for &text in texts {
219            let tokens = self.tokenizer.tokenize(text)?;
220            for token in tokens {
221                *word_counts.entry(token).or_insert(0) += 1;
222            }
223        }
224
225        // Build vocabulary with min_count threshold
226        self.vocabulary = Vocabulary::new();
227        for (word, count) in &word_counts {
228            if *count >= self.config.min_count {
229                self.vocabulary.add_token(word);
230            }
231        }
232
233        if self.vocabulary.is_empty() {
234            return Err(TextError::VocabularyError(
235                "No words meet the minimum count threshold".into(),
236            ));
237        }
238
239        self.word_counts = word_counts;
240
241        // Initialize embeddings
242        let vocab_size = self.vocabulary.len();
243        let vector_size = self.config.vector_size;
244        let bucket_size = self.config.bucket_size;
245
246        let mut rng = scirs2_core::random::rng();
247
248        // Initialize word embeddings
249        let word_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
250            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
251        });
252
253        // Initialize n-gram embeddings
254        let ngram_embeddings = Array2::from_shape_fn((bucket_size, vector_size), |_| {
255            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
256        });
257
258        self.word_embeddings = Some(word_embeddings);
259        self.ngram_embeddings = Some(ngram_embeddings);
260
261        // Build n-gram to bucket mapping
262        self.ngram_to_bucket.clear();
263        for i in 0..self.vocabulary.len() {
264            if let Some(word) = self.vocabulary.get_token(i) {
265                let ngrams = self.extract_ngrams(word);
266                for ngram in ngrams {
267                    if !self.ngram_to_bucket.contains_key(&ngram) {
268                        let bucket = self.hash_ngram(&ngram);
269                        self.ngram_to_bucket.insert(ngram, bucket);
270                    }
271                }
272            }
273        }
274
275        Ok(())
276    }
277
278    /// Train the FastText model
279    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
280        if texts.is_empty() {
281            return Err(TextError::InvalidInput(
282                "No texts provided for training".into(),
283            ));
284        }
285
286        // Build vocabulary if not already built
287        if self.vocabulary.is_empty() {
288            self.build_vocabulary(texts)?;
289        }
290
291        // Prepare training data
292        let mut sentences = Vec::new();
293        for &text in texts {
294            let tokens = self.tokenizer.tokenize(text)?;
295            let word_indices: Vec<usize> = tokens
296                .iter()
297                .filter_map(|token| self.vocabulary.get_index(token))
298                .collect();
299            if !word_indices.is_empty() {
300                sentences.push(word_indices);
301            }
302        }
303
304        // Training loop
305        for epoch in 0..self.config.epochs {
306            // Update learning rate
307            self.current_learning_rate =
308                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
309            self.current_learning_rate = self
310                .current_learning_rate
311                .max(self.config.learning_rate * 0.0001);
312
313            // Train on each sentence
314            for sentence in &sentences {
315                self.train_sentence(sentence)?;
316            }
317        }
318
319        Ok(())
320    }
321
322    /// Train on a single sentence
323    fn train_sentence(&mut self, sentence: &[usize]) -> Result<()> {
324        if sentence.len() < 2 {
325            return Ok(());
326        }
327
328        // Extract all ngrams for all words in sentence BEFORE taking mutable borrows
329        let mut sentence_ngrams = Vec::with_capacity(sentence.len());
330        for &target_idx in sentence {
331            let target_word = self
332                .vocabulary
333                .get_token(target_idx)
334                .ok_or_else(|| TextError::VocabularyError("Invalid word index".into()))?;
335            let ngrams = self.extract_ngrams(target_word);
336            let ngram_buckets: Vec<usize> = ngrams
337                .iter()
338                .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
339                .collect();
340            sentence_ngrams.push(ngram_buckets);
341        }
342
343        let word_embeddings = self
344            .word_embeddings
345            .as_mut()
346            .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
347        let ngram_embeddings = self
348            .ngram_embeddings
349            .as_mut()
350            .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;
351
352        let mut rng = scirs2_core::random::rng();
353
354        // Skip-gram training for each word in the sentence
355        for (pos, &target_idx) in sentence.iter().enumerate() {
356            // Random window size
357            let window = 1 + rng.random_range(0..self.config.window_size);
358
359            // Get precomputed n-gram buckets for this target word
360            let ngram_buckets = &sentence_ngrams[pos];
361
362            // Average word embedding with n-gram embeddings
363            let mut target_vec = word_embeddings.row(target_idx).to_owned();
364            for &bucket in ngram_buckets {
365                target_vec += &ngram_embeddings.row(bucket);
366            }
367            if !ngram_buckets.is_empty() {
368                target_vec /= 1.0 + ngram_buckets.len() as f64;
369            }
370
371            // For each context word in window
372            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
373                if i == pos {
374                    continue;
375                }
376
377                let context_idx = sentence[i];
378
379                // Positive example
380                let context_vec = word_embeddings.row(context_idx).to_owned();
381                let dot_product: f64 = target_vec
382                    .iter()
383                    .zip(context_vec.iter())
384                    .map(|(a, b)| a * b)
385                    .sum();
386                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
387                let gradient = (1.0 - sigmoid) * self.current_learning_rate;
388
389                // Update embeddings
390                let update = &target_vec * gradient;
391                let mut context_row = word_embeddings.row_mut(context_idx);
392                context_row += &update;
393
394                // Update n-gram embeddings (pre-compute scaled update to avoid move in loop)
395                if !ngram_buckets.is_empty() {
396                    let ngram_update = update / (1.0 + ngram_buckets.len() as f64);
397                    for &bucket in ngram_buckets {
398                        let mut ngram_row = ngram_embeddings.row_mut(bucket);
399                        ngram_row += &ngram_update;
400                    }
401                }
402
403                // Negative sampling
404                for _ in 0..self.config.negative_samples {
405                    let neg_idx = rng.random_range(0..self.vocabulary.len());
406                    if neg_idx == context_idx {
407                        continue;
408                    }
409
410                    let neg_vec = word_embeddings.row(neg_idx).to_owned();
411                    let dot_product: f64 = target_vec
412                        .iter()
413                        .zip(neg_vec.iter())
414                        .map(|(a, b)| a * b)
415                        .sum();
416                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
417                    let gradient = -sigmoid * self.current_learning_rate;
418
419                    let update = &target_vec * gradient;
420                    let mut neg_row = word_embeddings.row_mut(neg_idx);
421                    neg_row += &update;
422                }
423            }
424        }
425
426        Ok(())
427    }
428
429    /// Get the embedding vector for a word (handles OOV words)
430    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
431        let word_embeddings = self
432            .word_embeddings
433            .as_ref()
434            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
435        let ngram_embeddings = self
436            .ngram_embeddings
437            .as_ref()
438            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
439
440        let ngrams = self.extract_ngrams(word);
441        let mut vector = Array1::zeros(self.config.vector_size);
442        let mut count = 0.0;
443
444        // Add word embedding if in vocabulary
445        if let Some(idx) = self.vocabulary.get_index(word) {
446            vector += &word_embeddings.row(idx);
447            count += 1.0;
448        }
449
450        // Add n-gram embeddings
451        for ngram in &ngrams {
452            if let Some(&bucket) = self.ngram_to_bucket.get(ngram) {
453                vector += &ngram_embeddings.row(bucket);
454                count += 1.0;
455            }
456        }
457
458        if count > 0.0 {
459            vector /= count;
460            Ok(vector)
461        } else {
462            Err(TextError::VocabularyError(format!(
463                "Cannot compute vector for word '{}': no n-grams found",
464                word
465            )))
466        }
467    }
468
469    /// Find most similar words
470    pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
471        let word_vec = self.get_word_vector(word)?;
472        let word_embeddings = self
473            .word_embeddings
474            .as_ref()
475            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
476
477        let mut similarities = Vec::new();
478
479        for i in 0..self.vocabulary.len() {
480            if let Some(candidate) = self.vocabulary.get_token(i) {
481                if candidate == word {
482                    continue;
483                }
484
485                let candidate_vec = word_embeddings.row(i).to_owned();
486                let similarity = cosine_similarity(&word_vec, &candidate_vec);
487                similarities.push((candidate.to_string(), similarity));
488            }
489        }
490
491        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
492        Ok(similarities.into_iter().take(top_n).collect())
493    }
494
495    /// Save the model to a file
496    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
497        let word_embeddings = self
498            .word_embeddings
499            .as_ref()
500            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
501
502        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
503
504        // Write header
505        writeln!(
506            &mut file,
507            "{} {}",
508            self.vocabulary.len(),
509            self.config.vector_size
510        )
511        .map_err(|e| TextError::IoError(e.to_string()))?;
512
513        // Write each word and its vector
514        for i in 0..self.vocabulary.len() {
515            if let Some(word) = self.vocabulary.get_token(i) {
516                write!(&mut file, "{} ", word).map_err(|e| TextError::IoError(e.to_string()))?;
517
518                for j in 0..self.config.vector_size {
519                    write!(&mut file, "{:.6} ", word_embeddings[[i, j]])
520                        .map_err(|e| TextError::IoError(e.to_string()))?;
521                }
522
523                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
524            }
525        }
526
527        Ok(())
528    }
529
530    /// Get the vocabulary size
531    pub fn vocabulary_size(&self) -> usize {
532        self.vocabulary.len()
533    }
534
535    /// Get the vector size
536    pub fn vector_size(&self) -> usize {
537        self.config.vector_size
538    }
539}
540
541impl Default for FastText {
542    fn default() -> Self {
543        Self::new()
544    }
545}
546
547/// Calculate cosine similarity between two vectors
548fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
549    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
550    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
551    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
552
553    if norm_a > 0.0 && norm_b > 0.0 {
554        dot_product / (norm_a * norm_b)
555    } else {
556        0.0
557    }
558}
559
560#[cfg(test)]
561mod tests {
562    use super::*;
563
564    #[test]
565    fn test_extract_ngrams() {
566        let config = FastTextConfig {
567            min_n: 3,
568            max_n: 4,
569            ..Default::default()
570        };
571        let model = FastText::with_config(config);
572
573        let ngrams = model.extract_ngrams("test");
574        // Should include: "<te", "tes", "est", "st>", "<tes", "test", "est>", ...
575        assert!(!ngrams.is_empty());
576        assert!(ngrams.contains(&"<te".to_string()));
577        assert!(ngrams.contains(&"est".to_string()));
578    }
579
580    #[test]
581    fn test_fasttext_training() {
582        let texts = [
583            "the quick brown fox jumps over the lazy dog",
584            "a quick brown dog outpaces a quick fox",
585        ];
586
587        let config = FastTextConfig {
588            vector_size: 10,
589            window_size: 2,
590            min_count: 1,
591            epochs: 1,
592            min_n: 3,
593            max_n: 4,
594            ..Default::default()
595        };
596
597        let mut model = FastText::with_config(config);
598        let result = model.train(&texts);
599        assert!(result.is_ok());
600
601        // Test getting vector for in-vocabulary word
602        let vec = model.get_word_vector("quick");
603        assert!(vec.is_ok());
604        assert_eq!(vec.expect("Failed to get vector").len(), 10);
605
606        // Test getting vector for OOV word (should work due to n-grams)
607        let oov_vec = model.get_word_vector("quickest");
608        assert!(oov_vec.is_ok());
609    }
610
611    #[test]
612    fn test_fasttext_oov_handling() {
613        let texts = ["hello world", "hello there"];
614
615        let config = FastTextConfig {
616            vector_size: 10,
617            min_count: 1,
618            epochs: 1,
619            ..Default::default()
620        };
621
622        let mut model = FastText::with_config(config);
623        model.train(&texts).expect("Training failed");
624
625        // Get vector for OOV word that shares n-grams with "hello"
626        let oov_vec = model.get_word_vector("helloworld");
627        assert!(oov_vec.is_ok(), "FastText should handle OOV words");
628    }
629}