1pub mod fasttext;
135pub mod glove;
136
137pub use fasttext::{FastText, FastTextConfig};
139pub use glove::{
140 cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
141 GloVeTrainerConfig,
142};
143
144use crate::error::{Result, TextError};
145use crate::tokenize::{Tokenizer, WordTokenizer};
146use crate::vocabulary::Vocabulary;
147use scirs2_core::ndarray::{Array1, Array2};
148use scirs2_core::random::prelude::*;
149use std::collections::HashMap;
150use std::fmt::Debug;
151use std::fs::File;
152use std::io::{BufRead, BufReader, Write};
153use std::path::Path;
154
/// Common interface implemented by the word-embedding models in this module
/// (Word2Vec, GloVe, FastText).
pub trait WordEmbedding {
    /// Returns the embedding vector for `word`, or an error if the model
    /// cannot produce one (e.g. out-of-vocabulary, untrained model).
    fn embedding(&self, word: &str) -> Result<Array1<f64>>;

    /// Dimensionality of the embedding vectors.
    fn dimension(&self) -> usize;

    /// Cosine similarity between the embeddings of two words.
    ///
    /// # Errors
    /// Propagates the lookup error if either word has no embedding.
    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
        let v1 = self.embedding(word1)?;
        let v2 = self.embedding(word2)?;
        Ok(embedding_cosine_similarity(&v1, &v2))
    }

    /// Returns up to `top_n` `(word, similarity)` pairs most similar to `word`.
    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Solves the analogy "`a` is to `b` as `c` is to ?", returning up to
    /// `top_n` candidate words with their similarity scores.
    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Number of words in the model's vocabulary.
    fn vocab_size(&self) -> usize;
}
185
186pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
188 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
189 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
190 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
191
192 if norm_a > 0.0 && norm_b > 0.0 {
193 dot_product / (norm_a * norm_b)
194 } else {
195 0.0
196 }
197}
198
199pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
201 let vectors: Vec<Array1<f64>> = words
202 .iter()
203 .map(|&w| model.embedding(w))
204 .collect::<Result<Vec<_>>>()?;
205
206 let n = vectors.len();
207 let mut matrix = vec![vec![0.0; n]; n];
208
209 for i in 0..n {
210 for j in i..n {
211 let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
212 matrix[i][j] = sim;
213 matrix[j][i] = sim;
214 }
215 }
216
217 Ok(matrix)
218}
219
/// Adapter exposing GloVe through the unified [`WordEmbedding`] interface by
/// delegating to its inherent methods.
impl WordEmbedding for GloVe {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}
243
/// Adapter exposing FastText through the unified [`WordEmbedding`] interface
/// by delegating to its inherent methods.
impl WordEmbedding for FastText {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}
265
/// A node in the Huffman tree used for hierarchical softmax.
#[derive(Debug, Clone)]
struct HuffmanNode {
    // Leaves: vocabulary index. Internal nodes: arena index (>= vocab size).
    id: usize,
    // Word frequency for leaves; combined child frequency for internal nodes.
    frequency: usize,
    // Arena index of the left child, if any.
    left: Option<usize>,
    // Arena index of the right child, if any.
    right: Option<usize>,
    // True for vocabulary leaves, false for merged internal nodes.
    is_leaf: bool,
}
282
/// Precomputed Huffman codes and root-to-leaf paths for hierarchical softmax.
#[derive(Debug, Clone)]
struct HuffmanTree {
    // Binary code for each vocabulary word (0 = left branch, 1 = right).
    codes: Vec<Vec<u8>>,
    // Internal-node indices visited from the root down to each word's leaf.
    paths: Vec<Vec<usize>>,
    // Number of internal nodes = rows needed in the HS parameter matrix.
    num_internal: usize,
}
293
impl HuffmanTree {
    /// Builds a Huffman tree from per-word frequencies and precomputes each
    /// word's binary code plus the internal-node path from the root to its
    /// leaf (both needed by hierarchical-softmax training).
    ///
    /// # Errors
    /// Fails if `frequencies` is empty.
    fn build(frequencies: &[usize]) -> Result<Self> {
        let vocab_size = frequencies.len();
        if vocab_size == 0 {
            return Err(TextError::EmbeddingError(
                "Cannot build Huffman tree with empty vocabulary".into(),
            ));
        }
        // Degenerate single-word vocabulary: one internal node, code "0".
        if vocab_size == 1 {
            return Ok(Self {
                codes: vec![vec![0]],
                paths: vec![vec![0]],
                num_internal: 1,
            });
        }

        // Node arena; the first vocab_size entries are the leaves.
        let mut nodes: Vec<HuffmanNode> = frequencies
            .iter()
            .enumerate()
            .map(|(id, &freq)| HuffmanNode {
                id,
                // Clamp to 1 so zero-frequency words still get a valid code.
                frequency: freq.max(1),
                left: None,
                right: None,
                is_leaf: true,
            })
            .collect();

        // "Priority queue" as a Vec kept sorted in DESCENDING frequency order
        // so that pop() removes the lowest-frequency entry in O(1).
        let mut queue: Vec<(usize, usize)> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (i, n.frequency))
            .collect();
        queue.sort_by(|a, b| b.1.cmp(&a.1));

