// sklears_semi_supervised/dynamic_graph_learning.rs

//! Dynamic graph learning for streaming and evolving semi-supervised scenarios.
//!
//! This module provides dynamic graph learning algorithms that can handle
//! continuously evolving graph structures, streaming data updates, and online
//! semi-supervised learning scenarios.
6
7use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::rand_prelude::*;
9use sklears_core::error::SklearsError;
10use std::collections::{HashMap, VecDeque};
11
12/// Dynamic graph learning for streaming and continuously evolving scenarios
13#[derive(Clone)]
14pub struct DynamicGraphLearning {
15    /// Learning rate for online updates
16    pub learning_rate: f64,
17    /// Forgetting factor for old connections
18    pub forgetting_factor: f64,
19    /// Number of neighbors for new node integration
20    pub k_neighbors: usize,
21    /// Buffer size for streaming updates
22    pub buffer_size: usize,
23    /// Threshold for edge creation/removal
24    pub edge_threshold: f64,
25    /// Maximum number of nodes to maintain
26    pub max_nodes: Option<usize>,
27    /// Random state for reproducibility
28    pub random_state: Option<u64>,
29    /// Current adjacency matrix
30    adjacency_matrix: Option<Array2<f64>>,
31    /// Node features buffer
32    node_features: Option<Array2<f64>>,
33    /// Update history buffer
34    update_buffer: VecDeque<GraphUpdate>,
35}
36
37/// Represents a graph update operation
38#[derive(Clone, Debug)]
39pub struct GraphUpdate {
40    /// Type of update: "add_node", "remove_node", "update_edge", "update_features"
41    pub update_type: String,
42    /// Node indices involved
43    pub node_indices: Vec<usize>,
44    /// New feature values (for feature updates)
45    pub features: Option<Array1<f64>>,
46    /// Edge weight (for edge updates)
47    pub edge_weight: Option<f64>,
48    /// Timestamp of update
49    pub timestamp: f64,
50}
51
52impl DynamicGraphLearning {
53    /// Create a new dynamic graph learning instance
54    pub fn new() -> Self {
55        Self {
56            learning_rate: 0.01,
57            forgetting_factor: 0.95,
58            k_neighbors: 5,
59            buffer_size: 1000,
60            edge_threshold: 0.1,
61            max_nodes: None,
62            random_state: None,
63            adjacency_matrix: None,
64            node_features: None,
65            update_buffer: VecDeque::new(),
66        }
67    }
68
69    /// Set the learning rate for online updates
70    pub fn learning_rate(mut self, lr: f64) -> Self {
71        self.learning_rate = lr;
72        self
73    }
74
75    /// Set the forgetting factor for old connections
76    pub fn forgetting_factor(mut self, factor: f64) -> Self {
77        self.forgetting_factor = factor;
78        self
79    }
80
81    /// Set the number of neighbors for new node integration
82    pub fn k_neighbors(mut self, k: usize) -> Self {
83        self.k_neighbors = k;
84        self
85    }
86
87    /// Set the buffer size for streaming updates
88    pub fn buffer_size(mut self, size: usize) -> Self {
89        self.buffer_size = size;
90        self
91    }
92
93    /// Set the edge threshold for creation/removal
94    pub fn edge_threshold(mut self, threshold: f64) -> Self {
95        self.edge_threshold = threshold;
96        self
97    }
98
99    /// Set the maximum number of nodes to maintain
100    pub fn max_nodes(mut self, max_nodes: usize) -> Self {
101        self.max_nodes = Some(max_nodes);
102        self
103    }
104
105    /// Set the random state for reproducibility
106    pub fn random_state(mut self, seed: u64) -> Self {
107        self.random_state = Some(seed);
108        self
109    }
110
111    /// Initialize the dynamic graph with initial data
112    pub fn initialize(&mut self, initial_features: ArrayView2<f64>) -> Result<(), SklearsError> {
113        let n_samples = initial_features.nrows();
114        let n_features = initial_features.ncols();
115
116        if n_samples == 0 {
117            return Err(SklearsError::InvalidInput(
118                "No initial data provided".to_string(),
119            ));
120        }
121
122        // Initialize node features
123        self.node_features = Some(initial_features.to_owned());
124
125        // Initialize adjacency matrix with k-NN graph
126        let mut adjacency = Array2::zeros((n_samples, n_samples));
127
128        for i in 0..n_samples {
129            let mut distances: Vec<(usize, f64)> = Vec::new();
130
131            for j in 0..n_samples {
132                if i != j {
133                    let dist =
134                        self.compute_distance(initial_features.row(i), initial_features.row(j));
135                    distances.push((j, dist));
136                }
137            }
138
139            // Sort by distance and connect to k nearest neighbors
140            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
141            for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
142                let weight = (-dist).exp(); // Gaussian similarity
143                adjacency[[i, neighbor]] = weight;
144                adjacency[[neighbor, i]] = weight; // Symmetric
145            }
146        }
147
148        self.adjacency_matrix = Some(adjacency);
149        Ok(())
150    }
151
152    /// Add new nodes to the dynamic graph
153    pub fn add_nodes(&mut self, new_features: ArrayView2<f64>) -> Result<(), SklearsError> {
154        if self.node_features.is_none() || self.adjacency_matrix.is_none() {
155            return Err(SklearsError::InvalidInput(
156                "Graph not initialized".to_string(),
157            ));
158        }
159
160        let new_n_nodes = new_features.nrows();
161
162        // Check max nodes constraint and prune if necessary
163        if let Some(max_nodes) = self.max_nodes {
164            let current_n_nodes = self.node_features.as_ref().unwrap().nrows();
165            let total_nodes = current_n_nodes + new_n_nodes;
166            if total_nodes > max_nodes {
167                self.prune_old_nodes(max_nodes - new_n_nodes)?;
168            }
169        }
170
171        // Get references after potential pruning
172        let current_features = self.node_features.as_ref().unwrap();
173        let current_adjacency = self.adjacency_matrix.as_ref().unwrap();
174
175        let old_n_nodes = current_features.nrows();
176        let total_nodes = old_n_nodes + new_n_nodes;
177
178        // Extend feature matrix
179        let mut extended_features = Array2::zeros((total_nodes, current_features.ncols()));
180        extended_features
181            .slice_mut(s![..old_n_nodes, ..])
182            .assign(current_features);
183        extended_features
184            .slice_mut(s![old_n_nodes.., ..])
185            .assign(&new_features);
186
187        // Extend adjacency matrix
188        let mut extended_adjacency = Array2::zeros((total_nodes, total_nodes));
189        extended_adjacency
190            .slice_mut(s![..old_n_nodes, ..old_n_nodes])
191            .assign(current_adjacency);
192
193        // Connect new nodes to existing nodes
194        for i in old_n_nodes..total_nodes {
195            let mut distances: Vec<(usize, f64)> = Vec::new();
196
197            for j in 0..old_n_nodes {
198                let dist =
199                    self.compute_distance(extended_features.row(i), extended_features.row(j));
200                distances.push((j, dist));
201            }
202
203            // Connect to k nearest existing neighbors
204            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
205            for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
206                let weight = (-dist).exp();
207                extended_adjacency[[i, neighbor]] = weight;
208                extended_adjacency[[neighbor, i]] = weight;
209            }
210
211            // Connect new nodes to each other
212            for j in (old_n_nodes..total_nodes).filter(|&j| j != i) {
213                let dist =
214                    self.compute_distance(extended_features.row(i), extended_features.row(j));
215                let weight = (-dist).exp();
216                if weight > self.edge_threshold {
217                    extended_adjacency[[i, j]] = weight;
218                    extended_adjacency[[j, i]] = weight;
219                }
220            }
221        }
222
223        self.node_features = Some(extended_features);
224        self.adjacency_matrix = Some(extended_adjacency);
225
226        // Record updates
227        for i in old_n_nodes..total_nodes {
228            self.record_update(GraphUpdate {
229                update_type: "add_node".to_string(),
230                node_indices: vec![i],
231                features: Some(new_features.row(i - old_n_nodes).to_owned()),
232                edge_weight: None,
233                timestamp: self.get_current_time(),
234            });
235        }
236
237        Ok(())
238    }
239
240    /// Update node features dynamically
241    pub fn update_node_features(
242        &mut self,
243        node_idx: usize,
244        new_features: ArrayView1<f64>,
245    ) -> Result<(), SklearsError> {
246        if self.node_features.is_none() {
247            return Err(SklearsError::InvalidInput(
248                "Graph not initialized".to_string(),
249            ));
250        }
251
252        let features = self.node_features.as_mut().unwrap();
253
254        if node_idx >= features.nrows() {
255            return Err(SklearsError::InvalidInput(
256                "Node index out of bounds".to_string(),
257            ));
258        }
259
260        // Apply online learning update
261        let mut current_features = features.row_mut(node_idx);
262        for (i, &new_val) in new_features.iter().enumerate() {
263            current_features[i] =
264                (1.0 - self.learning_rate) * current_features[i] + self.learning_rate * new_val;
265        }
266
267        // Update edges based on new features
268        self.update_edges_for_node(node_idx)?;
269
270        // Record update
271        self.record_update(GraphUpdate {
272            update_type: "update_features".to_string(),
273            node_indices: vec![node_idx],
274            features: Some(new_features.to_owned()),
275            edge_weight: None,
276            timestamp: self.get_current_time(),
277        });
278
279        Ok(())
280    }
281
282    /// Update edges for a specific node after feature change
283    fn update_edges_for_node(&mut self, node_idx: usize) -> Result<(), SklearsError> {
284        if self.node_features.is_none() || self.adjacency_matrix.is_none() {
285            return Ok(());
286        }
287
288        // Create a copy of features to avoid borrowing conflicts
289        let features = self.node_features.as_ref().unwrap().clone();
290        let n_nodes = features.nrows();
291        let forgetting_factor = self.forgetting_factor;
292        let edge_threshold = self.edge_threshold;
293
294        // Get mutable reference to adjacency matrix
295        let adjacency = self.adjacency_matrix.as_mut().unwrap();
296
297        // Recompute edges for this node
298        for other_idx in 0..n_nodes {
299            if node_idx != other_idx {
300                let dist =
301                    Self::compute_distance_static(features.row(node_idx), features.row(other_idx));
302                let new_weight = (-dist).exp();
303
304                // Apply forgetting factor to existing edge and add new weight
305                let current_weight = adjacency[[node_idx, other_idx]];
306                let updated_weight =
307                    forgetting_factor * current_weight + (1.0 - forgetting_factor) * new_weight;
308
309                // Apply threshold for edge maintenance
310                let final_weight = if updated_weight > edge_threshold {
311                    updated_weight
312                } else {
313                    0.0
314                };
315
316                adjacency[[node_idx, other_idx]] = final_weight;
317                adjacency[[other_idx, node_idx]] = final_weight; // Symmetric
318            }
319        }
320
321        Ok(())
322    }
323
324    /// Prune old nodes to maintain memory constraints
325    fn prune_old_nodes(&mut self, target_nodes: usize) -> Result<(), SklearsError> {
326        if self.node_features.is_none() || self.adjacency_matrix.is_none() {
327            return Ok(());
328        }
329
330        let current_nodes = self.node_features.as_ref().unwrap().nrows();
331        if current_nodes <= target_nodes {
332            return Ok(());
333        }
334
335        let nodes_to_remove = current_nodes - target_nodes;
336
337        // Simple strategy: remove oldest nodes (first nodes_to_remove nodes)
338        // In practice, you might want more sophisticated strategies based on
339        // node importance, connectivity, or recency of updates
340
341        let features = self.node_features.as_ref().unwrap();
342        let adjacency = self.adjacency_matrix.as_ref().unwrap();
343
344        // Create new matrices without the pruned nodes
345        let new_features = features.slice(s![nodes_to_remove.., ..]).to_owned();
346        let new_adjacency = adjacency
347            .slice(s![nodes_to_remove.., nodes_to_remove..])
348            .to_owned();
349
350        self.node_features = Some(new_features);
351        self.adjacency_matrix = Some(new_adjacency);
352
353        Ok(())
354    }
355
356    /// Get the current adjacency matrix
357    pub fn get_adjacency_matrix(&self) -> Option<&Array2<f64>> {
358        self.adjacency_matrix.as_ref()
359    }
360
361    /// Get the current node features
362    pub fn get_node_features(&self) -> Option<&Array2<f64>> {
363        self.node_features.as_ref()
364    }
365
366    /// Get recent updates from the buffer
367    pub fn get_recent_updates(&self, n_updates: usize) -> Vec<&GraphUpdate> {
368        self.update_buffer.iter().rev().take(n_updates).collect()
369    }
370
371    /// Compute distance between two feature vectors
372    fn compute_distance(&self, feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
373        Self::compute_distance_static(feat1, feat2)
374    }
375
376    /// Static version of compute_distance to avoid borrowing conflicts
377    fn compute_distance_static(feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
378        feat1
379            .iter()
380            .zip(feat2.iter())
381            .map(|(&a, &b)| (a - b).powi(2))
382            .sum::<f64>()
383            .sqrt()
384    }
385
386    /// Record a graph update in the buffer
387    fn record_update(&mut self, update: GraphUpdate) {
388        self.update_buffer.push_back(update);
389
390        // Maintain buffer size
391        while self.update_buffer.len() > self.buffer_size {
392            self.update_buffer.pop_front();
393        }
394    }
395
396    /// Get current timestamp (simplified)
397    fn get_current_time(&self) -> f64 {
398        std::time::SystemTime::now()
399            .duration_since(std::time::UNIX_EPOCH)
400            .unwrap_or_default()
401            .as_secs_f64()
402    }
403
404    /// Apply decay to all edges to simulate forgetting
405    pub fn apply_temporal_decay(&mut self) -> Result<(), SklearsError> {
406        if let Some(adjacency) = self.adjacency_matrix.as_mut() {
407            *adjacency *= self.forgetting_factor;
408
409            // Remove edges below threshold
410            adjacency.mapv_inplace(|x| if x < self.edge_threshold { 0.0 } else { x });
411        }
412        Ok(())
413    }
414
415    /// Get graph statistics
416    pub fn get_statistics(&self) -> HashMap<String, f64> {
417        let mut stats = HashMap::new();
418
419        if let Some(adjacency) = &self.adjacency_matrix {
420            let n_nodes = adjacency.nrows() as f64;
421            let total_edges = adjacency.iter().filter(|&&x| x > 0.0).count() as f64 / 2.0; // Undirected
422            let density = if n_nodes > 1.0 {
423                total_edges / (n_nodes * (n_nodes - 1.0) / 2.0)
424            } else {
425                0.0
426            };
427
428            stats.insert("n_nodes".to_string(), n_nodes);
429            stats.insert("n_edges".to_string(), total_edges);
430            stats.insert("density".to_string(), density);
431            stats.insert("avg_degree".to_string(), total_edges * 2.0 / n_nodes);
432        }
433
434        stats.insert("buffer_size".to_string(), self.update_buffer.len() as f64);
435        stats
436    }
437}
438
439impl Default for DynamicGraphLearning {
440    fn default() -> Self {
441        Self::new()
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    /// Assert that `matrix` is square and equal to its own transpose.
    fn assert_symmetric(matrix: &Array2<f64>) {
        assert_eq!(matrix.nrows(), matrix.ncols());
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert_eq!(
                    matrix[[i, j]],
                    matrix[[j, i]],
                    "asymmetric entry at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_dynamic_graph_initialization() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let result = dgl.initialize(initial_data.view());
        assert!(result.is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.dim(), (3, 3));

        // No self-loops.
        for i in 0..3 {
            assert_eq!(adjacency[[i, i]], 0.0);
        }

        assert_symmetric(adjacency);

        // Every node is linked to at least one neighbor.
        for i in 0..3 {
            assert!(adjacency.row(i).iter().any(|&w| w > 0.0));
        }
    }

    #[test]
    fn test_add_nodes() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let new_data = array![[3.0, 4.0], [4.0, 5.0]];

        let result = dgl.add_nodes(new_data.view());
        assert!(result.is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.dim(), (4, 4));
        assert_symmetric(adjacency);
        // The new node [3, 4] is linked to its nearest existing node [2, 3].
        assert!(adjacency[[2, 1]] > 0.0);

        let features = dgl.get_node_features().unwrap();
        assert_eq!(features.dim(), (4, 2));
        // New nodes are appended after the existing ones.
        assert_eq!(features.slice(s![2.., ..]), new_data);
    }

    #[test]
    fn test_update_node_features() {
        let mut dgl = DynamicGraphLearning::new()
            .k_neighbors(2)
            .learning_rate(0.5);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let new_features = array![5.0, 6.0];
        let result = dgl.update_node_features(0, new_features.view());
        assert!(result.is_ok());

        let features = dgl.get_node_features().unwrap();
        // With learning_rate = 0.5 the node moves halfway towards the observation.
        assert_abs_diff_eq!(features[[0, 0]], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(features[[0, 1]], 4.0, epsilon = 1e-12);
        // Other nodes are untouched.
        assert_abs_diff_eq!(features[[1, 0]], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(features[[1, 1]], 3.0, epsilon = 1e-12);

        assert_symmetric(dgl.get_adjacency_matrix().unwrap());
    }

    #[test]
    fn test_temporal_decay() {
        let mut dgl = DynamicGraphLearning::new()
            .k_neighbors(2)
            .forgetting_factor(0.5)
            .edge_threshold(0.1);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let original_adjacency = dgl.get_adjacency_matrix().unwrap().clone();
        assert!(original_adjacency[[0, 1]] > 0.0);

        dgl.apply_temporal_decay().unwrap();

        let decayed_adjacency = dgl.get_adjacency_matrix().unwrap();

        // exp(-sqrt(2)) * 0.5 is about 0.12, which stays above the 0.1 threshold.
        for i in 0..2 {
            for j in 0..2 {
                if i != j && original_adjacency[[i, j]] > 0.0 {
                    assert!(decayed_adjacency[[i, j]] < original_adjacency[[i, j]]);
                    assert_abs_diff_eq!(
                        decayed_adjacency[[i, j]],
                        0.5 * original_adjacency[[i, j]],
                        epsilon = 1e-12
                    );
                }
            }
        }
    }

    #[test]
    fn test_max_nodes_constraint() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2).max_nodes(3);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let new_data = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let result = dgl.add_nodes(new_data.view());
        assert!(result.is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.nrows(), 3); // Should be pruned to max_nodes

        // The two initial nodes are the oldest, so they are the ones pruned.
        let features = dgl.get_node_features().unwrap();
        assert_eq!(features, &new_data);
    }

    #[test]
    fn test_graph_statistics() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let stats = dgl.get_statistics();

        assert!(stats.contains_key("n_nodes"));
        assert!(stats.contains_key("n_edges"));
        assert!(stats.contains_key("density"));
        assert!(stats.contains_key("avg_degree"));

        // With k = 2 and three nodes every node links to both others: a triangle.
        assert_eq!(stats["n_nodes"], 3.0);
        assert_eq!(stats["n_edges"], 3.0);
        assert_abs_diff_eq!(stats["density"], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(stats["avg_degree"], 2.0, epsilon = 1e-12);
        // Initialization itself records no updates.
        assert_eq!(stats["buffer_size"], 0.0);
    }

    #[test]
    fn test_update_buffer() {
        let mut dgl = DynamicGraphLearning::new().buffer_size(2);

        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];

        dgl.initialize(initial_data.view()).unwrap();

        let new_features = array![5.0, 6.0];
        dgl.update_node_features(0, new_features.view()).unwrap();
        dgl.update_node_features(1, new_features.view()).unwrap();
        dgl.update_node_features(0, new_features.view()).unwrap();

        let recent_updates = dgl.get_recent_updates(5);
        // Buffer size constraint: only the last two of three updates survive.
        assert_eq!(recent_updates.len(), 2);
        // Newest first: the last update touched node 0, the one before it node 1.
        assert_eq!(recent_updates[0].node_indices, vec![0]);
        assert_eq!(recent_updates[1].node_indices, vec![1]);
        assert!(recent_updates
            .iter()
            .all(|update| update.update_type == "update_features"));
    }

    #[test]
    fn test_error_cases() {
        let mut dgl = DynamicGraphLearning::new();

        // Operations before initialization.
        let new_data = array![[1.0, 2.0]];
        assert!(dgl.add_nodes(new_data.view()).is_err());

        let new_features = array![5.0, 6.0];
        assert!(dgl.update_node_features(0, new_features.view()).is_err());

        // Initialization with empty data.
        let empty_data = Array2::<f64>::zeros((0, 2));
        assert!(dgl.initialize(empty_data.view()).is_err());

        // Feature update with an invalid index.
        let initial_data = array![[1.0, 2.0]];
        dgl.initialize(initial_data.view()).unwrap();
        assert!(dgl.update_node_features(10, new_features.view()).is_err());
    }
}