// ruvector_gnn/mmap.rs
//! Memory-mapped embedding management for large-scale GNN training.
//!
//! This module provides efficient memory-mapped access to embeddings and gradients
//! that don't fit in RAM. It includes:
//! - `MmapManager`: Memory-mapped embedding storage with dirty tracking
//! - `MmapGradientAccumulator`: Concurrent gradient accumulation with fine-grained locking
//! - `AtomicBitmap`: Thread-safe bitmap for access/dirty tracking
//!
//! Only available on non-WASM targets.

#![cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]

use crate::error::{GnnError, Result};
use memmap2::{MmapMut, MmapOptions};
use parking_lot::RwLock;
use std::fs::{File, OpenOptions};
use std::io::{self, Write};
use std::path::Path;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
20
/// Thread-safe bitmap using atomic operations.
///
/// Used for tracking which embeddings have been accessed or modified.
/// Each bit represents one embedding node.
#[derive(Debug)]
pub struct AtomicBitmap {
    /// Backing storage: one `AtomicU64` word per 64 bits
    bits: Vec<AtomicU64>,
    /// Number of addressable bits (nodes)
    size: usize,
}

impl AtomicBitmap {
    /// Create a new atomic bitmap with the specified capacity, all bits zero.
    ///
    /// # Arguments
    /// * `size` - Number of bits to allocate
    pub fn new(size: usize) -> Self {
        let words = (size + 63) / 64;
        let mut bits = Vec::with_capacity(words);
        bits.resize_with(words, || AtomicU64::new(0));
        Self { bits, size }
    }

    /// Locate the word index and single-bit mask for a bit index.
    #[inline]
    fn locate(index: usize) -> (usize, u64) {
        (index / 64, 1u64 << (index % 64))
    }

    /// Set a bit to 1 (mark as accessed/dirty).
    ///
    /// Out-of-range indices are silently ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to set
    pub fn set(&self, index: usize) {
        if index < self.size {
            let (word, mask) = Self::locate(index);
            self.bits[word].fetch_or(mask, Ordering::Release);
        }
    }

    /// Clear a bit to 0 (mark as clean/not accessed).
    ///
    /// Out-of-range indices are silently ignored.
    ///
    /// # Arguments
    /// * `index` - Bit index to clear
    pub fn clear(&self, index: usize) {
        if index < self.size {
            let (word, mask) = Self::locate(index);
            self.bits[word].fetch_and(!mask, Ordering::Release);
        }
    }

    /// Check if a bit is set.
    ///
    /// # Arguments
    /// * `index` - Bit index to check
    ///
    /// # Returns
    /// `true` if the bit is set; out-of-range indices read as `false`
    pub fn get(&self, index: usize) -> bool {
        if index >= self.size {
            return false;
        }
        let (word, mask) = Self::locate(index);
        self.bits[word].load(Ordering::Acquire) & mask != 0
    }

    /// Clear all bits in the bitmap.
    pub fn clear_all(&self) {
        self.bits
            .iter()
            .for_each(|word| word.store(0, Ordering::Release));
    }

