sklears_clustering/
memory_mapped.rs

1//! Memory-mapped distance matrix computation for large datasets
2//!
3//! This module provides memory-mapped distance matrix computation that allows
4//! processing of datasets larger than available RAM by storing intermediate
5//! results on disk and accessing them through memory mapping.
6
7use memmap2::{MmapMut, MmapOptions};
8use std::fs::{File, OpenOptions};
9use std::io::Write;
10use std::path::PathBuf;
11use tempfile::TempDir;
12
13use sklears_core::{
14    error::{Result, SklearsError},
15    types::{Array2, Float},
16};
17
18use crate::simd_distances::{simd_distance, SimdDistanceMetric};
19
20/// Configuration for memory-mapped distance computation
21#[derive(Debug, Clone)]
22pub struct MemoryMappedConfig {
23    /// Directory for temporary files (None for system temp)
24    pub temp_dir: Option<PathBuf>,
25    /// Chunk size for processing (number of samples per chunk)
26    pub chunk_size: usize,
27    /// Distance metric to use
28    pub metric: SimdDistanceMetric,
29    /// Whether to use compression for temporary files
30    pub use_compression: bool,
31    /// Whether to keep temporary files for debugging
32    pub keep_temp_files: bool,
33}
34
35impl Default for MemoryMappedConfig {
36    fn default() -> Self {
37        Self {
38            temp_dir: None,
39            chunk_size: 1000,
40            metric: SimdDistanceMetric::Euclidean,
41            use_compression: false,
42            keep_temp_files: false,
43        }
44    }
45}
46
47/// Memory-mapped distance matrix for large-scale distance computation
48pub struct MemoryMappedDistanceMatrix {
49    /// Configuration
50    config: MemoryMappedConfig,
51    /// Number of samples
52    n_samples: usize,
53    /// Temporary directory
54    temp_dir: TempDir,
55    /// Memory-mapped file for distance matrix
56    distance_file: File,
57    /// Memory map of the distance matrix
58    distance_mmap: Option<MmapMut>,
59    /// Size of each distance value in bytes
60    value_size: usize,
61}
62
63impl MemoryMappedDistanceMatrix {
64    /// Create a new memory-mapped distance matrix
65    ///
66    /// # Arguments
67    /// * `n_samples` - Number of samples in the dataset
68    /// * `config` - Configuration for memory mapping
69    pub fn new(n_samples: usize, config: MemoryMappedConfig) -> Result<Self> {
70        // Create temporary directory
71        let temp_dir = if let Some(ref dir) = config.temp_dir {
72            TempDir::new_in(dir)
73        } else {
74            TempDir::new()
75        }
76        .map_err(|e| SklearsError::Other(format!("Failed to create temp directory: {}", e)))?;
77
78        // Calculate required file size for distance matrix
79        // We store only the upper triangle since distance matrices are symmetric
80        let n_pairs = (n_samples * (n_samples - 1)) / 2;
81        let value_size = std::mem::size_of::<Float>();
82        let file_size = n_pairs * value_size;
83
84        // Create memory-mapped file
85        let distance_file_path = temp_dir.path().join("distance_matrix.bin");
86        let distance_file = OpenOptions::new()
87            .read(true)
88            .write(true)
89            .create(true)
90            .truncate(true)
91            .open(&distance_file_path)
92            .map_err(|e| SklearsError::Other(format!("Failed to create distance file: {}", e)))?;
93
94        // Set file size
95        distance_file
96            .set_len(file_size as u64)
97            .map_err(|e| SklearsError::Other(format!("Failed to set file size: {}", e)))?;
98
99        Ok(Self {
100            config,
101            n_samples,
102            temp_dir,
103            distance_file,
104            distance_mmap: None,
105            value_size,
106        })
107    }
108
109    /// Initialize the memory map
110    fn initialize_mmap(&mut self) -> Result<()> {
111        if self.distance_mmap.is_none() {
112            let mmap = unsafe {
113                MmapOptions::new()
114                    .map_mut(&self.distance_file)
115                    .map_err(|e| {
116                        SklearsError::Other(format!("Failed to create memory map: {}", e))
117                    })?
118            };
119            self.distance_mmap = Some(mmap);
120        }
121        Ok(())
122    }
123
124    /// Compute distance matrix in chunks and store in memory-mapped file
125    ///
126    /// # Arguments
127    /// * `data` - Input data matrix (n_samples × n_features)
128    pub fn compute_distances(&mut self, data: &Array2<Float>) -> Result<()> {
129        if data.nrows() != self.n_samples {
130            return Err(SklearsError::InvalidInput(format!(
131                "Data has {} samples but expected {}",
132                data.nrows(),
133                self.n_samples
134            )));
135        }
136
137        self.initialize_mmap()?;
138
139        let chunk_size = self.config.chunk_size;
140        let n_chunks = (self.n_samples + chunk_size - 1) / chunk_size;
141
142        // Process data in chunks to manage memory usage
143        for i_chunk in 0..n_chunks {
144            let i_start = i_chunk * chunk_size;
145            let i_end = (i_start + chunk_size).min(self.n_samples);
146
147            for j_chunk in i_chunk..n_chunks {
148                let j_start = j_chunk * chunk_size;
149                let j_end = (j_start + chunk_size).min(self.n_samples);
150
151                // Compute distances between chunk i and chunk j
152                self.compute_chunk_distances(data, i_start, i_end, j_start, j_end)?;
153            }
154
155            // Log progress
156            if (i_chunk + 1) % 10 == 0 || i_chunk == n_chunks - 1 {
157                eprintln!("Processed chunk {} of {}", i_chunk + 1, n_chunks);
158            }
159        }
160
161        Ok(())
162    }
163
164    /// Compute distances between two chunks
165    fn compute_chunk_distances(
166        &mut self,
167        data: &Array2<Float>,
168        i_start: usize,
169        i_end: usize,
170        j_start: usize,
171        j_end: usize,
172    ) -> Result<()> {
173        for i in i_start..i_end {
174            let j_min = if i_start == j_start { i + 1 } else { j_start };
175            for j in j_min..j_end {
176                if i < j {
177                    let row_i = data.row(i);
178                    let row_j = data.row(j);
179
180                    let distance =
181                        simd_distance(&row_i, &row_j, self.config.metric).map_err(|e| {
182                            SklearsError::NumericalError(format!(
183                                "SIMD distance computation failed: {}",
184                                e
185                            ))
186                        })?;
187
188                    self.set_distance(i, j, distance)?;
189                }
190            }
191        }
192        Ok(())
193    }
194
195    /// Convert (i, j) indices to linear index in upper triangular storage
196    fn indices_to_linear(&self, i: usize, j: usize) -> usize {
197        assert!(i < j, "Only upper triangle is stored (i must be < j)");
198        assert!(i < self.n_samples && j < self.n_samples);
199
200        // Formula for upper triangular matrix indexing
201        let n = self.n_samples;
202        i * n - (i * (i + 1)) / 2 + j - i - 1
203    }
204
205    /// Set distance value at position (i, j)
206    fn set_distance(&mut self, i: usize, j: usize, distance: Float) -> Result<()> {
207        let linear_index = self.indices_to_linear(i, j);
208        let byte_offset = linear_index * self.value_size;
209
210        if let Some(ref mut mmap) = self.distance_mmap {
211            let bytes = distance.to_le_bytes();
212            let start = byte_offset;
213            let end = start + self.value_size;
214
215            if end <= mmap.len() {
216                mmap[start..end].copy_from_slice(&bytes);
217            } else {
218                return Err(SklearsError::Other(format!(
219                    "Index out of bounds: {} >= {}",
220                    end,
221                    mmap.len()
222                )));
223            }
224        }
225        Ok(())
226    }
227
228    /// Get distance value at position (i, j)
229    pub fn get_distance(&self, i: usize, j: usize) -> Result<Float> {
230        if i == j {
231            return Ok(0.0);
232        }
233
234        let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
235        let linear_index = self.indices_to_linear(min_idx, max_idx);
236        let byte_offset = linear_index * self.value_size;
237
238        if let Some(ref mmap) = self.distance_mmap {
239            let start = byte_offset;
240            let end = start + self.value_size;
241
242            if end <= mmap.len() {
243                let bytes: [u8; 8] = mmap[start..end].try_into().map_err(|_| {
244                    SklearsError::Other("Failed to read distance bytes".to_string())
245                })?;
246                Ok(Float::from_le_bytes(bytes))
247            } else {
248                Err(SklearsError::Other(format!(
249                    "Index out of bounds: {} >= {}",
250                    end,
251                    mmap.len()
252                )))
253            }
254        } else {
255            Err(SklearsError::Other(
256                "Memory map not initialized".to_string(),
257            ))
258        }
259    }
260
261    /// Get k-nearest neighbors for a specific sample
262    ///
263    /// # Arguments
264    /// * `sample_idx` - Index of the sample
265    /// * `k` - Number of nearest neighbors to find
266    ///
267    /// # Returns
268    /// Vector of (neighbor_index, distance) pairs sorted by distance
269    pub fn get_k_nearest_neighbors(
270        &self,
271        sample_idx: usize,
272        k: usize,
273    ) -> Result<Vec<(usize, Float)>> {
274        if sample_idx >= self.n_samples {
275            return Err(SklearsError::InvalidInput(format!(
276                "Sample index {} out of bounds (max: {})",
277                sample_idx,
278                self.n_samples - 1
279            )));
280        }
281
282        let mut neighbors = Vec::new();
283
284        // Collect all distances for this sample
285        for other_idx in 0..self.n_samples {
286            if other_idx != sample_idx {
287                let distance = self.get_distance(sample_idx, other_idx)?;
288                neighbors.push((other_idx, distance));
289            }
290        }
291
292        // Sort by distance and take k nearest
293        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
294        neighbors.truncate(k);
295
296        Ok(neighbors)
297    }
298
299    /// Get all neighbors within a specific radius
300    ///
301    /// # Arguments
302    /// * `sample_idx` - Index of the sample
303    /// * `radius` - Maximum distance for neighbors
304    ///
305    /// # Returns
306    /// Vector of (neighbor_index, distance) pairs within radius
307    pub fn get_neighbors_within_radius(
308        &self,
309        sample_idx: usize,
310        radius: Float,
311    ) -> Result<Vec<(usize, Float)>> {
312        if sample_idx >= self.n_samples {
313            return Err(SklearsError::InvalidInput(format!(
314                "Sample index {} out of bounds (max: {})",
315                sample_idx,
316                self.n_samples - 1
317            )));
318        }
319
320        let mut neighbors = Vec::new();
321
322        for other_idx in 0..self.n_samples {
323            if other_idx != sample_idx {
324                let distance = self.get_distance(sample_idx, other_idx)?;
325                if distance <= radius {
326                    neighbors.push((other_idx, distance));
327                }
328            }
329        }
330
331        // Sort by distance
332        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
333
334        Ok(neighbors)
335    }
336
337    /// Get memory usage statistics
338    pub fn memory_stats(&self) -> MemoryStats {
339        let n_pairs = (self.n_samples * (self.n_samples - 1)) / 2;
340        let matrix_size_bytes = n_pairs * self.value_size;
341        let matrix_size_mb = matrix_size_bytes as f64 / (1024.0 * 1024.0);
342        let matrix_size_gb = matrix_size_mb / 1024.0;
343
344        MemoryStats {
345            n_samples: self.n_samples,
346            n_pairs,
347            matrix_size_bytes,
348            matrix_size_mb,
349            matrix_size_gb,
350            temp_dir_path: self.temp_dir.path().to_path_buf(),
351        }
352    }
353
354    /// Export distance matrix to a standard format
355    ///
356    /// This method exports the distance matrix as a regular Array2 for smaller datasets
357    /// or when you need to integrate with other algorithms that expect in-memory matrices.
358    ///
359    /// Warning: This will load the entire distance matrix into memory.
360    pub fn to_array(&self) -> Result<Array2<Float>> {
361        if self.n_samples > 10000 {
362            eprintln!("Warning: Converting large distance matrix ({} samples) to Array2. This may use significant memory.", self.n_samples);
363        }
364
365        let mut matrix = Array2::zeros((self.n_samples, self.n_samples));
366
367        for i in 0..self.n_samples {
368            for j in i + 1..self.n_samples {
369                let distance = self.get_distance(i, j)?;
370                matrix[[i, j]] = distance;
371                matrix[[j, i]] = distance; // Symmetric
372            }
373        }
374
375        Ok(matrix)
376    }
377
378    /// Flush any pending writes to disk
379    pub fn sync(&mut self) -> Result<()> {
380        if let Some(ref mut mmap) = self.distance_mmap {
381            mmap.flush()
382                .map_err(|e| SklearsError::Other(format!("Failed to sync memory map: {}", e)))?;
383        }
384        Ok(())
385    }
386}
387
/// Summary of the storage used by a memory-mapped distance matrix.
#[derive(Clone, Debug)]
pub struct MemoryStats {
    /// How many samples the matrix covers
    pub n_samples: usize,
    /// How many unordered sample pairs are stored (strict upper triangle)
    pub n_pairs: usize,
    /// Storage needed for the stored pairs, in bytes
    pub matrix_size_bytes: usize,
    /// Same storage in megabytes (bytes / 1024²)
    pub matrix_size_mb: f64,
    /// Same storage in gigabytes (bytes / 1024³)
    pub matrix_size_gb: f64,
    /// Location of the temporary directory backing the matrix
    pub temp_dir_path: PathBuf,
}
404
405impl std::fmt::Display for MemoryStats {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        write!(
408            f,
409            "MemoryStats {{ samples: {}, pairs: {}, size: {:.2} MB ({:.2} GB), temp: {:?} }}",
410            self.n_samples,
411            self.n_pairs,
412            self.matrix_size_mb,
413            self.matrix_size_gb,
414            self.temp_dir_path
415        )
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_memory_mapped_small_dataset() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],];

        // chunk_size 2 splits the 4 samples into two tiles, so both diagonal
        // tiles and the off-diagonal tile are exercised.
        let config = MemoryMappedConfig {
            chunk_size: 2,
            ..Default::default()
        };

        let mut mmap_matrix = MemoryMappedDistanceMatrix::new(4, config).unwrap();
        mmap_matrix.compute_distances(&data).unwrap();

        let sqrt2 = 2.0_f64.sqrt();
        // Every stored pair with its expected Euclidean distance.
        let expected = [
            (0, 1, 1.0),
            (0, 2, 1.0),
            (0, 3, sqrt2),
            (1, 2, sqrt2),
            (1, 3, 1.0),
            (2, 3, 1.0),
        ];
        for &(i, j, want) in &expected {
            let got = mmap_matrix.get_distance(i, j).unwrap();
            assert!(
                (got - want).abs() < 1e-6,
                "d({}, {}) = {} but expected {}",
                i,
                j,
                got,
                want
            );

            // Symmetry
            let back = mmap_matrix.get_distance(j, i).unwrap();
            assert!((got - back).abs() < 1e-10);
        }
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0],];

        let config = MemoryMappedConfig::default();
        let mut mmap_matrix = MemoryMappedDistanceMatrix::new(4, config).unwrap();
        mmap_matrix.compute_distances(&data).unwrap();

        // Get 2 nearest neighbors of point 0
        let neighbors = mmap_matrix.get_k_nearest_neighbors(0, 2).unwrap();

        assert_eq!(neighbors.len(), 2);

        // Points 1 and 2 should be the nearest to point 0
        let neighbor_indices: Vec<usize> = neighbors.iter().map(|&(idx, _)| idx).collect();
        assert!(neighbor_indices.contains(&1));
        assert!(neighbor_indices.contains(&2));
        assert!(!neighbor_indices.contains(&3)); // Point 3 is farthest

        // Results are sorted by ascending distance, and both are at distance 1.
        assert!(neighbors[0].1 <= neighbors[1].1);
        for &(_, distance) in &neighbors {
            assert!((distance - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_neighbors_within_radius() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0],];

        let config = MemoryMappedConfig::default();
        let mut mmap_matrix = MemoryMappedDistanceMatrix::new(4, config).unwrap();
        mmap_matrix.compute_distances(&data).unwrap();

        // Get neighbors within radius 1.5 of point 0
        let neighbors = mmap_matrix.get_neighbors_within_radius(0, 1.5).unwrap();

        // Should include points 1 and 2 but not 3
        assert_eq!(neighbors.len(), 2);
        let neighbor_indices: Vec<usize> = neighbors.iter().map(|&(idx, _)| idx).collect();
        assert!(neighbor_indices.contains(&1));
        assert!(neighbor_indices.contains(&2));
        assert!(!neighbor_indices.contains(&3));

        // Every reported neighbour really lies within the radius.
        assert!(neighbors.iter().all(|&(_, distance)| distance <= 1.5));
    }

    #[test]
    fn test_memory_stats() {
        let config = MemoryMappedConfig::default();
        let mmap_matrix = MemoryMappedDistanceMatrix::new(100, config).unwrap();

        let stats = mmap_matrix.memory_stats();
        assert_eq!(stats.n_samples, 100);
        assert_eq!(stats.n_pairs, (100 * 99) / 2); // Upper triangle
        assert!(stats.matrix_size_bytes > 0);
        assert!(stats.matrix_size_mb > 0.0);
    }

    #[test]
    fn test_indices_to_linear() {
        let config = MemoryMappedConfig::default();
        let mmap_matrix = MemoryMappedDistanceMatrix::new(5, config).unwrap();

        // For n=5, upper triangle indices should be:
        // Row 0: (0,1)=0, (0,2)=1, (0,3)=2, (0,4)=3
        // Row 1: (1,2)=4, (1,3)=5, (1,4)=6
        // Row 2: (2,3)=7, (2,4)=8
        // Row 3: (3,4)=9
        assert_eq!(mmap_matrix.indices_to_linear(0, 1), 0);
        assert_eq!(mmap_matrix.indices_to_linear(0, 2), 1);
        assert_eq!(mmap_matrix.indices_to_linear(0, 3), 2);
        assert_eq!(mmap_matrix.indices_to_linear(1, 2), 4);
        assert_eq!(mmap_matrix.indices_to_linear(1, 3), 5);
        assert_eq!(mmap_matrix.indices_to_linear(2, 3), 7);
    }

    #[test]
    fn test_to_array_conversion() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],];

        let config = MemoryMappedConfig::default();
        let mut mmap_matrix = MemoryMappedDistanceMatrix::new(3, config).unwrap();
        mmap_matrix.compute_distances(&data).unwrap();

        let array_matrix = mmap_matrix.to_array().unwrap();

        assert_eq!(array_matrix.shape(), &[3, 3]);

        // Check diagonal is zero
        for i in 0..3 {
            assert!(array_matrix[[i, i]].abs() < 1e-10);
        }

        // Check symmetry
        for i in 0..3 {
            for j in 0..3 {
                assert!((array_matrix[[i, j]] - array_matrix[[j, i]]).abs() < 1e-10);
            }
        }

        // Check the off-diagonal values themselves
        let sqrt2 = 2.0_f64.sqrt();
        assert!((array_matrix[[0, 1]] - 1.0).abs() < 1e-6);
        assert!((array_matrix[[0, 2]] - 1.0).abs() < 1e-6);
        assert!((array_matrix[[1, 2]] - sqrt2).abs() < 1e-6);
    }
}