Skip to main content

scirs2_graph/gnn/transformers/
temporal_gnn.rs

1//! Temporal Graph Neural Networks (TGN)
2//!
3//! Implements temporal graph neural network components inspired by
4//! Rossi et al. (2020), "Temporal Graph Networks for Deep Learning on
5//! Dynamic Graphs".
6//!
7//! Key components:
8//! - **Temporal events**: continuous-time interactions between nodes
9//! - **Time encoding**: learnable temporal representations (Time2Vec / sinusoidal)
10//! - **Memory module**: GRU-based node memory updated on new events
11//! - **Temporal attention**: attend to recent neighbors with time decay
12//! - **Message function**: compute messages from temporal events
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::{Rng, RngExt};
16
17use crate::error::{GraphError, Result};
18
19// ============================================================================
20// Temporal Event
21// ============================================================================
22
/// A temporal interaction event in a continuous-time dynamic graph.
///
/// Events are directed `source -> target` interactions stamped with a
/// continuous-valued time. Batch consumers such as
/// `MemoryModule::process_events` expect events sorted by ascending
/// `timestamp`.
#[derive(Debug, Clone)]
pub struct TemporalEvent {
    /// Source node index
    pub source: usize,
    /// Target node index
    pub target: usize,
    /// Timestamp of the event (arbitrary continuous time units)
    pub timestamp: f64,
    /// Optional edge features associated with the event.
    /// Note: `MemoryModule::compute_message` currently builds messages from
    /// node memories + time encoding only and does not consume this field.
    pub features: Option<Vec<f64>>,
}
35
36impl TemporalEvent {
37    /// Create a new temporal event.
38    pub fn new(source: usize, target: usize, timestamp: f64) -> Self {
39        TemporalEvent {
40            source,
41            target,
42            timestamp,
43            features: None,
44        }
45    }
46
47    /// Create a temporal event with features.
48    pub fn with_features(source: usize, target: usize, timestamp: f64, features: Vec<f64>) -> Self {
49        TemporalEvent {
50            source,
51            target,
52            timestamp,
53            features: Some(features),
54        }
55    }
56}
57
58// ============================================================================
59// Time Encoding
60// ============================================================================
61
/// Type of time encoding to use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeEncodingType {
    /// Sinusoidal encoding similar to positional encoding in Transformers
    /// (fixed geometric frequencies; even slots sin, odd slots cos).
    Sinusoidal,
    /// Learnable Time2Vec encoding (slot 0 is a linear component, the
    /// remaining slots are learnable sinusoids).
    Time2Vec,
}
70
/// Temporal encoding module that maps scalar timestamps to vector representations.
///
/// Supports two encoding types:
/// - **Sinusoidal**: fixed frequencies at different scales
/// - **Time2Vec**: learnable linear + periodic components
///
/// All parameter fields are always allocated, but the sinusoidal variant
/// only reads `omega`; the Time2Vec variant ignores `omega[0]`/`phi[0]`
/// because slot 0 is the linear component.
#[derive(Debug, Clone)]
pub struct TimeEncoding {
    /// Encoding type
    pub encoding_type: TimeEncodingType,
    /// Output dimension for time encoding
    pub time_dim: usize,
    /// Frequency parameters: `[time_dim]` (fixed geometric schedule for
    /// sinusoidal; learnable, initialized in `[0, 2)` for Time2Vec)
    pub omega: Array1<f64>,
    /// Learnable phase parameters for Time2Vec: `[time_dim]`
    pub phi: Array1<f64>,
    /// Linear component weight for Time2Vec
    pub linear_weight: f64,
    /// Linear component bias for Time2Vec
    pub linear_bias: f64,
}
91
92impl TimeEncoding {
93    /// Create a new time encoding module.
94    ///
95    /// # Arguments
96    /// * `time_dim` - Output dimension for temporal encoding
97    /// * `encoding_type` - Type of encoding (sinusoidal or Time2Vec)
98    pub fn new(time_dim: usize, encoding_type: TimeEncodingType) -> Self {
99        let mut rng = scirs2_core::random::rng();
100
101        let omega = match &encoding_type {
102            TimeEncodingType::Sinusoidal => {
103                // Fixed geometric frequencies
104                Array1::from_iter(
105                    (0..time_dim)
106                        .map(|i| 1.0 / 10000.0_f64.powf(2.0 * (i / 2) as f64 / time_dim as f64)),
107                )
108            }
109            TimeEncodingType::Time2Vec => {
110                // Learnable frequencies initialized around 1.0
111                Array1::from_iter((0..time_dim).map(|_| rng.random::<f64>() * 2.0))
112            }
113        };
114
115        let phi =
116            Array1::from_iter((0..time_dim).map(|_| rng.random::<f64>() * std::f64::consts::TAU));
117
118        TimeEncoding {
119            encoding_type,
120            time_dim,
121            omega,
122            phi,
123            linear_weight: rng.random::<f64>() * 0.1,
124            linear_bias: 0.0,
125        }
126    }
127
128    /// Encode a single timestamp.
129    ///
130    /// # Arguments
131    /// * `t` - Scalar timestamp
132    ///
133    /// # Returns
134    /// Time encoding vector of length `time_dim`
135    pub fn encode(&self, t: f64) -> Array1<f64> {
136        let mut encoding = Array1::zeros(self.time_dim);
137
138        match self.encoding_type {
139            TimeEncodingType::Sinusoidal => {
140                for i in 0..self.time_dim {
141                    let angle = t * self.omega[i];
142                    if i % 2 == 0 {
143                        encoding[i] = angle.sin();
144                    } else {
145                        encoding[i] = angle.cos();
146                    }
147                }
148            }
149            TimeEncodingType::Time2Vec => {
150                // First component is linear
151                if self.time_dim > 0 {
152                    encoding[0] = self.linear_weight * t + self.linear_bias;
153                }
154                // Remaining components are periodic
155                for i in 1..self.time_dim {
156                    encoding[i] = (self.omega[i] * t + self.phi[i]).sin();
157                }
158            }
159        }
160
161        encoding
162    }
163
164    /// Encode multiple timestamps.
165    ///
166    /// # Arguments
167    /// * `timestamps` - Slice of timestamps
168    ///
169    /// # Returns
170    /// Matrix `[len, time_dim]` of time encodings
171    pub fn encode_batch(&self, timestamps: &[f64]) -> Array2<f64> {
172        let n = timestamps.len();
173        let mut result = Array2::zeros((n, self.time_dim));
174        for (i, &t) in timestamps.iter().enumerate() {
175            let enc = self.encode(t);
176            for j in 0..self.time_dim {
177                result[[i, j]] = enc[j];
178            }
179        }
180        result
181    }
182}
183
184// ============================================================================
185// Memory Module
186// ============================================================================
187
/// Method for updating node memory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryUpdateMethod {
    /// GRU-based memory update (gated blend of old memory and candidate)
    Gru,
    /// Simple concatenation + linear projection with tanh
    Mlp,
}
196
/// TGN-style memory module that maintains compressed node interaction history.
///
/// Each node has a memory vector that is updated when new events arrive.
/// The memory captures temporal patterns and interaction history.
#[derive(Debug, Clone)]
pub struct MemoryModule {
    /// Node memory states: `[n_nodes, memory_dim]`
    pub memory: Array2<f64>,
    /// Last update timestamp per node
    pub last_update: Vec<f64>,
    /// Memory dimension
    pub memory_dim: usize,
    /// Time encoding dimension
    pub time_dim: usize,
    /// Number of nodes
    pub n_nodes: usize,
    /// Memory update method
    pub update_method: MemoryUpdateMethod,
    /// Time encoding module
    pub time_encoding: TimeEncoding,
    /// Message dimension: `memory_dim + memory_dim + time_dim`.
    /// Event features are currently NOT part of the message and are not
    /// accounted for here.
    pub message_dim: usize,

