Skip to main content

trustformers_optim/
quantized.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
/// Affine mapping between `f32` values and signed 8-bit codes.
///
/// A value `v` is encoded as `round((v - min_val) / scale) - 128`, saturated to
/// `[-128, 127]`, and decoded as `(code + 128) * scale + min_val`. For a
/// non-degenerate range this sends `min_val` to `-128` and `max_val` to `127`.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Width of one quantization step: `(max_val - min_val) / 255`, or `1.0`
    /// when that quotient is zero, subnormal, infinite, or NaN.
    pub scale: f32,
    /// The code that `0.0` encodes to; always equal to `quantize(0.0)`.
    pub zero_point: i8,
    /// Value represented by the code `-128`.
    pub min_val: f32,
    /// Upper end of the range used to derive `scale`.
    pub max_val: f32,
}

impl QuantizationConfig {
    /// Creates a config that spreads 256 codes evenly over `[min_val, max_val]`.
    ///
    /// When the range is degenerate (`min_val == max_val`, as for an all-zero
    /// optimizer state) or not finite, `scale` falls back to `1.0` so that
    /// [`quantize`](Self::quantize) never divides by zero; a finite `min_val`
    /// still round-trips exactly through code `-128`.
    pub fn new(min_val: f32, max_val: f32) -> Self {
        let raw_scale = (max_val - min_val) / 255.0;
        // `is_normal` rejects 0, subnormals, +/-inf and NaN, all of which would
        // turn the division in `quantize` into inf/NaN arithmetic.
        let scale = if raw_scale.is_normal() { raw_scale } else { 1.0 };
        // Same formula (including the -128 offset) as `quantize(0.0)`, so the
        // field really is the code that represents 0.0.
        let zero_point = ((-min_val / scale).round() - 128.0).clamp(-128.0, 127.0) as i8;

        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }

    /// Encodes `value` as an 8-bit code.
    ///
    /// Values outside the configured range saturate to `-128` / `127`; NaN
    /// encodes to `0` (Rust's saturating float-to-int cast).
    pub fn quantize(&self, value: f32) -> i8 {
        let code = ((value - self.min_val) / self.scale).round() - 128.0;
        code.clamp(-128.0, 127.0) as i8
    }

    /// Decodes `quantized` back to an approximate `f32`.
    pub fn dequantize(&self, quantized: i8) -> f32 {
        (f32::from(quantized) + 128.0) * self.scale + self.min_val
    }
}

