1use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6 creation::{randn, zeros},
7 Tensor,
8};
9
10#[derive(Debug)]
12pub struct GCNConv {
13 in_features: usize,
14 out_features: usize,
15 weight: Parameter,
16 bias: Option<Parameter>,
17}
18
19impl GCNConv {
20 pub fn new(in_features: usize, out_features: usize, bias: bool) -> Self {
22 let weight = Parameter::new(
23 randn(&[in_features, out_features]).expect("failed to create weight tensor"),
24 );
25 let bias = if bias {
26 Some(Parameter::new(
27 zeros(&[out_features]).expect("failed to create bias tensor"),
28 ))
29 } else {
30 None
31 };
32
33 Self {
34 in_features,
35 out_features,
36 weight,
37 bias,
38 }
39 }
40
41 pub fn in_features(&self) -> usize {
43 self.in_features
44 }
45
46 pub fn out_features(&self) -> usize {
48 self.out_features
49 }
50
51 pub fn forward(&self, graph: &GraphData) -> GraphData {
53 let laplacian = crate::utils::graph_laplacian(&graph.edge_index, graph.num_nodes, true);
55
56 let x_transformed = graph
58 .x
59 .matmul(&self.weight.clone_data())
60 .expect("operation should succeed");
61 let mut output_features = laplacian
62 .matmul(&x_transformed)
63 .expect("operation should succeed");
64
65 if let Some(ref bias) = self.bias {
67 output_features = output_features
68 .add(&bias.clone_data())
69 .expect("operation should succeed");
70 }
71
72 GraphData {
74 x: output_features,
75 edge_index: graph.edge_index.clone(),
76 edge_attr: graph.edge_attr.clone(),
77 batch: graph.batch.clone(),
78 num_nodes: graph.num_nodes,
79 num_edges: graph.num_edges,
80 }
81 }
82}
83
84impl GraphLayer for GCNConv {
85 fn forward(&self, graph: &GraphData) -> GraphData {
86 self.forward(graph)
87 }
88
89 fn parameters(&self) -> Vec<Tensor> {
90 let mut params = vec![self.weight.clone_data()];
91 if let Some(ref bias) = self.bias {
92 params.push(bias.clone_data());
93 }
94 params
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// Two nodes with three features each, connected in both directions.
    fn two_node_graph() -> GraphData {
        let x = from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], DeviceType::Cpu).unwrap();
        let edge_index = from_vec(vec![0.0, 1.0, 1.0, 0.0], &[2, 2], DeviceType::Cpu).unwrap();
        GraphData::new(x, edge_index)
    }

    #[test]
    fn test_gcn_creation() {
        let gcn = GCNConv::new(8, 16, true);
        assert_eq!(gcn.in_features(), 8);
        assert_eq!(gcn.out_features(), 16);

        let params = gcn.parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].shape().dims(), &[8, 16]);
        assert_eq!(params[1].shape().dims(), &[16]);
    }

    #[test]
    fn test_gcn_creation_without_bias() {
        let gcn = GCNConv::new(8, 16, false);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape().dims(), &[8, 16]);
    }

    #[test]
    fn test_gcn_forward() {
        let gcn = GCNConv::new(3, 8, false);
        let graph = two_node_graph();

        let output = gcn.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[2, 8]);
        assert_eq!(output.num_nodes, 2);
    }

    #[test]
    fn test_gcn_forward_with_bias() {
        let gcn = GCNConv::new(3, 8, true);
        let graph = two_node_graph();

        let output = gcn.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[2, 8]);
        assert_eq!(output.num_nodes, 2);
    }
}