// ruvector_sona/lora.rs — two-tier LoRA adapters for SONA

//! LoRA (Low-Rank Adaptation) implementations for SONA
//!
//! Two-tier LoRA system:
//! - MicroLoRA: Rank 1-2, per-request adaptation (<100μs)
//! - BaseLoRA: Rank 4-16, background adaptation (hourly)

use crate::types::LearningSignal;
use serde::{Deserialize, Serialize};

/// Optimal batch size for processing (benchmark-validated)
///
/// Callers batching MicroLoRA forward passes should group inputs in chunks of
/// this size — presumably derived from the 0.447ms/vector, 2,236 ops/sec
/// benchmark noted on `MicroLoRA`; re-validate on different hardware.
pub const OPTIMAL_BATCH_SIZE: usize = 32;

/// Micro-LoRA for per-request adaptation
///
/// Uses rank 1-2 for ultra-low latency updates.
/// Forward pass: output += scale * (input @ down) @ up
///
/// **Performance notes (from benchmarks):**
/// - Rank-2 is ~5% faster than Rank-1 due to better SIMD vectorization
/// - Batch size 32 optimal: 0.447ms per-vector, 2,236 ops/sec throughput
/// - SIMD-enabled: +10% speedup over scalar
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    /// Down projection (hidden_dim -> rank), stored row-major as
    /// `[rank x hidden_dim]`: row `r` starts at `r * hidden_dim`.
    down_proj: Vec<f32>,
    /// Up projection (rank -> hidden_dim), same row-major `[rank x hidden_dim]`
    /// layout as `down_proj`. Zero-initialized so the adapter starts as a no-op.
    up_proj: Vec<f32>,
    /// Rank (1-2 for micro updates); enforced by `new`.
    rank: usize,
    /// Hidden dimension
    hidden_dim: usize,
    /// Accumulated gradients for down (not serialized; rebuilt as zeros on load).
    #[serde(skip)]
    grad_down: Vec<f32>,
    /// Accumulated gradients for up (not serialized; rebuilt as zeros on load).
    #[serde(skip)]
    grad_up: Vec<f32>,
    /// Update count for averaging accumulated gradients in `apply_accumulated`.
    #[serde(skip)]
    update_count: usize,
    /// Scaling factor applied to the LoRA delta in the forward pass.
    scale: f32,
}

impl MicroLoRA {
    /// Create new Micro-LoRA adapter
    ///
    /// Down projection gets small deterministic pseudo-random values; up
    /// projection is zero-initialized (standard LoRA init), so the forward
    /// delta is exactly zero until the first applied update.
    ///
    /// # Arguments
    /// * `hidden_dim` - Model hidden dimension
    /// * `rank` - LoRA rank (must be 1-2)
    ///
    /// # Panics
    /// Panics if rank > 2
    pub fn new(hidden_dim: usize, rank: usize) -> Self {
        assert!(
            (1..=2).contains(&rank),
            "MicroLoRA rank must be 1-2, got {}",
            rank
        );

        // Initialize down with small random-like values (deterministic for reproducibility).
        // 0.618_034 ≈ 1/φ: a golden-ratio fractional stride yields a
        // low-discrepancy sequence in [0, 1), then mapped into [-0.01, 0.01).
        let down_proj: Vec<f32> = (0..hidden_dim * rank)
            .map(|i| {
                let x = (i as f32 * 0.618_034) % 1.0;
                (x - 0.5) * 0.02
            })
            .collect();

        // Initialize up to zero (standard LoRA init)
        let up_proj = vec![0.0f32; rank * hidden_dim];

        Self {
            down_proj,
            up_proj,
            rank,
            hidden_dim,
            grad_down: vec![0.0; hidden_dim * rank],
            grad_up: vec![0.0; rank * hidden_dim],
            update_count: 0,
            // 1/sqrt(rank) keeps the delta magnitude roughly rank-independent.
            scale: 1.0 / (rank as f32).sqrt(),
        }
    }

