tensorlogic_scirs_backend/lazy/executor.rs

//! Lazy executor with per-node caching on top of [`Scirs2Exec`].
//!
//! `LazyExecutor` wraps a `Scirs2Exec` and maintains a node-level cache keyed
//! by `usize` (node index). Caching operates at the graph level: `forward`
//! stores every node output, and `get_cached` / `put_cached` expose the cache
//! to graph executors. Hits increment `LazyStats::cache_hits`; misses
//! increment `LazyStats::cache_misses`. Direct tensor operations (`einsum`,
//! `elem_op`, `reduce`) bypass the cache and delegate straight to the inner
//! executor.
//!
//! The executor also implements [`TlAutodiff`] by delegating forward/backward
//! passes to the inner executor and caching intermediate outputs.

use std::collections::HashMap;

use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlAutodiff, TlExecutor};
use tensorlogic_ir::EinsumGraph;

use crate::autodiff::ForwardTape;
use crate::{Scirs2Exec, Scirs2Tensor};

/// Accumulated statistics for a [`LazyExecutor`] session.
#[derive(Debug, Default, Clone)]
pub struct LazyStats {
    /// Number of tensor lookups served directly from the cache.
    pub cache_hits: usize,
    /// Number of tensor lookups that required actual computation.
    pub cache_misses: usize,
    /// Number of cached tensors invalidated and therefore subject to
    /// recomputation on their next access.
    pub tensors_recomputed: usize,
    /// High-water-mark estimate of live memory (in bytes).
    pub peak_memory_estimate_bytes: usize,
}

/// A lazy executor that caches computed tensors by node index.
///
/// All tensor operations are delegated to an inner [`Scirs2Exec`]; the cache
/// layer sits above it and avoids redundant work when the same node is
/// requested multiple times (e.g. during iterative training with a static
/// graph).
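///
/// A minimal session sketch (graph construction and input registration are
/// elided):
///
/// ```ignore
/// let mut exec = LazyExecutor::new();
/// let out = exec.forward(&graph)?; // cold run: populates the cache
/// let out = exec.forward(&graph)?; // warm run: node outputs are reused
/// println!("cache hits so far: {}", exec.stats().cache_hits);
/// ```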
pub struct LazyExecutor {
    inner: Scirs2Exec,
    /// Tensor cache: maps node_id → computed tensor.
    cache: HashMap<usize, Scirs2Tensor>,
    stats: LazyStats,
}

impl LazyExecutor {
    /// Create a new `LazyExecutor` with an empty cache.
    pub fn new() -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::new(),
            stats: LazyStats::default(),
        }
    }

    /// Create a `LazyExecutor` with a pre-allocated cache capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::with_capacity(capacity),
            stats: LazyStats::default(),
        }
    }

    /// Discard all cached tensors.
    pub fn invalidate_cache(&mut self) {
        self.cache.clear();
    }

    /// Remove a single node from the cache. Removing a cached entry bumps
    /// `tensors_recomputed` immediately; the next access to that node will be
    /// a cache miss.
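    ///
    /// A short sketch of the counter behaviour:
    ///
    /// ```ignore
    /// exec.put_cached(3, tensor);
    /// exec.invalidate_node(3); // entry removed, counter bumped
    /// assert_eq!(exec.stats().tensors_recomputed, 1);
    /// ```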
    pub fn invalidate_node(&mut self, node_id: usize) {
        if self.cache.remove(&node_id).is_some() {
            self.stats.tensors_recomputed += 1;
        }
    }

    /// Read-only reference to the accumulated statistics.
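    ///
    /// The hit rate, for instance, can be derived from the counters:
    ///
    /// ```ignore
    /// let s = exec.stats();
    /// let total = s.cache_hits + s.cache_misses;
    /// let hit_rate = if total == 0 { 0.0 } else { s.cache_hits as f64 / total as f64 };
    /// ```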
    pub fn stats(&self) -> &LazyStats {
        &self.stats
    }

    /// Rough estimate of the total memory needed if every node in `graph`
    /// were cached.
    ///
    /// Computed as `number_of_nodes * average_cached_tensor_size`, a simple
    /// heuristic based on what is already in the cache.
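    ///
    /// For example, three cached tensors totalling 2400 bytes average 800
    /// bytes each; for a ten-node graph the estimate is `800 * 10 = 8000`
    /// bytes.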
    pub fn memory_estimate_for(&self, graph: &EinsumGraph) -> usize {
        if graph.nodes.is_empty() || self.cache.is_empty() {
            return 0;
        }
        let total_cached_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        let avg_bytes = total_cached_bytes / self.cache.len();
        avg_bytes * graph.nodes.len()
    }

    /// Number of tensors currently in the cache.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    // ------------------------------------------------------------------
    // Internal helpers
    // ------------------------------------------------------------------

    /// Look up `node_id` in the cache. Bumps `cache_hits` and returns a clone
    /// of the cached tensor on a hit; bumps `cache_misses` and returns `None`
    /// on a miss.
    fn cache_get(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        if let Some(t) = self.cache.get(&node_id) {
            self.stats.cache_hits += 1;
            Some(t.clone())
        } else {
            self.stats.cache_misses += 1;
            None
        }
    }

    /// Insert a computed tensor into the cache and update peak memory stats.
    fn cache_insert(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache.insert(node_id, tensor);
        let current_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        if current_bytes > self.stats.peak_memory_estimate_bytes {
            self.stats.peak_memory_estimate_bytes = current_bytes;
        }
    }
}

impl Default for LazyExecutor {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// TlExecutor implementation
// ---------------------------------------------------------------------------

impl TlExecutor for LazyExecutor {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        // Einsum operations are not keyed by node_id here; the cache is
        // populated at the graph-traversal level (via forward()). Direct
        // calls to einsum bypass the cache entirely.
        self.inner.einsum(spec, inputs)
    }

    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op(op, x)
    }

    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op_binary(op, x, y)
    }

    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.reduce(op, x, axes)
    }
}

// ---------------------------------------------------------------------------
// TlAutodiff implementation
// ---------------------------------------------------------------------------

impl TlAutodiff for LazyExecutor {
    type Tape = ForwardTape;

    /// Execute the forward pass, caching every node output.
    ///
    /// Node outputs are stored in `self.cache` keyed by their index in
    /// `graph.nodes` so that subsequent forward calls on the same (or an
    /// overlapping) graph reuse already-computed tensors.
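    ///
    /// When inputs change between iterations the stale cache must be cleared
    /// first (a sketch; input registration is elided):
    ///
    /// ```ignore
    /// exec.invalidate_cache(); // inputs changed: drop stale node outputs
    /// let out = exec.forward(&graph)?;
    /// ```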
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        // Delegate to inner; it returns the final output tensor and stores the
        // full ForwardTape internally.
        let result = self.inner.forward(graph)?;

        // Populate cache with the node outputs stored by the inner executor.
        // The inner `ForwardTape` holds one Option<Scirs2Tensor> per *tensor*
        // index (not node index). We map node_index → its output tensor index.
        // Collect tensors first to avoid simultaneous mutable + immutable borrows.
        let node_tensors: Vec<(usize, Scirs2Tensor)> = if let Some(tape) = &self.inner.tape {
            graph
                .nodes
                .iter()
                .enumerate()
                .filter_map(|(node_idx, node)| {
                    node.outputs.first().and_then(|&tensor_idx| {
                        tape.tensors
                            .get(tensor_idx)
                            .and_then(|opt| opt.as_ref())
                            .map(|t| (node_idx, t.clone()))
                    })
                })
                .collect()
        } else {
            Vec::new()
        };

        for (node_idx, tensor) in node_tensors {
            if !self.cache.contains_key(&node_idx) {
                self.cache_insert(node_idx, tensor);
            } else {
                self.stats.cache_hits += 1;
            }
        }

        Ok(result)
    }

    /// Execute the backward pass, delegating to the inner executor.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.inner.backward(graph, loss)
    }
}

// ---------------------------------------------------------------------------
// Node-level cache lookup (optional convenience for graph executors)
// ---------------------------------------------------------------------------

impl LazyExecutor {
    /// Retrieve a cached tensor for the given node index (if available).
    ///
    /// This is the primary entry-point for lazy graph traversal: call this
    /// before scheduling a node for execution. On a hit the value is returned
    /// without calling any inner operations.
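    ///
    /// Typical check-then-compute pattern (`compute_node` is a hypothetical
    /// stand-in for real node execution):
    ///
    /// ```ignore
    /// let tensor = match exec.get_cached(node_id) {
    ///     Some(t) => t, // cache hit
    ///     None => {
    ///         let t = compute_node(node_id)?; // cache miss: do the work
    ///         exec.put_cached(node_id, t.clone());
    ///         t
    ///     }
    /// };
    /// ```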
    pub fn get_cached(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        self.cache_get(node_id)
    }

    /// Store a tensor result for a node. Subsequent calls to
    /// `get_cached(node_id)` will return this value.
    pub fn put_cached(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache_insert(node_id, tensor);
    }

    /// Access the inner [`Scirs2Exec`] (e.g. to register input tensors).
    pub fn inner_mut(&mut self) -> &mut Scirs2Exec {
        &mut self.inner
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    #[test]
    fn test_lazy_executor_default() {
        let exec = LazyExecutor::default();
        assert_eq!(exec.cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_cached_count_starts_zero() {
        let exec = LazyExecutor::new();
        assert_eq!(exec.cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_invalidate_cache() {
        let mut exec = LazyExecutor::with_capacity(4);
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        exec.put_cached(0, t);
        assert_eq!(exec.cached_count(), 1);
        exec.invalidate_cache();
        assert_eq!(exec.cached_count(), 0);
    }

    #[test]
    fn test_lazy_stats_default() {
        let stats = LazyStats::default();
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.tensors_recomputed, 0);
        assert_eq!(stats.peak_memory_estimate_bytes, 0);
    }
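
    // Exercises the hit/miss/invalidation counters through the public cache
    // API; a sketch that mirrors the ArrayD construction used in
    // test_lazy_executor_invalidate_cache above.
    #[test]
    fn test_cache_hit_miss_and_invalidation_stats() {
        let mut exec = LazyExecutor::new();
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        exec.put_cached(3, t);
        assert!(exec.get_cached(3).is_some()); // hit
        assert_eq!(exec.stats().cache_hits, 1);
        assert!(exec.get_cached(7).is_none()); // miss
        assert_eq!(exec.stats().cache_misses, 1);
        exec.invalidate_node(3); // removal bumps tensors_recomputed
        assert_eq!(exec.stats().tensors_recomputed, 1);
        assert_eq!(exec.cached_count(), 0);
    }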

    #[test]
    fn test_lazy_executor_memory_estimate_for_empty_graph() {
        let exec = LazyExecutor::new();
        let g = EinsumGraph::new();
        assert_eq!(exec.memory_estimate_for(&g), 0);
    }
}