Skip to main content

ruvector_gnn/
mmap.rs

1//! Memory-mapped embedding management for large-scale GNN training.
2//!
3//! This module provides efficient memory-mapped access to embeddings and gradients
4//! that don't fit in RAM. It includes:
5//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
//! - `MmapGradientAccumulator`: Concurrent gradient accumulation guarded by striped reader-writer locks
7//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
8//!
//! Only available on non-WASM targets with the `mmap` feature enabled.
10
11#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
12
13use crate::error::{GnnError, Result};
14use memmap2::{MmapMut, MmapOptions};
15use parking_lot::RwLock;
16use std::fs::{File, OpenOptions};
17use std::io;
18use std::path::Path;
19use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Thread-safe bitmap built on atomic 64-bit words.
///
/// Each bit stands for one embedding node and records whether that node has
/// been accessed or modified. Every operation takes `&self`, so a single bitmap
/// can be shared between threads without an external lock.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Backing words; bit `i` lives in word `i / 64` at position `i % 64`.
    bits: Vec<AtomicU64>,
    /// Number of addressable bits (nodes); indices `>= size` are ignored.
    size: usize,
}

impl AtomicBitmap {
    /// Bits stored per backing word.
    const WORD_BITS: usize = 64;

    /// Create a bitmap with `size` bits, all initially clear.
    ///
    /// # Arguments
    /// * `size` - Number of bits to allocate
    pub fn new(size: usize) -> Self {
        let word_count = size.div_ceil(Self::WORD_BITS);
        let mut bits = Vec::with_capacity(word_count);
        bits.resize_with(word_count, || AtomicU64::new(0));
        Self { bits, size }
    }

    /// Map a bit index to its backing word and single-bit mask.
    ///
    /// Returns `None` for indices outside `0..size`, which every public
    /// operation treats as a silent no-op (or `false` for reads).
    #[inline]
    fn locate(&self, index: usize) -> Option<(&AtomicU64, u64)> {
        (index < self.size).then(|| {
            let word = &self.bits[index / Self::WORD_BITS];
            (word, 1u64 << (index % Self::WORD_BITS))
        })
    }

