1use crate::types::LearningSignal;
8use serde::{Deserialize, Serialize};
9
/// Preferred number of learning signals to accumulate before applying an
/// update. NOTE(review): not referenced anywhere in this file — presumably
/// consumed by callers batching `accumulate_gradient`; confirm.
pub const OPTIMAL_BATCH_SIZE: usize = 32;
12
/// Tiny LoRA adapter (rank 1-2) for a single `hidden_dim`-wide projection.
///
/// Forward pass accumulates `output += scale * up_proj · (down_proj · input)`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MicroLoRA {
    // Down-projection weights, row-major `[rank][hidden_dim]`.
    down_proj: Vec<f32>,
    // Up-projection weights, row-major `[rank][hidden_dim]`; zero-initialized
    // so a fresh adapter is a no-op.
    up_proj: Vec<f32>,
    // LoRA rank; `new` enforces 1 or 2.
    rank: usize,
    // Width of the activation vectors this adapter applies to.
    hidden_dim: usize,
    // Gradient accumulator for `down_proj`.
    // NOTE: `#[serde(skip)]` makes this deserialize as an EMPTY Vec
    // (`Default`), not a zeroed full-size buffer.
    #[serde(skip)]
    grad_down: Vec<f32>,
    // Gradient accumulator for `up_proj`; also empty after deserialization.
    #[serde(skip)]
    grad_up: Vec<f32>,
    // Signals folded in since the last `apply_accumulated`; resets to 0 on
    // deserialization.
    #[serde(skip)]
    update_count: usize,
    // Output scaling factor (defaults to 1/sqrt(rank)); serialized, unlike
    // the transient fields above.
    scale: f32,
}
44
45impl MicroLoRA {
46 pub fn new(hidden_dim: usize, rank: usize) -> Self {
55 assert!(
56 rank >= 1 && rank <= 2,
57 "MicroLoRA rank must be 1-2, got {}",
58 rank
59 );
60
61 let down_proj: Vec<f32> = (0..hidden_dim * rank)
63 .map(|i| {
64 let x = (i as f32 * 0.618033988749895) % 1.0;
65 (x - 0.5) * 0.02
66 })
67 .collect();
68
69 let up_proj = vec![0.0f32; rank * hidden_dim];
71
72 Self {
73 down_proj,
74 up_proj,
75 rank,
76 hidden_dim,
77 grad_down: vec![0.0; hidden_dim * rank],
78 grad_up: vec![0.0; rank * hidden_dim],
79 update_count: 0,
80 scale: 1.0 / (rank as f32).sqrt(),
81 }
82 }
83
84 pub fn forward_scalar(&self, input: &[f32], output: &mut [f32]) {
86 assert_eq!(input.len(), self.hidden_dim);
87 assert_eq!(output.len(), self.hidden_dim);
88
89 let mut intermediate = vec![0.0f32; self.rank];
91 for r in 0..self.rank {
92 let mut sum = 0.0f32;
93 let offset = r * self.hidden_dim;
94 for i in 0..self.hidden_dim {
95 sum += input[i] * self.down_proj[offset + i];
96 }
97 intermediate[r] = sum;
98 }
99
100 for i in 0..self.hidden_dim {
102 let mut sum = 0.0f32;
103 for r in 0..self.rank {
104 sum += intermediate[r] * self.up_proj[r * self.hidden_dim + i];
105 }
106 output[i] += sum * self.scale;
107 }
108 }
109
    /// AVX2 implementation of the forward pass: 8-lane FMA over the down-
    /// and up-projections, with scalar tail loops when `hidden_dim` is not
    /// a multiple of 8. Accumulates into `output`, matching
    /// `forward_scalar`'s contract.
    ///
    /// NOTE(review): `_mm256_fmadd_ps` belongs to the FMA target feature,
    /// while this cfg gate only requires `avx2` — confirm builds always
    /// enable `+fma` alongside `+avx2`, otherwise this is unsound on
    /// avx2-without-fma targets.
    ///
    /// # Panics
    /// Panics if `input` or `output` is not `hidden_dim` long.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    pub fn forward_simd(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        assert_eq!(input.len(), self.hidden_dim);
        assert_eq!(output.len(), self.hidden_dim);

        // SAFETY: AVX2 is guaranteed at compile time by the cfg gate, and
        // the asserts above plus the `i + 8 <= hidden_dim` loop guards keep
        // every unaligned 8-float load/store within its slice.
        unsafe {
            let mut intermediate = vec![0.0f32; self.rank];

            // Down projection: one 8-wide FMA accumulator per rank row.
            for r in 0..self.rank {
                let mut sum = _mm256_setzero_ps();
                let offset = r * self.hidden_dim;

                let mut i = 0;
                while i + 8 <= self.hidden_dim {
                    let inp = _mm256_loadu_ps(input[i..].as_ptr());
                    let weight = _mm256_loadu_ps(self.down_proj[offset + i..].as_ptr());
                    sum = _mm256_fmadd_ps(inp, weight, sum);
                    i += 8;
                }

                // Horizontal reduction of the 8 partial sums.
                let mut result = [0.0f32; 8];
                _mm256_storeu_ps(result.as_mut_ptr(), sum);
                intermediate[r] = result.iter().sum();

                // Scalar tail for the last `hidden_dim % 8` elements.
                for j in i..self.hidden_dim {
                    intermediate[r] += input[j] * self.down_proj[offset + j];
                }
            }

            let scale_vec = _mm256_set1_ps(self.scale);

            // Up projection: accumulate scaled contribution into `output`,
            // 8 lanes at a time.
            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let mut sum = _mm256_setzero_ps();

                for r in 0..self.rank {
                    let up_offset = r * self.hidden_dim;
                    let weight = _mm256_loadu_ps(self.up_proj[up_offset + i..].as_ptr());
                    let inter = _mm256_set1_ps(intermediate[r]);
                    sum = _mm256_fmadd_ps(inter, weight, sum);
                }

                sum = _mm256_mul_ps(sum, scale_vec);
                let existing = _mm256_loadu_ps(output[i..].as_ptr());
                let result = _mm256_add_ps(existing, sum);
                _mm256_storeu_ps(output[i..].as_mut_ptr(), result);

                i += 8;
            }

            // Scalar tail mirrors forward_scalar for the remainder.
            for j in i..self.hidden_dim {
                let mut val = 0.0;
                for r in 0..self.rank {
                    val += intermediate[r] * self.up_proj[r * self.hidden_dim + j];
                }
                output[j] += val * self.scale;
            }
        }
    }