    /// Scalar forward pass (fallback): `output += scale * (input @ down^T) @ up`
    ///
    /// # Panics
    /// Panics if `input` or `output` length differs from `hidden_dim`.
    pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        // Down projection: hidden_dim -> rank
        // intermediate[r] = dot(input, down_proj row r)
        let mut intermediate = vec![0.0f32; self.rank];
        for (r, inter) in intermediate.iter_mut().enumerate() {
            let mut sum = 0.0f32;
            let offset = r * self.hidden_dim;
            for (i, &inp) in input.iter().enumerate() {
                sum += inp * self.down_proj[offset + i];
            }
            *inter = sum;
        }

        // Up projection: rank -> hidden_dim, accumulated into output (not overwritten)
        for (i, out) in output.iter_mut().enumerate() {
            let mut sum = 0.0f32;
            for (r, &inter) in intermediate.iter().enumerate() {
                sum += inter * self.up_proj[r * self.hidden_dim + i];
            }
            *out += sum * self.scale;
        }
    }

    /// SIMD-optimized forward pass (AVX2)
    ///
    /// Same contract as [`forward_scalar`]; only compiled in when the crate is
    /// built with the `avx2` target feature enabled.
    ///
    /// # Panics
    /// Panics if `input` or `output` length differs from `hidden_dim`.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        // SAFETY: AVX2 availability is guaranteed at compile time by the cfg
        // gate above. All unaligned loads/stores read/write 8 f32s starting at
        // index i with i + 8 <= hidden_dim, and the asserts above pin both
        // slices to hidden_dim elements, so every access stays in bounds.
        unsafe {
            // Down projection: hidden_dim -> rank
            let mut intermediate = vec![0.0f32; self.rank];

            for r in 0..self.rank {
                let mut sum = _mm256_setzero_ps();
                let offset = r * self.hidden_dim;

                // Main vector loop: 8 lanes of fused multiply-add per step.
                let mut i = 0;
                while i + 8 <= self.hidden_dim {
                    let inp = _mm256_loadu_ps(input[i..].as_ptr());
                    let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
                    sum = _mm256_fmadd_ps(inp, weight, sum);
                    i += 8;
                }

                // Horizontal sum of the 8 accumulator lanes
                let mut result = [0.0f32; 8];
                _mm256_storeu_ps(result.as_mut_ptr(), sum);
                intermediate[r] = result.iter().sum();

                // Handle remaining (< 8) tail elements in scalar
                for j in i..self.hidden_dim {
                    intermediate[r] += input[j] * self.down_proj[offset + j];
                }
            }

            // Up projection: rank -> hidden_dim
            let scale_vec = _mm256_set1_ps(self.scale);

            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let mut sum = _mm256_setzero_ps();

                for r in 0..self.rank {
                    let up_offset = r * self.hidden_dim;
                    let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
                    let inter = _mm256_set1_ps(intermediate[r]);
                    sum = _mm256_fmadd_ps(inter, weight, sum);
                }

                // Scale and add to output (accumulate, matching the scalar path)
                sum = _mm256_mul_ps(sum, scale_vec);
                let existing = _mm256_loadu_ps(output[i..].as_ptr());
                let result = _mm256_add_ps(existing, sum);
                _mm256_storeu_ps(output[i..].as_mut_ptr(), result);

                i += 8;
            }

            // Handle remaining (< 8) tail elements in scalar
            for j in i..self.hidden_dim {
                let mut val = 0.0;
                for r in 0..self.rank {
                    val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
                }
                output[j] += val * self.scale;
            }
        }
    }

    /// Forward pass with automatic SIMD detection
    ///
    /// Dispatch is compile-time: the AVX2 path is chosen when built with the
    /// `avx2` target feature, otherwise the scalar fallback runs. (There is no
    /// runtime CPU detection here.)
    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            self.forward_simd(input, output);
            return;
        }

        // Unreachable only when the SIMD branch above is compiled in.
        #[allow(unreachable_code)]
        self.forward_scalar(input, output);
    }

    /// Accumulate gradient from learning signal
    ///
    /// Silently ignores signals whose gradient length does not match
    /// `hidden_dim`. Only `grad_up` is accumulated; the same quality-weighted
    /// gradient row is added for every rank. `grad_down` is currently never
    /// written here — only the up projection is adapted (see
    /// `apply_accumulated`).
    pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
        if signal.gradient_estimate.len() != self.hidden_dim {
            return;
        }

        let quality = signal.quality_score;

        // Simplified gradient: outer product scaled by quality
        // This approximates the true gradient for rank-1 LoRA
        for r in 0..self.rank {
            for i in 0..self.hidden_dim {
                let grad_idx = r * self.hidden_dim + i;
                // Update up projection gradient (main target)
                self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
            }
        }

        self.update_count += 1;
    }

    /// Apply accumulated gradients with learning rate
    ///
    /// Averages the accumulated gradient over `update_count` signals, then
    /// adds it to `up_proj` and clears all accumulators. No-op when no
    /// gradients are pending.
    ///
    /// NOTE(review): the gradient is *added* (ascent), not subtracted —
    /// learning signals are treated as improvement directions; confirm this
    /// convention against `LearningSignal`'s producer.
    pub fn apply_accumulated(&mut self, learning_rate: f32) {
        if self.update_count == 0 {
            return;
        }

        // Average the accumulated gradients, folding in the learning rate.
        let scale = learning_rate / self.update_count as f32;

        // Update up projection (main adaptation target)
        for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
            *w += g * scale;
        }

        // Reset accumulators
        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    /// Reset adapter to initial state
    ///
    /// Zeroes `up_proj` and all gradient accumulators; `down_proj` keeps its
    /// deterministic init, so the forward delta returns to exactly zero.
    pub fn reset(&mut self) {
        self.up_proj.fill(0.0);
        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    /// Get rank
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Get hidden dimension
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Get parameter count (down + up weights; gradients excluded)
    pub fn param_count(&self) -> usize {
        self.down_proj.len() + self.up_proj.len()
    }

    /// Get scale factor
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Set scale factor (overrides the default `1/sqrt(rank)`)
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Get pending update count
    pub fn pending_updates(&self) -> usize {
        self.update_count
    }

    /// Get LoRA weights for export (lora_a, lora_b)
    pub fn get_weights(&self) -> (&Vec<f32>, &Vec<f32>) {
        (&self.down_proj, &self.up_proj)
    }

    /// Set LoRA weights from external source (disk load, other system)
    ///
    /// # Arguments
    /// * `down_proj` - Down projection weights (hidden_dim * rank)
    /// * `up_proj` - Up projection weights (rank * hidden_dim)
    ///
    /// # Errors
    /// Returns Err if dimensions don't match current rank/hidden_dim
    pub fn set_weights(&mut self, down_proj: Vec<f32>, up_proj: Vec<f32>) -> Result<(), String> {
        let expected_down = self.hidden_dim * self.rank;
        if down_proj.len() != expected_down {
            return Err(format!(
                "down_proj dimension mismatch: expected {}, got {}",
                expected_down,
                down_proj.len()
            ));
        }

        let expected_up = self.rank * self.hidden_dim;
        if up_proj.len() != expected_up {
            return Err(format!(
                "up_proj dimension mismatch: expected {}, got {}",
                expected_up,
                up_proj.len()
            ));
        }

        self.down_proj = down_proj;
        self.up_proj = up_proj;
        Ok(())
    }
}

