Skip to main content

torsh_graph/
temporal.rs

1//! Temporal Graph Neural Networks
2//!
3//! Advanced implementation of temporal graph neural networks for continuous-time dynamic graphs.
4//! Handles evolving graph structures and node/edge features over time with sophisticated
5//! temporal modeling capabilities.
6//!
7//! # Features:
8//! - Continuous-time temporal graphs with event-based modeling
9//! - Temporal graph convolution layers (TGN, DyRep, TGAT)
10//! - Memory-augmented temporal networks
11//! - Time-aware graph attention mechanisms
12//! - Temporal pooling and aggregation operations
13//! - Causal temporal modeling with proper time ordering
14
15// Framework infrastructure - components designed for future use
16#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use std::collections::{BTreeMap, HashMap};
20use torsh_tensor::{
21    creation::{from_vec, randn, zeros},
22    Tensor,
23};
24
/// Temporal event representing a graph change at a specific time.
///
/// Which optional fields are populated depends on `event_type`: edge events
/// use `source`/`target` (and optionally `weight`), node events use `node`,
/// and feature-update events carry the new vector in `features`.
#[derive(Debug, Clone)]
pub struct TemporalEvent {
    /// Time of the event (continuous time; internally quantized to
    /// millisecond ticks — `(time * 1000.0) as u64` — for ordering)
    pub time: f64,
    /// Type of event (node addition, edge addition, feature update, etc.)
    pub event_type: EventType,
    /// Source node ID (for edge events)
    pub source: Option<usize>,
    /// Target node ID (for edge events)
    pub target: Option<usize>,
    /// Node ID (for node events)
    pub node: Option<usize>,
    /// Feature vector associated with the event
    pub features: Option<Tensor>,
    /// Edge weight (for edge events)
    pub weight: Option<f32>,
}
43
/// Types of temporal events.
///
/// NOTE(review): only `NodeFeatureUpdate` and `EdgeFeatureUpdate` are
/// currently applied to the materialized graph by
/// `TemporalGraphData::apply_event`; the other variants are recorded in the
/// event log but do not yet mutate graph structure.
#[derive(Debug, Clone, PartialEq)]
pub enum EventType {
    /// A new node joins the graph.
    NodeAddition,
    /// A node is removed from the graph.
    NodeDeletion,
    /// A node's feature vector is replaced.
    NodeFeatureUpdate,
    /// A new edge is inserted.
    EdgeAddition,
    /// An edge is removed.
    EdgeDeletion,
    /// An edge's feature vector is replaced.
    EdgeFeatureUpdate,
    /// A full snapshot of the graph state.
    GraphSnapshot,
}
55
/// Temporal graph data structure for continuous-time dynamic graphs.
///
/// Maintains the latest materialized graph plus a time-ordered event log and
/// per-node / per-edge feature histories, so recent activity and past
/// feature values can be queried.
#[derive(Debug, Clone)]
pub struct TemporalGraphData {
    /// Static graph structure at current time
    pub current_graph: GraphData,
    /// Sequence of temporal events ordered by time, keyed by millisecond
    /// timestamp (`(time * 1000.0) as u64`)
    pub events: BTreeMap<u64, Vec<TemporalEvent>>, // Using u64 timestamp for ordering
    /// Time-indexed node features (node id -> millisecond timestamp -> features)
    pub node_features_history: HashMap<usize, BTreeMap<u64, Tensor>>,
    /// Time-indexed edge features ((source, target) -> millisecond timestamp -> features)
    pub edge_features_history: HashMap<(usize, usize), BTreeMap<u64, Tensor>>,
    /// Current timestamp (maximum event time observed so far)
    pub current_time: f64,
    /// Time window for temporal aggregation and event retention
    pub time_window: f64,
    /// Maximum number of distinct event timestamps to keep in memory
    pub max_events: usize,
}
74
75impl TemporalGraphData {
76    /// Create a new temporal graph
77    pub fn new(initial_graph: GraphData, time_window: f64, max_events: usize) -> Self {
78        Self {
79            current_graph: initial_graph,
80            events: BTreeMap::new(),
81            node_features_history: HashMap::new(),
82            edge_features_history: HashMap::new(),
83            current_time: 0.0,
84            time_window,
85            max_events,
86        }
87    }
88
89    /// Add a temporal event to the graph
90    pub fn add_event(&mut self, event: TemporalEvent) {
91        let timestamp = (event.time * 1000.0) as u64; // Convert to milliseconds for ordering
92        self.events
93            .entry(timestamp)
94            .or_insert_with(Vec::new)
95            .push(event.clone());
96
97        // Update current time
98        self.current_time = self.current_time.max(event.time);
99
100        // Apply event to current graph structure
101        self.apply_event(&event);
102
103        // Clean up old events outside time window
104        self.cleanup_old_events();
105    }
106
107    /// Apply an event to the current graph structure
108    fn apply_event(&mut self, event: &TemporalEvent) {
109        match event.event_type {
110            EventType::NodeFeatureUpdate => {
111                if let (Some(node), Some(ref features)) = (event.node, &event.features) {
112                    // Update node features in current graph
113                    self.update_node_features(node, features.clone());
114
115                    // Store in history
116                    let timestamp = (event.time * 1000.0) as u64;
117                    self.node_features_history
118                        .entry(node)
119                        .or_insert_with(BTreeMap::new)
120                        .insert(timestamp, features.clone());
121                }
122            }
123            EventType::EdgeFeatureUpdate => {
124                if let (Some(source), Some(target), Some(ref features)) =
125                    (event.source, event.target, &event.features)
126                {
127                    let timestamp = (event.time * 1000.0) as u64;
128                    self.edge_features_history
129                        .entry((source, target))
130                        .or_insert_with(BTreeMap::new)
131                        .insert(timestamp, features.clone());
132                }
133            }
134            _ => {
135                // For simplicity, other event types are not fully implemented
136                // In a complete implementation, these would modify the graph structure
137            }
138        }
139    }
140
141    /// Update node features in the current graph
142    fn update_node_features(&mut self, node_id: usize, features: Tensor) {
143        // Simplified implementation - would need proper tensor slicing in practice
144        let current_features = self
145            .current_graph
146            .x
147            .to_vec()
148            .expect("conversion should succeed");
149        let feature_dim = self.current_graph.x.shape().dims()[1];
150        let new_features = features.to_vec().expect("conversion should succeed");
151
152        let mut updated_features = current_features;
153        let start_idx = node_id * feature_dim;
154        let _end_idx = start_idx + feature_dim.min(new_features.len());
155
156        for (i, &value) in new_features.iter().take(feature_dim).enumerate() {
157            if start_idx + i < updated_features.len() {
158                updated_features[start_idx + i] = value;
159            }
160        }
161
162        self.current_graph.x = from_vec(
163            updated_features,
164            &[self.current_graph.num_nodes, feature_dim],
165            torsh_core::device::DeviceType::Cpu,
166        )
167        .expect("from_vec updated_features should succeed");
168    }
169
170    /// Clean up old events outside the time window
171    fn cleanup_old_events(&mut self) {
172        let cutoff_time = ((self.current_time - self.time_window) * 1000.0) as u64;
173
174        // Remove events older than time window
175        let old_keys: Vec<u64> = self
176            .events
177            .keys()
178            .filter(|&&timestamp| timestamp < cutoff_time)
179            .cloned()
180            .collect();
181
182        for key in old_keys {
183            self.events.remove(&key);
184        }
185
186        // Also limit total number of events
187        while self.events.len() > self.max_events {
188            if let Some(first_key) = self.events.keys().next().cloned() {
189                self.events.remove(&first_key);
190            } else {
191                break;
192            }
193        }
194    }
195
196    /// Get events within a specific time range
197    pub fn get_events_in_range(&self, start_time: f64, end_time: f64) -> Vec<&TemporalEvent> {
198        let start_timestamp = (start_time * 1000.0) as u64;
199        let end_timestamp = (end_time * 1000.0) as u64;
200
201        self.events
202            .range(start_timestamp..=end_timestamp)
203            .flat_map(|(_, events)| events.iter())
204            .collect()
205    }
206
207    /// Get node features at a specific time (with interpolation)
208    pub fn get_node_features_at_time(&self, node_id: usize, time: f64) -> Option<Tensor> {
209        let timestamp = (time * 1000.0) as u64;
210
211        if let Some(history) = self.node_features_history.get(&node_id) {
212            // Find the most recent features before or at the requested time
213            if let Some((_, features)) = history.range(..=timestamp).next_back() {
214                return Some(features.clone());
215            }
216        }
217
218        None
219    }
220
221    /// Create a snapshot of the graph at a specific time
222    pub fn snapshot_at_time(&self, _time: f64) -> GraphData {
223        // Simplified implementation - returns current graph
224        // In practice, this would reconstruct the graph state at the specified time
225        self.current_graph.clone()
226    }
227}
228
/// Temporal Graph Convolutional Network (TGCN) layer.
///
/// Combines a spatial linear transform of node features with a temporal
/// encoding derived from recent event activity.
#[derive(Debug)]
pub struct TGCNConv {
    // Input feature dimensionality per node.
    in_features: usize,
    // Output feature dimensionality per node.
    out_features: usize,
    // Dimensionality of the temporal (time-encoding) input.
    temporal_dim: usize,
    // Spatial transform weight: [in_features, out_features].
    spatial_weight: Parameter,
    // Temporal transform weight: [temporal_dim, out_features].
    // NOTE(review): not consumed by the current simplified forward pass.
    temporal_weight: Parameter,
    // Optional additive output bias: [out_features].
    bias: Option<Parameter>,
    // Capacity hint for temporal memory (stored; unused in visible code).
    memory_size: usize,
    // Size of the time encoding (mirrors `temporal_dim`).
    time_encoding_dim: usize,
}
241
242impl TGCNConv {
243    /// Create a new TGCN layer
244    pub fn new(
245        in_features: usize,
246        out_features: usize,
247        temporal_dim: usize,
248        memory_size: usize,
249        bias: bool,
250    ) -> Self {
251        let spatial_weight = Parameter::new(
252            randn(&[in_features, out_features]).expect("randn spatial_weight should succeed"),
253        );
254        let temporal_weight = Parameter::new(
255            randn(&[temporal_dim, out_features]).expect("randn temporal_weight should succeed"),
256        );
257        let bias = if bias {
258            Some(Parameter::new(
259                zeros(&[out_features]).expect("zeros bias should succeed"),
260            ))
261        } else {
262            None
263        };
264
265        Self {
266            in_features,
267            out_features,
268            temporal_dim,
269            spatial_weight,
270            temporal_weight,
271            bias,
272            memory_size,
273            time_encoding_dim: temporal_dim,
274        }
275    }
276
277    /// Forward pass through TGCN layer
278    pub fn forward(&self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
279        // Step 1: Spatial convolution on current graph
280        let spatial_features = temporal_graph
281            .current_graph
282            .x
283            .matmul(&self.spatial_weight.clone_data())
284            .expect("matmul spatial_features should succeed");
285
286        // Step 2: Temporal encoding based on recent events
287        let temporal_features = self.encode_temporal_context(temporal_graph);
288
289        // Step 3: Combine spatial and temporal features
290        let combined_features = spatial_features
291            .add(&temporal_features)
292            .expect("operation should succeed");
293
294        // Step 4: Add bias if present
295        let output_features = if let Some(ref bias) = self.bias {
296            combined_features
297                .add(&bias.clone_data())
298                .expect("operation should succeed")
299        } else {
300            combined_features
301        };
302
303        // Create output temporal graph
304        let mut output_graph = temporal_graph.clone();
305        output_graph.current_graph.x = output_features;
306        output_graph
307    }
308
309    /// Encode temporal context from recent events
310    fn encode_temporal_context(&self, temporal_graph: &TemporalGraphData) -> Tensor {
311        let num_nodes = temporal_graph.current_graph.num_nodes;
312        let current_time = temporal_graph.current_time;
313        let lookback_time = current_time - temporal_graph.time_window;
314
315        // Get recent events
316        let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);
317
318        // Initialize temporal encoding
319        let _temporal_encoding = zeros::<f32>(&[num_nodes, self.out_features])
320            .expect("zeros temporal_encoding should succeed");
321
322        // Simple temporal encoding based on event recency and frequency
323        let mut node_event_counts = vec![0.0; num_nodes];
324
325        for event in recent_events {
326            if let Some(node_id) = event.node {
327                if node_id < num_nodes {
328                    // Weight by recency (more recent events have higher weight)
329                    let recency_weight =
330                        1.0 - (current_time - event.time) / temporal_graph.time_window;
331                    node_event_counts[node_id] += recency_weight;
332                }
333            }
334        }
335
336        // Convert counts to temporal features
337        let temporal_data: Vec<f32> = node_event_counts
338            .iter()
339            .flat_map(|&count| {
340                // Simple encoding: repeat the count for each output feature
341                (0..self.out_features).map(move |_| count as f32)
342            })
343            .collect();
344
345        from_vec(
346            temporal_data,
347            &[num_nodes, self.out_features],
348            torsh_core::device::DeviceType::Cpu,
349        )
350        .expect("from_vec temporal_data should succeed")
351    }
352}
353
354impl GraphLayer for TGCNConv {
355    fn forward(&self, graph: &GraphData) -> GraphData {
356        // Convert to temporal graph for processing
357        let temporal_graph = TemporalGraphData::new(graph.clone(), 1.0, 1000);
358        let output_temporal = self.forward(&temporal_graph);
359        output_temporal.current_graph
360    }
361
362    fn parameters(&self) -> Vec<Tensor> {
363        let mut params = vec![
364            self.spatial_weight.clone_data(),
365            self.temporal_weight.clone_data(),
366        ];
367        if let Some(ref bias) = self.bias {
368            params.push(bias.clone_data());
369        }
370        params
371    }
372}
373
/// Temporal Graph Attention Network (TGAT) layer.
///
/// Multi-head attention over node features with an additional learned
/// transform of a per-node sinusoidal time encoding.
#[derive(Debug)]
pub struct TGATConv {
    // Input feature dimensionality per node.
    in_features: usize,
    // Output feature dimensionality per node (split across heads).
    out_features: usize,
    // Number of attention heads; `out_features` is split evenly across them.
    heads: usize,
    // Dimensionality of the sinusoidal time encoding.
    time_encoding_dim: usize,
    // Query projection: [in_features, out_features].
    query_weight: Parameter,
    // Key projection: [in_features, out_features].
    key_weight: Parameter,
    // Value projection: [in_features, out_features].
    value_weight: Parameter,
    // Time-encoding projection: [time_encoding_dim, out_features].
    time_weight: Parameter,
    // Final output projection: [out_features, out_features].
    output_weight: Parameter,
    // Optional additive output bias: [out_features].
    bias: Option<Parameter>,
    // Dropout probability (stored; not applied in the visible forward pass).
    dropout: f32,
}
389
390impl TGATConv {
391    /// Create a new TGAT layer
392    pub fn new(
393        in_features: usize,
394        out_features: usize,
395        heads: usize,
396        time_encoding_dim: usize,
397        dropout: f32,
398        bias: bool,
399    ) -> Self {
400        let query_weight = Parameter::new(
401            randn(&[in_features, out_features]).expect("randn query_weight should succeed"),
402        );
403        let key_weight = Parameter::new(
404            randn(&[in_features, out_features]).expect("randn key_weight should succeed"),
405        );
406        let value_weight = Parameter::new(
407            randn(&[in_features, out_features]).expect("randn value_weight should succeed"),
408        );
409        let time_weight = Parameter::new(
410            randn(&[time_encoding_dim, out_features]).expect("randn time_weight should succeed"),
411        );
412        let output_weight = Parameter::new(
413            randn(&[out_features, out_features]).expect("randn output_weight should succeed"),
414        );
415
416        let bias = if bias {
417            Some(Parameter::new(
418                zeros(&[out_features]).expect("zeros bias should succeed"),
419            ))
420        } else {
421            None
422        };
423
424        Self {
425            in_features,
426            out_features,
427            heads,
428            time_encoding_dim,
429            query_weight,
430            key_weight,
431            value_weight,
432            time_weight,
433            output_weight,
434            bias,
435            dropout,
436        }
437    }
438
439    /// Forward pass through TGAT layer
440    pub fn forward(&self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
441        let num_nodes = temporal_graph.current_graph.num_nodes;
442        let head_dim = self.out_features / self.heads;
443
444        // Compute Q, K, V transformations
445        let queries = temporal_graph
446            .current_graph
447            .x
448            .matmul(&self.query_weight.clone_data())
449            .expect("matmul queries should succeed");
450        let keys = temporal_graph
451            .current_graph
452            .x
453            .matmul(&self.key_weight.clone_data())
454            .expect("matmul keys should succeed");
455        let values = temporal_graph
456            .current_graph
457            .x
458            .matmul(&self.value_weight.clone_data())
459            .expect("matmul values should succeed");
460
461        // Compute time encoding for each node based on recent activity
462        let time_encoding = self.compute_time_encoding(temporal_graph);
463        let time_transformed = time_encoding
464            .matmul(&self.time_weight.clone_data())
465            .expect("matmul time_transformed should succeed");
466
467        // Reshape for multi-head attention
468        let q = queries
469            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
470            .expect("view queries should succeed");
471        let k = keys
472            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
473            .expect("view keys should succeed");
474        let v = values
475            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
476            .expect("view values should succeed");
477
478        // Perform temporal attention
479        let attended_features =
480            self.temporal_attention(&q, &k, &v, &time_transformed, temporal_graph);
481
482        // Reshape and apply output transformation
483        let concatenated = attended_features
484            .view(&[num_nodes as i32, self.out_features as i32])
485            .expect("view concatenated should succeed");
486        let mut output = concatenated
487            .matmul(&self.output_weight.clone_data())
488            .expect("matmul output should succeed");
489
490        // Add bias if present
491        if let Some(ref bias) = self.bias {
492            output = output
493                .add(&bias.clone_data())
494                .expect("operation should succeed");
495        }
496
497        // Create output temporal graph
498        let mut output_graph = temporal_graph.clone();
499        output_graph.current_graph.x = output;
500        output_graph
501    }
502
503    /// Compute time encoding for nodes based on recent events
504    fn compute_time_encoding(&self, temporal_graph: &TemporalGraphData) -> Tensor {
505        let num_nodes = temporal_graph.current_graph.num_nodes;
506        let current_time = temporal_graph.current_time;
507
508        // Simple time encoding: time since last event for each node
509        let mut time_features = vec![current_time as f32; num_nodes * self.time_encoding_dim];
510
511        // Update with actual last event times
512        for (node_id, history) in &temporal_graph.node_features_history {
513            if *node_id < num_nodes {
514                if let Some((timestamp, _)) = history.iter().next_back() {
515                    let last_event_time = (*timestamp as f64) / 1000.0;
516                    let time_diff = (current_time - last_event_time) as f32;
517
518                    // Encode time difference in multiple dimensions
519                    for dim in 0..self.time_encoding_dim {
520                        let freq = 2.0_f32.powf(dim as f32);
521                        let encoded = (time_diff * freq).sin();
522                        time_features[*node_id * self.time_encoding_dim + dim] = encoded;
523                    }
524                }
525            }
526        }
527
528        from_vec(
529            time_features,
530            &[num_nodes, self.time_encoding_dim],
531            torsh_core::device::DeviceType::Cpu,
532        )
533        .expect("from_vec time_features should succeed")
534    }
535
536    /// Temporal attention mechanism
537    fn temporal_attention(
538        &self,
539        q: &Tensor,
540        k: &Tensor,
541        v: &Tensor,
542        _time_encoding: &Tensor,
543        temporal_graph: &TemporalGraphData,
544    ) -> Tensor {
545        let num_nodes = temporal_graph.current_graph.num_nodes;
546        let head_dim = self.out_features / self.heads;
547
548        // Simplified temporal attention
549        let mut output =
550            zeros(&[num_nodes, self.heads, head_dim]).expect("zeros output should succeed");
551
552        // For each head, compute attention with temporal bias
553        for head in 0..self.heads {
554            // Extract head-specific features
555            let _q_head = q
556                .slice_tensor(1, head, head + 1)
557                .expect("slice_tensor q_head should succeed");
558            let _k_head = k
559                .slice_tensor(1, head, head + 1)
560                .expect("slice_tensor k_head should succeed");
561            let v_head = v
562                .slice_tensor(1, head, head + 1)
563                .expect("slice_tensor v_head should succeed");
564
565            // Simplified attention computation (using dot product)
566            for i in 0..num_nodes {
567                let mut attended_value =
568                    zeros(&[head_dim]).expect("zeros attended_value should succeed");
569                let mut attention_sum = 0.0;
570
571                for j in 0..num_nodes {
572                    // Basic attention score computation
573                    let score = 1.0 / (1.0 + (i as f32 - j as f32).abs()); // Distance-based attention
574
575                    // Get value for node j
576                    let v_j = v_head
577                        .slice_tensor(0, j, j + 1)
578                        .expect("slice_tensor v_j should succeed")
579                        .squeeze_tensor(0)
580                        .expect("squeeze_tensor should succeed")
581                        .squeeze_tensor(0)
582                        .expect("squeeze_tensor should succeed");
583
584                    let weighted_value = v_j.mul_scalar(score).expect("mul_scalar should succeed");
585                    attended_value = attended_value
586                        .add(&weighted_value)
587                        .expect("operation should succeed");
588                    attention_sum += score;
589                }
590
591                // Normalize
592                if attention_sum > 0.0 {
593                    attended_value = attended_value
594                        .div_scalar(attention_sum)
595                        .expect("div_scalar should succeed");
596                }
597
598                // Store in output (simplified assignment)
599                let attended_data = attended_value.to_vec().expect("conversion should succeed");
600                for (dim, &val) in attended_data.iter().enumerate() {
601                    if dim < head_dim {
602                        output
603                            .set_item(&[i, head, dim], val)
604                            .expect("set_item should succeed");
605                    }
606                }
607            }
608        }
609
610        output
611    }
612}
613
614impl GraphLayer for TGATConv {
615    fn forward(&self, graph: &GraphData) -> GraphData {
616        let temporal_graph = TemporalGraphData::new(graph.clone(), 1.0, 1000);
617        let output_temporal = self.forward(&temporal_graph);
618        output_temporal.current_graph
619    }
620
621    fn parameters(&self) -> Vec<Tensor> {
622        let mut params = vec![
623            self.query_weight.clone_data(),
624            self.key_weight.clone_data(),
625            self.value_weight.clone_data(),
626            self.time_weight.clone_data(),
627            self.output_weight.clone_data(),
628        ];
629        if let Some(ref bias) = self.bias {
630            params.push(bias.clone_data());
631        }
632        params
633    }
634}
635
/// Memory-augmented Temporal Graph Network (TGN) layer.
///
/// Maintains a per-node memory vector that is updated from event messages
/// and projected into node embeddings.
#[derive(Debug)]
pub struct TGNConv {
    // Event feature dimensionality.
    in_features: usize,
    // Output embedding dimensionality per node.
    out_features: usize,
    // Size of each node's memory vector.
    memory_dim: usize,
    // Dimensionality of the sinusoidal time encoding appended to messages.
    time_encoding_dim: usize,
    // Maps [event features || time encoding] -> message:
    // [in_features + time_encoding_dim, memory_dim].
    message_function: Parameter,
    // Maps [old memory || message] -> new memory: [memory_dim * 2, memory_dim].
    memory_updater: Parameter,
    // Projects memory -> embedding: [memory_dim, out_features].
    node_embedding: Parameter,
    // Optional additive output bias: [out_features].
    bias: Option<Parameter>,
    // Latest memory tensor per node (shape [1, memory_dim]).
    node_memories: HashMap<usize, Tensor>,
    // Time of each node's most recent memory update.
    last_update_times: HashMap<usize, f64>,
}
650
651impl TGNConv {
652    /// Create a new TGN layer
653    pub fn new(
654        in_features: usize,
655        out_features: usize,
656        memory_dim: usize,
657        time_encoding_dim: usize,
658        bias: bool,
659    ) -> Self {
660        let message_function = Parameter::new(
661            randn(&[in_features + time_encoding_dim, memory_dim])
662                .expect("randn message_function should succeed"),
663        );
664        let memory_updater = Parameter::new(
665            randn(&[memory_dim * 2, memory_dim]).expect("randn memory_updater should succeed"),
666        );
667        let node_embedding = Parameter::new(
668            randn(&[memory_dim, out_features]).expect("randn node_embedding should succeed"),
669        );
670
671        let bias = if bias {
672            Some(Parameter::new(
673                zeros(&[out_features]).expect("zeros bias should succeed"),
674            ))
675        } else {
676            None
677        };
678
679        Self {
680            in_features,
681            out_features,
682            memory_dim,
683            time_encoding_dim,
684            message_function,
685            memory_updater,
686            node_embedding,
687            bias,
688            node_memories: HashMap::new(),
689            last_update_times: HashMap::new(),
690        }
691    }
692
693    /// Forward pass through TGN layer
694    pub fn forward(&mut self, temporal_graph: &TemporalGraphData) -> TemporalGraphData {
695        // Update node memories based on recent events
696        self.update_memories(temporal_graph);
697
698        // Generate node embeddings from memories
699        let output_features = self.generate_embeddings(temporal_graph);
700
701        // Create output temporal graph
702        let mut output_graph = temporal_graph.clone();
703        output_graph.current_graph.x = output_features;
704        output_graph
705    }
706
707    /// Update node memories based on temporal events
708    fn update_memories(&mut self, temporal_graph: &TemporalGraphData) {
709        let current_time = temporal_graph.current_time;
710        let lookback_time = current_time - temporal_graph.time_window;
711
712        // Get recent events
713        let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);
714
715        for event in recent_events {
716            if let Some(node_id) = event.node {
717                // Generate message from event
718                let message = self.compute_message(event, current_time);
719
720                // Update node memory
721                self.update_node_memory(node_id, message, event.time);
722            }
723        }
724    }
725
726    /// Compute message from temporal event
727    fn compute_message(&self, event: &TemporalEvent, current_time: f64) -> Tensor {
728        // Time encoding
729        let time_diff = (current_time - event.time) as f32;
730        let mut time_encoding = Vec::new();
731
732        for i in 0..self.time_encoding_dim {
733            let freq = 2.0_f32.powf(i as f32);
734            time_encoding.push((time_diff * freq).sin());
735        }
736
737        // Combine event features with time encoding
738        let mut message_input = if let Some(ref features) = event.features {
739            features.to_vec().expect("conversion should succeed")
740        } else {
741            vec![1.0; self.in_features] // Default features
742        };
743
744        message_input.extend(time_encoding);
745
746        let input_tensor = from_vec(
747            message_input,
748            &[1, self.in_features + self.time_encoding_dim],
749            torsh_core::device::DeviceType::Cpu,
750        )
751        .expect("from_vec input_tensor should succeed");
752
753        // Apply message function
754        input_tensor
755            .matmul(&self.message_function.clone_data())
756            .expect("matmul message should succeed")
757    }
758
759    /// Update memory for a specific node
760    fn update_node_memory(&mut self, node_id: usize, message: Tensor, event_time: f64) {
761        // Get current memory or initialize
762        let current_memory = self
763            .node_memories
764            .get(&node_id)
765            .cloned()
766            .unwrap_or_else(|| zeros(&[1, self.memory_dim]).expect("zeros memory should succeed"));
767
768        // Concatenate current memory and message
769        let current_data = current_memory.to_vec().expect("conversion should succeed");
770        let message_data = message.to_vec().expect("conversion should succeed");
771        let mut combined_data = current_data;
772        combined_data.extend(message_data);
773
774        let combined_tensor = from_vec(
775            combined_data,
776            &[1, self.memory_dim * 2],
777            torsh_core::device::DeviceType::Cpu,
778        )
779        .expect("from_vec combined_tensor should succeed");
780
781        // Update memory using memory updater
782        let new_memory = combined_tensor
783            .matmul(&self.memory_updater.clone_data())
784            .expect("matmul new_memory should succeed");
785
786        self.node_memories.insert(node_id, new_memory);
787        self.last_update_times.insert(node_id, event_time);
788    }
789
790    /// Generate node embeddings from memories
791    fn generate_embeddings(&self, temporal_graph: &TemporalGraphData) -> Tensor {
792        let num_nodes = temporal_graph.current_graph.num_nodes;
793        let mut embeddings = Vec::new();
794
795        for node_id in 0..num_nodes {
796            let memory = self
797                .node_memories
798                .get(&node_id)
799                .cloned()
800                .unwrap_or_else(|| {
801                    zeros(&[1, self.memory_dim]).expect("zeros memory should succeed")
802                });
803
804            let embedding = memory
805                .matmul(&self.node_embedding.clone_data())
806                .expect("operation should succeed");
807            let embedding_data = embedding.to_vec().expect("conversion should succeed");
808            embeddings.extend(embedding_data);
809        }
810
811        let mut output = from_vec(
812            embeddings,
813            &[num_nodes, self.out_features],
814            torsh_core::device::DeviceType::Cpu,
815        )
816        .expect("from_vec embeddings should succeed");
817
818        // Add bias if present
819        if let Some(ref bias) = self.bias {
820            output = output
821                .add(&bias.clone_data())
822                .expect("operation should succeed");
823        }
824
825        output
826    }
827}
828
829/// Temporal graph pooling operations
830pub mod pooling {
831    use super::*;
832
    /// Strategies for collapsing a temporal graph into a single feature vector.
    #[derive(Debug, Clone, Copy)]
    pub enum TemporalPoolingMethod {
        /// Mean over the current (latest) node features only.
        MostRecent,
        /// Recency-weighted mean over recent event features.
        TimeWeightedMean,
        /// Pooling with exponential time-decay weighting.
        ExponentialDecay,
        /// Attention-style weighted pooling.
        AttentionBased,
    }
841
842    /// Global temporal pooling
843    pub fn temporal_pool(
844        temporal_graph: &TemporalGraphData,
845        method: TemporalPoolingMethod,
846    ) -> Tensor {
847        match method {
848            TemporalPoolingMethod::MostRecent => {
849                // Use current graph features
850                temporal_graph
851                    .current_graph
852                    .x
853                    .mean(Some(&[0]), false)
854                    .expect("mean pooling should succeed")
855            }
856            TemporalPoolingMethod::TimeWeightedMean => time_weighted_pool(temporal_graph),
857            TemporalPoolingMethod::ExponentialDecay => exponential_decay_pool(temporal_graph),
858            TemporalPoolingMethod::AttentionBased => attention_temporal_pool(temporal_graph),
859        }
860    }
861
862    /// Time-weighted pooling based on event recency
863    fn time_weighted_pool(temporal_graph: &TemporalGraphData) -> Tensor {
864        let current_time = temporal_graph.current_time;
865        let lookback_time = current_time - temporal_graph.time_window;
866        let recent_events = temporal_graph.get_events_in_range(lookback_time, current_time);
867
868        if recent_events.is_empty() {
869            return temporal_graph
870                .current_graph
871                .x
872                .mean(Some(&[0]), false)
873                .expect("mean pooling should succeed");
874        }
875
876        // Weight events by recency
877        let mut weighted_sum = zeros(&[temporal_graph.current_graph.x.shape().dims()[1]])
878            .expect("zeros weighted_sum should succeed");
879        let mut total_weight = 0.0;
880
881        for event in recent_events {
882            if let Some(ref features) = event.features {
883                let weight = 1.0 - (current_time - event.time) / temporal_graph.time_window;
884                let weighted_features = features
885                    .mul_scalar(weight as f32)
886                    .expect("mul_scalar should succeed");
887
888                // Sum the features (simplified)
889                let features_data = weighted_features
890                    .to_vec()
891                    .expect("conversion should succeed");
892                let current_data = weighted_sum.to_vec().expect("conversion should succeed");
893                let mut new_data = Vec::new();
894
895                for (_i, (&current, &new)) in
896                    current_data.iter().zip(features_data.iter()).enumerate()
897                {
898                    new_data.push(current + new);
899                }
900
901                weighted_sum = from_vec(
902                    new_data,
903                    &[weighted_sum.shape().dims()[0]],
904                    torsh_core::device::DeviceType::Cpu,
905                )
906                .expect("from_vec weighted_sum should succeed");
907
908                total_weight += weight;
909            }
910        }
911
912        if total_weight > 0.0 {
913            weighted_sum
914                .div_scalar(total_weight as f32)
915                .expect("div_scalar should succeed")
916        } else {
917            temporal_graph
918                .current_graph
919                .x
920                .mean(Some(&[0]), false)
921                .expect("mean pooling should succeed")
922        }
923    }
924
925    /// Exponential decay pooling
926    fn exponential_decay_pool(temporal_graph: &TemporalGraphData) -> Tensor {
927        let decay_rate = 0.1; // Decay parameter
928        let current_time = temporal_graph.current_time;
929
930        // Simple exponential decay - use current features
931        let decay_factor = (-decay_rate * current_time).exp() as f32;
932        temporal_graph
933            .current_graph
934            .x
935            .mul_scalar(decay_factor)
936            .expect("mul_scalar should succeed")
937            .mean(Some(&[0]), false)
938            .expect("mean pooling should succeed")
939    }
940
941    /// Attention-based temporal pooling
942    fn attention_temporal_pool(temporal_graph: &TemporalGraphData) -> Tensor {
943        // Simplified attention pooling
944        let features = &temporal_graph.current_graph.x;
945        let attention_scores = features
946            .sum_dim(&[1], false)
947            .expect("sum_dim should succeed");
948        let attention_weights = attention_scores.softmax(0).expect("softmax should succeed");
949        let attention_expanded = attention_weights
950            .unsqueeze(-1)
951            .expect("unsqueeze should succeed");
952
953        let weighted_features = features
954            .mul(&attention_expanded)
955            .expect("operation should succeed");
956        weighted_features
957            .sum_dim(&[0], false)
958            .expect("sum_dim should succeed")
959    }
960}
961
962/// Temporal graph utilities
963pub mod utils {
964    use super::*;
965
966    /// Generate random temporal events
967    pub fn generate_random_events(
968        num_events: usize,
969        num_nodes: usize,
970        time_span: f64,
971        feature_dim: usize,
972    ) -> Vec<TemporalEvent> {
973        let mut rng = scirs2_core::random::thread_rng();
974        let mut events = Vec::new();
975
976        for _ in 0..num_events {
977            let time = rng.gen_range(0.0..time_span);
978            let event_type = if rng.gen_range(0.0..1.0) < 0.7 {
979                EventType::NodeFeatureUpdate
980            } else {
981                EventType::EdgeAddition
982            };
983
984            let node = if matches!(event_type, EventType::NodeFeatureUpdate) {
985                Some(rng.gen_range(0..num_nodes))
986            } else {
987                None
988            };
989
990            let (source, target) = if matches!(event_type, EventType::EdgeAddition) {
991                let s = rng.gen_range(0..num_nodes);
992                let t = rng.gen_range(0..num_nodes);
993                (Some(s), Some(t))
994            } else {
995                (None, None)
996            };
997
998            let features = if matches!(event_type, EventType::NodeFeatureUpdate) {
999                Some(randn(&[feature_dim]).expect("randn features should succeed"))
1000            } else {
1001                None
1002            };
1003
1004            events.push(TemporalEvent {
1005                time,
1006                event_type,
1007                source,
1008                target,
1009                node,
1010                features,
1011                weight: Some(rng.gen_range(0.1..1.0)),
1012            });
1013        }
1014
1015        // Sort events by time
1016        events.sort_by(|a, b| {
1017            a.time
1018                .partial_cmp(&b.time)
1019                .expect("time comparison should succeed")
1020        });
1021        events
1022    }
1023
1024    /// Create temporal graph from event sequence
1025    pub fn create_temporal_graph_from_events(
1026        initial_graph: GraphData,
1027        events: Vec<TemporalEvent>,
1028        time_window: f64,
1029    ) -> TemporalGraphData {
1030        let mut temporal_graph = TemporalGraphData::new(initial_graph, time_window, 10000);
1031
1032        for event in events {
1033            temporal_graph.add_event(event);
1034        }
1035
1036        temporal_graph
1037    }
1038
1039    /// Compute temporal graph metrics
1040    pub fn temporal_metrics(temporal_graph: &TemporalGraphData) -> TemporalMetrics {
1041        let total_events = temporal_graph.events.values().map(|v| v.len()).sum();
1042        let unique_nodes_with_events = temporal_graph.node_features_history.len();
1043        let time_span = if let (Some(first), Some(last)) = (
1044            temporal_graph.events.keys().next(),
1045            temporal_graph.events.keys().next_back(),
1046        ) {
1047            (*last as f64 - *first as f64) / 1000.0
1048        } else {
1049            0.0
1050        };
1051
1052        let event_rate = if time_span > 0.0 {
1053            total_events as f64 / time_span
1054        } else {
1055            0.0
1056        };
1057
1058        TemporalMetrics {
1059            total_events,
1060            unique_active_nodes: unique_nodes_with_events,
1061            time_span,
1062            event_rate,
1063            current_time: temporal_graph.current_time,
1064        }
1065    }
1066
1067    /// Temporal graph metrics
1068    #[derive(Debug, Clone)]
1069    pub struct TemporalMetrics {
1070        pub total_events: usize,
1071        pub unique_active_nodes: usize,
1072        pub time_span: f64,
1073        pub event_rate: f64,
1074        pub current_time: f64,
1075    }
1076}
1077
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_temporal_graph_creation() {
        // Four-node ring with 3-dim random features.
        let x = randn(&[4, 3]).unwrap();
        let edge_index = from_vec(
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0],
            &[2, 4],
            DeviceType::Cpu,
        )
        .unwrap();
        let tg = TemporalGraphData::new(GraphData::new(x, edge_index), 10.0, 1000);

        assert_eq!(tg.current_graph.num_nodes, 4);
        assert_eq!(tg.time_window, 10.0);
        assert_eq!(tg.max_events, 1000);
    }

    #[test]
    fn test_temporal_event_addition() {
        let x = randn(&[3, 2]).unwrap();
        let edge_index = from_vec(vec![0.0, 1.0, 1.0, 2.0], &[2, 2], DeviceType::Cpu).unwrap();
        let mut tg = TemporalGraphData::new(GraphData::new(x, edge_index), 5.0, 100);

        // A single feature update on node 0 at t = 1.0.
        tg.add_event(TemporalEvent {
            time: 1.0,
            event_type: EventType::NodeFeatureUpdate,
            source: None,
            target: None,
            node: Some(0),
            features: Some(randn(&[2]).unwrap()),
            weight: None,
        });

        // The clock advances to the event time and the event is recorded.
        assert_eq!(tg.current_time, 1.0);
        assert!(!tg.events.is_empty());
    }

    #[test]
    fn test_tgcn_layer() {
        let x = randn(&[3, 4]).unwrap();
        let edge_index = from_vec(vec![0.0, 1.0, 1.0, 2.0], &[2, 2], DeviceType::Cpu).unwrap();
        let tg = TemporalGraphData::new(GraphData::new(x, edge_index), 1.0, 100);

        // TGCN with 4 input features mapped to 8 outputs.
        let layer = TGCNConv::new(4, 8, 16, 64, true);
        let result = layer.forward(&tg);
        assert_eq!(result.current_graph.x.shape().dims(), &[3, 8]);
    }

    #[test]
    fn test_tgat_layer() {
        let x = randn(&[4, 6]).unwrap();
        let edge_index =
            from_vec(vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0], &[2, 3], DeviceType::Cpu).unwrap();
        let tg = TemporalGraphData::new(GraphData::new(x, edge_index), 2.0, 200);

        // TGAT with 6 input features mapped to 12 outputs.
        let layer = TGATConv::new(6, 12, 3, 8, 0.1, true);
        let result = layer.forward(&tg);
        assert_eq!(result.current_graph.x.shape().dims(), &[4, 12]);
    }

    #[test]
    fn test_temporal_pooling() {
        let x = randn(&[5, 4]).unwrap();
        let edge_index = from_vec(
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0],
            &[2, 4],
            DeviceType::Cpu,
        )
        .unwrap();
        let tg = TemporalGraphData::new(GraphData::new(x, edge_index), 3.0, 150);

        // Both strategies reduce [5, 4] node features to a length-4 vector.
        let most_recent =
            pooling::temporal_pool(&tg, pooling::TemporalPoolingMethod::MostRecent);
        assert_eq!(most_recent.shape().dims(), &[4]);

        let time_weighted =
            pooling::temporal_pool(&tg, pooling::TemporalPoolingMethod::TimeWeightedMean);
        assert_eq!(time_weighted.shape().dims(), &[4]);
    }

    #[test]
    fn test_temporal_utils() {
        let events = utils::generate_random_events(10, 5, 10.0, 3);
        assert_eq!(events.len(), 10);

        // Events must come back in non-decreasing time order.
        assert!(events.windows(2).all(|pair| pair[1].time >= pair[0].time));

        let x = randn(&[5, 3]).unwrap();
        let edge_index = from_vec(
            vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 0.0],
            &[2, 5],
            DeviceType::Cpu,
        )
        .unwrap();
        let tg =
            utils::create_temporal_graph_from_events(GraphData::new(x, edge_index), events, 5.0);

        let metrics = utils::temporal_metrics(&tg);
        assert!(metrics.total_events > 0);
        assert!(metrics.time_span >= 0.0);
    }

    #[test]
    fn test_event_time_range_query() {
        let x = randn(&[3, 2]).unwrap();
        let edge_index = from_vec(vec![0.0, 1.0, 1.0, 2.0], &[2, 2], DeviceType::Cpu).unwrap();
        let mut tg = TemporalGraphData::new(GraphData::new(x, edge_index), 10.0, 100);

        // Feature updates at t = 0, 1, 2, 3, 4, cycling over nodes 0..3.
        for i in 0..5 {
            tg.add_event(TemporalEvent {
                time: i as f64,
                event_type: EventType::NodeFeatureUpdate,
                source: None,
                target: None,
                node: Some(i % 3),
                features: Some(randn(&[2]).unwrap()),
                weight: None,
            });
        }

        // The inclusive range query returns the events at t = 1, 2, 3.
        let hits = tg.get_events_in_range(1.0, 3.0);
        assert_eq!(hits.len(), 3);
    }
}