//! MicroLoRA: Ultra-fast per-query adaptation
//!
//! Path: ruvector_dag/sona/micro_lora.rs

use ndarray::{Array1, Array2};

/// Configuration for a [`MicroLoRA`] adapter.
#[derive(Debug, Clone)]
pub struct MicroLoRAConfig {
    /// Low-rank dimension; kept at 1-2 for "micro" adapters.
    pub rank: usize,
    /// Scaling factor applied to the low-rank delta in the forward pass.
    pub alpha: f32,
    /// Dropout rate.
    /// NOTE(review): not read anywhere in this file — confirm it is consumed elsewhere.
    pub dropout: f32,
}
11
12impl Default for MicroLoRAConfig {
13    fn default() -> Self {
14        Self {
15            rank: 2,
16            alpha: 1.0,
17            dropout: 0.0,
18        }
19    }
20}
21
22pub struct MicroLoRA {
23    config: MicroLoRAConfig,
24    a_matrix: Array2<f32>, // (in_dim, rank)
25    b_matrix: Array2<f32>, // (rank, out_dim)
26    #[allow(dead_code)]
27    in_dim: usize,
28    #[allow(dead_code)]
29    out_dim: usize,
30}
31
32impl MicroLoRA {
33    pub fn new(config: MicroLoRAConfig, dim: usize) -> Self {
34        let rank = config.rank;
35        // Initialize A with small random values, B with zeros
36        let a_matrix = Array2::from_shape_fn((dim, rank), |_| (rand::random::<f32>() - 0.5) * 0.01);
37        let b_matrix = Array2::zeros((rank, dim));
38
39        Self {
40            config,
41            a_matrix,
42            b_matrix,
43            in_dim: dim,
44            out_dim: dim,
45        }
46    }
47
48    /// Forward pass: x + alpha * (x @ A @ B)
49    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
50        let low_rank = x.dot(&self.a_matrix).dot(&self.b_matrix);
51        x + &(low_rank * self.config.alpha)
52    }
53
54    /// Adapt weights based on gradient signal
55    pub fn adapt(&mut self, gradient: &Array1<f32>, learning_rate: f32) {
56        // Update B matrix based on gradient (rank-1 update)
57        // This is the "instant" adaptation - must be <100μs
58        let grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
59        if grad_norm > 1e-8 {
60            let normalized = gradient / grad_norm;
61            // Outer product update to B
62            for i in 0..self.config.rank {
63                for j in 0..self.out_dim {
64                    self.b_matrix[[i, j]] +=
65                        learning_rate * self.a_matrix.column(i).sum() * normalized[j];
66                }
67            }
68        }
69    }
70
71    /// Reset to initial state
72    pub fn reset(&mut self) {
73        self.b_matrix.fill(0.0);
74    }
75
76    /// Get parameter count
77    pub fn param_count(&self) -> usize {
78        self.a_matrix.len() + self.b_matrix.len()
79    }
80}