1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use std::collections::{BTreeMap, HashMap};
20use torsh_tensor::{
21 creation::{from_vec, randn, zeros},
22 Tensor,
23};
24
/// A single timestamped change to a temporal graph: a node or edge being
/// added, deleted, or having its features updated.
#[derive(Debug, Clone)]
pub struct TemporalEvent {
    /// Event time in (fractional) seconds.
    pub time: f64,
    /// Which kind of change this event represents.
    pub event_type: EventType,
    /// Source node index for edge-level events; `None` otherwise.
    pub source: Option<usize>,
    /// Target node index for edge-level events; `None` otherwise.
    pub target: Option<usize>,
    /// Node index for node-level events; `None` otherwise.
    pub node: Option<usize>,
    /// Optional feature payload carried by the event.
    pub features: Option<Tensor>,
    /// Optional scalar weight associated with the event (e.g. edge weight).
    pub weight: Option<f32>,
}
43
/// Kinds of changes that can occur in a temporal graph's event stream.
#[derive(Debug, Clone, PartialEq)]
pub enum EventType {
    /// A node was added to the graph.
    NodeAddition,
    /// A node was removed from the graph.
    NodeDeletion,
    /// A node's feature vector was updated.
    NodeFeatureUpdate,
    /// An edge was added to the graph.
    EdgeAddition,
    /// An edge was removed from the graph.
    EdgeDeletion,
    /// An edge's feature vector was updated.
    EdgeFeatureUpdate,
    /// A full-graph snapshot marker.
    GraphSnapshot,
}
55
/// A graph snapshot plus an event log describing how it evolves over time.
#[derive(Debug, Clone)]
pub struct TemporalGraphData {
    /// Graph state after applying all events seen so far.
    pub current_graph: GraphData,
    /// Events bucketed by integer timestamp (milliseconds; see `add_event`).
    pub events: BTreeMap<u64, Vec<TemporalEvent>>,
    /// Per-node feature history keyed by millisecond timestamp.
    pub node_features_history: HashMap<usize, BTreeMap<u64, Tensor>>,
    /// Per-edge (source, target) feature history keyed by millisecond timestamp.
    pub edge_features_history: HashMap<(usize, usize), BTreeMap<u64, Tensor>>,
    /// Latest event time observed so far (seconds).
    pub current_time: f64,
    /// Sliding retention window in seconds; older events are pruned.
    pub time_window: f64,
    /// Hard cap on the number of retained timestamp buckets in `events`.
    pub max_events: usize,
}
74
75impl TemporalGraphData {
76 pub fn new(initial_graph: GraphData, time_window: f64, max_events: usize) -> Self {
78 Self {
79 current_graph: initial_graph,
80 events: BTreeMap::new(),
81 node_features_history: HashMap::new(),
82 edge_features_history: HashMap::new(),
83 current_time: 0.0,
84 time_window,
85 max_events,
86 }
87 }
88
89 pub fn add_event(&mut self, event: TemporalEvent) {
91 let timestamp = (event.time * 1000.0) as u64; self.events
93 .entry(timestamp)
94 .or_insert_with(Vec::new)
95 .push(event.clone());
96
97 self.current_time = self.current_time.max(event.time);
99
100 self.apply_event(&event);
102
103 self.cleanup_old_events();
105 }
106
107 fn apply_event(&mut self, event: &TemporalEvent) {
109 match event.event_type {
110 EventType::NodeFeatureUpdate => {
111 if let (Some(node), Some(ref features)) = (event.node, &event.features) {
112 self.update_node_features(node, features.clone());
114
115 let timestamp = (event.time * 1000.0) as u64;
117 self.node_features_history
118 .entry(node)
119 .or_insert_with(BTreeMap::new)
120 .insert(timestamp, features.clone());
121 }
122 }
123 EventType::EdgeFeatureUpdate => {
124 if let (Some(source), Some(target), Some(ref features)) =
125 (event.source, event.target, &event.features)
126 {
127 let timestamp = (event.time * 1000.0) as u64;
128 self.edge_features_history
129 .entry((source, target))
130 .or_insert_with(BTreeMap::new)
131 .insert(timestamp, features.clone());
132 }
133 }
134 _ => {
135 }
138 }
139 }
140
141 fn update_node_features(&mut self, node_id: usize, features: Tensor) {
143 let current_features = self
145 .current_graph
146 .x
147 .to_vec()
148 .expect("conversion should succeed");
149 let feature_dim = self.current_graph.x.shape().dims()[1];
150 let new_features = features.to_vec().expect("conversion should succeed");
151
152 let mut updated_features = current_features;
153 let start_idx = node_id * feature_dim;
154 let _end_idx = start_idx + feature_dim.min(new_features.len());
155
156 for (i, &value) in new_features.iter().take(feature_dim).enumerate() {
157 if start_idx + i < updated_features.len() {
158 updated_features[start_idx + i] = value;
159 }
160 }
161
162 self.current_graph.x = from_vec(
163 updated_features,
164 &[self.current_graph.num_nodes, feature_dim],
165 torsh_core::device::DeviceType::Cpu,
166 )
167 .expect("from_vec updated_features should succeed");
168 }
169
170 fn cleanup_old_events(&mut self) {
172 let cutoff_time = ((self.current_time - self.time_window) * 1000.0) as u64;
173
174 let old_keys: Vec<u64> = self
176 .events
177 .keys()
178 .filter(|&×tamp| timestamp < cutoff_time)
179 .cloned()
180 .collect();
181
182 for key in old_keys {
183 self.events.remove(&key);
184 }
185
186 while self.events.len() > self.max_events {
188 if let Some(first_key) = self.events.keys().next().cloned() {
189 self.events.remove(&first_key);
190 } else {
191 break;
192 }
193 }
194 }
195
196 pub fn get_events_in_range(&self, start_time: f64, end_time: f64) -> Vec<&TemporalEvent> {
198 let start_timestamp = (start_time * 1000.0) as u64;
199 let end_timestamp = (end_time * 1000.0) as u64;
200
201 self.events
202 .range(start_timestamp..=end_timestamp)
203 .flat_map(|(_, events)| events.iter())
204 .collect()
205 }
206
207 pub fn get_node_features_at_time(&self, node_id: usize, time: f64) -> Option<Tensor> {
209 let timestamp = (time * 1000.0) as u64;
210
211 if let Some(history) = self.node_features_history.get(&node_id) {
212 if let Some((_, features)) = history.range(..=timestamp).next_back() {
214 return Some(features.clone());
215 }
216 }
217
218 None
219 }
220
221 pub fn snapshot_at_time(&self, _time: f64) -> GraphData {
223 self.current_graph.clone()
226 }
227}
228
/// Temporal Graph Convolutional layer: a spatial linear projection of node
/// features combined with a learned temporal-activity encoding.
#[derive(Debug)]
pub struct TGCNConv {
    /// Expected input feature width of `current_graph.x`.
    in_features: usize,
    /// Output feature width produced by `forward`.
    out_features: usize,
    /// Width of the temporal weight's input dimension.
    temporal_dim: usize,
    /// Projects input features spatially: [in_features, out_features].
    spatial_weight: Parameter,
    /// Projects temporal encodings: [temporal_dim, out_features].
    temporal_weight: Parameter,
    /// Optional additive bias of shape [out_features].
    bias: Option<Parameter>,
    /// Stored capacity hint; not used by the current forward pass.
    memory_size: usize,
    /// Width of the time encoding; initialized equal to `temporal_dim`.
    time_encoding_dim: usize,
}
241
242impl TGCNConv {
243 pub fn new(
245 in_features: usize,
246 out_features: usize,
247 temporal_dim: usize,
248 memory_size: usize,
249 bias: bool,
250 ) -> Self {
251 let spatial_weight = Parameter::new(
252 randn(&[in_features, out_features]).expect("randn spatial_weight should succeed"),
253 );
254 let temporal_weight = Parameter::new(
255 randn(&[temporal_dim, out_features]).expect("randn temporal_weight should succeed"),
256 );
257 let bias = if bias {
258 Some(Parameter::new(
259 zeros(&[out_features]).expect("zeros bias should succeed"),
260 ))
261 } else {
262 None
263 };
264
265 Self {
266 in_features,
267 out_features,
268 temporal_dim,
269 spatial_weight,
270 temporal_weight,
271 bias,
272 memory_size,
273 time_encoding_dim: temporal_dim,
274 }
275 }
276
277 pub fn forward(&self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
279 let spatial_features = temporal_graph
281 .current_graph
282 .x
283 .matmul(&self.spatial_weight.clone_data())
284 .expect("matmul spatial_features should succeed");
285
286 let temporal_features = self.encode_temporal_context(temporal_graph);
288
289 let combined_features = spatial_features
291 .add(&temporal_features)
292 .expect("operation should succeed");
293
294 let output_features = if let Some(ref bias) = self.bias {
296 combined_features
297 .add(&bias.clone_data())
298 .expect("operation should succeed")
299 } else {
300 combined_features
301 };
302
303 let mut output_graph = temporal_graph.clone();
305 output_graph.current_graph.x = output_features;
306 output_graph
307 }
308
309 fn encode_temporal_context(&self, temporal_graph: &TemporalGraphData) -> Tensor {
311 let num_nodes = temporal_graph.current_graph.num_nodes;
312 let current_time = temporal_graph.current_time;
313 let lookback_time = current_time - temporal_graph.time_window;
314
315 let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);
317
318 let _temporal_encoding = zeros::<f32>(&[num_nodes, self.out_features])
320 .expect("zeros temporal_encoding should succeed");
321
322 let mut node_event_counts = vec![0.0; num_nodes];
324
325 for event in recent_events {
326 if let Some(node_id) = event.node {
327 if node_id < num_nodes {
328 let recency_weight =
330 1.0 - (current_time - event.time) / temporal_graph.time_window;
331 node_event_counts[node_id] += recency_weight;
332 }
333 }
334 }
335
336 let temporal_data: Vec<f32> = node_event_counts
338 .iter()
339 .flat_map(|&count| {
340 (0..self.out_features).map(move |_| count as f32)
342 })
343 .collect();
344
345 from_vec(
346 temporal_data,
347 &[num_nodes, self.out_features],
348 torsh_core::device::DeviceType::Cpu,
349 )
350 .expect("from_vec temporal_data should succeed")
351 }
352}
353
354impl GraphLayer for TGCNConv {
355 fn forward(&self, graph: &GraphData) -> GraphData {
356 let temporal_graph = TemporalGraphData::new(graph.clone(), 1.0, 1000);
358 let output_temporal = self.forward(&temporal_graph);
359 output_temporal.current_graph
360 }
361
362 fn parameters(&self) -> Vec<Tensor> {
363 let mut params = vec![
364 self.spatial_weight.clone_data(),
365 self.temporal_weight.clone_data(),
366 ];
367 if let Some(ref bias) = self.bias {
368 params.push(bias.clone_data());
369 }
370 params
371 }
372}
373
/// Temporal Graph Attention layer: multi-head attention over node features
/// with a sinusoidal time encoding of per-node event recency.
#[derive(Debug)]
pub struct TGATConv {
    /// Expected input feature width of `current_graph.x`.
    in_features: usize,
    /// Output feature width; split evenly across `heads`.
    out_features: usize,
    /// Number of attention heads.
    heads: usize,
    /// Width of the sinusoidal time encoding.
    time_encoding_dim: usize,
    /// Query projection: [in_features, out_features].
    query_weight: Parameter,
    /// Key projection: [in_features, out_features].
    key_weight: Parameter,
    /// Value projection: [in_features, out_features].
    value_weight: Parameter,
    /// Time-encoding projection: [time_encoding_dim, out_features].
    time_weight: Parameter,
    /// Post-attention output projection: [out_features, out_features].
    output_weight: Parameter,
    /// Optional additive bias of shape [out_features].
    bias: Option<Parameter>,
    /// Dropout rate; stored but not applied in the current forward pass.
    dropout: f32,
}
389
impl TGATConv {
    /// Creates a temporal graph attention layer.
    ///
    /// NOTE(review): `dropout` is stored but never applied in `forward` —
    /// confirm whether dropout is intended to be active here.
    pub fn new(
        in_features: usize,
        out_features: usize,
        heads: usize,
        time_encoding_dim: usize,
        dropout: f32,
        bias: bool,
    ) -> Self {
        // All projection weights start from random normal initialization.
        let query_weight = Parameter::new(
            randn(&[in_features, out_features]).expect("randn query_weight should succeed"),
        );
        let key_weight = Parameter::new(
            randn(&[in_features, out_features]).expect("randn key_weight should succeed"),
        );
        let value_weight = Parameter::new(
            randn(&[in_features, out_features]).expect("randn value_weight should succeed"),
        );
        let time_weight = Parameter::new(
            randn(&[time_encoding_dim, out_features]).expect("randn time_weight should succeed"),
        );
        let output_weight = Parameter::new(
            randn(&[out_features, out_features]).expect("randn output_weight should succeed"),
        );

        let bias = if bias {
            Some(Parameter::new(
                zeros(&[out_features]).expect("zeros bias should succeed"),
            ))
        } else {
            None
        };

        Self {
            in_features,
            out_features,
            heads,
            time_encoding_dim,
            query_weight,
            key_weight,
            value_weight,
            time_weight,
            output_weight,
            bias,
            dropout,
        }
    }