    /// Set the bit at `index` (mark the node as accessed/dirty).
    ///
    /// Out-of-range indices are ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to set
    pub fn set(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            word.fetch_or(mask, Ordering::Release);
        }
    }

    /// Clear the bit at `index` (mark the node as clean/not accessed).
    ///
    /// Out-of-range indices are ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to clear
    pub fn clear(&self, index: usize) {
        if let Some((word, mask)) = self.locate(index) {
            word.fetch_and(!mask, Ordering::Release);
        }
    }

    /// Report whether the bit at `index` is set.
    ///
    /// # Arguments
    /// * `index` - Bit index to check
    ///
    /// # Returns
    /// `true` if the bit is set; `false` if it is clear or `index` is out of range
    pub fn get(&self, index: usize) -> bool {
        self.locate(index)
            .is_some_and(|(word, mask)| (word.load(Ordering::Acquire) & mask) != 0)
    }

    /// Reset every bit in the bitmap to 0.
    pub fn clear_all(&self) {
        self.bits
            .iter()
            .for_each(|word| word.store(0, Ordering::Release));
    }

    /// Collect the indices of all set bits, in ascending order.
    ///
    /// Each word is loaded once, so bits flipped concurrently may or may not be
    /// reflected in the result.
    ///
    /// # Returns
    /// Vector of indices where bits are set
    pub fn get_set_indices(&self) -> Vec<usize> {
        self.bits
            .iter()
            .enumerate()
            .flat_map(|(word_idx, word)| {
                let base = word_idx * Self::WORD_BITS;
                let mut pending = word.load(Ordering::Acquire);
                std::iter::from_fn(move || {
                    if pending == 0 {
                        return None;
                    }
                    let bit = pending.trailing_zeros() as usize;
                    pending &= pending - 1; // drop the lowest set bit
                    Some(base + bit)
                })
            })
            .collect()
    }
}
112
113/// Memory-mapped embedding manager with dirty tracking and prefetching.
114///
115/// Manages large embedding matrices that may not fit in RAM using memory-mapped files.
116/// Tracks which embeddings have been accessed and modified for efficient I/O.
117#[derive(Debug)]
118pub struct MmapManager {
119    /// The memory-mapped file
120    #[allow(dead_code)]
121    file: File,
122    /// Mutable memory mapping
123    mmap: MmapMut,
124    /// Operating system page size
125    page_size: usize,
126    /// Embedding dimension
127    d_embed: usize,
128    /// Bitmap tracking which embeddings have been accessed
129    access_bitmap: AtomicBitmap,
130    /// Bitmap tracking which embeddings have been modified
131    dirty_bitmap: AtomicBitmap,
132    /// Pin count for each page (prevents eviction)
133    #[allow(dead_code)]
134    pin_count: Vec<AtomicU32>,
135    /// Maximum number of nodes
136    max_nodes: usize,
137}
138
139impl MmapManager {
140    /// Create a new memory-mapped embedding manager.
141    ///
142    /// # Arguments
143    /// * `path` - Path to the memory-mapped file
144    /// * `d_embed` - Embedding dimension
145    /// * `max_nodes` - Maximum number of nodes to support
146    ///
147    /// # Returns
148    /// A new `MmapManager` instance
149    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
150        // Calculate required file size
151        let embedding_size = d_embed * std::mem::size_of::<f32>();
152        let file_size = max_nodes * embedding_size;
153
154        // Create or open the file
155        let file = OpenOptions::new()
156            .read(true)
157            .write(true)
158            .create(true)
159            .truncate(false)
160            .open(path)
161            .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
162
163        // Set file size
164        file.set_len(file_size as u64)
165            .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
166
167        // Create memory mapping
168        let mmap = unsafe {
169            MmapOptions::new()
170                .len(file_size)
171                .map_mut(&file)
172                .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
173        };
174
175        // Get system page size
176        let page_size = page_size::get();
177        let num_pages = file_size.div_ceil(page_size);
178
179        Ok(Self {
180            file,
181            mmap,
182            page_size,
183            d_embed,
184            access_bitmap: AtomicBitmap::new(max_nodes),
185            dirty_bitmap: AtomicBitmap::new(max_nodes),
186            pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
187            max_nodes,
188        })
189    }
190
191    /// Calculate the byte offset for a given node's embedding.
192    ///
193    /// # Arguments
194    /// * `node_id` - Node identifier
195    ///
196    /// # Returns
197    /// Byte offset in the memory-mapped file, or None if overflow would occur
198    ///
199    /// # Security
200    /// Uses checked arithmetic to prevent integer overflow attacks.
201    #[inline]
202    pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
203        let node_idx = usize::try_from(node_id).ok()?;
204        let elem_size = std::mem::size_of::<f32>();
205        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
206    }
207
208    /// Validate that a node_id is within bounds.
209    #[inline]
210    fn validate_node_id(&self, node_id: u64) -> bool {
211        (node_id as usize) < self.max_nodes
212    }
213
214    /// Get a read-only reference to a node's embedding.
215    ///
216    /// # Arguments
217    /// * `node_id` - Node identifier
218    ///
219    /// # Returns
220    /// Slice containing the embedding vector
221    ///
222    /// # Panics
223    /// Panics if node_id is out of bounds or would cause overflow
224    pub fn get_embedding(&self, node_id: u64) -> &[f32] {
225        // Security: Validate bounds before any pointer arithmetic
226        assert!(
227            self.validate_node_id(node_id),
228            "node_id {} out of bounds (max: {})",
229            node_id,
230            self.max_nodes
231        );
232
233        let offset = self
234            .embedding_offset(node_id)
235            .expect("embedding offset calculation overflow");
236        let end = offset
237            .checked_add(
238                self.d_embed
239                    .checked_mul(std::mem::size_of::<f32>())
240                    .unwrap(),
241            )
242            .expect("end offset overflow");
243        assert!(
244            end <= self.mmap.len(),
245            "embedding extends beyond mmap bounds"
246        );
247
248        // Mark as accessed
249        self.access_bitmap.set(node_id as usize);
250
251        // Safety: We control the offset and know the data is properly aligned
252        unsafe {
253            let ptr = self.mmap.as_ptr().add(offset) as *const f32;
254            std::slice::from_raw_parts(ptr, self.d_embed)
255        }
256    }
257
258    /// Set a node's embedding data.
259    ///
260    /// # Arguments
261    /// * `node_id` - Node identifier
262    /// * `data` - Embedding vector to write
263    ///
264    /// # Panics
265    /// Panics if node_id is out of bounds, data length doesn't match d_embed,
266    /// or offset calculation would overflow.
267    pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
268        // Security: Validate bounds first
269        assert!(
270            self.validate_node_id(node_id),
271            "node_id {} out of bounds (max: {})",
272            node_id,
273            self.max_nodes
274        );
275        assert_eq!(
276            data.len(),
277            self.d_embed,
278            "Embedding data length must match d_embed"
279        );
280
281        let offset = self
282            .embedding_offset(node_id)
283            .expect("embedding offset calculation overflow");
284        let end = offset
285            .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
286            .expect("end offset overflow");
287        assert!(
288            end <= self.mmap.len(),
289            "embedding extends beyond mmap bounds"
290        );
291
292        // Mark as accessed and dirty
293        self.access_bitmap.set(node_id as usize);
294        self.dirty_bitmap.set(node_id as usize);
295
296        // Safety: We control the offset and know the data is properly aligned
297        unsafe {
298            let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
299            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
300        }
301    }
302
303    /// Flush all dirty pages to disk.
304    ///
305    /// # Returns
306    /// `Ok(())` on success, error otherwise
307    pub fn flush_dirty(&self) -> io::Result<()> {
308        let dirty_nodes = self.dirty_bitmap.get_set_indices();
309
310        if dirty_nodes.is_empty() {
311            return Ok(());
312        }
313
314        // Flush the entire mmap for simplicity
315        // In a production system, you might want to flush only dirty pages
316        self.mmap.flush()?;
317
318        // Clear dirty bitmap after successful flush
319        for &node_id in &dirty_nodes {
320            self.dirty_bitmap.clear(node_id);
321        }
322
323        Ok(())
324    }
325
326    /// Prefetch embeddings into memory for better cache locality.
327    ///
328    /// # Arguments
329    /// * `node_ids` - List of node IDs to prefetch
330    pub fn prefetch(&self, node_ids: &[u64]) {
331        #[cfg(target_os = "linux")]
332        {
333            #[allow(unused_imports)]
334            use std::os::unix::io::AsRawFd;
335
336            for &node_id in node_ids {
337                // Skip invalid node IDs
338                if !self.validate_node_id(node_id) {
339                    continue;
340                }
341                let offset = match self.embedding_offset(node_id) {
342                    Some(o) => o,
343                    None => continue,
344                };
345                let page_offset = (offset / self.page_size) * self.page_size;
346                let length = self.d_embed * std::mem::size_of::<f32>();
347
348                unsafe {
349                    // Use madvise to hint the kernel to prefetch
350                    libc::madvise(
351                        self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
352                        length,
353                        libc::MADV_WILLNEED,
354                    );
355                }
356            }
357        }
358
359        // On non-Linux platforms, just access the data to bring it into cache
360        #[cfg(not(target_os = "linux"))]
361        {
362            for &node_id in node_ids {
363                if self.validate_node_id(node_id) {
364                    let _ = self.get_embedding(node_id);
365                }
366            }
367        }
368    }
369
370    /// Get the embedding dimension.
371    pub fn d_embed(&self) -> usize {
372        self.d_embed
373    }
374
375    /// Get the maximum number of nodes.
376    pub fn max_nodes(&self) -> usize {
377        self.max_nodes
378    }
379}
380
381/// Memory-mapped gradient accumulator with fine-grained locking.
382///
383/// Allows multiple threads to accumulate gradients concurrently with minimal contention.
384/// Uses reader-writer locks at a configurable granularity.
385pub struct MmapGradientAccumulator {
386    /// Memory-mapped gradient storage (using UnsafeCell for interior mutability)
387    grad_mmap: std::cell::UnsafeCell<MmapMut>,
388    /// Number of nodes per lock (lock granularity)
389    lock_granularity: usize,
390    /// Reader-writer locks for gradient regions
391    locks: Vec<RwLock<()>>,
392    /// Number of nodes
393    n_nodes: usize,
394    /// Embedding dimension
395    d_embed: usize,
396    /// Gradient file
397    _file: File,
398}
399
400impl MmapGradientAccumulator {
401    /// Create a new memory-mapped gradient accumulator.
402    ///
403    /// # Arguments
404    /// * `path` - Path to the gradient file
405    /// * `d_embed` - Embedding dimension
406    /// * `max_nodes` - Maximum number of nodes
407    ///
408    /// # Returns
409    /// A new `MmapGradientAccumulator` instance
410    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
411        // Calculate required file size
412        let grad_size = d_embed * std::mem::size_of::<f32>();
413        let file_size = max_nodes * grad_size;
414
415        // Create or open the file
416        let file = OpenOptions::new()
417            .read(true)
418            .write(true)
419            .create(true)
420            .truncate(false)
421            .open(path)
422            .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
423
424        // Set file size
425        file.set_len(file_size as u64)
426            .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
427
428        // Create memory mapping
429        let grad_mmap = unsafe {
430            MmapOptions::new()
431                .len(file_size)
432                .map_mut(&file)
433                .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
434        };
435
436        // Zero out the gradients
437        for byte in grad_mmap.iter() {
438            // This forces the pages to be allocated and zeroed
439            let _ = byte;
440        }
441
442        // Use a lock granularity of 64 nodes per lock for good parallelism
443        let lock_granularity = 64;
444        let num_locks = max_nodes.div_ceil(lock_granularity);
445        let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
446
447        Ok(Self {
448            grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
449            lock_granularity,
450            locks,
451            n_nodes: max_nodes,
452            d_embed,
453            _file: file,
454        })
455    }
456
457    /// Calculate the byte offset for a node's gradient.
458    ///
459    /// # Arguments
460    /// * `node_id` - Node identifier
461    ///
462    /// # Returns
463    /// Byte offset in the gradient file, or None on overflow or out-of-bounds
464    ///
465    /// # Security
466    /// Uses checked arithmetic to prevent integer overflow (SEC-001).
467    #[inline]
468    pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
469        let node_idx = usize::try_from(node_id).ok()?;
470        if node_idx >= self.n_nodes {
471            return None;
472        }
473        let elem_size = std::mem::size_of::<f32>();
474        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
475    }
476
477    /// Accumulate gradients for a specific node.
478    ///
479    /// # Arguments
480    /// * `node_id` - Node identifier
481    /// * `grad` - Gradient vector to accumulate
482    ///
483    /// # Panics
484    /// Panics if grad length doesn't match d_embed
485    pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
486        assert_eq!(
487            grad.len(),
488            self.d_embed,
489            "Gradient length must match d_embed"
490        );
491
492        let offset = self
493            .grad_offset(node_id)
494            .expect("node_id out of bounds or offset overflow");
495
496        let lock_idx = (node_id as usize) / self.lock_granularity;
497        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
498        let _lock = self.locks[lock_idx].write();
499
500        // Safety: We validated node_id bounds and offset above, and hold the write lock
501        unsafe {
502            let mmap = &mut *self.grad_mmap.get();
503            assert!(
504                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
505                "gradient write would exceed mmap bounds"
506            );
507            let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
508            let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
509
510            // Accumulate gradients
511            for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
512                *g += new_g;
513            }
514        }
515    }
516
517    /// Apply accumulated gradients to embeddings and zero out gradients.
518    ///
519    /// # Arguments
520    /// * `learning_rate` - Learning rate for gradient descent
521    /// * `embeddings` - Embedding manager to update
522    pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
523        assert_eq!(
524            self.d_embed, embeddings.d_embed,
525            "Gradient and embedding dimensions must match"
526        );
527
528        // Process all nodes
529        for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
530            let grad = self.get_grad(node_id as u64);
531            let embedding = embeddings.get_embedding(node_id as u64);
532
533            // Apply gradient descent: embedding -= learning_rate * grad
534            let mut updated = vec![0.0f32; self.d_embed];
535            for i in 0..self.d_embed {
536                updated[i] = embedding[i] - learning_rate * grad[i];
537            }
538
539            embeddings.set_embedding(node_id as u64, &updated);
540        }
541
542        // Zero out gradients after applying
543        self.zero_grad();
544    }
545
546    /// Zero out all accumulated gradients.
547    pub fn zero_grad(&mut self) {
548        // Zero the entire gradient buffer
549        unsafe {
550            let mmap = &mut *self.grad_mmap.get();
551            for byte in mmap.iter_mut() {
552                *byte = 0;
553            }
554        }
555    }
556
557    /// Get a read-only reference to a node's accumulated gradient.
558    ///
559    /// # Arguments
560    /// * `node_id` - Node identifier
561    ///
562    /// # Returns
563    /// Slice containing the gradient vector
564    pub fn get_grad(&self, node_id: u64) -> &[f32] {
565        let offset = self
566            .grad_offset(node_id)
567            .expect("node_id out of bounds or offset overflow");
568
569        let lock_idx = (node_id as usize) / self.lock_granularity;
570        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
571        let _lock = self.locks[lock_idx].read();
572
573        // Safety: We validated node_id bounds and offset above, and hold the read lock
574        unsafe {
575            let mmap = &*self.grad_mmap.get();
576            assert!(
577                offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
578                "gradient read would exceed mmap bounds"
579            );
580            let ptr = mmap.as_ptr().add(offset) as *const f32;
581            std::slice::from_raw_parts(ptr, self.d_embed)
582        }
583    }
584
585    /// Get the embedding dimension.
586    pub fn d_embed(&self) -> usize {
587        self.d_embed
588    }
589
590    /// Get the number of nodes.
591    pub fn n_nodes(&self) -> usize {
592        self.n_nodes
593    }
594}
595
596// Implement Drop to ensure proper cleanup
597impl Drop for MmapManager {
598    fn drop(&mut self) {
599        // Try to flush dirty pages before dropping
600        let _ = self.flush_dirty();
601    }
602}
603
604impl Drop for MmapGradientAccumulator {
605    fn drop(&mut self) {
606        // Flush gradient data
607        unsafe {
608            let mmap = &mut *self.grad_mmap.get();
609            let _ = mmap.flush();
610        }
611    }
612}
613
614// Safety: MmapGradientAccumulator is safe to send between threads
615// because access is protected by RwLocks
616unsafe impl Send for MmapGradientAccumulator {}
617unsafe impl Sync for MmapGradientAccumulator {}
618
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_out_of_range_is_ignored() {
        let bitmap = AtomicBitmap::new(10);

        // Indices at or beyond `size` must be ignored, even when they fall
        // inside the last backing word.
        bitmap.set(10);
        bitmap.set(63);
        bitmap.set(1_000);
        bitmap.clear(1_000);

        assert!(!bitmap.get(10));
        assert!(!bitmap.get(63));
        assert!(bitmap.get_set_indices().is_empty());
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        // Verify each embedding
        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_embedding_out_of_bounds_panics() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();
        let _ = manager.get_embedding(100);
    }

    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        // Should be marked as dirty
        assert!(manager.dirty_bitmap.get(5));

        // Flush and check it's clean
        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![1.5f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        // Reopen and verify data persisted
        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 1.5);
            assert_eq!(retrieved[63], 1.5);
        }
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_across_lock_stripes() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 8, 200).unwrap();

        // 63/64 and 127/128 straddle lock-stripe boundaries (64 nodes per lock).
        let cases = [
            (0u64, 1.0f32),
            (63, 2.0),
            (64, 3.0),
            (127, 4.0),
            (128, 5.0),
            (199, 6.0),
        ];
        for &(node, value) in &cases {
            accumulator.accumulate(node, &[value; 8]);
        }
        for &(node, value) in &cases {
            let grad = accumulator.get_grad(node);
            assert_eq!(grad[0], value);
            assert_eq!(grad[7], value);
        }

        // Untouched neighbours stay zero.
        assert_eq!(accumulator.get_grad(1)[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        // Set initial embedding
        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        // Accumulate gradient
        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        // Apply with learning rate 0.1
        accumulator.apply(0.1, &mut embeddings);

        // Check updated embedding: 10.0 - 0.1 * 1.0 = 9.9
        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        // Check gradients were zeroed
        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        // Spawn 10 threads, each accumulating 1.0 to node 0
        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Should have accumulated 10.0
        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4)); // 64 floats * 4 bytes
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), Some(0));
        assert_eq!(accumulator.grad_offset(1), Some(128 * 4)); // 128 floats * 4 bytes
        assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
    }

    #[test]
    fn test_grad_offset_out_of_bounds() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 16, 100).unwrap();

        assert_eq!(accumulator.grad_offset(99), Some(16 * 4 * 99));
        assert_eq!(accumulator.grad_offset(100), None);
        assert_eq!(accumulator.grad_offset(u64::MAX), None);
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        accumulator.accumulate(0, &wrong_size);
    }

    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        // Set some embeddings
        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        // Prefetch should not crash
        manager.prefetch(&[0, 1, 2, 3, 4]);

        // Access should still work
        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }

    #[test]
    fn test_prefetch_last_and_invalid_nodes() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        manager.set_embedding(99, &[7.0f32; 64]);

        // Invalid ids are skipped; the last row ends exactly at the mapping end.
        manager.prefetch(&[99, 100, u64::MAX]);

        assert_eq!(manager.get_embedding(99)[63], 7.0);
    }
}