rustkernel_graph/
gnn.rs

1//! Graph Neural Network kernels.
2//!
3//! This module provides GPU-accelerated GNN algorithms:
4//! - GNNInference - Message passing neural network inference
5//! - GraphAttention - Graph attention network (GAT) layers
6
7use crate::types::CsrGraph;
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use serde::{Deserialize, Serialize};
10
11// ============================================================================
12// GNN Inference Kernel
13// ============================================================================
14
/// Configuration for GNN inference.
///
/// Controls the layer stack shape and the per-layer behavior of
/// `GNNInference::compute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Number of message passing layers.
    pub num_layers: usize,
    /// Hidden dimension size (used by all layers except the last).
    pub hidden_dim: usize,
    /// Output dimension (number of classes or embedding size).
    /// Values > 1 trigger classification post-processing (softmax + argmax).
    pub output_dim: usize,
    /// Aggregation function for messages.
    pub aggregation: AggregationType,
    /// Activation function (applied to every layer except the last).
    pub activation: ActivationType,
    /// Dropout rate (0-1).
    /// NOTE(review): not currently read by `GNNInference::compute`
    /// (inference-only path); kept for configuration compatibility.
    pub dropout: f64,
    /// Whether to add self-loops to the adjacency before message passing.
    pub add_self_loops: bool,
    /// Whether to use layer normalization (skipped on the last layer).
    pub layer_norm: bool,
}
35
impl Default for GNNConfig {
    /// Sensible defaults: a small 2-layer network with mean aggregation,
    /// ReLU, no dropout, self-loops and layer norm enabled.
    fn default() -> Self {
        Self {
            num_layers: 2,
            hidden_dim: 64,
            output_dim: 32,
            aggregation: AggregationType::Mean,
            activation: ActivationType::ReLU,
            dropout: 0.0,
            add_self_loops: true,
            layer_norm: true,
        }
    }
}
50
/// Message aggregation type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AggregationType {
    /// Sum of neighbor messages.
    Sum,
    /// Mean of neighbor messages.
    Mean,
    /// Max pooling over neighbors (element-wise maximum).
    Max,
    /// GraphSAGE-style sample and aggregate.
    /// NOTE: currently implemented as a plain mean over neighbors
    /// (no self-feature concatenation) — see `aggregate_neighbors`.
    SAGE,
}
63
/// Activation function type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ActivationType {
    /// Rectified Linear Unit: `max(x, 0)`.
    ReLU,
    /// Leaky ReLU with fixed alpha = 0.01.
    LeakyReLU,
    /// Exponential Linear Unit with alpha = 1: `exp(x) - 1` for x <= 0.
    ELU,
    /// Sigmoid function: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// No activation (identity).
    None,
}
80
/// GNN layer weights (simulated).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNWeights {
    /// Weight matrices per layer, indexed `[layer][input_dim][output_dim]`.
    pub layer_weights: Vec<Vec<Vec<f64>>>,
    /// Bias vectors per layer, indexed `[layer][output_dim]`.
    /// The bias length also defines each layer's output dimension.
    pub layer_biases: Vec<Vec<f64>>,
}
89
90impl GNNWeights {
91    /// Create random weights for testing.
92    pub fn random(input_dim: usize, config: &GNNConfig) -> Self {
93        use rand::{Rng, rng};
94        let mut r = rng();
95
96        let mut layer_weights = Vec::new();
97        let mut layer_biases = Vec::new();
98
99        let mut prev_dim = input_dim;
100
101        for i in 0..config.num_layers {
102            let out_dim = if i == config.num_layers - 1 {
103                config.output_dim
104            } else {
105                config.hidden_dim
106            };
107
108            // Xavier initialization
109            let scale = (2.0 / (prev_dim + out_dim) as f64).sqrt();
110
111            let weights: Vec<Vec<f64>> = (0..prev_dim)
112                .map(|_| {
113                    (0..out_dim)
114                        .map(|_| r.random_range(-scale..scale))
115                        .collect()
116                })
117                .collect();
118
119            let biases: Vec<f64> = (0..out_dim).map(|_| 0.0).collect();
120
121            layer_weights.push(weights);
122            layer_biases.push(biases);
123            prev_dim = out_dim;
124        }
125
126        Self {
127            layer_weights,
128            layer_biases,
129        }
130    }
131}
132
/// Result of GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNResult {
    /// Node embeddings after all layers, indexed `[node][output_dim]`.
    pub embeddings: Vec<Vec<f64>>,
    /// Class predictions (argmax per node); `Some` only when
    /// `output_dim > 1` (classification mode).
    pub predictions: Option<Vec<usize>>,
    /// Softmax probabilities per node; `Some` only when `output_dim > 1`.
    pub probabilities: Option<Vec<Vec<f64>>>,
}
143
/// GNN Inference kernel.
///
/// Performs message passing neural network inference on graph data.
/// Supports various aggregation strategies and can be used for
/// node classification, link prediction, and graph-level tasks.
#[derive(Debug, Clone)]
pub struct GNNInference {
    // Static kernel descriptor (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
153
impl Default for GNNInference {
    // Delegates to `new()`; there is no other construction state.
    fn default() -> Self {
        Self::new()
    }
}
159
160impl GNNInference {
161    /// Create a new GNN Inference kernel.
162    #[must_use]
163    pub fn new() -> Self {
164        Self {
165            metadata: KernelMetadata::batch("graph/gnn-inference", Domain::GraphAnalytics)
166                .with_description("Message passing neural network inference")
167                .with_throughput(10_000)
168                .with_latency_us(100.0),
169        }
170    }
171
172    /// Run GNN inference on a graph.
173    pub fn compute(
174        graph: &CsrGraph,
175        node_features: &[Vec<f64>],
176        weights: &GNNWeights,
177        config: &GNNConfig,
178    ) -> GNNResult {
179        if graph.num_nodes == 0 || node_features.is_empty() {
180            return GNNResult {
181                embeddings: Vec::new(),
182                predictions: None,
183                probabilities: None,
184            };
185        }
186
187        let n = graph.num_nodes;
188
189        // Build adjacency list from CSR format (with optional self-loops)
190        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
191        for node in 0..n {
192            let start = graph.row_offsets[node] as usize;
193            let end = graph.row_offsets[node + 1] as usize;
194            for &neighbor in &graph.col_indices[start..end] {
195                adj[node].push(neighbor as usize);
196                // Add reverse edge for undirected
197                if !adj[neighbor as usize].contains(&node) {
198                    adj[neighbor as usize].push(node);
199                }
200            }
201        }
202
203        if config.add_self_loops {
204            for i in 0..n {
205                if !adj[i].contains(&i) {
206                    adj[i].push(i);
207                }
208            }
209        }
210
211        // Initialize embeddings from features
212        let mut embeddings: Vec<Vec<f64>> = node_features.to_vec();
213
214        // Run message passing layers
215        for layer_idx in 0..config.num_layers {
216            embeddings = Self::message_passing_layer(
217                &embeddings,
218                &adj,
219                &weights.layer_weights[layer_idx],
220                &weights.layer_biases[layer_idx],
221                config,
222                layer_idx == config.num_layers - 1,
223            );
224        }
225
226        // Compute predictions if output looks like classification
227        let (predictions, probabilities) = if config.output_dim > 1 {
228            let probs: Vec<Vec<f64>> = embeddings.iter().map(|e| Self::softmax(e)).collect();
229            let preds: Vec<usize> = probs
230                .iter()
231                .map(|p| {
232                    p.iter()
233                        .enumerate()
234                        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
235                        .map(|(i, _)| i)
236                        .unwrap_or(0)
237                })
238                .collect();
239            (Some(preds), Some(probs))
240        } else {
241            (None, None)
242        };
243
244        GNNResult {
245            embeddings,
246            predictions,
247            probabilities,
248        }
249    }
250
251    /// Single message passing layer.
252    fn message_passing_layer(
253        embeddings: &[Vec<f64>],
254        adj: &[Vec<usize>],
255        weights: &[Vec<f64>],
256        biases: &[f64],
257        config: &GNNConfig,
258        is_last: bool,
259    ) -> Vec<Vec<f64>> {
260        let n = embeddings.len();
261        let out_dim = biases.len();
262        let mut new_embeddings = vec![vec![0.0; out_dim]; n];
263
264        for i in 0..n {
265            // Aggregate neighbor messages
266            let aggregated = Self::aggregate_neighbors(embeddings, &adj[i], config.aggregation);
267
268            // Transform: out = activation(W * aggregated + b)
269            for j in 0..out_dim {
270                let mut val = biases[j];
271                for (k, &agg_val) in aggregated.iter().enumerate() {
272                    if k < weights.len() && j < weights[k].len() {
273                        val += weights[k][j] * agg_val;
274                    }
275                }
276
277                // Apply activation (except on last layer if doing classification)
278                if !is_last {
279                    val = Self::activate(val, config.activation);
280                }
281
282                new_embeddings[i][j] = val;
283            }
284
285            // Layer normalization
286            if config.layer_norm && !is_last {
287                let mean: f64 = new_embeddings[i].iter().sum::<f64>() / out_dim as f64;
288                let var: f64 = new_embeddings[i]
289                    .iter()
290                    .map(|x| (x - mean).powi(2))
291                    .sum::<f64>()
292                    / out_dim as f64;
293                let std = (var + 1e-5).sqrt();
294
295                for j in 0..out_dim {
296                    new_embeddings[i][j] = (new_embeddings[i][j] - mean) / std;
297                }
298            }
299        }
300
301        new_embeddings
302    }
303
304    /// Aggregate messages from neighbors.
305    fn aggregate_neighbors(
306        embeddings: &[Vec<f64>],
307        neighbors: &[usize],
308        agg_type: AggregationType,
309    ) -> Vec<f64> {
310        if neighbors.is_empty() {
311            return vec![0.0; embeddings.get(0).map(|e| e.len()).unwrap_or(0)];
312        }
313
314        let dim = embeddings[neighbors[0]].len();
315
316        match agg_type {
317            AggregationType::Sum => {
318                let mut result = vec![0.0; dim];
319                for &n in neighbors {
320                    for (i, &v) in embeddings[n].iter().enumerate() {
321                        result[i] += v;
322                    }
323                }
324                result
325            }
326            AggregationType::Mean => {
327                let mut result = vec![0.0; dim];
328                for &n in neighbors {
329                    for (i, &v) in embeddings[n].iter().enumerate() {
330                        result[i] += v;
331                    }
332                }
333                let count = neighbors.len() as f64;
334                result.iter_mut().for_each(|v| *v /= count);
335                result
336            }
337            AggregationType::Max => {
338                let mut result = vec![f64::NEG_INFINITY; dim];
339                for &n in neighbors {
340                    for (i, &v) in embeddings[n].iter().enumerate() {
341                        result[i] = result[i].max(v);
342                    }
343                }
344                result
345            }
346            AggregationType::SAGE => {
347                // GraphSAGE: concat(self, mean(neighbors))
348                // Simplified: just use mean here
349                let mut result = vec![0.0; dim];
350                for &n in neighbors {
351                    for (i, &v) in embeddings[n].iter().enumerate() {
352                        result[i] += v;
353                    }
354                }
355                let count = neighbors.len() as f64;
356                result.iter_mut().for_each(|v| *v /= count);
357                result
358            }
359        }
360    }
361
362    /// Apply activation function.
363    fn activate(x: f64, activation: ActivationType) -> f64 {
364        match activation {
365            ActivationType::ReLU => x.max(0.0),
366            ActivationType::LeakyReLU => {
367                if x > 0.0 {
368                    x
369                } else {
370                    0.01 * x
371                }
372            }
373            ActivationType::ELU => {
374                if x > 0.0 {
375                    x
376                } else {
377                    x.exp() - 1.0
378                }
379            }
380            ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
381            ActivationType::Tanh => x.tanh(),
382            ActivationType::None => x,
383        }
384    }
385
386    /// Softmax for classification.
387    fn softmax(x: &[f64]) -> Vec<f64> {
388        let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
389        let exp_sum: f64 = x.iter().map(|v| (v - max_val).exp()).sum();
390        x.iter().map(|v| (v - max_val).exp() / exp_sum).collect()
391    }
392}
393
impl GpuKernel for GNNInference {
    // Exposes the static kernel descriptor built in `new()`.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
399
400// ============================================================================
401// Graph Attention Kernel
402// ============================================================================
403
/// Configuration for graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Hidden dimension per head.
    pub head_dim: usize,
    /// Output dimension (after the final output projection).
    pub output_dim: usize,
    /// Dropout for attention weights.
    /// NOTE(review): not currently read by `GraphAttention::compute`
    /// (inference-only path); kept for configuration compatibility.
    pub attention_dropout: f64,
    /// Whether to concatenate heads (`num_heads * head_dim`) or average
    /// them (`head_dim`).
    pub concat_heads: bool,
    /// Negative slope for the LeakyReLU applied to raw attention scores.
    pub negative_slope: f64,
}
420
impl Default for GraphAttentionConfig {
    /// GAT-style defaults: 4 concatenated heads of dim 16, output dim 64,
    /// and the conventional 0.2 LeakyReLU slope for attention scores.
    fn default() -> Self {
        Self {
            num_heads: 4,
            head_dim: 16,
            output_dim: 64,
            attention_dropout: 0.0,
            concat_heads: true,
            negative_slope: 0.2,
        }
    }
}
433
/// Attention weights for GAT layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATWeights {
    /// Query transformation weights per head, `[head][input_dim][head_dim]`.
    pub query_weights: Vec<Vec<Vec<f64>>>,
    /// Key transformation weights per head, `[head][input_dim][head_dim]`.
    pub key_weights: Vec<Vec<Vec<f64>>>,
    /// Value transformation weights per head, `[head][input_dim][head_dim]`.
    pub value_weights: Vec<Vec<Vec<f64>>>,
    /// Attention vector per head, length `2 * head_dim` (scores the
    /// concatenated `[Q_i || K_j]` pair).
    pub attention_vectors: Vec<Vec<f64>>,
    /// Output projection weights, `[combined_head_dim][output_dim]`.
    pub output_weights: Vec<Vec<f64>>,
}
448
449impl GATWeights {
450    /// Create random weights for testing.
451    pub fn random(input_dim: usize, config: &GraphAttentionConfig) -> Self {
452        use rand::{Rng, rng};
453        let mut r = rng();
454
455        let scale = (2.0 / (input_dim + config.head_dim) as f64).sqrt();
456
457        let mut query_weights = Vec::new();
458        let mut key_weights = Vec::new();
459        let mut value_weights = Vec::new();
460        let mut attention_vectors = Vec::new();
461
462        for _ in 0..config.num_heads {
463            let q: Vec<Vec<f64>> = (0..input_dim)
464                .map(|_| {
465                    (0..config.head_dim)
466                        .map(|_| r.random_range(-scale..scale))
467                        .collect()
468                })
469                .collect();
470            let k: Vec<Vec<f64>> = (0..input_dim)
471                .map(|_| {
472                    (0..config.head_dim)
473                        .map(|_| r.random_range(-scale..scale))
474                        .collect()
475                })
476                .collect();
477            let v: Vec<Vec<f64>> = (0..input_dim)
478                .map(|_| {
479                    (0..config.head_dim)
480                        .map(|_| r.random_range(-scale..scale))
481                        .collect()
482                })
483                .collect();
484            let a: Vec<f64> = (0..config.head_dim * 2)
485                .map(|_| r.random_range(-scale..scale))
486                .collect();
487
488            query_weights.push(q);
489            key_weights.push(k);
490            value_weights.push(v);
491            attention_vectors.push(a);
492        }
493
494        let total_dim = if config.concat_heads {
495            config.num_heads * config.head_dim
496        } else {
497            config.head_dim
498        };
499
500        let out_scale = (2.0 / (total_dim + config.output_dim) as f64).sqrt();
501        let output_weights: Vec<Vec<f64>> = (0..total_dim)
502            .map(|_| {
503                (0..config.output_dim)
504                    .map(|_| r.random_range(-out_scale..out_scale))
505                    .collect()
506            })
507            .collect();
508
509        Self {
510            query_weights,
511            key_weights,
512            value_weights,
513            attention_vectors,
514            output_weights,
515        }
516    }
517}
518
/// Result of graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATResult {
    /// Output embeddings after head combination and output projection,
    /// indexed `[node][output_dim]`.
    pub embeddings: Vec<Vec<f64>>,
    /// Attention weights per head as (source, target, weight) triplets.
    /// For each source node, the weights over its neighbors sum to 1.
    pub attention_weights: Vec<Vec<(usize, usize, f64)>>,
}
527
/// Graph Attention kernel.
///
/// Implements Graph Attention Networks (GAT) with multi-head attention.
/// Learns to weight neighbor contributions based on their relevance
/// to each node.
#[derive(Debug, Clone)]
pub struct GraphAttention {
    // Static kernel descriptor (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
537
impl Default for GraphAttention {
    // Delegates to `new()`; there is no other construction state.
    fn default() -> Self {
        Self::new()
    }
}
543
544impl GraphAttention {
545    /// Create a new Graph Attention kernel.
546    #[must_use]
547    pub fn new() -> Self {
548        Self {
549            metadata: KernelMetadata::batch("graph/graph-attention", Domain::GraphAnalytics)
550                .with_description("Graph attention networks with multi-head attention")
551                .with_throughput(5_000)
552                .with_latency_us(200.0),
553        }
554    }
555
556    /// Compute graph attention layer.
557    pub fn compute(
558        graph: &CsrGraph,
559        node_features: &[Vec<f64>],
560        weights: &GATWeights,
561        config: &GraphAttentionConfig,
562    ) -> GATResult {
563        if graph.num_nodes == 0 || node_features.is_empty() {
564            return GATResult {
565                embeddings: Vec::new(),
566                attention_weights: Vec::new(),
567            };
568        }
569
570        let n = graph.num_nodes;
571
572        // Build adjacency with self-loops from CSR format
573        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
574        for node in 0..n {
575            let start = graph.row_offsets[node] as usize;
576            let end = graph.row_offsets[node + 1] as usize;
577            for &neighbor in &graph.col_indices[start..end] {
578                adj[node].push(neighbor as usize);
579                if !adj[neighbor as usize].contains(&node) {
580                    adj[neighbor as usize].push(node);
581                }
582            }
583        }
584        for i in 0..n {
585            if !adj[i].contains(&i) {
586                adj[i].push(i);
587            }
588        }
589
590        // Compute attention for each head
591        let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::new();
592        let mut all_attention_weights: Vec<Vec<(usize, usize, f64)>> = Vec::new();
593
594        for head in 0..config.num_heads {
595            let (output, attn_weights) = Self::compute_head(
596                node_features,
597                &adj,
598                &weights.query_weights[head],
599                &weights.key_weights[head],
600                &weights.value_weights[head],
601                &weights.attention_vectors[head],
602                config,
603            );
604            head_outputs.push(output);
605            all_attention_weights.push(attn_weights);
606        }
607
608        // Combine heads
609        let combined: Vec<Vec<f64>> = if config.concat_heads {
610            (0..n)
611                .map(|i| head_outputs.iter().flat_map(|h| h[i].clone()).collect())
612                .collect()
613        } else {
614            // Average heads
615            (0..n)
616                .map(|i| {
617                    let dim = head_outputs[0][i].len();
618                    let mut avg = vec![0.0; dim];
619                    for h in &head_outputs {
620                        for (j, &v) in h[i].iter().enumerate() {
621                            avg[j] += v;
622                        }
623                    }
624                    avg.iter_mut().for_each(|v| *v /= config.num_heads as f64);
625                    avg
626                })
627                .collect()
628        };
629
630        // Output projection
631        let embeddings: Vec<Vec<f64>> = combined
632            .iter()
633            .map(|c| Self::linear_transform(c, &weights.output_weights))
634            .collect();
635
636        GATResult {
637            embeddings,
638            attention_weights: all_attention_weights,
639        }
640    }
641
642    /// Compute single attention head.
643    fn compute_head(
644        features: &[Vec<f64>],
645        adj: &[Vec<usize>],
646        query_w: &[Vec<f64>],
647        key_w: &[Vec<f64>],
648        value_w: &[Vec<f64>],
649        attn_vec: &[f64],
650        config: &GraphAttentionConfig,
651    ) -> (Vec<Vec<f64>>, Vec<(usize, usize, f64)>) {
652        let n = features.len();
653        let head_dim = config.head_dim;
654
655        // Transform features to Q, K, V
656        let queries: Vec<Vec<f64>> = features
657            .iter()
658            .map(|f| Self::linear_transform(f, query_w))
659            .collect();
660        let keys: Vec<Vec<f64>> = features
661            .iter()
662            .map(|f| Self::linear_transform(f, key_w))
663            .collect();
664        let values: Vec<Vec<f64>> = features
665            .iter()
666            .map(|f| Self::linear_transform(f, value_w))
667            .collect();
668
669        let mut output = vec![vec![0.0; head_dim]; n];
670        let mut attention_list: Vec<(usize, usize, f64)> = Vec::new();
671
672        for i in 0..n {
673            if adj[i].is_empty() {
674                continue;
675            }
676
677            // Compute attention scores for neighbors
678            let mut scores: Vec<f64> = Vec::with_capacity(adj[i].len());
679
680            for &j in &adj[i] {
681                // Concatenate Q_i and K_j, apply attention vector
682                let mut concat = queries[i].clone();
683                concat.extend(keys[j].iter().cloned());
684
685                let score: f64 = concat.iter().zip(attn_vec.iter()).map(|(c, a)| c * a).sum();
686
687                // LeakyReLU
688                let score = if score > 0.0 {
689                    score
690                } else {
691                    config.negative_slope * score
692                };
693
694                scores.push(score);
695            }
696
697            // Softmax over neighbors
698            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
699            let exp_scores: Vec<f64> = scores.iter().map(|s| (s - max_score).exp()).collect();
700            let sum_exp: f64 = exp_scores.iter().sum();
701            let attention: Vec<f64> = exp_scores.iter().map(|e| e / sum_exp).collect();
702
703            // Aggregate values weighted by attention
704            for (idx, &j) in adj[i].iter().enumerate() {
705                let attn = attention[idx];
706                attention_list.push((i, j, attn));
707
708                for (k, &v) in values[j].iter().enumerate() {
709                    output[i][k] += attn * v;
710                }
711            }
712        }
713
714        (output, attention_list)
715    }
716
717    /// Linear transformation.
718    fn linear_transform(input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
719        if weights.is_empty() {
720            return Vec::new();
721        }
722
723        let out_dim = weights[0].len();
724        let mut output = vec![0.0; out_dim];
725
726        for (i, &x) in input.iter().enumerate() {
727            if i < weights.len() {
728                for (j, &w) in weights[i].iter().enumerate() {
729                    output[j] += x * w;
730                }
731            }
732        }
733
734        output
735    }
736
737    /// Get node importance based on attention received.
738    pub fn node_importance(attention_weights: &[(usize, usize, f64)], n: usize) -> Vec<f64> {
739        let mut importance = vec![0.0; n];
740        let mut counts = vec![0usize; n];
741
742        for &(_, target, weight) in attention_weights {
743            if target < n {
744                importance[target] += weight;
745                counts[target] += 1;
746            }
747        }
748
749        // Normalize by count
750        for i in 0..n {
751            if counts[i] > 0 {
752                importance[i] /= counts[i] as f64;
753            }
754        }
755
756        importance
757    }
758}
759
impl GpuKernel for GraphAttention {
    // Exposes the static kernel descriptor built in `new()`.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
765
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Shared fixture: a 3-node undirected cycle (each node has degree 2).
    fn create_test_graph() -> CsrGraph {
        // Simple triangle graph: 0 -- 1 -- 2 -- 0
        CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
    }

