shape_runtime/schema_inference/
lockfile.rs1use arrow_schema::{DataType, Field, Schema as ArrowSchema};
4use sha2::{Digest, Sha256};
5use std::collections::BTreeMap;
6use std::path::Path;
7
8use crate::package_lock::{ArtifactDeterminism, LockedArtifact, PackageLock};
9
/// Lockfile type used for cached schema artifacts; an alias for [`PackageLock`].
pub type SchemaLockfile = PackageLock;

/// Namespace under which schema artifacts are looked up and stored in the lockfile.
const SCHEMA_ARTIFACT_NAMESPACE: &str = "schema.infer";
/// Producer identifier recorded on every schema artifact created by this module.
const SCHEMA_ARTIFACT_PRODUCER: &str = "shape-runtime/schema_inference@v1";
15
16pub fn infer_or_cached(
21 file_path: &Path,
22 source_key: &str,
23 lockfile: &mut SchemaLockfile,
24) -> Result<(ArrowSchema, bool), super::SchemaInferError> {
25 let format = file_path
26 .extension()
27 .and_then(|e| e.to_str())
28 .unwrap_or("")
29 .to_lowercase();
30 let file_hash = compute_file_hash(file_path).unwrap_or_else(|_| "unknown".to_string());
31
32 let (inputs, determinism) = schema_artifact_inputs(source_key, &format, &file_hash);
33 let inputs_hash = PackageLock::artifact_inputs_hash(inputs.clone(), &determinism)
34 .map_err(super::SchemaInferError::ParseError)?;
35
36 if let Some(artifact) = lockfile.artifact(SCHEMA_ARTIFACT_NAMESPACE, source_key, &inputs_hash) {
37 if let Ok(schema) = artifact_to_arrow_schema(artifact) {
38 return Ok((schema, true));
39 }
40 }
41
42 let schema = super::infer_schema(file_path)?;
43 let payload = schema_to_payload(&format, &schema);
44 let artifact = LockedArtifact::new(
45 SCHEMA_ARTIFACT_NAMESPACE,
46 source_key,
47 SCHEMA_ARTIFACT_PRODUCER,
48 determinism,
49 inputs,
50 payload,
51 )
52 .map_err(super::SchemaInferError::ParseError)?;
53 lockfile
54 .upsert_artifact(artifact)
55 .map_err(super::SchemaInferError::ParseError)?;
56
57 Ok((schema, false))
58}
59
60pub fn compute_file_hash(path: &Path) -> std::io::Result<String> {
62 use std::io::Read;
63 let mut file = std::fs::File::open(path)?;
64 let mut buffer = vec![0u8; 4096];
65 let bytes_read = file.read(&mut buffer)?;
66 buffer.truncate(bytes_read);
67
68 let mut hasher = Sha256::new();
69 hasher.update(&buffer);
70 let result = hasher.finalize();
71 Ok(format!("sha256:{:x}", result))
72}
73
74fn schema_artifact_inputs(
75 source_key: &str,
76 format: &str,
77 file_hash: &str,
78) -> (BTreeMap<String, String>, ArtifactDeterminism) {
79 let mut inputs = BTreeMap::new();
80 inputs.insert("source".to_string(), source_key.to_string());
81 inputs.insert("format".to_string(), format.to_string());
82 inputs.insert("file_hash".to_string(), file_hash.to_string());
83
84 let determinism = ArtifactDeterminism::External {
85 fingerprints: BTreeMap::from([(format!("file:{source_key}"), file_hash.to_string())]),
86 };
87 (inputs, determinism)
88}
89
90fn artifact_to_arrow_schema(artifact: &LockedArtifact) -> Result<ArrowSchema, String> {
91 let payload = artifact.payload()?;
92 payload_to_schema(&payload)
93}
94
95fn schema_to_payload(format: &str, schema: &ArrowSchema) -> shape_wire::WireValue {
96 let columns = schema
97 .fields()
98 .iter()
99 .map(|field| {
100 shape_wire::WireValue::Object(BTreeMap::from([
101 (
102 "name".to_string(),
103 shape_wire::WireValue::String(field.name().clone()),
104 ),
105 (
106 "data_type".to_string(),
107 shape_wire::WireValue::String(format_data_type(field.data_type())),
108 ),
109 (
110 "nullable".to_string(),
111 shape_wire::WireValue::Bool(field.is_nullable()),
112 ),
113 ]))
114 })
115 .collect::<Vec<_>>();
116
117 shape_wire::WireValue::Object(BTreeMap::from([
118 (
119 "format".to_string(),
120 shape_wire::WireValue::String(format.to_string()),
121 ),
122 ("columns".to_string(), shape_wire::WireValue::Array(columns)),
123 ]))
124}
125
126fn payload_to_schema(payload: &shape_wire::WireValue) -> Result<ArrowSchema, String> {
127 let shape_wire::WireValue::Object(map) = payload else {
128 return Err("schema artifact payload is not an object".to_string());
129 };
130 let columns = map
131 .get("columns")
132 .ok_or_else(|| "schema artifact payload missing columns".to_string())?;
133 let shape_wire::WireValue::Array(column_values) = columns else {
134 return Err("schema artifact payload columns must be an array".to_string());
135 };
136
137 let mut fields = Vec::with_capacity(column_values.len());
138 for column in column_values {
139 let shape_wire::WireValue::Object(col) = column else {
140 return Err("schema artifact column must be an object".to_string());
141 };
142 let name = col
143 .get("name")
144 .and_then(shape_wire::WireValue::as_str)
145 .ok_or_else(|| "schema artifact column missing name".to_string())?;
146 let data_type = col
147 .get("data_type")
148 .and_then(shape_wire::WireValue::as_str)
149 .ok_or_else(|| "schema artifact column missing data_type".to_string())?;
150 let nullable = col
151 .get("nullable")
152 .and_then(shape_wire::WireValue::as_bool)
153 .ok_or_else(|| "schema artifact column missing nullable".to_string())?;
154
155 fields.push(Field::new(name, parse_data_type(data_type), nullable));
156 }
157
158 Ok(ArrowSchema::new(fields))
159}
160
161fn format_data_type(dt: &DataType) -> String {
163 match dt {
164 DataType::Float64 => "Float64".to_string(),
165 DataType::Float32 => "Float32".to_string(),
166 DataType::Int64 => "Int64".to_string(),
167 DataType::Int32 => "Int32".to_string(),
168 DataType::Int16 => "Int16".to_string(),
169 DataType::Int8 => "Int8".to_string(),
170 DataType::UInt64 => "UInt64".to_string(),
171 DataType::UInt32 => "UInt32".to_string(),
172 DataType::Boolean => "Boolean".to_string(),
173 DataType::Utf8 => "Utf8".to_string(),
174 DataType::LargeUtf8 => "LargeUtf8".to_string(),
175 DataType::Timestamp(unit, tz) => {
176 let unit_str = match unit {
177 arrow_schema::TimeUnit::Second => "s",
178 arrow_schema::TimeUnit::Millisecond => "ms",
179 arrow_schema::TimeUnit::Microsecond => "us",
180 arrow_schema::TimeUnit::Nanosecond => "ns",
181 };
182 match tz {
183 Some(tz) => format!("Timestamp({},{})", unit_str, tz),
184 None => format!("Timestamp({})", unit_str),
185 }
186 }
187 DataType::Date32 => "Date32".to_string(),
188 DataType::Date64 => "Date64".to_string(),
189 other => format!("{other:?}"),
190 }
191}
192
193fn parse_data_type(s: &str) -> DataType {
195 match s {
196 "Float64" => DataType::Float64,
197 "Float32" => DataType::Float32,
198 "Int64" => DataType::Int64,
199 "Int32" => DataType::Int32,
200 "Int16" => DataType::Int16,
201 "Int8" => DataType::Int8,
202 "UInt64" => DataType::UInt64,
203 "UInt32" => DataType::UInt32,
204 "Boolean" => DataType::Boolean,
205 "Utf8" => DataType::Utf8,
206 "LargeUtf8" => DataType::LargeUtf8,
207 "Date32" => DataType::Date32,
208 "Date64" => DataType::Date64,
209 s if s.starts_with("Timestamp(") => {
210 let inner = &s[10..s.len() - 1];
211 let parts: Vec<&str> = inner.splitn(2, ',').collect();
212 let unit = match parts[0] {
213 "s" => arrow_schema::TimeUnit::Second,
214 "ms" => arrow_schema::TimeUnit::Millisecond,
215 "us" => arrow_schema::TimeUnit::Microsecond,
216 "ns" => arrow_schema::TimeUnit::Nanosecond,
217 _ => arrow_schema::TimeUnit::Nanosecond,
218 };
219 let tz = parts.get(1).map(|value| value.to_string().into());
220 DataType::Timestamp(unit, tz)
221 }
222 _ => DataType::Utf8,
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates (or truncates) the file at `path` and writes each entry of `lines`
    /// followed by a newline.
    fn write_lines(path: &Path, lines: &[&str]) {
        let mut file = std::fs::File::create(path).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
    }

    /// A lockfile holding one schema artifact survives a write/read cycle.
    #[test]
    fn test_lockfile_roundtrip() {
        use shape_wire::WireValue;

        let column = WireValue::Object(BTreeMap::from([
            ("name".to_string(), WireValue::String("price".to_string())),
            (
                "data_type".to_string(),
                WireValue::String("Float64".to_string()),
            ),
            ("nullable".to_string(), WireValue::Bool(false)),
        ]));
        let payload = WireValue::Object(BTreeMap::from([(
            "columns".to_string(),
            WireValue::Array(vec![column]),
        )]));
        let determinism = ArtifactDeterminism::External {
            fingerprints: BTreeMap::from([(
                "file:data.csv".to_string(),
                "sha256:deadbeef".to_string(),
            )]),
        };
        let artifact = LockedArtifact::new(
            SCHEMA_ARTIFACT_NAMESPACE,
            "data.csv",
            SCHEMA_ARTIFACT_PRODUCER,
            determinism,
            BTreeMap::new(),
            payload,
        )
        .unwrap();

        let mut lockfile = SchemaLockfile::new();
        lockfile.upsert_artifact(artifact).unwrap();

        let dir = tempfile::tempdir().unwrap();
        let lock_path = dir.path().join("shape.lock");
        lockfile.write(&lock_path).unwrap();

        let reloaded = SchemaLockfile::read(&lock_path).unwrap();
        assert_eq!(reloaded.artifacts.len(), 1);
    }

    /// Changing the file content changes the fingerprint, and both carry the prefix.
    #[test]
    fn test_compute_file_hash_changes_on_content_change() {
        let dir = tempfile::tempdir().unwrap();
        let sample = dir.path().join("sample.csv");

        std::fs::write(&sample, "a,b\n1,2\n").unwrap();
        let before = compute_file_hash(&sample).unwrap();

        std::fs::write(&sample, "a,b\n1,2\n3,4\n").unwrap();
        let after = compute_file_hash(&sample).unwrap();

        assert_ne!(before, after);
        for hash in [&before, &after] {
            assert!(hash.starts_with("sha256:"));
        }
    }

    /// First call infers, second call hits the cache, and a modified file infers again.
    #[test]
    fn test_infer_or_cached() {
        let dir = tempfile::tempdir().unwrap();
        let csv_path = dir.path().join("cached_test.csv");
        write_lines(&csv_path, &["x,y", "1,2", "3,4"]);

        let mut lockfile = SchemaLockfile::new();

        let (_first_schema, first_cached) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(!first_cached);

        let (_second_schema, second_cached) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(second_cached);

        // Appending a row changes the fingerprint, so the cache entry no longer matches.
        write_lines(&csv_path, &["x,y", "1,2", "3,4", "5,6"]);

        let (_third_schema, third_cached) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(!third_cached);
    }
}