178
179 pub fn forward(&self, input: &[f32], output: &mut [f32]) {
181 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
182 {
183 self.forward_simd(input, output);
184 return;
185 }
186
187 #[allow(unreachable_code)]
188 self.forward_scalar(input, output);
189 }
190
191 pub fn accumulate_gradient(&mut self, signal: &LearningSignal) {
193 if signal.gradient_estimate.len() != self.hidden_dim {
194 return;
195 }
196
197 let quality = signal.quality_score;
198
199 for r in 0..self.rank {
202 for i in 0..self.hidden_dim {
203 let grad_idx = r * self.hidden_dim + i;
204 self.grad_up[grad_idx] += signal.gradient_estimate[i] * quality;
206 }
207 }
208
209 self.update_count += 1;
210 }
211
212 pub fn apply_accumulated(&mut self, learning_rate: f32) {
214 if self.update_count == 0 {
215 return;
216 }
217
218 let scale = learning_rate / self.update_count as f32;
219
220 for (w, g) in self.up_proj.iter_mut().zip(self.grad_up.iter()) {
222 *w += g * scale;
223 }
224
225 self.grad_up.fill(0.0);
227 self.grad_down.fill(0.0);
228 self.update_count = 0;
229 }
230
231 pub fn reset(&mut self) {
233 self.up_proj.fill(0.0);
234 self.grad_up.fill(0.0);
235 self.grad_down.fill(0.0);
236 self.update_count = 0;
237 }
238
    /// LoRA rank (1 or 2).
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Width of the activation vectors this adapter applies to.
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Total trainable parameters (down + up projections).
    pub fn param_count(&self) -> usize {
        self.down_proj.len() + self.up_proj.len()
    }

    /// Current output scaling factor.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Overrides the output scaling factor (default is 1/sqrt(rank)).
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Number of signals accumulated since the last `apply_accumulated`.
    pub fn pending_updates(&self) -> usize {
        self.update_count
    }

    /// Borrows the raw weights as `(down_proj, up_proj)`, each row-major
    /// `[rank][hidden_dim]`.
    pub fn get_weights(&self) -> (&Vec<f32>, &Vec<f32>) {
        (&self.down_proj, &self.up_proj)
    }
