use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
57
/// Hyperparameters for [`FastText`] vocabulary construction and training.
#[derive(Debug, Clone)]
pub struct FastTextConfig {
    /// Dimensionality of the word and n-gram embedding vectors.
    pub vector_size: usize,
    /// Minimum character n-gram length (counted on the word with `<`/`>` boundary markers).
    pub min_n: usize,
    /// Maximum character n-gram length.
    pub max_n: usize,
    /// Maximum skip-gram context radius; the effective window per position is sampled in `1..=window_size`.
    pub window_size: usize,
    /// Number of full passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate; linearly decayed across epochs with a small floor.
    pub learning_rate: f64,
    /// Words occurring fewer than this many times are excluded from the vocabulary.
    pub min_count: usize,
    /// Number of negative samples drawn per (target, context) pair.
    pub negative_samples: usize,
    /// Frequent-word subsampling threshold.
    /// NOTE(review): not referenced by the training code in this file — confirm intended use.
    pub subsample: f64,
    /// Number of hash buckets shared by all character n-grams.
    pub bucket_size: usize,
}
82
/// Defaults roughly matching common fastText settings (100-d vectors,
/// 3..=6-char n-grams, 2M hash buckets).
impl Default for FastTextConfig {
    fn default() -> Self {
        Self {
            vector_size: 100,
            min_n: 3,
            max_n: 6,
            window_size: 5,
            epochs: 5,
            learning_rate: 0.05,
            min_count: 5,
            negative_samples: 5,
            subsample: 1e-3,
            bucket_size: 2_000_000,
        }
    }
}
99
/// FastText-style embedding model: each word is represented by its own vector
/// plus hashed character n-gram vectors, which lets the model compose vectors
/// for out-of-vocabulary words from their subwords.
pub struct FastText {
    /// Training hyperparameters.
    config: FastTextConfig,
    /// Words that met `min_count`, mapped to dense row indices.
    vocabulary: Vocabulary,
    /// Raw corpus frequency of every token (including ones filtered out).
    word_counts: HashMap<String, usize>,
    /// One row per vocabulary word; `None` until `build_vocabulary` runs.
    word_embeddings: Option<Array2<f64>>,
    /// One row per hash bucket; `None` until `build_vocabulary` runs.
    ngram_embeddings: Option<Array2<f64>>,
    /// N-grams seen at vocabulary-build time, mapped to their hash bucket.
    ngram_to_bucket: HashMap<String, usize>,
    /// Splits raw text into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Per-epoch decayed learning rate used by the current training pass.
    current_learning_rate: f64,
}
119
/// Manual [`Debug`]: the embedding matrices are summarized as presence flags
/// rather than dumped, since they can hold millions of values.
impl Debug for FastText {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FastText")
            .field("config", &self.config)
            .field("vocabulary_size", &self.vocabulary.len())
            .field("word_embeddings", &self.word_embeddings.is_some())
            .field("ngram_embeddings", &self.ngram_embeddings.is_some())
            .finish()
    }
}
130
/// Manual [`Clone`]: `Box<dyn Tokenizer>` is not cloneable, so the clone is
/// always given a fresh `WordTokenizer::default()`.
/// NOTE(review): a custom tokenizer configured on the original is therefore
/// NOT preserved by cloning — confirm callers don't rely on it.
impl Clone for FastText {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            word_counts: self.word_counts.clone(),
            word_embeddings: self.word_embeddings.clone(),
            ngram_embeddings: self.ngram_embeddings.clone(),
            ngram_to_bucket: self.ngram_to_bucket.clone(),
            // Cannot clone the boxed trait object; fall back to the default tokenizer.
            tokenizer: Box::new(WordTokenizer::default()),
            current_learning_rate: self.current_learning_rate,
        }
    }
}
145
146impl FastText {
147 pub fn new() -> Self {
149 Self {
150 config: FastTextConfig::default(),
151 vocabulary: Vocabulary::new(),
152 word_counts: HashMap::new(),
153 word_embeddings: None,
154 ngram_embeddings: None,
155 ngram_to_bucket: HashMap::new(),
156 tokenizer: Box::new(WordTokenizer::default()),
157 current_learning_rate: 0.05,
158 }
159 }
160
161 pub fn with_config(config: FastTextConfig) -> Self {
163 let learning_rate = config.learning_rate;
164 Self {
165 config,
166 vocabulary: Vocabulary::new(),
167 word_counts: HashMap::new(),
168 word_embeddings: None,
169 ngram_embeddings: None,
170 ngram_to_bucket: HashMap::new(),
171 tokenizer: Box::new(WordTokenizer::default()),
172 current_learning_rate: learning_rate,
173 }
174 }
175
176 fn extract_ngrams(&self, word: &str) -> Vec<String> {
178 let word_with_boundaries = format!("<{}>", word);
179 let chars: Vec<char> = word_with_boundaries.chars().collect();
180 let mut ngrams = Vec::new();
181
182 for n in self.config.min_n..=self.config.max_n {
183 if chars.len() < n {
184 continue;
185 }
186
187 for i in 0..=(chars.len() - n) {
188 let ngram: String = chars[i..i + n].iter().collect();
189 ngrams.push(ngram);
190 }
191 }
192
193 ngrams
194 }
195
196 fn hash_ngram(&self, ngram: &str) -> usize {
198 let mut hash: u64 = 2166136261;
200 for byte in ngram.bytes() {
201 hash ^= u64::from(byte);
202 hash = hash.wrapping_mul(16777619);
203 }
204 (hash % (self.config.bucket_size as u64)) as usize
205 }
206
207 pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
209 if texts.is_empty() {
210 return Err(TextError::InvalidInput(
211 "No texts provided for building vocabulary".into(),
212 ));
213 }
214
215 let mut word_counts = HashMap::new();
217
218 for &text in texts {
219 let tokens = self.tokenizer.tokenize(text)?;
220 for token in tokens {
221 *word_counts.entry(token).or_insert(0) += 1;
222 }
223 }
224
225 self.vocabulary = Vocabulary::new();
227 for (word, count) in &word_counts {
228 if *count >= self.config.min_count {
229 self.vocabulary.add_token(word);
230 }
231 }
232
233 if self.vocabulary.is_empty() {
234 return Err(TextError::VocabularyError(
235 "No words meet the minimum count threshold".into(),
236 ));
237 }
238
239 self.word_counts = word_counts;
240
241 let vocab_size = self.vocabulary.len();
243 let vector_size = self.config.vector_size;
244 let bucket_size = self.config.bucket_size;
245
246 let mut rng = scirs2_core::random::rng();
247
248 let word_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
250 (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
251 });
252
253 let ngram_embeddings = Array2::from_shape_fn((bucket_size, vector_size), |_| {
255 (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
256 });
257
258 self.word_embeddings = Some(word_embeddings);
259 self.ngram_embeddings = Some(ngram_embeddings);
260
261 self.ngram_to_bucket.clear();
263 for i in 0..self.vocabulary.len() {
264 if let Some(word) = self.vocabulary.get_token(i) {
265 let ngrams = self.extract_ngrams(word);
266 for ngram in ngrams {
267 if !self.ngram_to_bucket.contains_key(&ngram) {
268 let bucket = self.hash_ngram(&ngram);
269 self.ngram_to_bucket.insert(ngram, bucket);
270 }
271 }
272 }
273 }
274
275 Ok(())
276 }
277
278 pub fn train(&mut self, texts: &[&str]) -> Result<()> {
280 if texts.is_empty() {
281 return Err(TextError::InvalidInput(
282 "No texts provided for training".into(),
283 ));
284 }
285
286 if self.vocabulary.is_empty() {
288 self.build_vocabulary(texts)?;
289 }
290
291 let mut sentences = Vec::new();
293 for &text in texts {
294 let tokens = self.tokenizer.tokenize(text)?;
295 let word_indices: Vec<usize> = tokens
296 .iter()
297 .filter_map(|token| self.vocabulary.get_index(token))
298 .collect();
299 if !word_indices.is_empty() {
300 sentences.push(word_indices);
301 }
302 }
303
304 for epoch in 0..self.config.epochs {
306 self.current_learning_rate =
308 self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
309 self.current_learning_rate = self
310 .current_learning_rate
311 .max(self.config.learning_rate * 0.0001);
312
313 for sentence in &sentences {
315 self.train_sentence(sentence)?;
316 }
317 }
318
319 Ok(())
320 }
321
    /// Runs one skip-gram-with-negative-sampling pass over a single sentence
    /// (a slice of in-vocabulary word indices).
    ///
    /// For each target position, a randomly shrunk context window is visited;
    /// each context word (and the target's n-gram buckets) is nudged toward
    /// the target's composite vector, and `negative_samples` random words are
    /// nudged away from it.
    fn train_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        // Nothing to learn from a single-word sentence.
        if sentence.len() < 2 {
            return Ok(());
        }

        // Pre-resolve each word's n-gram buckets up front so the embedding
        // matrices can be borrowed mutably below without touching `self`.
        let mut sentence_ngrams = Vec::with_capacity(sentence.len());
        for &target_idx in sentence {
            let target_word = self
                .vocabulary
                .get_token(target_idx)
                .ok_or_else(|| TextError::VocabularyError("Invalid word index".into()))?;
            let ngrams = self.extract_ngrams(target_word);
            // Only n-grams registered during build_vocabulary are used.
            let ngram_buckets: Vec<usize> = ngrams
                .iter()
                .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
                .collect();
            sentence_ngrams.push(ngram_buckets);
        }

        let word_embeddings = self
            .word_embeddings
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
        let ngram_embeddings = self
            .ngram_embeddings
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;

        let mut rng = scirs2_core::random::rng();

        for (pos, &target_idx) in sentence.iter().enumerate() {
            // Dynamic window size, uniform in 1..=window_size (word2vec-style).
            let window = 1 + rng.random_range(0..self.config.window_size);

            let ngram_buckets = &sentence_ngrams[pos];

            // Composite input vector: the word vector plus its n-gram
            // vectors, averaged over the 1 + |n-grams| contributors.
            let mut target_vec = word_embeddings.row(target_idx).to_owned();
            for &bucket in ngram_buckets {
                target_vec += &ngram_embeddings.row(bucket);
            }
            if !ngram_buckets.is_empty() {
                target_vec /= 1.0 + ngram_buckets.len() as f64;
            }

            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_idx = sentence[i];

                // Positive pair: push the context vector toward the target.
                let context_vec = word_embeddings.row(context_idx).to_owned();
                let dot_product: f64 = target_vec
                    .iter()
                    .zip(context_vec.iter())
                    .map(|(a, b)| a * b)
                    .sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let gradient = (1.0 - sigmoid) * self.current_learning_rate;

                let update = &target_vec * gradient;
                let mut context_row = word_embeddings.row_mut(context_idx);
                context_row += &update;

                // Propagate a scaled share of the update to the target's
                // n-gram buckets.
                // NOTE(review): the target word's own embedding row is never
                // updated here (only its n-grams are); reference fastText
                // also updates the input word vector — confirm intent.
                if !ngram_buckets.is_empty() {
                    let ngram_update = update / (1.0 + ngram_buckets.len() as f64);
                    for &bucket in ngram_buckets {
                        let mut ngram_row = ngram_embeddings.row_mut(bucket);
                        ngram_row += &ngram_update;
                    }
                }

                // Negative sampling: push random words away from the target.
                // NOTE(review): only collisions with `context_idx` are
                // skipped; `target_idx` itself can be drawn as a negative.
                for _ in 0..self.config.negative_samples {
                    let neg_idx = rng.random_range(0..self.vocabulary.len());
                    if neg_idx == context_idx {
                        continue;
                    }

                    let neg_vec = word_embeddings.row(neg_idx).to_owned();
                    let dot_product: f64 = target_vec
                        .iter()
                        .zip(neg_vec.iter())
                        .map(|(a, b)| a * b)
                        .sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let gradient = -sigmoid * self.current_learning_rate;

                    let update = &target_vec * gradient;
                    let mut neg_row = word_embeddings.row_mut(neg_idx);
                    neg_row += &update;
                }
            }
        }

        Ok(())
    }
