velesdb_core/
half_precision.rs1use half::{bf16, f16};
32use serde::{Deserialize, Serialize};
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
36pub enum VectorPrecision {
37 #[default]
39 F32,
40 F16,
42 BF16,
44}
45
46impl VectorPrecision {
47 #[must_use]
49 pub const fn bytes_per_element(&self) -> usize {
50 match self {
51 Self::F32 => 4,
52 Self::F16 | Self::BF16 => 2,
53 }
54 }
55
56 #[must_use]
58 pub const fn memory_size(&self, dimension: usize) -> usize {
59 self.bytes_per_element() * dimension
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
68pub enum VectorData {
69 F32(Vec<f32>),
71 F16(Vec<f16>),
73 BF16(Vec<bf16>),
75}
76
77impl VectorData {
78 #[must_use]
94 pub fn from_f32_slice(data: &[f32], precision: VectorPrecision) -> Self {
95 match precision {
96 VectorPrecision::F32 => Self::F32(data.to_vec()),
97 VectorPrecision::F16 => Self::F16(data.iter().map(|&x| f16::from_f32(x)).collect()),
98 VectorPrecision::BF16 => Self::BF16(data.iter().map(|&x| bf16::from_f32(x)).collect()),
99 }
100 }
101
102 #[must_use]
104 pub fn from_f32_vec(data: Vec<f32>, precision: VectorPrecision) -> Self {
105 match precision {
106 VectorPrecision::F32 => Self::F32(data),
107 VectorPrecision::F16 => Self::F16(data.iter().map(|&x| f16::from_f32(x)).collect()),
108 VectorPrecision::BF16 => Self::BF16(data.iter().map(|&x| bf16::from_f32(x)).collect()),
109 }
110 }
111
112 #[must_use]
114 pub const fn precision(&self) -> VectorPrecision {
115 match self {
116 Self::F32(_) => VectorPrecision::F32,
117 Self::F16(_) => VectorPrecision::F16,
118 Self::BF16(_) => VectorPrecision::BF16,
119 }
120 }
121
122 #[must_use]
124 pub fn len(&self) -> usize {
125 match self {
126 Self::F32(v) => v.len(),
127 Self::F16(v) => v.len(),
128 Self::BF16(v) => v.len(),
129 }
130 }
131
132 #[must_use]
134 pub fn is_empty(&self) -> bool {
135 self.len() == 0
136 }
137
138 #[must_use]
140 pub fn memory_size(&self) -> usize {
141 self.precision().memory_size(self.len())
142 }
143
144 #[must_use]
149 pub fn to_f32_vec(&self) -> Vec<f32> {
150 match self {
151 Self::F32(v) => v.clone(),
152 Self::F16(v) => v.iter().map(|x| x.to_f32()).collect(),
153 Self::BF16(v) => v.iter().map(|x| x.to_f32()).collect(),
154 }
155 }
156
157 #[must_use]
161 pub fn as_f32_slice(&self) -> Option<&[f32]> {
162 match self {
163 Self::F32(v) => Some(v.as_slice()),
164 Self::F16(_) | Self::BF16(_) => None,
165 }
166 }
167
168 #[must_use]
170 pub fn convert(&self, target: VectorPrecision) -> Self {
171 if self.precision() == target {
172 return self.clone();
173 }
174 Self::from_f32_slice(&self.to_f32_vec(), target)
175 }
176}
177
178impl From<Vec<f32>> for VectorData {
179 fn from(data: Vec<f32>) -> Self {
180 Self::F32(data)
181 }
182}
183
184impl From<&[f32]> for VectorData {
185 fn from(data: &[f32]) -> Self {
186 Self::F32(data.to_vec())
187 }
188}
189
190#[must_use]
199pub fn dot_product(a: &VectorData, b: &VectorData) -> f32 {
200 use crate::simd_avx512::dot_product_auto;
201
202 match (a, b) {
203 (VectorData::F32(va), VectorData::F32(vb)) => dot_product_auto(va, vb),
204 (VectorData::F32(va), VectorData::F16(vb)) => {
205 va.iter().zip(vb.iter()).map(|(&x, y)| x * y.to_f32()).sum()
206 }
207 (VectorData::F16(va), VectorData::F32(vb)) => {
208 va.iter().zip(vb.iter()).map(|(x, &y)| x.to_f32() * y).sum()
209 }
210 (VectorData::F16(va), VectorData::F16(vb)) => va
211 .iter()
212 .zip(vb.iter())
213 .map(|(x, y)| x.to_f32() * y.to_f32())
214 .sum(),
215 (VectorData::F32(va), VectorData::BF16(vb)) => {
216 va.iter().zip(vb.iter()).map(|(&x, y)| x * y.to_f32()).sum()
217 }
218 (VectorData::BF16(va), VectorData::F32(vb)) => {
219 va.iter().zip(vb.iter()).map(|(x, &y)| x.to_f32() * y).sum()
220 }
221 (VectorData::BF16(va), VectorData::BF16(vb)) => va
222 .iter()
223 .zip(vb.iter())
224 .map(|(x, y)| x.to_f32() * y.to_f32())
225 .sum(),
226 _ => {
228 let va = a.to_f32_vec();
229 let vb = b.to_f32_vec();
230 dot_product_auto(&va, &vb)
231 }
232 }
233}
234
235#[must_use]
237pub fn cosine_similarity(a: &VectorData, b: &VectorData) -> f32 {
238 use crate::simd_avx512::cosine_similarity_auto;
239
240 if let (VectorData::F32(va), VectorData::F32(vb)) = (a, b) {
241 cosine_similarity_auto(va, vb)
242 } else {
243 let dot = dot_product(a, b);
244 let norm_a = norm_squared(a).sqrt();
245 let norm_b = norm_squared(b).sqrt();
246
247 if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
248 0.0
249 } else {
250 dot / (norm_a * norm_b)
251 }
252 }
253}
254
255#[must_use]
257pub fn euclidean_distance(a: &VectorData, b: &VectorData) -> f32 {
258 use crate::simd_avx512::euclidean_auto;
259
260 match (a, b) {
261 (VectorData::F32(va), VectorData::F32(vb)) => euclidean_auto(va, vb),
262 (VectorData::F32(va), VectorData::F16(vb)) => va
263 .iter()
264 .zip(vb.iter())
265 .map(|(&x, y)| (x - y.to_f32()).powi(2))
266 .sum::<f32>()
267 .sqrt(),
268 (VectorData::F16(va), VectorData::F32(vb)) => va
269 .iter()
270 .zip(vb.iter())
271 .map(|(x, &y)| (x.to_f32() - y).powi(2))
272 .sum::<f32>()
273 .sqrt(),
274 (VectorData::F16(va), VectorData::F16(vb)) => va
275 .iter()
276 .zip(vb.iter())
277 .map(|(x, y)| (x.to_f32() - y.to_f32()).powi(2))
278 .sum::<f32>()
279 .sqrt(),
280 _ => {
282 let va = a.to_f32_vec();
283 let vb = b.to_f32_vec();
284 euclidean_auto(&va, &vb)
285 }
286 }
287}
288
289fn norm_squared(v: &VectorData) -> f32 {
291 match v {
292 VectorData::F32(data) => data.iter().map(|&x| x * x).sum(),
293 VectorData::F16(data) => data
294 .iter()
295 .map(|x| {
296 let f = x.to_f32();
297 f * f
298 })
299 .sum(),
300 VectorData::BF16(data) => data
301 .iter()
302 .map(|x| {
303 let f = x.to_f32();
304 f * f
305 })
306 .sum(),
307 }
308}