1#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
12
13use crate::error::{GnnError, Result};
14use memmap2::{MmapMut, MmapOptions};
15use parking_lot::RwLock;
16use std::fs::{File, OpenOptions};
17use std::io::{self, Write};
18use std::path::Path;
19use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Fixed-size bitmap whose bits can be set, cleared and read through a shared
/// reference, with each bit stored in a packed `AtomicU64` word.
///
/// Indices `>= size` are silently ignored by `set`/`clear` and always read as
/// unset by `get`.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Packed bit storage: `(size + 63) / 64` words, bit `i` lives in word `i / 64`.
    bits: Vec<AtomicU64>,
    /// Number of addressable bits.
    size: usize,
}

impl AtomicBitmap {
    /// Creates a bitmap able to hold `size` bits, all initially cleared.
    pub fn new(size: usize) -> Self {
        let word_count = (size + 63) / 64;
        let mut bits = Vec::with_capacity(word_count);
        bits.resize_with(word_count, || AtomicU64::new(0));
        Self { bits, size }
    }

    /// Translates a bit index into `(word position, single-bit mask)`.
    ///
    /// Returns `None` when `index` lies outside the bitmap.
    #[inline]
    fn locate(&self, index: usize) -> Option<(usize, u64)> {
        if index < self.size {
            Some((index / 64, 1u64 << (index % 64)))
        } else {
            None
        }
    }

