Skip to main content

scirs2_datasets/
parquet_reader.rs

1//! Native Parquet dataset reader
2//!
3//! Provides `ParquetDataset` — reads a Parquet file into a typed dataset with
4//! per-column data accessible as `ColumnData` variants. Requires the
5//! `parquet_io` feature which activates the `parquet` and `arrow` crates.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! # #[cfg(feature = "parquet_io")]
11//! use scirs2_datasets::parquet_reader::ParquetDataset;
12//!
13//! # #[cfg(feature = "parquet_io")]
14//! # fn example() -> Result<(), scirs2_datasets::error::DatasetsError> {
15//! let dataset = ParquetDataset::from_file("data.parquet")?;
16//! println!("Rows: {}, Cols: {}", dataset.n_rows(), dataset.n_cols());
17//! for name in dataset.column_names() {
18//!     println!("  column: {}", name);
19//! }
20//! # Ok(())
21//! # }
22//! ```
23
24#[cfg(feature = "parquet_io")]
25use crate::error::{DatasetsError, Result};
26#[cfg(feature = "parquet_io")]
27use arrow::array::RecordBatchReader;
28#[cfg(feature = "parquet_io")]
29use indexmap::IndexMap;
30#[cfg(feature = "parquet_io")]
31use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
32#[cfg(feature = "parquet_io")]
33use scirs2_core::ndarray::Array2;
34#[cfg(feature = "parquet_io")]
35use std::fs::File;
36#[cfg(feature = "parquet_io")]
37use std::path::Path;
38
/// The values of one Parquet column, tagged with their element type.
///
/// Every cell is stored as an `Option`: `Some(value)` for a present value and
/// `None` for a null.
#[cfg(feature = "parquet_io")]
#[derive(Debug, Clone)]
pub enum ColumnData {
    /// Column of 32-bit signed integers
    Int32(Vec<Option<i32>>),
    /// Column of 64-bit signed integers
    Int64(Vec<Option<i64>>),
    /// Column of 32-bit floats
    Float32(Vec<Option<f32>>),
    /// Column of 64-bit floats
    Float64(Vec<Option<f64>>),
    /// Column of booleans
    Boolean(Vec<Option<bool>>),
    /// Column of UTF-8 strings
    Utf8(Vec<Option<String>>),
}

#[cfg(feature = "parquet_io")]
impl ColumnData {
    /// Number of cells in this column; null cells are counted too.
    pub fn len(&self) -> usize {
        match self {
            Self::Int32(cells) => cells.len(),
            Self::Int64(cells) => cells.len(),
            Self::Float32(cells) => cells.len(),
            Self::Float64(cells) => cells.len(),
            Self::Boolean(cells) => cells.len(),
            Self::Utf8(cells) => cells.len(),
        }
    }

    /// Returns `true` when the column has zero cells.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` for the Int32, Int64, Float32 and Float64 variants and
    /// `false` for Boolean and Utf8.
    pub fn is_numeric(&self) -> bool {
        match self {
            Self::Int32(_) | Self::Int64(_) | Self::Float32(_) | Self::Float64(_) => true,
            Self::Boolean(_) | Self::Utf8(_) => false,
        }
    }