    /// Multi-head temporal attention over the current node features.
    ///
    /// Projects features into query/key/value spaces, reshapes into
    /// `heads` heads of `out_features / heads` dims each, attends, then
    /// applies the output projection and optional bias.
    ///
    /// NOTE(review): assumes `out_features` is divisible by `heads`;
    /// otherwise the head views will not cover every feature — confirm.
    pub fn forward(&self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
        let num_nodes = temporal_graph.current_graph.num_nodes;
        let head_dim = self.out_features / self.heads;

        let queries = temporal_graph
            .current_graph
            .x
            .matmul(&self.query_weight.clone_data())
            .expect("matmul queries should succeed");
        let keys = temporal_graph
            .current_graph
            .x
            .matmul(&self.key_weight.clone_data())
            .expect("matmul keys should succeed");
        let values = temporal_graph
            .current_graph
            .x
            .matmul(&self.value_weight.clone_data())
            .expect("matmul values should succeed");

        // Sinusoidal per-node time encoding projected into feature space.
        let time_encoding = self.compute_time_encoding(temporal_graph);
        let time_transformed = time_encoding
            .matmul(&self.time_weight.clone_data())
            .expect("matmul time_transformed should succeed");

        // Reshape [num_nodes, out_features] -> [num_nodes, heads, head_dim].
        let q = queries
            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
            .expect("view queries should succeed");
        let k = keys
            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
            .expect("view keys should succeed");
        let v = values
            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
            .expect("view values should succeed");

        let attended_features =
            self.temporal_attention(&q, &k, &v, &time_transformed, temporal_graph);

        // Concatenate heads back to [num_nodes, out_features] and project.
        let concatenated = attended_features
            .view(&[num_nodes as i32, self.out_features as i32])
            .expect("view concatenated should succeed");
        let mut output = concatenated
            .matmul(&self.output_weight.clone_data())
            .expect("matmul output should succeed");

        if let Some(ref bias) = self.bias {
            output = output
                .add(&bias.clone_data())
                .expect("operation should succeed");
        }

        let mut output_graph = temporal_graph.clone();
        output_graph.current_graph.x = output;
        output_graph
    }

