tensorlogic_train/lora/layer.rs

//! Core LoRA layer: low-rank A/B decomposition of a frozen base weight.

use scirs2_core::random::{RngExt, SeedableRng, StdRng};
use std::f64::consts::PI;

use super::config::LoraConfig;
use super::error::{LoraError, LoraResult};

/// A single LoRA-augmented weight matrix.
///
/// Holds the frozen base weight `W` (d x k) and the low-rank factors
/// `B` (d x r, init zeros) and `A` (r x k, init Gaussian).
/// Forward: `output = input @ (W + scaling * B @ A)^T`.
pub struct LoraLayer {
    /// A matrix (r x k), initialised with random Gaussian N(0, 1/r).
    pub weight_a: Vec<Vec<f64>>,
    /// B matrix (d x r), initialised to zeros.
    pub weight_b: Vec<Vec<f64>>,
    /// Original frozen weight (d x k).
    pub base_weight: Vec<Vec<f64>>,
    /// Configuration.
    pub config: LoraConfig,
    /// Whether delta W has been merged into base_weight.
    pub merged: bool,
    /// RNG for dropout during forward pass.
    rng: StdRng,
}

// ---------------------------------------------------------------------------
// Private helpers
// ---------------------------------------------------------------------------

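/// Standard normal sample via the Box-Muller transform; `u1` is clamped
/// away from zero so `ln(u1)` stays finite.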
fn next_normal(rng: &mut StdRng) -> f64 {
    let u1 = rng.random::<f64>().max(f64::MIN_POSITIVE);
    let u2 = rng.random::<f64>();
    (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
}

fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
    if a.is_empty() || b.is_empty() {
        return Err(LoraError::DimensionMismatch {
            expected: "non-empty matrices".into(),
            got: format!("a rows={}, b rows={}", a.len(), b.len()),
        });
    }
    let a_cols = a[0].len();
    let b_rows = b.len();
    if a_cols != b_rows {
        return Err(LoraError::DimensionMismatch {
            expected: format!("a_cols ({a_cols}) == b_rows"),
            got: format!("{b_rows}"),
        });
    }
    let b_cols = b[0].len();
    let mut out = vec![vec![0.0; b_cols]; a.len()];
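    // i-k-j loop order: the innermost loop runs along a single row of `b`
    // and `out`, so each inner iteration streams over contiguous memory.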
    for i in 0..a.len() {
        for k in 0..a_cols {
            let a_ik = a[i][k];
            for j in 0..b_cols {
                out[i][j] += a_ik * b[k][j];
            }
        }
    }
    Ok(out)
}

fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if m.is_empty() {
        return Vec::new();
    }
    let rows = m.len();
    let cols = m[0].len();
    let mut t = vec![vec![0.0; rows]; cols];
    for i in 0..rows {
        for j in 0..cols {
            t[j][i] = m[i][j];
        }
    }
    t
}

fn add_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x + y).collect())
        .collect()
}

fn sub_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x - y).collect())
        .collect()
}

fn scale_matrix(s: f64, m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    m.iter()
        .map(|row| row.iter().map(|v| v * s).collect())
        .collect()
}

// ---------------------------------------------------------------------------
// LoraLayer impl
// ---------------------------------------------------------------------------

impl LoraLayer {
    /// Create a new LoRA layer wrapping `base_weight` (d x k).
    ///
    /// `weight_a` is initialised from N(0, 1/r) and `weight_b` from zeros,
    /// so the initial delta W is the zero matrix.
    pub fn new(base_weight: Vec<Vec<f64>>, config: LoraConfig) -> LoraResult<Self> {
        let d = base_weight.len();
        if d == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "d > 0".into(),
                got: "0".into(),
            });
        }
        let k = base_weight[0].len();
        if k == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "k > 0".into(),
                got: "0".into(),
            });
        }

        let rank = config.rank;
        if rank == 0 || rank > d.min(k) {
            return Err(LoraError::InvalidRank(rank));
        }

        let mut rng = StdRng::seed_from_u64(config.seed);
        let stddev = 1.0 / (rank as f64).sqrt();

        // A: r x k, Gaussian N(0, 1/r)
        let weight_a: Vec<Vec<f64>> = (0..rank)
            .map(|_| (0..k).map(|_| next_normal(&mut rng) * stddev).collect())
            .collect();

        // B: d x r, zeros
        let weight_b: Vec<Vec<f64>> = vec![vec![0.0; rank]; d];

        Ok(Self {
            weight_a,
            weight_b,
            base_weight,
            config,
            merged: false,
            rng,
        })
    }

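    /// LoRA scaling factor `alpha / r` applied to the update, as in the
    /// LoRA paper.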
    fn scaling(&self) -> f64 {
        self.config.alpha / self.config.rank as f64
    }

    /// Compute `B @ A` (d x k).
    fn delta_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        matmul(&self.weight_b, &self.weight_a)
    }

    /// Compute the effective weight `W + scaling * B @ A` without mutating state.
    /// When merged, returns `base_weight` (the delta is already folded in).
    pub fn effective_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        if self.merged {
            return Ok(self.base_weight.clone());
        }
        let dw = self.delta_weight()?;
        Ok(add_matrices(
            &self.base_weight,
            &scale_matrix(self.scaling(), &dw),
        ))
    }

    /// Forward pass: `output = input @ effective_weight^T`.
    ///
    /// `input` has shape `(n, k)` and the result has shape `(n, d)`.
    /// When not merged, applies optional dropout on the LoRA branch.
    pub fn forward(&mut self, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
        if input.is_empty() {
            return Ok(Vec::new());
        }
        let k = self.base_weight[0].len();
        if input[0].len() != k {
            return Err(LoraError::DimensionMismatch {
                expected: format!("input cols = {k}"),
                got: format!("{}", input[0].len()),
            });
        }

        if self.merged {
            // base_weight already contains the delta
            let wt = transpose(&self.base_weight);
            return matmul(input, &wt);
        }

        // base contribution: input @ W^T
        let wt = transpose(&self.base_weight);
        let base_out = matmul(input, &wt)?;

        // LoRA branch: input @ A^T @ B^T * scaling
        let at = transpose(&self.weight_a);
        let mut lora_hidden = matmul(input, &at)?;

        // Apply dropout on the hidden activations if p > 0
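        // (inverted dropout: survivors are scaled by 1/(1 - p) so the
        // branch keeps the same expected value)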
        if self.config.dropout > 0.0 && self.config.dropout < 1.0 {
            let inv_keep = 1.0 / (1.0 - self.config.dropout);
            for row in &mut lora_hidden {
                for v in row.iter_mut() {
                    if self.rng.random::<f64>() < self.config.dropout {
                        *v = 0.0;
                    } else {
                        *v *= inv_keep;
                    }
                }
            }
        }

        let bt = transpose(&self.weight_b);
        let lora_out = matmul(&lora_hidden, &bt)?;
        let scaled = scale_matrix(self.scaling(), &lora_out);
        Ok(add_matrices(&base_out, &scaled))
    }

    /// Merge `scaling * B @ A` into `base_weight`.
    pub fn merge(&mut self) -> LoraResult<()> {
        if self.merged {
            return Err(LoraError::MergeError("already merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = add_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = true;
        Ok(())
    }

    /// Remove `scaling * B @ A` from `base_weight`.
    pub fn unmerge(&mut self) -> LoraResult<()> {
        if !self.merged {
            return Err(LoraError::MergeError("not merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = sub_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = false;
        Ok(())
    }

    /// Number of trainable parameters: `r * (d + k)`.
    pub fn trainable_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        self.config.rank * (d + k)
    }

    /// Total parameter count: `d * k + r * (d + k)`.
    pub fn total_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        d * k + self.trainable_params()
    }

    /// Ratio of trainable to total parameters; lower means a larger saving.
    pub fn compression_ratio(&self) -> f64 {
        self.trainable_params() as f64 / self.total_params() as f64
    }
}
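
// ---------------------------------------------------------------------------
// Tests (sketch)
// ---------------------------------------------------------------------------

// A minimal test sketch of the invariants documented above: `weight_b` is
// zero-initialised, so delta W starts at zero, and merge/unmerge round-trips
// `base_weight`. The struct-literal `LoraConfig` below is an assumption: it
// uses only the fields this file reads (`rank`, `alpha`, `dropout`, `seed`);
// adjust it to the real constructor if the config type differs.
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> LoraConfig {
        // Hypothetical construction; see the note above.
        LoraConfig {
            rank: 2,
            alpha: 4.0,
            dropout: 0.0,
            seed: 42,
        }
    }

    #[test]
    fn zero_delta_at_init() {
        // B starts at zero, so the effective weight equals the base weight.
        let base = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let layer = LoraLayer::new(base.clone(), test_config()).unwrap();
        assert_eq!(layer.effective_weight().unwrap(), base);
    }

    #[test]
    fn merge_unmerge_round_trip() {
        let base = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let mut layer = LoraLayer::new(base.clone(), test_config()).unwrap();
        // Make B non-zero so merging actually changes base_weight.
        layer.weight_b[0][0] = 0.5;
        layer.merge().unwrap();
        assert!(layer.merged);
        layer.unmerge().unwrap();
        // Unmerge restores the original weights up to floating-point error.
        for (row, orig) in layer.base_weight.iter().zip(base.iter()) {
            for (v, o) in row.iter().zip(orig.iter()) {
                assert!((v - o).abs() < 1e-12);
            }
        }
    }
}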