scirs2_linalg/quantization/
vector.rs1use half::{bf16, f16};
7use scirs2_core::ndarray::Array1;
8
9use super::types::QuantizedDataType;
10
11#[derive(Debug, Clone)]
13pub struct QuantizedVector {
14 pub data: QuantizedData1D,
16
17 pub length: usize,
19
20 pub data_type: QuantizedDataType,
22}
23
24#[derive(Debug, Clone)]
26pub enum QuantizedData1D {
27 Int8(Array1<i8>),
29 Float16(Array1<f16>),
31 BFloat16(Array1<bf16>),
33}
34
35impl QuantizedData1D {
36 pub fn len(&self) -> usize {
38 match self {
39 QuantizedData1D::Int8(arr) => arr.len(),
40 QuantizedData1D::Float16(arr) => arr.len(),
41 QuantizedData1D::BFloat16(arr) => arr.len(),
42 }
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.len() == 0
48 }
49}
50
51#[allow(dead_code)]
55pub fn get_quantized_vector_1d_i8(vector: &QuantizedVector) -> Option<&Array1<i8>> {
56 match &vector.data {
57 QuantizedData1D::Int8(data) => Some(data),
58 _ => None,
59 }
60}
61
62impl QuantizedVector {
63 pub fn new_i8(data: Array1<i8>, length: usize, datatype: QuantizedDataType) -> Self {
65 Self {
66 data: QuantizedData1D::Int8(data),
67 length,
68 data_type: datatype,
69 }
70 }
71
72 pub fn new_f16(data: Array1<f16>, length: usize) -> Self {
74 Self {
75 data: QuantizedData1D::Float16(data),
76 length,
77 data_type: QuantizedDataType::Float16,
78 }
79 }
80
81 pub fn new_bf16(data: Array1<bf16>, length: usize) -> Self {
83 Self {
84 data: QuantizedData1D::BFloat16(data),
85 length,
86 data_type: QuantizedDataType::BFloat16,
87 }
88 }
89
90 pub fn from_i8(data: Array1<i8>, length: usize) -> Self {
92 Self {
93 data: QuantizedData1D::Int8(data),
94 length,
95 data_type: QuantizedDataType::Int8,
96 }
97 }
98
99 #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
102 pub fn get(&self, idx: usize) -> i8 {
103 self.get_i8(idx)
104 }
105
106 pub fn len(&self) -> usize {
108 self.length
109 }
110
111 pub fn is_empty(&self) -> bool {
113 self.length == 0
114 }
115
116 pub fn get_i8(&self, idx: usize) -> i8 {
118 match &self.data {
119 QuantizedData1D::Int8(arr) => {
120 match self.data_type {
121 QuantizedDataType::Int8 => arr[idx],
122 QuantizedDataType::Int4 => {
123 let byte_idx = idx / 2;
124 let nibble_idx = idx % 2;
125 let byte = arr[byte_idx];
126
127 if nibble_idx == 0 {
128 byte >> 4
130 } else {
131 byte & 0x0F
133 }
134 }
135 QuantizedDataType::UInt4 => {
136 let byte_idx = idx / 2;
137 let nibble_idx = idx % 2;
138 let byte = arr[byte_idx];
139
140 if nibble_idx == 0 {
141 (byte >> 4) & 0x0F
143 } else {
144 byte & 0x0F
146 }
147 }
148 _ => unreachable!(
149 "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
150 ),
151 }
152 }
153 _ => unreachable!("Cannot get i8 value from floating-point quantized vector"),
154 }
155 }
156
157 pub fn get_f32(&self, idx: usize) -> f32 {
159 match &self.data {
160 QuantizedData1D::Int8(arr) => match self.data_type {
161 QuantizedDataType::Int8 => arr[idx] as f32,
162 QuantizedDataType::Int4 => self.get_i8(idx) as f32,
163 QuantizedDataType::UInt4 => self.get_i8(idx) as f32,
164 _ => unreachable!(
165 "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
166 ),
167 },
168 QuantizedData1D::Float16(arr) => arr[idx].to_f32(),
169 QuantizedData1D::BFloat16(arr) => arr[idx].to_f32(),
170 }
171 }
172}