// shape_runtime/schema_inference/mod.rs
//! Schema inference for data files.
//!
//! Reads just the schema (column names + types) from CSV, JSON, and Parquet files
//! without loading the full data. Used for compile-time schema validation.

pub mod lockfile;

use arrow_schema::Schema as ArrowSchema;
use std::path::Path;

11/// Error type for schema inference operations.
12#[derive(Debug, Clone, thiserror::Error)]
13pub enum SchemaInferError {
14    /// File does not exist or cannot be opened.
15    #[error("File not found: {0}")]
16    FileNotFound(String),
17    /// File extension is not supported (.csv, .json, .ndjson, .parquet).
18    #[error("Unsupported file format: '{0}'. Supported: .csv, .json, .ndjson, .parquet")]
19    UnsupportedFormat(String),
20    /// Failed to parse file header or infer schema.
21    #[error("Schema inference failed: {0}")]
22    ParseError(String),
23}
24
25/// Infer the Arrow schema from a data file by extension.
26///
27/// Dispatches to the appropriate reader based on file extension:
28/// - `.csv` → CSV header + sample inference
29/// - `.json` / `.ndjson` → JSON schema inference
30/// - `.parquet` → Parquet footer metadata
31///
32/// Only reads the minimum data needed (header/sample rows/footer), not the full file.
33pub fn infer_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
34    let ext = path
35        .extension()
36        .and_then(|e| e.to_str())
37        .unwrap_or("")
38        .to_lowercase();
39
40    match ext.as_str() {
41        "csv" => infer_csv_schema(path),
42        "json" | "ndjson" => infer_json_schema(path),
43        "parquet" => infer_parquet_schema(path),
44        other => Err(SchemaInferError::UnsupportedFormat(other.to_string())),
45    }
46}
47
48/// Infer schema from a CSV file using header + sample rows.
49pub fn infer_csv_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
50    use arrow_csv::reader::Format;
51    use std::fs::File;
52    use std::io::BufReader;
53
54    let file = File::open(path).map_err(|e| {
55        if e.kind() == std::io::ErrorKind::NotFound {
56            SchemaInferError::FileNotFound(path.display().to_string())
57        } else {
58            SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
59        }
60    })?;
61
62    let format = Format::default().with_header(true);
63    let (schema, _records_read) = format
64        .infer_schema(BufReader::new(&file), Some(100))
65        .map_err(|e| SchemaInferError::ParseError(format!("CSV schema inference: {}", e)))?;
66
67    Ok(schema)
68}
69
70/// Infer schema from a JSON/NDJSON file using sample rows.
71pub fn infer_json_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
72    use std::fs::File;
73    use std::io::BufReader;
74
75    let file = File::open(path).map_err(|e| {
76        if e.kind() == std::io::ErrorKind::NotFound {
77            SchemaInferError::FileNotFound(path.display().to_string())
78        } else {
79            SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
80        }
81    })?;
82
83    let (schema, _records_read) =
84        arrow_json::reader::infer_json_schema(BufReader::new(file), Some(100))
85            .map_err(|e| SchemaInferError::ParseError(format!("JSON schema inference: {}", e)))?;
86
87    Ok(schema)
88}
89
90/// Infer schema from a Parquet file by reading only the footer metadata.
91pub fn infer_parquet_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
92    use parquet::file::reader::{FileReader, SerializedFileReader};
93    use std::fs::File;
94
95    let file = File::open(path).map_err(|e| {
96        if e.kind() == std::io::ErrorKind::NotFound {
97            SchemaInferError::FileNotFound(path.display().to_string())
98        } else {
99            SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
100        }
101    })?;
102
103    let reader = SerializedFileReader::new(file)
104        .map_err(|e| SchemaInferError::ParseError(format!("Parquet reader: {}", e)))?;
105
106    let parquet_schema = reader.metadata().file_metadata().schema_descr_ptr();
107    let arrow_schema = parquet::arrow::parquet_to_arrow_schema(
108        &parquet_schema,
109        None, // no key-value metadata filter
110    )
111    .map_err(|e| SchemaInferError::ParseError(format!("Parquet→Arrow schema: {}", e)))?;
112
113    Ok(arrow_schema)
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::Arc;

    /// Write `content` to a uniquely named file in the system temp dir and
    /// return its path.
    ///
    /// The process id is embedded in the filename so concurrent test
    /// processes (e.g. parallel CI jobs sharing one temp dir) cannot clobber
    /// each other's fixtures. Tests within one process already use distinct
    /// base names.
    fn temp_file(name: &str, content: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!("{}_{}", std::process::id(), name));
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
        path
    }

    #[test]
    fn test_infer_csv_schema() {
        let path = temp_file(
            "test_infer_csv.csv",
            "name,value,active\nalpha,1.5,true\nbeta,2.7,false\n",
        );
        let schema = infer_schema(&path).unwrap();
        let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, vec!["name", "value", "active"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_infer_json_schema() {
        let path = temp_file(
            "test_infer_json.ndjson",
            r#"{"name":"alpha","value":1.5}
{"name":"beta","value":2.7}
"#,
        );
        let schema = infer_schema(&path).unwrap();
        // JSON field order is not guaranteed; sort before comparing.
        let mut names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        names.sort();
        assert_eq!(names, vec!["name", "value"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_infer_parquet_schema() {
        use arrow_array::{Float64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};

        // Build a tiny two-column batch and write it out as Parquet.
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("price", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
                Arc::new(Float64Array::from(vec![150.0, 2800.0])),
            ],
        )
        .unwrap();

        // Unique per process for the same reason as `temp_file`.
        let path = std::env::temp_dir().join(format!(
            "{}_test_infer_parquet.parquet",
            std::process::id()
        ));
        let file = std::fs::File::create(&path).unwrap();
        let mut writer =
            parquet::arrow::arrow_writer::ArrowWriter::try_new(file, schema, None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let inferred = infer_schema(&path).unwrap();
        let names: Vec<&str> = inferred
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        assert_eq!(names, vec!["symbol", "price"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_unsupported_extension() {
        // File content is irrelevant; only the extension drives the dispatch.
        let path = temp_file("test_unsupported.xlsx", "");
        let err = infer_schema(&path).unwrap_err();
        assert!(matches!(err, SchemaInferError::UnsupportedFormat(_)));
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_missing_file() {
        let path = Path::new("/nonexistent/file.csv");
        let err = infer_schema(path).unwrap_err();
        assert!(matches!(err, SchemaInferError::FileNotFound(_)));
    }
}