Skip to main content

scirs2_neural/hardware/
mod.rs

1//! Hardware-aware neural network operations
2//!
3//! This module provides hardware-specific optimizations for neural network inference,
4//! including hardware profiling, per-layer quantization targeting specific chips,
5//! and hardware-aware pruning strategies.
6//!
7//! # Key Types
8//!
9//! - [`HardwareProfile`] – describes capabilities of a target hardware device
10//! - [`HardwareOptimizer`] – applies hardware-specific quantization and pruning
11//! - [`QuantizationPrecision`] – precision modes supported by hardware
12
13use crate::error::{NeuralError, Result};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18// ─────────────────────────────────────────────────────────────────────────────
19// Hardware profile
20// ─────────────────────────────────────────────────────────────────────────────
21
22/// Precision modes available on hardware accelerators.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub enum QuantizationPrecision {
25    /// 32-bit floating point (full precision)
26    FP32,
27    /// 16-bit floating point (half precision)
28    FP16,
29    /// BFloat16 (brain float)
30    BF16,
31    /// 8-bit integer
32    INT8,
33    /// 4-bit integer
34    INT4,
35}
36
37impl QuantizationPrecision {
38    /// Returns the number of bits for this precision.
39    pub fn bits(self) -> u8 {
40        match self {
41            QuantizationPrecision::FP32 => 32,
42            QuantizationPrecision::FP16 | QuantizationPrecision::BF16 => 16,
43            QuantizationPrecision::INT8 => 8,
44            QuantizationPrecision::INT4 => 4,
45        }
46    }
47
48    /// Returns whether this precision is floating-point.
49    pub fn is_float(self) -> bool {
50        matches!(
51            self,
52            QuantizationPrecision::FP32 | QuantizationPrecision::FP16 | QuantizationPrecision::BF16
53        )
54    }
55}
56
57/// Describes the computational capabilities of a target hardware device.
58///
59/// # Examples
60/// ```
61/// use scirs2_neural::hardware::{HardwareProfile, QuantizationPrecision};
62///
63/// let cpu_profile = HardwareProfile::cpu_default();
64/// assert!(cpu_profile.supported_precisions.contains(&QuantizationPrecision::FP32));
65/// ```
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct HardwareProfile {
68    /// Human-readable device name (e.g. "Apple M2 Pro", "NVIDIA A100")
69    pub name: String,
70    /// Number of physical compute cores
71    pub num_cores: usize,
72    /// Memory bandwidth in GB/s
73    pub memory_bandwidth_gb_s: f64,
74    /// Total on-chip memory / cache in MB
75    pub cache_mb: f64,
76    /// Supported quantization precisions
77    pub supported_precisions: Vec<QuantizationPrecision>,
78    /// SIMD vector width in bits (0 = no SIMD)
79    pub simd_width_bits: usize,
80    /// Whether the device has dedicated neural-network accelerator units
81    pub has_npu: bool,
82    /// Peak compute in TFLOPS at FP32
83    pub peak_tflops_fp32: f64,
84    /// Custom hardware properties
85    pub properties: HashMap<String, String>,
86}
87
88impl HardwareProfile {
89    /// Build a generic CPU profile.
90    pub fn cpu_default() -> Self {
91        Self {
92            name: "Generic CPU".to_string(),
93            num_cores: detect_num_cpus(),
94            memory_bandwidth_gb_s: 50.0,
95            cache_mb: 8.0,
96            supported_precisions: vec![
97                QuantizationPrecision::FP32,
98                QuantizationPrecision::FP16,
99                QuantizationPrecision::INT8,
100            ],
101            simd_width_bits: 256, // AVX2
102            has_npu: false,
103            peak_tflops_fp32: 0.5,
104            properties: HashMap::new(),
105        }
106    }
107
108    /// Build a mobile ARM profile (e.g. Cortex-A series).
109    pub fn mobile_arm() -> Self {
110        Self {
111            name: "Mobile ARM".to_string(),
112            num_cores: 8,
113            memory_bandwidth_gb_s: 30.0,
114            cache_mb: 4.0,
115            supported_precisions: vec![
116                QuantizationPrecision::FP32,
117                QuantizationPrecision::FP16,
118                QuantizationPrecision::INT8,
119                QuantizationPrecision::INT4,
120            ],
121            simd_width_bits: 128, // NEON
122            has_npu: true,
123            peak_tflops_fp32: 0.1,
124            properties: {
125                let mut m = HashMap::new();
126                m.insert("arch".to_string(), "arm64".to_string());
127                m
128            },
129        }
130    }
131
132    /// Build an NVIDIA GPU profile.
133    pub fn nvidia_gpu(name: &str, tflops: f64, bandwidth_gb_s: f64) -> Self {
134        Self {
135            name: name.to_string(),
136            num_cores: 4096,
137            memory_bandwidth_gb_s: bandwidth_gb_s,
138            cache_mb: 40.0,
139            supported_precisions: vec![
140                QuantizationPrecision::FP32,
141                QuantizationPrecision::FP16,
142                QuantizationPrecision::BF16,
143                QuantizationPrecision::INT8,
144                QuantizationPrecision::INT4,
145            ],
146            simd_width_bits: 512,
147            has_npu: true,
148            peak_tflops_fp32: tflops,
149            properties: {
150                let mut m = HashMap::new();
151                m.insert("vendor".to_string(), "NVIDIA".to_string());
152                m
153            },
154        }
155    }
156
157    /// Returns the most efficient precision supported for inference.
158    ///
159    /// Chooses the smallest integer type available, falling back to FP16/FP32.
160    pub fn preferred_inference_precision(&self) -> QuantizationPrecision {
161        let priority = [
162            QuantizationPrecision::INT4,
163            QuantizationPrecision::INT8,
164            QuantizationPrecision::FP16,
165            QuantizationPrecision::BF16,
166            QuantizationPrecision::FP32,
167        ];
168        for p in &priority {
169            if self.supported_precisions.contains(p) {
170                return *p;
171            }
172        }
173        QuantizationPrecision::FP32
174    }
175}
176
177// ─────────────────────────────────────────────────────────────────────────────
178// Layer quantization plan
179// ─────────────────────────────────────────────────────────────────────────────
180
181/// Per-layer quantization decision produced by [`HardwareOptimizer`].
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct LayerQuantizationPlan {
184    /// Layer name / identifier
185    pub layer_name: String,
186    /// Chosen weight precision
187    pub weight_precision: QuantizationPrecision,
188    /// Chosen activation precision
189    pub activation_precision: QuantizationPrecision,
190    /// Whether this layer should be pruned
191    pub prune: bool,
192    /// Pruning sparsity target (0 = no pruning, 1 = fully pruned)
193    pub pruning_sparsity: f64,
194}
195
196// ─────────────────────────────────────────────────────────────────────────────
197// Hardware optimizer
198// ─────────────────────────────────────────────────────────────────────────────
199
200/// Strategies used by [`HardwareOptimizer`] to decide which precision to use.
201#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
202pub enum OptimizationStrategy {
203    /// Maximise throughput at the cost of accuracy
204    MaxThroughput,
205    /// Balance speed and numerical accuracy
206    #[default]
207    Balanced,
208    /// Maximise accuracy, apply minimal quantization
209    MaxAccuracy,
210    /// Minimise power / energy consumption
211    PowerEfficient,
212}
213
214/// Configuration for [`HardwareOptimizer`].
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct HardwareOptimizerConfig {
217    /// Optimisation goal
218    pub strategy: OptimizationStrategy,
219    /// Fraction of layers eligible for INT8 quantization (0.0–1.0)
220    pub int8_fraction: f64,
221    /// Whether to enable weight pruning
222    pub enable_pruning: bool,
223    /// Target pruning sparsity for eligible layers
224    pub pruning_sparsity: f64,
225    /// Layers that must stay at full FP32 precision (by name/prefix)
226    pub sensitive_layers: Vec<String>,
227}
228
229impl Default for HardwareOptimizerConfig {
230    fn default() -> Self {
231        Self {
232            strategy: OptimizationStrategy::Balanced,
233            int8_fraction: 0.8,
234            enable_pruning: false,
235            pruning_sparsity: 0.3,
236            sensitive_layers: vec!["output".to_string(), "classifier".to_string()],
237        }
238    }
239}
240
241/// Applies hardware-aware quantization and pruning decisions to a model description.
242///
243/// The optimizer does **not** mutate any live parameters; instead it produces a
244/// [`Vec<LayerQuantizationPlan>`] that downstream code can apply.
245///
246/// # Examples
247/// ```
248/// use scirs2_neural::hardware::{HardwareProfile, HardwareOptimizer, HardwareOptimizerConfig};
249///
250/// let profile = HardwareProfile::mobile_arm();
251/// let optimizer = HardwareOptimizer::new(profile, HardwareOptimizerConfig::default());
252///
253/// let layers = &["conv1", "bn1", "relu1", "conv2", "fc_output"];
254/// let plan = optimizer.compute_quantization_plan(layers).expect("plan failed");
255/// assert_eq!(plan.len(), layers.len());
256/// ```
257pub struct HardwareOptimizer {
258    profile: HardwareProfile,
259    config: HardwareOptimizerConfig,
260}
261
262impl HardwareOptimizer {
263    /// Create a new hardware optimizer for the given device profile and config.
264    pub fn new(profile: HardwareProfile, config: HardwareOptimizerConfig) -> Self {
265        Self { profile, config }
266    }
267
268    /// Returns a reference to the hardware profile.
269    pub fn profile(&self) -> &HardwareProfile {
270        &self.profile
271    }
272
273    /// Returns a reference to the optimizer configuration.
274    pub fn config(&self) -> &HardwareOptimizerConfig {
275        &self.config
276    }
277
278    /// Produce a per-layer quantization plan for the given layer names.
279    ///
280    /// Layers whose names start with any prefix in `config.sensitive_layers` are
281    /// kept at FP32.  The remaining layers are quantized according to
282    /// `config.strategy` and the hardware's supported precisions.
283    pub fn compute_quantization_plan(
284        &self,
285        layer_names: &[&str],
286    ) -> Result<Vec<LayerQuantizationPlan>> {
287        if layer_names.is_empty() {
288            return Err(NeuralError::InvalidArgument(
289                "layer_names must not be empty".to_string(),
290            ));
291        }
292
293        let total = layer_names.len();
294        let n_int8 = ((total as f64) * self.config.int8_fraction.clamp(0.0, 1.0)) as usize;
295        let preferred = self.profile.preferred_inference_precision();
296
297        let mut plans = Vec::with_capacity(total);
298        let mut int8_assigned = 0usize;
299
300        for (i, &name) in layer_names.iter().enumerate() {
301            let is_sensitive = self.config.sensitive_layers.iter().any(|prefix| {
302                let p = prefix.as_str();
303                name.starts_with(p) || name.ends_with(p)
304            });
305
306            let precision = if is_sensitive {
307                QuantizationPrecision::FP32
308            } else if int8_assigned < n_int8
309                && self
310                    .profile
311                    .supported_precisions
312                    .contains(&QuantizationPrecision::INT8)
313            {
314                int8_assigned += 1;
315                QuantizationPrecision::INT8
316            } else {
317                preferred
318            };
319
320            // Pruning: only apply to inner layers (not first/last) if enabled
321            let prune =
322                self.config.enable_pruning && !is_sensitive && i > 0 && i < total.saturating_sub(1);
323
324            plans.push(LayerQuantizationPlan {
325                layer_name: name.to_string(),
326                weight_precision: precision,
327                activation_precision: precision,
328                prune,
329                pruning_sparsity: if prune {
330                    self.config.pruning_sparsity
331                } else {
332                    0.0
333                },
334            });
335        }
336
337        Ok(plans)
338    }
339
340    /// Estimate compressed model size in bytes given a base size and the plan.
341    ///
342    /// Uses the bit-width ratio of each layer's chosen precision vs FP32.
343    pub fn estimate_compressed_size_bytes(
344        &self,
345        base_size_bytes: u64,
346        plan: &[LayerQuantizationPlan],
347    ) -> u64 {
348        if plan.is_empty() {
349            return base_size_bytes;
350        }
351        let weight_ratio: f64 = plan
352            .iter()
353            .map(|p| p.weight_precision.bits() as f64 / 32.0)
354            .sum::<f64>()
355            / plan.len() as f64;
356
357        let prune_ratio: f64 =
358            plan.iter().map(|p| 1.0 - p.pruning_sparsity).sum::<f64>() / plan.len() as f64;
359
360        ((base_size_bytes as f64) * weight_ratio * prune_ratio) as u64
361    }
362
363    /// Quantize a flat f32 weight vector to i8 using symmetric INT8.
364    ///
365    /// Returns `(quantized_weights, scale)` where `scale` maps i8 → f32.
366    pub fn quantize_to_int8(weights: &[f32]) -> Result<(Vec<i8>, f32)> {
367        if weights.is_empty() {
368            return Err(NeuralError::InvalidArgument(
369                "weights slice is empty".to_string(),
370            ));
371        }
372        let abs_max = weights.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
373        let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
374        let quantized: Vec<i8> = weights
375            .iter()
376            .map(|&w| {
377                let q = (w / scale).round();
378                q.clamp(-128.0, 127.0) as i8
379            })
380            .collect();
381        Ok((quantized, scale))
382    }
383
384    /// Dequantize i8 values back to f32 using the given scale.
385    pub fn dequantize_from_int8(quantized: &[i8], scale: f32) -> Vec<f32> {
386        quantized.iter().map(|&q| (q as f32) * scale).collect()
387    }
388
389    /// Quantize a flat f32 vector with asymmetric uint8 (FP16 simulation).
390    ///
391    /// Returns `(quantized, scale, zero_point)`.
392    pub fn quantize_to_fp16_sim(weights: &[f32]) -> Result<(Vec<u8>, f32, f32)> {
393        if weights.is_empty() {
394            return Err(NeuralError::InvalidArgument(
395                "weights slice is empty".to_string(),
396            ));
397        }
398        let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
399        let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
400        let scale = if (max - min).abs() > f32::EPSILON {
401            (max - min) / 255.0
402        } else {
403            1.0
404        };
405        let zero_point = min;
406        let quantized: Vec<u8> = weights
407            .iter()
408            .map(|&w| {
409                let q = ((w - zero_point) / scale).round();
410                q.clamp(0.0, 255.0) as u8
411            })
412            .collect();
413        Ok((quantized, scale, zero_point))
414    }
415}
416
417impl Debug for HardwareOptimizer {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        f.debug_struct("HardwareOptimizer")
420            .field("profile", &self.profile.name)
421            .field("strategy", &self.config.strategy)
422            .finish()
423    }
424}
425
426// ─────────────────────────────────────────────────────────────────────────────
427// Helpers
428// ─────────────────────────────────────────────────────────────────────────────
429
/// Number of logical CPUs the host reports as available to this process.
///
/// Falls back to `1` when the standard library cannot determine the value.
fn detect_num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
436
437// ─────────────────────────────────────────────────────────────────────────────
438// Tests
439// ─────────────────────────────────────────────────────────────────────────────
440
#[cfg(test)]
mod tests {
    use super::*;

