// word2vec_example/
// word2vec_example.rs

use scirs2_core::ndarray::Array1;
use scirs2_text::embeddings::{cosine_similarity, Word2Vec, Word2VecAlgorithm};
use std::fmt::Display;
use std::time::Instant;

5#[allow(dead_code)]
6fn main() {
7    println!("Word2Vec Example");
8    println!("================\n");
9
10    // Sample corpus for demonstration
11    let corpus = [
12        "the quick brown fox jumps over the lazy dog",
13        "a quick brown fox jumps over a lazy dog",
14        "the fox is quick and brown",
15        "the dog is lazy and sleepy",
16        "quick brown foxes jump over lazy dogs",
17        "the quick fox jumped over the lazy sleeping dog",
18        "a brown dog chased the quick fox",
19        "foxes and dogs are natural enemies",
20        "the quick brown cat jumps over the lazy fox",
21        "a quick brown cat jumps over a lazy fox",
22    ];
23
24    println!("Training Word2Vec model on a small corpus...");
25    let start = Instant::now();
26
27    // Create a Word2Vec model with Skip-gram algorithm
28    let mut skipgram_model = Word2Vec::new()
29        .with_vector_size(50)
30        .with_window_size(3)
31        .with_min_count(1)
32        .with_epochs(100)
33        .with_algorithm(Word2VecAlgorithm::SkipGram)
34        .with_negative_samples(5);
35
36    // Train the model
37    skipgram_model
38        .train(&corpus)
39        .expect("Failed to train Skip-gram model");
40    let elapsed = start.elapsed();
41
42    println!(
43        "Training completed in {:.2} seconds\n",
44        elapsed.as_secs_f32()
45    );
46
47    // Find similar words
48    println!("Finding words similar to 'fox':");
49    let similar_to_fox = skipgram_model
50        .most_similar("fox", 5)
51        .expect("Failed to find similar words");
52
53    for (word, similarity) in similar_to_fox {
54        println!("{word}: {similarity:.4}");
55    }
56
57    println!("\nFinding words similar to 'dog':");
58    let similar_to_dog = skipgram_model
59        .most_similar("dog", 5)
60        .expect("Failed to find similar words");
61
62    for (word, similarity) in similar_to_dog {
63        println!("{word}: {similarity:.4}");
64    }
65
66    // Compute analogies (e.g., fox is to dog as quick is to ?)
67    println!("\nAnalogy: fox is to dog as quick is to ?");
68    let analogy_result = skipgram_model
69        .analogy("fox", "dog", "quick", 3)
70        .expect("Failed to compute analogy");
71
72    for (word, similarity) in analogy_result {
73        println!("{word}: {similarity:.4}");
74    }
75
76    // Get word vectors and calculate cosine similarity manually
77    println!("\nComparing word vectors:");
78    let fox_vector = skipgram_model
79        .get_word_vector("fox")
80        .expect("Failed to get vector for 'fox'");
81    let dog_vector = skipgram_model
82        .get_word_vector("dog")
83        .expect("Failed to get vector for 'dog'");
84    let quick_vector = skipgram_model
85        .get_word_vector("quick")
86        .expect("Failed to get vector for 'quick'");
87
88    println!(
89        "Cosine similarity between 'fox' and 'dog': {:.4}",
90        cosine_similarity(&fox_vector, &dog_vector)
91    );
92    println!(
93        "Cosine similarity between 'fox' and 'quick': {:.4}",
94        cosine_similarity(&fox_vector, &quick_vector)
95    );
96    println!(
97        "Cosine similarity between 'quick' and 'dog': {:.4}",
98        cosine_similarity(&quick_vector, &dog_vector)
99    );
100
101    // Train a CBOW model on the same corpus
102    println!("\nTraining CBOW model on the same corpus...");
103    let start = Instant::now();
104
105    let mut cbow_model = Word2Vec::new()
106        .with_vector_size(50)
107        .with_window_size(3)
108        .with_min_count(1)
109        .with_epochs(100)
110        .with_algorithm(Word2VecAlgorithm::CBOW)
111        .with_negative_samples(5);
112
113    cbow_model
114        .train(&corpus)
115        .expect("Failed to train CBOW model");
116    let elapsed = start.elapsed();
117
118    println!(
119        "Training completed in {:.2} seconds\n",
120        elapsed.as_secs_f32()
121    );
122
123    // Compare results from CBOW model
124    println!("CBOW model - Words similar to 'fox':");
125    let similar_to_fox_cbow = cbow_model
126        .most_similar("fox", 5)
127        .expect("Failed to find similar words");
128
129    for (word, similarity) in similar_to_fox_cbow {
130        println!("{word}: {similarity:.4}");
131    }
132
133    // Vector arithmetic: fox - dog + cat = ?
134    println!("\nVector arithmetic: fox - dog + cat = ?");
135
136    // Manual vector arithmetic
137    let fox_vec = skipgram_model.get_word_vector("fox").unwrap();
138    let dog_vec = skipgram_model.get_word_vector("dog").unwrap();
139    let cat_vec = skipgram_model.get_word_vector("cat").unwrap();
140
141    // Compute the result vector
142    let mut result_vec = Array1::zeros(fox_vec.dim());
143    result_vec.assign(&fox_vec);
144    result_vec -= &dog_vec;
145    result_vec += &cat_vec;
146
147    // Normalize the vector
148    let norm = (result_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
149    result_vec.mapv_inplace(|val| val / norm);
150
151    // Find words similar to the result vector
152    let similar_to_result = skipgram_model
153        .most_similar_by_vector(&result_vec, 5, &["fox", "dog", "cat"])
154        .expect("Failed to find similar words");
155
156    for (word, similarity) in similar_to_result {
157        println!("{word}: {similarity:.4}");
158    }
159
160    // Save and load the model
161    println!("\nSaving and loading the model...");
162    skipgram_model
163        .save("word2vec_model.txt")
164        .expect("Failed to save model");
165    println!("Model saved to 'word2vec_model.txt'");
166
167    let loaded_model = Word2Vec::load("word2vec_model.txt").expect("Failed to load model");
168    println!("Model loaded successfully");
169
170    // Verify the loaded model works
171    let similar_words_loaded = loaded_model
172        .most_similar("fox", 3)
173        .expect("Failed to find similar words with loaded model");
174
175    println!("\nWords similar to 'fox' using loaded model:");
176    for (word, similarity) in similar_words_loaded {
177        println!("{word}: {similarity:.4}");
178    }
179}