1use std::alloc::{alloc, dealloc, Layout};
7use std::ptr;
8
/// Alignment (in bytes) for the vector buffer; matches a typical CPU cache line.
const CACHE_LINE_SIZE: usize = 64;
11
/// Structure-of-Arrays storage for fixed-dimension `f32` vectors.
///
/// Each dimension is stored as one contiguous column (`data[dim * capacity + index]`),
/// so per-dimension scans and the SIMD distance kernels touch sequential memory.
/// The backing allocation is cache-line (64-byte) aligned.
#[repr(align(64))] pub struct SoAVectorStorage {
    // Number of vectors currently stored.
    count: usize,
    // Number of f32 components per vector (fixed at construction).
    dimensions: usize,
    // Vector slots allocated per dimension (kept a power of two).
    capacity: usize,
    // 64-byte-aligned buffer of `dimensions * capacity` f32s.
    data: *mut f32,
}
28
29impl SoAVectorStorage {
30 const MAX_DIMENSIONS: usize = 65536;
32 const MAX_CAPACITY: usize = 1 << 24; pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
40 assert!(
42 dimensions > 0 && dimensions <= Self::MAX_DIMENSIONS,
43 "dimensions must be between 1 and {}",
44 Self::MAX_DIMENSIONS
45 );
46 assert!(
47 initial_capacity <= Self::MAX_CAPACITY,
48 "initial_capacity exceeds maximum of {}",
49 Self::MAX_CAPACITY
50 );
51
52 let capacity = initial_capacity.next_power_of_two();
53
54 let total_elements = dimensions
56 .checked_mul(capacity)
57 .expect("dimensions * capacity overflow");
58 let total_bytes = total_elements
59 .checked_mul(std::mem::size_of::<f32>())
60 .expect("total size overflow");
61
62 let layout =
63 Layout::from_size_align(total_bytes, CACHE_LINE_SIZE).expect("invalid memory layout");
64
65 let data = unsafe { alloc(layout) as *mut f32 };
66
67 unsafe {
69 ptr::write_bytes(data, 0, total_elements);
70 }
71
72 Self {
73 count: 0,
74 dimensions,
75 capacity,
76 data,
77 }
78 }
79
80 pub fn push(&mut self, vector: &[f32]) {
82 assert_eq!(vector.len(), self.dimensions);
83
84 if self.count >= self.capacity {
85 self.grow();
86 }
87
88 for (dim_idx, &value) in vector.iter().enumerate() {
90 let offset = dim_idx * self.capacity + self.count;
91 unsafe {
92 *self.data.add(offset) = value;
93 }
94 }
95
96 self.count += 1;
97 }
98
99 pub fn get(&self, index: usize, output: &mut [f32]) {
101 assert!(index < self.count);
102 assert_eq!(output.len(), self.dimensions);
103
104 for dim_idx in 0..self.dimensions {
105 let offset = dim_idx * self.capacity + index;
106 output[dim_idx] = unsafe { *self.data.add(offset) };
107 }
108 }
109
    /// Returns component `dim_idx` of all stored vectors as one contiguous slice.
    ///
    /// # Panics
    /// Panics if `dim_idx >= dimensions()`.
    pub fn dimension_slice(&self, dim_idx: usize) -> &[f32] {
        assert!(dim_idx < self.dimensions);
        let offset = dim_idx * self.capacity;
        // SAFETY: the column starts at dim_idx * capacity and its first
        // `count` elements were initialized by `new`/`push`.
        unsafe { std::slice::from_raw_parts(self.data.add(offset), self.count) }
    }
117
    /// Mutable view of component `dim_idx` across all stored vectors.
    ///
    /// # Panics
    /// Panics if `dim_idx >= dimensions()`.
    pub fn dimension_slice_mut(&mut self, dim_idx: usize) -> &mut [f32] {
        assert!(dim_idx < self.dimensions);
        let offset = dim_idx * self.capacity;
        // SAFETY: `&mut self` guarantees exclusive access; the column's first
        // `count` elements are initialized.
        unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.count) }
    }
