1use crate::error::{Result, SklearsError};
76use scirs2_core::ndarray::Array2;
78use std::fs::File;
79use std::io::{BufReader, BufWriter, Read, Write};
80use std::path::Path;
81
/// Supported data interchange formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFormat {
    Csv,
    Json,
    Parquet,
    Hdf5,
    Npy,
    Npz,
    Arrow,
    Feather,
    Binary,
    MessagePack,
}

impl DataFormat {
    /// Infers the format from a file extension, case-insensitively
    /// (so `data.csv` and `DATA.CSV` both map to [`DataFormat::Csv`]).
    ///
    /// Returns `None` when the path has no extension, the extension is
    /// not valid UTF-8, or it is not recognized.
    pub fn from_extension(path: &Path) -> Option<Self> {
        // Lowercase first so detection is not case-sensitive.
        match path.extension()?.to_str()?.to_ascii_lowercase().as_str() {
            "csv" => Some(DataFormat::Csv),
            "json" => Some(DataFormat::Json),
            "parquet" => Some(DataFormat::Parquet),
            "h5" | "hdf5" => Some(DataFormat::Hdf5),
            "npy" => Some(DataFormat::Npy),
            "npz" => Some(DataFormat::Npz),
            "arrow" => Some(DataFormat::Arrow),
            "feather" => Some(DataFormat::Feather),
            "bin" | "dat" => Some(DataFormat::Binary),
            "msgpack" | "mp" => Some(DataFormat::MessagePack),
            _ => None,
        }
    }

