/// Densely packed, row-major storage for fixed-dimension `f32` vectors.
///
/// Vector `i` occupies `data[i * dim..(i + 1) * dim]`, so the whole collection
/// lives in one contiguous allocation with no per-vector overhead.
#[derive(Clone, Debug)]
pub struct FlatVectors {
    /// Backing buffer. The methods below keep it at `count * dim` elements;
    /// the field is public, so direct mutation can break that invariant.
    pub data: Vec<f32>,
    /// Number of components in every vector.
    pub dim: usize,
    /// Number of vectors stored.
    pub count: usize,
}

impl FlatVectors {
    /// Creates an empty store for vectors of length `dim` without allocating.
    pub fn new(dim: usize) -> Self {
        Self::with_capacity(dim, 0)
    }

    /// Creates an empty store with room for `n` vectors of length `dim`.
    pub fn with_capacity(dim: usize, n: usize) -> Self {
        Self {
            data: Vec::with_capacity(n * dim),
            dim,
            count: 0,
        }
    }

    /// Appends `vector` as the next entry (it gets index `self.len()` as seen
    /// before the call).
    ///
    /// # Panics
    ///
    /// Panics if `vector.len() != self.dim`. This is checked in release builds
    /// too, because a wrong-length vector would shift every later vector in
    /// `data` and make `get` return the wrong components for all of them.
    #[inline]
    pub fn push(&mut self, vector: &[f32]) {
        assert_eq!(
            vector.len(),
            self.dim,
            "FlatVectors::push: vector length does not match dim"
        );
        self.data.extend_from_slice(vector);
        self.count += 1;
    }

    /// Returns vector `idx` as a slice of length `dim`.
    ///
    /// # Panics
    ///
    /// Panics if vector `idx` lies outside `data`.
    #[inline]
    pub fn get(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.data[start..start + self.dim]
    }

    /// Overwrites every component of vector `idx` with `f32::NAN`.
    ///
    /// Despite the name, the values written are NaN, not `0.0`. The vector stays
    /// in place (`len()` is unchanged). The scalar distance kernels in this file
    /// return NaN for any vector containing NaN.
    ///
    /// # Panics
    ///
    /// Panics if vector `idx` lies outside `data`.
    #[inline]
    pub fn zero_out(&mut self, idx: usize) {
        let start = idx * self.dim;
        self.data[start..start + self.dim].fill(f32::NAN);
    }

