1use crate::error::GnnError;
7use ndarray::{Array1, Array2, ArrayView1};
8use rand::SeedableRng;
9use rand_distr::{Distribution, Normal};
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Linear {
15 weights: Array2<f32>,
16 bias: Array1<f32>,
17}
18
19impl Linear {
20 pub fn new(input_dim: usize, output_dim: usize) -> Self {
24 let seed = (input_dim as u64).wrapping_mul(6364136223846793005)
26 ^ (output_dim as u64).wrapping_mul(1442695040888963407);
27 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
28
29 let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
31 let normal = Normal::new(0.0, scale as f64).unwrap();
32
33 let weights =
34 Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);
35
36 let bias = Array1::zeros(output_dim);
37
38 Self { weights, bias }
39 }
40
41 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
43 let x = ArrayView1::from(input);
44 let output = self.weights.dot(&x) + &self.bias;
45 output.to_vec()
46 }
47
48 pub fn output_dim(&self) -> usize {
50 self.weights.shape()[0]
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct LayerNorm {
57 gamma: Array1<f32>,
58 beta: Array1<f32>,
59 eps: f32,
60}
61
62impl LayerNorm {
63 pub fn new(dim: usize, eps: f32) -> Self {
65 Self {
66 gamma: Array1::ones(dim),
67 beta: Array1::zeros(dim),
68 eps,
69 }
70 }
71
72 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
74 let x = ArrayView1::from(input);
75
76 let mean = x.mean().unwrap_or(0.0);
78 let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
79
80 let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
82
83 let output = &self.gamma * &normalized + &self.beta;
85 output.to_vec()
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MultiHeadAttention {
92 num_heads: usize,
93 head_dim: usize,
94 q_linear: Linear,
95 k_linear: Linear,
96 v_linear: Linear,
97 out_linear: Linear,
98}
99
100impl MultiHeadAttention {
101 pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
106 if embed_dim % num_heads != 0 {
107 return Err(GnnError::layer_config(format!(
108 "Embedding dimension ({}) must be divisible by number of heads ({})",
109 embed_dim, num_heads
110 )));
111 }
112
113 let head_dim = embed_dim / num_heads;
114
115 Ok(Self {
116 num_heads,
117 head_dim,
118 q_linear: Linear::new(embed_dim, embed_dim),
119 k_linear: Linear::new(embed_dim, embed_dim),
120 v_linear: Linear::new(embed_dim, embed_dim),
121 out_linear: Linear::new(embed_dim, embed_dim),
122 })
123 }
124
125 pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
135 if keys.is_empty() || values.is_empty() {
136 return query.to_vec();
137 }
138
139 let q = self.q_linear.forward(query);
141 let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
142 let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();
143
144 let q_heads = self.split_heads(&q);
146 let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
147 let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();
148
149 let mut head_outputs = Vec::new();
151 for h in 0..self.num_heads {
152 let q_h = &q_heads[h];
153 let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
154 let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();
155
156 let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
157 head_outputs.push(head_output);
158 }
159
160 let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();
162
163 self.out_linear.forward(&concat)
165 }
166
167 fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
169 let mut heads = Vec::new();
170 for h in 0..self.num_heads {
171 let start = h * self.head_dim;
172 let end = start + self.head_dim;
173 heads.push(x[start..end].to_vec());
174 }
175 heads
176 }
177
178 fn scaled_dot_product_attention(
180 &self,
181 query: &[f32],
182 keys: &[&Vec<f32>],
183 values: &[&Vec<f32>],
184 ) -> Vec<f32> {
185 if keys.is_empty() {
186 return query.to_vec();
187 }
188
189 let scale = (self.head_dim as f32).sqrt();
190
191 let scores: Vec<f32> = keys
193 .iter()
194 .map(|k| {
195 let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
196 dot / scale
197 })
198 .collect();
199
200 let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
202 let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
203 let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
204 let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
205
206 let mut output = vec![0.0; self.head_dim];
208 for (weight, value) in attention_weights.iter().zip(values.iter()) {
209 for (out, &val) in output.iter_mut().zip(value.iter()) {
210 *out += weight * val;
211 }
212 }
213
214 output
215 }
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct GRUCell {
221 w_z: Linear,
223 u_z: Linear,
224
225 w_r: Linear,
227 u_r: Linear,
228
229 w_h: Linear,
231 u_h: Linear,
232}
233
234impl GRUCell {
235 pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
237 Self {
238 w_z: Linear::new(input_dim, hidden_dim),
240 u_z: Linear::new(hidden_dim, hidden_dim),
241
242 w_r: Linear::new(input_dim, hidden_dim),
244 u_r: Linear::new(hidden_dim, hidden_dim),
245
246 w_h: Linear::new(input_dim, hidden_dim),
248 u_h: Linear::new(hidden_dim, hidden_dim),
249 }
250 }
251
252 pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
261 let z =
263 self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));
264
265 let r =
267 self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));
268
269 let r_hidden = self.mul_vecs(&r, hidden);
271 let h_tilde =
272 self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));
273
274 let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
276 let term1 = self.mul_vecs(&one_minus_z, hidden);
277 let term2 = self.mul_vecs(&z, &h_tilde);
278
279 self.add_vecs(&term1, &term2)
280 }
281
282 fn sigmoid(&self, x: f32) -> f32 {
284 if x > 0.0 {
285 1.0 / (1.0 + (-x).exp())
286 } else {
287 let ex = x.exp();
288 ex / (1.0 + ex)
289 }
290 }
291
292 fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
294 v.iter().map(|&x| self.sigmoid(x)).collect()
295 }
296
297 fn tanh(&self, x: f32) -> f32 {
299 x.tanh()
300 }
301
302 fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
304 v.iter().map(|&x| self.tanh(x)).collect()
305 }
306
307 fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
309 a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
310 }
311
312 fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
314 a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
315 }
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct RuvectorLayer {
321 w_msg: Linear,
323
324 w_agg: Linear,
326
327 w_update: GRUCell,
329
330 attention: MultiHeadAttention,
332
333 norm: LayerNorm,
335
336 dropout: f32,
338}
339
340impl RuvectorLayer {
341 pub fn new(
353 input_dim: usize,
354 hidden_dim: usize,
355 heads: usize,
356 dropout: f32,
357 ) -> Result<Self, GnnError> {
358 if !(0.0..=1.0).contains(&dropout) {
359 return Err(GnnError::layer_config(format!(
360 "Dropout must be between 0.0 and 1.0, got {}",
361 dropout
362 )));
363 }
364
365 Ok(Self {
366 w_msg: Linear::new(input_dim, hidden_dim),
367 w_agg: Linear::new(hidden_dim, hidden_dim),
368 w_update: GRUCell::new(hidden_dim, hidden_dim),
369 attention: MultiHeadAttention::new(hidden_dim, heads)?,
370 norm: LayerNorm::new(hidden_dim, 1e-5),
371 dropout,
372 })
373 }
374
375 pub fn forward(
385 &self,
386 node_embedding: &[f32],
387 neighbor_embeddings: &[Vec<f32>],
388 edge_weights: &[f32],
389 ) -> Vec<f32> {
390 if neighbor_embeddings.is_empty() {
391 let projected = self.w_msg.forward(node_embedding);
393 return self.norm.forward(&projected);
394 }
395
396 let node_msg = self.w_msg.forward(node_embedding);
398 let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
399 .iter()
400 .map(|n| self.w_msg.forward(n))
401 .collect();
402
403 let attention_output = self
405 .attention
406 .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);
407
408 let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);
410
411 let combined = self.add_vecs(&attention_output, &weighted_msgs);
413 let aggregated = self.w_agg.forward(&combined);
414
415 let updated = self.w_update.forward(&aggregated, &node_msg);
417
418 let dropped = self.apply_dropout(&updated);
420
421 self.norm.forward(&dropped)
423 }
424
425 fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
427 if messages.is_empty() || weights.is_empty() {
428 return vec![0.0; self.w_msg.output_dim()];
429 }
430
431 let weight_sum: f32 = weights.iter().sum();
433 let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
434 weights.iter().map(|&w| w / weight_sum).collect()
435 } else {
436 vec![1.0 / weights.len() as f32; weights.len()]
437 };
438
439 let dim = messages[0].len();
441 let mut aggregated = vec![0.0; dim];
442
443 for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
444 for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
445 *agg += weight * m;
446 }
447 }
448
449 aggregated
450 }
451
452 fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
454 let scale = 1.0 - self.dropout;
455 input.iter().map(|&x| x * scale).collect()
456 }
457
458 fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
460 a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4 -> 2 linear layer yields two outputs.
    #[test]
    fn test_linear_layer() {
        let layer = Linear::new(4, 2);
        let features = vec![1.0, 2.0, 3.0, 4.0];

        assert_eq!(layer.forward(&features).len(), 2);
    }

    /// With the default scale and shift, the normalized output has zero mean.
    #[test]
    fn test_layer_norm() {
        let layer_norm = LayerNorm::new(4, 1e-5);
        let features = vec![1.0, 2.0, 3.0, 4.0];
        let normalized = layer_norm.forward(&features);

        let mean: f32 = normalized.iter().sum::<f32>() / normalized.len() as f32;
        assert!((mean).abs() < 1e-5);
    }

    /// Attention preserves the embedding dimension.
    #[test]
    fn test_multihead_attention() {
        let mha = MultiHeadAttention::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let attended = mha.forward(&query, &keys, &values);
        assert_eq!(attended.len(), 8);
    }

    /// An embedding size that the head count does not divide is rejected.
    #[test]
    fn test_multihead_attention_invalid_dims() {
        let outcome = MultiHeadAttention::new(10, 3);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("divisible"));
    }

    /// The GRU cell returns a hidden state of the configured size.
    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];

        let next_hidden = cell.forward(&input, &hidden);
        assert_eq!(next_hidden.len(), 8);
    }

    /// A node with two weighted neighbors produces a hidden-sized embedding.
    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let edge_weights = vec![0.3, 0.7];

        let embedding = layer.forward(&node, &neighbors, &edge_weights);
        assert_eq!(embedding.len(), 8);
    }

    /// An isolated node still produces a hidden-sized embedding.
    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = Vec::new();
        let edge_weights: Vec<f32> = Vec::new();

        let embedding = layer.forward(&node, &neighbors, &edge_weights);
        assert_eq!(embedding.len(), 8);
    }

    /// A dropout rate above 1.0 is rejected.
    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        let outcome = RuvectorLayer::new(4, 8, 2, 1.5);
        assert!(outcome.is_err());
    }

    /// A hidden size that the head count does not divide is rejected.
    #[test]
    fn test_ruvector_layer_invalid_heads() {
        let outcome = RuvectorLayer::new(4, 7, 3, 0.1);
        assert!(outcome.is_err());
    }
}