    // GRU parameters (when update_method == Gru)
    /// GRU update gate: W_z `[message_dim, memory_dim]`
    gru_wz: Array2<f64>,
    /// GRU update gate: U_z `[memory_dim, memory_dim]`
    gru_uz: Array2<f64>,
    /// GRU reset gate: W_r `[message_dim, memory_dim]`
    gru_wr: Array2<f64>,
    /// GRU reset gate: U_r `[memory_dim, memory_dim]`
    gru_ur: Array2<f64>,
    /// GRU candidate: W_h `[message_dim, memory_dim]`
    gru_wh: Array2<f64>,
    /// GRU candidate: U_h `[memory_dim, memory_dim]`
    gru_uh: Array2<f64>,
    /// GRU update-gate bias
    gru_bz: Array1<f64>,
    /// GRU reset bias
    gru_br: Array1<f64>,
    /// GRU candidate bias
    gru_bh: Array1<f64>,

    // MLP parameters (when update_method == Mlp)
    /// MLP projection: `[message_dim + memory_dim, memory_dim]`
    mlp_w: Array2<f64>,
    /// MLP bias
    mlp_b: Array1<f64>,
}
246
247impl MemoryModule {
248    /// Create a new memory module.
249    ///
250    /// # Arguments
251    /// * `n_nodes` - Number of nodes in the graph
252    /// * `memory_dim` - Dimension of node memory vectors
253    /// * `time_dim` - Dimension for time encoding
254    /// * `update_method` - Memory update method (GRU or MLP)
255    pub fn new(
256        n_nodes: usize,
257        memory_dim: usize,
258        time_dim: usize,
259        update_method: MemoryUpdateMethod,
260    ) -> Self {
261        let mut rng = scirs2_core::random::rng();
262        let time_encoding = TimeEncoding::new(time_dim, TimeEncodingType::Time2Vec);
263
264        // message = concat(source_memory, target_memory, time_encoding)
265        let message_dim = memory_dim + memory_dim + time_dim;
266
267        let scale_gru = (6.0_f64 / (message_dim + memory_dim) as f64).sqrt();
268        let scale_u = (6.0_f64 / (2 * memory_dim) as f64).sqrt();
269        let scale_mlp = (6.0_f64 / (message_dim + 2 * memory_dim) as f64).sqrt();
270
271        let mut init = |r: usize, c: usize, s: f64| -> Array2<f64> {
272            Array2::from_shape_fn((r, c), |_| (rng.random::<f64>() * 2.0 - 1.0) * s)
273        };
274
275        MemoryModule {
276            memory: Array2::zeros((n_nodes, memory_dim)),
277            last_update: vec![0.0; n_nodes],
278            memory_dim,
279            time_dim,
280            n_nodes,
281            update_method,
282            time_encoding,
283            message_dim,
284            gru_wz: init(message_dim, memory_dim, scale_gru),
285            gru_uz: init(memory_dim, memory_dim, scale_u),
286            gru_wr: init(message_dim, memory_dim, scale_gru),
287            gru_ur: init(memory_dim, memory_dim, scale_u),
288            gru_wh: init(message_dim, memory_dim, scale_gru),
289            gru_uh: init(memory_dim, memory_dim, scale_u),
290            gru_bz: Array1::zeros(memory_dim),
291            gru_br: Array1::zeros(memory_dim),
292            gru_bh: Array1::zeros(memory_dim),
293            mlp_w: init(message_dim + memory_dim, memory_dim, scale_mlp),
294            mlp_b: Array1::zeros(memory_dim),
295        }
296    }
297
298    /// Compute message from a temporal event.
299    ///
300    /// Message = concat(source_memory, target_memory, time_encoding(delta_t))
301    fn compute_message(&self, event: &TemporalEvent) -> Vec<f64> {
302        let src = event.source;
303        let tgt = event.target;
304        let delta_t = event.timestamp - self.last_update[src].max(self.last_update[tgt]);
305        let time_enc = self.time_encoding.encode(delta_t);
306
307        let mut msg = Vec::with_capacity(self.message_dim);
308
309        // Source memory
310        for j in 0..self.memory_dim {
311            msg.push(if src < self.n_nodes {
312                self.memory[[src, j]]
313            } else {
314                0.0
315            });
316        }
317
318        // Target memory
319        for j in 0..self.memory_dim {
320            msg.push(if tgt < self.n_nodes {
321                self.memory[[tgt, j]]
322            } else {
323                0.0
324            });
325        }
326
327        // Time encoding
328        for j in 0..self.time_dim {
329            msg.push(time_enc[j]);
330        }
331
332        msg
333    }
334
335    /// GRU update step.
336    ///
337    /// ```text
338    /// z = sigmoid(W_z * msg + U_z * h + b_z)
339    /// r = sigmoid(W_r * msg + U_r * h + b_r)
340    /// h_tilde = tanh(W_h * msg + U_h * (r * h) + b_h)
341    /// h_new = (1 - z) * h + z * h_tilde
342    /// ```
343    fn gru_update(&self, memory: &[f64], message: &[f64]) -> Vec<f64> {
344        let d = self.memory_dim;
345        let m = self.message_dim;
346
347        // Compute z (update gate)
348        let mut z = vec![0.0f64; d];
349        for j in 0..d {
350            let mut s = self.gru_bz[j];
351            for k in 0..m {
352                s += message[k] * self.gru_wz[[k, j]];
353            }
354            for k in 0..d {
355                s += memory[k] * self.gru_uz[[k, j]];
356            }
357            z[j] = sigmoid(s);
358        }
359
360        // Compute r (reset gate)
361        let mut r = vec![0.0f64; d];
362        for j in 0..d {
363            let mut s = self.gru_br[j];
364            for k in 0..m {
365                s += message[k] * self.gru_wr[[k, j]];
366            }
367            for k in 0..d {
368                s += memory[k] * self.gru_ur[[k, j]];
369            }
370            r[j] = sigmoid(s);
371        }
372
373        // Compute h_tilde (candidate)
374        let mut h_tilde = vec![0.0f64; d];
375        for j in 0..d {
376            let mut s = self.gru_bh[j];
377            for k in 0..m {
378                s += message[k] * self.gru_wh[[k, j]];
379            }
380            for k in 0..d {
381                s += (r[k] * memory[k]) * self.gru_uh[[k, j]];
382            }
383            h_tilde[j] = s.tanh();
384        }
385
386        // h_new = (1 - z) * h + z * h_tilde
387        let mut h_new = vec![0.0f64; d];
388        for j in 0..d {
389            h_new[j] = (1.0 - z[j]) * memory[j] + z[j] * h_tilde[j];
390        }
391
392        h_new
393    }
394
395    /// MLP-based memory update.
396    fn mlp_update(&self, memory: &[f64], message: &[f64]) -> Vec<f64> {
397        let d = self.memory_dim;
398        let total_in = self.message_dim + d;
399
400        // Concatenate message + memory
401        let mut input = Vec::with_capacity(total_in);
402        input.extend_from_slice(message);
403        input.extend_from_slice(memory);
404
405        // Linear + tanh
406        let mut out = vec![0.0f64; d];
407        for j in 0..d {
408            let mut s = self.mlp_b[j];
409            for k in 0..total_in {
410                s += input[k] * self.mlp_w[[k, j]];
411            }
412            out[j] = s.tanh();
413        }
414
415        out
416    }
417
418    /// Process a single temporal event and update node memories.
419    ///
420    /// Updates the memory of both source and target nodes involved in the event.
421    ///
422    /// # Arguments
423    /// * `event` - The temporal interaction event
424    pub fn process_event(&mut self, event: &TemporalEvent) -> Result<()> {
425        if event.source >= self.n_nodes || event.target >= self.n_nodes {
426            return Err(GraphError::InvalidParameter {
427                param: "event".to_string(),
428                value: format!("source={}, target={}", event.source, event.target),
429                expected: format!("indices < {}", self.n_nodes),
430                context: "MemoryModule::process_event".to_string(),
431            });
432        }
433
434        let message = self.compute_message(event);
435
436        // Update source memory
437        let src_memory: Vec<f64> = (0..self.memory_dim)
438            .map(|j| self.memory[[event.source, j]])
439            .collect();
440        let new_src = match self.update_method {
441            MemoryUpdateMethod::Gru => self.gru_update(&src_memory, &message),
442            MemoryUpdateMethod::Mlp => self.mlp_update(&src_memory, &message),
443        };
444
445        // Update target memory
446        let tgt_memory: Vec<f64> = (0..self.memory_dim)
447            .map(|j| self.memory[[event.target, j]])
448            .collect();
449        let new_tgt = match self.update_method {
450            MemoryUpdateMethod::Gru => self.gru_update(&tgt_memory, &message),
451            MemoryUpdateMethod::Mlp => self.mlp_update(&tgt_memory, &message),
452        };
453
454        // Write back
455        for j in 0..self.memory_dim {
456            self.memory[[event.source, j]] = new_src[j];
457            self.memory[[event.target, j]] = new_tgt[j];
458        }
459
460        self.last_update[event.source] = event.timestamp;
461        self.last_update[event.target] = event.timestamp;
462
463        Ok(())
464    }
465
466    /// Process a batch of temporal events in chronological order.
467    ///
468    /// Events should be sorted by timestamp (ascending).
469    pub fn process_events(&mut self, events: &[TemporalEvent]) -> Result<()> {
470        for event in events {
471            self.process_event(event)?;
472        }
473        Ok(())
474    }
475
476    /// Get the current memory state for all nodes.
477    pub fn get_memory(&self) -> &Array2<f64> {
478        &self.memory
479    }
480
481    /// Reset all node memories to zero.
482    pub fn reset(&mut self) {
483        self.memory.fill(0.0);
484        self.last_update.fill(0.0);
485    }
486}
487
/// Logistic sigmoid: `1 / (1 + e^{-x})`.
///
/// Saturates to 0.0 for large negative `x` (where `exp(-x)` overflows to
/// infinity) and to 1.0 for large positive `x`.
#[inline]
fn sigmoid(x: f64) -> f64 {
    let e = (-x).exp();
    1.0 / (1.0 + e)
}
493
494// ============================================================================
495// Temporal Attention
496// ============================================================================
497
/// Temporal attention mechanism that attends to recent neighbors
/// with time-decay weighting.
///
/// For each node, computes attention over its recent interaction partners,
/// incorporating temporal information through time encodings.
///
/// `hidden_dim` is set equal to `memory_dim` at construction, and
/// `head_dim = hidden_dim / num_heads` (see `TemporalAttention::new`).
#[derive(Debug, Clone)]
pub struct TemporalAttention {
    /// Query projection: `[memory_dim, hidden_dim]`
    pub w_q: Array2<f64>,
    /// Key projection: `[memory_dim + time_dim, hidden_dim]`
    pub w_k: Array2<f64>,
    /// Value projection: `[memory_dim + time_dim, hidden_dim]`
    pub w_v: Array2<f64>,
    /// Number of attention heads
    pub num_heads: usize,
    /// Hidden dimension (equals `memory_dim`)
    pub hidden_dim: usize,
    /// Dimension per head (`hidden_dim / num_heads`)
    pub head_dim: usize,
    /// Time encoding module (sinusoidal, as set by the constructor)
    pub time_encoding: TimeEncoding,
    /// Memory dimension
    pub memory_dim: usize,
    /// Time encoding dimension
    pub time_dim: usize,
}
524
525impl TemporalAttention {
526    /// Create a new temporal attention module.
527    ///
528    /// # Arguments
529    /// * `memory_dim` - Dimension of node memory
530    /// * `time_dim` - Dimension of time encoding
531    /// * `num_heads` - Number of attention heads
532    pub fn new(memory_dim: usize, time_dim: usize, num_heads: usize) -> Result<Self> {
533        let hidden_dim = memory_dim;
534        if !hidden_dim.is_multiple_of(num_heads) {
535            return Err(GraphError::InvalidParameter {
536                param: "memory_dim".to_string(),
537                value: format!("{memory_dim}"),
538                expected: format!("divisible by num_heads={num_heads}"),
539                context: "TemporalAttention::new".to_string(),
540            });
541        }
542
543        let head_dim = hidden_dim / num_heads;
544        let mut rng = scirs2_core::random::rng();
545        let scale_q = (6.0_f64 / (memory_dim + hidden_dim) as f64).sqrt();
546        let scale_kv = (6.0_f64 / (memory_dim + time_dim + hidden_dim) as f64).sqrt();
547
548        let w_q = Array2::from_shape_fn((memory_dim, hidden_dim), |_| {
549            (rng.random::<f64>() * 2.0 - 1.0) * scale_q
550        });
551        let w_k = Array2::from_shape_fn((memory_dim + time_dim, hidden_dim), |_| {
552            (rng.random::<f64>() * 2.0 - 1.0) * scale_kv
553        });
554        let w_v = Array2::from_shape_fn((memory_dim + time_dim, hidden_dim), |_| {
555            (rng.random::<f64>() * 2.0 - 1.0) * scale_kv
556        });
557
558        let time_encoding = TimeEncoding::new(time_dim, TimeEncodingType::Sinusoidal);
559
560        Ok(TemporalAttention {
561            w_q,
562            w_k,
563            w_v,
564            num_heads,
565            hidden_dim,
566            head_dim,
567            time_encoding,
568            memory_dim,
569            time_dim,
570        })
571    }
572
573    /// Compute temporal attention for a single node given its neighbors and event times.
574    ///
575    /// # Arguments
576    /// * `query_memory` - Memory of the query node `[memory_dim]`
577    /// * `neighbor_memories` - Memory vectors of neighbor nodes `[num_neighbors, memory_dim]`
578    /// * `time_deltas` - Time differences for each neighbor interaction `[num_neighbors]`
579    ///
580    /// # Returns
581    /// Aggregated representation `[hidden_dim]`
582    pub fn forward(
583        &self,
584        query_memory: &Array1<f64>,
585        neighbor_memories: &Array2<f64>,
586        time_deltas: &[f64],
587    ) -> Result<Array1<f64>> {
588        let num_neighbors = neighbor_memories.dim().0;
589        if num_neighbors == 0 {
590            return Ok(Array1::zeros(self.hidden_dim));
591        }
592        if time_deltas.len() != num_neighbors {
593            return Err(GraphError::InvalidParameter {
594                param: "time_deltas".to_string(),
595                value: format!("len={}", time_deltas.len()),
596                expected: format!("len={num_neighbors}"),
597                context: "TemporalAttention::forward".to_string(),
598            });
599        }
600
601        let h = self.num_heads;
602        let dk = self.head_dim;
603        let scale = 1.0 / (dk as f64).sqrt();
604
605        // Query: W_q * query_memory -> [hidden_dim]
606        let mut q = vec![0.0f64; self.hidden_dim];
607        for j in 0..self.hidden_dim {
608            for m in 0..self.memory_dim {
609                q[j] += query_memory[m] * self.w_q[[m, j]];
610            }
611        }
612
613        // For each neighbor, compute key and value
614        let kv_in_dim = self.memory_dim + self.time_dim;
615        let mut keys = vec![vec![0.0f64; self.hidden_dim]; num_neighbors];
616        let mut values = vec![vec![0.0f64; self.hidden_dim]; num_neighbors];
617
618        for nb in 0..num_neighbors {
619            // Concatenate neighbor memory with time encoding
620            let time_enc = self.time_encoding.encode(time_deltas[nb]);
621            let mut kv_input = Vec::with_capacity(kv_in_dim);
622            for m in 0..self.memory_dim {
623                kv_input.push(neighbor_memories[[nb, m]]);
624            }
625            for m in 0..self.time_dim {
626                kv_input.push(time_enc[m]);
627            }
628
629            for j in 0..self.hidden_dim {
630                let mut sk = 0.0;
631                let mut sv = 0.0;
632                for m in 0..kv_in_dim {
633                    sk += kv_input[m] * self.w_k[[m, j]];
634                    sv += kv_input[m] * self.w_v[[m, j]];
635                }
636                keys[nb][j] = sk;
637                values[nb][j] = sv;
638            }
639        }
640
641        // Multi-head attention
642        let mut output = vec![0.0f64; self.hidden_dim];
643
644        for head in 0..h {
645            let offset = head * dk;
646
647            // Scores
648            let mut scores = vec![0.0f64; num_neighbors];
649            for nb in 0..num_neighbors {
650                let mut dot = 0.0;
651                for m in 0..dk {
652                    dot += q[offset + m] * keys[nb][offset + m];
653                }
654                scores[nb] = dot * scale;
655            }
656
657            // Softmax
658            let alphas = softmax_slice(&scores);
659
660            // Aggregate
661            for nb in 0..num_neighbors {
662                for m in 0..dk {
663                    output[offset + m] += alphas[nb] * values[nb][offset + m];
664                }
665            }
666        }
667
668        Ok(Array1::from_vec(output))
669    }
670}
671
/// Numerically-stable softmax over a slice.
///
/// Subtracts the maximum before exponentiating and clamps the normalizer
/// away from zero (1e-12 floor), so the output stays finite even for very
/// large or very negative inputs. Returns an empty vector for empty input.
fn softmax_slice(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let mut max_val = f64::NEG_INFINITY;
    for &x in xs {
        max_val = max_val.max(x);
    }
    let mut out: Vec<f64> = xs.iter().map(|&x| (x - max_val).exp()).collect();
    let norm = out.iter().sum::<f64>().max(1e-12);
    for p in &mut out {
        *p /= norm;
    }
    out
}
682
683// ============================================================================
684// TGN Configuration
685// ============================================================================
686
/// Configuration for the Temporal GNN model.
///
/// Consumed by `TemporalGnnModel::new`. `memory_dim` must be divisible by
/// `num_heads` (enforced by `TemporalAttention::new`).
#[derive(Debug, Clone)]
pub struct TemporalGnnConfig {
    /// Number of nodes
    pub n_nodes: usize,
    /// Memory dimension (also the attention hidden dimension)
    pub memory_dim: usize,
    /// Time encoding dimension
    pub time_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Memory update method
    pub update_method: MemoryUpdateMethod,
    /// Time encoding type
    pub time_encoding_type: TimeEncodingType,
}
703
704impl Default for TemporalGnnConfig {
705    fn default() -> Self {
706        TemporalGnnConfig {
707            n_nodes: 100,
708            memory_dim: 64,
709            time_dim: 16,
710            num_heads: 4,
711            update_method: MemoryUpdateMethod::Gru,
712            time_encoding_type: TimeEncodingType::Time2Vec,
713        }
714    }
715}
716
717// ============================================================================
718// Temporal GNN Model
719// ============================================================================
720
/// Full Temporal Graph Neural Network model.
///
/// Combines a TGN-style memory module with temporal attention for
/// processing continuous-time dynamic graphs.
///
/// Note: the event history grows without bound as events are processed;
/// it is only emptied by `reset`.
#[derive(Debug, Clone)]
pub struct TemporalGnnModel {
    /// Memory module
    pub memory_module: MemoryModule,
    /// Temporal attention
    pub temporal_attention: TemporalAttention,
    /// Configuration
    pub config: TemporalGnnConfig,
    /// Event history used for recent-neighbor lookup in `get_node_embedding`
    event_history: Vec<TemporalEvent>,
}
736
737impl TemporalGnnModel {
738    /// Create a new Temporal GNN model.
739    pub fn new(config: TemporalGnnConfig) -> Result<Self> {
740        let memory_module = MemoryModule::new(
741            config.n_nodes,
742            config.memory_dim,
743            config.time_dim,
744            config.update_method.clone(),
745        );
746        let temporal_attention =
747            TemporalAttention::new(config.memory_dim, config.time_dim, config.num_heads)?;
748
749        Ok(TemporalGnnModel {
750            memory_module,
751            temporal_attention,
752            config,
753            event_history: Vec::new(),
754        })
755    }
756
757    /// Process a batch of events and update model state.
758    ///
759    /// Events should be in chronological order.
760    pub fn process_events(&mut self, events: &[TemporalEvent]) -> Result<()> {
761        self.memory_module.process_events(events)?;
762        self.event_history.extend(events.iter().cloned());
763        Ok(())
764    }
765
766    /// Get node embedding at a given time by aggregating over recent neighbors.
767    ///
768    /// # Arguments
769    /// * `node` - Node index
770    /// * `current_time` - Current timestamp for computing time deltas
771    /// * `max_neighbors` - Maximum number of recent neighbors to attend to
772    ///
773    /// # Returns
774    /// Node embedding `[memory_dim]`
775    pub fn get_node_embedding(
776        &self,
777        node: usize,
778        current_time: f64,
779        max_neighbors: usize,
780    ) -> Result<Array1<f64>> {
781        if node >= self.config.n_nodes {
782            return Err(GraphError::InvalidParameter {
783                param: "node".to_string(),
784                value: format!("{node}"),
785                expected: format!("< {}", self.config.n_nodes),
786                context: "TemporalGnnModel::get_node_embedding".to_string(),
787            });
788        }
789
790        // Find recent neighbors from event history
791        let mut neighbor_events: Vec<(usize, f64)> = Vec::new();
792        for event in self.event_history.iter().rev() {
793            if neighbor_events.len() >= max_neighbors {
794                break;
795            }
796            if event.source == node {
797                neighbor_events.push((event.target, event.timestamp));
798            } else if event.target == node {
799                neighbor_events.push((event.source, event.timestamp));
800            }
801        }
802
803        if neighbor_events.is_empty() {
804            // Return raw memory if no neighbors found
805            return Ok(Array1::from_iter(
806                (0..self.config.memory_dim).map(|j| self.memory_module.memory[[node, j]]),
807            ));
808        }
809
810        // Build neighbor memory matrix and time deltas
811        let num_nb = neighbor_events.len();
812        let mut nb_memories = Array2::zeros((num_nb, self.config.memory_dim));
813        let mut time_deltas = vec![0.0f64; num_nb];
814
815        for (idx, &(nb_node, nb_time)) in neighbor_events.iter().enumerate() {
816            for j in 0..self.config.memory_dim {
817                nb_memories[[idx, j]] = self.memory_module.memory[[nb_node, j]];
818            }
819            time_deltas[idx] = current_time - nb_time;
820        }
821
822        // Query memory
823        let query = Array1::from_iter(
824            (0..self.config.memory_dim).map(|j| self.memory_module.memory[[node, j]]),
825        );
826
827        self.temporal_attention
828            .forward(&query, &nb_memories, &time_deltas)
829    }
830
831    /// Reset the model state (memory and event history).
832    pub fn reset(&mut self) {
833        self.memory_module.reset();
834        self.event_history.clear();
835    }
836
837    /// Get the current memory state.
838    pub fn get_memory(&self) -> &Array2<f64> {
839        self.memory_module.get_memory()
840    }
841}
842
843// ============================================================================
844// Tests
845// ============================================================================
846
847#[cfg(test)]
848mod tests {
849    use super::*;
850
    /// Sinusoidal time encodings must depend on the timestamp: vectors for
    /// distinct times should differ, and have the configured dimension.
    #[test]
    fn test_time_encoding_sinusoidal_varies_with_time() {
        let te = TimeEncoding::new(8, TimeEncodingType::Sinusoidal);
        let enc1 = te.encode(0.0);
        let enc2 = te.encode(1.0);
        let enc3 = te.encode(10.0);

        assert_eq!(enc1.len(), 8);
        assert_eq!(enc2.len(), 8);

        // Different timestamps should produce different encodings
        // (compared via L1 distance).
        let diff_12: f64 = enc1
            .iter()
            .zip(enc2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        let diff_13: f64 = enc1
            .iter()
            .zip(enc3.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();

        assert!(diff_12 > 1e-6, "encodings at t=0 and t=1 should differ");
        assert!(diff_13 > 1e-6, "encodings at t=0 and t=10 should differ");
    }
876
    /// Time2Vec encodings must be finite, and their first component must be
    /// the linear term `w * t + b` (so its difference over Δt equals w * Δt).
    #[test]
    fn test_time_encoding_time2vec() {
        let te = TimeEncoding::new(6, TimeEncodingType::Time2Vec);
        let enc = te.encode(5.0);
        assert_eq!(enc.len(), 6);
        for &v in enc.iter() {
            assert!(v.is_finite(), "Time2Vec encoding should be finite");
        }

        // First component should be linear
        let enc0 = te.encode(0.0);
        let enc10 = te.encode(10.0);
        // Linear component: w * t + b, so enc(10)[0] - enc(0)[0] == w * 10
        let expected_diff = te.linear_weight * 10.0;
        let actual_diff = enc10[0] - enc0[0];
        assert!(
            (actual_diff - expected_diff).abs() < 1e-10,
            "first component should be linear"
        );
    }
897
898    #[test]
899    fn test_memory_update_changes_state() {
900        let mut mem = MemoryModule::new(5, 8, 4, MemoryUpdateMethod::Gru);
901
902        // Initially all zeros
903        let initial_norm: f64 = mem.memory.iter().map(|x| x * x).sum();
904        assert!(initial_norm < 1e-12, "initial memory should be zero");
905
906        // Process an event
907        let event = TemporalEvent::new(0, 1, 1.0);
908        mem.process_event(&event).expect("process event");
909
910        // Memory should have changed for nodes 0 and 1
911        let node0_norm: f64 = (0..8).map(|j| mem.memory[[0, j]].powi(2)).sum();
912        let node1_norm: f64 = (0..8).map(|j| mem.memory[[1, j]].powi(2)).sum();
913
914        assert!(
915            node0_norm > 1e-12,
916            "node 0 memory should be updated after event"
917        );
918        assert!(
919            node1_norm > 1e-12,
920            "node 1 memory should be updated after event"
921        );
922
923        // Node 2 should still be zero (not involved in event)
924        let node2_norm: f64 = (0..8).map(|j| mem.memory[[2, j]].powi(2)).sum();
925        assert!(node2_norm < 1e-12, "node 2 memory should remain zero");
926    }
927
928    #[test]
929    fn test_memory_update_mlp() {
930        let mut mem = MemoryModule::new(4, 6, 3, MemoryUpdateMethod::Mlp);
931        let event = TemporalEvent::new(0, 1, 0.5);
932        mem.process_event(&event).expect("process event MLP");
933
934        let node0_norm: f64 = (0..6).map(|j| mem.memory[[0, j]].powi(2)).sum();
935        assert!(node0_norm > 1e-12, "MLP update should modify memory");
936    }
937
938    #[test]
939    fn test_temporal_attention_shape() {
940        let ta = TemporalAttention::new(8, 4, 2).expect("temporal attention");
941        let query = Array1::from_vec(vec![0.1; 8]);
942        let neighbors = Array2::from_shape_fn((3, 8), |(i, j)| (i + j) as f64 * 0.05);
943        let deltas = vec![1.0, 2.0, 3.0];
944
945        let out = ta.forward(&query, &neighbors, &deltas).expect("forward");
946        assert_eq!(out.len(), 8);
947        for &v in out.iter() {
948            assert!(v.is_finite(), "temporal attention output should be finite");
949        }
950    }
951
952    #[test]
953    fn test_temporal_attention_empty_neighbors() {
954        let ta = TemporalAttention::new(8, 4, 2).expect("temporal attention");
955        let query = Array1::from_vec(vec![0.1; 8]);
956        let neighbors = Array2::zeros((0, 8));
957        let deltas: Vec<f64> = Vec::new();
958
959        let out = ta
960            .forward(&query, &neighbors, &deltas)
961            .expect("empty forward");
962        assert_eq!(out.len(), 8);
963        // Should be all zeros for empty neighbors
964        let norm: f64 = out.iter().map(|x| x * x).sum();
965        assert!(norm < 1e-12, "empty neighbor attention should return zeros");
966    }
967
968    #[test]
969    fn test_temporal_gnn_model_full_pipeline() {
970        let config = TemporalGnnConfig {
971            n_nodes: 5,
972            memory_dim: 8,
973            time_dim: 4,
974            num_heads: 2,
975            update_method: MemoryUpdateMethod::Gru,
976            time_encoding_type: TimeEncodingType::Time2Vec,
977        };
978
979        let mut model = TemporalGnnModel::new(config).expect("model");
980
981        // Process events
982        let events = vec![
983            TemporalEvent::new(0, 1, 1.0),
984            TemporalEvent::new(1, 2, 2.0),
985            TemporalEvent::new(0, 2, 3.0),
986            TemporalEvent::new(2, 3, 4.0),
987        ];
988        model.process_events(&events).expect("process events");
989
990        // Get embeddings
991        let emb0 = model.get_node_embedding(0, 5.0, 3).expect("embedding 0");
992        let emb4 = model.get_node_embedding(4, 5.0, 3).expect("embedding 4");
993
994        assert_eq!(emb0.len(), 8);
995        assert_eq!(emb4.len(), 8);
996
997        // Node 0 has interactions, node 4 does not - embeddings should differ
998        let diff: f64 = emb0
999            .iter()
1000            .zip(emb4.iter())
1001            .map(|(a, b)| (a - b).abs())
1002            .sum();
1003        // Node 4 returns raw zero memory, node 0 has aggregated info
1004        assert!(
1005            emb0.iter().any(|&v| v.abs() > 1e-12),
1006            "active node should have non-zero embedding"
1007        );
1008    }
1009
1010    #[test]
1011    fn test_memory_module_event_out_of_bounds() {
1012        let mut mem = MemoryModule::new(3, 4, 2, MemoryUpdateMethod::Gru);
1013        let event = TemporalEvent::new(0, 5, 1.0); // node 5 > n_nodes=3
1014        let result = mem.process_event(&event);
1015        assert!(result.is_err());
1016    }
1017
1018    #[test]
1019    fn test_temporal_gnn_reset() {
1020        let config = TemporalGnnConfig {
1021            n_nodes: 3,
1022            memory_dim: 4,
1023            time_dim: 2,
1024            num_heads: 2,
1025            ..Default::default()
1026        };
1027
1028        let mut model = TemporalGnnModel::new(config).expect("model");
1029        let event = TemporalEvent::new(0, 1, 1.0);
1030        model.process_events(&[event]).expect("process");
1031
1032        // After reset, memory should be zero again
1033        model.reset();
1034        let mem_norm: f64 = model.get_memory().iter().map(|x| x * x).sum();
1035        assert!(mem_norm < 1e-12, "memory should be zero after reset");
1036    }
1037
1038    #[test]
1039    fn test_time_encoding_batch() {
1040        let te = TimeEncoding::new(4, TimeEncodingType::Sinusoidal);
1041        let timestamps = vec![0.0, 1.0, 5.0, 10.0];
1042        let batch = te.encode_batch(&timestamps);
1043        assert_eq!(batch.dim(), (4, 4));
1044
1045        // Each row should match individual encoding
1046        for (i, &t) in timestamps.iter().enumerate() {
1047            let single = te.encode(t);
1048            for j in 0..4 {
1049                assert!(
1050                    (batch[[i, j]] - single[j]).abs() < 1e-12,
1051                    "batch encoding should match single encoding"
1052                );
1053            }
1054        }
1055    }
1056
1057    #[test]
1058    fn test_memory_timestamps_updated() {
1059        let mut mem = MemoryModule::new(3, 4, 2, MemoryUpdateMethod::Gru);
1060
1061        assert!(mem.last_update[0] < 1e-12);
1062        assert!(mem.last_update[1] < 1e-12);
1063
1064        let event = TemporalEvent::new(0, 1, 5.0);
1065        mem.process_event(&event).expect("process");
1066
1067        assert!(
1068            (mem.last_update[0] - 5.0).abs() < 1e-12,
1069            "source timestamp should be updated"
1070        );
1071        assert!(
1072            (mem.last_update[1] - 5.0).abs() < 1e-12,
1073            "target timestamp should be updated"
1074        );
1075        assert!(
1076            mem.last_update[2] < 1e-12,
1077            "uninvolved node timestamp should remain 0"
1078        );
1079    }
1080}