use half::f16;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use statrs::distribution::{ContinuousCDF, StudentsT};
use web_rwkv_derive::DeserializeSeed;

use super::{ops::Activation, TensorCpu, TensorInit, TensorInto};
use crate::{
    context::Context,
    num::Float,
    tensor::{
        kind::{ReadWrite, Uniform},
        ops::TensorOp,
        serialization::Seed,
        shape::Shape,
        TensorError, TensorGpu, TensorGpuView, TensorShape,
    },
};

/// A 16-entry lookup table of 4-bit quantile values, stored as a 1-D `TensorCpu<f32>`.
#[derive(Debug, Clone)]
pub struct Float4Quant(pub TensorCpu<f32>);

impl Default for Float4Quant {
    fn default() -> Self {
        Self::new()
    }
}

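/// Computes 16 normalized quantiles of a Student's t distribution with `nu`
/// degrees of freedom, following the NF4 construction from QLoRA: 7 evenly
/// spaced probabilities on `[delta, 0.5)` and 9 on `[0.5, 1 - delta]`, where
/// `delta = ((1 / 32) + (1 / 30)) / 2` keeps the endpoint quantiles finite.
/// The quantiles are divided by the largest one so the table spans `[-1, 1]`.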
pub fn quantile_student(nu: f64) -> Vec<f32> {
    let delta = (1.0 / 32.0 + 1.0 / 30.0) / 2.0;

    let mut quant = Vec::with_capacity(16);

    // 7 probabilities on [delta, 0.5), then 9 on [0.5, 1 - delta].
    let step = (0.5 - delta) / 7.0;
    quant.extend((0..7).map(|i| delta + step * i as f64));

    let step = (1.0 - delta - 0.5) / 8.0;
    quant.extend((0..9).map(|i| 0.5 + step * i as f64));

    let dist = StudentsT::new(0.0, 1.0, nu).expect("invalid parameters");
    let quant = quant.iter().map(|&p| dist.inverse_cdf(p)).collect_vec();
    let max = *quant.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
    quant.into_iter().map(|p| (p / max) as f32).collect()
}

impl Float4Quant {
    pub fn new() -> Self {
        // The standard 16-entry NF4 quantile table from QLoRA.
        #[allow(clippy::excessive_precision)]
        let quant = vec![
            -1.0,
            -0.6961928009986877,
            -0.5250730514526367,
            -0.39491748809814453,
            -0.28444138169288635,
            -0.18477343022823334,
            -0.09105003625154495,
            0.0,
            0.07958029955625534,
            0.16093020141124725,
            0.24611230194568634,
            0.33791524171829224,
            0.44070982933044434,
            0.5626170039176941,
            0.7229568362236023,
            1.0,
        ];
        let shape = Shape::new(quant.len(), 1, 1, 1);
        Self(TensorCpu::from_data(shape, quant).unwrap())
    }

    pub fn new_student(nu: f64) -> Self {
        let quant = quantile_student(nu);
        let shape = Shape::new(quant.len(), 1, 1, 1);
        Self(TensorCpu::from_data(shape, quant).unwrap())
    }
}
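
// A minimal CPU sketch of how the 4-bit codes decode back to floats, for
// illustration only; the real dequantization happens inside the `matmul_*_nf4`
// shaders. Each byte of `w` packs two 4-bit indices into the 16-entry quantile
// table, and every block of values shares one `f16` absmax scale from `m`. The
// low-nibble-first packing order here is an assumption, not taken from the
// shader source.
#[allow(dead_code)]
fn dequantize_nf4_sketch(packed: &[u8], quant: &[f32; 16], absmax: f16) -> Vec<f32> {
    let scale = absmax.to_f32();
    packed
        .iter()
        .flat_map(|&byte| [byte & 0x0f, byte >> 4])
        .map(|code| quant[code as usize] * scale)
        .collect()
}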

/// A weight matrix on the GPU, either in full precision or quantized.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub enum Matrix {
    /// Full-precision `f16` weights.
    Fp16(TensorGpu<f16, ReadWrite>),
    /// 8-bit weights `w` with per-block quantization parameters `m`.
    Int8 {
        w: TensorGpu<u8, ReadWrite>,
        m: TensorGpu<f16, ReadWrite>,
    },
    /// 4-bit weights `w` (two codes per byte) with a quantile lookup table `q`
    /// and per-block absmax scales `m`.
    #[serde(alias = "Nf4")]
    Fp4 {
        q: TensorGpu<f32, Uniform>,
        w: TensorGpu<u8, ReadWrite>,
        m: TensorGpu<f16, ReadWrite>,
    },
}

impl Matrix {
    /// Builds a matrix-vector multiplication op, dispatching on the weight format.
    pub fn matmul_vec_op<'a, 'b, F0: Float, F1: Float>(
        &self,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<TensorOp, TensorError> {
        match self {
            Matrix::Fp16(matrix) => TensorOp::matmul_vec_fp16(matrix, input, output, act, false),
            Matrix::Int8 { w, m } => TensorOp::matmul_vec_int8(w, m, input, output, act, false),
            Matrix::Fp4 { w, q, m } => TensorOp::matmul_vec_nf4(w, q, m, input, output, act, false),
        }
    }

    /// Like [`Self::matmul_vec_op`], but with the sparse input path enabled.
    pub fn matmul_vec_op_sparse<'a, 'b, F0: Float, F1: Float>(
        &self,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<TensorOp, TensorError> {
        match self {
            Matrix::Fp16(matrix) => TensorOp::matmul_vec_fp16(matrix, input, output, act, true),
            Matrix::Int8 { w, m } => TensorOp::matmul_vec_int8(w, m, input, output, act, true),
            Matrix::Fp4 { w, q, m } => TensorOp::matmul_vec_nf4(w, q, m, input, output, act, true),
        }
    }

    /// Builds a matrix-matrix multiplication op, dispatching on the weight format.
    pub fn matmul_mat_op<'a, 'b, F0: Float, F1: Float>(
        &self,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<TensorOp, TensorError> {
        match self {
            Matrix::Fp16(matrix) => TensorOp::matmul_mat_fp16(matrix, input, output, act),
            Matrix::Int8 { w, m } => TensorOp::matmul_mat_int8(w, m, input, output, act),
            Matrix::Fp4 { w, q, m } => TensorOp::matmul_mat_nf4(w, q, m, input, output, act),
        }
    }

    /// Builds a matmul op; `turbo` selects the matrix-matrix kernel over the
    /// matrix-vector one.
    pub fn matmul_op<'a, 'b, F0: Float, F1: Float>(
        &self,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
        turbo: bool,
    ) -> Result<TensorOp, TensorError> {
        match turbo {
            true => self.matmul_mat_op(input, output, act),
            false => self.matmul_vec_op(input, output, act),
        }
    }

    /// Like [`Self::matmul_op`], but the non-turbo path uses the sparse
    /// matrix-vector kernel.
    pub fn matmul_op_sparse<'a, 'b, F0: Float, F1: Float>(
        &self,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
        turbo: bool,
    ) -> Result<TensorOp, TensorError> {
        match turbo {
            true => self.matmul_mat_op(input, output, act),
            false => self.matmul_vec_op_sparse(input, output, act),
        }
    }

    /// Quantizes an `f16` matrix into [`Matrix::Int8`] format on the GPU.
    pub fn quant_u8(matrix: &TensorGpu<f16, ReadWrite>) -> Result<Self, TensorError> {
        let context = matrix.context();
        let shape = matrix.shape();

        let w = context.tensor_init(shape);
        // Two `f16` quantization parameters per `INT8_BLOCK_SIZE` block of values.
        let m = context.tensor_init(Shape::new(
            (shape.len() << 1).div_ceil(TensorOp::INT8_BLOCK_SIZE as usize),
            1,
            1,
            1,
        ));

        let op = TensorOp::quantize_mat_int8(matrix, &m, &w)?;
        context.queue.submit(context.encode(&op));

        Ok(Matrix::Int8 { w, m })
    }

    /// Quantizes an `f16` matrix into [`Matrix::Fp4`] format using the
    /// standard NF4 quantile table.
    pub fn quant_nf4(matrix: &TensorGpu<f16, ReadWrite>) -> Result<Self, TensorError> {
        let context = matrix.context();
        let shape = matrix.shape();

        // Two 4-bit codes per byte, so the packed matrix halves the first
        // dimension; one absmax scale per `NF4_BLOCK_SIZE` block of values.
        let matrix_shape = Shape::new(shape[0] / 2, shape[1], shape[2], shape[3]);
        let absmax_shape = Shape::new(
            shape.len().div_ceil(TensorOp::NF4_BLOCK_SIZE as usize),
            1,
            1,
            1,
        );

        let q = Float4Quant::default().0.to(context);
        let w = context.tensor_init(matrix_shape);
        let m = context.tensor_init(absmax_shape);

        let op = TensorOp::quantize_mat_nf4(matrix, &q, &m, &w)?;
        context.queue.submit(context.encode(&op));

        Ok(Matrix::Fp4 { w, q, m })
    }

    /// Quantizes an `f16` matrix into [`Matrix::Fp4`] format using Student's t
    /// quantiles with `nu` degrees of freedom instead of the normal table.
    pub fn quant_sf4(matrix: &TensorGpu<f16, ReadWrite>, nu: f64) -> Result<Self, TensorError> {
        let context = matrix.context();
        let shape = matrix.shape();

        let matrix_shape = Shape::new(shape[0] / 2, shape[1], shape[2], shape[3]);
        let absmax_shape = Shape::new(
            shape.len().div_ceil(TensorOp::NF4_BLOCK_SIZE as usize),
            1,
            1,
            1,
        );

        let q = Float4Quant::new_student(nu).0.to(context);
        let w = context.tensor_init(matrix_shape);
        let m = context.tensor_init(absmax_shape);

        let op = TensorOp::quantize_mat_nf4(matrix, &q, &m, &w)?;
        context.queue.submit(context.encode(&op));

        Ok(Matrix::Fp4 { w, q, m })
    }
}
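
// An illustrative sketch of the intended call pattern, assuming the caller
// already has an `f16` weight tensor on the GPU plus input/output views:
// quantize once up front, then build matmul ops as usual. The returned
// `TensorOp` is encoded and submitted by the caller like any other op.
#[allow(dead_code)]
fn quant_matmul_sketch<'a, 'b>(
    matrix: &TensorGpu<f16, ReadWrite>,
    input: impl Into<TensorGpuView<'a, f16>>,
    output: impl Into<TensorGpuView<'b, f16>>,
    act: Activation,
) -> Result<TensorOp, TensorError> {
    let matrix = Matrix::quant_nf4(matrix)?;
    matrix.matmul_op(input, output, act, false)
}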

/// Summary statistics of a matrix: the minimum, 0.5%, 25%, 50%, 75%, and 99.5%
/// quantiles, and the maximum, in that order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatrixStatistics {
    pub quantile: [f32; 7],
}

impl<F: Float> TensorCpu<F> {
    pub fn statistics(&self) -> MatrixStatistics {
        let values: Vec<f32> = self
            .iter()
            .map(|x| x.hom())
            .sorted_unstable_by(|x: &f32, y: &f32| x.total_cmp(y))
            .collect();
        assert!(values.len() > 2);
        // Indices of the minimum, the quartiles, and the maximum in the sorted values.
        let p0 = 0;
        let p4 = values.len() - 1;
        let p2 = (p0 + p4) / 2;
        let p1 = (p0 + p2) / 2;
        let p3 = (p2 + p4) / 2;
        let p_005 = ((p4 as f32) * 0.005) as usize;
        let p_995 = ((p4 as f32) * 0.995) as usize;
        let quantile = [p0, p_005, p1, p2, p3, p_995, p4].map(|p| values[p]);
        MatrixStatistics { quantile }
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::quantile_student;

    #[test]
    fn test_student() -> Result<()> {
        let quant = quantile_student(5.0);
        println!("{quant:?}");

        // The table must hold 16 strictly increasing values spanning [-1, 1].
        assert_eq!(quant.len(), 16);
        assert!(quant.windows(2).all(|w| w[0] < w[1]));
        assert_eq!(quant[15], 1.0);
        assert!((quant[0] + 1.0).abs() < 1.0e-3);
        Ok(())
    }
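
    // A minimal check of `TensorCpu::statistics` on evenly spaced data,
    // assuming `f32` implements the crate's `Float` trait as used elsewhere in
    // this module; the expected values follow the index arithmetic in
    // `statistics` above.
    #[test]
    fn test_statistics() {
        use super::{Shape, TensorCpu, TensorInit};

        let data: Vec<f32> = (0..=100).map(|x| x as f32).collect();
        let shape = Shape::new(data.len(), 1, 1, 1);
        let tensor = TensorCpu::from_data(shape, data).expect("shape matches data");
        let stats = tensor.statistics();

        // Layout: [min, 0.5%, 25%, 50%, 75%, 99.5%, max].
        assert_eq!(stats.quantile[0], 0.0);
        assert_eq!(stats.quantile[3], 50.0);
        assert_eq!(stats.quantile[6], 100.0);
    }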
}