//! Neural network building blocks: linear projection, layer normalization,
//! multi-head attention, a GRU cell, and the composite `RuvectorLayer`.

use crate::error::GnnError;
use ndarray::{Array1, Array2, ArrayView1};
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

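/// A dense affine map: `y = W·x + b`.
///
/// A minimal usage sketch (dimensions are illustrative; the doctest is
/// marked `ignore` because it has no crate path to import from):
///
/// ```ignore
/// let layer = Linear::new(4, 2);
/// let output = layer.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(output.len(), 2);
/// ```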
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    weights: Array2<f32>,
    bias: Array1<f32>,
}

impl Linear {
    /// Creates a layer with Xavier/Glorot-initialized weights and a zero bias.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let mut rng = rand::thread_rng();

        // Xavier/Glorot scale: sqrt(2 / (fan_in + fan_out)).
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let normal = Normal::new(0.0, scale as f64).unwrap();

        let weights =
            Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);

        let bias = Array1::zeros(output_dim);

        Self { weights, bias }
    }

    /// Computes `W·x + b`. The input length must equal `input_dim`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);
        let output = self.weights.dot(&x) + &self.bias;
        output.to_vec()
    }

    /// Returns the output dimension (the number of rows of `weights`).
    pub fn output_dim(&self) -> usize {
        self.weights.shape()[0]
    }
}

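/// Layer normalization (Ba et al., 2016): normalizes a vector to zero mean
/// and unit variance, then applies a learned scale (`gamma`) and shift
/// (`beta`).
///
/// A minimal usage sketch (values are illustrative):
///
/// ```ignore
/// let norm = LayerNorm::new(4, 1e-5);
/// let output = norm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// // With the default gamma = 1 and beta = 0, the output has ~zero mean.
/// ```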
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    eps: f32,
}

impl LayerNorm {
    /// Creates a layer norm over vectors of length `dim`, with identity
    /// scale and zero shift.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps,
        }
    }

    /// Normalizes `input` to zero mean and unit variance, then applies
    /// `gamma` and `beta`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);

        let mean = x.mean().unwrap_or(0.0);
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;

        // `eps` guards against division by zero for constant inputs.
        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());

        let output = &self.gamma * &normalized + &self.beta;
        output.to_vec()
    }
}

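/// Multi-head scaled dot-product attention over a single query and a set of
/// neighbor keys/values. The embedding is split into `num_heads` slices of
/// `head_dim` each, attended independently, then concatenated and projected.
///
/// A minimal usage sketch (an 8-dimensional embedding split across 2 heads):
///
/// ```ignore
/// let attention = MultiHeadAttention::new(8, 2).unwrap();
/// let query = vec![0.5; 8];
/// let keys = vec![vec![0.3; 8], vec![0.7; 8]];
/// let values = vec![vec![0.2; 8], vec![0.8; 8]];
/// let output = attention.forward(&query, &keys, &values);
/// assert_eq!(output.len(), 8);
/// ```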
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
}

impl MultiHeadAttention {
    /// Creates an attention block over `embed_dim`-dimensional embeddings.
    ///
    /// Returns an error if `embed_dim` is not divisible by `num_heads`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
        if embed_dim % num_heads != 0 {
            return Err(GnnError::layer_config(format!(
                "Embedding dimension ({}) must be divisible by number of heads ({})",
                embed_dim, num_heads
            )));
        }

        let head_dim = embed_dim / num_heads;

        Ok(Self {
            num_heads,
            head_dim,
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
        })
    }

    /// Attends from `query` over `keys`/`values` and returns a vector of the
    /// same dimension. With no neighbors, the query is returned unchanged.
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return query.to_vec();
        }

        // Project query, keys, and values into attention space.
        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        // Split each projection into per-head slices.
        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();

        // Run scaled dot-product attention independently per head.
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        // Concatenate head outputs and apply the output projection.
        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();
        self.out_linear.forward(&concat)
    }

    /// Slices an `embed_dim` vector into `num_heads` chunks of `head_dim`.
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        let mut heads = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.head_dim;
            let end = start + self.head_dim;
            heads.push(x[start..end].to_vec());
        }
        heads
    }

    /// Computes `softmax(q·k / sqrt(head_dim)) · v` for one head.
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        // Scaled dot-product score for each key.
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Numerically stable softmax: subtract the max before exponentiating.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        // Weighted sum of the values.
        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in attention_weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }
}

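/// A gated recurrent unit cell (Cho et al., 2014). For input `x` and hidden
/// state `h`, the update computed by `forward` is:
///
/// ```text
/// z  = sigmoid(W_z x + U_z h)       // update gate
/// r  = sigmoid(W_r x + U_r h)       // reset gate
/// h~ = tanh(W_h x + U_h (r ⊙ h))    // candidate state
/// h' = (1 - z) ⊙ h + z ⊙ h~         // new hidden state
/// ```
///
/// A minimal usage sketch (dimensions are illustrative):
///
/// ```ignore
/// let gru = GRUCell::new(4, 8);
/// let new_hidden = gru.forward(&[1.0; 4], &[0.5; 8]);
/// assert_eq!(new_hidden.len(), 8);
/// ```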
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
    // Update gate weights.
    w_z: Linear,
    u_z: Linear,

    // Reset gate weights.
    w_r: Linear,
    u_r: Linear,

    // Candidate state weights.
    w_h: Linear,
    u_h: Linear,
}

impl GRUCell {
    /// Creates a cell mapping `input_dim` inputs onto a `hidden_dim` state.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        Self {
            w_z: Linear::new(input_dim, hidden_dim),
            u_z: Linear::new(hidden_dim, hidden_dim),

            w_r: Linear::new(input_dim, hidden_dim),
            u_r: Linear::new(hidden_dim, hidden_dim),

            w_h: Linear::new(input_dim, hidden_dim),
            u_h: Linear::new(hidden_dim, hidden_dim),
        }
    }
247 pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
256 let z =
258 self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));
259
260 let r =
262 self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));
263
264 let r_hidden = self.mul_vecs(&r, hidden);
266 let h_tilde =
267 self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));
268
269 let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
271 let term1 = self.mul_vecs(&one_minus_z, hidden);
272 let term2 = self.mul_vecs(&z, &h_tilde);
273
274 self.add_vecs(&term1, &term2)
275 }

    /// Numerically stable logistic sigmoid.
    fn sigmoid(&self, x: f32) -> f32 {
        if x > 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }

    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.sigmoid(x)).collect()
    }

    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }

    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.tanh(x)).collect()
    }

    /// Element-wise sum; assumes equal lengths.
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }

    /// Element-wise (Hadamard) product; assumes equal lengths.
    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
    }
}

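/// A single message-passing layer combining attention over neighbors,
/// edge-weighted aggregation, a GRU state update, dropout scaling, and
/// layer normalization.
///
/// A minimal usage sketch (a 4-dimensional node with two neighbors,
/// projected into an 8-dimensional hidden space over 2 heads):
///
/// ```ignore
/// let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
/// let node = vec![1.0, 2.0, 3.0, 4.0];
/// let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
/// let edge_weights = vec![0.3, 0.7];
/// let output = layer.forward(&node, &neighbors, &edge_weights);
/// assert_eq!(output.len(), 8);
/// ```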
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
    // Message projection (input_dim -> hidden_dim).
    w_msg: Linear,

    // Aggregation projection (hidden_dim -> hidden_dim).
    w_agg: Linear,

    // GRU-style state update.
    w_update: GRUCell,

    // Attention over neighbor messages.
    attention: MultiHeadAttention,

    // Output layer normalization.
    norm: LayerNorm,

    // Dropout rate in [0, 1], applied as a deterministic scale.
    dropout: f32,
}

impl RuvectorLayer {
    /// Creates a layer projecting `input_dim` embeddings into a `hidden_dim`
    /// space attended over `heads` heads.
    ///
    /// Returns an error if `dropout` is outside `[0, 1]`, or if `hidden_dim`
    /// is not divisible by `heads`.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> Result<Self, GnnError> {
        if !(0.0..=1.0).contains(&dropout) {
            return Err(GnnError::layer_config(format!(
                "Dropout must be between 0.0 and 1.0, got {}",
                dropout
            )));
        }

        Ok(Self {
            w_msg: Linear::new(input_dim, hidden_dim),
            w_agg: Linear::new(hidden_dim, hidden_dim),
            w_update: GRUCell::new(hidden_dim, hidden_dim),
            attention: MultiHeadAttention::new(hidden_dim, heads)?,
            norm: LayerNorm::new(hidden_dim, 1e-5),
            dropout,
        })
    }

    /// Runs one round of message passing for a node. With no neighbors, the
    /// node embedding is simply projected and normalized.
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        if neighbor_embeddings.is_empty() {
            let projected = self.w_msg.forward(node_embedding);
            return self.norm.forward(&projected);
        }

        // Project node and neighbors into message space.
        let node_msg = self.w_msg.forward(node_embedding);
        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
            .iter()
            .map(|n| self.w_msg.forward(n))
            .collect();

        // Attend from the node message over the neighbor messages.
        let attention_output = self
            .attention
            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);

        // Edge-weighted mean of neighbor messages.
        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);

        // Combine both views and project.
        let combined = self.add_vecs(&attention_output, &weighted_msgs);
        let aggregated = self.w_agg.forward(&combined);

        // GRU update with the node message as the previous hidden state.
        let updated = self.w_update.forward(&aggregated, &node_msg);

        let dropped = self.apply_dropout(&updated);
        self.norm.forward(&dropped)
    }

    /// Aggregates messages by their normalized edge weights. If the weights
    /// sum to zero, falls back to a uniform mean.
    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if messages.is_empty() || weights.is_empty() {
            return vec![0.0; self.w_msg.output_dim()];
        }

        let weight_sum: f32 = weights.iter().sum();
        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
            weights.iter().map(|&w| w / weight_sum).collect()
        } else {
            vec![1.0 / weights.len() as f32; weights.len()]
        };

        let dim = messages[0].len();
        let mut aggregated = vec![0.0; dim];

        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
                *agg += weight * m;
            }
        }

        aggregated
    }

    /// Deterministic dropout scaling (inference-style): every activation is
    /// scaled by `1 - dropout` rather than randomly zeroed.
    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
        let scale = 1.0 - self.dropout;
        input.iter().map(|&x| x * scale).collect()
    }

    /// Element-wise sum; assumes equal lengths.
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_layer() {
        let linear = Linear::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 2);
    }

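    // An extra sketch of the initialization contract: `new` zeroes the bias,
    // so a zero input must map to an all-zero output.
    #[test]
    fn test_linear_zero_input_returns_bias() {
        let linear = Linear::new(4, 2);
        let output = linear.forward(&[0.0; 4]);
        assert!(output.iter().all(|&v| v == 0.0));
    }
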
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);

        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!((mean).abs() < 1e-5);
    }

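    // A companion sketch for the other half of the normalization: with the
    // default gamma = 1 and beta = 0, the output variance should be ~1
    // (slightly below, because of the eps term in the denominator).
    #[test]
    fn test_layer_norm_unit_variance() {
        let norm = LayerNorm::new(4, 1e-5);
        let output = norm.forward(&[1.0, 2.0, 3.0, 4.0]);

        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        let variance: f32 =
            output.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / output.len() as f32;
        assert!((variance - 1.0).abs() < 1e-3);
    }
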
    #[test]
    fn test_multihead_attention() {
        let attention = MultiHeadAttention::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = attention.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_multihead_attention_invalid_dims() {
        let result = MultiHeadAttention::new(10, 3);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("divisible"));
    }

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];

        let new_hidden = gru.forward(&input, &hidden);
        assert_eq!(new_hidden.len(), 8);
    }

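    // A sketch of the gating algebra: the new state is a convex combination
    // of the previous state (0.5 here) and a tanh-bounded candidate, so every
    // output component should stay inside (-1, 1).
    #[test]
    fn test_gru_cell_output_bounded() {
        let gru = GRUCell::new(4, 8);
        let new_hidden = gru.forward(&[1.0; 4], &[0.5; 8]);
        assert!(new_hidden.iter().all(|&v| v.abs() < 1.0));
    }
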
    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let weights = vec![0.3, 0.7];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        let result = RuvectorLayer::new(4, 8, 2, 1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_ruvector_layer_invalid_heads() {
        // 7 is not divisible by 3, so the attention constructor must fail.
        let result = RuvectorLayer::new(4, 7, 3, 0.1);
        assert!(result.is_err());
    }
}