Skip to main content

ruvector_gnn/
layer.rs

//! GNN Layer Implementation for HNSW Topology
//!
//! This module implements graph neural network layers that operate on HNSW graph structure,
//! including attention mechanisms, normalization, and gated recurrent updates.
5
6use crate::error::GnnError;
7use ndarray::{Array1, Array2, ArrayView1};
8use rand::SeedableRng;
9use rand_distr::{Distribution, Normal};
10use serde::{Deserialize, Serialize};
11
12/// Linear transformation layer (weight matrix multiplication)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Linear {
15    weights: Array2<f32>,
16    bias: Array1<f32>,
17}
18
19impl Linear {
20    /// Create a new linear layer with Xavier/Glorot initialization.
21    /// Uses a deterministic seeded RNG (faster than thread_rng on all platforms,
22    /// especially ARM64) while still producing well-distributed weights.
23    pub fn new(input_dim: usize, output_dim: usize) -> Self {
24        // Seed from dims so layers with different shapes get distinct weights
25        let seed = (input_dim as u64).wrapping_mul(6364136223846793005)
26            ^ (output_dim as u64).wrapping_mul(1442695040888963407);
27        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
28
29        // Xavier initialization: scale = sqrt(2.0 / (input_dim + output_dim))
30        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
31        let normal = Normal::new(0.0, scale as f64).unwrap();
32
33        let weights =
34            Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);
35
36        let bias = Array1::zeros(output_dim);
37
38        Self { weights, bias }
39    }
40
41    /// Forward pass: y = Wx + b
42    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
43        let x = ArrayView1::from(input);
44        let output = self.weights.dot(&x) + &self.bias;
45        output.to_vec()
46    }
47
48    /// Get output dimension
49    pub fn output_dim(&self) -> usize {
50        self.weights.shape()[0]
51    }
52}
53
54/// Layer normalization
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct LayerNorm {
57    gamma: Array1<f32>,
58    beta: Array1<f32>,
59    eps: f32,
60}
61
62impl LayerNorm {
63    /// Create a new layer normalization layer
64    pub fn new(dim: usize, eps: f32) -> Self {
65        Self {
66            gamma: Array1::ones(dim),
67            beta: Array1::zeros(dim),
68            eps,
69        }
70    }
71
72    /// Forward pass: normalize and scale
73    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
74        let x = ArrayView1::from(input);
75
76        // Compute mean and variance
77        let mean = x.mean().unwrap_or(0.0);
78        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
79
80        // Normalize
81        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
82
83        // Scale and shift
84        let output = &self.gamma * &normalized + &self.beta;
85        output.to_vec()
86    }
87}
88
89/// Multi-head attention mechanism
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MultiHeadAttention {
92    num_heads: usize,
93    head_dim: usize,
94    q_linear: Linear,
95    k_linear: Linear,
96    v_linear: Linear,
97    out_linear: Linear,
98}
99
100impl MultiHeadAttention {
101    /// Create a new multi-head attention layer
102    ///
103    /// # Errors
104    /// Returns `GnnError::LayerConfig` if `embed_dim` is not divisible by `num_heads`.
105    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
106        if embed_dim % num_heads != 0 {
107            return Err(GnnError::layer_config(format!(
108                "Embedding dimension ({}) must be divisible by number of heads ({})",
109                embed_dim, num_heads
110            )));
111        }
112
113        let head_dim = embed_dim / num_heads;
114
115        Ok(Self {
116            num_heads,
117            head_dim,
118            q_linear: Linear::new(embed_dim, embed_dim),
119            k_linear: Linear::new(embed_dim, embed_dim),
120            v_linear: Linear::new(embed_dim, embed_dim),
121            out_linear: Linear::new(embed_dim, embed_dim),
122        })
123    }
124
125    /// Forward pass: compute multi-head attention
126    ///
127    /// # Arguments
128    /// * `query` - Query vector
129    /// * `keys` - Key vectors from neighbors
130    /// * `values` - Value vectors from neighbors
131    ///
132    /// # Returns
133    /// Attention-weighted output vector
134    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
135        if keys.is_empty() || values.is_empty() {
136            return query.to_vec();
137        }
138
139        // Project query, keys, and values
140        let q = self.q_linear.forward(query);
141        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
142        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();
143
144        // Reshape for multi-head attention
145        let q_heads = self.split_heads(&q);
146        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
147        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();
148
149        // Compute attention for each head
150        let mut head_outputs = Vec::new();
151        for h in 0..self.num_heads {
152            let q_h = &q_heads[h];
153            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
154            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();
155
156            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
157            head_outputs.push(head_output);
158        }
159
160        // Concatenate heads
161        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();
162
163        // Final linear projection
164        self.out_linear.forward(&concat)
165    }
166
167    /// Split vector into multiple heads
168    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
169        let mut heads = Vec::new();
170        for h in 0..self.num_heads {
171            let start = h * self.head_dim;
172            let end = start + self.head_dim;
173            heads.push(x[start..end].to_vec());
174        }
175        heads
176    }
177
178    /// Scaled dot-product attention
179    fn scaled_dot_product_attention(
180        &self,
181        query: &[f32],
182        keys: &[&Vec<f32>],
183        values: &[&Vec<f32>],
184    ) -> Vec<f32> {
185        if keys.is_empty() {
186            return query.to_vec();
187        }
188
189        let scale = (self.head_dim as f32).sqrt();
190
191        // Compute attention scores
192        let scores: Vec<f32> = keys
193            .iter()
194            .map(|k| {
195                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
196                dot / scale
197            })
198            .collect();
199
200        // Softmax with epsilon guard against division by zero
201        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
202        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
203        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
204        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
205
206        // Weighted sum of values
207        let mut output = vec![0.0; self.head_dim];
208        for (weight, value) in attention_weights.iter().zip(values.iter()) {
209            for (out, &val) in output.iter_mut().zip(value.iter()) {
210                *out += weight * val;
211            }
212        }
213
214        output
215    }
216}
217
218/// Gated Recurrent Unit (GRU) cell for state updates
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct GRUCell {
221    // Update gate
222    w_z: Linear,
223    u_z: Linear,
224
225    // Reset gate
226    w_r: Linear,
227    u_r: Linear,
228
229    // Candidate hidden state
230    w_h: Linear,
231    u_h: Linear,
232}
233
234impl GRUCell {
235    /// Create a new GRU cell
236    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
237        Self {
238            // Update gate
239            w_z: Linear::new(input_dim, hidden_dim),
240            u_z: Linear::new(hidden_dim, hidden_dim),
241
242            // Reset gate
243            w_r: Linear::new(input_dim, hidden_dim),
244            u_r: Linear::new(hidden_dim, hidden_dim),
245
246            // Candidate hidden state
247            w_h: Linear::new(input_dim, hidden_dim),
248            u_h: Linear::new(hidden_dim, hidden_dim),
249        }
250    }
251
252    /// Forward pass: update hidden state
253    ///
254    /// # Arguments
255    /// * `input` - Current input
256    /// * `hidden` - Previous hidden state
257    ///
258    /// # Returns
259    /// Updated hidden state
260    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
261        // Update gate: z_t = sigmoid(W_z * x_t + U_z * h_{t-1})
262        let z =
263            self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));
264
265        // Reset gate: r_t = sigmoid(W_r * x_t + U_r * h_{t-1})
266        let r =
267            self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));
268
269        // Candidate hidden state: h_tilde = tanh(W_h * x_t + U_h * (r_t ⊙ h_{t-1}))
270        let r_hidden = self.mul_vecs(&r, hidden);
271        let h_tilde =
272            self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));
273
274        // Final hidden state: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde
275        let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
276        let term1 = self.mul_vecs(&one_minus_z, hidden);
277        let term2 = self.mul_vecs(&z, &h_tilde);
278
279        self.add_vecs(&term1, &term2)
280    }
281
282    /// Sigmoid activation with numerical stability
283    fn sigmoid(&self, x: f32) -> f32 {
284        if x > 0.0 {
285            1.0 / (1.0 + (-x).exp())
286        } else {
287            let ex = x.exp();
288            ex / (1.0 + ex)
289        }
290    }
291
292    /// Sigmoid for vectors
293    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
294        v.iter().map(|&x| self.sigmoid(x)).collect()
295    }
296
297    /// Tanh activation
298    fn tanh(&self, x: f32) -> f32 {
299        x.tanh()
300    }
301
302    /// Tanh for vectors
303    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
304        v.iter().map(|&x| self.tanh(x)).collect()
305    }
306
307    /// Element-wise addition
308    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
309        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
310    }
311
312    /// Element-wise multiplication
313    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
314        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
315    }
316}
317
318/// Main GNN layer operating on HNSW topology
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct RuvectorLayer {
321    /// Message weight matrix
322    w_msg: Linear,
323
324    /// Aggregation weight matrix
325    w_agg: Linear,
326
327    /// GRU update cell
328    w_update: GRUCell,
329
330    /// Multi-head attention
331    attention: MultiHeadAttention,
332
333    /// Layer normalization
334    norm: LayerNorm,
335
336    /// Dropout rate
337    dropout: f32,
338}
339
340impl RuvectorLayer {
341    /// Create a new Ruvector GNN layer
342    ///
343    /// # Arguments
344    /// * `input_dim` - Dimension of input node embeddings
345    /// * `hidden_dim` - Dimension of hidden representations
346    /// * `heads` - Number of attention heads
347    /// * `dropout` - Dropout rate (0.0 to 1.0)
348    ///
349    /// # Errors
350    /// Returns `GnnError::LayerConfig` if `dropout` is outside `[0.0, 1.0]` or
351    /// if `hidden_dim` is not divisible by `heads`.
352    pub fn new(
353        input_dim: usize,
354        hidden_dim: usize,
355        heads: usize,
356        dropout: f32,
357    ) -> Result<Self, GnnError> {
358        if !(0.0..=1.0).contains(&dropout) {
359            return Err(GnnError::layer_config(format!(
360                "Dropout must be between 0.0 and 1.0, got {}",
361                dropout
362            )));
363        }
364
365        Ok(Self {
366            w_msg: Linear::new(input_dim, hidden_dim),
367            w_agg: Linear::new(hidden_dim, hidden_dim),
368            w_update: GRUCell::new(hidden_dim, hidden_dim),
369            attention: MultiHeadAttention::new(hidden_dim, heads)?,
370            norm: LayerNorm::new(hidden_dim, 1e-5),
371            dropout,
372        })
373    }
374
375    /// Forward pass through the GNN layer
376    ///
377    /// # Arguments
378    /// * `node_embedding` - Current node's embedding
379    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
380    /// * `edge_weights` - Weights of edges to neighbors (e.g., distances)
381    ///
382    /// # Returns
383    /// Updated node embedding
384    pub fn forward(
385        &self,
386        node_embedding: &[f32],
387        neighbor_embeddings: &[Vec<f32>],
388        edge_weights: &[f32],
389    ) -> Vec<f32> {
390        if neighbor_embeddings.is_empty() {
391            // No neighbors: return normalized projection
392            let projected = self.w_msg.forward(node_embedding);
393            return self.norm.forward(&projected);
394        }
395
396        // Step 1: Message passing - transform node and neighbor embeddings
397        let node_msg = self.w_msg.forward(node_embedding);
398        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
399            .iter()
400            .map(|n| self.w_msg.forward(n))
401            .collect();
402
403        // Step 2: Attention-based aggregation
404        let attention_output = self
405            .attention
406            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);
407
408        // Step 3: Weighted aggregation using edge weights
409        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);
410
411        // Step 4: Combine attention and weighted aggregation
412        let combined = self.add_vecs(&attention_output, &weighted_msgs);
413        let aggregated = self.w_agg.forward(&combined);
414
415        // Step 5: GRU update
416        let updated = self.w_update.forward(&aggregated, &node_msg);
417
418        // Step 6: Apply dropout (simplified - always apply scaling)
419        let dropped = self.apply_dropout(&updated);
420
421        // Step 7: Layer normalization
422        self.norm.forward(&dropped)
423    }
424
425    /// Aggregate neighbor messages with edge weights
426    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
427        if messages.is_empty() || weights.is_empty() {
428            return vec![0.0; self.w_msg.output_dim()];
429        }
430
431        // Normalize weights to sum to 1
432        let weight_sum: f32 = weights.iter().sum();
433        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
434            weights.iter().map(|&w| w / weight_sum).collect()
435        } else {
436            vec![1.0 / weights.len() as f32; weights.len()]
437        };
438
439        // Weighted sum
440        let dim = messages[0].len();
441        let mut aggregated = vec![0.0; dim];
442
443        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
444            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
445                *agg += weight * m;
446            }
447        }
448
449        aggregated
450    }
451
452    /// Apply dropout (simplified version - just scales by (1-dropout))
453    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
454        let scale = 1.0 - self.dropout;
455        input.iter().map(|&x| x * scale).collect()
456    }
457
458    /// Element-wise vector addition
459    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
460        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4 -> 2 linear layer yields a 2-element output.
    #[test]
    fn test_linear_layer() {
        let layer = Linear::new(4, 2);
        let result = layer.forward(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(result.len(), 2);
    }

    /// Normalized output has (approximately) zero mean.
    #[test]
    fn test_layer_norm() {
        let layer = LayerNorm::new(4, 1e-5);
        let result = layer.forward(&[1.0, 2.0, 3.0, 4.0]);

        let total: f32 = result.iter().sum();
        let mean = total / result.len() as f32;
        assert!((mean).abs() < 1e-5);
    }

    /// Attention keeps the embedding dimension.
    #[test]
    fn test_multihead_attention() {
        let layer = MultiHeadAttention::new(8, 2).unwrap();
        let q = vec![0.5; 8];
        let ks = vec![vec![0.3; 8], vec![0.7; 8]];
        let vs = vec![vec![0.2; 8], vec![0.8; 8]];

        assert_eq!(layer.forward(&q, &ks, &vs).len(), 8);
    }

    /// An embedding size that is not a multiple of the head count is rejected.
    #[test]
    fn test_multihead_attention_invalid_dims() {
        let outcome = MultiHeadAttention::new(10, 3);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("divisible"));
    }

    /// The GRU output has the hidden dimension.
    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(4, 8);
        let x = vec![1.0; 4];
        let h_prev = vec![0.5; 8];

        assert_eq!(cell.forward(&x, &h_prev).len(), 8);
    }

    /// A full layer pass with two neighbors yields a hidden-sized embedding.
    #[test]
    fn test_ruvector_layer() {
        let gnn = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let edge_weights = vec![0.3, 0.7];

        let result = gnn.forward(&node, &neighbors, &edge_weights);
        assert_eq!(result.len(), 8);
    }

    /// With no neighbors the layer still returns a hidden-sized embedding.
    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let gnn = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let no_neighbors: Vec<Vec<f32>> = Vec::new();
        let no_weights: Vec<f32> = Vec::new();

        let result = gnn.forward(&node, &no_neighbors, &no_weights);
        assert_eq!(result.len(), 8);
    }

    /// A dropout rate above 1.0 is a configuration error.
    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        assert!(RuvectorLayer::new(4, 8, 2, 1.5).is_err());
    }

    /// A hidden size not divisible by the head count is a configuration error.
    #[test]
    fn test_ruvector_layer_invalid_heads() {
        assert!(RuvectorLayer::new(4, 7, 3, 0.1).is_err());
    }
}
551}