//! `web_rwkv/tensor/matrix.rs` — GPU matrix storage formats (fp16, int8, 4-bit) and quantization.

1use half::f16;
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use statrs::distribution::{ContinuousCDF, StudentsT};
5use web_rwkv_derive::DeserializeSeed;
6
7use super::{ops::Activation, TensorCpu, TensorInit, TensorInto};
8use crate::{
9    context::Context,
10    num::Float,
11    tensor::{
12        kind::{ReadWrite, Uniform},
13        ops::TensorOp,
14        serialization::Seed,
15        shape::Shape,
16        TensorError, TensorGpu, TensorGpuView, TensorShape,
17    },
18};
19
/// A 16-entry quantile lookup table (code book) used for 4-bit quantization.
#[derive(Debug, Clone)]
pub struct Float4Quant(pub TensorCpu<f32>);
22
23impl Default for Float4Quant {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29pub fn quantile_student(nu: f64) -> Vec<f32> {
30    let delta = (1.0 / 32.0 + 1.0 / 30.0) / 2.0;
31
32    let mut quant = Vec::with_capacity(16);
33
34    let step = (0.5 - delta) / 7.0;
35    quant.extend((0..7).map(|i| delta + step * i as f64));
36
37    let step = (1.0 - delta - 0.5) / 8.0;
38    quant.extend((0..9).map(|i| 0.5 + step * i as f64));
39
40    let dist = StudentsT::new(0.0, 1.0, nu).expect("invalid parameters");
41    let quant = quant.iter().map(|&p| dist.inverse_cdf(p)).collect_vec();
42    let max = *quant.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
43    quant.into_iter().map(|p| (p / max) as f32).collect()
44}
45
46impl Float4Quant {
47    /// Use normal distribution to quantize.
48    pub fn new() -> Self {
49        #[allow(clippy::excessive_precision)]
50        let quant = vec![
51            -1.0,
52            -0.6961928009986877,
53            -0.5250730514526367,
54            -0.39491748809814453,
55            -0.28444138169288635,
56            -0.18477343022823334,
57            -0.09105003625154495,
58            0.0,
59            0.07958029955625534,
60            0.16093020141124725,
61            0.24611230194568634,
62            0.33791524171829224,
63            0.44070982933044434,
64            0.5626170039176941,
65            0.7229568362236023,
66            1.0,
67        ];
68        let shape = Shape::new(quant.len(), 1, 1, 1);
69        Self(TensorCpu::from_data(shape, quant).unwrap())
70    }
71
72    /// Use Student's T distribution to quantize. For most cases `nu` can be set to 5.
73    pub fn new_student(nu: f64) -> Self {
74        let quant = quantile_student(nu);
75        let shape = Shape::new(quant.len(), 1, 1, 1);
76        Self(TensorCpu::from_data(shape, quant).unwrap())
77    }
78}
79
/// A weight matrix stored on the GPU in one of the supported precisions.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub enum Matrix {
    /// Unquantized half-precision weights.
    Fp16(TensorGpu<f16, ReadWrite>),
    /// 8-bit quantized weights `w` with per-block `f16` parameters `m`
    /// (two per block — see `quant_u8`).
    Int8 {
        w: TensorGpu<u8, ReadWrite>,
        m: TensorGpu<f16, ReadWrite>,
    },
    /// 4-bit quantized weights: `w` packs two codes per byte, `q` is the
    /// 16-entry code book, and `m` holds one `f16` absmax per block.
    /// Alias kept so checkpoints serialized as `Nf4` still deserialize.
    #[serde(alias = "Nf4")]
    Fp4 {
        q: TensorGpu<f32, Uniform>,
        w: TensorGpu<u8, ReadWrite>,
        m: TensorGpu<f16, ReadWrite>,
    },
}
95
96impl Matrix {
97    pub fn matmul_vec_op<'a, 'b, F0: Float, F1: Float>(
98        &self,
99        input: impl Into<TensorGpuView<'a, F0>>,
100        output: impl Into<TensorGpuView<'b, F1>>,
101        act: Activation,
102    ) -> Result<TensorOp, TensorError> {
103        match self {
104            Matrix::Fp16(matrix) => TensorOp::matmul_vec_fp16(matrix, input, output, act, false),
105            Matrix::Int8 { w, m } => TensorOp::matmul_vec_int8(w, m, input, output, act, false),
106            Matrix::Fp4 { w, q, m } => TensorOp::matmul_vec_nf4(w, q, m, input, output, act, false),
107        }
108    }
109
110    pub fn matmul_vec_op_sparse<'a, 'b, F0: Float, F1: Float>(
111        &self,
112        input: impl Into<TensorGpuView<'a, F0>>,
113        output: impl Into<TensorGpuView<'b, F1>>,
114        act: Activation,
115    ) -> Result<TensorOp, TensorError> {
116        match self {
117            Matrix::Fp16(matrix) => TensorOp::matmul_vec_fp16(matrix, input, output, act, true),
118            Matrix::Int8 { w, m } => TensorOp::matmul_vec_int8(w, m, input, output, act, true),
119            Matrix::Fp4 { w, q, m } => TensorOp::matmul_vec_nf4(w, q, m, input, output, act, true),
120        }
121    }
122
123    pub fn matmul_mat_op<'a, 'b, F0: Float, F1: Float>(
124        &self,
125        input: impl Into<TensorGpuView<'a, F0>>,
126        output: impl Into<TensorGpuView<'b, F1>>,
127        act: Activation,
128    ) -> Result<TensorOp, TensorError> {
129        match self {
130            Matrix::Fp16(matrix) => TensorOp::matmul_mat_fp16(matrix, input, output, act),
131            Matrix::Int8 { w, m } => TensorOp::matmul_mat_int8(w, m, input, output, act),
132            Matrix::Fp4 { w, q, m } => TensorOp::matmul_mat_nf4(w, q, m, input, output, act),
133        }
134    }
135
136    pub fn matmul_op<'a, 'b, F0: Float, F1: Float>(
137        &self,
138        input: impl Into<TensorGpuView<'a, F0>>,
139        output: impl Into<TensorGpuView<'b, F1>>,
140        act: Activation,
141        turbo: bool,
142    ) -> Result<TensorOp, TensorError> {
143        match turbo {
144            true => self.matmul_mat_op(input, output, act),
145            false => self.matmul_vec_op(input, output, act),
146        }
147    }
148
149    pub fn matmul_op_sparse<'a, 'b, F0: Float, F1: Float>(
150        &self,
151        input: impl Into<TensorGpuView<'a, F0>>,
152        output: impl Into<TensorGpuView<'b, F1>>,
153        act: Activation,
154        turbo: bool,
155    ) -> Result<TensorOp, TensorError> {
156        match turbo {
157            true => self.matmul_mat_op(input, output, act),
158            false => self.matmul_vec_op_sparse(input, output, act),
159        }
160    }
161
162    pub fn quant_u8(matrix: &TensorGpu<f16, ReadWrite>) -> Result<Self, TensorError> {
163        let context = matrix.context();
164        let shape = matrix.shape();
165
166        let w = context.tensor_init(shape);
167        let m = context.tensor_init(Shape::new(
168            (shape.len() << 1).div_ceil(TensorOp::INT8_BLOCK_SIZE as usize),
169            1,
170            1,
171            1,
172        ));
173
174        let op = TensorOp::quantize_mat_int8(matrix, &m, &w)?;
175        context.queue.submit(context.encode(&op));
176
177        Ok(Matrix::Int8 { w, m })
178    }
179
180    pub fn quant_nf4(matrix: &TensorGpu<f16, ReadWrite>) -> Result<Self, TensorError> {
181        let context = matrix.context();
182        let shape = matrix.shape();
183
184        let matrix_shape = Shape::new(shape[0] / 2, shape[1], shape[2], shape[3]);
185        let absmax_shape = Shape::new(
186            shape.len().div_ceil(TensorOp::NF4_BLOCK_SIZE as usize),
187            1,
188            1,
189            1,
190        );
191
192        let q = Float4Quant::default().0.to(context);
193        let w = context.tensor_init(matrix_shape);
194        let m = context.tensor_init(absmax_shape);
195
196        let op = TensorOp::quantize_mat_nf4(matrix, &q, &m, &w)?;
197        context.queue.submit(context.encode(&op));
198
199        Ok(Matrix::Fp4 { w, q, m })
200    }
201
202    pub fn quant_sf4(matrix: &TensorGpu<f16, ReadWrite>, nu: f64) -> Result<Self, TensorError> {
203        let context = matrix.context();
204        let shape = matrix.shape();
205
206        let matrix_shape = Shape::new(shape[0] / 2, shape[1], shape[2], shape[3]);
207        let absmax_shape = Shape::new(
208            shape.len().div_ceil(TensorOp::NF4_BLOCK_SIZE as usize),
209            1,
210            1,
211            1,
212        );
213
214        let q = Float4Quant::new_student(nu).0.to(context);
215        let w = context.tensor_init(matrix_shape);
216        let m = context.tensor_init(absmax_shape);
217
218        let op = TensorOp::quantize_mat_nf4(matrix, &q, &m, &w)?;
219        context.queue.submit(context.encode(&op));
220
221        Ok(Matrix::Fp4 { w, q, m })
222    }
223}
224
/// Summary statistics of a matrix's value distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatrixStatistics {
    /// Quantile values: `min`, `q_005`, `q_25`, `q_50`, `q_75`, `q_995`, `max`.
    pub quantile: [f32; 7],
}
230
231impl<F: Float> TensorCpu<F> {
232    pub fn statistics(&self) -> MatrixStatistics {
233        let values: Vec<f32> = self
234            .iter()
235            .map(|x| x.hom())
236            .sorted_unstable_by(|x: &f32, y: &f32| x.total_cmp(y))
237            .collect();
238        assert!(values.len() > 2);
239        let p0 = 0;
240        let p4 = values.len() - 1;
241        let p2 = (p0 + p4) / 2;
242        let p1 = (p0 + p2) / 2;
243        let p3 = (p2 + p4) / 2;
244        let p_005 = ((p4 as f32) * 0.005) as usize;
245        let p_995 = ((p4 as f32) * 0.995) as usize;
246        let quantile = [p0, p_005, p1, p2, p3, p_995, p4].map(|p| values[p]);
247        MatrixStatistics { quantile }
248    }
249}
250
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::quantile_student;

    /// The old test only printed the table; assert its invariants instead:
    /// 16 entries, monotonically non-decreasing, and normalized so the
    /// largest entry is exactly 1.0.
    #[test]
    fn test_student() -> Result<()> {
        let quant = quantile_student(5.0);
        assert_eq!(quant.len(), 16);
        assert!(quant.windows(2).all(|w| w[0] <= w[1]));
        assert_eq!(*quant.last().unwrap(), 1.0);
        Ok(())
    }
}