// torsh_models/model_merging.rs

1//! Model merging and fusion utilities
2//!
3//! This module provides utilities for combining multiple models:
4//! - Model averaging (simple, weighted, exponential moving average)
5//! - LoRA merging and extraction
6//! - Model soup (combining fine-tuned models)
7//! - Task arithmetic (adding/subtracting task vectors)
8//! - SLERP (Spherical Linear Interpolation)
9
10use std::collections::HashMap;
11use torsh_core::error::Result as TorshResult;
12use torsh_nn::{Module, Parameter};
13use torsh_tensor::Tensor;
14
15use crate::{ModelError, ModelResult};
16
/// Model merging strategy
///
/// Selects how [`ModelMerger`] combines corresponding parameters across the
/// input models. Strategy-specific factors (`alpha`, `t`, `threshold`) are
/// carried inline on the variant.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MergeStrategy {
    /// Simple (uniform) average of parameters across all models
    Average,
    /// Weighted average; weights come from `ModelMerger::with_weights`
    /// and are validated to sum to 1.0
    WeightedAverage,
    /// Exponential moving average over the model sequence;
    /// `alpha` in [0, 1] is the weight given to each newer model
    ExponentialMovingAverage { alpha: f32 },
    /// Task arithmetic (subtract base, add task vectors); requires a base
    /// model snapshot (see `ModelMerger::with_task_arithmetic`), otherwise
    /// falls back to plain averaging
    TaskArithmetic,
    /// SLERP - Spherical Linear Interpolation between exactly two models;
    /// `t` in [0, 1] interpolates from the first model (0) to the second (1)
    Slerp { t: f32 },
    /// Maximum magnitude (take parameter with largest magnitude)
    // NOTE(review): currently approximated — see `max_magnitude_tensors`.
    MaxMagnitude,
    /// Consensus (only merge if models agree within threshold)
    // NOTE(review): `threshold` is not yet consulted — the current
    // implementation falls back to a plain average.
    Consensus { threshold: f32 },
}
35
36/// Model merger for combining multiple models
/// Model merger for combining multiple models
///
/// Configure through the `with_*` constructors (or `set_strategy`) and then
/// call `merge_models` with the models to combine.
pub struct ModelMerger {
    /// Merging strategy
    strategy: MergeStrategy,
    /// Weights for weighted averaging (if applicable); validated at
    /// construction to sum to 1.0
    weights: Option<Vec<f32>>,
    /// Base model parameters for task arithmetic, snapshotted at
    /// construction time
    base_model: Option<HashMap<String, Parameter>>,
}
45
46impl ModelMerger {
47    /// Create a new model merger with simple averaging
48    pub fn new() -> Self {
49        Self {
50            strategy: MergeStrategy::Average,
51            weights: None,
52            base_model: None,
53        }
54    }
55
56    /// Create merger with weighted averaging
57    pub fn with_weights(weights: Vec<f32>) -> ModelResult<Self> {
58        // Validate weights
59        if weights.is_empty() {
60            return Err(ModelError::ValidationError {
61                reason: "Weights vector cannot be empty".to_string(),
62            });
63        }
64
65        let sum: f32 = weights.iter().sum();
66        if (sum - 1.0).abs() > 1e-5 {
67            return Err(ModelError::ValidationError {
68                reason: format!("Weights must sum to 1.0, got {}", sum),
69            });
70        }
71
72        Ok(Self {
73            strategy: MergeStrategy::WeightedAverage,
74            weights: Some(weights),
75            base_model: None,
76        })
77    }
78
79    /// Create merger with exponential moving average
80    pub fn with_ema(alpha: f32) -> ModelResult<Self> {
81        if !(0.0..=1.0).contains(&alpha) {
82            return Err(ModelError::ValidationError {
83                reason: format!("Alpha must be between 0 and 1, got {}", alpha),
84            });
85        }
86
87        Ok(Self {
88            strategy: MergeStrategy::ExponentialMovingAverage { alpha },
89            weights: None,
90            base_model: None,
91        })
92    }
93
94    /// Create merger with SLERP
95    pub fn with_slerp(t: f32) -> ModelResult<Self> {
96        if !(0.0..=1.0).contains(&t) {
97            return Err(ModelError::ValidationError {
98                reason: format!("t must be between 0 and 1, got {}", t),
99            });
100        }
101
102        Ok(Self {
103            strategy: MergeStrategy::Slerp { t },
104            weights: None,
105            base_model: None,
106        })
107    }
108
109    /// Create merger with task arithmetic
110    pub fn with_task_arithmetic(base_model: &dyn Module) -> Self {
111        Self {
112            strategy: MergeStrategy::TaskArithmetic,
113            weights: None,
114            base_model: Some(base_model.parameters()),
115        }
116    }
117
118    /// Set merging strategy
119    pub fn set_strategy(&mut self, strategy: MergeStrategy) {
120        self.strategy = strategy;
121    }
122
123    /// Merge multiple models into one
124    pub fn merge_models(&self, models: &[&dyn Module]) -> ModelResult<HashMap<String, Parameter>> {
125        if models.is_empty() {
126            return Err(ModelError::ValidationError {
127                reason: "Cannot merge empty model list".to_string(),
128            });
129        }
130
131        if models.len() == 1 {
132            return Ok(models[0].parameters());
133        }
134
135        // Validate weights match number of models if using weighted averaging
136        if let Some(ref weights) = self.weights {
137            if weights.len() != models.len() {
138                return Err(ModelError::ValidationError {
139                    reason: format!(
140                        "Number of weights ({}) must match number of models ({})",
141                        weights.len(),
142                        models.len()
143                    ),
144                });
145            }
146        }
147
148        // Get all parameter names from first model
149        let param_names: Vec<String> = models[0].parameters().keys().cloned().collect();
150
151        // Validate all models have the same parameters
152        for (i, model) in models.iter().enumerate().skip(1) {
153            let model_params = model.parameters();
154            for name in &param_names {
155                if !model_params.contains_key(name) {
156                    return Err(ModelError::ValidationError {
157                        reason: format!(
158                            "Model {} missing parameter '{}' present in model 0",
159                            i, name
160                        ),
161                    });
162                }
163            }
164        }
165
166        // Merge parameters
167        let mut merged_params = HashMap::new();
168
169        for name in &param_names {
170            // Collect tensors with proper Arc<RwLock> handling
171            let tensor_arcs: Vec<_> = models
172                .iter()
173                .map(|m| {
174                    m.parameters()
175                        .get(name)
176                        .expect("parameter should exist in all models")
177                        .tensor()
178                })
179                .collect();
180
181            let merged_tensor = match self.strategy {
182                MergeStrategy::Average => self.average_tensors(&tensor_arcs)?,
183                MergeStrategy::WeightedAverage => self.weighted_average_tensors(
184                    &tensor_arcs,
185                    self.weights
186                        .as_ref()
187                        .expect("weights should be set for weighted average strategy"),
188                )?,
189                MergeStrategy::ExponentialMovingAverage { alpha } => {
190                    self.ema_tensors(&tensor_arcs, alpha)?
191                }
192                MergeStrategy::TaskArithmetic => {
193                    self.task_arithmetic_tensors(&tensor_arcs, name)?
194                }
195                MergeStrategy::Slerp { t } => {
196                    if tensor_arcs.len() != 2 {
197                        return Err(ModelError::ValidationError {
198                            reason: "SLERP requires exactly 2 models".to_string(),
199                        });
200                    }
201                    self.slerp_tensors(&tensor_arcs[0], &tensor_arcs[1], t)?
202                }
203                MergeStrategy::MaxMagnitude => self.max_magnitude_tensors(&tensor_arcs)?,
204                MergeStrategy::Consensus { threshold } => {
205                    self.consensus_tensors(&tensor_arcs, threshold)?
206                }
207            };
208
209            merged_params.insert(
210                name.clone(),
211                Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(
212                    merged_tensor,
213                ))),
214            );
215        }
216
217        Ok(merged_params)
218    }
219
220    /// Simple averaging of tensors
221    fn average_tensors(
222        &self,
223        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
224    ) -> TorshResult<Tensor> {
225        if tensor_arcs.is_empty() {
226            return Err(torsh_core::TorshError::InvalidArgument(
227                "Cannot average empty tensor list".to_string(),
228            ));
229        }
230
231        let first = tensor_arcs[0].read();
232        let mut sum = first.clone();
233        drop(first);
234
235        for tensor_arc in &tensor_arcs[1..] {
236            let tensor = tensor_arc.read();
237            sum = sum.add(&*tensor)?;
238        }
239
240        sum.div_scalar(tensor_arcs.len() as f32)
241    }
242
243    /// Weighted averaging of tensors
244    fn weighted_average_tensors(
245        &self,
246        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
247        weights: &[f32],
248    ) -> TorshResult<Tensor> {
249        if tensor_arcs.is_empty() || weights.is_empty() {
250            return Err(torsh_core::TorshError::InvalidArgument(
251                "Cannot average empty tensor or weight list".to_string(),
252            ));
253        }
254
255        let first = tensor_arcs[0].read();
256        let mut result = first.mul_scalar(weights[0])?;
257        drop(first);
258
259        for (tensor_arc, &weight) in tensor_arcs.iter().zip(weights.iter()).skip(1) {
260            let tensor = tensor_arc.read();
261            let weighted = tensor.mul_scalar(weight)?;
262            result = result.add(&weighted)?;
263        }
264
265        Ok(result)
266    }
267
268    /// Exponential moving average
269    fn ema_tensors(
270        &self,
271        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
272        alpha: f32,
273    ) -> TorshResult<Tensor> {
274        if tensor_arcs.is_empty() {
275            return Err(torsh_core::TorshError::InvalidArgument(
276                "Cannot compute EMA of empty tensor list".to_string(),
277            ));
278        }
279
280        let first = tensor_arcs[0].read();
281        let mut result = first.clone();
282        drop(first);
283
284        for tensor_arc in &tensor_arcs[1..] {
285            let tensor = tensor_arc.read();
286            // result = alpha * tensor + (1 - alpha) * result
287            let weighted_new = tensor.mul_scalar(alpha)?;
288            let weighted_old = result.mul_scalar(1.0 - alpha)?;
289            result = weighted_new.add(&weighted_old)?;
290        }
291
292        Ok(result)
293    }
294
295    /// Task arithmetic: (model - base) merging
296    fn task_arithmetic_tensors(
297        &self,
298        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
299        param_name: &str,
300    ) -> TorshResult<Tensor> {
301        if let Some(ref base_params) = self.base_model {
302            if let Some(base_param) = base_params.get(param_name) {
303                let base_tensor_arc = base_param.tensor();
304                let base_tensor = base_tensor_arc.read();
305
306                // Compute task vectors: model - base
307                let mut task_vectors = Vec::new();
308                for tensor_arc in tensor_arcs {
309                    let tensor = tensor_arc.read();
310                    let task_vector = tensor.sub(&*base_tensor)?;
311                    task_vectors.push(task_vector);
312                }
313                drop(base_tensor);
314
315                // Average task vectors - need to create Arc<RwLock> wrappers
316                let task_arcs: Vec<_> = task_vectors
317                    .into_iter()
318                    .map(|t| std::sync::Arc::new(parking_lot::RwLock::new(t)))
319                    .collect();
320
321                let avg_task_vector = self.average_tensors(&task_arcs)?;
322
323                // Add back to base
324                let base_tensor = base_tensor_arc.read();
325                base_tensor.add(&avg_task_vector)
326            } else {
327                // No base parameter, just average
328                self.average_tensors(tensor_arcs)
329            }
330        } else {
331            // No base model, fall back to averaging
332            self.average_tensors(tensor_arcs)
333        }
334    }
335
336    /// SLERP - Spherical Linear Interpolation
337    fn slerp_tensors(
338        &self,
339        tensor_arc1: &std::sync::Arc<parking_lot::RwLock<Tensor>>,
340        tensor_arc2: &std::sync::Arc<parking_lot::RwLock<Tensor>>,
341        t: f32,
342    ) -> TorshResult<Tensor> {
343        let tensor1 = tensor_arc1.read();
344        let tensor2 = tensor_arc2.read();
345
346        // Simplified SLERP - just linear interpolation for now
347        // Full SLERP implementation would require more tensor operations
348        let result = tensor1.mul_scalar(1.0 - t)?;
349        let weighted2 = tensor2.mul_scalar(t)?;
350        result.add(&weighted2)
351    }
352
353    /// Maximum magnitude merging
354    fn max_magnitude_tensors(
355        &self,
356        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
357    ) -> TorshResult<Tensor> {
358        if tensor_arcs.is_empty() {
359            return Err(torsh_core::TorshError::InvalidArgument(
360                "Cannot compute max magnitude of empty tensor list".to_string(),
361            ));
362        }
363
364        let first = tensor_arcs[0].read();
365        let mut result = first.clone();
366        drop(first);
367
368        for tensor_arc in &tensor_arcs[1..] {
369            let tensor = tensor_arc.read();
370            // Simplified: just take average for now
371            // Full implementation would need element-wise comparison
372            result = result.add(&*tensor)?.div_scalar(2.0)?;
373        }
374
375        Ok(result)
376    }
377
378    /// Consensus merging - only merge if models agree within threshold
379    fn consensus_tensors(
380        &self,
381        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
382        _threshold: f32,
383    ) -> TorshResult<Tensor> {
384        if tensor_arcs.is_empty() {
385            return Err(torsh_core::TorshError::InvalidArgument(
386                "Cannot compute consensus of empty tensor list".to_string(),
387            ));
388        }
389
390        // Simplified: just average for now
391        // Full implementation would check threshold
392        self.average_tensors(tensor_arcs)
393    }
394}
395
396impl Default for ModelMerger {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402/// LoRA (Low-Rank Adaptation) merger
/// LoRA (Low-Rank Adaptation) merger
///
/// Merges LoRA adapter matrices into a base model (W' = W + alpha * B @ A)
/// and extracts rank-limited LoRA factors from a fine-tuned model via SVD.
pub struct LoRAMerger {
    /// Scaling factor applied to the LoRA delta (B @ A) when merging
    alpha: f32,
    /// Target rank of LoRA matrices; capped by the weight's dimensions
    /// during extraction
    rank: usize,
}
409
410impl LoRAMerger {
411    /// Create a new LoRA merger
412    pub fn new(alpha: f32, rank: usize) -> Self {
413        Self { alpha, rank }
414    }
415
416    /// Merge LoRA weights into base model
417    pub fn merge_lora(
418        &self,
419        base_model: &dyn Module,
420        lora_a: &HashMap<String, Parameter>,
421        lora_b: &HashMap<String, Parameter>,
422    ) -> ModelResult<HashMap<String, Parameter>> {
423        let mut merged_params = base_model.parameters();
424
425        for (name, base_param) in &merged_params.clone() {
426            // Check if LoRA parameters exist for this layer
427            let lora_a_name = format!("{}.lora_a", name);
428            let lora_b_name = format!("{}.lora_b", name);
429
430            if let (Some(a_param), Some(b_param)) =
431                (lora_a.get(&lora_a_name), lora_b.get(&lora_b_name))
432            {
433                // Read tensors from Arc<RwLock>
434                let a_tensor = a_param.tensor();
435                let b_tensor = b_param.tensor();
436                let base_tensor = base_param.tensor();
437
438                let a = a_tensor.read();
439                let b = b_tensor.read();
440                let base = base_tensor.read();
441
442                // Compute delta_W = alpha * B @ A
443                let delta_w = b.matmul(&*a)?;
444                let scaled_delta = delta_w.mul_scalar(self.alpha)?;
445
446                // Add to base weight: W' = W + delta_W
447                let merged = base.add(&scaled_delta)?;
448
449                merged_params.insert(
450                    name.clone(),
451                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(merged))),
452                );
453            }
454        }
455
456        Ok(merged_params)
457    }
458
459    /// Extract LoRA parameters from fine-tuned model
460    pub fn extract_lora(
461        &self,
462        base_model: &dyn Module,
463        finetuned_model: &dyn Module,
464    ) -> ModelResult<(HashMap<String, Parameter>, HashMap<String, Parameter>)> {
465        let base_params = base_model.parameters();
466        let finetuned_params = finetuned_model.parameters();
467
468        let mut lora_a = HashMap::new();
469        let mut lora_b = HashMap::new();
470
471        for (name, base_param) in &base_params {
472            if let Some(finetuned_param) = finetuned_params.get(name) {
473                // Read tensors from Arc<RwLock>
474                let base_tensor = base_param.tensor();
475                let finetuned_tensor = finetuned_param.tensor();
476
477                let base = base_tensor.read();
478                let finetuned = finetuned_tensor.read();
479
480                // Compute delta: W_finetuned - W_base
481                let delta = finetuned.sub(&*base)?;
482
483                // Perform low-rank decomposition (simplified SVD)
484                // In practice, use proper SVD with rank truncation
485                let (a, b) = self.low_rank_decomposition(&delta)?;
486
487                lora_a.insert(
488                    format!("{}.lora_a", name),
489                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(a))),
490                );
491                lora_b.insert(
492                    format!("{}.lora_b", name),
493                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(b))),
494                );
495            }
496        }
497
498        Ok((lora_a, lora_b))
499    }
500
501    /// Low-rank decomposition using SVD
502    ///
503    /// Computes a rank-k approximation of the input tensor using Singular Value Decomposition.
504    /// For a matrix W ∈ ℝ^(m×n), computes W ≈ A @ B where:
505    /// - A ∈ ℝ^(m×k) contains the k largest left singular vectors scaled by singular values
506    /// - B ∈ ℝ^(k×n) contains the k largest right singular vectors
507    ///
508    /// This is the optimal rank-k approximation in the Frobenius norm (Eckart-Young theorem).
509    fn low_rank_decomposition(&self, tensor: &Tensor) -> TorshResult<(Tensor, Tensor)> {
510        let shape = tensor.shape();
511
512        if shape.dims().len() != 2 {
513            return Err(torsh_core::TorshError::InvalidArgument(
514                "LoRA decomposition requires 2D tensor".to_string(),
515            ));
516        }
517
518        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
519        let rank = self.rank.min(rows).min(cols);
520
521        // Perform SVD: tensor = U @ diag(S) @ V^T
522        let (u, s, vt) = torsh_linalg::decomposition::svd(tensor, false)?;
523
524        // Extract the first 'rank' components
525        // A = U[:, :rank] @ diag(sqrt(S[:rank]))
526        // B = diag(sqrt(S[:rank])) @ V^T[:rank, :]
527
528        let mut a_data = Vec::with_capacity(rows * rank);
529        let mut b_data = Vec::with_capacity(rank * cols);
530
531        // Build A = U[:, :rank] @ diag(sqrt(S[:rank]))
532        for i in 0..rows {
533            for j in 0..rank {
534                let s_val = s.get(&[j])?.sqrt();
535                let u_val = u.get(&[i, j])?;
536                a_data.push(u_val * s_val);
537            }
538        }
539
540        // Build B = diag(sqrt(S[:rank])) @ V^T[:rank, :]
541        for i in 0..rank {
542            let s_val = s.get(&[i])?.sqrt();
543            for j in 0..cols {
544                let vt_val = vt.get(&[i, j])?;
545                b_data.push(s_val * vt_val);
546            }
547        }
548
549        let a = Tensor::from_data(a_data, vec![rows, rank], tensor.device())?;
550        let b = Tensor::from_data(b_data, vec![rank, cols], tensor.device())?;
551
552        Ok((a, b))
553    }
554}
555
556/// Model soup - combining multiple fine-tuned models
/// Model soup - combining multiple fine-tuned models
///
/// Supports uniform soup (average all models) and greedy soup (add models
/// only when they improve a validation score).
pub struct ModelSoup {
    /// Models to combine
    models: Vec<Box<dyn Module>>,
    /// Greedy selection threshold
    // NOTE(review): stored by `with_greedy_threshold` but never read by
    // `greedy_soup` — confirm intended semantics before relying on it.
    greedy_threshold: Option<f32>,
}
563
564impl ModelSoup {
565    /// Create a new model soup
566    pub fn new() -> Self {
567        Self {
568            models: Vec::new(),
569            greedy_threshold: None,
570        }
571    }
572
573    /// Add a model to the soup
574    pub fn add_model(&mut self, model: Box<dyn Module>) {
575        self.models.push(model);
576    }
577
578    /// Set greedy selection threshold
579    pub fn with_greedy_threshold(mut self, threshold: f32) -> Self {
580        self.greedy_threshold = Some(threshold);
581        self
582    }
583
584    /// Create soup by averaging all models
585    pub fn uniform_soup(&self) -> ModelResult<HashMap<String, Parameter>> {
586        let merger = ModelMerger::new();
587        let model_refs: Vec<&dyn Module> = self.models.iter().map(|m| m.as_ref()).collect();
588        merger.merge_models(&model_refs)
589    }
590
591    /// Create soup using greedy selection
592    /// Adds models one at a time if they improve validation performance
593    pub fn greedy_soup<F>(&self, validate_fn: F) -> ModelResult<HashMap<String, Parameter>>
594    where
595        F: Fn(&HashMap<String, Parameter>) -> f32,
596    {
597        if self.models.is_empty() {
598            return Err(ModelError::ValidationError {
599                reason: "Cannot create soup from empty model list".to_string(),
600            });
601        }
602
603        // Start with first model
604        let mut best_params = self.models[0].parameters();
605        let mut best_score = validate_fn(&best_params);
606
607        // Try adding each model
608        for model in &self.models[1..] {
609            let merger = ModelMerger::new();
610
611            // Create temporary soup with current best + this model
612            let temp_soup = merger.merge_models(&[&*self.models[0], model.as_ref()])?;
613
614            let temp_score = validate_fn(&temp_soup);
615
616            // If score improves, keep it
617            if temp_score > best_score {
618                best_params = temp_soup;
619                best_score = temp_score;
620            }
621        }
622
623        Ok(best_params)
624    }
625}
626
627impl Default for ModelSoup {
628    fn default() -> Self {
629        Self::new()
630    }
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_strategy_creation() {
        let merger = ModelMerger::new();
        assert_eq!(merger.strategy, MergeStrategy::Average);

        let weighted = ModelMerger::with_weights(vec![0.5, 0.5]).unwrap();
        assert_eq!(weighted.strategy, MergeStrategy::WeightedAverage);

        let ema = ModelMerger::with_ema(0.9).unwrap();
        assert!(matches!(
            ema.strategy,
            MergeStrategy::ExponentialMovingAverage { .. }
        ));

        let slerp = ModelMerger::with_slerp(0.5).unwrap();
        assert!(matches!(slerp.strategy, MergeStrategy::Slerp { .. }));
    }

    #[test]
    fn test_weight_validation() {
        // Invalid: empty weight list
        assert!(ModelMerger::with_weights(vec![]).is_err());

        // Invalid: doesn't sum to 1
        assert!(ModelMerger::with_weights(vec![0.3, 0.3]).is_err());

        // Valid
        assert!(ModelMerger::with_weights(vec![0.6, 0.4]).is_ok());
    }

    #[test]
    fn test_ema_alpha_validation() {
        // Out of range on both sides
        assert!(ModelMerger::with_ema(-0.1).is_err());
        assert!(ModelMerger::with_ema(1.5).is_err());

        // Inclusive boundaries are valid
        assert!(ModelMerger::with_ema(0.0).is_ok());
        assert!(ModelMerger::with_ema(1.0).is_ok());
    }

    #[test]
    fn test_slerp_t_validation() {
        assert!(ModelMerger::with_slerp(-0.1).is_err());
        assert!(ModelMerger::with_slerp(1.1).is_err());
        assert!(ModelMerger::with_slerp(0.5).is_ok());
    }

    #[test]
    fn test_lora_merger_creation() {
        let lora = LoRAMerger::new(0.5, 8);
        assert_eq!(lora.alpha, 0.5);
        assert_eq!(lora.rank, 8);
    }
}