124
    /// Number of vectors currently stored.
    pub fn len(&self) -> usize {
        self.count
    }
129
    /// Returns `true` if no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
134
    /// Number of components per vector.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
139
140 fn grow(&mut self) {
142 let new_capacity = self.capacity * 2;
143
144 let new_total_elements = self
146 .dimensions
147 .checked_mul(new_capacity)
148 .expect("dimensions * new_capacity overflow");
149 let new_total_bytes = new_total_elements
150 .checked_mul(std::mem::size_of::<f32>())
151 .expect("total size overflow in grow");
152
153 let new_layout = Layout::from_size_align(new_total_bytes, CACHE_LINE_SIZE)
154 .expect("invalid memory layout in grow");
155
156 let new_data = unsafe { alloc(new_layout) as *mut f32 };
157
158 for dim_idx in 0..self.dimensions {
160 let old_offset = dim_idx * self.capacity;
161 let new_offset = dim_idx * new_capacity;
162
163 unsafe {
164 ptr::copy_nonoverlapping(
165 self.data.add(old_offset),
166 new_data.add(new_offset),
167 self.count,
168 );
169 }
170 }
171
172 let old_layout = Layout::from_size_align(
174 self.dimensions * self.capacity * std::mem::size_of::<f32>(),
175 CACHE_LINE_SIZE,
176 )
177 .unwrap();
178
179 unsafe {
180 dealloc(self.data as *mut u8, old_layout);
181 }
182
183 self.data = new_data;
184 self.capacity = new_capacity;
185 }
186
    /// Computes the euclidean distance from `query` to every stored vector,
    /// writing one distance per stored vector into `output`.
    ///
    /// Dispatches to a SIMD kernel — NEON on aarch64, AVX2 on x86_64 when
    /// detected at runtime — once the batch is large enough to amortize
    /// vector setup; otherwise falls back to the scalar implementation.
    ///
    /// # Panics
    /// Panics if `query.len() != dimensions()` or `output.len() != len()`.
    #[inline(always)]
    pub fn batch_euclidean_distances(&self, query: &[f32], output: &mut [f32]) {
        assert_eq!(query.len(), self.dimensions);
        assert_eq!(output.len(), self.count);

        #[cfg(target_arch = "aarch64")]
        {
            // For small batches the vector setup outweighs the benefit.
            if self.count >= 16 {
                // SAFETY: the asserts above establish the kernel's length
                // requirements; NEON is baseline on aarch64.
                unsafe { self.batch_euclidean_distances_neon(query, output) };
                return;
            }
        }

        #[cfg(target_arch = "x86_64")]
        {
            if self.count >= 32 && is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was just verified at runtime, and the
                // asserts above establish the kernel's length requirements.
                unsafe { self.batch_euclidean_distances_avx2(query, output) };
                return;
            }
        }

        self.batch_euclidean_distances_scalar(query, output);
    }