    /// The generic CPU profile is named, has at least one core and supports FP32.
    #[test]
    fn test_hardware_profile_cpu_default() {
        let profile = HardwareProfile::cpu_default();
        let has_fp32 = profile
            .supported_precisions
            .contains(&QuantizationPrecision::FP32);

        assert!(!profile.name.is_empty());
        assert!(profile.num_cores >= 1);
        assert!(has_fp32);
    }

    /// The mobile ARM profile has an NPU, INT8 support and 128-bit SIMD.
    #[test]
    fn test_hardware_profile_mobile_arm() {
        let profile = HardwareProfile::mobile_arm();
        let has_int8 = profile
            .supported_precisions
            .contains(&QuantizationPrecision::INT8);

        assert!(profile.has_npu);
        assert!(has_int8);
        assert_eq!(profile.simd_width_bits, 128);
    }

    /// Mobile ARM supports INT4, which ranks first.
    #[test]
    fn test_preferred_precision_mobile() {
        let preferred = HardwareProfile::mobile_arm().preferred_inference_precision();
        assert_eq!(preferred, QuantizationPrecision::INT4);
    }

    /// The generic CPU has no INT4, so INT8 is preferred.
    #[test]
    fn test_preferred_precision_cpu() {
        let preferred = HardwareProfile::cpu_default().preferred_inference_precision();
        assert_eq!(preferred, QuantizationPrecision::INT8);
    }

