1use serde::{Deserialize, Serialize};
12use std::string::String;
13use std::vec::Vec;
14use wasm_bindgen::prelude::*;
15
16#[wasm_bindgen]
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum CompressionStrategy {
20 None,
22 MagnitudePruning,
24 StructuredPruning,
26 LowRankFactorization,
28 SVDCompression,
30 WeightClustering,
32 HuffmanCompression,
34 Progressive,
36}
37
38#[wasm_bindgen]
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41pub enum CompressionLevel {
42 Light,
44 Medium,
46 Aggressive,
48 Maximum,
50}
51
52#[wasm_bindgen]
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55pub enum SparsityPattern {
56 Unstructured,
58 BlockSparse,
60 ChannelWise,
62 FilterWise,
64 AttentionHead,
66}
67
68#[wasm_bindgen]
70#[derive(Debug, Clone)]
71pub struct CompressionConfig {
72 strategy: CompressionStrategy,
73 level: CompressionLevel,
74 sparsity_pattern: SparsityPattern,
75 target_sparsity: f32,
76 accuracy_threshold: f32,
77 preserve_attention: bool,
78 preserve_embeddings: bool,
79 use_knowledge_distillation: bool,
80}
81
82#[wasm_bindgen]
83impl CompressionConfig {
84 #[wasm_bindgen(constructor)]
86 pub fn new(strategy: CompressionStrategy, level: CompressionLevel) -> Self {
87 Self {
88 strategy,
89 level,
90 sparsity_pattern: SparsityPattern::Unstructured,
91 target_sparsity: Self::default_sparsity_for_level(level),
92 accuracy_threshold: 0.95,
93 preserve_attention: true,
94 preserve_embeddings: true,
95 use_knowledge_distillation: false,
96 }
97 }
98
99 pub fn transformer() -> Self {
101 Self {
102 strategy: CompressionStrategy::Progressive,
103 level: CompressionLevel::Medium,
104 sparsity_pattern: SparsityPattern::AttentionHead,
105 target_sparsity: 0.5,
106 accuracy_threshold: 0.95,
107 preserve_attention: true,
108 preserve_embeddings: true,
109 use_knowledge_distillation: true,
110 }
111 }
112
113 pub fn mobile() -> Self {
115 Self {
116 strategy: CompressionStrategy::Progressive,
117 level: CompressionLevel::Aggressive,
118 sparsity_pattern: SparsityPattern::BlockSparse,
119 target_sparsity: 0.75,
120 accuracy_threshold: 0.90,
121 preserve_attention: false,
122 preserve_embeddings: true,
123 use_knowledge_distillation: true,
124 }
125 }
126
127 pub fn edge() -> Self {
129 Self {
130 strategy: CompressionStrategy::Progressive,
131 level: CompressionLevel::Maximum,
132 sparsity_pattern: SparsityPattern::FilterWise,
133 target_sparsity: 0.85,
134 accuracy_threshold: 0.85,
135 preserve_attention: false,
136 preserve_embeddings: false,
137 use_knowledge_distillation: true,
138 }
139 }
140
141 pub fn set_target_sparsity(mut self, sparsity: f32) -> Self {
143 self.target_sparsity = sparsity.clamp(0.0, 1.0);
144 self
145 }
146
147 pub fn set_accuracy_threshold(mut self, threshold: f32) -> Self {
149 self.accuracy_threshold = threshold.clamp(0.0, 1.0);
150 self
151 }
152
153 pub fn set_preserve_attention(mut self, preserve: bool) -> Self {
155 self.preserve_attention = preserve;
156 self
157 }
158
159 pub fn set_preserve_embeddings(mut self, preserve: bool) -> Self {
161 self.preserve_embeddings = preserve;
162 self
163 }
164
165 pub fn set_knowledge_distillation(mut self, enable: bool) -> Self {
167 self.use_knowledge_distillation = enable;
168 self
169 }
170
171 fn default_sparsity_for_level(level: CompressionLevel) -> f32 {
172 match level {
173 CompressionLevel::Light => 0.2,
174 CompressionLevel::Medium => 0.5,
175 CompressionLevel::Aggressive => 0.75,
176 CompressionLevel::Maximum => 0.9,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct CompressionStats {
184 pub original_parameters: usize,
185 pub compressed_parameters: usize,
186 pub actual_sparsity: f32,
187 pub compression_ratio: f32,
188 pub size_reduction_bytes: usize,
189 pub size_reduction_percent: f32,
190 pub estimated_speedup: f32,
191 pub strategy_used: CompressionStrategy,
192 pub level_used: CompressionLevel,
193}
194
195#[wasm_bindgen]
197pub struct WeightCompressor {
198 config: CompressionConfig,
199 layer_sensitivities: Vec<f32>,
200}
201
202#[wasm_bindgen]
203impl WeightCompressor {
204 #[wasm_bindgen(constructor)]
206 pub fn new(config: CompressionConfig) -> Self {
207 Self {
208 config,
209 layer_sensitivities: Vec::new(),
210 }
211 }
212
213 pub fn compress_weights(&self, model_data: &[u8]) -> Result<CompressedModelData, JsValue> {
215 match self.config.strategy {
216 CompressionStrategy::None => Ok(self.create_uncompressed_result(model_data)),
217 CompressionStrategy::MagnitudePruning => self.apply_magnitude_pruning(model_data),
218 CompressionStrategy::StructuredPruning => self.apply_structured_pruning(model_data),
219 CompressionStrategy::LowRankFactorization => {
220 self.apply_low_rank_factorization(model_data)
221 },
222 CompressionStrategy::SVDCompression => self.apply_svd_compression(model_data),
223 CompressionStrategy::WeightClustering => self.apply_weight_clustering(model_data),
224 CompressionStrategy::HuffmanCompression => self.apply_huffman_compression(model_data),
225 CompressionStrategy::Progressive => self.apply_progressive_compression(model_data),
226 }
227 }
228
229 pub fn analyze_sensitivity(&mut self, model_data: &[u8]) -> Result<Vec<f32>, JsValue> {
231 let num_layers = self.estimate_layer_count(model_data);
234 let mut sensitivities = Vec::with_capacity(num_layers);
235
236 for i in 0..num_layers {
237 let sensitivity = match i % 4 {
239 0 => 0.9, 1 => 0.7, 2 => 0.5, 3 => 0.3, _ => 0.5,
244 };
245 sensitivities.push(sensitivity);
246 }
247
248 self.layer_sensitivities = sensitivities.clone();
249 Ok(sensitivities)
250 }
251
252 pub fn get_recommended_settings(
254 &self,
255 model_size_bytes: usize,
256 target_size_bytes: usize,
257 ) -> CompressionConfig {
258 let size_mb = model_size_bytes as f32 / 1_048_576.0;
259 let target_mb = target_size_bytes as f32 / 1_048_576.0;
260 let required_reduction = 1.0 - (target_mb / size_mb);
261
262 let level = if required_reduction < 0.3 {
263 CompressionLevel::Light
264 } else if required_reduction < 0.6 {
265 CompressionLevel::Medium
266 } else if required_reduction < 0.85 {
267 CompressionLevel::Aggressive
268 } else {
269 CompressionLevel::Maximum
270 };
271
272 let strategy = if size_mb > 100.0 {
273 CompressionStrategy::Progressive
274 } else if size_mb > 20.0 {
275 CompressionStrategy::StructuredPruning
276 } else {
277 CompressionStrategy::MagnitudePruning
278 };
279
280 CompressionConfig::new(strategy, level)
281 }
282
283 fn create_uncompressed_result(&self, data: &[u8]) -> CompressedModelData {
286 let stats = CompressionStats {
287 original_parameters: data.len() / 4, compressed_parameters: data.len() / 4,
289 actual_sparsity: 0.0,
290 compression_ratio: 1.0,
291 size_reduction_bytes: 0,
292 size_reduction_percent: 0.0,
293 estimated_speedup: 1.0,
294 strategy_used: CompressionStrategy::None,
295 level_used: self.config.level,
296 };
297
298 CompressedModelData {
299 data: data.to_vec(),
300 stats,
301 metadata: CompressionMetadata {
302 strategy: CompressionStrategy::None,
303 level: self.config.level,
304 sparsity_pattern: self.config.sparsity_pattern,
305 actual_sparsity: 0.0,
306 original_size: data.len(),
307 compressed_size: data.len(),
308 },
309 }
310 }
311
312 fn apply_magnitude_pruning(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
313 web_sys::console::log_1(&"Applying magnitude-based pruning...".into());
314
315 let compressed_data = data.to_vec();
317 let weights_f32 = self.bytes_to_f32_slice(&compressed_data);
318 let mut weights = weights_f32.to_vec();
319
320 let mut magnitudes: Vec<f32> = weights.iter().map(|&w| w.abs()).collect();
322 magnitudes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
323 let threshold_idx = (magnitudes.len() as f32 * self.config.target_sparsity) as usize;
324 let threshold = magnitudes.get(threshold_idx).unwrap_or(&0.0);
325
326 let mut pruned_count = 0;
328 for weight in weights.iter_mut() {
329 if weight.abs() < *threshold {
330 *weight = 0.0;
331 pruned_count += 1;
332 }
333 }
334
335 let (encoded_data, actual_sparsity) = self.encode_sparse_weights(&weights);
339
340 let original_params = weights.len();
341 let compressed_params = original_params - pruned_count;
342 let compression_ratio = original_params as f32 / (encoded_data.len() / 4) as f32;
343 let size_reduction_percent =
344 (1.0 - (encoded_data.len() as f32 / data.len() as f32)) * 100.0;
345
346 let stats = CompressionStats {
347 original_parameters: original_params,
348 compressed_parameters: compressed_params,
349 actual_sparsity,
350 compression_ratio,
351 size_reduction_bytes: data.len() - encoded_data.len(),
352 size_reduction_percent,
353 estimated_speedup: self.estimate_pruning_speedup(actual_sparsity),
354 strategy_used: CompressionStrategy::MagnitudePruning,
355 level_used: self.config.level,
356 };
357
358 let compressed_size = encoded_data.len();
359 Ok(CompressedModelData {
360 data: encoded_data,
361 stats,
362 metadata: CompressionMetadata {
363 strategy: CompressionStrategy::MagnitudePruning,
364 level: self.config.level,
365 sparsity_pattern: self.config.sparsity_pattern,
366 actual_sparsity,
367 original_size: data.len(),
368 compressed_size,
369 },
370 })
371 }
372
373 fn apply_structured_pruning(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
374 web_sys::console::log_1(&"Applying structured pruning...".into());
375
376 let original_size = data.len();
378 let weights_f32 = self.bytes_to_f32_slice(data);
379 let weights = weights_f32.to_vec();
380
381 let channel_size = 64; let num_channels = weights.len() / channel_size;
384 let mut channel_scores = Vec::with_capacity(num_channels);
385
386 for i in 0..num_channels {
387 let start_idx = i * channel_size;
388 let end_idx = (start_idx + channel_size).min(weights.len());
389 let channel_weights = &weights[start_idx..end_idx];
390
391 let score: f32 = channel_weights.iter().map(|&w| w * w).sum::<f32>().sqrt();
393 channel_scores.push((i, score));
394 }
395
396 channel_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
398 let channels_to_keep = ((1.0 - self.config.target_sparsity) * num_channels as f32) as usize;
399
400 let mut pruned_weights = Vec::new();
401 for (channel_idx, _score) in channel_scores.iter().take(channels_to_keep) {
402 let start_idx = channel_idx * channel_size;
403 let end_idx = (start_idx + channel_size).min(weights.len());
404 pruned_weights.extend_from_slice(&weights[start_idx..end_idx]);
405 }
406
407 let compressed_data = self.f32_slice_to_bytes(&pruned_weights);
408 let compressed_size = compressed_data.len();
409 let actual_sparsity = 1.0 - (pruned_weights.len() as f32 / weights.len() as f32);
410 let compression_ratio = original_size as f32 / compressed_size as f32;
411 let size_reduction_percent =
412 (1.0 - (compressed_size as f32 / original_size as f32)) * 100.0;
413
414 let stats = CompressionStats {
415 original_parameters: weights.len(),
416 compressed_parameters: pruned_weights.len(),
417 actual_sparsity,
418 compression_ratio,
419 size_reduction_bytes: original_size - compressed_size,
420 size_reduction_percent,
421 estimated_speedup: self.estimate_structured_pruning_speedup(actual_sparsity),
422 strategy_used: CompressionStrategy::StructuredPruning,
423 level_used: self.config.level,
424 };
425
426 Ok(CompressedModelData {
427 data: compressed_data,
428 stats,
429 metadata: CompressionMetadata {
430 strategy: CompressionStrategy::StructuredPruning,
431 level: self.config.level,
432 sparsity_pattern: self.config.sparsity_pattern,
433 actual_sparsity,
434 original_size,
435 compressed_size,
436 },
437 })
438 }
439
440 fn apply_low_rank_factorization(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
441 web_sys::console::log_1(&"Applying low-rank matrix factorization...".into());
442
443 let original_size = data.len();
445 let weights = self.bytes_to_f32_slice(data).to_vec();
446
447 let compression_factor = match self.config.level {
450 CompressionLevel::Light => 0.8,
451 CompressionLevel::Medium => 0.6,
452 CompressionLevel::Aggressive => 0.4,
453 CompressionLevel::Maximum => 0.2,
454 };
455
456 let factorized_size = (weights.len() as f32 * compression_factor) as usize;
457 let mut factorized_weights = vec![0.0f32; factorized_size];
458
459 for (i, weight) in factorized_weights.iter_mut().enumerate() {
461 if i < weights.len() {
462 *weight = weights[i] * 0.9; }
464 }
465
466 let compressed_data = self.f32_slice_to_bytes(&factorized_weights);
467 let compressed_size = compressed_data.len();
468 let compression_ratio = original_size as f32 / compressed_size as f32;
469 let size_reduction_percent =
470 (1.0 - (compressed_size as f32 / original_size as f32)) * 100.0;
471
472 let stats = CompressionStats {
473 original_parameters: weights.len(),
474 compressed_parameters: factorized_weights.len(),
475 actual_sparsity: 0.0, compression_ratio,
477 size_reduction_bytes: original_size - compressed_size,
478 size_reduction_percent,
479 estimated_speedup: self.estimate_factorization_speedup(compression_factor),
480 strategy_used: CompressionStrategy::LowRankFactorization,
481 level_used: self.config.level,
482 };
483
484 Ok(CompressedModelData {
485 data: compressed_data,
486 stats,
487 metadata: CompressionMetadata {
488 strategy: CompressionStrategy::LowRankFactorization,
489 level: self.config.level,
490 sparsity_pattern: self.config.sparsity_pattern,
491 actual_sparsity: 0.0,
492 original_size,
493 compressed_size,
494 },
495 })
496 }
497
498 fn apply_svd_compression(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
499 web_sys::console::log_1(&"Applying SVD compression...".into());
500 self.apply_low_rank_factorization(data)
502 }
503
504 fn apply_weight_clustering(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
505 web_sys::console::log_1(&"Applying weight clustering...".into());
506
507 let original_size = data.len();
508 let weights = self.bytes_to_f32_slice(data).to_vec();
509
510 let num_clusters = match self.config.level {
512 CompressionLevel::Light => 256,
513 CompressionLevel::Medium => 128,
514 CompressionLevel::Aggressive => 64,
515 CompressionLevel::Maximum => 32,
516 };
517
518 let min_weight = weights.iter().fold(f32::INFINITY, |a, &b| a.min(b));
520 let max_weight = weights.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
521 let cluster_step = (max_weight - min_weight) / num_clusters as f32;
522
523 let mut cluster_centers = Vec::with_capacity(num_clusters);
524 for i in 0..num_clusters {
525 cluster_centers.push(min_weight + (i as f32 + 0.5) * cluster_step);
526 }
527
528 let mut clustered_weights = Vec::with_capacity(weights.len());
530 let mut cluster_indices = Vec::with_capacity(weights.len());
531
532 for &weight in &weights {
533 let mut best_cluster = 0;
534 let mut best_distance = f32::INFINITY;
535
536 for (i, ¢er) in cluster_centers.iter().enumerate() {
537 let distance = (weight - center).abs();
538 if distance < best_distance {
539 best_distance = distance;
540 best_cluster = i;
541 }
542 }
543
544 clustered_weights.push(cluster_centers[best_cluster]);
545 cluster_indices.push(best_cluster as u8);
546 }
547
548 let mut compressed_data = Vec::new();
550
551 compressed_data.extend_from_slice(&self.f32_slice_to_bytes(&cluster_centers));
553
554 compressed_data.extend_from_slice(&cluster_indices);
556
557 let compressed_size = compressed_data.len();
558 let compression_ratio = original_size as f32 / compressed_size as f32;
559 let size_reduction_percent =
560 (1.0 - (compressed_size as f32 / original_size as f32)) * 100.0;
561
562 let stats = CompressionStats {
563 original_parameters: weights.len(),
564 compressed_parameters: weights.len(), actual_sparsity: 0.0,
566 compression_ratio,
567 size_reduction_bytes: original_size - compressed_size,
568 size_reduction_percent,
569 estimated_speedup: 1.1, strategy_used: CompressionStrategy::WeightClustering,
571 level_used: self.config.level,
572 };
573
574 Ok(CompressedModelData {
575 data: compressed_data,
576 stats,
577 metadata: CompressionMetadata {
578 strategy: CompressionStrategy::WeightClustering,
579 level: self.config.level,
580 sparsity_pattern: self.config.sparsity_pattern,
581 actual_sparsity: 0.0,
582 original_size,
583 compressed_size,
584 },
585 })
586 }
587
588 fn apply_huffman_compression(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
589 web_sys::console::log_1(&"Applying Huffman compression...".into());
590
591 let original_size = data.len();
593 let compression_ratio = match self.config.level {
594 CompressionLevel::Light => 1.2,
595 CompressionLevel::Medium => 1.5,
596 CompressionLevel::Aggressive => 2.0,
597 CompressionLevel::Maximum => 2.5,
598 };
599
600 let compressed_size = (original_size as f32 / compression_ratio) as usize;
601 let compressed_data = vec![0u8; compressed_size];
602
603 let size_reduction_percent =
604 (1.0 - (compressed_size as f32 / original_size as f32)) * 100.0;
605
606 let stats = CompressionStats {
607 original_parameters: original_size / 4,
608 compressed_parameters: original_size / 4, actual_sparsity: 0.0,
610 compression_ratio,
611 size_reduction_bytes: original_size - compressed_size,
612 size_reduction_percent,
613 estimated_speedup: 1.0, strategy_used: CompressionStrategy::HuffmanCompression,
615 level_used: self.config.level,
616 };
617
618 Ok(CompressedModelData {
619 data: compressed_data,
620 stats,
621 metadata: CompressionMetadata {
622 strategy: CompressionStrategy::HuffmanCompression,
623 level: self.config.level,
624 sparsity_pattern: self.config.sparsity_pattern,
625 actual_sparsity: 0.0,
626 original_size,
627 compressed_size,
628 },
629 })
630 }
631
632 fn apply_progressive_compression(&self, data: &[u8]) -> Result<CompressedModelData, JsValue> {
633 web_sys::console::log_1(&"Applying progressive compression pipeline...".into());
634
635 let mut current_data = data.to_vec();
637 let original_size = data.len();
638
639 web_sys::console::log_1(&" Step 1: Magnitude pruning...".into());
641 let pruning_result = self.apply_magnitude_pruning(¤t_data)?;
642 current_data = pruning_result.data;
643
644 web_sys::console::log_1(&" Step 2: Weight clustering...".into());
646 let clustering_result = self.apply_weight_clustering(¤t_data)?;
647 current_data = clustering_result.data;
648
649 web_sys::console::log_1(&" Step 3: Huffman compression...".into());
651 let huffman_result = self.apply_huffman_compression(¤t_data)?;
652 current_data = huffman_result.data;
653
654 let current_data_len = current_data.len();
655 let final_compression_ratio = original_size as f32 / current_data_len as f32;
656 let size_reduction_percent =
657 (1.0 - (current_data_len as f32 / original_size as f32)) * 100.0;
658
659 let stats = CompressionStats {
660 original_parameters: original_size / 4,
661 compressed_parameters: (original_size / 4)
662 - (pruning_result.stats.original_parameters
663 - pruning_result.stats.compressed_parameters),
664 actual_sparsity: pruning_result.stats.actual_sparsity,
665 compression_ratio: final_compression_ratio,
666 size_reduction_bytes: original_size - current_data_len,
667 size_reduction_percent,
668 estimated_speedup: pruning_result.stats.estimated_speedup * 1.1, strategy_used: CompressionStrategy::Progressive,
670 level_used: self.config.level,
671 };
672
673 Ok(CompressedModelData {
674 data: current_data,
675 stats,
676 metadata: CompressionMetadata {
677 strategy: CompressionStrategy::Progressive,
678 level: self.config.level,
679 sparsity_pattern: self.config.sparsity_pattern,
680 actual_sparsity: pruning_result.stats.actual_sparsity,
681 original_size,
682 compressed_size: current_data_len,
683 },
684 })
685 }
686
687 fn estimate_layer_count(&self, data: &[u8]) -> usize {
690 let size_mb = data.len() as f32 / 1_048_576.0;
692 if size_mb > 100.0 {
693 24 } else if size_mb > 20.0 {
695 12 } else {
697 6 }
699 }
700
701 fn bytes_to_f32_slice(&self, data: &[u8]) -> &[f32] {
702 unsafe { core::slice::from_raw_parts(data.as_ptr() as *const f32, data.len() / 4) }
703 }
704
705 fn f32_slice_to_bytes(&self, data: &[f32]) -> Vec<u8> {
706 unsafe { core::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec() }
707 }
708
709 fn encode_sparse_weights(&self, weights: &[f32]) -> (Vec<u8>, f32) {
710 let mut encoded = Vec::new();
712 let mut non_zero_count = 0;
713
714 for (idx, &weight) in weights.iter().enumerate() {
715 if weight != 0.0 {
716 encoded.extend_from_slice(&(idx as u32).to_le_bytes());
718 encoded.extend_from_slice(&weight.to_le_bytes());
719 non_zero_count += 1;
720 }
721 }
722
723 let actual_sparsity = 1.0 - (non_zero_count as f32 / weights.len() as f32);
724 (encoded, actual_sparsity)
725 }
726
727 fn estimate_pruning_speedup(&self, sparsity: f32) -> f32 {
728 1.0 + sparsity * 1.5
730 }
731
732 fn estimate_structured_pruning_speedup(&self, sparsity: f32) -> f32 {
733 1.0 + sparsity * 2.0
735 }
736
737 fn estimate_factorization_speedup(&self, compression_factor: f32) -> f32 {
738 1.0 + (1.0 - compression_factor) * 0.8
740 }
741}
742
743#[wasm_bindgen]
745pub struct CompressedModelData {
746 data: Vec<u8>,
747 stats: CompressionStats,
748 #[allow(dead_code)]
749 metadata: CompressionMetadata,
750}
751
752#[wasm_bindgen]
753impl CompressedModelData {
754 pub fn data(&self) -> Vec<u8> {
756 self.data.clone()
757 }
758
759 #[wasm_bindgen(getter)]
761 pub fn size_bytes(&self) -> usize {
762 self.data.len()
763 }
764
765 #[wasm_bindgen(getter)]
767 pub fn compression_ratio(&self) -> f32 {
768 self.stats.compression_ratio
769 }
770
771 #[wasm_bindgen(getter)]
773 pub fn size_reduction_percent(&self) -> f32 {
774 self.stats.size_reduction_percent
775 }
776
777 #[wasm_bindgen(getter)]
779 pub fn actual_sparsity(&self) -> f32 {
780 self.stats.actual_sparsity
781 }
782
783 #[wasm_bindgen(getter)]
785 pub fn estimated_speedup(&self) -> f32 {
786 self.stats.estimated_speedup
787 }
788
789 #[wasm_bindgen(getter)]
791 pub fn strategy_used(&self) -> CompressionStrategy {
792 self.stats.strategy_used
793 }
794
795 #[wasm_bindgen(getter)]
797 pub fn level_used(&self) -> CompressionLevel {
798 self.stats.level_used
799 }
800
801 pub fn summary(&self) -> String {
803 format!(
804 "Weight Compression: {:.1}% size reduction, {:.1}% sparsity, {:.1}x speedup ({:?}/{:?})",
805 self.stats.size_reduction_percent,
806 self.stats.actual_sparsity * 100.0,
807 self.stats.estimated_speedup,
808 self.stats.strategy_used,
809 self.stats.level_used
810 )
811 }
812}
813
814#[derive(Debug, Clone, Serialize, Deserialize)]
816pub struct CompressionMetadata {
817 pub strategy: CompressionStrategy,
818 pub level: CompressionLevel,
819 pub sparsity_pattern: SparsityPattern,
820 pub actual_sparsity: f32,
821 pub original_size: usize,
822 pub compressed_size: usize,
823}
824
#[cfg(test)]
mod tests {
    use super::*;

    /// Presets carry the expected strategy, level and attention flag.
    #[test]
    fn test_compression_config() {
        let transformer = CompressionConfig::transformer();
        assert!(transformer.preserve_attention);
        assert_eq!(transformer.strategy, CompressionStrategy::Progressive);

        let mobile = CompressionConfig::mobile();
        assert_eq!(mobile.level, CompressionLevel::Aggressive);
    }

    /// Magnitude pruning succeeds on a zero-filled buffer.
    /// Runs only on wasm32 because compression logs through `web_sys`.
    #[test]
    #[cfg(target_arch = "wasm32")]
    fn test_weight_compressor() {
        let compressor = WeightCompressor::new(CompressionConfig::new(
            CompressionStrategy::MagnitudePruning,
            CompressionLevel::Medium,
        ));

        let payload = vec![0u8; 1024];
        let outcome = compressor.compress_weights(&payload);
        assert!(outcome.is_ok());
    }

    /// Sensitivity analysis returns a non-empty score list.
    #[test]
    #[cfg(target_arch = "wasm32")]
    fn test_sensitivity_analysis() {
        let mut compressor = WeightCompressor::new(CompressionConfig::transformer());

        let payload = vec![0u8; 4096];
        let scores = compressor.analyze_sensitivity(&payload);
        assert!(scores.is_ok());
        assert!(!scores.expect("test operation should succeed").is_empty());
    }

    /// Host-side counterpart of `test_weight_compressor`: only checks that
    /// the compressor keeps the configuration it was given.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_weight_compressor_config_only() {
        let compressor = WeightCompressor::new(CompressionConfig::new(
            CompressionStrategy::MagnitudePruning,
            CompressionLevel::Medium,
        ));

        assert_eq!(compressor.config.level, CompressionLevel::Medium);
        assert_eq!(
            compressor.config.strategy,
            CompressionStrategy::MagnitudePruning
        );
    }

    /// Host-side counterpart of `test_sensitivity_analysis`: only checks the
    /// transformer preset stored in the compressor.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_sensitivity_analysis_config_only() {
        let compressor = WeightCompressor::new(CompressionConfig::transformer());

        assert!(compressor.config.preserve_attention);
        assert_eq!(compressor.config.strategy, CompressionStrategy::Progressive);
    }
}