ruvector_sparse_inference/
config.rs

//! Configuration structures for sparse inference.
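//!
//! A minimal usage sketch; the block is marked `ignore` because the exact
//! import path (`ruvector_sparse_inference::config` is assumed here) depends
//! on how the crate exposes this module.
//!
//! ```ignore
//! use ruvector_sparse_inference::config::{ModelConfig, QuantizationType, SparsityConfig};
//!
//! // 768 -> 3072 -> 768 layer with a rank-64 low-rank approximation.
//! let mut model = ModelConfig::new(768, 3072, 768, 64);
//! // Select neurons by activation threshold (τ = 0.02).
//! model.sparsity = SparsityConfig::with_threshold(0.02);
//! // Optionally quantize weights to 4 bits with groups of 32.
//! model.quantization = Some(QuantizationType::Int4 { group_size: 32 });
//! assert!(model.validate().is_ok());
//! ```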

use serde::{Deserialize, Serialize};

/// Configuration for sparsity settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparsityConfig {
    /// Activation threshold τ for neuron selection.
    pub threshold: Option<f32>,

    /// Top-K neuron selection (alternative to threshold).
    pub top_k: Option<usize>,

    /// Target sparsity ratio (0.0 to 1.0).
    /// Used for automatic threshold calibration.
    pub target_sparsity: Option<f32>,

    /// Enable adaptive threshold adjustment.
    pub adaptive_threshold: bool,
}

impl Default for SparsityConfig {
    fn default() -> Self {
        Self {
            threshold: Some(0.01),
            top_k: None,
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }
}

impl SparsityConfig {
    /// Create config with threshold-based selection.
    pub fn with_threshold(threshold: f32) -> Self {
        Self {
            threshold: Some(threshold),
            top_k: None,
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }

    /// Create config with top-K selection.
    pub fn with_top_k(k: usize) -> Self {
        Self {
            threshold: None,
            top_k: Some(k),
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }

    /// Create config with target sparsity ratio.
    pub fn with_target_sparsity(sparsity: f32) -> Self {
        Self {
            threshold: None,
            top_k: None,
            target_sparsity: Some(sparsity),
            adaptive_threshold: true,
        }
    }

    /// Validate configuration.
    pub fn validate(&self) -> Result<(), String> {
        if self.threshold.is_none() && self.top_k.is_none() && self.target_sparsity.is_none() {
            return Err("Must specify threshold, top_k, or target_sparsity".to_string());
        }

        if let Some(threshold) = self.threshold {
            if threshold < 0.0 {
                return Err(format!("Threshold must be non-negative, got {}", threshold));
            }
        }

        if let Some(k) = self.top_k {
            if k == 0 {
                return Err("top_k must be greater than 0".to_string());
            }
        }

        if let Some(sparsity) = self.target_sparsity {
            if !(0.0..=1.0).contains(&sparsity) {
                return Err(format!("target_sparsity must be in [0, 1], got {}", sparsity));
            }
        }

        Ok(())
    }
}

/// Configuration for the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Input dimension.
    pub input_dim: usize,

    /// Hidden dimension (number of neurons).
    pub hidden_dim: usize,

    /// Output dimension.
    pub output_dim: usize,

    /// Activation function type.
    pub activation: ActivationType,

    /// Low-rank approximation rank.
    pub rank: usize,

    /// Sparsity configuration.
    pub sparsity: SparsityConfig,

    /// Optional quantization scheme (`None` disables quantization).
    pub quantization: Option<QuantizationType>,
}

impl ModelConfig {
    /// Create a new model configuration.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        rank: usize,
    ) -> Self {
        Self {
            input_dim,
            hidden_dim,
            output_dim,
            activation: ActivationType::Gelu,
            rank,
            sparsity: SparsityConfig::default(),
            quantization: None,
        }
    }

    /// Validate configuration.
    pub fn validate(&self) -> Result<(), String> {
        if self.input_dim == 0 {
            return Err("input_dim must be greater than 0".to_string());
        }
        if self.hidden_dim == 0 {
            return Err("hidden_dim must be greater than 0".to_string());
        }
        if self.output_dim == 0 {
            return Err("output_dim must be greater than 0".to_string());
        }
        if self.rank == 0 || self.rank > self.input_dim.min(self.hidden_dim) {
            return Err(format!(
                "rank must be in (0, min(input_dim, hidden_dim)], got {}",
                self.rank
            ));
        }
        self.sparsity.validate()?;
        Ok(())
    }
}

/// Cache strategy for cold neurons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CacheStrategy {
    /// Least Recently Used eviction.
    #[default]
    Lru,
    /// Least Frequently Used eviction.
    Lfu,
    /// First In First Out eviction.
    Fifo,
    /// No caching (always load from disk).
    None,
}

/// Cache configuration.
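///
/// A construction sketch using struct-update syntax over the defaults below
/// (marked `ignore` only because imports are omitted):
///
/// ```ignore
/// let cache = CacheConfig {
///     // Memory-map cold weights instead of loading them eagerly.
///     use_mmap: true,
///     // Evict cold-cache entries by access frequency rather than recency.
///     cache_strategy: CacheStrategy::Lfu,
///     ..CacheConfig::default()
/// };
/// assert_eq!(cache.hot_neuron_count, 1024);
/// ```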
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Fraction of neurons to keep hot (0.0 to 1.0).
    pub hot_neuron_fraction: f32,

    /// Maximum number of cold neurons to cache.
    pub max_cold_cache_size: usize,

    /// Cache eviction strategy.
    pub cache_strategy: CacheStrategy,

    /// Number of hot neurons (always in memory).
    pub hot_neuron_count: usize,

    /// LRU cache size for cold neurons.
    pub lru_cache_size: usize,

    /// Enable memory-mapped cold weights.
    pub use_mmap: bool,

    /// Activation frequency threshold for hot classification.
    pub hot_threshold: f32,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            hot_neuron_fraction: 0.2,
            max_cold_cache_size: 1000,
            cache_strategy: CacheStrategy::Lru,
            hot_neuron_count: 1024,
            lru_cache_size: 4096,
            use_mmap: false,
            hot_threshold: 0.5,
        }
    }
}

/// Activation function types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
    /// Rectified Linear Unit: max(0, x)
    Relu,

    /// Gaussian Error Linear Unit: x * Φ(x)
    Gelu,

    /// Sigmoid Linear Unit: x * sigmoid(x)
    Silu,

    /// Swish activation (same as SiLU)
    Swish,

    /// Identity (no activation)
    Identity,
}

impl ActivationType {
    /// Apply activation function to a single value.
    pub fn apply(&self, x: f32) -> f32 {
        match self {
            Self::Relu => x.max(0.0),
            Self::Gelu => {
                // Approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
                const SQRT_2_OVER_PI: f32 = 0.7978845608;
                let x3 = x * x * x;
                let inner = SQRT_2_OVER_PI * (x + 0.044715 * x3);
                0.5 * x * (1.0 + inner.tanh())
            }
            Self::Silu | Self::Swish => {
                // x * sigmoid(x) = x / (1 + exp(-x))
                x / (1.0 + (-x).exp())
            }
            Self::Identity => x,
        }
    }

    /// Apply activation function to a slice in-place.
    pub fn apply_slice(&self, data: &mut [f32]) {
        for x in data.iter_mut() {
            *x = self.apply(*x);
        }
    }
}

/// Quantization types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 32-bit floating point (no quantization).
    F32,

    /// 16-bit floating point.
    F16,

    /// 8-bit integer quantization.
    Int8,

    /// 4-bit integer quantization (GGUF-style).
    Int4 {
        /// Group size for quantization.
        group_size: usize,
    },
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparsity_config_validation() {
        let config = SparsityConfig::with_threshold(0.01);
        assert!(config.validate().is_ok());

        let config = SparsityConfig::with_top_k(100);
        assert!(config.validate().is_ok());

        let mut config = SparsityConfig::default();
        config.threshold = None;
        config.top_k = None;
        config.target_sparsity = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_model_config_validation() {
        let config = ModelConfig::new(128, 512, 128, 64);
        assert!(config.validate().is_ok());

        let mut config = ModelConfig::new(128, 512, 128, 0);
        assert!(config.validate().is_err());

        config.rank = 200;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_activation_functions() {
        let relu = ActivationType::Relu;
        assert_eq!(relu.apply(-1.0), 0.0);
        assert_eq!(relu.apply(1.0), 1.0);

        let gelu = ActivationType::Gelu;
        assert!(gelu.apply(0.0).abs() < 0.01);
        assert!(gelu.apply(1.0) > 0.8);

        let silu = ActivationType::Silu;
        assert!(silu.apply(0.0).abs() < 0.01);
        assert!(silu.apply(1.0) > 0.7);
    }
}