Skip to main content

trustformers_optim/
quantized.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
/// Per-tensor affine quantization parameters mapping the closed range
/// `[min_val, max_val]` onto the 256 codes of an `i8`.
///
/// A value `v` is encoded as `round((v - min_val) / scale) - 128`, clamped to
/// `-128..=127`, so `min_val` maps to `-128` and `max_val` to `127`.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizationConfig {
    /// Width of one quantization step: `(max_val - min_val) / 255`, or `1.0`
    /// when that width is not positive and finite (see [`QuantizationConfig::new`]).
    pub scale: f32,
    /// The code that [`QuantizationConfig::quantize`] produces for `0.0`
    /// (saturated to `-128`/`127` when `0.0` lies outside the range).
    /// Informational only: encoding and decoding work from `min_val`.
    pub zero_point: i8,
    /// Lower end of the range; encoded as `-128`.
    pub min_val: f32,
    /// Upper end of the range; encoded as `127`.
    pub max_val: f32,
}

impl QuantizationConfig {
    /// Creates a config spanning `[min_val, max_val]`.
    ///
    /// If the range has no positive, finite width, `scale` falls back to `1.0`.
    /// This covers `min_val == max_val` (e.g. a constant or all-zero tensor), an
    /// inverted range, and infinite bounds such as the `(+inf, -inf)` produced
    /// by scanning an empty slice. With this fallback `quantize` never divides
    /// by zero, and `min_val` still round-trips exactly.
    pub fn new(min_val: f32, max_val: f32) -> Self {
        let step = (max_val - min_val) / 255.0;
        let scale = if step.is_finite() && step > 0.0 {
            step
        } else {
            1.0
        };

        let mut config = Self {
            scale,
            zero_point: 0,
            min_val,
            max_val,
        };
        // Derive the zero point from the encoder itself so the two always agree
        // (the encoder shifts by -128, which an unsigned-offset formula would miss).
        config.zero_point = config.quantize(0.0);
        config
    }

    /// Encodes `value` as an `i8` code; values outside `[min_val, max_val]`
    /// saturate to `-128`/`127`.
    pub fn quantize(&self, value: f32) -> i8 {
        let code = ((value - self.min_val) / self.scale).round() - 128.0;
        // After the clamp the cast cannot wrap (a NaN input casts to 0).
        code.clamp(-128.0, 127.0) as i8
    }

