use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
62
/// Configuration for training a FastText model.
#[derive(Debug, Clone)]
pub struct FastTextConfig {
    /// Dimensionality of the word and n-gram embedding vectors.
    pub vector_size: usize,
    /// Minimum character n-gram length (see `FastText::extract_ngrams`).
    pub min_n: usize,
    /// Maximum character n-gram length.
    pub max_n: usize,
    /// Maximum context window radius; the effective window for each position
    /// is sampled from `1..=window_size` during training.
    pub window_size: usize,
    /// Number of passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate; decayed linearly per epoch while training.
    pub learning_rate: f64,
    /// Words occurring fewer than this many times are excluded from the vocabulary.
    pub min_count: usize,
    /// Number of negative samples drawn per (target, context) pair.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words.
    /// NOTE(review): this field is never read anywhere in this file — confirm intended use.
    pub subsample: f64,
    /// Number of hash buckets shared by all character n-grams.
    pub bucket_size: usize,
}
87
impl Default for FastTextConfig {
    /// Defaults mirroring common FastText settings: 100-dim vectors,
    /// 3..=6 character n-grams, window 5, 2M hash buckets.
    fn default() -> Self {
        Self {
            vector_size: 100,
            min_n: 3,
            max_n: 6,
            window_size: 5,
            epochs: 5,
            learning_rate: 0.05,
            min_count: 5,
            negative_samples: 5,
            subsample: 1e-3,
            bucket_size: 2_000_000,
        }
    }
}
104
/// Skip-gram FastText model with character n-gram (subword) embeddings.
pub struct FastText {
    /// Training hyperparameters.
    config: FastTextConfig,
    /// Token -> index mapping for in-vocabulary words.
    vocabulary: Vocabulary,
    /// Raw corpus counts for every token seen (including ones below `min_count`).
    word_counts: HashMap<String, usize>,
    /// Input embeddings, one row per vocabulary word (`None` until trained/loaded).
    word_embeddings: Option<Array2<f64>>,
    /// Output (context) embeddings used for negative sampling.
    output_embeddings: Option<Array2<f64>>,
    /// Embeddings for hashed character n-grams, one row per bucket.
    ngram_embeddings: Option<Array2<f64>>,
    /// Cache mapping each known n-gram string to its hash bucket.
    ngram_to_bucket: HashMap<String, usize>,
    /// Tokenizer used to split input texts into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Learning rate for the current epoch (decayed from `config.learning_rate`).
    current_learning_rate: f64,
    /// Unigram^0.75 negative-sampling weights, indexed by word id.
    sampling_weights: Vec<f64>,
}
135
impl Debug for FastText {
    // Manual impl: the embedding matrices and the boxed tokenizer are
    // summarized (presence flags / counts) instead of dumped in full.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FastText")
            .field("config", &self.config)
            .field("vocabulary_size", &self.vocabulary.len())
            .field("word_embeddings", &self.word_embeddings.is_some())
            .field("ngram_embeddings", &self.ngram_embeddings.is_some())
            .field("ngram_count", &self.ngram_to_bucket.len())
            .finish()
    }
}
147
impl Clone for FastText {
    // Manual impl because `Box<dyn Tokenizer + Send + Sync>` is not Clone.
    // NOTE(review): the clone receives a fresh default `WordTokenizer`, so a
    // model configured via `with_tokenizer` will tokenize differently after
    // cloning — confirm this is acceptable to callers.
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            word_counts: self.word_counts.clone(),
            word_embeddings: self.word_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            ngram_embeddings: self.ngram_embeddings.clone(),
            ngram_to_bucket: self.ngram_to_bucket.clone(),
            tokenizer: Box::new(WordTokenizer::default()),
            current_learning_rate: self.current_learning_rate,
            sampling_weights: self.sampling_weights.clone(),
        }
    }
}
164
165impl FastText {
166 pub fn new() -> Self {
168 Self {
169 config: FastTextConfig::default(),
170 vocabulary: Vocabulary::new(),
171 word_counts: HashMap::new(),
172 word_embeddings: None,
173 output_embeddings: None,
174 ngram_embeddings: None,
175 ngram_to_bucket: HashMap::new(),
176 tokenizer: Box::new(WordTokenizer::default()),
177 current_learning_rate: 0.05,
178 sampling_weights: Vec::new(),
179 }
180 }
181
182 pub fn with_config(config: FastTextConfig) -> Self {
184 let learning_rate = config.learning_rate;
185 Self {
186 config,
187 vocabulary: Vocabulary::new(),
188 word_counts: HashMap::new(),
189 word_embeddings: None,
190 output_embeddings: None,
191 ngram_embeddings: None,
192 ngram_to_bucket: HashMap::new(),
193 tokenizer: Box::new(WordTokenizer::default()),
194 current_learning_rate: learning_rate,
195 sampling_weights: Vec::new(),
196 }
197 }
198
    /// Replaces the tokenizer used to split input texts (builder-style).
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }
204
205 pub fn extract_ngrams(&self, word: &str) -> Vec<String> {
212 let word_with_boundaries = format!("<{}>", word);
213 let chars: Vec<char> = word_with_boundaries.chars().collect();
214 let mut ngrams = Vec::new();
215
216 for n in self.config.min_n..=self.config.max_n {
217 if chars.len() < n {
218 continue;
219 }
220
221 for i in 0..=(chars.len() - n) {
222 let ngram: String = chars[i..i + n].iter().collect();
223 ngrams.push(ngram);
224 }
225 }
226
227 ngrams
228 }
229
230 fn hash_ngram(&self, ngram: &str) -> usize {
232 let mut hash: u64 = 2166136261;
233 for byte in ngram.bytes() {
234 hash ^= u64::from(byte);
235 hash = hash.wrapping_mul(16777619);
236 }
237 (hash % (self.config.bucket_size as u64)) as usize
238 }
239
    /// Builds the vocabulary and (re)initializes all embedding matrices from `texts`.
    ///
    /// Tokens occurring fewer than `config.min_count` times are excluded.
    /// Word embeddings are initialized uniformly in (-1/dim, 1/dim), output
    /// embeddings to zero; the n-gram -> bucket cache and the unigram^0.75
    /// negative-sampling weights are rebuilt from scratch.
    ///
    /// # Errors
    /// - `TextError::InvalidInput` if `texts` is empty.
    /// - `TextError::VocabularyError` if no token meets the `min_count` threshold.
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count every token in the corpus.
        let mut word_counts = HashMap::new();

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
            }
        }

        // Keep only tokens at or above the frequency threshold.
        // NOTE(review): HashMap iteration order is unspecified, so vocabulary
        // indices (and hence initial embedding assignment) differ between runs.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        self.word_counts = word_counts;

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;
        let bucket_size = self.config.bucket_size;

        let mut rng = scirs2_core::random::rng();

        // Input embeddings: small uniform values scaled by 1/dim.
        let word_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        // Output (context) embeddings start at zero, as in word2vec.
        let output_embeddings = Array2::zeros((vocab_size, vector_size));

        // Shared n-gram bucket embeddings, also small uniform values.
        let ngram_embeddings = Array2::from_shape_fn((bucket_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.word_embeddings = Some(word_embeddings);
        self.output_embeddings = Some(output_embeddings);
        self.ngram_embeddings = Some(ngram_embeddings);

        // Cache the hash bucket of every n-gram of every vocabulary word.
        self.ngram_to_bucket.clear();
        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                let ngrams = self.extract_ngrams(word);
                for ngram in ngrams {
                    if !self.ngram_to_bucket.contains_key(&ngram) {
                        let bucket = self.hash_ngram(&ngram);
                        self.ngram_to_bucket.insert(ngram, bucket);
                    }
                }
            }
        }

        // Negative-sampling distribution: count^0.75, as in the word2vec paper.
        self.sampling_weights = vec![0.0; vocab_size];
        for i in 0..vocab_size {
            if let Some(word) = self.vocabulary.get_token(i) {
                let count = self.word_counts.get(word).copied().unwrap_or(1);
                self.sampling_weights[i] = (count as f64).powf(0.75);
            }
        }

        Ok(())
    }
