1#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
12
13use crate::error::{GnnError, Result};
14use memmap2::{MmapMut, MmapOptions};
15use parking_lot::RwLock;
16use std::fs::{File, OpenOptions};
17use std::io;
18use std::path::Path;
19use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Fixed-capacity bitmap whose bits can be set, cleared and read through `&self`.
///
/// Bits are packed 64 to an `AtomicU64` word: bit `i` lives in word `i / 64` at position
/// `i % 64`. Indices at or beyond the capacity passed to [`AtomicBitmap::new`] are ignored
/// by the mutators and always read as unset.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Packed bit storage, `size.div_ceil(64)` words long.
    bits: Vec<AtomicU64>,
    /// Number of addressable bits.
    size: usize,
}

impl AtomicBitmap {
    /// Creates a bitmap with room for `size` bits, all initially clear.
    pub fn new(size: usize) -> Self {
        let word_count = size.div_ceil(64);
        let bits: Vec<AtomicU64> = std::iter::repeat_with(|| AtomicU64::new(0))
            .take(word_count)
            .collect();
        Self { bits, size }
    }

    /// Maps an in-range bit index to `(word index, single-bit mask)`; `None` when out of range.
    fn locate(&self, index: usize) -> Option<(usize, u64)> {
        (index < self.size).then(|| (index / 64, 1u64 << (index % 64)))
    }

