// ruvector_gnn/layer.rs
//! GNN Layer Implementation for HNSW Topology
//!
//! This module implements graph neural network layers that operate on HNSW graph structure,
//! including attention mechanisms, normalization, and gated recurrent updates.

use crate::error::GnnError;
use ndarray::{Array1, Array2, ArrayView1};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
11
/// Linear transformation layer (weight matrix multiplication)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    // Weight matrix of shape (output_dim, input_dim).
    weights: Array2<f32>,
    // Bias vector of length output_dim; initialized to zeros.
    bias: Array1<f32>,
}
18
19impl Linear {
20    /// Create a new linear layer with Xavier/Glorot initialization
21    pub fn new(input_dim: usize, output_dim: usize) -> Self {
22        let mut rng = rand::thread_rng();
23
24        // Xavier initialization: scale = sqrt(2.0 / (input_dim + output_dim))
25        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
26        let normal = Normal::new(0.0, scale as f64).unwrap();
27
28        let weights =
29            Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);
30
31        let bias = Array1::zeros(output_dim);
32
33        Self { weights, bias }
34    }
35
36    /// Forward pass: y = Wx + b
37    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
38        let x = ArrayView1::from(input);
39        let output = self.weights.dot(&x) + &self.bias;
40        output.to_vec()
41    }
42
43    /// Get output dimension
44    pub fn output_dim(&self) -> usize {
45        self.weights.shape()[0]
46    }
47}
48
/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    // Learnable per-feature scale; initialized to ones.
    gamma: Array1<f32>,
    // Learnable per-feature shift; initialized to zeros.
    beta: Array1<f32>,
    // Small constant added to the variance for numerical stability.
    eps: f32,
}
56
57impl LayerNorm {
58    /// Create a new layer normalization layer
59    pub fn new(dim: usize, eps: f32) -> Self {
60        Self {
61            gamma: Array1::ones(dim),
62            beta: Array1::zeros(dim),
63            eps,
64        }
65    }
66
67    /// Forward pass: normalize and scale
68    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
69        let x = ArrayView1::from(input);
70
71        // Compute mean and variance
72        let mean = x.mean().unwrap_or(0.0);
73        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
74
75        // Normalize
76        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
77
78        // Scale and shift
79        let output = &self.gamma * &normalized + &self.beta;
80        output.to_vec()
81    }
82}
83
/// Multi-head attention mechanism
///
/// Projects queries/keys/values with learned linear maps, runs scaled
/// dot-product attention independently per head, and mixes the concatenated
/// head outputs with a final linear projection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    // Number of parallel attention heads.
    num_heads: usize,
    // Per-head dimensionality (embed_dim / num_heads).
    head_dim: usize,
    // Learned projections; each maps embed_dim -> embed_dim.
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
}
94
95impl MultiHeadAttention {
96    /// Create a new multi-head attention layer
97    ///
98    /// # Errors
99    /// Returns `GnnError::LayerConfig` if `embed_dim` is not divisible by `num_heads`.
100    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
101        if embed_dim % num_heads != 0 {
102            return Err(GnnError::layer_config(format!(
103                "Embedding dimension ({}) must be divisible by number of heads ({})",
104                embed_dim, num_heads
105            )));
106        }
107
108        let head_dim = embed_dim / num_heads;
109
110        Ok(Self {
111            num_heads,
112            head_dim,
113            q_linear: Linear::new(embed_dim, embed_dim),
114            k_linear: Linear::new(embed_dim, embed_dim),
115            v_linear: Linear::new(embed_dim, embed_dim),
116            out_linear: Linear::new(embed_dim, embed_dim),
117        })
118    }
119
120    /// Forward pass: compute multi-head attention
121    ///
122    /// # Arguments
123    /// * `query` - Query vector
124    /// * `keys` - Key vectors from neighbors
125    /// * `values` - Value vectors from neighbors
126    ///
127    /// # Returns
128    /// Attention-weighted output vector
129    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
130        if keys.is_empty() || values.is_empty() {
131            return query.to_vec();
132        }
133
134        // Project query, keys, and values
135        let q = self.q_linear.forward(query);
136        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
137        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();
138
139        // Reshape for multi-head attention
140        let q_heads = self.split_heads(&q);
141        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
142        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();
143
144        // Compute attention for each head
145        let mut head_outputs = Vec::new();
146        for h in 0..self.num_heads {
147            let q_h = &q_heads[h];
148            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
149            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();
150
151            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
152            head_outputs.push(head_output);
153        }
154
155        // Concatenate heads
156        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();
157
158        // Final linear projection
159        self.out_linear.forward(&concat)
160    }
161
162    /// Split vector into multiple heads
163    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
164        let mut heads = Vec::new();
165        for h in 0..self.num_heads {
166            let start = h * self.head_dim;
167            let end = start + self.head_dim;
168            heads.push(x[start..end].to_vec());
169        }
170        heads
171    }
172
173    /// Scaled dot-product attention
174    fn scaled_dot_product_attention(
175        &self,
176        query: &[f32],
177        keys: &[&Vec<f32>],
178        values: &[&Vec<f32>],
179    ) -> Vec<f32> {
180        if keys.is_empty() {
181            return query.to_vec();
182        }
183
184        let scale = (self.head_dim as f32).sqrt();
185
186        // Compute attention scores
187        let scores: Vec<f32> = keys
188            .iter()
189            .map(|k| {
190                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
191                dot / scale
192            })
193            .collect();
194
195        // Softmax with epsilon guard against division by zero
196        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
197        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
198        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
199        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
200
201        // Weighted sum of values
202        let mut output = vec![0.0; self.head_dim];
203        for (weight, value) in attention_weights.iter().zip(values.iter()) {
204            for (out, &val) in output.iter_mut().zip(value.iter()) {
205                *out += weight * val;
206            }
207        }
208
209        output
210    }
211}
212
/// Gated Recurrent Unit (GRU) cell for state updates
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
    // Update gate: W_z acts on the input, U_z on the previous hidden state.
    w_z: Linear,
    u_z: Linear,

    // Reset gate: W_r acts on the input, U_r on the previous hidden state.
    w_r: Linear,
    u_r: Linear,

    // Candidate hidden state: W_h acts on the input, U_h on the reset-gated hidden state.
    w_h: Linear,
    u_h: Linear,
}
228
229impl GRUCell {
230    /// Create a new GRU cell
231    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
232        Self {
233            // Update gate
234            w_z: Linear::new(input_dim, hidden_dim),
235            u_z: Linear::new(hidden_dim, hidden_dim),
236
237            // Reset gate
238            w_r: Linear::new(input_dim, hidden_dim),
239            u_r: Linear::new(hidden_dim, hidden_dim),
240
241            // Candidate hidden state
242            w_h: Linear::new(input_dim, hidden_dim),
243            u_h: Linear::new(hidden_dim, hidden_dim),
244        }
245    }
246
247    /// Forward pass: update hidden state
248    ///
249    /// # Arguments
250    /// * `input` - Current input
251    /// * `hidden` - Previous hidden state
252    ///
253    /// # Returns
254    /// Updated hidden state
255    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
256        // Update gate: z_t = sigmoid(W_z * x_t + U_z * h_{t-1})
257        let z =
258            self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));
259
260        // Reset gate: r_t = sigmoid(W_r * x_t + U_r * h_{t-1})
261        let r =
262            self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));
263
264        // Candidate hidden state: h_tilde = tanh(W_h * x_t + U_h * (r_t ⊙ h_{t-1}))
265        let r_hidden = self.mul_vecs(&r, hidden);
266        let h_tilde =
267            self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));
268
269        // Final hidden state: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde
270        let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
271        let term1 = self.mul_vecs(&one_minus_z, hidden);
272        let term2 = self.mul_vecs(&z, &h_tilde);
273
274        self.add_vecs(&term1, &term2)
275    }
276
277    /// Sigmoid activation with numerical stability
278    fn sigmoid(&self, x: f32) -> f32 {
279        if x > 0.0 {
280            1.0 / (1.0 + (-x).exp())
281        } else {
282            let ex = x.exp();
283            ex / (1.0 + ex)
284        }
285    }
286
287    /// Sigmoid for vectors
288    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
289        v.iter().map(|&x| self.sigmoid(x)).collect()
290    }
291
292    /// Tanh activation
293    fn tanh(&self, x: f32) -> f32 {
294        x.tanh()
295    }
296
297    /// Tanh for vectors
298    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
299        v.iter().map(|&x| self.tanh(x)).collect()
300    }
301
302    /// Element-wise addition
303    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
304        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
305    }
306
307    /// Element-wise multiplication
308    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
309        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
310    }
311}
312
/// Main GNN layer operating on HNSW topology
///
/// Combines attention-based and edge-weighted neighbor aggregation, a GRU
/// state update, dropout scaling, and layer normalization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
    /// Message weight matrix (projects input_dim -> hidden_dim)
    w_msg: Linear,

    /// Aggregation weight matrix (hidden_dim -> hidden_dim)
    w_agg: Linear,

    /// GRU update cell
    w_update: GRUCell,

    /// Multi-head attention
    attention: MultiHeadAttention,

    /// Layer normalization
    norm: LayerNorm,

    /// Dropout rate in [0.0, 1.0]; applied as deterministic scaling
    dropout: f32,
}
334
335impl RuvectorLayer {
336    /// Create a new Ruvector GNN layer
337    ///
338    /// # Arguments
339    /// * `input_dim` - Dimension of input node embeddings
340    /// * `hidden_dim` - Dimension of hidden representations
341    /// * `heads` - Number of attention heads
342    /// * `dropout` - Dropout rate (0.0 to 1.0)
343    ///
344    /// # Errors
345    /// Returns `GnnError::LayerConfig` if `dropout` is outside `[0.0, 1.0]` or
346    /// if `hidden_dim` is not divisible by `heads`.
347    pub fn new(
348        input_dim: usize,
349        hidden_dim: usize,
350        heads: usize,
351        dropout: f32,
352    ) -> Result<Self, GnnError> {
353        if !(0.0..=1.0).contains(&dropout) {
354            return Err(GnnError::layer_config(format!(
355                "Dropout must be between 0.0 and 1.0, got {}",
356                dropout
357            )));
358        }
359
360        Ok(Self {
361            w_msg: Linear::new(input_dim, hidden_dim),
362            w_agg: Linear::new(hidden_dim, hidden_dim),
363            w_update: GRUCell::new(hidden_dim, hidden_dim),
364            attention: MultiHeadAttention::new(hidden_dim, heads)?,
365            norm: LayerNorm::new(hidden_dim, 1e-5),
366            dropout,
367        })
368    }
369
370    /// Forward pass through the GNN layer
371    ///
372    /// # Arguments
373    /// * `node_embedding` - Current node's embedding
374    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
375    /// * `edge_weights` - Weights of edges to neighbors (e.g., distances)
376    ///
377    /// # Returns
378    /// Updated node embedding
379    pub fn forward(
380        &self,
381        node_embedding: &[f32],
382        neighbor_embeddings: &[Vec<f32>],
383        edge_weights: &[f32],
384    ) -> Vec<f32> {
385        if neighbor_embeddings.is_empty() {
386            // No neighbors: return normalized projection
387            let projected = self.w_msg.forward(node_embedding);
388            return self.norm.forward(&projected);
389        }
390
391        // Step 1: Message passing - transform node and neighbor embeddings
392        let node_msg = self.w_msg.forward(node_embedding);
393        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
394            .iter()
395            .map(|n| self.w_msg.forward(n))
396            .collect();
397
398        // Step 2: Attention-based aggregation
399        let attention_output = self
400            .attention
401            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);
402
403        // Step 3: Weighted aggregation using edge weights
404        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);
405
406        // Step 4: Combine attention and weighted aggregation
407        let combined = self.add_vecs(&attention_output, &weighted_msgs);
408        let aggregated = self.w_agg.forward(&combined);
409
410        // Step 5: GRU update
411        let updated = self.w_update.forward(&aggregated, &node_msg);
412
413        // Step 6: Apply dropout (simplified - always apply scaling)
414        let dropped = self.apply_dropout(&updated);
415
416        // Step 7: Layer normalization
417        self.norm.forward(&dropped)
418    }
419
420    /// Aggregate neighbor messages with edge weights
421    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
422        if messages.is_empty() || weights.is_empty() {
423            return vec![0.0; self.w_msg.output_dim()];
424        }
425
426        // Normalize weights to sum to 1
427        let weight_sum: f32 = weights.iter().sum();
428        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
429            weights.iter().map(|&w| w / weight_sum).collect()
430        } else {
431            vec![1.0 / weights.len() as f32; weights.len()]
432        };
433
434        // Weighted sum
435        let dim = messages[0].len();
436        let mut aggregated = vec![0.0; dim];
437
438        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
439            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
440                *agg += weight * m;
441            }
442        }
443
444        aggregated
445    }
446
447    /// Apply dropout (simplified version - just scales by (1-dropout))
448    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
449        let scale = 1.0 - self.dropout;
450        input.iter().map(|&x| x * scale).collect()
451    }
452
453    /// Element-wise vector addition
454    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
455        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_layer() {
        let layer = Linear::new(4, 2);
        let out = layer.forward(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let normalized = norm.forward(&[1.0, 2.0, 3.0, 4.0]);

        // Normalized output should have (approximately) zero mean.
        let mean = normalized.iter().sum::<f32>() / normalized.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_multihead_attention() {
        let mha = MultiHeadAttention::new(8, 2).unwrap();
        let keys = [vec![0.3; 8], vec![0.7; 8]];
        let values = [vec![0.2; 8], vec![0.8; 8]];

        let out = mha.forward(&[0.5; 8], &keys, &values);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_multihead_attention_invalid_dims() {
        // 10 is not divisible by 3 heads.
        let err = MultiHeadAttention::new(10, 3).unwrap_err();
        assert!(err.to_string().contains("divisible"));
    }

    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(4, 8);
        let next = cell.forward(&[1.0; 4], &[0.5; 8]);
        assert_eq!(next.len(), 8);
    }

    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
        let neighbors = [vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];

        let out = layer.forward(&[1.0, 2.0, 3.0, 4.0], &neighbors, &[0.3, 0.7]);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();

        // Isolated node still produces a hidden_dim-sized embedding.
        let out = layer.forward(&[1.0, 2.0, 3.0, 4.0], &[], &[]);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_invalid_dropout() {
        assert!(RuvectorLayer::new(4, 8, 2, 1.5).is_err());
    }

    #[test]
    fn test_ruvector_layer_invalid_heads() {
        // hidden_dim = 7 is not divisible by 3 heads.
        assert!(RuvectorLayer::new(4, 7, 3, 0.1).is_err());
    }
}