Skip to main content

shape_runtime/schema_inference/
lockfile.rs

1//! Schema cache helpers backed by generic `shape.lock` artifacts.
2
3use arrow_schema::{DataType, Field, Schema as ArrowSchema};
4use sha2::{Digest, Sha256};
5use std::collections::BTreeMap;
6use std::path::Path;
7
8use crate::package_lock::{ArtifactDeterminism, LockedArtifact, PackageLock};
9
10/// Unified schema lock type backed by `shape.lock`.
11pub type SchemaLockfile = PackageLock;
12
13const SCHEMA_ARTIFACT_NAMESPACE: &str = "schema.infer";
14const SCHEMA_ARTIFACT_PRODUCER: &str = "shape-runtime/schema_inference@v1";
15
16/// Infer a schema, using the unified lockfile cache when possible.
17///
18/// Returns `(schema, from_cache)` — `from_cache` is true if the lockfile had
19/// a valid entry for the current external file fingerprint.
20pub fn infer_or_cached(
21    file_path: &Path,
22    source_key: &str,
23    lockfile: &mut SchemaLockfile,
24) -> Result<(ArrowSchema, bool), super::SchemaInferError> {
25    let format = file_path
26        .extension()
27        .and_then(|e| e.to_str())
28        .unwrap_or("")
29        .to_lowercase();
30    let file_hash = compute_file_hash(file_path).unwrap_or_else(|_| "unknown".to_string());
31
32    let (inputs, determinism) = schema_artifact_inputs(source_key, &format, &file_hash);
33    let inputs_hash = PackageLock::artifact_inputs_hash(inputs.clone(), &determinism)
34        .map_err(super::SchemaInferError::ParseError)?;
35
36    if let Some(artifact) = lockfile.artifact(SCHEMA_ARTIFACT_NAMESPACE, source_key, &inputs_hash) {
37        if let Ok(schema) = artifact_to_arrow_schema(artifact) {
38            return Ok((schema, true));
39        }
40    }
41
42    let schema = super::infer_schema(file_path)?;
43    let payload = schema_to_payload(&format, &schema);
44    let artifact = LockedArtifact::new(
45        SCHEMA_ARTIFACT_NAMESPACE,
46        source_key,
47        SCHEMA_ARTIFACT_PRODUCER,
48        determinism,
49        inputs,
50        payload,
51    )
52    .map_err(super::SchemaInferError::ParseError)?;
53    lockfile
54        .upsert_artifact(artifact)
55        .map_err(super::SchemaInferError::ParseError)?;
56
57    Ok((schema, false))
58}
59
60/// Compute SHA-256 hash of the first 4KB of a file.
61pub fn compute_file_hash(path: &Path) -> std::io::Result<String> {
62    use std::io::Read;
63    let mut file = std::fs::File::open(path)?;
64    let mut buffer = vec![0u8; 4096];
65    let bytes_read = file.read(&mut buffer)?;
66    buffer.truncate(bytes_read);
67
68    let mut hasher = Sha256::new();
69    hasher.update(&buffer);
70    let result = hasher.finalize();
71    Ok(format!("sha256:{:x}", result))
72}
73
74fn schema_artifact_inputs(
75    source_key: &str,
76    format: &str,
77    file_hash: &str,
78) -> (BTreeMap<String, String>, ArtifactDeterminism) {
79    let mut inputs = BTreeMap::new();
80    inputs.insert("source".to_string(), source_key.to_string());
81    inputs.insert("format".to_string(), format.to_string());
82    inputs.insert("file_hash".to_string(), file_hash.to_string());
83
84    let determinism = ArtifactDeterminism::External {
85        fingerprints: BTreeMap::from([(format!("file:{source_key}"), file_hash.to_string())]),
86    };
87    (inputs, determinism)
88}
89
90fn artifact_to_arrow_schema(artifact: &LockedArtifact) -> Result<ArrowSchema, String> {
91    let payload = artifact.payload()?;
92    payload_to_schema(&payload)
93}
94
95fn schema_to_payload(format: &str, schema: &ArrowSchema) -> shape_wire::WireValue {
96    let columns = schema
97        .fields()
98        .iter()
99        .map(|field| {
100            shape_wire::WireValue::Object(BTreeMap::from([
101                (
102                    "name".to_string(),
103                    shape_wire::WireValue::String(field.name().clone()),
104                ),
105                (
106                    "data_type".to_string(),
107                    shape_wire::WireValue::String(format_data_type(field.data_type())),
108                ),
109                (
110                    "nullable".to_string(),
111                    shape_wire::WireValue::Bool(field.is_nullable()),
112                ),
113            ]))
114        })
115        .collect::<Vec<_>>();
116
117    shape_wire::WireValue::Object(BTreeMap::from([
118        (
119            "format".to_string(),
120            shape_wire::WireValue::String(format.to_string()),
121        ),
122        ("columns".to_string(), shape_wire::WireValue::Array(columns)),
123    ]))
124}
125
126fn payload_to_schema(payload: &shape_wire::WireValue) -> Result<ArrowSchema, String> {
127    let shape_wire::WireValue::Object(map) = payload else {
128        return Err("schema artifact payload is not an object".to_string());
129    };
130    let columns = map
131        .get("columns")
132        .ok_or_else(|| "schema artifact payload missing columns".to_string())?;
133    let shape_wire::WireValue::Array(column_values) = columns else {
134        return Err("schema artifact payload columns must be an array".to_string());
135    };
136
137    let mut fields = Vec::with_capacity(column_values.len());
138    for column in column_values {
139        let shape_wire::WireValue::Object(col) = column else {
140            return Err("schema artifact column must be an object".to_string());
141        };
142        let name = col
143            .get("name")
144            .and_then(shape_wire::WireValue::as_str)
145            .ok_or_else(|| "schema artifact column missing name".to_string())?;
146        let data_type = col
147            .get("data_type")
148            .and_then(shape_wire::WireValue::as_str)
149            .ok_or_else(|| "schema artifact column missing data_type".to_string())?;
150        let nullable = col
151            .get("nullable")
152            .and_then(shape_wire::WireValue::as_bool)
153            .ok_or_else(|| "schema artifact column missing nullable".to_string())?;
154
155        fields.push(Field::new(name, parse_data_type(data_type), nullable));
156    }
157
158    Ok(ArrowSchema::new(fields))
159}
160
161/// Format an Arrow DataType as a string for lockfile storage.
162fn format_data_type(dt: &DataType) -> String {
163    match dt {
164        DataType::Float64 => "Float64".to_string(),
165        DataType::Float32 => "Float32".to_string(),
166        DataType::Int64 => "Int64".to_string(),
167        DataType::Int32 => "Int32".to_string(),
168        DataType::Int16 => "Int16".to_string(),
169        DataType::Int8 => "Int8".to_string(),
170        DataType::UInt64 => "UInt64".to_string(),
171        DataType::UInt32 => "UInt32".to_string(),
172        DataType::Boolean => "Boolean".to_string(),
173        DataType::Utf8 => "Utf8".to_string(),
174        DataType::LargeUtf8 => "LargeUtf8".to_string(),
175        DataType::Timestamp(unit, tz) => {
176            let unit_str = match unit {
177                arrow_schema::TimeUnit::Second => "s",
178                arrow_schema::TimeUnit::Millisecond => "ms",
179                arrow_schema::TimeUnit::Microsecond => "us",
180                arrow_schema::TimeUnit::Nanosecond => "ns",
181            };
182            match tz {
183                Some(tz) => format!("Timestamp({},{})", unit_str, tz),
184                None => format!("Timestamp({})", unit_str),
185            }
186        }
187        DataType::Date32 => "Date32".to_string(),
188        DataType::Date64 => "Date64".to_string(),
189        other => format!("{other:?}"),
190    }
191}
192
193/// Parse a data type string from the lockfile back into an Arrow DataType.
194fn parse_data_type(s: &str) -> DataType {
195    match s {
196        "Float64" => DataType::Float64,
197        "Float32" => DataType::Float32,
198        "Int64" => DataType::Int64,
199        "Int32" => DataType::Int32,
200        "Int16" => DataType::Int16,
201        "Int8" => DataType::Int8,
202        "UInt64" => DataType::UInt64,
203        "UInt32" => DataType::UInt32,
204        "Boolean" => DataType::Boolean,
205        "Utf8" => DataType::Utf8,
206        "LargeUtf8" => DataType::LargeUtf8,
207        "Date32" => DataType::Date32,
208        "Date64" => DataType::Date64,
209        s if s.starts_with("Timestamp(") => {
210            let inner = &s[10..s.len() - 1];
211            let parts: Vec<&str> = inner.splitn(2, ',').collect();
212            let unit = match parts[0] {
213                "s" => arrow_schema::TimeUnit::Second,
214                "ms" => arrow_schema::TimeUnit::Millisecond,
215                "us" => arrow_schema::TimeUnit::Microsecond,
216                "ns" => arrow_schema::TimeUnit::Nanosecond,
217                _ => arrow_schema::TimeUnit::Nanosecond,
218            };
219            let tz = parts.get(1).map(|value| value.to_string().into());
220            DataType::Timestamp(unit, tz)
221        }
222        _ => DataType::Utf8,
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Create (or truncate) `path` and write each row followed by a newline.
    fn write_rows(path: &Path, rows: &[&str]) {
        let mut file = std::fs::File::create(path).unwrap();
        for row in rows {
            writeln!(file, "{}", row).unwrap();
        }
    }

    /// A lockfile holding one schema artifact survives a write/read cycle.
    #[test]
    fn test_lockfile_roundtrip() {
        let mut column = BTreeMap::new();
        column.insert(
            "name".to_string(),
            shape_wire::WireValue::String("price".to_string()),
        );
        column.insert(
            "data_type".to_string(),
            shape_wire::WireValue::String("Float64".to_string()),
        );
        column.insert("nullable".to_string(), shape_wire::WireValue::Bool(false));

        let mut root = BTreeMap::new();
        root.insert(
            "columns".to_string(),
            shape_wire::WireValue::Array(vec![shape_wire::WireValue::Object(column)]),
        );
        let payload = shape_wire::WireValue::Object(root);

        let mut fingerprints = BTreeMap::new();
        fingerprints.insert("file:data.csv".to_string(), "sha256:deadbeef".to_string());

        let artifact = LockedArtifact::new(
            SCHEMA_ARTIFACT_NAMESPACE,
            "data.csv",
            SCHEMA_ARTIFACT_PRODUCER,
            ArtifactDeterminism::External { fingerprints },
            BTreeMap::new(),
            payload,
        )
        .unwrap();

        let mut lockfile = SchemaLockfile::new();
        lockfile.upsert_artifact(artifact).unwrap();

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shape.lock");
        lockfile.write(&path).unwrap();

        let loaded = SchemaLockfile::read(&path).unwrap();
        assert_eq!(loaded.artifacts.len(), 1);
    }

    /// Changing the file contents changes the computed fingerprint.
    #[test]
    fn test_compute_file_hash_changes_on_content_change() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        let hash_after_writing = |contents: &str| {
            std::fs::write(&path, contents).unwrap();
            compute_file_hash(&path).unwrap()
        };

        let h1 = hash_after_writing("a,b\n1,2\n");
        let h2 = hash_after_writing("a,b\n1,2\n3,4\n");

        assert_ne!(h1, h2);
        for hash in [&h1, &h2] {
            assert!(hash.starts_with("sha256:"));
        }
    }

    /// First call infers, second call hits the cache, and a content change
    /// invalidates the cached entry.
    #[test]
    fn test_infer_or_cached() {
        let dir = tempfile::tempdir().unwrap();
        let csv_path = dir.path().join("cached_test.csv");
        write_rows(&csv_path, &["x,y", "1,2", "3,4"]);

        let mut lockfile = SchemaLockfile::new();

        let (_schema1, from_cache1) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(!from_cache1);

        let (_schema2, from_cache2) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(from_cache2);

        write_rows(&csv_path, &["x,y", "1,2", "3,4", "5,6"]);

        let (_schema3, from_cache3) =
            infer_or_cached(&csv_path, "cached_test.csv", &mut lockfile).unwrap();
        assert!(!from_cache3);
    }
}