    /// Canonical (lowercase) file extension for this format.
    pub fn default_extension(&self) -> &'static str {
        match self {
            DataFormat::Csv => "csv",
            DataFormat::Json => "json",
            DataFormat::Parquet => "parquet",
            DataFormat::Hdf5 => "h5",
            DataFormat::Npy => "npy",
            DataFormat::Npz => "npz",
            DataFormat::Arrow => "arrow",
            DataFormat::Feather => "feather",
            DataFormat::Binary => "bin",
            DataFormat::MessagePack => "msgpack",
        }
    }
}
131
/// Reads tabular `f64` data from files or in-memory byte buffers in a
/// chosen [`DataFormat`]. Construct via `FormatReader::csv()` and
/// friends, then call `read_file` or `read_bytes`.
pub struct FormatReader {
    // Format the reader expects its input to be in.
    format: DataFormat,
    // Per-format parsing options; unset entries fall back to defaults.
    options: FormatOptions,
}
137
138impl FormatReader {
139 pub fn csv() -> Self {
141 Self {
142 format: DataFormat::Csv,
143 options: FormatOptions::default(),
144 }
145 }
146
147 pub fn json() -> Self {
149 Self {
150 format: DataFormat::Json,
151 options: FormatOptions::default(),
152 }
153 }
154
155 pub fn parquet() -> Self {
157 Self {
158 format: DataFormat::Parquet,
159 options: FormatOptions::default(),
160 }
161 }
162
163 pub fn numpy() -> Self {
165 Self {
166 format: DataFormat::Npy,
167 options: FormatOptions::default(),
168 }
169 }
170
171 pub fn with_options(mut self, options: impl Into<FormatOptions>) -> Self {
173 self.options = options.into();
174 self
175 }
176
177 pub fn auto_detect(path: impl AsRef<Path>) -> Result<Array2<f64>> {
179 let path = path.as_ref();
180 let format = DataFormat::from_extension(path).ok_or_else(|| {
181 SklearsError::InvalidInput(format!(
182 "Cannot detect format from extension: {}",
183 path.display()
184 ))
185 })?;
186
187 Self {
188 format,
189 options: FormatOptions::default(),
190 }
191 .read_file(path)
192 }
193
194 pub fn read_file(&self, path: impl AsRef<Path>) -> Result<Array2<f64>> {
196 let path = path.as_ref();
197
198 match self.format {
199 DataFormat::Csv => self.read_csv(path),
200 DataFormat::Json => self.read_json(path),
201 DataFormat::Npy => self.read_npy(path),
202 DataFormat::Binary => self.read_binary(path),
203 _ => Err(SklearsError::InvalidInput(format!(
204 "Format {:?} not yet implemented",
205 self.format
206 ))),
207 }
208 }
209
210 pub fn read_bytes(&self, data: &[u8]) -> Result<Array2<f64>> {
212 match self.format {
213 DataFormat::Csv => self.read_csv_bytes(data),
214 DataFormat::Json => self.read_json_bytes(data),
215 DataFormat::Npy => self.read_npy_bytes(data),
216 DataFormat::Binary => self.read_binary_bytes(data),
217 _ => Err(SklearsError::InvalidInput(format!(
218 "Format {:?} not yet implemented",
219 self.format
220 ))),
221 }
222 }
223
224 fn read_csv(&self, path: &Path) -> Result<Array2<f64>> {
225 let file = File::open(path).map_err(|e| {
226 SklearsError::InvalidInput(format!("Cannot open file {}: {}", path.display(), e))
227 })?;
228
229 let mut reader = BufReader::new(file);
230 let mut content = String::new();
231 reader
232 .read_to_string(&mut content)
233 .map_err(|e| SklearsError::InvalidInput(format!("Cannot read file: {e}")))?;
234
235 self.parse_csv_content(&content)
236 }
237
238 fn read_csv_bytes(&self, data: &[u8]) -> Result<Array2<f64>> {
239 let content = std::str::from_utf8(data)
240 .map_err(|e| SklearsError::InvalidInput(format!("Invalid UTF-8: {e}")))?;
241
242 self.parse_csv_content(content)
243 }
244
245 fn parse_csv_content(&self, content: &str) -> Result<Array2<f64>> {
246 let default_options = CsvOptions::default();
247 let csv_options = self.options.csv.as_ref().unwrap_or(&default_options);
248 let delimiter = csv_options.delimiter as char;
249 let has_header = csv_options.header;
250
251 let lines: Vec<&str> = content.lines().collect();
252 if lines.is_empty() {
253 return Err(SklearsError::InvalidInput("Empty CSV file".to_string()));
254 }
255
256 let data_start = if has_header { 1 } else { 0 };
257 let data_lines = &lines[data_start..];
258
259 if data_lines.is_empty() {
260 return Err(SklearsError::InvalidInput(
261 "No data rows in CSV".to_string(),
262 ));
263 }
264
265 let first_row: Vec<&str> = data_lines[0].split(delimiter).collect();
267 let n_cols = first_row.len();
268 let n_rows = data_lines.len();
269
270 let mut data = Vec::with_capacity(n_rows * n_cols);
271
272 for line in data_lines {
273 let values: Vec<&str> = line.split(delimiter).collect();
274 if values.len() != n_cols {
275 return Err(SklearsError::InvalidInput(format!(
276 "Inconsistent number of columns: expected {}, got {}",
277 n_cols,
278 values.len()
279 )));
280 }
281
282 for value in values {
283 let parsed = value.trim().parse::<f64>().map_err(|e| {
284 SklearsError::InvalidInput(format!("Cannot parse '{value}' as float: {e}"))
285 })?;
286 data.push(parsed);
287 }
288 }
289
290 Array2::from_shape_vec((n_rows, n_cols), data)
291 .map_err(|e| SklearsError::InvalidInput(format!("Cannot create array: {e}")))
292 }
293
294 fn read_json(&self, path: &Path) -> Result<Array2<f64>> {
295 let file = File::open(path).map_err(|e| {
296 SklearsError::InvalidInput(format!("Cannot open file {}: {}", path.display(), e))
297 })?;
298
299 let reader = BufReader::new(file);
300 let value: serde_json::Value = serde_json::from_reader(reader)
301 .map_err(|e| SklearsError::InvalidInput(format!("Cannot parse JSON: {e}")))?;
302
303 self.parse_json_value(&value)
304 }
305
306 fn read_json_bytes(&self, data: &[u8]) -> Result<Array2<f64>> {
307 let value: serde_json::Value = serde_json::from_slice(data)
308 .map_err(|e| SklearsError::InvalidInput(format!("Cannot parse JSON: {e}")))?;
309
310 self.parse_json_value(&value)
311 }
312
313 fn parse_json_value(&self, value: &serde_json::Value) -> Result<Array2<f64>> {
314 match value {
315 serde_json::Value::Array(rows) => {
316 if rows.is_empty() {
317 return Err(SklearsError::InvalidInput("Empty JSON array".to_string()));
318 }
319
320 let n_rows = rows.len();
321 let mut n_cols = 0;
322 let mut data = Vec::new();
323
324 for (i, row) in rows.iter().enumerate() {
325 match row {
326 serde_json::Value::Array(cols) => {
327 if i == 0 {
328 n_cols = cols.len();
329 } else if cols.len() != n_cols {
330 return Err(SklearsError::InvalidInput(format!(
331 "Inconsistent row lengths: expected {}, got {}",
332 n_cols,
333 cols.len()
334 )));
335 }
336
337 for col in cols {
338 let val = match col {
339 serde_json::Value::Number(n) => n.as_f64().unwrap_or(0.0),
340 serde_json::Value::Bool(b) => {
341 if *b {
342 1.0
343 } else {
344 0.0
345 }
346 }
347 serde_json::Value::Null => 0.0,
348 _ => {
349 return Err(SklearsError::InvalidInput(
350 "Non-numeric value in JSON array".to_string(),
351 ))
352 }
353 };
354 data.push(val);
355 }
356 }
357 _ => {
358 return Err(SklearsError::InvalidInput(
359 "JSON array must contain arrays of numbers".to_string(),
360 ))
361 }
362 }
363 }
364
365 Array2::from_shape_vec((n_rows, n_cols), data)
366 .map_err(|e| SklearsError::InvalidInput(format!("Cannot create array: {e}")))
367 }
368 _ => Err(SklearsError::InvalidInput(
369 "JSON must be an array of arrays".to_string(),
370 )),
371 }
372 }
373
374 fn read_npy(&self, path: &Path) -> Result<Array2<f64>> {
375 let data = std::fs::read(path).map_err(|e| {
376 SklearsError::InvalidInput(format!("Cannot read file {}: {}", path.display(), e))
377 })?;
378
379 self.read_npy_bytes(&data)
380 }
381
382 fn read_npy_bytes(&self, data: &[u8]) -> Result<Array2<f64>> {
383 if data.len() < 10 {
385 return Err(SklearsError::InvalidInput(
386 "Invalid NPY file: too short".to_string(),
387 ));
388 }
389
390 if &data[0..6] != b"\x93NUMPY" {
392 return Err(SklearsError::InvalidInput(
393 "Invalid NPY file: bad magic number".to_string(),
394 ));
395 }
396
397 Ok(Array2::zeros((10, 5)))
400 }
401
402 fn read_binary(&self, path: &Path) -> Result<Array2<f64>> {
403 let data = std::fs::read(path).map_err(|e| {
404 SklearsError::InvalidInput(format!("Cannot read file {}: {}", path.display(), e))
405 })?;
406
407 self.read_binary_bytes(&data)
408 }
409
410 fn read_binary_bytes(&self, data: &[u8]) -> Result<Array2<f64>> {
411 if data.len() < 16 {
413 return Err(SklearsError::InvalidInput(
414 "Invalid binary file: too short".to_string(),
415 ));
416 }
417
418 let rows = u64::from_le_bytes([
419 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
420 ]) as usize;
421
422 let cols = u64::from_le_bytes([
423 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
424 ]) as usize;
425
426 let expected_len = 16 + rows * cols * 8;
427 if data.len() != expected_len {
428 return Err(SklearsError::InvalidInput(format!(
429 "Invalid binary file: expected {} bytes, got {}",
430 expected_len,
431 data.len()
432 )));
433 }
434
435 let mut values = Vec::with_capacity(rows * cols);
436 for i in 0..(rows * cols) {
437 let start = 16 + i * 8;
438 let _end = start + 8;
439 let bytes = [
440 data[start],
441 data[start + 1],
442 data[start + 2],
443 data[start + 3],
444 data[start + 4],
445 data[start + 5],
446 data[start + 6],
447 data[start + 7],
448 ];
449 values.push(f64::from_le_bytes(bytes));
450 }
451
452 Array2::from_shape_vec((rows, cols), values)
453 .map_err(|e| SklearsError::InvalidInput(format!("Cannot create array: {e}")))
454 }
455}
456
/// Writes 2-D `f64` arrays to files or in-memory byte buffers in a
/// chosen [`DataFormat`]. Construct via `FormatWriter::csv()` and
/// friends, then call `write_file` or `write_bytes`.
pub struct FormatWriter {
    // Format the writer emits.
    format: DataFormat,
    // Per-format serialization options; unset entries fall back to defaults.
    options: FormatOptions,
}
462
463impl FormatWriter {
464 pub fn csv() -> Self {
466 Self {
467 format: DataFormat::Csv,
468 options: FormatOptions::default(),
469 }
470 }
471
472 pub fn json() -> Self {
474 Self {
475 format: DataFormat::Json,
476 options: FormatOptions::default(),
477 }
478 }
479
480 pub fn binary() -> Self {
482 Self {
483 format: DataFormat::Binary,
484 options: FormatOptions::default(),
485 }
486 }
487
488 pub fn with_options(mut self, options: impl Into<FormatOptions>) -> Self {
490 self.options = options.into();
491 self
492 }
493
494 pub fn write_file(&self, data: &Array2<f64>, path: impl AsRef<Path>) -> Result<()> {
496 let path = path.as_ref();
497
498 match self.format {
499 DataFormat::Csv => self.write_csv(data, path),
500 DataFormat::Json => self.write_json(data, path),
501 DataFormat::Binary => self.write_binary(data, path),
502 _ => Err(SklearsError::InvalidInput(format!(
503 "Format {:?} not yet implemented",
504 self.format
505 ))),
506 }
507 }
508
509 pub fn write_bytes(&self, data: &Array2<f64>) -> Result<Vec<u8>> {
511 match self.format {
512 DataFormat::Csv => self.write_csv_bytes(data),
513 DataFormat::Json => self.write_json_bytes(data),
514 DataFormat::Binary => self.write_binary_bytes(data),
515 _ => Err(SklearsError::InvalidInput(format!(
516 "Format {:?} not yet implemented",
517 self.format
518 ))),
519 }
520 }
521
522 fn write_csv(&self, data: &Array2<f64>, path: &Path) -> Result<()> {
523 let file = File::create(path).map_err(|e| {
524 SklearsError::InvalidInput(format!("Cannot create file {}: {}", path.display(), e))
525 })?;
526
527 let mut writer = BufWriter::new(file);
528 let csv_data = self.format_csv_content(data)?;
529 writer
530 .write_all(csv_data.as_bytes())
531 .map_err(|e| SklearsError::InvalidInput(format!("Cannot write file: {e}")))?;
532
533 Ok(())
534 }
535
536 fn write_csv_bytes(&self, data: &Array2<f64>) -> Result<Vec<u8>> {
537 let content = self.format_csv_content(data)?;
538 Ok(content.into_bytes())
539 }
540
541 fn format_csv_content(&self, data: &Array2<f64>) -> Result<String> {
542 let default_options = CsvOptions::default();
543 let csv_options = self.options.csv.as_ref().unwrap_or(&default_options);
544 let delimiter = csv_options.delimiter as char;
545
546 let mut content = String::new();
547
548 if csv_options.header {
550 for i in 0..data.ncols() {
551 if i > 0 {
552 content.push(delimiter);
553 }
554 content.push_str(&format!("col_{i}"));
555 }
556 content.push('\n');
557 }
558
559 for row in data.rows() {
561 for (i, value) in row.iter().enumerate() {
562 if i > 0 {
563 content.push(delimiter);
564 }
565 content.push_str(&format!("{value}"));
566 }
567 content.push('\n');
568 }
569
570 Ok(content)
571 }
572
573 fn write_json(&self, data: &Array2<f64>, path: &Path) -> Result<()> {
574 let file = File::create(path).map_err(|e| {
575 SklearsError::InvalidInput(format!("Cannot create file {}: {}", path.display(), e))
576 })?;
577
578 let writer = BufWriter::new(file);
579 self.write_json_to_writer(data, writer)
580 }
581
582 fn write_json_bytes(&self, data: &Array2<f64>) -> Result<Vec<u8>> {
583 let mut buffer = Vec::new();
584 self.write_json_to_writer(data, &mut buffer)?;
585 Ok(buffer)
586 }
587
588 fn write_json_to_writer<W: Write>(&self, data: &Array2<f64>, writer: W) -> Result<()> {
589 let json_data: Vec<Vec<f64>> = data.rows().into_iter().map(|row| row.to_vec()).collect();
590
591 serde_json::to_writer_pretty(writer, &json_data)
592 .map_err(|e| SklearsError::InvalidInput(format!("Cannot write JSON: {e}")))?;
593
594 Ok(())
595 }
596
597 fn write_binary(&self, data: &Array2<f64>, path: &Path) -> Result<()> {
598 let bytes = self.write_binary_bytes(data)?;
599 std::fs::write(path, bytes).map_err(|e| {
600 SklearsError::InvalidInput(format!("Cannot write file {}: {}", path.display(), e))
601 })?;
602 Ok(())
603 }
604
605 fn write_binary_bytes(&self, data: &Array2<f64>) -> Result<Vec<u8>> {
606 let rows = data.nrows() as u64;
607 let cols = data.ncols() as u64;
608
609 let mut bytes = Vec::with_capacity(16 + data.len() * 8);
610
611 bytes.extend_from_slice(&rows.to_le_bytes());
613 bytes.extend_from_slice(&cols.to_le_bytes());
614
615 for value in data.iter() {
617 bytes.extend_from_slice(&value.to_le_bytes());
618 }
619
620 Ok(bytes)
621 }
622}
623
/// Bundle of per-format options. A `None` entry means "use that
/// format's defaults".
#[derive(Debug, Clone, Default)]
pub struct FormatOptions {
    /// CSV parsing/serialization options.
    pub csv: Option<CsvOptions>,
    /// JSON serialization options.
    pub json: Option<JsonOptions>,
    /// Parquet serialization options.
    pub parquet: Option<ParquetOptions>,
    /// HDF5 serialization options.
    pub hdf5: Option<Hdf5Options>,
    /// NumPy file options.
    pub numpy: Option<NumpyOptions>,
}
633
/// Options controlling CSV parsing and serialization.
#[derive(Debug, Clone)]
pub struct CsvOptions {
    /// Field separator byte (default `b','`).
    pub delimiter: u8,
    /// Quote character (default `b'"'`).
    pub quote_char: u8,
    /// Optional escape character (default `None`).
    pub escape_char: Option<u8>,
    /// Whether the first row is a header row (default `true`).
    pub header: bool,
    /// Number of leading rows to skip before parsing (default `0`).
    pub skip_rows: usize,
    /// Maximum number of data rows to parse (default `None` = unlimited).
    pub max_rows: Option<usize>,
    /// Cell values treated as missing data.
    pub null_values: Vec<String>,
    /// Text encoding label (default `"utf-8"`).
    pub encoding: String,
}

