ruvector_gnn/layer.rs

//! GNN Layer Implementation for HNSW Topology
//!
//! This module implements graph neural network layers that operate on HNSW graph structure,
//! including attention mechanisms, normalization, and gated recurrent updates.

use ndarray::{Array1, Array2, ArrayView1};
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

/// Linear transformation layer (weight matrix multiplication)
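///
/// A minimal usage sketch (doctest marked `ignore`; dimensions are illustrative):
///
/// ```ignore
/// let layer = Linear::new(4, 2);
/// let y = layer.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(y.len(), 2);
/// ```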
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    weights: Array2<f32>,
    bias: Array1<f32>,
}

impl Linear {
    /// Create a new linear layer with Xavier/Glorot initialization
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let mut rng = rand::thread_rng();

        // Xavier initialization: scale = sqrt(2.0 / (input_dim + output_dim))
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let normal = Normal::new(0.0, scale as f64).unwrap();

        let weights =
            Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);

        let bias = Array1::zeros(output_dim);

        Self { weights, bias }
    }

    /// Forward pass: y = Wx + b
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);
        let output = self.weights.dot(&x) + &self.bias;
        output.to_vec()
    }

    /// Get output dimension
    pub fn output_dim(&self) -> usize {
        self.weights.shape()[0]
    }
}

/// Layer normalization
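///
/// Normalizes over the feature dimension and applies a learned affine
/// transform: `y = gamma * (x - mean) / sqrt(var + eps) + beta`.
///
/// A minimal usage sketch (doctest marked `ignore`; values are illustrative):
///
/// ```ignore
/// let norm = LayerNorm::new(4, 1e-5);
/// let y = norm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// // y has approximately zero mean and unit variance.
/// ```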
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    eps: f32,
}

impl LayerNorm {
    /// Create a new layer normalization layer
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps,
        }
    }

    /// Forward pass: normalize and scale
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let x = ArrayView1::from(input);

        // Compute mean and variance
        let mean = x.mean().unwrap_or(0.0);
        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;

        // Normalize
        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());

        // Scale and shift
        let output = &self.gamma * &normalized + &self.beta;
        output.to_vec()
    }
}

/// Multi-head attention mechanism
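///
/// Each head computes scaled dot-product attention over the neighbor set,
/// `softmax(q · k_i / sqrt(head_dim)) · v_i`, after which the head outputs
/// are concatenated and passed through a final linear projection.
///
/// A minimal usage sketch (doctest marked `ignore`; `embed_dim = 8` split
/// across 2 heads, values are illustrative):
///
/// ```ignore
/// let attn = MultiHeadAttention::new(8, 2);
/// let query = vec![0.5; 8];
/// let keys = vec![vec![0.3; 8], vec![0.7; 8]];
/// let values = vec![vec![0.2; 8], vec![0.8; 8]];
/// let out = attn.forward(&query, &keys, &values);
/// assert_eq!(out.len(), 8);
/// ```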
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
}

impl MultiHeadAttention {
    /// Create a new multi-head attention layer
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "Embedding dimension must be divisible by number of heads"
        );

        let head_dim = embed_dim / num_heads;

        Self {
            num_heads,
            head_dim,
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
        }
    }

    /// Forward pass: compute multi-head attention
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `keys` - Key vectors from neighbors
    /// * `values` - Value vectors from neighbors
    ///
    /// # Returns
    /// Attention-weighted output vector
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return query.to_vec();
        }

        // Project query, keys, and values
        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        // Reshape for multi-head attention
        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();

        // Compute attention for each head
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        // Concatenate heads
        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();

        // Final linear projection
        self.out_linear.forward(&concat)
    }

    /// Split vector into multiple heads
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        let mut heads = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.head_dim;
            let end = start + self.head_dim;
            heads.push(x[start..end].to_vec());
        }
        heads
    }

    /// Scaled dot-product attention
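    ///
    /// Scores are `q · k_i / sqrt(head_dim)`, passed through a numerically
    /// stable softmax (max-subtracted, with an epsilon floor on the
    /// normalizer); the result weights the sum of the value vectors.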
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        // Compute attention scores
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Softmax with epsilon guard against division by zero
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        // Weighted sum of values
        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in attention_weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }
}

/// Gated Recurrent Unit (GRU) cell for state updates
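///
/// Follows the standard GRU formulation (Cho et al., 2014):
/// `z = sigmoid(W_z x + U_z h)`, `r = sigmoid(W_r x + U_r h)`,
/// `h_tilde = tanh(W_h x + U_h (r ⊙ h))`, and
/// `h' = (1 - z) ⊙ h + z ⊙ h_tilde`.
///
/// A minimal usage sketch (doctest marked `ignore`; dimensions are
/// illustrative):
///
/// ```ignore
/// let gru = GRUCell::new(4, 8);
/// let h_new = gru.forward(&[1.0; 4], &[0.5; 8]);
/// assert_eq!(h_new.len(), 8);
/// ```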
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
    // Update gate
    w_z: Linear,
    u_z: Linear,

    // Reset gate
    w_r: Linear,
    u_r: Linear,

    // Candidate hidden state
    w_h: Linear,
    u_h: Linear,
}

impl GRUCell {
    /// Create a new GRU cell
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        Self {
            // Update gate
            w_z: Linear::new(input_dim, hidden_dim),
            u_z: Linear::new(hidden_dim, hidden_dim),

            // Reset gate
            w_r: Linear::new(input_dim, hidden_dim),
            u_r: Linear::new(hidden_dim, hidden_dim),

            // Candidate hidden state
            w_h: Linear::new(input_dim, hidden_dim),
            u_h: Linear::new(hidden_dim, hidden_dim),
        }
    }

    /// Forward pass: update hidden state
    ///
    /// # Arguments
    /// * `input` - Current input
    /// * `hidden` - Previous hidden state
    ///
    /// # Returns
    /// Updated hidden state
    pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
        // Update gate: z_t = sigmoid(W_z * x_t + U_z * h_{t-1})
        let z =
            self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));

        // Reset gate: r_t = sigmoid(W_r * x_t + U_r * h_{t-1})
        let r =
            self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));

        // Candidate hidden state: h_tilde = tanh(W_h * x_t + U_h * (r_t ⊙ h_{t-1}))
        let r_hidden = self.mul_vecs(&r, hidden);
        let h_tilde =
            self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));

        // Final hidden state: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde
        let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
        let term1 = self.mul_vecs(&one_minus_z, hidden);
        let term2 = self.mul_vecs(&z, &h_tilde);

        self.add_vecs(&term1, &term2)
    }

    /// Sigmoid activation with numerical stability
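    ///
    /// Uses the algebraically equivalent form `exp(x) / (1 + exp(x))` for
    /// non-positive inputs, so `exp` is only ever called on a non-positive
    /// argument and cannot overflow.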
    fn sigmoid(&self, x: f32) -> f32 {
        if x > 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }

    /// Sigmoid for vectors
    fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.sigmoid(x)).collect()
    }

    /// Tanh activation
    fn tanh(&self, x: f32) -> f32 {
        x.tanh()
    }

    /// Tanh for vectors
    fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
        v.iter().map(|&x| self.tanh(x)).collect()
    }

    /// Element-wise addition
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }

    /// Element-wise multiplication
    fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
    }
}

/// Main GNN layer operating on HNSW topology
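///
/// The forward pass runs: message projection, multi-head attention over
/// neighbors combined with edge-weighted aggregation, a further linear
/// projection, GRU state update, dropout scaling, and layer normalization.
///
/// A minimal usage sketch (doctest marked `ignore`; dimensions and weights
/// are illustrative):
///
/// ```ignore
/// let layer = RuvectorLayer::new(4, 8, 2, 0.1);
/// let node = vec![1.0, 2.0, 3.0, 4.0];
/// let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0]];
/// let edge_weights = vec![1.0];
/// let out = layer.forward(&node, &neighbors, &edge_weights);
/// assert_eq!(out.len(), 8);
/// ```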
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
    /// Message weight matrix
    w_msg: Linear,

    /// Aggregation weight matrix
    w_agg: Linear,

    /// GRU update cell
    w_update: GRUCell,

    /// Multi-head attention
    attention: MultiHeadAttention,

    /// Layer normalization
    norm: LayerNorm,

    /// Dropout rate
    dropout: f32,
}

impl RuvectorLayer {
    /// Create a new Ruvector GNN layer
    ///
    /// # Arguments
    /// * `input_dim` - Dimension of input node embeddings
    /// * `hidden_dim` - Dimension of hidden representations
    /// * `heads` - Number of attention heads
    /// * `dropout` - Dropout rate (0.0 to 1.0)
    pub fn new(input_dim: usize, hidden_dim: usize, heads: usize, dropout: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&dropout),
            "Dropout must be between 0.0 and 1.0"
        );

        Self {
            w_msg: Linear::new(input_dim, hidden_dim),
            w_agg: Linear::new(hidden_dim, hidden_dim),
            w_update: GRUCell::new(hidden_dim, hidden_dim),
            attention: MultiHeadAttention::new(hidden_dim, heads),
            norm: LayerNorm::new(hidden_dim, 1e-5),
            dropout,
        }
    }

    /// Forward pass through the GNN layer
    ///
    /// # Arguments
    /// * `node_embedding` - Current node's embedding
    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
    /// * `edge_weights` - Weights of edges to neighbors (e.g., distances)
    ///
    /// # Returns
    /// Updated node embedding
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        if neighbor_embeddings.is_empty() {
            // No neighbors: return normalized projection
            let projected = self.w_msg.forward(node_embedding);
            return self.norm.forward(&projected);
        }

        // Step 1: Message passing - transform node and neighbor embeddings
        let node_msg = self.w_msg.forward(node_embedding);
        let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
            .iter()
            .map(|n| self.w_msg.forward(n))
            .collect();

        // Step 2: Attention-based aggregation
        let attention_output = self
            .attention
            .forward(&node_msg, &neighbor_msgs, &neighbor_msgs);

        // Step 3: Weighted aggregation using edge weights
        let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);

        // Step 4: Combine attention and weighted aggregation
        let combined = self.add_vecs(&attention_output, &weighted_msgs);
        let aggregated = self.w_agg.forward(&combined);

        // Step 5: GRU update
        let updated = self.w_update.forward(&aggregated, &node_msg);

        // Step 6: Apply dropout (simplified - always apply scaling)
        let dropped = self.apply_dropout(&updated);

        // Step 7: Layer normalization
        self.norm.forward(&dropped)
    }

    /// Aggregate neighbor messages with edge weights
    fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if messages.is_empty() || weights.is_empty() {
            return vec![0.0; self.w_msg.output_dim()];
        }

        // Normalize weights to sum to 1
        let weight_sum: f32 = weights.iter().sum();
        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
            weights.iter().map(|&w| w / weight_sum).collect()
        } else {
            vec![1.0 / weights.len() as f32; weights.len()]
        };

        // Weighted sum
        let dim = messages[0].len();
        let mut aggregated = vec![0.0; dim];

        for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
            for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
                *agg += weight * m;
            }
        }

        aggregated
    }

    /// Apply dropout (simplified: deterministically scales by `(1 - dropout)`
    /// rather than randomly zeroing elements)
    fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
        let scale = 1.0 - self.dropout;
        input.iter().map(|&x| x * scale).collect()
    }

    /// Element-wise vector addition
    fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_layer() {
        let linear = Linear::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);

        // Check that output has zero mean (approximately)
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_multihead_attention() {
        let attention = MultiHeadAttention::new(8, 2);
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = attention.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(4, 8);
        let input = vec![1.0; 4];
        let hidden = vec![0.5; 8];

        let new_hidden = gru.forward(&input, &hidden);
        assert_eq!(new_hidden.len(), 8);
    }

    #[test]
    fn test_ruvector_layer() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1);

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
        let weights = vec![0.3, 0.7];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_ruvector_layer_no_neighbors() {
        let layer = RuvectorLayer::new(4, 8, 2, 0.1);

        let node = vec![1.0, 2.0, 3.0, 4.0];
        let neighbors: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];

        let output = layer.forward(&node, &neighbors, &weights);
        assert_eq!(output.len(), 8);
    }
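
    /// Additional check, sketched here as a supplement to the original suite:
    /// with the default `gamma = 1` and `beta = 0`, layer normalization
    /// should also produce (approximately) unit variance.
    #[test]
    fn test_layer_norm_unit_variance() {
        let norm = LayerNorm::new(4, 1e-5);
        let output = norm.forward(&[1.0, 2.0, 3.0, 4.0]);

        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        let variance: f32 =
            output.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / output.len() as f32;
        assert!((variance - 1.0).abs() < 1e-3);
    }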
}