word2vec_example/
word2vec_example.rs1use scirs2_core::ndarray::Array1;
2use scirs2_text::embeddings::{cosine_similarity, Word2Vec, Word2VecAlgorithm};
3use std::time::Instant;
4
5#[allow(dead_code)]
6fn main() {
7 println!("Word2Vec Example");
8 println!("================\n");
9
10 let corpus = [
12 "the quick brown fox jumps over the lazy dog",
13 "a quick brown fox jumps over a lazy dog",
14 "the fox is quick and brown",
15 "the dog is lazy and sleepy",
16 "quick brown foxes jump over lazy dogs",
17 "the quick fox jumped over the lazy sleeping dog",
18 "a brown dog chased the quick fox",
19 "foxes and dogs are natural enemies",
20 "the quick brown cat jumps over the lazy fox",
21 "a quick brown cat jumps over a lazy fox",
22 ];
23
24 println!("Training Word2Vec model on a small corpus...");
25 let start = Instant::now();
26
27 let mut skipgram_model = Word2Vec::new()
29 .with_vector_size(50)
30 .with_window_size(3)
31 .with_min_count(1)
32 .with_epochs(100)
33 .with_algorithm(Word2VecAlgorithm::SkipGram)
34 .with_negative_samples(5);
35
36 skipgram_model
38 .train(&corpus)
39 .expect("Failed to train Skip-gram model");
40 let elapsed = start.elapsed();
41
42 println!(
43 "Training completed in {:.2} seconds\n",
44 elapsed.as_secs_f32()
45 );
46
47 println!("Finding words similar to 'fox':");
49 let similar_to_fox = skipgram_model
50 .most_similar("fox", 5)
51 .expect("Failed to find similar words");
52
53 for (word, similarity) in similar_to_fox {
54 println!("{word}: {similarity:.4}");
55 }
56
57 println!("\nFinding words similar to 'dog':");
58 let similar_to_dog = skipgram_model
59 .most_similar("dog", 5)
60 .expect("Failed to find similar words");
61
62 for (word, similarity) in similar_to_dog {
63 println!("{word}: {similarity:.4}");
64 }
65
66 println!("\nAnalogy: fox is to dog as quick is to ?");
68 let analogy_result = skipgram_model
69 .analogy("fox", "dog", "quick", 3)
70 .expect("Failed to compute analogy");
71
72 for (word, similarity) in analogy_result {
73 println!("{word}: {similarity:.4}");
74 }
75
76 println!("\nComparing word vectors:");
78 let fox_vector = skipgram_model
79 .get_word_vector("fox")
80 .expect("Failed to get vector for 'fox'");
81 let dog_vector = skipgram_model
82 .get_word_vector("dog")
83 .expect("Failed to get vector for 'dog'");
84 let quick_vector = skipgram_model
85 .get_word_vector("quick")
86 .expect("Failed to get vector for 'quick'");
87
88 println!(
89 "Cosine similarity between 'fox' and 'dog': {:.4}",
90 cosine_similarity(&fox_vector, &dog_vector)
91 );
92 println!(
93 "Cosine similarity between 'fox' and 'quick': {:.4}",
94 cosine_similarity(&fox_vector, &quick_vector)
95 );
96 println!(
97 "Cosine similarity between 'quick' and 'dog': {:.4}",
98 cosine_similarity(&quick_vector, &dog_vector)
99 );
100
101 println!("\nTraining CBOW model on the same corpus...");
103 let start = Instant::now();
104
105 let mut cbow_model = Word2Vec::new()
106 .with_vector_size(50)
107 .with_window_size(3)
108 .with_min_count(1)
109 .with_epochs(100)
110 .with_algorithm(Word2VecAlgorithm::CBOW)
111 .with_negative_samples(5);
112
113 cbow_model
114 .train(&corpus)
115 .expect("Failed to train CBOW model");
116 let elapsed = start.elapsed();
117
118 println!(
119 "Training completed in {:.2} seconds\n",
120 elapsed.as_secs_f32()
121 );
122
123 println!("CBOW model - Words similar to 'fox':");
125 let similar_to_fox_cbow = cbow_model
126 .most_similar("fox", 5)
127 .expect("Failed to find similar words");
128
129 for (word, similarity) in similar_to_fox_cbow {
130 println!("{word}: {similarity:.4}");
131 }
132
133 println!("\nVector arithmetic: fox - dog + cat = ?");
135
136 let fox_vec = skipgram_model.get_word_vector("fox").unwrap();
138 let dog_vec = skipgram_model.get_word_vector("dog").unwrap();
139 let cat_vec = skipgram_model.get_word_vector("cat").unwrap();
140
141 let mut result_vec = Array1::zeros(fox_vec.dim());
143 result_vec.assign(&fox_vec);
144 result_vec -= &dog_vec;
145 result_vec += &cat_vec;
146
147 let norm = (result_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
149 result_vec.mapv_inplace(|val| val / norm);
150
151 let similar_to_result = skipgram_model
153 .most_similar_by_vector(&result_vec, 5, &["fox", "dog", "cat"])
154 .expect("Failed to find similar words");
155
156 for (word, similarity) in similar_to_result {
157 println!("{word}: {similarity:.4}");
158 }
159
160 println!("\nSaving and loading the model...");
162 skipgram_model
163 .save("word2vec_model.txt")
164 .expect("Failed to save model");
165 println!("Model saved to 'word2vec_model.txt'");
166
167 let loaded_model = Word2Vec::load("word2vec_model.txt").expect("Failed to load model");
168 println!("Model loaded successfully");
169
170 let similar_words_loaded = loaded_model
172 .most_similar("fox", 3)
173 .expect("Failed to find similar words with loaded model");
174
175 println!("\nWords similar to 'fox' using loaded model:");
176 for (word, similarity) in similar_words_loaded {
177 println!("{word}: {similarity:.4}");
178 }
179}