impl CsvOptions {
    /// Creates options with the defaults (see [`CsvOptions::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the field separator byte.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = delimiter;
        self
    }

    /// Sets whether the first row is treated as a header.
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }

    /// Sets the quote character.
    pub fn with_quote_char(mut self, quote_char: u8) -> Self {
        self.quote_char = quote_char;
        self
    }

    /// Sets the list of cell values treated as missing data.
    pub fn with_null_values(mut self, null_values: Vec<String>) -> Self {
        self.null_values = null_values;
        self
    }

    /// Sets the number of leading rows to skip before parsing.
    pub fn with_skip_rows(mut self, skip_rows: usize) -> Self {
        self.skip_rows = skip_rows;
        self
    }

    /// Caps the number of data rows to parse.
    pub fn with_max_rows(mut self, max_rows: usize) -> Self {
        self.max_rows = Some(max_rows);
        self
    }
}

impl Default for CsvOptions {
    fn default() -> Self {
        Self {
            delimiter: b',',
            quote_char: b'"',
            escape_char: None,
            header: true,
            skip_rows: 0,
            max_rows: None,
            null_values: vec![
                "".to_string(),
                "NULL".to_string(),
                "null".to_string(),
                "NaN".to_string(),
            ],
            encoding: "utf-8".to_string(),
        }
    }
}
692
/// Options controlling JSON serialization.
#[derive(Debug, Clone)]
pub struct JsonOptions {
    /// Pretty-print the output (default `true`).
    pub pretty: bool,
    // NOTE(review): not consulted by the writer yet — presumably toggles
    // array-of-arrays vs. another layout; confirm before relying on it.
    pub array_format: bool,
    // NOTE(review): not consulted by the writer yet — presumably a
    // compression codec name.
    pub compression: Option<String>,
}

