1use memmap2::{MmapMut, MmapOptions};
8use std::fs::{File, OpenOptions};
9use std::io::Write;
10use std::path::PathBuf;
11use tempfile::TempDir;
12
13use sklears_core::{
14 error::{Result, SklearsError},
15 types::{Array2, Float},
16};
17
18use crate::simd_distances::{simd_distance, SimdDistanceMetric};
19
/// Settings for building a [`MemoryMappedDistanceMatrix`].
#[derive(Debug, Clone)]
pub struct MemoryMappedConfig {
    /// Parent directory for the temporary working directory. When `None`, the
    /// location chosen by `TempDir::new()` is used instead.
    pub temp_dir: Option<PathBuf>,
    /// Number of samples along each edge of a block when `compute_distances` tiles
    /// the pairwise computation. Must be non-zero.
    pub chunk_size: usize,
    /// Metric handed to `simd_distance` for every pair of rows.
    pub metric: SimdDistanceMetric,
    /// NOTE(review): not read by any code visible in this file; presumably reserved
    /// for compressing the backing file — confirm before relying on it.
    pub use_compression: bool,
    /// NOTE(review): not read by any code visible in this file; presumably meant to
    /// keep the temporary directory after the matrix is dropped — confirm.
    pub keep_temp_files: bool,
}
34
35impl Default for MemoryMappedConfig {
36 fn default() -> Self {
37 Self {
38 temp_dir: None,
39 chunk_size: 1000,
40 metric: SimdDistanceMetric::Euclidean,
41 use_compression: false,
42 keep_temp_files: false,
43 }
44 }
45}
46
47pub struct MemoryMappedDistanceMatrix {
49 config: MemoryMappedConfig,
51 n_samples: usize,
53 temp_dir: TempDir,
55 distance_file: File,
57 distance_mmap: Option<MmapMut>,
59 value_size: usize,
61}
62
63impl MemoryMappedDistanceMatrix {
64 pub fn new(n_samples: usize, config: MemoryMappedConfig) -> Result<Self> {
70 let temp_dir = if let Some(ref dir) = config.temp_dir {
72 TempDir::new_in(dir)
73 } else {
74 TempDir::new()
75 }
76 .map_err(|e| SklearsError::Other(format!("Failed to create temp directory: {}", e)))?;
77
78 let n_pairs = (n_samples * (n_samples - 1)) / 2;
81 let value_size = std::mem::size_of::<Float>();
82 let file_size = n_pairs * value_size;
83
84 let distance_file_path = temp_dir.path().join("distance_matrix.bin");
86 let distance_file = OpenOptions::new()
87 .read(true)
88 .write(true)
89 .create(true)
90 .truncate(true)
91 .open(&distance_file_path)
92 .map_err(|e| SklearsError::Other(format!("Failed to create distance file: {}", e)))?;
93
94 distance_file
96 .set_len(file_size as u64)
97 .map_err(|e| SklearsError::Other(format!("Failed to set file size: {}", e)))?;
98
99 Ok(Self {
100 config,
101 n_samples,
102 temp_dir,
103 distance_file,
104 distance_mmap: None,
105 value_size,
106 })
107 }
108
109 fn initialize_mmap(&mut self) -> Result<()> {
111 if self.distance_mmap.is_none() {
112 let mmap = unsafe {
113 MmapOptions::new()
114 .map_mut(&self.distance_file)
115 .map_err(|e| {
116 SklearsError::Other(format!("Failed to create memory map: {}", e))
117 })?
118 };
119 self.distance_mmap = Some(mmap);
120 }
121 Ok(())
122 }
123
124 pub fn compute_distances(&mut self, data: &Array2<Float>) -> Result<()> {
129 if data.nrows() != self.n_samples {
130 return Err(SklearsError::InvalidInput(format!(
131 "Data has {} samples but expected {}",
132 data.nrows(),
133 self.n_samples
134 )));
135 }
136
137 self.initialize_mmap()?;
138
139 let chunk_size = self.config.chunk_size;
140 let n_chunks = (self.n_samples + chunk_size - 1) / chunk_size;
141
142 for i_chunk in 0..n_chunks {
144 let i_start = i_chunk * chunk_size;
145 let i_end = (i_start + chunk_size).min(self.n_samples);
146
147 for j_chunk in i_chunk..n_chunks {
148 let j_start = j_chunk * chunk_size;
149 let j_end = (j_start + chunk_size).min(self.n_samples);
150
151 self.compute_chunk_distances(data, i_start, i_end, j_start, j_end)?;
153 }
154
155 if (i_chunk + 1) % 10 == 0 || i_chunk == n_chunks - 1 {
157 eprintln!("Processed chunk {} of {}", i_chunk + 1, n_chunks);
158 }
159 }
160
161 Ok(())
162 }
163
164 fn compute_chunk_distances(
166 &mut self,
167 data: &Array2<Float>,
168 i_start: usize,
169 i_end: usize,
170 j_start: usize,
171 j_end: usize,
172 ) -> Result<()> {
173 for i in i_start..i_end {
174 let j_min = if i_start == j_start { i + 1 } else { j_start };
175 for j in j_min..j_end {
176 if i < j {
177 let row_i = data.row(i);
178 let row_j = data.row(j);
179
180 let distance =
181 simd_distance(&row_i, &row_j, self.config.metric).map_err(|e| {
182 SklearsError::NumericalError(format!(
183 "SIMD distance computation failed: {}",
184 e
185 ))
186 })?;
187
188 self.set_distance(i, j, distance)?;
189 }
190 }
191 }
192 Ok(())
193 }
194
195 fn indices_to_linear(&self, i: usize, j: usize) -> usize {
197 assert!(i < j, "Only upper triangle is stored (i must be < j)");
198 assert!(i < self.n_samples && j < self.n_samples);
199
200 let n = self.n_samples;
202 i * n - (i * (i + 1)) / 2 + j - i - 1
203 }
204
205 fn set_distance(&mut self, i: usize, j: usize, distance: Float) -> Result<()> {
207 let linear_index = self.indices_to_linear(i, j);
208 let byte_offset = linear_index * self.value_size;
209
210 if let Some(ref mut mmap) = self.distance_mmap {
211 let bytes = distance.to_le_bytes();
212 let start = byte_offset;
213 let end = start + self.value_size;
214
215 if end <= mmap.len() {
216 mmap[start..end].copy_from_slice(&bytes);
217 } else {
218 return Err(SklearsError::Other(format!(
219 "Index out of bounds: {} >= {}",
220 end,
221 mmap.len()
222 )));
223 }
224 }
225 Ok(())
226 }
227
228 pub fn get_distance(&self, i: usize, j: usize) -> Result<Float> {
230 if i == j {
231 return Ok(0.0);
232 }
233
234 let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
235 let linear_index = self.indices_to_linear(min_idx, max_idx);
236 let byte_offset = linear_index * self.value_size;
237
238 if let Some(ref mmap) = self.distance_mmap {
239 let start = byte_offset;
240 let end = start + self.value_size;
241
242 if end <= mmap.len() {
243 let bytes: [u8; 8] = mmap[start..end].try_into().map_err(|_| {
244 SklearsError::Other("Failed to read distance bytes".to_string())
245 })?;
246 Ok(Float::from_le_bytes(bytes))
247 } else {
248 Err(SklearsError::Other(format!(
249 "Index out of bounds: {} >= {}",
250 end,
251 mmap.len()
252 )))
253 }
254 } else {
255 Err(SklearsError::Other(
256 "Memory map not initialized".to_string(),
257 ))
258 }
259 }
260
261 pub fn get_k_nearest_neighbors(
270 &self,
271 sample_idx: usize,
272 k: usize,
273 ) -> Result<Vec<(usize, Float)>> {
274 if sample_idx >= self.n_samples {
275 return Err(SklearsError::InvalidInput(format!(
276 "Sample index {} out of bounds (max: {})",
277 sample_idx,
278 self.n_samples - 1
279 )));
280 }
281
282 let mut neighbors = Vec::new();
283
284 for other_idx in 0..self.n_samples {
286 if other_idx != sample_idx {
287 let distance = self.get_distance(sample_idx, other_idx)?;
288 neighbors.push((other_idx, distance));
289 }
290 }
291
292 neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
294 neighbors.truncate(k);
295
296 Ok(neighbors)
297 }
298
299 pub fn get_neighbors_within_radius(
308 &self,
309 sample_idx: usize,
310 radius: Float,
311 ) -> Result<Vec<(usize, Float)>> {
312 if sample_idx >= self.n_samples {
313 return Err(SklearsError::InvalidInput(format!(
314 "Sample index {} out of bounds (max: {})",
315 sample_idx,
316 self.n_samples - 1
317 )));
318 }
319
320 let mut neighbors = Vec::new();
321
322 for other_idx in 0..self.n_samples {
323 if other_idx != sample_idx {
324 let distance = self.get_distance(sample_idx, other_idx)?;
325 if distance <= radius {
326 neighbors.push((other_idx, distance));
327 }
328 }
329 }
330
331 neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
333
334 Ok(neighbors)
335 }
336
337 pub fn memory_stats(&self) -> MemoryStats {
339 let n_pairs = (self.n_samples * (self.n_samples - 1)) / 2;
340 let matrix_size_bytes = n_pairs * self.value_size;
341 let matrix_size_mb = matrix_size_bytes as f64 / (1024.0 * 1024.0);
342 let matrix_size_gb = matrix_size_mb / 1024.0;
343
344 MemoryStats {
345 n_samples: self.n_samples,
346 n_pairs,
347 matrix_size_bytes,
348 matrix_size_mb,
349 matrix_size_gb,
350 temp_dir_path: self.temp_dir.path().to_path_buf(),
351 }
352 }
353
354 pub fn to_array(&self) -> Result<Array2<Float>> {
361 if self.n_samples > 10000 {
362 eprintln!("Warning: Converting large distance matrix ({} samples) to Array2. This may use significant memory.", self.n_samples);
363 }
364
365 let mut matrix = Array2::zeros((self.n_samples, self.n_samples));
366
367 for i in 0..self.n_samples {
368 for j in i + 1..self.n_samples {
369 let distance = self.get_distance(i, j)?;
370 matrix[[i, j]] = distance;
371 matrix[[j, i]] = distance; }
373 }
374
375 Ok(matrix)
376 }
377
378 pub fn sync(&mut self) -> Result<()> {
380 if let Some(ref mut mmap) = self.distance_mmap {
381 mmap.flush()
382 .map_err(|e| SklearsError::Other(format!("Failed to sync memory map: {}", e)))?;
383 }
384 Ok(())
385 }
386}
387
/// Size summary of a memory-mapped distance matrix, as returned by
/// `MemoryMappedDistanceMatrix::memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Number of samples covered by the matrix.
    pub n_samples: usize,
    /// Number of stored pairwise distances.
    pub n_pairs: usize,
    /// Storage needed for all pairs, in bytes.
    pub matrix_size_bytes: usize,
    /// `matrix_size_bytes` expressed in units of 1024 * 1024 bytes.
    pub matrix_size_mb: f64,
    /// `matrix_size_mb` divided by 1024.
    pub matrix_size_gb: f64,
    /// Path of the temporary directory that holds the backing file.
    pub temp_dir_path: PathBuf,
}