    /// Number of vectors stored (including any overwritten by `zero_out`).
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no vectors have been pushed.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
}
61
62#[inline]
68pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
69 debug_assert_eq!(a.len(), b.len());
70
71 #[cfg(feature = "simd")]
72 {
73 simd_l2_squared(a, b)
74 }
75
76 #[cfg(not(feature = "simd"))]
77 {
78 scalar_l2_squared(a, b)
79 }
80}
81
/// Portable squared Euclidean distance, `Σ (a[i] - b[i])²`.
///
/// The longest prefix whose length is a multiple of 16 is accumulated into
/// four independent lanes (element `j` goes to lane `j % 4`), so the additions
/// do not form a single serial dependency chain. The remaining `len % 16`
/// elements are added to lane 0, and the lanes are summed as
/// `l0 + l1 + l2 + l3`. The summation order is fixed, so the result is
/// deterministic for a given input.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. Without this check, a longer `b` would have
/// its trailing elements silently ignored, producing a wrong distance.
#[inline]
pub fn scalar_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "scalar_l2_squared: length mismatch");

    let split = a.len() - a.len() % 16;
    let (a_head, a_tail) = a.split_at(split);
    let (b_head, b_tail) = b.split_at(split);

    let mut lanes = [0.0f32; 4];
    for (ca, cb) in a_head.chunks_exact(4).zip(b_head.chunks_exact(4)) {
        for ((lane, &x), &y) in lanes.iter_mut().zip(ca).zip(cb) {
            let d = x - y;
            *lane += d * d;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        lanes[0] += d * d;
    }

    lanes[0] + lanes[1] + lanes[2] + lanes[3]
}
113
#[cfg(feature = "simd")]
/// Squared Euclidean distance computed by `simsimd`.
///
/// `simsimd` reports its result as an `Option` of a wider float; the value is
/// narrowed to `f32`. When it yields `None` (presumably on unsupported input,
/// such as a length mismatch; see the simsimd docs), the portable
/// `scalar_l2_squared` is used instead.
#[inline]
pub fn simd_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    match simsimd::SpatialSimilarity::sqeuclidean(a, b) {
        Some(wide) => wide as f32,
        None => scalar_l2_squared(a, b),
    }
}
123
124#[inline]
126pub fn inner_product(a: &[f32], b: &[f32]) -> f32 {
127 debug_assert_eq!(a.len(), b.len());
128
129 #[cfg(feature = "simd")]
130 {
131 simsimd::SpatialSimilarity::inner(a, b)
132 .map(|d| -(d as f32))
133 .unwrap_or_else(|| scalar_inner_product(a, b))
134 }
135
136 #[cfg(not(feature = "simd"))]
137 {
138 scalar_inner_product(a, b)
139 }
140}
141
/// Portable negated dot product, `-Σ a[i] * b[i]`.
///
/// Uses the same accumulation scheme as `scalar_l2_squared`. The 16-aligned
/// prefix is spread over four lanes (element `j` goes to lane `j % 4`), and the
/// `len % 16` tail goes into lane 0. The lanes are summed as `l0 + l1 + l2 + l3`
/// before negation.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. Without this check, a longer `b` would have
/// its trailing elements silently ignored.
#[inline]
fn scalar_inner_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "scalar_inner_product: length mismatch");

    let split = a.len() - a.len() % 16;
    let (a_head, a_tail) = a.split_at(split);
    let (b_head, b_tail) = b.split_at(split);

    let mut lanes = [0.0f32; 4];
    for (ca, cb) in a_head.chunks_exact(4).zip(b_head.chunks_exact(4)) {
        for ((lane, &x), &y) in lanes.iter_mut().zip(ca).zip(cb) {
            *lane += x * y;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        lanes[0] += x * y;
    }

    -(lanes[0] + lanes[1] + lanes[2] + lanes[3])
}
167
/// Asymmetric PQ distance: sums one precomputed table entry per code.
///
/// `table` is read as consecutive rows of `k` entries. Row `i` is indexed by
/// `codes[i]`, so the result is `Σ_i table[i * k + codes[i]]`, accumulated left
/// to right starting from `0.0`. Returns `0.0` for empty `codes`.
///
/// # Panics
///
/// Panics if `codes` is non-empty and `k == 0`, if `table` holds fewer than
/// `codes.len() * k` entries, or if any code is `>= k`. These conditions are
/// checked rather than read with `get_unchecked`: this is a safe function, so
/// malformed input must not be able to cause undefined behaviour.
#[inline]
pub fn pq_asymmetric_distance(codes: &[u8], table: &[f32], k: usize) -> f32 {
    if codes.is_empty() {
        return 0.0;
    }
    assert!(
        k > 0 && table.len() >= codes.len() * k,
        "pq_asymmetric_distance: table too small for codes.len() * k entries"
    );

    // The assertion guarantees at least `codes.len()` full rows, so the zip
    // visits every code. Indexing a row panics on a code >= k instead of reading
    // the next row.
    table
        .chunks_exact(k)
        .zip(codes)
        .fold(0.0f32, |acc, (row, &code)| acc + row[code as usize])
}
178
/// Membership set over ids `0..n` with O(1) `clear`.
///
/// Each id has a slot storing the generation in which it was last inserted.
/// An id is a member exactly when its slot equals the current generation, so
/// `clear` only bumps a counter instead of touching every slot.
#[derive(Clone, Debug)]
pub struct VisitedSet {
    /// Current generation. It starts at 1 so that the zero-initialised slots read
    /// as "not visited".
    generation: u64,
    /// Per-id generation stamp, indexed by id.
    gens: Vec<u64>,
}