    /// Builds a [num_nodes, time_encoding_dim] sinusoidal encoding of the
    /// time elapsed since each node's most recent feature update.
    ///
    /// Nodes without any history keep rows filled with the raw current time.
    fn compute_time_encoding(&self, temporal_graph: &TemporalGraphData) -> Tensor {
        let num_nodes = temporal_graph.current_graph.num_nodes;
        let current_time = temporal_graph.current_time;

        // Default fill for nodes with no recorded events.
        let mut time_features = vec![current_time as f32; num_nodes * self.time_encoding_dim];

        for (node_id, history) in &temporal_graph.node_features_history {
            if *node_id < num_nodes {
                if let Some((timestamp, _)) = history.iter().next_back() {
                    // History keys are millisecond timestamps.
                    let last_event_time = (*timestamp as f64) / 1000.0;
                    let time_diff = (current_time - last_event_time) as f32;

                    // Sinusoids at exponentially spaced frequencies.
                    for dim in 0..self.time_encoding_dim {
                        let freq = 2.0_f32.powf(dim as f32);
                        let encoded = (time_diff * freq).sin();
                        time_features[*node_id * self.time_encoding_dim + dim] = encoded;
                    }
                }
            }
        }

        from_vec(
            time_features,
            &[num_nodes, self.time_encoding_dim],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("from_vec time_features should succeed")
    }

    /// Simplified attention where the weight of pair (i, j) depends only
    /// on node-index distance: 1 / (1 + |i - j|), normalized per node.
    ///
    /// NOTE(review): the projected queries/keys and the time encoding are
    /// computed by the caller but ignored here (`_q_head`, `_k_head`,
    /// `_time_encoding`) — confirm this placeholder scoring is intentional.
    fn temporal_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        _time_encoding: &Tensor,
        temporal_graph: &TemporalGraphData,
    ) -> Tensor {
        let num_nodes = temporal_graph.current_graph.num_nodes;
        let head_dim = self.out_features / self.heads;

        let mut output =
            zeros(&[num_nodes, self.heads, head_dim]).expect("zeros output should succeed");

        for head in 0..self.heads {
            let _q_head = q
                .slice_tensor(1, head, head + 1)
                .expect("slice_tensor q_head should succeed");
            let _k_head = k
                .slice_tensor(1, head, head + 1)
                .expect("slice_tensor k_head should succeed");
            let v_head = v
                .slice_tensor(1, head, head + 1)
                .expect("slice_tensor v_head should succeed");

            // Dense O(n^2) attention over all node pairs.
            for i in 0..num_nodes {
                let mut attended_value =
                    zeros(&[head_dim]).expect("zeros attended_value should succeed");
                let mut attention_sum = 0.0;

                for j in 0..num_nodes {
                    // Distance-based proxy score; closer indices score higher.
                    let score = 1.0 / (1.0 + (i as f32 - j as f32).abs());
                    let v_j = v_head
                        .slice_tensor(0, j, j + 1)
                        .expect("slice_tensor v_j should succeed")
                        .squeeze_tensor(0)
                        .expect("squeeze_tensor should succeed")
                        .squeeze_tensor(0)
                        .expect("squeeze_tensor should succeed");

                    let weighted_value = v_j.mul_scalar(score).expect("mul_scalar should succeed");
                    attended_value = attended_value
                        .add(&weighted_value)
                        .expect("operation should succeed");
                    attention_sum += score;
                }

                // Normalize so the attention weights sum to one.
                if attention_sum > 0.0 {
                    attended_value = attended_value
                        .div_scalar(attention_sum)
                        .expect("div_scalar should succeed");
                }

                let attended_data = attended_value.to_vec().expect("conversion should succeed");
                for (dim, &val) in attended_data.iter().enumerate() {
                    if dim < head_dim {
                        output
                            .set_item(&[i, head, dim], val)
                            .expect("set_item should succeed");
                    }
                }
            }
        }

        output
    }
}
613
614impl GraphLayer for TGATConv {
615 fn forward(&self, graph: &GraphData) -> GraphData {
616 let temporal_graph = TemporalGraphData::new(graph.clone(), 1.0, 1000);
617 let output_temporal = self.forward(&temporal_graph);
618 output_temporal.current_graph
619 }
620
621 fn parameters(&self) -> Vec<Tensor> {
622 let mut params = vec![
623 self.query_weight.clone_data(),
624 self.key_weight.clone_data(),
625 self.value_weight.clone_data(),
626 self.time_weight.clone_data(),
627 self.output_weight.clone_data(),
628 ];
629 if let Some(ref bias) = self.bias {
630 params.push(bias.clone_data());
631 }
632 params
633 }
634}
635
/// Temporal Graph Network layer: maintains a learned per-node memory that
/// is updated from events and projected into output embeddings.
#[derive(Debug)]
pub struct TGNConv {
    /// Expected width of event feature payloads.
    in_features: usize,
    /// Output embedding width produced by `forward`.
    out_features: usize,
    /// Width of each node's memory vector.
    memory_dim: usize,
    /// Width of the sinusoidal time encoding appended to messages.
    time_encoding_dim: usize,
    /// Maps [features ++ time encoding] -> message: [in + time, memory_dim].
    message_function: Parameter,
    /// Maps [old memory ++ message] -> new memory: [2 * memory_dim, memory_dim].
    memory_updater: Parameter,
    /// Maps memory -> embedding: [memory_dim, out_features].
    node_embedding: Parameter,
    /// Optional additive bias of shape [out_features].
    bias: Option<Parameter>,
    /// Mutable per-node memory state, lazily initialized to zeros.
    node_memories: HashMap<usize, Tensor>,
    /// Time of the last memory update per node.
    last_update_times: HashMap<usize, f64>,
}
650
impl TGNConv {
    /// Creates a TGN layer with learned message, memory-update, and
    /// embedding transforms plus empty per-node memory state.
    pub fn new(
        in_features: usize,
        out_features: usize,
        memory_dim: usize,
        time_encoding_dim: usize,
        bias: bool,
    ) -> Self {
        // Maps [event features ++ time encoding] to a memory-sized message.
        let message_function = Parameter::new(
            randn(&[in_features + time_encoding_dim, memory_dim])
                .expect("randn message_function should succeed"),
        );
        // Maps [old memory ++ message] to the new memory.
        let memory_updater = Parameter::new(
            randn(&[memory_dim * 2, memory_dim]).expect("randn memory_updater should succeed"),
        );
        // Projects memory into the output embedding space.
        let node_embedding = Parameter::new(
            randn(&[memory_dim, out_features]).expect("randn node_embedding should succeed"),
        );

        let bias = if bias {
            Some(Parameter::new(
                zeros(&[out_features]).expect("zeros bias should succeed"),
            ))
        } else {
            None
        };

        Self {
            in_features,
            out_features,
            memory_dim,
            time_encoding_dim,
            message_function,
            memory_updater,
            node_embedding,
            bias,
            node_memories: HashMap::new(),
            last_update_times: HashMap::new(),
        }
    }

