Skip to main content

tla_connect/trace_validation/
validator.rs

1//! Apalache-based trace validation (Approach 3).
2//!
3//! Validates that a recorded NDJSON trace is a valid behavior of a TLA+
4//! specification by running Apalache on a TraceSpec.
5
6use crate::error::{Error, ValidationError};
7use std::collections::{BTreeMap, BTreeSet};
8use std::path::{Path, PathBuf};
9use tracing::{debug, info};
10
/// Outcome of replaying a recorded trace against a TLA+ TraceSpec.
#[derive(Debug)]
#[non_exhaustive]
#[must_use = "trace validation result should be checked"]
pub enum TraceResult {
    /// The whole trace was replayed successfully.
    ///
    /// This is reported when Apalache violates the inverted `TraceFinished`
    /// invariant, which is how the TraceSpec signals that every trace entry
    /// has been consumed.
    Valid,

    /// The trace could not be replayed as a behavior of the specification.
    Invalid {
        /// Human-readable explanation of why validation failed.
        reason: String,
    },
}
27
/// Settings that control how a trace is checked with Apalache.
///
/// The operator names (`init`, `next`, `inv`, `cinit`) are forwarded to
/// Apalache as `--init`, `--next`, `--inv` and `--cinit` respectively.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TraceValidatorConfig {
    /// Location of the TLA+ TraceSpec file to check the trace against.
    pub trace_spec: PathBuf,

    /// Name of the INIT predicate in the TraceSpec (default: "TraceInit").
    pub init: String,

    /// Name of the NEXT predicate in the TraceSpec (default: "TraceNext").
    pub next: String,

    /// Name of the invariant that becomes violated once the trace has been
    /// fully consumed (default: "TraceFinished").
    pub inv: String,

    /// Name of the constant initialization predicate
    /// (default: "TraceConstInit").
    pub cinit: String,

    /// Apalache executable to invoke (default: "apalache-mc").
    pub apalache_bin: String,

    /// Upper bound on how long the Apalache subprocess may run.
    /// `None` means no timeout is applied.
    pub timeout: Option<std::time::Duration>,
}
54
55impl Default for TraceValidatorConfig {
56    fn default() -> Self {
57        Self {
58            trace_spec: PathBuf::new(),
59            init: "TraceInit".into(),
60            next: "TraceNext".into(),
61            inv: "TraceFinished".into(),
62            cinit: "TraceConstInit".into(),
63            apalache_bin: "apalache-mc".into(),
64            timeout: None,
65        }
66    }
67}
68
// Generates the `TraceValidatorConfigBuilder` companion type through the
// crate's `impl_builder!` macro. `trace_spec` is listed as required; the test
// module in this file checks that `build()` fails with an error naming
// `trace_spec` when it has not been set.
// NOTE(review): the exact semantics of the `optional` and `optional_or`
// groups are defined by the macro in `crate::builder`, which is not visible
// from this file — confirm there before relying on them.
crate::builder::impl_builder!(TraceValidatorConfig, TraceValidatorConfigBuilder {
    required { trace_spec: PathBuf }
    optional { init: String, next: String, inv: String, cinit: String, apalache_bin: String }
    optional_or { timeout: std::time::Duration }
});
74
75/// Validates Rust execution traces against TLA+ specs using Apalache.
76///
77/// Uses the "inverted invariant" technique: the TraceSpec defines a
78/// `TraceFinished` invariant that is violated when the entire trace has
79/// been consumed. If Apalache reports a violation, the trace is valid.
80#[must_use = "validation result should be checked"]
81pub fn validate_trace(config: &TraceValidatorConfig, trace_file: &Path) -> Result<TraceResult, Error> {
82    let trace_spec = config
83        .trace_spec
84        .canonicalize()
85        .map_err(|_| ValidationError::TraceSpecNotFound(config.trace_spec.clone()))?;
86
87    let trace_file = trace_file
88        .canonicalize()
89        .map_err(|_| ValidationError::TraceFileNotFound(trace_file.to_path_buf()))?;
90
91    let spec_dir = trace_spec
92        .parent()
93        .ok_or_else(|| ValidationError::TraceSpecNotFound(trace_spec.clone()))?;
94
95    let spec_filename = trace_spec
96        .file_name()
97        .ok_or_else(|| ValidationError::TraceSpecNotFound(trace_spec.clone()))?;
98
99    info!(
100        spec = %trace_spec.display(),
101        trace = %trace_file.display(),
102        "Validating trace with Apalache"
103    );
104
105    let (trace_data, trace_len) = ndjson_to_tla_module(&trace_file)?;
106
107    let work_dir = tempfile::Builder::new()
108        .prefix("tla_trace_")
109        .tempdir()
110        .map_err(|e| ValidationError::WorkDir(e.to_string()))?;
111    let spec_subdir = work_dir.path().join("spec");
112    let out_subdir = work_dir.path().join("out");
113    std::fs::create_dir_all(&spec_subdir).map_err(ValidationError::Io)?;
114    std::fs::create_dir_all(&out_subdir).map_err(ValidationError::Io)?;
115
116    for entry in std::fs::read_dir(spec_dir).map_err(ValidationError::Io)? {
117        let entry = entry.map_err(ValidationError::Io)?;
118        let path = entry.path();
119        if path.extension().and_then(|e| e.to_str()) == Some("tla") {
120            let dest = spec_subdir.join(entry.file_name());
121            std::fs::copy(&path, &dest).map_err(|e| ValidationError::FileCopy {
122                path: path.clone(),
123                reason: e.to_string(),
124            })?;
125        }
126    }
127
128    let trace_data_path = spec_subdir.join("TraceData.tla");
129    std::fs::write(&trace_data_path, &trace_data).map_err(ValidationError::Io)?;
130
131    debug!(
132        "Generated TraceData.tla ({} bytes, {} trace entries)",
133        trace_data.len(),
134        trace_len
135    );
136
137    let length = trace_len.saturating_sub(1);
138
139    let mut cmd = std::process::Command::new(&config.apalache_bin);
140    cmd.arg("check")
141        .arg(format!("--init={}", config.init))
142        .arg(format!("--next={}", config.next))
143        .arg(format!("--inv={}", config.inv))
144        .arg(format!("--cinit={}", config.cinit))
145        .arg(format!("--length={length}"))
146        .arg(format!("--out-dir={}", out_subdir.display()))
147        .arg(spec_subdir.join(spec_filename));
148
149    debug!("Apalache command: {:?}", cmd);
150
151    let output = crate::util::run_with_timeout(&mut cmd, config.timeout)
152        .map_err(ValidationError::from)?;
153
154    let stdout = String::from_utf8_lossy(&output.stdout);
155    let stderr = String::from_utf8_lossy(&output.stderr);
156
157    debug!("Apalache stdout:\n{}", stdout);
158    if !stderr.is_empty() {
159        debug!("Apalache stderr:\n{}", stderr);
160    }
161
162    parse_apalache_output(&stdout, &stderr, output.status.code())
163}
164
165fn parse_apalache_output(
166    stdout: &str,
167    stderr: &str,
168    exit_code: Option<i32>,
169) -> Result<TraceResult, Error> {
170    match exit_code {
171        Some(12) => {
172            info!("Trace validated successfully (Apalache violated TraceFinished)");
173            Ok(TraceResult::Valid)
174        }
175
176        Some(0) => Ok(TraceResult::Invalid {
177            reason: "Apalache completed without violating TraceFinished – \
178                     the trace could not be fully replayed against the spec"
179                .to_string(),
180        }),
181
182        _ => {
183            let error_lines: Vec<&str> = stdout
184                .lines()
185                .filter(|l| l.contains("Error") || l.contains("error"))
186                .chain(stderr.lines().filter(|l| !l.is_empty()))
187                .collect();
188
189            Err(ValidationError::from(crate::error::ApalacheError::Execution {
190                exit_code,
191                message: error_lines.join("\n"),
192            })
193            .into())
194        }
195    }
196}
197
198/// Convert an NDJSON trace file to a TLA+ module defining `TraceLog`.
199#[doc(hidden)]
200pub fn ndjson_to_tla_module(trace_file: &Path) -> Result<(String, usize), Error> {
201    let content = std::fs::read_to_string(trace_file).map_err(ValidationError::Io)?;
202
203    let mut json_objects = Vec::new();
204    let mut records = Vec::new();
205    let mut expected_keys: Option<BTreeSet<String>> = None;
206
207    for (i, line) in content.lines().enumerate() {
208        let line = line.trim();
209        if line.is_empty() {
210            continue;
211        }
212
213        let line_num = i + 1;
214
215        let obj: serde_json::Value = serde_json::from_str(line).map_err(|e| {
216            ValidationError::InvalidJson {
217                line: line_num,
218                reason: e.to_string(),
219            }
220        })?;
221
222        let obj_map = obj.as_object().ok_or_else(|| ValidationError::NonObjectState {
223            found: format!("line {line_num}: {}", obj),
224        })?;
225
226        let current_keys: BTreeSet<String> = obj_map.keys().cloned().collect();
227
228        if let Some(ref expected) = expected_keys {
229            if &current_keys != expected {
230                return Err(ValidationError::InconsistentSchema {
231                    line: line_num,
232                    expected: expected.iter().cloned().collect(),
233                    found: current_keys.into_iter().collect(),
234                }
235                .into());
236            }
237        } else {
238            expected_keys = Some(current_keys);
239        }
240
241        validate_json_types(&obj, line_num)?;
242
243        let record = json_obj_to_tla_record(&obj, line_num)?;
244        json_objects.push(obj);
245        records.push(record);
246    }
247
248    if records.is_empty() {
249        return Err(ValidationError::EmptyTrace(trace_file.to_path_buf()).into());
250    }
251
252    let record_type = infer_snowcat_record_type(&json_objects[0])?;
253
254    let actions: Vec<String> = json_objects
255        .iter()
256        .map(|obj| {
257            obj.get("action")
258                .and_then(|v| v.as_str())
259                .unwrap_or("unknown")
260                .to_string()
261        })
262        .collect();
263
264    let count = records.len();
265    let mut out = String::new();
266    out.push_str("---- MODULE TraceData ----\n");
267    out.push_str("EXTENDS Integers, Sequences\n\n");
268
269    out.push_str(&format!("\\* @type: () => Seq({record_type});\n"));
270    out.push_str("TraceLog == <<\n");
271    for (i, record) in records.iter().enumerate() {
272        if i > 0 {
273            out.push_str(",\n");
274        }
275        out.push_str("  ");
276        out.push_str(record);
277    }
278    out.push_str("\n>>\n\n");
279
280    out.push_str("\\* @type: () => Seq(Str);\n");
281    out.push_str("TraceActions == <<\n");
282    for (i, action) in actions.iter().enumerate() {
283        if i > 0 {
284            out.push_str(",\n");
285        }
286        out.push_str(&format!("  \"{}\"", escape_tla_string(action)));
287    }
288    out.push_str("\n>>\n\n====\n");
289    Ok((out, count))
290}
291
292/// Validate JSON types are supported (reject floats, nested structures).
293fn validate_json_types(value: &serde_json::Value, line: usize) -> Result<(), Error> {
294    let obj = value.as_object().ok_or_else(|| ValidationError::NonObjectState {
295        found: format!("{value}"),
296    })?;
297
298    for (key, val) in obj {
299        validate_json_value(val, line, key)?;
300    }
301    Ok(())
302}
303
304/// Recursively validate a JSON value, rejecting floats at any depth.
305fn validate_json_value(value: &serde_json::Value, line: usize, field: &str) -> Result<(), Error> {
306    match value {
307        serde_json::Value::Number(n) => {
308            if n.is_f64() && !n.is_i64() && !n.is_u64() {
309                return Err(ValidationError::FloatNotSupported {
310                    line,
311                    field: field.to_string(),
312                    value: n.as_f64().unwrap_or(0.0),
313                }
314                .into());
315            }
316        }
317        serde_json::Value::Array(arr) => {
318            for (idx, elem) in arr.iter().enumerate() {
319                validate_json_value(elem, line, &format!("{field}[{idx}]"))?;
320            }
321        }
322        serde_json::Value::Object(obj) => {
323            for (key, val) in obj {
324                validate_json_value(val, line, &format!("{field}.{key}"))?;
325            }
326        }
327        _ => {}
328    }
329    Ok(())
330}
331
332fn infer_snowcat_record_type(value: &serde_json::Value) -> Result<String, Error> {
333    let obj = value.as_object().ok_or_else(|| ValidationError::NonObjectState {
334        found: format!("{value}"),
335    })?;
336
337    let sorted: BTreeMap<_, _> = obj.iter().collect();
338    let mut fields = Vec::new();
339    for (key, val) in &sorted {
340        let ty = infer_snowcat_type(val, key)?;
341        fields.push(format!("{key}: {ty}"));
342    }
343
344    Ok(format!("{{{}}}", fields.join(", ")))
345}
346
347fn infer_snowcat_type(value: &serde_json::Value, field: &str) -> Result<String, Error> {
348    match value {
349        serde_json::Value::Bool(_) => Ok("Bool".to_string()),
350        serde_json::Value::Number(_) => Ok("Int".to_string()),
351        serde_json::Value::String(_) => Ok("Str".to_string()),
352        serde_json::Value::Array(arr) => {
353            if arr.is_empty() {
354                return Ok("Seq(Int)".to_string());
355            }
356            let first_type = infer_snowcat_type(&arr[0], field)?;
357            for (i, elem) in arr.iter().enumerate().skip(1) {
358                let elem_type = infer_snowcat_type(elem, field)?;
359                if elem_type != first_type {
360                    return Err(ValidationError::InconsistentArrayType {
361                        field: format!("{field}[{i}]"),
362                        expected: first_type,
363                        found: elem_type,
364                    }
365                    .into());
366                }
367            }
368            Ok(format!("Seq({})", first_type))
369        }
370        serde_json::Value::Object(obj) => {
371            let sorted: BTreeMap<_, _> = obj.iter().collect();
372            let fields: Result<Vec<String>, Error> = sorted
373                .iter()
374                .map(|(k, v)| Ok(format!("{k}: {}", infer_snowcat_type(v, &format!("{field}.{k}"))?)))
375                .collect();
376            Ok(format!("{{{}}}", fields?.join(", ")))
377        }
378        serde_json::Value::Null => Ok("Str".to_string()),
379    }
380}
381
382fn json_obj_to_tla_record(value: &serde_json::Value, line: usize) -> Result<String, Error> {
383    let obj = value.as_object().ok_or_else(|| ValidationError::TlaConversion {
384        line,
385        reason: format!("Expected JSON object, got: {value}"),
386    })?;
387
388    let sorted: BTreeMap<_, _> = obj.iter().collect();
389    let mut fields = Vec::new();
390
391    for (key, val) in &sorted {
392        let tla_val = json_to_tla_value(val, line, key)?;
393        fields.push(format!("{key} |-> {tla_val}"));
394    }
395
396    Ok(format!("[{}]", fields.join(", ")))
397}
398
399fn json_to_tla_value(value: &serde_json::Value, line: usize, field: &str) -> Result<String, Error> {
400    match value {
401        serde_json::Value::Null => Ok("\"null\"".to_string()),
402        serde_json::Value::Bool(b) => Ok(if *b { "TRUE" } else { "FALSE" }.to_string()),
403        serde_json::Value::Number(n) => {
404            if let Some(i) = n.as_i64() {
405                Ok(i.to_string())
406            } else if let Some(u) = n.as_u64() {
407                Ok(u.to_string())
408            } else {
409                Err(ValidationError::FloatNotSupported {
410                    line,
411                    field: field.to_string(),
412                    value: n.as_f64().unwrap_or(0.0),
413                }
414                .into())
415            }
416        }
417        serde_json::Value::String(s) => Ok(format!("\"{}\"", escape_tla_string(s))),
418        serde_json::Value::Array(arr) => {
419            let elems: Result<Vec<String>, Error> = arr
420                .iter()
421                .enumerate()
422                .map(|(i, v)| json_to_tla_value(v, line, &format!("{field}[{i}]")))
423                .collect();
424            Ok(format!("<<{}>>", elems?.join(", ")))
425        }
426        serde_json::Value::Object(_) => {
427            json_obj_to_tla_record(value, line)
428        }
429    }
430}
431
432impl From<PathBuf> for TraceValidatorConfig {
433    fn from(trace_spec: PathBuf) -> Self {
434        Self {
435            trace_spec,
436            ..Default::default()
437        }
438    }
439}
440
441impl From<&str> for TraceValidatorConfig {
442    fn from(trace_spec: &str) -> Self {
443        Self {
444            trace_spec: PathBuf::from(trace_spec),
445            ..Default::default()
446        }
447    }
448}
449
/// Escape a string for use in a TLA+ string literal.
///
/// Backslash and double quote are backslash-escaped, newline, carriage
/// return and tab become `\n`, `\r` and `\t`, and any other control
/// character is written as `\uXXXX` (four lowercase hex digits). All other
/// characters are copied unchanged.
///
/// NOTE(review): confirm that the TLA+ parser used by Apalache accepts
/// `\uXXXX` escapes inside string literals.
fn escape_tla_string(s: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if c.is_control() => {
                // Format straight into the output buffer instead of building
                // a temporary `String`; writing to a `String` cannot fail.
                let _ = write!(out, "\\u{:04x}", u32::from(c));
            }
            c => out.push(c),
        }
    }
    out
}
468
// Unit tests for the pure helpers in this module (string escaping, JSON to
// TLA+ conversion, type inference) and for the generated config builder.
// None of these tests invoke Apalache.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- escape_tla_string ---

    #[test]
    fn escape_tla_string_plain() {
        assert_eq!(escape_tla_string("hello"), "hello");
    }

    #[test]
    fn escape_tla_string_special_chars() {
        assert_eq!(escape_tla_string("a\\b"), "a\\\\b");
        assert_eq!(escape_tla_string("a\"b"), "a\\\"b");
        assert_eq!(escape_tla_string("a\nb"), "a\\nb");
        assert_eq!(escape_tla_string("a\rb"), "a\\rb");
        assert_eq!(escape_tla_string("a\tb"), "a\\tb");
    }

    #[test]
    fn escape_tla_string_control_char() {
        // Control characters without a dedicated escape use the \uXXXX form.
        let result = escape_tla_string("a\x01b");
        assert_eq!(result, "a\\u0001b");
    }

    // --- json_to_tla_value ---

    #[test]
    fn json_to_tla_value_null() {
        // JSON null is rendered as the TLA+ string "null".
        let val = json!(null);
        assert_eq!(json_to_tla_value(&val, 1, "f").unwrap(), "\"null\"");
    }

    #[test]
    fn json_to_tla_value_bool() {
        assert_eq!(json_to_tla_value(&json!(true), 1, "f").unwrap(), "TRUE");
        assert_eq!(json_to_tla_value(&json!(false), 1, "f").unwrap(), "FALSE");
    }

    #[test]
    fn json_to_tla_value_int() {
        assert_eq!(json_to_tla_value(&json!(42), 1, "f").unwrap(), "42");
        assert_eq!(json_to_tla_value(&json!(-7), 1, "f").unwrap(), "-7");
    }

    #[test]
    fn json_to_tla_value_string() {
        assert_eq!(json_to_tla_value(&json!("hello"), 1, "f").unwrap(), "\"hello\"");
    }

    #[test]
    fn json_to_tla_value_array() {
        // Arrays become TLA+ tuples; the empty array becomes the empty tuple.
        assert_eq!(json_to_tla_value(&json!([1, 2, 3]), 1, "f").unwrap(), "<<1, 2, 3>>");
        assert_eq!(json_to_tla_value(&json!([]), 1, "f").unwrap(), "<<>>");
    }

    #[test]
    fn json_to_tla_value_float_rejected() {
        assert!(json_to_tla_value(&json!(3.14), 1, "f").is_err());
    }

    // --- validate_json_types ---

    #[test]
    fn validate_json_types_nested_float() {
        // Float nested in array of arrays should be rejected
        let val = json!({"data": [[3.14]]});
        assert!(validate_json_types(&val, 1).is_err());
    }

    #[test]
    fn validate_json_types_nested_object_float() {
        // Float nested in object should be rejected
        let val = json!({"outer": {"inner": 3.14}});
        assert!(validate_json_types(&val, 1).is_err());
    }

    #[test]
    fn validate_json_types_valid() {
        let val = json!({"a": 1, "b": "str", "c": true, "d": [1, 2]});
        assert!(validate_json_types(&val, 1).is_ok());
    }

    // --- infer_snowcat_type ---

    #[test]
    fn infer_snowcat_type_primitives() {
        // null shares the `Str` type because it is rendered as the string "null".
        assert_eq!(infer_snowcat_type(&json!(true), "f").unwrap(), "Bool");
        assert_eq!(infer_snowcat_type(&json!(42), "f").unwrap(), "Int");
        assert_eq!(infer_snowcat_type(&json!("hi"), "f").unwrap(), "Str");
        assert_eq!(infer_snowcat_type(&json!(null), "f").unwrap(), "Str");
    }

    #[test]
    fn infer_snowcat_type_array() {
        // An empty array has no element to infer from and defaults to Seq(Int).
        assert_eq!(infer_snowcat_type(&json!([1, 2]), "f").unwrap(), "Seq(Int)");
        assert_eq!(infer_snowcat_type(&json!([]), "f").unwrap(), "Seq(Int)");
    }

    #[test]
    fn infer_snowcat_type_mixed_array_rejected() {
        let result = infer_snowcat_type(&json!([1, "hello"]), "f");
        assert!(result.is_err());
    }

    // --- json_obj_to_tla_record ---

    #[test]
    fn json_obj_to_tla_record_sorted() {
        let val = json!({"z": 1, "a": 2});
        let record = json_obj_to_tla_record(&val, 1).unwrap();
        // Fields should be sorted alphabetically
        assert!(record.starts_with("[a |->"));
    }

    // --- TraceValidatorConfig builder ---

    #[test]
    fn builder_missing_required_field() {
        // Building without `trace_spec` must fail and name the missing field.
        let result = TraceValidatorConfig::builder().build();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("trace_spec"));
    }
}