Skip to main content

torsh_graph/conv/
sage.rs

1//! GraphSAGE (Sample and Aggregate) layer implementation
2
3use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6    creation::{randn, zeros},
7    Tensor,
8};
9
/// GraphSAGE convolution layer
///
/// Combines a projection of each node's own features with a projection of the
/// mean of its neighbours' features; see [`SAGEConv::forward`] for the exact
/// computation.
#[derive(Debug)]
pub struct SAGEConv {
    /// Width of the input node features.
    in_features: usize,
    /// Width of the produced node features.
    out_features: usize,
    /// Projection for the aggregated neighbour features, shape `[in_features, out_features]`.
    weight_neighbor: Parameter,
    /// Projection for the node's own features, shape `[in_features, out_features]`.
    weight_self: Parameter,
    /// Optional bias of shape `[out_features]`, zero-initialised by [`SAGEConv::new`].
    bias: Option<Parameter>,
}
19
20impl SAGEConv {
21    /// Create a new GraphSAGE convolution layer
22    pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
23        let weight_neighbor = Parameter::new(
24            randn(&[in_features, out_features]).expect("failed to create neighbor weight tensor"),
25        );
26        let weight_self = Parameter::new(
27            randn(&[in_features, out_features]).expect("failed to create self weight tensor"),
28        );
29        let bias = if bias {
30            Some(Parameter::new(
31                zeros(&[out_features]).expect("failed to create bias tensor"),
32            ))
33        } else {
34            None
35        };
36
37        Self {
38            in_features,
39            out_features,
40            weight_neighbor,
41            weight_self,
42            bias,
43        }
44    }
45
46    /// Get input feature dimension
47    pub fn in_features(&self) -> usize {
48        self.in_features
49    }
50
51    /// Get output feature dimension
52    pub fn out_features(&self) -> usize {
53        self.out_features
54    }
55
56    /// Apply GraphSAGE convolution
57    pub fn forward(&self, graph: &GraphData) -> GraphData {
58        let num_nodes = graph.num_nodes;
59        let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
60            .expect("failed to extract edge index data");
61
62        // Build adjacency list for efficient neighbor aggregation
63        let mut adjacency_list: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
64        for j in 0..edge_data[0].len() {
65            let src = edge_data[0][j] as usize;
66            let dst = edge_data[1][j] as usize;
67            adjacency_list[dst].push(src);
68        }
69
70        // Aggregate neighbor features (mean aggregation)
71        let mut neighbor_features = zeros(&[num_nodes, self.in_features])
72            .expect("failed to create neighbor features tensor");
73
74        for node in 0..num_nodes {
75            if !adjacency_list[node].is_empty() {
76                let mut aggregated = zeros(&[self.in_features])
77                    .expect("failed to create aggregated features tensor");
78
79                for &neighbor in &adjacency_list[node] {
80                    let neighbor_slice = graph
81                        .x
82                        .slice(0, neighbor, neighbor + 1)
83                        .expect("failed to slice neighbor features")
84                        .to_tensor()
85                        .expect("failed to convert slice to tensor");
86                    let neighbor_feat = neighbor_slice.squeeze(0).expect("squeeze should succeed");
87                    aggregated = aggregated
88                        .add(&neighbor_feat)
89                        .expect("operation should succeed");
90                }
91
92                // Mean aggregation
93                aggregated = aggregated
94                    .div_scalar(adjacency_list[node].len() as f32)
95                    .expect("failed to compute mean aggregation");
96                // Store aggregated features for this node
97                let aggregated_data = aggregated.to_vec().expect("conversion should succeed");
98                for (i, &value) in aggregated_data.iter().enumerate() {
99                    neighbor_features
100                        .set_item(&[node, i], value)
101                        .expect("failed to set neighbor feature value");
102                }
103            }
104        }
105
106        // Transform neighbor features and self features
107        let neighbor_transformed = neighbor_features
108            .matmul(&self.weight_neighbor.clone_data())
109            .expect("operation should succeed");
110        let self_transformed = graph
111            .x
112            .matmul(&self.weight_self.clone_data())
113            .expect("operation should succeed");
114
115        // Combine neighbor and self representations
116        let mut output_features = neighbor_transformed
117            .add(&self_transformed)
118            .expect("operation should succeed");
119
120        // Add bias if present
121        if let Some(ref bias) = self.bias {
122            output_features = output_features
123                .add(&bias.clone_data())
124                .expect("operation should succeed");
125        }
126
127        // L2 normalize the output features (common in GraphSAGE)
128        // For simplicity, using standard normalization instead of row-wise normalization
129        let norm_val = output_features
130            .norm()
131            .expect("failed to compute feature norm");
132        let epsilon = 1e-8_f32;
133        let norm_scalar = norm_val
134            .item()
135            .expect("tensor should have single item")
136            .max(epsilon);
137        output_features = output_features
138            .div_scalar(norm_scalar)
139            .expect("failed to normalize output features");
140
141        // Create output graph
142        GraphData {
143            x: output_features,
144            edge_index: graph.edge_index.clone(),
145            edge_attr: graph.edge_attr.clone(),
146            batch: graph.batch.clone(),
147            num_nodes: graph.num_nodes,
148            num_edges: graph.num_edges,
149        }
150    }
151}
152
153impl GraphLayer for SAGEConv {
154    fn forward(&self, graph: &GraphData) -> GraphData {
155        self.forward(graph)
156    }
157
158    fn parameters(&self) -> Vec<Tensor> {
159        let mut params = vec![
160            self.weight_neighbor.clone_data(),
161            self.weight_self.clone_data(),
162        ];
163        if let Some(ref bias) = self.bias {
164            params.push(bias.clone_data());
165        }
166        params
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_sage_creation() {
        // With a bias the layer exposes weight_neighbor, weight_self and bias.
        let layer = SAGEConv::new(10, 20, true);
        assert_eq!(layer.parameters().len(), 3);
    }

    #[test]
    fn test_sage_forward() {
        let layer = SAGEConv::new(4, 8, false);

        // Three nodes with four features each, laid out row by row:
        // node 0 = [1, 2, 3, 4], node 1 = [5, 6, 7, 8], node 2 = [9, 10, 11, 12].
        let features: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let x = from_vec(features, &[3, 4], DeviceType::Cpu).unwrap();

        // Directed cycle 0 -> 1 -> 2 -> 0 (row 0: sources, row 1: targets).
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(x, edge_index);

        let output = layer.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 8]);
        assert_eq!(output.num_nodes, 3);

        let all_finite = output.x.to_vec().unwrap().iter().all(|v| v.is_finite());
        assert!(all_finite, "Output should be finite");
    }
}