273}
274
/// Multi-layer LoRA adapter stack: one `LoRALayer` per model layer, all
/// sharing the same `rank`/`hidden_dim`, scaled by `alpha / rank`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BaseLoRA {
    // Per-layer adapters, indexed by layer.
    pub layers: Vec<LoRALayer>,
    // LoRA rank shared by every layer.
    pub rank: usize,
    // Activation width shared by every layer.
    pub hidden_dim: usize,
    // LoRA alpha; effective output scale is `alpha / rank`.
    pub alpha: f32,
}
290
/// Low-rank weight pair for a single model layer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoRALayer {
    // Down-projection weights, row-major `[rank][hidden_dim]`
    // (indexed `down_proj[r * hidden_dim + i]` in `forward_layer`).
    pub down_proj: Vec<f32>,
    // Up-projection weights, row-major `[rank][hidden_dim]`.
    pub up_proj: Vec<f32>,
    // Index of the model layer this adapter belongs to.
    pub layer_idx: usize,
}
301
302impl BaseLoRA {
303 pub fn new(hidden_dim: usize, rank: usize, num_layers: usize) -> Self {
305 let layers = (0..num_layers)
306 .map(|idx| LoRALayer {
307 down_proj: vec![0.0; hidden_dim * rank],
308 up_proj: vec![0.0; rank * hidden_dim],
309 layer_idx: idx,
310 })
311 .collect();
312
313 Self {
314 layers,
315 rank,
316 hidden_dim,
317 alpha: rank as f32,
318 }
319 }
320
321 pub fn forward_layer(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
323 if layer_idx >= self.layers.len() {
324 return;
325 }
326
327 let layer = &self.layers[layer_idx];
328 let scale = self.alpha / self.rank as f32;
329
330 let mut intermediate = vec![0.0f32; self.rank];
332 for r in 0..self.rank {
333 let offset = r * self.hidden_dim;
334 intermediate[r] = input
335 .iter()
336 .zip(&layer.down_proj[offset..offset + self.hidden_dim])
337 .map(|(a, b)| a * b)
338 .sum();
339 }
340
341 for i in 0..self.hidden_dim {
343 let mut sum = 0.0f32;
344 for r in 0..self.rank {
345 sum += intermediate[r] * layer.up_proj[r * self.hidden_dim + i];
346 }
347 output[i] += sum * scale;
348 }
349 }
350
351 pub fn merge_into(&self, model_weights: &mut [f32], layer_idx: usize) {
353 if layer_idx >= self.layers.len() {
354 return;
355 }
356
357 let layer = &self.layers[layer_idx];
358 let scale = self.alpha / self.rank as f32;
359
360 for i in 0..self.hidden_dim {
363 for j in 0..self.hidden_dim {
364 let mut delta = 0.0f32;
365 for r in 0..self.rank {
366 delta +=
367 layer.down_proj[i * self.rank + r] * layer.up_proj[r * self.hidden_dim + j];
368 }
369 model_weights[i * self.hidden_dim + j] += delta * scale;
370 }
371 }
372 }
373
    /// Number of per-layer adapters.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Total parameters across all layers (down + up per layer).
    pub fn param_count(&self) -> usize {
        self.layers.len() * (self.hidden_dim * self.rank + self.rank * self.hidden_dim)
    }

    /// Borrows layer `layer_idx`'s weights as `(down_proj, up_proj)`, or
    /// `None` when the index is out of range.
    pub fn get_layer_weights(&self, layer_idx: usize) -> Option<(&Vec<f32>, &Vec<f32>)> {
        self.layers
            .get(layer_idx)
            .map(|layer| (&layer.down_proj, &layer.up_proj))
    }
390}
391
/// Combines the fast per-request `MicroLoRA` with the per-layer `BaseLoRA`
/// stack; each path can be toggled independently.
#[derive(Clone, Debug)]
pub struct LoRAEngine {
    // Rank-1/2 adapter applied at every layer.
    pub micro: MicroLoRA,
    // Per-layer adapter stack.
    pub base: BaseLoRA,
    // When false, `forward`/`accumulate_micro`/`apply_micro` skip the micro path.
    pub micro_enabled: bool,
    // When false, `forward` skips the base path.
    pub base_enabled: bool,
}
404
405impl LoRAEngine {
406 pub fn new(hidden_dim: usize, micro_rank: usize, base_rank: usize, num_layers: usize) -> Self {
408 Self {
409 micro: MicroLoRA::new(hidden_dim, micro_rank.clamp(1, 2)),
410 base: BaseLoRA::new(hidden_dim, base_rank, num_layers),
411 micro_enabled: true,
412 base_enabled: true,
413 }
414 }
415
416 pub fn forward(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
418 if self.micro_enabled {
419 self.micro.forward(input, output);
420 }
421 if self.base_enabled && layer_idx < self.base.num_layers() {
422 self.base.forward_layer(layer_idx, input, output);
423 }
424 }
425
426 pub fn accumulate_micro(&mut self, signal: &LearningSignal) {
428 if self.micro_enabled {
429 self.micro.accumulate_gradient(signal);
430 }
431 }
432
433 pub fn apply_micro(&mut self, learning_rate: f32) {
435 if self.micro_enabled {
436 self.micro.apply_accumulated(learning_rate);
437 }
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor wiring: rank/dim getters and parameter count (rank 1 =>
    // hidden_dim down + hidden_dim up).
    #[test]
    fn test_micro_lora_creation() {
        let lora = MicroLoRA::new(256, 1);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.hidden_dim(), 256);
        assert_eq!(lora.param_count(), 256 + 256);
    }

    // A fresh adapter must be a no-op: up_proj is zero-initialized, so the
    // accumulated output stays (numerically) zero.
    #[test]
    fn test_micro_lora_forward() {
        let lora = MicroLoRA::new(64, 1);
        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];

        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().sum();
        assert!(
            sum.abs() < 1e-6,
            "Expected ~0 with zero up_proj, got {}",
            sum
        );
    }

    // Accumulate one signal, apply it, and check the adapter now produces
    // non-zero output (up_proj moved away from zero).
    #[test]
    fn test_micro_lora_learning() {
        let mut lora = MicroLoRA::new(64, 1);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.8);

        lora.accumulate_gradient(&signal);
        assert_eq!(lora.pending_updates(), 1);

        lora.apply_accumulated(0.01);
        assert_eq!(lora.pending_updates(), 0);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        lora.forward(&input, &mut output);

        let sum: f32 = output.iter().map(|x| x.abs()).sum();
        assert!(sum > 0.0, "Expected non-zero output after learning");
    }

    // BaseLoRA constructor wiring: layer count and rank.
    #[test]
    fn test_base_lora() {
        let lora = BaseLoRA::new(64, 4, 12);
        assert_eq!(lora.num_layers(), 12);
        assert_eq!(lora.rank, 4);
    }

    // Smoke test: the combined engine accepts a signal, applies it, and
    // runs a forward pass without panicking.
    #[test]
    fn test_lora_engine() {
        let mut engine = LoRAEngine::new(64, 1, 4, 12);

        let signal = LearningSignal::with_gradient(vec![0.1; 64], vec![0.5; 64], 0.9);

        engine.accumulate_micro(&signal);
        engine.apply_micro(0.01);

        let input = vec![1.0f32; 64];
        let mut output = vec![0.0f32; 64];
        engine.forward(0, &input, &mut output);
    }

    // Ranks outside 1-2 must be rejected by the constructor assertion.
    #[test]
    #[should_panic(expected = "MicroLoRA rank must be 1-2")]
    fn test_invalid_rank() {
        MicroLoRA::new(64, 5);
    }
}