1#[cfg(feature = "parquet_io")]
25use crate::error::{DatasetsError, Result};
26#[cfg(feature = "parquet_io")]
27use arrow::array::RecordBatchReader;
28#[cfg(feature = "parquet_io")]
29use indexmap::IndexMap;
30#[cfg(feature = "parquet_io")]
31use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
32#[cfg(feature = "parquet_io")]
33use scirs2_core::ndarray::Array2;
34#[cfg(feature = "parquet_io")]
35use std::fs::File;
36#[cfg(feature = "parquet_io")]
37use std::path::Path;
38
39#[cfg(feature = "parquet_io")]
/// Typed storage for a single column loaded from a Parquet file.
///
/// Every variant holds one entry per row; `None` marks a null cell.
///
/// Equality is structural and follows IEEE-754 rules for the float variants,
/// so two columns that both contain `NaN` in the same position compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    /// 32-bit signed integer values.
    Int32(Vec<Option<i32>>),
    /// 64-bit signed integer values.
    Int64(Vec<Option<i64>>),
    /// Single-precision floating-point values.
    Float32(Vec<Option<f32>>),
    /// Double-precision floating-point values.
    Float64(Vec<Option<f64>>),
    /// Boolean values.
    Boolean(Vec<Option<bool>>),
    /// Owned UTF-8 string values.
    Utf8(Vec<Option<String>>),
}
58
59#[cfg(feature = "parquet_io")]
60impl ColumnData {
61 pub fn len(&self) -> usize {
63 match self {
64 ColumnData::Int32(v) => v.len(),
65 ColumnData::Int64(v) => v.len(),
66 ColumnData::Float32(v) => v.len(),
67 ColumnData::Float64(v) => v.len(),
68 ColumnData::Boolean(v) => v.len(),
69 ColumnData::Utf8(v) => v.len(),
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.len() == 0
76 }
77
78 pub fn is_numeric(&self) -> bool {
80 matches!(
81 self,
82 ColumnData::Int32(_)
83 | ColumnData::Int64(_)
84 | ColumnData::Float32(_)
85 | ColumnData::Float64(_)
86 )
87 }
88
89 pub fn to_f64_vec(&self) -> Option<Vec<f64>> {
91 match self {
92 ColumnData::Int32(v) => {
93 Some(v.iter().map(|x| x.map_or(f64::NAN, |n| n as f64)).collect())
94 }
95 ColumnData::Int64(v) => {
96 Some(v.iter().map(|x| x.map_or(f64::NAN, |n| n as f64)).collect())
97 }
98 ColumnData::Float32(v) => {
99 Some(v.iter().map(|x| x.map_or(f64::NAN, |n| n as f64)).collect())
100 }
101 ColumnData::Float64(v) => Some(v.iter().map(|x| x.unwrap_or(f64::NAN)).collect()),
102 ColumnData::Boolean(_) | ColumnData::Utf8(_) => None,
103 }
104 }
105}
106
107#[cfg(feature = "parquet_io")]
112pub struct ParquetDataset {
113 pub columns: IndexMap<String, ColumnData>,
115 pub n_rows: usize,
117}
118
119#[cfg(feature = "parquet_io")]
120impl ParquetDataset {
121 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
130 let file = File::open(path.as_ref()).map_err(DatasetsError::IoError)?;
131
132 let builder = ParquetRecordBatchReaderBuilder::try_new(file)
133 .map_err(|e| DatasetsError::InvalidFormat(format!("Parquet open error: {e}")))?;
134
135 let reader = builder.build().map_err(|e| {
136 DatasetsError::InvalidFormat(format!("Parquet reader build error: {e}"))
137 })?;
138
139 Self::from_record_batch_reader(reader)
140 }
141
142 fn from_record_batch_reader(mut reader: impl RecordBatchReader) -> Result<Self> {
145 use arrow::array::{
146 Array, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
147 };
148 use arrow::datatypes::DataType as ArrowDataType;
149
150 let schema = reader.schema();
151 let field_names: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
152
153 let num_cols = field_names.len();
158 let mut accumulators: Vec<Option<ColumnAccumulator>> =
159 (0..num_cols).map(|_| None).collect();
160 let mut total_rows: usize = 0;
161
162 for batch_result in reader.by_ref() {
163 let batch = batch_result.map_err(|e| {
164 DatasetsError::InvalidFormat(format!("Parquet read batch error: {e}"))
165 })?;
166
167 total_rows = total_rows.saturating_add(batch.num_rows());
168
169 for (col_idx, field) in batch.schema().fields().iter().enumerate() {
170 let array = batch.column(col_idx);
171
172 let col_acc =
173 accumulators[col_idx].get_or_insert_with(|| match field.data_type() {
174 ArrowDataType::Int32 => ColumnAccumulator::Int32(Vec::new()),
175 ArrowDataType::Int64 => ColumnAccumulator::Int64(Vec::new()),
176 ArrowDataType::Float32 => ColumnAccumulator::Float32(Vec::new()),
177 ArrowDataType::Float64 => ColumnAccumulator::Float64(Vec::new()),
178 ArrowDataType::Boolean => ColumnAccumulator::Boolean(Vec::new()),
179 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => {
180 ColumnAccumulator::Utf8(Vec::new())
181 }
182 _ => ColumnAccumulator::Unsupported,
183 });
184
185 match col_acc {
186 ColumnAccumulator::Int32(buf) => {
187 let typed =
188 array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
189 DatasetsError::InvalidFormat(format!(
190 "Column '{}' type mismatch",
191 field.name()
192 ))
193 })?;
194 for i in 0..typed.len() {
195 buf.push(if typed.is_null(i) {
196 None
197 } else {
198 Some(typed.value(i))
199 });
200 }
201 }
202 ColumnAccumulator::Int64(buf) => {
203 let typed =
204 array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
205 DatasetsError::InvalidFormat(format!(
206 "Column '{}' type mismatch",
207 field.name()
208 ))
209 })?;
210 for i in 0..typed.len() {
211 buf.push(if typed.is_null(i) {
212 None
213 } else {
214 Some(typed.value(i))
215 });
216 }
217 }
218 ColumnAccumulator::Float32(buf) => {
219 let typed =
220 array
221 .as_any()
222 .downcast_ref::<Float32Array>()
223 .ok_or_else(|| {
224 DatasetsError::InvalidFormat(format!(
225 "Column '{}' type mismatch",
226 field.name()
227 ))
228 })?;
229 for i in 0..typed.len() {
230 buf.push(if typed.is_null(i) {
231 None
232 } else {
233 Some(typed.value(i))
234 });
235 }
236 }
237 ColumnAccumulator::Float64(buf) => {
238 let typed =
239 array
240 .as_any()
241 .downcast_ref::<Float64Array>()
242 .ok_or_else(|| {
243 DatasetsError::InvalidFormat(format!(
244 "Column '{}' type mismatch",
245 field.name()
246 ))
247 })?;
248 for i in 0..typed.len() {
249 buf.push(if typed.is_null(i) {
250 None
251 } else {
252 Some(typed.value(i))
253 });
254 }
255 }
256 ColumnAccumulator::Boolean(buf) => {
257 let typed =
258 array
259 .as_any()
260 .downcast_ref::<BooleanArray>()
261 .ok_or_else(|| {
262 DatasetsError::InvalidFormat(format!(
263 "Column '{}' type mismatch",
264 field.name()
265 ))
266 })?;
267 for i in 0..typed.len() {
268 buf.push(if typed.is_null(i) {
269 None
270 } else {
271 Some(typed.value(i))
272 });
273 }
274 }
275 ColumnAccumulator::Utf8(buf) => {
276 let typed =
277 array
278 .as_any()
279 .downcast_ref::<StringArray>()
280 .ok_or_else(|| {
281 DatasetsError::InvalidFormat(format!(
282 "Column '{}' type mismatch",
283 field.name()
284 ))
285 })?;
286 for i in 0..typed.len() {
287 buf.push(if typed.is_null(i) {
288 None
289 } else {
290 Some(typed.value(i).to_owned())
291 });
292 }
293 }
294 ColumnAccumulator::Unsupported => {
295 }
297 }
298 }
299 }
300
301 let mut columns: IndexMap<String, ColumnData> = IndexMap::with_capacity(num_cols);
303 for (col_idx, name) in field_names.iter().enumerate() {
304 match accumulators[col_idx].take() {
305 Some(ColumnAccumulator::Int32(v)) => {
306 columns.insert(name.clone(), ColumnData::Int32(v));
307 }
308 Some(ColumnAccumulator::Int64(v)) => {
309 columns.insert(name.clone(), ColumnData::Int64(v));
310 }
311 Some(ColumnAccumulator::Float32(v)) => {
312 columns.insert(name.clone(), ColumnData::Float32(v));
313 }
314 Some(ColumnAccumulator::Float64(v)) => {
315 columns.insert(name.clone(), ColumnData::Float64(v));
316 }
317 Some(ColumnAccumulator::Boolean(v)) => {
318 columns.insert(name.clone(), ColumnData::Boolean(v));
319 }
320 Some(ColumnAccumulator::Utf8(v)) => {
321 columns.insert(name.clone(), ColumnData::Utf8(v));
322 }
323 Some(ColumnAccumulator::Unsupported) | None => {
324 }
326 }
327 }
328
329 Ok(Self {
330 columns,
331 n_rows: total_rows,
332 })
333 }
334
335 pub fn column(&self, name: &str) -> Option<&ColumnData> {
337 self.columns.get(name)
338 }
339
340 pub fn column_names(&self) -> Vec<&str> {
342 self.columns.keys().map(|s| s.as_str()).collect()
343 }
344
345 pub fn n_rows(&self) -> usize {
347 self.n_rows
348 }
349
350 pub fn n_cols(&self) -> usize {
352 self.columns.len()
353 }
354
355 pub fn to_float_matrix(&self) -> Result<Array2<f64>> {
365 let numeric_cols: Vec<(&str, Vec<f64>)> = self
366 .columns
367 .iter()
368 .filter_map(|(name, col)| col.to_f64_vec().map(|v| (name.as_str(), v)))
369 .collect();
370
371 if numeric_cols.is_empty() {
372 return Err(DatasetsError::InvalidFormat(
373 "No numeric columns found in ParquetDataset".to_string(),
374 ));
375 }
376
377 let n_rows = self.n_rows;
378 let n_cols = numeric_cols.len();
379
380 for (name, col) in &numeric_cols {
382 if col.len() != n_rows {
383 return Err(DatasetsError::InvalidFormat(format!(
384 "Column '{}' has {} rows, expected {}",
385 name,
386 col.len(),
387 n_rows
388 )));
389 }
390 }
391
392 let mut matrix = Array2::<f64>::zeros((n_rows, n_cols));
393 for (j, (_, col)) in numeric_cols.iter().enumerate() {
394 for (i, &v) in col.iter().enumerate() {
395 matrix[[i, j]] = v;
396 }
397 }
398
399 Ok(matrix)
400 }
401}
402
403#[cfg(feature = "parquet_io")]
/// Growable per-column buffer used while record batches are being read.
///
/// The typed variants mirror those of `ColumnData` and are converted into it
/// once reading has finished.
#[derive(Debug)]
enum ColumnAccumulator {
    /// Buffer for 32-bit integer cells.
    Int32(Vec<Option<i32>>),
    /// Buffer for 64-bit integer cells.
    Int64(Vec<Option<i64>>),
    /// Buffer for single-precision float cells.
    Float32(Vec<Option<f32>>),
    /// Buffer for double-precision float cells.
    Float64(Vec<Option<f64>>),
    /// Buffer for boolean cells.
    Boolean(Vec<Option<bool>>),
    /// Buffer for string cells.
    Utf8(Vec<Option<String>>),
    /// Marker for a column whose Arrow type has no typed variant above;
    /// nothing is collected for it and it is left out of the final dataset.
    Unsupported,
}
415
416#[cfg(test)]
421#[cfg(feature = "parquet_io")]
mod tests {
    // Round-trip tests: each case writes a small Parquet file into a temporary
    // directory with `ArrowWriter` and reads it back via `ParquetDataset::from_file`.
    use super::*;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    // NOTE(review): no `Write` trait method appears to be called in this
    // module; confirm and drop this import if it is indeed unused.
    use std::io::Write;
    use std::sync::Arc;

