trustformers_optim/lora.rs

//! # Low-Rank Adaptation (LoRA) Optimizers
//!
//! This module provides optimizers specifically designed for Low-Rank Adaptation (LoRA),
//! a parameter-efficient fine-tuning technique that reduces the number of trainable parameters
//! by decomposing weight updates into low-rank matrices.
//!
//! ## Overview
//!
//! LoRA works by freezing the original model weights and introducing trainable low-rank
//! decomposition matrices A and B such that ∆W = B × A, with rank r << min(input_dim, output_dim).
//! In this module, A has shape (input_dim × rank), B has shape (rank × output_dim), and the
//! update is applied as x @ A @ B.
//!
//! ## Benefits
//! - Can reduce the number of trainable parameters dramatically (up to ~10,000x for very large models)
//! - Enables efficient fine-tuning on consumer hardware
//! - Maintains model quality comparable to full fine-tuning
//! - Allows efficient storage and sharing of multiple adaptations
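//!
//! ## Example
//!
//! A minimal usage sketch (not compiled as a doctest; the module path and the
//! surrounding training loop are assumptions):
//!
//! ```ignore
//! use trustformers_optim::lora::{create_lora_adam, LoRAConfig};
//!
//! let config = LoRAConfig::default(); // rank = 8, alpha = 16.0
//! let mut optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);
//!
//! // Adapting a 4096 x 4096 projection trains only 2 * 4096 * 8 = 65,536
//! // parameters instead of ~16.8M for the full matrix (~256x fewer).
//! optimizer.add_adapter("query", 4096, 4096)?;
//! assert_eq!(optimizer.num_trainable_parameters(), 2 * 4096 * 8);
//!
//! // Adapter parameters are updated in place on each optimizer step; the
//! // parameter slice passed to `step` is ignored for LoRA.
//! optimizer.step(&mut [])?;
//! ```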

use crate::optimizer::OptimizerState;
use anyhow::{anyhow, Result as AnyhowResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;

/// Configuration for LoRA optimization.
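///
/// A minimal sketch of overriding the defaults via struct-update syntax:
///
/// ```ignore
/// let config = LoRAConfig {
///     rank: 16,
///     alpha: 32.0,
///     ..LoRAConfig::default()
/// };
/// ```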
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoRAConfig {
    /// Rank of the low-rank decomposition
    pub rank: usize,
    /// Alpha parameter for scaling (typically rank * 2)
    pub alpha: f32,
    /// Dropout probability for LoRA layers
    pub dropout: f32,
    /// Whether to enable bias training
    pub bias: bool,
    /// Modules to apply LoRA to (e.g., ["query", "key", "value"])
    pub target_modules: Vec<String>,
    /// Whether to merge adapter weights into base model
    pub merge_weights: bool,
}

impl Default for LoRAConfig {
    fn default() -> Self {
        Self {
            rank: 8,
            alpha: 16.0,
            dropout: 0.1,
            bias: false,
            target_modules: vec!["query".to_string(), "key".to_string(), "value".to_string()],
            merge_weights: false,
        }
    }
}

/// LoRA adapter containing the low-rank matrices.
#[derive(Debug, Clone)]
pub struct LoRAAdapter {
    /// Low-rank matrix A (input_dim x rank)
    pub lora_a: Tensor,
    /// Low-rank matrix B (rank x output_dim)
    pub lora_b: Tensor,
    /// Scaling factor
    pub scaling: f32,
    /// Whether this adapter is active
    pub active: bool,
}

impl LoRAAdapter {
    /// Creates a new LoRA adapter with random initialization.
    pub fn new(input_dim: usize, output_dim: usize, rank: usize, alpha: f32) -> Result<Self> {
        // Initialize A with random (Gaussian) values and B with zeros, so the
        // initial update ∆W = A @ B is zero (common practice)
        let lora_a = Tensor::randn(&[input_dim, rank])?;
        let lora_b = Tensor::zeros(&[rank, output_dim])?;
        let scaling = alpha / rank as f32;

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            active: true,
        })
    }

    /// Apply the LoRA update to an input: (input @ A @ B) * scaling
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        if !self.active {
            // An inactive adapter contributes nothing; the zero output must
            // have the adapter's output dimension, not the input's last one.
            let mut shape = input.shape().to_vec();
            if let Some(last) = shape.last_mut() {
                *last = self.lora_b.shape()[1];
            }
            return Tensor::zeros(&shape);
        }

        // Compute x @ A @ B * scaling
        let intermediate = input.matmul(&self.lora_a)?;
        let output = intermediate.matmul(&self.lora_b)?;
        output.mul_scalar(self.scaling)
    }

    /// Get the effective weight matrix ∆W = scaling * A @ B
    /// (shape input_dim x output_dim, matching the x @ W convention of `forward`).
    pub fn get_delta_weight(&self) -> Result<Tensor> {
        if !self.active {
            return Err(
                trustformers_core::errors::TrustformersError::tensor_op_error(
                    "Adapter is not active",
                    "get_delta_weight",
                ),
            );
        }

        // A is (input_dim x rank) and B is (rank x output_dim), so A @ B has
        // the shape of the adapted weight matrix.
        let delta_w = self.lora_a.matmul(&self.lora_b)?;
        delta_w.mul_scalar(self.scaling)
    }

    /// Merge adapter weights into the base weight matrix
    /// (expected shape: input_dim x output_dim).
    pub fn merge_into_weight(&self, base_weight: &mut Tensor) -> Result<()> {
        if !self.active {
            return Ok(());
        }

        let delta_w = self.get_delta_weight()?;
        *base_weight = base_weight.add(&delta_w)?;
        Ok(())
    }

    /// Enable or disable this adapter
    pub fn set_active(&mut self, active: bool) {
        self.active = active;
    }

    /// Get the number of trainable parameters
    pub fn num_parameters(&self) -> usize {
        self.lora_a.len() + self.lora_b.len()
    }
}

/// LoRA-specific optimizer that only updates adapter parameters.
pub struct LoRAOptimizer {
    /// Base optimizer for LoRA parameters
    base_optimizer: Box<dyn OptimizerState>,
    /// LoRA adapters by module name
    adapters: HashMap<String, LoRAAdapter>,
    /// LoRA configuration
    config: LoRAConfig,
    /// Frozen base model parameters
    frozen_parameters: HashMap<String, Tensor>,
    /// Learning rate
    learning_rate: f32,
}

impl std::fmt::Debug for LoRAOptimizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoRAOptimizer")
            .field("adapters", &self.adapters)
            .field("config", &self.config)
            .field("frozen_parameters", &self.frozen_parameters)
            .field("learning_rate", &self.learning_rate)
            .finish()
    }
}

impl LoRAOptimizer {
    /// Creates a new LoRA optimizer.
    pub fn new(
        base_optimizer: Box<dyn OptimizerState>,
        config: LoRAConfig,
        learning_rate: f32,
    ) -> Self {
        Self {
            base_optimizer,
            adapters: HashMap::new(),
            config,
            frozen_parameters: HashMap::new(),
            learning_rate,
        }
    }

    /// Add a LoRA adapter for a specific module.
    pub fn add_adapter(
        &mut self,
        module_name: &str,
        input_dim: usize,
        output_dim: usize,
    ) -> Result<()> {
        let adapter = LoRAAdapter::new(input_dim, output_dim, self.config.rank, self.config.alpha)?;
        self.adapters.insert(module_name.to_string(), adapter);
        Ok(())
    }

    /// Remove a LoRA adapter.
    pub fn remove_adapter(&mut self, module_name: &str) -> Option<LoRAAdapter> {
        self.adapters.remove(module_name)
    }

    /// Get a reference to an adapter.
    pub fn get_adapter(&self, module_name: &str) -> Option<&LoRAAdapter> {
        self.adapters.get(module_name)
    }

    /// Get a mutable reference to an adapter.
    pub fn get_adapter_mut(&mut self, module_name: &str) -> Option<&mut LoRAAdapter> {
        self.adapters.get_mut(module_name)
    }

    /// Enable or disable an adapter.
    pub fn set_adapter_active(&mut self, module_name: &str, active: bool) -> Result<()> {
        if let Some(adapter) = self.adapters.get_mut(module_name) {
            adapter.set_active(active);
            Ok(())
        } else {
            Err(
                trustformers_core::errors::TrustformersError::tensor_op_error(
                    &format!("Adapter {} not found", module_name),
                    "set_adapter_active",
                ),
            )
        }
    }

    /// Enable or disable all adapters.
    pub fn set_all_adapters_active(&mut self, active: bool) {
        for adapter in self.adapters.values_mut() {
            adapter.set_active(active);
        }
    }

    /// Get the total number of trainable parameters.
    pub fn num_trainable_parameters(&self) -> usize {
        self.adapters.values().map(|a| a.num_parameters()).sum()
    }

    /// Freeze base model parameters.
    pub fn freeze_base_parameters(&mut self, parameters: HashMap<String, Tensor>) {
        self.frozen_parameters = parameters;
    }

    /// Merge all active adapters into their respective base weights.
    pub fn merge_adapters_into_base(&mut self) -> Result<()> {
        for (module_name, adapter) in &self.adapters {
            if adapter.active {
                if let Some(base_weight) = self.frozen_parameters.get_mut(module_name) {
                    adapter.merge_into_weight(base_weight)?;
                }
            }
        }
        Ok(())
    }

    /// Save adapter weights (for efficient storage/sharing).
    pub fn save_adapters(&self) -> HashMap<String, (Tensor, Tensor, f32)> {
        self.adapters
            .iter()
            .map(|(name, adapter)| {
                (
                    name.clone(),
                    (
                        adapter.lora_a.clone(),
                        adapter.lora_b.clone(),
                        adapter.scaling,
                    ),
                )
            })
            .collect()
    }

    /// Load adapter weights.
    pub fn load_adapters(
        &mut self,
        adapters: HashMap<String, (Tensor, Tensor, f32)>,
    ) -> Result<()> {
        for (name, (lora_a, lora_b, scaling)) in adapters {
            let adapter = LoRAAdapter {
                lora_a,
                lora_b,
                scaling,
                active: true,
            };
            self.adapters.insert(name, adapter);
        }
        Ok(())
    }

    /// Get configuration.
    pub fn get_config(&self) -> &LoRAConfig {
        &self.config
    }

    /// Extract trainable parameters (adapter parameters only).
    fn get_trainable_parameters(&self) -> Vec<Tensor> {
        let mut params = Vec::new();
        for adapter in self.adapters.values() {
            if adapter.active {
                params.push(adapter.lora_a.clone());
                params.push(adapter.lora_b.clone());
            }
        }
        params
    }

    /// Update adapter parameters from optimized tensors.
    ///
    /// Relies on `HashMap` iteration order matching `get_trainable_parameters`,
    /// which holds because the map is not modified between the two calls in `step`.
    fn update_adapters_from_parameters(&mut self, parameters: &[Tensor]) -> AnyhowResult<()> {
        let mut param_idx = 0;
        for adapter in self.adapters.values_mut() {
            if adapter.active {
                if param_idx + 1 >= parameters.len() {
                    return Err(anyhow!("Not enough parameters provided"));
                }
                adapter.lora_a = parameters[param_idx].clone();
                adapter.lora_b = parameters[param_idx + 1].clone();
                param_idx += 2;
            }
        }
        Ok(())
    }
}

impl OptimizerState for LoRAOptimizer {
    fn zero_grad(&mut self) -> AnyhowResult<()> {
        self.base_optimizer.zero_grad()
    }

    fn step(&mut self, _parameters: &mut [Tensor]) -> AnyhowResult<()> {
        // For LoRA, we only optimize adapter parameters, not the full model parameters
        let mut trainable_params = self.get_trainable_parameters();

        // Step the base optimizer on adapter parameters
        self.base_optimizer.step(&mut trainable_params)?;

        // Update our adapters with the optimized parameters
        self.update_adapters_from_parameters(&trainable_params)?;

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
        self.base_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> AnyhowResult<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        // Save adapter states
        for (name, adapter) in &self.adapters {
            state.insert(format!("adapter_{}_lora_a", name), adapter.lora_a.clone());
            state.insert(format!("adapter_{}_lora_b", name), adapter.lora_b.clone());
            state.insert(
                format!("adapter_{}_scaling", name),
                Tensor::scalar(adapter.scaling)?,
            );
            state.insert(
                format!("adapter_{}_active", name),
                Tensor::scalar(adapter.active as i32 as f32)?,
            );
        }

        // Save base optimizer state
        let base_state = self.base_optimizer.state_dict()?;
        for (key, value) in base_state {
            state.insert(format!("base_{}", key), value);
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> AnyhowResult<()> {
        let mut base_state = HashMap::new();
        let mut adapter_states: HashMap<
            String,
            (Option<Tensor>, Option<Tensor>, Option<f32>, Option<bool>),
        > = HashMap::new();

        // Separate adapter and base optimizer states.
        // NOTE: the `adapter_{name}_{field}` key format is parsed by splitting
        // on '_', so adapter (module) names must not themselves contain underscores.
        for (key, value) in state {
            if key.starts_with("adapter_") {
                let parts: Vec<&str> = key.split('_').collect();
                if parts.len() >= 3 {
                    let adapter_name = parts[1];
                    let field = parts[2..].join("_");

                    let entry = adapter_states
                        .entry(adapter_name.to_string())
                        .or_insert((None, None, None, None));

                    match field.as_str() {
                        "lora_a" => entry.0 = Some(value),
                        "lora_b" => entry.1 = Some(value),
                        "scaling" => entry.2 = Some(value.to_scalar()?),
                        "active" => entry.3 = Some(value.to_scalar()? > 0.5),
                        _ => {},
                    }
                }
            } else if let Some(stripped) = key.strip_prefix("base_") {
                base_state.insert(stripped.to_string(), value);
            }
        }

        // Reconstruct adapters
        for (name, (lora_a_opt, lora_b_opt, scaling_opt, active_opt)) in adapter_states {
            if let (Some(lora_a), Some(lora_b), Some(scaling), Some(active)) =
                (lora_a_opt, lora_b_opt, scaling_opt, active_opt)
            {
                let adapter = LoRAAdapter {
                    lora_a,
                    lora_b,
                    scaling,
                    active,
                };
                self.adapters.insert(name, adapter);
            }
        }

        // Load base optimizer state
        self.base_optimizer.load_state_dict(base_state)?;

        Ok(())
    }
}

/// Convenience function to create a LoRA-enabled Adam optimizer.
pub fn create_lora_adam(
    learning_rate: f32,
    config: LoRAConfig,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) -> LoRAOptimizer {
    let adam = Box::new(crate::sparse::SparseAdam::with_default_config(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        weight_decay,
    ));
    LoRAOptimizer::new(adam, config, learning_rate)
}

/// Convenience function to create a LoRA-enabled AdamW optimizer.
///
/// Note: currently constructs the same `SparseAdam`-backed optimizer as
/// [`create_lora_adam`].
pub fn create_lora_adamw(
    learning_rate: f32,
    config: LoRAConfig,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
) -> LoRAOptimizer {
    let adamw = Box::new(crate::sparse::SparseAdam::with_default_config(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        weight_decay,
    ));
    LoRAOptimizer::new(adamw, config, learning_rate)
}

/// Convenience function to create a LoRA-enabled SGD optimizer.
///
/// Note: momentum is provided by a `QHM` optimizer; the `dampening`,
/// `weight_decay`, and `nesterov` arguments are currently ignored.
pub fn create_lora_sgd(
    learning_rate: f32,
    config: LoRAConfig,
    momentum: f32,
    _dampening: f32,
    _weight_decay: f32,
    _nesterov: bool,
) -> LoRAOptimizer {
    let sgd = Box::new(crate::convergence::QHM::with_defaults(
        learning_rate,
        momentum,
        0.999,
    ));
    LoRAOptimizer::new(sgd, config, learning_rate)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config_default() {
        let config = LoRAConfig::default();
        assert_eq!(config.rank, 8);
        assert_eq!(config.alpha, 16.0);
        assert_eq!(config.dropout, 0.1);
        assert!(!config.bias);
        assert_eq!(config.target_modules.len(), 3);
        assert!(!config.merge_weights);
    }

    #[test]
    fn test_lora_adapter_creation() {
        let adapter = LoRAAdapter::new(512, 256, 8, 16.0).unwrap();

        assert_eq!(adapter.lora_a.shape(), &[512, 8]);
        assert_eq!(adapter.lora_b.shape(), &[8, 256]);
        assert_eq!(adapter.scaling, 2.0); // 16.0 / 8
        assert!(adapter.active);
    }

    #[test]
    fn test_lora_adapter_parameters() {
        let adapter = LoRAAdapter::new(512, 256, 8, 16.0).unwrap();
        let expected_params = 512 * 8 + 8 * 256; // A + B parameters
        assert_eq!(adapter.num_parameters(), expected_params);
    }
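
    // A sketch of additional shape checks for the adapter forward pass and the
    // merged delta weight; assumes `Tensor::randn`, `Tensor::zeros`, `matmul`,
    // and `shape()` behave as used elsewhere in this module.
    #[test]
    fn test_lora_adapter_forward_and_delta_shapes() {
        let adapter = LoRAAdapter::new(512, 256, 8, 16.0).unwrap();

        // Forward maps [batch, input_dim] -> [batch, output_dim].
        let input = Tensor::randn(&[2, 512]).unwrap();
        let output = adapter.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 256]);

        // The effective update has the (input_dim x output_dim) weight shape.
        let delta = adapter.get_delta_weight().unwrap();
        assert_eq!(delta.shape(), &[512, 256]);

        // Merging into a matching base weight succeeds and keeps the shape.
        let mut base = Tensor::zeros(&[512, 256]).unwrap();
        adapter.merge_into_weight(&mut base).unwrap();
        assert_eq!(base.shape(), &[512, 256]);
    }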

    #[test]
    fn test_lora_optimizer_creation() {
        let config = LoRAConfig::default();
        let optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.num_trainable_parameters(), 0); // No adapters added yet
    }

    #[test]
    fn test_adapter_management() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        // Add adapter
        optimizer.add_adapter("query", 512, 512).unwrap();
        assert_eq!(optimizer.num_trainable_parameters(), 512 * 8 + 8 * 512);

        // Check adapter exists
        assert!(optimizer.get_adapter("query").is_some());

        // Remove adapter
        let removed = optimizer.remove_adapter("query");
        assert!(removed.is_some());
        assert_eq!(optimizer.num_trainable_parameters(), 0);
    }
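
    // A sketch of an adapter save/load round trip between two optimizers,
    // using only APIs defined in this module.
    #[test]
    fn test_adapter_save_load_roundtrip() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config.clone(), 0.9, 0.999, 1e-8, 0.01);
        optimizer.add_adapter("query", 512, 256).unwrap();

        // Export adapter tensors and scaling factors.
        let saved = optimizer.save_adapters();
        assert_eq!(saved.len(), 1);

        // Import them into a fresh optimizer.
        let mut restored = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);
        restored.load_adapters(saved).unwrap();

        let adapter = restored.get_adapter("query").unwrap();
        assert_eq!(adapter.lora_a.shape(), &[512, 8]);
        assert_eq!(adapter.lora_b.shape(), &[8, 256]);
        assert_eq!(restored.num_trainable_parameters(), 512 * 8 + 8 * 256);
    }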

    #[test]
    fn test_adapter_activation() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);

        optimizer.add_adapter("query", 512, 512).unwrap();

        // Initially active
        assert!(optimizer.get_adapter("query").unwrap().active);

        // Deactivate
        optimizer.set_adapter_active("query", false).unwrap();
        assert!(!optimizer.get_adapter("query").unwrap().active);

        // Activate all
        optimizer.set_all_adapters_active(true);
        assert!(optimizer.get_adapter("query").unwrap().active);
    }
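
    // A sketch of an optimizer state_dict round trip. It assumes the backing
    // optimizer's state_dict/load_state_dict succeed before any step has been
    // taken; adapter names must not contain underscores for the key format
    // used by load_state_dict.
    #[test]
    fn test_state_dict_roundtrip() {
        let config = LoRAConfig::default();
        let mut optimizer = create_lora_adam(1e-3, config.clone(), 0.9, 0.999, 1e-8, 0.01);
        optimizer.add_adapter("query", 512, 256).unwrap();

        let state = optimizer.state_dict().unwrap();

        let mut restored = create_lora_adam(1e-3, config, 0.9, 0.999, 1e-8, 0.01);
        restored.load_state_dict(state).unwrap();

        let adapter = restored.get_adapter("query").unwrap();
        assert!(adapter.active);
        assert_eq!(adapter.scaling, 2.0); // alpha 16.0 / rank 8
        assert_eq!(adapter.lora_a.shape(), &[512, 8]);
        assert_eq!(adapter.lora_b.shape(), &[8, 256]);
    }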
}