Skip to main content

tensorlogic_compiler/error_recovery/
tolerant_compiler.rs

1//! Tolerant (partial-error-recovery) compilation driver.
2//!
3//! The [`TolerantCompiler`] compiles a *program* — i.e. a slice of
4//! [`TLExpr`] — under a configurable [`RecoveryStrategy`]. Each expression is
5//! compiled in **isolation**, so a single malformed expression never aborts
6//! the compilation of its siblings (under the default strategy).
7//!
8//! Internally the driver:
9//!
10//! 1. Iterates the input slice in order.
11//! 2. For each expression, calls [`compile_to_einsum_with_context`] inside
12//!    [`std::panic::catch_unwind`] — any panic becomes a
13//!    [`Severity::Fatal`] diagnostic rather than unwinding across the
14//!    driver boundary.
15//! 3. Any `Err(...)` from the compiler is converted into a
16//!    [`Severity::Error`] diagnostic and the offending slot becomes `None`.
17//! 4. The chosen [`RecoveryStrategy`] decides whether to continue, skip
18//!    this expression only, or abort the whole program.
19//!
20//! The original strict entry points ([`compile_to_einsum`],
21//! [`compile_to_einsum_with_context`]) are untouched.
22//!
23//! [`compile_to_einsum`]: crate::compile_to_einsum
24//! [`compile_to_einsum_with_context`]: crate::compile_to_einsum_with_context
25//! [`TLExpr`]: tensorlogic_ir::TLExpr
26
27use std::panic::{self, AssertUnwindSafe};
28
29use tensorlogic_ir::{EinsumGraph, TLExpr};
30
31use crate::compile_to_einsum_with_context;
32use crate::context::CompilerContext;
33
34use super::collector::DiagnosticCollector;
35use super::diagnostic::{Diagnostic, Severity};
36use super::strategy::{RecoveryAction, RecoveryStrategy};
37
38/// Result returned by [`TolerantCompiler::compile_program`].
39///
40/// The `graphs` vector has exactly the same length as the input slice.
41/// `graphs[i] == None` iff expression `i` was skipped due to a blocking
42/// diagnostic (Error or Fatal) or because the program was aborted while
43/// processing expression `i` or earlier.
44#[derive(Debug, Clone)]
45pub struct PartialCompilationResult {
46    /// Per-expression compilation output, aligned with the input slice.
47    pub graphs: Vec<Option<EinsumGraph>>,
48    /// Collected diagnostics, in insertion (i.e. expression) order.
49    pub diagnostics: DiagnosticCollector,
50    /// The recovery strategy that produced this result.
51    pub strategy: RecoveryStrategy,
52    /// `true` iff the driver stopped early (AbortOnAny / SkipOnFatal hit a
53    /// blocker). Expressions after `aborted_at` have `None` graphs.
54    pub aborted: bool,
55    /// Index at which the driver aborted (only meaningful when `aborted` is
56    /// `true`).
57    pub aborted_at: Option<usize>,
58}
59
60impl PartialCompilationResult {
61    /// Number of expressions that successfully produced a graph.
62    pub fn success_count(&self) -> usize {
63        self.graphs.iter().filter(|g| g.is_some()).count()
64    }
65
66    /// Number of expressions that failed to produce a graph.
67    pub fn failure_count(&self) -> usize {
68        self.graphs.iter().filter(|g| g.is_none()).count()
69    }
70
71    /// `true` if every expression produced a graph.
72    pub fn is_all_success(&self) -> bool {
73        self.graphs.iter().all(|g| g.is_some())
74    }
75
76    /// Iterate successful `(index, &graph)` pairs.
77    pub fn successes(&self) -> impl Iterator<Item = (usize, &EinsumGraph)> {
78        self.graphs
79            .iter()
80            .enumerate()
81            .filter_map(|(i, g)| g.as_ref().map(|gg| (i, gg)))
82    }
83
84    /// Indices of expressions that produced no graph.
85    pub fn failures(&self) -> Vec<usize> {
86        self.graphs
87            .iter()
88            .enumerate()
89            .filter_map(|(i, g)| if g.is_none() { Some(i) } else { None })
90            .collect()
91    }
92}
93
94/// Tolerant compilation façade.
95///
96/// The compiler is stateless besides its [`RecoveryStrategy`]; each call to
97/// [`Self::compile_program`] starts from a fresh [`CompilerContext`] unless
98/// the caller uses [`Self::compile_program_with_contexts`].
99#[derive(Debug, Clone, Default)]
100pub struct TolerantCompiler {
101    strategy: RecoveryStrategy,
102}
103
104impl TolerantCompiler {
105    /// Construct a tolerant compiler with [`RecoveryStrategy::SkipOnError`].
106    pub fn new() -> Self {
107        Self::default()
108    }
109
110    /// Construct a tolerant compiler with a specific recovery strategy.
111    pub fn with_strategy(strategy: RecoveryStrategy) -> Self {
112        Self { strategy }
113    }
114
115    /// Configured recovery strategy.
116    pub fn strategy(&self) -> RecoveryStrategy {
117        self.strategy
118    }
119
120    /// Update the recovery strategy in place.
121    pub fn set_strategy(&mut self, strategy: RecoveryStrategy) {
122        self.strategy = strategy;
123    }
124
125    /// Compile a program (slice of expressions) under the current recovery
126    /// strategy using **one fresh [`CompilerContext`] per expression** so
127    /// failures do not poison the context for siblings.
128    pub fn compile_program(&self, program: &[TLExpr]) -> PartialCompilationResult {
129        self.compile_program_with(program, |_idx| CompilerContext::new())
130    }
131
132    /// Compile a program, calling `make_ctx` once per expression to obtain a
133    /// fresh context. This lets callers share domain declarations across
134    /// expressions while still keeping each compilation isolated.
135    pub fn compile_program_with<F>(
136        &self,
137        program: &[TLExpr],
138        mut make_ctx: F,
139    ) -> PartialCompilationResult
140    where
141        F: FnMut(usize) -> CompilerContext,
142    {
143        let collector = DiagnosticCollector::new();
144        let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
145
146        let mut aborted = false;
147        let mut aborted_at: Option<usize> = None;
148
149        for (idx, expr) in program.iter().enumerate() {
150            if aborted {
151                graphs.push(None);
152                continue;
153            }
154
155            let mut ctx = make_ctx(idx);
156            match self.compile_one(idx, expr, &mut ctx, &collector) {
157                OneResult::Ok(graph) => graphs.push(Some(graph)),
158                OneResult::Skipped => graphs.push(None),
159                OneResult::Aborted => {
160                    graphs.push(None);
161                    aborted = true;
162                    aborted_at = Some(idx);
163                }
164            }
165        }
166
167        PartialCompilationResult {
168            graphs,
169            diagnostics: collector,
170            strategy: self.strategy,
171            aborted,
172            aborted_at,
173        }
174    }
175
176    /// Compile a program re-using a caller-supplied vector of contexts. Every
177    /// context is used exactly once, matched by index to the expression slot.
178    ///
179    /// The caller must provide `contexts.len() >= program.len()`; surplus
180    /// contexts are ignored.
181    pub fn compile_program_with_contexts(
182        &self,
183        program: &[TLExpr],
184        contexts: &mut [CompilerContext],
185    ) -> PartialCompilationResult {
186        let collector = DiagnosticCollector::new();
187        let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
188
189        let mut aborted = false;
190        let mut aborted_at: Option<usize> = None;
191
192        for (idx, expr) in program.iter().enumerate() {
193            if aborted {
194                graphs.push(None);
195                continue;
196            }
197
198            if idx >= contexts.len() {
199                collector.push(
200                    Diagnostic::fatal(format!(
201                        "tolerant compiler: missing CompilerContext for expression #{}",
202                        idx
203                    ))
204                    .with_expression_index(idx),
205                );
206                // Behave as if a fatal fired.
207                let action = self.strategy.decide(Severity::Fatal);
208                match action {
209                    RecoveryAction::Continue => graphs.push(None),
210                    RecoveryAction::SkipExpression => graphs.push(None),
211                    RecoveryAction::AbortProgram => {
212                        graphs.push(None);
213                        aborted = true;
214                        aborted_at = Some(idx);
215                    }
216                }
217                continue;
218            }
219
220            match self.compile_one(idx, expr, &mut contexts[idx], &collector) {
221                OneResult::Ok(graph) => graphs.push(Some(graph)),
222                OneResult::Skipped => graphs.push(None),
223                OneResult::Aborted => {
224                    graphs.push(None);
225                    aborted = true;
226                    aborted_at = Some(idx);
227                }
228            }
229        }
230
231        PartialCompilationResult {
232            graphs,
233            diagnostics: collector,
234            strategy: self.strategy,
235            aborted,
236            aborted_at,
237        }
238    }
239
240    /// Compile a single expression in tolerant mode, pushing diagnostics
241    /// into the supplied collector. Returns the per-expression outcome.
242    fn compile_one(
243        &self,
244        idx: usize,
245        expr: &TLExpr,
246        ctx: &mut CompilerContext,
247        collector: &DiagnosticCollector,
248    ) -> OneResult {
249        // catch_unwind only at the TOP LEVEL boundary, per the scope
250        // discipline: panics become Fatal diagnostics instead of unwinding
251        // across the driver.
252        let unwind_result = panic::catch_unwind(AssertUnwindSafe(|| {
253            compile_to_einsum_with_context(expr, ctx)
254        }));
255
256        match unwind_result {
257            Ok(Ok(graph)) => OneResult::Ok(graph),
258            Ok(Err(err)) => {
259                let diag =
260                    Diagnostic::error(format!("compilation error in expression #{}: {}", idx, err))
261                        .with_expression_index(idx);
262                collector.push(diag);
263                self.react(idx, Severity::Error)
264            }
265            Err(payload) => {
266                let msg = panic_payload_to_string(&payload);
267                let diag = Diagnostic::fatal(format!(
268                    "panic while compiling expression #{}: {}",
269                    idx, msg
270                ))
271                .with_expression_index(idx);
272                collector.push(diag);
273                self.react(idx, Severity::Fatal)
274            }
275        }
276    }
277
278    /// Consult the strategy to decide whether to skip or abort after a
279    /// diagnostic of the given severity has already been pushed.
280    fn react(&self, _idx: usize, severity: Severity) -> OneResult {
281        match self.strategy.decide(severity) {
282            RecoveryAction::Continue => {
283                // Non-blocking diagnostics never reach this arm (Error/Fatal
284                // only). Guard defensively just in case.
285                OneResult::Skipped
286            }
287            RecoveryAction::SkipExpression => OneResult::Skipped,
288            RecoveryAction::AbortProgram => OneResult::Aborted,
289        }
290    }
291}
292
293/// Per-expression outcome of the tolerant driver.
294#[derive(Debug)]
295enum OneResult {
296    Ok(EinsumGraph),
297    Skipped,
298    Aborted,
299}
300
301/// Public free function — tolerant counterpart of
302/// [`crate::compile_to_einsum`].
303///
304/// Equivalent to
305/// `TolerantCompiler::with_strategy(RecoveryStrategy::SkipOnError).compile_program(program)`.
306pub fn compile_tolerant(program: &[TLExpr]) -> PartialCompilationResult {
307    TolerantCompiler::new().compile_program(program)
308}
309
310/// Public free function — tolerant compilation with a caller-chosen strategy.
311pub fn compile_tolerant_with_strategy(
312    program: &[TLExpr],
313    strategy: RecoveryStrategy,
314) -> PartialCompilationResult {
315    TolerantCompiler::with_strategy(strategy).compile_program(program)
316}
317
/// Render a panic payload (as returned by `catch_unwind`) as a
/// human-readable string, without panicking again.
///
/// `&'static str` and `String` payloads (the two kinds `panic!` produces)
/// are returned verbatim. Any other payload yields a fixed placeholder.
fn panic_payload_to_string(payload: &Box<dyn std::any::Any + Send>) -> String {
    // Reborrow the boxed value itself. Downcasting must see the payload's
    // type, not `Box`.
    let inner: &(dyn std::any::Any + Send) = &**payload;
    inner
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| inner.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_string())
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// A trivially well-formed expression: `p(x)`.
    fn good_expr() -> TLExpr {
        TLExpr::pred("p", vec![Term::var("x")])
    }

    #[test]
    fn compile_tolerant_all_good() {
        let program = vec![good_expr(), good_expr(), good_expr()];
        let res = compile_tolerant(&program);
        assert_eq!(res.graphs.len(), 3);
        assert!(res.is_all_success());
        assert_eq!(res.success_count(), 3);
        assert!(!res.aborted);
        assert!(res.diagnostics.is_empty());
    }

    #[test]
    fn partial_result_success_iter() {
        let program = vec![good_expr(), good_expr()];
        let res = compile_tolerant(&program);
        let v: Vec<usize> = res.successes().map(|(i, _)| i).collect();
        assert_eq!(v, vec![0, 1]);
    }

    /// An empty program yields an empty, non-aborted, diagnostic-free result.
    #[test]
    fn compile_tolerant_empty_program() {
        let res = compile_tolerant(&[]);
        assert!(res.graphs.is_empty());
        assert!(res.is_all_success());
        assert_eq!(res.success_count(), 0);
        assert_eq!(res.failure_count(), 0);
        assert!(res.failures().is_empty());
        assert!(!res.aborted);
        assert_eq!(res.aborted_at, None);
        assert!(res.diagnostics.is_empty());
    }

    /// A missing context must not compile the expression: its slot stays
    /// `None` and a diagnostic is recorded.
    #[test]
    fn with_contexts_missing_context_reports_diagnostic() {
        let program = vec![good_expr()];
        let res = TolerantCompiler::new().compile_program_with_contexts(&program, &mut []);
        assert_eq!(res.graphs.len(), 1);
        assert_eq!(res.failures(), vec![0]);
        assert_eq!(res.success_count(), 0);
        assert!(!res.diagnostics.is_empty());
    }

    /// `compile_program_with` calls the factory once per expression and
    /// compiles each slot.
    #[test]
    fn compile_program_with_factory() {
        let program = vec![good_expr(), good_expr()];
        let mut seen = Vec::new();
        let res = TolerantCompiler::new().compile_program_with(&program, |idx| {
            seen.push(idx);
            CompilerContext::new()
        });
        assert_eq!(seen, vec![0, 1]);
        assert!(res.is_all_success());
        assert!(!res.aborted);
    }
}