    /// Writes `batches` to `test.parquet` inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the path so the caller keeps
    /// the directory alive for as long as the file is needed.
    fn write_test_parquet(
        schema: Arc<Schema>,
        batches: Vec<RecordBatch>,
    ) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.parquet");
        let file = std::fs::File::create(&path).expect("create file");
        let mut writer = ArrowWriter::try_new(file, schema, None).expect("create parquet writer");
        for batch in batches {
            writer.write(&batch).expect("write batch");
        }
        writer.close().expect("close writer");
        (dir, path)
    }

    /// Int32 and Float64 columns come back with their types and values intact.
    #[test]
    fn test_parquet_read_numeric_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", ArrowDataType::Int32, false),
            Field::new("y", ArrowDataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.1, 2.2, 3.3])),
            ],
        )
        .expect("record batch");

        // `_dir` must stay bound so the temporary directory outlives the read.
        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");

        assert_eq!(ds.n_rows(), 3);
        assert_eq!(ds.n_cols(), 2);
        assert!(ds.column("x").is_some());
        assert!(ds.column("y").is_some());

        if let Some(ColumnData::Int32(vals)) = ds.column("x") {
            assert_eq!(vals[0], Some(1));
            assert_eq!(vals[2], Some(3));
        } else {
            panic!("Expected Int32 column");
        }

        if let Some(ColumnData::Float64(vals)) = ds.column("y") {
            assert!((vals[1].expect("non-null") - 2.2).abs() < 1e-10);
        } else {
            panic!("Expected Float64 column");
        }
    }

    /// A nullable Utf8 column keeps its strings and reports the null as `None`.
    #[test]
    fn test_parquet_read_string_column() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "name",
            ArrowDataType::Utf8,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(StringArray::from(vec![
                Some("alice"),
                None,
                Some("bob"),
            ]))],
        )
        .expect("record batch");

        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");

        assert_eq!(ds.n_rows(), 3);
        if let Some(ColumnData::Utf8(vals)) = ds.column("name") {
            assert_eq!(vals[0], Some("alice".to_owned()));
            assert_eq!(vals[1], None);
            assert_eq!(vals[2], Some("bob".to_owned()));
        } else {
            panic!("Expected Utf8 column");
        }
    }

    /// Column names are reported in schema order, not sorted alphabetically.
    #[test]
    fn test_parquet_column_names_order() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("z", ArrowDataType::Int32, false),
            Field::new("a", ArrowDataType::Float64, false),
            Field::new("m", ArrowDataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![0])),
                Arc::new(Float64Array::from(vec![0.0])),
                Arc::new(arrow::array::Int64Array::from(vec![0i64])),
            ],
        )
        .expect("record batch");

        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");

        assert_eq!(ds.column_names(), vec!["z", "a", "m"]);
    }

    /// Two Float64 columns become a 2x2 matrix with one matrix column per
    /// dataset column.
    #[test]
    fn test_parquet_to_float_matrix() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", ArrowDataType::Float64, false),
            Field::new("b", ArrowDataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Float64Array::from(vec![1.0, 2.0])),
                Arc::new(Float64Array::from(vec![3.0, 4.0])),
            ],
        )
        .expect("record batch");

        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");
        let mat = ds.to_float_matrix().expect("to_float_matrix");

        assert_eq!(mat.shape(), &[2, 2]);
        // Index is [row, column]: column "a" is matrix column 0, "b" is column 1.
        assert!((mat[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((mat[[0, 1]] - 3.0).abs() < 1e-10);
        assert!((mat[[1, 0]] - 2.0).abs() < 1e-10);
        assert!((mat[[1, 1]] - 4.0).abs() < 1e-10);
    }

    /// Null cells in a nullable Float64 column are surfaced as `None`.
    #[test]
    fn test_parquet_nullable_values() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "v",
            ArrowDataType::Float64,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Float64Array::from(vec![
                Some(1.0),
                None,
                Some(3.0),
            ]))],
        )
        .expect("record batch");

        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");

        if let Some(ColumnData::Float64(vals)) = ds.column("v") {
            assert_eq!(vals[0], Some(1.0));
            assert_eq!(vals[1], None);
            assert_eq!(vals[2], Some(3.0));
        } else {
            panic!("Expected Float64 column");
        }
    }

    /// `to_float_matrix` fails when the dataset has only a string column.
    #[test]
    fn test_parquet_to_float_matrix_no_numeric_fails() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "name",
            ArrowDataType::Utf8,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec!["x"]))])
                .expect("record batch");

        let (_dir, path) = write_test_parquet(schema, vec![batch]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");
        assert!(ds.to_float_matrix().is_err());
    }

    /// Rows written as two separate batches are concatenated in write order.
    #[test]
    fn test_parquet_multiple_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "v",
            ArrowDataType::Int32,
            false,
        )]));
        let batch1 =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))])
                .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![3, 4, 5]))],
        )
        .expect("batch2");

        let (_dir, path) = write_test_parquet(schema, vec![batch1, batch2]);
        let ds = ParquetDataset::from_file(&path).expect("from_file");

        assert_eq!(ds.n_rows(), 5);
        if let Some(ColumnData::Int32(vals)) = ds.column("v") {
            assert_eq!(vals.len(), 5);
            assert_eq!(vals[4], Some(5));
        } else {
            panic!("Expected Int32 column");
        }
    }
}