    /// Updates node memories from recent events, then emits embeddings.
    ///
    /// Takes `&mut self` because per-node memory is mutated as a side
    /// effect; repeated calls on the same input are NOT idempotent.
    pub fn forward(&mut self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
        self.update_memories(temporal_graph);

        let output_features = self.generate_embeddings(temporal_graph);

        let mut output_graph = temporal_graph.clone();
        output_graph.current_graph.x = output_features;
        output_graph
    }

    /// Folds every in-window node event into that node's memory.
    fn update_memories(&mut self, temporal_graph: &TemporalGraphData) {
        let current_time = temporal_graph.current_time;
        let lookback_time = current_time - temporal_graph.time_window;

        let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);

        for event in recent_events {
            if let Some(node_id) = event.node {
                let message = self.compute_message(event, current_time);

                self.update_node_memory(node_id, message, event.time);
            }
        }
    }

    /// Builds a [1, memory_dim] message from an event's features and a
    /// sinusoidal encoding of its age relative to `current_time`.
    ///
    /// NOTE(review): assumes `event.features` (when present) has exactly
    /// `in_features` elements; otherwise the `from_vec` shape below will
    /// not match — confirm against event producers.
    fn compute_message(&self, event: &TemporalEvent, current_time: f64) -> Tensor {
        let time_diff = (current_time - event.time) as f32;
        let mut time_encoding = Vec::new();

        // Sinusoids at exponentially spaced frequencies, as in TGAT.
        for i in 0..self.time_encoding_dim {
            let freq = 2.0_f32.powf(i as f32);
            time_encoding.push((time_diff * freq).sin());
        }

        // Feature-less events fall back to a constant all-ones vector.
        let mut message_input = if let Some(ref features) = event.features {
            features.to_vec().expect("conversion should succeed")
        } else {
            vec![1.0; self.in_features]
        };

        message_input.extend(time_encoding);

        let input_tensor = from_vec(
            message_input,
            &[1, self.in_features + self.time_encoding_dim],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("from_vec input_tensor should succeed");

        input_tensor
            .matmul(&self.message_function.clone_data())
            .expect("matmul message should succeed")
    }

    /// Combines the node's previous memory (zeros if absent) with the
    /// message through the learned updater and stores the new memory.
    fn update_node_memory(&mut self, node_id: usize, message: Tensor, event_time: f64) {
        let current_memory = self
            .node_memories
            .get(&node_id)
            .cloned()
            .unwrap_or_else(|| zeros(&[1, self.memory_dim]).expect("zeros memory should succeed"));

        // Concatenate [old memory ++ message] into a single row vector.
        let current_data = current_memory.to_vec().expect("conversion should succeed");
        let message_data = message.to_vec().expect("conversion should succeed");
        let mut combined_data = current_data;
        combined_data.extend(message_data);

        let combined_tensor = from_vec(
            combined_data,
            &[1, self.memory_dim * 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("from_vec combined_tensor should succeed");

        let new_memory = combined_tensor
            .matmul(&self.memory_updater.clone_data())
            .expect("matmul new_memory should succeed");

        self.node_memories.insert(node_id, new_memory);
        self.last_update_times.insert(node_id, event_time);
    }

    /// Projects every node's memory (zeros if never updated) into a
    /// [num_nodes, out_features] embedding matrix, plus optional bias.
    fn generate_embeddings(&self, temporal_graph: &TemporalGraphData) -> Tensor {
        let num_nodes = temporal_graph.current_graph.num_nodes;
        let mut embeddings = Vec::new();

        for node_id in 0..num_nodes {
            let memory = self
                .node_memories
                .get(&node_id)
                .cloned()
                .unwrap_or_else(|| {
                    zeros(&[1, self.memory_dim]).expect("zeros memory should succeed")
                });

            let embedding = memory
                .matmul(&self.node_embedding.clone_data())
                .expect("operation should succeed");
            let embedding_data = embedding.to_vec().expect("conversion should succeed");
            embeddings.extend(embedding_data);
        }

        let mut output = from_vec(
            embeddings,
            &[num_nodes, self.out_features],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("from_vec embeddings should succeed");

        if let Some(ref bias) = self.bias {
            output = output
                .add(&bias.clone_data())
                .expect("operation should succeed");
        }

        output
    }
}
828
829pub mod pooling {
831 use super::*;
832
833 #[derive(Debug, Clone, Copy)]
835 pub enum TemporalPoolingMethod {
836 MostRecent,
837 TimeWeightedMean,
838 ExponentialDecay,
839 AttentionBased,
840 }
841
842 pub fn temporal_pool(
844 temporal_graph: &TemporalGraphData,
845 method: TemporalPoolingMethod,
846 ) -> Tensor {
847 match method {
848 TemporalPoolingMethod::MostRecent => {
849 temporal_graph
851 .current_graph
852 .x
853 .mean(Some(&[0]), false)
854 .expect("mean pooling should succeed")
855 }
856 TemporalPoolingMethod::TimeWeightedMean => time_weighted_pool(temporal_graph),
857 TemporalPoolingMethod::ExponentialDecay => exponential_decay_pool(temporal_graph),
858 TemporalPoolingMethod::AttentionBased => attention_temporal_pool(temporal_graph),
859 }
860 }
861
862 fn time_weighted_pool(temporal_graph: &TemporalGraphData) -> Tensor {
864 let current_time = temporal_graph.current_time;
865 let lookback_time = current_time - temporal_graph.time_window;
866 let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);
867
868 if recent_events.is_empty() {
869 return temporal_graph
870 .current_graph
871 .x
872 .mean(Some(&[0]), false)
873 .expect("mean pooling should succeed");
874 }
875
876 let mut weighted_sum = zeros(&[temporal_graph.current_graph.x.shape().dims()[1]])
878 .expect("zeros weighted_sum should succeed");
879 let mut total_weight = 0.0;
880
881 for event in recent_events {
882 if let Some(ref features) = event.features {
883 let weight = 1.0 - (current_time - event.time) / temporal_graph.time_window;
884 let weighted_features = features
885 .mul_scalar(weight as f32)
886 .expect("mul_scalar should succeed");
887
888 let features_data = weighted_features
890 .to_vec()
891 .expect("conversion should succeed");
892 let current_data = weighted_sum.to_vec().expect("conversion should succeed");
893 let mut new_data = Vec::new();
894
895 for (_i, (¤t, &new)) in
896 current_data.iter().zip(features_data.iter()).enumerate()
897 {
898 new_data.push(current + new);
899 }
900
901 weighted_sum = from_vec(
902 new_data,
903 &[weighted_sum.shape().dims()[0]],
904 torsh_core::device::DeviceType::Cpu,
905 )
906 .expect("from_vec weighted_sum should succeed");
907
908 total_weight += weight;
909 }
910 }
911
912 if total_weight > 0.0 {
913 weighted_sum
914 .div_scalar(total_weight as f32)
915 .expect("div_scalar should succeed")
916 } else {
917 temporal_graph
918 .current_graph
919 .x
920 .mean(Some(&[0]), false)
921 .expect("mean pooling should succeed")
922 }
923 }
924
925 fn exponential_decay_pool(temporal_graph: &TemporalGraphData) -> Tensor {
927 let decay_rate = 0.1; let current_time = temporal_graph.current_time;
929
930 let decay_factor = (-decay_rate * current_time).exp() as f32;
932 temporal_graph
933 .current_graph
934 .x
935 .mul_scalar(decay_factor)
936 .expect("mul_scalar should succeed")
937 .mean(Some(&[0]), false)
938 .expect("mean pooling should succeed")
939 }
940
941 fn attention_temporal_pool(temporal_graph: &TemporalGraphData) -> Tensor {
943 let features = &temporal_graph.current_graph.x;
945 let attention_scores = features
946 .sum_dim(&[1], false)
947 .expect("sum_dim should succeed");
948 let attention_weights = attention_scores.softmax(0).expect("softmax should succeed");
949 let attention_expanded = attention_weights
950 .unsqueeze(-1)
951 .expect("unsqueeze should succeed");
952
953 let weighted_features = features
954 .mul(&attention_expanded)
955 .expect("operation should succeed");
956 weighted_features
957 .sum_dim(&[0], false)
958 .expect("sum_dim should succeed")
959 }
960}
961
962pub mod utils {
964 use super::*;
965
966 pub fn generate_random_events(
968 num_events: usize,
969 num_nodes: usize,
970 time_span: f64,
971 feature_dim: usize,
972 ) -> Vec<TemporalEvent> {
973 let mut rng = scirs2_core::random::thread_rng();
974 let mut events = Vec::new();
975
976 for _ in 0..num_events {
977 let time = rng.gen_range(0.0..time_span);
978 let event_type = if rng.gen_range(0.0..1.0) < 0.7 {
979 EventType::NodeFeatureUpdate
980 } else {
981 EventType::EdgeAddition
982 };
983
984 let node = if matches!(event_type, EventType::NodeFeatureUpdate) {
985 Some(rng.gen_range(0..num_nodes))
986 } else {
987 None
988 };
989
990 let (source, target) = if matches!(event_type, EventType::EdgeAddition) {
991 let s = rng.gen_range(0..num_nodes);
992 let t = rng.gen_range(0..num_nodes);
993 (Some(s), Some(t))
994 } else {
995 (None, None)
996 };
997
998 let features = if matches!(event_type, EventType::NodeFeatureUpdate) {
999 Some(randn(&[feature_dim]).expect("randn features should succeed"))
1000 } else {
1001 None
1002 };
1003
1004 events.push(TemporalEvent {
1005 time,
1006 event_type,
1007 source,
1008 target,
1009 node,
1010 features,
1011 weight: Some(rng.gen_range(0.1..1.0)),
1012 });
1013 }
1014
1015 events.sort_by(|a, b| {
1017 a.time
1018 .partial_cmp(&b.time)
1019 .expect("time comparison should succeed")
1020 });
1021 events
1022 }
1023
1024 pub fn create_temporal_graph_from_events(
1026 initial_graph: GraphData,
1027 events: Vec<TemporalEvent>,
1028 time_window: f64,
1029 ) -> TemporalGraphData {
1030 let mut temporal_graph = TemporalGraphData::new(initial_graph, time_window, 10000);
1031
1032 for event in events {
1033 temporal_graph.add_event(event);
1034 }
1035
1036 temporal_graph
1037 }
1038
1039 pub fn temporal_metrics(temporal_graph: &TemporalGraphData) -> TemporalMetrics {
1041 let total_events = temporal_graph.events.values().map(|v| v.len()).sum();
1042 let unique_nodes_with_events = temporal_graph.node_features_history.len();
1043 let time_span = if let (Some(first), Some(last)) = (
1044 temporal_graph.events.keys().next(),
1045 temporal_graph.events.keys().next_back(),
1046 ) {
1047 (*last as f64 - *first as f64) / 1000.0
1048 } else {
1049 0.0
1050 };
1051
1052 let event_rate = if time_span > 0.0 {
1053 total_events as f64 / time_span
1054 } else {
1055 0.0
1056 };
1057
1058 TemporalMetrics {
1059 total_events,
1060 unique_active_nodes: unique_nodes_with_events,
1061 time_span,
1062 event_rate,
1063 current_time: temporal_graph.current_time,
1064 }
1065 }
1066
1067 #[derive(Debug, Clone)]
1069 pub struct TemporalMetrics {
1070 pub total_events: usize,
1071 pub unique_active_nodes: usize,
1072 pub time_span: f64,
1073 pub event_rate: f64,
1074 pub current_time: f64,
1075 }
1076}
1077
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    // Constructing a temporal graph keeps the snapshot and settings intact.
    #[test]
    fn test_temporal_graph_creation() {
        let features = randn(&[4, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let temporal_graph = TemporalGraphData::new(graph, 10.0, 1000);

        assert_eq!(temporal_graph.current_graph.num_nodes, 4);
        assert_eq!(temporal_graph.time_window, 10.0);
        assert_eq!(temporal_graph.max_events, 1000);
    }

    // Adding an event advances the clock and records it in the log.
    #[test]
    fn test_temporal_event_addition() {
        let features = randn(&[3, 2]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0];
        let edge_index = from_vec(edges, &[2, 2], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let mut temporal_graph = TemporalGraphData::new(graph, 5.0, 100);

        let event = TemporalEvent {
            time: 1.0,
            event_type: EventType::NodeFeatureUpdate,
            source: None,
            target: None,
            node: Some(0),
            features: Some(randn(&[2]).unwrap()),
            weight: None,
        };

        temporal_graph.add_event(event);

        assert_eq!(temporal_graph.current_time, 1.0);
        assert!(!temporal_graph.events.is_empty());
    }

    // TGCN maps [3, 4] input features to [3, 8] output features.
    #[test]
    fn test_tgcn_layer() {
        let features = randn(&[3, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0];
        let edge_index = from_vec(edges, &[2, 2], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let temporal_graph = TemporalGraphData::new(graph, 1.0, 100);
        let tgcn = TGCNConv::new(4, 8, 16, 64, true);

        let output = tgcn.forward(&temporal_graph);
        assert_eq!(output.current_graph.x.shape().dims(), &[3, 8]);
    }

    // TGAT maps [4, 6] input features to [4, 12] outputs over 3 heads.
    #[test]
    fn test_tgat_layer() {
        let features = randn(&[4, 6]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
        let edge_index = from_vec(edges, &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let temporal_graph = TemporalGraphData::new(graph, 2.0, 200);
        let tgat = TGATConv::new(6, 12, 3, 8, 0.1, true);

        let output = tgat.forward(&temporal_graph);
        assert_eq!(output.current_graph.x.shape().dims(), &[4, 12]);
    }

    // Pooling collapses [5, 4] node features to a single [4] vector.
    #[test]
    fn test_temporal_pooling() {
        let features = randn(&[5, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let temporal_graph = TemporalGraphData::new(graph, 3.0, 150);

        let pooled =
            pooling::temporal_pool(&temporal_graph, pooling::TemporalPoolingMethod::MostRecent);
        assert_eq!(pooled.shape().dims(), &[4]);

        let weighted_pooled = pooling::temporal_pool(
            &temporal_graph,
            pooling::TemporalPoolingMethod::TimeWeightedMean,
        );
        assert_eq!(weighted_pooled.shape().dims(), &[4]);
    }

    // Generated events are time-sorted and replay cleanly into a graph.
    #[test]
    fn test_temporal_utils() {
        let events = utils::generate_random_events(10, 5, 10.0, 3);
        assert_eq!(events.len(), 10);

        // Events must come back sorted by ascending time.
        for i in 1..events.len() {
            assert!(events[i].time >= events[i - 1].time);
        }

        let features = randn(&[5, 3]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 0.0];
        let edge_index = from_vec(edges, &[2, 5], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let temporal_graph = utils::create_temporal_graph_from_events(graph, events, 5.0);
        let metrics = utils::temporal_metrics(&temporal_graph);

        assert!(metrics.total_events > 0);
        assert!(metrics.time_span >= 0.0);
    }

    // Range queries are inclusive: [1.0, 3.0] matches times 1, 2, and 3.
    #[test]
    fn test_event_time_range_query() {
        let features = randn(&[3, 2]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0];
        let edge_index = from_vec(edges, &[2, 2], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let mut temporal_graph = TemporalGraphData::new(graph, 10.0, 100);

        // Five events at integer times 0..=4 spread across 3 nodes.
        for i in 0..5 {
            let event = TemporalEvent {
                time: i as f64,
                event_type: EventType::NodeFeatureUpdate,
                source: None,
                target: None,
                node: Some(i % 3),
                features: Some(randn(&[2]).unwrap()),
                weight: None,
            };
            temporal_graph.add_event(event);
        }

        let events_in_range = temporal_graph.get_events_in_range(1.0, 3.0);
        assert_eq!(events_in_range.len(), 3);
    }
}
1218}