    /// Sets bit `index`; out-of-range indices are silently ignored.
    pub fn set(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_or(mask, Ordering::Release);
        }
    }

    /// Clears bit `index`; out-of-range indices are silently ignored.
    pub fn clear(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_and(!mask, Ordering::Release);
        }
    }

    /// Reports whether bit `index` is set; out-of-range indices report `false`.
    pub fn get(&self, index: usize) -> bool {
        self.locate(index)
            .is_some_and(|(word, mask)| (self.bits[word].load(Ordering::Acquire) & mask) != 0)
    }

    /// Clears every bit, one word at a time.
    pub fn clear_all(&self) {
        self.bits
            .iter()
            .for_each(|word| word.store(0, Ordering::Release));
    }

    /// Returns the indices of all set bits in ascending order.
    ///
    /// Each word is loaded once; bits changed concurrently may or may not be observed.
    pub fn get_set_indices(&self) -> Vec<usize> {
        self.bits
            .iter()
            .enumerate()
            .flat_map(|(word_idx, word)| {
                let mut remaining = word.load(Ordering::Acquire);
                std::iter::from_fn(move || {
                    if remaining == 0 {
                        return None;
                    }
                    let bit = remaining.trailing_zeros() as usize;
                    // Drop the lowest set bit so the next call yields the following one.
                    remaining &= remaining - 1;
                    Some(word_idx * 64 + bit)
                })
            })
            .collect()
    }
}
112
113#[derive(Debug)]
118pub struct MmapManager {
119 #[allow(dead_code)]
121 file: File,
122 mmap: MmapMut,
124 page_size: usize,
126 d_embed: usize,
128 access_bitmap: AtomicBitmap,
130 dirty_bitmap: AtomicBitmap,
132 #[allow(dead_code)]
134 pin_count: Vec<AtomicU32>,
135 max_nodes: usize,
137}
138
139impl MmapManager {
140 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
150 let embedding_size = d_embed * std::mem::size_of::<f32>();
152 let file_size = max_nodes * embedding_size;
153
154 let file = OpenOptions::new()
156 .read(true)
157 .write(true)
158 .create(true)
159 .truncate(false)
160 .open(path)
161 .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
162
163 file.set_len(file_size as u64)
165 .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
166
167 let mmap = unsafe {
169 MmapOptions::new()
170 .len(file_size)
171 .map_mut(&file)
172 .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
173 };
174
175 let page_size = page_size::get();
177 let num_pages = file_size.div_ceil(page_size);
178
179 Ok(Self {
180 file,
181 mmap,
182 page_size,
183 d_embed,
184 access_bitmap: AtomicBitmap::new(max_nodes),
185 dirty_bitmap: AtomicBitmap::new(max_nodes),
186 pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
187 max_nodes,
188 })
189 }
190
191 #[inline]
202 pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
203 let node_idx = usize::try_from(node_id).ok()?;
204 let elem_size = std::mem::size_of::<f32>();
205 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
206 }
207
208 #[inline]
210 fn validate_node_id(&self, node_id: u64) -> bool {
211 (node_id as usize) < self.max_nodes
212 }
213
214 pub fn get_embedding(&self, node_id: u64) -> &[f32] {
225 assert!(
227 self.validate_node_id(node_id),
228 "node_id {} out of bounds (max: {})",
229 node_id,
230 self.max_nodes
231 );
232
233 let offset = self
234 .embedding_offset(node_id)
235 .expect("embedding offset calculation overflow");
236 let end = offset
237 .checked_add(
238 self.d_embed
239 .checked_mul(std::mem::size_of::<f32>())
240 .unwrap(),
241 )
242 .expect("end offset overflow");
243 assert!(
244 end <= self.mmap.len(),
245 "embedding extends beyond mmap bounds"
246 );
247
248 self.access_bitmap.set(node_id as usize);
250
251 unsafe {
253 let ptr = self.mmap.as_ptr().add(offset) as *const f32;
254 std::slice::from_raw_parts(ptr, self.d_embed)
255 }
256 }
257
258 pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
268 assert!(
270 self.validate_node_id(node_id),
271 "node_id {} out of bounds (max: {})",
272 node_id,
273 self.max_nodes
274 );
275 assert_eq!(
276 data.len(),
277 self.d_embed,
278 "Embedding data length must match d_embed"
279 );
280
281 let offset = self
282 .embedding_offset(node_id)
283 .expect("embedding offset calculation overflow");
284 let end = offset
285 .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
286 .expect("end offset overflow");
287 assert!(
288 end <= self.mmap.len(),
289 "embedding extends beyond mmap bounds"
290 );
291
292 self.access_bitmap.set(node_id as usize);
294 self.dirty_bitmap.set(node_id as usize);
295
296 unsafe {
298 let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
299 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
300 }
301 }
302
303 pub fn flush_dirty(&self) -> io::Result<()> {
308 let dirty_nodes = self.dirty_bitmap.get_set_indices();
309
310 if dirty_nodes.is_empty() {
311 return Ok(());
312 }
313
314 self.mmap.flush()?;
317
318 for &node_id in &dirty_nodes {
320 self.dirty_bitmap.clear(node_id);
321 }
322
323 Ok(())
324 }
325
326 pub fn prefetch(&self, node_ids: &[u64]) {
331 #[cfg(target_os = "linux")]
332 {
333 #[allow(unused_imports)]
334 use std::os::unix::io::AsRawFd;
335
336 for &node_id in node_ids {
337 if !self.validate_node_id(node_id) {
339 continue;
340 }
341 let offset = match self.embedding_offset(node_id) {
342 Some(o) => o,
343 None => continue,
344 };
345 let page_offset = (offset / self.page_size) * self.page_size;
346 let length = self.d_embed * std::mem::size_of::<f32>();
347
348 unsafe {
349 libc::madvise(
351 self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
352 length,
353 libc::MADV_WILLNEED,
354 );
355 }
356 }
357 }
358
359 #[cfg(not(target_os = "linux"))]
361 {
362 for &node_id in node_ids {
363 if self.validate_node_id(node_id) {
364 let _ = self.get_embedding(node_id);
365 }
366 }
367 }
368 }
369
370 pub fn d_embed(&self) -> usize {
372 self.d_embed
373 }
374
375 pub fn max_nodes(&self) -> usize {
377 self.max_nodes
378 }
379}
380
381pub struct MmapGradientAccumulator {
386 grad_mmap: std::cell::UnsafeCell<MmapMut>,
388 lock_granularity: usize,
390 locks: Vec<RwLock<()>>,
392 n_nodes: usize,
394 d_embed: usize,
396 _file: File,
398}
399
400impl MmapGradientAccumulator {
401 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
411 let grad_size = d_embed * std::mem::size_of::<f32>();
413 let file_size = max_nodes * grad_size;
414
415 let file = OpenOptions::new()
417 .read(true)
418 .write(true)
419 .create(true)
420 .truncate(false)
421 .open(path)
422 .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
423
424 file.set_len(file_size as u64)
426 .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
427
428 let grad_mmap = unsafe {
430 MmapOptions::new()
431 .len(file_size)
432 .map_mut(&file)
433 .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
434 };
435
436 for byte in grad_mmap.iter() {
438 let _ = byte;
440 }
441
442 let lock_granularity = 64;
444 let num_locks = max_nodes.div_ceil(lock_granularity);
445 let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
446
447 Ok(Self {
448 grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
449 lock_granularity,
450 locks,
451 n_nodes: max_nodes,
452 d_embed,
453 _file: file,
454 })
455 }
456
457 #[inline]
468 pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
469 let node_idx = usize::try_from(node_id).ok()?;
470 if node_idx >= self.n_nodes {
471 return None;
472 }
473 let elem_size = std::mem::size_of::<f32>();
474 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
475 }
476
477 pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
486 assert_eq!(
487 grad.len(),
488 self.d_embed,
489 "Gradient length must match d_embed"
490 );
491
492 let offset = self
493 .grad_offset(node_id)
494 .expect("node_id out of bounds or offset overflow");
495
496 let lock_idx = (node_id as usize) / self.lock_granularity;
497 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
498 let _lock = self.locks[lock_idx].write();
499
500 unsafe {
502 let mmap = &mut *self.grad_mmap.get();
503 assert!(
504 offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
505 "gradient write would exceed mmap bounds"
506 );
507 let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
508 let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
509
510 for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
512 *g += new_g;
513 }
514 }
515 }
516
517 pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
523 assert_eq!(
524 self.d_embed, embeddings.d_embed,
525 "Gradient and embedding dimensions must match"
526 );
527
528 for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
530 let grad = self.get_grad(node_id as u64);
531 let embedding = embeddings.get_embedding(node_id as u64);
532
533 let mut updated = vec![0.0f32; self.d_embed];
535 for i in 0..self.d_embed {
536 updated[i] = embedding[i] - learning_rate * grad[i];
537 }
538
539 embeddings.set_embedding(node_id as u64, &updated);
540 }
541
542 self.zero_grad();
544 }
545
546 pub fn zero_grad(&mut self) {
548 unsafe {
550 let mmap = &mut *self.grad_mmap.get();
551 for byte in mmap.iter_mut() {
552 *byte = 0;
553 }
554 }
555 }
556
557 pub fn get_grad(&self, node_id: u64) -> &[f32] {
565 let offset = self
566 .grad_offset(node_id)
567 .expect("node_id out of bounds or offset overflow");
568
569 let lock_idx = (node_id as usize) / self.lock_granularity;
570 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
571 let _lock = self.locks[lock_idx].read();
572
573 unsafe {
575 let mmap = &*self.grad_mmap.get();
576 assert!(
577 offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
578 "gradient read would exceed mmap bounds"
579 );
580 let ptr = mmap.as_ptr().add(offset) as *const f32;
581 std::slice::from_raw_parts(ptr, self.d_embed)
582 }
583 }
584
585 pub fn d_embed(&self) -> usize {
587 self.d_embed
588 }
589
590 pub fn n_nodes(&self) -> usize {
592 self.n_nodes
593 }
594}
595
596impl Drop for MmapManager {
598 fn drop(&mut self) {
599 let _ = self.flush_dirty();
601 }
602}
603
604impl Drop for MmapGradientAccumulator {
605 fn drop(&mut self) {
606 unsafe {
608 let mmap = &mut *self.grad_mmap.get();
609 let _ = mmap.flush();
610 }
611 }
612}
613
// SAFETY: the accumulator owns its mapping, its per-stripe locks and its backing file, and
// none of them is tied to the thread that created it, so it may be moved to another thread.
unsafe impl Send for MmapGradientAccumulator {}
// SAFETY: the only method that mutates gradient data through `&self` is `accumulate`, and it
// holds the write lock of the node's stripe while doing so. Caveat for callers: the slice
// returned by `get_grad` is no longer protected once that method returns, so it must not be
// read while another thread may `accumulate` into the same stripe.
unsafe impl Sync for MmapGradientAccumulator {}
618
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates an embedding store named `embeddings.bin` inside `dir`.
    fn manager_in(dir: &TempDir, d_embed: usize, max_nodes: usize) -> MmapManager {
        MmapManager::new(&dir.path().join("embeddings.bin"), d_embed, max_nodes).unwrap()
    }

    /// Creates a gradient accumulator named `gradients.bin` inside `dir`.
    fn accumulator_in(dir: &TempDir, d_embed: usize, max_nodes: usize) -> MmapGradientAccumulator {
        MmapGradientAccumulator::new(&dir.path().join("gradients.bin"), d_embed, max_nodes)
            .unwrap()
    }

    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);
        assert!([0, 127].iter().all(|&i| !bitmap.get(i)));

        for i in [0, 127, 64] {
            bitmap.set(i);
        }
        assert!([0, 127, 64].iter().all(|&i| bitmap.get(i)));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);
        let expected = vec![0, 63, 64, 128, 255];
        for &i in &expected {
            bitmap.set(i);
        }

        let mut indices = bitmap.get_set_indices();
        indices.sort_unstable();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);
        let marked = [0, 64, 127];
        marked.iter().for_each(|&i| bitmap.set(i));
        assert!(bitmap.get(0));

        bitmap.clear_all();
        assert!(marked.iter().all(|&i| !bitmap.get(i)));
    }

    #[test]
    fn test_mmap_manager_creation() {
        let dir = TempDir::new().unwrap();
        let manager = manager_in(&dir, 128, 1000);

        assert_eq!((manager.d_embed(), manager.max_nodes()), (128, 1000));
        assert!(dir.path().join("embeddings.bin").exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let dir = TempDir::new().unwrap();
        let mut manager = manager_in(&dir, 64, 100);

        manager.set_embedding(0, &[1.0f32; 64]);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!((retrieved[0], retrieved[63]), (1.0, 1.0));
    }

    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let dir = TempDir::new().unwrap();
        let mut manager = manager_in(&dir, 32, 100);

        for node in 0u64..10 {
            let first = node * 32;
            let embedding: Vec<f32> = (first..first + 32).map(|v| v as f32).collect();
            manager.set_embedding(node, &embedding);
        }

        for node in 0u64..10 {
            let retrieved = manager.get_embedding(node);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (node * 32) as f32);
            assert_eq!(retrieved[31], (node * 32 + 31) as f32);
        }
    }

    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let dir = TempDir::new().unwrap();
        let mut manager = manager_in(&dir, 64, 100);

        manager.set_embedding(5, &[2.0f32; 64]);
        assert!(manager.dirty_bitmap.get(5));

        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    #[test]
    fn test_mmap_manager_persistence() {
        let dir = TempDir::new().unwrap();

        {
            let mut writer = manager_in(&dir, 64, 100);
            writer.set_embedding(10, &[1.5f32; 64]);
            writer.flush_dirty().unwrap();
        }

        // A second manager over the same file must see what the first one wrote.
        let reader = manager_in(&dir, 64, 100);
        let retrieved = reader.get_embedding(10);
        assert_eq!(retrieved[0], 1.5);
        assert_eq!(retrieved[63], 1.5);
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let dir = TempDir::new().unwrap();
        let accumulator = accumulator_in(&dir, 128, 1000);

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(dir.path().join("gradients.bin").exists());
    }

    #[test]
    fn test_gradient_accumulator_accumulate() {
        let dir = TempDir::new().unwrap();
        let accumulator = accumulator_in(&dir, 64, 100);

        for step in [1.0f32, 2.0] {
            accumulator.accumulate(0, &[step; 64]);
        }

        let total = accumulator.get_grad(0);
        assert_eq!(total[0], 3.0);
        assert_eq!(total[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let dir = TempDir::new().unwrap();
        let mut accumulator = accumulator_in(&dir, 64, 100);

        accumulator.accumulate(0, &[1.5f32; 64]);
        assert_eq!(accumulator.get_grad(0)[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_apply() {
        let dir = TempDir::new().unwrap();
        let mut embeddings = manager_in(&dir, 32, 100);
        let mut accumulator = accumulator_in(&dir, 32, 100);

        embeddings.set_embedding(0, &[10.0f32; 32]);
        accumulator.accumulate(0, &[1.0f32; 32]);
        accumulator.apply(0.1, &mut embeddings);

        // 10.0 - 0.1 * 1.0
        assert!((embeddings.get_embedding(0)[0] - 9.9).abs() < 1e-6);
        // `apply` must leave the gradient buffer zeroed.
        assert_eq!(accumulator.get_grad(0)[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        let dir = TempDir::new().unwrap();
        let accumulator = accumulator_in(&dir, 64, 100);

        // The scope joins every thread (and re-raises any panic) before returning.
        std::thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| accumulator.accumulate(0, &[1.0f32; 64]));
            }
        });

        // Ten writers each added 1.0 under the stripe lock; no update may be lost.
        assert_eq!(accumulator.get_grad(0)[0], 10.0);
    }

    #[test]
    fn test_embedding_offset_calculation() {
        let dir = TempDir::new().unwrap();
        let manager = manager_in(&dir, 64, 100);

        let row_bytes = 64 * std::mem::size_of::<f32>();
        for node in [0u64, 1, 10] {
            assert_eq!(manager.embedding_offset(node), Some(row_bytes * node as usize));
        }
    }

    #[test]
    fn test_grad_offset_calculation() {
        let dir = TempDir::new().unwrap();
        let accumulator = accumulator_in(&dir, 128, 100);

        let row_bytes = 128 * std::mem::size_of::<f32>();
        for node in [0u64, 1, 5] {
            assert_eq!(accumulator.grad_offset(node), Some(row_bytes * node as usize));
        }
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let dir = TempDir::new().unwrap();
        manager_in(&dir, 64, 100).set_embedding(0, &[1.0f32; 32]);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let dir = TempDir::new().unwrap();
        accumulator_in(&dir, 64, 100).accumulate(0, &[1.0f32; 32]);
    }

    #[test]
    fn test_prefetch() {
        let dir = TempDir::new().unwrap();
        let mut manager = manager_in(&dir, 64, 100);

        for node in 0u64..10 {
            manager.set_embedding(node, &[node as f32; 64]);
        }

        // Prefetching is only a hint; the data must read back unchanged afterwards.
        manager.prefetch(&[0, 1, 2, 3, 4]);
        assert_eq!(manager.get_embedding(2)[0], 2.0);
    }
}