    /// One plan entry per layer; the layer ending in "output" stays at FP32.
    #[test]
    fn test_compute_quantization_plan_basic() {
        let config = HardwareOptimizerConfig {
            int8_fraction: 0.6,
            enable_pruning: false,
            ..Default::default()
        };
        let optimizer = HardwareOptimizer::new(HardwareProfile::mobile_arm(), config);
        let layer_names = ["conv1", "bn1", "conv2", "bn2", "fc_output"];

        let plan = optimizer
            .compute_quantization_plan(&layer_names)
            .expect("plan ok");
        assert_eq!(plan.len(), 5);

        // fc_output is sensitive → FP32
        let sensitive_entry = plan
            .iter()
            .find(|entry| entry.layer_name == "fc_output")
            .expect("fc");
        assert_eq!(sensitive_entry.weight_precision, QuantizationPrecision::FP32);
    }

    /// An empty layer list is rejected.
    #[test]
    fn test_compute_quantization_plan_empty_layers_err() {
        let optimizer = HardwareOptimizer::new(
            HardwareProfile::cpu_default(),
            HardwareOptimizerConfig::default(),
        );
        let outcome = optimizer.compute_quantization_plan(&[]);
        assert!(outcome.is_err());
    }

    /// INT8 quantize → dequantize reproduces the inputs to within 0.01.
    #[test]
    fn test_quantize_int8_roundtrip() {
        let weights: Vec<f32> = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.25, -0.25];
        let (quant, scale) = HardwareOptimizer::quantize_to_int8(&weights).expect("quant ok");
        let dequant = HardwareOptimizer::dequantize_from_int8(&quant, scale);

