// small_world_rs/primitives/vector.rs

use half::f16;
use serde::{Deserialize, Serialize};

4#[derive(Debug, Clone, Serialize, Deserialize)]
5enum VectorStorage {
6    F16(Vec<f16>),
7    F32(Vec<f32>),
8}
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11/// Vector is a primitive type that stores a vector embedding, whether it's f16 or f32
12pub struct Vector {
13    storage: VectorStorage,
14}
15
16impl Vector {
17    pub fn new_f32(values: &[f32]) -> Self {
18        Self {
19            storage: VectorStorage::F32(values.to_vec()),
20        }
21    }
22
23    pub fn new_f16(values: &[f32]) -> Self {
24        Self {
25            storage: VectorStorage::F16(values.iter().map(|&x| f16::from_f32(x)).collect()),
26        }
27    }
28
29    pub fn len(&self) -> usize {
30        match &self.storage {
31            VectorStorage::F16(v) => v.len(),
32            VectorStorage::F32(v) => v.len(),
33        }
34    }
35
36    // Get value at index as f32
37    pub fn get(&self, index: usize) -> Option<f32> {
38        match &self.storage {
39            VectorStorage::F16(v) => v.get(index).map(|x| x.to_f32()),
40            VectorStorage::F32(v) => v.get(index).copied(),
41        }
42    }
43
44    pub fn as_slice(&self) -> Vec<f32> {
45        match &self.storage {
46            VectorStorage::F16(v) => v.as_slice().iter().map(|x| x.to_f32()).collect(),
47            VectorStorage::F32(v) => v.as_slice().to_vec(),
48        }
49    }
50}
51
52pub enum VectorIter<'a> {
53    F16(std::iter::Map<std::slice::Iter<'a, f16>, fn(&f16) -> f32>),
54    F32(std::iter::Copied<std::slice::Iter<'a, f32>>),
55}
56
57impl<'a> Iterator for VectorIter<'a> {
58    type Item = f32;
59
60    fn next(&mut self) -> Option<Self::Item> {
61        match self {
62            VectorIter::F16(iter) => iter.next(),
63            VectorIter::F32(iter) => iter.next(),
64        }
65    }
66}
67
68impl Vector {
69    pub fn iter(&self) -> VectorIter {
70        match &self.storage {
71            VectorStorage::F16(v) => VectorIter::F16(v.iter().map(|x| x.to_f32())),
72            VectorStorage::F32(v) => VectorIter::F32(v.iter().copied()),
73        }
74    }
75}