1use std::hash::{Hash, Hasher};
2
3use bincode::{Decode, Encode};
4#[derive(Debug, Clone, Encode, Decode)]
5pub struct Vector<const N: usize>(Vec<f32>);
6
7impl<const N: usize> From<[f32; N]> for Vector<N> {
8 fn from(values: [f32; N]) -> Self {
9 Self(values.to_vec())
10 }
11}
12
13impl<const N: usize> Hash for Vector<N> {
14 fn hash<H: Hasher>(&self, state: &mut H) {
15 let bytes: &[u8] =
16 unsafe { std::slice::from_raw_parts(self.0.as_ptr() as *const u8, N * 4) };
17 bytes.hash(state);
18 }
19}
20
21impl<const N: usize> PartialEq for Vector<N> {
22 fn eq(&self, other: &Self) -> bool {
24 let bytes_left: &[u8] =
25 unsafe { std::slice::from_raw_parts(self.0.as_ptr() as *const u8, N * 4) };
26 let bytes_right: &[u8] =
27 unsafe { std::slice::from_raw_parts(other.0.as_ptr() as *const u8, N * 4) };
28 bytes_left == bytes_right
29 }
30}
31
32impl<const N: usize> Eq for Vector<N> {}
33
34impl<const N: usize> Vector<N> {
35 pub fn try_from(values: Vec<f32>) -> Result<Self, Vec<f32>> {
55 if values.len() != N {
56 Err(values)
57 } else {
58 Ok(Self(values))
59 }
60 }
61
62 pub fn into_inner(self) -> Vec<f32> {
64 self.0
65 }
66
67 pub fn as_slice(&self) -> &[f32; N] {
68 unsafe { &*(self.0.as_ptr() as *const [f32; N]) }
71 }
72
73 pub(crate) fn subtract_from(&self, other: &Vector<N>) -> Vector<N> {
74 let vals = self
75 .as_slice()
76 .iter()
77 .zip(other.as_slice().iter())
78 .map(|(a, b)| a - b)
79 .collect::<Vec<_>>();
80 debug_assert_eq!(vals.capacity(), N);
81 Vector(vals)
82 }
83
84 pub(crate) fn memory_usage(&self) -> usize {
85 std::mem::size_of::<Self>() + self.0.len() * 4
86 }
87
88 #[cfg(test)]
89 pub(crate) fn add(&self, vector: &Vector<N>) -> Vector<N> {
90 let vals = self
91 .as_slice()
92 .iter()
93 .zip(vector.as_slice().iter())
94 .map(|(a, b)| a + b)
95 .collect::<Vec<_>>();
96 Vector(vals)
97 }
98
99 pub(crate) fn avg(&self, vector: &Vector<N>) -> Vector<N> {
100 let vals = self
101 .as_slice()
102 .iter()
103 .zip(vector.as_slice().iter())
104 .map(|(a, b)| (a + b) / 2.0)
105 .collect::<Vec<_>>();
106 Vector(vals)
107 }
108
109 pub(crate) fn dot_product(&self, vector: &Vector<N>) -> f32 {
110 self.as_slice()
111 .iter()
112 .zip(vector.as_slice().iter())
113 .map(|(a, b)| a * b)
114 .sum()
115 }
116
117 pub fn sq_euc_dist(&self, vector: &Vector<N>) -> f32 {
118 self.as_slice()
119 .iter()
120 .zip(vector.as_slice().iter())
121 .map(|(a, b)| (a - b).powi(2))
122 .sum()
123 }
124
125 pub(crate) fn norm(&self) -> f32 {
126 self.as_slice()
127 .iter()
128 .map(|a| a.powi(2))
129 .sum::<f32>()
130 .sqrt()
131 }
132
133 pub fn cosine_similarity(&self, vector: &Vector<N>) -> f32 {
134 let dot = self.dot_product(vector);
135 let norm_self = self.norm();
136 let norm_vector = vector.norm();
137 dot / (norm_self * norm_vector)
138 }
139}