        // Standard Huffman construction: repeatedly merge the two rarest nodes.
        while queue.len() > 1 {
            let (idx1, freq1) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
            let (idx2, freq2) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;

            let new_id = nodes.len();
            let new_node = HuffmanNode {
                id: new_id,
                frequency: freq1 + freq2,
                left: Some(idx1),
                right: Some(idx2),
                is_leaf: false,
            };
            nodes.push(new_node);

            // Re-insert the merged node preserving the descending order.
            // The comparator arguments are reversed to binary-search a
            // descending slice.
            let new_freq = freq1 + freq2;
            let insert_pos = queue
                .binary_search_by(|(_, f)| new_freq.cmp(f))
                .unwrap_or_else(|pos| pos);
            queue.insert(insert_pos, (new_id, new_freq));
        }

        // Every node past the original leaves is an internal node.
        let num_internal = nodes.len() - vocab_size;
        let mut codes = vec![Vec::new(); vocab_size];
        let mut paths = vec![Vec::new(); vocab_size];

        // Iterative DFS from the root, accumulating the bit code
        // (0 = left, 1 = right) and the internal nodes visited on the way.
        let root_idx = nodes.len() - 1;
        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];

        while let Some((node_idx, code, path)) = stack.pop() {
            let node = &nodes[node_idx];

            if node.is_leaf {
                codes[node.id] = code;
                paths[node.id] = path;
            } else {
                // Internal nodes are renumbered from 0 in HS-parameter space.
                let internal_idx = node.id - vocab_size;

                if let Some(left_idx) = node.left {
                    let mut left_code = code.clone();
                    left_code.push(0);
                    let mut left_path = path.clone();
                    left_path.push(internal_idx);
                    stack.push((left_idx, left_code, left_path));
                }

                if let Some(right_idx) = node.right {
                    let mut right_code = code.clone();
                    right_code.push(1);
                    let mut right_path = path.clone();
                    right_path.push(internal_idx);
                    stack.push((right_idx, right_code, right_path));
                }
            }
        }

        Ok(Self {
            codes,
            paths,
            num_internal,
        })
    }
}
408
/// Weighted sampling table for drawing negative samples via inverse-CDF
/// binary search.
#[derive(Debug, Clone)]
struct SamplingTable {
    // Normalized cumulative distribution; the final entry is 1.0.
    cdf: Vec<f64>,
    // Raw (unnormalized) weights the table was built from.
    weights: Vec<f64>,
}
417
418impl SamplingTable {
419 fn new(weights: &[f64]) -> Result<Self> {
421 if weights.is_empty() {
422 return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
423 }
424
425 if weights.iter().any(|&w| w < 0.0) {
427 return Err(TextError::EmbeddingError("Weights must be positive".into()));
428 }
429
430 let sum: f64 = weights.iter().sum();
432 if sum <= 0.0 {
433 return Err(TextError::EmbeddingError(
434 "Sum of _weights must be positive".into(),
435 ));
436 }
437
438 let mut cdf = Vec::with_capacity(weights.len());
439 let mut total = 0.0;
440
441 for &w in weights {
442 total += w;
443 cdf.push(total / sum);
444 }
445
446 Ok(Self {
447 cdf,
448 weights: weights.to_vec(),
449 })
450 }
451
452 fn sample<R: Rng>(&self, rng: &mut R) -> usize {
454 let r = rng.random::<f64>();
455
456 match self.cdf.binary_search_by(|&cdf_val| {
458 cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
459 }) {
460 Ok(idx) => idx,
461 Err(idx) => idx,
462 }
463 }
464
465 fn weights(&self) -> &[f64] {
467 &self.weights
468 }
469}
470
/// Training objective for Word2Vec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous bag-of-words: the averaged context predicts the center word.
    CBOW,
    /// Skip-gram: the center word predicts each surrounding context word.
    SkipGram,
}
479
/// Hyperparameters for Word2Vec training.
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the learned embedding vectors.
    pub vector_size: usize,
    /// Maximum context-window radius; the effective radius is re-sampled
    /// uniformly in `1..=window_size` at every position.
    pub window_size: usize,
    /// Words occurring fewer than this many times are dropped from the vocabulary.
    pub min_count: usize,
    /// Number of passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate; decays linearly per epoch during training.
    pub learning_rate: f64,
    /// Training objective: CBOW or skip-gram.
    pub algorithm: Word2VecAlgorithm,
    /// Negative samples drawn per positive example (negative-sampling mode).
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words; <= 0 disables subsampling.
    pub subsample: f64,
    /// Batch size. NOTE(review): not referenced by the visible training
    /// loops — confirm whether it is used elsewhere or vestigial.
    pub batch_size: usize,
    /// Use hierarchical softmax instead of negative sampling when true.
    pub hierarchical_softmax: bool,
}
504
impl Default for Word2VecConfig {
    /// Defaults mirror common word2vec settings: 100-dim skip-gram,
    /// window 5, min count 5, 5 negative samples, initial lr 0.025.
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}
521
/// Word2Vec model supporting CBOW and skip-gram training with either
/// negative sampling or hierarchical softmax.
pub struct Word2Vec {
    // Training hyperparameters.
    config: Word2VecConfig,
    // Token <-> index mapping built from the corpus.
    vocabulary: Vocabulary,
    // Input (target-word) embedding matrix, vocab_size x vector_size.
    input_embeddings: Option<Array2<f64>>,
    // Output (context-word) embedding matrix; None for models loaded from disk.
    output_embeddings: Option<Array2<f64>>,
    // Tokenizer used to split raw text into tokens.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    // Negative-sampling distribution (unigram counts raised to 0.75).
    sampling_table: Option<SamplingTable>,
    // Huffman tree for hierarchical softmax, when enabled.
    huffman_tree: Option<HuffmanTree>,
    // Hierarchical-softmax parameters: one row per internal tree node.
    hs_params: Option<Array2<f64>>,
    // Learning rate for the current epoch (decayed from config.learning_rate).
    current_learning_rate: f64,
}
551
impl Debug for Word2Vec {
    /// Manual `Debug` impl: the boxed tokenizer (a non-`Debug` trait object)
    /// and `hs_params` are intentionally omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Word2Vec")
            .field("config", &self.config)
            .field("vocabulary", &self.vocabulary)
            .field("input_embeddings", &self.input_embeddings)
            .field("output_embeddings", &self.output_embeddings)
            .field("sampling_table", &self.sampling_table)
            .field("huffman_tree", &self.huffman_tree)
            .field("current_learning_rate", &self.current_learning_rate)
            .finish()
    }
}
565
impl Default for Word2Vec {
    /// Equivalent to [`Word2Vec::new`].
    fn default() -> Self {
        Self::new()
    }
}
572
impl Clone for Word2Vec {
    /// Clones all model state.
    ///
    /// NOTE(review): the tokenizer trait object cannot be cloned, so the
    /// clone always receives a default `WordTokenizer` — a custom tokenizer
    /// set via `with_tokenizer` is silently dropped. Confirm callers expect this.
    fn clone(&self) -> Self {
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            huffman_tree: self.huffman_tree.clone(),
            hs_params: self.hs_params.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}
590
591impl Word2Vec {
592 pub fn new() -> Self {
594 Self {
595 config: Word2VecConfig::default(),
596 vocabulary: Vocabulary::new(),
597 input_embeddings: None,
598 output_embeddings: None,
599 tokenizer: Box::new(WordTokenizer::default()),
600 sampling_table: None,
601 huffman_tree: None,
602 hs_params: None,
603 current_learning_rate: 0.025,
604 }
605 }
606
    /// Creates an untrained model with the given configuration and a default
    /// `WordTokenizer`.
    pub fn with_config(config: Word2VecConfig) -> Self {
        // Seed the per-epoch decaying rate from the configured initial value.
        let learning_rate = config.learning_rate;
        Self {
            config,
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            huffman_tree: None,
            hs_params: None,
            current_learning_rate: learning_rate,
        }
    }
622
    /// Builder: replaces the tokenizer used to split input texts.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }
628
    /// Builder: sets the embedding dimensionality.
    pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
        self.config.vector_size = vectorsize;
        self
    }
634
    /// Builder: sets the maximum context-window radius.
    pub fn with_window_size(mut self, windowsize: usize) -> Self {
        self.config.window_size = windowsize;
        self
    }
640
    /// Builder: sets the minimum word frequency for vocabulary inclusion.
    pub fn with_min_count(mut self, mincount: usize) -> Self {
        self.config.min_count = mincount;
        self
    }
646
    /// Builder: sets the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }
652
    /// Builder: sets the initial learning rate (also resets the current,
    /// decayed rate so the next training run starts from it).
    pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
        self.config.learning_rate = learningrate;
        self.current_learning_rate = learningrate;
        self
    }