323
324 fn sample_negative(&self, rng: &mut impl Rng) -> usize {
326 if self.sampling_weights.is_empty() {
327 return 0;
328 }
329 let total: f64 = self.sampling_weights.iter().sum();
330 if total <= 0.0 {
331 return rng.random_range(0..self.vocabulary.len().max(1));
332 }
333 let r = rng.random::<f64>() * total;
334 let mut cumulative = 0.0;
335 for (i, &w) in self.sampling_weights.iter().enumerate() {
336 cumulative += w;
337 if r <= cumulative {
338 return i;
339 }
340 }
341 self.sampling_weights.len() - 1
342 }
343
    /// Computes the averaged representation of the in-vocabulary word at
    /// `word_idx` — (word vector + its cached n-gram bucket vectors) divided
    /// by (1 + number of buckets) — and returns the bucket indices used.
    ///
    /// NOTE(review): not called from any code visible in this file.
    ///
    /// # Errors
    /// Fails if embeddings are uninitialized or `word_idx` is out of range.
    fn compute_word_representation(&self, word_idx: usize) -> Result<(Array1<f64>, Vec<usize>)> {
        let word_embeddings = self
            .word_embeddings
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
        let ngram_embeddings = self
            .ngram_embeddings
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;

        let word = self
            .vocabulary
            .get_token(word_idx)
            .ok_or_else(|| TextError::VocabularyError("Invalid word index".into()))?;

        // Only n-grams already cached during vocabulary construction contribute.
        let ngrams = self.extract_ngrams(word);
        let ngram_buckets: Vec<usize> = ngrams
            .iter()
            .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
            .collect();

        let mut vec = word_embeddings.row(word_idx).to_owned();
        for &bucket in &ngram_buckets {
            vec += &ngram_embeddings.row(bucket);
        }
        // Average over the word vector plus each n-gram vector.
        let divisor = 1.0 + ngram_buckets.len() as f64;
        vec /= divisor;

        Ok((vec, ngram_buckets))
    }
