use crate::types::LearningSignal;
use serde::{Deserialize, Serialize};

pub const OPTIMAL_BATCH_SIZE: usize = 32;

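/// Minimal low-rank adapter (rank 1 or 2) over a single hidden dimension.
/// Both `down_proj` and `up_proj` are stored rank-major (`[rank, hidden_dim]`).
/// The gradient buffers and update counter are transient state: they are
/// skipped by serde and come back empty/zero after deserialization.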
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    down_proj: Vec<f32>,
    up_proj: Vec<f32>,
    rank: usize,
    hidden_dim: usize,
    #[serde(skip)]
    grad_down: Vec<f32>,
    #[serde(skip)]
    grad_up: Vec<f32>,
    #[serde(skip)]
    update_count: usize,
    scale: f32,
}

impl MicroLoRA {
    pub fn new(hidden_dim: usize, rank: usize) -> Self {
        assert!(
            (1..=2).contains(&rank),
            "MicroLoRA rank must be 1-2, got {}",
            rank
        );

        let down_proj: Vec<f32> = (0..hidden_dim * rank)
            .map(|i| {
                let x = (i as f32 * 0.618_034) % 1.0;
                (x - 0.5) * 0.02
            })
            .collect();

        let up_proj = vec![0.0f32; rank * hidden_dim];

        Self {
            down_proj,
            up_proj,
            rank,
            hidden_dim,
            grad_down: vec![0.0; hidden_dim * rank],
            grad_up: vec![0.0; rank * hidden_dim],
            update_count: 0,
            scale: 1.0 / (rank as f32).sqrt(),
        }
    }

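    /// Scalar reference path: computes `intermediate = down_proj * input` and
    /// accumulates `scale * up_proj^T * intermediate` into `output`.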
    pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        let mut intermediate = vec![0.0f32; self.rank];
        for (r, inter) in intermediate.iter_mut().enumerate() {
            let mut sum = 0.0f32;
            let offset = r * self.hidden_dim;
            for (i, &inp) in input.iter().enumerate() {
                sum += inp * self.down_proj[offset + i];
            }
            *inter = sum;
        }

        for (i, out) in output.iter_mut().enumerate() {
            let mut sum = 0.0f32;
            for (r, &inter) in intermediate.iter().enumerate() {
                sum += inter * self.up_proj[r * self.hidden_dim + i];
            }
            *out += sum * self.scale;
        }
    }

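    /// AVX2/FMA path: processes 8 lanes at a time and handles the tail with
    /// scalar arithmetic when `hidden_dim` is not a multiple of 8.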
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
    pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        unsafe {
            let mut intermediate = vec![0.0f32; self.rank];

            for r in 0..self.rank {
                let mut sum = _mm256_setzero_ps();
                let offset = r * self.hidden_dim;

                let mut i = 0;
                while i + 8 <= self.hidden_dim {
                    let inp = _mm256_loadu_ps(input[i..].as_ptr());
                    let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
                    sum = _mm256_fmadd_ps(inp, weight, sum);
                    i += 8;
                }

                let mut result = [0.0f32; 8];
                _mm256_storeu_ps(result.as_mut_ptr(), sum);
                intermediate[r] = result.iter().sum();

                for j in i..self.hidden_dim {
                    intermediate[r] += input[j] * self.down_proj[offset + j];
                }
            }

            let scale_vec = _mm256_set1_ps(self.scale);

            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let mut sum = _mm256_setzero_ps();

                for r in 0..self.rank {
                    let up_offset = r * self.hidden_dim;
                    let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
                    let inter = _mm256_set1_ps(intermediate[r]);
                    sum = _mm256_fmadd_ps(inter, weight, sum);
                }

                sum = _mm256_mul_ps(sum, scale_vec);
                let existing = _mm256_loadu_ps(output[i..].as_ptr());
                let result = _mm256_add_ps(existing, sum);
                _mm256_storeu_ps(output[i..].as_mut_ptr(), result);

                i += 8;
            }

            for j in i..self.hidden_dim {
                let mut val = 0.0;
                for r in 0..self.rank {
                    val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
                }
                output[j] += val * self.scale;
            }
        }
    }

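    /// Dispatches to the SIMD kernel when compiled with AVX2+FMA, otherwise
    /// uses the scalar path. Both accumulate into `output`.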
    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
        {
            self.forward_simd(input, output);
            return;
        }

        #[allow(unreachable_code)]
        self.forward_scalar(input, output);
    }

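    /// Accumulates a quality-weighted gradient estimate into `grad_up`;
    /// signals whose gradient length does not match `hidden_dim` are ignored.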
    pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
        if signal.gradient_estimate.len() != self.hidden_dim {
            return;
        }

        let quality = signal.quality_score;

        for r in 0..self.rank {
            for i in 0..self.hidden_dim {
                let grad_idx = r * self.hidden_dim + i;
                self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
            }
        }

        self.update_count += 1;
    }

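    /// Applies the averaged accumulated gradient to `up_proj` and clears both
    /// gradient buffers and the pending-update counter.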
    pub fn apply_accumulated(&mut self, learning_rate: f32) {
        if self.update_count == 0 {
            return;
        }

        let scale = learning_rate / self.update_count as f32;

        for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
            *w += g * scale;
        }

        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    pub fn reset(&mut self) {
        self.up_proj.fill(0.0);
        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    pub fn param_count(&self) -> usize {
        self.down_proj.len() + self.up_proj.len()
    }

    pub fn scale(&self) -> f32 {
        self.scale
    }

    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    pub fn pending_updates(&self) -> usize {
        self.update_count
    }

    pub fn get_weights(&self) -> (&Vec<f32>, &Vec<f32>) {
        (&self.down_proj, &self.up_proj)
    }

    pub fn set_weights(&mut self, down_proj: Vec<f32>, up_proj: Vec<f32>) -> Result<(), String> {
        let expected_down = self.hidden_dim * self.rank;
        if down_proj.len() != expected_down {
            return Err(format!(
                "down_proj dimension mismatch: expected {}, got {}",
                expected_down,
                down_proj.len()
            ));
        }

        let expected_up = self.rank * self.hidden_dim;
        if up_proj.len() != expected_up {
            return Err(format!(
                "up_proj dimension mismatch: expected {}, got {}",
                expected_up,
                up_proj.len()
            ));
        }

        self.down_proj = down_proj;
        self.up_proj = up_proj;
        Ok(())
    }
}

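/// Per-layer LoRA adapter stack with the conventional `alpha / rank` scaling.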
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    pub layers: Vec<LoRALayer>,
    pub rank: usize,
    pub hidden_dim: usize,
    pub alpha: f32,
}

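/// Down/up projection pair for a single layer, both stored rank-major
/// (`[rank, hidden_dim]`), matching `MicroLoRA`.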
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    pub down_proj: Vec<f32>,
    pub up_proj: Vec<f32>,
    pub layer_idx: usize,
}

impl BaseLoRA {
    pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|idx| LoRALayer {
                down_proj: vec![0.0; hidden_dim * rank],
                up_proj: vec![0.0; rank * hidden_dim],
                layer_idx: idx,
            })
            .collect();

        Self {
            layers,
            rank,
            hidden_dim,
            alpha: rank as f32,
        }
    }

    pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
        if layer_idx >= self.layers.len() {
            return;
        }

        let layer = &self.layers[layer_idx];
        let scale = self.alpha / self.rank as f32;

        let mut intermediate = vec![0.0f32; self.rank];
        for (r, inter) in intermediate.iter_mut().enumerate() {
            let offset = r * self.hidden_dim;
            *inter = input
                .iter()
                .zip(&layer.down_proj[offset..offset + self.hidden_dim])
                .map(|(a, b)| a * b)
                .sum();
        }

        for (i, out) in output.iter_mut().enumerate() {
            let mut sum = 0.0f32;
            for (r, &inter) in intermediate.iter().enumerate() {
                sum += inter * layer.up_proj[r * self.hidden_dim + i];
            }
            *out += sum * scale;
        }
    }

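    /// Folds this layer's low-rank delta into a dense `hidden_dim x hidden_dim`
    /// weight matrix, where row `i` is the input dimension and column `j` the
    /// output dimension, consistent with `forward_layer`.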
    pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
        if layer_idx >= self.layers.len() {
            return;
        }

        let layer = &self.layers[layer_idx];
        let scale = self.alpha / self.rank as f32;

        for i in 0..self.hidden_dim {
            for j in 0..self.hidden_dim {
                let mut delta = 0.0f32;
                for r in 0..self.rank {
                    // down_proj is rank-major ([rank, hidden_dim]), the same layout
                    // that forward_layer reads, so index it as r * hidden_dim + i.
                    delta += layer.down_proj[r * self.hidden_dim + i]
                        * layer.up_proj[r * self.hidden_dim + j];
                }
                model_weights[i * self.hidden_dim + j] += delta * scale;
            }
        }
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    pub fn param_count(&self) -> usize {
        self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
    }

    pub fn get_layer_weights(&self, layer_idx: usize) -> Option<(&Vec<f32>, &Vec<f32>)> {
        self.layers
            .get(layer_idx)
            .map(|layer| (&layer.down_proj, &layer.up_proj))
    }
}

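/// Combines the MicroLoRA adapter with the per-layer BaseLoRA stack; either
/// component can be toggled independently.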
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    pub micro: MicroLoRA,
    pub base: BaseLoRA,
    pub micro_enabled: bool,
    pub base_enabled: bool,
}

impl LoRAEngine {
    pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
        Self {
            micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
            base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
            micro_enabled: true,
            base_enabled: true,
        }
    }

    pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
        if self.micro_enabled {
            self.micro.forward(input, output);
        }
        if self.base_enabled && layer_idx < self.base.num_layers() {
            self.base.forward_layer(layer_idx, input, output);
        }
    }

    pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
        if self.micro_enabled {
            self.micro.accumulate_gradient(signal);
        }
    }

    pub fn apply_micro(&mut self, learning_rate: f32) {
        if self.micro_enabled {
            self.micro.apply_accumulated(learning_rate);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_micro_lora_creation() {
        let lora = MicroLoRA::new(256, 1);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.hidden_dim(), 256);
        assert_eq!(lora.param_count(), 256 + 256);
    }

    #[test]
    fn test_micro_lora_forward() {
        let lora = MicroLoRA::new(64, 1);
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];

        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().sum();
        assert!(
            sum.abs() < 1e-6,
            "Expected ~0 with zero up_proj, got {}",
            sum
        );
    }

    #[test]
    fn test_micro_lora_learning() {
        let mut lora = MicroLoRA::new(64, 1);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.8);

        lora.accumulate_gradient(&signal);
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_accumulated(0.01);
        assert_eq!(lora.pending_updates(), 0);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().map(|x| x.abs()).sum();
        assert!(sum > 0.0, "Expected non-zero output after learning");
    }

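    // Added sketch: cross-checks the dispatching `forward` against the scalar
    // reference path on a width that is not a multiple of 8, so SIMD tail
    // handling is exercised on AVX2+FMA builds. Weight values are arbitrary.
    #[test]
    fn test_forward_matches_scalar() {
        let mut lora = MicroLoRA::new(67, 2);
        let down = (0..67 * 2).map(|i| (i as f32).sin() * 0.1).collect();
        let up = (0..2 * 67).map(|i| (i as f32).cos() * 0.1).collect();
        lora.set_weights(down, up).unwrap();

        let input: Vec<f32> = (0..67).map(|i| i as f32 * 0.01).collect();
        let mut via_forward = vec![0.0f32; 67];
        let mut via_scalar = vec![0.0f32; 67];

        lora.forward(&input, &mut via_forward);
        lora.forward_scalar(&input, &mut via_scalar);

        for (a, b) in via_forward.iter().zip(&via_scalar) {
            assert!((a - b).abs() < 1e-4, "SIMD and scalar paths diverged");
        }
    }
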
    #[test]
    fn test_base_lora() {
        let lora = BaseLoRA::new(64, 4, 12);
        assert_eq!(lora.num_layers(), 12);
        assert_eq!(lora.rank, 4);
    }

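    // Added sketch: checks that merging a layer's delta into a dense matrix
    // (using the rank-major layout read by `forward_layer`) reproduces the
    // adapter output. Values are arbitrary; the dense matrix is applied as
    // out[j] = sum_i input[i] * W[i][j], matching `merge_into`'s indexing.
    #[test]
    fn test_merge_matches_forward_layer() {
        let mut lora = BaseLoRA::new(4, 2, 1);
        for (i, w) in lora.layers[0].down_proj.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.1;
        }
        for (i, w) in lora.layers[0].up_proj.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.01;
        }

        let input = vec![1.0f32, 2.0, 3.0, 4.0];

        let mut adapter_out = vec![0.0f32; 4];
        lora.forward_layer(0, &input, &mut adapter_out);

        let mut merged = vec![0.0f32; 4 * 4];
        lora.merge_into(&mut merged, 0);

        let mut merged_out = vec![0.0f32; 4];
        for (j, out) in merged_out.iter_mut().enumerate() {
            for (i, &inp) in input.iter().enumerate() {
                *out += inp * merged[i * 4 + j];
            }
        }

        for (a, b) in adapter_out.iter().zip(&merged_out) {
            assert!((a - b).abs() < 1e-5, "merged delta disagrees with forward_layer");
        }
    }
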
    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.9);

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }

    #[test]
    fn test_set_weights_valid() {
        let mut lora = MicroLoRA::new(64, 2);
        let down = vec![1.0f32; 64 * 2];
        let up = vec![0.5f32; 2 * 64];

        let result = lora.set_weights(down.clone(), up.clone());
        assert!(result.is_ok());

        let (got_down, got_up) = lora.get_weights();
        assert_eq!(got_down, &down);
        assert_eq!(got_up, &up);
    }

    #[test]
    fn test_set_weights_wrong_down_dim() {
        let mut lora = MicroLoRA::new(64, 2);
        let wrong_down = vec![1.0f32; 64 * 3];
        let up = vec![0.5f32; 2 * 64];

        let result = lora.set_weights(wrong_down, up);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("down_proj dimension mismatch"));
    }

    #[test]
    fn test_set_weights_wrong_up_dim() {
        let mut lora = MicroLoRA::new(64, 2);
        let down = vec![1.0f32; 64 * 2];
        let wrong_up = vec![0.5f32; 3 * 64];

        let result = lora.set_weights(down, wrong_up);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("up_proj dimension mismatch"));
    }
}