// shape_runtime/schema_inference/mod.rs

pub mod lockfile;
8use arrow_schema::Schema as ArrowSchema;
9use std::path::Path;
10
/// Errors that can occur while inferring an Arrow schema from a data file.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SchemaInferError {
    /// The input path did not exist on disk.
    #[error("File not found: {0}")]
    FileNotFound(String),
    /// The file extension was not one of the supported formats.
    #[error("Unsupported file format: '{0}'. Supported: .csv, .json, .ndjson, .parquet")]
    UnsupportedFormat(String),
    /// The file was opened, but reading or schema inference failed.
    #[error("Schema inference failed: {0}")]
    ParseError(String),
}
24
25pub fn infer_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
34 let ext = path
35 .extension()
36 .and_then(|e| e.to_str())
37 .unwrap_or("")
38 .to_lowercase();
39
40 match ext.as_str() {
41 "csv" => infer_csv_schema(path),
42 "json" | "ndjson" => infer_json_schema(path),
43 "parquet" => infer_parquet_schema(path),
44 other => Err(SchemaInferError::UnsupportedFormat(other.to_string())),
45 }
46}
47
48pub fn infer_csv_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
50 use arrow_csv::reader::Format;
51 use std::fs::File;
52 use std::io::BufReader;
53
54 let file = File::open(path).map_err(|e| {
55 if e.kind() == std::io::ErrorKind::NotFound {
56 SchemaInferError::FileNotFound(path.display().to_string())
57 } else {
58 SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
59 }
60 })?;
61
62 let format = Format::default().with_header(true);
63 let (schema, _records_read) = format
64 .infer_schema(BufReader::new(&file), Some(100))
65 .map_err(|e| SchemaInferError::ParseError(format!("CSV schema inference: {}", e)))?;
66
67 Ok(schema)
68}
69
70pub fn infer_json_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
72 use std::fs::File;
73 use std::io::BufReader;
74
75 let file = File::open(path).map_err(|e| {
76 if e.kind() == std::io::ErrorKind::NotFound {
77 SchemaInferError::FileNotFound(path.display().to_string())
78 } else {
79 SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
80 }
81 })?;
82
83 let (schema, _records_read) =
84 arrow_json::reader::infer_json_schema(BufReader::new(file), Some(100))
85 .map_err(|e| SchemaInferError::ParseError(format!("JSON schema inference: {}", e)))?;
86
87 Ok(schema)
88}
89
90pub fn infer_parquet_schema(path: &Path) -> Result<ArrowSchema, SchemaInferError> {
92 use parquet::file::reader::{FileReader, SerializedFileReader};
93 use std::fs::File;
94
95 let file = File::open(path).map_err(|e| {
96 if e.kind() == std::io::ErrorKind::NotFound {
97 SchemaInferError::FileNotFound(path.display().to_string())
98 } else {
99 SchemaInferError::ParseError(format!("Cannot open '{}': {}", path.display(), e))
100 }
101 })?;
102
103 let reader = SerializedFileReader::new(file)
104 .map_err(|e| SchemaInferError::ParseError(format!("Parquet reader: {}", e)))?;
105
106 let parquet_schema = reader.metadata().file_metadata().schema_descr_ptr();
107 let arrow_schema = parquet::arrow::parquet_to_arrow_schema(
108 &parquet_schema,
109 None, )
111 .map_err(|e| SchemaInferError::ParseError(format!("Parquet→Arrow schema: {}", e)))?;
112
113 Ok(arrow_schema)
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::Arc;

    /// Write `content` to a file named `name` under the OS temp dir and
    /// return its path. Used for CSV and NDJSON fixtures alike.
    fn write_fixture(name: &str, content: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(name);
        let mut file = std::fs::File::create(&path).unwrap();
        file.write_all(content.as_bytes()).unwrap();
        path
    }

    #[test]
    fn test_infer_csv_schema() {
        let path = write_fixture(
            "test_infer_csv.csv",
            "name,value,active\nalpha,1.5,true\nbeta,2.7,false\n",
        );
        let schema = infer_schema(&path).unwrap();
        let names = schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["name", "value", "active"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_infer_json_schema() {
        let path = write_fixture(
            "test_infer_json.ndjson",
            r#"{"name":"alpha","value":1.5}
{"name":"beta","value":2.7}
"#,
        );
        let schema = infer_schema(&path).unwrap();
        let mut names = schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect::<Vec<_>>();
        names.sort();
        assert_eq!(names, vec!["name", "value"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_infer_parquet_schema() {
        use arrow_array::{Float64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};

        // Build a tiny two-column batch and round-trip it through a
        // Parquet file on disk.
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("price", DataType::Float64, false),
        ]));
        let symbols: StringArray = vec!["AAPL", "GOOG"].into();
        let prices: Float64Array = vec![150.0, 2800.0].into();
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(symbols), Arc::new(prices)])
                .unwrap();

        let path = std::env::temp_dir().join("test_infer_parquet.parquet");
        let file = std::fs::File::create(&path).unwrap();
        let mut writer =
            parquet::arrow::arrow_writer::ArrowWriter::try_new(file, schema, None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let inferred = infer_schema(&path).unwrap();
        let names = inferred
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["symbol", "price"]);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_unsupported_extension() {
        let path = std::env::temp_dir().join("test_unsupported.xlsx");
        std::fs::File::create(&path).unwrap();
        let err = infer_schema(&path).unwrap_err();
        assert!(matches!(err, SchemaInferError::UnsupportedFormat(_)));
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_missing_file() {
        let err = infer_schema(Path::new("/nonexistent/file.csv")).unwrap_err();
        assert!(matches!(err, SchemaInferError::FileNotFound(_)));
    }
}