ruvector_dither/
channel.rs1use crate::{DitherSource, GoldenRatioDither};
8
9pub struct ChannelDither {
14 channels: Vec<GoldenRatioDither>,
15 bits: u32,
16 eps: f32,
17}
18
19impl ChannelDither {
20 pub fn new(layer_id: u32, n_channels: usize, bits: u32, eps: f32) -> Self {
22 let channels = (0..n_channels)
23 .map(|ch| GoldenRatioDither::from_ids(layer_id, ch as u32))
24 .collect();
25 Self { channels, bits, eps }
26 }
27
28 pub fn quantize_batch(&mut self, activations: &mut [f32]) {
35 assert!(!self.channels.is_empty(), "ChannelDither must have >= 1 channel");
36 assert!(self.bits >= 2 && self.bits <= 31, "bits must be in [2, 31]");
37 let nc = self.channels.len();
38 let qmax = ((1u32 << (self.bits - 1)) - 1) as f32;
39 let lsb = 1.0 / qmax;
40 for (i, x) in activations.iter_mut().enumerate() {
41 let ch = i % nc;
42 let d = self.channels[ch].next(self.eps * lsb);
43 *x = ((*x + d) * qmax).round().clamp(-qmax, qmax) / qmax;
44 }
45 }
46
47 pub fn n_channels(&self) -> usize {
49 self.channels.len()
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor creates exactly as many channels as requested.
    #[test]
    fn channel_dither_correct_count() {
        let dither = ChannelDither::new(0, 16, 8, 0.5);
        assert_eq!(dither.n_channels(), 16);
    }

    /// Inputs spread evenly over `[-1, 1]` remain inside `[-1, 1]` after
    /// quantization.
    #[test]
    fn channel_dither_in_bounds() {
        let mut dither = ChannelDither::new(1, 8, 5, 0.5);
        let mut acts: Vec<f32> = Vec::with_capacity(64);
        for i in 0..64 {
            acts.push((i as f32 / 63.0) * 2.0 - 1.0);
        }

        dither.quantize_batch(&mut acts);

        for &v in &acts {
            assert!((-1.0..=1.0).contains(&v), "out of bounds: {v}");
        }
    }

    /// Two quantizers that differ only in `layer_id` must not produce the same
    /// output for the same input.
    #[test]
    fn different_layers_produce_different_outputs() {
        // Quantizes a fresh constant buffer with a quantizer for `layer_id`.
        let quantized_with = |layer_id: u32| -> Vec<f32> {
            let mut buf: Vec<f32> = vec![0.5; 16];
            ChannelDither::new(layer_id, 8, 8, 0.5).quantize_batch(&mut buf);
            buf
        };

        let buf0 = quantized_with(0);
        let buf1 = quantized_with(99);
        assert_ne!(buf0, buf1, "different layer_ids must yield different dithered outputs");
    }
}