impl VisitedSet {
    /// Creates an empty set able to hold ids `0..n`.
    pub fn new(n: usize) -> Self {
        Self {
            generation: 1,
            gens: vec![0; n],
        }
    }

    /// Removes every id in O(1) by starting a new generation.
    #[inline]
    pub fn clear(&mut self) {
        // A u64 counter cannot realistically wrap, so stale stamps never alias.
        self.generation += 1;
    }

    /// Marks `id` as visited in the current generation.
    ///
    /// # Panics
    ///
    /// Panics if `id >= n`.
    #[inline]
    pub fn insert(&mut self, id: u32) {
        self.gens[id as usize] = self.generation;
    }

    /// Returns `true` if `id` was inserted since the last `clear`.
    ///
    /// # Panics
    ///
    /// Panics if `id >= n`.
    #[inline]
    pub fn contains(&self, id: u32) -> bool {
        self.gens[id as usize] == self.generation
    }
}
217
#[cfg(feature = "gpu")]
/// Batched distance computation behind a GPU-flavoured API.
///
/// NOTE(review): nothing in this module dispatches to a GPU yet. The backend is
/// chosen from the compile target only, and `batch_l2_squared` runs on the CPU
/// through rayon.
pub mod gpu {
    use std::cmp::Ordering;

    use super::FlatVectors;

