rustkernel_graph/gnn.rs

//! Graph Neural Network kernels.
//!
//! This module provides GPU-accelerated GNN algorithms:
//! - GNNInference - Message passing neural network inference
//! - GraphAttention - Graph attention network (GAT) layers

use crate::types::CsrGraph;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};

// ============================================================================
// GNN Inference Kernel
// ============================================================================

/// Configuration for GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Number of message passing layers.
    pub num_layers: usize,
    /// Hidden dimension size.
    pub hidden_dim: usize,
    /// Output dimension (number of classes or embedding size).
    pub output_dim: usize,
    /// Aggregation function for messages.
    pub aggregation: AggregationType,
    /// Activation function.
    pub activation: ActivationType,
    /// Dropout rate (0-1); kept for configuration parity but not applied
    /// during inference in this implementation.
    pub dropout: f64,
    /// Whether to add self-loops.
    pub add_self_loops: bool,
    /// Whether to use layer normalization.
    pub layer_norm: bool,
}

impl Default for GNNConfig {
    fn default() -> Self {
        Self {
            num_layers: 2,
            hidden_dim: 64,
            output_dim: 32,
            aggregation: AggregationType::Mean,
            activation: ActivationType::ReLU,
            dropout: 0.0,
            add_self_loops: true,
            layer_norm: true,
        }
    }
}

/// Message aggregation type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AggregationType {
    /// Sum of neighbor messages.
    Sum,
    /// Mean of neighbor messages.
    Mean,
    /// Max pooling over neighbors.
    Max,
    /// GraphSAGE-style sample and aggregate.
    SAGE,
}

/// Activation function type.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ActivationType {
    /// Rectified Linear Unit.
    ReLU,
    /// Leaky ReLU with alpha=0.01.
    LeakyReLU,
    /// Exponential Linear Unit.
    ELU,
    /// Sigmoid function.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// No activation.
    None,
}

/// GNN layer weights (simulated).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNWeights {
    /// Weight matrices per layer, indexed as `[layer][input_dim][output_dim]`.
    pub layer_weights: Vec<Vec<Vec<f64>>>,
    /// Bias vectors per layer, indexed as `[layer][output_dim]`.
    pub layer_biases: Vec<Vec<f64>>,
}

impl GNNWeights {
    /// Create random weights for testing.
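    ///
    /// Weights are drawn uniformly from `(-s, s)` with
    /// `s = sqrt(2 / (fan_in + fan_out))` (a Xavier/Glorot-style bound);
    /// biases are initialized to zero.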
    pub fn random(input_dim: usize, config: &GNNConfig) -> Self {
        use rand::{Rng, rng};
        let mut r = rng();

        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();

        let mut prev_dim = input_dim;

        for i in 0..config.num_layers {
            let out_dim = if i == config.num_layers - 1 {
                config.output_dim
            } else {
                config.hidden_dim
            };

            // Xavier/Glorot-style initialization scale
            let scale = (2.0 / (prev_dim + out_dim) as f64).sqrt();

            let weights: Vec<Vec<f64>> = (0..prev_dim)
                .map(|_| {
                    (0..out_dim)
                        .map(|_| r.random_range(-scale..scale))
                        .collect()
                })
                .collect();

            let biases: Vec<f64> = (0..out_dim).map(|_| 0.0).collect();

            layer_weights.push(weights);
            layer_biases.push(biases);
            prev_dim = out_dim;
        }

        Self {
            layer_weights,
            layer_biases,
        }
    }
}

/// Result of GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNResult {
    /// Node embeddings after all layers.
    pub embeddings: Vec<Vec<f64>>,
    /// Class predictions (if classification).
    pub predictions: Option<Vec<usize>>,
    /// Softmax probabilities (if classification).
    pub probabilities: Option<Vec<Vec<f64>>>,
}

/// GNN Inference kernel.
///
/// Performs message passing neural network inference on graph data.
/// Supports various aggregation strategies and can be used for
/// node classification, link prediction, and graph-level tasks.
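///
/// # Example
///
/// An illustrative sketch (fenced `ignore` so it is not compiled as a
/// doc-test; the `rustkernel_graph` module paths are assumed from this
/// crate's layout):
///
/// ```ignore
/// use rustkernel_graph::types::CsrGraph;
/// use rustkernel_graph::gnn::{GNNConfig, GNNInference, GNNWeights};
///
/// // Triangle graph with 2-dimensional input features per node.
/// let graph = CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)]);
/// let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
///
/// // output_dim > 1, so `compute` also returns class predictions.
/// let config = GNNConfig { output_dim: 2, ..Default::default() };
/// let weights = GNNWeights::random(2, &config);
/// let result = GNNInference::compute(&graph, &features, &weights, &config);
///
/// assert_eq!(result.embeddings.len(), 3);
/// assert!(result.predictions.is_some());
/// ```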
#[derive(Debug, Clone)]
pub struct GNNInference {
    metadata: KernelMetadata,
}

impl Default for GNNInference {
    fn default() -> Self {
        Self::new()
    }
}

impl GNNInference {
    /// Create a new GNN Inference kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("graph/gnn-inference", Domain::GraphAnalytics)
                .with_description("Message passing neural network inference")
                .with_throughput(10_000)
                .with_latency_us(100.0),
        }
    }

    /// Run GNN inference on a graph.
    #[allow(clippy::needless_range_loop)]
    pub fn compute(
        graph: &CsrGraph,
        node_features: &[Vec<f64>],
        weights: &GNNWeights,
        config: &GNNConfig,
    ) -> GNNResult {
        if graph.num_nodes == 0 || node_features.is_empty() {
            return GNNResult {
                embeddings: Vec::new(),
                predictions: None,
                probabilities: None,
            };
        }

        let n = graph.num_nodes;

        // Build adjacency list from CSR format (with optional self-loops).
        // Each edge is mirrored so message passing treats the graph as
        // undirected; the `contains` checks also dedupe edges when the CSR
        // is already symmetric.
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
        for node in 0..n {
            let start = graph.row_offsets[node] as usize;
            let end = graph.row_offsets[node + 1] as usize;
            for &neighbor in &graph.col_indices[start..end] {
                let neighbor = neighbor as usize;
                if !adj[node].contains(&neighbor) {
                    adj[node].push(neighbor);
                }
                if !adj[neighbor].contains(&node) {
                    adj[neighbor].push(node);
                }
            }
        }

        if config.add_self_loops {
            for i in 0..n {
                if !adj[i].contains(&i) {
                    adj[i].push(i);
                }
            }
        }

        // Initialize embeddings from features
        let mut embeddings: Vec<Vec<f64>> = node_features.to_vec();

        // Run message passing layers
        for layer_idx in 0..config.num_layers {
            embeddings = Self::message_passing_layer(
                &embeddings,
                &adj,
                &weights.layer_weights[layer_idx],
                &weights.layer_biases[layer_idx],
                config,
                layer_idx == config.num_layers - 1,
            );
        }

        // Compute predictions if output looks like classification
        let (predictions, probabilities) = if config.output_dim > 1 {
            let probs: Vec<Vec<f64>> = embeddings.iter().map(|e| Self::softmax(e)).collect();
            let preds: Vec<usize> = probs
                .iter()
                .map(|p| {
                    p.iter()
                        .enumerate()
                        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                        .map(|(i, _)| i)
                        .unwrap_or(0)
                })
                .collect();
            (Some(preds), Some(probs))
        } else {
            (None, None)
        };

        GNNResult {
            embeddings,
            predictions,
            probabilities,
        }
    }

    /// Single message passing layer.
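    ///
    /// For each node `i`, aggregates neighbor embeddings with the configured
    /// aggregation and applies an affine transform plus activation:
    /// `out[j] = act(Σ_k weights[k][j] · agg[k] + biases[j])`.
    /// Activation and layer normalization are skipped on the last layer.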
    #[allow(clippy::needless_range_loop)]
    fn message_passing_layer(
        embeddings: &[Vec<f64>],
        adj: &[Vec<usize>],
        weights: &[Vec<f64>],
        biases: &[f64],
        config: &GNNConfig,
        is_last: bool,
    ) -> Vec<Vec<f64>> {
        let n = embeddings.len();
        let out_dim = biases.len();
        let mut new_embeddings = vec![vec![0.0; out_dim]; n];

        for i in 0..n {
            // Aggregate neighbor messages
            let aggregated = Self::aggregate_neighbors(embeddings, &adj[i], config.aggregation);

            // Transform: out = activation(W * aggregated + b)
            for j in 0..out_dim {
                let mut val = biases[j];
                for (k, &agg_val) in aggregated.iter().enumerate() {
                    if k < weights.len() && j < weights[k].len() {
                        val += weights[k][j] * agg_val;
                    }
                }

                // Apply activation (except on last layer if doing classification)
                if !is_last {
                    val = Self::activate(val, config.activation);
                }

                new_embeddings[i][j] = val;
            }

            // Layer normalization
            if config.layer_norm && !is_last {
                let mean: f64 = new_embeddings[i].iter().sum::<f64>() / out_dim as f64;
                let var: f64 = new_embeddings[i]
                    .iter()
                    .map(|x| (x - mean).powi(2))
                    .sum::<f64>()
                    / out_dim as f64;
                let std = (var + 1e-5).sqrt();

                for j in 0..out_dim {
                    new_embeddings[i][j] = (new_embeddings[i][j] - mean) / std;
                }
            }
        }

        new_embeddings
    }

    /// Aggregate messages from neighbors.
    fn aggregate_neighbors(
        embeddings: &[Vec<f64>],
        neighbors: &[usize],
        agg_type: AggregationType,
    ) -> Vec<f64> {
        if neighbors.is_empty() {
            return vec![0.0; embeddings.first().map(|e| e.len()).unwrap_or(0)];
        }

        let dim = embeddings[neighbors[0]].len();

        match agg_type {
            AggregationType::Sum => {
                let mut result = vec![0.0; dim];
                for &n in neighbors {
                    for (i, &v) in embeddings[n].iter().enumerate() {
                        result[i] += v;
                    }
                }
                result
            }
            AggregationType::Mean | AggregationType::SAGE => {
                // GraphSAGE aggregation is simplified to a plain mean here
                // (no concat with the node's own embedding).
                let mut result = vec![0.0; dim];
                for &n in neighbors {
                    for (i, &v) in embeddings[n].iter().enumerate() {
                        result[i] += v;
                    }
                }
                let count = neighbors.len() as f64;
                result.iter_mut().for_each(|v| *v /= count);
                result
            }
            AggregationType::Max => {
                let mut result = vec![f64::NEG_INFINITY; dim];
                for &n in neighbors {
                    for (i, &v) in embeddings[n].iter().enumerate() {
                        result[i] = result[i].max(v);
                    }
                }
                result
            }
        }
    }

    /// Apply activation function.
    fn activate(x: f64, activation: ActivationType) -> f64 {
        match activation {
            ActivationType::ReLU => x.max(0.0),
            ActivationType::LeakyReLU => {
                if x > 0.0 {
                    x
                } else {
                    0.01 * x
                }
            }
            ActivationType::ELU => {
                if x > 0.0 {
                    x
                } else {
                    x.exp() - 1.0
                }
            }
            ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationType::Tanh => x.tanh(),
            ActivationType::None => x,
        }
    }

    /// Softmax for classification.
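    ///
    /// Uses the numerically stable form
    /// `softmax(x)_i = exp(x_i - m) / Σ_j exp(x_j - m)` with `m = max_j x_j`,
    /// so the exponentials cannot overflow.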
    fn softmax(x: &[f64]) -> Vec<f64> {
        let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = x.iter().map(|v| (v - max_val).exp()).sum();
        x.iter().map(|v| (v - max_val).exp() / exp_sum).collect()
    }
}

impl GpuKernel for GNNInference {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// ============================================================================
// Graph Attention Kernel
// ============================================================================

/// Configuration for graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Hidden dimension per head.
    pub head_dim: usize,
    /// Output dimension.
    pub output_dim: usize,
    /// Dropout for attention weights (not applied at inference time in
    /// this implementation).
    pub attention_dropout: f64,
    /// Whether to concatenate heads or average.
    pub concat_heads: bool,
    /// Negative slope for LeakyReLU in attention.
    pub negative_slope: f64,
}

impl Default for GraphAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 4,
            head_dim: 16,
            output_dim: 64,
            attention_dropout: 0.0,
            concat_heads: true,
            negative_slope: 0.2,
        }
    }
}

/// Attention weights for GAT layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATWeights {
    /// Query transformation weights per head.
    pub query_weights: Vec<Vec<Vec<f64>>>,
    /// Key transformation weights per head.
    pub key_weights: Vec<Vec<Vec<f64>>>,
    /// Value transformation weights per head.
    pub value_weights: Vec<Vec<Vec<f64>>>,
    /// Attention vector per head (length `2 * head_dim`, applied to the
    /// concatenated query/key pair).
    pub attention_vectors: Vec<Vec<f64>>,
    /// Output projection weights.
    pub output_weights: Vec<Vec<f64>>,
}

impl GATWeights {
    /// Create random weights for testing.
    pub fn random(input_dim: usize, config: &GraphAttentionConfig) -> Self {
        use rand::{Rng, rng};
        let mut r = rng();

        let scale = (2.0 / (input_dim + config.head_dim) as f64).sqrt();

        let mut query_weights = Vec::new();
        let mut key_weights = Vec::new();
        let mut value_weights = Vec::new();
        let mut attention_vectors = Vec::new();

        for _ in 0..config.num_heads {
            let q: Vec<Vec<f64>> = (0..input_dim)
                .map(|_| {
                    (0..config.head_dim)
                        .map(|_| r.random_range(-scale..scale))
                        .collect()
                })
                .collect();
            let k: Vec<Vec<f64>> = (0..input_dim)
                .map(|_| {
                    (0..config.head_dim)
                        .map(|_| r.random_range(-scale..scale))
                        .collect()
                })
                .collect();
            let v: Vec<Vec<f64>> = (0..input_dim)
                .map(|_| {
                    (0..config.head_dim)
                        .map(|_| r.random_range(-scale..scale))
                        .collect()
                })
                .collect();
            let a: Vec<f64> = (0..config.head_dim * 2)
                .map(|_| r.random_range(-scale..scale))
                .collect();

            query_weights.push(q);
            key_weights.push(k);
            value_weights.push(v);
            attention_vectors.push(a);
        }

        let total_dim = if config.concat_heads {
            config.num_heads * config.head_dim
        } else {
            config.head_dim
        };

        let out_scale = (2.0 / (total_dim + config.output_dim) as f64).sqrt();
        let output_weights: Vec<Vec<f64>> = (0..total_dim)
            .map(|_| {
                (0..config.output_dim)
                    .map(|_| r.random_range(-out_scale..out_scale))
                    .collect()
            })
            .collect();

        Self {
            query_weights,
            key_weights,
            value_weights,
            attention_vectors,
            output_weights,
        }
    }
}

/// Result of graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATResult {
    /// Output embeddings.
    pub embeddings: Vec<Vec<f64>>,
    /// Attention weights per head (source, target, weight).
    pub attention_weights: Vec<Vec<(usize, usize, f64)>>,
}

/// Graph Attention kernel.
///
/// Implements Graph Attention Networks (GAT) with multi-head attention.
/// Learns to weight neighbor contributions based on their relevance
/// to each node.
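///
/// # Example
///
/// An illustrative sketch (fenced `ignore` so it is not compiled as a
/// doc-test; the `rustkernel_graph` module paths are assumed):
///
/// ```ignore
/// use rustkernel_graph::types::CsrGraph;
/// use rustkernel_graph::gnn::{GATWeights, GraphAttention, GraphAttentionConfig};
///
/// let graph = CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)]);
/// let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
///
/// let config = GraphAttentionConfig::default();
/// let weights = GATWeights::random(2, &config);
/// let result = GraphAttention::compute(&graph, &features, &weights, &config);
///
/// // One attention-weight list per head.
/// assert_eq!(result.attention_weights.len(), config.num_heads);
/// ```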
#[derive(Debug, Clone)]
pub struct GraphAttention {
    metadata: KernelMetadata,
}

impl Default for GraphAttention {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphAttention {
    /// Create a new Graph Attention kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("graph/graph-attention", Domain::GraphAnalytics)
                .with_description("Graph attention networks with multi-head attention")
                .with_throughput(5_000)
                .with_latency_us(200.0),
        }
    }

    /// Compute graph attention layer.
    #[allow(clippy::needless_range_loop)]
    pub fn compute(
        graph: &CsrGraph,
        node_features: &[Vec<f64>],
        weights: &GATWeights,
        config: &GraphAttentionConfig,
    ) -> GATResult {
        if graph.num_nodes == 0 || node_features.is_empty() {
            return GATResult {
                embeddings: Vec::new(),
                attention_weights: Vec::new(),
            };
        }

        let n = graph.num_nodes;

        // Build adjacency with self-loops from CSR format. Edges are
        // mirrored (undirected) and deduped, as in `GNNInference::compute`.
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
        for node in 0..n {
            let start = graph.row_offsets[node] as usize;
            let end = graph.row_offsets[node + 1] as usize;
            for &neighbor in &graph.col_indices[start..end] {
                let neighbor = neighbor as usize;
                if !adj[node].contains(&neighbor) {
                    adj[node].push(neighbor);
                }
                if !adj[neighbor].contains(&node) {
                    adj[neighbor].push(node);
                }
            }
        }
        for i in 0..n {
            if !adj[i].contains(&i) {
                adj[i].push(i);
            }
        }

        // Compute attention for each head
        let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::new();
        let mut all_attention_weights: Vec<Vec<(usize, usize, f64)>> = Vec::new();

        for head in 0..config.num_heads {
            let (output, attn_weights) = Self::compute_head(
                node_features,
                &adj,
                &weights.query_weights[head],
                &weights.key_weights[head],
                &weights.value_weights[head],
                &weights.attention_vectors[head],
                config,
            );
            head_outputs.push(output);
            all_attention_weights.push(attn_weights);
        }

        // Combine heads
        let combined: Vec<Vec<f64>> = if config.concat_heads {
            (0..n)
                .map(|i| head_outputs.iter().flat_map(|h| h[i].clone()).collect())
                .collect()
        } else {
            // Average heads
            (0..n)
                .map(|i| {
                    let dim = head_outputs[0][i].len();
                    let mut avg = vec![0.0; dim];
                    for h in &head_outputs {
                        for (j, &v) in h[i].iter().enumerate() {
                            avg[j] += v;
                        }
                    }
                    avg.iter_mut().for_each(|v| *v /= config.num_heads as f64);
                    avg
                })
                .collect()
        };

        // Output projection
        let embeddings: Vec<Vec<f64>> = combined
            .iter()
            .map(|c| Self::linear_transform(c, &weights.output_weights))
            .collect();

        GATResult {
            embeddings,
            attention_weights: all_attention_weights,
        }
    }

    /// Compute single attention head.
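    ///
    /// Follows the GAT-style scoring, here with separate Q/K/V projections
    /// per head: `e_ij = LeakyReLU(a · [Q_i ‖ K_j])`, then
    /// `α_ij = softmax_j(e_ij)` over the neighbors of `i`, and finally
    /// `out_i = Σ_j α_ij · V_j`.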
    #[allow(clippy::type_complexity)]
    fn compute_head(
        features: &[Vec<f64>],
        adj: &[Vec<usize>],
        query_w: &[Vec<f64>],
        key_w: &[Vec<f64>],
        value_w: &[Vec<f64>],
        attn_vec: &[f64],
        config: &GraphAttentionConfig,
    ) -> (Vec<Vec<f64>>, Vec<(usize, usize, f64)>) {
        let n = features.len();
        let head_dim = config.head_dim;

        // Transform features to Q, K, V
        let queries: Vec<Vec<f64>> = features
            .iter()
            .map(|f| Self::linear_transform(f, query_w))
            .collect();
        let keys: Vec<Vec<f64>> = features
            .iter()
            .map(|f| Self::linear_transform(f, key_w))
            .collect();
        let values: Vec<Vec<f64>> = features
            .iter()
            .map(|f| Self::linear_transform(f, value_w))
            .collect();

        let mut output = vec![vec![0.0; head_dim]; n];
        let mut attention_list: Vec<(usize, usize, f64)> = Vec::new();

        for i in 0..n {
            if adj[i].is_empty() {
                continue;
            }

            // Compute attention scores for neighbors
            let mut scores: Vec<f64> = Vec::with_capacity(adj[i].len());

            for &j in &adj[i] {
                // Concatenate Q_i and K_j, apply attention vector
                let mut concat = queries[i].clone();
                concat.extend(keys[j].iter().cloned());

                let score: f64 = concat.iter().zip(attn_vec.iter()).map(|(c, a)| c * a).sum();

                // LeakyReLU
                let score = if score > 0.0 {
                    score
                } else {
                    config.negative_slope * score
                };

                scores.push(score);
            }

            // Softmax over neighbors
            let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<f64> = scores.iter().map(|s| (s - max_score).exp()).collect();
            let sum_exp: f64 = exp_scores.iter().sum();
            let attention: Vec<f64> = exp_scores.iter().map(|e| e / sum_exp).collect();

            // Aggregate values weighted by attention
            for (idx, &j) in adj[i].iter().enumerate() {
                let attn = attention[idx];
                attention_list.push((i, j, attn));

                for (k, &v) in values[j].iter().enumerate() {
                    output[i][k] += attn * v;
                }
            }
        }

        (output, attention_list)
    }

    /// Linear transformation.
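    ///
    /// Computes `output[j] = Σ_i input[i] · weights[i][j]`, i.e. a row
    /// vector times an `input_dim × output_dim` weight matrix.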
    fn linear_transform(input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
        if weights.is_empty() {
            return Vec::new();
        }

        let out_dim = weights[0].len();
        let mut output = vec![0.0; out_dim];

        for (i, &x) in input.iter().enumerate() {
            if i < weights.len() {
                for (j, &w) in weights[i].iter().enumerate() {
                    output[j] += x * w;
                }
            }
        }

        output
    }

    /// Get node importance based on attention received.
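    ///
    /// Importance of node `v` is the mean weight over all recorded
    /// `(source, v, weight)` entries; nodes that receive no attention
    /// score 0.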
    pub fn node_importance(attention_weights: &[(usize, usize, f64)], n: usize) -> Vec<f64> {
        let mut importance = vec![0.0; n];
        let mut counts = vec![0usize; n];

        for &(_, target, weight) in attention_weights {
            if target < n {
                importance[target] += weight;
                counts[target] += 1;
            }
        }

        // Normalize by count
        for i in 0..n {
            if counts[i] > 0 {
                importance[i] /= counts[i] as f64;
            }
        }

        importance
    }
}

impl GpuKernel for GraphAttention {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    fn create_test_graph() -> CsrGraph {
        // Simple triangle graph: 0 -- 1 -- 2 -- 0
        CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
    }

    #[test]
    fn test_gnn_inference_metadata() {
        let kernel = GNNInference::new();
        assert_eq!(kernel.metadata().id, "graph/gnn-inference");
    }

    #[test]
    fn test_gnn_inference_basic() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GNNConfig {
            num_layers: 2,
            hidden_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GNNWeights::random(2, &config);
        let result = GNNInference::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 2);
        assert!(result.predictions.is_some());
    }

    #[test]
    fn test_gnn_aggregation_types() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        for agg in [
            AggregationType::Sum,
            AggregationType::Mean,
            AggregationType::Max,
            AggregationType::SAGE,
        ] {
            let config = GNNConfig {
                aggregation: agg,
                num_layers: 1,
                hidden_dim: 4,
                output_dim: 2,
                ..Default::default()
            };

            let weights = GNNWeights::random(2, &config);
            let result = GNNInference::compute(&graph, &features, &weights, &config);

            assert_eq!(result.embeddings.len(), 3);
        }
    }

    #[test]
    fn test_gnn_empty_graph() {
        let graph = CsrGraph::empty();
        let features: Vec<Vec<f64>> = vec![];
        let config = GNNConfig::default();
        let weights = GNNWeights::random(2, &config);

        let result = GNNInference::compute(&graph, &features, &weights, &config);
        assert!(result.embeddings.is_empty());
    }

    #[test]
    fn test_graph_attention_metadata() {
        let kernel = GraphAttention::new();
        assert_eq!(kernel.metadata().id, "graph/graph-attention");
    }

    #[test]
    fn test_graph_attention_basic() {
        let graph = create_test_graph();
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let config = GraphAttentionConfig {
            num_heads: 2,
            head_dim: 4,
            output_dim: 3,
            ..Default::default()
        };

        let weights = GATWeights::random(4, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 3);
        assert!(!result.attention_weights.is_empty());
    }

    #[test]
    fn test_attention_weights_sum_to_one() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GraphAttentionConfig {
            num_heads: 1,
            head_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GATWeights::random(2, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        // Group attention weights by source node
        let mut sums: HashMap<usize, f64> = HashMap::new();
        for &(src, _, weight) in &result.attention_weights[0] {
            *sums.entry(src).or_insert(0.0) += weight;
        }

        // Each source's attention should sum to ~1
        for (_, sum) in sums {
            assert!(
                (sum - 1.0).abs() < 0.01,
                "Attention should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_node_importance() {
        let attn_weights = vec![
            (0, 1, 0.5),
            (0, 2, 0.5),
            (1, 0, 0.3),
            (1, 2, 0.7),
            (2, 0, 0.4),
            (2, 1, 0.6),
        ];

        let importance = GraphAttention::node_importance(&attn_weights, 3);

        assert_eq!(importance.len(), 3);
        // Node 2 receives the most attention on average:
        // (0.5 + 0.7) / 2 = 0.6 vs 0.55 for node 1 and 0.35 for node 0.
        assert!(importance[2] > importance[1]);
        assert!(importance[1] > importance[0]);
        assert!(importance.iter().all(|&i| i >= 0.0));
    }

    #[test]
    fn test_activation_functions() {
        assert_eq!(GNNInference::activate(1.0, ActivationType::ReLU), 1.0);
        assert_eq!(GNNInference::activate(-1.0, ActivationType::ReLU), 0.0);
        assert!((GNNInference::activate(0.0, ActivationType::Sigmoid) - 0.5).abs() < 0.001);
        assert_eq!(GNNInference::activate(1.0, ActivationType::None), 1.0);
    }
}