659
    /// Builder: selects CBOW or skip-gram training.
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }
665
    /// Builder: sets how many negative samples are drawn per positive example.
    pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
        self.config.negative_samples = negativesamples;
        self
    }
671
    /// Builder: sets the frequent-word subsampling threshold (<= 0 disables).
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }
677
    /// Builder: sets the batch size.
    /// NOTE(review): not referenced by the visible training loops — confirm.
    pub fn with_batch_size(mut self, batchsize: usize) -> Self {
        self.config.batch_size = batchsize;
        self
    }
683
    /// Builds the vocabulary from raw texts and initializes the embedding
    /// matrices (plus the hierarchical-softmax structures when enabled).
    ///
    /// # Errors
    /// Fails on empty input, tokenizer errors, or when no word reaches
    /// `min_count`.
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count raw token occurrences across all texts.
        let mut word_counts = HashMap::new();
        let mut _total_words = 0;

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
                _total_words += 1;
            }
        }

        // Keep only words that meet the frequency threshold.
        // NOTE(review): HashMap iteration order makes vocabulary index
        // assignment non-deterministic across runs — confirm acceptable.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;

        // Initialize both matrices uniformly in (-1/vector_size, 1/vector_size).
        let mut rng = scirs2_core::random::rng();
        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });
        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.input_embeddings = Some(input_embeddings);
        self.output_embeddings = Some(output_embeddings);

        // Negative-sampling distribution (unigram^0.75).
        self.create_sampling_table(&word_counts)?;

        // Optional hierarchical softmax: a Huffman tree over word frequencies
        // plus one parameter row per internal node.
        if self.config.hierarchical_softmax {
            let frequencies: Vec<usize> = (0..vocab_size)
                .map(|i| {
                    self.vocabulary
                        .get_token(i)
                        .and_then(|word| word_counts.get(word).copied())
                        .unwrap_or(1)
                })
                .collect();

            let tree = HuffmanTree::build(&frequencies)?;
            let num_internal = tree.num_internal;

            let hs_params = Array2::zeros((num_internal, vector_size));
            self.hs_params = Some(hs_params);
            self.huffman_tree = Some(tree);
        }

        Ok(())
    }
