//! Parameter-efficient fine-tuning (PEFT) methods for trustformers_core.
//!
//! File: trustformers_core/peft.rs
1#![allow(unused_variables)] // PEFT implementation with reserved parameters
2
3use crate::errors::{Result, TrustformersError};
4use crate::layers::Linear;
5use crate::tensor::Tensor;
6use crate::traits::Layer;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Parameter-Efficient Fine-Tuning (PEFT) methods
///
/// Selects the adaptation strategy used by [`PeftConfig`]. Serialized by
/// variant name via serde, so renaming a variant breaks saved configs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PeftMethod {
    /// Low-Rank Adaptation (LoRA)
    LoRA,
    /// Quantized LoRA (QLoRA)
    QLoRA,
    /// Adaptive Low-Rank Adaptation (AdaLoRA)
    AdaLoRA,
    /// Prefix Tuning
    PrefixTuning,
    /// P-Tuning v2
    PTuningV2,
    /// Prompt Tuning
    PromptTuning,
    /// Adapter layers
    Adapter,
    /// BitFit (bias-only fine-tuning)
    BitFit,
}
30
/// Configuration for PEFT methods
///
/// The `Option` fields only apply to some methods (`r`/`alpha`/`dropout` are
/// LoRA-family knobs); `Default` yields a standard LoRA setup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeftConfig {
    pub method: PeftMethod,
    pub r: Option<usize>,            // Rank for LoRA
    pub alpha: Option<f32>,          // Scaling factor for LoRA (effective scale is alpha / r)
    pub dropout: Option<f32>,        // Dropout rate
    pub target_modules: Vec<String>, // Which modules to apply PEFT to
    pub bias: Option<String>,        // Bias training strategy (e.g. "none")
    pub task_type: Option<String>,   // Task type for optimization (e.g. "CAUSAL_LM")
    pub inference_mode: bool,        // Whether in inference mode
}
43
44impl Default for PeftConfig {
45    fn default() -> Self {
46        Self {
47            method: PeftMethod::LoRA,
48            r: Some(8),
49            alpha: Some(16.0),
50            dropout: Some(0.1),
51            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
52            bias: Some("none".to_string()),
53            task_type: Some("CAUSAL_LM".to_string()),
54            inference_mode: false,
55        }
56    }
57}
58
/// LoRA (Low-Rank Adaptation) layer
///
/// LoRA approximates weight updates as W' = W + BA where B is r×d and A is d×r
/// This reduces trainable parameters from d×d to 2×d×r where r << d
#[derive(Debug, Clone)]
pub struct LoRALayer {
    pub base_layer: Linear,
    pub lora_a: Linear, // Down-projection: input_dim -> r (no bias)
    pub lora_b: Linear, // Up-projection: r -> output_dim (no bias)
    pub alpha: f32,     // Scaling factor; effective LoRA scale is alpha / r
    pub r: usize,       // Rank
    pub dropout: f32,   // Dropout rate applied on the LoRA path in forward()
    pub merged: bool, // Whether LoRA weights are merged into base layer
    pub frozen: bool, // Whether base layer is frozen (toggled by train()/eval())
}
74
75impl LoRALayer {
76    pub fn new(
77        input_dim: usize,
78        output_dim: usize,
79        r: usize,
80        alpha: f32,
81        dropout: f32,
82        bias: bool,
83    ) -> Result<Self> {
84        if r == 0 {
85            return Err(TrustformersError::invalid_config(
86                "LoRA rank must be greater than 0".into(),
87            ));
88        }
89
90        Ok(Self {
91            base_layer: Linear::new(input_dim, output_dim, bias),
92            lora_a: Linear::new(input_dim, r, false), // No bias for LoRA layers
93            lora_b: Linear::new(r, output_dim, false),
94            alpha,
95            r,
96            dropout,
97            merged: false,
98            frozen: true, // Base layer starts frozen
99        })
100    }
101
102    /// Initialize LoRA weights
103    pub fn initialize_weights(&mut self) -> Result<()> {
104        // Initialize A with Gaussian noise, B with zeros (standard LoRA initialization)
105        // This ensures that initially LoRA contributes nothing: BA = 0
106
107        // Initialize lora_a weights with small random values
108        let a_weights = Tensor::randn(&[self.r, self.lora_a.weight().shape()[1]])?;
109        let scaled_a = a_weights.scalar_mul(0.01)?; // Small initialization
110        self.lora_a.set_weight(scaled_a)?;
111
112        // Initialize lora_b weights to zero
113        let b_weights = Tensor::zeros(&[self.lora_b.weight().shape()[0], self.r])?;
114        self.lora_b.set_weight(b_weights)?;
115
116        Ok(())
117    }
118
119    /// Merge LoRA weights into base layer for inference
120    pub fn merge_weights(&mut self) -> Result<()> {
121        if self.merged {
122            return Ok(()); // Already merged
123        }
124
125        // Compute LoRA contribution: (alpha/r) * B @ A
126        let lora_weight = self.lora_b.weight().matmul(self.lora_a.weight())?;
127        let scaling = self.alpha / self.r as f32;
128        let scaled_lora = lora_weight.scalar_mul(scaling)?;
129
130        // Add to base weights: W' = W + (alpha/r) * B @ A
131        let new_weight = self.base_layer.weight().add(&scaled_lora)?;
132        self.base_layer.set_weight(new_weight)?;
133        self.merged = true;
134
135        Ok(())
136    }
137
138    /// Unmerge LoRA weights from base layer
139    pub fn unmerge_weights(&mut self) -> Result<()> {
140        if !self.merged {
141            return Ok(()); // Not merged
142        }
143
144        // Subtract LoRA contribution
145        let lora_weight = self.lora_b.weight().matmul(self.lora_a.weight())?;
146        let scaling = self.alpha / self.r as f32;
147        let scaled_lora = lora_weight.scalar_mul(scaling)?;
148
149        // Subtract from base weights: W = W' - (alpha/r) * B @ A
150        let neg_lora = scaled_lora.scalar_mul(-1.0)?;
151        let new_weight = self.base_layer.weight().add(&neg_lora)?;
152        self.base_layer.set_weight(new_weight)?;
153        self.merged = false;
154
155        Ok(())
156    }
157
158    /// Set training mode
159    pub fn train(&mut self) {
160        self.frozen = false;
161    }
162
163    /// Set evaluation mode
164    pub fn eval(&mut self) {
165        self.frozen = true;
166    }
167
168    /// Get trainable parameters
169    pub fn trainable_parameters(&self) -> Vec<&Tensor> {
170        let mut params = vec![self.lora_a.weight(), self.lora_b.weight()];
171
172        if !self.frozen {
173            params.push(self.base_layer.weight());
174            if let Some(bias) = self.base_layer.bias() {
175                params.push(bias);
176            }
177        }
178
179        params
180    }
181}
182
183impl Layer for LoRALayer {
184    type Input = Tensor;
185    type Output = Tensor;
186
187    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
188        if self.merged {
189            // If weights are merged, just use base layer
190            self.base_layer.forward(input)
191        } else {
192            // Compute: h = (W + (alpha/r) * B @ A) @ x
193            // = W @ x + (alpha/r) * B @ (A @ x)
194
195            let base_output = self.base_layer.forward(input.clone())?;
196
197            // LoRA path: A @ x
198            let lora_a_output = self.lora_a.forward(input)?;
199
200            // Apply dropout to LoRA path
201            let lora_a_dropped = if self.dropout > 0.0 {
202                lora_a_output.dropout(self.dropout)?
203            } else {
204                lora_a_output
205            };
206
207            // B @ (A @ x)
208            let lora_output = self.lora_b.forward(lora_a_dropped)?;
209
210            // Scale and add: output = base_output + (alpha/r) * lora_output
211            let scaling = self.alpha / self.r as f32;
212            let scaled_lora = lora_output.scalar_mul(scaling)?;
213
214            base_output.add(&scaled_lora)
215        }
216    }
217}
218
/// QLoRA layer combining LoRA with quantization
///
/// Wraps a [`LoRALayer`]; when `quantized_base` is populated, the forward
/// pass dequantizes the base weights on the fly while the LoRA path stays in
/// full precision.
#[derive(Debug, Clone)]
pub struct QLoRALayer {
    pub lora_layer: LoRALayer,
    pub quantized_base: Option<crate::quantization::QuantizedTensor>,
}
225
226impl QLoRALayer {
227    pub fn new(
228        input_dim: usize,
229        output_dim: usize,
230        r: usize,
231        alpha: f32,
232        dropout: f32,
233        bias: bool,
234    ) -> Result<Self> {
235        Ok(Self {
236            lora_layer: LoRALayer::new(input_dim, output_dim, r, alpha, dropout, bias)?,
237            quantized_base: None,
238        })
239    }
240
241    /// Quantize the base layer weights
242    pub fn quantize_base(
243        &mut self,
244        config: &crate::quantization::QuantizationConfig,
245    ) -> Result<()> {
246        let quantized =
247            crate::quantization::Quantizer::quantize(self.lora_layer.base_layer.weight(), config)?;
248        self.quantized_base = Some(quantized);
249        Ok(())
250    }
251}
252
253impl Layer for QLoRALayer {
254    type Input = Tensor;
255    type Output = Tensor;
256
257    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
258        // If base layer is quantized, dequantize for computation
259        if let Some(ref quantized) = self.quantized_base {
260            let dequantized_weight = quantized.dequantize()?;
261
262            // Create temporary linear layer with dequantized weights
263            let mut temp_base = self.lora_layer.base_layer.clone();
264            temp_base.set_weight(dequantized_weight)?;
265
266            // Compute base output
267            let base_output = temp_base.forward(input.clone())?;
268
269            // LoRA computation
270            let lora_a_output = self.lora_layer.lora_a.forward(input)?;
271            let lora_a_dropped = if self.lora_layer.dropout > 0.0 {
272                lora_a_output.dropout(self.lora_layer.dropout)?
273            } else {
274                lora_a_output
275            };
276            let lora_output = self.lora_layer.lora_b.forward(lora_a_dropped)?;
277
278            let scaling = self.lora_layer.alpha / self.lora_layer.r as f32;
279            let scaled_lora = lora_output.scalar_mul(scaling)?;
280
281            base_output.add(&scaled_lora)
282        } else {
283            // Fall back to regular LoRA
284            self.lora_layer.forward(input)
285        }
286    }
287}
288
/// Adapter layer for parameter-efficient fine-tuning
///
/// Bottleneck adapter: down-projection to `bottleneck_size`, non-linearity,
/// dropout, up-projection back to the hidden size, with an optional residual
/// connection around the whole block.
#[derive(Debug, Clone)]
pub struct AdapterLayer {
    pub down_proj: Linear, // hidden_size -> bottleneck_size
    pub up_proj: Linear,   // bottleneck_size -> hidden_size
    pub activation: ActivationType,
    pub bottleneck_size: usize,
    pub dropout: f32,
    pub residual_connection: bool, // add the input back after up-projection
}
299
/// Non-linearity applied inside an `AdapterLayer` bottleneck.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    GELU,  // tanh approximation (see AdapterLayer::gelu)
    Swish, // x * sigmoid(x)
    Tanh,
}
307
308impl AdapterLayer {
309    pub fn new(
310        hidden_size: usize,
311        bottleneck_size: usize,
312        activation: ActivationType,
313        dropout: f32,
314    ) -> Self {
315        Self {
316            down_proj: Linear::new(hidden_size, bottleneck_size, true),
317            up_proj: Linear::new(bottleneck_size, hidden_size, true),
318            activation,
319            bottleneck_size,
320            dropout,
321            residual_connection: true,
322        }
323    }
324
325    fn apply_activation(&self, tensor: &Tensor) -> Result<Tensor> {
326        match self.activation {
327            ActivationType::ReLU => self.relu(tensor),
328            ActivationType::GELU => self.gelu(tensor),
329            ActivationType::Swish => self.swish(tensor),
330            ActivationType::Tanh => self.tanh(tensor),
331        }
332    }
333
334    fn relu(&self, tensor: &Tensor) -> Result<Tensor> {
335        match tensor {
336            Tensor::F32(arr) => {
337                let result = arr.mapv(|x| x.max(0.0));
338                Ok(Tensor::F32(result))
339            },
340            _ => Err(TrustformersError::tensor_op_error(
341                "Unsupported tensor type for ReLU",
342                "LoRAActivation::relu",
343            )),
344        }
345    }
346
347    fn gelu(&self, tensor: &Tensor) -> Result<Tensor> {
348        match tensor {
349            Tensor::F32(arr) => {
350                let result = arr.mapv(|x| {
351                    0.5 * x
352                        * (1.0
353                            + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3)))
354                                .tanh())
355                });
356                Ok(Tensor::F32(result))
357            },
358            _ => Err(TrustformersError::tensor_op_error(
359                "Unsupported tensor type for GELU",
360                "LoRAActivation::gelu",
361            )),
362        }
363    }
364
365    fn swish(&self, tensor: &Tensor) -> Result<Tensor> {
366        match tensor {
367            Tensor::F32(arr) => {
368                let result = arr.mapv(|x| x / (1.0 + (-x).exp()));
369                Ok(Tensor::F32(result))
370            },
371            _ => Err(TrustformersError::tensor_op_error(
372                "Unsupported tensor type for Swish",
373                "LoRAActivation::swish",
374            )),
375        }
376    }
377
378    fn tanh(&self, tensor: &Tensor) -> Result<Tensor> {
379        match tensor {
380            Tensor::F32(arr) => {
381                let result = arr.mapv(|x| x.tanh());
382                Ok(Tensor::F32(result))
383            },
384            _ => Err(TrustformersError::tensor_op_error(
385                "Unsupported tensor type for Tanh",
386                "LoRAActivation::tanh",
387            )),
388        }
389    }
390}
391
392impl Layer for AdapterLayer {
393    type Input = Tensor;
394    type Output = Tensor;
395
396    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
397        // Down-projection
398        let down_output = self.down_proj.forward(input.clone())?;
399
400        // Activation
401        let activated = self.apply_activation(&down_output)?;
402
403        // Dropout
404        let dropped = if self.dropout > 0.0 { activated.dropout(self.dropout)? } else { activated };
405
406        // Up-projection
407        let up_output = self.up_proj.forward(dropped)?;
408
409        // Residual connection
410        if self.residual_connection {
411            input.add(&up_output)
412        } else {
413            Ok(up_output)
414        }
415    }
416}
417
/// Prefix tuning layer
///
/// Learns `prefix_length` virtual token embeddings plus a projection that
/// maps them to key/value states for each transformer layer
/// (see `get_prefix_states`).
#[derive(Debug, Clone)]
pub struct PrefixTuningLayer {
    pub prefix_length: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub prefix_projection: Linear, // hidden_size -> 2 * hidden_size (key + value halves)
    pub prefix_embeddings: Tensor, // [prefix_length, hidden_size]
}
428
429impl PrefixTuningLayer {
430    pub fn new(
431        prefix_length: usize,
432        hidden_size: usize,
433        num_layers: usize,
434        num_heads: usize,
435    ) -> Result<Self> {
436        let projection_dim = hidden_size * 2; // For both key and value
437        let total_prefix_dim = num_layers * num_heads * prefix_length * 2; // Key + Value
438
439        Ok(Self {
440            prefix_length,
441            hidden_size,
442            num_layers,
443            num_heads,
444            prefix_projection: Linear::new(hidden_size, projection_dim, true),
445            prefix_embeddings: Tensor::randn(&[prefix_length, hidden_size])?,
446        })
447    }
448
449    pub fn get_prefix_states(&self) -> Result<Vec<(Tensor, Tensor)>> {
450        let mut prefix_states = Vec::new();
451
452        for layer_idx in 0..self.num_layers {
453            // Project prefix embeddings to get key and value states
454            let projected = self.prefix_projection.forward(self.prefix_embeddings.clone())?;
455
456            // Split into key and value
457            let key_value_split = projected.split(1, self.hidden_size)?; // Split along last dimension
458            if key_value_split.len() != 2 {
459                return Err(TrustformersError::invalid_input(
460                    "Projection split failed".into(),
461                ));
462            }
463
464            let key_states = key_value_split[0].clone();
465            let value_states = key_value_split[1].clone();
466
467            prefix_states.push((key_states, value_states));
468        }
469
470        Ok(prefix_states)
471    }
472}
473
/// Prompt tuning embeddings
///
/// Learnable virtual-token vectors that are prepended to the input sequence
/// in `forward`.
#[derive(Debug, Clone)]
pub struct PromptTuningEmbedding {
    pub num_virtual_tokens: usize,
    pub hidden_size: usize,
    pub prompt_embeddings: Tensor, // [num_virtual_tokens, hidden_size]
    pub init_method: PromptInitMethod,
}
482
/// Initialization strategy for prompt-tuning virtual tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PromptInitMethod {
    /// Standard Gaussian noise.
    Random,
    /// Scaled-down noise standing in for text-based initialization.
    Text,
    /// Zeros standing in for vocabulary-average initialization.
    VocabAverage,
}
489
490impl PromptTuningEmbedding {
491    pub fn new(
492        num_virtual_tokens: usize,
493        hidden_size: usize,
494        init_method: PromptInitMethod,
495    ) -> Result<Self> {
496        let prompt_embeddings = match init_method {
497            PromptInitMethod::Random => Tensor::randn(&[num_virtual_tokens, hidden_size])?,
498            PromptInitMethod::Text => {
499                // Initialize with small random values for text-based initialization
500                let embeddings = Tensor::randn(&[num_virtual_tokens, hidden_size])?;
501                embeddings.scalar_mul(0.1)?
502            },
503            PromptInitMethod::VocabAverage => {
504                // Initialize with zeros for vocabulary average initialization
505                Tensor::zeros(&[num_virtual_tokens, hidden_size])?
506            },
507        };
508
509        Ok(Self {
510            num_virtual_tokens,
511            hidden_size,
512            prompt_embeddings,
513            init_method,
514        })
515    }
516
517    pub fn get_prompt_embeddings(&self) -> &Tensor {
518        &self.prompt_embeddings
519    }
520
521    pub fn update_embeddings(&mut self, new_embeddings: Tensor) -> Result<()> {
522        if new_embeddings.shape() != self.prompt_embeddings.shape() {
523            return Err(TrustformersError::shape_error(format!(
524                "Shape mismatch: expected {:?}, got {:?}",
525                self.prompt_embeddings.shape(),
526                new_embeddings.shape()
527            )));
528        }
529
530        self.prompt_embeddings = new_embeddings;
531        Ok(())
532    }
533}
534
535impl Layer for PrefixTuningLayer {
536    type Input = Tensor;
537    type Output = Tensor;
538
539    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
540        // Apply prefix projection to input
541        let projected = self.prefix_projection.forward(input)?;
542
543        // For prefix tuning, we typically just return the projected input
544        // The actual prefix embeddings are used during attention computation
545        // which is handled by the attention mechanism that queries this layer
546        Ok(projected)
547    }
548}
549
550impl Layer for PromptTuningEmbedding {
551    type Input = Tensor;
552    type Output = Tensor;
553
554    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
555        // For prompt tuning, we concatenate the virtual prompt tokens with the input
556        // The prompt embeddings are prepended to the input sequence
557
558        // Get batch size from input (assuming input shape is [batch_size, seq_len, hidden_size])
559        let input_shape = input.shape();
560        if input_shape.len() != 3 {
561            return Err(TrustformersError::shape_error(format!(
562                "Expected 3D input tensor [batch_size, seq_len, {}], got {:?}",
563                self.hidden_size, input_shape
564            )));
565        }
566
567        let batch_size = input_shape[0];
568
569        // Expand prompt embeddings to match batch size
570        // First reshape to add batch dimension: [num_virtual_tokens, hidden_size] -> [1, num_virtual_tokens, hidden_size]
571        let prompt_with_batch =
572            self.prompt_embeddings
573                .reshape(&[1, self.num_virtual_tokens, self.hidden_size])?;
574
575        // Then broadcast to match batch size: [1, num_virtual_tokens, hidden_size] -> [batch_size, num_virtual_tokens, hidden_size]
576        let prompt_expanded = prompt_with_batch.broadcast_to(&[
577            batch_size,
578            self.num_virtual_tokens,
579            self.hidden_size,
580        ])?;
581
582        // Concatenate prompt embeddings with input along sequence dimension
583        let concatenated = Tensor::concat(&[prompt_expanded, input], 1)?;
584
585        Ok(concatenated)
586    }
587}
588
/// Serializable representation of PEFT layer data
///
/// Flat `Vec<f32>` snapshots of each layer type's tensors plus the scalar
/// configuration needed to rebuild it; produced and consumed by the
/// `serialize_*` / `deserialize_*` helpers on `PeftModel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableLayerData {
    /// LoRA layer snapshot.
    LoRA {
        base_weight: Vec<f32>,       // flattened [output_dim, input_dim]
        base_bias: Option<Vec<f32>>, // None when the base layer has no bias
        lora_a_weight: Vec<f32>,     // flattened [r, input_dim]
        lora_b_weight: Vec<f32>,     // flattened [output_dim, r]
        alpha: f32,
        r: usize,
        dropout: f32,
        merged: bool,
        frozen: bool,
        input_dim: usize,
        output_dim: usize,
    },
    /// Bottleneck adapter snapshot.
    Adapter {
        down_proj_weight: Vec<f32>, // flattened [bottleneck_size, hidden_size]
        down_proj_bias: Vec<f32>,
        up_proj_weight: Vec<f32>, // flattened [hidden_size, bottleneck_size]
        up_proj_bias: Vec<f32>,
        activation: ActivationType,
        bottleneck_size: usize,
        dropout: f32,
        residual_connection: bool,
        hidden_size: usize,
    },
    /// Prefix-tuning snapshot.
    PrefixTuning {
        prefix_projection_weight: Vec<f32>,
        prefix_projection_bias: Vec<f32>,
        prefix_embeddings: Vec<f32>, // flattened [prefix_length, hidden_size]
        prefix_length: usize,
        hidden_size: usize,
        num_layers: usize,
        num_heads: usize,
    },
    /// Prompt-tuning snapshot.
    PromptTuning {
        prompt_embeddings: Vec<f32>, // flattened [num_virtual_tokens, hidden_size]
        num_virtual_tokens: usize,
        hidden_size: usize,
        init_method: PromptInitMethod,
    },
}
632
/// PEFT model wrapper that applies PEFT methods to a base model
///
/// Holds type-erased PEFT layers keyed by name, together with parallel
/// serializable snapshots captured when each layer is registered.
pub struct PeftModel {
    pub config: PeftConfig,
    // Type-erased layers; boxed so different PEFT types share one map.
    pub peft_layers: HashMap<String, Box<dyn Layer<Input = Tensor, Output = Tensor>>>,
    // Serializable snapshots, keyed by the same names as `peft_layers`.
    pub layer_metadata: HashMap<String, SerializableLayerData>,
    // Whether PEFT is active (queries return empty when inactive).
    pub active: bool,
}
640
641impl PeftModel {
642    pub fn new(config: PeftConfig) -> Self {
643        Self {
644            config,
645            peft_layers: HashMap::new(),
646            layer_metadata: HashMap::new(),
647            active: true,
648        }
649    }
650
651    /// Convert a LoRA layer to serializable data
652    fn serialize_lora_layer(layer: &LoRALayer) -> Result<SerializableLayerData> {
653        let base_weight = layer.base_layer.weight().data()?;
654        let base_bias = layer.base_layer.bias().map(|b| b.data()).transpose()?;
655        let lora_a_weight = layer.lora_a.weight().data()?;
656        let lora_b_weight = layer.lora_b.weight().data()?;
657
658        Ok(SerializableLayerData::LoRA {
659            base_weight,
660            base_bias,
661            lora_a_weight,
662            lora_b_weight,
663            alpha: layer.alpha,
664            r: layer.r,
665            dropout: layer.dropout,
666            merged: layer.merged,
667            frozen: layer.frozen,
668            input_dim: layer.base_layer.weight().shape()[1],
669            output_dim: layer.base_layer.weight().shape()[0],
670        })
671    }
672
673    /// Convert serializable data to a LoRA layer
674    fn deserialize_lora_layer(data: &SerializableLayerData) -> Result<LoRALayer> {
675        if let SerializableLayerData::LoRA {
676            base_weight,
677            base_bias,
678            lora_a_weight,
679            lora_b_weight,
680            alpha,
681            r,
682            dropout,
683            merged,
684            frozen,
685            input_dim,
686            output_dim,
687        } = data
688        {
689            let mut layer = LoRALayer::new(
690                *input_dim,
691                *output_dim,
692                *r,
693                *alpha,
694                *dropout,
695                base_bias.is_some(),
696            )?;
697
698            // Set base layer weights
699            let base_weight_tensor =
700                Tensor::from_vec(base_weight.clone(), &[*output_dim, *input_dim])?;
701            layer.base_layer.set_weight(base_weight_tensor)?;
702
703            if let Some(bias_data) = base_bias {
704                let bias_tensor = Tensor::from_vec(bias_data.clone(), &[*output_dim])?;
705                layer.base_layer.set_bias(bias_tensor)?;
706            }
707
708            // Set LoRA weights
709            let lora_a_tensor = Tensor::from_vec(lora_a_weight.clone(), &[*r, *input_dim])?;
710            layer.lora_a.set_weight(lora_a_tensor)?;
711
712            let lora_b_tensor = Tensor::from_vec(lora_b_weight.clone(), &[*output_dim, *r])?;
713            layer.lora_b.set_weight(lora_b_tensor)?;
714
715            // Set state
716            layer.merged = *merged;
717            layer.frozen = *frozen;
718
719            Ok(layer)
720        } else {
721            Err(TrustformersError::invalid_input(
722                "Expected LoRA layer data".into(),
723            ))
724        }
725    }
726
727    /// Convert an Adapter layer to serializable data
728    fn serialize_adapter_layer(layer: &AdapterLayer) -> Result<SerializableLayerData> {
729        let down_proj_weight = layer.down_proj.weight().data()?;
730        let down_proj_bias =
731            layer.down_proj.bias().map(|b| b.data()).transpose()?.unwrap_or_default();
732        let up_proj_weight = layer.up_proj.weight().data()?;
733        let up_proj_bias = layer.up_proj.bias().map(|b| b.data()).transpose()?.unwrap_or_default();
734
735        Ok(SerializableLayerData::Adapter {
736            down_proj_weight,
737            down_proj_bias,
738            up_proj_weight,
739            up_proj_bias,
740            activation: layer.activation,
741            bottleneck_size: layer.bottleneck_size,
742            dropout: layer.dropout,
743            residual_connection: layer.residual_connection,
744            hidden_size: layer.up_proj.weight().shape()[1],
745        })
746    }
747
748    /// Convert a PrefixTuning layer to serializable data
749    fn serialize_prefix_tuning_layer(layer: &PrefixTuningLayer) -> Result<SerializableLayerData> {
750        let prefix_projection_weight = layer.prefix_projection.weight().data()?;
751        let prefix_projection_bias = layer
752            .prefix_projection
753            .bias()
754            .map(|b| b.data())
755            .transpose()?
756            .unwrap_or_default();
757        let prefix_embeddings = layer.prefix_embeddings.data()?;
758
759        Ok(SerializableLayerData::PrefixTuning {
760            prefix_projection_weight,
761            prefix_projection_bias,
762            prefix_embeddings,
763            prefix_length: layer.prefix_length,
764            hidden_size: layer.hidden_size,
765            num_layers: layer.num_layers,
766            num_heads: layer.num_heads,
767        })
768    }
769
770    /// Convert a PromptTuning embedding to serializable data
771    fn serialize_prompt_tuning_embedding(
772        embedding: &PromptTuningEmbedding,
773    ) -> Result<SerializableLayerData> {
774        let prompt_embeddings = embedding.prompt_embeddings.data()?;
775
776        Ok(SerializableLayerData::PromptTuning {
777            prompt_embeddings,
778            num_virtual_tokens: embedding.num_virtual_tokens,
779            hidden_size: embedding.hidden_size,
780            init_method: embedding.init_method,
781        })
782    }
783
784    /// Convert serializable data to an Adapter layer
785    fn deserialize_adapter_layer(data: &SerializableLayerData) -> Result<AdapterLayer> {
786        if let SerializableLayerData::Adapter {
787            down_proj_weight,
788            down_proj_bias,
789            up_proj_weight,
790            up_proj_bias,
791            activation,
792            bottleneck_size,
793            dropout,
794            residual_connection,
795            hidden_size,
796        } = data
797        {
798            let mut layer =
799                AdapterLayer::new(*hidden_size, *bottleneck_size, *activation, *dropout);
800
801            // Set down projection weights
802            let down_weight_tensor =
803                Tensor::from_vec(down_proj_weight.clone(), &[*bottleneck_size, *hidden_size])?;
804            layer.down_proj.set_weight(down_weight_tensor)?;
805
806            let down_bias_tensor = Tensor::from_vec(down_proj_bias.clone(), &[*bottleneck_size])?;
807            layer.down_proj.set_bias(down_bias_tensor)?;
808
809            // Set up projection weights
810            let up_weight_tensor =
811                Tensor::from_vec(up_proj_weight.clone(), &[*hidden_size, *bottleneck_size])?;
812            layer.up_proj.set_weight(up_weight_tensor)?;
813
814            let up_bias_tensor = Tensor::from_vec(up_proj_bias.clone(), &[*hidden_size])?;
815            layer.up_proj.set_bias(up_bias_tensor)?;
816
817            // Set configuration
818            layer.residual_connection = *residual_connection;
819
820            Ok(layer)
821        } else {
822            Err(TrustformersError::invalid_input(
823                "Expected Adapter layer data".into(),
824            ))
825        }
826    }
827
828    /// Convert serializable data to a PrefixTuning layer
829    fn deserialize_prefix_tuning_layer(data: &SerializableLayerData) -> Result<PrefixTuningLayer> {
830        if let SerializableLayerData::PrefixTuning {
831            prefix_projection_weight,
832            prefix_projection_bias,
833            prefix_embeddings,
834            prefix_length,
835            hidden_size,
836            num_layers,
837            num_heads,
838        } = data
839        {
840            let mut layer =
841                PrefixTuningLayer::new(*prefix_length, *hidden_size, *num_layers, *num_heads)?;
842
843            // Set prefix projection weights
844            let proj_weight_tensor = Tensor::from_vec(
845                prefix_projection_weight.clone(),
846                &[*hidden_size, *prefix_length],
847            )?;
848            layer.prefix_projection.set_weight(proj_weight_tensor)?;
849
850            let proj_bias_tensor =
851                Tensor::from_vec(prefix_projection_bias.clone(), &[*hidden_size])?;
852            layer.prefix_projection.set_bias(proj_bias_tensor)?;
853
854            // Set prefix embeddings
855            let embeddings_tensor = Tensor::from_vec(
856                prefix_embeddings.clone(),
857                &[
858                    *num_layers,
859                    2,
860                    *num_heads,
861                    *prefix_length,
862                    *hidden_size / *num_heads,
863                ],
864            )?;
865            layer.prefix_embeddings = embeddings_tensor;
866
867            Ok(layer)
868        } else {
869            Err(TrustformersError::invalid_input(
870                "Expected PrefixTuning layer data".into(),
871            ))
872        }
873    }
874
875    /// Convert serializable data to a PromptTuning embedding
876    fn deserialize_prompt_tuning_embedding(
877        data: &SerializableLayerData,
878    ) -> Result<PromptTuningEmbedding> {
879        if let SerializableLayerData::PromptTuning {
880            prompt_embeddings,
881            num_virtual_tokens,
882            hidden_size,
883            init_method,
884        } = data
885        {
886            let mut embedding =
887                PromptTuningEmbedding::new(*num_virtual_tokens, *hidden_size, *init_method)?;
888
889            // Set prompt embeddings
890            let embeddings_tensor = Tensor::from_vec(
891                prompt_embeddings.clone(),
892                &[*num_virtual_tokens, *hidden_size],
893            )?;
894            embedding.prompt_embeddings = embeddings_tensor;
895
896            Ok(embedding)
897        } else {
898            Err(TrustformersError::invalid_input(
899                "Expected PromptTuning embedding data".into(),
900            ))
901        }
902    }
903
904    pub fn add_lora_layer(&mut self, name: String, layer: LoRALayer) {
905        // Store serializable metadata
906        if let Ok(metadata) = Self::serialize_lora_layer(&layer) {
907            self.layer_metadata.insert(name.clone(), metadata);
908        }
909        self.peft_layers.insert(name, Box::new(layer));
910    }
911
912    pub fn add_adapter_layer(&mut self, name: String, layer: AdapterLayer) {
913        // Store serializable metadata
914        if let Ok(metadata) = Self::serialize_adapter_layer(&layer) {
915            self.layer_metadata.insert(name.clone(), metadata);
916        }
917        self.peft_layers.insert(name, Box::new(layer));
918    }
919
920    pub fn add_prefix_tuning_layer(&mut self, name: String, layer: PrefixTuningLayer) {
921        // Store serializable metadata
922        if let Ok(metadata) = Self::serialize_prefix_tuning_layer(&layer) {
923            self.layer_metadata.insert(name.clone(), metadata);
924        }
925        self.peft_layers.insert(name, Box::new(layer));
926    }
927
928    pub fn add_prompt_tuning_embedding(&mut self, name: String, embedding: PromptTuningEmbedding) {
929        // Store serializable metadata
930        if let Ok(metadata) = Self::serialize_prompt_tuning_embedding(&embedding) {
931            self.layer_metadata.insert(name.clone(), metadata);
932        }
933        self.peft_layers.insert(name, Box::new(embedding));
934    }
935
    /// Mark PEFT as active. While active, `get_trainable_parameters` reports
    /// registered layers that match the configured target modules.
    pub fn enable_peft(&mut self) {
        self.active = true;
    }
939
    /// Mark PEFT as inactive. While inactive, `get_trainable_parameters`
    /// returns an empty list; the registered layers themselves are kept.
    pub fn disable_peft(&mut self) {
        self.active = false;
    }
943
944    pub fn merge_and_unload(&mut self) -> Result<()> {
945        // Merge all LoRA layers
946        for (name, layer) in &mut self.peft_layers {
947            // This would need to be implemented per layer type
948            // For now, just mark as merged
949        }
950
951        self.active = false;
952        Ok(())
953    }
954
955    pub fn get_trainable_parameters(&self) -> Vec<String> {
956        if !self.active {
957            return Vec::new();
958        }
959
960        let mut trainable = Vec::new();
961        for name in self.peft_layers.keys() {
962            if self.config.target_modules.contains(name) {
963                trainable.push(name.clone());
964            }
965        }
966
967        trainable
968    }
969
970    pub fn save_pretrained(&self, path: &str) -> Result<()> {
971        // Create directory if it doesn't exist
972        std::fs::create_dir_all(path).map_err(|e| TrustformersError::io_error(e.to_string()))?;
973
974        // Save PEFT configuration
975        let config_json = serde_json::to_string_pretty(&self.config)
976            .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
977        std::fs::write(format!("{}/peft_config.json", path), config_json)
978            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
979
980        // Save adapter weights using stored metadata
981        let weights_json = serde_json::to_string_pretty(&self.layer_metadata)
982            .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
983        std::fs::write(format!("{}/adapter_weights.json", path), weights_json)
984            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
985
986        Ok(())
987    }
988
989    pub fn load_pretrained(path: &str) -> Result<Self> {
990        let config_str = std::fs::read_to_string(format!("{}/peft_config.json", path))
991            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
992
993        let config: PeftConfig = serde_json::from_str(&config_str)
994            .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
995
996        let mut model = Self::new(config);
997
998        // Load adapter weights
999        let weights_str = std::fs::read_to_string(format!("{}/adapter_weights.json", path))
1000            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
1001
1002        let layer_metadata: HashMap<String, SerializableLayerData> =
1003            serde_json::from_str(&weights_str)
1004                .map_err(|e| TrustformersError::other(format!("Serialization error: {}", e)))?;
1005
1006        // Reconstruct layers from metadata
1007        for (name, data) in layer_metadata {
1008            match &data {
1009                SerializableLayerData::LoRA { .. } => {
1010                    let layer = Self::deserialize_lora_layer(&data)?;
1011                    model.add_lora_layer(name, layer);
1012                },
1013                SerializableLayerData::Adapter { .. } => {
1014                    let layer = Self::deserialize_adapter_layer(&data)?;
1015                    model.add_adapter_layer(name, layer);
1016                },
1017                SerializableLayerData::PrefixTuning { .. } => {
1018                    let layer = Self::deserialize_prefix_tuning_layer(&data)?;
1019                    model.add_prefix_tuning_layer(name, layer);
1020                },
1021                SerializableLayerData::PromptTuning { .. } => {
1022                    let embedding = Self::deserialize_prompt_tuning_embedding(&data)?;
1023                    model.add_prompt_tuning_embedding(name, embedding);
1024                },
1025            }
1026        }
1027
1028        Ok(model)
1029    }
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed LoRA layer records its hyperparameters and
    /// starts out unmerged with a frozen base.
    #[test]
    fn test_lora_layer_creation() {
        let layer = LoRALayer::new(768, 768, 8, 16.0, 0.1, true).expect("operation failed in test");
        assert_eq!(layer.r, 8);
        assert_eq!(layer.alpha, 16.0);
        assert!(layer.frozen);
        assert!(!layer.merged);
    }

    /// The LoRA forward pass preserves the input shape.
    #[test]
    fn test_lora_layer_forward() {
        let mut layer =
            LoRALayer::new(64, 64, 4, 8.0, 0.0, false).expect("operation failed in test");
        layer.initialize_weights().expect("operation failed in test");

        let x = Tensor::randn(&[10, 64]).expect("Failed to create random tensor");
        let y = layer.forward(x.clone()).expect("forward pass failed");
        assert_eq!(y.shape(), x.shape());
    }

    /// Merging and unmerging toggles the `merged` flag.
    #[test]
    fn test_lora_merge_unmerge() {
        let mut layer =
            LoRALayer::new(32, 32, 2, 4.0, 0.0, false).expect("operation failed in test");
        layer.initialize_weights().expect("operation failed in test");
        assert!(!layer.merged);

        layer.merge_weights().expect("merge operation failed");
        assert!(layer.merged);

        layer.unmerge_weights().expect("merge operation failed");
        assert!(!layer.merged);
    }

    /// QLoRA still runs a shape-preserving forward pass after quantizing the
    /// base weights.
    #[test]
    fn test_qlora_layer() {
        let mut layer =
            QLoRALayer::new(64, 64, 4, 8.0, 0.1, false).expect("operation failed in test");

        let quant_config = crate::quantization::QuantizationConfig::default();
        layer.quantize_base(&quant_config).expect("operation failed in test");

        let x = Tensor::randn(&[5, 64]).expect("Failed to create random tensor");
        let y = layer.forward(x.clone()).expect("forward pass failed");
        assert_eq!(y.shape(), x.shape());
    }

    /// An adapter stores its bottleneck size and its forward pass preserves
    /// the input shape.
    #[test]
    fn test_adapter_layer() {
        let adapter = AdapterLayer::new(128, 32, ActivationType::GELU, 0.1);
        assert_eq!(adapter.bottleneck_size, 32);

        let x = Tensor::randn(&[8, 128]).expect("Failed to create random tensor");
        let y = adapter.forward(x.clone()).expect("forward pass failed");
        assert_eq!(y.shape(), x.shape());
    }

    /// Prefix tuning exposes one prefix state per transformer layer.
    #[test]
    fn test_prefix_tuning_layer() {
        let prefix = PrefixTuningLayer::new(10, 64, 12, 8).expect("operation failed in test");
        assert_eq!(prefix.prefix_length, 10);
        assert_eq!(prefix.num_layers, 12);

        let states = prefix.get_prefix_states().expect("operation failed in test");
        assert_eq!(states.len(), 12);
    }

    /// Prompt-tuning embeddings are shaped [num_virtual_tokens, hidden_size].
    #[test]
    fn test_prompt_tuning_embedding() {
        let prompt = PromptTuningEmbedding::new(5, 768, PromptInitMethod::Random)
            .expect("operation failed in test");
        assert_eq!(prompt.num_virtual_tokens, 5);
        assert_eq!(prompt.hidden_size, 768);
        assert_eq!(prompt.get_prompt_embeddings().shape(), vec![5, 768]);
    }

    /// Adding a layer registers it, and PEFT can be toggled off afterwards.
    #[test]
    fn test_peft_model() {
        let mut model = PeftModel::new(PeftConfig::default());

        let lora = LoRALayer::new(64, 64, 4, 8.0, 0.1, false).expect("operation failed in test");
        model.add_lora_layer("test_layer".to_string(), lora);

        assert_eq!(model.peft_layers.len(), 1);
        assert!(model.active);

        model.disable_peft();
        assert!(!model.active);
    }

    /// The PEFT config survives a JSON round-trip.
    #[test]
    fn test_peft_config_serialization() {
        let original = PeftConfig::default();
        let json = serde_json::to_string(&original).expect("JSON serialization failed");
        let restored: PeftConfig =
            serde_json::from_str(&json).expect("JSON deserialization failed");

        assert_eq!(original.method, restored.method);
        assert_eq!(original.r, restored.r);
        assert_eq!(original.alpha, restored.alpha);
    }

    /// ReLU clamps negatives to zero and passes non-negatives through.
    #[test]
    fn test_activation_functions() {
        let adapter = AdapterLayer::new(64, 16, ActivationType::ReLU, 0.0);
        let input =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).expect("Tensor from_vec failed");

        let result = adapter.relu(&input).expect("operation failed in test");
        let values = result.data().expect("operation failed in test");
        assert_eq!(values[0], 0.0); // ReLU(-1) = 0
        assert_eq!(values[1], 0.0); // ReLU(0) = 0
        assert_eq!(values[2], 1.0); // ReLU(1) = 1
        assert_eq!(values[3], 2.0); // ReLU(2) = 2
    }
}