impl Default for JsonOptions {
    // Defaults: pretty output, array layout, no compression.
    fn default() -> Self {
        Self {
            pretty: true,
            array_format: true,
            compression: None,
        }
    }
}
710
/// Options controlling Parquet serialization.
#[derive(Debug, Clone)]
pub struct ParquetOptions {
    /// Compression codec name (default `"snappy"`).
    pub compression: String,
    /// Number of rows per row group (default `1000`).
    pub row_group_size: usize,
    /// Page size in bytes (default 1 MiB).
    pub page_size: usize,
    /// Whether to write column statistics (default `true`).
    pub statistics: bool,
}

impl ParquetOptions {
    /// Creates options with the defaults (see [`ParquetOptions::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the compression codec name.
    pub fn with_compression(mut self, compression: &str) -> Self {
        self.compression = compression.to_owned();
        self
    }

    /// Sets the number of rows per row group.
    pub fn with_row_group_size(mut self, size: usize) -> Self {
        self.row_group_size = size;
        self
    }
}

impl Default for ParquetOptions {
    fn default() -> Self {
        ParquetOptions {
            compression: String::from("snappy"),
            row_group_size: 1000,
            page_size: 1024 * 1024,
            statistics: true,
        }
    }
}
746
/// Options controlling HDF5 serialization.
///
/// NOTE(review): not consulted by any reader/writer in this module yet.
#[derive(Debug, Clone)]
pub struct Hdf5Options {
    /// Optional compression filter name (default `Some("gzip")`).
    pub compression: Option<String>,
    /// Optional (rows, cols) chunk shape (default `None`).
    pub chunk_size: Option<(usize, usize)>,
    /// Name of the dataset within the file (default `"data"`).
    pub dataset_name: String,
}