/// Base LoRA for background adaptation
///
/// Higher rank (4-16) for more expressive adaptation.
/// Applied hourly during background learning cycles.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    /// Per-layer LoRA weight pairs; position in the Vec is the layer index.
    pub layers: Vec<LoRALayer>,
    /// Rank shared by all layers (intended range: 4-16, not enforced).
    pub rank: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Alpha scaling factor; the effective forward scale is `alpha / rank`
    /// (defaults to `rank`, i.e. scale 1.0 — see `new`).
    pub alpha: f32,
}

/// Single LoRA layer
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    /// Down projection weights, row-major `[rank x hidden_dim]`
    /// (row `r` starts at `r * hidden_dim`, as indexed by `forward_layer`).
    pub down_proj: Vec<f32>,
    /// Up projection weights, row-major `[rank x hidden_dim]`.
    pub up_proj: Vec<f32>,
    /// Layer index (mirrors this layer's position in `BaseLoRA::layers`)
    pub layer_idx: usize,
}

334impl BaseLoRA {
335    /// Create new Base LoRA
336    pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
337        let layers = (0..num_layers)
338            .map(|idx| LoRALayer {
339                down_proj: vec![0.0; hidden_dim * rank],
340                up_proj: vec![0.0; rank * hidden_dim],
341                layer_idx: idx,
342            })
343            .collect();
344
345        Self {
346            layers,
347            rank,
348            hidden_dim,
349            alpha: rank as f32,
350        }
351    }
352
353    /// Forward pass for single layer
354    pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
355        if layer_idx >= self.layers.len() {
356            return;
357        }
358
359        let layer = &self.layers[layer_idx];
360        let scale = self.alpha / self.rank as f32;
361
362        // Down projection
363        let mut intermediate = vec![0.0f32; self.rank];
364        for (r, inter) in intermediate.iter_mut().enumerate() {
365            let offset = r * self.hidden_dim;
366            *inter = input
367                .iter()
368                .zip(&layer.down_proj[offset..offset + self.hidden_dim])
369                .map(|(a, b)| a * b)
370                .sum();
371        }
372
373        // Up projection
374        for (i, out) in output.iter_mut().enumerate() {
375            let mut sum = 0.0f32;
376            for (r, &inter) in intermediate.iter().enumerate() {
377                sum += inter * layer.up_proj[r * self.hidden_dim + i];
378            }
379            *out += sum * scale;
380        }
381    }
382
383    /// Merge LoRA weights into model weights (for inference optimization)
384    pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
385        if layer_idx >= self.layers.len() {
386            return;
387        }
388
389        let layer = &self.layers[layer_idx];
390        let scale = self.alpha / self.rank as f32;
391
392        // W' = W + scale * (down @ up)
393        // Assumes model_weights is [hidden_dim x hidden_dim]
394        for i in 0..self.hidden_dim {
395            for j in 0..self.hidden_dim {
396                let mut delta = 0.0f32;
397                for r in 0..self.rank {
398                    delta +=
399                        layer.down_proj[i * self.rank + r] * layer.up_proj[r * self.hidden_dim + j];
400                }
401                model_weights[i * self.hidden_dim + j] += delta * scale;
402            }
403        }
404    }
405
406    /// Get number of layers
407    pub fn num_layers(&self) -> usize {
408        self.layers.len()
409    }
410
411    /// Get total parameter count
412    pub fn param_count(&self) -> usize {
413        self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
414    }
415
416    /// Get weights for a specific layer for export (lora_a, lora_b)
417    pub fn get_layer_weights(&self, layer_idx: usize) -> Option<(&Vec<f32>, &Vec<f32>)> {
418        self.layers
419            .get(layer_idx)
420            .map(|layer| (&layer.down_proj, &layer.up_proj))
421    }
422}
423
/// Combined LoRA engine managing both tiers
///
/// Not serializable (unlike the individual adapters): the micro tier's
/// gradient accumulators are runtime-only state.
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    /// Micro-LoRA for instant, per-request adaptation
    pub micro: MicroLoRA,
    /// Base LoRA for background (hourly) adaptation
    pub base: BaseLoRA,
    /// Whether micro-LoRA is enabled (checked on every forward/accumulate/apply)
    pub micro_enabled: bool,
    /// Whether base LoRA is enabled (checked on every forward)
    pub base_enabled: bool,
}

