tensorlogic_compiler/error_recovery/
tolerant_compiler.rs1use std::panic::{self, AssertUnwindSafe};
28
29use tensorlogic_ir::{EinsumGraph, TLExpr};
30
31use crate::compile_to_einsum_with_context;
32use crate::context::CompilerContext;
33
34use super::collector::DiagnosticCollector;
35use super::diagnostic::{Diagnostic, Severity};
36use super::strategy::{RecoveryAction, RecoveryStrategy};
37
38#[derive(Debug, Clone)]
45pub struct PartialCompilationResult {
46 pub graphs: Vec<Option<EinsumGraph>>,
48 pub diagnostics: DiagnosticCollector,
50 pub strategy: RecoveryStrategy,
52 pub aborted: bool,
55 pub aborted_at: Option<usize>,
58}
59
60impl PartialCompilationResult {
61 pub fn success_count(&self) -> usize {
63 self.graphs.iter().filter(|g| g.is_some()).count()
64 }
65
66 pub fn failure_count(&self) -> usize {
68 self.graphs.iter().filter(|g| g.is_none()).count()
69 }
70
71 pub fn is_all_success(&self) -> bool {
73 self.graphs.iter().all(|g| g.is_some())
74 }
75
76 pub fn successes(&self) -> impl Iterator<Item = (usize, &EinsumGraph)> {
78 self.graphs
79 .iter()
80 .enumerate()
81 .filter_map(|(i, g)| g.as_ref().map(|gg| (i, gg)))
82 }
83
84 pub fn failures(&self) -> Vec<usize> {
86 self.graphs
87 .iter()
88 .enumerate()
89 .filter_map(|(i, g)| if g.is_none() { Some(i) } else { None })
90 .collect()
91 }
92}
93
94#[derive(Debug, Clone, Default)]
100pub struct TolerantCompiler {
101 strategy: RecoveryStrategy,
102}
103
104impl TolerantCompiler {
105 pub fn new() -> Self {
107 Self::default()
108 }
109
110 pub fn with_strategy(strategy: RecoveryStrategy) -> Self {
112 Self { strategy }
113 }
114
115 pub fn strategy(&self) -> RecoveryStrategy {
117 self.strategy
118 }
119
120 pub fn set_strategy(&mut self, strategy: RecoveryStrategy) {
122 self.strategy = strategy;
123 }
124
125 pub fn compile_program(&self, program: &[TLExpr]) -> PartialCompilationResult {
129 self.compile_program_with(program, |_idx| CompilerContext::new())
130 }
131
132 pub fn compile_program_with<F>(
136 &self,
137 program: &[TLExpr],
138 mut make_ctx: F,
139 ) -> PartialCompilationResult
140 where
141 F: FnMut(usize) -> CompilerContext,
142 {
143 let collector = DiagnosticCollector::new();
144 let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
145
146 let mut aborted = false;
147 let mut aborted_at: Option<usize> = None;
148
149 for (idx, expr) in program.iter().enumerate() {
150 if aborted {
151 graphs.push(None);
152 continue;
153 }
154
155 let mut ctx = make_ctx(idx);
156 match self.compile_one(idx, expr, &mut ctx, &collector) {
157 OneResult::Ok(graph) => graphs.push(Some(graph)),
158 OneResult::Skipped => graphs.push(None),
159 OneResult::Aborted => {
160 graphs.push(None);
161 aborted = true;
162 aborted_at = Some(idx);
163 }
164 }
165 }
166
167 PartialCompilationResult {
168 graphs,
169 diagnostics: collector,
170 strategy: self.strategy,
171 aborted,
172 aborted_at,
173 }
174 }
175
176 pub fn compile_program_with_contexts(
182 &self,
183 program: &[TLExpr],
184 contexts: &mut [CompilerContext],
185 ) -> PartialCompilationResult {
186 let collector = DiagnosticCollector::new();
187 let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
188
189 let mut aborted = false;
190 let mut aborted_at: Option<usize> = None;
191
192 for (idx, expr) in program.iter().enumerate() {
193 if aborted {
194 graphs.push(None);
195 continue;
196 }
197
198 if idx >= contexts.len() {
199 collector.push(
200 Diagnostic::fatal(format!(
201 "tolerant compiler: missing CompilerContext for expression #{}",
202 idx
203 ))
204 .with_expression_index(idx),
205 );
206 let action = self.strategy.decide(Severity::Fatal);
208 match action {
209 RecoveryAction::Continue => graphs.push(None),
210 RecoveryAction::SkipExpression => graphs.push(None),
211 RecoveryAction::AbortProgram => {
212 graphs.push(None);
213 aborted = true;
214 aborted_at = Some(idx);
215 }
216 }
217 continue;
218 }
219
220 match self.compile_one(idx, expr, &mut contexts[idx], &collector) {
221 OneResult::Ok(graph) => graphs.push(Some(graph)),
222 OneResult::Skipped => graphs.push(None),
223 OneResult::Aborted => {
224 graphs.push(None);
225 aborted = true;
226 aborted_at = Some(idx);
227 }
228 }
229 }
230
231 PartialCompilationResult {
232 graphs,
233 diagnostics: collector,
234 strategy: self.strategy,
235 aborted,
236 aborted_at,
237 }
238 }
239
240 fn compile_one(
243 &self,
244 idx: usize,
245 expr: &TLExpr,
246 ctx: &mut CompilerContext,
247 collector: &DiagnosticCollector,
248 ) -> OneResult {
249 let unwind_result = panic::catch_unwind(AssertUnwindSafe(|| {
253 compile_to_einsum_with_context(expr, ctx)
254 }));
255
256 match unwind_result {
257 Ok(Ok(graph)) => OneResult::Ok(graph),
258 Ok(Err(err)) => {
259 let diag =
260 Diagnostic::error(format!("compilation error in expression #{}: {}", idx, err))
261 .with_expression_index(idx);
262 collector.push(diag);
263 self.react(idx, Severity::Error)
264 }
265 Err(payload) => {
266 let msg = panic_payload_to_string(&payload);
267 let diag = Diagnostic::fatal(format!(
268 "panic while compiling expression #{}: {}",
269 idx, msg
270 ))
271 .with_expression_index(idx);
272 collector.push(diag);
273 self.react(idx, Severity::Fatal)
274 }
275 }
276 }
277
278 fn react(&self, _idx: usize, severity: Severity) -> OneResult {
281 match self.strategy.decide(severity) {
282 RecoveryAction::Continue => {
283 OneResult::Skipped
286 }
287 RecoveryAction::SkipExpression => OneResult::Skipped,
288 RecoveryAction::AbortProgram => OneResult::Aborted,
289 }
290 }
291}
292
293#[derive(Debug)]
295enum OneResult {
296 Ok(EinsumGraph),
297 Skipped,
298 Aborted,
299}
300
301pub fn compile_tolerant(program: &[TLExpr]) -> PartialCompilationResult {
307 TolerantCompiler::new().compile_program(program)
308}
309
310pub fn compile_tolerant_with_strategy(
312 program: &[TLExpr],
313 strategy: RecoveryStrategy,
314) -> PartialCompilationResult {
315 TolerantCompiler::with_strategy(strategy).compile_program(program)
316}
317
/// Renders a panic payload captured by `std::panic::catch_unwind` as text.
///
/// Panic payloads are commonly a `&'static str` (literal messages) or a `String`
/// (formatted messages). Any other payload type yields a fixed placeholder.
fn panic_payload_to_string(payload: &Box<dyn std::any::Any + Send>) -> String {
    // Inspect the box's *contents*: the `Box` itself is also `Any`, and downcasting
    // it instead would never match a string type.
    let contents: &(dyn std::any::Any + Send) = &**payload;
    match (
        contents.downcast_ref::<&'static str>(),
        contents.downcast_ref::<String>(),
    ) {
        (Some(text), _) => (*text).to_string(),
        (None, Some(text)) => text.clone(),
        (None, None) => "<non-string panic payload>".to_string(),
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// A single-predicate expression `p(x)` that is expected to compile cleanly.
    fn good_expr() -> TLExpr {
        TLExpr::pred("p", vec![Term::var("x")])
    }

    #[test]
    fn compile_tolerant_all_good() {
        let program: Vec<TLExpr> = std::iter::repeat_with(good_expr).take(3).collect();
        let outcome = compile_tolerant(&program);

        assert_eq!(outcome.graphs.len(), 3);
        assert!(outcome.is_all_success());
        assert_eq!(outcome.success_count(), 3);
        assert!(!outcome.aborted);
        assert!(outcome.diagnostics.is_empty());
    }

    #[test]
    fn partial_result_success_iter() {
        let program = [good_expr(), good_expr()];
        let outcome = compile_tolerant(&program);
        let indices: Vec<usize> = outcome.successes().map(|(index, _graph)| index).collect();
        assert_eq!(indices, [0, 1]);
    }
}