use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use crate::ModelError;

/// Hyperparameters for a LoRA (low-rank adaptation) layer.
#[derive(Debug, Clone)]
pub struct LoraConfig {
    /// Rank of the low-rank decomposition; must be at least 1.
    pub rank: usize,
    /// Scaling numerator; the LoRA update is scaled by `alpha / rank`.
    pub alpha: f32,
}

impl Default for LoraConfig {
    fn default() -> Self {
        Self {
            rank: 4,
            alpha: 1.0,
        }
    }
}

/// A linear layer augmented with a trainable low-rank (LoRA) update.
///
/// The frozen base weight `W` (`[in_features, out_features]`) stays fixed
/// during training; only the low-rank factors `A` (`[in_features, rank]`)
/// and `B` (`[rank, out_features]`) receive gradients. The forward pass
/// computes `y = x·W + (alpha / rank) · (x·A)·B [+ bias]`.
#[derive(Debug, Clone)]
pub struct LoraLinear {
    /// Frozen base weight `W`, registered as a graph constant.
    pub frozen_weight: NodeId,
    /// Trainable LoRA factor `A`, shape `[in_features, rank]`.
    pub lora_a: NodeId,
    /// Trainable LoRA factor `B`, shape `[rank, out_features]`.
    pub lora_b: NodeId,
    /// Optional frozen bias, shape `[out_features]`.
    pub bias: Option<NodeId>,
    /// Precomputed scaling factor `alpha / rank`.
    pub scaling: f32,
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}

impl LoraLinear {
    /// Creates a fresh LoRA layer whose frozen base weight is all zeros.
    ///
    /// Uses the standard LoRA initialization: `A` gets small deterministic
    /// pseudo-random values and `B` starts at zero, so the adapter
    /// contributes nothing until training updates `B`. To wrap an existing
    /// linear layer instead, use [`LoraLinear::from_linear`].
    ///
    /// `config.rank` must be at least 1; a rank of 0 would make the
    /// `alpha / rank` scaling non-finite.
    pub fn new(
        graph: &mut Graph,
        in_features: usize,
        out_features: usize,
        config: &LoraConfig,
    ) -> Result<Self, ModelError> {
        let rank = config.rank;
        let scaling = config.alpha / rank as f32;

        // The base weight is frozen: register it as a constant so it
        // receives no gradients.
        let frozen_weight_tensor = Tensor::zeros(vec![in_features, out_features])?;
        let frozen_weight = graph.constant(frozen_weight_tensor);

        // Deterministic pseudo-random init for `A`, scaled by 1/sqrt(fan_in).
        // The golden-ratio multiplier spreads values evenly in [-1, 1]
        // without pulling in an RNG dependency.
        let stddev = 1.0 / (in_features as f32).sqrt();
        let a_len = in_features * rank;
        let a_data: Vec<f32> = (0..a_len)
            .map(|i| {
                let x = ((i as f32 + 1.0) * 0.618_034).fract() * 2.0 - 1.0;
                x * stddev
            })
            .collect();
        let lora_a_tensor = Tensor::from_vec(vec![in_features, rank], a_data)?;
        let lora_a = graph.variable(lora_a_tensor);

        // `B` starts at zero so the initial update `A·B` is exactly zero.
        let lora_b_tensor = Tensor::zeros(vec![rank, out_features])?;
        let lora_b = graph.variable(lora_b_tensor);

        Ok(Self {
            frozen_weight,
            lora_a,
            lora_b,
            bias: None,
            scaling,
            in_features,
            out_features,
            rank,
        })
    }

    /// Wraps an existing linear layer's weight and bias in a LoRA adapter.
    ///
    /// The current values of `weight_node` and `bias_node` are snapshotted
    /// and re-registered as frozen constants; only the fresh LoRA factors
    /// `A` and `B` remain trainable.
    pub fn from_linear(
        graph: &mut Graph,
        weight_node: NodeId,
        bias_node: NodeId,
        in_features: usize,
        out_features: usize,
        config: &LoraConfig,
    ) -> Result<Self, ModelError> {
        let rank = config.rank;
        let scaling = config.alpha / rank as f32;

        // Snapshot the existing weight and freeze it as a constant.
        let weight_tensor = graph.value(weight_node)?.clone();
        let frozen_weight = graph.constant(weight_tensor);

        // Likewise freeze the existing bias.
        let bias_tensor = graph.value(bias_node)?.clone();
        let frozen_bias = graph.constant(bias_tensor);

        // Same deterministic init as in `new`: pseudo-random `A`, zero `B`.
        let stddev = 1.0 / (in_features as f32).sqrt();
        let a_len = in_features * rank;
        let a_data: Vec<f32> = (0..a_len)
            .map(|i| {
                let x = ((i as f32 + 1.0) * 0.618_034).fract() * 2.0 - 1.0;
                x * stddev
            })
            .collect();
        let lora_a_tensor = Tensor::from_vec(vec![in_features, rank], a_data)?;
        let lora_a = graph.variable(lora_a_tensor);

        let lora_b_tensor = Tensor::zeros(vec![rank, out_features])?;
        let lora_b = graph.variable(lora_b_tensor);

        Ok(Self {
            frozen_weight,
            lora_a,
            lora_b,
            bias: Some(frozen_bias),
            scaling,
            in_features,
            out_features,
            rank,
        })
    }

    /// Attaches a frozen bias, validating that its shape is `[out_features]`.
    pub fn with_bias(mut self, graph: &mut Graph, bias_tensor: Tensor) -> Result<Self, ModelError> {
        if bias_tensor.shape() != [self.out_features] {
            return Err(ModelError::InvalidParameterShape {
                parameter: "bias",
                expected: vec![self.out_features],
                got: bias_tensor.shape().to_vec(),
            });
        }
        self.bias = Some(graph.constant(bias_tensor));
        Ok(self)
    }

    /// Runs the forward pass: `y = x·W + scaling · (x·A)·B [+ bias]`.
    ///
    /// Expects a 2-D input of shape `[batch, in_features]` and returns a
    /// node of shape `[batch, out_features]`.
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let input_shape = graph.value(input)?.shape().to_vec();
        if input_shape.len() != 2 || input_shape[1] != self.in_features {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.in_features,
                got: input_shape,
            });
        }

        // Frozen path: x·W.
        let frozen_out = graph.matmul_2d(input, self.frozen_weight)?;

        // LoRA path: (x·A)·B, going through the rank-r bottleneck first so
        // the intermediate stays small at [batch, rank].
        let lora_mid = graph.matmul_2d(input, self.lora_a)?;
        let lora_out = graph.matmul_2d(lora_mid, self.lora_b)?;

        // Scale the LoRA update by alpha / rank.
        let scale_node = graph.constant(Tensor::scalar(self.scaling));
        let lora_scaled = graph.mul(lora_out, scale_node)?;

        // Combine the frozen and LoRA paths.
        let mut output = graph.add(frozen_out, lora_scaled)?;

        // Add the frozen bias, if present.
        if let Some(bias) = self.bias {
            output = graph.add(output, bias)?;
        }

        Ok(output)
    }

    /// Returns the trainable parameters: only the LoRA factors `A` and `B`.
    /// The frozen weight and bias are deliberately excluded.
    pub fn trainable_params(&self) -> Vec<NodeId> {
        vec![self.lora_a, self.lora_b]
    }

    /// Folds the LoRA update into the base weight, returning the plain
    /// tensor `W + scaling · A·B`.
    ///
    /// The merged tensor can serve as the weight of an ordinary linear
    /// layer, removing the extra LoRA matmuls at inference time.
    pub fn merge(&self, graph: &Graph) -> Result<Tensor, ModelError> {
        let w = graph.value(self.frozen_weight)?;
        let a = graph.value(self.lora_a)?;
        let b = graph.value(self.lora_b)?;

        // Materialize the low-rank product A·B ([in_features, out_features]).
        let ab = yscv_kernels::matmul_2d(a, b)?;

        // Element-wise: merged = W + scaling * (A·B).
        let w_data = w.data();
        let ab_data = ab.data();
        let merged_data: Vec<f32> = w_data
            .iter()
            .zip(ab_data.iter())
            .map(|(&wi, &abi)| wi + abi * self.scaling)
            .collect();

        Ok(Tensor::from_vec(w.shape().to_vec(), merged_data)?)
    }
}
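
// A minimal smoke-test sketch of intended usage. `Graph::new()` is an
// assumed constructor name for `yscv_autograd::Graph`; adjust to the
// crate's actual API if it differs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_produces_batch_by_out_features() -> Result<(), ModelError> {
        let mut graph = Graph::new(); // assumed constructor
        let layer = LoraLinear::new(&mut graph, 8, 4, &LoraConfig::default())?;

        // A batch of 2 inputs with 8 features each.
        let input = graph.variable(Tensor::zeros(vec![2, 8])?);
        let output = layer.forward(&mut graph, input)?;

        assert_eq!(graph.value(output)?.shape(), [2, 4]);
        Ok(())
    }

    #[test]
    fn merge_equals_base_weight_at_init() -> Result<(), ModelError> {
        // `B` is zero-initialized, so before any training step the merged
        // weight must equal the frozen base weight.
        let mut graph = Graph::new(); // assumed constructor
        let layer = LoraLinear::new(&mut graph, 8, 4, &LoraConfig::default())?;

        let merged = layer.merge(&graph)?;
        assert_eq!(merged.data(), graph.value(layer.frozen_weight)?.data());
        Ok(())
    }
}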