377
    /// Trains skip-gram embeddings with negative sampling on `texts`.
    ///
    /// Builds the vocabulary first if it is empty. The learning rate decays
    /// linearly per epoch from `config.learning_rate`, floored at 0.01% of it.
    ///
    /// # Errors
    /// `TextError::InvalidInput` if `texts` is empty, plus any tokenizer or
    /// vocabulary-construction errors.
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        // Lazily build the vocabulary and embeddings on first use.
        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        // Map each text to a sentence of vocabulary indices; OOV tokens are dropped.
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let word_indices: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !word_indices.is_empty() {
                sentences.push(word_indices);
            }
        }

        // Precompute every word's n-gram buckets once; reused across epochs.
        let mut word_ngram_buckets: Vec<Vec<usize>> = Vec::with_capacity(self.vocabulary.len());
        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                let ngrams = self.extract_ngrams(word);
                let buckets: Vec<usize> = ngrams
                    .iter()
                    .filter_map(|ng| self.ngram_to_bucket.get(ng).copied())
                    .collect();
                word_ngram_buckets.push(buckets);
            } else {
                word_ngram_buckets.push(Vec::new());
            }
        }

        for epoch in 0..self.config.epochs {
            // Linear decay: lr * (1 - epoch/epochs), clamped to a small floor.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                self.train_sentence(sentence, &word_ngram_buckets)?;
            }
        }

        Ok(())
    }
