1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use torsh_core::error::{Result, TorshError};
14use torsh_tensor::Tensor;
15
/// Summary of the differences between two model state dictionaries,
/// produced by `compare_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDiff {
    /// Parameter names present in both models.
    pub common_parameters: Vec<String>,
    /// Parameter names present only in the first model.
    pub only_in_first: Vec<String>,
    /// Parameter names present only in the second model.
    pub only_in_second: Vec<String>,
    /// Common parameters whose tensor shapes disagree.
    pub shape_differences: Vec<ShapeDifference>,
    /// Element-wise statistics for common, same-shape parameters.
    pub value_differences: Vec<ValueDifference>,
    /// Number of parameters in the (first, second) model.
    pub param_counts: (usize, usize),
    /// Estimated memory usage in bytes of the (first, second) model.
    pub memory_footprints: (u64, u64),
}
34
/// A common parameter whose tensor shape differs between the two models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeDifference {
    /// Name of the parameter (state-dict key).
    pub parameter_name: String,
    /// Shape of the parameter in the first model.
    pub shape_first: Vec<usize>,
    /// Shape of the parameter in the second model.
    pub shape_second: Vec<usize>,
}
42
/// Element-wise difference statistics for one common, same-shape parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueDifference {
    /// Name of the parameter (state-dict key).
    pub parameter_name: String,
    /// Mean of |a - b| over all elements.
    pub mean_absolute_diff: f64,
    /// Maximum of |a - b| over all elements.
    pub max_absolute_diff: f64,
    /// Mean absolute difference relative to the first tensor's mean, in percent
    /// (0.0 when that mean is near zero).
    pub relative_diff_percent: f64,
    /// Cosine similarity of the two flattened tensors (0.0 when either norm is zero).
    pub cosine_similarity: f64,
}
52
/// Tuning knobs for `compare_models`.
#[derive(Debug, Clone)]
pub struct ComparisonOptions {
    /// Whether to compute element-wise value statistics for same-shape parameters.
    pub compute_value_diffs: bool,
    /// Difference threshold.
    /// NOTE(review): currently passed to `compute_value_difference` but unused
    /// there — confirm intended filtering semantics.
    pub diff_threshold: f64,
    /// Upper bound on the number of parameters to value-compare.
    pub max_params_to_compare: usize,
}
63
64impl Default for ComparisonOptions {
65 fn default() -> Self {
66 Self {
67 compute_value_diffs: true,
68 diff_threshold: 1e-5,
69 max_params_to_compare: 1000,
70 }
71 }
72}
73
74pub fn compare_models(
76 model1_state: &HashMap<String, Tensor<f32>>,
77 model2_state: &HashMap<String, Tensor<f32>>,
78 options: Option<ComparisonOptions>,
79) -> Result<ModelDiff> {
80 let options = options.unwrap_or_default();
81
82 let keys1: std::collections::HashSet<_> = model1_state.keys().cloned().collect();
83 let keys2: std::collections::HashSet<_> = model2_state.keys().cloned().collect();
84
85 let common_parameters: Vec<String> = keys1.intersection(&keys2).cloned().collect();
86 let only_in_first: Vec<String> = keys1.difference(&keys2).cloned().collect();
87 let only_in_second: Vec<String> = keys2.difference(&keys1).cloned().collect();
88
89 let mut shape_differences = Vec::new();
90 let mut value_differences = Vec::new();
91
92 for param_name in &common_parameters {
94 let tensor1 = &model1_state[param_name];
95 let tensor2 = &model2_state[param_name];
96
97 let shape1 = tensor1.shape().dims().to_vec();
98 let shape2 = tensor2.shape().dims().to_vec();
99
100 if shape1 != shape2 {
101 shape_differences.push(ShapeDifference {
102 parameter_name: param_name.clone(),
103 shape_first: shape1,
104 shape_second: shape2,
105 });
106 } else if options.compute_value_diffs
107 && value_differences.len() < options.max_params_to_compare
108 {
109 if let Ok(diff) =
111 compute_value_difference(tensor1, tensor2, param_name, options.diff_threshold)
112 {
113 value_differences.push(diff);
114 }
115 }
116 }
117
118 let param_counts = (model1_state.len(), model2_state.len());
119 let memory_footprints = (
120 estimate_memory_footprint(model1_state),
121 estimate_memory_footprint(model2_state),
122 );
123
124 Ok(ModelDiff {
125 common_parameters,
126 only_in_first,
127 only_in_second,
128 shape_differences,
129 value_differences,
130 param_counts,
131 memory_footprints,
132 })
133}
134
135fn compute_value_difference(
137 tensor1: &Tensor<f32>,
138 tensor2: &Tensor<f32>,
139 param_name: &str,
140 _threshold: f64,
141) -> Result<ValueDifference> {
142 let data1 = tensor1.to_vec()?;
144 let data2 = tensor2.to_vec()?;
145
146 if data1.len() != data2.len() {
147 return Err(TorshError::InvalidArgument(
148 "Tensors must have same number of elements".to_string(),
149 ));
150 }
151
152 let mut sum_abs_diff = 0.0f64;
154 let mut max_abs_diff = 0.0f64;
155 let mut dot_product = 0.0f64;
156 let mut norm1_sq = 0.0f64;
157 let mut norm2_sq = 0.0f64;
158
159 for (&v1, &v2) in data1.iter().zip(data2.iter()) {
160 let v1 = v1 as f64;
161 let v2 = v2 as f64;
162 let abs_diff = (v1 - v2).abs();
163
164 sum_abs_diff += abs_diff;
165 max_abs_diff = max_abs_diff.max(abs_diff);
166 dot_product += v1 * v2;
167 norm1_sq += v1 * v1;
168 norm2_sq += v2 * v2;
169 }
170
171 let mean_absolute_diff = sum_abs_diff / data1.len() as f64;
172
173 let mean1 = data1.iter().map(|&x| x as f64).sum::<f64>() / data1.len() as f64;
175 let relative_diff_percent = if mean1.abs() > 1e-10 {
176 (mean_absolute_diff / mean1.abs()) * 100.0
177 } else {
178 0.0
179 };
180
181 let cosine_similarity = if norm1_sq > 0.0 && norm2_sq > 0.0 {
183 dot_product / (norm1_sq.sqrt() * norm2_sq.sqrt())
184 } else {
185 0.0
186 };
187
188 Ok(ValueDifference {
189 parameter_name: param_name.to_string(),
190 mean_absolute_diff,
191 max_absolute_diff: max_abs_diff,
192 relative_diff_percent,
193 cosine_similarity,
194 })
195}
196
197fn estimate_memory_footprint(state_dict: &HashMap<String, Tensor<f32>>) -> u64 {
199 state_dict
200 .values()
201 .map(|tensor| {
202 let num_elements = tensor.shape().numel();
203 (num_elements * std::mem::size_of::<f32>()) as u64
204 })
205 .sum()
206}
207
/// Configuration for `create_model_ensemble`.
#[derive(Debug, Clone)]
pub struct EnsembleConfig {
    /// Per-model weights; replaced by uniform weights when the count does
    /// not match the number of models.
    pub weights: Vec<f32>,
    /// Normalise the weights to sum to 1 before averaging.
    pub normalize_weights: bool,
    /// Intended combination strategy.
    /// NOTE(review): not consulted by `create_model_ensemble` in this file —
    /// averaging is always performed. Confirm whether it is used elsewhere.
    pub voting_strategy: VotingStrategy,
}
218
/// Strategy for combining ensemble members.
///
/// NOTE(review): the visible code (`create_model_ensemble`) performs
/// weighted parameter averaging regardless of the selected variant —
/// confirm whether the other variants are handled elsewhere.
#[derive(Debug, Clone, Copy)]
pub enum VotingStrategy {
    /// Unweighted average.
    Average,
    /// Average using `EnsembleConfig::weights`.
    WeightedAverage,
    /// Majority vote.
    MajorityVote,
    /// Soft (probability-based) voting.
    SoftVoting,
}
231
232impl Default for EnsembleConfig {
233 fn default() -> Self {
234 Self {
235 weights: vec![1.0],
236 normalize_weights: true,
237 voting_strategy: VotingStrategy::WeightedAverage,
238 }
239 }
240}
241
242pub fn create_model_ensemble(
244 models: &[HashMap<String, Tensor<f32>>],
245 config: Option<EnsembleConfig>,
246) -> Result<HashMap<String, Tensor<f32>>> {
247 if models.is_empty() {
248 return Err(TorshError::InvalidArgument(
249 "Cannot create ensemble from empty model list".to_string(),
250 ));
251 }
252
253 let config = config.unwrap_or_default();
254 let mut weights = config.weights.clone();
255
256 if weights.len() != models.len() {
258 weights = vec![1.0; models.len()];
259 }
260
261 if config.normalize_weights {
263 let sum: f32 = weights.iter().sum();
264 if sum > 0.0 {
265 weights.iter_mut().for_each(|w| *w /= sum);
266 }
267 }
268
269 let param_keys: std::collections::HashSet<_> = models[0].keys().cloned().collect();
271
272 let mut ensemble_state = HashMap::new();
273
274 for param_name in param_keys {
275 let tensors: Vec<&Tensor<f32>> = models.iter().filter_map(|m| m.get(¶m_name)).collect();
277
278 if tensors.len() != models.len() {
279 continue; }
281
282 if let Ok(averaged) = weighted_average_tensors(&tensors, &weights) {
284 ensemble_state.insert(param_name, averaged);
285 }
286 }
287
288 Ok(ensemble_state)
289}
290
291fn weighted_average_tensors(tensors: &[&Tensor<f32>], weights: &[f32]) -> Result<Tensor<f32>> {
293 if tensors.is_empty() {
294 return Err(TorshError::InvalidArgument("Empty tensor list".to_string()));
295 }
296
297 let shape = tensors[0].shape();
299 for tensor in &tensors[1..] {
300 if tensor.shape() != shape {
301 return Err(TorshError::InvalidArgument(
302 "All tensors must have the same shape".to_string(),
303 ));
304 }
305 }
306
307 let data_vecs: Vec<Vec<f32>> = tensors
309 .iter()
310 .map(|t| t.to_vec())
311 .collect::<Result<Vec<_>>>()?;
312 let num_elements = data_vecs[0].len();
313
314 let mut result = vec![0.0f32; num_elements];
315 for (tensor_data, weight) in data_vecs.iter().zip(weights.iter()) {
316 for (i, value) in tensor_data.iter().enumerate() {
317 result[i] += value * weight;
318 }
319 }
320
321 Tensor::from_data(result, shape.dims().to_vec(), torsh_core::DeviceType::Cpu)
323}
324
/// Statistics describing the outcome of a model quantization pass.
/// NOTE(review): no producer of this type is visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Size of the original model in bytes.
    pub original_size_bytes: u64,
    /// Size of the quantized model in bytes.
    pub quantized_size_bytes: u64,
    /// original / quantized size ratio.
    pub compression_ratio: f32,
    /// Number of parameters that were quantized.
    pub parameters_quantized: usize,
    /// Mean quantization error.
    pub mean_quantization_error: f64,
    /// Maximum quantization error.
    pub max_quantization_error: f64,
}
335
/// Metadata captured while converting a model between formats.
/// NOTE(review): no producer of this type is visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversionMetadata {
    /// Format the model was converted from.
    pub source_format: String,
    /// Format the model was converted to.
    pub target_format: String,
    /// Wall-clock conversion time in milliseconds.
    pub conversion_time_ms: u64,
    /// Non-fatal issues encountered during conversion.
    pub warnings: Vec<String>,
    /// Operations that could not be converted.
    pub unsupported_operations: Vec<String>,
}
345
346pub fn load_model_auto(path: &Path) -> Result<HashMap<String, Tensor<f32>>> {
348 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
349
350 match extension {
351 "torsh" | "pth" | "pt" => load_torsh_model(path),
352 "onnx" => load_onnx_model_state(path),
353 "h5" | "keras" => load_keras_model(path),
354 _ => Err(TorshError::InvalidArgument(format!(
355 "Unsupported model format: {}",
356 extension
357 ))),
358 }
359}
360
/// Load a Torsh/PyTorch-style checkpoint.
///
/// NOTE(review): placeholder — ignores `_path` and silently returns an
/// empty state dict; real deserialization is not yet implemented.
fn load_torsh_model(_path: &Path) -> Result<HashMap<String, Tensor<f32>>> {
    Ok(HashMap::new())
}
366
/// Load parameter tensors from an ONNX model.
///
/// NOTE(review): placeholder — ignores `_path` and silently returns an
/// empty state dict; real ONNX parsing is not yet implemented.
fn load_onnx_model_state(_path: &Path) -> Result<HashMap<String, Tensor<f32>>> {
    Ok(HashMap::new())
}
372
/// Load parameter tensors from a Keras/HDF5 model.
///
/// NOTE(review): placeholder — ignores `_path` and silently returns an
/// empty state dict; real HDF5 parsing is not yet implemented.
fn load_keras_model(_path: &Path) -> Result<HashMap<String, Tensor<f32>>> {
    Ok(HashMap::new())
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::DeviceType;

    // One shared parameter ("layer1.weight") plus one parameter unique to
    // each model; checks the set partitioning and counts in ModelDiff.
    #[test]
    fn test_model_comparison() {
        let mut model1 = HashMap::new();
        let mut model2 = HashMap::new();

        let tensor1 =
            Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap();
        let tensor2 =
            Tensor::from_data(vec![1.1, 2.1, 3.1, 4.1], vec![2, 2], DeviceType::Cpu).unwrap();
        model1.insert("layer1.weight".to_string(), tensor1);
        model2.insert("layer1.weight".to_string(), tensor2);

        // Only in model1.
        let tensor3 = Tensor::from_data(vec![5.0, 6.0], vec![2], DeviceType::Cpu).unwrap();
        model1.insert("layer1.bias".to_string(), tensor3);

        // Only in model2.
        let tensor4 = Tensor::from_data(vec![7.0, 8.0], vec![2], DeviceType::Cpu).unwrap();
        model2.insert("layer2.weight".to_string(), tensor4);

        let diff = compare_models(&model1, &model2, None).unwrap();

        assert_eq!(diff.common_parameters.len(), 1);
        assert_eq!(diff.only_in_first.len(), 1);
        assert_eq!(diff.only_in_second.len(), 1);
        assert_eq!(diff.param_counts, (2, 2));
    }

    // 4 + 2 = 6 f32 elements × 4 bytes = 24 bytes.
    #[test]
    fn test_memory_footprint() {
        let mut model = HashMap::new();
        let tensor1 =
            Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap();
        let tensor2 = Tensor::from_data(vec![5.0, 6.0], vec![2], DeviceType::Cpu).unwrap();

        model.insert("weight".to_string(), tensor1);
        model.insert("bias".to_string(), tensor2);

        let footprint = estimate_memory_footprint(&model);
        assert_eq!(footprint, 24);
    }

    // Equal weights 0.5/0.5 → element-wise midpoint of the two tensors.
    #[test]
    fn test_weighted_average() {
        let tensor1 = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu).unwrap();
        let tensor2 = Tensor::from_data(vec![3.0, 4.0], vec![2], DeviceType::Cpu).unwrap();

        let tensors = vec![&tensor1, &tensor2];
        let weights = vec![0.5, 0.5];

        let result = weighted_average_tensors(&tensors, &weights).unwrap();
        let result_data = result.to_vec().unwrap();

        assert_eq!(result_data.len(), 2);
        assert!((result_data[0] - 2.0).abs() < 1e-5);
        assert!((result_data[1] - 3.0).abs() < 1e-5);
    }

    // Two one-parameter models with explicit 0.5/0.5 weights (normalization
    // off) should yield the same midpoint as test_weighted_average.
    #[test]
    fn test_ensemble_creation() {
        let mut model1 = HashMap::new();
        let mut model2 = HashMap::new();

        let tensor1 = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu).unwrap();
        let tensor2 = Tensor::from_data(vec![3.0, 4.0], vec![2], DeviceType::Cpu).unwrap();

        model1.insert("weight".to_string(), tensor1);
        model2.insert("weight".to_string(), tensor2);

        let models = vec![model1, model2];
        let config = EnsembleConfig {
            weights: vec![0.5, 0.5],
            normalize_weights: false,
            voting_strategy: VotingStrategy::WeightedAverage,
        };

        let ensemble = create_model_ensemble(&models, Some(config)).unwrap();

        assert_eq!(ensemble.len(), 1);
        let result = &ensemble["weight"];
        let result_data = result.to_vec().unwrap();
        assert!((result_data[0] - 2.0).abs() < 1e-5);
        assert!((result_data[1] - 3.0).abs() < 1e-5);
    }
}