1use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::{Rng, RngExt};
16
17use crate::error::{GraphError, Result};
18
/// A single timestamped interaction (edge) between two nodes, optionally
/// carrying an edge-feature vector.
#[derive(Debug, Clone)]
pub struct TemporalEvent {
    /// Index of the node the interaction originates from.
    pub source: usize,
    /// Index of the node the interaction points at.
    pub target: usize,
    /// Time at which the interaction occurred.
    pub timestamp: f64,
    /// Optional features attached to this interaction.
    pub features: Option<Vec<f64>>,
}

impl TemporalEvent {
    /// Creates an event with no attached features.
    pub fn new(source: usize, target: usize, timestamp: f64) -> Self {
        Self {
            features: None,
            source,
            target,
            timestamp,
        }
    }

    /// Creates an event carrying an edge-feature vector.
    pub fn with_features(source: usize, target: usize, timestamp: f64, features: Vec<f64>) -> Self {
        Self {
            features: Some(features),
            ..Self::new(source, target, timestamp)
        }
    }
}
57
/// Strategy used by [`TimeEncoding`] to turn a scalar timestamp into a vector.
///
/// Fieldless enum: deriving `Copy` is free and lets callers pass it by value
/// without explicit `clone()` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeEncodingType {
    /// Transformer-style interleaved sin/cos with a fixed geometric
    /// frequency schedule.
    Sinusoidal,
    /// Time2Vec: one linear component plus periodic sine components with
    /// random frequencies and phases.
    Time2Vec,
}
70
/// Functional encoder that maps a scalar timestamp to a `time_dim`-dimensional
/// feature vector, using either sinusoidal or Time2Vec encoding.
#[derive(Debug, Clone)]
pub struct TimeEncoding {
    /// Which encoding family `encode` uses.
    pub encoding_type: TimeEncodingType,
    /// Dimensionality of the produced encoding vector.
    pub time_dim: usize,
    /// Per-dimension frequencies: fixed geometric schedule for sinusoidal,
    /// random draws in [0, 2) for Time2Vec.
    pub omega: Array1<f64>,
    /// Per-dimension random phase offsets in [0, TAU); only read by the
    /// periodic Time2Vec components.
    pub phi: Array1<f64>,
    /// Slope of the linear (non-periodic) Time2Vec component at index 0.
    pub linear_weight: f64,
    /// Intercept of the linear Time2Vec component.
    pub linear_bias: f64,
}
91
92impl TimeEncoding {
93 pub fn new(time_dim: usize, encoding_type: TimeEncodingType) -> Self {
99 let mut rng = scirs2_core::random::rng();
100
101 let omega = match &encoding_type {
102 TimeEncodingType::Sinusoidal => {
103 Array1::from_iter(
105 (0..time_dim)
106 .map(|i| 1.0 / 10000.0_f64.powf(2.0 * (i / 2) as f64 / time_dim as f64)),
107 )
108 }
109 TimeEncodingType::Time2Vec => {
110 Array1::from_iter((0..time_dim).map(|_| rng.random::<f64>() * 2.0))
112 }
113 };
114
115 let phi =
116 Array1::from_iter((0..time_dim).map(|_| rng.random::<f64>() * std::f64::consts::TAU));
117
118 TimeEncoding {
119 encoding_type,
120 time_dim,
121 omega,
122 phi,
123 linear_weight: rng.random::<f64>() * 0.1,
124 linear_bias: 0.0,
125 }
126 }
127
128 pub fn encode(&self, t: f64) -> Array1<f64> {
136 let mut encoding = Array1::zeros(self.time_dim);
137
138 match self.encoding_type {
139 TimeEncodingType::Sinusoidal => {
140 for i in 0..self.time_dim {
141 let angle = t * self.omega[i];
142 if i % 2 == 0 {
143 encoding[i] = angle.sin();
144 } else {
145 encoding[i] = angle.cos();
146 }
147 }
148 }
149 TimeEncodingType::Time2Vec => {
150 if self.time_dim > 0 {
152 encoding[0] = self.linear_weight * t + self.linear_bias;
153 }
154 for i in 1..self.time_dim {
156 encoding[i] = (self.omega[i] * t + self.phi[i]).sin();
157 }
158 }
159 }
160
161 encoding
162 }
163
164 pub fn encode_batch(&self, timestamps: &[f64]) -> Array2<f64> {
172 let n = timestamps.len();
173 let mut result = Array2::zeros((n, self.time_dim));
174 for (i, &t) in timestamps.iter().enumerate() {
175 let enc = self.encode(t);
176 for j in 0..self.time_dim {
177 result[[i, j]] = enc[j];
178 }
179 }
180 result
181 }
182}
183
/// How [`MemoryModule`] combines a node's current memory with an incoming
/// event message.
///
/// Fieldless enum: deriving `Copy` is free and avoids needless `clone()`
/// calls at use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryUpdateMethod {
    /// Gated recurrent unit step (update/reset gates plus candidate state).
    Gru,
    /// Single-layer `tanh` MLP over `[message | memory]`.
    Mlp,
}
196
/// Evolving per-node memory updated from a stream of temporal events.
#[derive(Debug, Clone)]
pub struct MemoryModule {
    /// Node memory matrix, shape `(n_nodes, memory_dim)`; starts at zero.
    pub memory: Array2<f64>,
    /// Per-node timestamp of the most recent event involving that node.
    pub last_update: Vec<f64>,
    /// Dimension of each node's memory vector.
    pub memory_dim: usize,
    /// Dimension of the time encoding used inside event messages.
    pub time_dim: usize,
    /// Number of nodes tracked by this module.
    pub n_nodes: usize,
    /// Which update rule (`Gru` or `Mlp`) `process_event` applies.
    pub update_method: MemoryUpdateMethod,
    /// Time2Vec encoder for the relative event time (delta since the last
    /// update of either endpoint).
    pub time_encoding: TimeEncoding,
    /// Message length: `2 * memory_dim + time_dim`
    /// ([source memory | target memory | time encoding]).
    pub message_dim: usize,

