1use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6 creation::{randn, zeros},
7 Tensor,
8};
9
10#[derive(Debug)]
12pub struct SAGEConv {
13 in_features: usize,
14 out_features: usize,
15 weight_neighbor: Parameter,
16 weight_self: Parameter,
17 bias: Option<Parameter>,
18}
19
20impl SAGEConv {
21 pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
23 let weight_neighbor = Parameter::new(
24 randn(&[in_features, out_features]).expect("failed to create neighbor weight tensor"),
25 );
26 let weight_self = Parameter::new(
27 randn(&[in_features, out_features]).expect("failed to create self weight tensor"),
28 );
29 let bias = if bias {
30 Some(Parameter::new(
31 zeros(&[out_features]).expect("failed to create bias tensor"),
32 ))
33 } else {
34 None
35 };
36
37 Self {
38 in_features,
39 out_features,
40 weight_neighbor,
41 weight_self,
42 bias,
43 }
44 }
45
46 pub fn in_features(&self) -> usize {
48 self.in_features
49 }
50
51 pub fn out_features(&self) -> usize {
53 self.out_features
54 }
55
56 pub fn forward(&self, graph: &GraphData) -> GraphData {
58 let num_nodes = graph.num_nodes;
59 let edge_data = crate::utils::tensor_to_vec2::<f32>(&graph.edge_index)
60 .expect("failed to extract edge index data");
61
62 let mut adjacency_list: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
64 for j in 0..edge_data[0].len() {
65 let src = edge_data[0][j] as usize;
66 let dst = edge_data[1][j] as usize;
67 adjacency_list[dst].push(src);
68 }
69
70 let mut neighbor_features = zeros(&[num_nodes, self.in_features])
72 .expect("failed to create neighbor features tensor");
73
74 for node in 0..num_nodes {
75 if !adjacency_list[node].is_empty() {
76 let mut aggregated = zeros(&[self.in_features])
77 .expect("failed to create aggregated features tensor");
78
79 for &neighbor in &adjacency_list[node] {
80 let neighbor_slice = graph
81 .x
82 .slice(0, neighbor, neighbor + 1)
83 .expect("failed to slice neighbor features")
84 .to_tensor()
85 .expect("failed to convert slice to tensor");
86 let neighbor_feat = neighbor_slice.squeeze(0).expect("squeeze should succeed");
87 aggregated = aggregated
88 .add(&neighbor_feat)
89 .expect("operation should succeed");
90 }
91
92 aggregated = aggregated
94 .div_scalar(adjacency_list[node].len() as f32)
95 .expect("failed to compute mean aggregation");
96 let aggregated_data = aggregated.to_vec().expect("conversion should succeed");
98 for (i, &value) in aggregated_data.iter().enumerate() {
99 neighbor_features
100 .set_item(&[node, i], value)
101 .expect("failed to set neighbor feature value");
102 }
103 }
104 }
105
106 let neighbor_transformed = neighbor_features
108 .matmul(&self.weight_neighbor.clone_data())
109 .expect("operation should succeed");
110 let self_transformed = graph
111 .x
112 .matmul(&self.weight_self.clone_data())
113 .expect("operation should succeed");
114
115 let mut output_features = neighbor_transformed
117 .add(&self_transformed)
118 .expect("operation should succeed");
119
120 if let Some(ref bias) = self.bias {
122 output_features = output_features
123 .add(&bias.clone_data())
124 .expect("operation should succeed");
125 }
126
127 let norm_val = output_features
130 .norm()
131 .expect("failed to compute feature norm");
132 let epsilon = 1e-8_f32;
133 let norm_scalar = norm_val
134 .item()
135 .expect("tensor should have single item")
136 .max(epsilon);
137 output_features = output_features
138 .div_scalar(norm_scalar)
139 .expect("failed to normalize output features");
140
141 GraphData {
143 x: output_features,
144 edge_index: graph.edge_index.clone(),
145 edge_attr: graph.edge_attr.clone(),
146 batch: graph.batch.clone(),
147 num_nodes: graph.num_nodes,
148 num_edges: graph.num_edges,
149 }
150 }
151}
152
153impl GraphLayer for SAGEConv {
154 fn forward(&self, graph: &GraphData) -> GraphData {
155 self.forward(graph)
156 }
157
158 fn parameters(&self) -> Vec<Tensor> {
159 let mut params = vec![
160 self.weight_neighbor.clone_data(),
161 self.weight_self.clone_data(),
162 ];
163 if let Some(ref bias) = self.bias {
164 params.push(bias.clone_data());
165 }
166 params
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// A layer built with `bias = true` exposes both weight matrices plus the bias.
    #[test]
    fn test_sage_creation() {
        let layer = SAGEConv::new(10, 20, true);
        let parameter_count = layer.parameters().len();
        assert_eq!(parameter_count, 3);
    }

    /// A forward pass over a 3-node directed cycle (0 -> 1 -> 2 -> 0) yields
    /// finite `[3, 8]` node features and keeps the node count.
    #[test]
    fn test_sage_forward() {
        let layer = SAGEConv::new(4, 8, false);

        let features: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let x = from_vec(features, &[3, 4], DeviceType::Cpu).unwrap();
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(x, edge_index);

        let output = layer.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 8]);
        assert_eq!(output.num_nodes, 3);

        let values = output.x.to_vec().unwrap();
        assert!(values.iter().all(|v| v.is_finite()), "Output should be finite");
    }
}