36#[derive(Debug, Clone)]
37pub struct QuantizedState {
38    pub data: Vec<i8>,
39    pub config: QuantizationConfig,
40}
41
42impl QuantizedState {
43    pub fn new(values: &[f32]) -> Self {
44        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
45        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
46
47        let config = QuantizationConfig::new(min_val, max_val);
48        let data: Vec<i8> = values.iter().map(|&v| config.quantize(v)).collect();
49
50        Self { data, config }
51    }
52
53    pub fn to_f32(&self) -> Vec<f32> {
54        self.data.iter().map(|&q| self.config.dequantize(q)).collect()
55    }
56
57    pub fn update(&mut self, new_values: &[f32]) {
58        let min_val = new_values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
59        let max_val = new_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
60
61        self.config = QuantizationConfig::new(min_val, max_val);
62        self.data = new_values.iter().map(|&v| self.config.quantize(v)).collect();
63    }
64}
65
66#[derive(Debug)]
67pub struct Adam8bit {
68    pub learning_rate: f32,
69    pub beta1: f32,
70    pub beta2: f32,
71    pub epsilon: f32,
72    pub weight_decay: f32,
73    pub step: usize,
74    pub momentum_states: HashMap<String, QuantizedState>,
75    pub variance_states: HashMap<String, QuantizedState>,
76}
77
78impl Default for Adam8bit {
79    fn default() -> Self {
80        Self {
81            learning_rate: 1e-3,
82            beta1: 0.9,
83            beta2: 0.999,
84            epsilon: 1e-8,
85            weight_decay: 0.0,
86            step: 0,
87            momentum_states: HashMap::new(),
88            variance_states: HashMap::new(),
89        }
90    }
91}
92
93impl Adam8bit {
94    pub fn new(learning_rate: f32) -> Self {
95        Self {
96            learning_rate,
97            ..Default::default()
98        }
99    }
100
101    pub fn with_config(
102        learning_rate: f32,
103        beta1: f32,
104        beta2: f32,
105        epsilon: f32,
106        weight_decay: f32,
107    ) -> Self {
108        Self {
109            learning_rate,
110            beta1,
111            beta2,
112            epsilon,
113            weight_decay,
114            step: 0,
115            momentum_states: HashMap::new(),
116            variance_states: HashMap::new(),
117        }
118    }
119
120    pub fn step(
121        &mut self,
122        parameters: &mut HashMap<String, Tensor>,
123        gradients: &HashMap<String, Tensor>,
124    ) -> Result<()> {
125        self.step += 1;
126
127        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
128        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
129
130        for (name, param) in parameters.iter_mut() {
131            let grad = gradients
132                .get(name)
133                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
134
135            let mut param_data = param.data()?;
136            let grad_data = grad.data()?;
137
138            if param_data.len() != grad_data.len() {
139                return Err(anyhow::anyhow!(
140                    "Parameter and gradient size mismatch for: {}",
141                    name
142                ));
143            }
144
145            if !self.momentum_states.contains_key(name) {
146                let zeros = vec![0.0; param_data.len()];
147                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
148                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
149            }
150
151            let momentum_state = self
152                .momentum_states
153                .get_mut(name)
154                .expect("momentum_state should exist after initialization");
155            let variance_state = self
156                .variance_states
157                .get_mut(name)
158                .expect("variance_state should exist after initialization");
159
160            let mut momentum = momentum_state.to_f32();
161            let mut variance = variance_state.to_f32();
162
163            for i in 0..param_data.len() {
164                let mut grad_val = grad_data[i];
165
166                if self.weight_decay > 0.0 {
167                    grad_val += self.weight_decay * param_data[i];
168                }
169
170                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
171                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
172
173                let corrected_momentum = momentum[i] / bias_correction1;
174                let corrected_variance = variance[i] / bias_correction2;
175
176                param_data[i] -= self.learning_rate * corrected_momentum
177                    / (corrected_variance.sqrt() + self.epsilon);
178            }
179
180            momentum_state.update(&momentum);
181            variance_state.update(&variance);
182
183            // Update the parameter tensor with modified data
184            *param = Tensor::new(param_data)?;
185        }
186
187        Ok(())
188    }
189
190    pub fn memory_usage(&self) -> usize {
191        let mut total = 0;
192        for state in self.momentum_states.values() {
193            total += state.data.len();
194        }
195        for state in self.variance_states.values() {
196            total += state.data.len();
197        }
198        total
199    }
200
201    pub fn memory_savings_vs_fp32(&self) -> f32 {
202        let quantized_size = self.memory_usage();
203        let fp32_equivalent = quantized_size * 4;
204        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
205    }
206}
207
208#[derive(Debug)]
209pub struct AdamW8bit {
210    pub learning_rate: f32,
211    pub beta1: f32,
212    pub beta2: f32,
213    pub epsilon: f32,
214    pub weight_decay: f32,
215    pub step: usize,
216    pub momentum_states: HashMap<String, QuantizedState>,
217    pub variance_states: HashMap<String, QuantizedState>,
218}
219
220impl Default for AdamW8bit {
221    fn default() -> Self {
222        Self {
223            learning_rate: 1e-3,
224            beta1: 0.9,
225            beta2: 0.999,
226            epsilon: 1e-8,
227            weight_decay: 1e-2,
228            step: 0,
229            momentum_states: HashMap::new(),
230            variance_states: HashMap::new(),
231        }
232    }
233}
234
235impl AdamW8bit {
236    pub fn new(learning_rate: f32) -> Self {
237        Self {
238            learning_rate,
239            ..Default::default()
240        }
241    }
242
243    pub fn with_config(
244        learning_rate: f32,
245        beta1: f32,
246        beta2: f32,
247        epsilon: f32,
248        weight_decay: f32,
249    ) -> Self {
250        Self {
251            learning_rate,
252            beta1,
253            beta2,
254            epsilon,
255            weight_decay,
256            step: 0,
257            momentum_states: HashMap::new(),
258            variance_states: HashMap::new(),
259        }
260    }
261
262    pub fn step(
263        &mut self,
264        parameters: &mut HashMap<String, Tensor>,
265        gradients: &HashMap<String, Tensor>,
266    ) -> Result<()> {
267        self.step += 1;
268
269        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
270        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
271
272        for (name, param) in parameters.iter_mut() {
273            let grad = gradients
274                .get(name)
275                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
276
277            let mut param_data = param.data()?;
278            let grad_data = grad.data()?;
279
280            if param_data.len() != grad_data.len() {
281                return Err(anyhow::anyhow!(
282                    "Parameter and gradient size mismatch for: {}",
283                    name
284                ));
285            }
286
287            if !self.momentum_states.contains_key(name) {
288                let zeros = vec![0.0; param_data.len()];
289                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
290                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
291            }
292
293            let momentum_state = self
294                .momentum_states
295                .get_mut(name)
296                .expect("momentum_state should exist after initialization");
297            let variance_state = self
298                .variance_states
299                .get_mut(name)
300                .expect("variance_state should exist after initialization");
301
302            let mut momentum = momentum_state.to_f32();
303            let mut variance = variance_state.to_f32();
304
305            for i in 0..param_data.len() {
306                let grad_val = grad_data[i];
307
308                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
309                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
310
311                let corrected_momentum = momentum[i] / bias_correction1;
312                let corrected_variance = variance[i] / bias_correction2;
313
314                let update = corrected_momentum / (corrected_variance.sqrt() + self.epsilon);
315
316                param_data[i] = param_data[i] * (1.0 - self.learning_rate * self.weight_decay)
317                    - self.learning_rate * update;
318            }
319
320            momentum_state.update(&momentum);
321            variance_state.update(&variance);
322
323            // Update the parameter tensor with modified data
324            *param = Tensor::new(param_data)?;
325        }
326
327        Ok(())
328    }
329
330    pub fn memory_usage(&self) -> usize {
331        let mut total = 0;
332        for state in self.momentum_states.values() {
333            total += state.data.len();
334        }
335        for state in self.variance_states.values() {
336            total += state.data.len();
337        }
338        total
339    }
340
341    pub fn memory_savings_vs_fp32(&self) -> f32 {
342        let quantized_size = self.memory_usage();
343        let fp32_equivalent = quantized_size * 4;
344        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds single-entry parameter and gradient maps keyed by `name`.
    fn single_param(
        name: &str,
        param: Vec<f32>,
        grad: Vec<f32>,
    ) -> (HashMap<String, Tensor>, HashMap<String, Tensor>) {
        let mut parameters = HashMap::new();
        let mut gradients = HashMap::new();
        parameters.insert(
            name.to_string(),
            Tensor::new(param).expect("Failed to create tensor"),
        );
        gradients.insert(
            name.to_string(),
            Tensor::new(grad).expect("Failed to create tensor"),
        );
        (parameters, gradients)
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::new(-1.0, 1.0);

        // Range endpoints map to the extreme codes.
        assert_eq!(config.quantize(-1.0), -128);
        assert_eq!(config.quantize(1.0), 127);

        // The midpoint lands next to code 0.
        let mid_quantized = config.quantize(0.0);
        assert!((-1..=1).contains(&mid_quantized));

        // Round trip stays within one quantization step.
        let original = 0.5;
        let quantized = config.quantize(original);
        let reconstructed = config.dequantize(quantized);
        assert_abs_diff_eq!(original, reconstructed, epsilon = 0.02);
    }

    #[test]
    fn test_quantized_state() {
        let values = vec![0.1, -0.5, 0.8, -0.2];
        let state = QuantizedState::new(&values);

        let reconstructed = state.to_f32();

        assert_eq!(values.len(), reconstructed.len());

        // Quantization preserves relative ordering.
        assert!(reconstructed[2] > reconstructed[0]); // 0.8 > 0.1
        assert!(reconstructed[1] < reconstructed[0]); // -0.5 < 0.1

        // Reconstruction is approximate.
        for (orig, recon) in values.iter().zip(reconstructed.iter()) {
            assert_abs_diff_eq!(orig, recon, epsilon = 0.1);
        }
    }

    #[test]
    fn test_quantized_state_constant_values_roundtrip() {
        // A zero-width range must still reproduce the constant exactly.
        let values = vec![0.25_f32; 3];
        let state = QuantizedState::new(&values);
        assert_eq!(state.to_f32(), values);
    }

    #[test]
    fn test_quantized_state_update_replaces_contents() {
        let mut state = QuantizedState::new(&[0.0, 1.0]);
        state.update(&[2.0, 4.0, 3.0]);

        let reconstructed = state.to_f32();
        assert_eq!(reconstructed.len(), 3);
        for (orig, recon) in [2.0_f32, 4.0, 3.0].iter().zip(reconstructed.iter()) {
            assert_abs_diff_eq!(orig, recon, epsilon = 0.02);
        }
    }

    #[test]
    fn test_adam8bit_creation() {
        let optimizer = Adam8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.beta1, 0.9);
        assert_eq!(optimizer.beta2, 0.999);
        assert_eq!(optimizer.step, 0);
    }

    #[test]
    fn test_adam8bit_memory_usage() {
        let mut optimizer = Adam8bit::new(0.001);
        let (mut parameters, gradients) =
            single_param("layer1", vec![1.0, 2.0, 3.0, 4.0], vec![0.1, 0.2, 0.3, 0.4]);

        optimizer.step(&mut parameters, &gradients).expect("Step failed");

        assert_eq!(optimizer.memory_usage(), 8);
        assert!(optimizer.memory_savings_vs_fp32() > 0.7);
    }

    #[test]
    fn test_adam8bit_step_moves_against_gradient() {
        let mut optimizer = Adam8bit::new(0.001);
        let (mut parameters, gradients) =
            single_param("w", vec![1.0, 2.0, 3.0, 4.0], vec![0.1, 0.2, 0.3, 0.4]);

        optimizer.step(&mut parameters, &gradients).expect("Step failed");

        // On the first step m_hat / sqrt(v_hat) is ~sign(g), so each weight
        // moves by ~learning_rate against its (positive) gradient.
        let updated = parameters["w"].data().expect("Failed to read tensor");
        for (before, after) in [1.0_f32, 2.0, 3.0, 4.0].iter().zip(updated.iter()) {
            assert_abs_diff_eq!(before - after, 0.001_f32, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_adamw8bit_creation() {
        let optimizer = AdamW8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.weight_decay, 0.01);
    }

    #[test]
    fn test_adamw8bit_decoupled_weight_decay() {
        let mut optimizer = AdamW8bit::with_config(0.1, 0.9, 0.999, 1e-8, 0.5);
        let (mut parameters, gradients) = single_param("w", vec![2.0, -4.0], vec![0.0, 0.0]);

        optimizer.step(&mut parameters, &gradients).expect("Step failed");

        // Zero gradient => zero Adam update; only the factor 1 - 0.1 * 0.5 applies.
        let updated = parameters["w"].data().expect("Failed to read tensor");
        assert_abs_diff_eq!(updated[0], 1.9_f32, epsilon = 1e-5);
        assert_abs_diff_eq!(updated[1], -3.8_f32, epsilon = 1e-5);
    }

    #[test]
    fn test_step_errors_on_missing_gradient() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "w".to_string(),
            Tensor::new(vec![1.0, 2.0]).expect("Failed to create tensor"),
        );
        let gradients = HashMap::new();

        assert!(Adam8bit::new(0.001).step(&mut parameters, &gradients).is_err());
        assert!(AdamW8bit::new(0.001).step(&mut parameters, &gradients).is_err());
    }

    #[test]
    fn test_step_errors_on_size_mismatch() {
        let (mut parameters, gradients) = single_param("w", vec![1.0, 2.0], vec![0.1, 0.2, 0.3]);

        assert!(Adam8bit::new(0.001).step(&mut parameters, &gradients).is_err());
        assert!(AdamW8bit::new(0.001).step(&mut parameters, &gradients).is_err());
    }
}