436
437 fn train_sentence(
439 &mut self,
440 sentence: &[usize],
441 word_ngram_buckets: &[Vec<usize>],
442 ) -> Result<()> {
443 if sentence.len() < 2 {
444 return Ok(());
445 }
446
447 let sampling_weights = self.sampling_weights.clone();
449 let vocab_len = self.vocabulary.len().max(1);
450 let negative_samples = self.config.negative_samples;
451 let current_lr = self.current_learning_rate;
452
453 let word_embeddings = self
454 .word_embeddings
455 .as_mut()
456 .ok_or_else(|| TextError::EmbeddingError("Word embeddings not initialized".into()))?;
457 let output_embeddings = self
458 .output_embeddings
459 .as_mut()
460 .ok_or_else(|| TextError::EmbeddingError("Output embeddings not initialized".into()))?;
461 let ngram_embeddings = self
462 .ngram_embeddings
463 .as_mut()
464 .ok_or_else(|| TextError::EmbeddingError("N-gram embeddings not initialized".into()))?;
465
466 let vector_size = self.config.vector_size;
467 let mut rng = scirs2_core::random::rng();
468
469 let total_weight: f64 = sampling_weights.iter().sum();
471 let cumulative_weights: Vec<f64> = if total_weight > 0.0 {
472 let mut cum = Vec::with_capacity(sampling_weights.len());
473 let mut acc = 0.0;
474 for &w in &sampling_weights {
475 acc += w;
476 cum.push(acc);
477 }
478 cum
479 } else {
480 Vec::new()
481 };
482
483 for (pos, &target_idx) in sentence.iter().enumerate() {
485 let window = 1 + rng.random_range(0..self.config.window_size);
487
488 let ngram_buckets = &word_ngram_buckets[target_idx];
490
491 let mut input_vec = word_embeddings.row(target_idx).to_owned();
493 for &bucket in ngram_buckets {
494 input_vec += &ngram_embeddings.row(bucket);
495 }
496 let divisor = 1.0 + ngram_buckets.len() as f64;
497 input_vec /= divisor;
498
499 for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
501 if i == pos {
502 continue;
503 }
504
505 let context_idx = sentence[i];
506
507 let mut grad_input = Array1::zeros(vector_size);
509
510 let output_vec = output_embeddings.row(context_idx).to_owned();
512 let dot_product: f64 = input_vec
513 .iter()
514 .zip(output_vec.iter())
515 .map(|(a, b)| a * b)
516 .sum();
517 let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
518 let gradient = (1.0 - sigmoid) * current_lr;
519
520 grad_input.scaled_add(gradient, &output_vec);
522
523 let mut out_row = output_embeddings.row_mut(context_idx);
525 let update = &input_vec * gradient;
526 out_row += &update;
527
528 for _ in 0..negative_samples {
530 let neg_idx = if cumulative_weights.is_empty() {
531 if vocab_len > 0 {
532 rng.random_range(0..vocab_len)
533 } else {
534 0
535 }
536 } else {
537 let r = rng.random::<f64>() * total_weight;
538 match cumulative_weights.binary_search_by(|w| {
539 w.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
540 }) {
541 Ok(i) => i,
542 Err(i) => i.min(cumulative_weights.len() - 1),
543 }
544 };
545 if neg_idx == context_idx {
546 continue;
547 }
548
549 let neg_vec = output_embeddings.row(neg_idx).to_owned();
550 let dot_product: f64 = input_vec
551 .iter()
552 .zip(neg_vec.iter())
553 .map(|(a, b)| a * b)
554 .sum();
555 let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
556 let gradient = -sigmoid * current_lr;
557
558 grad_input.scaled_add(gradient, &neg_vec);
560
561 let mut neg_row = output_embeddings.row_mut(neg_idx);
563 let neg_update = &input_vec * gradient;
564 neg_row += &neg_update;
565 }
566
567 let scaled_grad = &grad_input / divisor;
569
570 let mut word_row = word_embeddings.row_mut(target_idx);
571 word_row += &scaled_grad;
572
573 for &bucket in ngram_buckets {
574 let mut ngram_row = ngram_embeddings.row_mut(bucket);
575 ngram_row += &scaled_grad;
576 }
577 }
578 }
579
580 Ok(())
581 }
582
    /// Computes the vector for `word`: the average of its word embedding (if
    /// in-vocabulary) and the embeddings of all its character n-grams.
    /// Out-of-vocabulary words are represented purely by hashed n-grams —
    /// FastText's key property.
    ///
    /// # Errors
    /// Fails if the model is untrained, or if no contributing vector exists.
    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
        let word_embeddings = self
            .word_embeddings
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
        let ngram_embeddings = self
            .ngram_embeddings
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;

        let ngrams = self.extract_ngrams(word);
        let mut vector = Array1::zeros(self.config.vector_size);
        let mut count = 0.0;

        // In-vocabulary words contribute their dedicated word vector.
        if let Some(idx) = self.vocabulary.get_index(word) {
            vector += &word_embeddings.row(idx);
            count += 1.0;
        }

        for ngram in &ngrams {
            if let Some(&bucket) = self.ngram_to_bucket.get(ngram) {
                vector += &ngram_embeddings.row(bucket);
                count += 1.0;
            } else {
                // Unseen n-gram: hash directly into the shared bucket space.
                let bucket = self.hash_ngram(ngram);
                // Defensive check: hash_ngram already reduces modulo bucket_size.
                if bucket < self.config.bucket_size {
                    vector += &ngram_embeddings.row(bucket);
                    count += 1.0;
                }
            }
        }

        // Average over all contributing vectors.
        if count > 0.0 {
            vector /= count;
            Ok(vector)
        } else {
            Err(TextError::VocabularyError(format!(
                "Cannot compute vector for word '{}': no n-grams found",
                word
            )))
        }
    }
