1use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::rand_prelude::*;
9use sklears_core::error::SklearsError;
10use std::collections::{HashMap, VecDeque};
11
12#[derive(Clone)]
14pub struct DynamicGraphLearning {
15 pub learning_rate: f64,
17 pub forgetting_factor: f64,
19 pub k_neighbors: usize,
21 pub buffer_size: usize,
23 pub edge_threshold: f64,
25 pub max_nodes: Option<usize>,
27 pub random_state: Option<u64>,
29 adjacency_matrix: Option<Array2<f64>>,
31 node_features: Option<Array2<f64>>,
33 update_buffer: VecDeque<GraphUpdate>,
35}
36
37#[derive(Clone, Debug)]
39pub struct GraphUpdate {
40 pub update_type: String,
42 pub node_indices: Vec<usize>,
44 pub features: Option<Array1<f64>>,
46 pub edge_weight: Option<f64>,
48 pub timestamp: f64,
50}
51
52impl DynamicGraphLearning {
53 pub fn new() -> Self {
55 Self {
56 learning_rate: 0.01,
57 forgetting_factor: 0.95,
58 k_neighbors: 5,
59 buffer_size: 1000,
60 edge_threshold: 0.1,
61 max_nodes: None,
62 random_state: None,
63 adjacency_matrix: None,
64 node_features: None,
65 update_buffer: VecDeque::new(),
66 }
67 }
68
69 pub fn learning_rate(mut self, lr: f64) -> Self {
71 self.learning_rate = lr;
72 self
73 }
74
75 pub fn forgetting_factor(mut self, factor: f64) -> Self {
77 self.forgetting_factor = factor;
78 self
79 }
80
81 pub fn k_neighbors(mut self, k: usize) -> Self {
83 self.k_neighbors = k;
84 self
85 }
86
87 pub fn buffer_size(mut self, size: usize) -> Self {
89 self.buffer_size = size;
90 self
91 }
92
93 pub fn edge_threshold(mut self, threshold: f64) -> Self {
95 self.edge_threshold = threshold;
96 self
97 }
98
99 pub fn max_nodes(mut self, max_nodes: usize) -> Self {
101 self.max_nodes = Some(max_nodes);
102 self
103 }
104
105 pub fn random_state(mut self, seed: u64) -> Self {
107 self.random_state = Some(seed);
108 self
109 }
110
111 pub fn initialize(&mut self, initial_features: ArrayView2<f64>) -> Result<(), SklearsError> {
113 let n_samples = initial_features.nrows();
114 let n_features = initial_features.ncols();
115
116 if n_samples == 0 {
117 return Err(SklearsError::InvalidInput(
118 "No initial data provided".to_string(),
119 ));
120 }
121
122 self.node_features = Some(initial_features.to_owned());
124
125 let mut adjacency = Array2::zeros((n_samples, n_samples));
127
128 for i in 0..n_samples {
129 let mut distances: Vec<(usize, f64)> = Vec::new();
130
131 for j in 0..n_samples {
132 if i != j {
133 let dist =
134 self.compute_distance(initial_features.row(i), initial_features.row(j));
135 distances.push((j, dist));
136 }
137 }
138
139 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
141 for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
142 let weight = (-dist).exp(); adjacency[[i, neighbor]] = weight;
144 adjacency[[neighbor, i]] = weight; }
146 }
147
148 self.adjacency_matrix = Some(adjacency);
149 Ok(())
150 }
151
152 pub fn add_nodes(&mut self, new_features: ArrayView2<f64>) -> Result<(), SklearsError> {
154 if self.node_features.is_none() || self.adjacency_matrix.is_none() {
155 return Err(SklearsError::InvalidInput(
156 "Graph not initialized".to_string(),
157 ));
158 }
159
160 let new_n_nodes = new_features.nrows();
161
162 if let Some(max_nodes) = self.max_nodes {
164 let current_n_nodes = self.node_features.as_ref().unwrap().nrows();
165 let total_nodes = current_n_nodes + new_n_nodes;
166 if total_nodes > max_nodes {
167 self.prune_old_nodes(max_nodes - new_n_nodes)?;
168 }
169 }
170
171 let current_features = self.node_features.as_ref().unwrap();
173 let current_adjacency = self.adjacency_matrix.as_ref().unwrap();
174
175 let old_n_nodes = current_features.nrows();
176 let total_nodes = old_n_nodes + new_n_nodes;
177
178 let mut extended_features = Array2::zeros((total_nodes, current_features.ncols()));
180 extended_features
181 .slice_mut(s![..old_n_nodes, ..])
182 .assign(current_features);
183 extended_features
184 .slice_mut(s![old_n_nodes.., ..])
185 .assign(&new_features);
186
187 let mut extended_adjacency = Array2::zeros((total_nodes, total_nodes));
189 extended_adjacency
190 .slice_mut(s![..old_n_nodes, ..old_n_nodes])
191 .assign(current_adjacency);
192
193 for i in old_n_nodes..total_nodes {
195 let mut distances: Vec<(usize, f64)> = Vec::new();
196
197 for j in 0..old_n_nodes {
198 let dist =
199 self.compute_distance(extended_features.row(i), extended_features.row(j));
200 distances.push((j, dist));
201 }
202
203 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
205 for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
206 let weight = (-dist).exp();
207 extended_adjacency[[i, neighbor]] = weight;
208 extended_adjacency[[neighbor, i]] = weight;
209 }
210
211 for j in (old_n_nodes..total_nodes).filter(|&j| j != i) {
213 let dist =
214 self.compute_distance(extended_features.row(i), extended_features.row(j));
215 let weight = (-dist).exp();
216 if weight > self.edge_threshold {
217 extended_adjacency[[i, j]] = weight;
218 extended_adjacency[[j, i]] = weight;
219 }
220 }
221 }
222
223 self.node_features = Some(extended_features);
224 self.adjacency_matrix = Some(extended_adjacency);
225
226 for i in old_n_nodes..total_nodes {
228 self.record_update(GraphUpdate {
229 update_type: "add_node".to_string(),
230 node_indices: vec![i],
231 features: Some(new_features.row(i - old_n_nodes).to_owned()),
232 edge_weight: None,
233 timestamp: self.get_current_time(),
234 });
235 }
236
237 Ok(())
238 }
239
240 pub fn update_node_features(
242 &mut self,
243 node_idx: usize,
244 new_features: ArrayView1<f64>,
245 ) -> Result<(), SklearsError> {
246 if self.node_features.is_none() {
247 return Err(SklearsError::InvalidInput(
248 "Graph not initialized".to_string(),
249 ));
250 }
251
252 let features = self.node_features.as_mut().unwrap();
253
254 if node_idx >= features.nrows() {
255 return Err(SklearsError::InvalidInput(
256 "Node index out of bounds".to_string(),
257 ));
258 }
259
260 let mut current_features = features.row_mut(node_idx);
262 for (i, &new_val) in new_features.iter().enumerate() {
263 current_features[i] =
264 (1.0 - self.learning_rate) * current_features[i] + self.learning_rate * new_val;
265 }
266
267 self.update_edges_for_node(node_idx)?;
269
270 self.record_update(GraphUpdate {
272 update_type: "update_features".to_string(),
273 node_indices: vec![node_idx],
274 features: Some(new_features.to_owned()),
275 edge_weight: None,
276 timestamp: self.get_current_time(),
277 });
278
279 Ok(())
280 }
281
282 fn update_edges_for_node(&mut self, node_idx: usize) -> Result<(), SklearsError> {
284 if self.node_features.is_none() || self.adjacency_matrix.is_none() {
285 return Ok(());
286 }
287
288 let features = self.node_features.as_ref().unwrap().clone();
290 let n_nodes = features.nrows();
291 let forgetting_factor = self.forgetting_factor;
292 let edge_threshold = self.edge_threshold;
293
294 let adjacency = self.adjacency_matrix.as_mut().unwrap();
296
297 for other_idx in 0..n_nodes {
299 if node_idx != other_idx {
300 let dist =
301 Self::compute_distance_static(features.row(node_idx), features.row(other_idx));
302 let new_weight = (-dist).exp();
303
304 let current_weight = adjacency[[node_idx, other_idx]];
306 let updated_weight =
307 forgetting_factor * current_weight + (1.0 - forgetting_factor) * new_weight;
308
309 let final_weight = if updated_weight > edge_threshold {
311 updated_weight
312 } else {
313 0.0
314 };
315
316 adjacency[[node_idx, other_idx]] = final_weight;
317 adjacency[[other_idx, node_idx]] = final_weight; }
319 }
320
321 Ok(())
322 }
323
324 fn prune_old_nodes(&mut self, target_nodes: usize) -> Result<(), SklearsError> {
326 if self.node_features.is_none() || self.adjacency_matrix.is_none() {
327 return Ok(());
328 }
329
330 let current_nodes = self.node_features.as_ref().unwrap().nrows();
331 if current_nodes <= target_nodes {
332 return Ok(());
333 }
334
335 let nodes_to_remove = current_nodes - target_nodes;
336
337 let features = self.node_features.as_ref().unwrap();
342 let adjacency = self.adjacency_matrix.as_ref().unwrap();
343
344 let new_features = features.slice(s![nodes_to_remove.., ..]).to_owned();
346 let new_adjacency = adjacency
347 .slice(s![nodes_to_remove.., nodes_to_remove..])
348 .to_owned();
349
350 self.node_features = Some(new_features);
351 self.adjacency_matrix = Some(new_adjacency);
352
353 Ok(())
354 }
355
356 pub fn get_adjacency_matrix(&self) -> Option<&Array2<f64>> {
358 self.adjacency_matrix.as_ref()
359 }
360
361 pub fn get_node_features(&self) -> Option<&Array2<f64>> {
363 self.node_features.as_ref()
364 }
365
366 pub fn get_recent_updates(&self, n_updates: usize) -> Vec<&GraphUpdate> {
368 self.update_buffer.iter().rev().take(n_updates).collect()
369 }
370
371 fn compute_distance(&self, feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
373 Self::compute_distance_static(feat1, feat2)
374 }
375
376 fn compute_distance_static(feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
378 feat1
379 .iter()
380 .zip(feat2.iter())
381 .map(|(&a, &b)| (a - b).powi(2))
382 .sum::<f64>()
383 .sqrt()
384 }
385
386 fn record_update(&mut self, update: GraphUpdate) {
388 self.update_buffer.push_back(update);
389
390 while self.update_buffer.len() > self.buffer_size {
392 self.update_buffer.pop_front();
393 }
394 }
395
396 fn get_current_time(&self) -> f64 {
398 std::time::SystemTime::now()
399 .duration_since(std::time::UNIX_EPOCH)
400 .unwrap_or_default()
401 .as_secs_f64()
402 }
403
404 pub fn apply_temporal_decay(&mut self) -> Result<(), SklearsError> {
406 if let Some(adjacency) = self.adjacency_matrix.as_mut() {
407 *adjacency *= self.forgetting_factor;
408
409 adjacency.mapv_inplace(|x| if x < self.edge_threshold { 0.0 } else { x });
411 }
412 Ok(())
413 }
414
415 pub fn get_statistics(&self) -> HashMap<String, f64> {
417 let mut stats = HashMap::new();
418
419 if let Some(adjacency) = &self.adjacency_matrix {
420 let n_nodes = adjacency.nrows() as f64;
421 let total_edges = adjacency.iter().filter(|&&x| x > 0.0).count() as f64 / 2.0; let density = if n_nodes > 1.0 {
423 total_edges / (n_nodes * (n_nodes - 1.0) / 2.0)
424 } else {
425 0.0
426 };
427
428 stats.insert("n_nodes".to_string(), n_nodes);
429 stats.insert("n_edges".to_string(), total_edges);
430 stats.insert("density".to_string(), density);
431 stats.insert("avg_degree".to_string(), total_edges * 2.0 / n_nodes);
432 }
433
434 stats.insert("buffer_size".to_string(), self.update_buffer.len() as f64);
435 stats
436 }
437}
438
439impl Default for DynamicGraphLearning {
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    /// Asserts that `matrix` is square, symmetric and has a zero diagonal.
    fn assert_symmetric_without_self_loops(matrix: &Array2<f64>) {
        let n = matrix.nrows();
        assert_eq!(matrix.ncols(), n);
        for i in 0..n {
            assert_eq!(matrix[[i, i]], 0.0);
            for j in 0..n {
                assert_eq!(matrix[[i, j]], matrix[[j, i]]);
            }
        }
    }

    #[test]
    fn test_dynamic_graph_initialization() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        assert!(dgl.initialize(initial_data.view()).is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.dim(), (3, 3));
        assert_symmetric_without_self_loops(adjacency);

        // With k = 2 every node links to both others, with weight exp(-distance).
        assert_abs_diff_eq!(adjacency[[0, 1]], (-(2.0_f64).sqrt()).exp(), epsilon = 1e-12);
        assert_abs_diff_eq!(adjacency[[0, 2]], (-(8.0_f64).sqrt()).exp(), epsilon = 1e-12);
    }

    #[test]
    fn test_add_nodes() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let new_data = array![[3.0, 4.0], [4.0, 5.0]];
        assert!(dgl.add_nodes(new_data.view()).is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.dim(), (4, 4));
        assert_symmetric_without_self_loops(adjacency);
        // Each new node reaches both old nodes (k = 2) ...
        assert!(adjacency[[2, 0]] > 0.0);
        assert!(adjacency[[2, 1]] > 0.0);
        // ... and the two new nodes are close enough (weight ~0.24 > 0.1) to be linked.
        assert!(adjacency[[2, 3]] > 0.0);

        let features = dgl.get_node_features().unwrap();
        assert_eq!(features.dim(), (4, 2));
        assert_eq!(features.row(2), new_data.row(0));
        assert_eq!(features.row(3), new_data.row(1));

        let recent = dgl.get_recent_updates(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].update_type, "add_node");
        assert_eq!(recent[0].node_indices, vec![3]);
        assert_eq!(recent[1].node_indices, vec![2]);
    }

    #[test]
    fn test_update_node_features() {
        let mut dgl = DynamicGraphLearning::new()
            .k_neighbors(2)
            .learning_rate(0.5);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let new_features = array![5.0, 6.0];
        assert!(dgl.update_node_features(0, new_features.view()).is_ok());

        // 0.5 * old + 0.5 * new for node 0; node 1 untouched.
        let features = dgl.get_node_features().unwrap();
        assert_abs_diff_eq!(features[[0, 0]], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(features[[0, 1]], 4.0, epsilon = 1e-12);
        assert_abs_diff_eq!(features[[1, 0]], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(features[[1, 1]], 3.0, epsilon = 1e-12);

        assert_symmetric_without_self_loops(dgl.get_adjacency_matrix().unwrap());
    }

    #[test]
    fn test_temporal_decay() {
        let mut dgl = DynamicGraphLearning::new()
            .k_neighbors(2)
            .forgetting_factor(0.5)
            .edge_threshold(0.1);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let original_adjacency = dgl.get_adjacency_matrix().unwrap().clone();
        assert!(original_adjacency[[0, 1]] > 0.0);

        dgl.apply_temporal_decay().unwrap();
        let decayed_adjacency = dgl.get_adjacency_matrix().unwrap();

        for i in 0..2 {
            for j in 0..2 {
                if i != j && original_adjacency[[i, j]] > 0.0 {
                    assert!(decayed_adjacency[[i, j]] < original_adjacency[[i, j]]);
                    // exp(-sqrt(2)) * 0.5 ~= 0.12 stays above the 0.1 threshold.
                    assert_abs_diff_eq!(
                        decayed_adjacency[[i, j]],
                        0.5 * original_adjacency[[i, j]],
                        epsilon = 1e-12
                    );
                }
            }
        }
    }

    #[test]
    fn test_max_nodes_constraint() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2).max_nodes(3);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let new_data = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        assert!(dgl.add_nodes(new_data.view()).is_ok());

        let adjacency = dgl.get_adjacency_matrix().unwrap();
        assert_eq!(adjacency.dim(), (3, 3));
        assert_symmetric_without_self_loops(adjacency);
        // Nodes 0 and 2 are sqrt(8) apart: weight ~0.06 is below the default 0.1 threshold.
        assert_eq!(adjacency[[0, 2]], 0.0);

        // Oldest nodes are pruned first, so only the newly added rows remain.
        let features = dgl.get_node_features().unwrap();
        assert_eq!(features, &new_data);
    }

    #[test]
    fn test_graph_statistics() {
        let mut dgl = DynamicGraphLearning::new().k_neighbors(2);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let stats = dgl.get_statistics();
        for key in ["n_nodes", "n_edges", "density", "avg_degree", "buffer_size"] {
            assert!(stats.contains_key(key), "missing statistic {key}");
        }

        // Fully connected triangle.
        assert_eq!(stats["n_nodes"], 3.0);
        assert_eq!(stats["n_edges"], 3.0);
        assert_eq!(stats["density"], 1.0);
        assert_eq!(stats["avg_degree"], 2.0);
    }

    #[test]
    fn test_update_buffer() {
        let mut dgl = DynamicGraphLearning::new().buffer_size(2);
        let initial_data = array![[1.0, 2.0], [2.0, 3.0]];
        dgl.initialize(initial_data.view()).unwrap();

        let new_features = array![5.0, 6.0];
        dgl.update_node_features(0, new_features.view()).unwrap();
        dgl.update_node_features(1, new_features.view()).unwrap();
        dgl.update_node_features(0, new_features.view()).unwrap();

        // Only the two most recent records survive, newest first.
        let recent_updates = dgl.get_recent_updates(5);
        assert_eq!(recent_updates.len(), 2);
        assert_eq!(recent_updates[0].node_indices, vec![0]);
        assert_eq!(recent_updates[1].node_indices, vec![1]);
        assert_eq!(recent_updates[0].update_type, "update_features");
    }

    #[test]
    fn test_error_cases() {
        let mut dgl = DynamicGraphLearning::new();

        // Operations before initialization are rejected.
        let new_data = array![[1.0, 2.0]];
        assert!(dgl.add_nodes(new_data.view()).is_err());

        let new_features = array![5.0, 6.0];
        assert!(dgl.update_node_features(0, new_features.view()).is_err());

        // Empty initial data is rejected.
        let empty_data = Array2::<f64>::zeros((0, 2));
        assert!(dgl.initialize(empty_data.view()).is_err());

        // Out-of-range node index is rejected.
        let initial_data = array![[1.0, 2.0]];
        dgl.initialize(initial_data.view()).unwrap();
        assert!(dgl.update_node_features(10, new_features.view()).is_err());
    }
}