ruvector_sona/
lora.rs

1//! LoRA (Low-Rank Adaptation) implementations for SONA
2//!
3//! Two-tier LoRA system:
4//! - MicroLoRA: Rank 1-2, per-request adaptation (<100μs)
5//! - BaseLoRA: Rank 4-16, background adaptation (hourly)
6
7use crate::types::LearningSignal;
8use serde::{Deserialize, Serialize};
9
/// Micro-LoRA for per-request adaptation
///
/// Uses rank 1-2 for ultra-low latency updates.
/// Forward pass: output += scale * (input @ down) @ up
///
/// Both projections are stored row-major as `[rank x hidden_dim]`
/// (each rank row spans `hidden_dim` values), giving
/// `2 * rank * hidden_dim` parameters total.
///
/// NOTE(review): the `#[serde(skip)]` accumulators deserialize to empty
/// `Vec`s (their `Default`), not to `hidden_dim * rank` zeros — calling
/// `accumulate_gradient` on a deserialized adapter would index out of
/// bounds. Confirm deserialization paths re-initialize these buffers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    /// Down projection, `[rank x hidden_dim]` row-major (hidden_dim -> rank)
    down_proj: Vec<f32>,
    /// Up projection, `[rank x hidden_dim]` row-major (rank -> hidden_dim)
    up_proj: Vec<f32>,
    /// Rank (1-2 for micro updates)
    rank: usize,
    /// Hidden dimension
    hidden_dim: usize,
    /// Accumulated gradients for down (transient training state, not serialized)
    #[serde(skip)]
    grad_down: Vec<f32>,
    /// Accumulated gradients for up (transient training state, not serialized)
    #[serde(skip)]
    grad_up: Vec<f32>,
    /// Number of signals accumulated since the last apply, used for averaging
    #[serde(skip)]
    update_count: usize,
    /// Scaling factor applied to the LoRA delta (1/sqrt(rank) by default)
    scale: f32,
}
36
37impl MicroLoRA {
38    /// Create new Micro-LoRA adapter
39    ///
40    /// # Arguments
41    /// * `hidden_dim` - Model hidden dimension
42    /// * `rank` - LoRA rank (must be 1-2)
43    ///
44    /// # Panics
45    /// Panics if rank > 2
46    pub fn new(hidden_dim: usize, rank: usize) -> Self {
47        assert!(rank >= 1 && rank <= 2, "MicroLoRA rank must be 1-2, got {}", rank);
48
49        // Initialize down with small random-like values (deterministic for reproducibility)
50        let down_proj: Vec<f32> = (0..hidden_dim * rank)
51            .map(|i| {
52                let x = (i as f32 * 0.618033988749895) % 1.0;
53                (x - 0.5) * 0.02
54            })
55            .collect();
56
57        // Initialize up to zero (standard LoRA init)
58        let up_proj = vec![0.0f32; rank * hidden_dim];
59
60        Self {
61            down_proj,
62            up_proj,
63            rank,
64            hidden_dim,
65            grad_down: vec![0.0; hidden_dim * rank],
66            grad_up: vec![0.0; rank * hidden_dim],
67            update_count: 0,
68            scale: 1.0 / (rank as f32).sqrt(),
69        }
70    }
71
72    /// Scalar forward pass (fallback)
73    pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
74        assert_eq!(input.len(), self.hidden_dim);
75        assert_eq!(output.len(), self.hidden_dim);
76
77        // Down projection: hidden_dim -> rank
78        let mut intermediate = vec![0.0f32; self.rank];
79        for r in 0..self.rank {
80            let mut sum = 0.0f32;
81            let offset = r * self.hidden_dim;
82            for i in 0..self.hidden_dim {
83                sum += input[i] * self.down_proj[offset + i];
84            }
85            intermediate[r] = sum;
86        }
87
88        // Up projection: rank -> hidden_dim
89        for i in 0..self.hidden_dim {
90            let mut sum = 0.0f32;
91            for r in 0..self.rank {
92                sum += intermediate[r] * self.up_proj[r * self.hidden_dim + i];
93            }
94            output[i] += sum * self.scale;
95        }
96    }
97
    /// SIMD-optimized forward pass (AVX2)
    ///
    /// Same contract as `forward_scalar`: accumulates
    /// `scale * (input @ down) @ up` into `output`, processing 8 lanes at
    /// a time with scalar tail loops for the `hidden_dim % 8` leftovers.
    ///
    /// # Panics
    /// Panics if `input` or `output` length differs from `hidden_dim`.
    ///
    /// NOTE(review): `_mm256_fmadd_ps` is an FMA-set instruction, but this
    /// gate only requires `avx2`. A build with `+avx2` and without `+fma`
    /// would be unsound — confirm builds always enable both (e.g. via
    /// `-C target-cpu=native`).
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        // SAFETY: the intrinsics are gated by the `avx2` target feature on
        // this function's cfg; every loadu/storeu pointer comes from a slice
        // whose length was just asserted, and `i + 8 <= hidden_dim` keeps
        // each 8-lane access in bounds (unaligned variants, no alignment
        // requirement).
        unsafe {
            // Down projection: hidden_dim -> rank
            let mut intermediate = vec![0.0f32; self.rank];

            for r in 0..self.rank {
                let mut sum = _mm256_setzero_ps();
                let offset = r * self.hidden_dim;

                let mut i = 0;
                while i + 8 <= self.hidden_dim {
                    let inp = _mm256_loadu_ps(input[i..].as_ptr());
                    let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
                    sum = _mm256_fmadd_ps(inp, weight, sum);
                    i += 8;
                }

                // Horizontal sum of the 8 partial lanes
                let mut result = [0.0f32; 8];
                _mm256_storeu_ps(result.as_mut_ptr(), sum);
                intermediate[r] = result.iter().sum();

                // Scalar tail: hidden_dim not a multiple of 8
                for j in i..self.hidden_dim {
                    intermediate[r] += input[j] * self.down_proj[offset + j];
                }
            }

            // Up projection: rank -> hidden_dim
            let scale_vec = _mm256_set1_ps(self.scale);

            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let mut sum = _mm256_setzero_ps();

                for r in 0..self.rank {
                    let up_offset = r * self.hidden_dim;
                    let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
                    let inter = _mm256_set1_ps(intermediate[r]);
                    sum = _mm256_fmadd_ps(inter, weight, sum);
                }

                // Scale and accumulate into the existing output values
                sum = _mm256_mul_ps(sum, scale_vec);
                let existing = _mm256_loadu_ps(output[i..].as_ptr());
                let result = _mm256_add_ps(existing, sum);
                _mm256_storeu_ps(output[i..].as_mut_ptr(), result);

                i += 8;
            }

            // Scalar tail for the up projection
            for j in i..self.hidden_dim {
                let mut val = 0.0;
                for r in 0..self.rank {
                    val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
                }
                output[j] += val * self.scale;
            }
        }
    }
