//! Model merging utilities: strategies for combining the parameters of
//! multiple trained models, LoRA adapter merging and extraction, and model
//! soups.

use std::collections::HashMap;

use torsh_core::error::Result as TorshResult;
use torsh_nn::{Module, Parameter};
use torsh_tensor::Tensor;

use crate::{ModelError, ModelResult};

/// Strategy used to combine parameters from multiple models.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MergeStrategy {
    /// Uniform average of all model parameters.
    Average,
    /// Weighted average; weights must sum to 1.0.
    WeightedAverage,
    /// Exponential moving average with smoothing factor `alpha`.
    ExponentialMovingAverage { alpha: f32 },
    /// Average of task vectors (deltas from a base model) added back to the base.
    TaskArithmetic,
    /// Spherical linear interpolation between exactly two models.
    Slerp { t: f32 },
    /// Keep the values with the largest magnitude across models.
    MaxMagnitude,
    /// Merge only where models agree within a threshold.
    Consensus { threshold: f32 },
}

/// Merges the parameters of multiple models according to a [`MergeStrategy`].
pub struct ModelMerger {
    /// Strategy applied when merging.
    strategy: MergeStrategy,
    /// Per-model weights for [`MergeStrategy::WeightedAverage`].
    weights: Option<Vec<f32>>,
    /// Base parameters for [`MergeStrategy::TaskArithmetic`].
    base_model: Option<HashMap<String, Parameter>>,
}

impl ModelMerger {
    /// Creates a merger using the uniform [`MergeStrategy::Average`].
    pub fn new() -> Self {
        Self {
            strategy: MergeStrategy::Average,
            weights: None,
            base_model: None,
        }
    }

    /// Creates a merger using [`MergeStrategy::WeightedAverage`].
    ///
    /// The weights must be non-empty and sum to 1.0 (within a small tolerance).
    pub fn with_weights(weights: Vec<f32>) -> ModelResult<Self> {
        if weights.is_empty() {
            return Err(ModelError::ValidationError {
                reason: "Weights vector cannot be empty".to_string(),
            });
        }

        let sum: f32 = weights.iter().sum();
        if (sum - 1.0).abs() > 1e-5 {
            return Err(ModelError::ValidationError {
                reason: format!("Weights must sum to 1.0, got {}", sum),
            });
        }

        Ok(Self {
            strategy: MergeStrategy::WeightedAverage,
            weights: Some(weights),
            base_model: None,
        })
    }

    /// Creates a merger using [`MergeStrategy::ExponentialMovingAverage`].
    ///
    /// `alpha` must lie in `[0, 1]`.
    pub fn with_ema(alpha: f32) -> ModelResult<Self> {
        if !(0.0..=1.0).contains(&alpha) {
            return Err(ModelError::ValidationError {
                reason: format!("Alpha must be between 0 and 1, got {}", alpha),
            });
        }

        Ok(Self {
            strategy: MergeStrategy::ExponentialMovingAverage { alpha },
            weights: None,
            base_model: None,
        })
    }

    /// Creates a merger using [`MergeStrategy::Slerp`].
    ///
    /// `t` must lie in `[0, 1]`; merging requires exactly two models.
    pub fn with_slerp(t: f32) -> ModelResult<Self> {
        if !(0.0..=1.0).contains(&t) {
            return Err(ModelError::ValidationError {
                reason: format!("t must be between 0 and 1, got {}", t),
            });
        }

        Ok(Self {
            strategy: MergeStrategy::Slerp { t },
            weights: None,
            base_model: None,
        })
    }

    /// Creates a merger using [`MergeStrategy::TaskArithmetic`], capturing the
    /// base model's parameters so task vectors can be computed against them.
    pub fn with_task_arithmetic(base_model: &dyn Module) -> Self {
        Self {
            strategy: MergeStrategy::TaskArithmetic,
            weights: None,
            base_model: Some(base_model.parameters()),
        }
    }

    /// Replaces the current merge strategy.
    pub fn set_strategy(&mut self, strategy: MergeStrategy) {
        self.strategy = strategy;
    }
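
    // Usage sketch for the constructors above (the values are illustrative
    // only):
    //
    //     let uniform = ModelMerger::new();
    //     let weighted = ModelMerger::with_weights(vec![0.7, 0.3])?;
    //     let ema = ModelMerger::with_ema(0.99)?;
    //     let slerp = ModelMerger::with_slerp(0.5)?;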

    /// Merges the parameters of `models` according to the configured strategy.
    ///
    /// All models must expose the same parameter names as `models[0]`.
    pub fn merge_models(&self, models: &[&dyn Module]) -> ModelResult<HashMap<String, Parameter>> {
        if models.is_empty() {
            return Err(ModelError::ValidationError {
                reason: "Cannot merge empty model list".to_string(),
            });
        }

        if models.len() == 1 {
            return Ok(models[0].parameters());
        }

        if let Some(ref weights) = self.weights {
            if weights.len() != models.len() {
                return Err(ModelError::ValidationError {
                    reason: format!(
                        "Number of weights ({}) must match number of models ({})",
                        weights.len(),
                        models.len()
                    ),
                });
            }
        }

        // Collect each model's parameter map once instead of re-querying it
        // for every parameter name.
        let all_params: Vec<HashMap<String, Parameter>> =
            models.iter().map(|m| m.parameters()).collect();
        let param_names: Vec<String> = all_params[0].keys().cloned().collect();

        for (i, model_params) in all_params.iter().enumerate().skip(1) {
            for name in &param_names {
                if !model_params.contains_key(name) {
                    return Err(ModelError::ValidationError {
                        reason: format!(
                            "Model {} missing parameter '{}' present in model 0",
                            i, name
                        ),
                    });
                }
            }
        }

        let mut merged_params = HashMap::new();

        for name in &param_names {
            let tensor_arcs: Vec<_> = all_params
                .iter()
                .map(|params| {
                    params
                        .get(name)
                        .expect("parameter should exist in all models")
                        .tensor()
                })
                .collect();

            let merged_tensor = match self.strategy {
                MergeStrategy::Average => self.average_tensors(&tensor_arcs)?,
                MergeStrategy::WeightedAverage => self.weighted_average_tensors(
                    &tensor_arcs,
                    self.weights
                        .as_ref()
                        .expect("weights should be set for weighted average strategy"),
                )?,
                MergeStrategy::ExponentialMovingAverage { alpha } => {
                    self.ema_tensors(&tensor_arcs, alpha)?
                }
                MergeStrategy::TaskArithmetic => {
                    self.task_arithmetic_tensors(&tensor_arcs, name)?
                }
                MergeStrategy::Slerp { t } => {
                    if tensor_arcs.len() != 2 {
                        return Err(ModelError::ValidationError {
                            reason: "SLERP requires exactly 2 models".to_string(),
                        });
                    }
                    self.slerp_tensors(&tensor_arcs[0], &tensor_arcs[1], t)?
                }
                MergeStrategy::MaxMagnitude => self.max_magnitude_tensors(&tensor_arcs)?,
                MergeStrategy::Consensus { threshold } => {
                    self.consensus_tensors(&tensor_arcs, threshold)?
                }
            };

            merged_params.insert(
                name.clone(),
                Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(
                    merged_tensor,
                ))),
            );
        }

        Ok(merged_params)
    }
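
    // Merge sketch, assuming `model_a` and `model_b` are values of
    // (hypothetical) types implementing `Module`:
    //
    //     let merger = ModelMerger::with_slerp(0.5)?;
    //     let merged: HashMap<String, Parameter> =
    //         merger.merge_models(&[&model_a, &model_b])?;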

    /// Element-wise mean of all tensors: `sum(t_i) / n`.
    fn average_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
    ) -> TorshResult<Tensor> {
        if tensor_arcs.is_empty() {
            return Err(torsh_core::TorshError::InvalidArgument(
                "Cannot average empty tensor list".to_string(),
            ));
        }

        let first = tensor_arcs[0].read();
        let mut sum = first.clone();
        drop(first);

        for tensor_arc in &tensor_arcs[1..] {
            let tensor = tensor_arc.read();
            sum = sum.add(&*tensor)?;
        }

        sum.div_scalar(tensor_arcs.len() as f32)
    }
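
    // Worked example: averaging [1.0, 2.0] and [3.0, 4.0] yields
    // ([1.0, 2.0] + [3.0, 4.0]) / 2 = [2.0, 3.0].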

    /// Element-wise weighted sum of all tensors: `sum(w_i * t_i)`.
    fn weighted_average_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
        weights: &[f32],
    ) -> TorshResult<Tensor> {
        if tensor_arcs.is_empty() || weights.is_empty() {
            return Err(torsh_core::TorshError::InvalidArgument(
                "Cannot average empty tensor or weight list".to_string(),
            ));
        }

        let first = tensor_arcs[0].read();
        let mut result = first.mul_scalar(weights[0])?;
        drop(first);

        for (tensor_arc, &weight) in tensor_arcs.iter().zip(weights.iter()).skip(1) {
            let tensor = tensor_arc.read();
            let weighted = tensor.mul_scalar(weight)?;
            result = result.add(&weighted)?;
        }

        Ok(result)
    }

    /// Folds the tensors with the EMA recurrence
    /// `result = alpha * t_i + (1 - alpha) * result`, starting from `t_0`.
    fn ema_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
        alpha: f32,
    ) -> TorshResult<Tensor> {
        if tensor_arcs.is_empty() {
            return Err(torsh_core::TorshError::InvalidArgument(
                "Cannot compute EMA of empty tensor list".to_string(),
            ));
        }

        let first = tensor_arcs[0].read();
        let mut result = first.clone();
        drop(first);

        for tensor_arc in &tensor_arcs[1..] {
            let tensor = tensor_arc.read();
            let weighted_new = tensor.mul_scalar(alpha)?;
            let weighted_old = result.mul_scalar(1.0 - alpha)?;
            result = weighted_new.add(&weighted_old)?;
        }

        Ok(result)
    }

    /// Task arithmetic: averages the task vectors `t_i - base` and adds the
    /// result back onto the base parameter. Falls back to a plain average when
    /// no base model (or no matching base parameter) is available.
    fn task_arithmetic_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
        param_name: &str,
    ) -> TorshResult<Tensor> {
        if let Some(ref base_params) = self.base_model {
            if let Some(base_param) = base_params.get(param_name) {
                let base_tensor_arc = base_param.tensor();
                let base_tensor = base_tensor_arc.read();

                let mut task_vectors = Vec::new();
                for tensor_arc in tensor_arcs {
                    let tensor = tensor_arc.read();
                    let task_vector = tensor.sub(&*base_tensor)?;
                    task_vectors.push(task_vector);
                }
                drop(base_tensor);

                let task_arcs: Vec<_> = task_vectors
                    .into_iter()
                    .map(|t| std::sync::Arc::new(parking_lot::RwLock::new(t)))
                    .collect();

                let avg_task_vector = self.average_tensors(&task_arcs)?;

                let base_tensor = base_tensor_arc.read();
                base_tensor.add(&avg_task_vector)
            } else {
                self.average_tensors(tensor_arcs)
            }
        } else {
            self.average_tensors(tensor_arcs)
        }
    }
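
    // Task-arithmetic identity used above, written out: with base parameters
    // `theta_base` and fine-tuned parameters `theta_1..theta_n`,
    //
    //     tau_i  = theta_i - theta_base          // per-model task vector
    //     merged = theta_base + mean(tau_1..tau_n)
    //
    // which equals the plain average of the `theta_i`; task vectors become
    // useful once they are scaled or filtered independently.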

    /// Interpolates between two tensors. Note: despite the name, this
    /// currently computes linear interpolation, `(1 - t) * t1 + t * t2`,
    /// rather than true spherical interpolation along the arc between them.
    fn slerp_tensors(
        &self,
        tensor_arc1: &std::sync::Arc<parking_lot::RwLock<Tensor>>,
        tensor_arc2: &std::sync::Arc<parking_lot::RwLock<Tensor>>,
        t: f32,
    ) -> TorshResult<Tensor> {
        let tensor1 = tensor_arc1.read();
        let tensor2 = tensor_arc2.read();

        let result = tensor1.mul_scalar(1.0 - t)?;
        let weighted2 = tensor2.mul_scalar(t)?;
        result.add(&weighted2)
    }

    /// Intended to keep the largest-magnitude value per element across models.
    /// The current implementation is a placeholder that computes a running
    /// pairwise average rather than a true element-wise magnitude comparison.
    fn max_magnitude_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
    ) -> TorshResult<Tensor> {
        if tensor_arcs.is_empty() {
            return Err(torsh_core::TorshError::InvalidArgument(
                "Cannot compute max magnitude of empty tensor list".to_string(),
            ));
        }

        let first = tensor_arcs[0].read();
        let mut result = first.clone();
        drop(first);

        for tensor_arc in &tensor_arcs[1..] {
            let tensor = tensor_arc.read();
            // Placeholder: running average, not element-wise max magnitude.
            result = result.add(&*tensor)?.div_scalar(2.0)?;
        }

        Ok(result)
    }

    /// Intended to merge only where models agree within `threshold`. The
    /// threshold is currently unused and the method falls back to a plain
    /// average.
    fn consensus_tensors(
        &self,
        tensor_arcs: &[std::sync::Arc<parking_lot::RwLock<Tensor>>],
        _threshold: f32,
    ) -> TorshResult<Tensor> {
        if tensor_arcs.is_empty() {
            return Err(torsh_core::TorshError::InvalidArgument(
                "Cannot compute consensus of empty tensor list".to_string(),
            ));
        }

        self.average_tensors(tensor_arcs)
    }
}

impl Default for ModelMerger {
    fn default() -> Self {
        Self::new()
    }
}

/// Merges LoRA adapters into base weights and extracts low-rank adapters from
/// fine-tuned models.
pub struct LoRAMerger {
    /// Scaling factor applied to the low-rank update.
    alpha: f32,
    /// Target rank of the decomposition.
    rank: usize,
}

impl LoRAMerger {
    /// Creates a LoRA merger with the given scaling factor and rank.
    pub fn new(alpha: f32, rank: usize) -> Self {
        Self { alpha, rank }
    }

    /// Merges LoRA adapters into the base model's parameters.
    ///
    /// For every base parameter `W` with matching adapters `{name}.lora_a` and
    /// `{name}.lora_b`, computes `W + alpha * (B @ A)`. Note that this scales
    /// by `alpha` directly; some LoRA implementations scale by `alpha / rank`.
    pub fn merge_lora(
        &self,
        base_model: &dyn Module,
        lora_a: &HashMap<String, Parameter>,
        lora_b: &HashMap<String, Parameter>,
    ) -> ModelResult<HashMap<String, Parameter>> {
        let mut merged_params = base_model.parameters();

        for (name, base_param) in &merged_params.clone() {
            let lora_a_name = format!("{}.lora_a", name);
            let lora_b_name = format!("{}.lora_b", name);

            if let (Some(a_param), Some(b_param)) =
                (lora_a.get(&lora_a_name), lora_b.get(&lora_b_name))
            {
                let a_tensor = a_param.tensor();
                let b_tensor = b_param.tensor();
                let base_tensor = base_param.tensor();

                let a = a_tensor.read();
                let b = b_tensor.read();
                let base = base_tensor.read();

                let delta_w = b.matmul(&*a)?;
                let scaled_delta = delta_w.mul_scalar(self.alpha)?;

                let merged = base.add(&scaled_delta)?;

                merged_params.insert(
                    name.clone(),
                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(merged))),
                );
            }
        }

        Ok(merged_params)
    }
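
    // Shape sketch for the merge above, with hypothetical dimensions
    // d_out = 768, d_in = 768, rank r = 8:
    //
    //     A: [r, d_in]     B: [d_out, r]     B @ A: [d_out, d_in]
    //     W' = W + alpha * (B @ A)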

    /// Extracts LoRA adapters as the low-rank decomposition of the difference
    /// between fine-tuned and base parameters: `delta = W_ft - W_base ≈ B @ A`.
    pub fn extract_lora(
        &self,
        base_model: &dyn Module,
        finetuned_model: &dyn Module,
    ) -> ModelResult<(HashMap<String, Parameter>, HashMap<String, Parameter>)> {
        let base_params = base_model.parameters();
        let finetuned_params = finetuned_model.parameters();

        let mut lora_a = HashMap::new();
        let mut lora_b = HashMap::new();

        for (name, base_param) in &base_params {
            if let Some(finetuned_param) = finetuned_params.get(name) {
                let base_tensor = base_param.tensor();
                let finetuned_tensor = finetuned_param.tensor();

                let base = base_tensor.read();
                let finetuned = finetuned_tensor.read();

                let delta = finetuned.sub(&*base)?;

                // Only 2D weight matrices are decomposed; biases and other
                // non-2D parameters are skipped instead of failing extraction.
                if delta.shape().dims().len() != 2 {
                    continue;
                }

                let (a, b) = self.low_rank_decomposition(&delta)?;

                lora_a.insert(
                    format!("{}.lora_a", name),
                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(a))),
                );
                lora_b.insert(
                    format!("{}.lora_b", name),
                    Parameter::from_tensor(std::sync::Arc::new(parking_lot::RwLock::new(b))),
                );
            }
        }

        Ok((lora_a, lora_b))
    }

    /// Rank-`self.rank` factorization of `tensor` via truncated SVD, returning
    /// `(a, b)` with `a: [rank, cols]` and `b: [rows, rank]` so that
    /// `b.matmul(&a)` approximates `tensor`, matching `merge_lora`'s `B @ A`
    /// orientation.
    fn low_rank_decomposition(&self, tensor: &Tensor) -> TorshResult<(Tensor, Tensor)> {
        let shape = tensor.shape();

        if shape.dims().len() != 2 {
            return Err(torsh_core::TorshError::InvalidArgument(
                "LoRA decomposition requires 2D tensor".to_string(),
            ));
        }

        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
        let rank = self.rank.min(rows).min(cols);

        let (u, s, vt) = torsh_linalg::decomposition::svd(tensor, false)?;

        // Split sqrt(S) between the two factors: B = U * sqrt(S), A = sqrt(S) * V^T.
        let mut b_data = Vec::with_capacity(rows * rank);
        let mut a_data = Vec::with_capacity(rank * cols);

        for i in 0..rows {
            for j in 0..rank {
                let s_val = s.get(&[j])?.sqrt();
                let u_val = u.get(&[i, j])?;
                b_data.push(u_val * s_val);
            }
        }

        for i in 0..rank {
            let s_val = s.get(&[i])?.sqrt();
            for j in 0..cols {
                let vt_val = vt.get(&[i, j])?;
                a_data.push(s_val * vt_val);
            }
        }

        let b = Tensor::from_data(b_data, vec![rows, rank], tensor.device())?;
        let a = Tensor::from_data(a_data, vec![rank, cols], tensor.device())?;

        Ok((a, b))
    }
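
    // Why sqrt(S) is split across both factors: with the truncated SVD
    // delta ≈ U_r S_r V_rᵀ, setting B = U_r √S_r and A = √S_r V_rᵀ gives
    // B @ A = U_r S_r V_rᵀ, the best rank-r approximation of delta
    // (Eckart–Young), while keeping both factors on a comparable numeric
    // scale.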
}

/// Combines multiple models into a "model soup" by averaging their weights.
pub struct ModelSoup {
    /// Candidate models to include in the soup.
    models: Vec<Box<dyn Module>>,
    /// Minimum score improvement for greedy inclusion (currently unused).
    greedy_threshold: Option<f32>,
}

impl ModelSoup {
    /// Creates an empty soup.
    pub fn new() -> Self {
        Self {
            models: Vec::new(),
            greedy_threshold: None,
        }
    }

    /// Adds a candidate model to the soup.
    pub fn add_model(&mut self, model: Box<dyn Module>) {
        self.models.push(model);
    }

    /// Sets the greedy-inclusion threshold (not yet consulted by `greedy_soup`).
    pub fn with_greedy_threshold(mut self, threshold: f32) -> Self {
        self.greedy_threshold = Some(threshold);
        self
    }

    /// Uniform soup: the plain average of all candidate models' parameters.
    pub fn uniform_soup(&self) -> ModelResult<HashMap<String, Parameter>> {
        let merger = ModelMerger::new();
        let model_refs: Vec<&dyn Module> = self.models.iter().map(|m| m.as_ref()).collect();
        merger.merge_models(&model_refs)
    }

    /// Greedy soup: keeps the parameter set that scores best under
    /// `validate_fn`. This is a simplified variant that tries each pairwise
    /// average of the first model with another candidate, rather than growing
    /// an accumulating ingredient set as in the original model-soups recipe.
    pub fn greedy_soup<F>(&self, validate_fn: F) -> ModelResult<HashMap<String, Parameter>>
    where
        F: Fn(&HashMap<String, Parameter>) -> f32,
    {
        if self.models.is_empty() {
            return Err(ModelError::ValidationError {
                reason: "Cannot create soup from empty model list".to_string(),
            });
        }

        let mut best_params = self.models[0].parameters();
        let mut best_score = validate_fn(&best_params);

        for model in &self.models[1..] {
            let merger = ModelMerger::new();

            let temp_soup = merger.merge_models(&[&*self.models[0], model.as_ref()])?;

            let temp_score = validate_fn(&temp_soup);

            if temp_score > best_score {
                best_params = temp_soup;
                best_score = temp_score;
            }
        }

        Ok(best_params)
    }
}
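
// Usage sketch, assuming `m1` and `m2` are boxed values of a (hypothetical)
// type implementing `Module`, and `score` evaluates parameters on a held-out
// validation set:
//
//     let mut soup = ModelSoup::new();
//     soup.add_model(m1);
//     soup.add_model(m2);
//     let uniform = soup.uniform_soup()?;
//     let greedy = soup.greedy_soup(|params| score(params))?;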

impl Default for ModelSoup {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_strategy_creation() {
        let merger = ModelMerger::new();
        assert_eq!(merger.strategy, MergeStrategy::Average);

        let weighted = ModelMerger::with_weights(vec![0.5, 0.5]).unwrap();
        assert_eq!(weighted.strategy, MergeStrategy::WeightedAverage);

        let ema = ModelMerger::with_ema(0.9).unwrap();
        assert!(matches!(
            ema.strategy,
            MergeStrategy::ExponentialMovingAverage { .. }
        ));
    }

    #[test]
    fn test_weight_validation() {
        // Weights that do not sum to 1.0 are rejected.
        let result = ModelMerger::with_weights(vec![0.3, 0.3]);
        assert!(result.is_err());

        // Weights summing to 1.0 are accepted.
        let result = ModelMerger::with_weights(vec![0.6, 0.4]);
        assert!(result.is_ok());
    }
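
    // Boundary checks for the other validated constructors; the probe values
    // are arbitrary.
    #[test]
    fn test_ema_and_slerp_range_validation() {
        assert!(ModelMerger::with_ema(-0.1).is_err());
        assert!(ModelMerger::with_ema(1.0).is_ok());
        assert!(ModelMerger::with_slerp(1.5).is_err());
        assert!(ModelMerger::with_slerp(0.0).is_ok());
    }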

    #[test]
    fn test_lora_merger_creation() {
        let lora = LoRAMerger::new(0.5, 8);
        assert_eq!(lora.alpha, 0.5);
        assert_eq!(lora.rank, 8);
    }
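
    // An empty weight vector must be rejected before the sum check runs.
    #[test]
    fn test_empty_weights_rejected() {
        assert!(ModelMerger::with_weights(vec![]).is_err());
    }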
}