632
    /// Returns the `top_n` vocabulary words most cosine-similar to `word`
    /// (the query word itself is excluded from the results).
    ///
    /// # Errors
    /// Fails if the model is untrained or no vector can be derived for `word`.
    pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, top_n, &[word])
    }
638
639 pub fn most_similar_by_vector(
641 &self,
642 vector: &Array1<f64>,
643 top_n: usize,
644 exclude_words: &[&str],
645 ) -> Result<Vec<(String, f64)>> {
646 let mut similarities = Vec::new();
647
648 for i in 0..self.vocabulary.len() {
649 if let Some(candidate) = self.vocabulary.get_token(i) {
650 if exclude_words.contains(&candidate) {
651 continue;
652 }
653
654 if let Ok(candidate_vec) = self.get_word_vector(candidate) {
655 let similarity = cosine_similarity(vector, &candidate_vec);
656 similarities.push((candidate.to_string(), similarity));
657 }
658 }
659 }
660
661 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
662 Ok(similarities.into_iter().take(top_n).collect())
663 }
664
665 pub fn analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
670 let a_vec = self.get_word_vector(a)?;
671 let b_vec = self.get_word_vector(b)?;
672 let c_vec = self.get_word_vector(c)?;
673
674 let mut d_vec = b_vec.clone();
676 d_vec -= &a_vec;
677 d_vec += &c_vec;
678
679 let norm = d_vec.iter().fold(0.0, |sum, &val| sum + val * val).sqrt();
681 if norm > 0.0 {
682 d_vec.mapv_inplace(|val| val / norm);
683 }
684
685 self.most_similar_by_vector(&d_vec, top_n, &[a, b, c])
686 }
687
688 pub fn word_similarity(&self, word1: &str, word2: &str) -> Result<f64> {
692 let vec1 = self.get_word_vector(word1)?;
693 let vec2 = self.get_word_vector(word2)?;
694 Ok(cosine_similarity(&vec1, &vec2))
695 }
696
697 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
704 let word_embeddings = self
705 .word_embeddings
706 .as_ref()
707 .ok_or_else(|| TextError::EmbeddingError("Model not trained".into()))?;
708
709 let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
710
711 writeln!(
713 &mut file,
714 "FASTTEXT {} {} {} {} {}",
715 self.vocabulary.len(),
716 self.config.vector_size,
717 self.config.min_n,
718 self.config.max_n,
719 self.config.bucket_size,
720 )
721 .map_err(|e| TextError::IoError(e.to_string()))?;
722
723 for i in 0..self.vocabulary.len() {
725 if let Some(word) = self.vocabulary.get_token(i) {
726 write!(&mut file, "{} ", word).map_err(|e| TextError::IoError(e.to_string()))?;
727
728 for j in 0..self.config.vector_size {
730 write!(&mut file, "{:.6} ", word_embeddings[[i, j]])
731 .map_err(|e| TextError::IoError(e.to_string()))?;
732 }
733
734 writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
735 }
736 }
737
738 writeln!(&mut file, "NGRAMS {}", self.ngram_to_bucket.len())
740 .map_err(|e| TextError::IoError(e.to_string()))?;
741
742 for (ngram, &bucket) in &self.ngram_to_bucket {
743 writeln!(&mut file, "{} {}", ngram, bucket)
744 .map_err(|e| TextError::IoError(e.to_string()))?;
745 }
746
747 Ok(())
748 }
749
    /// Loads a model previously written by [`FastText::save`].
    ///
    /// Only word embeddings and the n-gram -> bucket map are persisted, so the
    /// loaded model supports vector lookups and similarity queries but not
    /// resumed training: output embeddings are `None`, n-gram embeddings are
    /// zero-initialized, and word counts / sampling weights are empty.
    ///
    /// # Errors
    /// `TextError::IoError` on read failures; `TextError::EmbeddingError` on
    /// any malformed header, vector, or n-gram line.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut reader = BufReader::new(file);

        // Header: "FASTTEXT <vocab> <dim> <min_n> <max_n> <buckets>".
        let mut header = String::new();
        reader
            .read_line(&mut header)
            .map_err(|e| TextError::IoError(e.to_string()))?;

        let parts: Vec<&str> = header.split_whitespace().collect();
        if parts.len() < 6 || parts[0] != "FASTTEXT" {
            return Err(TextError::EmbeddingError(
                "Invalid FastText file format (expected FASTTEXT header)".into(),
            ));
        }

        let vocab_size = parts[1]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid vocab size".into()))?;
        let vector_size = parts[2]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid vector size".into()))?;
        let min_n = parts[3]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid min_n".into()))?;
        let max_n = parts[4]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid max_n".into()))?;
        let bucket_size = parts[5]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid bucket_size".into()))?;

        // Remaining hyperparameters are not persisted; fall back to defaults.
        let config = FastTextConfig {
            vector_size,
            min_n,
            max_n,
            bucket_size,
            ..Default::default()
        };

        let mut vocabulary = Vocabulary::new();
        let mut word_embeddings = Array2::zeros((vocab_size, vector_size));

        // One "word v0 v1 ... v{dim-1}" line per vocabulary entry.
        for i in 0..vocab_size {
            let mut line = String::new();
            reader
                .read_line(&mut line)
                .map_err(|e| TextError::IoError(e.to_string()))?;

            let parts: Vec<&str> = line.split_whitespace().collect();
            if parts.len() < vector_size + 1 {
                return Err(TextError::EmbeddingError(format!(
                    "Invalid vector at line {}",
                    i + 2
                )));
            }

            vocabulary.add_token(parts[0]);

            for j in 0..vector_size {
                word_embeddings[[i, j]] = parts[j + 1].parse::<f64>().map_err(|_| {
                    TextError::EmbeddingError(format!(
                        "Invalid float at line {}, position {}",
                        i + 2,
                        j + 1
                    ))
                })?;
            }
        }

        // Optional trailing section: "NGRAMS <count>" then "ngram bucket" lines.
        let mut ngram_to_bucket = HashMap::new();
        let mut ngram_header = String::new();
        if reader
            .read_line(&mut ngram_header)
            .map_err(|e| TextError::IoError(e.to_string()))?
            > 0
        {
            let ngram_parts: Vec<&str> = ngram_header.split_whitespace().collect();
            if ngram_parts.len() >= 2 && ngram_parts[0] == "NGRAMS" {
                let ngram_count = ngram_parts[1]
                    .parse::<usize>()
                    .map_err(|_| TextError::EmbeddingError("Invalid ngram count".into()))?;

                for _ in 0..ngram_count {
                    let mut ngram_line = String::new();
                    reader
                        .read_line(&mut ngram_line)
                        .map_err(|e| TextError::IoError(e.to_string()))?;

                    // Malformed n-gram lines are skipped rather than rejected.
                    let np: Vec<&str> = ngram_line.split_whitespace().collect();
                    if np.len() >= 2 {
                        let bucket = np[1]
                            .parse::<usize>()
                            .map_err(|_| TextError::EmbeddingError("Invalid bucket".into()))?;
                        ngram_to_bucket.insert(np[0].to_string(), bucket);
                    }
                }
            }
        }

        // N-gram embedding values are not saved; allocate zeros so lookups work.
        let ngram_embeddings = Array2::zeros((bucket_size, vector_size));

        Ok(Self {
            config,
            vocabulary,
            word_counts: HashMap::new(),
            word_embeddings: Some(word_embeddings),
            output_embeddings: None,
            ngram_embeddings: Some(ngram_embeddings),
            ngram_to_bucket,
            tokenizer: Box::new(WordTokenizer::default()),
            current_learning_rate: 0.05,
            sampling_weights: Vec::new(),
        })
    }
