Skip to main content

tensorlogic_compiler/
jit.rs

1//! JIT compilation for hot expression paths.
2//!
3//! [`JitCompiler`] wraps a standard compilation pipeline and tracks expression
4//! usage frequency. When an expression reaches [`JitCompiler::hot_threshold`]
5//! compilations it is promoted to the "hot path": the expression is re-optimised
6//! more aggressively with [`OptimizationPipeline`] (aggressive preset) before
7//! compilation, and the result is stored as a pre-computed [`Arc<EinsumGraph>`].
8//! All subsequent compilations of the same hot expression return the cached
9//! graph in O(1) without re-running the optimizer or compiler.
10//!
11//! # Design notes
12//!
13//! Expression identity is determined via the `Debug` representation of the
14//! `TLExpr` — a deterministic structural fingerprint. This avoids requiring
15//! `Hash` or `PartialEq` on `TLExpr` while still being correct for the
16//! intended use case (repeated compilation of the same logical rule).
17//!
18//! The call-count map stores a clone of the originating `TLExpr` alongside
19//! its hit count so that, when the threshold is crossed, the original
20//! expression is available for the extra optimization pass.
21//!
22//! # Thread safety
23//!
24//! Both the hot-path cache and the call-count map are guarded by a single
25//! `Mutex`. The cold path (compilation itself) is performed *outside* the
26//! lock so that concurrent cold compilations of different expressions do not
27//! serialise on I/O-heavy optimizer work.
28
29use std::collections::hash_map::DefaultHasher;
30use std::collections::HashMap;
31use std::hash::{Hash, Hasher};
32use std::sync::{Arc, Mutex};
33
34use anyhow::Result;
35use tensorlogic_ir::{EinsumGraph, TLExpr};
36
37use crate::{
38    compile_to_einsum_with_config,
39    config::CompilationConfig,
40    dead_code::{DceConfig, DeadCodeEliminator},
41    optimize::pipeline::{OptimizationPipeline, PipelineConfig},
42};
43
44// ─────────────────────────────────────────────────────────────────────────────
45// Public error type
46// ─────────────────────────────────────────────────────────────────────────────
47
/// Errors emitted by [`JitCompiler`].
///
/// Currently a single-variant enum; the wrapper exists so the JIT surface has
/// a concrete error type implementing `std::error::Error` rather than leaking
/// `anyhow::Error` directly.
#[derive(Debug)]
pub enum JitError {
    /// The underlying compilation step failed.
    ///
    /// Carries the `anyhow::Error` produced by the cold-path compiler or the
    /// extra optimization pass.
    CompilationFailed(anyhow::Error),
}
54
55impl std::fmt::Display for JitError {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            JitError::CompilationFailed(e) => write!(f, "JIT compilation failed: {}", e),
59        }
60    }
61}
62
63impl std::error::Error for JitError {
64    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
65        match self {
66            JitError::CompilationFailed(e) => e.source(),
67        }
68    }
69}
70
71impl From<anyhow::Error> for JitError {
72    fn from(e: anyhow::Error) -> Self {
73        JitError::CompilationFailed(e)
74    }
75}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// Statistics
79// ─────────────────────────────────────────────────────────────────────────────
80
/// Statistics snapshot from a [`JitCompiler`].
///
/// Obtained via [`JitCompiler::stats`]; counters are cumulative since
/// construction (or the last [`JitCompiler::clear_cache`]).
#[derive(Debug, Clone, Default)]
pub struct JitStats {
    /// Number of distinct expressions currently promoted to the hot-path cache.
    pub hot_paths: usize,
    /// Total number of compile calls that went through the cold path
    /// (including the final cold call that triggers an upgrade).
    pub cold_compilations: usize,
    /// Number of compile calls that returned a pre-compiled hot-path graph.
    pub jit_hits: usize,
    /// Number of expressions that were upgraded from cold to hot (promoted).
    pub jit_upgrades: usize,
}
94
95// ─────────────────────────────────────────────────────────────────────────────
96// Internal types
97// ─────────────────────────────────────────────────────────────────────────────
98
/// A compiled hot-path entry.
///
/// `Clone` is cheap: the graph sits behind an `Arc`, so cloning only bumps a
/// reference count.
#[derive(Clone)]
struct JitEntry {
    /// Pre-optimised, pre-compiled graph, handed to callers via `Arc::clone`.
    graph: Arc<EinsumGraph>,
    /// Number of cache hits since promotion.
    hit_count: usize,
}
107
/// Per-expression tracking record kept in the call-count map.
struct CallRecord {
    /// Running invocation count (incremented on every `compile` call).
    count: usize,
    /// Clone of the originating expression, needed for extra-optimization
    /// when the threshold is crossed — the `u64` hash key alone cannot
    /// recover the expression.
    expr: TLExpr,
}
116
/// Mutable JIT state shared behind the compiler's single `Mutex`
/// (see the module docs' "Thread safety" section).
struct JitCacheInner {
    /// Expressions that have been promoted to the hot path, keyed by their
    /// `expr_hash` fingerprint.
    hot_paths: HashMap<u64, JitEntry>,
    /// Call counts plus originating expression for every seen expression.
    call_counts: HashMap<u64, CallRecord>,
    /// Running statistics.
    stats: JitStats,
}
125
126impl JitCacheInner {
127    fn new() -> Self {
128        Self {
129            hot_paths: HashMap::new(),
130            call_counts: HashMap::new(),
131            stats: JitStats::default(),
132        }
133    }
134}
135
136// ─────────────────────────────────────────────────────────────────────────────
137// JitCompiler
138// ─────────────────────────────────────────────────────────────────────────────
139
/// JIT compiler with hot-path detection and pre-optimized graph caching.
///
/// # Example
///
/// ```rust
/// use tensorlogic_compiler::JitCompiler;
/// use tensorlogic_ir::{TLExpr, Term};
///
/// let jit = JitCompiler::new(3);
/// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
///
/// for _ in 0..5 {
///     let graph = jit.compile(&expr).expect("compile");
///     let _ = graph;
/// }
///
/// let stats = jit.stats();
/// assert_eq!(jit.hot_path_count(), 1);
/// assert!(stats.jit_hits > 0);
/// ```
pub struct JitCompiler {
    /// Compilation configuration forwarded to the cold path.
    config: CompilationConfig,
    /// Number of compilations required before an expression is promoted.
    /// Public for direct inspection; also exposed via [`JitCompiler::threshold`].
    pub hot_threshold: usize,
    /// Shared cache state guarded by a single mutex (see module docs).
    cache: Arc<Mutex<JitCacheInner>>,
}
168
169// ─────────────────────────────────────────────────────────────────────────────
170// Expression hashing helper
171// ─────────────────────────────────────────────────────────────────────────────
172
173/// Compute a structural fingerprint for a `TLExpr` via its `Debug` output.
174///
175/// Two structurally identical expressions produce the same fingerprint.
176/// Collisions are possible but astronomically unlikely for the intended use
177/// case of tracking repeated rule compilations.
178fn expr_hash(expr: &TLExpr) -> u64 {
179    let repr = format!("{expr:?}");
180    let mut hasher = DefaultHasher::new();
181    repr.hash(&mut hasher);
182    hasher.finish()
183}
184
185// ─────────────────────────────────────────────────────────────────────────────
186// JitCompiler implementation
187// ─────────────────────────────────────────────────────────────────────────────
188
189impl JitCompiler {
190    /// Create a new JIT compiler with default [`CompilationConfig`].
191    ///
192    /// `hot_threshold` is the number of compilations an expression must
193    /// accumulate before it is promoted to the hot-path cache.
194    pub fn new(hot_threshold: usize) -> Self {
195        Self::with_config(CompilationConfig::default(), hot_threshold)
196    }
197
198    /// Create a new JIT compiler with a custom [`CompilationConfig`].
199    pub fn with_config(config: CompilationConfig, hot_threshold: usize) -> Self {
200        Self {
201            config,
202            hot_threshold,
203            cache: Arc::new(Mutex::new(JitCacheInner::new())),
204        }
205    }
206
207    /// Compile `expr`, returning a shared `Arc<EinsumGraph>`.
208    ///
209    /// - On the first `hot_threshold` calls the expression is compiled via the
210    ///   normal cold path.
211    /// - When the call count reaches `hot_threshold` the expression is
212    ///   optimised with an aggressive expression-level pass and recompiled;
213    ///   the result is inserted into the hot-path cache.
214    /// - All subsequent calls for the same expression return the cached graph
215    ///   directly without invoking the compiler.
216    pub fn compile(&self, expr: &TLExpr) -> Result<Arc<EinsumGraph>, JitError> {
217        let key = expr_hash(expr);
218
219        // ── Fast path: check hot cache before doing any compilation work ──────
220        {
221            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
222
223            // Increment call count; insert a new record if first time seen.
224            let record = guard.call_counts.entry(key).or_insert_with(|| CallRecord {
225                count: 0,
226                expr: expr.clone(),
227            });
228            record.count += 1;
229
230            // Hot-path hit: return cached graph immediately.
231            //
232            // We clone the Arc while holding the mutable borrow on the entry,
233            // then drop the mutable borrow before updating the sibling stats
234            // field — satisfying the single-&mut rule.
235            if let Some(arc) = guard.hot_paths.get_mut(&key).map(|entry| {
236                entry.hit_count += 1;
237                Arc::clone(&entry.graph)
238            }) {
239                guard.stats.jit_hits += 1;
240                return Ok(arc);
241            }
242        }
243
244        // ── Cold path: compile the expression normally ─────────────────────────
245        let cold_graph = compile_to_einsum_with_config(expr, &self.config)?;
246
247        // ── Check current call count to decide on promotion ───────────────────
248        let current_count = {
249            let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
250            guard.call_counts.get(&key).map(|r| r.count).unwrap_or(0)
251        };
252
253        if current_count >= self.hot_threshold {
254            // Retrieve the stored expression for the extra optimisation pass.
255            let stored_expr = {
256                let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
257                guard.call_counts.get(&key).map(|r| r.expr.clone())
258            };
259
260            if let Some(original_expr) = stored_expr {
261                let optimized_graph = self.apply_extra_optimization(&original_expr)?;
262                let arc = Arc::new(optimized_graph);
263
264                let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
265                // Guard against a concurrent thread that already promoted this key.
266                if let std::collections::hash_map::Entry::Vacant(slot) = guard.hot_paths.entry(key)
267                {
268                    slot.insert(JitEntry {
269                        graph: Arc::clone(&arc),
270                        hit_count: 0,
271                    });
272                    guard.stats.jit_upgrades += 1;
273                    guard.stats.hot_paths += 1;
274                }
275                guard.stats.cold_compilations += 1;
276                return Ok(arc);
277            }
278        }
279
280        // Below threshold: return cold-compiled graph without promotion.
281        let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
282        guard.stats.cold_compilations += 1;
283        Ok(Arc::new(cold_graph))
284    }
285
286    /// Apply the extra expression-level optimization pass used when promoting
287    /// an expression to the hot path.
288    ///
289    /// Strategy (in order of decreasing preference):
290    ///
291    /// 1. Run the [`OptimizationPipeline`] with an **aggressive** configuration
292    ///    (max 20 iterations, all passes enabled including distributivity and
293    ///    quantifier hoisting) on `expr`.
294    /// 2. Follow with a full [`DeadCodeEliminator`] fixed-point pass.
295    /// 3. Recompile the doubly-optimised expression with [`compile_to_einsum_with_config`].
296    ///
297    /// This produces a graph whose underlying expression has had significantly
298    /// more algebraic simplification applied compared to the cold path.
299    fn apply_extra_optimization(&self, expr: &TLExpr) -> Result<EinsumGraph, JitError> {
300        // Step 1: Aggressive expression-level pipeline optimisation.
301        let aggressive_config = PipelineConfig {
302            enable_negation_opt: true,
303            enable_constant_folding: true,
304            enable_algebraic_simplification: true,
305            enable_strength_reduction: true,
306            enable_distributivity: true,
307            enable_quantifier_opt: true,
308            enable_dead_code_elimination: true,
309            max_iterations: 20,
310            stop_on_fixed_point: true,
311        };
312        let pipeline = OptimizationPipeline::with_config(aggressive_config);
313        let (after_pipeline, _pipeline_stats) = pipeline.optimize(expr);
314
315        // Step 2: Additional dead-code elimination pass to prune branches that
316        //         may have become unreachable after constant folding / strength
317        //         reduction in the pipeline.
318        let dce_config = DceConfig {
319            eliminate_constant_and: true,
320            eliminate_constant_or: true,
321            eliminate_constant_not: true,
322            eliminate_if_branches: true,
323            eliminate_unused_let: true,
324            max_passes: 20,
325        };
326        let eliminator = DeadCodeEliminator::new(dce_config);
327        let (fully_optimized, _dce_stats) = eliminator.run(after_pipeline);
328
329        // Step 3: Compile the fully-optimised expression to an EinsumGraph.
330        let graph = compile_to_einsum_with_config(&fully_optimized, &self.config)?;
331
332        Ok(graph)
333    }
334
335    /// Return a snapshot of the current JIT statistics.
336    pub fn stats(&self) -> JitStats {
337        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
338        guard.stats.clone()
339    }
340
341    /// Evict all cached hot-path graphs and reset all counters.
342    ///
343    /// After this call the JIT compiler behaves as if it were freshly
344    /// constructed.
345    pub fn clear_cache(&mut self) {
346        if let Ok(mut guard) = self.cache.lock() {
347            *guard = JitCacheInner::new();
348        }
349    }
350
351    /// Return the number of distinct expressions currently in the hot-path cache.
352    pub fn hot_path_count(&self) -> usize {
353        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
354        guard.hot_paths.len()
355    }
356
357    /// Return the total number of times `expr` has been compiled via this instance.
358    ///
359    /// Returns `0` if `expr` has never been seen.
360    pub fn call_count(&self, expr: &TLExpr) -> usize {
361        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
362        guard
363            .call_counts
364            .get(&expr_hash(expr))
365            .map(|r| r.count)
366            .unwrap_or(0)
367    }
368
369    /// Return the hot-path threshold used by this instance.
370    pub fn threshold(&self) -> usize {
371        self.hot_threshold
372    }
373}
374
375// ─────────────────────────────────────────────────────────────────────────────
376// Tests
377// ─────────────────────────────────────────────────────────────────────────────
378
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// Two-variable predicate used as the "hot" expression in most tests.
    fn simple_expr() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// Structurally distinct predicate for separate-tracking tests.
    fn different_expr() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("a")])
    }

    #[test]
    fn test_cold_path_returns_graph() {
        let jit = JitCompiler::new(5);
        // Graph must be valid (may be empty for trivial predicates — just must not panic).
        let _graph = jit.compile(&simple_expr()).expect("cold compile");
        let snapshot = jit.stats();
        assert_eq!(snapshot.cold_compilations, 1);
        assert_eq!(snapshot.jit_hits, 0);
    }

    #[test]
    fn test_hot_upgrade_at_threshold() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        (0..3).for_each(|_| {
            jit.compile(&expr).expect("compile");
        });
        assert_eq!(jit.hot_path_count(), 1);
        assert!(jit.stats().jit_upgrades >= 1);
    }

    #[test]
    fn test_jit_hit_after_upgrade() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        // First two calls are cold; the second triggers the upgrade, so the
        // third must be served from the hot cache.
        jit.compile(&expr).expect("call 1");
        jit.compile(&expr).expect("call 2");
        jit.compile(&expr).expect("call 3");
        let stats = jit.stats();
        assert!(
            stats.jit_hits >= 1,
            "expected at least 1 jit_hit, got {stats:?}"
        );
    }

    #[test]
    fn test_clear_cache_resets() {
        let mut jit = JitCompiler::new(1);
        let expr = simple_expr();
        jit.compile(&expr).expect("compile once");
        assert_eq!(jit.hot_path_count(), 1);
        // Clearing must drop both the hot cache and the call records.
        jit.clear_cache();
        assert_eq!(jit.hot_path_count(), 0);
        assert_eq!(jit.call_count(&expr), 0);
    }

    #[test]
    fn test_different_exprs_tracked_separately() {
        let jit = JitCompiler::new(10);
        let first = simple_expr();
        let second = different_expr();
        (0..3).for_each(|_| {
            jit.compile(&first).expect("e1");
        });
        jit.compile(&second).expect("e2");
        assert_eq!(jit.call_count(&first), 3);
        assert_eq!(jit.call_count(&second), 1);
    }

    #[test]
    fn test_threshold_one_upgrades_immediately() {
        let jit = JitCompiler::new(1);
        jit.compile(&simple_expr()).expect("first call");
        assert_eq!(jit.hot_path_count(), 1);
    }

    #[test]
    fn test_stats_consistent() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        let total = 5usize;
        (0..total).for_each(|_| {
            jit.compile(&expr).expect("compile");
        });
        let stats = jit.stats();
        assert_eq!(
            stats.cold_compilations + stats.jit_hits,
            total,
            "cold + hits must equal total calls; got {stats:?}"
        );
    }

    #[test]
    fn test_hot_graph_not_empty() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        jit.compile(&expr).expect("call 1");
        jit.compile(&expr).expect("call 2");
        // Third call is served from the hot cache — must not panic.
        let _graph = jit.compile(&expr).expect("call 3 (hot)");
    }

    #[test]
    fn test_threshold_accessor() {
        assert_eq!(JitCompiler::new(7).threshold(), 7);
    }
}