    /// Widen a numeric column to `f64`, mapping every null to `f64::NAN`.
    ///
    /// Returns `None` for Boolean and Utf8 columns. Int64 values are converted
    /// with an `as` cast, so magnitudes beyond 2^53 may be rounded.
    pub fn to_f64_vec(&self) -> Option<Vec<f64>> {
        let widened: Vec<f64> = match self {
            Self::Int32(cells) => cells
                .iter()
                .map(|cell| cell.map_or(f64::NAN, f64::from))
                .collect(),
            Self::Int64(cells) => cells
                .iter()
                .map(|cell| cell.map_or(f64::NAN, |n| n as f64))
                .collect(),
            Self::Float32(cells) => cells
                .iter()
                .map(|cell| cell.map_or(f64::NAN, f64::from))
                .collect(),
            Self::Float64(cells) => cells.iter().map(|cell| cell.unwrap_or(f64::NAN)).collect(),
            Self::Boolean(_) | Self::Utf8(_) => return None,
        };
        Some(widened)
    }
}
106
/// A dataset loaded from a Parquet file.
///
/// Columns are stored in an `IndexMap` so insertion order (i.e., file column
/// order) is preserved.  Column names are case-sensitive.
///
/// Both fields are public, so callers can modify them independently;
/// `to_float_matrix` re-checks that column lengths agree with `n_rows`.
#[cfg(feature = "parquet_io")]
#[derive(Debug, Clone)]
pub struct ParquetDataset {
    /// Per-column data indexed by column name
    pub columns: IndexMap<String, ColumnData>,
    /// Number of rows read from the file (summed over all record batches)
    pub n_rows: usize,
}
118
#[cfg(feature = "parquet_io")]
impl ParquetDataset {
    /// Read a Parquet file from the filesystem.
    ///
    /// Columns whose Arrow type is not Int32, Int64, Float32, Float64,
    /// Boolean, Utf8 or LargeUtf8 are silently omitted from the returned
    /// dataset; they do not cause an error and no warning is emitted.
    ///
    /// # Errors
    ///
    /// Returns `DatasetsError::IoError` if the file cannot be opened, and
    /// `DatasetsError::InvalidFormat` if the Parquet reader cannot be
    /// constructed or a record batch fails to decode.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let file = File::open(path.as_ref()).map_err(DatasetsError::IoError)?;

        let reader = ParquetRecordBatchReaderBuilder::try_new(file)
            .map_err(|e| DatasetsError::InvalidFormat(format!("Parquet open error: {e}")))?
            .build()
            .map_err(|e| {
                DatasetsError::InvalidFormat(format!("Parquet reader build error: {e}"))
            })?;