        weights.iter().zip(&dequant).for_each(|(orig, deq)| {
            assert!((orig - deq).abs() < 0.01, "orig={orig} deq={deq}");
        });
    }

    /// uint8 quantize → manual dequantize reproduces the inputs to within 0.02.
    #[test]
    fn test_quantize_fp16_sim_roundtrip() {
        let weights: Vec<f32> = vec![0.1, 0.5, -0.3, 0.9, -0.9];
        let (quant, scale, zp) = HardwareOptimizer::quantize_to_fp16_sim(&weights).expect("ok");

        let dequant = quant.iter().map(|&q| (q as f32) * scale + zp);
        weights.iter().zip(dequant).for_each(|(orig, deq)| {
            assert!((orig - deq).abs() < 0.02, "orig={orig} deq={deq}");
        });
    }

    /// An empty weight slice is rejected.
    #[test]
    fn test_quantize_int8_empty_err() {
        let outcome = HardwareOptimizer::quantize_to_int8(&[]);
        assert!(outcome.is_err());
    }

    /// A quantized plan must shrink the estimated size.
    #[test]
    fn test_estimate_compressed_size() {
        let optimizer = HardwareOptimizer::new(
            HardwareProfile::cpu_default(),
            HardwareOptimizerConfig::default(),
        );
        let plan = optimizer
            .compute_quantization_plan(&["layer1", "layer2"])
            .expect("plan");

        let compressed = optimizer.estimate_compressed_size_bytes(1_000_000, &plan);
        // INT8 is 8 bits vs 32 bits FP32, so expect < original
        assert!(compressed < 1_000_000);
    }

    /// Bit widths reported for FP32, FP16, INT8 and INT4.
    #[test]
    fn test_precision_bits() {
        let expectations = [
            (QuantizationPrecision::FP32, 32),
            (QuantizationPrecision::FP16, 16),
            (QuantizationPrecision::INT8, 8),
            (QuantizationPrecision::INT4, 4),
        ];
        for (precision, expected_bits) in expectations.iter().copied() {
            assert_eq!(precision.bits(), expected_bits);
        }
    }

    /// The `Debug` output names the type.
    #[test]
    fn test_hardware_optimizer_debug() {
        let optimizer = HardwareOptimizer::new(
            HardwareProfile::cpu_default(),
            HardwareOptimizerConfig::default(),
        );
        let rendered = format!("{optimizer:?}");
        assert!(rendered.contains("HardwareOptimizer"));
    }
}