Skip to main content

scirs2_graph/embeddings/
deepwalk.rs

1//! DeepWalk graph embedding algorithm
2//!
3//! Implements the DeepWalk algorithm from Perozzi et al. (2014) for learning
4//! latent representations of nodes using short uniform random walks.
5//! Supports both negative sampling and hierarchical softmax approximation.
6//!
7//! # References
8//! - Perozzi, B., Al-Rfou, R., & Skiena, S. (2014). DeepWalk: Online Learning
9//!   of Social Representations. KDD 2014.
10
11use super::core::{Embedding, EmbeddingModel};
12use super::negative_sampling::NegativeSampler;
13use super::random_walk::RandomWalkGenerator;
14use super::types::{DeepWalkConfig, RandomWalk};
15use crate::base::{DiGraph, EdgeWeight, Graph, Node};
16use crate::error::{GraphError, Result};
17use scirs2_core::random::seq::SliceRandom;
18use scirs2_core::random::Rng;
19use std::collections::HashMap;
20
/// Objective used to approximate the skip-gram softmax when training DeepWalk.
///
/// [`DeepWalkMode::NegativeSampling`] is the default, which is also what
/// `DeepWalkMode::default()` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DeepWalkMode {
    /// Standard negative sampling (default, simpler and often faster)
    #[default]
    NegativeSampling,
    /// Hierarchical softmax using a Huffman tree built from node frequencies
    HierarchicalSoftmax,
}
29
30/// A node in the Huffman tree used for hierarchical softmax
31#[derive(Debug, Clone)]
32struct HuffmanNode {
33    /// Path from root (true = right, false = left)
34    code: Vec<bool>,
35    /// Indices of internal nodes along the path
36    point: Vec<usize>,
37}
38
39/// Huffman tree for hierarchical softmax
40#[derive(Debug)]
41struct HuffmanTree {
42    /// Huffman encoding for each leaf node (indexed by node index)
43    codes: Vec<HuffmanNode>,
44    /// Number of internal nodes
45    num_internal: usize,
46}
47
48impl HuffmanTree {
49    /// Build a Huffman tree from node frequencies
50    fn build(frequencies: &[f64]) -> Result<Self> {
51        let n = frequencies.len();
52        if n == 0 {
53            return Err(GraphError::InvalidGraph(
54                "Cannot build Huffman tree from empty frequency list".to_string(),
55            ));
56        }
57
58        if n == 1 {
59            // Single node: trivial encoding
60            let codes = vec![HuffmanNode {
61                code: vec![false],
62                point: vec![0],
63            }];
64            return Ok(HuffmanTree {
65                codes,
66                num_internal: 1,
67            });
68        }
69
70        // Build Huffman tree using a priority queue simulation
71        // Total nodes = n leaves + (n-1) internal nodes
72        let total = 2 * n - 1;
73        let mut count = vec![0.0f64; total];
74        let mut parent = vec![0usize; total];
75        let mut binary = vec![false; total]; // true = right child
76
77        // Initialize leaf frequencies
78        for (i, &freq) in frequencies.iter().enumerate() {
79            count[i] = freq.max(1e-10); // Avoid zero frequencies
80        }
81
82        // Initialize internal node counts to large value
83        for i in n..total {
84            count[i] = f64::MAX;
85        }
86
87        // Build tree bottom-up
88        let mut pos1 = n - 1; // Position scanning leaves (right to left, sorted by freq)
89        let mut pos2 = n; // Position scanning internal nodes
90
91        // Sort leaf indices by frequency (ascending)
92        let mut sorted_indices: Vec<usize> = (0..n).collect();
93        sorted_indices.sort_by(|&a, &b| {
94            count[a]
95                .partial_cmp(&count[b])
96                .unwrap_or(std::cmp::Ordering::Equal)
97        });
98
99        // Reorder counts by sorted order
100        let mut sorted_counts = vec![0.0; n];
101        let mut reverse_map = vec![0usize; n]; // original_index -> sorted_position
102        for (sorted_pos, &orig_idx) in sorted_indices.iter().enumerate() {
103            sorted_counts[sorted_pos] = count[orig_idx];
104            reverse_map[orig_idx] = sorted_pos;
105        }
106        count[..n].copy_from_slice(&sorted_counts[..n]);
107
108        // Build internal nodes
109        for internal_idx in n..total {
110            // Find two nodes with smallest counts
111            let min1;
112            let min2;
113
114            // First minimum
115            if pos1 < n && (pos2 >= internal_idx || count[pos1] < count[pos2]) {
116                min1 = pos1;
117                pos1 = pos1.wrapping_sub(1); // will wrap to usize::MAX when 0
118                if pos1 == usize::MAX {
119                    pos1 = n; // sentinel: no more leaves
120                }
121            } else {
122                min1 = pos2;
123                pos2 += 1;
124            }
125
126            // Second minimum
127            if pos1 < n && (pos2 >= internal_idx || count[pos1] < count[pos2]) {
128                min2 = pos1;
129                pos1 = pos1.wrapping_sub(1);
130                if pos1 == usize::MAX {
131                    pos1 = n;
132                }
133            } else if pos2 < internal_idx {
134                min2 = pos2;
135                pos2 += 1;
136            } else {
137                min2 = min1; // Fallback (shouldn't happen with valid input)
138            }
139
140            count[internal_idx] = count[min1] + count[min2];
141            parent[min1] = internal_idx;
142            parent[min2] = internal_idx;
143            binary[min2] = true; // Right child
144        }
145
146        // Generate codes by traversing from each leaf to root
147        let mut codes = vec![
148            HuffmanNode {
149                code: Vec::new(),
150                point: Vec::new(),
151            };
152            n
153        ];
154
155        for sorted_pos in 0..n {
156            let mut code = Vec::new();
157            let mut point = Vec::new();
158
159            let mut current = sorted_pos;
160            while current < total - 1 {
161                // Not root
162                code.push(binary[current]);
163                let par = parent[current];
164                // Internal node index = par - n (0-indexed)
165                if par >= n {
166                    point.push(par - n);
167                }
168                current = par;
169            }
170
171            // Reverse to get root-to-leaf order
172            code.reverse();
173            point.reverse();
174
175            // Map back from sorted position to original index
176            let orig_idx = sorted_indices[sorted_pos];
177            codes[orig_idx] = HuffmanNode { code, point };
178        }
179
180        Ok(HuffmanTree {
181            codes,
182            num_internal: n - 1,
183        })
184    }
185}
186
187/// DeepWalk embedding algorithm
188///
189/// Learns node embeddings using uniform random walks followed by
190/// skip-gram optimization. Supports both negative sampling and
191/// hierarchical softmax.
192pub struct DeepWalk<N: Node> {
193    config: DeepWalkConfig,
194    model: EmbeddingModel<N>,
195    walk_generator: RandomWalkGenerator<N>,
196    /// Training mode
197    mode: DeepWalkMode,
198    /// Internal node vectors for hierarchical softmax
199    internal_vectors: Vec<Vec<f64>>,
200}
201
202impl<N: Node> DeepWalk<N> {
203    /// Create a new DeepWalk instance with negative sampling (default)
204    pub fn new(config: DeepWalkConfig) -> Self {
205        DeepWalk {
206            model: EmbeddingModel::new(config.dimensions),
207            config,
208            walk_generator: RandomWalkGenerator::new(),
209            mode: DeepWalkMode::NegativeSampling,
210            internal_vectors: Vec::new(),
211        }
212    }
213
214    /// Create a new DeepWalk instance with hierarchical softmax
215    pub fn with_hierarchical_softmax(config: DeepWalkConfig) -> Self {
216        DeepWalk {
217            model: EmbeddingModel::new(config.dimensions),
218            config,
219            walk_generator: RandomWalkGenerator::new(),
220            mode: DeepWalkMode::HierarchicalSoftmax,
221            internal_vectors: Vec::new(),
222        }
223    }
224
225    /// Set the training mode
226    pub fn set_mode(&mut self, mode: DeepWalkMode) {
227        self.mode = mode;
228    }
229
230    /// Get the current training mode
231    pub fn mode(&self) -> DeepWalkMode {
232        self.mode
233    }
234
235    /// Generate training data (uniform random walks) on undirected graph
236    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
237    where
238        N: Clone + std::fmt::Debug,
239        E: EdgeWeight,
240        Ix: petgraph::graph::IndexType,
241    {
242        let mut all_walks = Vec::new();
243
244        for node in graph.nodes() {
245            for _ in 0..self.config.num_walks {
246                let walk =
247                    self.walk_generator
248                        .simple_random_walk(graph, node, self.config.walk_length)?;
249                all_walks.push(walk);
250            }
251        }
252
253        Ok(all_walks)
254    }
255
256    /// Generate training data (uniform random walks) on directed graph
257    pub fn generate_walks_digraph<E, Ix>(
258        &mut self,
259        graph: &DiGraph<N, E, Ix>,
260    ) -> Result<Vec<RandomWalk<N>>>
261    where
262        N: Clone + std::fmt::Debug,
263        E: EdgeWeight,
264        Ix: petgraph::graph::IndexType,
265    {
266        let mut all_walks = Vec::new();
267
268        for node in graph.nodes() {
269            for _ in 0..self.config.num_walks {
270                let walk = self.walk_generator.simple_random_walk_digraph(
271                    graph,
272                    node,
273                    self.config.walk_length,
274                )?;
275                all_walks.push(walk);
276            }
277        }
278
279        Ok(all_walks)
280    }
281
282    /// Train the DeepWalk model on an undirected graph
283    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
284    where
285        N: Clone + std::fmt::Debug,
286        E: EdgeWeight,
287        Ix: petgraph::graph::IndexType,
288    {
289        // Initialize random embeddings
290        let mut rng = scirs2_core::random::rng();
291        self.model.initialize_random(graph, &mut rng);
292
293        match self.mode {
294            DeepWalkMode::NegativeSampling => {
295                self.train_negative_sampling(graph, &mut rng)?;
296            }
297            DeepWalkMode::HierarchicalSoftmax => {
298                self.train_hierarchical_softmax(graph, &mut rng)?;
299            }
300        }
301
302        Ok(())
303    }
304
305    /// Train using negative sampling
306    fn train_negative_sampling<E, Ix>(
307        &mut self,
308        graph: &Graph<N, E, Ix>,
309        rng: &mut impl Rng,
310    ) -> Result<()>
311    where
312        N: Clone + std::fmt::Debug,
313        E: EdgeWeight,
314        Ix: petgraph::graph::IndexType,
315    {
316        let negative_sampler = NegativeSampler::new(graph);
317
318        for epoch in 0..self.config.epochs {
319            let walks = self.generate_walks(graph)?;
320            let context_pairs =
321                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
322
323            let mut shuffled_pairs = context_pairs;
324            shuffled_pairs.shuffle(rng);
325
326            let current_lr = self.config.learning_rate
327                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
328
329            self.model.train_skip_gram(
330                &shuffled_pairs,
331                &negative_sampler,
332                current_lr,
333                self.config.negative_samples,
334                rng,
335            )?;
336        }
337
338        Ok(())
339    }
340
341    /// Train using hierarchical softmax approximation
342    fn train_hierarchical_softmax<E, Ix>(
343        &mut self,
344        graph: &Graph<N, E, Ix>,
345        rng: &mut impl Rng,
346    ) -> Result<()>
347    where
348        N: Clone + std::fmt::Debug,
349        E: EdgeWeight,
350        Ix: petgraph::graph::IndexType,
351    {
352        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
353        let n = nodes.len();
354
355        if n == 0 {
356            return Err(GraphError::InvalidGraph(
357                "Cannot train on empty graph".to_string(),
358            ));
359        }
360
361        // Build node-to-index mapping
362        let node_to_idx: HashMap<N, usize> = nodes
363            .iter()
364            .enumerate()
365            .map(|(i, n)| (n.clone(), i))
366            .collect();
367
368        // Compute node frequencies (degree-based)
369        let frequencies: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64 + 1.0).collect();
370
371        // Build Huffman tree
372        let huffman = HuffmanTree::build(&frequencies)?;
373
374        // Initialize internal node vectors (one per internal node)
375        let dim = self.config.dimensions;
376        self.internal_vectors = (0..huffman.num_internal).map(|_| vec![0.0; dim]).collect();
377
378        // Training loop
379        for epoch in 0..self.config.epochs {
380            let walks = self.generate_walks(graph)?;
381
382            let current_lr = self.config.learning_rate
383                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
384
385            // Process each walk
386            for walk in &walks {
387                let walk_indices: Vec<usize> = walk
388                    .nodes
389                    .iter()
390                    .filter_map(|n| node_to_idx.get(n).copied())
391                    .collect();
392
393                // Generate (target, context) pairs from walk
394                for (i, &target_idx) in walk_indices.iter().enumerate() {
395                    let start = i.saturating_sub(self.config.window_size);
396                    let end = (i + self.config.window_size + 1).min(walk_indices.len());
397
398                    for j in start..end {
399                        if i == j {
400                            continue;
401                        }
402
403                        let context_idx = walk_indices[j];
404                        self.hierarchical_softmax_update(
405                            &nodes[target_idx],
406                            context_idx,
407                            &huffman,
408                            current_lr,
409                        );
410                    }
411                }
412            }
413
414            // Shuffle walks for next epoch
415            let _ = rng; // Use rng to avoid unused warning
416        }
417
418        Ok(())
419    }
420
421    /// Update embeddings using hierarchical softmax for one (target, context) pair
422    fn hierarchical_softmax_update(
423        &mut self,
424        target_node: &N,
425        context_idx: usize,
426        huffman: &HuffmanTree,
427        learning_rate: f64,
428    ) where
429        N: Clone,
430    {
431        let dim = self.config.dimensions;
432
433        if context_idx >= huffman.codes.len() {
434            return;
435        }
436
437        let huffman_node = &huffman.codes[context_idx];
438
439        // Get target embedding
440        let target_emb = match self.model.embeddings.get(target_node) {
441            Some(e) => e.vector.clone(),
442            None => return,
443        };
444
445        let mut grad = vec![0.0; dim];
446
447        // Walk along the Huffman tree path
448        for (step, (&is_right, &internal_idx)) in huffman_node
449            .code
450            .iter()
451            .zip(huffman_node.point.iter())
452            .enumerate()
453        {
454            if internal_idx >= self.internal_vectors.len() {
455                continue;
456            }
457
458            // Compute dot product: target . internal_node
459            let dot: f64 = target_emb
460                .iter()
461                .zip(self.internal_vectors[internal_idx].iter())
462                .map(|(a, b)| a * b)
463                .sum();
464
465            let sig = 1.0 / (1.0 + (-dot).exp());
466
467            // Label: 1 for left child (code=false), 0 for right child (code=true)
468            let label = if is_right { 0.0 } else { 1.0 };
469            let g = learning_rate * (label - sig);
470
471            // Accumulate gradient for target embedding
472            for d in 0..dim {
473                grad[d] += g * self.internal_vectors[internal_idx][d];
474            }
475
476            // Update internal node vector
477            for d in 0..dim {
478                self.internal_vectors[internal_idx][d] += g * target_emb[d];
479            }
480
481            let _ = step; // Consume step variable
482        }
483
484        // Apply accumulated gradient to target embedding
485        if let Some(emb) = self.model.embeddings.get_mut(target_node) {
486            for d in 0..dim {
487                emb.vector[d] += grad[d];
488            }
489        }
490    }
491
492    /// Train the DeepWalk model on a directed graph
493    pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
494    where
495        N: Clone + std::fmt::Debug,
496        E: EdgeWeight,
497        Ix: petgraph::graph::IndexType,
498    {
499        let mut rng = scirs2_core::random::rng();
500        self.model.initialize_random_digraph(graph, &mut rng);
501
502        // For directed graphs, we only support negative sampling for now
503        // Build a manual negative sampler from DiGraph
504        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
505        let degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64 + 1.0).collect();
506        let total: f64 = degrees.iter().sum();
507        let powered: Vec<f64> = degrees.iter().map(|d| (d / total).powf(0.75)).collect();
508        let total_powered: f64 = powered.iter().sum();
509        let probs: Vec<f64> = powered.iter().map(|p| p / total_powered).collect();
510
511        let mut cumulative = vec![0.0; probs.len()];
512        if !cumulative.is_empty() {
513            cumulative[0] = probs[0];
514            for i in 1..probs.len() {
515                cumulative[i] = cumulative[i - 1] + probs[i];
516            }
517        }
518
519        for epoch in 0..self.config.epochs {
520            let walks = self.generate_walks_digraph(graph)?;
521            let context_pairs =
522                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
523
524            let mut shuffled_pairs = context_pairs;
525            shuffled_pairs.shuffle(&mut rng);
526
527            let current_lr = self.config.learning_rate
528                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
529
530            let dim = self.config.dimensions;
531            let num_neg = self.config.negative_samples;
532
533            // Manual skip-gram with negative sampling
534            for pair in &shuffled_pairs {
535                let target_emb = match self.model.embeddings.get(&pair.target) {
536                    Some(e) => e.clone(),
537                    None => continue,
538                };
539                let context_emb = match self.model.context_embeddings.get(&pair.context) {
540                    Some(e) => e.clone(),
541                    None => continue,
542                };
543
544                let dot: f64 = target_emb
545                    .vector
546                    .iter()
547                    .zip(context_emb.vector.iter())
548                    .map(|(a, b)| a * b)
549                    .sum();
550                let sig = 1.0 / (1.0 + (-dot).exp());
551                let g = current_lr * (1.0 - sig);
552
553                let mut target_grad = vec![0.0; dim];
554                for d in 0..dim {
555                    target_grad[d] = g * context_emb.vector[d];
556                }
557
558                if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
559                    for d in 0..dim {
560                        ctx.vector[d] += g * target_emb.vector[d];
561                    }
562                }
563
564                // Negative samples
565                for _ in 0..num_neg {
566                    let r = rng.random::<f64>();
567                    let neg_idx = cumulative
568                        .iter()
569                        .position(|&c| r <= c)
570                        .unwrap_or(cumulative.len().saturating_sub(1));
571
572                    if neg_idx >= nodes.len() {
573                        continue;
574                    }
575                    let neg_node = &nodes[neg_idx];
576                    if neg_node == &pair.target || neg_node == &pair.context {
577                        continue;
578                    }
579
580                    if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
581                        let neg_dot: f64 = target_emb
582                            .vector
583                            .iter()
584                            .zip(neg_emb.vector.iter())
585                            .map(|(a, b)| a * b)
586                            .sum();
587                        let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
588                        let neg_g = current_lr * (-neg_sig);
589
590                        for d in 0..dim {
591                            target_grad[d] += neg_g * neg_emb.vector[d];
592                        }
593
594                        if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
595                            for d in 0..dim {
596                                neg_ctx.vector[d] += neg_g * target_emb.vector[d];
597                            }
598                        }
599                    }
600                }
601
602                if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
603                    for d in 0..dim {
604                        target.vector[d] += target_grad[d];
605                    }
606                }
607            }
608        }
609
610        Ok(())
611    }
612
613    /// Get the trained model
614    pub fn model(&self) -> &EmbeddingModel<N> {
615        &self.model
616    }
617
618    /// Get mutable reference to the model
619    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
620        &mut self.model
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    fn make_triangle() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..3 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(0, 2, 1.0);
        g
    }

    fn make_path_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let _ = g.add_edge(3, 4, 1.0);
        g
    }

    fn make_directed_cycle() -> DiGraph<i32, f64> {
        let mut g = DiGraph::new();
        for i in 0..4 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let _ = g.add_edge(3, 0, 1.0);
        g
    }

    /// Returns true when `prefix` is a prefix of `code`.
    fn is_prefix(prefix: &[bool], code: &[bool]) -> bool {
        prefix.len() <= code.len() && code[..prefix.len()] == *prefix
    }

    #[test]
    fn test_deepwalk_negative_sampling() {
        let g = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut dw = DeepWalk::new(config);
        assert_eq!(dw.mode(), DeepWalkMode::NegativeSampling);

        let result = dw.train(&g);
        assert!(
            result.is_ok(),
            "DeepWalk negative sampling training should succeed"
        );

        for node in [0, 1, 2] {
            assert!(
                dw.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding"
            );
        }
    }

    #[test]
    fn test_deepwalk_hierarchical_softmax() {
        let g = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut dw = DeepWalk::with_hierarchical_softmax(config);
        assert_eq!(dw.mode(), DeepWalkMode::HierarchicalSoftmax);

        let result = dw.train(&g);
        assert!(
            result.is_ok(),
            "DeepWalk hierarchical softmax training should succeed"
        );

        for node in [0, 1, 2] {
            assert!(
                dw.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding"
            );
        }
    }

    #[test]
    fn test_deepwalk_walk_generation() {
        let g = make_path_graph();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 2,
            ..Default::default()
        };

        let mut dw = DeepWalk::new(config);
        let walks = dw.generate_walks(&g);
        assert!(walks.is_ok());

        let walks = walks.expect("walks should be valid");
        // 5 nodes * 2 walks = 10 walks
        assert_eq!(walks.len(), 10);

        for walk in &walks {
            assert!(!walk.nodes.is_empty());
            assert!(walk.nodes.len() <= 4);
            for node in &walk.nodes {
                assert!((0..5).contains(node));
            }
        }
    }

    #[test]
    fn test_deepwalk_digraph() {
        let g = make_directed_cycle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 6,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut dw = DeepWalk::new(config);
        let result = dw.train_digraph(&g);
        assert!(result.is_ok(), "DiGraph DeepWalk training should succeed");

        for node in 0..4 {
            assert!(
                dw.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding in directed graph"
            );
        }
    }

    #[test]
    fn test_deepwalk_mode_switching() {
        let g = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 2,
            epochs: 1,
            ..Default::default()
        };

        let mut dw = DeepWalk::new(config);
        assert_eq!(dw.mode(), DeepWalkMode::NegativeSampling);

        dw.set_mode(DeepWalkMode::HierarchicalSoftmax);
        assert_eq!(dw.mode(), DeepWalkMode::HierarchicalSoftmax);

        let result = dw.train(&g);
        assert!(result.is_ok());
    }

    #[test]
    fn test_deepwalk_embedding_dimensions() {
        let g = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 32,
            walk_length: 5,
            num_walks: 2,
            epochs: 1,
            ..Default::default()
        };

        let mut dw = DeepWalk::new(config);
        let _ = dw.train(&g);

        for node in [0, 1, 2] {
            let emb = dw.model().get_embedding(&node);
            assert!(emb.is_some());
            assert_eq!(emb.map(|e| e.dimensions()).unwrap_or(0), 32);
        }
    }

    #[test]
    fn test_huffman_tree_basic() {
        let freqs = vec![5.0, 2.0, 1.0, 3.0];
        let tree = HuffmanTree::build(&freqs);
        assert!(tree.is_ok());

        let tree = tree.expect("tree should be valid");
        assert_eq!(tree.codes.len(), 4);
        assert_eq!(tree.num_internal, 3);

        for (i, code) in tree.codes.iter().enumerate() {
            assert!(
                !code.code.is_empty(),
                "Node {i} should have non-empty Huffman code"
            );
            assert!(
                !code.point.is_empty(),
                "Node {i} should have non-empty path"
            );
        }
    }

    #[test]
    fn test_huffman_tree_single_node() {
        let freqs = vec![1.0];
        let tree = HuffmanTree::build(&freqs);
        assert!(tree.is_ok());

        let tree = tree.expect("tree should be valid");
        assert_eq!(tree.codes.len(), 1);
    }

    #[test]
    fn test_huffman_tree_empty_is_error() {
        assert!(HuffmanTree::build(&[]).is_err());
    }

    #[test]
    fn test_huffman_tree_codes_are_consistent() {
        // Includes a zero frequency to exercise the clamping path.
        let freqs = [5.0, 2.0, 1.0, 3.0, 8.0, 0.0];
        let tree = HuffmanTree::build(&freqs).expect("tree should be valid");
        assert_eq!(tree.codes.len(), freqs.len());
        assert_eq!(tree.num_internal, freqs.len() - 1);

        for code in &tree.codes {
            assert_eq!(code.code.len(), code.point.len());
            assert!(code.point.iter().all(|&p| p < tree.num_internal));
        }

        // Every path starts at the same root.
        let root = tree.codes[0].point[0];
        assert!(tree.codes.iter().all(|c| c.point[0] == root));

        // Leaves of a full binary tree: codes are prefix-free...
        for (i, a) in tree.codes.iter().enumerate() {
            for (j, b) in tree.codes.iter().enumerate() {
                if i != j {
                    assert!(
                        !is_prefix(&a.code, &b.code),
                        "code of leaf {i} is a prefix of leaf {j}"
                    );
                }
            }
        }

        // ...and satisfy Kraft's equality.
        let kraft: f64 = tree
            .codes
            .iter()
            .map(|c| 0.5f64.powi(c.code.len() as i32))
            .sum();
        assert!((kraft - 1.0).abs() < 1e-12);
    }
}