1use crate::types::LearningSignal;
8use serde::{Deserialize, Serialize};
9
/// Suggested number of accumulated learning signals per weight update.
/// NOTE(review): not referenced in this file — presumably consumed by the
/// training loop that batches calls to `accumulate_gradient`; verify callers.
pub const OPTIMAL_BATCH_SIZE: usize = 32;
12
/// Minimal LoRA adapter (rank 1-2) applied as an additive delta to a
/// `hidden_dim`-wide activation: `output += scale * up^T (down . input)`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    /// Down-projection, row-major `[rank][hidden_dim]` (see the
    /// `down_proj[r * hidden_dim + i]` indexing in `forward_scalar`).
    down_proj: Vec<f32>,
    /// Up-projection, row-major `[rank][hidden_dim]`; zero-initialized so a
    /// fresh adapter contributes nothing.
    up_proj: Vec<f32>,
    /// Adapter rank; the constructor restricts this to 1 or 2.
    rank: usize,
    /// Width of the activation vectors this adapter operates on.
    hidden_dim: usize,
    /// Accumulated gradient for `down_proj`. Skipped by serde: after
    /// deserialization this is an empty (default) Vec, not `hidden_dim * rank`
    /// zeros — callers of the gradient path should be aware.
    #[serde(skip)]
    grad_down: Vec<f32>,
    /// Accumulated gradient for `up_proj`; same serde-skip caveat as above.
    #[serde(skip)]
    grad_up: Vec<f32>,
    /// Number of signals accumulated since the last apply/reset; not persisted.
    #[serde(skip)]
    update_count: usize,
    /// Output scaling factor applied in the forward pass (persisted).
    scale: f32,
}
44
45impl MicroLoRA {
46 pub fn new(hidden_dim: usize, rank: usize) -> Self {
55 assert!(rank >= 1 && rank <= 2, "MicroLoRA rank must be 1-2, got {}", rank);
56
57 let down_proj: Vec<f32> = (0..hidden_dim * rank)
59 .map(|i| {
60 let x = (i as f32 * 0.618033988749895) % 1.0;
61 (x - 0.5) * 0.02
62 })
63 .collect();
64
65 let up_proj = vec![0.0f32; rank * hidden_dim];
67
68 Self {
69 down_proj,
70 up_proj,
71 rank,
72 hidden_dim,
73 grad_down: vec![0.0; hidden_dim * rank],
74 grad_up: vec![0.0; rank * hidden_dim],
75 update_count: 0,
76 scale: 1.0 / (rank as f32).sqrt(),
77 }
78 }
79
80 pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
82 assert_eq!(input.len(), self.hidden_dim);
83 assert_eq!(output.len(), self.hidden_dim);
84
85 let mut intermediate = vec![0.0f32; self.rank];
87 for r in 0..self.rank {
88 let mut sum = 0.0f32;
89 let offset = r * self.hidden_dim;
90 for i in 0..self.hidden_dim {
91 sum += input[i] * self.down_proj[offset + i];
92 }
93 intermediate[r] = sum;
94 }
95
96 for i in 0..self.hidden_dim {
98 let mut sum = 0.0f32;
99 for r in 0..self.rank {
100 sum += intermediate[r] * self.up_proj[r * self.hidden_dim + i];
101 }
102 output[i] += sum * self.scale;
103 }
104 }
105
106 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
108 pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
109 use std::arch::x86_64::*;
110
111 assert_eq!(input.len(), self.hidden_dim);
112 assert_eq!(output.len(), self.hidden_dim);
113
114 unsafe {
115 let mut intermediate = vec![0.0f32; self.rank];
117
118 for r in 0..self.rank {
119 let mut sum = _mm256_setzero_ps();
120 let offset = r * self.hidden_dim;
121
122 let mut i = 0;
123 while i + 8 <= self.hidden_dim {
124 let inp = _mm256_loadu_ps(input[i..].as_ptr());
125 let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
126 sum = _mm256_fmadd_ps(inp, weight, sum);
127 i += 8;
128 }
129
130 let mut result = [0.0f32; 8];
132 _mm256_storeu_ps(result.as_mut_ptr(), sum);
133 intermediate[r] = result.iter().sum();
134
135 for j in i..self.hidden_dim {
137 intermediate[r] += input[j] * self.down_proj[offset + j];
138 }
139 }
140
141 let scale_vec = _mm256_set1_ps(self.scale);
143
144 let mut i = 0;
145 while i + 8 <= self.hidden_dim {
146 let mut sum = _mm256_setzero_ps();
147
148 for r in 0..self.rank {
149 let up_offset = r * self.hidden_dim;
150 let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
151 let inter = _mm256_set1_ps(intermediate[r]);
152 sum = _mm256_fmadd_ps(inter, weight, sum);
153 }
154
155 sum = _mm256_mul_ps(sum, scale_vec);
157 let existing = _mm256_loadu_ps(output[i..].as_ptr());
158 let result = _mm256_add_ps(existing, sum);
159 _mm256_storeu_ps(output[i..].as_mut_ptr(), result);
160
161 i += 8;
162 }
163
164 for j in i..self.hidden_dim {
166 let mut val = 0.0;
167 for r in 0..self.rank {
168 val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
169 }
170 output[j] += val * self.scale;
171 }
172 }
173 }
174
175 pub fn forward(&self, input: &[f32], output: &mut [f32]) {
177 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
178 {
179 self.forward_simd(input, output);
180 return;
181 }
182
183 #[allow(unreachable_code)]
184 self.forward_scalar(input, output);
185 }
186
187 pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
189 if signal.gradient_estimate.len() != self.hidden_dim {
190 return;
191 }
192
193 let quality = signal.quality_score;
194
195 for r in 0..self.rank {
198 for i in 0..self.hidden_dim {
199 let grad_idx = r * self.hidden_dim + i;
200 self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
202 }
203 }
204
205 self.update_count += 1;
206 }
207
208 pub fn apply_accumulated(&mut self, learning_rate: f32) {
210 if self.update_count == 0 {
211 return;
212 }
213
214 let scale = learning_rate / self.update_count as f32;
215
216 for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
218 *w += g * scale;
219 }
220
221 self.grad_up.fill(0.0);
223 self.grad_down.fill(0.0);
224 self.update_count = 0;
225 }
226
227 pub fn reset(&mut self) {
229 self.up_proj.fill(0.0);
230 self.grad_up.fill(0.0);
231 self.grad_down.fill(0.0);
232 self.update_count = 0;
233 }
234
235 pub fn rank(&self) -> usize {
237 self.rank
238 }
239
240 pub fn hidden_dim(&self) -> usize {
242 self.hidden_dim
243 }
244
245 pub fn param_count(&self) -> usize {
247 self.down_proj.len() + self.up_proj.len()
248 }
249
250 pub fn scale(&self) -> f32 {
252 self.scale
253 }
254
255 pub fn set_scale(&mut self, scale: f32) {
257 self.scale = scale;
258 }
259
260 pub fn pending_updates(&self) -> usize {
262 self.update_count
263 }
264
265 pub fn get_weights(&self) -> (&Vec<f32>, &Vec<f32>) {
267 (&self.down_proj, &self.up_proj)
268 }
269}
270
/// Multi-layer LoRA adapter: one `LoRALayer` per transformer layer, sharing a
/// common `rank`, `hidden_dim`, and `alpha` scaling numerator.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    /// Per-layer projection pairs, indexed by layer.
    pub layers: Vec<LoRALayer>,
    /// Shared adapter rank for every layer.
    pub rank: usize,
    /// Activation width for every layer.
    pub hidden_dim: usize,
    /// Scaling numerator; the effective forward scale is `alpha / rank`.
    pub alpha: f32,
}
286
/// Projection pair for a single layer of a `BaseLoRA` adapter.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    /// Down-projection; `BaseLoRA::forward_layer` reads it row-major as
    /// `[rank][hidden_dim]` (`down_proj[r * hidden_dim + i]`).
    pub down_proj: Vec<f32>,
    /// Up-projection, row-major `[rank][hidden_dim]`.
    pub up_proj: Vec<f32>,
    /// Index of the model layer this pair adapts.
    pub layer_idx: usize,
}
297
298impl BaseLoRA {
299 pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
301 let layers = (0..num_layers)
302 .map(|idx| LoRALayer {
303 down_proj: vec![0.0; hidden_dim * rank],
304 up_proj: vec![0.0; rank * hidden_dim],
305 layer_idx: idx,
306 })
307 .collect();
308
309 Self {
310 layers,
311 rank,
312 hidden_dim,
313 alpha: rank as f32,
314 }
315 }
316
317 pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
319 if layer_idx >= self.layers.len() {
320 return;
321 }
322
323 let layer = &self.layers[layer_idx];
324 let scale = self.alpha / self.rank as f32;
325
326 let mut intermediate = vec![0.0f32; self.rank];
328 for r in 0..self.rank {
329 let offset = r * self.hidden_dim;
330 intermediate[r] = input.iter()
331 .zip(&layer.down_proj[offset..offset + self.hidden_dim])
332 .map(|(a, b)| a * b)
333 .sum();
334 }
335
336 for i in 0..self.hidden_dim {
338 let mut sum = 0.0f32;
339 for r in 0..self.rank {
340 sum += intermediate[r] * layer.up_proj[r * self.hidden_dim + i];
341 }
342 output[i] += sum * scale;
343 }
344 }
345
346 pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
348 if layer_idx >= self.layers.len() {
349 return;
350 }
351
352 let layer = &self.layers[layer_idx];
353 let scale = self.alpha / self.rank as f32;
354
355 for i in 0..self.hidden_dim {
358 for j in 0..self.hidden_dim {
359 let mut delta = 0.0f32;
360 for r in 0..self.rank {
361 delta += layer.down_proj[i * self.rank + r]
362 * layer.up_proj[r * self.hidden_dim + j];
363 }
364 model_weights[i * self.hidden_dim + j] += delta * scale;
365 }
366 }
367 }
368
369 pub fn num_layers(&self) -> usize {
371 self.layers.len()
372 }
373
374 pub fn param_count(&self) -> usize {
376 self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
377 }
378
379 pub fn get_layer_weights(&self, layer_idx: usize) -> Option<(&Vec<f32>, &Vec<f32>)> {
381 self.layers.get(layer_idx).map(|layer| (&layer.down_proj, &layer.up_proj))
382 }
383}
384
/// Combines the fast per-request `MicroLoRA` with the slower per-layer
/// `BaseLoRA`; each path can be toggled independently.
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    /// Single shared micro adapter applied at every layer when enabled.
    pub micro: MicroLoRA,
    /// Per-layer base adapter.
    pub base: BaseLoRA,
    /// Whether the micro path participates in forward/learning calls.
    pub micro_enabled: bool,
    /// Whether the base path participates in forward calls.
    pub base_enabled: bool,
}
397
398impl LoRAEngine {
399 pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
401 Self {
402 micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
403 base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
404 micro_enabled: true,
405 base_enabled: true,
406 }
407 }
408
409 pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
411 if self.micro_enabled {
412 self.micro.forward(input, output);
413 }
414 if self.base_enabled && layer_idx < self.base.num_layers() {
415 self.base.forward_layer(layer_idx, input, output);
416 }
417 }
418
419 pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
421 if self.micro_enabled {
422 self.micro.accumulate_gradient(signal);
423 }
424 }
425
426 pub fn apply_micro(&mut self, learning_rate: f32) {
428 if self.micro_enabled {
429 self.micro.apply_accumulated(learning_rate);
430 }
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built adapter reports the geometry it was configured with.
    #[test]
    fn test_micro_lora_creation() {
        let adapter = MicroLoRA::new(256, 1);
        assert_eq!(adapter.rank(), 1);
        assert_eq!(adapter.hidden_dim(), 256);
        assert_eq!(adapter.param_count(), 256 + 256);
    }

    /// With the zero-initialized up-projection the adapter adds nothing.
    #[test]
    fn test_micro_lora_forward() {
        let dim = 64;
        let adapter = MicroLoRA::new(dim, 1);
        let ones = vec![1.0f32; dim];
        let mut out = vec![0.0f32; dim];

        adapter.forward(&ones, &mut out);

        let total: f32 = out.iter().sum();
        assert!(total.abs() < 1e-6, "Expected ~0 with zero up_proj, got {}", total);
    }

    /// Accumulating and applying a gradient makes the forward pass non-trivial.
    #[test]
    fn test_micro_lora_learning() {
        let dim = 64;
        let mut adapter = MicroLoRA::new(dim, 1);

        let signal = LearningSignal::with_gradient(vec![0.1; dim], vec![0.5; dim], 0.8);

        adapter.accumulate_gradient(&signal);
        assert_eq!(adapter.pending_updates(), 1);

        adapter.apply_accumulated(0.01);
        assert_eq!(adapter.pending_updates(), 0);

        let ones = vec![1.0f32; dim];
        let mut out = vec![0.0f32; dim];
        adapter.forward(&ones, &mut out);

        let magnitude: f32 = out.iter().map(|x| x.abs()).sum();
        assert!(magnitude > 0.0, "Expected non-zero output after learning");
    }

    /// Base adapter exposes its layer count and rank.
    #[test]
    fn test_base_lora() {
        let base = BaseLoRA::new(64, 4, 12);
        assert_eq!(base.num_layers(), 12);
        assert_eq!(base.rank, 4);
    }

    /// Smoke test: the engine accepts a signal, applies it, and runs forward.
    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.9);

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

    /// Ranks outside 1-2 are rejected at construction time.
    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }
}