impl std::fmt::Display for MemoryStats {
    /// Renders a one-line summary with both size figures rounded to two decimals and
    /// the directory path in its `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            n_samples,
            n_pairs,
            matrix_size_mb,
            matrix_size_gb,
            temp_dir_path,
            ..
        } = self;

        f.write_fmt(format_args!(
            "MemoryStats {{ samples: {}, pairs: {}, size: {:.2} MB ({:.2} GB), temp: {:?} }}",
            n_samples, n_pairs, matrix_size_mb, matrix_size_gb, temp_dir_path
        ))
    }
}
418
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Allocates a matrix sized for every row of `data` and fills in all distances.
    fn filled(data: &Array2<Float>, config: MemoryMappedConfig) -> MemoryMappedDistanceMatrix {
        let mut matrix = MemoryMappedDistanceMatrix::new(data.nrows(), config).unwrap();
        matrix.compute_distances(data).unwrap();
        matrix
    }

    /// Extracts the sample indices from `(index, distance)` pairs, keeping their order.
    fn indices_of(neighbors: &[(usize, Float)]) -> Vec<usize> {
        neighbors.iter().map(|&(idx, _)| idx).collect()
    }

    /// Distances on the unit square are correct and symmetric when the data is split
    /// into several chunks.
    #[test]
    fn test_memory_mapped_small_dataset() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],];
        let matrix = filled(
            &data,
            MemoryMappedConfig {
                chunk_size: 2,
                ..Default::default()
            },
        );

        // Expected distance from the origin to each of the other three corners.
        let expected = [(1, 1.0), (2, 1.0), (3, 2.0_f64.sqrt())];
        for &(other, want) in expected.iter() {
            let got = matrix.get_distance(0, other).unwrap();
            assert!((got - want).abs() < 1e-6);
        }

        let forward = matrix.get_distance(0, 1).unwrap();
        let backward = matrix.get_distance(1, 0).unwrap();
        assert!((forward - backward).abs() < 1e-10);
    }

    /// The two nearest neighbours of the origin exclude the far-away point.
    #[test]
    fn test_k_nearest_neighbors() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0],];
        let matrix = filled(&data, MemoryMappedConfig::default());

        let neighbors = matrix.get_k_nearest_neighbors(0, 2).unwrap();
        assert_eq!(neighbors.len(), 2);

        let found = indices_of(&neighbors);
        assert!(found.contains(&1));
        assert!(found.contains(&2));
        assert!(!found.contains(&3));
    }

    /// A radius of 1.5 around the origin captures only the two unit-distance points.
    #[test]
    fn test_neighbors_within_radius() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0],];
        let matrix = filled(&data, MemoryMappedConfig::default());

        let neighbors = matrix.get_neighbors_within_radius(0, 1.5).unwrap();
        assert_eq!(neighbors.len(), 2);

        let found = indices_of(&neighbors);
        assert!(found.contains(&1));
        assert!(found.contains(&2));
        assert!(!found.contains(&3));
    }

    /// Statistics report the sample count, the pair count and a non-zero size.
    #[test]
    fn test_memory_stats() {
        let matrix =
            MemoryMappedDistanceMatrix::new(100, MemoryMappedConfig::default()).unwrap();

        let stats = matrix.memory_stats();
        assert_eq!(stats.n_samples, 100);
        assert_eq!(stats.n_pairs, (100 * 99) / 2);
        assert!(stats.matrix_size_bytes > 0);
        assert!(stats.matrix_size_mb > 0.0);
    }

    /// Upper-triangle coordinates map to consecutive row-major slots.
    #[test]
    fn test_indices_to_linear() {
        let matrix = MemoryMappedDistanceMatrix::new(5, MemoryMappedConfig::default()).unwrap();

        // ((i, j), expected slot) for a five-sample matrix.
        let cases = [
            ((0, 1), 0),
            ((0, 2), 1),
            ((0, 3), 2),
            ((1, 2), 4),
            ((1, 3), 5),
            ((2, 3), 7),
        ];
        for &((i, j), slot) in cases.iter() {
            assert_eq!(matrix.indices_to_linear(i, j), slot);
        }
    }

    /// The dense array has the right shape, a zero diagonal and is symmetric.
    #[test]
    fn test_to_array_conversion() {
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],];
        let matrix = filled(&data, MemoryMappedConfig::default());

        let dense = matrix.to_array().unwrap();
        assert_eq!(dense.shape(), &[3, 3]);

        for i in 0..3 {
            assert!((dense[[i, i]] - 0.0).abs() < 1e-10);
        }

        for i in 0..3 {
            for j in 0..3 {
                assert!((dense[[i, j]] - dense[[j, i]]).abs() < 1e-10);
            }
        }
    }
}