870
    /// Number of words in the vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }
875
    /// Dimensionality of the embedding vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }
880
    /// The configured character n-gram length range as `(min_n, max_n)`.
    pub fn ngram_range(&self) -> (usize, usize) {
        (self.config.min_n, self.config.max_n)
    }
885
    /// Number of distinct n-grams cached in the n-gram -> bucket map.
    pub fn ngram_count(&self) -> usize {
        self.ngram_to_bucket.len()
    }
890
    /// True if `word` is in the vocabulary (exact match, no subword fallback).
    pub fn contains(&self, word: &str) -> bool {
        self.vocabulary.contains(word)
    }
895
896 pub fn can_represent(&self, word: &str) -> bool {
898 if self.vocabulary.contains(word) {
899 return true;
900 }
901 let ngrams = self.extract_ngrams(word);
903 ngrams
904 .iter()
905 .any(|ng| self.ngram_to_bucket.contains_key(ng))
906 }
907
908 pub fn get_vocabulary_words(&self) -> Vec<String> {
910 let mut words = Vec::with_capacity(self.vocabulary.len());
911 for i in 0..self.vocabulary.len() {
912 if let Some(word) = self.vocabulary.get_token(i) {
913 words.push(word.to_string());
914 }
915 }
916 words
917 }
918}
919
impl Default for FastText {
    /// Equivalent to [`FastText::new`].
    fn default() -> Self {
        Self::new()
    }
}
925
926fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
928 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
929 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
930 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
931
932 if norm_a > 0.0 && norm_b > 0.0 {
933 dot_product / (norm_a * norm_b)
934 } else {
935 0.0
936 }
937}
938
#[cfg(test)]
mod tests {
    //! Unit tests covering n-gram extraction, training, OOV handling,
    //! similarity queries, analogies, persistence, and config defaults.
    use super::*;