    /// Decodes an `i8` code back to an approximate `f32`.
    ///
    /// With a finite `scale`, the code `-128` decodes to exactly `min_val`.
    pub fn dequantize(&self, quantized: i8) -> f32 {
        (f32::from(quantized) + 128.0) * self.scale + self.min_val
    }
}
35
36#[derive(Debug, Clone)]
37pub struct QuantizedState {
38    pub data: Vec<i8>,
39    pub config: QuantizationConfig,
40}
41
42impl QuantizedState {
43    pub fn new(values: &[f32]) -> Self {
44        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
45        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
46
47        let config = QuantizationConfig::new(min_val, max_val);
48        let data: Vec<i8> = values.iter().map(|&v| config.quantize(v)).collect();
49
50        Self { data, config }
51    }
52
53    pub fn to_f32(&self) -> Vec<f32> {
54        self.data.iter().map(|&q| self.config.dequantize(q)).collect()
55    }
56
57    pub fn update(&mut self, new_values: &[f32]) {
58        let min_val = new_values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
59        let max_val = new_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
60
61        self.config = QuantizationConfig::new(min_val, max_val);
62        self.data = new_values.iter().map(|&v| self.config.quantize(v)).collect();
63    }
64}
65
66#[derive(Debug)]
67pub struct Adam8bit {
68    pub learning_rate: f32,
69    pub beta1: f32,
70    pub beta2: f32,
71    pub epsilon: f32,
72    pub weight_decay: f32,
73    pub step: usize,
74    pub momentum_states: HashMap<String, QuantizedState>,
75    pub variance_states: HashMap<String, QuantizedState>,
76}
77
78impl Default for Adam8bit {
79    fn default() -> Self {
80        Self {
81            learning_rate: 1e-3,
82            beta1: 0.9,
83            beta2: 0.999,
84            epsilon: 1e-8,
85            weight_decay: 0.0,
86            step: 0,
87            momentum_states: HashMap::new(),
88            variance_states: HashMap::new(),
89        }
90    }
91}
92
93impl Adam8bit {
94    pub fn new(learning_rate: f32) -> Self {
95        Self {
96            learning_rate,
97            ..Default::default()
98        }
99    }
100
101    pub fn with_config(
102        learning_rate: f32,
103        beta1: f32,
104        beta2: f32,
105        epsilon: f32,
106        weight_decay: f32,
107    ) -> Self {
108        Self {
109            learning_rate,
110            beta1,
111            beta2,
112            epsilon,
113            weight_decay,
114            step: 0,
115            momentum_states: HashMap::new(),
116            variance_states: HashMap::new(),
117        }
118    }
119
120    pub fn step(
121        &mut self,
122        parameters: &mut HashMap<String, Tensor>,
123        gradients: &HashMap<String, Tensor>,
124    ) -> Result<()> {
125        self.step += 1;
126
127        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
128        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
129
130        for (name, param) in parameters.iter_mut() {
131            let grad = gradients
132                .get(name)
133                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
134
135            let mut param_data = param.data()?;
136            let grad_data = grad.data()?;
137
138            if param_data.len() != grad_data.len() {
139                return Err(anyhow::anyhow!(
140                    "Parameter and gradient size mismatch for: {}",
141                    name
142                ));
143            }
144
145            if !self.momentum_states.contains_key(name) {
146                let zeros = vec![0.0; param_data.len()];
147                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
148                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
149            }
150
151            let momentum_state = self.momentum_states.get_mut(name).unwrap();
152            let variance_state = self.variance_states.get_mut(name).unwrap();
153
154            let mut momentum = momentum_state.to_f32();
155            let mut variance = variance_state.to_f32();
156
157            for i in 0..param_data.len() {
158                let mut grad_val = grad_data[i];
159
160                if self.weight_decay > 0.0 {
161                    grad_val += self.weight_decay * param_data[i];
162                }
163
164                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
165                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
166
167                let corrected_momentum = momentum[i] / bias_correction1;
168                let corrected_variance = variance[i] / bias_correction2;
169
170                param_data[i] -= self.learning_rate * corrected_momentum
171                    / (corrected_variance.sqrt() + self.epsilon);
172            }
173
174            momentum_state.update(&momentum);
175            variance_state.update(&variance);
176
177            // Update the parameter tensor with modified data
178            *param = Tensor::new(param_data)?;
179        }
180
181        Ok(())
182    }
183
184    pub fn memory_usage(&self) -> usize {
185        let mut total = 0;
186        for state in self.momentum_states.values() {
187            total += state.data.len();
188        }
189        for state in self.variance_states.values() {
190            total += state.data.len();
191        }
192        total
193    }
194
195    pub fn memory_savings_vs_fp32(&self) -> f32 {
196        let quantized_size = self.memory_usage();
197        let fp32_equivalent = quantized_size * 4;
198        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
199    }
200}
201
202#[derive(Debug)]
203pub struct AdamW8bit {
204    pub learning_rate: f32,
205    pub beta1: f32,
206    pub beta2: f32,
207    pub epsilon: f32,
208    pub weight_decay: f32,
209    pub step: usize,
210    pub momentum_states: HashMap<String, QuantizedState>,
211    pub variance_states: HashMap<String, QuantizedState>,
212}
213
214impl Default for AdamW8bit {
215    fn default() -> Self {
216        Self {
217            learning_rate: 1e-3,
218            beta1: 0.9,
219            beta2: 0.999,
220            epsilon: 1e-8,
221            weight_decay: 1e-2,
222            step: 0,
223            momentum_states: HashMap::new(),
224            variance_states: HashMap::new(),
225        }
226    }
227}
228
229impl AdamW8bit {
230    pub fn new(learning_rate: f32) -> Self {
231        Self {
232            learning_rate,
233            ..Default::default()
234        }
235    }
236
237    pub fn with_config(
238        learning_rate: f32,
239        beta1: f32,
240        beta2: f32,
241        epsilon: f32,
242        weight_decay: f32,
243    ) -> Self {
244        Self {
245            learning_rate,
246            beta1,
247            beta2,
248            epsilon,
249            weight_decay,
250            step: 0,
251            momentum_states: HashMap::new(),
252            variance_states: HashMap::new(),
253        }
254    }
255
256    pub fn step(
257        &mut self,
258        parameters: &mut HashMap<String, Tensor>,
259        gradients: &HashMap<String, Tensor>,
260    ) -> Result<()> {
261        self.step += 1;
262
263        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
264        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
265
266        for (name, param) in parameters.iter_mut() {
267            let grad = gradients
268                .get(name)
269                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
270
271            let mut param_data = param.data()?;
272            let grad_data = grad.data()?;
273
274            if param_data.len() != grad_data.len() {
275                return Err(anyhow::anyhow!(
276                    "Parameter and gradient size mismatch for: {}",
277                    name
278                ));
279            }
280
281            if !self.momentum_states.contains_key(name) {
282                let zeros = vec![0.0; param_data.len()];
283                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
284                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
285            }
286
287            let momentum_state = self.momentum_states.get_mut(name).unwrap();
288            let variance_state = self.variance_states.get_mut(name).unwrap();
289
290            let mut momentum = momentum_state.to_f32();
291            let mut variance = variance_state.to_f32();
292
293            for i in 0..param_data.len() {
294                let grad_val = grad_data[i];
295
296                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
297                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
298
299                let corrected_momentum = momentum[i] / bias_correction1;
300                let corrected_variance = variance[i] / bias_correction2;
301
302                let update = corrected_momentum / (corrected_variance.sqrt() + self.epsilon);
303
304                param_data[i] = param_data[i] * (1.0 - self.learning_rate * self.weight_decay)
305                    - self.learning_rate * update;
306            }
307
308            momentum_state.update(&momentum);
309            variance_state.update(&variance);
310
311            // Update the parameter tensor with modified data
312            *param = Tensor::new(param_data)?;
313        }
314
315        Ok(())
316    }
317
318    pub fn memory_usage(&self) -> usize {
319        let mut total = 0;
320        for state in self.momentum_states.values() {
321            total += state.data.len();
322        }
323        for state in self.variance_states.values() {
324            total += state.data.len();
325        }
326        total
327    }
328
329    pub fn memory_savings_vs_fp32(&self) -> f32 {
330        let quantized_size = self.memory_usage();
331        let fp32_equivalent = quantized_size * 4;
332        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds one-entry parameter and gradient maps stored under `name`.
    fn single_param(
        name: &str,
        param: Vec<f32>,
        grad: Vec<f32>,
    ) -> (HashMap<String, Tensor>, HashMap<String, Tensor>) {
        let mut parameters = HashMap::new();
        parameters.insert(name.to_string(), Tensor::new(param).unwrap());
        let mut gradients = HashMap::new();
        gradients.insert(name.to_string(), Tensor::new(grad).unwrap());
        (parameters, gradients)
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::new(-1.0, 1.0);

        // The range endpoints map to the extremes of the i8 code space.
        assert_eq!(config.quantize(-1.0), -128);
        assert_eq!(config.quantize(1.0), 127);

        // 0.0 sits mid-range for a symmetric interval (±1 code for rounding).
        let mid_quantized = config.quantize(0.0);
        assert!((-1..=1).contains(&mid_quantized));

        // Round-trip error stays within about one quantization step.
        let original = 0.5;
        let reconstructed = config.dequantize(config.quantize(original));
        assert_abs_diff_eq!(original, reconstructed, epsilon = 0.02);
    }

    #[test]
    fn test_quantized_state() {
        let values: [f32; 4] = [0.1, -0.5, 0.8, -0.2];
        let state = QuantizedState::new(&values);

        let reconstructed = state.to_f32();
        assert_eq!(values.len(), reconstructed.len());

        // Quantization preserves relative ordering.
        assert!(reconstructed[2] > reconstructed[0]); // 0.8 > 0.1
        assert!(reconstructed[1] < reconstructed[0]); // -0.5 < 0.1

        // Reconstruction is approximate but close.
        for (orig, recon) in values.iter().zip(reconstructed.iter()) {
            assert_abs_diff_eq!(orig, recon, epsilon = 0.1);
        }
    }

    #[test]
    fn test_quantized_state_constant_values_round_trip() {
        // A zero-width range (e.g. freshly zeroed optimizer state) must decode cleanly.
        let state = QuantizedState::new(&[0.25_f32; 4]);
        for value in state.to_f32() {
            assert_abs_diff_eq!(value, 0.25, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantized_state_update_refits_range() {
        let mut state = QuantizedState::new(&[0.0_f32; 3]);
        state.update(&[-1.0, 0.0, 1.0]);

        let reconstructed = state.to_f32();
        assert_eq!(reconstructed.len(), 3);
        assert_abs_diff_eq!(reconstructed[0], -1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(reconstructed[2], 1.0, epsilon = 0.01);
    }

    #[test]
    fn test_adam8bit_creation() {
        let optimizer = Adam8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.beta1, 0.9);
        assert_eq!(optimizer.beta2, 0.999);
        assert_eq!(optimizer.step, 0);
    }

    #[test]
    fn test_adam8bit_memory_usage() {
        let mut optimizer = Adam8bit::new(0.001);
        let (mut parameters, gradients) =
            single_param("layer1", vec![1.0, 2.0, 3.0, 4.0], vec![0.1, 0.2, 0.3, 0.4]);

        optimizer.step(&mut parameters, &gradients).unwrap();

        // Two states (momentum + variance) of four one-byte codes each.
        assert_eq!(optimizer.memory_usage(), 8);
        assert!(optimizer.memory_savings_vs_fp32() > 0.7);
    }

    #[test]
    fn test_adam8bit_step_moves_against_gradient() {
        let mut optimizer = Adam8bit::new(0.01);
        let (mut parameters, gradients) = single_param("w", vec![1.0, -1.0], vec![0.5, -0.5]);

        optimizer.step(&mut parameters, &gradients).unwrap();

        // On the first step the bias-corrected update is ~learning_rate * sign(grad).
        let updated = parameters["w"].data().unwrap();
        assert_abs_diff_eq!(updated[0], 0.99, epsilon = 1e-4);
        assert_abs_diff_eq!(updated[1], -0.99, epsilon = 1e-4);
        assert_eq!(optimizer.step, 1);
    }

    #[test]
    fn test_adam8bit_missing_gradient_is_error() {
        let mut optimizer = Adam8bit::new(0.001);
        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Tensor::new(vec![1.0_f32]).unwrap());
        let gradients: HashMap<String, Tensor> = HashMap::new();

        assert!(optimizer.step(&mut parameters, &gradients).is_err());
    }

    #[test]
    fn test_adam8bit_size_mismatch_is_error() {
        let mut optimizer = Adam8bit::new(0.001);
        let (mut parameters, gradients) =
            single_param("w", vec![1.0, 2.0, 3.0], vec![0.1, 0.2]);

        assert!(optimizer.step(&mut parameters, &gradients).is_err());
    }

    #[test]
    fn test_adamw8bit_creation() {
        let optimizer = AdamW8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.weight_decay, 0.01);
    }

    #[test]
    fn test_adamw8bit_step_applies_decoupled_decay() {
        let mut optimizer = AdamW8bit::new(0.01);
        let (mut parameters, gradients) = single_param("w", vec![1.0, -1.0], vec![0.5, -0.5]);

        optimizer.step(&mut parameters, &gradients).unwrap();

        // p * (1 - lr * wd) - lr * ~sign(grad) = 1 * 0.9999 - 0.01.
        let updated = parameters["w"].data().unwrap();
        assert_abs_diff_eq!(updated[0], 0.9899, epsilon = 1e-4);
        assert_abs_diff_eq!(updated[1], -0.9899, epsilon = 1e-4);
        assert_eq!(optimizer.memory_usage(), 4);
    }
}