ruvector_gnn/mmap.rs

//! Memory-mapped embedding management for large-scale GNN training.
//!
//! This module provides efficient memory-mapped access to embeddings and gradients
//! that don't fit in RAM. It includes:
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
//! - `MmapGradientAccumulator`: Concurrent gradient accumulation with fine-grained locking
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
//!
//! Only available on non-WASM targets.
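//!
//! # Example
//!
//! A minimal end-to-end sketch, not compiled as a doctest; the file names,
//! dimensions, and the assumption that these types are publicly reachable
//! from the caller's crate are illustrative only:
//!
//! ```ignore
//! use std::path::Path;
//!
//! // Back 1,000 embeddings of dimension 128 with on-disk files.
//! let mut embeddings = MmapManager::new(Path::new("embeddings.bin"), 128, 1_000)?;
//! let mut grads = MmapGradientAccumulator::new(Path::new("gradients.bin"), 128, 1_000)?;
//!
//! // Write an initial embedding and accumulate a gradient for node 0.
//! embeddings.set_embedding(0, &vec![0.5f32; 128]);
//! grads.accumulate(0, &vec![0.1f32; 128]);
//!
//! // One SGD step (embedding -= learning_rate * grad), then persist dirty pages.
//! grads.apply(0.01, &mut embeddings);
//! embeddings.flush_dirty()?;
//! ```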
#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]

use crate::error::{GnnError, Result};
use memmap2::{MmapMut, MmapOptions};
use parking_lot::RwLock;
use std::fs::{File, OpenOptions};
use std::io;
use std::path::Path;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

/// Thread-safe bitmap using atomic operations.
///
/// Used for tracking which embeddings have been accessed or modified.
/// Each bit represents one embedding node.
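///
/// # Example
///
/// A small usage sketch (marked `ignore`, since the type's public path outside
/// this module is an assumption):
///
/// ```ignore
/// let bitmap = AtomicBitmap::new(128);
/// bitmap.set(3);
/// bitmap.set(64);
/// assert!(bitmap.get(3));
/// assert!(!bitmap.get(4));
/// // Indices come back in ascending order, one per set bit.
/// assert_eq!(bitmap.get_set_indices(), vec![3, 64]);
/// ```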
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Array of 64-bit atomic integers, each storing 64 bits
    bits: Vec<AtomicU64>,
    /// Total number of bits (nodes)
    size: usize,
}

impl AtomicBitmap {
    /// Create a new atomic bitmap with the specified capacity.
    ///
    /// # Arguments
    /// * `size` - Number of bits to allocate
    pub fn new(size: usize) -> Self {
        let num_words = (size + 63) / 64;
        let bits = (0..num_words).map(|_| AtomicU64::new(0)).collect();

        Self { bits, size }
    }

    /// Set a bit to 1 (mark as accessed/dirty).
    ///
    /// # Arguments
    /// * `index` - Bit index to set
    pub fn set(&self, index: usize) {
        if index >= self.size {
            return;
        }
        let word_idx = index / 64;
        let bit_idx = index % 64;
        self.bits[word_idx].fetch_or(1u64 << bit_idx, Ordering::Release);
    }

    /// Clear a bit to 0 (mark as clean/not accessed).
    ///
    /// # Arguments
    /// * `index` - Bit index to clear
    pub fn clear(&self, index: usize) {
        if index >= self.size {
            return;
        }
        let word_idx = index / 64;
        let bit_idx = index % 64;
        self.bits[word_idx].fetch_and(!(1u64 << bit_idx), Ordering::Release);
    }

    /// Check if a bit is set.
    ///
    /// # Arguments
    /// * `index` - Bit index to check
    ///
    /// # Returns
    /// `true` if the bit is set, `false` otherwise
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let word_idx = index / 64;
        let bit_idx = index % 64;
        let word = self.bits[word_idx].load(Ordering::Acquire);
        (word & (1u64 << bit_idx)) != 0
    }

    /// Clear all bits in the bitmap.
    pub fn clear_all(&self) {
        for word in &self.bits {
            word.store(0, Ordering::Release);
        }
    }

    /// Get all set bit indices (for finding dirty pages).
    ///
    /// # Returns
    /// Vector of indices where bits are set
    pub fn get_set_indices(&self) -> Vec<usize> {
        let mut indices = Vec::new();
        for (word_idx, word) in self.bits.iter().enumerate() {
            let mut w = word.load(Ordering::Acquire);
            while w != 0 {
                let bit_idx = w.trailing_zeros() as usize;
                indices.push(word_idx * 64 + bit_idx);
                w &= w - 1; // Clear lowest set bit
            }
        }
        indices
    }
}

/// Memory-mapped embedding manager with dirty tracking and prefetching.
///
/// Manages large embedding matrices that may not fit in RAM using memory-mapped files.
/// Tracks which embeddings have been accessed and modified for efficient I/O.
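///
/// # Example
///
/// A usage sketch (marked `ignore`; the path and dimensions are illustrative):
///
/// ```ignore
/// use std::path::Path;
///
/// let mut manager = MmapManager::new(Path::new("embeddings.bin"), 64, 1_000)?;
/// manager.set_embedding(7, &vec![0.25f32; 64]); // marks node 7 as dirty
/// manager.prefetch(&[7, 8, 9]);                 // madvise hint on Linux, plain reads elsewhere
/// assert_eq!(manager.get_embedding(7)[0], 0.25);
/// manager.flush_dirty()?;                       // persists data and clears the dirty bits
/// ```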
#[derive(Debug)]
pub struct MmapManager {
    /// The memory-mapped file
    file: File,
    /// Mutable memory mapping
    mmap: MmapMut,
    /// Operating system page size
    page_size: usize,
    /// Embedding dimension
    d_embed: usize,
    /// Bitmap tracking which embeddings have been accessed
    access_bitmap: AtomicBitmap,
    /// Bitmap tracking which embeddings have been modified
    dirty_bitmap: AtomicBitmap,
    /// Pin count for each page (prevents eviction)
    pin_count: Vec<AtomicU32>,
    /// Maximum number of nodes
    max_nodes: usize,
}

impl MmapManager {
    /// Create a new memory-mapped embedding manager.
    ///
    /// # Arguments
    /// * `path` - Path to the memory-mapped file
    /// * `d_embed` - Embedding dimension
    /// * `max_nodes` - Maximum number of nodes to support
    ///
    /// # Returns
    /// A new `MmapManager` instance
    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
        // Calculate required file size
        let embedding_size = d_embed * std::mem::size_of::<f32>();
        let file_size = max_nodes * embedding_size;

        // Create or open the file
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)
            .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;

        // Set file size
        file.set_len(file_size as u64)
            .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;

        // Create memory mapping
        let mmap = unsafe {
            MmapOptions::new()
                .len(file_size)
                .map_mut(&file)
                .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
        };

        // Get system page size
        let page_size = page_size::get();
        let num_pages = (file_size + page_size - 1) / page_size;

        Ok(Self {
            file,
            mmap,
            page_size,
            d_embed,
            access_bitmap: AtomicBitmap::new(max_nodes),
            dirty_bitmap: AtomicBitmap::new(max_nodes),
            pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
            max_nodes,
        })
    }

    /// Calculate the byte offset for a given node's embedding.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Byte offset in the memory-mapped file, or None if overflow would occur
    ///
    /// # Security
    /// Uses checked arithmetic to prevent integer overflow attacks.
    #[inline]
    pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
        let node_idx = usize::try_from(node_id).ok()?;
        let elem_size = std::mem::size_of::<f32>();
        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
    }

    /// Validate that a node_id is within bounds.
    #[inline]
    fn validate_node_id(&self, node_id: u64) -> bool {
        (node_id as usize) < self.max_nodes
    }

    /// Get a read-only reference to a node's embedding.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice containing the embedding vector
    ///
    /// # Panics
    /// Panics if node_id is out of bounds or would cause overflow
    pub fn get_embedding(&self, node_id: u64) -> &[f32] {
        // Security: Validate bounds before any pointer arithmetic
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );

        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        let end = offset
            .checked_add(
                self.d_embed
                    .checked_mul(std::mem::size_of::<f32>())
                    .unwrap(),
            )
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );

        // Mark as accessed
        self.access_bitmap.set(node_id as usize);

        // Safety: We control the offset and know the data is properly aligned
        unsafe {
            let ptr = self.mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }

    /// Set a node's embedding data.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `data` - Embedding vector to write
    ///
    /// # Panics
    /// Panics if node_id is out of bounds, data length doesn't match d_embed,
    /// or offset calculation would overflow.
    pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
        // Security: Validate bounds first
        assert!(
            self.validate_node_id(node_id),
            "node_id {} out of bounds (max: {})",
            node_id,
            self.max_nodes
        );
        assert_eq!(
            data.len(),
            self.d_embed,
            "Embedding data length must match d_embed"
        );

        let offset = self
            .embedding_offset(node_id)
            .expect("embedding offset calculation overflow");
        let end = offset
            .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
            .expect("end offset overflow");
        assert!(
            end <= self.mmap.len(),
            "embedding extends beyond mmap bounds"
        );

        // Mark as accessed and dirty
        self.access_bitmap.set(node_id as usize);
        self.dirty_bitmap.set(node_id as usize);

        // Safety: We control the offset and know the data is properly aligned
        unsafe {
            let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
        }
    }

    /// Flush all dirty pages to disk.
    ///
    /// # Returns
    /// `Ok(())` on success, error otherwise
    pub fn flush_dirty(&self) -> io::Result<()> {
        let dirty_nodes = self.dirty_bitmap.get_set_indices();

        if dirty_nodes.is_empty() {
            return Ok(());
        }

        // Flush the entire mmap for simplicity
        // In a production system, you might want to flush only dirty pages
        self.mmap.flush()?;

        // Clear dirty bitmap after successful flush
        for &node_id in &dirty_nodes {
            self.dirty_bitmap.clear(node_id);
        }

        Ok(())
    }

    /// Prefetch embeddings into memory for better cache locality.
    ///
    /// # Arguments
    /// * `node_ids` - List of node IDs to prefetch
    pub fn prefetch(&self, node_ids: &[u64]) {
        #[cfg(target_os = "linux")]
        {
            #[allow(unused_imports)]
            use std::os::unix::io::AsRawFd;

            for &node_id in node_ids {
                // Skip invalid node IDs
                if !self.validate_node_id(node_id) {
                    continue;
                }
                let offset = match self.embedding_offset(node_id) {
                    Some(o) => o,
                    None => continue,
                };
                let page_offset = (offset / self.page_size) * self.page_size;
                let length = self.d_embed * std::mem::size_of::<f32>();

                unsafe {
                    // Use madvise to hint the kernel to prefetch
                    libc::madvise(
                        self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
                        length,
                        libc::MADV_WILLNEED,
                    );
                }
            }
        }

        // On non-Linux platforms, just access the data to bring it into cache
        #[cfg(not(target_os = "linux"))]
        {
            for &node_id in node_ids {
                if self.validate_node_id(node_id) {
                    let _ = self.get_embedding(node_id);
                }
            }
        }
    }

    /// Get the embedding dimension.
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }

    /// Get the maximum number of nodes.
    pub fn max_nodes(&self) -> usize {
        self.max_nodes
    }
}

/// Memory-mapped gradient accumulator with fine-grained locking.
///
/// Allows multiple threads to accumulate gradients concurrently with minimal contention.
/// Uses reader-writer locks over fixed-size groups of nodes (currently 64 nodes per lock).
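///
/// # Example
///
/// A concurrent-accumulation sketch (marked `ignore`; the file name and thread
/// count are illustrative):
///
/// ```ignore
/// use std::path::Path;
/// use std::sync::Arc;
/// use std::thread;
///
/// let acc = Arc::new(MmapGradientAccumulator::new(Path::new("gradients.bin"), 64, 1_000)?);
/// let handles: Vec<_> = (0..4)
///     .map(|_| {
///         let acc = Arc::clone(&acc);
///         thread::spawn(move || acc.accumulate(0, &vec![1.0f32; 64]))
///     })
///     .collect();
/// for handle in handles {
///     handle.join().unwrap();
/// }
/// assert_eq!(acc.get_grad(0)[0], 4.0);
/// ```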
pub struct MmapGradientAccumulator {
    /// Memory-mapped gradient storage (using UnsafeCell for interior mutability)
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    /// Number of nodes per lock (lock granularity)
    lock_granularity: usize,
    /// Reader-writer locks for gradient regions
    locks: Vec<RwLock<()>>,
    /// Number of nodes
    n_nodes: usize,
    /// Embedding dimension
    d_embed: usize,
    /// Gradient file
    _file: File,
}

impl MmapGradientAccumulator {
    /// Create a new memory-mapped gradient accumulator.
    ///
    /// # Arguments
    /// * `path` - Path to the gradient file
    /// * `d_embed` - Embedding dimension
    /// * `max_nodes` - Maximum number of nodes
    ///
    /// # Returns
    /// A new `MmapGradientAccumulator` instance
    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
        // Calculate required file size
        let grad_size = d_embed * std::mem::size_of::<f32>();
        let file_size = max_nodes * grad_size;

        // Create or open the file
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)
            .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;

        // Set file size
        file.set_len(file_size as u64)
            .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;

        // Create memory mapping
        let mut grad_mmap = unsafe {
            MmapOptions::new()
                .len(file_size)
                .map_mut(&file)
                .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
        };

        // Zero out the gradients so accumulation starts from a clean slate.
        // Newly extended pages are already zero-filled by the OS, but a
        // pre-existing file may still hold stale gradient data.
        for byte in grad_mmap.iter_mut() {
            *byte = 0;
        }

        // Use a lock granularity of 64 nodes per lock for good parallelism
        let lock_granularity = 64;
        let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
        let locks = (0..num_locks).map(|_| RwLock::new(())).collect();

        Ok(Self {
            grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
            lock_granularity,
            locks,
            n_nodes: max_nodes,
            d_embed,
            _file: file,
        })
    }

    /// Calculate the byte offset for a node's gradient.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Byte offset in the gradient file
    #[inline]
    pub fn grad_offset(&self, node_id: u64) -> usize {
        (node_id as usize) * self.d_embed * std::mem::size_of::<f32>()
    }

    /// Accumulate gradients for a specific node.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    /// * `grad` - Gradient vector to accumulate
    ///
    /// # Panics
    /// Panics if grad length doesn't match d_embed
    pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
        assert_eq!(
            grad.len(),
            self.d_embed,
            "Gradient length must match d_embed"
        );

        let lock_idx = (node_id as usize) / self.lock_granularity;
        let _lock = self.locks[lock_idx].write();

        let offset = self.grad_offset(node_id);

        // Safety: We hold the write lock for this region, ensuring exclusive access
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
            let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);

            // Accumulate gradients
            for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
                *g += new_g;
            }
        }
    }

    /// Apply accumulated gradients to embeddings and zero out gradients.
    ///
    /// # Arguments
    /// * `learning_rate` - Learning rate for gradient descent
    /// * `embeddings` - Embedding manager to update
    pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
        assert_eq!(
            self.d_embed, embeddings.d_embed,
            "Gradient and embedding dimensions must match"
        );

        // Process all nodes
        for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
            let grad = self.get_grad(node_id as u64);
            let embedding = embeddings.get_embedding(node_id as u64);

            // Apply gradient descent: embedding -= learning_rate * grad
            let mut updated = vec![0.0f32; self.d_embed];
            for i in 0..self.d_embed {
                updated[i] = embedding[i] - learning_rate * grad[i];
            }

            embeddings.set_embedding(node_id as u64, &updated);
        }

        // Zero out gradients after applying
        self.zero_grad();
    }

    /// Zero out all accumulated gradients.
    pub fn zero_grad(&mut self) {
        // Zero the entire gradient buffer
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            for byte in mmap.iter_mut() {
                *byte = 0;
            }
        }
    }

    /// Get a read-only reference to a node's accumulated gradient.
    ///
    /// # Arguments
    /// * `node_id` - Node identifier
    ///
    /// # Returns
    /// Slice containing the gradient vector
    pub fn get_grad(&self, node_id: u64) -> &[f32] {
        let lock_idx = (node_id as usize) / self.lock_granularity;
        let _lock = self.locks[lock_idx].read();

        let offset = self.grad_offset(node_id);

        // Safety: We hold the read lock for this region
        unsafe {
            let mmap = &*self.grad_mmap.get();
            let ptr = mmap.as_ptr().add(offset) as *const f32;
            std::slice::from_raw_parts(ptr, self.d_embed)
        }
    }

    /// Get the embedding dimension.
    pub fn d_embed(&self) -> usize {
        self.d_embed
    }

    /// Get the number of nodes.
    pub fn n_nodes(&self) -> usize {
        self.n_nodes
    }
}

// Implement Drop to ensure proper cleanup
impl Drop for MmapManager {
    fn drop(&mut self) {
        // Try to flush dirty pages before dropping
        let _ = self.flush_dirty();
    }
}

impl Drop for MmapGradientAccumulator {
    fn drop(&mut self) {
        // Flush gradient data
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let _ = mmap.flush();
        }
    }
}

// Safety: MmapGradientAccumulator is safe to send between threads
// because access is protected by RwLocks
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        // Verify each embedding
        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        // Should be marked as dirty
        assert!(manager.dirty_bitmap.get(5));

        // Flush and check it's clean
        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![3.14f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        // Reopen and verify data persisted
        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 3.14);
            assert_eq!(retrieved[63], 3.14);
        }
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        // Set initial embedding
        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        // Accumulate gradient
        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        // Apply with learning rate 0.1
        accumulator.apply(0.1, &mut embeddings);

        // Check updated embedding: 10.0 - 0.1 * 1.0 = 9.9
        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        // Check gradients were zeroed
        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        // Spawn 10 threads, each accumulating 1.0 to node 0
        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Should have accumulated 10.0
        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4)); // 64 floats * 4 bytes
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), 0);
        assert_eq!(accumulator.grad_offset(1), 128 * 4); // 128 floats * 4 bytes
        assert_eq!(accumulator.grad_offset(5), 128 * 4 * 5);
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        accumulator.accumulate(0, &wrong_size);
    }

    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        // Set some embeddings
        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        // Prefetch should not crash
        manager.prefetch(&[0, 1, 2, 3, 4]);

        // Access should still work
        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }
}