437impl LoRAEngine {
438    /// Create new LoRA engine
439    pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
440        Self {
441            micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
442            base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
443            micro_enabled: true,
444            base_enabled: true,
445        }
446    }
447
448    /// Apply both LoRA tiers
449    pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
450        if self.micro_enabled {
451            self.micro.forward(input, output);
452        }
453        if self.base_enabled && layer_idx < self.base.num_layers() {
454            self.base.forward_layer(layer_idx, input, output);
455        }
456    }
457
458    /// Accumulate micro-LoRA gradient
459    pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
460        if self.micro_enabled {
461            self.micro.accumulate_gradient(signal);
462        }
463    }
464
465    /// Apply micro-LoRA updates
466    pub fn apply_micro(&mut self, learning_rate: f32) {
467        if self.micro_enabled {
468            self.micro.apply_accumulated(learning_rate);
469        }
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_micro_lora_creation() {
        let lora = MicroLoRA::new(256, 1);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.hidden_dim(), 256);
        // rank-1: one 256-element down row + one 256-element up row
        assert_eq!(lora.param_count(), 256 + 256);
    }

    #[test]
    fn test_micro_lora_forward() {
        let lora = MicroLoRA::new(64, 1);
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];

        lora.forward(&input, &mut output);

        // Output should be modified (even if small due to init)
        // With zero-init up_proj, output should still be zero
        let sum: f32 = output.iter().sum();
        assert!(
            sum.abs() < 1e-6,
            "Expected ~0 with zero up_proj, got {}",
            sum
        );
    }

    #[test]
    fn test_micro_lora_learning() {
        let mut lora = MicroLoRA::new(64, 1);

        // quality 0.8 signal with a uniform gradient estimate of 0.5
        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.8);

        lora.accumulate_gradient(&signal);
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_accumulated(0.01);
        assert_eq!(lora.pending_updates(), 0);

        // Now forward should produce non-zero output
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().map(|x| x.abs()).sum();
        assert!(sum > 0.0, "Expected non-zero output after learning");
    }

    #[test]
    fn test_base_lora() {
        let lora = BaseLoRA::new(64, 4, 12);
        assert_eq!(lora.num_layers(), 12);
        assert_eq!(lora.rank, 4);
    }

    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.9);

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        // Smoke test: full two-tier forward must not panic on layer 0.
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }

    #[test]
    fn test_set_weights_valid() {
        let mut lora = MicroLoRA::new(64, 2);
        let down = vec![1.0f32; 64 * 2];
        let up = vec![0.5f32; 2 * 64];

        let result = lora.set_weights(down.clone(), up.clone());
        assert!(result.is_ok());

        // Round-trip: exported weights match what was set.
        let (got_down, got_up) = lora.get_weights();
        assert_eq!(got_down, &down);
        assert_eq!(got_up, &up);
    }

    #[test]
    fn test_set_weights_wrong_down_dim() {
        let mut lora = MicroLoRA::new(64, 2);
        // 64 * 3 instead of the expected hidden_dim * rank = 64 * 2
        let wrong_down = vec![1.0f32; 64 * 3];
        let up = vec![0.5f32; 2 * 64];

        let result = lora.set_weights(wrong_down, up);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("down_proj dimension mismatch"));
    }

    #[test]
    fn test_set_weights_wrong_up_dim() {
        let mut lora = MicroLoRA::new(64, 2);
        let down = vec![1.0f32; 64 * 2];
        // 3 * 64 instead of the expected rank * hidden_dim = 2 * 64
        let wrong_up = vec![0.5f32; 3 * 64];

        let result = lora.set_weights(down, wrong_up);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("up_proj dimension mismatch"));
    }
}