1use crate::error::{DatasetsError, Result};
8use crate::streaming::{DataChunk, StreamConfig};
9use crate::utils::Dataset;
10use memmap2::{Mmap, MmapOptions};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
12use std::fs::File;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15
/// Options controlling how large binary datasets are loaded lazily.
///
/// Start from [`Default`] (or `LazyLoadConfig::new()`) and adjust with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct LazyLoadConfig {
    /// Approximate number of bytes one loaded chunk should occupy; divided by
    /// the per-row byte size to pick a chunk length.
    pub target_memory_bytes: usize,
    /// Smallest number of rows a chunk may contain.
    pub min_chunk_size: usize,
    /// Largest number of rows a chunk may contain.
    pub max_chunk_size: usize,
    /// Whether to use memory mapping.
    /// NOTE(review): not read by the loaders in this module, which always map — confirm intent.
    pub use_mmap: bool,
    /// Page size hint in bytes (default 0).
    /// NOTE(review): not read in this module; presumably 0 means "system default" — confirm.
    pub page_size: usize,
    /// Whether to issue a "will need" read-ahead hint after mapping the file.
    pub prefetch: bool,
    /// Whether mapped pages should be locked in memory.
    /// NOTE(review): not read in this module — confirm intent.
    pub lock_pages: bool,
}
34
35impl Default for LazyLoadConfig {
36 fn default() -> Self {
37 Self {
38 target_memory_bytes: 512 * 1024 * 1024, min_chunk_size: 1000,
40 max_chunk_size: 100_000,
41 use_mmap: true,
42 page_size: 0,
43 prefetch: true,
44 lock_pages: false,
45 }
46 }
47}
48
49impl LazyLoadConfig {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn with_target_memory(mut self, bytes: usize) -> Self {
57 self.target_memory_bytes = bytes;
58 self
59 }
60
61 pub fn with_chunk_size_range(mut self, min: usize, max: usize) -> Self {
63 self.min_chunk_size = min;
64 self.max_chunk_size = max;
65 self
66 }
67
68 pub fn with_mmap(mut self, use_mmap: bool) -> Self {
70 self.use_mmap = use_mmap;
71 self
72 }
73
74 pub fn with_prefetch(mut self, prefetch: bool) -> Self {
76 self.prefetch = prefetch;
77 self
78 }
79}
80
81pub struct MmapDataset {
83 mmap: Arc<Mmap>,
84 n_samples: usize,
85 n_features: usize,
86 data_offset: usize,
87 element_size: usize,
88 config: LazyLoadConfig,
89}
90
91impl MmapDataset {
92 pub fn from_binary<P: AsRef<Path>>(
105 path: P,
106 n_samples: usize,
107 n_features: usize,
108 data_offset: usize,
109 config: LazyLoadConfig,
110 ) -> Result<Self> {
111 let file = File::open(path.as_ref()).map_err(DatasetsError::IoError)?;
112
113 let mut mmap_opts = MmapOptions::new();
114 if data_offset > 0 {
115 mmap_opts.offset(data_offset as u64);
116 }
117
118 let mmap = unsafe {
119 mmap_opts.map(&file).map_err(|e| {
120 DatasetsError::InvalidFormat(format!("Memory mapping failed: {}", e))
121 })?
122 };
123
124 if config.prefetch {
126 let _ = mmap.advise(memmap2::Advice::WillNeed);
127 }
128
129 Ok(Self {
130 mmap: Arc::new(mmap),
131 n_samples,
132 n_features,
133 data_offset,
134 element_size: std::mem::size_of::<f64>(),
135 config,
136 })
137 }
138
139 pub fn n_samples(&self) -> usize {
141 self.n_samples
142 }
143
144 pub fn n_features(&self) -> usize {
146 self.n_features
147 }
148
149 pub fn view_range(&self, start: usize, end: usize) -> Result<Array2<f64>> {
159 if start >= self.n_samples || end > self.n_samples || start >= end {
160 return Err(DatasetsError::InvalidFormat(format!(
161 "Invalid range [{}, {}) for dataset with {} samples",
162 start, end, self.n_samples
163 )));
164 }
165
166 let n_samples_in_range = end - start;
167 let start_byte = start * self.n_features * self.element_size;
168 let len_bytes = n_samples_in_range * self.n_features * self.element_size;
169
170 if start_byte + len_bytes > self.mmap.len() {
172 return Err(DatasetsError::InvalidFormat(
173 "Range exceeds available data".to_string(),
174 ));
175 }
176
177 let byte_slice = &self.mmap[start_byte..start_byte + len_bytes];
179 let (_, f64_slice, _) = unsafe { byte_slice.align_to::<f64>() };
180
181 let array =
183 Array2::from_shape_vec((n_samples_in_range, self.n_features), f64_slice.to_vec())
184 .map_err(|e| {
185 DatasetsError::InvalidFormat(format!("Array creation failed: {}", e))
186 })?;
187
188 Ok(array)
189 }
190
191 pub fn adaptive_chunk_size(&self) -> usize {
193 let bytes_per_sample = self.n_features * self.element_size;
194 let ideal_chunk = self.config.target_memory_bytes / bytes_per_sample;
195
196 ideal_chunk
198 .max(self.config.min_chunk_size)
199 .min(self.config.max_chunk_size)
200 .min(self.n_samples)
201 }
202
203 pub fn iter_chunks(&self) -> LazyChunkIterator {
205 let chunk_size = self.adaptive_chunk_size();
206 LazyChunkIterator {
207 dataset: self,
208 current_pos: 0,
209 chunk_size,
210 }
211 }
212}
213
214pub struct LazyChunkIterator<'a> {
216 dataset: &'a MmapDataset,
217 current_pos: usize,
218 chunk_size: usize,
219}
220
221impl<'a> Iterator for LazyChunkIterator<'a> {
222 type Item = Result<DataChunk>;
223
224 fn next(&mut self) -> Option<Self::Item> {
225 if self.current_pos >= self.dataset.n_samples {
226 return None;
227 }
228
229 let end = (self.current_pos + self.chunk_size).min(self.dataset.n_samples);
230 let chunk_idx = self.current_pos / self.chunk_size;
231
232 let result = self.dataset.view_range(self.current_pos, end).map(|data| {
233 let sample_indices: Vec<usize> = (self.current_pos..end).collect();
234 let is_last = end >= self.dataset.n_samples;
235
236 DataChunk {
237 data,
238 target: None,
239 chunk_index: chunk_idx,
240 sample_indices,
241 is_last,
242 }
243 });
244
245 self.current_pos = end;
246 Some(result)
247 }
248}
249
250pub struct LazyDataset {
252 path: PathBuf,
253 n_samples: usize,
254 n_features: usize,
255 data_offset: usize,
256 config: LazyLoadConfig,
257 mmap_dataset: Option<Arc<MmapDataset>>,
258}
259
260impl LazyDataset {
261 pub fn new<P: AsRef<Path>>(
263 path: P,
264 n_samples: usize,
265 n_features: usize,
266 data_offset: usize,
267 config: LazyLoadConfig,
268 ) -> Self {
269 Self {
270 path: path.as_ref().to_path_buf(),
271 n_samples,
272 n_features,
273 data_offset,
274 config,
275 mmap_dataset: None,
276 }
277 }
278
279 fn ensure_mapped(&mut self) -> Result<()> {
281 if self.mmap_dataset.is_none() {
282 let mmap = MmapDataset::from_binary(
283 &self.path,
284 self.n_samples,
285 self.n_features,
286 self.data_offset,
287 self.config.clone(),
288 )?;
289 self.mmap_dataset = Some(Arc::new(mmap));
290 }
291 Ok(())
292 }
293
294 pub fn load_range(&mut self, start: usize, end: usize) -> Result<Array2<f64>> {
296 self.ensure_mapped()?;
297 self.mmap_dataset
298 .as_ref()
299 .ok_or_else(|| DatasetsError::InvalidFormat("Dataset not mapped".to_string()))?
300 .view_range(start, end)
301 }
302
303 pub fn load_all(&mut self) -> Result<Dataset> {
305 let data = self.load_range(0, self.n_samples)?;
306 Ok(Dataset {
307 data,
308 target: None,
309 targetnames: None,
310 featurenames: None,
311 feature_descriptions: None,
312 description: None,
313 metadata: Default::default(),
314 })
315 }
316
317 pub fn shape(&self) -> (usize, usize) {
319 (self.n_samples, self.n_features)
320 }
321}
322
323pub fn from_binary<P: AsRef<Path>>(
334 path: P,
335 n_samples: usize,
336 n_features: usize,
337) -> Result<LazyDataset> {
338 Ok(LazyDataset::new(
339 path,
340 n_samples,
341 n_features,
342 0,
343 LazyLoadConfig::default(),
344 ))
345}
346
347pub fn from_binary_with_config<P: AsRef<Path>>(
349 path: P,
350 n_samples: usize,
351 n_features: usize,
352 config: LazyLoadConfig,
353) -> Result<LazyDataset> {
354 Ok(LazyDataset::new(path, n_samples, n_features, 0, config))
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a temporary directory, mapping the failure into `DatasetsError`.
    fn temp_dir() -> Result<tempfile::TempDir> {
        tempfile::tempdir().map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to create temp dir: {}", e))
        })
    }

    /// Writes `values` as raw native-endian `f64` bytes, the layout
    /// `MmapDataset::from_binary` reads, without any `unsafe` reinterpretation.
    fn write_f64_file(path: &Path, values: &[f64]) -> Result<()> {
        let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f64>());
        for value in values {
            bytes.extend_from_slice(&value.to_ne_bytes());
        }
        std::fs::write(path, bytes).map_err(DatasetsError::IoError)
    }

    #[test]
    fn test_lazy_load_config() {
        let config = LazyLoadConfig::new()
            .with_target_memory(256 * 1024 * 1024)
            .with_chunk_size_range(500, 50_000)
            .with_mmap(true)
            .with_prefetch(false);

        assert_eq!(config.target_memory_bytes, 256 * 1024 * 1024);
        assert_eq!(config.min_chunk_size, 500);
        assert_eq!(config.max_chunk_size, 50_000);
        assert!(config.use_mmap);
        assert!(!config.prefetch);
    }

    #[test]
    fn test_mmap_dataset() -> Result<()> {
        let dir = temp_dir()?;
        let file_path = dir.path().join("test_data.bin");
        let data: Vec<f64> = (0..30).map(f64::from).collect();
        write_f64_file(&file_path, &data)?;

        let config = LazyLoadConfig::default();
        let mmap_ds = MmapDataset::from_binary(&file_path, 10, 3, 0, config)?;

        assert_eq!(mmap_ds.n_samples(), 10);
        assert_eq!(mmap_ds.n_features(), 3);

        let view = mmap_ds.view_range(0, 3)?;
        assert_eq!(view.nrows(), 3);
        assert_eq!(view.ncols(), 3);
        assert_eq!(view[[0, 0]], 0.0);
        assert_eq!(view[[2, 2]], 8.0);

        Ok(())
    }

    #[test]
    fn test_adaptive_chunking() -> Result<()> {
        let dir = temp_dir()?;
        let file_path = dir.path().join("test_adaptive.bin");
        let data: Vec<f64> = (0..10_000).map(f64::from).collect();
        write_f64_file(&file_path, &data)?;

        // 8000 bytes / (10 features * 8 bytes) = 100 rows, inside [10, 200].
        let config = LazyLoadConfig::new()
            .with_target_memory(8000)
            .with_chunk_size_range(10, 200);

        let mmap_ds = MmapDataset::from_binary(&file_path, 1000, 10, 0, config)?;
        assert_eq!(mmap_ds.adaptive_chunk_size(), 100);

        Ok(())
    }

    #[test]
    fn test_iter_chunks_cover_all_samples() -> Result<()> {
        let dir = temp_dir()?;
        let file_path = dir.path().join("test_chunks.bin");
        let data: Vec<f64> = (0..30).map(f64::from).collect();
        write_f64_file(&file_path, &data)?;

        // The default memory budget dwarfs 10 rows, so max_chunk_size (4) decides.
        let config = LazyLoadConfig::new().with_chunk_size_range(1, 4);
        let mmap_ds = MmapDataset::from_binary(&file_path, 10, 3, 0, config)?;
        assert_eq!(mmap_ds.adaptive_chunk_size(), 4);

        let chunks = mmap_ds.iter_chunks().collect::<Result<Vec<_>>>()?;
        let rows: Vec<usize> = chunks.iter().map(|c| c.data.nrows()).collect();
        assert_eq!(rows, vec![4, 4, 2]);
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.chunk_index, i);
            assert_eq!(chunk.is_last, i == 2);
        }
        assert_eq!(chunks[1].sample_indices, vec![4, 5, 6, 7]);
        // Row 4, column 0 holds value 4 * 3 = 12.
        assert_eq!(chunks[1].data[[0, 0]], 12.0);

        Ok(())
    }
}