//! LoRA (Low-Rank Adaptation) support: adapter parameters, an optimizer
//! wrapper that trains only those adapters, and convenience constructors.

use crate::optimizer::OptimizerState;
use anyhow::{anyhow, Result as AnyhowResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;

/// Configuration for LoRA fine-tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoRAConfig {
    /// Rank of the low-rank decomposition.
    pub rank: usize,
    /// Scaling numerator; adapter outputs are scaled by `alpha / rank`.
    pub alpha: f32,
    /// Dropout probability applied to adapter inputs.
    pub dropout: f32,
    /// Whether bias parameters are also adapted.
    pub bias: bool,
    /// Names of the modules that receive adapters.
    pub target_modules: Vec<String>,
    /// Whether adapter weights should be merged into the base weights.
    pub merge_weights: bool,
}

impl Default for LoRAConfig {
    fn default() -> Self {
        Self {
            rank: 8,
            alpha: 16.0,
            dropout: 0.1,
            bias: false,
            target_modules: vec!["query".to_string(), "key".to_string(), "value".to_string()],
            merge_weights: false,
        }
    }
}
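
// A minimal usage sketch (illustrative values): override only the fields you
// need and keep the rest of the defaults. The resulting adapters use a scaling
// of `alpha / rank` (32.0 / 16 = 2.0 here; the default is 16.0 / 8 = 2.0).
//
//     let config = LoRAConfig {
//         rank: 16,
//         alpha: 32.0,
//         ..LoRAConfig::default()
//     };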

/// A single LoRA adapter: two low-rank factors plus a scaling constant.
#[derive(Debug, Clone)]
pub struct LoRAAdapter {
    /// Down-projection factor of shape `[input_dim, rank]`.
    pub lora_a: Tensor,
    /// Up-projection factor of shape `[rank, output_dim]`.
    pub lora_b: Tensor,
    /// Precomputed `alpha / rank` scaling.
    pub scaling: f32,
    /// Whether this adapter currently contributes to the forward pass.
    pub active: bool,
}

impl LoRAAdapter {
    /// Creates a new adapter. `lora_a` is randomly initialized and `lora_b`
    /// starts at zero, so the initial adapter output is zero.
    pub fn new(input_dim: usize, output_dim: usize, rank: usize, alpha: f32) -> Result<Self> {
        let lora_a = Tensor::randn(&[input_dim, rank])?;
        let lora_b = Tensor::zeros(&[rank, output_dim])?;
        let scaling = alpha / rank as f32;

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            active: true,
        })
    }

    /// Computes the adapter's contribution: `input @ lora_a @ lora_b * scaling`.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        if !self.active {
            // An inactive adapter contributes nothing. The zero tensor must
            // have the adapter's output dimension in its last axis, not the
            // input dimension, so `zeros_like(input)` would be wrong whenever
            // `input_dim != output_dim`.
            let mut shape = input.shape().to_vec();
            let last = shape.len() - 1;
            shape[last] = self.lora_b.shape()[1];
            return Tensor::zeros(&shape);
        }

        let intermediate = input.matmul(&self.lora_a)?;
        let output = intermediate.matmul(&self.lora_b)?;
        output.mul_scalar(self.scaling)
    }

    /// Returns the dense weight delta `lora_a @ lora_b * scaling`, of shape
    /// `[input_dim, output_dim]`, matching the weight layout used by `forward`.
    pub fn get_delta_weight(&self) -> Result<Tensor> {
        if !self.active {
            return Err(
                trustformers_core::errors::TrustformersError::tensor_op_error(
                    "Adapter is not active",
                    "get_delta_weight",
                ),
            );
        }

        // Note the order: `[input_dim, rank] @ [rank, output_dim]`; the
        // reverse product would not be shape-compatible.
        let delta_w = self.lora_a.matmul(&self.lora_b)?;
        delta_w.mul_scalar(self.scaling)
    }

    /// Adds this adapter's weight delta into `base_weight` in place.
    pub fn merge_into_weight(&self, base_weight: &mut Tensor) -> Result<()> {
        if !self.active {
            return Ok(());
        }

        let delta_w = self.get_delta_weight()?;
        *base_weight = base_weight.add(&delta_w)?;
        Ok(())
    }

    /// Enables or disables this adapter.
    pub fn set_active(&mut self, active: bool) {
        self.active = active;
    }

    /// Number of trainable parameters in this adapter.
    pub fn num_parameters(&self) -> usize {
        self.lora_a.len() + self.lora_b.len()
    }
}
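
// Sketch of the algebra this adapter implements: the low-rank path
// `input @ lora_a @ lora_b * scaling` equals a dense `input @ delta_w` with
// `delta_w = lora_a @ lora_b * scaling`. Assuming a 2-D input:
//
//     let adapter = LoRAAdapter::new(512, 256, 8, 16.0)?;
//     let x = Tensor::randn(&[4, 512])?;
//     let low_rank = adapter.forward(&x)?;                   // shape [4, 256]
//     let dense = x.matmul(&adapter.get_delta_weight()?)?;   // same values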

/// An optimizer wrapper that trains only LoRA adapter parameters while the
/// base model weights stay frozen.
pub struct LoRAOptimizer {
    /// Wrapped optimizer that performs the actual parameter updates.
    base_optimizer: Box<dyn OptimizerState>,
    /// Adapters keyed by module name.
    adapters: HashMap<String, LoRAAdapter>,
    /// LoRA configuration shared by all adapters created through this optimizer.
    config: LoRAConfig,
    /// Frozen base weights, keyed by module name.
    frozen_parameters: HashMap<String, Tensor>,
    /// Current learning rate.
    learning_rate: f32,
}

impl std::fmt::Debug for LoRAOptimizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `base_optimizer` is a trait object without a `Debug` bound, so it
        // is omitted from the output.
        f.debug_struct("LoRAOptimizer")
            .field("adapters", &self.adapters)
            .field("config", &self.config)
            .field("frozen_parameters", &self.frozen_parameters)
            .field("learning_rate", &self.learning_rate)
            .finish()
    }
}

impl LoRAOptimizer {
    /// Creates a new LoRA optimizer wrapping `base_optimizer`.
    pub fn new(
        base_optimizer: Box<dyn OptimizerState>,
        config: LoRAConfig,
        learning_rate: f32,
    ) -> Self {
        Self {
            base_optimizer,
            adapters: HashMap::new(),
            config,
            frozen_parameters: HashMap::new(),
            learning_rate,
        }
    }

    /// Creates and registers an adapter for `module_name` using the configured
    /// rank and alpha.
    pub fn add_adapter(
        &mut self,
        module_name: &str,
        input_dim: usize,
        output_dim: usize,
    ) -> Result<()> {
        let adapter = LoRAAdapter::new(input_dim, output_dim, self.config.rank, self.config.alpha)?;
        self.adapters.insert(module_name.to_string(), adapter);
        Ok(())
    }

    /// Removes and returns the adapter for `module_name`, if present.
    pub fn remove_adapter(&mut self, module_name: &str) -> Option<LoRAAdapter> {
        self.adapters.remove(module_name)
    }

    /// Returns the adapter for `module_name`, if present.
    pub fn get_adapter(&self, module_name: &str) -> Option<&LoRAAdapter> {
        self.adapters.get(module_name)
    }

    /// Returns a mutable reference to the adapter for `module_name`, if present.
    pub fn get_adapter_mut(&mut self, module_name: &str) -> Option<&mut LoRAAdapter> {
        self.adapters.get_mut(module_name)
    }

    /// Activates or deactivates a single adapter.
    pub fn set_adapter_active(&mut self, module_name: &str, active: bool) -> Result<()> {
        if let Some(adapter) = self.adapters.get_mut(module_name) {
            adapter.set_active(active);
            Ok(())
        } else {
            Err(
                trustformers_core::errors::TrustformersError::tensor_op_error(
                    &format!("Adapter {} not found", module_name),
                    "set_adapter_active",
                ),
            )
        }
    }

    /// Activates or deactivates all adapters at once.
    pub fn set_all_adapters_active(&mut self, active: bool) {
        for adapter in self.adapters.values_mut() {
            adapter.set_active(active);
        }
    }

    /// Total number of trainable (adapter) parameters.
    pub fn num_trainable_parameters(&self) -> usize {
        self.adapters.values().map(|a| a.num_parameters()).sum()
    }

    /// Stores the frozen base weights so adapters can later be merged into them.
    pub fn freeze_base_parameters(&mut self, parameters: HashMap<String, Tensor>) {
        self.frozen_parameters = parameters;
    }

    /// Merges every active adapter into its corresponding frozen base weight.
    pub fn merge_adapters_into_base(&mut self) -> Result<()> {
        for (module_name, adapter) in &self.adapters {
            if adapter.active {
                if let Some(base_weight) = self.frozen_parameters.get_mut(module_name) {
                    adapter.merge_into_weight(base_weight)?;
                }
            }
        }
        Ok(())
    }

    /// Exports all adapters as `(lora_a, lora_b, scaling)` tuples keyed by name.
    pub fn save_adapters(&self) -> HashMap<String, (Tensor, Tensor, f32)> {
        self.adapters
            .iter()
            .map(|(name, adapter)| {
                (
                    name.clone(),
                    (
                        adapter.lora_a.clone(),
                        adapter.lora_b.clone(),
                        adapter.scaling,
                    ),
                )
            })
            .collect()
    }

    /// Imports adapters previously produced by `save_adapters`. Loaded
    /// adapters are marked active.
    pub fn load_adapters(
        &mut self,
        adapters: HashMap<String, (Tensor, Tensor, f32)>,
    ) -> Result<()> {
        for (name, (lora_a, lora_b, scaling)) in adapters {
            let adapter = LoRAAdapter {
                lora_a,
                lora_b,
                scaling,
                active: true,
            };
            self.adapters.insert(name, adapter);
        }
        Ok(())
    }

    /// Returns the LoRA configuration.
    pub fn get_config(&self) -> &LoRAConfig {
        &self.config
    }

    /// Collects the trainable tensors (`lora_a`, then `lora_b`, for each
    /// active adapter) in adapter iteration order.
    fn get_trainable_parameters(&self) -> Vec<Tensor> {
        let mut params = Vec::new();
        for adapter in self.adapters.values() {
            if adapter.active {
                params.push(adapter.lora_a.clone());
                params.push(adapter.lora_b.clone());
            }
        }
        params
    }

    /// Writes updated tensors back into the adapters. This relies on the map
    /// not being mutated between `get_trainable_parameters` and this call, so
    /// both passes see the same iteration order.
    fn update_adapters_from_parameters(&mut self, parameters: &[Tensor]) -> AnyhowResult<()> {
        let mut param_idx = 0;
        for adapter in self.adapters.values_mut() {
            if adapter.active {
                if param_idx + 1 >= parameters.len() {
                    return Err(anyhow!("Not enough parameters provided"));
                }
                adapter.lora_a = parameters[param_idx].clone();
                adapter.lora_b = parameters[param_idx + 1].clone();
                param_idx += 2;
            }
        }
        Ok(())
    }
}
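
// A minimal sketch of wiring adapters up from the configuration: register one
// adapter per configured target module. The `hidden_dim` of 768 is purely
// illustrative; real dimensions come from the model being adapted.
//
//     let hidden_dim = 768;
//     for module in optimizer.get_config().target_modules.clone() {
//         optimizer.add_adapter(&module, hidden_dim, hidden_dim)?;
//     }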

impl OptimizerState for LoRAOptimizer {
    fn zero_grad(&mut self) -> AnyhowResult<()> {
        self.base_optimizer.zero_grad()
    }

    fn step(&mut self, _parameters: &mut [Tensor]) -> AnyhowResult<()> {
        // Only the adapter factors are updated; the passed-in parameters (the
        // frozen base weights) are deliberately ignored.
        let mut trainable_params = self.get_trainable_parameters();

        self.base_optimizer.step(&mut trainable_params)?;

        self.update_adapters_from_parameters(&trainable_params)?;

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
        self.base_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> AnyhowResult<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        // Adapter tensors and metadata, keyed as `adapter_{name}_{field}`.
        for (name, adapter) in &self.adapters {
            state.insert(format!("adapter_{}_lora_a", name), adapter.lora_a.clone());
            state.insert(format!("adapter_{}_lora_b", name), adapter.lora_b.clone());
            state.insert(
                format!("adapter_{}_scaling", name),
                Tensor::scalar(adapter.scaling)?,
            );
            state.insert(
                format!("adapter_{}_active", name),
                Tensor::scalar(adapter.active as i32 as f32)?,
            );
        }

        // The wrapped optimizer's state, namespaced under `base_`.
        let base_state = self.base_optimizer.state_dict()?;
        for (key, value) in base_state {
            state.insert(format!("base_{}", key), value);
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> AnyhowResult<()> {
        let mut base_state = HashMap::new();
        let mut adapter_states: HashMap<
            String,
            (Option<Tensor>, Option<Tensor>, Option<f32>, Option<bool>),
        > = HashMap::new();

        for (key, value) in state {
            if let Some(rest) = key.strip_prefix("adapter_") {
                // Match on the known field suffixes so adapter names that
                // themselves contain underscores round-trip correctly
                // (splitting on '_' would truncate such names).
                let (name, field) = if let Some(name) = rest.strip_suffix("_lora_a") {
                    (name.to_string(), "lora_a")
                } else if let Some(name) = rest.strip_suffix("_lora_b") {
                    (name.to_string(), "lora_b")
                } else if let Some(name) = rest.strip_suffix("_scaling") {
                    (name.to_string(), "scaling")
                } else if let Some(name) = rest.strip_suffix("_active") {
                    (name.to_string(), "active")
                } else {
                    continue;
                };

                let entry = adapter_states
                    .entry(name)
                    .or_insert((None, None, None, None));

                match field {
                    "lora_a" => entry.0 = Some(value),
                    "lora_b" => entry.1 = Some(value),
                    "scaling" => entry.2 = Some(value.to_scalar()?),
                    "active" => entry.3 = Some(value.to_scalar()? > 0.5),
                    _ => {},
                }
            } else if let Some(stripped) = key.strip_prefix("base_") {
                base_state.insert(stripped.to_string(), value);
            }
        }

        // Only rebuild adapters whose four fields were all present.
        for (name, (lora_a_opt, lora_b_opt, scaling_opt, active_opt)) in adapter_states {
            if let (Some(lora_a), Some(lora_b), Some(scaling), Some(active)) =
                (lora_a_opt, lora_b_opt, scaling_opt, active_opt)
            {
                let adapter = LoRAAdapter {
                    lora_a,
                    lora_b,
                    scaling,
                    active,
                };
                self.adapters.insert(name, adapter);
            }
        }

        self.base_optimizer.load_state_dict(base_state)?;

        Ok(())
    }
}
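
// Checkpointing sketch: `state_dict` flattens adapter tensors under
// `adapter_{name}_{field}` keys and the wrapped optimizer's state under
// `base_{key}`, so a round trip restores both.
//
//     let state = optimizer.state_dict()?;
//     restored_optimizer.load_state_dict(state)?;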

/// Creates a LoRA optimizer backed by `SparseAdam` with Adam-style settings.
pub fn create_lora_adam(
    learning_rate: f32,
    config: LoRAConfig,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) -> LoRAOptimizer {
    let adam = Box::new(crate::sparse::SparseAdam::with_default_config(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        weight_decay,
    ));
    LoRAOptimizer::new(adam, config, learning_rate)
}

/// Creates a LoRA optimizer with AdamW-style settings. Note: currently backed
/// by the same `SparseAdam` implementation as `create_lora_adam`.
pub fn create_lora_adamw(
    learning_rate: f32,
    config: LoRAConfig,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) -> LoRAOptimizer {
    let adamw = Box::new(crate::sparse::SparseAdam::with_default_config(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        weight_decay,
    ));
    LoRAOptimizer::new(adamw, config, learning_rate)
}

/// Creates a LoRA optimizer with SGD-style settings. Note: currently backed by
/// the `QHM` optimizer; `dampening`, `weight_decay`, and `nesterov` are ignored.
pub fn create_lora_sgd(
    learning_rate: f32,
    config: LoRAConfig,
    momentum: f32,
    _dampening: f32,
    _weight_decay: f32,
    _nesterov: bool,
) -> LoRAOptimizer {
    let sgd = Box::new(crate::convergence::QHM::with_defaults(
        learning_rate,
        momentum,
        0.999,
    ));
    LoRAOptimizer::new(sgd, config, learning_rate)
}
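
// End-to-end sketch (`base_weights: HashMap<String, Tensor>` is hypothetical):
// only the LoRA factors are trained; the frozen base weights can optionally be
// merged with the adapters for deployment.
//
//     let mut optimizer = create_lora_adam(1e-3, LoRAConfig::default(), 0.9, 0.999, 1e-8, 0.01);
//     optimizer.add_adapter("query", 768, 768)?;
//     optimizer.freeze_base_parameters(base_weights);
//     // ... training loop: zero_grad(), backward pass, step() ...
//     optimizer.merge_adapters_into_base()?;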

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config_default() {
        let config = LoRAConfig::default();
        assert_eq!(config.rank, 8);
        assert_eq!(config.alpha, 16.0);
        assert_eq!(config.dropout, 0.1);
        assert!(!config.bias);
        assert_eq!(config.target_modules.len(), 3);
        assert!(!config.merge_weights);
    }

    #[test]
    fn test_lora_adapter_creation() {
        let adapter = LoRAAdapter::new(512, 256, 8, 16.0).unwrap();

        assert_eq!(adapter.lora_a.shape(), &[512, 8]);
        assert_eq!(adapter.lora_b.shape(), &[8, 256]);
        // scaling = alpha / rank = 16.0 / 8
        assert_eq!(adapter.scaling, 2.0);
        assert!(adapter.active);
    }

    #[test]
    fn test_lora_adapter_parameters() {
        let adapter = LoRAAdapter::new(512, 256, 8, 16.0).unwrap();
        let expected_params = 512 * 8 + 8 * 256;
        assert_eq!(adapter.num_parameters(), expected_params);
    }

    #[test]
    fn test_lora_optimizer_creation() {
        let config = LoRAConfig::default();
        let optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        assert_eq!(optimizer.get_lr(), 1e-3);
        // No adapters registered yet.
        assert_eq!(optimizer.num_trainable_parameters(), 0);
    }

    #[test]
    fn test_adapter_management() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        optimizer.add_adapter("query", 512, 512).unwrap();
        assert_eq!(optimizer.num_trainable_parameters(), 512 * 8 + 8 * 512);

        assert!(optimizer.get_adapter("query").is_some());

        let removed = optimizer.remove_adapter("query");
        assert!(removed.is_some());
        assert_eq!(optimizer.num_trainable_parameters(), 0);
    }

    #[test]
    fn test_adapter_activation() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        optimizer.add_adapter("query", 512, 512).unwrap();

        assert!(optimizer.get_adapter("query").unwrap().active);

        optimizer.set_adapter_active("query", false).unwrap();
        assert!(!optimizer.get_adapter("query").unwrap().active);

        optimizer.set_all_adapters_active(true);
        assert!(optimizer.get_adapter("query").unwrap().active);
    }
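
    #[test]
    fn test_adapter_save_load_roundtrip() {
        // Sketch of persistence: export adapters from one optimizer and
        // restore them into a fresh one, using only `save_adapters` /
        // `load_adapters` as defined above.
        let mut src = create_lora_adam(1e-3, LoRAConfig::default(), 0.9, 0.999, 1e-8, 0.01);
        src.add_adapter("query", 64, 64).unwrap();
        let saved = src.save_adapters();

        let mut dst = create_lora_adam(1e-3, LoRAConfig::default(), 0.9, 0.999, 1e-8, 0.01);
        dst.load_adapters(saved).unwrap();
        assert_eq!(dst.num_trainable_parameters(), 64 * 8 + 8 * 64);
        assert!(dst.get_adapter("query").unwrap().active);
    }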
}