1#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
12
13use crate::error::{GnnError, Result};
14use memmap2::{MmapMut, MmapOptions};
15use parking_lot::RwLock;
16use std::fs::{File, OpenOptions};
17use std::io::{self, Write};
18use std::path::Path;
19use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Lock-free bitmap backed by a vector of `AtomicU64` words.
///
/// Out-of-range indices are silently ignored by `set`/`clear` and read as
/// `false` from `get`, so callers never panic on a bad index.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Bit storage, 64 bits per word.
    bits: Vec<AtomicU64>,
    /// Logical number of addressable bits.
    size: usize,
}

impl AtomicBitmap {
    /// Creates a bitmap able to track `size` bits, all initially clear.
    pub fn new(size: usize) -> Self {
        let word_count = (size + 63) / 64;
        let mut bits = Vec::with_capacity(word_count);
        bits.resize_with(word_count, || AtomicU64::new(0));
        Self { bits, size }
    }

    /// Atomically sets the bit at `index`; out-of-range indices are ignored.
    pub fn set(&self, index: usize) {
        if index < self.size {
            let mask = 1u64 << (index % 64);
            self.bits[index / 64].fetch_or(mask, Ordering::Release);
        }
    }

    /// Atomically clears the bit at `index`; out-of-range indices are ignored.
    pub fn clear(&self, index: usize) {
        if index < self.size {
            let mask = 1u64 << (index % 64);
            self.bits[index / 64].fetch_and(!mask, Ordering::Release);
        }
    }