    /// Sets bit `index` (release ordering). Out-of-range indices are ignored.
    pub fn set(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_or(mask, Ordering::Release);
        }
    }

    /// Clears bit `index` (release ordering). Out-of-range indices are ignored.
    pub fn clear(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            self.bits[word].fetch_and(!mask, Ordering::Release);
        }
    }

    /// Reads bit `index` (acquire ordering); out-of-range indices read as `false`.
    pub fn get(&self, index: usize) -> bool {
        match self.locate(index) {
            Some((word, mask)) => (self.bits[word].load(Ordering::Acquire) & mask) != 0,
            None => false,
        }
    }

    /// Clears every bit, one word at a time.
    pub fn clear_all(&self) {
        self.bits
            .iter()
            .for_each(|word| word.store(0, Ordering::Release));
    }

    /// Returns the indices of all set bits in ascending order.
    ///
    /// Each word is loaded once; the result is a per-word snapshot, not an
    /// atomic snapshot of the whole bitmap.
    pub fn get_set_indices(&self) -> Vec<usize> {
        self.bits
            .iter()
            .enumerate()
            .flat_map(|(word_pos, word)| {
                let mut pending = word.load(Ordering::Acquire);
                std::iter::from_fn(move || {
                    if pending == 0 {
                        return None;
                    }
                    let bit = pending.trailing_zeros() as usize;
                    // Drop the lowest set bit so the next call finds the following one.
                    pending &= pending - 1;
                    Some(word_pos * 64 + bit)
                })
            })
            .collect()
    }
}
112
113#[derive(Debug)]
118pub struct MmapManager {
119 file: File,
121 mmap: MmapMut,
123 page_size: usize,
125 d_embed: usize,
127 access_bitmap: AtomicBitmap,
129 dirty_bitmap: AtomicBitmap,
131 pin_count: Vec<AtomicU32>,
133 max_nodes: usize,
135}
136
137impl MmapManager {
138 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
148 let embedding_size = d_embed * std::mem::size_of::<f32>();
150 let file_size = max_nodes * embedding_size;
151
152 let file = OpenOptions::new()
154 .read(true)
155 .write(true)
156 .create(true)
157 .open(path)
158 .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
159
160 file.set_len(file_size as u64)
162 .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
163
164 let mmap = unsafe {
166 MmapOptions::new()
167 .len(file_size)
168 .map_mut(&file)
169 .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
170 };
171
172 let page_size = page_size::get();
174 let num_pages = (file_size + page_size - 1) / page_size;
175
176 Ok(Self {
177 file,
178 mmap,
179 page_size,
180 d_embed,
181 access_bitmap: AtomicBitmap::new(max_nodes),
182 dirty_bitmap: AtomicBitmap::new(max_nodes),
183 pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
184 max_nodes,
185 })
186 }
187
188 #[inline]
199 pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
200 let node_idx = usize::try_from(node_id).ok()?;
201 let elem_size = std::mem::size_of::<f32>();
202 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
203 }
204
205 #[inline]
207 fn validate_node_id(&self, node_id: u64) -> bool {
208 (node_id as usize) < self.max_nodes
209 }
210
211 pub fn get_embedding(&self, node_id: u64) -> &[f32] {
222 assert!(
224 self.validate_node_id(node_id),
225 "node_id {} out of bounds (max: {})",
226 node_id,
227 self.max_nodes
228 );
229
230 let offset = self
231 .embedding_offset(node_id)
232 .expect("embedding offset calculation overflow");
233 let end = offset
234 .checked_add(
235 self.d_embed
236 .checked_mul(std::mem::size_of::<f32>())
237 .unwrap(),
238 )
239 .expect("end offset overflow");
240 assert!(
241 end <= self.mmap.len(),
242 "embedding extends beyond mmap bounds"
243 );
244
245 self.access_bitmap.set(node_id as usize);
247
248 unsafe {
250 let ptr = self.mmap.as_ptr().add(offset) as *const f32;
251 std::slice::from_raw_parts(ptr, self.d_embed)
252 }
253 }
254
255 pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
265 assert!(
267 self.validate_node_id(node_id),
268 "node_id {} out of bounds (max: {})",
269 node_id,
270 self.max_nodes
271 );
272 assert_eq!(
273 data.len(),
274 self.d_embed,
275 "Embedding data length must match d_embed"
276 );
277
278 let offset = self
279 .embedding_offset(node_id)
280 .expect("embedding offset calculation overflow");
281 let end = offset
282 .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
283 .expect("end offset overflow");
284 assert!(
285 end <= self.mmap.len(),
286 "embedding extends beyond mmap bounds"
287 );
288
289 self.access_bitmap.set(node_id as usize);
291 self.dirty_bitmap.set(node_id as usize);
292
293 unsafe {
295 let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
296 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
297 }
298 }
299
300 pub fn flush_dirty(&self) -> io::Result<()> {
305 let dirty_nodes = self.dirty_bitmap.get_set_indices();
306
307 if dirty_nodes.is_empty() {
308 return Ok(());
309 }
310
311 self.mmap.flush()?;
314
315 for &node_id in &dirty_nodes {
317 self.dirty_bitmap.clear(node_id);
318 }
319
320 Ok(())
321 }
322
323 pub fn prefetch(&self, node_ids: &[u64]) {
328 #[cfg(target_os = "linux")]
329 {
330 #[allow(unused_imports)]
331 use std::os::unix::io::AsRawFd;
332
333 for &node_id in node_ids {
334 if !self.validate_node_id(node_id) {
336 continue;
337 }
338 let offset = match self.embedding_offset(node_id) {
339 Some(o) => o,
340 None => continue,
341 };
342 let page_offset = (offset / self.page_size) * self.page_size;
343 let length = self.d_embed * std::mem::size_of::<f32>();
344
345 unsafe {
346 libc::madvise(
348 self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
349 length,
350 libc::MADV_WILLNEED,
351 );
352 }
353 }
354 }
355
356 #[cfg(not(target_os = "linux"))]
358 {
359 for &node_id in node_ids {
360 if self.validate_node_id(node_id) {
361 let _ = self.get_embedding(node_id);
362 }
363 }
364 }
365 }
366
367 pub fn d_embed(&self) -> usize {
369 self.d_embed
370 }
371
372 pub fn max_nodes(&self) -> usize {
374 self.max_nodes
375 }
376}
377
378pub struct MmapGradientAccumulator {
383 grad_mmap: std::cell::UnsafeCell<MmapMut>,
385 lock_granularity: usize,
387 locks: Vec<RwLock<()>>,
389 n_nodes: usize,
391 d_embed: usize,
393 _file: File,
395}
396
397impl MmapGradientAccumulator {
398 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
408 let grad_size = d_embed * std::mem::size_of::<f32>();
410 let file_size = max_nodes * grad_size;
411
412 let file = OpenOptions::new()
414 .read(true)
415 .write(true)
416 .create(true)
417 .open(path)
418 .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
419
420 file.set_len(file_size as u64)
422 .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
423
424 let grad_mmap = unsafe {
426 MmapOptions::new()
427 .len(file_size)
428 .map_mut(&file)
429 .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
430 };
431
432 for byte in grad_mmap.iter() {
434 let _ = byte;
436 }
437
438 let lock_granularity = 64;
440 let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
441 let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
442
443 Ok(Self {
444 grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
445 lock_granularity,
446 locks,
447 n_nodes: max_nodes,
448 d_embed,
449 _file: file,
450 })
451 }
452
453 #[inline]
464 pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
465 let node_idx = usize::try_from(node_id).ok()?;
466 if node_idx >= self.n_nodes {
467 return None;
468 }
469 let elem_size = std::mem::size_of::<f32>();
470 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
471 }
472
473 pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
482 assert_eq!(
483 grad.len(),
484 self.d_embed,
485 "Gradient length must match d_embed"
486 );
487
488 let offset = self.grad_offset(node_id)
489 .expect("node_id out of bounds or offset overflow");
490
491 let lock_idx = (node_id as usize) / self.lock_granularity;
492 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
493 let _lock = self.locks[lock_idx].write();
494
495 unsafe {
497 let mmap = &mut *self.grad_mmap.get();
498 assert!(offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
499 "gradient write would exceed mmap bounds");
500 let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
501 let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
502
503 for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
505 *g += new_g;
506 }
507 }
508 }
509
510 pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
516 assert_eq!(
517 self.d_embed, embeddings.d_embed,
518 "Gradient and embedding dimensions must match"
519 );
520
521 for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
523 let grad = self.get_grad(node_id as u64);
524 let embedding = embeddings.get_embedding(node_id as u64);
525
526 let mut updated = vec![0.0f32; self.d_embed];
528 for i in 0..self.d_embed {
529 updated[i] = embedding[i] - learning_rate * grad[i];
530 }
531
532 embeddings.set_embedding(node_id as u64, &updated);
533 }
534
535 self.zero_grad();
537 }
538
539 pub fn zero_grad(&mut self) {
541 unsafe {
543 let mmap = &mut *self.grad_mmap.get();
544 for byte in mmap.iter_mut() {
545 *byte = 0;
546 }
547 }
548 }
549
550 pub fn get_grad(&self, node_id: u64) -> &[f32] {
558 let offset = self.grad_offset(node_id)
559 .expect("node_id out of bounds or offset overflow");
560
561 let lock_idx = (node_id as usize) / self.lock_granularity;
562 assert!(lock_idx < self.locks.len(), "lock index out of bounds");
563 let _lock = self.locks[lock_idx].read();
564
565 unsafe {
567 let mmap = &*self.grad_mmap.get();
568 assert!(offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
569 "gradient read would exceed mmap bounds");
570 let ptr = mmap.as_ptr().add(offset) as *const f32;
571 std::slice::from_raw_parts(ptr, self.d_embed)
572 }
573 }
574
575 pub fn d_embed(&self) -> usize {
577 self.d_embed
578 }
579
580 pub fn n_nodes(&self) -> usize {
582 self.n_nodes
583 }
584}
585
586impl Drop for MmapManager {
588 fn drop(&mut self) {
589 let _ = self.flush_dirty();
591 }
592}
593
594impl Drop for MmapGradientAccumulator {
595 fn drop(&mut self) {
596 unsafe {
598 let mmap = &mut *self.grad_mmap.get();
599 let _ = mmap.flush();
600 }
601 }
602}
603
// SAFETY: the mapped memory is not tied to the thread that created it, so the
// accumulator may be moved across threads.
unsafe impl Send for MmapGradientAccumulator {}
// SAFETY: `UnsafeCell` makes this type `!Sync` by default. Shared (`&self`)
// mutation happens only in `accumulate`, which holds the write lock of the
// node's lock stripe; whole-mapping mutation (`zero_grad`, `apply`) requires
// `&mut self`.
// NOTE(review): `get_grad` releases its read lock before the returned slice is
// used, so a concurrent `accumulate` on the same node while that slice is alive
// would be a data race — confirm callers never overlap the two.
unsafe impl Sync for MmapGradientAccumulator {}
608
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_ignores_out_of_range() {
        let bitmap = AtomicBitmap::new(10);

        bitmap.set(10);
        bitmap.set(1000);
        bitmap.clear(1000);

        assert!(!bitmap.get(10));
        assert!(bitmap.get_set_indices().is_empty());
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        assert!(manager.dirty_bitmap.get(5));

        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        // Arbitrary, exactly representable value; values close to PI/E are
        // rejected by clippy's deny-by-default `approx_constant` lint.
        let value = 1.25f32;

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![value; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], value);
            assert_eq!(retrieved[63], value);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_embedding_out_of_bounds() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 8, 4).unwrap();
        let _ = manager.get_embedding(4);
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    #[should_panic(expected = "node_id out of bounds or offset overflow")]
    fn test_accumulate_out_of_bounds() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 8, 100).unwrap();
        accumulator.accumulate(100, &[0.0f32; 8]);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        accumulator.apply(0.1, &mut embeddings);

        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4));
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), Some(0));
        assert_eq!(accumulator.grad_offset(1), Some(128 * 4));
        assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
        assert_eq!(accumulator.grad_offset(100), None);
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        accumulator.accumulate(0, &wrong_size);
    }

    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        manager.prefetch(&[0, 1, 2, 3, 4]);

        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }

    #[test]
    fn test_prefetch_skips_invalid_ids() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 16, 4).unwrap();

        // Must not panic: out-of-range ids are skipped.
        manager.prefetch(&[4, 100, u64::MAX]);
    }
}