Skip to main content

torsh_graph/conv/
transformer.rs

//! Graph Transformer Networks layer implementation.
2
3use crate::parameter::Parameter;
4use crate::{GraphData, GraphLayer};
5use torsh_tensor::{
6    creation::{randn, zeros},
7    Tensor,
8};
9
10/// Graph Transformer Networks layer
11#[derive(Debug)]
12pub struct GraphTransformer {
13    in_features: usize,
14    out_features: usize,
15    heads: usize,
16    edge_dim: usize,
17    query_weight: Parameter,
18    key_weight: Parameter,
19    value_weight: Parameter,
20    edge_weight: Parameter,
21    output_weight: Parameter,
22    bias: Option<Parameter>,
23    dropout: f32,
24}
25
26impl GraphTransformer {
27    /// Create a new Graph Transformer layer
28    pub fn new(
29        in_features: usize,
30        out_features: usize,
31        heads: usize,
32        edge_dim: usize,
33        dropout: f32,
34        bias: bool,
35    ) -> Self {
36        let query_weight = Parameter::new(
37            randn(&[in_features, out_features]).expect("failed to create query weight tensor"),
38        );
39        let key_weight = Parameter::new(
40            randn(&[in_features, out_features]).expect("failed to create key weight tensor"),
41        );
42        let value_weight = Parameter::new(
43            randn(&[in_features, out_features]).expect("failed to create value weight tensor"),
44        );
45        let edge_weight =
46            Parameter::new(randn(&[edge_dim, heads]).expect("failed to create edge weight tensor"));
47        let output_weight = Parameter::new(
48            randn(&[out_features, out_features]).expect("failed to create output weight tensor"),
49        );
50
51        let bias = if bias {
52            Some(Parameter::new(
53                zeros(&[out_features]).expect("failed to create bias tensor"),
54            ))
55        } else {
56            None
57        };
58
59        Self {
60            in_features,
61            out_features,
62            heads,
63            edge_dim,
64            query_weight,
65            key_weight,
66            value_weight,
67            edge_weight,
68            output_weight,
69            bias,
70            dropout,
71        }
72    }
73
74    /// Get input feature dimension
75    pub fn in_features(&self) -> usize {
76        self.in_features
77    }
78
79    /// Get output feature dimension
80    pub fn out_features(&self) -> usize {
81        self.out_features
82    }
83
84    /// Get number of attention heads
85    pub fn heads(&self) -> usize {
86        self.heads
87    }
88
89    /// Get edge feature dimension
90    pub fn edge_dim(&self) -> usize {
91        self.edge_dim
92    }
93
94    /// Get dropout rate
95    pub fn dropout(&self) -> f32 {
96        self.dropout
97    }
98
99    /// Apply graph transformer convolution
100    pub fn forward(&self, graph: &GraphData) -> GraphData {
101        let num_nodes = graph.num_nodes;
102        let head_dim = self.out_features / self.heads;
103
104        // Linear transformations for Q, K, V
105        let queries = graph
106            .x
107            .matmul(&self.query_weight.clone_data())
108            .expect("operation should succeed");
109        let keys = graph
110            .x
111            .matmul(&self.key_weight.clone_data())
112            .expect("operation should succeed");
113        let values = graph
114            .x
115            .matmul(&self.value_weight.clone_data())
116            .expect("operation should succeed");
117
118        // Reshape for multi-head attention
119        let q = queries
120            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
121            .expect("view should succeed");
122        let k = keys
123            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
124            .expect("view should succeed");
125        let v = values
126            .view(&[num_nodes as i32, self.heads as i32, head_dim as i32])
127            .expect("view should succeed");
128
129        // Initialize output
130        let mut output_features = zeros(&[num_nodes, self.out_features])
131            .expect("failed to create output features tensor");
132
133        // For simplicity, use a basic attention mechanism
134        for head in 0..self.heads {
135            let head_dim_start = head * head_dim;
136            let head_dim_end = (head + 1) * head_dim;
137
138            let q_head = q
139                .slice(1, head, head + 1)
140                .expect("failed to slice query head");
141            let k_head = k
142                .slice(1, head, head + 1)
143                .expect("failed to slice key head");
144            let v_head = v
145                .slice(1, head, head + 1)
146                .expect("failed to slice value head");
147
148            // Basic self-attention computation
149            let scale = 1.0 / (head_dim as f64).sqrt();
150            let k_head_tensor = k_head
151                .to_tensor()
152                .expect("failed to convert key head to tensor")
153                .squeeze_tensor(1)
154                .expect("failed to squeeze key head");
155            let q_head_tensor = q_head
156                .to_tensor()
157                .expect("failed to convert query head to tensor")
158                .squeeze_tensor(1)
159                .expect("failed to squeeze query head");
160            let v_head_tensor = v_head
161                .to_tensor()
162                .expect("failed to convert value head to tensor")
163                .squeeze_tensor(1)
164                .expect("failed to squeeze value head");
165
166            let k_transposed = k_head_tensor
167                .transpose(0, 1)
168                .expect("transpose should succeed");
169            let attention_scores = q_head_tensor
170                .matmul(&k_transposed)
171                .expect("operation should succeed")
172                .mul_scalar(scale as f32)
173                .expect("failed to scale attention scores");
174            let attention_weights = attention_scores
175                .softmax(-1)
176                .expect("failed to apply softmax to attention scores");
177            let head_output = attention_weights
178                .matmul(&v_head_tensor)
179                .expect("operation should succeed");
180
181            // Copy to output
182            let output_slice = output_features
183                .slice(1, head_dim_start, head_dim_end)
184                .expect("failed to slice output features");
185            // head_output is already [num_nodes, head_dim] - no need to squeeze
186            let mut output_slice_tensor = output_slice
187                .to_tensor()
188                .expect("failed to convert output slice to tensor");
189            output_slice_tensor
190                .copy_(&head_output)
191                .expect("failed to copy head output to output tensor");
192        }
193
194        // Apply output projection
195        output_features = output_features
196            .matmul(&self.output_weight.clone_data())
197            .expect("operation should succeed");
198
199        // Add bias if present
200        if let Some(ref bias) = self.bias {
201            output_features = output_features
202                .add(&bias.clone_data())
203                .expect("operation should succeed");
204        }
205
206        GraphData {
207            x: output_features,
208            edge_index: graph.edge_index.clone(),
209            edge_attr: graph.edge_attr.clone(),
210            batch: graph.batch.clone(),
211            num_nodes: graph.num_nodes,
212            num_edges: graph.num_edges,
213        }
214    }
215}
216
217impl GraphLayer for GraphTransformer {
218    fn forward(&self, graph: &GraphData) -> GraphData {
219        self.forward(graph)
220    }
221
222    fn parameters(&self) -> Vec<Tensor> {
223        let mut params = vec![
224            self.query_weight.clone_data(),
225            self.key_weight.clone_data(),
226            self.value_weight.clone_data(),
227            self.edge_weight.clone_data(),
228            self.output_weight.clone_data(),
229        ];
230        if let Some(ref bias) = self.bias {
231            params.push(bias.clone_data());
232        }
233        params
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// Builds a 3-node graph with 6 features per node and a directed 3-cycle of edges.
    fn sample_graph() -> GraphData {
        let x = from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // node 0
                7.0, 8.0, 9.0, 10.0, 11.0, 12.0, // node 1
                13.0, 14.0, 15.0, 16.0, 17.0, 18.0, // node 2
            ],
            &[3, 6],
            DeviceType::Cpu,
        )
        .unwrap();
        let edge_index =
            from_vec(vec![0.0, 1.0, 2.0, 1.0, 2.0, 0.0], &[2, 3], DeviceType::Cpu).unwrap();
        GraphData::new(x, edge_index)
    }

    #[test]
    fn test_transformer_creation() {
        let transformer = GraphTransformer::new(16, 32, 8, 4, 0.1, true);
        let params = transformer.parameters();
        assert_eq!(params.len(), 6); // Q, K, V, edge, output weights + bias
        assert_eq!(transformer.heads(), 8);
        assert_eq!(transformer.in_features(), 16);
        assert_eq!(transformer.out_features(), 32);
        assert_eq!(transformer.edge_dim(), 4);
        assert_eq!(transformer.dropout(), 0.1);
    }

    #[test]
    fn test_transformer_without_bias_has_five_parameters() {
        let transformer = GraphTransformer::new(16, 32, 8, 4, 0.1, false);
        assert_eq!(transformer.parameters().len(), 5);
    }

    #[test]
    fn test_transformer_parameter_shapes() {
        let transformer = GraphTransformer::new(16, 32, 8, 4, 0.0, true);
        let params = transformer.parameters();
        assert_eq!(params[0].shape().dims(), &[16, 32]); // query
        assert_eq!(params[1].shape().dims(), &[16, 32]); // key
        assert_eq!(params[2].shape().dims(), &[16, 32]); // value
        assert_eq!(params[3].shape().dims(), &[4, 8]); // edge
        assert_eq!(params[4].shape().dims(), &[32, 32]); // output
        assert_eq!(params[5].shape().dims(), &[32]); // bias
    }

    #[test]
    fn test_transformer_forward() {
        let transformer = GraphTransformer::new(6, 12, 3, 2, 0.0, false);
        let graph = sample_graph();

        let output = transformer.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 12]);
        assert_eq!(output.num_nodes, 3);
    }

    #[test]
    fn test_transformer_forward_with_bias_preserves_graph_structure() {
        let transformer = GraphTransformer::new(6, 12, 3, 2, 0.0, true);
        let graph = sample_graph();

        let output = transformer.forward(&graph);
        assert_eq!(output.x.shape().dims(), &[3, 12]);
        assert_eq!(output.num_nodes, graph.num_nodes);
        assert_eq!(output.num_edges, graph.num_edges);
    }

    #[test]
    #[should_panic]
    fn test_transformer_forward_rejects_indivisible_heads() {
        // 12 output features cannot be split evenly across 5 heads.
        let transformer = GraphTransformer::new(6, 12, 5, 2, 0.0, false);
        let graph = sample_graph();
        let _ = transformer.forward(&graph);
    }
}