Skip to main content

ruvector_dither/
channel.rs

//! Per-channel and per-layer dither management.
//!
//! `ChannelDither` bundles one `GoldenRatioDither` state per channel,
//! seeded from `(layer_id, channel_id)` pairs so every channel is
//! structurally decorrelated without any RNG.
6
7use crate::{DitherSource, GoldenRatioDither};
8
9/// Per-channel dither pool seeded from `(layer_id, channel_id)` pairs.
10///
11/// Allocates one `GoldenRatioDither` per channel; each is independently
12/// advanced, so channels cannot constructively interfere.
13pub struct ChannelDither {
14    channels: Vec<GoldenRatioDither>,
15    bits: u32,
16    eps: f32,
17}
18
19impl ChannelDither {
20    /// Build a pool of `n_channels` dithers for `layer_id` / `bits` / `eps`.
21    pub fn new(layer_id: u32, n_channels: usize, bits: u32, eps: f32) -> Self {
22        let channels = (0..n_channels)
23            .map(|ch| GoldenRatioDither::from_ids(layer_id, ch as u32))
24            .collect();
25        Self { channels, bits, eps }
26    }
27
28    /// Quantize `activations` in-place.  Each column (channel dimension) uses
29    /// its own independent dither state.
30    ///
31    /// `activations` is a flat row-major tensor of shape `[batch, channels]`.
32    /// If the slice is not a multiple of `n_channels`, the remainder is
33    /// processed using channel 0.
34    pub fn quantize_batch(&mut self, activations: &mut [f32]) {
35        assert!(!self.channels.is_empty(), "ChannelDither must have >= 1 channel");
36        assert!(self.bits >= 2 && self.bits <= 31, "bits must be in [2, 31]");
37        let nc = self.channels.len();
38        let qmax = ((1u32 << (self.bits - 1)) - 1) as f32;
39        let lsb = 1.0 / qmax;
40        for (i, x) in activations.iter_mut().enumerate() {
41            let ch = i % nc;
42            let d = self.channels[ch].next(self.eps * lsb);
43            *x = ((*x + d) * qmax).round().clamp(-qmax, qmax) / qmax;
44        }
45    }
46
47    /// Number of channels in this pool.
48    pub fn n_channels(&self) -> usize {
49        self.channels.len()
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn channel_dither_correct_count() {
        let cd = ChannelDither::new(0, 16, 8, 0.5);
        assert_eq!(cd.n_channels(), 16);
    }

    #[test]
    fn channel_dither_in_bounds() {
        let mut cd = ChannelDither::new(1, 8, 5, 0.5);
        let mut acts: Vec<f32> = (0..64).map(|i| (i as f32 / 63.0) * 2.0 - 1.0).collect();
        cd.quantize_batch(&mut acts);
        for v in acts {
            assert!(v >= -1.0 && v <= 1.0, "out of bounds: {v}");
        }
    }

    #[test]
    fn different_layers_produce_different_outputs() {
        let input: Vec<f32> = vec![0.5; 16];
        let mut buf0 = input.clone();
        let mut buf1 = input.clone();
        ChannelDither::new(0, 8, 8, 0.5).quantize_batch(&mut buf0);
        ChannelDither::new(99, 8, 8, 0.5).quantize_batch(&mut buf1);
        assert_ne!(buf0, buf1, "different layer_ids must yield different dithered outputs");
    }

    /// A trailing partial row must be processed by channels `0..r`, exactly
    /// as the same positions are in a longer batch from an identically
    /// seeded pool.
    #[test]
    fn partial_row_matches_prefix_of_full_rows() {
        let mut full = vec![0.5f32; 16];
        let mut partial = vec![0.5f32; 11];
        ChannelDither::new(3, 8, 8, 0.5).quantize_batch(&mut full);
        ChannelDither::new(3, 8, 8, 0.5).quantize_batch(&mut partial);
        assert_eq!(&full[..11], &partial[..]);
    }

    /// Inputs far outside `[-1, 1]` saturate at the grid edges.
    #[test]
    fn out_of_range_inputs_saturate() {
        let mut acts = vec![1000.0f32, -1000.0, 1000.0, -1000.0];
        ChannelDither::new(2, 2, 6, 0.5).quantize_batch(&mut acts);
        assert_eq!(acts, vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    #[should_panic(expected = "ChannelDither must have >= 1 channel")]
    fn empty_pool_panics_on_quantize() {
        ChannelDither::new(0, 0, 8, 0.5).quantize_batch(&mut [0.25]);
    }

    #[test]
    #[should_panic(expected = "bits must be in [2, 31]")]
    fn out_of_range_bits_panics_on_quantize() {
        ChannelDither::new(0, 4, 1, 0.5).quantize_batch(&mut [0.25]);
    }
}
82}