tensorlogic_train/lora/adapter.rs

//! Multi-layer LoRA adapter managing named LoRA layers.
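//!
//! A [`LoraAdapter`] owns a set of [`LoraLayer`]s keyed by name (iteration
//! follows registration order via [`IndexMap`]), forwards inputs through
//! individual layers, merges or unmerges all layers at once, and reports
//! per-layer and aggregate parameter statistics.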

use indexmap::IndexMap;

use super::config::LoraConfig;
use super::error::{LoraError, LoraResult};
use super::layer::LoraLayer;

/// Per-layer statistics included in the adapter summary.
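///
/// `d` and `k` are the row and column counts of the layer's base weight
/// (the base matrix is `d x k`), `rank` is the configured LoRA rank, and
/// the parameter counts and `compression_ratio` are taken directly from the
/// underlying [`LoraLayer`].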
#[derive(Debug, Clone)]
pub struct LayerStats {
    pub name: String,
    pub d: usize,
    pub k: usize,
    pub rank: usize,
    pub trainable_params: usize,
    pub total_params: usize,
    pub compression_ratio: f64,
    pub merged: bool,
}

/// Summary of the entire LoRA adapter.
#[derive(Debug, Clone)]
pub struct LoraAdapterSummary {
    pub layers: Vec<LayerStats>,
    pub total_trainable: usize,
    pub total_params: usize,
}

/// Manages multiple named [`LoraLayer`]s that share a single [`LoraConfig`].
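///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because it assumes that
/// `LoraConfig` implements `Default`, which may not match the actual
/// config API):
///
/// ```ignore
/// let mut adapter = LoraAdapter::new(LoraConfig::default());
/// // Register a 4 x 4 base weight under the name "q_proj".
/// adapter.add_layer("q_proj", vec![vec![0.0; 4]; 4])?;
/// // Forward a single-row input through the adapted layer.
/// let output = adapter.forward("q_proj", &[vec![1.0; 4]])?;
/// let summary = adapter.summary();
/// println!("trainable params: {}", summary.total_trainable);
/// ```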
pub struct LoraAdapter {
    config: LoraConfig,
    layers: IndexMap<String, LoraLayer>,
}

impl LoraAdapter {
    pub fn new(config: LoraConfig) -> Self {
        Self {
            config,
            layers: IndexMap::new(),
        }
    }

    /// Wrap `base_weight` in a new LoRA layer registered under `name`.
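    ///
    /// Re-adding an existing `name` replaces the previously registered layer.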
    pub fn add_layer(&mut self, name: &str, base_weight: Vec<Vec<f64>>) -> LoraResult<()> {
        let layer = LoraLayer::new(base_weight, self.config.clone())?;
        self.layers.insert(name.to_string(), layer);
        Ok(())
    }

    /// Forward pass through the named layer.
    pub fn forward(&mut self, name: &str, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
        // A missing layer name is reported as a `DimensionMismatch` with a
        // descriptive message rather than a dedicated "not found" error.
        let layer = self
            .layers
            .get_mut(name)
            .ok_or_else(|| LoraError::DimensionMismatch {
                expected: format!("layer '{name}' exists"),
                got: "not found".into(),
            })?;
        layer.forward(input)
    }

    /// Merge all layers.
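    ///
    /// In the standard LoRA scheme, merging folds the low-rank update into
    /// the base weight so inference needs no separate low-rank pass;
    /// [`LoraAdapter::unmerge_all`] reverses it. Layers already marked
    /// `merged` are skipped.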
    pub fn merge_all(&mut self) -> LoraResult<()> {
        for layer in self.layers.values_mut() {
            if !layer.merged {
                layer.merge()?;
            }
        }
        Ok(())
    }

    /// Unmerge all layers.
    pub fn unmerge_all(&mut self) -> LoraResult<()> {
        for layer in self.layers.values_mut() {
            if layer.merged {
                layer.unmerge()?;
            }
        }
        Ok(())
    }

    /// Sum of trainable params across all layers.
    pub fn total_trainable_params(&self) -> usize {
        self.layers.values().map(|l| l.trainable_params()).sum()
    }

    /// Build a summary with per-layer statistics.
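    ///
    /// `total_trainable` and `total_params` are the sums of the per-layer
    /// fields, so `total_trainable` matches
    /// [`LoraAdapter::total_trainable_params`].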
    pub fn summary(&self) -> LoraAdapterSummary {
        let mut layers = Vec::with_capacity(self.layers.len());
        for (name, layer) in &self.layers {
            let d = layer.base_weight.len();
            // Avoid indexing into an empty base weight; report k = 0 instead
            // of panicking if a layer somehow has no rows.
            let k = layer.base_weight.first().map_or(0, |row| row.len());
            layers.push(LayerStats {
                name: name.clone(),
                d,
                k,
                rank: layer.config.rank,
                trainable_params: layer.trainable_params(),
                total_params: layer.total_params(),
                compression_ratio: layer.compression_ratio(),
                merged: layer.merged,
            });
        }
        let total_trainable = layers.iter().map(|s| s.trainable_params).sum();
        let total_params = layers.iter().map(|s| s.total_params).sum();
        LoraAdapterSummary {
            layers,
            total_trainable,
            total_params,
        }
    }
}