tensorlogic_compiler/jit.rs
1//! JIT compilation for hot expression paths.
2//!
3//! [`JitCompiler`] wraps a standard compilation pipeline and tracks expression
4//! usage frequency. When an expression exceeds [`JitCompiler::hot_threshold`]
5//! compilations it is promoted to the "hot path": the expression is re-optimised
6//! more aggressively with [`OptimizationPipeline`] (aggressive preset) before
7//! compilation, and the result is stored as a pre-computed [`Arc<EinsumGraph>`].
8//! All subsequent compilations of the same hot expression return the cached
9//! graph in O(1) without re-running the optimizer or compiler.
10//!
11//! # Design notes
12//!
13//! Expression identity is determined via the `Debug` representation of the
14//! `TLExpr` — a deterministic structural fingerprint. This avoids requiring
15//! `Hash` or `PartialEq` on `TLExpr` while still being correct for the
16//! intended use case (repeated compilation of the same logical rule).
17//!
18//! The call-count map stores a clone of the originating `TLExpr` alongside
19//! its hit count so that, when the threshold is crossed, the original
20//! expression is available for the extra optimization pass.
21//!
22//! # Thread safety
23//!
24//! Both the hot-path cache and the call-count map are guarded by a single
25//! `Mutex`. The cold path (compilation itself) is performed *outside* the
26//! lock so that concurrent cold compilations of different expressions do not
27//! serialise on I/O-heavy optimizer work.
28
29use std::collections::hash_map::DefaultHasher;
30use std::collections::HashMap;
31use std::hash::{Hash, Hasher};
32use std::sync::{Arc, Mutex};
33
34use anyhow::Result;
35use tensorlogic_ir::{EinsumGraph, TLExpr};
36
37use crate::{
38 compile_to_einsum_with_config,
39 config::CompilationConfig,
40 dead_code::{DceConfig, DeadCodeEliminator},
41 optimize::pipeline::{OptimizationPipeline, PipelineConfig},
42};
43
44// ─────────────────────────────────────────────────────────────────────────────
45// Public error type
46// ─────────────────────────────────────────────────────────────────────────────
47
/// Errors emitted by [`JitCompiler`].
///
/// Currently the only failure mode is a propagated error from the underlying
/// compilation pipeline; the enum leaves room for future JIT-specific variants.
#[derive(Debug)]
pub enum JitError {
    /// The underlying compilation step failed.
    ///
    /// Wraps the [`anyhow::Error`] produced by the cold-path compiler or by
    /// the promotion-time optimization pass (see the `From` impl below).
    CompilationFailed(anyhow::Error),
}
54
55impl std::fmt::Display for JitError {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 JitError::CompilationFailed(e) => write!(f, "JIT compilation failed: {}", e),
59 }
60 }
61}
62
63impl std::error::Error for JitError {
64 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
65 match self {
66 JitError::CompilationFailed(e) => e.source(),
67 }
68 }
69}
70
71impl From<anyhow::Error> for JitError {
72 fn from(e: anyhow::Error) -> Self {
73 JitError::CompilationFailed(e)
74 }
75}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// Statistics
79// ─────────────────────────────────────────────────────────────────────────────
80
/// Statistics snapshot from a [`JitCompiler`].
///
/// Produced by [`JitCompiler::stats`]; the values are cloned out while the
/// cache mutex is held, so each snapshot is internally consistent at the
/// moment it was taken.
#[derive(Debug, Clone, Default)]
pub struct JitStats {
    /// Number of distinct expressions currently promoted to the hot-path cache.
    pub hot_paths: usize,
    /// Total number of compile calls that went through the cold path
    /// (including the final cold call that triggers an upgrade).
    pub cold_compilations: usize,
    /// Number of compile calls that returned a pre-compiled hot-path graph.
    pub jit_hits: usize,
    /// Number of expressions that were upgraded from cold to hot (promoted).
    pub jit_upgrades: usize,
}
94
95// ─────────────────────────────────────────────────────────────────────────────
96// Internal types
97// ─────────────────────────────────────────────────────────────────────────────
98
/// A compiled hot-path entry.
///
/// `Clone` is cheap: the graph sits behind an [`Arc`], so cloning only bumps
/// a reference count.
#[derive(Clone)]
struct JitEntry {
    /// Pre-optimised, pre-compiled graph.
    graph: Arc<EinsumGraph>,
    /// Number of cache hits since promotion.
    hit_count: usize,
}
107
/// Per-expression tracking record kept in the call-count map.
struct CallRecord {
    /// Running invocation count (incremented on every `compile` call).
    count: usize,
    /// Clone of the originating expression, needed for extra-optimization
    /// when the threshold is crossed. The map key is only a hash, so the
    /// expression cannot be recovered from the key itself.
    expr: TLExpr,
}
116
/// Mutable state shared behind the [`JitCompiler`]'s mutex.
struct JitCacheInner {
    /// Expressions that have been promoted to the hot path, keyed by the
    /// fingerprint produced by `expr_hash`.
    hot_paths: HashMap<u64, JitEntry>,
    /// Call counts plus originating expression for every seen expression.
    call_counts: HashMap<u64, CallRecord>,
    /// Running statistics.
    stats: JitStats,
}
125
126impl JitCacheInner {
127 fn new() -> Self {
128 Self {
129 hot_paths: HashMap::new(),
130 call_counts: HashMap::new(),
131 stats: JitStats::default(),
132 }
133 }
134}
135
136// ─────────────────────────────────────────────────────────────────────────────
137// JitCompiler
138// ─────────────────────────────────────────────────────────────────────────────
139
/// JIT compiler with hot-path detection and pre-optimized graph caching.
///
/// All mutable state lives behind an internal `Arc<Mutex<_>>`, so `compile`
/// takes `&self` and a single instance can be shared freely.
///
/// # Example
///
/// ```rust
/// use tensorlogic_compiler::JitCompiler;
/// use tensorlogic_ir::{TLExpr, Term};
///
/// let jit = JitCompiler::new(3);
/// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
///
/// for _ in 0..5 {
///     let graph = jit.compile(&expr).expect("compile");
///     let _ = graph;
/// }
///
/// let stats = jit.stats();
/// assert_eq!(jit.hot_path_count(), 1);
/// assert!(stats.jit_hits > 0);
/// ```
pub struct JitCompiler {
    /// Compilation configuration forwarded to the cold path.
    config: CompilationConfig,
    /// Number of compilations required before an expression is promoted.
    pub hot_threshold: usize,
    /// Shared cache guarded by a mutex.
    cache: Arc<Mutex<JitCacheInner>>,
}
168
169// ─────────────────────────────────────────────────────────────────────────────
170// Expression hashing helper
171// ─────────────────────────────────────────────────────────────────────────────
172
173/// Compute a structural fingerprint for a `TLExpr` via its `Debug` output.
174///
175/// Two structurally identical expressions produce the same fingerprint.
176/// Collisions are possible but astronomically unlikely for the intended use
177/// case of tracking repeated rule compilations.
178fn expr_hash(expr: &TLExpr) -> u64 {
179 let repr = format!("{expr:?}");
180 let mut hasher = DefaultHasher::new();
181 repr.hash(&mut hasher);
182 hasher.finish()
183}
184
185// ─────────────────────────────────────────────────────────────────────────────
186// JitCompiler implementation
187// ─────────────────────────────────────────────────────────────────────────────
188
189impl JitCompiler {
190 /// Create a new JIT compiler with default [`CompilationConfig`].
191 ///
192 /// `hot_threshold` is the number of compilations an expression must
193 /// accumulate before it is promoted to the hot-path cache.
194 pub fn new(hot_threshold: usize) -> Self {
195 Self::with_config(CompilationConfig::default(), hot_threshold)
196 }
197
198 /// Create a new JIT compiler with a custom [`CompilationConfig`].
199 pub fn with_config(config: CompilationConfig, hot_threshold: usize) -> Self {
200 Self {
201 config,
202 hot_threshold,
203 cache: Arc::new(Mutex::new(JitCacheInner::new())),
204 }
205 }
206
207 /// Compile `expr`, returning a shared `Arc<EinsumGraph>`.
208 ///
209 /// - On the first `hot_threshold` calls the expression is compiled via the
210 /// normal cold path.
211 /// - When the call count reaches `hot_threshold` the expression is
212 /// optimised with an aggressive expression-level pass and recompiled;
213 /// the result is inserted into the hot-path cache.
214 /// - All subsequent calls for the same expression return the cached graph
215 /// directly without invoking the compiler.
216 pub fn compile(&self, expr: &TLExpr) -> Result<Arc<EinsumGraph>, JitError> {
217 let key = expr_hash(expr);
218
219 // ── Fast path: check hot cache before doing any compilation work ──────
220 {
221 let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
222
223 // Increment call count; insert a new record if first time seen.
224 let record = guard.call_counts.entry(key).or_insert_with(|| CallRecord {
225 count: 0,
226 expr: expr.clone(),
227 });
228 record.count += 1;
229
230 // Hot-path hit: return cached graph immediately.
231 //
232 // We clone the Arc while holding the mutable borrow on the entry,
233 // then drop the mutable borrow before updating the sibling stats
234 // field — satisfying the single-&mut rule.
235 if let Some(arc) = guard.hot_paths.get_mut(&key).map(|entry| {
236 entry.hit_count += 1;
237 Arc::clone(&entry.graph)
238 }) {
239 guard.stats.jit_hits += 1;
240 return Ok(arc);
241 }
242 }
243
244 // ── Cold path: compile the expression normally ─────────────────────────
245 let cold_graph = compile_to_einsum_with_config(expr, &self.config)?;
246
247 // ── Check current call count to decide on promotion ───────────────────
248 let current_count = {
249 let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
250 guard.call_counts.get(&key).map(|r| r.count).unwrap_or(0)
251 };
252
253 if current_count >= self.hot_threshold {
254 // Retrieve the stored expression for the extra optimisation pass.
255 let stored_expr = {
256 let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
257 guard.call_counts.get(&key).map(|r| r.expr.clone())
258 };
259
260 if let Some(original_expr) = stored_expr {
261 let optimized_graph = self.apply_extra_optimization(&original_expr)?;
262 let arc = Arc::new(optimized_graph);
263
264 let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
265 // Guard against a concurrent thread that already promoted this key.
266 if let std::collections::hash_map::Entry::Vacant(slot) = guard.hot_paths.entry(key)
267 {
268 slot.insert(JitEntry {
269 graph: Arc::clone(&arc),
270 hit_count: 0,
271 });
272 guard.stats.jit_upgrades += 1;
273 guard.stats.hot_paths += 1;
274 }
275 guard.stats.cold_compilations += 1;
276 return Ok(arc);
277 }
278 }
279
280 // Below threshold: return cold-compiled graph without promotion.
281 let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
282 guard.stats.cold_compilations += 1;
283 Ok(Arc::new(cold_graph))
284 }
285
286 /// Apply the extra expression-level optimization pass used when promoting
287 /// an expression to the hot path.
288 ///
289 /// Strategy (in order of decreasing preference):
290 ///
291 /// 1. Run the [`OptimizationPipeline`] with an **aggressive** configuration
292 /// (max 20 iterations, all passes enabled including distributivity and
293 /// quantifier hoisting) on `expr`.
294 /// 2. Follow with a full [`DeadCodeEliminator`] fixed-point pass.
295 /// 3. Recompile the doubly-optimised expression with [`compile_to_einsum_with_config`].
296 ///
297 /// This produces a graph whose underlying expression has had significantly
298 /// more algebraic simplification applied compared to the cold path.
299 fn apply_extra_optimization(&self, expr: &TLExpr) -> Result<EinsumGraph, JitError> {
300 // Step 1: Aggressive expression-level pipeline optimisation.
301 let aggressive_config = PipelineConfig {
302 enable_negation_opt: true,
303 enable_constant_folding: true,
304 enable_algebraic_simplification: true,
305 enable_strength_reduction: true,
306 enable_distributivity: true,
307 enable_quantifier_opt: true,
308 enable_dead_code_elimination: true,
309 max_iterations: 20,
310 stop_on_fixed_point: true,
311 };
312 let pipeline = OptimizationPipeline::with_config(aggressive_config);
313 let (after_pipeline, _pipeline_stats) = pipeline.optimize(expr);
314
315 // Step 2: Additional dead-code elimination pass to prune branches that
316 // may have become unreachable after constant folding / strength
317 // reduction in the pipeline.
318 let dce_config = DceConfig {
319 eliminate_constant_and: true,
320 eliminate_constant_or: true,
321 eliminate_constant_not: true,
322 eliminate_if_branches: true,
323 eliminate_unused_let: true,
324 max_passes: 20,
325 };
326 let eliminator = DeadCodeEliminator::new(dce_config);
327 let (fully_optimized, _dce_stats) = eliminator.run(after_pipeline);
328
329 // Step 3: Compile the fully-optimised expression to an EinsumGraph.
330 let graph = compile_to_einsum_with_config(&fully_optimized, &self.config)?;
331
332 Ok(graph)
333 }
334
335 /// Return a snapshot of the current JIT statistics.
336 pub fn stats(&self) -> JitStats {
337 let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
338 guard.stats.clone()
339 }
340
341 /// Evict all cached hot-path graphs and reset all counters.
342 ///
343 /// After this call the JIT compiler behaves as if it were freshly
344 /// constructed.
345 pub fn clear_cache(&mut self) {
346 if let Ok(mut guard) = self.cache.lock() {
347 *guard = JitCacheInner::new();
348 }
349 }
350
351 /// Return the number of distinct expressions currently in the hot-path cache.
352 pub fn hot_path_count(&self) -> usize {
353 let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
354 guard.hot_paths.len()
355 }
356
357 /// Return the total number of times `expr` has been compiled via this instance.
358 ///
359 /// Returns `0` if `expr` has never been seen.
360 pub fn call_count(&self, expr: &TLExpr) -> usize {
361 let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
362 guard
363 .call_counts
364 .get(&expr_hash(expr))
365 .map(|r| r.count)
366 .unwrap_or(0)
367 }
368
369 /// Return the hot-path threshold used by this instance.
370 pub fn threshold(&self) -> usize {
371 self.hot_threshold
372 }
373}
374
375// ─────────────────────────────────────────────────────────────────────────────
376// Tests
377// ─────────────────────────────────────────────────────────────────────────────
378
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// Binary predicate used as the primary expression in most tests.
    fn simple_expr() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// Structurally distinct unary predicate for separate-tracking tests.
    fn different_expr() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("a")])
    }

    #[test]
    fn test_cold_path_returns_graph() {
        let jit = JitCompiler::new(5);
        // Graph must be valid (may be empty for trivial predicates — just must not panic).
        let _ = jit.compile(&simple_expr()).expect("cold compile");
        let snapshot = jit.stats();
        assert_eq!(snapshot.cold_compilations, 1);
        assert_eq!(snapshot.jit_hits, 0);
    }

    #[test]
    fn test_hot_upgrade_at_threshold() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        (0..3).for_each(|_| {
            jit.compile(&expr).expect("compile");
        });
        assert_eq!(jit.hot_path_count(), 1);
        assert!(jit.stats().jit_upgrades >= 1);
    }

    #[test]
    fn test_jit_hit_after_upgrade() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        // Calls 1–2 go cold (the second triggers the upgrade); call 3 must hit.
        jit.compile(&expr).expect("call 1");
        jit.compile(&expr).expect("call 2");
        jit.compile(&expr).expect("call 3");
        let stats = jit.stats();
        assert!(
            stats.jit_hits >= 1,
            "expected at least 1 jit_hit, got {stats:?}"
        );
    }

    #[test]
    fn test_clear_cache_resets() {
        let mut jit = JitCompiler::new(1);
        let expr = simple_expr();
        jit.compile(&expr).expect("compile once");
        assert_eq!(jit.hot_path_count(), 1);
        // A cleared compiler must look freshly constructed.
        jit.clear_cache();
        assert_eq!(jit.hot_path_count(), 0);
        assert_eq!(jit.call_count(&expr), 0);
    }

    #[test]
    fn test_different_exprs_tracked_separately() {
        let jit = JitCompiler::new(10);
        let first = simple_expr();
        let second = different_expr();
        (0..3).for_each(|_| {
            jit.compile(&first).expect("e1");
        });
        jit.compile(&second).expect("e2");
        assert_eq!(jit.call_count(&first), 3);
        assert_eq!(jit.call_count(&second), 1);
    }

    #[test]
    fn test_threshold_one_upgrades_immediately() {
        let jit = JitCompiler::new(1);
        jit.compile(&simple_expr()).expect("first call");
        assert_eq!(jit.hot_path_count(), 1);
    }

    #[test]
    fn test_stats_consistent() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        let total = 5usize;
        (0..total).for_each(|_| {
            jit.compile(&expr).expect("compile");
        });
        let stats = jit.stats();
        // Every call is either a cold compile or a hot-cache hit.
        assert_eq!(
            stats.cold_compilations + stats.jit_hits,
            total,
            "cold + hits must equal total calls; got {stats:?}"
        );
    }

    #[test]
    fn test_hot_graph_not_empty() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        jit.compile(&expr).expect("call 1");
        jit.compile(&expr).expect("call 2");
        // Third call hits the hot cache — should not panic.
        let _ = jit.compile(&expr).expect("call 3 (hot)");
    }

    #[test]
    fn test_threshold_accessor() {
        assert_eq!(JitCompiler::new(7).threshold(), 7);
    }
}
495}