1#![allow(unused_variables)] use crate::errors::{Result, TrustformersError};
4use crate::layers::Linear;
5use crate::tensor::Tensor;
6use crate::traits::Layer;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// The parameter-efficient fine-tuning (PEFT) technique to apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PeftMethod {
    /// Low-Rank Adaptation: trainable low-rank deltas on frozen weights.
    LoRA,
    /// LoRA on top of a quantized base model.
    QLoRA,
    /// LoRA with adaptive rank allocation.
    AdaLoRA,
    /// Learned key/value prefixes prepended to attention.
    PrefixTuning,
    /// Deep prompt tuning (P-Tuning v2).
    PTuningV2,
    /// Learned virtual tokens prepended to the input embeddings.
    PromptTuning,
    /// Bottleneck adapter modules inserted between layers.
    Adapter,
    /// Fine-tuning bias terms only.
    BitFit,
}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct PeftConfig {
34 pub method: PeftMethod,
35 pub r: Option<usize>, pub alpha: Option<f32>, pub dropout: Option<f32>, pub target_modules: Vec<String>, pub bias: Option<String>, pub task_type: Option<String>, pub inference_mode: bool, }
43
44impl Default for PeftConfig {
45 fn default() -> Self {
46 Self {
47 method: PeftMethod::LoRA,
48 r: Some(8),
49 alpha: Some(16.0),
50 dropout: Some(0.1),
51 target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
52 bias: Some("none".to_string()),
53 task_type: Some("CAUSAL_LM".to_string()),
54 inference_mode: false,
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
64pub struct LoRALayer {
65 pub base_layer: Linear,
66 pub lora_a: Linear, pub lora_b: Linear, pub alpha: f32, pub r: usize, pub dropout: f32,
71 pub merged: bool, pub frozen: bool, }
74
75impl LoRALayer {
76 pub fn new(
77 input_dim: usize,
78 output_dim: usize,
79 r: usize,
80 alpha: f32,
81 dropout: f32,
82 bias: bool,
83 ) -> Result<Self> {
84 if r == 0 {
85 return Err(TrustformersError::invalid_config(
86 "LoRA rank must be greater than 0".into(),
87 ));
88 }
89
90 Ok(Self {
91 base_layer: Linear::new(input_dim, output_dim, bias),
92 lora_a: Linear::new(input_dim, r, false), lora_b: Linear::new(r, output_dim, false),
94 alpha,
95 r,
96 dropout,
97 merged: false,
98 frozen: true, })
100 }
101
102 pub fn initialize_weights(&mut self) -> Result<()> {
104 let a_weights = Tensor::randn(&[self.r, self.lora_a.weight().shape()[1]])?;
109 let scaled_a = a_weights.scalar_mul(0.01)?; self.lora_a.set_weight(scaled_a)?;
111
112 let b_weights = Tensor::zeros(&[self.lora_b.weight().shape()[0], self.r])?;
114 self.lora_b.set_weight(b_weights)?;
115
116 Ok(())
117 }
118
119 pub fn merge_weights(&mut self) -> Result<()> {
121 if self.merged {
122 return Ok(()); }
124
125 let lora_weight = self.lora_b.weight().matmul(self.lora_a.weight())?;
127 let scaling = self.alpha / self.r as f32;
128 let scaled_lora = lora_weight.scalar_mul(scaling)?;
129
130 let new_weight = self.base_layer.weight().add(&scaled_lora)?;
132 self.base_layer.set_weight(new_weight)?;
133 self.merged = true;
134
135 Ok(())
136 }
137
138 pub fn unmerge_weights(&mut self) -> Result<()> {
140 if !self.merged {
141 return Ok(()); }
143
144 let lora_weight = self.lora_b.weight().matmul(self.lora_a.weight())?;
146 let scaling = self.alpha / self.r as f32;
147 let scaled_lora = lora_weight.scalar_mul(scaling)?;
148
149 let neg_lora = scaled_lora.scalar_mul(-1.0)?;
151 let new_weight = self.base_layer.weight().add(&neg_lora)?;
152 self.base_layer.set_weight(new_weight)?;
153 self.merged = false;
154
155 Ok(())
156 }
157
158 pub fn train(&mut self) {
160 self.frozen = false;
161 }
162
163 pub fn eval(&mut self) {
165 self.frozen = true;
166 }
167
168 pub fn trainable_parameters(&self) -> Vec<&Tensor> {
170 let mut params = vec![self.lora_a.weight(), self.lora_b.weight()];
171
172 if !self.frozen {
173 params.push(self.base_layer.weight());
174 if let Some(bias) = self.base_layer.bias() {
175 params.push(bias);
176 }
177 }
178
179 params
180 }
181}
182
183impl Layer for LoRALayer {
184 type Input = Tensor;
185 type Output = Tensor;
186
187 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
188 if self.merged {
189 self.base_layer.forward(input)
191 } else {
192 let base_output = self.base_layer.forward(input.clone())?;
196
197 let lora_a_output = self.lora_a.forward(input)?;
199
200 let lora_a_dropped = if self.dropout > 0.0 {
202 lora_a_output.dropout(self.dropout)?
203 } else {
204 lora_a_output
205 };
206
207 let lora_output = self.lora_b.forward(lora_a_dropped)?;
209
210 let scaling = self.alpha / self.r as f32;
212 let scaled_lora = lora_output.scalar_mul(scaling)?;
213
214 base_output.add(&scaled_lora)
215 }
216 }
217}
218
/// QLoRA: a LoRA adapter on top of an (optionally) quantized base weight.
#[derive(Debug, Clone)]
pub struct QLoRALayer {
    /// The underlying LoRA adapter (base weight kept in full precision here).
    pub lora_layer: LoRALayer,
    /// Quantized copy of the base weight; `None` until `quantize_base` runs.
    pub quantized_base: Option<crate::quantization::QuantizedTensor>,
}
225
226impl QLoRALayer {
227 pub fn new(
228 input_dim: usize,
229 output_dim: usize,
230 r: usize,
231 alpha: f32,
232 dropout: f32,
233 bias: bool,
234 ) -> Result<Self> {
235 Ok(Self {
236 lora_layer: LoRALayer::new(input_dim, output_dim, r, alpha, dropout, bias)?,
237 quantized_base: None,
238 })
239 }
240
241 pub fn quantize_base(
243 &mut self,
244 config: &crate::quantization::QuantizationConfig,
245 ) -> Result<()> {
246 let quantized =
247 crate::quantization::Quantizer::quantize(self.lora_layer.base_layer.weight(), config)?;
248 self.quantized_base = Some(quantized);
249 Ok(())
250 }
251}
252
253impl Layer for QLoRALayer {
254 type Input = Tensor;
255 type Output = Tensor;
256
257 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
258 if let Some(ref quantized) = self.quantized_base {
260 let dequantized_weight = quantized.dequantize()?;
261
262 let mut temp_base = self.lora_layer.base_layer.clone();
264 temp_base.set_weight(dequantized_weight)?;
265
266 let base_output = temp_base.forward(input.clone())?;
268
269 let lora_a_output = self.lora_layer.lora_a.forward(input)?;
271 let lora_a_dropped = if self.lora_layer.dropout > 0.0 {
272 lora_a_output.dropout(self.lora_layer.dropout)?
273 } else {
274 lora_a_output
275 };
276 let lora_output = self.lora_layer.lora_b.forward(lora_a_dropped)?;
277
278 let scaling = self.lora_layer.alpha / self.lora_layer.r as f32;
279 let scaled_lora = lora_output.scalar_mul(scaling)?;
280
281 base_output.add(&scaled_lora)
282 } else {
283 self.lora_layer.forward(input)
285 }
286 }
287}
288
/// Bottleneck adapter: down-project, non-linearity, up-project, and an
/// optional residual connection back to the input.
#[derive(Debug, Clone)]
pub struct AdapterLayer {
    /// Projection hidden_size -> bottleneck_size (with bias).
    pub down_proj: Linear,
    /// Projection bottleneck_size -> hidden_size (with bias).
    pub up_proj: Linear,
    /// Non-linearity applied between the two projections.
    pub activation: ActivationType,
    /// Width of the bottleneck.
    pub bottleneck_size: usize,
    /// Dropout applied after the activation.
    pub dropout: f32,
    /// Whether the input is added back onto the adapter output.
    pub residual_connection: bool,
}
299
/// Element-wise non-linearity used inside `AdapterLayer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
    /// max(x, 0)
    ReLU,
    /// Gaussian Error Linear Unit (tanh approximation).
    GELU,
    /// x * sigmoid(x) (a.k.a. SiLU).
    Swish,
    /// Hyperbolic tangent.
    Tanh,
}
307
308impl AdapterLayer {
309 pub fn new(
310 hidden_size: usize,
311 bottleneck_size: usize,
312 activation: ActivationType,
313 dropout: f32,
314 ) -> Self {
315 Self {
316 down_proj: Linear::new(hidden_size, bottleneck_size, true),
317 up_proj: Linear::new(bottleneck_size, hidden_size, true),
318 activation,
319 bottleneck_size,
320 dropout,
321 residual_connection: true,
322 }
323 }
324
325 fn apply_activation(&self, tensor: &Tensor) -> Result<Tensor> {
326 match self.activation {
327 ActivationType::ReLU => self.relu(tensor),
328 ActivationType::GELU => self.gelu(tensor),
329 ActivationType::Swish => self.swish(tensor),
330 ActivationType::Tanh => self.tanh(tensor),
331 }
332 }
333
334 fn relu(&self, tensor: &Tensor) -> Result<Tensor> {
335 match tensor {
336 Tensor::F32(arr) => {
337 let result = arr.mapv(|x| x.max(0.0));
338 Ok(Tensor::F32(result))
339 },
340 _ => Err(TrustformersError::tensor_op_error(
341 "Unsupported tensor type for ReLU",
342 "LoRAActivation::relu",
343 )),
344 }
345 }
346
347 fn gelu(&self, tensor: &Tensor) -> Result<Tensor> {
348 match tensor {
349 Tensor::F32(arr) => {
350 let result = arr.mapv(|x| {
351 0.5 * x
352 * (1.0
353 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3)))
354 .tanh())
355 });
356 Ok(Tensor::F32(result))
357 },
358 _ => Err(TrustformersError::tensor_op_error(
359 "Unsupported tensor type for GELU",
360 "LoRAActivation::gelu",
361 )),
362 }
363 }
364
365 fn swish(&self, tensor: &Tensor) -> Result<Tensor> {
366 match tensor {
367 Tensor::F32(arr) => {
368 let result = arr.mapv(|x| x / (1.0 + (-x).exp()));
369 Ok(Tensor::F32(result))
370 },
371 _ => Err(TrustformersError::tensor_op_error(
372 "Unsupported tensor type for Swish",
373 "LoRAActivation::swish",
374 )),
375 }
376 }
377
378 fn tanh(&self, tensor: &Tensor) -> Result<Tensor> {
379 match tensor {
380 Tensor::F32(arr) => {
381 let result = arr.mapv(|x| x.tanh());
382 Ok(Tensor::F32(result))
383 },
384 _ => Err(TrustformersError::tensor_op_error(
385 "Unsupported tensor type for Tanh",
386 "LoRAActivation::tanh",
387 )),
388 }
389 }
390}
391
392impl Layer for AdapterLayer {
393 type Input = Tensor;
394 type Output = Tensor;
395
396 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
397 let down_output = self.down_proj.forward(input.clone())?;
399
400 let activated = self.apply_activation(&down_output)?;
402
403 let dropped = if self.dropout > 0.0 { activated.dropout(self.dropout)? } else { activated };
405
406 let up_output = self.up_proj.forward(dropped)?;
408
409 if self.residual_connection {
411 input.add(&up_output)
412 } else {
413 Ok(up_output)
414 }
415 }
416}
417
/// Prefix tuning: learned prefix embeddings projected into per-layer
/// key/value prefixes.
#[derive(Debug, Clone)]
pub struct PrefixTuningLayer {
    /// Number of prefix positions.
    pub prefix_length: usize,
    /// Model hidden size.
    pub hidden_size: usize,
    /// Number of transformer layers that receive prefixes.
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_heads: usize,
    /// Maps embeddings to concatenated key/value halves
    /// (hidden_size -> 2 * hidden_size).
    pub prefix_projection: Linear,
    /// Learned prefix table, shape [prefix_length, hidden_size].
    pub prefix_embeddings: Tensor,
}
428
429impl PrefixTuningLayer {
430 pub fn new(
431 prefix_length: usize,
432 hidden_size: usize,
433 num_layers: usize,
434 num_heads: usize,
435 ) -> Result<Self> {
436 let projection_dim = hidden_size * 2; let total_prefix_dim = num_layers * num_heads * prefix_length * 2; Ok(Self {
440 prefix_length,
441 hidden_size,
442 num_layers,
443 num_heads,
444 prefix_projection: Linear::new(hidden_size, projection_dim, true),
445 prefix_embeddings: Tensor::randn(&[prefix_length, hidden_size])?,
446 })
447 }
448
449 pub fn get_prefix_states(&self) -> Result<Vec<(Tensor, Tensor)>> {
450 let mut prefix_states = Vec::new();
451
452 for layer_idx in 0..self.num_layers {
453 let projected = self.prefix_projection.forward(self.prefix_embeddings.clone())?;
455
456 let key_value_split = projected.split(1, self.hidden_size)?; if key_value_split.len() != 2 {
459 return Err(TrustformersError::invalid_input(
460 "Projection split failed".into(),
461 ));
462 }
463
464 let key_states = key_value_split[0].clone();
465 let value_states = key_value_split[1].clone();
466
467 prefix_states.push((key_states, value_states));
468 }
469
470 Ok(prefix_states)
471 }
472}
473
/// Prompt tuning: a table of learned "virtual token" embeddings that are
/// prepended to the input sequence.
#[derive(Debug, Clone)]
pub struct PromptTuningEmbedding {
    /// Number of virtual tokens.
    pub num_virtual_tokens: usize,
    /// Embedding width (model hidden size).
    pub hidden_size: usize,
    /// Learned table, shape [num_virtual_tokens, hidden_size].
    pub prompt_embeddings: Tensor,
    /// How the table was initialized.
    pub init_method: PromptInitMethod,
}
482
/// Initialization strategy for prompt-tuning embeddings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PromptInitMethod {
    /// Standard-normal random values.
    Random,
    /// Text-based init (currently approximated with scaled random values).
    Text,
    /// Vocabulary-average init (currently approximated with zeros).
    VocabAverage,
}
489
490impl PromptTuningEmbedding {
491 pub fn new(
492 num_virtual_tokens: usize,
493 hidden_size: usize,
494 init_method: PromptInitMethod,
495 ) -> Result<Self> {
496 let prompt_embeddings = match init_method {
497 PromptInitMethod::Random => Tensor::randn(&[num_virtual_tokens, hidden_size])?,
498 PromptInitMethod::Text => {
499 let embeddings = Tensor::randn(&[num_virtual_tokens, hidden_size])?;
501 embeddings.scalar_mul(0.1)?
502 },
503 PromptInitMethod::VocabAverage => {
504 Tensor::zeros(&[num_virtual_tokens, hidden_size])?
506 },
507 };
508
509 Ok(Self {
510 num_virtual_tokens,
511 hidden_size,
512 prompt_embeddings,
513 init_method,
514 })
515 }
516
517 pub fn get_prompt_embeddings(&self) -> &Tensor {
518 &self.prompt_embeddings
519 }
520
521 pub fn update_embeddings(&mut self, new_embeddings: Tensor) -> Result<()> {
522 if new_embeddings.shape() != self.prompt_embeddings.shape() {
523 return Err(TrustformersError::shape_error(format!(
524 "Shape mismatch: expected {:?}, got {:?}",
525 self.prompt_embeddings.shape(),
526 new_embeddings.shape()
527 )));
528 }
529
530 self.prompt_embeddings = new_embeddings;
531 Ok(())
532 }
533}
534
535impl Layer for PrefixTuningLayer {
536 type Input = Tensor;
537 type Output = Tensor;
538
539 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
540 let projected = self.prefix_projection.forward(input)?;
542
543 Ok(projected)
547 }
548}
549
550impl Layer for PromptTuningEmbedding {
551 type Input = Tensor;
552 type Output = Tensor;
553
554 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
555 let input_shape = input.shape();
560 if input_shape.len() != 3 {
561 return Err(TrustformersError::shape_error(format!(
562 "Expected 3D input tensor [batch_size, seq_len, {}], got {:?}",
563 self.hidden_size, input_shape
564 )));
565 }
566
567 let batch_size = input_shape[0];
568
569 let prompt_with_batch =
572 self.prompt_embeddings
573 .reshape(&[1, self.num_virtual_tokens, self.hidden_size])?;
574
575 let prompt_expanded = prompt_with_batch.broadcast_to(&[
577 batch_size,
578 self.num_virtual_tokens,
579 self.hidden_size,
580 ])?;
581
582 let concatenated = Tensor::concat(&[prompt_expanded, input], 1)?;
584
585 Ok(concatenated)
586 }
587}
588
/// Flat, serde-friendly snapshot of a PEFT layer's weights and settings,
/// used by `PeftModel` save/load. Weight matrices are stored as flattened
/// `Vec<f32>` buffers alongside the dimensions needed to rebuild them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableLayerData {
    /// Snapshot of a `LoRALayer`.
    LoRA {
        base_weight: Vec<f32>,
        base_bias: Option<Vec<f32>>,
        lora_a_weight: Vec<f32>,
        lora_b_weight: Vec<f32>,
        alpha: f32,
        r: usize,
        dropout: f32,
        merged: bool,
        frozen: bool,
        input_dim: usize,
        output_dim: usize,
    },
    /// Snapshot of an `AdapterLayer`.
    Adapter {
        down_proj_weight: Vec<f32>,
        down_proj_bias: Vec<f32>,
        up_proj_weight: Vec<f32>,
        up_proj_bias: Vec<f32>,
        activation: ActivationType,
        bottleneck_size: usize,
        dropout: f32,
        residual_connection: bool,
        hidden_size: usize,
    },
    /// Snapshot of a `PrefixTuningLayer`.
    PrefixTuning {
        prefix_projection_weight: Vec<f32>,
        prefix_projection_bias: Vec<f32>,
        prefix_embeddings: Vec<f32>,
        prefix_length: usize,
        hidden_size: usize,
        num_layers: usize,
        num_heads: usize,
    },
    /// Snapshot of a `PromptTuningEmbedding`.
    PromptTuning {
        prompt_embeddings: Vec<f32>,
        num_virtual_tokens: usize,
        hidden_size: usize,
        init_method: PromptInitMethod,
    },
}
632
/// Container that owns PEFT layers plus their serializable snapshots.
pub struct PeftModel {
    /// The PEFT configuration used to build this model.
    pub config: PeftConfig,
    /// Live adapter layers, keyed by module name.
    pub peft_layers: HashMap<String, Box<dyn Layer<Input = Tensor, Output = Tensor>>>,
    /// Serializable snapshots mirroring `peft_layers` (used by save/load).
    pub layer_metadata: HashMap<String, SerializableLayerData>,
    /// When false, the adapters are considered disabled.
    pub active: bool,
}
640
641impl PeftModel {
642 pub fn new(config: PeftConfig) -> Self {
643 Self {
644 config,
645 peft_layers: HashMap::new(),
646 layer_metadata: HashMap::new(),
647 active: true,
648 }
649 }
650
651 fn serialize_lora_layer(layer: &LoRALayer) -> Result<SerializableLayerData> {
653 let base_weight = layer.base_layer.weight().data()?;
654 let base_bias = layer.base_layer.bias().map(|b| b.data()).transpose()?;
655 let lora_a_weight = layer.lora_a.weight().data()?;
656 let lora_b_weight = layer.lora_b.weight().data()?;
657
658 Ok(SerializableLayerData::LoRA {
659 base_weight,
660 base_bias,
661 lora_a_weight,
662 lora_b_weight,
663 alpha: layer.alpha,
664 r: layer.r,
665 dropout: layer.dropout,
666 merged: layer.merged,
667 frozen: layer.frozen,
668 input_dim: layer.base_layer.weight().shape()[1],
669 output_dim: layer.base_layer.weight().shape()[0],
670 })
671 }
672
673 fn deserialize_lora_layer(data: &SerializableLayerData) -> Result<LoRALayer> {
675 if let SerializableLayerData::LoRA {
676 base_weight,
677 base_bias,
678 lora_a_weight,
679 lora_b_weight,
680 alpha,
681 r,
682 dropout,
683 merged,
684 frozen,
685 input_dim,
686 output_dim,
687 } = data
688 {
689 let mut layer = LoRALayer::new(
690 *input_dim,
691 *output_dim,
692 *r,
693 *alpha,
694 *dropout,
695 base_bias.is_some(),
696 )?;
697
698 let base_weight_tensor =
700 Tensor::from_vec(base_weight.clone(), &[*output_dim, *input_dim])?;
701 layer.base_layer.set_weight(base_weight_tensor)?;
702
703 if let Some(bias_data) = base_bias {
704 let bias_tensor = Tensor::from_vec(bias_data.clone(), &[*output_dim])?;
705 layer.base_layer.set_bias(bias_tensor)?;
706 }
707
708 let lora_a_tensor = Tensor::from_vec(lora_a_weight.clone(), &[*r, *input_dim])?;
710 layer.lora_a.set_weight(lora_a_tensor)?;
711
712 let lora_b_tensor = Tensor::from_vec(lora_b_weight.clone(), &[*output_dim, *r])?;
713 layer.lora_b.set_weight(lora_b_tensor)?;
714
715 layer.merged = *merged;
717 layer.frozen = *frozen;
718
719 Ok(layer)
720 } else {
721 Err(TrustformersError::invalid_input(
722 "Expected LoRA layer data".into(),
723 ))
724 }
725 }
726
727 fn serialize_adapter_layer(layer: &AdapterLayer) -> Result<SerializableLayerData> {
729 let down_proj_weight = layer.down_proj.weight().data()?;
730 let down_proj_bias =
731 layer.down_proj.bias().map(|b| b.data()).transpose()?.unwrap_or_default();
732 let up_proj_weight = layer.up_proj.weight().data()?;
733 let up_proj_bias = layer.up_proj.bias().map(|b| b.data()).transpose()?.unwrap_or_default();
734
735 Ok(SerializableLayerData::Adapter {
736 down_proj_weight,
737 down_proj_bias,
738 up_proj_weight,
739 up_proj_bias,
740 activation: layer.activation,
741 bottleneck_size: layer.bottleneck_size,
742 dropout: layer.dropout,
743 residual_connection: layer.residual_connection,
744 hidden_size: layer.up_proj.weight().shape()[1],
745 })
746 }
747
748 fn serialize_prefix_tuning_layer(layer: &PrefixTuningLayer) -> Result<SerializableLayerData> {
750 let prefix_projection_weight = layer.prefix_projection.weight().data()?;
751 let prefix_projection_bias = layer
752 .prefix_projection
753 .bias()
754 .map(|b| b.data())
755 .transpose()?
756 .unwrap_or_default();
757 let prefix_embeddings = layer.prefix_embeddings.data()?;
758
759 Ok(SerializableLayerData::PrefixTuning {
760 prefix_projection_weight,
761 prefix_projection_bias,
762 prefix_embeddings,
763 prefix_length: layer.prefix_length,
764 hidden_size: layer.hidden_size,
765 num_layers: layer.num_layers,
766 num_heads: layer.num_heads,
767 })
768 }
769
770 fn serialize_prompt_tuning_embedding(
772 embedding: &PromptTuningEmbedding,
773 ) -> Result<SerializableLayerData> {
774 let prompt_embeddings = embedding.prompt_embeddings.data()?;
775
776 Ok(SerializableLayerData::PromptTuning {
777 prompt_embeddings,
778 num_virtual_tokens: embedding.num_virtual_tokens,
779 hidden_size: embedding.hidden_size,
780 init_method: embedding.init_method,
781 })
782 }
783
784 fn deserialize_adapter_layer(data: &SerializableLayerData) -> Result<AdapterLayer> {
786 if let SerializableLayerData::Adapter {
787 down_proj_weight,
788 down_proj_bias,
789 up_proj_weight,
790 up_proj_bias,
791 activation,
792 bottleneck_size,
793 dropout,
794 residual_connection,
795 hidden_size,
796 } = data
797 {
798 let mut layer =
799 AdapterLayer::new(*hidden_size, *bottleneck_size, *activation, *dropout);
800
801 let down_weight_tensor =
803 Tensor::from_vec(down_proj_weight.clone(), &[*bottleneck_size, *hidden_size])?;
804 layer.down_proj.set_weight(down_weight_tensor)?;
805
806 let down_bias_tensor = Tensor::from_vec(down_proj_bias.clone(), &[*bottleneck_size])?;
807 layer.down_proj.set_bias(down_bias_tensor)?;
808
809 let up_weight_tensor =
811 Tensor::from_vec(up_proj_weight.clone(), &[*hidden_size, *bottleneck_size])?;
812 layer.up_proj.set_weight(up_weight_tensor)?;
813
814 let up_bias_tensor = Tensor::from_vec(up_proj_bias.clone(), &[*hidden_size])?;
815 layer.up_proj.set_bias(up_bias_tensor)?;
816
817 layer.residual_connection = *residual_connection;
819
820 Ok(layer)
821 } else {
822 Err(TrustformersError::invalid_input(
823 "Expected Adapter layer data".into(),
824 ))
825 }
826 }
827
828 fn deserialize_prefix_tuning_layer(data: &SerializableLayerData) -> Result<PrefixTuningLayer> {
830 if let SerializableLayerData::PrefixTuning {
831 prefix_projection_weight,
832 prefix_projection_bias,
833 prefix_embeddings,
834 prefix_length,
835 hidden_size,
836 num_layers,
837 num_heads,
838 } = data
839 {
840 let mut layer =
841 PrefixTuningLayer::new(*prefix_length, *hidden_size, *num_layers, *num_heads)?;
842
843 let proj_weight_tensor = Tensor::from_vec(
845 prefix_projection_weight.clone(),
846 &[*hidden_size, *prefix_length],
847 )?;
848 layer.prefix_projection.set_weight(proj_weight_tensor)?;
849
850 let proj_bias_tensor =
851 Tensor::from_vec(prefix_projection_bias.clone(), &[*hidden_size])?;
852 layer.prefix_projection.set_bias(proj_bias_tensor)?;
853
854 let embeddings_tensor = Tensor::from_vec(
856 prefix_embeddings.clone(),
857 &[
858 *num_layers,
859 2,
860 *num_heads,
861 *prefix_length,
862 *hidden_size / *num_heads,
863 ],
864 )?;
865 layer.prefix_embeddings = embeddings_tensor;
866
867 Ok(layer)
868 } else {
869 Err(TrustformersError::invalid_input(
870 "Expected PrefixTuning layer data".into(),
871 ))
872 }
873 }
874
875 fn deserialize_prompt_tuning_embedding(
877 data: &SerializableLayerData,
878 ) -> Result<PromptTuningEmbedding> {
879 if let SerializableLayerData::PromptTuning {
880 prompt_embeddings,
881 num_virtual_tokens,
882 hidden_size,
883 init_method,
884 } = data
885 {
886 let mut embedding =
887 PromptTuningEmbedding::new(*num_virtual_tokens, *hidden_size, *init_method)?;
888
889 let embeddings_tensor = Tensor::from_vec(
891 prompt_embeddings.clone(),
892 &[*num_virtual_tokens, *hidden_size],
893 )?;
894 embedding.prompt_embeddings = embeddings_tensor;
895
896 Ok(embedding)
897 } else {
898 Err(TrustformersError::invalid_input(
899 "Expected PromptTuning embedding data".into(),
900 ))
901 }
902 }
903
904 pub fn add_lora_layer(&mut self, name: String, layer: LoRALayer) {
905 if let Ok(metadata) = Self::serialize_lora_layer(&layer) {
907 self.layer_metadata.insert(name.clone(), metadata);
908 }
909 self.peft_layers.insert(name, Box::new(layer));
910 }
911
912 pub fn add_adapter_layer(&mut self, name: String, layer: AdapterLayer) {
913 if let Ok(metadata) = Self::serialize_adapter_layer(&layer) {
915 self.layer_metadata.insert(name.clone(), metadata);
916 }
917 self.peft_layers.insert(name, Box::new(layer));
918 }
919
920 pub fn add_prefix_tuning_layer(&mut self, name: String, layer: PrefixTuningLayer) {
921 if let Ok(metadata) = Self::serialize_prefix_tuning_layer(&layer) {
923 self.layer_metadata.insert(name.clone(), metadata);
924 }
925 self.peft_layers.insert(name, Box::new(layer));
926 }
927
928 pub fn add_prompt_tuning_embedding(&mut self, name: String, embedding: PromptTuningEmbedding) {
929 if let Ok(metadata) = Self::serialize_prompt_tuning_embedding(&embedding) {
931 self.layer_metadata.insert(name.clone(), metadata);
932 }
933 self.peft_layers.insert(name, Box::new(embedding));
934 }
935
936 pub fn enable_peft(&mut self) {
937 self.active = true;
938 }
939
940 pub fn disable_peft(&mut self) {
941 self.active = false;
942 }
943
944 pub fn merge_and_unload(&mut self) -> Result<()> {
945 for (name, layer) in &mut self.peft_layers {
947 }
950
951 self.active = false;
952 Ok(())
953 }
954
955 pub fn get_trainable_parameters(&self) -> Vec<String> {
956 if !self.active {
957 return Vec::new();
958 }
959
960 let mut trainable = Vec::new();
961 for name in self.peft_layers.keys() {
962 if self.config.target_modules.contains(name) {
963 trainable.push(name.clone());
964 }
965 }
966
967 trainable
968 }
969
970 pub fn save_pretrained(&self, path: &str) -> Result<()> {
971 std::fs::create_dir_all(path).map_err(|e| TrustformersError::io_error(e.to_string()))?;
973
974 let config_json = serde_json::to_string_pretty(&self.config)
976 .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
977 std::fs::write(format!("{}/peft_config.json", path), config_json)
978 .map_err(|e| TrustformersError::io_error(e.to_string()))?;
979
980 let weights_json = serde_json::to_string_pretty(&self.layer_metadata)
982 .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
983 std::fs::write(format!("{}/adapter_weights.json", path), weights_json)
984 .map_err(|e| TrustformersError::io_error(e.to_string()))?;
985
986 Ok(())
987 }
988
989 pub fn load_pretrained(path: &str) -> Result<Self> {
990 let config_str = std::fs::read_to_string(format!("{}/peft_config.json", path))
991 .map_err(|e| TrustformersError::io_error(e.to_string()))?;
992
993 let config: PeftConfig = serde_json::from_str(&config_str)
994 .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
995
996 let mut model = Self::new(config);
997
998 let weights_str = std::fs::read_to_string(format!("{}/adapter_weights.json", path))
1000 .map_err(|e| TrustformersError::io_error(e.to_string()))?;
1001
1002 let layer_metadata: HashMap<String, SerializableLayerData> =
1003 serde_json::from_str(&weights_str)
1004 .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
1005
1006 for (name, data) in layer_metadata {
1008 match &data {
1009 SerializableLayerData::LoRA { .. } => {
1010 let layer = Self::deserialize_lora_layer(&data)?;
1011 model.add_lora_layer(name, layer);
1012 },
1013 SerializableLayerData::Adapter { .. } => {
1014 let layer = Self::deserialize_adapter_layer(&data)?;
1015 model.add_adapter_layer(name, layer);
1016 },
1017 SerializableLayerData::PrefixTuning { .. } => {
1018 let layer = Self::deserialize_prefix_tuning_layer(&data)?;
1019 model.add_prefix_tuning_layer(name, layer);
1020 },
1021 SerializableLayerData::PromptTuning { .. } => {
1022 let embedding = Self::deserialize_prompt_tuning_embedding(&data)?;
1023 model.add_prompt_tuning_embedding(name, embedding);
1024 },
1025 }
1026 }
1027
1028 Ok(model)
1029 }
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor stores hyper-parameters and starts un-merged with a
    // frozen base layer.
    #[test]
    fn test_lora_layer_creation() {
        let lora = LoRALayer::new(768, 768, 8, 16.0, 0.1, true).expect("operation failed in test");
        assert_eq!(lora.r, 8);
        assert_eq!(lora.alpha, 16.0);
        assert!(!lora.merged);
        assert!(lora.frozen);
    }

    // Square LoRA layer preserves the input shape on the un-merged path.
    #[test]
    fn test_lora_layer_forward() {
        let mut lora =
            LoRALayer::new(64, 64, 4, 8.0, 0.0, false).expect("operation failed in test");
        lora.initialize_weights().expect("operation failed in test");

        let input = Tensor::randn(&[10, 64]).expect("Failed to create random tensor");
        let output = lora.forward(input.clone()).expect("forward pass failed");

        assert_eq!(output.shape(), input.shape());
    }

    // merge/unmerge toggle the `merged` flag and are inverses of each other.
    #[test]
    fn test_lora_merge_unmerge() {
        let mut lora =
            LoRALayer::new(32, 32, 2, 4.0, 0.0, false).expect("operation failed in test");
        lora.initialize_weights().expect("operation failed in test");

        assert!(!lora.merged);

        lora.merge_weights().expect("merge operation failed");
        assert!(lora.merged);

        lora.unmerge_weights().expect("merge operation failed");
        assert!(!lora.merged);
    }

    // Forward still works (and preserves shape) after quantizing the base.
    #[test]
    fn test_qlora_layer() {
        let mut qlora =
            QLoRALayer::new(64, 64, 4, 8.0, 0.1, false).expect("operation failed in test");

        let quant_config = crate::quantization::QuantizationConfig::default();
        qlora.quantize_base(&quant_config).expect("operation failed in test");

        let input = Tensor::randn(&[5, 64]).expect("Failed to create random tensor");
        let output = qlora.forward(input.clone()).expect("forward pass failed");

        assert_eq!(output.shape(), input.shape());
    }

    // Bottleneck adapter with residual keeps the hidden size unchanged.
    #[test]
    fn test_adapter_layer() {
        let adapter = AdapterLayer::new(128, 32, ActivationType::GELU, 0.1);
        assert_eq!(adapter.bottleneck_size, 32);

        let input = Tensor::randn(&[8, 128]).expect("Failed to create random tensor");
        let output = adapter.forward(input.clone()).expect("forward pass failed");

        assert_eq!(output.shape(), input.shape());
    }

    // One (key, value) pair is produced per transformer layer.
    #[test]
    fn test_prefix_tuning_layer() {
        let prefix = PrefixTuningLayer::new(10, 64, 12, 8).expect("operation failed in test");
        assert_eq!(prefix.prefix_length, 10);
        assert_eq!(prefix.num_layers, 12);

        let prefix_states = prefix.get_prefix_states().expect("operation failed in test");
        assert_eq!(prefix_states.len(), 12);
    }

    // Embedding table has shape [num_virtual_tokens, hidden_size].
    #[test]
    fn test_prompt_tuning_embedding() {
        let prompt = PromptTuningEmbedding::new(5, 768, PromptInitMethod::Random)
            .expect("operation failed in test");
        assert_eq!(prompt.num_virtual_tokens, 5);
        assert_eq!(prompt.hidden_size, 768);

        let embeddings = prompt.get_prompt_embeddings();
        assert_eq!(embeddings.shape(), vec![5, 768]);
    }

    // Adding a layer registers it; enable/disable toggles `active`.
    #[test]
    fn test_peft_model() {
        let config = PeftConfig::default();
        let mut peft_model = PeftModel::new(config);

        let lora = LoRALayer::new(64, 64, 4, 8.0, 0.1, false).expect("operation failed in test");
        peft_model.add_lora_layer("test_layer".to_string(), lora);

        assert_eq!(peft_model.peft_layers.len(), 1);
        assert!(peft_model.active);

        peft_model.disable_peft();
        assert!(!peft_model.active);
    }

    // Config survives a JSON round-trip with its key fields intact.
    #[test]
    fn test_peft_config_serialization() {
        let config = PeftConfig::default();
        let json = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: PeftConfig =
            serde_json::from_str(&json).expect("JSON deserialization failed");

        assert_eq!(config.method, deserialized.method);
        assert_eq!(config.r, deserialized.r);
        assert_eq!(config.alpha, deserialized.alpha);
    }

    // ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_activation_functions() {
        let adapter = AdapterLayer::new(64, 16, ActivationType::ReLU, 0.0);
        let input =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).expect("Tensor from_vec failed");

        let relu_result = adapter.relu(&input).expect("operation failed in test");
        let data = relu_result.data().expect("operation failed in test");
        assert_eq!(data[0], 0.0);
        assert_eq!(data[1], 0.0);
        assert_eq!(data[2], 1.0);
        assert_eq!(data[3], 2.0);
    }
}