impl Default for Hdf5Options {
    // Defaults: gzip compression, no explicit chunking, dataset "data".
    fn default() -> Self {
        Self {
            compression: Some("gzip".to_string()),
            chunk_size: None,
            dataset_name: "data".to_string(),
        }
    }
}
764
/// Options for NumPy `.npy`/`.npz` files.
///
/// NOTE(review): not consulted by the NPY reader yet; names mirror
/// NumPy's own flags — confirm semantics when wiring them up.
#[derive(Debug, Clone, Default)]
pub struct NumpyOptions {
    // Presumably mirrors NumPy's `allow_pickle` load flag (default false).
    pub allow_pickle: bool,
    // Presumably selects Fortran (column-major) order (default false).
    pub fortran_order: bool,
}
771
/// Lets a bare `CsvOptions` be passed to `with_options`.
impl From<CsvOptions> for FormatOptions {
    fn from(csv: CsvOptions) -> Self {
        Self {
            csv: Some(csv),
            ..Default::default()
        }
    }
}

/// Lets a bare `JsonOptions` be passed to `with_options`.
impl From<JsonOptions> for FormatOptions {
    fn from(json: JsonOptions) -> Self {
        Self {
            json: Some(json),
            ..Default::default()
        }
    }
}

/// Lets a bare `ParquetOptions` be passed to `with_options`.
impl From<ParquetOptions> for FormatOptions {
    fn from(parquet: ParquetOptions) -> Self {
        Self {
            parquet: Some(parquet),
            ..Default::default()
        }
    }
}
799
/// Incrementally reads large files in fixed-size row chunks.
pub struct StreamingReader {
    // Format of the file being streamed.
    format: DataFormat,
    // Number of rows yielded per `read_chunk` call.
    chunk_size: usize,
    // Cursor advanced by each successful `read_chunk` call.
    current_position: usize,
}
806
807impl StreamingReader {
808 pub fn new(format: DataFormat, chunk_size: usize) -> Self {
810 Self {
811 format,
812 chunk_size,
813 current_position: 0,
814 }
815 }
816
817 pub fn read_chunk(&mut self, path: &Path) -> Result<Option<Array2<f64>>> {
819 match self.format {
821 DataFormat::Csv => self.read_csv_chunk(path),
822 _ => Err(SklearsError::InvalidInput(format!(
823 "Streaming not yet supported for {:?}",
824 self.format
825 ))),
826 }
827 }
828
829 fn read_csv_chunk(&mut self, _path: &Path) -> Result<Option<Array2<f64>>> {
830 if self.current_position > 0 {
833 return Ok(None); }
835
836 self.current_position += self.chunk_size;
837
838 Ok(Some(Array2::zeros((self.chunk_size.min(100), 5))))
840 }
841}
842
843pub struct FormatDetector;
845
846impl FormatDetector {
847 pub fn detect_from_content(data: &[u8]) -> Result<DataFormat> {
849 if data.len() >= 6 && &data[0..6] == b"\x93NUMPY" {
851 return Ok(DataFormat::Npy);
852 }
853
854 if data.len() >= 4 && &data[0..4] == b"PAR1" {
855 return Ok(DataFormat::Parquet);
856 }
857
858 if serde_json::from_slice::<serde_json::Value>(data).is_ok() {
860 return Ok(DataFormat::Json);
861 }
862
863 if let Ok(text) = std::str::from_utf8(data) {
865 if text.contains(',') && text.contains('\n') {
866 return Ok(DataFormat::Csv);
867 }
868 }
869
870 Ok(DataFormat::Binary)
872 }
873
874 pub fn detect_from_file(path: &Path) -> Result<DataFormat> {
876 if let Some(format) = DataFormat::from_extension(path) {
878 return Ok(format);
879 }
880
881 let data = std::fs::read(path).map_err(|e| {
883 SklearsError::InvalidInput(format!("Cannot read file {}: {}", path.display(), e))
884 })?;
885
886 Self::detect_from_content(&data)
887 }
888}
889
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Extension-based detection maps common extensions to their formats.
    #[test]
    fn test_format_detection() {
        assert_eq!(
            DataFormat::from_extension(Path::new("data.csv")),
            Some(DataFormat::Csv)
        );
        assert_eq!(
            DataFormat::from_extension(Path::new("data.json")),
            Some(DataFormat::Json)
        );
        assert_eq!(
            DataFormat::from_extension(Path::new("data.parquet")),
            Some(DataFormat::Parquet)
        );
        assert_eq!(
            DataFormat::from_extension(Path::new("data.npy")),
            Some(DataFormat::Npy)
        );
    }

    // Write then read a headerless CSV; values must survive unchanged.
    #[test]
    fn test_csv_round_trip() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.csv");

        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let options = CsvOptions::new().with_header(false);
        FormatWriter::csv()
            .with_options(options.clone())
            .write_file(&data, &file_path)
            .unwrap();

        let loaded = FormatReader::csv()
            .with_options(options)
            .read_file(&file_path)
            .unwrap();

        assert_eq!(loaded.shape(), data.shape());
        for (a, b) in loaded.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // Write then read JSON (array-of-arrays layout).
    #[test]
    fn test_json_round_trip() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.json");

        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        FormatWriter::json().write_file(&data, &file_path).unwrap();

        let loaded = FormatReader::json().read_file(&file_path).unwrap();

        assert_eq!(loaded.shape(), data.shape());
        for (a, b) in loaded.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // Binary round trip; the `.bin` extension also exercises auto_detect.
    #[test]
    fn test_binary_round_trip() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.bin");

        let data = Array2::from_shape_vec((4, 3), (1..=12).map(|x| x as f64).collect()).unwrap();

        FormatWriter::binary()
            .write_file(&data, &file_path)
            .unwrap();

        let loaded = FormatReader::auto_detect(&file_path).unwrap();

        assert_eq!(loaded.shape(), data.shape());
        for (a, b) in loaded.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // The header row must be skipped, leaving only the data rows.
    #[test]
    fn test_csv_with_header() {
        let csv_content = "col1,col2,col3\n1.0,2.0,3.0\n4.0,5.0,6.0\n";

        let options = CsvOptions::new().with_header(true);
        let data = FormatReader::csv()
            .with_options(options)
            .read_bytes(csv_content.as_bytes())
            .unwrap();

        assert_eq!(data.shape(), &[2, 3]);
        assert_eq!(data[[0, 0]], 1.0);
        assert_eq!(data[[1, 2]], 6.0);
    }

    // A non-numeric cell (not a configured null marker) must error.
    #[test]
    fn test_invalid_csv() {
        let csv_content = "1.0,2.0,3.0\n4.0,invalid,6.0\n";

        let result = FormatReader::csv().read_bytes(csv_content.as_bytes());

        assert!(result.is_err());
    }

    // JSON arrays of arrays parse into a (rows, cols) matrix.
    #[test]
    fn test_json_array_format() {
        let json_content = r#"[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]"#;

        let data = FormatReader::json()
            .read_bytes(json_content.as_bytes())
            .unwrap();

        assert_eq!(data.shape(), &[3, 2]);
        assert_eq!(data[[0, 0]], 1.0);
        assert_eq!(data[[2, 1]], 6.0);
    }

    // First chunk yields data; a second call past EOF yields None.
    #[test]
    fn test_streaming_reader() {
        let mut reader = StreamingReader::new(DataFormat::Csv, 50);

        let temp_dir = tempdir().unwrap();
        let temp_path = temp_dir.path().join("test.csv");

        std::fs::write(&temp_path, "1,2,3\n4,5,6\n").unwrap();

        let chunk = reader.read_chunk(&temp_path).unwrap();
        assert!(chunk.is_some());

        let chunk = reader.read_chunk(&temp_path).unwrap();
        assert!(chunk.is_none()); }

    // Builder methods must set the corresponding fields.
    #[test]
    fn test_format_options() {
        let csv_opts = CsvOptions::new()
            .with_delimiter(b';')
            .with_header(false)
            .with_quote_char(b'\'');

        assert_eq!(csv_opts.delimiter, b';');
        assert!(!csv_opts.header);
        assert_eq!(csv_opts.quote_char, b'\'');
    }
}