1use std::alloc::{Layout, alloc_zeroed, dealloc};
46use std::mem::size_of;
47use std::ptr::NonNull;
48use std::sync::atomic::{AtomicU32, Ordering};
49
/// Byte alignment applied to the per-vector stride of embedding blocks, so
/// every stored vector starts on a 32-byte boundary.
pub const SIMD_ALIGNMENT: usize = 32;

/// Byte alignment applied to the per-node stride of neighbor blocks.
// NOTE(review): 64 is presumably the cache-line size of the intended targets — confirm.
pub const CACHE_LINE_SIZE: usize = 64;

/// Alignment of each storage allocation; allocation sizes are rounded up to a
/// multiple of this value as well.
// NOTE(review): 4096 presumably mirrors the OS page size — confirm.
pub const BLOCK_ALIGNMENT: usize = 4096;
62
63pub const EMBEDDING_MAGIC: u32 = 0x564543_01; pub const NEIGHBOR_MAGIC: u32 = 0x4E4249_01; #[repr(C, align(64))]
75#[derive(Debug, Clone)]
76pub struct EmbeddingBlockHeader {
77 pub magic: u32,
79 pub version: u32,
81 pub count: u32,
83 pub dim: u32,
85 pub data_offset: u32,
87 pub stride: u32,
89 pub checksum: u32,
91 pub reserved: [u32; 9],
93}
94
95impl EmbeddingBlockHeader {
96 pub fn new(count: u32, dim: u32) -> Self {
98 let vector_size = dim as usize * size_of::<f32>();
99 let stride = align_up(vector_size, SIMD_ALIGNMENT);
100 let data_offset = size_of::<Self>();
101
102 Self {
103 magic: EMBEDDING_MAGIC,
104 version: 1,
105 count,
106 dim,
107 data_offset: data_offset as u32,
108 stride: stride as u32,
109 checksum: 0,
110 reserved: [0; 9],
111 }
112 }
113
114 pub fn is_valid(&self) -> bool {
116 self.magic == EMBEDDING_MAGIC && self.version <= 1
117 }
118
119 pub fn block_size(&self) -> usize {
121 self.data_offset as usize + (self.count as usize * self.stride as usize)
122 }
123}
124
125#[repr(C, align(64))]
127#[derive(Debug, Clone)]
128pub struct NeighborBlockHeader {
129 pub magic: u32,
131 pub version: u32,
133 pub node_count: u32,
135 pub max_edges: u32,
137 pub data_offset: u32,
139 pub stride: u32,
141 pub checksum: u32,
143 pub reserved: [u32; 9],
145}
146
147impl NeighborBlockHeader {
148 pub fn new(node_count: u32, max_edges: u32) -> Self {
150 let list_size = max_edges as usize * size_of::<u32>();
151 let stride = align_up(list_size, CACHE_LINE_SIZE);
152 let data_offset = size_of::<Self>();
153
154 Self {
155 magic: NEIGHBOR_MAGIC,
156 version: 1,
157 node_count,
158 max_edges: max_edges,
159 data_offset: data_offset as u32,
160 stride: stride as u32,
161 checksum: 0,
162 reserved: [0; 9],
163 }
164 }
165
166 pub fn is_valid(&self) -> bool {
168 self.magic == NEIGHBOR_MAGIC && self.version <= 1
169 }
170
171 pub fn block_size(&self) -> usize {
173 self.data_offset as usize + (self.node_count as usize * self.stride as usize)
174 }
175}
176
/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; the bit-mask used here does
/// not produce multiples of `alignment` otherwise.
#[inline]
pub const fn align_up(value: usize, alignment: usize) -> usize {
    // Bump past the next boundary, then clear the low bits to land on it.
    let bumped = value + alignment - 1;
    let low_bits = alignment - 1;
    bumped & !low_bits
}
186
/// Rounds `value` down to the nearest multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; the bit-mask used here does
/// not produce multiples of `alignment` otherwise.
#[inline]
pub const fn align_down(value: usize, alignment: usize) -> usize {
    // Clearing the bits below the alignment drops the remainder.
    let low_bits = alignment - 1;
    value & !low_bits
}
192
/// Allocates `size` zero-initialised bytes aligned to `alignment`.
///
/// Returns `None` when `size` is zero, when `size`/`alignment` do not form a
/// valid [`Layout`] (for example a non-power-of-two alignment), or when the
/// allocator reports failure. A successful allocation must later be released
/// with the same `size` and `alignment`.
pub fn alloc_aligned(size: usize, alignment: usize) -> Option<NonNull<u8>> {
    match Layout::from_size_align(size, alignment) {
        Ok(layout) if size != 0 => {
            // SAFETY: the guard above guarantees `layout` has a non-zero size,
            // which is the only requirement of `alloc_zeroed`.
            let raw = unsafe { alloc_zeroed(layout) };
            // A null pointer signals allocation failure.
            NonNull::new(raw)
        }
        _ => None,
    }
}
206
/// Releases a block with the given `size` and `alignment`.
///
/// When `size`/`alignment` do not form a valid [`Layout`] the call does
/// nothing and the memory is not released.
///
/// # Safety
///
/// `ptr` must have been allocated by the global allocator with exactly this
/// `size` and `alignment`, and must not be used or freed again afterwards.
pub unsafe fn free_aligned(ptr: NonNull<u8>, size: usize, alignment: usize) {
    let layout = match Layout::from_size_align(size, alignment) {
        Ok(layout) => layout,
        // No valid layout means no allocation can match; nothing to free.
        Err(_) => return,
    };

    // SAFETY: the caller guarantees `ptr` came from the global allocator with
    // this exact layout and is not used again.
    unsafe { dealloc(ptr.as_ptr(), layout) }
}
216
217pub struct EmbeddingStorage {
223 data: NonNull<u8>,
225 size: usize,
227 header: EmbeddingBlockHeader,
229}
230
231impl EmbeddingStorage {
232 pub fn new(capacity: usize, dim: usize) -> Option<Self> {
234 let header = EmbeddingBlockHeader::new(capacity as u32, dim as u32);
235 let size = align_up(header.block_size(), BLOCK_ALIGNMENT);
236
237 let data = alloc_aligned(size, BLOCK_ALIGNMENT)?;
238
239 unsafe {
241 let header_ptr = data.as_ptr() as *mut EmbeddingBlockHeader;
242 header_ptr.write(header.clone());
243 }
244
245 Some(Self { data, size, header })
246 }
247
248 #[inline]
250 pub fn get(&self, index: usize) -> Option<&[f32]> {
251 if index >= self.header.count as usize {
252 return None;
253 }
254
255 let offset = self.header.data_offset as usize + index * self.header.stride as usize;
256
257 unsafe {
258 let ptr = self.data.as_ptr().add(offset) as *const f32;
259 Some(std::slice::from_raw_parts(ptr, self.header.dim as usize))
260 }
261 }
262
263 #[inline]
265 pub fn get_mut(&mut self, index: usize) -> Option<&mut [f32]> {
266 if index >= self.header.count as usize {
267 return None;
268 }
269
270 let offset = self.header.data_offset as usize + index * self.header.stride as usize;
271
272 unsafe {
273 let ptr = self.data.as_ptr().add(offset) as *mut f32;
274 Some(std::slice::from_raw_parts_mut(
275 ptr,
276 self.header.dim as usize,
277 ))
278 }
279 }
280
281 #[inline]
283 pub fn set(&mut self, index: usize, vector: &[f32]) -> bool {
284 if let Some(slot) = self.get_mut(index) {
285 if vector.len() == slot.len() {
286 slot.copy_from_slice(vector);
287 return true;
288 }
289 }
290 false
291 }
292
293 #[inline]
295 pub fn prefetch(&self, index: usize) {
296 if index < self.header.count as usize {
297 let offset = self.header.data_offset as usize + index * self.header.stride as usize;
298
299 unsafe {
300 let ptr = self.data.as_ptr().add(offset);
301
302 #[cfg(target_arch = "x86_64")]
303 {
304 use std::arch::x86_64::_mm_prefetch;
305 _mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr as *const i8);
306 }
307
308 #[cfg(target_arch = "aarch64")]
309 {
310 let _ = ptr;
312 }
313 }
314 }
315 }
316
317 pub fn dim(&self) -> usize {
319 self.header.dim as usize
320 }
321
322 pub fn capacity(&self) -> usize {
324 self.header.count as usize
325 }
326
327 pub fn stride(&self) -> usize {
329 self.header.stride as usize
330 }
331
332 #[inline]
334 pub fn as_ptr(&self) -> *const f32 {
335 unsafe { self.data.as_ptr().add(self.header.data_offset as usize) as *const f32 }
336 }
337
338 #[inline]
340 pub fn as_mut_ptr(&mut self) -> *mut f32 {
341 unsafe { self.data.as_ptr().add(self.header.data_offset as usize) as *mut f32 }
342 }
343}
344
345impl Drop for EmbeddingStorage {
346 fn drop(&mut self) {
347 unsafe {
348 free_aligned(self.data, self.size, BLOCK_ALIGNMENT);
349 }
350 }
351}
352
353unsafe impl Send for EmbeddingStorage {}
355unsafe impl Sync for EmbeddingStorage {}
356
357pub struct NeighborStorage {
363 data: NonNull<u8>,
365 size: usize,
367 header: NeighborBlockHeader,
369 edge_counts: Vec<AtomicU32>,
371}
372
373impl NeighborStorage {
374 pub fn new(node_count: usize, max_edges: usize) -> Option<Self> {
376 let header = NeighborBlockHeader::new(node_count as u32, max_edges as u32);
377 let size = align_up(header.block_size(), BLOCK_ALIGNMENT);
378
379 let data = alloc_aligned(size, BLOCK_ALIGNMENT)?;
380
381 unsafe {
383 let header_ptr = data.as_ptr() as *mut NeighborBlockHeader;
384 header_ptr.write(header.clone());
385 }
386
387 let edge_counts: Vec<AtomicU32> = (0..node_count).map(|_| AtomicU32::new(0)).collect();
389
390 Some(Self {
391 data,
392 size,
393 header,
394 edge_counts,
395 })
396 }
397 #[inline]
399 pub fn get_neighbors(&self, node: usize) -> Option<&[u32]> {
400 if node >= self.header.node_count as usize {
401 return None;
402 }
403
404 let offset = self.header.data_offset as usize + node * self.header.stride as usize;
405 let count = self.edge_counts[node].load(Ordering::Relaxed) as usize;
406
407 unsafe {
408 let ptr = self.data.as_ptr().add(offset) as *const u32;
409 Some(std::slice::from_raw_parts(
410 ptr,
411 count.min(self.header.max_edges as usize),
412 ))
413 }
414 }
415
416 #[inline]
418 fn get_neighbors_mut(&mut self, node: usize) -> Option<&mut [u32]> {
419 if node >= self.header.node_count as usize {
420 return None;
421 }
422
423 let offset = self.header.data_offset as usize + node * self.header.stride as usize;
424
425 unsafe {
426 let ptr = self.data.as_ptr().add(offset) as *mut u32;
427 Some(std::slice::from_raw_parts_mut(
428 ptr,
429 self.header.max_edges as usize,
430 ))
431 }
432 }
433
434 pub fn add_neighbor(&self, node: usize, neighbor: u32) -> bool {
436 if node >= self.header.node_count as usize {
437 return false;
438 }
439
440 let current = self.edge_counts[node].fetch_add(1, Ordering::AcqRel);
441
442 if current >= self.header.max_edges {
443 self.edge_counts[node].fetch_sub(1, Ordering::Release);
444 return false;
445 }
446
447 let offset = self.header.data_offset as usize + node * self.header.stride as usize;
448
449 unsafe {
450 let ptr = self.data.as_ptr().add(offset) as *mut u32;
451 ptr.add(current as usize).write(neighbor);
452 }
453
454 true
455 }
456
457 pub fn set_neighbors(&mut self, node: usize, neighbors: &[u32]) -> bool {
459 let max_edges = self.header.max_edges as usize;
460 if let Some(slot) = self.get_neighbors_mut(node) {
461 let count = neighbors.len().min(max_edges);
462 slot[..count].copy_from_slice(&neighbors[..count]);
463 self.edge_counts[node].store(count as u32, Ordering::Release);
464 true
465 } else {
466 false
467 }
468 }
469
470 #[inline]
472 pub fn prefetch(&self, node: usize) {
473 if node < self.header.node_count as usize {
474 let offset = self.header.data_offset as usize + node * self.header.stride as usize;
475
476 unsafe {
477 let ptr = self.data.as_ptr().add(offset);
478
479 #[cfg(target_arch = "x86_64")]
480 {
481 use std::arch::x86_64::_mm_prefetch;
482 _mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr as *const i8);
483 }
484
485 #[cfg(target_arch = "aarch64")]
486 {
487 let _ = ptr;
489 }
490 }
491 }
492 }
493
494 pub fn prefetch_neighbors(&self, embeddings: &EmbeddingStorage, node: usize) {
496 if let Some(neighbors) = self.get_neighbors(node) {
497 for &neighbor in neighbors.iter().take(4) {
499 embeddings.prefetch(neighbor as usize);
500 }
501 }
502 }
503
504 pub fn edge_count(&self, node: usize) -> usize {
506 if node < self.edge_counts.len() {
507 self.edge_counts[node].load(Ordering::Relaxed) as usize
508 } else {
509 0
510 }
511 }
512
513 pub fn max_edges(&self) -> usize {
515 self.header.max_edges as usize
516 }
517
518 pub fn node_count(&self) -> usize {
520 self.header.node_count as usize
521 }
522}
523
524impl Drop for NeighborStorage {
525 fn drop(&mut self) {
526 unsafe {
527 free_aligned(self.data, self.size, BLOCK_ALIGNMENT);
528 }
529 }
530}
531
532unsafe impl Send for NeighborStorage {}
533unsafe impl Sync for NeighborStorage {}
534
535pub struct HotPathVectorStore {
541 embeddings: EmbeddingStorage,
543 neighbors: Vec<NeighborStorage>,
545 entry_point: AtomicU32,
547 num_layers: usize,
549}
550
551impl HotPathVectorStore {
552 pub fn new(capacity: usize, dim: usize, num_layers: usize, max_edges: usize) -> Option<Self> {
554 let embeddings = EmbeddingStorage::new(capacity, dim)?;
555
556 let mut neighbors = Vec::with_capacity(num_layers);
557 for _ in 0..num_layers {
558 neighbors.push(NeighborStorage::new(capacity, max_edges)?);
559 }
560
561 Some(Self {
562 embeddings,
563 neighbors,
564 entry_point: AtomicU32::new(0),
565 num_layers,
566 })
567 }
568
569 #[inline]
571 pub fn get_embedding(&self, id: usize) -> Option<&[f32]> {
572 self.embeddings.get(id)
573 }
574
575 pub fn set_embedding(&mut self, id: usize, vector: &[f32]) -> bool {
577 self.embeddings.set(id, vector)
578 }
579
580 #[inline]
582 pub fn get_neighbors(&self, id: usize, layer: usize) -> Option<&[u32]> {
583 self.neighbors.get(layer)?.get_neighbors(id)
584 }
585
586 pub fn add_neighbor(&self, id: usize, layer: usize, neighbor: u32) -> bool {
588 if let Some(storage) = self.neighbors.get(layer) {
589 storage.add_neighbor(id, neighbor)
590 } else {
591 false
592 }
593 }
594
595 #[inline]
597 pub fn prefetch_node(&self, id: usize, layer: usize) {
598 self.embeddings.prefetch(id);
599 if let Some(neighbors) = self.neighbors.get(layer) {
600 neighbors.prefetch(id);
601 }
602 }
603
604 pub fn entry_point(&self) -> u32 {
606 self.entry_point.load(Ordering::Relaxed)
607 }
608
609 pub fn set_entry_point(&self, id: u32) {
611 self.entry_point.store(id, Ordering::Release);
612 }
613
614 pub fn dim(&self) -> usize {
616 self.embeddings.dim()
617 }
618
619 pub fn capacity(&self) -> usize {
621 self.embeddings.capacity()
622 }
623
624 pub fn num_layers(&self) -> usize {
626 self.num_layers
627 }
628}
629
630pub struct BatchDistanceComputer<'a> {
636 store: &'a HotPathVectorStore,
637 query: &'a [f32],
638}
639
640impl<'a> BatchDistanceComputer<'a> {
641 pub fn new(store: &'a HotPathVectorStore, query: &'a [f32]) -> Self {
643 Self { store, query }
644 }
645
646 pub fn compute_batch(&self, candidates: &[u32]) -> Vec<(u32, f32)> {
648 let mut results = Vec::with_capacity(candidates.len());
649
650 const PREFETCH_DISTANCE: usize = 4;
652
653 for (i, &id) in candidates.iter().enumerate() {
654 if i + PREFETCH_DISTANCE < candidates.len() {
656 self.store
657 .embeddings
658 .prefetch(candidates[i + PREFETCH_DISTANCE] as usize);
659 }
660
661 if let Some(vector) = self.store.get_embedding(id as usize) {
663 let dist = l2_distance(self.query, vector);
664 results.push((id, dist));
665 }
666 }
667
668 results
669 }
670}
671
/// Sum of squared component differences between `a` and `b` (the squared
/// Euclidean distance; no square root is taken).
///
/// Debug builds assert equal lengths; otherwise only the common prefix of the
/// two slices is compared.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum()
}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Rounding helpers on a 32-byte boundary.
    #[test]
    fn test_alignment() {
        let rounded_up = [
            align_up(100, 32),
            align_up(128, 32),
            align_up(129, 32),
        ];
        assert_eq!(rounded_up[0], 128);
        assert_eq!(rounded_up[1], 128);
        assert_eq!(rounded_up[2], 160);

        assert_eq!(align_down(100, 32), 96);
    }

    /// A stored vector reads back unchanged and the data area is SIMD-aligned.
    #[test]
    fn test_embedding_storage() {
        let mut embeddings = EmbeddingStorage::new(100, 128).unwrap();

        let ramp: Vec<f32> = (0..128).map(|component| component as f32).collect();
        assert!(embeddings.set(0, &ramp));

        let stored = embeddings.get(0).unwrap();
        assert_eq!(stored, ramp.as_slice());

        let base_address = embeddings.as_ptr() as usize;
        assert_eq!(base_address % SIMD_ALIGNMENT, 0);
    }

    /// Appending and bulk-replacing neighbor lists.
    #[test]
    fn test_neighbor_storage() {
        let mut lists = NeighborStorage::new(100, 32).unwrap();

        // Appends show up in insertion order.
        for id in [1, 5, 10] {
            assert!(lists.add_neighbor(0, id));
        }
        let appended = lists.get_neighbors(0).unwrap();
        assert_eq!(appended, &[1, 5, 10]);

        // A bulk write replaces the whole list of another node.
        lists.set_neighbors(1, &[2, 4, 6, 8]);
        let replaced = lists.get_neighbors(1).unwrap();
        assert_eq!(replaced, &[2, 4, 6, 8]);
    }

    /// Embeddings, entry point and layer-0 neighbors through the combined store.
    #[test]
    fn test_hot_path_store() {
        let mut hot_store = HotPathVectorStore::new(100, 64, 3, 16).unwrap();

        let ramp: Vec<f32> = (0..64).map(|component| component as f32).collect();
        assert!(hot_store.set_embedding(0, &ramp));

        hot_store.set_entry_point(0);
        assert_eq!(hot_store.entry_point(), 0);

        assert!(hot_store.add_neighbor(0, 0, 1));
        assert!(hot_store.add_neighbor(0, 0, 2));

        let layer_zero = hot_store.get_neighbors(0, 0).unwrap();
        assert_eq!(layer_zero, &[1, 2]);
    }

    /// Batch scoring returns one entry per known candidate, in input order.
    #[test]
    fn test_batch_distance() {
        let mut hot_store = HotPathVectorStore::new(10, 4, 1, 8).unwrap();

        // Row `i` holds `[i, i + 1, i + 2, i + 3]`.
        for row in 0..10 {
            let values: Vec<f32> = (0..4).map(|column| (row + column) as f32).collect();
            hot_store.set_embedding(row, &values);
        }

        let query = vec![0.0, 1.0, 2.0, 3.0];
        let scorer = BatchDistanceComputer::new(&hot_store, &query);

        let candidate_ids: Vec<u32> = (0..5).collect();
        let scored = scorer.compute_batch(&candidate_ids);

        assert_eq!(scored.len(), 5);
        assert_eq!(scored[0].0, 0);
    }
}