    #[test]
    fn test_extract_ngrams() {
        let config = FastTextConfig {
            min_n: 3,
            max_n: 4,
            ..Default::default()
        };
        let model = FastText::with_config(config);

        // "test" is wrapped to "<test>"; spot-check 3- and 4-grams.
        let ngrams = model.extract_ngrams("test");
        assert!(!ngrams.is_empty());
        assert!(ngrams.contains(&"<te".to_string()));
        assert!(ngrams.contains(&"est".to_string()));
        assert!(ngrams.contains(&"st>".to_string()));
        assert!(ngrams.contains(&"<tes".to_string()));
        assert!(ngrams.contains(&"test".to_string()));
        assert!(ngrams.contains(&"est>".to_string()));
    }

    #[test]
    fn test_extract_ngrams_short_word() {
        let config = FastTextConfig {
            min_n: 3,
            max_n: 6,
            ..Default::default()
        };
        let model = FastText::with_config(config);

        // "a" becomes "<a>" (3 chars): exactly one 3-gram, no longer ones.
        let ngrams = model.extract_ngrams("a");
        assert_eq!(ngrams.len(), 1);
        assert_eq!(ngrams[0], "<a>");
    }

    #[test]
    fn test_fasttext_training() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown dog outpaces a quick fox",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 1,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        let result = model.train(&texts);
        assert!(result.is_ok());

        // In-vocabulary lookup yields a vector of the configured size.
        let vec = model.get_word_vector("quick");
        assert!(vec.is_ok());
        assert_eq!(vec.expect("Failed to get vector").len(), 10);