    /// Get all set bit indices (for finding dirty pages).
    ///
    /// # Returns
    /// Vector of indices where bits are set, in ascending order
    pub fn get_set_indices(&self) -> Vec<usize> {
        let mut out = Vec::new();
        for (word_idx, word) in self.bits.iter().enumerate() {
            let mut remaining = word.load(Ordering::Acquire);
            while remaining != 0 {
                out.push(word_idx * 64 + remaining.trailing_zeros() as usize);
                remaining &= remaining - 1; // drop the lowest set bit
            }
        }
        out
    }
}
112
/// Memory-mapped embedding manager with dirty tracking and prefetching.
///
/// Manages large embedding matrices that may not fit in RAM using memory-mapped files.
/// Tracks which embeddings have been accessed and modified for efficient I/O.
#[derive(Debug)]
pub struct MmapManager {
    /// Backing file; kept open so the mapping stays valid for `Self`'s lifetime
    file: File,
    /// Mutable memory mapping over `file`
    mmap: MmapMut,
    /// Operating system page size in bytes (used for page-aligned prefetch)
    page_size: usize,
    /// Embedding dimension (number of f32 elements per node)
    d_embed: usize,
    /// Bitmap tracking which embeddings have been accessed
    access_bitmap: AtomicBitmap,
    /// Bitmap tracking which embeddings have been modified since the last flush
    dirty_bitmap: AtomicBitmap,
    /// Pin count for each page (intended to prevent eviction)
    /// NOTE(review): only initialized in `new`; no code in this module reads it yet
    pin_count: Vec<AtomicU32>,
    /// Maximum number of nodes this mapping can hold
    max_nodes: usize,
}
136
137impl MmapManager {
138    /// Create a new memory-mapped embedding manager.
139    ///
140    /// # Arguments
141    /// * `path` - Path to the memory-mapped file
142    /// * `d_embed` - Embedding dimension
143    /// * `max_nodes` - Maximum number of nodes to support
144    ///
145    /// # Returns
146    /// A new `MmapManager` instance
147    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
148        // Calculate required file size
149        let embedding_size = d_embed * std::mem::size_of::<f32>();
150        let file_size = max_nodes * embedding_size;
151
152        // Create or open the file
153        let file = OpenOptions::new()
154            .read(true)
155            .write(true)
156            .create(true)
157            .open(path)
158            .map_err(|e| GnnError::mmap(format!("Failed to open mmap file: {}", e)))?;
159
160        // Set file size
161        file.set_len(file_size as u64)
162            .map_err(|e| GnnError::mmap(format!("Failed to set file size: {}", e)))?;
163
164        // Create memory mapping
165        let mmap = unsafe {
166            MmapOptions::new()
167                .len(file_size)
168                .map_mut(&file)
169                .map_err(|e| GnnError::mmap(format!("Failed to create mmap: {}", e)))?
170        };
171
172        // Get system page size
173        let page_size = page_size::get();
174        let num_pages = (file_size + page_size - 1) / page_size;
175
176        Ok(Self {
177            file,
178            mmap,
179            page_size,
180            d_embed,
181            access_bitmap: AtomicBitmap::new(max_nodes),
182            dirty_bitmap: AtomicBitmap::new(max_nodes),
183            pin_count: (0..num_pages).map(|_| AtomicU32::new(0)).collect(),
184            max_nodes,
185        })
186    }
187
188    /// Calculate the byte offset for a given node's embedding.
189    ///
190    /// # Arguments
191    /// * `node_id` - Node identifier
192    ///
193    /// # Returns
194    /// Byte offset in the memory-mapped file, or None if overflow would occur
195    ///
196    /// # Security
197    /// Uses checked arithmetic to prevent integer overflow attacks.
198    #[inline]
199    pub fn embedding_offset(&self, node_id: u64) -> Option<usize> {
200        let node_idx = usize::try_from(node_id).ok()?;
201        let elem_size = std::mem::size_of::<f32>();
202        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
203    }
204
205    /// Validate that a node_id is within bounds.
206    #[inline]
207    fn validate_node_id(&self, node_id: u64) -> bool {
208        (node_id as usize) < self.max_nodes
209    }
210
211    /// Get a read-only reference to a node's embedding.
212    ///
213    /// # Arguments
214    /// * `node_id` - Node identifier
215    ///
216    /// # Returns
217    /// Slice containing the embedding vector
218    ///
219    /// # Panics
220    /// Panics if node_id is out of bounds or would cause overflow
221    pub fn get_embedding(&self, node_id: u64) -> &[f32] {
222        // Security: Validate bounds before any pointer arithmetic
223        assert!(
224            self.validate_node_id(node_id),
225            "node_id {} out of bounds (max: {})",
226            node_id,
227            self.max_nodes
228        );
229
230        let offset = self
231            .embedding_offset(node_id)
232            .expect("embedding offset calculation overflow");
233        let end = offset
234            .checked_add(
235                self.d_embed
236                    .checked_mul(std::mem::size_of::<f32>())
237                    .unwrap(),
238            )
239            .expect("end offset overflow");
240        assert!(
241            end <= self.mmap.len(),
242            "embedding extends beyond mmap bounds"
243        );
244
245        // Mark as accessed
246        self.access_bitmap.set(node_id as usize);
247
248        // Safety: We control the offset and know the data is properly aligned
249        unsafe {
250            let ptr = self.mmap.as_ptr().add(offset) as *const f32;
251            std::slice::from_raw_parts(ptr, self.d_embed)
252        }
253    }
254
255    /// Set a node's embedding data.
256    ///
257    /// # Arguments
258    /// * `node_id` - Node identifier
259    /// * `data` - Embedding vector to write
260    ///
261    /// # Panics
262    /// Panics if node_id is out of bounds, data length doesn't match d_embed,
263    /// or offset calculation would overflow.
264    pub fn set_embedding(&mut self, node_id: u64, data: &[f32]) {
265        // Security: Validate bounds first
266        assert!(
267            self.validate_node_id(node_id),
268            "node_id {} out of bounds (max: {})",
269            node_id,
270            self.max_nodes
271        );
272        assert_eq!(
273            data.len(),
274            self.d_embed,
275            "Embedding data length must match d_embed"
276        );
277
278        let offset = self
279            .embedding_offset(node_id)
280            .expect("embedding offset calculation overflow");
281        let end = offset
282            .checked_add(data.len().checked_mul(std::mem::size_of::<f32>()).unwrap())
283            .expect("end offset overflow");
284        assert!(
285            end <= self.mmap.len(),
286            "embedding extends beyond mmap bounds"
287        );
288
289        // Mark as accessed and dirty
290        self.access_bitmap.set(node_id as usize);
291        self.dirty_bitmap.set(node_id as usize);
292
293        // Safety: We control the offset and know the data is properly aligned
294        unsafe {
295            let ptr = self.mmap.as_mut_ptr().add(offset) as *mut f32;
296            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, self.d_embed);
297        }
298    }
299
300    /// Flush all dirty pages to disk.
301    ///
302    /// # Returns
303    /// `Ok(())` on success, error otherwise
304    pub fn flush_dirty(&self) -> io::Result<()> {
305        let dirty_nodes = self.dirty_bitmap.get_set_indices();
306
307        if dirty_nodes.is_empty() {
308            return Ok(());
309        }
310
311        // Flush the entire mmap for simplicity
312        // In a production system, you might want to flush only dirty pages
313        self.mmap.flush()?;
314
315        // Clear dirty bitmap after successful flush
316        for &node_id in &dirty_nodes {
317            self.dirty_bitmap.clear(node_id);
318        }
319
320        Ok(())
321    }
322
323    /// Prefetch embeddings into memory for better cache locality.
324    ///
325    /// # Arguments
326    /// * `node_ids` - List of node IDs to prefetch
327    pub fn prefetch(&self, node_ids: &[u64]) {
328        #[cfg(target_os = "linux")]
329        {
330            #[allow(unused_imports)]
331            use std::os::unix::io::AsRawFd;
332
333            for &node_id in node_ids {
334                // Skip invalid node IDs
335                if !self.validate_node_id(node_id) {
336                    continue;
337                }
338                let offset = match self.embedding_offset(node_id) {
339                    Some(o) => o,
340                    None => continue,
341                };
342                let page_offset = (offset / self.page_size) * self.page_size;
343                let length = self.d_embed * std::mem::size_of::<f32>();
344
345                unsafe {
346                    // Use madvise to hint the kernel to prefetch
347                    libc::madvise(
348                        self.mmap.as_ptr().add(page_offset) as *mut libc::c_void,
349                        length,
350                        libc::MADV_WILLNEED,
351                    );
352                }
353            }
354        }
355
356        // On non-Linux platforms, just access the data to bring it into cache
357        #[cfg(not(target_os = "linux"))]
358        {
359            for &node_id in node_ids {
360                if self.validate_node_id(node_id) {
361                    let _ = self.get_embedding(node_id);
362                }
363            }
364        }
365    }
366
367    /// Get the embedding dimension.
368    pub fn d_embed(&self) -> usize {
369        self.d_embed
370    }
371
372    /// Get the maximum number of nodes.
373    pub fn max_nodes(&self) -> usize {
374        self.max_nodes
375    }
376}
377
/// Memory-mapped gradient accumulator with fine-grained locking.
///
/// Allows multiple threads to accumulate gradients concurrently with minimal contention.
/// Uses reader-writer locks at a configurable granularity.
pub struct MmapGradientAccumulator {
    /// Memory-mapped gradient storage; `UnsafeCell` provides interior
    /// mutability, with all access serialized through `locks`
    grad_mmap: std::cell::UnsafeCell<MmapMut>,
    /// Number of consecutive nodes guarded by each lock
    lock_granularity: usize,
    /// Reader-writer locks, one per `lock_granularity`-sized node region
    locks: Vec<RwLock<()>>,
    /// Number of nodes
    n_nodes: usize,
    /// Embedding dimension (f32 elements per node)
    d_embed: usize,
    /// Backing file; kept open so the mapping stays valid for `Self`'s lifetime
    _file: File,
}
396
397impl MmapGradientAccumulator {
398    /// Create a new memory-mapped gradient accumulator.
399    ///
400    /// # Arguments
401    /// * `path` - Path to the gradient file
402    /// * `d_embed` - Embedding dimension
403    /// * `max_nodes` - Maximum number of nodes
404    ///
405    /// # Returns
406    /// A new `MmapGradientAccumulator` instance
407    pub fn new(path: &Path, d_embed: usize, max_nodes: usize) -> Result<Self> {
408        // Calculate required file size
409        let grad_size = d_embed * std::mem::size_of::<f32>();
410        let file_size = max_nodes * grad_size;
411
412        // Create or open the file
413        let file = OpenOptions::new()
414            .read(true)
415            .write(true)
416            .create(true)
417            .open(path)
418            .map_err(|e| GnnError::mmap(format!("Failed to open gradient file: {}", e)))?;
419
420        // Set file size
421        file.set_len(file_size as u64)
422            .map_err(|e| GnnError::mmap(format!("Failed to set gradient file size: {}", e)))?;
423
424        // Create memory mapping
425        let grad_mmap = unsafe {
426            MmapOptions::new()
427                .len(file_size)
428                .map_mut(&file)
429                .map_err(|e| GnnError::mmap(format!("Failed to create gradient mmap: {}", e)))?
430        };
431
432        // Zero out the gradients
433        for byte in grad_mmap.iter() {
434            // This forces the pages to be allocated and zeroed
435            let _ = byte;
436        }
437
438        // Use a lock granularity of 64 nodes per lock for good parallelism
439        let lock_granularity = 64;
440        let num_locks = (max_nodes + lock_granularity - 1) / lock_granularity;
441        let locks = (0..num_locks).map(|_| RwLock::new(())).collect();
442
443        Ok(Self {
444            grad_mmap: std::cell::UnsafeCell::new(grad_mmap),
445            lock_granularity,
446            locks,
447            n_nodes: max_nodes,
448            d_embed,
449            _file: file,
450        })
451    }
452
453    /// Calculate the byte offset for a node's gradient.
454    ///
455    /// # Arguments
456    /// * `node_id` - Node identifier
457    ///
458    /// # Returns
459    /// Byte offset in the gradient file, or None on overflow or out-of-bounds
460    ///
461    /// # Security
462    /// Uses checked arithmetic to prevent integer overflow (SEC-001).
463    #[inline]
464    pub fn grad_offset(&self, node_id: u64) -> Option<usize> {
465        let node_idx = usize::try_from(node_id).ok()?;
466        if node_idx >= self.n_nodes {
467            return None;
468        }
469        let elem_size = std::mem::size_of::<f32>();
470        node_idx.checked_mul(self.d_embed)?.checked_mul(elem_size)
471    }
472
473    /// Accumulate gradients for a specific node.
474    ///
475    /// # Arguments
476    /// * `node_id` - Node identifier
477    /// * `grad` - Gradient vector to accumulate
478    ///
479    /// # Panics
480    /// Panics if grad length doesn't match d_embed
481    pub fn accumulate(&self, node_id: u64, grad: &[f32]) {
482        assert_eq!(
483            grad.len(),
484            self.d_embed,
485            "Gradient length must match d_embed"
486        );
487
488        let offset = self.grad_offset(node_id)
489            .expect("node_id out of bounds or offset overflow");
490
491        let lock_idx = (node_id as usize) / self.lock_granularity;
492        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
493        let _lock = self.locks[lock_idx].write();
494
495        // Safety: We validated node_id bounds and offset above, and hold the write lock
496        unsafe {
497            let mmap = &mut *self.grad_mmap.get();
498            assert!(offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
499                "gradient write would exceed mmap bounds");
500            let ptr = mmap.as_mut_ptr().add(offset) as *mut f32;
501            let grad_slice = std::slice::from_raw_parts_mut(ptr, self.d_embed);
502
503            // Accumulate gradients
504            for (g, &new_g) in grad_slice.iter_mut().zip(grad.iter()) {
505                *g += new_g;
506            }
507        }
508    }
509
510    /// Apply accumulated gradients to embeddings and zero out gradients.
511    ///
512    /// # Arguments
513    /// * `learning_rate` - Learning rate for gradient descent
514    /// * `embeddings` - Embedding manager to update
515    pub fn apply(&mut self, learning_rate: f32, embeddings: &mut MmapManager) {
516        assert_eq!(
517            self.d_embed, embeddings.d_embed,
518            "Gradient and embedding dimensions must match"
519        );
520
521        // Process all nodes
522        for node_id in 0..self.n_nodes.min(embeddings.max_nodes) {
523            let grad = self.get_grad(node_id as u64);
524            let embedding = embeddings.get_embedding(node_id as u64);
525
526            // Apply gradient descent: embedding -= learning_rate * grad
527            let mut updated = vec![0.0f32; self.d_embed];
528            for i in 0..self.d_embed {
529                updated[i] = embedding[i] - learning_rate * grad[i];
530            }
531
532            embeddings.set_embedding(node_id as u64, &updated);
533        }
534
535        // Zero out gradients after applying
536        self.zero_grad();
537    }
538
539    /// Zero out all accumulated gradients.
540    pub fn zero_grad(&mut self) {
541        // Zero the entire gradient buffer
542        unsafe {
543            let mmap = &mut *self.grad_mmap.get();
544            for byte in mmap.iter_mut() {
545                *byte = 0;
546            }
547        }
548    }
549
550    /// Get a read-only reference to a node's accumulated gradient.
551    ///
552    /// # Arguments
553    /// * `node_id` - Node identifier
554    ///
555    /// # Returns
556    /// Slice containing the gradient vector
557    pub fn get_grad(&self, node_id: u64) -> &[f32] {
558        let offset = self.grad_offset(node_id)
559            .expect("node_id out of bounds or offset overflow");
560
561        let lock_idx = (node_id as usize) / self.lock_granularity;
562        assert!(lock_idx < self.locks.len(), "lock index out of bounds");
563        let _lock = self.locks[lock_idx].read();
564
565        // Safety: We validated node_id bounds and offset above, and hold the read lock
566        unsafe {
567            let mmap = &*self.grad_mmap.get();
568            assert!(offset + self.d_embed * std::mem::size_of::<f32>() <= mmap.len(),
569                "gradient read would exceed mmap bounds");
570            let ptr = mmap.as_ptr().add(offset) as *const f32;
571            std::slice::from_raw_parts(ptr, self.d_embed)
572        }
573    }
574
575    /// Get the embedding dimension.
576    pub fn d_embed(&self) -> usize {
577        self.d_embed
578    }
579
580    /// Get the number of nodes.
581    pub fn n_nodes(&self) -> usize {
582        self.n_nodes
583    }
584}
585
// Implement Drop to ensure proper cleanup
impl Drop for MmapManager {
    fn drop(&mut self) {
        // Best-effort flush of dirty pages; the error is intentionally ignored
        // because Drop has no way to propagate it.
        let _ = self.flush_dirty();
    }
}
593
impl Drop for MmapGradientAccumulator {
    fn drop(&mut self) {
        // Best-effort flush of accumulated gradients; Drop cannot report errors.
        // SAFETY: `&mut self` guarantees exclusive access, so dereferencing the
        // UnsafeCell contents cannot race with other accessors.
        unsafe {
            let mmap = &mut *self.grad_mmap.get();
            let _ = mmap.flush();
        }
    }
}
603
// SAFETY: MmapGradientAccumulator can be sent and shared between threads
// because all access to the inner `UnsafeCell<MmapMut>` goes through the
// per-region RwLocks (see `accumulate`/`get_grad`) or `&mut self`.
// NOTE(review): `get_grad` returns a slice tied to `&self` that outlives its
// read-lock guard; holding that slice while another thread calls `accumulate`
// on the same region would race — verify callers' usage before relying on Sync.
unsafe impl Send for MmapGradientAccumulator {}
unsafe impl Sync for MmapGradientAccumulator {}
608
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: removed unused `use std::fs;` — nothing in this module references fs::
    use tempfile::TempDir;

    #[test]
    fn test_atomic_bitmap_basic() {
        let bitmap = AtomicBitmap::new(128);

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(127));

        bitmap.set(0);
        bitmap.set(127);
        bitmap.set(64);

        assert!(bitmap.get(0));
        assert!(bitmap.get(127));
        assert!(bitmap.get(64));
        assert!(!bitmap.get(1));

        bitmap.clear(0);
        assert!(!bitmap.get(0));
        assert!(bitmap.get(127));
    }

    #[test]
    fn test_atomic_bitmap_get_set_indices() {
        let bitmap = AtomicBitmap::new(256);

        bitmap.set(0);
        bitmap.set(63);
        bitmap.set(64);
        bitmap.set(128);
        bitmap.set(255);

        let mut indices = bitmap.get_set_indices();
        indices.sort();

        assert_eq!(indices, vec![0, 63, 64, 128, 255]);
    }

    #[test]
    fn test_atomic_bitmap_clear_all() {
        let bitmap = AtomicBitmap::new(128);

        bitmap.set(0);
        bitmap.set(64);
        bitmap.set(127);

        assert!(bitmap.get(0));

        bitmap.clear_all();

        assert!(!bitmap.get(0));
        assert!(!bitmap.get(64));
        assert!(!bitmap.get(127));
    }

    #[test]
    fn test_mmap_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 128, 1000).unwrap();

        assert_eq!(manager.d_embed(), 128);
        assert_eq!(manager.max_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_mmap_manager_set_get_embedding() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![1.0f32; 64];
        manager.set_embedding(0, &embedding);

        let retrieved = manager.get_embedding(0);
        assert_eq!(retrieved.len(), 64);
        assert_eq!(retrieved[0], 1.0);
        assert_eq!(retrieved[63], 1.0);
    }

    #[test]
    fn test_mmap_manager_multiple_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 32, 100).unwrap();

        for i in 0..10 {
            let embedding: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32).collect();
            manager.set_embedding(i, &embedding);
        }

        // Verify each embedding
        for i in 0..10 {
            let retrieved = manager.get_embedding(i);
            assert_eq!(retrieved.len(), 32);
            assert_eq!(retrieved[0], (i * 32) as f32);
            assert_eq!(retrieved[31], (i * 32 + 31) as f32);
        }
    }

    #[test]
    fn test_mmap_manager_dirty_tracking() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        let embedding = vec![2.0f32; 64];
        manager.set_embedding(5, &embedding);

        // Should be marked as dirty
        assert!(manager.dirty_bitmap.get(5));

        // Flush and check it's clean
        manager.flush_dirty().unwrap();
        assert!(!manager.dirty_bitmap.get(5));
    }

    #[test]
    fn test_mmap_manager_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        {
            let mut manager = MmapManager::new(&path, 64, 100).unwrap();
            let embedding = vec![3.14f32; 64];
            manager.set_embedding(10, &embedding);
            manager.flush_dirty().unwrap();
        }

        // Reopen and verify data persisted
        {
            let manager = MmapManager::new(&path, 64, 100).unwrap();
            let retrieved = manager.get_embedding(10);
            assert_eq!(retrieved[0], 3.14);
            assert_eq!(retrieved[63], 3.14);
        }
    }

    #[test]
    fn test_gradient_accumulator_creation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 1000).unwrap();

        assert_eq!(accumulator.d_embed(), 128);
        assert_eq!(accumulator.n_nodes(), 1000);
        assert!(path.exists());
    }

    #[test]
    fn test_gradient_accumulator_accumulate() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad1 = vec![1.0f32; 64];
        let grad2 = vec![2.0f32; 64];

        accumulator.accumulate(0, &grad1);
        accumulator.accumulate(0, &grad2);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 3.0);
        assert_eq!(accumulated[63], 3.0);
    }

    #[test]
    fn test_gradient_accumulator_zero_grad() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let mut accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();

        let grad = vec![1.5f32; 64];
        accumulator.accumulate(0, &grad);

        let accumulated = accumulator.get_grad(0);
        assert_eq!(accumulated[0], 1.5);

        accumulator.zero_grad();

        let zeroed = accumulator.get_grad(0);
        assert_eq!(zeroed[0], 0.0);
        assert_eq!(zeroed[63], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_apply() {
        let temp_dir = TempDir::new().unwrap();
        let embed_path = temp_dir.path().join("embeddings.bin");
        let grad_path = temp_dir.path().join("gradients.bin");

        let mut embeddings = MmapManager::new(&embed_path, 32, 100).unwrap();
        let mut accumulator = MmapGradientAccumulator::new(&grad_path, 32, 100).unwrap();

        // Set initial embedding
        let initial = vec![10.0f32; 32];
        embeddings.set_embedding(0, &initial);

        // Accumulate gradient
        let grad = vec![1.0f32; 32];
        accumulator.accumulate(0, &grad);

        // Apply with learning rate 0.1
        accumulator.apply(0.1, &mut embeddings);

        // Check updated embedding: 10.0 - 0.1 * 1.0 = 9.9
        let updated = embeddings.get_embedding(0);
        assert!((updated[0] - 9.9).abs() < 1e-6);

        // Check gradients were zeroed
        let zeroed_grad = accumulator.get_grad(0);
        assert_eq!(zeroed_grad[0], 0.0);
    }

    #[test]
    fn test_gradient_accumulator_concurrent_accumulation() {
        use std::thread;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator =
            std::sync::Arc::new(MmapGradientAccumulator::new(&path, 64, 100).unwrap());

        let mut handles = vec![];

        // Spawn 10 threads, each accumulating 1.0 to node 0
        for _ in 0..10 {
            let acc = accumulator.clone();
            let handle = thread::spawn(move || {
                let grad = vec![1.0f32; 64];
                acc.accumulate(0, &grad);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Should have accumulated 10.0
        let result = accumulator.get_grad(0);
        assert_eq!(result[0], 10.0);
    }

    #[test]
    fn test_embedding_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let manager = MmapManager::new(&path, 64, 100).unwrap();

        assert_eq!(manager.embedding_offset(0), Some(0));
        assert_eq!(manager.embedding_offset(1), Some(64 * 4)); // 64 floats * 4 bytes
        assert_eq!(manager.embedding_offset(10), Some(64 * 4 * 10));
    }

    #[test]
    fn test_grad_offset_calculation() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 128, 100).unwrap();

        assert_eq!(accumulator.grad_offset(0), Some(0));
        assert_eq!(accumulator.grad_offset(1), Some(128 * 4)); // 128 floats * 4 bytes
        assert_eq!(accumulator.grad_offset(5), Some(128 * 4 * 5));
    }

    #[test]
    #[should_panic(expected = "Embedding data length must match d_embed")]
    fn test_set_embedding_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        manager.set_embedding(0, &wrong_size);
    }

    #[test]
    #[should_panic(expected = "Gradient length must match d_embed")]
    fn test_accumulate_wrong_size() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("gradients.bin");

        let accumulator = MmapGradientAccumulator::new(&path, 64, 100).unwrap();
        let wrong_size = vec![1.0f32; 32]; // Should be 64
        accumulator.accumulate(0, &wrong_size);
    }

    #[test]
    fn test_prefetch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("embeddings.bin");

        let mut manager = MmapManager::new(&path, 64, 100).unwrap();

        // Set some embeddings
        for i in 0..10 {
            let embedding = vec![i as f32; 64];
            manager.set_embedding(i, &embedding);
        }

        // Prefetch should not crash
        manager.prefetch(&[0, 1, 2, 3, 4]);

        // Access should still work
        let retrieved = manager.get_embedding(2);
        assert_eq!(retrieved[0], 2.0);
    }
}