use ndarray::{Array1, Array2, ArrayView1};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
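/// Fully connected (dense) layer: a weight matrix of shape `(output_dim, input_dim)`
/// and a bias vector of length `output_dim`.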
pub struct Linear {
    weights: Array2<f32>,
    bias: Array1<f32>,
}

impl Linear {
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
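        // Xavier/Glorot initialisation: weights drawn from a normal distribution
        // with std dev sqrt(2 / (fan_in + fan_out)); biases start at zero.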
        let mut rng = rand::thread_rng();

        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let normal = Normal::new(0.0, scale as f64).unwrap();

        let weights =
            Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);

        let bias = Array1::zeros(output_dim);

        Self { weights, bias }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
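        // y = W·x + b (panics if `input` does not have length `input_dim`).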
        let x = ArrayView1::from(input);
        let output = self.weights.dot(&x) + &self.bias;
        output.to_vec()
    }

    pub fn output_dim(&self) -> usize {
        self.weights.shape()[0]
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
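/// Layer normalisation with learnable scale (`gamma`) and shift (`beta`).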
pub struct LayerNorm {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    eps: f32,
}

impl LayerNorm {
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
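        // Normalise to zero mean and unit variance, then apply the learned
        // affine transform: gamma * x_hat + beta.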
        let x = ArrayView1::from(input);

        let mean = x.mean().unwrap_or(0.0);
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;

        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());

        let output = &self.gamma * &normalized + &self.beta;
        output.to_vec()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
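/// Multi-head scaled dot-product attention over a single query vector and a set
/// of key/value vectors, each of dimension `num_heads * head_dim`.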
pub struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
}

impl MultiHeadAttention {
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "Embedding dimension must be divisible by number of heads"
        );

        let head_dim = embed_dim / num_heads;

        Self {
            num_heads,
            head_dim,
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
        }
    }

    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
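        // 1. Project the query, keys, and values.
        // 2. Split each projection into `num_heads` slices of length `head_dim`.
        // 3. Run scaled dot-product attention independently per head.
        // 4. Concatenate the head outputs and apply the output projection.
        // With no keys/values, the query is returned unchanged.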
        if keys.is_empty() || values.is_empty() {
            return query.to_vec();
        }

        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();

        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();

        self.out_linear.forward(&concat)
    }

    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        let mut heads = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.head_dim;
            let end = start + self.head_dim;
            heads.push(x[start..end].to_vec());
        }
        heads
    }

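    /// Single-head attention: softmax(q · kᵀ / sqrt(head_dim)) used to weight the values.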
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Numerically stable softmax: subtract the maximum score before exponentiating.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in attention_weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
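/// Gated recurrent unit (GRU) cell operating on single `f32` vectors.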
pub struct GRUCell {
    // Update gate parameters.
    w_z: Linear,
    u_z: Linear,

    // Reset gate parameters.
    w_r: Linear,
    u_r: Linear,

    // Candidate hidden-state parameters.
    w_h: Linear,
    u_h: Linear,
}

impl GRUCell {
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        Self {
            w_z: Linear::new(input_dim, hidden_dim),
            u_z: Linear::new(hidden_dim, hidden_dim),

            w_r: Linear::new(input_dim, hidden_dim),
            u_r: Linear::new(hidden_dim, hidden_dim),

            w_h: Linear::new(input_dim, hidden_dim),
            u_h: Linear::new(hidden_dim, hidden_dim),
        }
    }

    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
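        // Standard GRU update:
        //   z  = sigmoid(W_z x + U_z h)          (update gate)
        //   r  = sigmoid(W_r x + U_r h)          (reset gate)
        //   h~ = tanh(W_h x + U_h (r ⊙ h))       (candidate state)
        //   h' = (1 - z) ⊙ h + z ⊙ h~            (new hidden state)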
        let z =
            self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));

        let r =
            self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));

        let r_hidden = self.mul_vecs(&r, hidden);
        let h_tilde =
            self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));

        let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
        let term1 = self.mul_vecs(&one_minus_z, hidden);
        let term2 = self.mul_vecs(&z, &h_tilde);

        self.add_vecs(&term1, &term2)
    }

    fn sigmoid(&self, x: f32) -> f32 {
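        // Numerically stable logistic sigmoid: never exponentiates a large positive value.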
        if x > 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }

    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.sigmoid(x)).collect()
    }

    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }

    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.tanh(x)).collect()
    }

    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }

    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
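/// Graph message-passing layer: projects node and neighbor embeddings, combines
/// multi-head attention over the neighbors with an edge-weighted mean of their
/// messages, updates the node state through a GRU cell, and layer-normalises the result.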
pub struct RuvectorLayer {
    /// Message projection applied to the node and its neighbors.
    w_msg: Linear,

    /// Projection applied to the combined (attention + weighted) messages.
    w_agg: Linear,

    /// GRU cell that merges the aggregated messages into the node state.
    w_update: GRUCell,

    /// Attention over neighbor messages.
    attention: MultiHeadAttention,

    /// Layer normalisation of the final output.
    norm: LayerNorm,

    /// Dropout rate in [0, 1].
    dropout: f32,
}

impl RuvectorLayer {
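    /// Builds a layer that maps `input_dim`-dimensional node features to a
    /// `hidden_dim`-dimensional state, using `heads` attention heads and the
    /// given dropout rate (panics if `dropout` is outside `[0, 1]`).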
    pub fn new(input_dim: usize, hidden_dim: usize, heads: usize, dropout: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&dropout),
            "Dropout must be between 0.0 and 1.0"
        );

        Self {
            w_msg: Linear::new(input_dim, hidden_dim),
            w_agg: Linear::new(hidden_dim, hidden_dim),
            w_update: GRUCell::new(hidden_dim, hidden_dim),
            attention: MultiHeadAttention::new(hidden_dim, heads),
            norm: LayerNorm::new(hidden_dim, 1e-5),
            dropout,
        }
    }

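    /// Runs one round of message passing for a single node.
    ///
    /// `node_embedding` is the node's current feature vector, `neighbor_embeddings`
    /// holds the features of its neighbors, and `edge_weights` weights each
    /// neighbor's message. With no neighbors, the node embedding is simply
    /// projected and normalised.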
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        if neighbor_embeddings.is_empty() {
            let projected = self.w_msg.forward(node_embedding);
            return self.norm.forward(&projected);
        }

        // Project the node and its neighbors into message space.
        let node_msg = self.w_msg.forward(node_embedding);
        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
            .iter()
            .map(|n| self.w_msg.forward(n))
            .collect();

        // Attend over neighbor messages with the node message as the query.
        let attention_output = self
            .attention
            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);

        // Edge-weighted mean of the neighbor messages.
        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);

        // Combine both aggregation paths and project.
        let combined = self.add_vecs(&attention_output, &weighted_msgs);
        let aggregated = self.w_agg.forward(&combined);

        // GRU update: the aggregated message is the input, the node message the previous state.
        let updated = self.w_update.forward(&aggregated, &node_msg);

        let dropped = self.apply_dropout(&updated);

        self.norm.forward(&dropped)
    }

    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if messages.is_empty() || weights.is_empty() {
            return vec![0.0; self.w_msg.output_dim()];
        }

        // Normalise the edge weights; fall back to a uniform distribution if they sum to zero.
        let weight_sum: f32 = weights.iter().sum();
        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
            weights.iter().map(|&w| w / weight_sum).collect()
        } else {
            vec![1.0 / weights.len() as f32; weights.len()]
        };

        let dim = messages[0].len();
        let mut aggregated = vec![0.0; dim];

        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
                *agg += weight * m;
            }
        }

        aggregated
    }

    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
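        // Deterministic dropout: scale by the keep probability (1 - dropout)
        // instead of randomly zeroing elements.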
        let scale = 1.0 - self.dropout;
        input.iter().map(|&x| x * scale).collect()
    }

    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_layer() {
        let linear = Linear::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);

        // The normalised output should have approximately zero mean.
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_multihead_attention() {
        let attention = MultiHeadAttention::new(8, 2);
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = attention.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];

        let new_hidden = gru.forward(&input, &hidden);
        assert_eq!(new_hidden.len(), 8);
    }

    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1);

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let weights = vec![0.3, 0.7];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1);

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }
}