// ruvector_sparse_inference/config.rs

use serde::{Deserialize, Serialize};

/// Configuration for activation sparsification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparsityConfig {
    /// Magnitude threshold for sparsification; must be non-negative.
    pub threshold: Option<f32>,

    /// Fixed number of activations to keep; must be greater than 0.
    pub top_k: Option<usize>,

    /// Target sparsity fraction, in [0, 1].
    pub target_sparsity: Option<f32>,

    /// Whether the threshold is adapted at runtime to reach
    /// `target_sparsity`; enabled by `with_target_sparsity`.
    pub adaptive_threshold: bool,
}

impl Default for SparsityConfig {
    fn default() -> Self {
        Self {
            threshold: Some(0.01),
            top_k: None,
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }
}

impl SparsityConfig {
    /// Sparsity driven by a fixed magnitude threshold.
    pub fn with_threshold(threshold: f32) -> Self {
        Self {
            threshold: Some(threshold),
            top_k: None,
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }

    /// Sparsity driven by a fixed top-k count.
    pub fn with_top_k(k: usize) -> Self {
        Self {
            threshold: None,
            top_k: Some(k),
            target_sparsity: None,
            adaptive_threshold: false,
        }
    }

    /// Adaptive sparsity targeting a given sparsity fraction.
    pub fn with_target_sparsity(sparsity: f32) -> Self {
        Self {
            threshold: None,
            top_k: None,
            target_sparsity: Some(sparsity),
            adaptive_threshold: true,
        }
    }

    /// Checks that at least one selection mode is set and that all
    /// provided values are in range.
    pub fn validate(&self) -> Result<(), String> {
        if self.threshold.is_none() && self.top_k.is_none() && self.target_sparsity.is_none() {
            return Err("Must specify threshold, top_k, or target_sparsity".to_string());
        }

        if let Some(threshold) = self.threshold {
            if threshold < 0.0 {
                return Err(format!("Threshold must be non-negative, got {}", threshold));
            }
        }

        if let Some(k) = self.top_k {
            if k == 0 {
                return Err("top_k must be greater than 0".to_string());
            }
        }

        if let Some(sparsity) = self.target_sparsity {
            if !(0.0..=1.0).contains(&sparsity) {
                return Err(format!("target_sparsity must be in [0, 1], got {}", sparsity));
            }
        }

        Ok(())
    }
}

/// Configuration for a sparse inference model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Input dimension; must be greater than 0.
    pub input_dim: usize,

    /// Hidden dimension; must be greater than 0.
    pub hidden_dim: usize,

    /// Output dimension; must be greater than 0.
    pub output_dim: usize,

    /// Activation function applied to hidden activations.
    pub activation: ActivationType,

    /// Factorization rank; must be in (0, min(input_dim, hidden_dim)].
    pub rank: usize,

    /// Sparsification settings.
    pub sparsity: SparsityConfig,

    /// Optional quantization scheme.
    pub quantization: Option<QuantizationType>,
}

impl ModelConfig {
    /// Creates a config with GELU activation, default sparsity settings,
    /// and no quantization.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        rank: usize,
    ) -> Self {
        Self {
            input_dim,
            hidden_dim,
            output_dim,
            activation: ActivationType::Gelu,
            rank,
            sparsity: SparsityConfig::default(),
            quantization: None,
        }
    }

    /// Checks all dimensions, the rank bound, and the nested sparsity config.
    pub fn validate(&self) -> Result<(), String> {
        if self.input_dim == 0 {
            return Err("input_dim must be greater than 0".to_string());
        }
        if self.hidden_dim == 0 {
            return Err("hidden_dim must be greater than 0".to_string());
        }
        if self.output_dim == 0 {
            return Err("output_dim must be greater than 0".to_string());
        }
        if self.rank == 0 || self.rank > self.input_dim.min(self.hidden_dim) {
            return Err(format!(
                "rank must be in (0, min(input_dim, hidden_dim)], got {}",
                self.rank
            ));
        }
        self.sparsity.validate()?;
        Ok(())
    }
}

/// Eviction strategy for the neuron cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CacheStrategy {
    /// Least recently used (the default).
    #[default]
    Lru,
    /// Least frequently used.
    Lfu,
    /// First in, first out.
    Fifo,
    /// No caching.
    None,
}

/// Configuration for the hot/cold neuron cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Fraction of neurons treated as hot.
    pub hot_neuron_fraction: f32,

    /// Maximum number of entries in the cold cache.
    pub max_cold_cache_size: usize,

    /// Eviction strategy for cached entries.
    pub cache_strategy: CacheStrategy,

    /// Number of neurons kept in the hot set.
    pub hot_neuron_count: usize,

    /// Capacity of the LRU cache.
    pub lru_cache_size: usize,

    /// Whether to memory-map data rather than load it eagerly.
    pub use_mmap: bool,

    /// Threshold above which a neuron is classified as hot.
    pub hot_threshold: f32,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            hot_neuron_fraction: 0.2,
            max_cold_cache_size: 1000,
            cache_strategy: CacheStrategy::Lru,
            hot_neuron_count: 1024,
            lru_cache_size: 4096,
            use_mmap: false,
            hot_threshold: 0.5,
        }
    }
}

/// Supported activation functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
    /// max(0, x).
    Relu,

    /// GELU, computed via the tanh approximation.
    Gelu,

    /// x * sigmoid(x).
    Silu,

    /// Alias for SiLU; shares the same implementation.
    Swish,

    /// Pass-through (no nonlinearity).
    Identity,
}

impl ActivationType {
    /// Applies the activation to a single value.
    pub fn apply(&self, x: f32) -> f32 {
        match self {
            Self::Relu => x.max(0.0),
            Self::Gelu => {
                // Tanh approximation of GELU:
                // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
                const SQRT_2_OVER_PI: f32 = 0.7978845608;
                let x3 = x * x * x;
                let inner = SQRT_2_OVER_PI * (x + 0.044715 * x3);
                0.5 * x * (1.0 + inner.tanh())
            }
            Self::Silu | Self::Swish => {
                // x * sigmoid(x), written as x / (1 + e^(-x)).
                x / (1.0 + (-x).exp())
            }
            Self::Identity => x,
        }
    }

    /// Applies the activation element-wise, in place.
    pub fn apply_slice(&self, data: &mut [f32]) {
        for x in data.iter_mut() {
            *x = self.apply(*x);
        }
    }
}

/// Quantization schemes for stored model data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 32-bit floating point (unquantized).
    F32,

    /// 16-bit floating point.
    F16,

    /// 8-bit integer quantization.
    Int8,

    /// 4-bit integer quantization with grouped scaling.
    Int4 {
        /// Number of elements per quantization group.
        group_size: usize,
    },
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparsity_config_validation() {
        let config = SparsityConfig::with_threshold(0.01);
        assert!(config.validate().is_ok());

        let config = SparsityConfig::with_top_k(100);
        assert!(config.validate().is_ok());

        let mut config = SparsityConfig::default();
        config.threshold = None;
        config.top_k = None;
        config.target_sparsity = None;
        assert!(config.validate().is_err());
    }
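
    // Added test (not in the original file): exercises the
    // `with_target_sparsity` constructor and the range checks in
    // `SparsityConfig::validate`, which the tests above leave uncovered.
    #[test]
    fn test_target_sparsity_validation() {
        let config = SparsityConfig::with_target_sparsity(0.9);
        assert!(config.adaptive_threshold);
        assert!(config.validate().is_ok());

        // target_sparsity outside [0, 1] must be rejected.
        let config = SparsityConfig::with_target_sparsity(1.5);
        assert!(config.validate().is_err());

        // Negative thresholds must be rejected.
        let config = SparsityConfig::with_threshold(-0.1);
        assert!(config.validate().is_err());
    }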

    #[test]
    fn test_model_config_validation() {
        let config = ModelConfig::new(128, 512, 128, 64);
        assert!(config.validate().is_ok());

        let mut config = ModelConfig::new(128, 512, 128, 0);
        assert!(config.validate().is_err());

        config.rank = 200;
        assert!(config.validate().is_err());
    }
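
    // Added test (not in the original file): the rank check in
    // `ModelConfig::validate` is inclusive at min(input_dim, hidden_dim),
    // so the boundary value should pass while one past it fails.
    #[test]
    fn test_rank_boundary() {
        let config = ModelConfig::new(128, 512, 128, 128);
        assert!(config.validate().is_ok());

        let mut config = ModelConfig::new(128, 512, 128, 128);
        config.rank = 129;
        assert!(config.validate().is_err());
    }

    // Added test (not in the original file); assumes `serde_json` is
    // available as a dev-dependency. Round-trips a config through JSON to
    // check that the derived Serialize/Deserialize impls agree.
    #[test]
    fn test_config_json_round_trip() {
        let mut config = ModelConfig::new(128, 512, 128, 64);
        config.quantization = Some(QuantizationType::Int4 { group_size: 64 });
        let json = serde_json::to_string(&config).expect("serialize");
        let restored: ModelConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.input_dim, config.input_dim);
        assert_eq!(restored.quantization, config.quantization);
    }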

    #[test]
    fn test_activation_functions() {
        let relu = ActivationType::Relu;
        assert_eq!(relu.apply(-1.0), 0.0);
        assert_eq!(relu.apply(1.0), 1.0);

        let gelu = ActivationType::Gelu;
        assert!(gelu.apply(0.0).abs() < 0.01);
        assert!(gelu.apply(1.0) > 0.8);

        let silu = ActivationType::Silu;
        assert!(silu.apply(0.0).abs() < 0.01);
        assert!(silu.apply(1.0) > 0.7);
    }
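
    // Added test (not in the original file): `apply_slice` should match
    // element-wise `apply`.
    #[test]
    fn test_apply_slice() {
        let mut data = [-1.0_f32, 0.0, 2.0];
        ActivationType::Relu.apply_slice(&mut data);
        assert_eq!(data, [0.0, 0.0, 2.0]);
    }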
}