214
215 #[inline(always)]
217 fn batch_euclidean_distances_scalar(&self, query: &[f32], output: &mut [f32]) {
218 output.fill(0.0);
220
221 for dim_idx in 0..self.dimensions {
223 let dim_slice = self.dimension_slice(dim_idx);
224 let query_val = unsafe { *query.get_unchecked(dim_idx) };
226
227 for vec_idx in 0..self.count {
230 let diff = unsafe { *dim_slice.get_unchecked(vec_idx) } - query_val;
231 unsafe { *output.get_unchecked_mut(vec_idx) += diff * diff };
232 }
233 }
234
235 for distance in output.iter_mut() {
237 *distance = distance.sqrt();
238 }
239 }
240
    /// NEON (aarch64) batch euclidean distance kernel: four distances per
    /// iteration via fused multiply-add, with a scalar tail for `count % 4`.
    ///
    /// # Safety
    /// Caller must guarantee `query.len() == self.dimensions` and
    /// `output.len() == self.count` (enforced by the asserts in
    /// `batch_euclidean_distances`); all raw-pointer offsets below stay
    /// in bounds under those invariants.
    #[cfg(target_arch = "aarch64")]
    #[inline(always)]
    unsafe fn batch_euclidean_distances_neon(&self, query: &[f32], output: &mut [f32]) {
        use std::arch::aarch64::*;

        let out_ptr = output.as_mut_ptr();
        let query_ptr = query.as_ptr();

        // Number of full 4-lane chunks; the remainder is handled scalar.
        let chunks = self.count / 4;

        // Zero the accumulators (distances build up dimension by dimension).
        let zero = vdupq_n_f32(0.0);
        for i in 0..chunks {
            let idx = i * 4;
            vst1q_f32(out_ptr.add(idx), zero);
        }
        for i in (chunks * 4)..self.count {
            *output.get_unchecked_mut(i) = 0.0;
        }

        // Accumulate squared differences, one contiguous column per pass.
        for dim_idx in 0..self.dimensions {
            let dim_slice = self.dimension_slice(dim_idx);
            let dim_ptr = dim_slice.as_ptr();
            // Broadcast the query component across all four lanes.
            let query_val = vdupq_n_f32(*query_ptr.add(dim_idx));

            for i in 0..chunks {
                let idx = i * 4;
                let dim_vals = vld1q_f32(dim_ptr.add(idx));
                let out_vals = vld1q_f32(out_ptr.add(idx));

                let diff = vsubq_f32(dim_vals, query_val);
                // Fused multiply-add: out + diff * diff in one instruction.
                let result = vfmaq_f32(out_vals, diff, diff);

                vst1q_f32(out_ptr.add(idx), result);
            }

            // Scalar tail for the last count % 4 vectors.
            let query_val_scalar = *query_ptr.add(dim_idx);
            for i in (chunks * 4)..self.count {
                let diff = *dim_slice.get_unchecked(i) - query_val_scalar;
                *output.get_unchecked_mut(i) += diff * diff;
            }
        }

        // Final square root: vectorized for full chunks, scalar for the tail.
        for i in 0..chunks {
            let idx = i * 4;
            let vals = vld1q_f32(out_ptr.add(idx));
            let sqrt_vals = vsqrtq_f32(vals);
            vst1q_f32(out_ptr.add(idx), sqrt_vals);
        }
        for i in (chunks * 4)..self.count {
            *output.get_unchecked_mut(i) = output.get_unchecked(i).sqrt();
        }
    }
303
    /// AVX2 (x86_64) batch euclidean distance kernel: eight distances per
    /// iteration, with a scalar tail for `count % 8`.
    ///
    /// # Safety
    /// Caller must guarantee the CPU supports AVX2, and that
    /// `query.len() == self.dimensions` and `output.len() == self.count`
    /// (enforced by `batch_euclidean_distances` before dispatch).
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_euclidean_distances_avx2(&self, query: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;

        // Number of full 8-lane chunks; the remainder is handled scalar.
        let chunks = self.count / 8;

        // Zero the accumulators (distances build up dimension by dimension).
        let zero = _mm256_setzero_ps();
        for i in 0..chunks {
            let idx = i * 8;
            _mm256_storeu_ps(output.as_mut_ptr().add(idx), zero);
        }
        for i in (chunks * 8)..self.count {
            output[i] = 0.0;
        }

        // Accumulate squared differences, one contiguous column per pass.
        for dim_idx in 0..self.dimensions {
            let dim_slice = self.dimension_slice(dim_idx);
            // Broadcast the query component across all eight lanes.
            let query_val = _mm256_set1_ps(query[dim_idx]);

            for i in 0..chunks {
                let idx = i * 8;
                let dim_vals = _mm256_loadu_ps(dim_slice.as_ptr().add(idx));
                let out_vals = _mm256_loadu_ps(output.as_ptr().add(idx));

                let diff = _mm256_sub_ps(dim_vals, query_val);
                let sq = _mm256_mul_ps(diff, diff);
                let result = _mm256_add_ps(out_vals, sq);

                _mm256_storeu_ps(output.as_mut_ptr().add(idx), result);
            }

            // Scalar tail for the last count % 8 vectors.
            for i in (chunks * 8)..self.count {
                let diff = dim_slice[i] - query[dim_idx];
                output[i] += diff * diff;
            }
        }

        // Final square root over the whole output (scalar).
        for distance in output.iter_mut() {
            *distance = distance.sqrt();
        }
    }