759
760 fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
762 let mut sampling_weights = vec![0.0; self.vocabulary.len()];
764
765 for (word, &count) in wordcounts.iter() {
766 if let Some(idx) = self.vocabulary.get_index(word) {
767 sampling_weights[idx] = (count as f64).powf(0.75);
769 }
770 }
771
772 match SamplingTable::new(&sampling_weights) {
773 Ok(table) => {
774 self.sampling_table = Some(table);
775 Ok(())
776 }
777 Err(e) => Err(e),
778 }
779 }
780
    /// Trains the model on `texts` for `config.epochs` epochs.
    ///
    /// Builds the vocabulary automatically if it hasn't been built yet. The
    /// learning rate decays linearly per epoch with a floor of 0.01% of the
    /// initial rate.
    ///
    /// # Errors
    /// Fails on empty input, tokenizer errors, or uninitialized embeddings.
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        // Lazily build the vocabulary on first use.
        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Embeddings not initialized. Call build_vocabulary() first".into(),
            ));
        }

        // Pre-tokenize all texts into vocabulary indices, dropping
        // out-of-vocabulary tokens and empty sentences.
        let mut _total_tokens = 0;
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let filtered_tokens: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !filtered_tokens.is_empty() {
                _total_tokens += filtered_tokens.len();
                sentences.push(filtered_tokens);
            }
        }

        for epoch in 0..self.config.epochs {
            // Linear learning-rate decay, clamped to a small positive floor.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                // Optionally drop frequent words before training this sentence.
                let subsampled_sentence = if self.config.subsample > 0.0 {
                    self.subsample_sentence(sentence)?
                } else {
                    sentence.clone()
                };

                if subsampled_sentence.is_empty() {
                    continue;
                }

                // Dispatch on objective (HS vs. negative sampling) and algorithm.
                if self.config.hierarchical_softmax {
                    match self.config.algorithm {
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
                        }
                    }
                } else {
                    match self.config.algorithm {
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_sentence(&subsampled_sentence)?;
                        }
                    }
                }
            }
        }

        Ok(())
    }
865
    /// Randomly drops frequent words from a sentence, word2vec-style.
    ///
    /// Keep probability per word: `(sqrt(f/t) + 1) * (t/f)` where `f` is the
    /// word's sampling weight and `t = subsample * vocab_size`.
    /// NOTE(review): canonical word2vec computes `t` from total corpus token
    /// counts and uses true frequencies for `f`; this uses vocabulary size
    /// and the unigram^0.75 weights — confirm the deviation is intentional.
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    // Never drop words with no recorded frequency.
                    return true;
                }
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }
889
890 fn get_word_frequency(&self, wordidx: usize) -> f64 {
892 if let Some(table) = &self.sampling_table {
895 table.weights()[wordidx]
896 } else {
897 1.0 }
899 }
900
    /// One CBOW training pass over a sentence using negative sampling.
    ///
    /// At each position the averaged context vector predicts the center word
    /// (positive example) against `negative_samples` sampled words.
    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Dynamic window radius in 1..=window_size (word2vec convention).
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect context word indices around the target position.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue;
            }

            // Average of the context input vectors — the CBOW "hidden layer".
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive update: push the target's output vector toward the
            // context average (gradient of log-sigmoid).
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative updates: push sampled words' output vectors away.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue;
                    }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Propagate the error back into each context word's input vector,
            // split evenly across the context.
            // NOTE(review): sigmoids are recomputed against already-updated
            // output vectors and fresh negative samples are drawn, which
            // differs from canonical word2vec — confirm intentional.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }
1009
    /// One skip-gram training pass over a sentence using negative sampling.
    ///
    /// Every (target, context) pair gets a positive logistic update plus
    /// `negative_samples` negative updates against sampled words.
    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Dynamic window radius in 1..=window_size (word2vec convention).
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive example: push the context output toward the target input.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the gradient destined for the target input vector.
                // NOTE(review): this reads the already-updated output vectors,
                // whereas canonical word2vec uses their pre-update values —
                // confirm intentional.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue;
                        }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated gradient once per context pair.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }
1085
    /// Dimensionality of the embedding vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }
1090
1091 pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
1093 if self.input_embeddings.is_none() {
1094 return Err(TextError::EmbeddingError(
1095 "Model not trained. Call train() first".into(),
1096 ));
1097 }
1098
1099 match self.vocabulary.get_index(word) {
1100 Some(idx) => Ok(self
1101 .input_embeddings
1102 .as_ref()
1103 .expect("Operation failed")
1104 .row(idx)
1105 .to_owned()),
1106 None => Err(TextError::VocabularyError(format!(
1107 "Word '{word}' not in vocabulary"
1108 ))),
1109 }
1110 }
1111
    /// Returns the `topn` words most similar to `word` by cosine similarity,
    /// excluding `word` itself from the results.
    pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, topn, &[word])
    }
1117
    /// Finds the `top_n` vocabulary words whose embeddings are most similar
    /// (by cosine) to `vector`, skipping any word listed in `exclude_words`.
    ///
    /// # Errors
    /// Fails if the model is untrained.
    pub fn most_similar_by_vector(
        &self,
        vector: &Array1<f64>,
        top_n: usize,
        exclude_words: &[&str],
    ) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
        let vocab_size = self.vocabulary.len();

        // Resolve excluded words to indices once; unknown words are ignored.
        let exclude_indices: Vec<usize> = exclude_words
            .iter()
            .filter_map(|&word| self.vocabulary.get_index(word))
            .collect();

        // Brute-force scan over the whole vocabulary.
        let mut similarities = Vec::with_capacity(vocab_size);

        for i in 0..vocab_size {
            if exclude_indices.contains(&i) {
                continue;
            }

            let word_vec = input_embeddings.row(i);
            // NOTE(review): `cosine_similarity` is an unqualified sibling
            // helper (distinct from `embedding_cosine_similarity`) —
            // presumably defined elsewhere in this file; confirm they agree.
            let similarity = cosine_similarity(vector, &word_vec.to_owned());

            if let Some(word) = self.vocabulary.get_token(i) {
                similarities.push((word.to_string(), similarity));
            }
        }

        // Sort descending by similarity; NaNs compare as equal.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let result = similarities.into_iter().take(top_n).collect();
        Ok(result)
    }
1163
1164 pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
1166 if self.input_embeddings.is_none() {
1167 return Err(TextError::EmbeddingError(
1168 "Model not trained. Call train() first".into(),
1169 ));
1170 }
1171
1172 let a_vec = self.get_word_vector(a)?;
1174 let b_vec = self.get_word_vector(b)?;
1175 let c_vec = self.get_word_vector(c)?;
1176
1177 let mut d_vec = b_vec.clone();
1179 d_vec -= &a_vec;
1180 d_vec += &c_vec;
1181
1182 let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1184 d_vec.mapv_inplace(|val| val / norm);
1185
1186 self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
1188 }
1189
    /// Saves the model to a text file in word2vec format: a header line
    /// `"<vocab_size> <vector_size>"` followed by one `"word v1 … vn"` line
    /// per vocabulary word.
    ///
    /// Only the input embeddings are written; training state (output
    /// embeddings, sampling table, HS parameters) is not persisted.
    ///
    /// # Errors
    /// Fails if the model is untrained or on any I/O error.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;

        writeln!(
            &mut file,
            "{} {}",
            self.vocabulary.len(),
            self.config.vector_size
        )
        .map_err(|e| TextError::IoError(e.to_string()))?;

        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");

        // One line per word: token followed by space-separated components
        // (6 decimal places; `load` splits on whitespace so the trailing
        // space is harmless).
        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;

                let vector = input_embeddings.row(i);
                for j in 0..self.config.vector_size {
                    write!(&mut file, "{:.6} ", vector[j])
                        .map_err(|e| TextError::IoError(e.to_string()))?;
                }

                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
            }
        }

        Ok(())
    }
