scirs2_datasets/arrow_dataset.rs

//! HuggingFace-compatible Arrow dataset reader.
//!
//! Reads `.arrow` files in Arrow IPC format with optional `dataset_info.json`
//! metadata. This mirrors the on-disk layout used by HuggingFace `datasets`
//! when calling `dataset.save_to_disk()`.
//!
//! # Feature gates
//!
//! | Feature | Provides |
//! |---------|----------|
//! | *(default)* | Magic-byte validation, directory scanning, JSON metadata parse |
//! | `parquet_io` | Full Arrow IPC record-batch reading via the `arrow` crate |
//!
//! # Expected file layout
//!
//! ```text
//! my_dataset/
//!   dataset_info.json        ← optional metadata
//!   train/
//!     data-00000-of-00001.arrow
//!   test/
//!     data-00000-of-00001.arrow
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use scirs2_datasets::arrow_dataset::ArrowDataset;
//!
//! # fn example() -> Result<(), scirs2_datasets::error::DatasetsError> {
//! // Validate magic bytes without parsing the full file
//! let ok = ArrowDataset::validate_arrow_magic("/path/to/data.arrow")?;
//! println!("Is Arrow IPC: {}", ok);
//! # Ok(())
//! # }
//! ```
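//!
//! Loading a directory saved in that layout is symmetric; a minimal sketch
//! (the `my_dataset` path is a placeholder):
//!
//! ```rust,no_run
//! use scirs2_datasets::arrow_dataset::ArrowDataset;
//!
//! # fn example() -> Result<(), scirs2_datasets::error::DatasetsError> {
//! let ds = ArrowDataset::from_directory("my_dataset")?;
//! println!("Backed by {} Arrow file(s)", ds.file_paths().len());
//! # Ok(())
//! # }
//! ```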

use crate::error::{DatasetsError, Result};
use std::collections::HashMap;
use std::io::Read;
use std::path::{Path, PathBuf};

// Leading Arrow IPC magic: on disk the header is "ARROW1" followed by two
// padding bytes; only the 6-byte magic itself is checked here.
const ARROW_MAGIC: &[u8; 6] = b"ARROW1";
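
// Per the Arrow IPC file-format spec, the full framing (unlike the bare
// streaming format) is: "ARROW1" + 2 padding bytes, the stream body, a
// footer, the footer length, and a trailing "ARROW1". Checking the leading
// magic alone cheaply rejects non-Arrow files without parsing any of that.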

// ============================================================================
// Public types
// ============================================================================

/// Feature type descriptor for a HuggingFace dataset column.
#[derive(Debug, Clone)]
pub enum FeatureType {
    /// A scalar value with a given dtype string (e.g. `"int64"`, `"float32"`).
    Value {
        /// Data type name as used in `dataset_info.json`.
        dtype: String,
    },
    /// A variable-length sequence of another feature.
    Sequence {
        /// Inner feature descriptor.
        feature: Box<FeatureType>,
    },
    /// Categorical label with an associated name list.
    ClassLabel {
        /// Ordered list of class names.
        names: Vec<String>,
    },
    /// Free-text column.
    Text,
    /// Image column (raw pixel bytes or file path).
    Image,
    /// Unknown or unsupported feature type.
    Unknown,
}

/// Parsed representation of a HuggingFace `dataset_info.json` file.
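///
/// A typical file, abridged to the keys this parser reads, looks like:
///
/// ```text
/// {
///   "dataset_name": "my_dataset",
///   "version": "1.0.0",
///   "split": "train",
///   "num_rows": 42,
///   "features": { "label": { "names": ["neg", "pos"] } }
/// }
/// ```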
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Dataset name (from the `dataset_name` key).
    pub dataset_name: String,
    /// Dataset version string.
    pub version: String,
    /// Column feature descriptors, keyed by column name.
    pub features: HashMap<String, FeatureType>,
    /// Number of rows reported in the metadata (may differ from actual).
    pub num_rows: Option<usize>,
    /// Split name this metadata describes (e.g. `"train"`, `"test"`).
    pub split: Option<String>,
}

impl Default for DatasetInfo {
    fn default() -> Self {
        Self {
            dataset_name: String::new(),
            version: "0.0.0".to_string(),
            features: HashMap::new(),
            num_rows: None,
            split: None,
        }
    }
}

/// A loaded Arrow IPC dataset handle.
///
/// Without the `parquet_io` feature this struct stores only metadata and the
/// file paths discovered on disk; actual column data is not decoded. Enabling
/// `parquet_io` activates full record-batch parsing via the `arrow` crate.
#[derive(Debug)]
pub struct ArrowDataset {
    /// Parsed `dataset_info.json` metadata, if present.
    pub info: Option<DatasetInfo>,
    /// Ordered list of column names discovered in the first file.
    pub column_names: Vec<String>,
    /// Total number of rows across all loaded files.
    pub num_rows: usize,
    /// Arrow IPC file paths that were loaded.
    pub(crate) file_paths: Vec<PathBuf>,
    /// Raw column bytes keyed by column name: each column's Arrow buffers
    /// concatenated as-is, not decoded values (only populated with
    /// `parquet_io`).
    columns: HashMap<String, Vec<u8>>,
}

impl ArrowDataset {
    // ------------------------------------------------------------------
    // Constructors
    // ------------------------------------------------------------------

    /// Load a HuggingFace-style dataset from a directory.
    ///
    /// Scans `dir` (and one level of subdirectories) for `*.arrow` files and
    /// optionally reads `dataset_info.json` from the same directory. Files
    /// are validated and indexed but not decoded, so `column_names` and
    /// `num_rows` remain empty/zero.
    ///
    /// # Errors
    ///
    /// Returns `DatasetsError::NotFound` if no `.arrow` files are present.
    pub fn from_directory(dir: impl AsRef<Path>) -> Result<Self> {
        let dir = dir.as_ref();

        if !dir.exists() {
            return Err(DatasetsError::NotFound(format!(
                "Directory not found: {}",
                dir.display()
            )));
        }

        // Collect .arrow files (top-level and one sub-directory deep)
        let mut arrow_files: Vec<PathBuf> = Vec::new();
        for entry in std::fs::read_dir(dir).map_err(DatasetsError::IoError)? {
            let entry = entry.map_err(DatasetsError::IoError)?;
            let path = entry.path();
            if path.is_file() {
                if path.extension().and_then(|e| e.to_str()) == Some("arrow") {
                    arrow_files.push(path);
                }
            } else if path.is_dir() {
                // One level of sub-dirs (split directories like train/, test/)
                for sub in std::fs::read_dir(&path).map_err(DatasetsError::IoError)? {
                    let sub = sub.map_err(DatasetsError::IoError)?;
                    let sub_path = sub.path();
                    if sub_path.is_file()
                        && sub_path.extension().and_then(|e| e.to_str()) == Some("arrow")
                    {
                        arrow_files.push(sub_path);
                    }
                }
            }
        }

        if arrow_files.is_empty() {
            return Err(DatasetsError::NotFound(format!(
                "No .arrow files found under: {}",
                dir.display()
            )));
        }

        // Sort for deterministic order
        arrow_files.sort();

        // Attempt to parse dataset_info.json, falling back to the grandparent
        // of the first file (the dataset root when files live in split dirs)
        let info = Self::try_load_dataset_info(dir).ok().or_else(|| {
            arrow_files
                .first()
                .and_then(|p| p.parent())
                .and_then(|p| p.parent())
                .and_then(|root| Self::try_load_dataset_info(root).ok())
        });

        // Validate magic bytes on all files
        for path in &arrow_files {
            Self::validate_arrow_magic(path)?;
        }

        Ok(Self {
            info,
            column_names: Vec::new(),
            num_rows: 0,
            file_paths: arrow_files,
            columns: HashMap::new(),
        })
    }

    /// Load from a single Arrow IPC file.
    ///
    /// Validates magic bytes and (with `parquet_io`) decodes record batches.
    ///
    /// # Errors
    ///
    /// Returns `DatasetsError::NotFound` if the file does not exist, or
    /// `DatasetsError::InvalidFormat` if it does not begin with the Arrow IPC
    /// magic bytes (`ARROW1`).
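    ///
    /// # Example
    ///
    /// A minimal sketch; the path follows the layout shown in the module docs:
    ///
    /// ```rust,no_run
    /// use scirs2_datasets::arrow_dataset::ArrowDataset;
    ///
    /// # fn example() -> Result<(), scirs2_datasets::error::DatasetsError> {
    /// let ds = ArrowDataset::from_arrow_file("my_dataset/train/data-00000-of-00001.arrow")?;
    /// assert_eq!(ds.file_paths().len(), 1);
    /// # Ok(())
    /// # }
    /// ```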
    pub fn from_arrow_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();

        if !path.exists() {
            return Err(DatasetsError::NotFound(format!(
                "Arrow file not found: {}",
                path.display()
            )));
        }

        // Validate magic bytes
        Self::validate_arrow_magic(path)?;

        #[cfg(feature = "parquet_io")]
        {
            Self::from_arrow_file_full(path)
        }

        #[cfg(not(feature = "parquet_io"))]
        {
            Ok(Self {
                info: None,
                column_names: Vec::new(),
                num_rows: 0,
                file_paths: vec![path.to_path_buf()],
                columns: HashMap::new(),
            })
        }
    }

    // ------------------------------------------------------------------
    // Accessors
    // ------------------------------------------------------------------

    /// Returns the column names discovered in the dataset.
    pub fn column_names(&self) -> &[String] {
        &self.column_names
    }

    /// Total number of rows across all loaded files.
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// Dataset metadata from `dataset_info.json`, if present.
    pub fn info(&self) -> Option<&DatasetInfo> {
        self.info.as_ref()
    }

    /// Arrow IPC file paths that back this dataset.
    pub fn file_paths(&self) -> &[PathBuf] {
        &self.file_paths
    }

    // ------------------------------------------------------------------
    // Validation helpers
    // ------------------------------------------------------------------

    /// Validate that a file begins with the Arrow IPC magic bytes (`ARROW1`).
    ///
    /// Returns `Ok(true)` on success, or an error if the file cannot be read
    /// or does not start with the expected magic.
    pub fn validate_arrow_magic(path: impl AsRef<Path>) -> Result<bool> {
        let path = path.as_ref();
        let mut f = std::fs::File::open(path).map_err(DatasetsError::IoError)?;
        let mut buf = [0u8; 6];
        f.read_exact(&mut buf).map_err(|e| {
            DatasetsError::InvalidFormat(format!(
                "Could not read magic bytes from {}: {}",
                path.display(),
                e
            ))
        })?;
        if &buf == ARROW_MAGIC {
            Ok(true)
        } else {
            Err(DatasetsError::InvalidFormat(format!(
                "Not an Arrow IPC file (bad magic bytes): {}",
                path.display()
            )))
        }
    }


    // ------------------------------------------------------------------
    // Internal helpers
    // ------------------------------------------------------------------

    /// Try to parse `dataset_info.json` from `dir`.
    fn try_load_dataset_info(dir: &Path) -> Result<DatasetInfo> {
        let info_path = dir.join("dataset_info.json");
        if !info_path.exists() {
            return Err(DatasetsError::NotFound(
                "dataset_info.json not found".to_string(),
            ));
        }

        let content = std::fs::read_to_string(&info_path).map_err(DatasetsError::IoError)?;

        Self::parse_dataset_info_json(&content)
    }

    /// Parse the JSON string of a `dataset_info.json` file.
    ///
    /// This is a best-effort parser; unknown keys are silently ignored.
    fn parse_dataset_info_json(json: &str) -> Result<DatasetInfo> {
        let value: serde_json::Value =
            serde_json::from_str(json).map_err(|e| DatasetsError::SerdeError(e.to_string()))?;

        let dataset_name = value
            .get("dataset_name")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();

        let version = value
            .get("version")
            .and_then(|v| v.as_str())
            .unwrap_or("0.0.0")
            .to_string();

        let split = value
            .get("split")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());

        let num_rows = value
            .get("num_rows")
            .or_else(|| value.get("num_examples"))
            .and_then(|v| v.as_u64())
            .map(|n| n as usize);

        // Parse features map
        let features = if let Some(feat_map) = value.get("features").and_then(|v| v.as_object()) {
            feat_map
                .iter()
                .map(|(k, v)| (k.clone(), Self::parse_feature_type(v)))
                .collect()
        } else {
            HashMap::new()
        };

        Ok(DatasetInfo {
            dataset_name,
            version,
            features,
            num_rows,
            split,
        })
    }

    /// Parse a single feature descriptor from a JSON value.
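    ///
    /// Accepted notations, mirroring the checks in the body below:
    ///
    /// ```text
    /// "text" | "string"            → Text
    /// "image"                      → Image
    /// "int64" (any other string)   → Value { dtype }
    /// {"names": [...]}             → ClassLabel
    /// {"feature": {...}}           → Sequence
    /// {"dtype": "..."}             → Value
    /// {"_type": "Value" | "ClassLabel" | "Image" | "Sequence", ...}
    /// ```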
    fn parse_feature_type(v: &serde_json::Value) -> FeatureType {
        // Handle both string shorthand and object notation
        if let Some(s) = v.as_str() {
            return match s {
                "text" | "string" => FeatureType::Text,
                "image" => FeatureType::Image,
                other => FeatureType::Value {
                    dtype: other.to_string(),
                },
            };
        }

        if let Some(obj) = v.as_object() {
            // ClassLabel: {"names": ["class_a", "class_b"]}
            if let Some(names_val) = obj.get("names") {
                if let Some(names_arr) = names_val.as_array() {
                    let names: Vec<String> = names_arr
                        .iter()
                        .filter_map(|n| n.as_str().map(|s| s.to_string()))
                        .collect();
                    return FeatureType::ClassLabel { names };
                }
            }

            // Sequence: {"feature": {...}}
            if let Some(inner) = obj.get("feature") {
                return FeatureType::Sequence {
                    feature: Box::new(Self::parse_feature_type(inner)),
                };
            }

            // Value: {"dtype": "int64"}
            if let Some(dtype) = obj.get("dtype").and_then(|d| d.as_str()) {
                return FeatureType::Value {
                    dtype: dtype.to_string(),
                };
            }

            // `_type`-tagged notation. Objects that also carry "names",
            // "feature", or "dtype" were already handled by the structural
            // checks above; these branches catch the remaining cases.

            // Value: {"_type": "Value"} without a usable dtype
            if obj.get("_type").and_then(|t| t.as_str()) == Some("Value") {
                let dtype = obj
                    .get("dtype")
                    .and_then(|d| d.as_str())
                    .unwrap_or("unknown")
                    .to_string();
                return FeatureType::Value { dtype };
            }

            if obj.get("_type").and_then(|t| t.as_str()) == Some("ClassLabel") {
                if let Some(names_arr) = obj.get("names").and_then(|n| n.as_array()) {
                    let names: Vec<String> = names_arr
                        .iter()
                        .filter_map(|n| n.as_str().map(|s| s.to_string()))
                        .collect();
                    return FeatureType::ClassLabel { names };
                }
            }

            if obj.get("_type").and_then(|t| t.as_str()) == Some("Image") {
                return FeatureType::Image;
            }

            if obj.get("_type").and_then(|t| t.as_str()) == Some("Sequence") {
                if let Some(inner) = obj.get("feature") {
                    return FeatureType::Sequence {
                        feature: Box::new(Self::parse_feature_type(inner)),
                    };
                }
            }
        }

        FeatureType::Unknown
    }

    // ------------------------------------------------------------------
    // Full implementation (parquet_io feature)
    // ------------------------------------------------------------------

    #[cfg(feature = "parquet_io")]
    fn from_arrow_file_full(path: &Path) -> Result<Self> {
        use arrow::ipc::reader::FileReader;
        use std::fs::File;

        let file = File::open(path).map_err(DatasetsError::IoError)?;
        let reader = FileReader::try_new(file, None)
            .map_err(|e| DatasetsError::InvalidFormat(format!("Arrow IPC read error: {}", e)))?;

        let schema = reader.schema();
        let column_names: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();

        let mut total_rows = 0usize;
        let mut columns: HashMap<String, Vec<u8>> = HashMap::new();

        for batch_result in reader {
            let batch = batch_result.map_err(|e| {
                DatasetsError::InvalidFormat(format!("Arrow batch read error: {}", e))
            })?;
            total_rows += batch.num_rows();

            // Store raw column bytes (column name → concatenated Arrow
            // buffers). Note this is lossy: the buffers are appended
            // back-to-back with no structural metadata, so they cannot be
            // decoded back into typed arrays from this map alone.
            for (i, field) in schema.fields().iter().enumerate() {
                let col = batch.column(i);
                let buffers = col.to_data().buffers().to_vec();
                let entry = columns.entry(field.name().clone()).or_default();
                for buf in buffers {
                    entry.extend_from_slice(buf.as_slice());
                }
            }
        }
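
        // A typed read would downcast instead of copying raw bytes; a minimal
        // sketch, assuming a float64 column at index 0 (hypothetical; actual
        // column types depend on the dataset):
        //
        //     use arrow::array::Float64Array;
        //     if let Some(arr) = batch.column(0).as_any().downcast_ref::<Float64Array>() {
        //         let first = arr.value(0);
        //     }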

        Ok(Self {
            info: None,
            column_names,
            num_rows: total_rows,
            file_paths: vec![path.to_path_buf()],
            columns,
        })
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Helper: write an Arrow IPC magic header to a temp file.
    ///
    /// Takes a per-test `name` so tests running in parallel never race on a
    /// shared path.
    fn temp_arrow_file(name: &str, valid: bool) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!("{}.arrow", name));
        let mut f = std::fs::File::create(&path).expect("create temp file");
        if valid {
            // Write ARROW1 magic + padding + dummy continuation bytes
            f.write_all(b"ARROW1\x00\x00some_padding_bytes_for_test")
                .expect("write magic");
        } else {
            // Write wrong magic
            f.write_all(b"NOTARROW_FILE_CONTENT")
                .expect("write wrong magic");
        }
        path
    }

    #[test]
    fn arrow_dataset_validates_magic_bytes() {
        let path = temp_arrow_file("test_validates_magic", true);
        let result = ArrowDataset::validate_arrow_magic(&path);
        assert!(result.is_ok(), "valid Arrow magic should succeed");
        assert!(
            result.expect("valid arrow result"),
            "validate_arrow_magic should return true for valid magic"
        );
    }

    #[test]
    fn arrow_dataset_rejects_wrong_magic() {
        let path = temp_arrow_file("test_rejects_wrong_magic", false);
        let result = ArrowDataset::validate_arrow_magic(&path);
        assert!(result.is_err(), "wrong magic should return an error");
        if let Err(DatasetsError::InvalidFormat(msg)) = result {
            assert!(
                msg.contains("magic bytes"),
                "error should mention magic bytes, got: {}",
                msg
            );
        } else {
            panic!("expected InvalidFormat error");
        }
    }

    #[test]
    #[cfg(not(feature = "parquet_io"))]
    fn arrow_dataset_from_arrow_file_valid() {
        let path = temp_arrow_file("test_from_file_valid", true);
        // Without parquet_io the constructor accepts valid magic and returns a stub
        let result = ArrowDataset::from_arrow_file(&path);
        assert!(
            result.is_ok(),
            "from_arrow_file with valid magic should succeed"
        );
        let ds = result.expect("valid arrow dataset");
        assert_eq!(ds.file_paths().len(), 1);
    }

    /// With parquet_io, from_arrow_file on a magic-only stub is expected to
    /// fail: the Arrow IPC reader parses well beyond the magic bytes, and a
    /// complete IPC file cannot easily be constructed in a unit test without
    /// the `arrow` crate itself.
    #[test]
    #[cfg(feature = "parquet_io")]
    fn arrow_dataset_from_arrow_file_valid_parquet_io() {
        let path = temp_arrow_file("test_from_file_valid_full", true);
        let result = ArrowDataset::from_arrow_file(&path);
        // Either succeeds (unlikely for a stub) or fails with InvalidFormat
        match result {
            Ok(_) => {}
            Err(DatasetsError::InvalidFormat(_)) => {} // Expected for stub file
            Err(other) => panic!("unexpected error variant: {:?}", other),
        }
    }

    #[test]
    fn arrow_dataset_from_arrow_file_invalid() {
        let path = temp_arrow_file("test_from_file_invalid", false);
        let result = ArrowDataset::from_arrow_file(&path);
        assert!(
            result.is_err(),
            "from_arrow_file with bad magic should fail"
        );
    }

    #[test]
    fn arrow_dataset_from_directory_empty_dir() {
        let dir = std::env::temp_dir().join("test_empty_arrow_dir");
        std::fs::create_dir_all(&dir).expect("create temp dir");
        // Remove any stale .arrow files from previous runs
        for entry in std::fs::read_dir(&dir).expect("read dir") {
            let entry = entry.expect("entry");
            if entry.path().extension().and_then(|e| e.to_str()) == Some("arrow") {
                std::fs::remove_file(entry.path()).ok();
            }
        }
        let result = ArrowDataset::from_directory(&dir);
        assert!(result.is_err(), "empty dir should return NotFound");
        if let Err(DatasetsError::NotFound(_)) = result {
            // expected
        } else {
            panic!("expected NotFound error for empty directory");
        }
    }

    #[test]
    fn arrow_dataset_from_directory_with_arrow_file() {
        let dir = std::env::temp_dir().join("test_arrow_dir_with_file");
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let arrow_path = dir.join("data-00000-of-00001.arrow");
        {
            let mut f = std::fs::File::create(&arrow_path).expect("create arrow");
            f.write_all(b"ARROW1\x00\x00dummy_ipc_content_for_test")
                .expect("write arrow");
        }
        let result = ArrowDataset::from_directory(&dir);
        assert!(
            result.is_ok(),
            "directory with valid arrow file should succeed"
        );
        let ds = result.expect("arrow dataset from dir");
        assert_eq!(ds.file_paths().len(), 1);
    }

    #[test]
    fn dataset_info_default() {
        let info = DatasetInfo::default();
        assert!(info.dataset_name.is_empty());
        assert_eq!(info.version, "0.0.0");
        assert!(info.features.is_empty());
        assert!(info.num_rows.is_none());
        assert!(info.split.is_none());
    }

    #[test]
    fn dataset_info_parse_json() {
        let json = r#"{
            "dataset_name": "my_dataset",
            "version": "1.0.0",
            "split": "train",
            "num_rows": 42,
            "features": {
                "text": "text",
                "label": {"names": ["neg", "pos"]},
                "score": {"dtype": "float32"}
            }
        }"#;
        let info = ArrowDataset::parse_dataset_info_json(json).expect("parse dataset_info.json");
        assert_eq!(info.dataset_name, "my_dataset");
        assert_eq!(info.version, "1.0.0");
        assert_eq!(info.split.as_deref(), Some("train"));
        assert_eq!(info.num_rows, Some(42));
        assert_eq!(info.features.len(), 3);
        if let FeatureType::ClassLabel { names } = &info.features["label"] {
            assert_eq!(names, &["neg", "pos"]);
        } else {
            panic!("expected ClassLabel for 'label' feature");
        }
    }
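
    /// Sketch covering the `_type`-tagged object notation; it should map to
    /// the same variants as the shorthand forms exercised above.
    #[test]
    fn dataset_info_parse_json_type_notation() {
        let json = r#"{
            "features": {
                "score": {"_type": "Value", "dtype": "float32"},
                "tokens": {"_type": "Sequence", "feature": {"dtype": "int64"}}
            }
        }"#;
        let info = ArrowDataset::parse_dataset_info_json(json).expect("parse _type notation");
        assert!(matches!(
            info.features["score"],
            FeatureType::Value { ref dtype } if dtype == "float32"
        ));
        assert!(matches!(
            info.features["tokens"],
            FeatureType::Sequence { .. }
        ));
    }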

    #[test]
    fn arrow_dataset_nonexistent_file() {
        let result = ArrowDataset::from_arrow_file("/nonexistent/path/data.arrow");
        assert!(result.is_err());
        if let Err(DatasetsError::NotFound(_)) = result {
            // expected
        } else {
            panic!("expected NotFound for nonexistent file");
        }
    }

    #[test]
    fn arrow_dataset_nonexistent_directory() {
        let result = ArrowDataset::from_directory("/nonexistent/arrow_dataset_dir_xyz");
        assert!(result.is_err());
        if let Err(DatasetsError::NotFound(_)) = result {
            // expected
        } else {
            panic!("expected NotFound for nonexistent directory");
        }
    }
}