352}
353
/// Maps a feature-name string to a runtime CPU feature check.
///
/// NOTE(review): appears unused within this file — `batch_euclidean_distances`
/// invokes the `is_x86_feature_detected!` macro directly. Consider removing
/// if no other module calls it.
#[cfg(target_arch = "x86_64")]
fn is_x86_feature_detected_helper(feature: &str) -> bool {
    match feature {
        "avx2" => is_x86_feature_detected!("avx2"),
        _ => false,
    }
}
362
impl Drop for SoAVectorStorage {
    fn drop(&mut self) {
        // Rebuild the layout the current buffer was allocated with; this must
        // match the size/alignment used in `new`/`grow` exactly. The product
        // was already overflow-checked when the buffer was allocated.
        let layout = Layout::from_size_align(
            self.dimensions * self.capacity * std::mem::size_of::<f32>(),
            CACHE_LINE_SIZE,
        )
        .unwrap();

        // SAFETY: `self.data` was allocated with exactly this layout and is
        // never used after drop.
        unsafe {
            dealloc(self.data as *mut u8, layout);
        }
    }
}
376
// SAFETY: SoAVectorStorage exclusively owns its heap allocation and the raw
// pointer is never shared outside the struct, so moving it across threads is
// sound.
unsafe impl Send for SoAVectorStorage {}
// SAFETY: all mutation goes through `&mut self` methods, so concurrent
// `&self` access is read-only.
unsafe impl Sync for SoAVectorStorage {}
379
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: vectors pushed in row form come back intact through the
    // column-major (SoA) storage.
    #[test]
    fn test_soa_storage() {
        let mut storage = SoAVectorStorage::new(3, 4);

        storage.push(&[1.0, 2.0, 3.0]);
        storage.push(&[4.0, 5.0, 6.0]);

        assert_eq!(storage.len(), 2);

        let mut output = vec![0.0; 3];
        storage.get(0, &mut output);
        assert_eq!(output, vec![1.0, 2.0, 3.0]);

        storage.get(1, &mut output);
        assert_eq!(output, vec![4.0, 5.0, 6.0]);
    }

    // Each dimension slice holds one component from every stored vector,
    // in insertion order.
    #[test]
    fn test_dimension_slice() {
        let mut storage = SoAVectorStorage::new(3, 4);

        storage.push(&[1.0, 2.0, 3.0]);
        storage.push(&[4.0, 5.0, 6.0]);
        storage.push(&[7.0, 8.0, 9.0]);

        let dim0 = storage.dimension_slice(0);
        assert_eq!(dim0, &[1.0, 4.0, 7.0]);

        let dim1 = storage.dimension_slice(1);
        assert_eq!(dim1, &[2.0, 5.0, 8.0]);
    }

    // Distances from one unit basis vector: 0 to itself, sqrt(2) ~ 1.414 to
    // the other two. With count == 3 this exercises the scalar fallback path.
    #[test]
    fn test_batch_distances() {
        let mut storage = SoAVectorStorage::new(3, 4);

        storage.push(&[1.0, 0.0, 0.0]);
        storage.push(&[0.0, 1.0, 0.0]);
        storage.push(&[0.0, 0.0, 1.0]);

        let query = vec![1.0, 0.0, 0.0];
        let mut distances = vec![0.0; 3];

        storage.batch_euclidean_distances(&query, &mut distances);

        assert!((distances[0] - 0.0).abs() < 0.001);
        assert!((distances[1] - 1.414).abs() < 0.01);
        assert!((distances[2] - 1.414).abs() < 0.01);
    }
}