    /// Returns the bit at `index`, or `false` when `index` is out of range.
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let word = self.bits[index / 64].load(Ordering::Acquire);
        (word >> (index % 64)) & 1 != 0
    }

    /// Resets every bit to zero.
    pub fn clear_all(&self) {
        for word in &self.bits {
            word.store(0, Ordering::Release);
        }
    }

    /// Returns the indices of all currently-set bits in ascending order.
    pub fn get_set_indices(&self) -> Vec<usize> {
        self.bits
            .iter()
            .enumerate()
            .flat_map(|(word_idx, word)| {
                let mut w = word.load(Ordering::Acquire);
                std::iter::from_fn(move || {
                    if w == 0 {
                        return None;
                    }
                    let bit = w.trailing_zeros() as usize;
                    w &= w - 1; // drop the lowest set bit
                    Some(word_idx * 64 + bit)
                })
            })
            .collect()
    }
}
112
/// Memory-mapped store of fixed-width `f32` node embeddings with per-node
/// access and dirty-bit tracking.
#[derive(Debug)]
pub struct MmapManager {
    /// Backing file handle; kept open so the mapping remains valid for
    /// `Self`'s lifetime.
    file: File,
    /// Writable memory map covering the whole backing file.
    mmap: MmapMut,
    /// OS page size, used to page-align `prefetch` hints.
    page_size: usize,
    /// Number of `f32` values per embedding.
    d_embed: usize,
    /// One bit per node, set whenever the embedding is read or written.
    access_bitmap: AtomicBitmap,
    /// One bit per node, set on write and cleared by `flush_dirty`.
    dirty_bitmap: AtomicBitmap,
    /// One counter per OS page.
    /// NOTE(review): initialized but never read or updated anywhere in this
    /// file — confirm whether page pinning is implemented elsewhere or this is
    /// dead weight.
    pin_count: Vec<AtomicU32>,
    /// Capacity in nodes; valid ids are `0..max_nodes`.
    max_nodes: usize,
}
136
137impl MmapManager {
138 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
148 let embedding_size = d_embed * std::mem::size_of::<f32>();
150 let file_size = max_nodes * embedding_size;
151
152 let file = OpenOptions::new()
154 .read(true)
155 .write(true)
156 .create(true)
157 .open(path)
158 .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
159
160 file.set_len(file_size as u64)
162 .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
163
164 let mmap = unsafe {
166 MmapOptions::new()
167 .len(file_size)
168 .map_mut(&file)
169 .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
170 };
171
172 let page_size = page_size::get();
174 let num_pages = (file_size + page_size - 1) / page_size;
175
176 Ok(Self {
177 file,
178 mmap,
179 page_size,
180 d_embed,
181 access_bitmap: AtomicBitmap::new(max_nodes),
182 dirty_bitmap: AtomicBitmap::new(max_nodes),
183 pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
184 max_nodes,
185 })
186 }
187
188 #[inline]
199 pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
200 let node_idx = usize::try_from(node_id).ok()?;
201 let elem_size = std::mem::size_of::<f32>();
202 node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
203 }
204
205 #[inline]
207 fn validate_node_id(&self, node_id: u64) -> bool {
208 (node_id as usize) < self.max_nodes
209 }
210
211 pub fn get_embedding(&self, node_id: u64) -> &[f32] {
222 assert!(
224 self.validate_node_id(node_id),
225 "node_id {} out of bounds (max: {})",
226 node_id,
227 self.max_nodes
228 );
229
230 let offset = self
231 .embedding_offset(node_id)
232 .expect("embedding offset calculation overflow");
233 let end = offset
234 .checked_add(
235 self.d_embed
236 .checked_mul(std::mem::size_of::<f32>())
237 .unwrap(),
238 )
239 .expect("end offset overflow");
240 assert!(
241 end <= self.mmap.len(),
242 "embedding extends beyond mmap bounds"
243 );
244
245 self.access_bitmap.set(node_id as usize);
247
248 unsafe {
250 let ptr = self.mmap.as_ptr().add(offset) as *const f32;
251 std::slice::from_raw_parts(ptr, self.d_embed)
252 }
253 }
254
255 pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
265 assert!(
267 self.validate_node_id(node_id),
268 "node_id {} out of bounds (max: {})",
269 node_id,
270 self.max_nodes
271 );
272 assert_eq!(
273 data.len(),
274 self.d_embed,
275 "Embedding data length must match d_embed"
276 );
277
278 let offset = self
279 .embedding_offset(node_id)
280 .expect("embedding offset calculation overflow");
281 let end = offset
282 .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
283 .expect("end offset overflow");
284 assert!(
285 end <= self.mmap.len(),
286 "embedding extends beyond mmap bounds"
287 );
288
289 self.access_bitmap.set(node_id as usize);
291 self.dirty_bitmap.set(node_id as usize);
292
293 unsafe {
295 let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
296 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
297 }
298 }
299
300 pub fn flush_dirty(&self) -> io::Result<()> {
305 let dirty_nodes = self.dirty_bitmap.get_set_indices();
306
307 if dirty_nodes.is_empty() {
308 return Ok(());
309 }
310
311 self.mmap.flush()?;
314
315 for &node_id in &dirty_nodes {
317 self.dirty_bitmap.clear(node_id);
318 }
319
320 Ok(())
321 }
322
323 pub fn prefetch(&self, node_ids: &[u64]) {
328 #[cfg(target_os = "linux")]
329 {
330 #[allow(unused_imports)]
331 use std::os::unix::io::AsRawFd;
332
333 for &node_id in node_ids {
334 if !self.validate_node_id(node_id) {
336 continue;
337 }
338 let offset = match self.embedding_offset(node_id) {
339 Some(o) => o,
340 None => continue,
341 };
342 let page_offset = (offset / self.page_size) * self.page_size;
343 let length = self.d_embed * std::mem::size_of::<f32>();
344
345 unsafe {
346 libc::madvise(
348 self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
349 length,
350 libc::MADV_WILLNEED,
351 );
352 }
353 }
354 }
355
356 #[cfg(not(target_os = "linux"))]
358 {
359 for &node_id in node_ids {
360 if self.validate_node_id(node_id) {
361 let _ = self.get_embedding(node_id);
362 }
363 }
364 }
365 }
366
367 pub fn d_embed(&self) -> usize {
369 self.d_embed
370 }
371
372 pub fn max_nodes(&self) -> usize {
374 self.max_nodes
375 }
376}
377
/// Memory-mapped gradient accumulator with striped locking so multiple
/// threads can accumulate into different node ranges concurrently.
pub struct MmapGradientAccumulator {
    /// Writable mapping of the gradient file. `UnsafeCell` permits mutation
    /// through `&self` in `accumulate`; that mutation is serialized by the
    /// striped `locks` below.
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    /// Number of consecutive nodes that share one lock stripe.
    lock_granularity: usize,
    /// One `RwLock` per stripe of `lock_granularity` nodes.
    locks: Vec<RwLock<()>>,
    /// Capacity in nodes; valid ids are `0..n_nodes`.
    n_nodes: usize,
    /// Number of `f32` values per gradient vector.
    d_embed: usize,
    /// Backing file handle, retained only to keep the mapping's fd alive.
    _file: File,
}
396
397impl MmapGradientAccumulator {
398 pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
408 let grad_size = d_embed * std::mem::size_of::<f32>();
410 let file_size = max_nodes * grad_size;
411
412 let file = OpenOptions::new()
414 .read(true)
415 .write(true)
416 .create(true)
417 .open(path)
418 .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
419
420 file.set_len(file_size as u64)
422 .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
423
424 let grad_mmap = unsafe {
426 MmapOptions::new()
427 .len(file_size)
428 .map_mut(&file)
429 .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
430 };
431
432 for byte in grad_mmap.iter() {
434 let _ = byte;
436 }
437
438 let lock_granularity = 64;
440 let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
441 let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
442
443 Ok(Self {
444 grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
445 lock_granularity,
446 locks,
447 n_nodes: max_nodes,
448 d_embed,
449 _file: file,
450 })
451 }
452
453 #[inline]
461 pub fn grad_offset(&self, node_id: u64) -> usize {
462 (node_id as usize) * self.d_embed * std::mem::size_of::<f32>()
463 }
464
465 pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
474 assert_eq!(
475 grad.len(),
476 self.d_embed,
477 "Gradient length must match d_embed"
478 );
479
480 let lock_idx = (node_id as usize) / self.lock_granularity;
481 let _lock = self.locks[lock_idx].write();
482
483 let offset = self.grad_offset(node_id);
484
485 unsafe {
487 let mmap = &mut *self.grad_mmap.get();
488 let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
489 let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
490
491 for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
493 *g += new_g;
494 }
495 }
496 }
497
498 pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
504 assert_eq!(
505 self.d_embed, embeddings.d_embed,
506 "Gradient and embedding dimensions must match"
507 );
508
509 for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
511 let grad = self.get_grad(node_id as u64);
512 let embedding = embeddings.get_embedding(node_id as u64);
513
514 let mut updated = vec![0.0f32; self.d_embed];
516 for i in 0..self.d_embed {
517 updated[i] = embedding[i] - learning_rate * grad[i];
518 }
519
520 embeddings.set_embedding(node_id as u64, &updated);
521 }
522
523 self.zero_grad();
525 }
526
527 pub fn zero_grad(&mut self) {
529 unsafe {
531 let mmap = &mut *self.grad_mmap.get();
532 for byte in mmap.iter_mut() {
533 *byte = 0;
534 }
535 }
536 }
537
538 pub fn get_grad(&self, node_id: u64) -> &[f32] {
546 let lock_idx = (node_id as usize) / self.lock_granularity;
547 let _lock = self.locks[lock_idx].read();
548
549 let offset = self.grad_offset(node_id);
550
551 unsafe {
553 let mmap = &*self.grad_mmap.get();
554 let ptr = mmap.as_ptr().add(offset) as *const f32;
555 std::slice::from_raw_parts(ptr, self.d_embed)
556 }
557 }
558
559 pub fn d_embed(&self) -> usize {
561 self.d_embed
562 }
563
564 pub fn n_nodes(&self) -> usize {
566 self.n_nodes
567 }
568}
569
impl Drop for MmapManager {
    /// Best-effort flush of dirty embeddings on drop. The error is discarded
    /// because panicking in `drop` would abort during unwinding; callers who
    /// need the result should call `flush_dirty` explicitly first.
    fn drop(&mut self) {
        let _ = self.flush_dirty();
    }
}
577
impl Drop for MmapGradientAccumulator {
    /// Best-effort flush of the gradient mapping on drop; the result is
    /// discarded for the same reason as `MmapManager`'s drop.
    fn drop(&mut self) {
        // SAFETY: `&mut self` in drop guarantees exclusive access to the
        // UnsafeCell's contents — no other reference can exist.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let _ = mmap.flush();
        }
    }
}
587
// SAFETY: the only `&self` mutation path (`accumulate`) serializes writes to
// the UnsafeCell'd mapping through the striped `locks`; everything else takes
// `&mut self`.
// NOTE(review): `get_grad` releases its stripe lock before returning the
// slice, so a reader can overlap a concurrent `accumulate` on the same
// stripe — confirm this relaxed guarantee is intended before relying on Sync.
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}
592
// Unit tests: bitmap semantics, mmap round-trips, persistence across remap,
// gradient accumulation (including concurrent), and panic contracts.
// NOTE(review): `use std::fs;` appears unused in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // set/get/clear on bits in the first, middle, and last word.
    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    // Indices spanning word boundaries are all reported.
    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    // Write then read back a single embedding.
    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    // Distinct per-node payloads round-trip without cross-contamination.
    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    // A write marks the node dirty; flush_dirty clears the bit.
    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        assert!(manager.dirty_bitmap.get(5));

        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    // Data survives dropping the manager and remapping the same file.
    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![3.14f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 3.14);
            assert_eq!(retrieved[63], 3.14);
        }
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    // Two accumulations sum element-wise.
    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    // apply performs embedding -= lr * grad, then zeroes the gradients.
    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        accumulator.apply(0.1, &mut embeddings);

        // 10.0 - 0.1 * 1.0 == 9.9 (within float tolerance)
        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    // 10 threads each add 1.0 to node 0; striped locking must serialize them.
    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    // Offsets advance by d_embed * size_of::<f32>() per node.
    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4));
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), 0);
        assert_eq!(accumulator.grad_offset(1), 128 * 4);
        assert_eq!(accumulator.grad_offset(5), 128 * 4 * 5);
    }

    // A vector half the configured width must be rejected.
    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32];
        accumulator.accumulate(0, &wrong_size);
    }

    // prefetch is advisory; only checks it neither panics nor corrupts data.
    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        manager.prefetch(&[0, 1, 2, 3, 4]);

        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }
}