tensorlogic_scirs_backend/lazy/executor.rs
//! Lazy executor with per-node caching on top of [`Scirs2Exec`].
//!
//! `LazyExecutor` wraps a `Scirs2Exec` and maintains a node-level cache keyed
//! by `usize` (node index). Cache lookups that succeed increment
//! `LazyStats::cache_hits`, while misses increment `LazyStats::cache_misses`
//! and populate the cache for future calls. The raw tensor operations
//! (`einsum`, `elem_op`, `reduce`) bypass the cache and delegate directly to
//! the inner executor; caching happens at the graph-traversal level via
//! [`TlAutodiff::forward`] and the `get_cached` / `put_cached` helpers.
//!
//! The executor also implements [`TlAutodiff`] by delegating forward / backward
//! passes to the inner executor and caching intermediate outputs.
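//!
//! A minimal usage sketch (marked `ignore`; the import path below is an
//! assumption based on this file's location in the crate):
//!
//! ```ignore
//! use tensorlogic_scirs_backend::lazy::LazyExecutor;
//! use scirs2_core::ndarray::{ArrayD, IxDyn};
//!
//! let mut exec = LazyExecutor::new();
//! assert!(exec.get_cached(0).is_none()); // miss: bumps `cache_misses`
//!
//! let t = ArrayD::<f64>::zeros(IxDyn(&[2, 2]));
//! exec.put_cached(0, t);
//! assert!(exec.get_cached(0).is_some()); // hit: bumps `cache_hits`
//! assert_eq!(exec.stats().cache_hits, 1);
//! ```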
use std::collections::HashMap;

use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlAutodiff, TlExecutor};
use tensorlogic_ir::EinsumGraph;

use crate::autodiff::ForwardTape;
use crate::{Scirs2Exec, Scirs2Tensor};

/// Accumulated statistics for a [`LazyExecutor`] session.
#[derive(Debug, Default, Clone)]
pub struct LazyStats {
    /// Number of tensor lookups served directly from the cache.
    pub cache_hits: usize,
    /// Number of tensor lookups that required actual computation.
    pub cache_misses: usize,
    /// Number of cached tensors invalidated and therefore subject to
    /// re-computation on their next access.
    pub tensors_recomputed: usize,
    /// High-water-mark estimate of live cache memory (in bytes).
    pub peak_memory_estimate_bytes: usize,
}

/// A lazy executor that caches computed tensors by node index.
///
/// All tensor operations are delegated to an inner [`Scirs2Exec`]; the cache
/// layer sits above it and avoids redundant work when the same node is
/// requested multiple times (e.g. during iterative training with a static
/// graph).
pub struct LazyExecutor {
    inner: Scirs2Exec,
    /// Tensor cache: maps node_id → computed tensor.
    cache: HashMap<usize, Scirs2Tensor>,
    stats: LazyStats,
}

impl LazyExecutor {
    /// Create a new `LazyExecutor` with an empty cache.
    pub fn new() -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::new(),
            stats: LazyStats::default(),
        }
    }

    /// Create a `LazyExecutor` with a pre-allocated cache capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::with_capacity(capacity),
            stats: LazyStats::default(),
        }
    }

    /// Discard all cached tensors.
    pub fn invalidate_cache(&mut self) {
        self.cache.clear();
    }

    /// Remove a single node from the cache, bumping `tensors_recomputed` if
    /// the node was actually cached. The next access to that node will be a
    /// cache miss.
    pub fn invalidate_node(&mut self, node_id: usize) {
        if self.cache.remove(&node_id).is_some() {
            self.stats.tensors_recomputed += 1;
        }
    }

    /// Read-only reference to the accumulated statistics.
    pub fn stats(&self) -> &LazyStats {
        &self.stats
    }

    /// Rough estimate of the total memory needed to materialize every node
    /// output in `graph`.
    ///
    /// Computed as `number_of_nodes * average_cached_tensor_size`, a simple
    /// heuristic based on what is already in the cache. Returns 0 when either
    /// the graph or the cache is empty.
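    ///
    /// Worked example (hypothetical sizes): with two cached tensors of 32 and
    /// 64 bytes, the average is 48 bytes, so a 10-node graph is estimated at
    /// `48 * 10 = 480` bytes.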
    pub fn memory_estimate_for(&self, graph: &EinsumGraph) -> usize {
        if graph.nodes.is_empty() || self.cache.is_empty() {
            return 0;
        }
        let total_cached_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        let avg_bytes = total_cached_bytes / self.cache.len();
        avg_bytes * graph.nodes.len()
    }

    /// Number of tensors currently in the cache.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    // ------------------------------------------------------------------
    // Internal helpers
    // ------------------------------------------------------------------

    /// Look up `node_id` in the cache. Returns a clone of the cached tensor
    /// and bumps `cache_hits` if found; bumps `cache_misses` and returns
    /// `None` otherwise.
    fn cache_get(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        if let Some(t) = self.cache.get(&node_id) {
            self.stats.cache_hits += 1;
            Some(t.clone())
        } else {
            self.stats.cache_misses += 1;
            None
        }
    }

    /// Insert a computed tensor into the cache and update the peak-memory
    /// high-water mark.
    fn cache_insert(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache.insert(node_id, tensor);
        let current_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        if current_bytes > self.stats.peak_memory_estimate_bytes {
            self.stats.peak_memory_estimate_bytes = current_bytes;
        }
    }
}

impl Default for LazyExecutor {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// TlExecutor implementation
// ---------------------------------------------------------------------------

impl TlExecutor for LazyExecutor {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        // Einsum operations are not keyed by node_id here; the cache is
        // populated at the graph-traversal level (via `forward`). Direct
        // calls to einsum bypass the cache; they are atomic operations.
        self.inner.einsum(spec, inputs)
    }

    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op(op, x)
    }

    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op_binary(op, x, y)
    }

    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.reduce(op, x, axes)
    }
}

// ---------------------------------------------------------------------------
// TlAutodiff implementation
// ---------------------------------------------------------------------------

impl TlAutodiff for LazyExecutor {
    type Tape = ForwardTape;

    /// Execute the forward pass, caching each node's first output.
    ///
    /// Node outputs are stored in `self.cache` keyed by their index in
    /// `graph.nodes` so that subsequent forward calls on the same (or an
    /// overlapping) graph reuse already-computed tensors.
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        // Delegate to the inner executor; it returns the final output tensor
        // and stores the full ForwardTape internally.
        let result = self.inner.forward(graph)?;

        // Populate the cache with the node outputs recorded by the inner
        // executor. The inner `ForwardTape` holds one Option<Scirs2Tensor>
        // per *tensor* index (not node index), so each node index is mapped
        // to the tensor index of its first output. The tensors are collected
        // into a Vec first to avoid holding an immutable borrow of
        // `self.inner` while mutating `self.cache`.
        let node_tensors: Vec<(usize, Scirs2Tensor)> = if let Some(tape) = &self.inner.tape {
            graph
                .nodes
                .iter()
                .enumerate()
                .filter_map(|(node_idx, node)| {
                    node.outputs.first().and_then(|&tensor_idx| {
                        tape.tensors
                            .get(tensor_idx)
                            .and_then(|opt| opt.as_ref())
                            .map(|t| (node_idx, t.clone()))
                    })
                })
                .collect()
        } else {
            Vec::new()
        };

        for (node_idx, tensor) in node_tensors {
            if !self.cache.contains_key(&node_idx) {
                self.cache_insert(node_idx, tensor);
            } else {
                self.stats.cache_hits += 1;
            }
        }

        Ok(result)
    }

    /// Execute the backward pass, delegating to the inner executor.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.inner.backward(graph, loss)
    }
}

// ---------------------------------------------------------------------------
// Node-level cache lookup (optional convenience for graph executors)
// ---------------------------------------------------------------------------

impl LazyExecutor {
    /// Retrieve a cached tensor for the given node index (if available).
    ///
    /// This is the primary entry point for lazy graph traversal: call this
    /// before scheduling a node for execution. On a hit the value is returned
    /// without calling any inner operations.
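    ///
    /// A sketch of the intended check-then-compute pattern (marked `ignore`;
    /// `compute_node` stands in for a hypothetical caller-supplied function):
    ///
    /// ```ignore
    /// let tensor = match exec.get_cached(node_id) {
    ///     Some(t) => t, // cache hit: nothing is scheduled
    ///     None => {
    ///         let t = compute_node(node_id)?; // cache miss: do the work...
    ///         exec.put_cached(node_id, t.clone()); // ...and memoize it
    ///         t
    ///     }
    /// };
    /// ```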
    pub fn get_cached(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        self.cache_get(node_id)
    }

    /// Store a tensor result for a node. Subsequent calls to
    /// `get_cached(node_id)` will return this value.
    pub fn put_cached(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache_insert(node_id, tensor);
    }

    /// Access the inner [`Scirs2Exec`] (e.g. to register input tensors).
    pub fn inner_mut(&mut self) -> &mut Scirs2Exec {
        &mut self.inner
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    #[test]
    fn test_lazy_executor_default() {
        let exec = LazyExecutor::default();
        assert_eq!(exec.cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_cached_count_starts_zero() {
        let exec = LazyExecutor::new();
        assert_eq!(exec.cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_invalidate_cache() {
        let mut exec = LazyExecutor::with_capacity(4);
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        exec.put_cached(0, t);
        assert_eq!(exec.cached_count(), 1);
        exec.invalidate_cache();
        assert_eq!(exec.cached_count(), 0);
    }
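
    // A sketch exercising the hit/miss counters through the public
    // `get_cached` / `put_cached` API shown above.
    #[test]
    fn test_lazy_executor_hit_and_miss_counters() {
        let mut exec = LazyExecutor::new();
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        // First lookup misses and bumps `cache_misses`.
        assert!(exec.get_cached(0).is_none());
        assert_eq!(exec.stats().cache_misses, 1);
        // After insertion the same lookup hits and bumps `cache_hits`.
        exec.put_cached(0, t);
        assert!(exec.get_cached(0).is_some());
        assert_eq!(exec.stats().cache_hits, 1);
    }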

    #[test]
    fn test_lazy_stats_default() {
        let stats = LazyStats::default();
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.tensors_recomputed, 0);
        assert_eq!(stats.peak_memory_estimate_bytes, 0);
    }
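
    // Sketches of the invalidation and peak-memory accounting, again using
    // only the public API above. Note the size heuristic in `cache_insert`:
    // a 2x2 f64 tensor counts as 4 * 8 = 32 bytes.
    #[test]
    fn test_lazy_executor_invalidate_node_counts_recomputes() {
        let mut exec = LazyExecutor::new();
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[3]));
        exec.put_cached(7, t);
        // Removing a cached node bumps `tensors_recomputed`...
        exec.invalidate_node(7);
        assert_eq!(exec.stats().tensors_recomputed, 1);
        // ...but invalidating an absent node does not.
        exec.invalidate_node(7);
        assert_eq!(exec.stats().tensors_recomputed, 1);
    }

    #[test]
    fn test_lazy_executor_peak_memory_tracks_inserts() {
        let mut exec = LazyExecutor::new();
        use scirs2_core::ndarray::ArrayD;
        let t: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        exec.put_cached(0, t);
        assert_eq!(exec.stats().peak_memory_estimate_bytes, 32);
        // The high-water mark survives cache invalidation.
        exec.invalidate_cache();
        assert_eq!(exec.stats().peak_memory_estimate_bytes, 32);
    }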

    #[test]
    fn test_lazy_executor_memory_estimate_for_empty_graph() {
        let exec = LazyExecutor::new();
        let g = EinsumGraph::new();
        assert_eq!(exec.memory_estimate_for(&g), 0);
    }
}
314}