166
167    /// Forward pass with automatic SIMD detection
168    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
169        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
170        {
171            self.forward_simd(input, output);
172            return;
173        }
174
175        #[allow(unreachable_code)]
176        self.forward_scalar(input, output);
177    }
178
179    /// Accumulate gradient from learning signal
180    pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
181        if signal.gradient_estimate.len() != self.hidden_dim {
182            return;
183        }
184
185        let quality = signal.quality_score;
186
187        // Simplified gradient: outer product scaled by quality
188        // This approximates the true gradient for rank-1 LoRA
189        for r in 0..self.rank {
190            for i in 0..self.hidden_dim {
191                let grad_idx = r * self.hidden_dim + i;
192                // Update up projection gradient (main target)
193                self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
194            }
195        }
196
197        self.update_count += 1;
198    }
199
200    /// Apply accumulated gradients with learning rate
201    pub fn apply_accumulated(&mut self, learning_rate: f32) {
202        if self.update_count == 0 {
203            return;
204        }
205
206        let scale = learning_rate / self.update_count as f32;
207
208        // Update up projection (main adaptation target)
209        for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
210            *w += g * scale;
211        }
212
213        // Reset accumulators
214        self.grad_up.fill(0.0);
215        self.grad_down.fill(0.0);
216        self.update_count = 0;
217    }
218
    /// Reset adapter to initial state
    ///
    /// Only the up projection needs zeroing: `down_proj` keeps its
    /// deterministic init from `new` and is never modified by
    /// `apply_accumulated`, so zeroing `up_proj` restores the no-op
    /// adapter exactly.
    pub fn reset(&mut self) {
        self.up_proj.fill(0.0);
        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    /// Get rank
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Get hidden dimension
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Get parameter count (down + up projections combined)
    pub fn param_count(&self) -> usize {
        self.down_proj.len() + self.up_proj.len()
    }

    /// Get scale factor applied to the LoRA delta
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Set scale factor (overrides the default `1/sqrt(rank)`)
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Get the number of gradient signals accumulated but not yet applied
    pub fn pending_updates(&self) -> usize {
        self.update_count
    }
256}
257
/// Base LoRA for background adaptation
///
/// Higher rank (4-16) for more expressive adaptation.
/// Applied hourly during background learning cycles.
///
/// Projections use the same `[rank x hidden_dim]` row-major layout as
/// `MicroLoRA`; the effective output scale is `alpha / rank`.
///
/// NOTE(review): unlike `MicroLoRA::new`, the rank range is not validated
/// anywhere in this type — confirm callers enforce 4-16 if that matters.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    /// Per-layer LoRA weight pairs, indexed by transformer layer
    pub layers: Vec<LoRALayer>,
    /// Rank shared by every layer
    pub rank: usize,
    /// Hidden dimension shared by every layer
    pub hidden_dim: usize,
    /// Alpha scaling factor (set to `rank` by `new`, giving scale 1.0)
    pub alpha: f32,
}
273
/// Single LoRA layer
///
/// Weight pair for one transformer layer; each projection holds
/// `rank * hidden_dim` values (the dimensions live on the owning
/// `BaseLoRA`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    /// Down projection weights (`[rank x hidden_dim]` row-major)
    pub down_proj: Vec<f32>,
    /// Up projection weights (`[rank x hidden_dim]` row-major)
    pub up_proj: Vec<f32>,
    /// Layer index
    pub layer_idx: usize,
}
284
285impl BaseLoRA {
286    /// Create new Base LoRA
287    pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
288        let layers = (0..num_layers)
289            .map(|idx| LoRALayer {
290                down_proj: vec![0.0; hidden_dim * rank],
291                up_proj: vec![0.0; rank * hidden_dim],
292                layer_idx: idx,
293            })
294            .collect();
295
296        Self {
297            layers,
298            rank,
299            hidden_dim,
300            alpha: rank as f32,
301        }
302    }
303
304    /// Forward pass for single layer
305    pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
306        if layer_idx >= self.layers.len() {
307            return;
308        }
309
310        let layer = &self.layers[layer_idx];
311        let scale = self.alpha / self.rank as f32;
312
313        // Down projection
314        let mut intermediate = vec![0.0f32; self.rank];
315        for r in 0..self.rank {
316            let offset = r * self.hidden_dim;
317            intermediate[r] = input.iter()
318                .zip(&layer.down_proj[offset..offset + self.hidden_dim])
319                .map(|(a, b)| a * b)
320                .sum();
321        }
322
323        // Up projection
324        for i in 0..self.hidden_dim {
325            let mut sum = 0.0f32;
326            for r in 0..self.rank {
327                sum += intermediate[r] * layer.up_proj[r * self.hidden_dim + i];
328            }
329            output[i] += sum * scale;
330        }
331    }
332
333    /// Merge LoRA weights into model weights (for inference optimization)
334    pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
335        if layer_idx >= self.layers.len() {
336            return;
337        }
338
339        let layer = &self.layers[layer_idx];
340        let scale = self.alpha / self.rank as f32;
341
342        // W' = W + scale * (down @ up)
343        // Assumes model_weights is [hidden_dim x hidden_dim]
344        for i in 0..self.hidden_dim {
345            for j in 0..self.hidden_dim {
346                let mut delta = 0.0f32;
347                for r in 0..self.rank {
348                    delta += layer.down_proj[i * self.rank + r]
349                           * layer.up_proj[r * self.hidden_dim + j];
350                }
351                model_weights[i * self.hidden_dim + j] += delta * scale;
352            }
353        }
354    }
355
356    /// Get number of layers
357    pub fn num_layers(&self) -> usize {
358        self.layers.len()
359    }
360
361    /// Get total parameter count
362    pub fn param_count(&self) -> usize {
363        self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
364    }
365}
366
/// Combined LoRA engine managing both tiers
///
/// Holds one layer-independent `MicroLoRA` plus a per-layer `BaseLoRA`,
/// with an independent enable flag per tier.
///
/// NOTE(review): unlike the adapters it owns, the engine itself derives
/// neither `Serialize` nor `Deserialize` — confirm persistence is handled
/// by serializing `micro` / `base` individually.
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    /// Micro-LoRA for instant adaptation (applied to every layer)
    pub micro: MicroLoRA,
    /// Base LoRA for background adaptation (per-layer)
    pub base: BaseLoRA,
    /// Whether micro-LoRA is enabled
    pub micro_enabled: bool,
    /// Whether base LoRA is enabled
    pub base_enabled: bool,
}
379
380impl LoRAEngine {
381    /// Create new LoRA engine
382    pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
383        Self {
384            micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
385            base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
386            micro_enabled: true,
387            base_enabled: true,
388        }
389    }
390
391    /// Apply both LoRA tiers
392    pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
393        if self.micro_enabled {
394            self.micro.forward(input, output);
395        }
396        if self.base_enabled && layer_idx < self.base.num_layers() {
397            self.base.forward_layer(layer_idx, input, output);
398        }
399    }
400
401    /// Accumulate micro-LoRA gradient
402    pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
403        if self.micro_enabled {
404            self.micro.accumulate_gradient(signal);
405        }
406    }
407
408    /// Apply micro-LoRA updates
409    pub fn apply_micro(&mut self, learning_rate: f32) {
410        if self.micro_enabled {
411            self.micro.apply_accumulated(learning_rate);
412        }
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_micro_lora_creation() {
        let lora = MicroLoRA::new(256, 1);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.hidden_dim(), 256);
        // Rank-1 adapter: one down row + one up row of 256 values each.
        assert_eq!(lora.param_count(), 512);
    }

    #[test]
    fn test_micro_lora_forward() {
        let lora = MicroLoRA::new(64, 1);
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];

        lora.forward(&input, &mut output);

        // Standard LoRA init zeroes the up projection, so an untrained
        // adapter must leave the output at (approximately) zero.
        let total: f32 = output.iter().sum();
        assert!(total.abs() < 1e-6, "Expected ~0 with zero up_proj, got {}", total);
    }

    #[test]
    fn test_micro_lora_learning() {
        let mut lora = MicroLoRA::new(64, 1);
        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.8);

        lora.accumulate_gradient(&signal);
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_accumulated(0.01);
        assert_eq!(lora.pending_updates(), 0);

        // After one applied update the up projection is non-zero, so the
        // forward pass should now produce a non-zero delta.
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        lora.forward(&input, &mut output);

        let magnitude: f32 = output.iter().map(|x| x.abs()).sum();
        assert!(magnitude > 0.0, "Expected non-zero output after learning");
    }

    #[test]
    fn test_base_lora() {
        let lora = BaseLoRA::new(64, 4, 12);
        assert_eq!(lora.num_layers(), 12);
        assert_eq!(lora.rank, 4);
    }

    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);
        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.9);

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        // Smoke test: a forward pass through layer 0 must not panic.
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }
}