        // OOV word still resolvable through shared n-grams.
        let oov_vec = model.get_word_vector("quickest");
        assert!(oov_vec.is_ok());
    }

    #[test]
    fn test_fasttext_oov_handling() {
        let texts = ["hello world", "hello there"];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 1,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let oov_vec = model.get_word_vector("helloworld");
        assert!(oov_vec.is_ok(), "FastText should handle OOV words");
    }

    #[test]
    fn test_fasttext_analogy() {
        let texts = [
            "the king ruled the kingdom wisely",
            "the queen ruled the kingdom wisely",
            "the man worked in the field",
            "the woman worked in the field",
            "the king and the queen were happy",
            "the man and the woman were happy",
        ];

        let config = FastTextConfig {
            vector_size: 20,
            window_size: 3,
            min_count: 1,
            epochs: 5,
            min_n: 3,
            max_n: 5,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        // Only checks that the query succeeds and returns candidates,
        // not the semantic quality of the tiny-corpus result.
        let result = model.analogy("king", "man", "woman", 3);
        assert!(result.is_ok());
        let answers = result.expect("analogy");
        assert!(!answers.is_empty());
    }

    #[test]
    fn test_fasttext_word_similarity() {
        let texts = [
            "the cat sat on the mat",
            "the dog sat on the rug",
            "the cat and dog played",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 3,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let sim = model.word_similarity("cat", "dog");
        assert!(sim.is_ok());
        assert!(sim.expect("similarity").is_finite());
    }

    #[test]
    fn test_fasttext_save_load() {
        let texts = ["the quick brown fox jumps", "the lazy brown dog sleeps"];

        let config = FastTextConfig {
            vector_size: 5,
            min_count: 1,
            epochs: 1,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let save_path = std::env::temp_dir().join("test_fasttext_save.txt");
        model.save(&save_path).expect("Failed to save");

        // Round-trip should preserve the header metadata.
        let loaded = FastText::load(&save_path).expect("Failed to load");
        assert_eq!(loaded.vocabulary_size(), model.vocabulary_size());
        assert_eq!(loaded.vector_size(), model.vector_size());
        assert_eq!(loaded.ngram_range(), model.ngram_range());

        std::fs::remove_file(save_path).ok();
    }

    #[test]
    fn test_fasttext_can_represent() {
        let texts = ["hello world"];

        let config = FastTextConfig {
            vector_size: 5,
            min_count: 1,
            epochs: 1,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        assert!(model.contains("hello"));
        assert!(model.can_represent("hello"));
        assert!(!model.contains("helloworld"));
        // OOV but representable through shared n-grams.
        assert!(model.can_represent("helloworld")); }

    #[test]
    fn test_fasttext_most_similar() {
        let texts = [
            "the dog runs fast",
            "the cat runs fast",
            "the bird flies high",
        ];

        let config = FastTextConfig {
            vector_size: 10,
            min_count: 1,
            epochs: 5,
            min_n: 3,
            max_n: 4,
            bucket_size: 1000,
            ..Default::default()
        };

        let mut model = FastText::with_config(config);
        model.train(&texts).expect("Training failed");

        let similar = model.most_similar("dog", 2).expect("most_similar");
        assert!(!similar.is_empty());
        assert!(similar.len() <= 2);
    }

    #[test]
    fn test_fasttext_empty_input() {
        let texts: Vec<&str> = vec![];
        let mut model = FastText::new();
        let result = model.train(&texts);
        assert!(result.is_err());
    }

    #[test]
    fn test_fasttext_config_defaults() {
        let config = FastTextConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.min_n, 3);
        assert_eq!(config.max_n, 6);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.bucket_size, 2_000_000);
    }

    #[test]
    fn test_hash_ngram_deterministic() {
        let model = FastText::new();
        let h1 = model.hash_ngram("abc");
        let h2 = model.hash_ngram("abc");
        assert_eq!(h1, h2);

        // Distinct inputs should (for these values) land in distinct buckets.
        let h3 = model.hash_ngram("xyz");
        assert_ne!(h1, h3);
    }
}