    /// GPU API a [`GpuDistanceContext`] is nominally associated with.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum GpuBackend {
        /// Apple Metal.
        Metal,
        /// NVIDIA CUDA.
        Cuda,
        /// Vulkan compute.
        Vulkan,
    }

    /// Context for batched nearest-neighbour distance queries.
    pub struct GpuDistanceContext {
        backend: GpuBackend,
        // Not read by the current CPU implementation of `batch_l2_squared`.
        #[allow(dead_code)]
        batch_size: usize,
    }

    impl GpuDistanceContext {
        /// Creates a context. This implementation always returns `Some`.
        ///
        /// The backend is `Metal` when compiled for macOS and `Cuda` otherwise;
        /// no device detection is performed.
        pub fn new() -> Option<Self> {
            #[cfg(target_os = "macos")]
            let backend = GpuBackend::Metal;
            #[cfg(not(target_os = "macos"))]
            let backend = GpuBackend::Cuda;

            Some(Self {
                backend,
                batch_size: 4096,
            })
        }

        /// Returns up to `k` `(index, squared L2 distance)` pairs closest to
        /// `query`, sorted by ascending distance with ties broken by index.
        ///
        /// Distances are computed in parallel with rayon on the CPU. NaN
        /// distances (for example, vectors overwritten by `FlatVectors::zero_out`)
        /// sort after every non-NaN distance instead of panicking the sort.
        /// Indices are `u32`, which assumes `vectors.count` fits in `u32`.
        pub fn batch_l2_squared(
            &self,
            query: &[f32],
            vectors: &FlatVectors,
            k: usize,
        ) -> Vec<(u32, f32)> {
            use rayon::prelude::*;

            if k == 0 {
                return Vec::new();
            }

            let mut dists: Vec<(u32, f32)> = (0..vectors.count as u32)
                .into_par_iter()
                .map(|i| (i, super::scalar_l2_squared(query, vectors.get(i as usize))))
                .collect();

            if k < dists.len() {
                // Move the k nearest to the front in O(n), then sort only those.
                dists.select_nth_unstable_by(k - 1, by_distance_nan_last);
                dists.truncate(k);
            }
            dists.sort_unstable_by(by_distance_nan_last);
            dists
        }

        /// The backend chosen in [`GpuDistanceContext::new`].
        pub fn backend(&self) -> GpuBackend {
            self.backend
        }
    }

    /// Total order on `(index, distance)` pairs: ascending distance, NaN after
    /// every other value, then ascending index.
    fn by_distance_nan_last(a: &(u32, f32), b: &(u32, f32)) -> Ordering {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            .then_with(|| a.0.cmp(&b.0))
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference squared-L2 used to check the unrolled kernels.
    fn naive_l2_squared(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .fold(0.0, |acc, v| acc + v)
    }

    /// Reference dot product used to check the unrolled kernels.
    fn naive_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).fold(0.0, |acc, v| acc + v)
    }

    /// Deterministic vectors of length `len` with small, exactly representable values.
    fn sample_pair(len: usize) -> (Vec<f32>, Vec<f32>) {
        let a: Vec<f32> = (0..len).map(|i| (i % 7) as f32 * 0.5).collect();
        let b: Vec<f32> = (0..len).map(|i| (i % 5) as f32).collect();
        (a, b)
    }

    #[test]
    fn test_l2_squared() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((l2_squared(&a, &b) - 27.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_identical() {
        let a = vec![1.0; 128];
        assert!(l2_squared(&a, &a) < 1e-10);
    }

    #[test]
    fn test_l2_squared_unrolled_body_and_tail() {
        // 16 and 32 exercise only the unrolled loop, 21 and 37 add a tail,
        // and 3 is tail-only.
        for len in [3usize, 16, 21, 32, 37] {
            let (a, b) = sample_pair(len);
            let expected = naive_l2_squared(&a, &b);
            let got = l2_squared(&a, &b);
            assert!((got - expected).abs() <= 1e-4 * expected.max(1.0), "len = {}", len);
        }
    }

    #[test]
    fn test_inner_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((inner_product(&a, &b) - (-32.0)).abs() < 1e-6);
    }

    #[test]
    fn test_inner_product_unrolled_body_and_tail() {
        for len in [3usize, 16, 21, 32, 37] {
            let (a, b) = sample_pair(len);
            let expected = -naive_dot(&a, &b);
            let got = inner_product(&a, &b);
            assert!(
                (got - expected).abs() <= 1e-4 * expected.abs().max(1.0),
                "len = {}",
                len
            );
        }
    }

    #[test]
    fn test_flat_vectors() {
        let mut fv = FlatVectors::new(3);
        fv.push(&[1.0, 2.0, 3.0]);
        fv.push(&[4.0, 5.0, 6.0]);
        assert_eq!(fv.len(), 2);
        assert_eq!(fv.get(0), &[1.0, 2.0, 3.0]);
        assert_eq!(fv.get(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flat_vectors_zero_out_writes_nan() {
        let mut fv = FlatVectors::with_capacity(2, 2);
        fv.push(&[1.0, 2.0]);
        fv.push(&[3.0, 4.0]);
        fv.zero_out(0);
        assert!(fv.get(0).iter().all(|v| v.is_nan()));
        assert_eq!(fv.get(1), &[3.0, 4.0]);
        assert_eq!(fv.len(), 2);
    }

    #[test]
    fn test_visited_set() {
        let mut vs = VisitedSet::new(100);
        vs.insert(42);
        assert!(vs.contains(42));
        assert!(!vs.contains(43));
        vs.clear();
        assert!(!vs.contains(42));
        vs.insert(43);
        assert!(vs.contains(43));
    }

    #[test]
    fn test_visited_set_many_generations() {
        let mut vs = VisitedSet::new(8);
        for round in 0..5u32 {
            vs.clear();
            for id in 0..8 {
                assert!(!vs.contains(id));
            }
            vs.insert(round);
            assert!(vs.contains(round));
        }
    }

    #[test]
    fn test_pq_flat_table() {
        let table = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let codes = vec![1u8, 2u8];
        let dist = pq_asymmetric_distance(&codes, &table, 4);
        assert!((dist - (0.2 + 0.7)).abs() < 1e-6);
    }

    #[test]
    fn test_pq_empty_codes() {
        let table = vec![0.1f32, 0.2, 0.3, 0.4];
        assert_eq!(pq_asymmetric_distance(&[], &table, 4), 0.0);
    }
}