    // GRU update-gate weights: W_z (message_dim, memory_dim).
    gru_wz: Array2<f64>,
    // GRU update-gate recurrent weights: U_z (memory_dim, memory_dim).
    gru_uz: Array2<f64>,
    // GRU reset-gate weights: W_r (message_dim, memory_dim).
    gru_wr: Array2<f64>,
    // GRU reset-gate recurrent weights: U_r (memory_dim, memory_dim).
    gru_ur: Array2<f64>,
    // GRU candidate-state weights: W_h (message_dim, memory_dim).
    gru_wh: Array2<f64>,
    // GRU candidate-state recurrent weights: U_h (memory_dim, memory_dim).
    gru_uh: Array2<f64>,
    // GRU biases (all length memory_dim, initialized to zero).
    gru_bz: Array1<f64>,
    gru_br: Array1<f64>,
    gru_bh: Array1<f64>,

    // MLP weights over [message | memory]: (message_dim + memory_dim, memory_dim).
    mlp_w: Array2<f64>,
    // MLP bias, length memory_dim.
    mlp_b: Array1<f64>,
}
246
247impl MemoryModule {
248 pub fn new(
256 n_nodes: usize,
257 memory_dim: usize,
258 time_dim: usize,
259 update_method: MemoryUpdateMethod,
260 ) -> Self {
261 let mut rng = scirs2_core::random::rng();
262 let time_encoding = TimeEncoding::new(time_dim, TimeEncodingType::Time2Vec);
263
264 let message_dim = memory_dim + memory_dim + time_dim;
266
267 let scale_gru = (6.0_f64 / (message_dim + memory_dim) as f64).sqrt();
268 let scale_u = (6.0_f64 / (2 * memory_dim) as f64).sqrt();
269 let scale_mlp = (6.0_f64 / (message_dim + 2 * memory_dim) as f64).sqrt();
270
271 let mut init = |r: usize, c: usize, s: f64| -> Array2<f64> {
272 Array2::from_shape_fn((r, c), |_| (rng.random::<f64>() * 2.0 - 1.0) * s)
273 };
274
275 MemoryModule {
276 memory: Array2::zeros((n_nodes, memory_dim)),
277 last_update: vec![0.0; n_nodes],
278 memory_dim,
279 time_dim,
280 n_nodes,
281 update_method,
282 time_encoding,
283 message_dim,
284 gru_wz: init(message_dim, memory_dim, scale_gru),
285 gru_uz: init(memory_dim, memory_dim, scale_u),
286 gru_wr: init(message_dim, memory_dim, scale_gru),
287 gru_ur: init(memory_dim, memory_dim, scale_u),
288 gru_wh: init(message_dim, memory_dim, scale_gru),
289 gru_uh: init(memory_dim, memory_dim, scale_u),
290 gru_bz: Array1::zeros(memory_dim),
291 gru_br: Array1::zeros(memory_dim),
292 gru_bh: Array1::zeros(memory_dim),
293 mlp_w: init(message_dim + memory_dim, memory_dim, scale_mlp),
294 mlp_b: Array1::zeros(memory_dim),
295 }
296 }
297
298 fn compute_message(&self, event: &TemporalEvent) -> Vec<f64> {
302 let src = event.source;
303 let tgt = event.target;
304 let delta_t = event.timestamp - self.last_update[src].max(self.last_update[tgt]);
305 let time_enc = self.time_encoding.encode(delta_t);
306
307 let mut msg = Vec::with_capacity(self.message_dim);
308
309 for j in 0..self.memory_dim {
311 msg.push(if src < self.n_nodes {
312 self.memory[[src, j]]
313 } else {
314 0.0
315 });
316 }
317
318 for j in 0..self.memory_dim {
320 msg.push(if tgt < self.n_nodes {
321 self.memory[[tgt, j]]
322 } else {
323 0.0
324 });
325 }
326
327 for j in 0..self.time_dim {
329 msg.push(time_enc[j]);
330 }
331
332 msg
333 }
334
335 fn gru_update(&self, memory: &[f64], message: &[f64]) -> Vec<f64> {
344 let d = self.memory_dim;
345 let m = self.message_dim;
346
347 let mut z = vec![0.0f64; d];
349 for j in 0..d {
350 let mut s = self.gru_bz[j];
351 for k in 0..m {
352 s += message[k] * self.gru_wz[[k, j]];
353 }
354 for k in 0..d {
355 s += memory[k] * self.gru_uz[[k, j]];
356 }
357 z[j] = sigmoid(s);
358 }
359
360 let mut r = vec![0.0f64; d];
362 for j in 0..d {
363 let mut s = self.gru_br[j];
364 for k in 0..m {
365 s += message[k] * self.gru_wr[[k, j]];
366 }
367 for k in 0..d {
368 s += memory[k] * self.gru_ur[[k, j]];
369 }
370 r[j] = sigmoid(s);
371 }
372
373 let mut h_tilde = vec![0.0f64; d];
375 for j in 0..d {
376 let mut s = self.gru_bh[j];
377 for k in 0..m {
378 s += message[k] * self.gru_wh[[k, j]];
379 }
380 for k in 0..d {
381 s += (r[k] * memory[k]) * self.gru_uh[[k, j]];
382 }
383 h_tilde[j] = s.tanh();
384 }
385
386 let mut h_new = vec![0.0f64; d];
388 for j in 0..d {
389 h_new[j] = (1.0 - z[j]) * memory[j] + z[j] * h_tilde[j];
390 }
391
392 h_new
393 }
394
395 fn mlp_update(&self, memory: &[f64], message: &[f64]) -> Vec<f64> {
397 let d = self.memory_dim;
398 let total_in = self.message_dim + d;
399
400 let mut input = Vec::with_capacity(total_in);
402 input.extend_from_slice(message);
403 input.extend_from_slice(memory);
404
405 let mut out = vec![0.0f64; d];
407 for j in 0..d {
408 let mut s = self.mlp_b[j];
409 for k in 0..total_in {
410 s += input[k] * self.mlp_w[[k, j]];
411 }
412 out[j] = s.tanh();
413 }
414
415 out
416 }
417
418 pub fn process_event(&mut self, event: &TemporalEvent) -> Result<()> {
425 if event.source >= self.n_nodes || event.target >= self.n_nodes {
426 return Err(GraphError::InvalidParameter {
427 param: "event".to_string(),
428 value: format!("source={}, target={}", event.source, event.target),
429 expected: format!("indices < {}", self.n_nodes),
430 context: "MemoryModule::process_event".to_string(),
431 });
432 }
433
434 let message = self.compute_message(event);
435
436 let src_memory: Vec<f64> = (0..self.memory_dim)
438 .map(|j| self.memory[[event.source, j]])
439 .collect();
440 let new_src = match self.update_method {
441 MemoryUpdateMethod::Gru => self.gru_update(&src_memory, &message),
442 MemoryUpdateMethod::Mlp => self.mlp_update(&src_memory, &message),
443 };
444
445 let tgt_memory: Vec<f64> = (0..self.memory_dim)
447 .map(|j| self.memory[[event.target, j]])
448 .collect();
449 let new_tgt = match self.update_method {
450 MemoryUpdateMethod::Gru => self.gru_update(&tgt_memory, &message),
451 MemoryUpdateMethod::Mlp => self.mlp_update(&tgt_memory, &message),
452 };
453
454 for j in 0..self.memory_dim {
456 self.memory[[event.source, j]] = new_src[j];
457 self.memory[[event.target, j]] = new_tgt[j];
458 }
459
460 self.last_update[event.source] = event.timestamp;
461 self.last_update[event.target] = event.timestamp;
462
463 Ok(())
464 }
465
466 pub fn process_events(&mut self, events: &[TemporalEvent]) -> Result<()> {
470 for event in events {
471 self.process_event(event)?;
472 }
473 Ok(())
474 }
475
476 pub fn get_memory(&self) -> &Array2<f64> {
478 &self.memory
479 }
480
481 pub fn reset(&mut self) {
483 self.memory.fill(0.0);
484 self.last_update.fill(0.0);
485 }
486}
487
/// Logistic (sigmoid) function: `1 / (1 + e^{-x})`, mapping ℝ to (0, 1).
#[inline]
fn sigmoid(x: f64) -> f64 {
    ((-x).exp() + 1.0).recip()
}
493
/// Multi-head attention over a node's temporal neighborhood.
///
/// Queries are projected from the node's own memory; keys and values are
/// projected from `[neighbor memory | sinusoidal time encoding]`.
#[derive(Debug, Clone)]
pub struct TemporalAttention {
    /// Query projection, shape `(memory_dim, hidden_dim)`.
    pub w_q: Array2<f64>,
    /// Key projection, shape `(memory_dim + time_dim, hidden_dim)`.
    pub w_k: Array2<f64>,
    /// Value projection, shape `(memory_dim + time_dim, hidden_dim)`.
    pub w_v: Array2<f64>,
    /// Number of attention heads; must divide `hidden_dim`.
    pub num_heads: usize,
    /// Total output dimension (set equal to `memory_dim` in `new`).
    pub hidden_dim: usize,
    /// Per-head slice width: `hidden_dim / num_heads`.
    pub head_dim: usize,
    /// Sinusoidal encoder applied to each neighbor's time delta.
    pub time_encoding: TimeEncoding,
    /// Dimension of node memories fed in as queries.
    pub memory_dim: usize,
    /// Dimension of the time encoding appended to key/value inputs.
    pub time_dim: usize,
}
524
525impl TemporalAttention {
526 pub fn new(memory_dim: usize, time_dim: usize, num_heads: usize) -> Result<Self> {
533 let hidden_dim = memory_dim;
534 if !hidden_dim.is_multiple_of(num_heads) {
535 return Err(GraphError::InvalidParameter {
536 param: "memory_dim".to_string(),
537 value: format!("{memory_dim}"),
538 expected: format!("divisible by num_heads={num_heads}"),
539 context: "TemporalAttention::new".to_string(),
540 });
541 }
542
543 let head_dim = hidden_dim / num_heads;
544 let mut rng = scirs2_core::random::rng();
545 let scale_q = (6.0_f64 / (memory_dim + hidden_dim) as f64).sqrt();
546 let scale_kv = (6.0_f64 / (memory_dim + time_dim + hidden_dim) as f64).sqrt();
547
548 let w_q = Array2::from_shape_fn((memory_dim, hidden_dim), |_| {
549 (rng.random::<f64>() * 2.0 - 1.0) * scale_q
550 });
551 let w_k = Array2::from_shape_fn((memory_dim + time_dim, hidden_dim), |_| {
552 (rng.random::<f64>() * 2.0 - 1.0) * scale_kv
553 });
554 let w_v = Array2::from_shape_fn((memory_dim + time_dim, hidden_dim), |_| {
555 (rng.random::<f64>() * 2.0 - 1.0) * scale_kv
556 });
557
558 let time_encoding = TimeEncoding::new(time_dim, TimeEncodingType::Sinusoidal);
559
560 Ok(TemporalAttention {
561 w_q,
562 w_k,
563 w_v,
564 num_heads,
565 hidden_dim,
566 head_dim,
567 time_encoding,
568 memory_dim,
569 time_dim,
570 })
571 }
572
573 pub fn forward(
583 &self,
584 query_memory: &Array1<f64>,
585 neighbor_memories: &Array2<f64>,
586 time_deltas: &[f64],
587 ) -> Result<Array1<f64>> {
588 let num_neighbors = neighbor_memories.dim().0;
589 if num_neighbors == 0 {
590 return Ok(Array1::zeros(self.hidden_dim));
591 }
592 if time_deltas.len() != num_neighbors {
593 return Err(GraphError::InvalidParameter {
594 param: "time_deltas".to_string(),
595 value: format!("len={}", time_deltas.len()),
596 expected: format!("len={num_neighbors}"),
597 context: "TemporalAttention::forward".to_string(),
598 });
599 }
600
601 let h = self.num_heads;
602 let dk = self.head_dim;
603 let scale = 1.0 / (dk as f64).sqrt();
604
605 let mut q = vec![0.0f64; self.hidden_dim];
607 for j in 0..self.hidden_dim {
608 for m in 0..self.memory_dim {
609 q[j] += query_memory[m] * self.w_q[[m, j]];
610 }
611 }
612
613 let kv_in_dim = self.memory_dim + self.time_dim;
615 let mut keys = vec![vec![0.0f64; self.hidden_dim]; num_neighbors];
616 let mut values = vec![vec![0.0f64; self.hidden_dim]; num_neighbors];
617
618 for nb in 0..num_neighbors {
619 let time_enc = self.time_encoding.encode(time_deltas[nb]);
621 let mut kv_input = Vec::with_capacity(kv_in_dim);
622 for m in 0..self.memory_dim {
623 kv_input.push(neighbor_memories[[nb, m]]);
624 }
625 for m in 0..self.time_dim {
626 kv_input.push(time_enc[m]);
627 }
628
629 for j in 0..self.hidden_dim {
630 let mut sk = 0.0;
631 let mut sv = 0.0;
632 for m in 0..kv_in_dim {
633 sk += kv_input[m] * self.w_k[[m, j]];
634 sv += kv_input[m] * self.w_v[[m, j]];
635 }
636 keys[nb][j] = sk;
637 values[nb][j] = sv;
638 }
639 }
640
641 let mut output = vec![0.0f64; self.hidden_dim];
643
644 for head in 0..h {
645 let offset = head * dk;
646
647 let mut scores = vec![0.0f64; num_neighbors];
649 for nb in 0..num_neighbors {
650 let mut dot = 0.0;
651 for m in 0..dk {
652 dot += q[offset + m] * keys[nb][offset + m];
653 }
654 scores[nb] = dot * scale;
655 }
656
657 let alphas = softmax_slice(&scores);
659
660 for nb in 0..num_neighbors {
662 for m in 0..dk {
663 output[offset + m] += alphas[nb] * values[nb][offset + m];
664 }
665 }
666 }
667
668 Ok(Array1::from_vec(output))
669 }
670}
671
/// Softmax over a slice, with max-subtraction for overflow safety; the
/// normalizer is clamped away from zero so the output is always finite.
fn softmax_slice(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = xs.iter().map(|&x| (x - peak).exp()).collect();
    let total = weights.iter().sum::<f64>().max(1e-12);
    for w in &mut weights {
        *w /= total;
    }
    weights
}
682
/// Hyperparameters for [`TemporalGnnModel`].
#[derive(Debug, Clone)]
pub struct TemporalGnnConfig {
    /// Number of nodes tracked by the memory module.
    pub n_nodes: usize,
    /// Per-node memory dimension; also the embedding/hidden dimension.
    pub memory_dim: usize,
    /// Dimension of the time encodings.
    pub time_dim: usize,
    /// Number of attention heads; must divide `memory_dim`.
    pub num_heads: usize,
    /// Memory update rule (GRU or MLP).
    pub update_method: MemoryUpdateMethod,
    /// Requested time-encoding family.
    /// NOTE(review): currently never read — `TemporalGnnModel::new` hardcodes
    /// Time2Vec for the memory module and sinusoidal for attention. Confirm
    /// whether this field should be wired through.
    pub time_encoding_type: TimeEncodingType,
}
703
704impl Default for TemporalGnnConfig {
705 fn default() -> Self {
706 TemporalGnnConfig {
707 n_nodes: 100,
708 memory_dim: 64,
709 time_dim: 16,
710 num_heads: 4,
711 update_method: MemoryUpdateMethod::Gru,
712 time_encoding_type: TimeEncodingType::Time2Vec,
713 }
714 }
715}
716
/// End-to-end temporal GNN: streaming per-node memories plus temporal
/// attention for computing time-aware node embeddings.
#[derive(Debug, Clone)]
pub struct TemporalGnnModel {
    /// Per-node memory state, updated as events are processed.
    pub memory_module: MemoryModule,
    /// Attention layer used to aggregate a node's temporal neighborhood.
    pub temporal_attention: TemporalAttention,
    /// Configuration this model was built from.
    pub config: TemporalGnnConfig,
    // Events already applied; scanned (newest-first) by
    // `get_node_embedding` to find recent neighbors.
    event_history: Vec<TemporalEvent>,
}
736
737impl TemporalGnnModel {
738 pub fn new(config: TemporalGnnConfig) -> Result<Self> {
740 let memory_module = MemoryModule::new(
741 config.n_nodes,
742 config.memory_dim,
743 config.time_dim,
744 config.update_method.clone(),
745 );
746 let temporal_attention =
747 TemporalAttention::new(config.memory_dim, config.time_dim, config.num_heads)?;
748
749 Ok(TemporalGnnModel {
750 memory_module,
751 temporal_attention,
752 config,
753 event_history: Vec::new(),
754 })
755 }
756
757 pub fn process_events(&mut self, events: &[TemporalEvent]) -> Result<()> {
761 self.memory_module.process_events(events)?;
762 self.event_history.extend(events.iter().cloned());
763 Ok(())
764 }
765
766 pub fn get_node_embedding(
776 &self,
777 node: usize,
778 current_time: f64,
779 max_neighbors: usize,
780 ) -> Result<Array1<f64>> {
781 if node >= self.config.n_nodes {
782 return Err(GraphError::InvalidParameter {
783 param: "node".to_string(),
784 value: format!("{node}"),
785 expected: format!("< {}", self.config.n_nodes),
786 context: "TemporalGnnModel::get_node_embedding".to_string(),
787 });
788 }
789
790 let mut neighbor_events: Vec<(usize, f64)> = Vec::new();
792 for event in self.event_history.iter().rev() {
793 if neighbor_events.len() >= max_neighbors {
794 break;
795 }
796 if event.source == node {
797 neighbor_events.push((event.target, event.timestamp));
798 } else if event.target == node {
799 neighbor_events.push((event.source, event.timestamp));
800 }
801 }
802
803 if neighbor_events.is_empty() {
804 return Ok(Array1::from_iter(
806 (0..self.config.memory_dim).map(|j| self.memory_module.memory[[node, j]]),
807 ));
808 }
809
810 let num_nb = neighbor_events.len();
812 let mut nb_memories = Array2::zeros((num_nb, self.config.memory_dim));
813 let mut time_deltas = vec![0.0f64; num_nb];
814
815 for (idx, &(nb_node, nb_time)) in neighbor_events.iter().enumerate() {
816 for j in 0..self.config.memory_dim {
817 nb_memories[[idx, j]] = self.memory_module.memory[[nb_node, j]];
818 }
819 time_deltas[idx] = current_time - nb_time;
820 }
821
822 let query = Array1::from_iter(
824 (0..self.config.memory_dim).map(|j| self.memory_module.memory[[node, j]]),
825 );
826
827 self.temporal_attention
828 .forward(&query, &nb_memories, &time_deltas)
829 }
830
831 pub fn reset(&mut self) {
833 self.memory_module.reset();
834 self.event_history.clear();
835 }
836
837 pub fn get_memory(&self) -> &Array2<f64> {
839 self.memory_module.get_memory()
840 }
841}
842
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_encoding_sinusoidal_varies_with_time() {
        let te = TimeEncoding::new(8, TimeEncodingType::Sinusoidal);
        let enc1 = te.encode(0.0);
        let enc2 = te.encode(1.0);
        let enc3 = te.encode(10.0);

        assert_eq!(enc1.len(), 8);
        assert_eq!(enc2.len(), 8);

        let diff_12: f64 = enc1
            .iter()
            .zip(enc2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        let diff_13: f64 = enc1
            .iter()
            .zip(enc3.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();

        assert!(diff_12 > 1e-6, "encodings at t=0 and t=1 should differ");
        assert!(diff_13 > 1e-6, "encodings at t=0 and t=10 should differ");
    }

    #[test]
    fn test_time_encoding_time2vec() {
        let te = TimeEncoding::new(6, TimeEncodingType::Time2Vec);
        let enc = te.encode(5.0);
        assert_eq!(enc.len(), 6);
        for &v in enc.iter() {
            assert!(v.is_finite(), "Time2Vec encoding should be finite");
        }

        // Component 0 is linear in t, so its difference over a time span of
        // 10 must be exactly linear_weight * 10.
        let enc0 = te.encode(0.0);
        let enc10 = te.encode(10.0);
        let expected_diff = te.linear_weight * 10.0;
        let actual_diff = enc10[0] - enc0[0];
        assert!(
            (actual_diff - expected_diff).abs() < 1e-10,
            "first component should be linear"
        );
    }

    #[test]
    fn test_memory_update_changes_state() {
        let mut mem = MemoryModule::new(5, 8, 4, MemoryUpdateMethod::Gru);

        let initial_norm: f64 = mem.memory.iter().map(|x| x * x).sum();
        assert!(initial_norm < 1e-12, "initial memory should be zero");

        let event = TemporalEvent::new(0, 1, 1.0);
        mem.process_event(&event).expect("process event");

        let node0_norm: f64 = (0..8).map(|j| mem.memory[[0, j]].powi(2)).sum();
        let node1_norm: f64 = (0..8).map(|j| mem.memory[[1, j]].powi(2)).sum();

        assert!(
            node0_norm > 1e-12,
            "node 0 memory should be updated after event"
        );
        assert!(
            node1_norm > 1e-12,
            "node 1 memory should be updated after event"
        );

        let node2_norm: f64 = (0..8).map(|j| mem.memory[[2, j]].powi(2)).sum();
        assert!(node2_norm < 1e-12, "node 2 memory should remain zero");
    }

    #[test]
    fn test_memory_update_mlp() {
        let mut mem = MemoryModule::new(4, 6, 3, MemoryUpdateMethod::Mlp);
        let event = TemporalEvent::new(0, 1, 0.5);
        mem.process_event(&event).expect("process event MLP");

        let node0_norm: f64 = (0..6).map(|j| mem.memory[[0, j]].powi(2)).sum();
        assert!(node0_norm > 1e-12, "MLP update should modify memory");
    }

    #[test]
    fn test_temporal_attention_shape() {
        let ta = TemporalAttention::new(8, 4, 2).expect("temporal attention");
        let query = Array1::from_vec(vec![0.1; 8]);
        let neighbors = Array2::from_shape_fn((3, 8), |(i, j)| (i + j) as f64 * 0.05);
        let deltas = vec![1.0, 2.0, 3.0];

        let out = ta.forward(&query, &neighbors, &deltas).expect("forward");
        assert_eq!(out.len(), 8);
        for &v in out.iter() {
            assert!(v.is_finite(), "temporal attention output should be finite");
        }
    }

    #[test]
    fn test_temporal_attention_empty_neighbors() {
        let ta = TemporalAttention::new(8, 4, 2).expect("temporal attention");
        let query = Array1::from_vec(vec![0.1; 8]);
        let neighbors = Array2::zeros((0, 8));
        let deltas: Vec<f64> = Vec::new();

        let out = ta
            .forward(&query, &neighbors, &deltas)
            .expect("empty forward");
        assert_eq!(out.len(), 8);
        let norm: f64 = out.iter().map(|x| x * x).sum();
        assert!(norm < 1e-12, "empty neighbor attention should return zeros");
    }

    #[test]
    fn test_temporal_gnn_model_full_pipeline() {
        let config = TemporalGnnConfig {
            n_nodes: 5,
            memory_dim: 8,
            time_dim: 4,
            num_heads: 2,
            update_method: MemoryUpdateMethod::Gru,
            time_encoding_type: TimeEncodingType::Time2Vec,
        };

        let mut model = TemporalGnnModel::new(config).expect("model");

        let events = vec![
            TemporalEvent::new(0, 1, 1.0),
            TemporalEvent::new(1, 2, 2.0),
            TemporalEvent::new(0, 2, 3.0),
            TemporalEvent::new(2, 3, 4.0),
        ];
        model.process_events(&events).expect("process events");

        // Node 0 participated in events; node 4 never did, so its embedding
        // falls back to its (zero) memory.
        let emb0 = model.get_node_embedding(0, 5.0, 3).expect("embedding 0");
        let emb4 = model.get_node_embedding(4, 5.0, 3).expect("embedding 4");

        assert_eq!(emb0.len(), 8);
        assert_eq!(emb4.len(), 8);

        // Previously `diff` was computed but never asserted on (unused
        // variable); the comparison is now part of the test.
        let diff: f64 = emb0
            .iter()
            .zip(emb4.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            emb0.iter().any(|&v| v.abs() > 1e-12),
            "active node should have non-zero embedding"
        );
        assert!(
            diff > 1e-12,
            "active and inactive node embeddings should differ"
        );
    }

    #[test]
    fn test_memory_module_event_out_of_bounds() {
        let mut mem = MemoryModule::new(3, 4, 2, MemoryUpdateMethod::Gru);
        let event = TemporalEvent::new(0, 5, 1.0);
        let result = mem.process_event(&event);
        assert!(result.is_err());
    }

    #[test]
    fn test_temporal_gnn_reset() {
        let config = TemporalGnnConfig {
            n_nodes: 3,
            memory_dim: 4,
            time_dim: 2,
            num_heads: 2,
            ..Default::default()
        };

        let mut model = TemporalGnnModel::new(config).expect("model");
        let event = TemporalEvent::new(0, 1, 1.0);
        model.process_events(&[event]).expect("process");

        model.reset();
        let mem_norm: f64 = model.get_memory().iter().map(|x| x * x).sum();
        assert!(mem_norm < 1e-12, "memory should be zero after reset");
    }

    #[test]
    fn test_time_encoding_batch() {
        let te = TimeEncoding::new(4, TimeEncodingType::Sinusoidal);
        let timestamps = vec![0.0, 1.0, 5.0, 10.0];
        // Fix: the argument was mangled to `×tamps` (an HTML entity for
        // `&`), which does not compile.
        let batch = te.encode_batch(&timestamps);
        assert_eq!(batch.dim(), (4, 4));

        for (i, &t) in timestamps.iter().enumerate() {
            let single = te.encode(t);
            for j in 0..4 {
                assert!(
                    (batch[[i, j]] - single[j]).abs() < 1e-12,
                    "batch encoding should match single encoding"
                );
            }
        }
    }

    #[test]
    fn test_memory_timestamps_updated() {
        let mut mem = MemoryModule::new(3, 4, 2, MemoryUpdateMethod::Gru);

        assert!(mem.last_update[0] < 1e-12);
        assert!(mem.last_update[1] < 1e-12);

        let event = TemporalEvent::new(0, 1, 5.0);
        mem.process_event(&event).expect("process");

        assert!(
            (mem.last_update[0] - 5.0).abs() < 1e-12,
            "source timestamp should be updated"
        );
        assert!(
            (mem.last_update[1] - 5.0).abs() < 1e-12,
            "target timestamp should be updated"
        );
        assert!(
            mem.last_update[2] < 1e-12,
            "uninvolved node timestamp should remain 0"
        );
    }
}