Skip to main content

torsh_graph/
lib.rs

//! Graph Neural Network components for ToRSh
//!
//! This crate provides PyTorch-compatible graph neural network layers and operations,
//! built on top of SciRS2's graph algorithms and spectral methods.
//!
//! # Enhanced Features
//! - GPU acceleration for graph operations
//! - Memory-efficient sparse representations
//! - Graph attention visualization
//! - Batch processing capabilities

12pub mod classification;
13// pub mod continuous_time; // Continuous-time graph networks (TGN, Neural ODE) - TODO: Fix API compatibility
14pub mod conv;
15pub mod data;
16pub mod datasets;
17// pub mod diffusion; // Graph diffusion models (DDPM, DDIM, discrete diffusion) - TODO: Fix API compatibility
18pub mod distributed; // Distributed graph neural networks
19pub mod enhanced_scirs2_integration; // Full SciRS2 algorithm suite
20                                     // pub mod equivariant; // Equivariant graph neural networks (EGNN, SchNet) - TODO: Fix API compatibility
21pub mod explainability;
22pub mod foundation; // Graph foundation models and self-supervised learning
23pub mod functional;
24pub mod generative; // Graph generation models (VAE, GAN)
25pub mod geometric; // Geometric graph neural networks
26pub mod hypergraph;
27pub mod jit;
28pub mod lottery_ticket; // Graph lottery ticket hypothesis and pruning
29pub mod matching; // Graph matching and similarity learning
30pub mod multimodal;
31pub mod neural_operators;
32pub mod neuromorphic;
33pub mod optimal_transport; // Graph optimal transport (Gromov-Wasserstein, Sinkhorn)
34pub mod parameter;
35pub mod pool;
36pub mod quantum;
37pub mod scirs2_integration;
38pub mod spectral; // Spectral graph methods
39pub mod temporal;
40pub mod utils;
41
42use torsh_tensor::Tensor;
43// Enhanced SciRS2 integration for performance optimization
44// use scirs2_core::gpu::{GpuContext, GpuBuffer}; // Will be used when available
45// use scirs2_core::memory_efficient::MemoryMappedArray; // Will be used when available
46
47/// Graph data structure for GNNs
48#[derive(Debug, Clone)]
49pub struct GraphData {
50    /// Node feature matrix (num_nodes x num_features)
51    pub x: Tensor,
52    /// Edge index matrix (2 x num_edges)
53    pub edge_index: Tensor,
54    /// Edge features (optional)
55    pub edge_attr: Option<Tensor>,
56    /// Batch assignment vector (optional, for batched graphs)
57    pub batch: Option<Tensor>,
58    /// Number of nodes
59    pub num_nodes: usize,
60    /// Number of edges
61    pub num_edges: usize,
62}
63
64impl GraphData {
65    /// Create a new graph data structure
66    ///
67    /// # Arguments
68    /// * `x` - Node feature matrix with shape `[num_nodes, num_features]`
69    /// * `edge_index` - Edge connectivity with shape `[2, num_edges]`
70    ///
71    /// # Returns
72    /// A new `GraphData` instance
73    ///
74    /// # Example
75    /// ```
76    /// use torsh_graph::GraphData;
77    /// use torsh_tensor::creation::from_vec;
78    /// use torsh_core::device::DeviceType;
79    ///
80    /// // Create a simple triangle graph
81    /// let x = from_vec(vec![1.0, 2.0, 3.0], &[3, 1], DeviceType::Cpu).unwrap();
82    /// let edge_index = from_vec(
83    ///     vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], // src, dst
84    ///     &[2, 3],
85    ///     DeviceType::Cpu
86    /// ).unwrap();
87    ///
88    /// let graph = GraphData::new(x, edge_index);
89    /// assert_eq!(graph.num_nodes, 3);
90    /// assert_eq!(graph.num_edges, 3);
91    /// ```
92    pub fn new(x: Tensor, edge_index: Tensor) -> Self {
93        let num_nodes = x.shape().dims()[0];
94        let num_edges = edge_index.shape().dims()[1];
95
96        Self {
97            x,
98            edge_index,
99            edge_attr: None,
100            batch: None,
101            num_nodes,
102            num_edges,
103        }
104    }
105
106    /// Add edge attributes
107    pub fn with_edge_attr(mut self, edge_attr: Tensor) -> Self {
108        self.edge_attr = Some(edge_attr);
109        self
110    }
111
112    /// Add batch assignment
113    pub fn with_batch(mut self, batch: Tensor) -> Self {
114        self.batch = Some(batch);
115        self
116    }
117
118    /// Add edge attributes (optional chaining)
119    pub fn with_edge_attr_opt(mut self, edge_attr: Option<Tensor>) -> Self {
120        self.edge_attr = edge_attr;
121        self
122    }
123
124    /// Get memory usage statistics
125    pub fn memory_stats(&self) -> GraphMemoryStats {
126        let x_bytes = self.x.numel() * std::mem::size_of::<f32>(); // Assuming f32
127        let edge_index_bytes = self.edge_index.numel() * std::mem::size_of::<f32>(); // Changed to f32
128        let edge_attr_bytes = self
129            .edge_attr
130            .as_ref()
131            .map(|t| t.numel() * std::mem::size_of::<f32>())
132            .unwrap_or(0);
133        let batch_bytes = self
134            .batch
135            .as_ref()
136            .map(|t| t.numel() * std::mem::size_of::<f32>())
137            .unwrap_or(0);
138
139        GraphMemoryStats {
140            total_bytes: x_bytes + edge_index_bytes + edge_attr_bytes + batch_bytes,
141            node_features_bytes: x_bytes,
142            edge_index_bytes,
143            edge_attr_bytes,
144            batch_bytes,
145        }
146    }
147
148    /// Validate graph structure
149    ///
150    /// Checks that:
151    /// - All edge indices refer to valid nodes
152    /// - Edge attributes (if present) match the number of edges
153    ///
154    /// # Returns
155    /// * `Ok(())` - Graph structure is valid
156    /// * `Err(GraphValidationError)` - Validation failed
157    ///
158    /// # Example
159    /// ```
160    /// use torsh_graph::GraphData;
161    /// use torsh_tensor::creation::from_vec;
162    /// use torsh_core::device::DeviceType;
163    ///
164    /// let x = from_vec(vec![1.0, 2.0], &[2, 1], DeviceType::Cpu).unwrap();
165    /// let edge_index = from_vec(vec![0.0, 1.0], &[2, 1], DeviceType::Cpu).unwrap();
166    /// let graph = GraphData::new(x, edge_index);
167    ///
168    /// assert!(graph.validate().is_ok());
169    /// ```
170    pub fn validate(&self) -> Result<(), GraphValidationError> {
171        // Check edge indices are within node range
172        if let Ok(edge_data) = self.edge_index.to_vec() {
173            let max_node_id = edge_data.iter().fold(0.0f32, |a, &b| a.max(b));
174            if max_node_id >= self.num_nodes as f32 {
175                return Err(GraphValidationError::InvalidNodeIndex {
176                    node_id: max_node_id as i64,
177                    num_nodes: self.num_nodes,
178                });
179            }
180        }
181
182        // Validate edge attributes shape if present
183        if let Some(ref edge_attr) = self.edge_attr {
184            if edge_attr.shape().dims()[0] != self.num_edges {
185                return Err(GraphValidationError::EdgeAttrSizeMismatch {
186                    expected: self.num_edges,
187                    actual: edge_attr.shape().dims()[0],
188                });
189            }
190        }
191
192        Ok(())
193    }
194}
195
/// Memory usage statistics for a graph, as returned by `GraphData::memory_stats`.
///
/// Each figure is an estimate computed as element count × `size_of::<f32>()`.
/// Absent optional tensors contribute 0.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphMemoryStats {
    /// Sum of all component byte counts below.
    pub total_bytes: usize,
    /// Bytes attributed to the node feature matrix `x`.
    pub node_features_bytes: usize,
    /// Bytes attributed to `edge_index`.
    pub edge_index_bytes: usize,
    /// Bytes attributed to `edge_attr` (0 when absent).
    pub edge_attr_bytes: usize,
    /// Bytes attributed to `batch` (0 when absent).
    pub batch_bytes: usize,
}
205
206/// Graph validation errors
207#[derive(Debug, thiserror::Error)]
208pub enum GraphValidationError {
209    #[error("Invalid node index {node_id}, graph has only {num_nodes} nodes")]
210    InvalidNodeIndex { node_id: i64, num_nodes: usize },
211
212    #[error("Edge attribute size mismatch: expected {expected}, got {actual}")]
213    EdgeAttrSizeMismatch { expected: usize, actual: usize },
214
215    #[error("Tensor operation error: {0}")]
216    TensorError(String),
217}
218
219/// Trait for graph neural network layers
220pub trait GraphLayer: std::fmt::Debug {
221    /// Forward pass through the layer
222    fn forward(&self, graph: &GraphData) -> GraphData;
223
224    /// Get layer parameters
225    fn parameters(&self) -> Vec<Tensor>;
226}
227
228/// Graph attention visualization utilities
229pub mod attention_viz {
230
231    use torsh_tensor::Tensor;
232
233    /// Attention weights for visualization
234    #[derive(Debug, Clone)]
235    pub struct AttentionWeights {
236        pub edge_weights: Tensor,         // [num_edges]
237        pub node_weights: Option<Tensor>, // [num_nodes]
238        pub layer_name: String,
239        pub head_index: Option<usize>,
240    }
241
242    impl AttentionWeights {
243        pub fn new(edge_weights: Tensor, layer_name: String) -> Self {
244            Self {
245                edge_weights,
246                node_weights: None,
247                layer_name,
248                head_index: None,
249            }
250        }
251
252        pub fn with_node_weights(mut self, node_weights: Tensor) -> Self {
253            self.node_weights = Some(node_weights);
254            self
255        }
256
257        pub fn with_head_index(mut self, head_index: usize) -> Self {
258            self.head_index = Some(head_index);
259            self
260        }
261
262        /// Normalize attention weights for visualization
263        pub fn normalize(&self) -> Self {
264            // Simplified normalization - just return clone for now due to tensor API limitations
265            Self {
266                edge_weights: self.edge_weights.clone(),
267                node_weights: self.node_weights.clone(),
268                layer_name: self.layer_name.clone(),
269                head_index: self.head_index,
270            }
271        }
272    }
273}
274
275/// Node importance analysis utilities
276pub mod importance_analysis {
277
278    use torsh_tensor::Tensor;
279
280    /// Node importance metrics
281    #[derive(Debug, Clone)]
282    pub struct NodeImportance {
283        pub centrality_scores: Tensor,           // [num_nodes]
284        pub gradient_norm: Option<Tensor>,       // [num_nodes]
285        pub attention_sum: Option<Tensor>,       // [num_nodes]
286        pub feature_attribution: Option<Tensor>, // [num_nodes, num_features]
287    }
288
289    impl NodeImportance {
290        pub fn new(centrality_scores: Tensor) -> Self {
291            Self {
292                centrality_scores,
293                gradient_norm: None,
294                attention_sum: None,
295                feature_attribution: None,
296            }
297        }
298
299        /// Combine multiple importance metrics
300        pub fn combined_importance(
301            &self,
302            weights: &[f32],
303        ) -> Result<Tensor, Box<dyn std::error::Error>> {
304            // Simplified implementation - just return weighted centrality for now
305            Ok(self.centrality_scores.mul_scalar(weights[0])?)
306        }
307    }
308}