//! yscv_model/lora.rs — LoRA (Low-Rank Adaptation) adapters for linear layers.
use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use crate::ModelError;
5
/// LoRA configuration.
///
/// `alpha / rank` becomes the scaling applied to the low-rank update.
#[derive(Debug, Clone)]
pub struct LoraConfig {
    /// Rank of the low-rank matrices (default 4).
    pub rank: usize,
    /// Alpha scaling factor (default 1.0).
    pub alpha: f32,
}

impl Default for LoraConfig {
    /// Conventional defaults: rank 4, alpha 1.0.
    fn default() -> Self {
        Self { rank: 4, alpha: 1.0 }
    }
}
23
24/// A LoRA adapter for a linear layer.
25///
26/// Wraps a frozen weight matrix with trainable low-rank A and B matrices.
27/// The effective weight is `W + A @ B * scaling` where `scaling = alpha / rank`.
28///
29/// In this crate, weight is stored as `[in_features, out_features]` (matching
30/// `LinearLayer`), so:
31///   - `lora_a` has shape `[in_features, rank]`
32///   - `lora_b` has shape `[rank, out_features]`
33///   - Forward: `y = x @ W + x @ A @ B * scaling + bias`
34#[derive(Debug, Clone)]
35pub struct LoraLinear {
36    /// Original frozen weight `[in_features, out_features]`.
37    pub frozen_weight: NodeId,
38    /// Low-rank matrix A `[in_features, rank]` — initialized with small random values.
39    pub lora_a: NodeId,
40    /// Low-rank matrix B `[rank, out_features]` — initialized to zeros.
41    pub lora_b: NodeId,
42    /// Optional bias (frozen).
43    pub bias: Option<NodeId>,
44    /// Scaling factor = alpha / rank.
45    pub scaling: f32,
46    pub in_features: usize,
47    pub out_features: usize,
48    pub rank: usize,
49}
50
51impl LoraLinear {
52    /// Creates a LoRA adapter from dimensions.
53    ///
54    /// Initializes `frozen_weight` and `bias` as graph constants (`requires_grad=false`),
55    /// `lora_a` with small Gaussian-like values (`requires_grad=true`),
56    /// `lora_b` with zeros (`requires_grad=true`).
57    pub fn new(
58        graph: &mut Graph,
59        in_features: usize,
60        out_features: usize,
61        config: &LoraConfig,
62    ) -> Result<Self, ModelError> {
63        let rank = config.rank;
64        let scaling = config.alpha / rank as f32;
65
66        // Frozen weight initialized to zeros (would normally be loaded from a pretrained model).
67        let frozen_weight_tensor = Tensor::zeros(vec![in_features, out_features])?;
68        let frozen_weight = graph.constant(frozen_weight_tensor);
69
70        // lora_a: small values using a simple deterministic initialization
71        // (Kaiming-like: stddev = 1/sqrt(in_features))
72        let stddev = 1.0 / (in_features as f32).sqrt();
73        let a_len = in_features * rank;
74        let a_data: Vec<f32> = (0..a_len)
75            .map(|i| {
76                // Simple deterministic pseudo-random using a hash-like function
77                let x = ((i as f32 + 1.0) * 0.618_034).fract() * 2.0 - 1.0;
78                x * stddev
79            })
80            .collect();
81        let lora_a_tensor = Tensor::from_vec(vec![in_features, rank], a_data)?;
82        let lora_a = graph.variable(lora_a_tensor);
83
84        // lora_b: initialized to zeros so initial LoRA contribution is zero
85        let lora_b_tensor = Tensor::zeros(vec![rank, out_features])?;
86        let lora_b = graph.variable(lora_b_tensor);
87
88        Ok(Self {
89            frozen_weight,
90            lora_a,
91            lora_b,
92            bias: None,
93            scaling,
94            in_features,
95            out_features,
96            rank,
97        })
98    }
99
100    /// Creates a LoRA adapter from an existing `LinearLayer`'s weights.
101    ///
102    /// The original weight and bias are frozen (stored as constants).
103    /// New trainable `lora_a` and `lora_b` matrices are created.
104    pub fn from_linear(
105        graph: &mut Graph,
106        weight_node: NodeId,
107        bias_node: NodeId,
108        in_features: usize,
109        out_features: usize,
110        config: &LoraConfig,
111    ) -> Result<Self, ModelError> {
112        let rank = config.rank;
113        let scaling = config.alpha / rank as f32;
114
115        // Freeze the original weight: copy tensor into a constant node.
116        let weight_tensor = graph.value(weight_node)?.clone();
117        let frozen_weight = graph.constant(weight_tensor);
118
119        // Freeze the original bias: copy tensor into a constant node.
120        let bias_tensor = graph.value(bias_node)?.clone();
121        let frozen_bias = graph.constant(bias_tensor);
122
123        // lora_a: small values (Kaiming-like)
124        let stddev = 1.0 / (in_features as f32).sqrt();
125        let a_len = in_features * rank;
126        let a_data: Vec<f32> = (0..a_len)
127            .map(|i| {
128                let x = ((i as f32 + 1.0) * 0.618_034).fract() * 2.0 - 1.0;
129                x * stddev
130            })
131            .collect();
132        let lora_a_tensor = Tensor::from_vec(vec![in_features, rank], a_data)?;
133        let lora_a = graph.variable(lora_a_tensor);
134
135        // lora_b: zeros so initial LoRA contribution is zero
136        let lora_b_tensor = Tensor::zeros(vec![rank, out_features])?;
137        let lora_b = graph.variable(lora_b_tensor);
138
139        Ok(Self {
140            frozen_weight,
141            lora_a,
142            lora_b,
143            bias: Some(frozen_bias),
144            scaling,
145            in_features,
146            out_features,
147            rank,
148        })
149    }
150
151    /// Creates a LoRA adapter with a frozen bias term.
152    pub fn with_bias(mut self, graph: &mut Graph, bias_tensor: Tensor) -> Result<Self, ModelError> {
153        if bias_tensor.shape() != [self.out_features] {
154            return Err(ModelError::InvalidParameterShape {
155                parameter: "bias",
156                expected: vec![self.out_features],
157                got: bias_tensor.shape().to_vec(),
158            });
159        }
160        self.bias = Some(graph.constant(bias_tensor));
161        Ok(self)
162    }
163
164    /// Forward pass: `y = x @ W + x @ A @ B * scaling + bias`.
165    ///
166    /// Gradients flow through A and B but not through the frozen weight.
167    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
168        let input_shape = graph.value(input)?.shape().to_vec();
169        if input_shape.len() != 2 || input_shape[1] != self.in_features {
170            return Err(ModelError::InvalidInputShape {
171                expected_features: self.in_features,
172                got: input_shape,
173            });
174        }
175
176        // Frozen path: x @ W  [batch, in] @ [in, out] -> [batch, out]
177        let frozen_out = graph.matmul_2d(input, self.frozen_weight)?;
178
179        // LoRA path: x @ A @ B * scaling
180        // x @ A: [batch, in] @ [in, rank] -> [batch, rank]
181        let lora_mid = graph.matmul_2d(input, self.lora_a)?;
182        // (x @ A) @ B: [batch, rank] @ [rank, out] -> [batch, out]
183        let lora_out = graph.matmul_2d(lora_mid, self.lora_b)?;
184
185        // Scale by alpha / rank
186        let scale_node = graph.constant(Tensor::scalar(self.scaling));
187        let lora_scaled = graph.mul(lora_out, scale_node)?;
188
189        // Combine frozen + lora
190        let mut output = graph.add(frozen_out, lora_scaled)?;
191
192        // Add bias if present
193        if let Some(bias) = self.bias {
194            output = graph.add(output, bias)?;
195        }
196
197        Ok(output)
198    }
199
200    /// Returns the trainable parameter `NodeId`s (only `lora_a` and `lora_b`).
201    pub fn trainable_params(&self) -> Vec<NodeId> {
202        vec![self.lora_a, self.lora_b]
203    }
204
205    /// Merges LoRA weights into the frozen weight, returning the effective weight tensor.
206    ///
207    /// `W_eff = W + A @ B * scaling`
208    ///
209    /// Useful for inference after training to avoid the LoRA overhead.
210    pub fn merge(&self, graph: &Graph) -> Result<Tensor, ModelError> {
211        let w = graph.value(self.frozen_weight)?;
212        let a = graph.value(self.lora_a)?;
213        let b = graph.value(self.lora_b)?;
214
215        // A @ B: [in, rank] @ [rank, out] -> [in, out]
216        let ab = yscv_kernels::matmul_2d(a, b)?;
217
218        // Scale and add to frozen weight
219        let w_data = w.data();
220        let ab_data = ab.data();
221        let merged_data: Vec<f32> = w_data
222            .iter()
223            .zip(ab_data.iter())
224            .map(|(&wi, &abi)| wi + abi * self.scaling)
225            .collect();
226
227        Ok(Tensor::from_vec(w.shape().to_vec(), merged_data)?)
228    }
229}