tensorlogic_train/lora/
adapter.rs1use indexmap::IndexMap;
4
5use super::config::LoraConfig;
6use super::error::{LoraError, LoraResult};
7use super::layer::LoraLayer;
8
/// Per-layer statistics reported by `LoraAdapter::summary`.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerStats {
    /// Name the layer was registered under.
    pub name: String,
    /// Number of rows in the layer's base weight.
    pub d: usize,
    /// Length of the first row of the layer's base weight.
    pub k: usize,
    /// LoRA rank taken from the layer's configuration.
    pub rank: usize,
    /// Value reported by the layer's `trainable_params()`.
    pub trainable_params: usize,
    /// Value reported by the layer's `total_params()`.
    pub total_params: usize,
    /// Value reported by the layer's `compression_ratio()`.
    pub compression_ratio: f64,
    /// Whether the layer was in its merged state when the summary was taken.
    pub merged: bool,
}

/// Adapter-wide report: one [`LayerStats`] entry per layer plus summed totals.
#[derive(Debug, Clone, PartialEq)]
pub struct LoraAdapterSummary {
    /// Per-layer statistics.
    pub layers: Vec<LayerStats>,
    /// Sum of `trainable_params` over `layers`.
    pub total_trainable: usize,
    /// Sum of `total_params` over `layers`.
    pub total_params: usize,
}
29
30pub struct LoraAdapter {
32 config: LoraConfig,
33 layers: IndexMap<String, LoraLayer>,
34}
35
36impl LoraAdapter {
37 pub fn new(config: LoraConfig) -> Self {
38 Self {
39 config,
40 layers: IndexMap::new(),
41 }
42 }
43
44 pub fn add_layer(&mut self, name: &str, base_weight: Vec<Vec<f64>>) -> LoraResult<()> {
46 let layer = LoraLayer::new(base_weight, self.config.clone())?;
47 self.layers.insert(name.to_string(), layer);
48 Ok(())
49 }
50
51 pub fn forward(&mut self, name: &str, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
53 let layer = self
54 .layers
55 .get_mut(name)
56 .ok_or_else(|| LoraError::DimensionMismatch {
57 expected: format!("layer '{name}' exists"),
58 got: "not found".into(),
59 })?;
60 layer.forward(input)
61 }
62
63 pub fn merge_all(&mut self) -> LoraResult<()> {
65 for layer in self.layers.values_mut() {
66 if !layer.merged {
67 layer.merge()?;
68 }
69 }
70 Ok(())
71 }
72
73 pub fn unmerge_all(&mut self) -> LoraResult<()> {
75 for layer in self.layers.values_mut() {
76 if layer.merged {
77 layer.unmerge()?;
78 }
79 }
80 Ok(())
81 }
82
83 pub fn total_trainable_params(&self) -> usize {
85 self.layers.values().map(|l| l.trainable_params()).sum()
86 }
87
88 pub fn summary(&self) -> LoraAdapterSummary {
90 let mut layers = Vec::with_capacity(self.layers.len());
91 for (name, layer) in &self.layers {
92 let d = layer.base_weight.len();
93 let k = layer.base_weight[0].len();
94 layers.push(LayerStats {
95 name: name.clone(),
96 d,
97 k,
98 rank: layer.config.rank,
99 trainable_params: layer.trainable_params(),
100 total_params: layer.total_params(),
101 compression_ratio: layer.compression_ratio(),
102 merged: layer.merged,
103 });
104 }
105 let total_trainable = layers.iter().map(|s| s.trainable_params).sum();
106 let total_params = layers.iter().map(|s| s.total_params).sum();
107 LoraAdapterSummary {
108 layers,
109 total_trainable,
110 total_params,
111 }
112 }
113}