        Self::from_record_batch_reader(reader)
    }

    /// Internal constructor — consumes a `RecordBatchReader` and accumulates
    /// column data across all batches.
    ///
    /// The column set is taken from the reader's schema before any batch is
    /// read, so a reader that yields no batches still produces one (empty)
    /// column per supported field.
    fn from_record_batch_reader(reader: impl RecordBatchReader) -> Result<Self> {
        use arrow::array::{
            Array, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
            LargeStringArray, StringArray,
        };
        use arrow::datatypes::DataType as ArrowDataType;

        /// Downcast a dynamically typed array to the concrete array type `T`,
        /// reporting a type mismatch for `column` when it is a different type.
        fn downcast<'a, T: Array + 'static>(array: &'a dyn Array, column: &str) -> Result<&'a T> {
            array.as_any().downcast_ref::<T>().ok_or_else(|| {
                DatasetsError::InvalidFormat(format!("Column '{column}' type mismatch"))
            })
        }

        let schema = reader.schema();

        // One accumulator per schema field, chosen from the declared type.
        // Fields of any other type are marked `Unsupported` and dropped later.
        let mut accumulators: Vec<ColumnAccumulator> = schema
            .fields()
            .iter()
            .map(|field| match field.data_type() {
                ArrowDataType::Int32 => ColumnAccumulator::Int32(Vec::new()),
                ArrowDataType::Int64 => ColumnAccumulator::Int64(Vec::new()),
                ArrowDataType::Float32 => ColumnAccumulator::Float32(Vec::new()),
                ArrowDataType::Float64 => ColumnAccumulator::Float64(Vec::new()),
                ArrowDataType::Boolean => ColumnAccumulator::Boolean(Vec::new()),
                ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => {
                    ColumnAccumulator::Utf8(Vec::new())
                }
                _ => ColumnAccumulator::Unsupported,
            })
            .collect();
        let mut total_rows: usize = 0;

        for batch_result in reader {
            let batch = batch_result.map_err(|e| {
                DatasetsError::InvalidFormat(format!("Parquet read batch error: {e}"))
            })?;

            total_rows = total_rows.saturating_add(batch.num_rows());

            let per_column = schema
                .fields()
                .iter()
                .zip(batch.columns())
                .zip(accumulators.iter_mut());

            for ((field, array), accumulator) in per_column {
                let array: &dyn Array = array.as_ref();
                let name = field.name();

                // Arrow's typed iterators yield `Option<value>` with `None`
                // for nulls, which is exactly the accumulator element type.
                match accumulator {
                    ColumnAccumulator::Int32(buf) => {
                        buf.extend(downcast::<Int32Array>(array, name)?.iter());
                    }
                    ColumnAccumulator::Int64(buf) => {
                        buf.extend(downcast::<Int64Array>(array, name)?.iter());
                    }
                    ColumnAccumulator::Float32(buf) => {
                        buf.extend(downcast::<Float32Array>(array, name)?.iter());
                    }
                    ColumnAccumulator::Float64(buf) => {
                        buf.extend(downcast::<Float64Array>(array, name)?.iter());
                    }
                    ColumnAccumulator::Boolean(buf) => {
                        buf.extend(downcast::<BooleanArray>(array, name)?.iter());
                    }
                    ColumnAccumulator::Utf8(buf) => {
                        // Utf8 and LargeUtf8 share this accumulator but are
                        // backed by different array types (32-bit vs 64-bit
                        // offsets), so try both before reporting a mismatch.
                        if let Some(strings) = array.as_any().downcast_ref::<StringArray>() {
                            buf.extend(strings.iter().map(|s| s.map(str::to_owned)));
                        } else {
                            let strings = downcast::<LargeStringArray>(array, name)?;
                            buf.extend(strings.iter().map(|s| s.map(str::to_owned)));
                        }
                    }
                    ColumnAccumulator::Unsupported => {
                        // Skip silently — column will be absent from dataset
                    }
                }
            }
        }

        // Move the accumulated buffers into the public representation,
        // preserving schema order and leaving out unsupported columns.
        let mut columns: IndexMap<String, ColumnData> =
            IndexMap::with_capacity(accumulators.len());
        for (field, accumulator) in schema.fields().iter().zip(accumulators) {
            let data = match accumulator {
                ColumnAccumulator::Int32(v) => ColumnData::Int32(v),
                ColumnAccumulator::Int64(v) => ColumnData::Int64(v),
                ColumnAccumulator::Float32(v) => ColumnData::Float32(v),
                ColumnAccumulator::Float64(v) => ColumnData::Float64(v),
                ColumnAccumulator::Boolean(v) => ColumnData::Boolean(v),
                ColumnAccumulator::Utf8(v) => ColumnData::Utf8(v),
                ColumnAccumulator::Unsupported => continue,
            };
            columns.insert(field.name().clone(), data);
        }

        Ok(Self {
            columns,
            n_rows: total_rows,
        })
    }

    /// Look up a column by name.
    ///
    /// Returns `None` if no column with that exact name was loaded.
    pub fn column(&self, name: &str) -> Option<&ColumnData> {
        self.columns.get(name)
    }

    /// Return column names in file order.
    pub fn column_names(&self) -> Vec<&str> {
        self.columns.keys().map(String::as_str).collect()
    }

    /// Number of rows.
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Number of supported columns in the dataset.
    pub fn n_cols(&self) -> usize {
        self.columns.len()
    }

    /// Convert all numeric columns to a dense `Array2<f64>` of shape
    /// `(n_rows, number of numeric columns)`.
    ///
    /// Each numeric dataset column becomes one matrix column. String and
    /// Boolean columns are skipped. Null values become `f64::NAN`. Matrix
    /// column order follows `column_names()` with the skipped columns removed.
    ///
    /// # Errors
    ///
    /// Returns an error if there are no numeric columns, or if column lengths
    /// are inconsistent.
    pub fn to_float_matrix(&self) -> Result<Array2<f64>> {
        let numeric_cols: Vec<(&str, Vec<f64>)> = self
            .columns
            .iter()
            .filter_map(|(name, col)| col.to_f64_vec().map(|v| (name.as_str(), v)))
            .collect();

        if numeric_cols.is_empty() {
            return Err(DatasetsError::InvalidFormat(
                "No numeric columns found in ParquetDataset".to_string(),
            ));
        }

        let n_rows = self.n_rows;

        // `columns` and `n_rows` are public fields, so they may have been
        // edited independently; reject the first column that disagrees.
        if let Some((name, col)) = numeric_cols.iter().find(|(_, col)| col.len() != n_rows) {
            return Err(DatasetsError::InvalidFormat(format!(
                "Column '{}' has {} rows, expected {}",
                name,
                col.len(),
                n_rows
            )));
        }

        let mut matrix = Array2::<f64>::zeros((n_rows, numeric_cols.len()));
        for (j, (_, col)) in numeric_cols.iter().enumerate() {
            let mut target = matrix.column_mut(j);
            for (cell, &value) in target.iter_mut().zip(col) {
                *cell = value;
            }
        }

        Ok(matrix)
    }
}
402
/// Per-column scratch buffer filled while record batches are being read.
///
/// The data-carrying variants mirror `ColumnData` one-to-one; `Unsupported`
/// marks a column whose Arrow type has no `ColumnData` counterpart.
#[cfg(feature = "parquet_io")]
#[derive(Debug)]
enum ColumnAccumulator {
    /// Buffered 32-bit signed integers
    Int32(Vec<Option<i32>>),
    /// Buffered 64-bit signed integers
    Int64(Vec<Option<i64>>),
    /// Buffered 32-bit floats
    Float32(Vec<Option<f32>>),
    /// Buffered 64-bit floats
    Float64(Vec<Option<f64>>),
    /// Buffered booleans
    Boolean(Vec<Option<bool>>),
    /// Buffered UTF-8 strings
    Utf8(Vec<Option<String>>),
    /// Placeholder for a column type that is not read
    Unsupported,
}
415
416// ============================================================================
417// Tests
418// ============================================================================
419
#[cfg(test)]
#[cfg(feature = "parquet_io")]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;

    /// Write a minimal Parquet file to a temp path and return the path.
    ///
    /// The `TempDir` guard is returned as well; the file is deleted when the
    /// guard is dropped, so callers must keep it alive while reading.
    fn write_test_parquet(
        schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
    ) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.parquet");
        let file = std::fs::File::create(&path).expect("create file");
        let mut writer = ArrowWriter::try_new(file, schema, None).expect("create parquet writer");
        for batch in batches {
            writer.write(&batch).expect("write batch");
        }
        writer.close().expect("close writer");
        (dir, path)
    }

    /// Write `batches` to a temporary Parquet file and load it back.
    fn round_trip(schema: Arc<Schema>, batches: Vec<RecordBatch>) -> ParquetDataset {
        // `_dir` stays alive until the end of this function, i.e. until the
        // file has been fully read.
        let (_dir, path) = write_test_parquet(schema, batches);
        ParquetDataset::from_file(&path).expect("from_file")
    }

    #[test]
    fn test_parquet_read_numeric_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", ArrowDataType::Int32, false),
            Field::new("y", ArrowDataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.1, 2.2, 3.3])),
            ],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);

        assert_eq!(ds.n_rows(), 3);
        assert_eq!(ds.n_cols(), 2);

        match ds.column("x") {
            Some(ColumnData::Int32(vals)) => {
                assert_eq!(vals[0], Some(1));
                assert_eq!(vals[2], Some(3));
            }
            other => panic!("Expected Int32 column, got {other:?}"),
        }

        match ds.column("y") {
            Some(ColumnData::Float64(vals)) => {
                assert!((vals[1].expect("non-null") - 2.2).abs() < 1e-10);
            }
            other => panic!("Expected Float64 column, got {other:?}"),
        }
    }

    #[test]
    fn test_parquet_read_string_column() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "name",
            ArrowDataType::Utf8,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(StringArray::from(vec![
                Some("alice"),
                None,
                Some("bob"),
            ]))],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);

        assert_eq!(ds.n_rows(), 3);
        match ds.column("name") {
            Some(ColumnData::Utf8(vals)) => {
                assert_eq!(vals[0], Some("alice".to_owned()));
                assert_eq!(vals[1], None);
                assert_eq!(vals[2], Some("bob".to_owned()));
            }
            other => panic!("Expected Utf8 column, got {other:?}"),
        }
    }

    #[test]
    fn test_parquet_column_names_order() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("z", ArrowDataType::Int32, false),
            Field::new("a", ArrowDataType::Float64, false),
            Field::new("m", ArrowDataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![0])),
                Arc::new(Float64Array::from(vec![0.0])),
                Arc::new(arrow::array::Int64Array::from(vec![0i64])),
            ],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);

        // Insertion order must be preserved
        assert_eq!(ds.column_names(), vec!["z", "a", "m"]);
    }

    #[test]
    fn test_parquet_to_float_matrix() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", ArrowDataType::Float64, false),
            Field::new("b", ArrowDataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Float64Array::from(vec![1.0, 2.0])),
                Arc::new(Float64Array::from(vec![3.0, 4.0])),
            ],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);
        let mat = ds.to_float_matrix().expect("to_float_matrix");

        assert_eq!(mat.shape(), &[2, 2]);
        assert!((mat[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((mat[[0, 1]] - 3.0).abs() < 1e-10);
        assert!((mat[[1, 0]] - 2.0).abs() < 1e-10);
        assert!((mat[[1, 1]] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_parquet_to_float_matrix_skips_non_numeric_and_maps_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", ArrowDataType::Int32, false),
            Field::new("label", ArrowDataType::Utf8, false),
            Field::new("score", ArrowDataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![10, 20])),
                Arc::new(StringArray::from(vec!["a", "b"])),
                Arc::new(Float64Array::from(vec![Some(0.5), None])),
            ],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);
        let mat = ds.to_float_matrix().expect("to_float_matrix");

        // The string column is dropped; "id" and "score" remain in order.
        assert_eq!(mat.shape(), &[2, 2]);
        assert!((mat[[0, 0]] - 10.0).abs() < 1e-10);
        assert!((mat[[1, 0]] - 20.0).abs() < 1e-10);
        assert!((mat[[0, 1]] - 0.5).abs() < 1e-10);
        assert!(mat[[1, 1]].is_nan());
    }

    #[test]
    fn test_parquet_nullable_values() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "v",
            ArrowDataType::Float64,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Float64Array::from(vec![
                Some(1.0),
                None,
                Some(3.0),
            ]))],
        )
        .expect("record batch");

        let ds = round_trip(schema, vec![batch]);

        match ds.column("v") {
            Some(ColumnData::Float64(vals)) => {
                assert_eq!(vals[0], Some(1.0));
                assert_eq!(vals[1], None);
                assert_eq!(vals[2], Some(3.0));
            }
            other => panic!("Expected Float64 column, got {other:?}"),
        }
    }

    #[test]
    fn test_parquet_to_float_matrix_no_numeric_fails() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "name",
            ArrowDataType::Utf8,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec!["x"]))])
                .expect("record batch");

        let ds = round_trip(schema, vec![batch]);
        assert!(ds.to_float_matrix().is_err());
    }

    #[test]
    fn test_parquet_multiple_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "v",
            ArrowDataType::Int32,
            false,
        )]));
        let batch1 =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))])
                .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![3, 4, 5]))],
        )
        .expect("batch2");

        let ds = round_trip(schema, vec![batch1, batch2]);

        assert_eq!(ds.n_rows(), 5);
        match ds.column("v") {
            Some(ColumnData::Int32(vals)) => {
                assert_eq!(vals.len(), 5);
                assert_eq!(vals[4], Some(5));
            }
            other => panic!("Expected Int32 column, got {other:?}"),
        }
    }
}
636}