    #[test]
    fn test_gnn_inference_metadata() {
        let kernel = GNNInference::new();
        assert_eq!(kernel.metadata().id, "graph/gnn-inference");
    }

    #[test]
    fn test_gnn_inference_basic() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GNNConfig {
            num_layers: 2,
            hidden_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GNNWeights::random(2, &config);
        let result = GNNInference::compute(&graph, &features, &weights, &config);

        // One embedding per node, sized by output_dim; output_dim > 1
        // triggers classification predictions.
        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 2);
        assert!(result.predictions.is_some());
    }

    #[test]
    fn test_gnn_aggregation_types() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        // Smoke-test every aggregation variant with a 1-layer network.
        for agg in [
            AggregationType::Sum,
            AggregationType::Mean,
            AggregationType::Max,
            AggregationType::SAGE,
        ] {
            let config = GNNConfig {
                aggregation: agg,
                num_layers: 1,
                hidden_dim: 4,
                output_dim: 2,
                ..Default::default()
            };

            let weights = GNNWeights::random(2, &config);
            let result = GNNInference::compute(&graph, &features, &weights, &config);

            assert_eq!(result.embeddings.len(), 3);
        }
    }

    #[test]
    fn test_gnn_empty_graph() {
        // Empty graph + empty features must short-circuit to an empty result.
        let graph = CsrGraph::empty();
        let features: Vec<Vec<f64>> = vec![];
        let config = GNNConfig::default();
        let weights = GNNWeights::random(2, &config);

        let result = GNNInference::compute(&graph, &features, &weights, &config);
        assert!(result.embeddings.is_empty());
    }

    #[test]
    fn test_graph_attention_metadata() {
        let kernel = GraphAttention::new();
        assert_eq!(kernel.metadata().id, "graph/graph-attention");
    }

    #[test]
    fn test_graph_attention_basic() {
        let graph = create_test_graph();
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let config = GraphAttentionConfig {
            num_heads: 2,
            head_dim: 4,
            output_dim: 3,
            ..Default::default()
        };

        let weights = GATWeights::random(4, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        // One embedding per node, projected to output_dim, with attention
        // triplets recorded for each head.
        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 3);
        assert!(!result.attention_weights.is_empty());
    }

    #[test]
    fn test_attention_weights_sum_to_one() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GraphAttentionConfig {
            num_heads: 1,
            head_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GATWeights::random(2, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        // Group attention weights by source node
        let mut sums: HashMap<usize, f64> = HashMap::new();
        for &(src, _, weight) in &result.attention_weights[0] {
            *sums.entry(src).or_insert(0.0) += weight;
        }

        // Each source's attention should sum to ~1
        for (_, sum) in sums {
            assert!(
                (sum - 1.0).abs() < 0.01,
                "Attention should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_node_importance() {
        let attn_weights = vec![
            (0, 1, 0.5),
            (0, 2, 0.5),
            (1, 0, 0.3),
            (1, 2, 0.7),
            (2, 0, 0.4),
            (2, 1, 0.6),
        ];

        let importance = GraphAttention::node_importance(&attn_weights, 3);

        assert_eq!(importance.len(), 3);
        // Node 2 receives more attention on average
        assert!(importance.iter().all(|&i| i >= 0.0));
    }

    #[test]
    fn test_activation_functions() {
        assert_eq!(GNNInference::activate(1.0, ActivationType::ReLU), 1.0);
        assert_eq!(GNNInference::activate(-1.0, ActivationType::ReLU), 0.0);
        assert!((GNNInference::activate(0.0, ActivationType::Sigmoid) - 0.5).abs() < 0.001);
        assert_eq!(GNNInference::activate(1.0, ActivationType::None), 1.0);
    }
}