// torsh_graph/lib.rs
1//! Graph Neural Network components for ToRSh
2//!
3//! This module provides PyTorch-compatible graph neural network layers and operations,
4//! built on top of SciRS2's graph algorithms and spectral methods.
5//!
6//! # Enhanced Features:
7//! - GPU acceleration for graph operations
8//! - Memory-efficient sparse representations
9//! - Graph attention visualization
10//! - Batch processing capabilities
11
12pub mod classification;
13// pub mod continuous_time; // Continuous-time graph networks (TGN, Neural ODE) - TODO: Fix API compatibility
14pub mod conv;
15pub mod data;
16pub mod datasets;
17// pub mod diffusion; // Graph diffusion models (DDPM, DDIM, discrete diffusion) - TODO: Fix API compatibility
18pub mod distributed; // Distributed graph neural networks
19pub mod enhanced_scirs2_integration; // Full SciRS2 algorithm suite
20 // pub mod equivariant; // Equivariant graph neural networks (EGNN, SchNet) - TODO: Fix API compatibility
21pub mod explainability;
22pub mod foundation; // Graph foundation models and self-supervised learning
23pub mod functional;
24pub mod generative; // Graph generation models (VAE, GAN)
25pub mod geometric; // Geometric graph neural networks
26pub mod hypergraph;
27pub mod jit;
28pub mod lottery_ticket; // Graph lottery ticket hypothesis and pruning
29pub mod matching; // Graph matching and similarity learning
30pub mod multimodal;
31pub mod neural_operators;
32pub mod neuromorphic;
33pub mod optimal_transport; // Graph optimal transport (Gromov-Wasserstein, Sinkhorn)
34pub mod parameter;
35pub mod pool;
36pub mod quantum;
37pub mod scirs2_integration;
38pub mod spectral; // Spectral graph methods
39pub mod temporal;
40pub mod utils;
41
42use torsh_tensor::Tensor;
43// Enhanced SciRS2 integration for performance optimization
44// use scirs2_core::gpu::{GpuContext, GpuBuffer}; // Will be used when available
45// use scirs2_core::memory_efficient::MemoryMappedArray; // Will be used when available
46
/// Graph data structure for GNNs
///
/// Holds node features, edge connectivity, and optional per-edge features
/// and per-node batch assignments for batched (disjoint-union) graphs.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// Node feature matrix (num_nodes x num_features)
    pub x: Tensor,
    /// Edge index matrix (2 x num_edges); by the convention used in the
    /// doc examples, row 0 holds source nodes and row 1 destinations
    pub edge_index: Tensor,
    /// Edge features (optional); leading dimension must equal `num_edges`
    /// (enforced by `validate`)
    pub edge_attr: Option<Tensor>,
    /// Batch assignment vector (optional, for batched graphs)
    pub batch: Option<Tensor>,
    /// Number of nodes (taken from `x`'s first dimension in `new`)
    pub num_nodes: usize,
    /// Number of edges (taken from `edge_index`'s second dimension in `new`)
    pub num_edges: usize,
}
63
64impl GraphData {
65 /// Create a new graph data structure
66 ///
67 /// # Arguments
68 /// * `x` - Node feature matrix with shape `[num_nodes, num_features]`
69 /// * `edge_index` - Edge connectivity with shape `[2, num_edges]`
70 ///
71 /// # Returns
72 /// A new `GraphData` instance
73 ///
74 /// # Example
75 /// ```
76 /// use torsh_graph::GraphData;
77 /// use torsh_tensor::creation::from_vec;
78 /// use torsh_core::device::DeviceType;
79 ///
80 /// // Create a simple triangle graph
81 /// let x = from_vec(vec![1.0, 2.0, 3.0], &[3, 1], DeviceType::Cpu).unwrap();
82 /// let edge_index = from_vec(
83 /// vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], // src, dst
84 /// &[2, 3],
85 /// DeviceType::Cpu
86 /// ).unwrap();
87 ///
88 /// let graph = GraphData::new(x, edge_index);
89 /// assert_eq!(graph.num_nodes, 3);
90 /// assert_eq!(graph.num_edges, 3);
91 /// ```
92 pub fn new(x: Tensor, edge_index: Tensor) -> Self {
93 let num_nodes = x.shape().dims()[0];
94 let num_edges = edge_index.shape().dims()[1];
95
96 Self {
97 x,
98 edge_index,
99 edge_attr: None,
100 batch: None,
101 num_nodes,
102 num_edges,
103 }
104 }
105
106 /// Add edge attributes
107 pub fn with_edge_attr(mut self, edge_attr: Tensor) -> Self {
108 self.edge_attr = Some(edge_attr);
109 self
110 }
111
112 /// Add batch assignment
113 pub fn with_batch(mut self, batch: Tensor) -> Self {
114 self.batch = Some(batch);
115 self
116 }
117
118 /// Add edge attributes (optional chaining)
119 pub fn with_edge_attr_opt(mut self, edge_attr: Option<Tensor>) -> Self {
120 self.edge_attr = edge_attr;
121 self
122 }
123
124 /// Get memory usage statistics
125 pub fn memory_stats(&self) -> GraphMemoryStats {
126 let x_bytes = self.x.numel() * std::mem::size_of::<f32>(); // Assuming f32
127 let edge_index_bytes = self.edge_index.numel() * std::mem::size_of::<f32>(); // Changed to f32
128 let edge_attr_bytes = self
129 .edge_attr
130 .as_ref()
131 .map(|t| t.numel() * std::mem::size_of::<f32>())
132 .unwrap_or(0);
133 let batch_bytes = self
134 .batch
135 .as_ref()
136 .map(|t| t.numel() * std::mem::size_of::<f32>())
137 .unwrap_or(0);
138
139 GraphMemoryStats {
140 total_bytes: x_bytes + edge_index_bytes + edge_attr_bytes + batch_bytes,
141 node_features_bytes: x_bytes,
142 edge_index_bytes,
143 edge_attr_bytes,
144 batch_bytes,
145 }
146 }
147
148 /// Validate graph structure
149 ///
150 /// Checks that:
151 /// - All edge indices refer to valid nodes
152 /// - Edge attributes (if present) match the number of edges
153 ///
154 /// # Returns
155 /// * `Ok(())` - Graph structure is valid
156 /// * `Err(GraphValidationError)` - Validation failed
157 ///
158 /// # Example
159 /// ```
160 /// use torsh_graph::GraphData;
161 /// use torsh_tensor::creation::from_vec;
162 /// use torsh_core::device::DeviceType;
163 ///
164 /// let x = from_vec(vec![1.0, 2.0], &[2, 1], DeviceType::Cpu).unwrap();
165 /// let edge_index = from_vec(vec![0.0, 1.0], &[2, 1], DeviceType::Cpu).unwrap();
166 /// let graph = GraphData::new(x, edge_index);
167 ///
168 /// assert!(graph.validate().is_ok());
169 /// ```
170 pub fn validate(&self) -> Result<(), GraphValidationError> {
171 // Check edge indices are within node range
172 if let Ok(edge_data) = self.edge_index.to_vec() {
173 let max_node_id = edge_data.iter().fold(0.0f32, |a, &b| a.max(b));
174 if max_node_id >= self.num_nodes as f32 {
175 return Err(GraphValidationError::InvalidNodeIndex {
176 node_id: max_node_id as i64,
177 num_nodes: self.num_nodes,
178 });
179 }
180 }
181
182 // Validate edge attributes shape if present
183 if let Some(ref edge_attr) = self.edge_attr {
184 if edge_attr.shape().dims()[0] != self.num_edges {
185 return Err(GraphValidationError::EdgeAttrSizeMismatch {
186 expected: self.num_edges,
187 actual: edge_attr.shape().dims()[0],
188 });
189 }
190 }
191
192 Ok(())
193 }
194}
195
/// Memory usage statistics for a graph
///
/// All sizes are in bytes, as computed by `GraphData::memory_stats`
/// (which assumes f32 elements for every tensor).
#[derive(Debug, Clone)]
pub struct GraphMemoryStats {
    /// Sum of all component sizes below
    pub total_bytes: usize,
    /// Bytes used by the node feature matrix `x`
    pub node_features_bytes: usize,
    /// Bytes used by the edge connectivity matrix
    pub edge_index_bytes: usize,
    /// Bytes used by edge attributes (0 when absent)
    pub edge_attr_bytes: usize,
    /// Bytes used by the batch assignment vector (0 when absent)
    pub batch_bytes: usize,
}
205
/// Graph validation errors
///
/// Returned by `GraphData::validate`.
#[derive(Debug, thiserror::Error)]
pub enum GraphValidationError {
    /// An edge refers to a node id outside the graph's node range.
    #[error("Invalid node index {node_id}, graph has only {num_nodes} nodes")]
    InvalidNodeIndex { node_id: i64, num_nodes: usize },

