1#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
12
13use crate::error::{GnnError, Result};
14use memmap2::{MmapMut, MmapOptions};
15use parking_lot::RwLock;
16use std::fs::{File, OpenOptions};
17use std::io::{self, Write};
18use std::path::Path;
19use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Lock-free bitmap backed by a vector of `AtomicU64` words.
///
/// `set`/`clear` silently ignore out-of-range indices and `get` reports
/// them as `false` (see the `impl` block).
#[derive(Debug)]
pub struct AtomicBitmap {
    // Backing words; bit `i` lives in word `i / 64` at position `i % 64`.
    bits: Vec<AtomicU64>,
    // Logical number of bits tracked; valid indices are `0..size`.
    size: usize,
}
32
33impl AtomicBitmap {
34 pub fn new(size: usize) -> Self {
39 let num_words = (size + 63) / 64;
40 let bits = (0..num_words).map(|_| AtomicU64::new(0)).collect();
41
42 Self { bits, size }
43 }
44
45 pub fn set(&self, index: usize) {
50 if index >= self.size {
51 return;
52 }
53 let word_idx = index / 64;
54 let bit_idx = index % 64;
55 self.bits[word_idx].fetch_or(1u64 << bit_idx, Ordering::Release);
56 }
57
58 pub fn clear(&self, index: usize) {
63 if index >= self.size {
64 return;
65 }
66 let word_idx = index / 64;
67 let bit_idx = index % 64;
68 self.bits[word_idx].fetch_and(!(1u64 << bit_idx), Ordering::Release);
69 }
70
71 pub fn get(&self, index: usize) -> bool {
79 if index >= self.size {
80 return false;
81 }
82 let word_idx = index / 64;
83 let bit_idx = index % 64;
84 let word = self.bits[word_idx].load(Ordering::Acquire);
85 (word & (1u64 << bit_idx)) != 0
86 }
87
88 pub fn clear_all(&self) {
90 for word in &self.bits {
91 word.store(0, Ordering::Release);
92 }
93 }
94
95 pub fn get_set_indices(&self) -> Vec<usize> {
100 let mut indices = Vec::new();
101 for (word_idx, word) in self.bits.iter().enumerate() {
102 let mut w = word.load(Ordering::Acquire);
103 while w != 0 {
104 let bit_idx = w.trailing_zeros() as usize;
105 indices.push(word_idx * 64 + bit_idx);
106 w &= w - 1; }
108 }
109 indices
110 }
111}
112
/// Memory-mapped storage for fixed-width `f32` node embeddings.
#[derive(Debug)]
pub struct MmapManager {
    // Backing file; kept open so the mapping stays valid for `Self`'s lifetime.
    file: File,
    // Writable mapping covering the whole backing file.
    mmap: MmapMut,
    // OS page size, used for page-aligned prefetch hints.
    page_size: usize,
    // Number of `f32` values per embedding.
    d_embed: usize,
    // One bit per node: set whenever the node is read or written.
    access_bitmap: AtomicBitmap,
    // One bit per node: set on write, cleared by `flush_dirty`.
    dirty_bitmap: AtomicBitmap,
    // Per-page pin counters — NOTE(review): never read or written anywhere in
    // this file; presumably reserved for a future paging policy. TODO confirm.
    pin_count: Vec<AtomicU32>,
    // Capacity in nodes; valid ids are `0..max_nodes`.
    max_nodes: usize,
}
136
137impl MmapManager {
138 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
148 let embedding_size = d_embed * std::mem::size_of::<f32>();
150 let file_size = max_nodes * embedding_size;
151
152 let file = OpenOptions::new()
154 .read(true)
155 .write(true)
156 .create(true)
157 .open(path)
158 .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
159
160 file.set_len(file_size as u64)
162 .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
163
164 let mmap = unsafe {
166 MmapOptions::new()
167 .len(file_size)
168 .map_mut(&file)
169 .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
170 };
171
172 let page_size = page_size::get();
174 let num_pages = (file_size + page_size - 1) / page_size;
175
176 Ok(Self {
177 file,
178 mmap,
179 page_size,
180 d_embed,
181 access_bitmap: AtomicBitmap::new(max_nodes),
182 dirty_bitmap: AtomicBitmap::new(max_nodes),
183 pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
184 max_nodes,
185 })
186 }
187
188 #[inline]
199 pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
200 let node_idx = usize::try_from(node_id).ok()?;
201 let elem_size = std::mem::size_of::<f32>();
202 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
203 }
204
205 #[inline]
207 fn validate_node_id(&self, node_id: u64) -> bool {
208 (node_id as usize) < self.max_nodes
209 }
210
211 pub fn get_embedding(&self, node_id: u64) -> &[f32] {
222 assert!(
224 self.validate_node_id(node_id),
225 "node_id {} out of bounds (max: {})",
226 node_id,
227 self.max_nodes
228 );
229
230 let offset = self
231 .embedding_offset(node_id)
232 .expect("embedding offset calculation overflow");
233 let end = offset
234 .checked_add(
235 self.d_embed
236 .checked_mul(std::mem::size_of::<f32>())
237 .unwrap(),
238 )
239 .expect("end offset overflow");
240 assert!(
241 end <= self.mmap.len(),
242 "embedding extends beyond mmap bounds"
243 );
244
245 self.access_bitmap.set(node_id as usize);
247
248 unsafe {
250 let ptr = self.mmap.as_ptr().add(offset) as *const f32;
251 std::slice::from_raw_parts(ptr, self.d_embed)
252 }
253 }
254
255 pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
265 assert!(
267 self.validate_node_id(node_id),
268 "node_id {} out of bounds (max: {})",
269 node_id,
270 self.max_nodes
271 );
272 assert_eq!(
273 data.len(),
274 self.d_embed,
275 "Embedding data length must match d_embed"
276 );
277
278 let offset = self
279 .embedding_offset(node_id)
280 .expect("embedding offset calculation overflow");
281 let end = offset
282 .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
283 .expect("end offset overflow");
284 assert!(
285 end <= self.mmap.len(),
286 "embedding extends beyond mmap bounds"
287 );
288
289 self.access_bitmap.set(node_id as usize);
291 self.dirty_bitmap.set(node_id as usize);
292
293 unsafe {
295 let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
296 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
297 }
298 }
299
300 pub fn flush_dirty(&self) -> io::Result<()> {
305 let dirty_nodes = self.dirty_bitmap.get_set_indices();
306
307 if dirty_nodes.is_empty() {
308 return Ok(());
309 }
310
311 self.mmap.flush()?;
314
315 for &node_id in &dirty_nodes {
317 self.dirty_bitmap.clear(node_id);
318 }
319
320 Ok(())
321 }
322
323 pub fn prefetch(&self, node_ids: &[u64]) {
328 #[cfg(target_os = "linux")]
329 {
330 #[allow(unused_imports)]
331 use std::os::unix::io::AsRawFd;
332
333 for &node_id in node_ids {
334 if !self.validate_node_id(node_id) {
336 continue;
337 }
338 let offset = match self.embedding_offset(node_id) {
339 Some(o) => o,
340 None => continue,
341 };
342 let page_offset = (offset / self.page_size) * self.page_size;
343 let length = self.d_embed * std::mem::size_of::<f32>();
344
345 unsafe {
346 libc::madvise(
348 self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
349 length,
350 libc::MADV_WILLNEED,
351 );
352 }
353 }
354 }
355
356 #[cfg(not(target_os = "linux"))]
358 {
359 for &node_id in node_ids {
360 if self.validate_node_id(node_id) {
361 let _ = self.get_embedding(node_id);
362 }
363 }
364 }
365 }
366
367 pub fn d_embed(&self) -> usize {
369 self.d_embed
370 }
371
372 pub fn max_nodes(&self) -> usize {
374 self.max_nodes
375 }
376}
377
/// File-backed gradient accumulator with striped locking so multiple threads
/// can call `accumulate` concurrently.
pub struct MmapGradientAccumulator {
    // Writable mapping holding `n_nodes * d_embed` `f32` gradients. Wrapped
    // in `UnsafeCell` so `&self` methods can write under the lock stripes.
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    // Number of consecutive nodes guarded by one lock stripe.
    lock_granularity: usize,
    // Lock stripes; stripe `i` guards nodes
    // `i * lock_granularity .. (i + 1) * lock_granularity`.
    locks: Vec<RwLock<()>>,
    // Capacity in nodes.
    n_nodes: usize,
    // Number of `f32` values per gradient.
    d_embed: usize,
    // Backing file; kept open (never read directly) so the mapping stays valid.
    _file: File,
}
396
397impl MmapGradientAccumulator {
398 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
408 let grad_size = d_embed * std::mem::size_of::<f32>();
410 let file_size = max_nodes * grad_size;
411
412 let file = OpenOptions::new()
414 .read(true)
415 .write(true)
416 .create(true)
417 .open(path)
418 .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
419
420 file.set_len(file_size as u64)
422 .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
423
424 let grad_mmap = unsafe {
426 MmapOptions::new()
427 .len(file_size)
428 .map_mut(&file)
429 .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
430 };
431
432 for byte in grad_mmap.iter() {
434 let _ = byte;
436 }
437
438 let lock_granularity = 64;
440 let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
441 let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
442
443 Ok(Self {
444 grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
445 lock_granularity,
446 locks,
447 n_nodes: max_nodes,
448 d_embed,
449 _file: file,
450 })
451 }
452
453 #[inline]
464 pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
465 let node_idx = usize::try_from(node_id).ok()?;
466 if node_idx >= self.n_nodes {
467 return None;
468 }
469 let elem_size = std::mem::size_of::<f32>();
470 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
471 }
472
473 pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
482 assert_eq!(
483 grad.len(),
484 self.d_embed,
485 "Gradient length must match d_embed"
486 );
487
488 let offset = self
489 .grad_offset(node_id)
490 .expect("node_id out of bounds or offset overflow");
491
492 let lock_idx = (node_id as usize) / self.lock_granularity;
493 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
494 let _lock = self.locks[lock_idx].write();
495
496 unsafe {
498 let mmap = &mut *self.grad_mmap.get();
499 assert!(
500 offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
501 "gradient write would exceed mmap bounds"
502 );
503 let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
504 let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
505
506 for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
508 *g += new_g;
509 }
510 }
511 }
512
513 pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
519 assert_eq!(
520 self.d_embed, embeddings.d_embed,
521 "Gradient and embedding dimensions must match"
522 );
523
524 for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
526 let grad = self.get_grad(node_id as u64);
527 let embedding = embeddings.get_embedding(node_id as u64);
528
529 let mut updated = vec![0.0f32; self.d_embed];
531 for i in 0..self.d_embed {
532 updated[i] = embedding[i] - learning_rate * grad[i];
533 }
534
535 embeddings.set_embedding(node_id as u64, &updated);
536 }
537
538 self.zero_grad();
540 }
541
542 pub fn zero_grad(&mut self) {
544 unsafe {
546 let mmap = &mut *self.grad_mmap.get();
547 for byte in mmap.iter_mut() {
548 *byte = 0;
549 }
550 }
551 }
552
553 pub fn get_grad(&self, node_id: u64) -> &[f32] {
561 let offset = self
562 .grad_offset(node_id)
563 .expect("node_id out of bounds or offset overflow");
564
565 let lock_idx = (node_id as usize) / self.lock_granularity;
566 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
567 let _lock = self.locks[lock_idx].read();
568
569 unsafe {
571 let mmap = &*self.grad_mmap.get();
572 assert!(
573 offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
574 "gradient read would exceed mmap bounds"
575 );
576 let ptr = mmap.as_ptr().add(offset) as *const f32;
577 std::slice::from_raw_parts(ptr, self.d_embed)
578 }
579 }
580
581 pub fn d_embed(&self) -> usize {
583 self.d_embed
584 }
585
586 pub fn n_nodes(&self) -> usize {
588 self.n_nodes
589 }
590}
591
592impl Drop for MmapManager {
594 fn drop(&mut self) {
595 let _ = self.flush_dirty();
597 }
598}
599
600impl Drop for MmapGradientAccumulator {
601 fn drop(&mut self) {
602 unsafe {
604 let mmap = &mut *self.grad_mmap.get();
605 let _ = mmap.flush();
606 }
607 }
608}
609
// SAFETY(review): assumed sound on the grounds that all mutation of
// `grad_mmap` goes through the per-stripe `locks`. Note, however, that
// `get_grad` hands out a slice that outlives its lock guard, so callers
// can race with `accumulate` — TODO confirm this is acceptable to callers.
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}
614
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Set/clear/get round-trip across word boundaries (0, 64, 127).
    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    // Index enumeration must cover first/last bits of multiple words.
    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    // Distinct values per node verify that offsets do not overlap.
    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    // A write must set the dirty bit; flush_dirty must clear it.
    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        assert!(manager.dirty_bitmap.get(5));

        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    // Data flushed by one manager must be visible to a fresh one.
    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![3.14f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 3.14);
            assert_eq!(retrieved[63], 3.14);
        }
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    // Two accumulations on the same node must sum element-wise.
    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    // apply: e <- e - lr * g, then gradients are zeroed.
    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        accumulator.apply(0.1, &mut embeddings);

        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    // Ten threads adding 1.0 to the same node must total exactly 10.0
    // (the stripe locks serialize the read-modify-write).
    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    // Offsets are node_id * d_embed * size_of::<f32>().
    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4));
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), Some(0));
        assert_eq!(accumulator.grad_offset(1), Some(128 * 4));
        assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        accumulator.accumulate(0, &wrong_size);
    }

    // prefetch is advisory; the only observable contract is that data is
    // still readable afterwards.
    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        manager.prefetch(&[0, 1, 2, 3, 4]);

        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }
}