use crate::types::LearningSignal;
use serde::{Deserialize, Serialize};

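/// Rank-1/2 LoRA adapter for fast online updates.
///
/// Computes `output += scale * up_proj * (down_proj * input)` on top of the
/// base activations. `up_proj` is zero-initialized, so a fresh adapter is an
/// exact no-op until `apply_accumulated` writes its first update.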
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    /// Down-projection weights, rank-major: row `r` lives at
    /// `down_proj[r * hidden_dim..(r + 1) * hidden_dim]`.
    down_proj: Vec<f32>,
    /// Up-projection weights, same rank-major layout.
    up_proj: Vec<f32>,
    rank: usize,
    hidden_dim: usize,
    #[serde(skip)]
    grad_down: Vec<f32>,
    #[serde(skip)]
    grad_up: Vec<f32>,
    #[serde(skip)]
    update_count: usize,
    scale: f32,
}

impl MicroLoRA {
    pub fn new(hidden_dim: usize, rank: usize) -> Self {
        assert!(rank >= 1 && rank <= 2, "MicroLoRA rank must be 1-2, got {}", rank);

        // Deterministic low-discrepancy init for the down projection
        // (golden-ratio sequence), centered on zero and scaled small.
        let down_proj: Vec<f32> = (0..hidden_dim * rank)
            .map(|i| {
                let x = (i as f32 * 0.618033988749895) % 1.0;
                (x - 0.5) * 0.02
            })
            .collect();

        // A zero up projection makes the adapter an initial no-op.
        let up_proj = vec![0.0f32; rank * hidden_dim];

        Self {
            down_proj,
            up_proj,
            rank,
            hidden_dim,
            grad_down: vec![0.0; hidden_dim * rank],
            grad_up: vec![0.0; rank * hidden_dim],
            update_count: 0,
            scale: 1.0 / (rank as f32).sqrt(),
        }
    }

    /// Scalar path: `output += scale * up * (down * input)`.
    pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        // Down projection: hidden_dim -> rank.
        let mut intermediate = vec![0.0f32; self.rank];
        for r in 0..self.rank {
            let mut sum = 0.0f32;
            let offset = r * self.hidden_dim;
            for i in 0..self.hidden_dim {
                sum += input[i] * self.down_proj[offset + i];
            }
            intermediate[r] = sum;
        }

        // Up projection: rank -> hidden_dim, accumulated into `output`.
        for i in 0..self.hidden_dim {
            let mut sum = 0.0f32;
            for r in 0..self.rank {
                sum += intermediate[r] * self.up_proj[r * self.hidden_dim + i];
            }
            output[i] += sum * self.scale;
        }
    }

    /// AVX2/FMA path. Gated on `fma` in addition to `avx2` because
    /// `_mm256_fmadd_ps` requires the FMA target feature.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
    pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        unsafe {
            // Down projection: hidden_dim -> rank, 8 lanes at a time.
            let mut intermediate = vec![0.0f32; self.rank];

            for r in 0..self.rank {
                let mut sum = _mm256_setzero_ps();
                let offset = r * self.hidden_dim;

                let mut i = 0;
                while i + 8 <= self.hidden_dim {
                    let inp = _mm256_loadu_ps(input[i..].as_ptr());
                    let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
                    sum = _mm256_fmadd_ps(inp, weight, sum);
                    i += 8;
                }

                // Horizontal reduction of the 8 partial sums.
                let mut result = [0.0f32; 8];
                _mm256_storeu_ps(result.as_mut_ptr(), sum);
                intermediate[r] = result.iter().sum();

                // Scalar tail for dimensions not divisible by 8.
                for j in i..self.hidden_dim {
                    intermediate[r] += input[j] * self.down_proj[offset + j];
                }
            }

            // Up projection: rank -> hidden_dim, accumulated into `output`.
            let scale_vec = _mm256_set1_ps(self.scale);

            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let mut sum = _mm256_setzero_ps();

                for r in 0..self.rank {
                    let up_offset = r * self.hidden_dim;
                    let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
                    let inter = _mm256_set1_ps(intermediate[r]);
                    sum = _mm256_fmadd_ps(inter, weight, sum);
                }

                sum = _mm256_mul_ps(sum, scale_vec);
                let existing = _mm256_loadu_ps(output[i..].as_ptr());
                let result = _mm256_add_ps(existing, sum);
                _mm256_storeu_ps(output[i..].as_mut_ptr(), result);

                i += 8;
            }

            // Scalar tail.
            for j in i..self.hidden_dim {
                let mut val = 0.0;
                for r in 0..self.rank {
                    val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
                }
                output[j] += val * self.scale;
            }
        }
    }

    /// Dispatches to the SIMD path when compiled with AVX2+FMA,
    /// otherwise falls back to the scalar path.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
        {
            self.forward_simd(input, output);
            return;
        }

        #[allow(unreachable_code)]
        self.forward_scalar(input, output);
    }

    /// Accumulates a quality-weighted gradient estimate into `grad_up`.
    /// Only the up projection is trained; `down_proj` keeps its fixed
    /// initialization, which keeps updates cheap.
    pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
        if signal.gradient_estimate.len() != self.hidden_dim {
            return;
        }

        // The gradient buffers are `#[serde(skip)]` and come back empty
        // after deserialization; re-allocate them before indexing.
        if self.grad_up.len() != self.rank * self.hidden_dim {
            self.grad_up = vec![0.0; self.rank * self.hidden_dim];
            self.grad_down = vec![0.0; self.hidden_dim * self.rank];
        }

        let quality = signal.quality_score;

        // Broadcast the estimate across ranks, weighted by signal quality.
        for r in 0..self.rank {
            for i in 0..self.hidden_dim {
                let grad_idx = r * self.hidden_dim + i;
                self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
            }
        }

        self.update_count += 1;
    }

    /// Applies the averaged accumulated gradient and clears the buffers.
    pub fn apply_accumulated(&mut self, learning_rate: f32) {
        if self.update_count == 0 {
            return;
        }

        // Average over the accumulated signals, then step.
        let scale = learning_rate / self.update_count as f32;

        for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
            *w += g * scale;
        }

        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    pub fn reset(&mut self) {
        self.up_proj.fill(0.0);
        self.grad_up.fill(0.0);
        self.grad_down.fill(0.0);
        self.update_count = 0;
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    pub fn param_count(&self) -> usize {
        self.down_proj.len() + self.up_proj.len()
    }

    pub fn scale(&self) -> f32 {
        self.scale
    }

    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    pub fn pending_updates(&self) -> usize {
        self.update_count
    }
}

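/// Persistent per-layer LoRA weights applied on top of the frozen base model.
///
/// Each layer contributes `output += (alpha / rank) * up * (down * input)`.
/// Weights start at zero (a no-op) and are expected to be populated
/// elsewhere, e.g. loaded from a checkpoint.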
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    pub layers: Vec<LoRALayer>,
    pub rank: usize,
    pub hidden_dim: usize,
    pub alpha: f32,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    /// Down-projection weights, rank-major (`rank x hidden_dim`).
    pub down_proj: Vec<f32>,
    /// Up-projection weights, rank-major (`rank x hidden_dim`).
    pub up_proj: Vec<f32>,
    pub layer_idx: usize,
}

impl BaseLoRA {
    pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
        // All weights start at zero: the adapter is a no-op until real
        // weights are written in.
        let layers = (0..num_layers)
            .map(|idx| LoRALayer {
                down_proj: vec![0.0; hidden_dim * rank],
                up_proj: vec![0.0; rank * hidden_dim],
                layer_idx: idx,
            })
            .collect();

        Self {
            layers,
            rank,
            hidden_dim,
            alpha: rank as f32,
        }
    }

    pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
        if layer_idx >= self.layers.len() {
            return;
        }
        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        let layer = &self.layers[layer_idx];
        let scale = self.alpha / self.rank as f32;

        // Down projection: hidden_dim -> rank.
        let mut intermediate = vec![0.0f32; self.rank];
        for r in 0..self.rank {
            let offset = r * self.hidden_dim;
            intermediate[r] = input.iter()
                .zip(&layer.down_proj[offset..offset + self.hidden_dim])
                .map(|(a, b)| a * b)
                .sum();
        }

        // Up projection: rank -> hidden_dim, accumulated into `output`.
        for i in 0..self.hidden_dim {
            let mut sum = 0.0f32;
            for r in 0..self.rank {
                sum += intermediate[r] * layer.up_proj[r * self.hidden_dim + i];
            }
            output[i] += sum * scale;
        }
    }

    /// Folds this layer's low-rank delta into a dense `hidden_dim x hidden_dim`
    /// weight matrix: `W[i][j] += scale * sum_r down[r][i] * up[r][j]`.
    pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
        if layer_idx >= self.layers.len() {
            return;
        }
        assert_eq!(model_weights.len(), self.hidden_dim * self.hidden_dim);

        let layer = &self.layers[layer_idx];
        let scale = self.alpha / self.rank as f32;

        for i in 0..self.hidden_dim {
            for j in 0..self.hidden_dim {
                let mut delta = 0.0f32;
                for r in 0..self.rank {
                    // Rank-major layout, matching `forward_layer`: element
                    // (r, i) of the down projection lives at
                    // `down_proj[r * hidden_dim + i]`.
                    delta += layer.down_proj[r * self.hidden_dim + i]
                        * layer.up_proj[r * self.hidden_dim + j];
                }
                model_weights[i * self.hidden_dim + j] += delta * scale;
            }
        }
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    pub fn param_count(&self) -> usize {
        self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
    }
}

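/// Combines the fast-adapting `MicroLoRA` with the persistent per-layer
/// `BaseLoRA`; either path can be toggled independently.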
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    pub micro: MicroLoRA,
    pub base: BaseLoRA,
    pub micro_enabled: bool,
    pub base_enabled: bool,
}

impl LoRAEngine {
    pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
        Self {
            micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
            base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
            micro_enabled: true,
            base_enabled: true,
        }
    }

    pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
        if self.micro_enabled {
            self.micro.forward(input, output);
        }
        if self.base_enabled && layer_idx < self.base.num_layers() {
            self.base.forward_layer(layer_idx, input, output);
        }
    }

    pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
        if self.micro_enabled {
            self.micro.accumulate_gradient(signal);
        }
    }

    pub fn apply_micro(&mut self, learning_rate: f32) {
        if self.micro_enabled {
            self.micro.apply_accumulated(learning_rate);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_micro_lora_creation() {
        let lora = MicroLoRA::new(256, 1);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.hidden_dim(), 256);
        assert_eq!(lora.param_count(), 256 + 256);
    }

    #[test]
    fn test_micro_lora_forward() {
        let lora = MicroLoRA::new(64, 1);
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];

        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().sum();
        assert!(sum.abs() < 1e-6, "Expected ~0 with zero up_proj, got {}", sum);
    }
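
    // A sketch of a parity check between the SIMD and scalar paths; it only
    // compiles when the AVX2/FMA path itself is compiled in. The test writes
    // the private `up_proj` directly (visible from this child module) so the
    // forward pass is non-trivial, and uses dim 70 to exercise the scalar
    // tails.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
    #[test]
    fn test_simd_matches_scalar() {
        let mut lora = MicroLoRA::new(70, 2);
        for (i, w) in lora.up_proj.iter_mut().enumerate() {
            *w = ((i % 13) as f32 - 6.0) * 0.01;
        }

        let input: Vec<f32> = (0..70).map(|i| (i as f32) * 0.05 - 1.5).collect();
        let mut out_scalar = vec![0.0f32; 70];
        let mut out_simd = vec![0.0f32; 70];

        lora.forward_scalar(&input, &mut out_scalar);
        lora.forward_simd(&input, &mut out_simd);

        for (a, b) in out_scalar.iter().zip(out_simd.iter()) {
            assert!((a - b).abs() < 1e-4, "scalar {} vs simd {}", a, b);
        }
    }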

    #[test]
    fn test_micro_lora_learning() {
        let mut lora = MicroLoRA::new(64, 1);

        let signal = LearningSignal::with_gradient(
            vec![0.1; 64],
            vec![0.5; 64],
            0.8,
        );

        lora.accumulate_gradient(&signal);
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_accumulated(0.01);
        assert_eq!(lora.pending_updates(), 0);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().map(|x| x.abs()).sum();
        assert!(sum > 0.0, "Expected non-zero output after learning");
    }

    #[test]
    fn test_base_lora() {
        let lora = BaseLoRA::new(64, 4, 12);
        assert_eq!(lora.num_layers(), 12);
        assert_eq!(lora.rank, 4);
    }

    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);

        let signal = LearningSignal::with_gradient(
            vec![0.1; 64],
            vec![0.5; 64],
            0.9,
        );

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

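    // A sketch checking that `merge_into` agrees with `forward_layer`:
    // folding the low-rank delta into a dense matrix W and computing
    // y[j] = sum_i x[i] * W[i][j] should reproduce the layer's output.
    #[test]
    fn test_merge_matches_forward() {
        let mut lora = BaseLoRA::new(4, 2, 1);
        for (i, w) in lora.layers[0].down_proj.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.1;
        }
        for (i, w) in lora.layers[0].up_proj.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.05;
        }

        let input = vec![1.0f32, -2.0, 0.5, 3.0];
        let mut via_forward = vec![0.0f32; 4];
        lora.forward_layer(0, &input, &mut via_forward);

        let mut merged = vec![0.0f32; 16];
        lora.merge_into(&mut merged, 0);
        let mut via_merge = vec![0.0f32; 4];
        for i in 0..4 {
            for j in 0..4 {
                via_merge[j] += input[i] * merged[i * 4 + j];
            }
        }

        for (a, b) in via_forward.iter().zip(via_merge.iter()) {
            assert!((a - b).abs() < 1e-4, "forward {} vs merged {}", a, b);
        }
    }
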
    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }
}