use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::Rng;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

use crate::error::{Result, TransformError};
use crate::normalize::NormalizationMethod;

/// Configuration for out-of-core (chunked) processing.
#[derive(Debug, Clone)]
pub struct OutOfCoreConfig {
    /// Maximum size of each processing chunk, in megabytes.
    pub chunk_size_mb: usize,
    /// Whether to use memory-mapped I/O where available.
    pub use_mmap: bool,
    /// Number of worker threads to use.
    pub n_threads: usize,
    /// Directory in which temporary files are created.
    pub temp_dir: String,
}

impl Default for OutOfCoreConfig {
    fn default() -> Self {
        OutOfCoreConfig {
            chunk_size_mb: 100,
            use_mmap: true,
            n_threads: num_cpus::get(),
            temp_dir: std::env::temp_dir().to_string_lossy().to_string(),
        }
    }
}

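/// A transformer that can be fitted to, and applied over, a stream of
/// data chunks instead of a single in-memory array.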
pub trait OutOfCoreTransformer: Send + Sync {
    /// Fits the transformer to data supplied as an iterator of chunks.
    fn fit_chunks<I>(&mut self, chunks: I) -> Result<()>
    where
        I: Iterator<Item = Result<Array2<f64>>>;

    /// Transforms chunked data, writing the result through a chunked file writer.
    fn transform_chunks<I>(&self, chunks: I) -> Result<ChunkedArrayWriter>
    where
        I: Iterator<Item = Result<Array2<f64>>>;

    /// Returns the output shape produced for a given input shape.
    fn get_transformshape(&self, inputshape: (usize, usize)) -> (usize, usize);
}

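/// Reads a two-dimensional `f64` array from a raw little-endian binary
/// file in row-major chunks of up to `chunk_size` rows. A minimal usage
/// sketch (the file name and shape are hypothetical; the file must hold
/// exactly `rows * cols * 8` bytes of little-endian `f64` data):
///
/// ```ignore
/// let reader = ChunkedArrayReader::new("data.bin", (1_000, 8), 100)?;
/// for chunk in reader.chunks() {
///     let chunk = chunk?;
///     assert_eq!(chunk.ncols(), 8);
/// }
/// ```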
pub struct ChunkedArrayReader {
    file: BufReader<File>,
    shape: (usize, usize),
    chunk_size: usize,
    current_row: usize,
    dtype_size: usize,
    /// Reusable read buffer, sized for one full chunk.
    buffer_pool: Vec<u8>,
}

impl ChunkedArrayReader {
    /// Opens `path` for chunked reading of an array with the given shape.
    pub fn new<P: AsRef<Path>>(path: P, shape: (usize, usize), chunk_size: usize) -> Result<Self> {
        let file = File::open(&path).map_err(|e| {
            TransformError::TransformationError(format!("Failed to open file: {e}"))
        })?;

        // Pre-allocate a buffer large enough for the biggest possible chunk.
        let max_chunk_bytes = chunk_size * shape.1 * std::mem::size_of::<f64>();

        Ok(ChunkedArrayReader {
            file: BufReader::new(file),
            shape,
            chunk_size,
            current_row: 0,
            dtype_size: std::mem::size_of::<f64>(),
            buffer_pool: vec![0u8; max_chunk_bytes],
        })
    }

    /// Reads the next chunk, or returns `Ok(None)` once all rows are consumed.
    pub fn read_chunk(&mut self) -> Result<Option<Array2<f64>>> {
        if self.current_row >= self.shape.0 {
            return Ok(None);
        }

        // The final chunk may be shorter than `chunk_size`.
        let rows_to_read = self.chunk_size.min(self.shape.0 - self.current_row);
        let mut chunk = Array2::zeros((rows_to_read, self.shape.1));

        let total_elements = rows_to_read * self.shape.1;
        let total_bytes = total_elements * self.dtype_size;

        if self.buffer_pool.len() < total_bytes {
            return Err(TransformError::TransformationError(
                "Buffer pool too small for chunk".to_string(),
            ));
        }

        self.file
            .read_exact(&mut self.buffer_pool[..total_bytes])
            .map_err(|e| {
                TransformError::TransformationError(format!("Failed to read data: {e}"))
            })?;

        // Decode little-endian f64 values in row-major order.
        for (element_idx, f64_bytes) in self.buffer_pool[..total_bytes].chunks_exact(8).enumerate()
        {
            let i = element_idx / self.shape.1;
            let j = element_idx % self.shape.1;

            let mut bytes_array = [0u8; 8];
            bytes_array.copy_from_slice(f64_bytes);
            chunk[[i, j]] = f64::from_le_bytes(bytes_array);
        }

        self.current_row += rows_to_read;
        Ok(Some(chunk))
    }

    /// Consumes the reader, yielding an iterator over the remaining chunks.
    pub fn chunks(self) -> ChunkedArrayIterator {
        ChunkedArrayIterator { reader: self }
    }
}

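/// Iterator adapter over the chunks of a [`ChunkedArrayReader`].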
pub struct ChunkedArrayIterator {
    reader: ChunkedArrayReader,
}

impl Iterator for ChunkedArrayIterator {
    type Item = Result<Array2<f64>>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.reader.read_chunk() {
            Ok(Some(chunk)) => Some(Ok(chunk)),
            Ok(None) => None,
            Err(e) => Some(Err(e)),
        }
    }
}

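/// Writes a two-dimensional `f64` array to a raw little-endian binary
/// file one chunk at a time. A minimal round-trip sketch (the output
/// path is hypothetical):
///
/// ```ignore
/// let mut writer = ChunkedArrayWriter::new("/tmp/out.bin", (4, 2))?;
/// writer.write_chunk(&Array2::zeros((4, 2)))?;
/// let path = writer.finalize()?; // errors unless exactly 4 rows were written
/// ```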
#[derive(Debug)]
pub struct ChunkedArrayWriter {
    file: BufWriter<File>,
    shape: (usize, usize),
    rows_written: usize,
    path: String,
    /// Reusable encode buffer, cleared before each chunk is written.
    write_buffer: Vec<u8>,
}

impl ChunkedArrayWriter {
    /// Creates `path` for chunked writing of an array with the given shape.
    pub fn new<P: AsRef<Path>>(path: P, shape: (usize, usize)) -> Result<Self> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        let file = File::create(&path).map_err(|e| {
            TransformError::TransformationError(format!("Failed to create file: {e}"))
        })?;

        // Size the encode buffer for a typical chunk of up to 1000 rows.
        let typical_chunk_size = 1000_usize.min(shape.0);
        let buffer_capacity = typical_chunk_size * shape.1 * std::mem::size_of::<f64>();

        Ok(ChunkedArrayWriter {
            file: BufWriter::new(file),
            shape,
            rows_written: 0,
            path: path_str,
            write_buffer: Vec::with_capacity(buffer_capacity),
        })
    }

    /// Appends one chunk; its column count must match the declared shape.
    pub fn write_chunk(&mut self, chunk: &Array2<f64>) -> Result<()> {
        if chunk.shape()[1] != self.shape.1 {
            return Err(TransformError::InvalidInput(format!(
                "Chunk has {} columns, expected {}",
                chunk.shape()[1],
                self.shape.1
            )));
        }

        if self.rows_written + chunk.shape()[0] > self.shape.0 {
            return Err(TransformError::InvalidInput(
                "Too many rows written".to_string(),
            ));
        }

        let total_elements = chunk.shape()[0] * chunk.shape()[1];
        let total_bytes = total_elements * std::mem::size_of::<f64>();

        self.write_buffer.clear();
        self.write_buffer.reserve(total_bytes);

        // Encode the chunk as little-endian f64 values in row-major order.
        for i in 0..chunk.shape()[0] {
            for j in 0..chunk.shape()[1] {
                let bytes = chunk[[i, j]].to_le_bytes();
                self.write_buffer.extend_from_slice(&bytes);
            }
        }

        self.file.write_all(&self.write_buffer).map_err(|e| {
            TransformError::TransformationError(format!("Failed to write data: {e}"))
        })?;

        self.rows_written += chunk.shape()[0];
        Ok(())
    }

    /// Flushes buffered data and returns the output path, verifying that
    /// exactly the declared number of rows was written.
    pub fn finalize(mut self) -> Result<String> {
        self.file.flush().map_err(|e| {
            TransformError::TransformationError(format!("Failed to flush data: {e}"))
        })?;

        if self.rows_written != self.shape.0 {
            return Err(TransformError::InvalidInput(format!(
                "Expected {} rows, but wrote {}",
                self.shape.0, self.rows_written
            )));
        }

        Ok(self.path)
    }
}

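/// Out-of-core normalizer that fits per-feature statistics over streamed
/// chunks, then transforms further chunks to a temporary binary file.
/// A minimal sketch of the fit/transform flow (`make_chunks` is a
/// hypothetical helper that yields the same chunk stream twice; the
/// output writer is sized to the fitted row count, so `finalize` checks
/// that the same number of rows was transformed):
///
/// ```ignore
/// let mut normalizer = OutOfCoreNormalizer::new(NormalizationMethod::ZScore);
/// normalizer.fit_chunks(make_chunks()?)?;
/// let writer = normalizer.transform_chunks(make_chunks()?)?;
/// let output_path = writer.finalize()?;
/// ```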
pub struct OutOfCoreNormalizer {
    /// The normalization method to apply.
    method: NormalizationMethod,
    /// Per-feature statistics, populated by `fit_chunks`.
    stats: Option<NormalizationStats>,
}

/// Per-feature statistics gathered during fitting.
#[derive(Clone)]
struct NormalizationStats {
    min: Array1<f64>,
    max: Array1<f64>,
    mean: Array1<f64>,
    std: Array1<f64>,
    median: Array1<f64>,
    iqr: Array1<f64>,
    count: usize,
}

impl OutOfCoreNormalizer {
    /// Creates a new, unfitted normalizer for the given method.
    pub fn new(method: NormalizationMethod) -> Self {
        OutOfCoreNormalizer {
            method,
            stats: None,
        }
    }

    /// Computes per-feature min/max/mean/std in a single streaming pass.
    fn compute_simple_stats<I>(&mut self, chunks: I, nfeatures: usize) -> Result<()>
    where
        I: Iterator<Item = Result<Array2<f64>>>,
    {
        let mut min = Array1::from_elem(nfeatures, f64::INFINITY);
        let mut max = Array1::from_elem(nfeatures, f64::NEG_INFINITY);
        let mut sum = Array1::zeros(nfeatures);
        let mut sum_sq = Array1::zeros(nfeatures);
        let mut count = 0;

        for chunk_result in chunks {
            let chunk = chunk_result?;
            count += chunk.shape()[0];

            for j in 0..nfeatures {
                let col = chunk.column(j);
                for &val in col.iter() {
                    min[j] = min[j].min(val);
                    max[j] = max[j].max(val);
                    sum[j] += val;
                    sum_sq[j] += val * val;
                }
            }
        }

        // Population variance via E[x^2] - E[x]^2.
        let mean = sum / count as f64;
        let variance = sum_sq / count as f64 - &mean * &mean;
        let std = variance.mapv(|v: f64| v.sqrt());

        self.stats = Some(NormalizationStats {
            min,
            max,
            mean,
            std,
            // Median and IQR are not needed for the simple methods.
            median: Array1::zeros(nfeatures),
            iqr: Array1::zeros(nfeatures),
            count,
        });

        Ok(())
    }

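    /// Estimates per-feature median and IQR from a uniform reservoir
    /// sample (Algorithm R): once a reservoir is full, each new value
    /// replaces a random slot with probability `RESERVOIR_SIZE / count`.
    /// The core step, sketched with the rand-crate idiom `gen_range`
    /// (the code below approximates it with a scaled `random::<f64>()`):
    ///
    /// ```ignore
    /// if reservoir.len() < RESERVOIR_SIZE {
    ///     reservoir.push(val);
    /// } else {
    ///     let k = rng.gen_range(0..count); // uniform in [0, count)
    ///     if k < RESERVOIR_SIZE {
    ///         reservoir[k] = val;
    ///     }
    /// }
    /// ```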
    fn compute_robust_stats<I>(&mut self, chunks: I, nfeatures: usize) -> Result<()>
    where
        I: Iterator<Item = Result<Array2<f64>>>,
    {
        // One reservoir of sampled values per feature.
        const RESERVOIR_SIZE: usize = 10000;
        let mut reservoirs: Vec<Vec<f64>> = vec![Vec::with_capacity(RESERVOIR_SIZE); nfeatures];
        let mut count = 0;
        let mut rng = scirs2_core::random::rng();

        for chunk_result in chunks {
            let chunk = chunk_result?;

            for i in 0..chunk.shape()[0] {
                count += 1;

                for j in 0..nfeatures {
                    let val = chunk[[i, j]];

                    if reservoirs[j].len() < RESERVOIR_SIZE {
                        reservoirs[j].push(val);
                    } else {
                        // Replace a random slot with probability RESERVOIR_SIZE / count.
                        let k = (count as f64 * rng.random::<f64>()) as usize;
                        if k < RESERVOIR_SIZE {
                            reservoirs[j][k] = val;
                        }
                    }
                }
            }
        }

        // Derive median and IQR from each sorted reservoir.
        let mut median = Array1::zeros(nfeatures);
        let mut iqr = Array1::zeros(nfeatures);

        for j in 0..nfeatures {
            if !reservoirs[j].is_empty() {
                reservoirs[j].sort_by(|a, b| a.partial_cmp(b).unwrap());
                let len = reservoirs[j].len();

                median[j] = if len.is_multiple_of(2) {
                    (reservoirs[j][len / 2 - 1] + reservoirs[j][len / 2]) / 2.0
                } else {
                    reservoirs[j][len / 2]
                };

                // First quartile.
                let q1_idx = len / 4;
                let q1 = reservoirs[j][q1_idx];

                // Third quartile.
                let q3_idx = 3 * len / 4;
                let q3 = reservoirs[j][q3_idx.min(len - 1)];

                iqr[j] = q3 - q1;

                // Guard against a degenerate (near-constant) feature.
                if iqr[j] < 1e-10 {
                    iqr[j] = 1.0;
                }
            } else {
                median[j] = 0.0;
                iqr[j] = 1.0;
            }
        }

        self.stats = Some(NormalizationStats {
            // Min/max/mean/std are not needed for robust scaling.
            min: Array1::zeros(nfeatures),
            max: Array1::zeros(nfeatures),
            mean: Array1::zeros(nfeatures),
            std: Array1::zeros(nfeatures),
            median,
            iqr,
            count,
        });

        Ok(())
    }
}

impl OutOfCoreTransformer for OutOfCoreNormalizer {
    fn fit_chunks<I>(&mut self, chunks: I) -> Result<()>
    where
        I: Iterator<Item = Result<Array2<f64>>>,
    {
        // Peek at the first chunk to learn the number of features.
        let mut chunks_iter = chunks.peekable();
        let nfeatures = match chunks_iter.peek() {
            Some(Ok(chunk)) => chunk.shape()[1],
            // Propagate the first chunk's error.
            Some(Err(_)) => return chunks_iter.next().unwrap().map(|_| ()),
            None => {
                return Err(TransformError::InvalidInput(
                    "No chunks provided".to_string(),
                ))
            }
        };

        match self.method {
            NormalizationMethod::MinMax
            | NormalizationMethod::MinMaxCustom(_, _)
            | NormalizationMethod::ZScore
            | NormalizationMethod::MaxAbs => {
                self.compute_simple_stats(chunks_iter, nfeatures)?;
            }
            NormalizationMethod::Robust => {
                self.compute_robust_stats(chunks_iter, nfeatures)?;
            }
            _ => {
                return Err(TransformError::NotImplemented(
                    "This normalization method is not supported for out-of-core processing"
                        .to_string(),
                ));
            }
        }

        Ok(())
    }

    fn transform_chunks<I>(&self, chunks: I) -> Result<ChunkedArrayWriter>
    where
        I: Iterator<Item = Result<Array2<f64>>>,
    {
        if self.stats.is_none() {
            return Err(TransformError::TransformationError(
                "Normalizer has not been fitted".to_string(),
            ));
        }

        let stats = self.stats.as_ref().unwrap();

        // Write transformed chunks to a process-specific temporary file.
        let output_path = format!(
            "{}/transform_output_{}.bin",
            std::env::temp_dir().to_string_lossy(),
            std::process::id()
        );

        // The output is sized to the row count seen during fitting.
        let mut writer = ChunkedArrayWriter::new(&output_path, (stats.count, stats.min.len()))?;

        for chunk_result in chunks {
            let chunk = chunk_result?;
            let mut transformed = Array2::zeros((chunk.nrows(), chunk.ncols()));

            match self.method {
                NormalizationMethod::MinMax => {
                    let range = &stats.max - &stats.min;
                    for i in 0..chunk.shape()[0] {
                        for j in 0..chunk.shape()[1] {
                            if range[j].abs() > 1e-10 {
                                transformed[[i, j]] = (chunk[[i, j]] - stats.min[j]) / range[j];
                            } else {
                                // Constant feature: map to the midpoint.
                                transformed[[i, j]] = 0.5;
                            }
                        }
                    }
                }
                NormalizationMethod::ZScore => {
                    for i in 0..chunk.shape()[0] {
                        for j in 0..chunk.shape()[1] {
                            if stats.std[j] > 1e-10 {
                                transformed[[i, j]] =
                                    (chunk[[i, j]] - stats.mean[j]) / stats.std[j];
                            } else {
                                transformed[[i, j]] = 0.0;
                            }
                        }
                    }
                }
                NormalizationMethod::MaxAbs => {
                    for i in 0..chunk.shape()[0] {
                        for j in 0..chunk.shape()[1] {
                            let max_abs = stats.max[j].abs().max(stats.min[j].abs());
                            if max_abs > 1e-10 {
                                transformed[[i, j]] = chunk[[i, j]] / max_abs;
                            } else {
                                transformed[[i, j]] = 0.0;
                            }
                        }
                    }
                }
                NormalizationMethod::Robust => {
                    for i in 0..chunk.shape()[0] {
                        for j in 0..chunk.shape()[1] {
                            if stats.iqr[j] > 1e-10 {
                                transformed[[i, j]] =
                                    (chunk[[i, j]] - stats.median[j]) / stats.iqr[j];
                            } else {
                                transformed[[i, j]] = 0.0;
                            }
                        }
                    }
                }
                // Note: MinMaxCustom is accepted by fit_chunks but has no
                // transform arm yet, so it falls through to NotImplemented.
                _ => {
                    return Err(TransformError::NotImplemented(
                        "This normalization method is not supported".to_string(),
                    ));
                }
            }

            writer.write_chunk(&transformed)?;
        }

        Ok(writer)
    }

    fn get_transformshape(&self, inputshape: (usize, usize)) -> (usize, usize) {
        // Normalization preserves the input shape.
        inputshape
    }
}

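/// Streams a CSV file of numeric values as row chunks. A usage sketch
/// (the file name is hypothetical; every field must parse as `f64`):
///
/// ```ignore
/// for chunk in csv_chunks("data.csv", 10_000, true)? {
///     let chunk = chunk?;
///     // process up to 10_000 rows at a time
/// }
/// ```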
#[allow(dead_code)]
pub fn csv_chunks<P: AsRef<Path>>(
    path: P,
    chunk_size: usize,
    has_header: bool,
) -> Result<impl Iterator<Item = Result<Array2<f64>>>> {
    let file = File::open(path).map_err(|e| {
        TransformError::TransformationError(format!("Failed to open CSV file: {e}"))
    })?;

    Ok(CsvChunkIterator::new(
        BufReader::new(file),
        chunk_size,
        has_header,
    ))
}

struct CsvChunkIterator {
    reader: BufReader<File>,
    chunk_size: usize,
    skipheader: bool,
    header_skipped: bool,
}

impl CsvChunkIterator {
    fn new(reader: BufReader<File>, chunk_size: usize, skipheader: bool) -> Self {
        CsvChunkIterator {
            reader,
            chunk_size,
            skipheader,
            header_skipped: false,
        }
    }
}

impl Iterator for CsvChunkIterator {
    type Item = Result<Array2<f64>>;

    fn next(&mut self) -> Option<Self::Item> {
        use std::io::BufRead;

        // Skip the header once, before it can count against this
        // chunk's row budget.
        if self.skipheader && !self.header_skipped {
            self.header_skipped = true;
            let mut header = String::new();
            if let Err(e) = self.reader.read_line(&mut header) {
                return Some(Err(TransformError::IoError(e)));
            }
        }

        let mut rows = Vec::new();
        let mut n_cols = None;

        for line_result in (&mut self.reader).lines().take(self.chunk_size) {
            let line = match line_result {
                Ok(l) => l,
                Err(e) => return Some(Err(TransformError::IoError(e))),
            };

            // Parse one comma-separated row of f64 values.
            let values: Result<Vec<f64>> = line
                .split(',')
                .map(|s| {
                    s.trim().parse::<f64>().map_err(|e| {
                        TransformError::ParseError(format!("Failed to parse number: {e}"))
                    })
                })
                .collect();

            let values = match values {
                Ok(v) => v,
                Err(e) => return Some(Err(e)),
            };

            // All rows must have the same number of columns.
            if let Some(nc) = n_cols {
                if values.len() != nc {
                    return Some(Err(TransformError::InvalidInput(
                        "Inconsistent number of columns in CSV".to_string(),
                    )));
                }
            } else {
                n_cols = Some(values.len());
            }

            rows.push(values);
        }

        if rows.is_empty() {
            return None;
        }

        // Assemble the parsed rows into a dense array.
        let n_rows = rows.len();
        let n_cols = n_cols.unwrap();
        let mut array = Array2::zeros((n_rows, n_cols));

        for (i, row) in rows.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                array[[i, j]] = val;
            }
        }

        Some(Ok(array))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_out_of_core_robust_scaling() {
        let data = vec![
            Array::from_shape_vec((3, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0]).unwrap(),
            Array::from_shape_vec((3, 2), vec![4.0, 40.0, 5.0, 50.0, 6.0, 60.0]).unwrap(),
            Array::from_shape_vec((3, 2), vec![7.0, 70.0, 8.0, 80.0, 9.0, 90.0]).unwrap(),
        ];

        let chunks = data.into_iter().map(Ok);

        let mut normalizer = OutOfCoreNormalizer::new(NormalizationMethod::Robust);
        normalizer.fit_chunks(chunks).unwrap();

        let stats = normalizer.stats.as_ref().unwrap();
        assert_eq!(stats.median.len(), 2);
        assert_eq!(stats.iqr.len(), 2);

        // Feature 0 holds 1..=9, so the sampled median should be near 5.
        assert!((stats.median[0] - 5.0).abs() < 1.0);
        assert!(stats.iqr[0] > 0.0);

        // Feature 1 holds 10..=90, so the sampled median should be near 50.
        assert!((stats.median[1] - 50.0).abs() < 10.0);
        assert!(stats.iqr[1] > 0.0);
    }

    #[test]
    fn test_out_of_core_robust_transform() {
        let fit_data = vec![
            Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap(),
            Array::from_shape_vec((2, 1), vec![3.0, 4.0]).unwrap(),
            Array::from_shape_vec((1, 1), vec![5.0]).unwrap(),
        ];

        let mut normalizer = OutOfCoreNormalizer::new(NormalizationMethod::Robust);
        normalizer
            .fit_chunks(fit_data.into_iter().map(Ok))
            .unwrap();

        let transform_data = vec![Array::from_shape_vec((2, 1), vec![3.0, 6.0]).unwrap()];

        let result = normalizer.transform_chunks(transform_data.into_iter().map(Ok));
        assert!(result.is_ok());
    }

    #[test]
    fn test_out_of_core_normalizer_not_fitted() {
        let normalizer = OutOfCoreNormalizer::new(NormalizationMethod::Robust);
        let data = vec![Array::zeros((2, 2))];

        let result = normalizer.transform_chunks(data.into_iter().map(Ok));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not been fitted"));
    }

    #[test]
    fn test_out_of_core_empty_chunks() {
        let mut normalizer = OutOfCoreNormalizer::new(NormalizationMethod::Robust);
        let empty_chunks: Vec<Result<Array2<f64>>> = vec![];

        let result = normalizer.fit_chunks(empty_chunks.into_iter());
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No chunks provided"));
    }
}