    /// `edge_attr`'s leading dimension does not equal the edge count.
    #[error("Edge attribute size mismatch: expected {expected}, got {actual}")]
    EdgeAttrSizeMismatch { expected: usize, actual: usize },

    /// An underlying tensor operation failed during validation.
    #[error("Tensor operation error: {0}")]
    TensorError(String),
}
218
/// Trait for graph neural network layers
///
/// A layer consumes a `GraphData` and produces a new one; implementors
/// must also be `Debug` for diagnostics.
pub trait GraphLayer: std::fmt::Debug {
    /// Forward pass through the layer
    ///
    /// Returns a new graph; the input is borrowed and left unmodified.
    fn forward(&self, graph: &GraphData) -> GraphData;

    /// Get layer parameters
    ///
    /// Returns the layer's parameter tensors (presumably the trainable
    /// weights — confirm against concrete implementors).
    fn parameters(&self) -> Vec<Tensor>;
}
227
228/// Graph attention visualization utilities
229pub mod attention_viz {
230
231 use torsh_tensor::Tensor;
232
233 /// Attention weights for visualization
234 #[derive(Debug, Clone)]
235 pub struct AttentionWeights {
236 pub edge_weights: Tensor, // [num_edges]
237 pub node_weights: Option<Tensor>, // [num_nodes]
238 pub layer_name: String,
239 pub head_index: Option<usize>,
240 }
241
242 impl AttentionWeights {
243 pub fn new(edge_weights: Tensor, layer_name: String) -> Self {
244 Self {
245 edge_weights,
246 node_weights: None,
247 layer_name,
248 head_index: None,
249 }
250 }
251
252 pub fn with_node_weights(mut self, node_weights: Tensor) -> Self {
253 self.node_weights = Some(node_weights);
254 self
255 }
256
257 pub fn with_head_index(mut self, head_index: usize) -> Self {
258 self.head_index = Some(head_index);
259 self
260 }
261
262 /// Normalize attention weights for visualization
263 pub fn normalize(&self) -> Self {
264 // Simplified normalization - just return clone for now due to tensor API limitations
265 Self {
266 edge_weights: self.edge_weights.clone(),
267 node_weights: self.node_weights.clone(),
268 layer_name: self.layer_name.clone(),
269 head_index: self.head_index,
270 }
271 }
272 }
273}
274
275/// Node importance analysis utilities
276pub mod importance_analysis {
277
278 use torsh_tensor::Tensor;
279
    /// Node importance metrics
    ///
    /// Aggregates several per-node importance signals; only centrality is
    /// mandatory, the rest are filled in as they become available.
    #[derive(Debug, Clone)]
    pub struct NodeImportance {
        /// Centrality score per node: [num_nodes]
        pub centrality_scores: Tensor,
        /// Gradient-norm importance per node (optional): [num_nodes]
        pub gradient_norm: Option<Tensor>,
        /// Summed attention received per node (optional): [num_nodes]
        pub attention_sum: Option<Tensor>,
        /// Per-feature attribution (optional): [num_nodes, num_features]
        pub feature_attribution: Option<Tensor>,
    }
288
289 impl NodeImportance {
290 pub fn new(centrality_scores: Tensor) -> Self {
291 Self {
292 centrality_scores,
293 gradient_norm: None,
294 attention_sum: None,
295 feature_attribution: None,
296 }
297 }
298
299 /// Combine multiple importance metrics
300 pub fn combined_importance(
301 &self,
302 weights: &[f32],
303 ) -> Result<Tensor, Box<dyn std::error::Error>> {
304 // Simplified implementation - just return weighted centrality for now
305 Ok(self.centrality_scores.mul_scalar(weights[0])?)
306 }
307 }
308}