428
429 pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
431 let word_embeddings = self
432 .word_embeddings
433 .as_ref()
434 .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
435 let ngram_embeddings = self
436 .ngram_embeddings
437 .as_ref()
438 .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
439
440 let ngrams = self.extract_ngrams(word);
441 let mut vector = Array1::zeros(self.config.vector_size);
442 let mut count = 0.0;
443
444 if let Some(idx) = self.vocabulary.get_index(word) {
446 vector += &word_embeddings.row(idx);
447 count += 1.0;
448 }
449
450 for ngram in &ngrams {
452 if let Some(&bucket) = self.ngram_to_bucket.get(ngram) {
453 vector += &ngram_embeddings.row(bucket);
454 count += 1.0;
455 }
456 }
457
458 if count > 0.0 {
459 vector /= count;
460 Ok(vector)
461 } else {
462 Err(TextError::VocabularyError(format!(
463 "Cannot compute vector for word '{}': no n-grams found",
464 word
465 )))
466 }
467 }
468
469 pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
471 let word_vec = self.get_word_vector(word)?;
472 let word_embeddings = self
473 .word_embeddings
474 .as_ref()
475 .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
476
477 let mut similarities = Vec::new();
478
479 for i in 0..self.vocabulary.len() {
480 if let Some(candidate) = self.vocabulary.get_token(i) {
481 if candidate == word {
482 continue;
483 }
484
485 let candidate_vec = word_embeddings.row(i).to_owned();
486 let similarity = cosine_similarity(&word_vec, &candidate_vec);
487 similarities.push((candidate.to_string(), similarity));
488 }
489 }
490
491 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
492 Ok(similarities.into_iter().take(top_n).collect())
493 }
494
495 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
497 let word_embeddings = self
498 .word_embeddings
499 .as_ref()
500 .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
501
502 let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
503
504 writeln!(
506 &mut file,
507 "{} {}",
508 self.vocabulary.len(),
509 self.config.vector_size
510 )
511 .map_err(|e| TextError::IoError(e.to_string()))?;
512
513 for i in 0..self.vocabulary.len() {
515 if let Some(word) = self.vocabulary.get_token(i) {
516 write!(&mut file, "{} ", word).map_err(|e| TextError::IoError(e.to_string()))?;
517
518 for j in 0..self.config.vector_size {
519 write!(&mut file, "{:.6} ", word_embeddings[[i, j]])
520 .map_err(|e| TextError::IoError(e.to_string()))?;
521 }
522
523 writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
524 }
525 }
526
527 Ok(())
528 }
529
    /// Number of words retained in the vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Dimensionality of the embedding vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }
539}
540
/// Equivalent to [`FastText::new`]: default hyperparameters, untrained.
impl Default for FastText {
    fn default() -> Self {
        Self::new()
    }
}
546
547fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
549 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
550 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
551 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
552
553 if norm_a > 0.0 && norm_b > 0.0 {
554 dot_product / (norm_a * norm_b)
555 } else {
556 0.0
557 }
558}
559
560#[cfg(test)]
561mod tests {
562 use super::*;
563
564 #[test]
565 fn test_extract_ngrams() {
566 let config = FastTextConfig {
567 min_n: 3,
568 max_n: 4,
569 ..Default::default()
570 };
571 let model = FastText::with_config(config);
572
573 let ngrams = model.extract_ngrams("test");
574 assert!(!ngrams.is_empty());
576 assert!(ngrams.contains(&"<te".to_string()));
577 assert!(ngrams.contains(&"est".to_string()));
578 }
579
580 #[test]
581 fn test_fasttext_training() {
582 let texts = [
583 "the quick brown fox jumps over the lazy dog",
584 "a quick brown dog outpaces a quick fox",
585 ];
586
587 let config = FastTextConfig {
588 vector_size: 10,
589 window_size: 2,
590 min_count: 1,
591 epochs: 1,
592 min_n: 3,
593 max_n: 4,
594 ..Default::default()
595 };
596
597 let mut model = FastText::with_config(config);
598 let result = model.train(&texts);
599 assert!(result.is_ok());
600
601 let vec = model.get_word_vector("quick");
603 assert!(vec.is_ok());
604 assert_eq!(vec.expect("Failed to get vector").len(), 10);
605
606 let oov_vec = model.get_word_vector("quickest");
608 assert!(oov_vec.is_ok());
609 }
610
611 #[test]
612 fn test_fasttext_oov_handling() {
613 let texts = ["hello world", "hello there"];
614
615 let config = FastTextConfig {
616 vector_size: 10,
617 min_count: 1,
618 epochs: 1,
619 ..Default::default()
620 };
621
622 let mut model = FastText::with_config(config);
623 model.train(&texts).expect("Training failed");
624
625 let oov_vec = model.get_word_vector("helloworld");
627 assert!(oov_vec.is_ok(), "FastText should handle OOV words");
628 }
629}