tensorlogic_train/lora/layer.rs

use scirs2_core::random::{RngExt, SeedableRng, StdRng};
use std::f64::consts::PI;

use super::config::LoraConfig;
use super::error::{LoraError, LoraResult};

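/// A LoRA-adapted linear layer: a frozen base weight plus a trainable
/// low-rank update `B * A`, scaled by `alpha / rank`.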
pub struct LoraLayer {
    /// Low-rank factor A with shape (rank x k), drawn from a scaled normal distribution.
    pub weight_a: Vec<Vec<f64>>,
    /// Low-rank factor B with shape (d x rank), initialized to zero.
    pub weight_b: Vec<Vec<f64>>,
    /// Frozen base weight W with shape (d x k).
    pub base_weight: Vec<Vec<f64>>,
    /// LoRA hyperparameters (rank, alpha, dropout, seed).
    pub config: LoraConfig,
    /// True once the scaled low-rank update has been folded into `base_weight`.
    pub merged: bool,
    /// RNG used for initialization and dropout masks.
    rng: StdRng,
}

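/// Draws a sample from the standard normal distribution using the Box-Muller
/// transform; `u1` is clamped away from zero so `ln` stays finite.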
fn next_normal(rng: &mut StdRng) -> f64 {
    let u1 = rng.random::<f64>().max(f64::MIN_POSITIVE);
    let u2 = rng.random::<f64>();
    (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
}

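/// Naive O(n^3) matrix multiply with shape validation on the inner dimension.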
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
    if a.is_empty() || b.is_empty() {
        return Err(LoraError::DimensionMismatch {
            expected: "non-empty matrices".into(),
            got: format!("a rows={}, b rows={}", a.len(), b.len()),
        });
    }
    let a_cols = a[0].len();
    let b_rows = b.len();
    if a_cols != b_rows {
        return Err(LoraError::DimensionMismatch {
            expected: format!("a_cols ({a_cols}) == b_rows"),
            got: format!("{b_rows}"),
        });
    }
    let b_cols = b[0].len();
    let mut out = vec![vec![0.0; b_cols]; a.len()];
    for i in 0..a.len() {
        for k in 0..a_cols {
            let a_ik = a[i][k];
            for j in 0..b_cols {
                out[i][j] += a_ik * b[k][j];
            }
        }
    }
    Ok(out)
}

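/// Returns the transpose of a row-major matrix; an empty input yields an empty matrix.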
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if m.is_empty() {
        return Vec::new();
    }
    let rows = m.len();
    let cols = m[0].len();
    let mut t = vec![vec![0.0; rows]; cols];
    for i in 0..rows {
        for j in 0..cols {
            t[j][i] = m[i][j];
        }
    }
    t
}

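/// Element-wise sum of two matrices of the same shape.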
fn add_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x + y).collect())
        .collect()
}

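/// Element-wise difference `a - b` of two matrices of the same shape.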
fn sub_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x - y).collect())
        .collect()
}

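/// Multiplies every element of a matrix by the scalar `s`.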
fn scale_matrix(s: f64, m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    m.iter()
        .map(|row| row.iter().map(|v| v * s).collect())
        .collect()
}

impl LoraLayer {
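    /// Builds a LoRA layer around `base_weight` (shape d x k): `A` is rank x k
    /// with entries drawn from N(0, 1/rank), and `B` is d x rank zeros, so the
    /// initial update `B * A` is zero and the layer starts equal to the base.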
    pub fn new(base_weight: Vec<Vec<f64>>, config: LoraConfig) -> LoraResult<Self> {
        let d = base_weight.len();
        if d == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "d > 0".into(),
                got: "0".into(),
            });
        }
        let k = base_weight[0].len();
        if k == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "k > 0".into(),
                got: "0".into(),
            });
        }

        let rank = config.rank;
        if rank == 0 || rank > d.min(k) {
            return Err(LoraError::InvalidRank(rank));
        }

        let mut rng = StdRng::seed_from_u64(config.seed);
        let stddev = 1.0 / (rank as f64).sqrt();

        // A: rank x k, scaled normal initialization.
        let weight_a: Vec<Vec<f64>> = (0..rank)
            .map(|_| (0..k).map(|_| next_normal(&mut rng) * stddev).collect())
            .collect();

        // B: d x rank, all zeros, so B * A starts at zero.
        let weight_b: Vec<Vec<f64>> = vec![vec![0.0; rank]; d];

        Ok(Self {
            weight_a,
            weight_b,
            base_weight,
            config,
            merged: false,
            rng,
        })
    }

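    /// LoRA scaling factor `alpha / rank` applied to the low-rank update.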
    fn scaling(&self) -> f64 {
        self.config.alpha / self.config.rank as f64
    }

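    /// Unscaled low-rank update `delta W = B * A` (shape d x k).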
    fn delta_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        matmul(&self.weight_b, &self.weight_a)
    }

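    /// Effective weight seen by callers: `W + (alpha / rank) * B * A`, or just the
    /// stored base weight when the update has already been merged into it.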
    pub fn effective_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        if self.merged {
            return Ok(self.base_weight.clone());
        }
        let dw = self.delta_weight()?;
        Ok(add_matrices(
            &self.base_weight,
            &scale_matrix(self.scaling(), &dw),
        ))
    }

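    /// Forward pass for a batch of row vectors (shape n x k): computes
    /// `x W^T + (alpha / rank) * dropout(x A^T) B^T`, or just `x W^T` when merged.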
    pub fn forward(&mut self, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
        if input.is_empty() {
            return Ok(Vec::new());
        }
        let k = self.base_weight[0].len();
        if input[0].len() != k {
            return Err(LoraError::DimensionMismatch {
                expected: format!("input cols = {k}"),
                got: format!("{}", input[0].len()),
            });
        }

        if self.merged {
            // The update is already folded into the base weight.
            let wt = transpose(&self.base_weight);
            return matmul(input, &wt);
        }

        // Base path: x W^T.
        let wt = transpose(&self.base_weight);
        let base_out = matmul(input, &wt)?;

        // LoRA path: project down with A, giving an n x rank hidden activation.
        let at = transpose(&self.weight_a);
        let mut lora_hidden = matmul(input, &at)?;

        // Inverted dropout on the low-rank activations.
        if self.config.dropout > 0.0 && self.config.dropout < 1.0 {
            let inv_keep = 1.0 / (1.0 - self.config.dropout);
            for row in &mut lora_hidden {
                for v in row.iter_mut() {
                    if self.rng.random::<f64>() < self.config.dropout {
                        *v = 0.0;
                    } else {
                        *v *= inv_keep;
                    }
                }
            }
        }

        // Project back up with B and add the scaled update to the base output.
        let bt = transpose(&self.weight_b);
        let lora_out = matmul(&lora_hidden, &bt)?;
        let scaled = scale_matrix(self.scaling(), &lora_out);
        Ok(add_matrices(&base_out, &scaled))
    }

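    /// Folds the scaled update into `base_weight` so inference uses a single matmul.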
    pub fn merge(&mut self) -> LoraResult<()> {
        if self.merged {
            return Err(LoraError::MergeError("already merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = add_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = true;
        Ok(())
    }

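    /// Reverses `merge` by subtracting the scaled update from `base_weight`.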
    pub fn unmerge(&mut self) -> LoraResult<()> {
        if !self.merged {
            return Err(LoraError::MergeError("not merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = sub_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = false;
        Ok(())
    }

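    /// Number of trainable LoRA parameters: rank * (d + k) for A and B combined.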
    pub fn trainable_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        self.config.rank * (d + k)
    }

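    /// Total parameter count: the d * k base weight plus the LoRA factors.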
    pub fn total_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        d * k + self.trainable_params()
    }

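    /// Fraction of all parameters that are trainable (LoRA params / total params).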
    pub fn compression_ratio(&self) -> f64 {
        self.trainable_params() as f64 / self.total_params() as f64
    }
}
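
// A minimal usage sketch, not part of the original module: it checks that the
// merged and unmerged forward paths agree when dropout is off. It assumes
// `LoraConfig` is a plain struct with exactly the public fields `rank`,
// `alpha`, `dropout`, and `seed` used above; adjust the construction if the
// real config type differs.
#[cfg(test)]
mod sketch_tests {
    use super::*;

    #[test]
    fn merged_and_unmerged_forward_agree_without_dropout() {
        let base = vec![vec![0.5, -0.25], vec![1.0, 0.75]];
        // Assumed field-by-field construction of LoraConfig.
        let config = LoraConfig {
            rank: 1,
            alpha: 2.0,
            dropout: 0.0,
            seed: 42,
        };
        let mut layer = LoraLayer::new(base, config).unwrap();
        // Nudge B so the low-rank update is non-zero before merging.
        layer.weight_b[0][0] = 0.1;

        let input = vec![vec![1.0, 2.0]];
        let before = layer.forward(&input).unwrap();
        layer.merge().unwrap();
        let after = layer.forward(&input).unwrap();
        for (x, y) in before[0].iter().zip(after[0].iter()) {
            assert!((x - y).abs() < 1e-9);
        }
    }
}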