1230
1231 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
1233 let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
1234 let mut reader = BufReader::new(file);
1235
1236 let mut header = String::new();
1238 reader
1239 .read_line(&mut header)
1240 .map_err(|e| TextError::IoError(e.to_string()))?;
1241
1242 let parts: Vec<&str> = header.split_whitespace().collect();
1243 if parts.len() != 2 {
1244 return Err(TextError::EmbeddingError(
1245 "Invalid model file format".into(),
1246 ));
1247 }
1248
1249 let vocab_size = parts[0].parse::<usize>().map_err(|_| {
1250 TextError::EmbeddingError("Invalid vocabulary size in model file".into())
1251 })?;
1252
1253 let vector_size = parts[1]
1254 .parse::<usize>()
1255 .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
1256
1257 let mut model = Self::new().with_vector_size(vector_size);
1259 let mut vocabulary = Vocabulary::new();
1260 let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
1261
1262 let mut i = 0;
1264 for line in reader.lines() {
1265 let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
1266 let parts: Vec<&str> = line.split_whitespace().collect();
1267
1268 if parts.len() != vector_size + 1 {
1269 let line_num = i + 2;
1270 return Err(TextError::EmbeddingError(format!(
1271 "Invalid vector format at line {line_num}"
1272 )));
1273 }
1274
1275 let word = parts[0];
1276 vocabulary.add_token(word);
1277
1278 for j in 0..vector_size {
1279 input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
1280 TextError::EmbeddingError(format!(
1281 "Invalid vector component at line {}, position {}",
1282 i + 2,
1283 j + 1
1284 ))
1285 })?;
1286 }
1287
1288 i += 1;
1289 }
1290
1291 if i != vocab_size {
1292 return Err(TextError::EmbeddingError(format!(
1293 "Expected {vocab_size} words but found {i}"
1294 )));
1295 }
1296
1297 model.vocabulary = vocabulary;
1298 model.input_embeddings = Some(input_embeddings);
1299 model.output_embeddings = None; Ok(model)
1302 }
1303
1304 pub fn get_vocabulary(&self) -> Vec<String> {
1308 let mut vocab = Vec::new();
1309 for i in 0..self.vocabulary.len() {
1310 if let Some(token) = self.vocabulary.get_token(i) {
1311 vocab.push(token.to_string());
1312 }
1313 }
1314 vocab
1315 }
1316
    /// Configured embedding dimensionality (same as [`Word2Vec::vector_size`]).
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }
1321
    /// Configured training algorithm (CBOW or skip-gram).
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }
1326
    /// Configured maximum context-window radius.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }
1331
    /// Configured minimum word frequency for vocabulary inclusion.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }
1336
    /// Returns a full copy of the input-embedding matrix, or `None` if the
    /// model is untrained.
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }
1341
    /// Configured number of negative samples per positive example.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }
1346
    /// Configured initial learning rate (not the current decayed rate).
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }
1351
    /// Configured number of training epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }
1356
    /// Configured frequent-word subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }
1361
    /// Whether training uses hierarchical softmax instead of negative sampling.
    pub fn uses_hierarchical_softmax(&self) -> bool {
        self.config.hierarchical_softmax
    }
