1use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6 creation::{randn, zeros},
7 Tensor,
8};
9
10#[derive(Debug)]
12pub struct GraphTransformer {
13 in_features: usize,
14 out_features: usize,
15 heads: usize,
16 edge_dim: usize,
17 query_weight: Parameter,
18 key_weight: Parameter,
19 value_weight: Parameter,
20 edge_weight: Parameter,
21 output_weight: Parameter,
22 bias: Option<Parameter>,
23 dropout: f32,
24}
25
26impl GraphTransformer {
27 pub fn new(
29 in_features: usize,
30 out_features: usize,
31 heads: usize,
32 edge_dim: usize,
33 dropout: f32,
34 bias: bool,
35 ) -> Self {
36 let query_weight = Parameter::new(
37 randn(&[in_features, out_features]).expect("failed to create query weight tensor"),
38 );
39 let key_weight = Parameter::new(
40 randn(&[in_features, out_features]).expect("failed to create key weight tensor"),
41 );
42 let value_weight = Parameter::new(
43 randn(&[in_features, out_features]).expect("failed to create value weight tensor"),
44 );
45 let edge_weight =
46 Parameter::new(randn(&[edge_dim, heads]).expect("failed to create edge weight tensor"));
47 let output_weight = Parameter::new(
48 randn(&[out_features, out_features]).expect("failed to create output weight tensor"),
49 );
50
51 let bias = if bias {
52 Some(Parameter::new(
53 zeros(&[out_features]).expect("failed to create bias tensor"),
54 ))
55 } else {
56 None
57 };
58
59 Self {
60 in_features,
61 out_features,
62 heads,
63 edge_dim,
64 query_weight,
65 key_weight,
66 value_weight,
67 edge_weight,
68 output_weight,
69 bias,
70 dropout,
71 }
72 }
73
74 pub fn in_features(&self) -> usize {
76 self.in_features
77 }
78
79 pub fn out_features(&self) -> usize {
81 self.out_features
82 }
83
84 pub fn heads(&self) -> usize {
86 self.heads
87 }
88
89 pub fn edge_dim(&self) -> usize {
91 self.edge_dim
92 }
93
94 pub fn dropout(&self) -> f32 {
96 self.dropout
97 }
98
99 pub fn forward(&self, graph: &GraphData) -> GraphData {
101 let num_nodes = graph.num_nodes;
102 let head_dim = self.out_features / self.heads;
103
104 let queries = graph
106 .x
107 .matmul(&self.query_weight.clone_data())
108 .expect("operation should succeed");
109 let keys = graph
110 .x
111 .matmul(&self.key_weight.clone_data())
112 .expect("operation should succeed");
113 let values = graph
114 .x
115 .matmul(&self.value_weight.clone_data())
116 .expect("operation should succeed");
117
118 let q = queries
120 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
121 .expect("view should succeed");
122 let k = keys
123 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
124 .expect("view should succeed");
125 let v = values
126 .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
127 .expect("view should succeed");
128
129 let mut output_features = zeros(&[num_nodes, self.out_features])
131 .expect("failed to create output features tensor");
132
133 for head in 0..self.heads {
135 let head_dim_start = head * head_dim;
136 let head_dim_end = (head + 1) * head_dim;
137
138 let q_head = q
139 .slice(1, head, head + 1)
140 .expect("failed to slice query head");
141 let k_head = k
142 .slice(1, head, head + 1)
143 .expect("failed to slice key head");
144 let v_head = v
145 .slice(1, head, head + 1)
146 .expect("failed to slice value head");
147
148 let scale = 1.0 / (head_dim as f64).sqrt();
150 let k_head_tensor = k_head
151 .to_tensor()
152 .expect("failed to convert key head to tensor")
153 .squeeze_tensor(1)
154 .expect("failed to squeeze key head");
155 let q_head_tensor = q_head
156 .to_tensor()
157 .expect("failed to convert query head to tensor")
158 .squeeze_tensor(1)
159 .expect("failed to squeeze query head");
160 let v_head_tensor = v_head
161 .to_tensor()
162 .expect("failed to convert value head to tensor")
163 .squeeze_tensor(1)
164 .expect("failed to squeeze value head");
165
166 let k_transposed = k_head_tensor
167 .transpose(0, 1)
168 .expect("transpose should succeed");
169 let attention_scores = q_head_tensor
170 .matmul(&k_transposed)
171 .expect("operation should succeed")
172 .mul_scalar(scale as f32)
173 .expect("failed to scale attention scores");
174 let attention_weights = attention_scores
175 .softmax(-1)
176 .expect("failed to apply softmax to attention scores");
177 let head_output = attention_weights
178 .matmul(&v_head_tensor)
179 .expect("operation should succeed");
180
181 let output_slice = output_features
183 .slice(1, head_dim_start, head_dim_end)
184 .expect("failed to slice output features");
185 let mut output_slice_tensor = output_slice
187 .to_tensor()
188 .expect("failed to convert output slice to tensor");
189 output_slice_tensor
190 .copy_(&head_output)
191 .expect("failed to copy head output to output tensor");
192 }
193
194 output_features = output_features
196 .matmul(&self.output_weight.clone_data())
197 .expect("operation should succeed");
198
199 if let Some(ref bias) = self.bias {
201 output_features = output_features
202 .add(&bias.clone_data())
203 .expect("operation should succeed");
204 }
205
206 GraphData {
207 x: output_features,
208 edge_index: graph.edge_index.clone(),
209 edge_attr: graph.edge_attr.clone(),
210 batch: graph.batch.clone(),
211 num_nodes: graph.num_nodes,
212 num_edges: graph.num_edges,
213 }
214 }
215}
216
217impl GraphLayer for GraphTransformer {
218 fn forward(&self, graph: &GraphData) -> GraphData {
219 self.forward(graph)
220 }
221
222 fn parameters(&self) -> Vec<Tensor> {
223 let mut params = vec![
224 self.query_weight.clone_data(),
225 self.key_weight.clone_data(),
226 self.value_weight.clone_data(),
227 self.edge_weight.clone_data(),
228 self.output_weight.clone_data(),
229 ];
230 if let Some(ref bias) = self.bias {
231 params.push(bias.clone_data());
232 }
233 params
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// A layer built with a bias exposes five weight tensors plus the bias.
    #[test]
    fn test_transformer_creation() {
        let layer = GraphTransformer::new(16, 32, 8, 4, 0.1, true);
        let parameter_count = layer.parameters().len();

        assert_eq!(parameter_count, 6);
        assert_eq!(layer.heads, 8);
    }

    /// A forward pass over a 3-node graph keeps the node count and yields
    /// `out_features` columns per node.
    #[test]
    fn test_transformer_forward() {
        let layer = GraphTransformer::new(6, 12, 3, 2, 0.0, false);

        // Three nodes with six features each.
        let node_values = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0, 17.0, 18.0,
        ];
        let node_features = from_vec(node_values, &[3, 6], DeviceType::Cpu).unwrap();

        // Directed ring 0 -> 1 -> 2 -> 0.
        let ring_edges =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(node_features, ring_edges);

        let result = layer.forward(&graph);

        assert_eq!(result.x.shape().dims(), &[3, 12]);
        assert_eq!(result.num_nodes, 3);
    }
}