1366
1367 fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1371 if sentence.len() < 2 {
1372 return Ok(());
1373 }
1374
1375 let input_embeddings = self
1376 .input_embeddings
1377 .as_mut()
1378 .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1379 let hs_params = self
1380 .hs_params
1381 .as_mut()
1382 .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1383 let tree = self
1384 .huffman_tree
1385 .as_ref()
1386 .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1387
1388 let vector_size = self.config.vector_size;
1389 let window_size = self.config.window_size;
1390 let lr = self.current_learning_rate;
1391
1392 let codes = tree.codes.clone();
1393 let paths = tree.paths.clone();
1394
1395 let mut rng = scirs2_core::random::rng();
1396
1397 for pos in 0..sentence.len() {
1398 let window = 1 + rng.random_range(0..window_size);
1399 let target_word = sentence[pos];
1400
1401 for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1402 if i == pos {
1403 continue;
1404 }
1405
1406 let context_word = sentence[i];
1407 let code = &codes[context_word];
1408 let path = &paths[context_word];
1409
1410 let mut grad_input = Array1::zeros(vector_size);
1411
1412 for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1414 if node_idx >= hs_params.nrows() {
1415 continue;
1416 }
1417
1418 let input_vec = input_embeddings.row(target_word);
1420 let param_vec = hs_params.row(node_idx);
1421
1422 let dot: f64 = input_vec
1423 .iter()
1424 .zip(param_vec.iter())
1425 .map(|(a, b)| a * b)
1426 .sum();
1427 let sigmoid = 1.0 / (1.0 + (-dot).exp());
1428
1429 let target = if label == 0 { 1.0 } else { 0.0 };
1431 let gradient = (target - sigmoid) * lr;
1432
1433 grad_input.scaled_add(gradient, ¶m_vec.to_owned());
1435
1436 let input_owned = input_vec.to_owned();
1438 let mut param_mut = hs_params.row_mut(node_idx);
1439 param_mut.scaled_add(gradient, &input_owned);
1440 }
1441
1442 let mut input_mut = input_embeddings.row_mut(target_word);
1444 input_mut += &grad_input;
1445 }
1446 }
1447
1448 Ok(())
1449 }
1450
1451 fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
1453 if sentence.len() < 2 {
1454 return Ok(());
1455 }
1456
1457 let input_embeddings = self
1458 .input_embeddings
1459 .as_mut()
1460 .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
1461 let hs_params = self
1462 .hs_params
1463 .as_mut()
1464 .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
1465 let tree = self
1466 .huffman_tree
1467 .as_ref()
1468 .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
1469
1470 let vector_size = self.config.vector_size;
1471 let window_size = self.config.window_size;
1472 let lr = self.current_learning_rate;
1473
1474 let codes = tree.codes.clone();
1475 let paths = tree.paths.clone();
1476
1477 let mut rng = scirs2_core::random::rng();
1478
1479 for pos in 0..sentence.len() {
1480 let window = 1 + rng.random_range(0..window_size);
1481 let target_word = sentence[pos];
1482
1483 let mut context_words = Vec::new();
1485 for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
1486 if i != pos {
1487 context_words.push(sentence[i]);
1488 }
1489 }
1490
1491 if context_words.is_empty() {
1492 continue;
1493 }
1494
1495 let mut context_avg = Array1::zeros(vector_size);
1497 for &ctx_idx in &context_words {
1498 context_avg += &input_embeddings.row(ctx_idx);
1499 }
1500 context_avg /= context_words.len() as f64;
1501
1502 let code = &codes[target_word];
1504 let path = &paths[target_word];
1505
1506 let mut grad_context = Array1::zeros(vector_size);
1507
1508 for (step, (&node_idx, &label)) in path.iter().zip(code.iter()).enumerate() {
1509 if node_idx >= hs_params.nrows() {
1510 continue;
1511 }
1512
1513 let param_vec = hs_params.row(node_idx);
1514
1515 let dot: f64 = context_avg
1516 .iter()
1517 .zip(param_vec.iter())
1518 .map(|(a, b)| a * b)
1519 .sum();
1520 let sigmoid = 1.0 / (1.0 + (-dot).exp());
1521
1522 let target = if label == 0 { 1.0 } else { 0.0 };
1523 let gradient = (target - sigmoid) * lr;
1524
1525 grad_context.scaled_add(gradient, ¶m_vec.to_owned());
1526
1527 let ctx_owned = context_avg.clone();
1529 let mut param_mut = hs_params.row_mut(node_idx);
1530 param_mut.scaled_add(gradient, &ctx_owned);
1531 }
1532
1533 let grad_per_word = &grad_context / context_words.len() as f64;
1535 for &ctx_idx in &context_words {
1536 let mut input_mut = input_embeddings.row_mut(ctx_idx);
1537 input_mut += &grad_per_word;
1538 }
1539 }
1540
1541 Ok(())
1542 }
1543}
1544
1545impl WordEmbedding for Word2Vec {
1548 fn embedding(&self, word: &str) -> Result<Array1<f64>> {
1549 self.get_word_vector(word)
1550 }
1551
1552 fn dimension(&self) -> usize {
1553 self.vector_size()
1554 }
1555
1556 fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1557 self.most_similar(word, top_n)
1558 }
1559
1560 fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
1561 self.analogy(a, b, c, top_n)
1562 }
1563
1564 fn vocab_size(&self) -> usize {
1565 self.vocabulary.len()
1566 }
1567}
1568
1569#[allow(dead_code)]
1571pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
1572 let dot_product = (a * b).sum();
1573 let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1574 let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
1575
1576 if norm_a > 0.0 && norm_b > 0.0 {
1577 dot_product / (norm_a * norm_b)
1578 } else {
1579 0.0
1580 }
1581}
1582
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cosine_similarity() {
        // Hand-computed cosine between two fixed 3-d vectors.
        let u = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let v = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        assert_relative_eq!(
            cosine_similarity(&u, &v),
            0.9746318461970762,
            max_relative = 1e-10
        );
    }

    #[test]
    fn test_word2vec_config() {
        // The default configuration values.
        let cfg = Word2VecConfig::default();
        assert_eq!(cfg.vector_size, 100);
        assert_eq!(cfg.window_size, 5);
        assert_eq!(cfg.min_count, 5);
        assert_eq!(cfg.epochs, 5);
        assert_eq!(cfg.algorithm, Word2VecAlgorithm::SkipGram);
    }

    #[test]
    fn test_word2vec_builder() {
        // Builder setters must land in the stored config.
        let built = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(built.config.vector_size, 200);
        assert_eq!(built.config.window_size, 10);
        assert_eq!(built.config.learning_rate, 0.05);
        assert_eq!(built.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    #[test]
    fn test_build_vocabulary() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new().with_min_count(1);
        assert!(model.build_vocabulary(&corpus).is_ok());

        // Nine distinct tokens appear across the two sentences.
        assert_eq!(model.vocabulary.len(), 9);

        // Both embedding matrices are allocated with shape vocab x vector_size.
        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(
            model
                .input_embeddings
                .as_ref()
                .expect("Operation failed")
                .shape(),
            &[9, 100]
        );
    }

    #[test]
    fn test_skipgram_training_small() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        assert!(model.train(&corpus).is_ok());

        // A trained in-vocabulary word yields a vector of the chosen size.
        let fox = model.get_word_vector("fox");
        assert!(fox.is_ok());
        assert_eq!(fox.expect("Operation failed").len(), 10);
    }

    #[test]
    fn test_huffman_tree_build() {
        let frequencies = vec![5, 3, 8, 1, 2];
        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");

        assert_eq!(tree.codes.len(), 5);
        assert_eq!(tree.paths.len(), 5);

        // Every leaf must be assigned a non-empty Huffman code.
        assert!(tree.codes.iter().all(|code| !code.is_empty()));

        // A binary tree over n leaves has n - 1 internal nodes.
        assert_eq!(tree.num_internal, 4);
    }

    #[test]
    fn test_huffman_tree_single_word() {
        // Degenerate tree: a single leaf still gets a code and a path.
        let frequencies = vec![10];
        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
        assert_eq!(tree.codes.len(), 1);
        assert_eq!(tree.paths.len(), 1);
    }

    #[test]
    fn test_skipgram_hierarchical_softmax() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::with_config(Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            hierarchical_softmax: true,
            ..Default::default()
        });

        let outcome = model.train(&corpus);
        assert!(
            outcome.is_ok(),
            "HS skipgram training failed: {:?}",
            outcome.err()
        );

        assert!(model.uses_hierarchical_softmax());

        // Vectors remain retrievable after HS training.
        let vec = model.get_word_vector("fox");
        assert!(vec.is_ok());
        assert_eq!(vec.expect("get vec").len(), 10);
    }

    #[test]
    fn test_cbow_hierarchical_softmax() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::with_config(Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::CBOW,
            hierarchical_softmax: true,
            ..Default::default()
        });

        let outcome = model.train(&corpus);
        assert!(
            outcome.is_ok(),
            "HS CBOW training failed: {:?}",
            outcome.err()
        );

        assert!(model.get_word_vector("dog").is_ok());
    }

    #[test]
    fn test_word_embedding_trait_word2vec() {
        let corpus = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);

        model.train(&corpus).expect("Training failed");

        // Exercise every method through the trait object.
        let emb: &dyn WordEmbedding = &model;
        assert_eq!(emb.dimension(), 10);
        assert!(emb.vocab_size() > 0);

        assert!(emb.embedding("fox").is_ok());

        let sim = emb.similarity("fox", "dog");
        assert!(sim.is_ok());
        assert!(sim.expect("sim").is_finite());

        assert!(emb.find_similar("fox", 2).is_ok());
        assert!(emb.solve_analogy("the", "fox", "dog", 2).is_ok());
    }

    #[test]
    fn test_embedding_cosine_similarity_fn() {
        // Orthogonal unit vectors have similarity 0.
        let x = Array1::from_vec(vec![1.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);
        assert!((embedding_cosine_similarity(&x, &y) - 0.0).abs() < 1e-6);

        // Identical vectors have similarity 1.
        let p = Array1::from_vec(vec![1.0, 1.0]);
        let q = Array1::from_vec(vec![1.0, 1.0]);
        assert!((embedding_cosine_similarity(&p, &q) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_pairwise_similarity_fn() {
        let texts = ["the quick brown fox", "the lazy brown dog"];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);
        model.train(&texts).expect("Training failed");

        let words = vec!["the", "fox", "dog"];
        let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");

        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Diagonal entries are self-similarities of 1.
        for (i, row) in matrix.iter().enumerate() {
            assert!((row[i] - 1.0).abs() < 1e-6);
        }

        // The similarity matrix must be symmetric.
        for (i, row) in matrix.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert!((value - matrix[j][i]).abs() < 1e-10);
            }
        }
    }
}