Skip to main content

tensorlogic_infer/
step_executor.rs

1//! Step-through executor wrapper that logs intermediate tensor statistics.
2//!
3//! `StepExecutor<E>` wraps any `TlExecutor` and records `IntermediateValue`
4//! statistics for each operation, optionally guarded by `BreakpointCondition`s.
5
6use ndarray::ArrayD;
7
8use crate::ops::{ElemOp, ReduceOp};
9use crate::traits::TlExecutor;
10
/// Conditions that decide whether an intermediate value is recorded.
///
/// Conditions are combined with logical OR: a snapshot is recorded as soon as
/// any one registered condition matches the operation's output.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers can test for, or
/// de-duplicate, conditions they have already registered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BreakpointCondition {
    /// Break at the operation with this sequential (0-based) index.
    NodeIndex(usize),
    /// Break whenever the output contains at least one NaN element.
    OnNaN,
    /// Break whenever the output contains at least one +Inf or -Inf element.
    OnInf,
    /// Always record — equivalent to a full trace.
    Always,
}
23
/// Statistics snapshot of a tensor value at a specific execution step.
///
/// Equality is field-wise. Because `min`, `max` and `mean` are `f64`, a
/// snapshot holding NaN in any of them never compares equal to anything,
/// including its own clone.
#[derive(Debug, Clone, PartialEq)]
pub struct IntermediateValue {
    /// Sequential operation index (0-based).
    pub step: usize,
    /// Human-readable name of the operation.
    pub operation: String,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Smallest element; NaN elements are skipped. `0.0` for an empty tensor.
    pub min: f64,
    /// Largest element; NaN elements are skipped. `0.0` for an empty tensor.
    pub max: f64,
    /// Arithmetic mean over all elements; NaN elements propagate into it.
    /// `0.0` for an empty tensor.
    pub mean: f64,
    /// Whether any element is NaN.
    pub has_nan: bool,
    /// Whether any element is ±Inf.
    pub has_inf: bool,
    /// Total number of elements.
    pub element_count: usize,
}
46
47impl IntermediateValue {
48    /// Build statistics from a tensor.
49    pub fn from_tensor(step: usize, op: &str, tensor: &ArrayD<f64>) -> Self {
50        let element_count = tensor.len();
51        let has_nan = tensor.iter().any(|x| x.is_nan());
52        let has_inf = tensor.iter().any(|x| x.is_infinite());
53
54        let (min, max, sum) = tensor.iter().cloned().fold(
55            (f64::INFINITY, f64::NEG_INFINITY, 0.0f64),
56            |(mn, mx, s), v| (mn.min(v), mx.max(v), s + v),
57        );
58
59        let (min, max) = if element_count == 0 {
60            (0.0, 0.0)
61        } else {
62            (min, max)
63        };
64
65        let mean = if element_count == 0 {
66            0.0
67        } else {
68            sum / element_count as f64
69        };
70
71        Self {
72            step,
73            operation: op.to_owned(),
74            shape: tensor.shape().to_vec(),
75            min,
76            max,
77            mean,
78            has_nan,
79            has_inf,
80            element_count,
81        }
82    }
83}
84
85/// Wraps any `TlExecutor` and logs `IntermediateValue` snapshots at each operation.
86///
87/// A snapshot is recorded when at least one active `BreakpointCondition` triggers.
88/// If no conditions are added no logging occurs; add `BreakpointCondition::Always`
89/// to capture every step.
90pub struct StepExecutor<E> {
91    /// The inner executor that performs actual computation.
92    pub inner: E,
93    conditions: Vec<BreakpointCondition>,
94    /// Accumulated log of intermediate values.
95    pub log: Vec<IntermediateValue>,
96    step_count: usize,
97}
98
99impl<E> StepExecutor<E> {
100    /// Create a new `StepExecutor` wrapping `inner` with no active conditions.
101    pub fn new(inner: E) -> Self {
102        Self {
103            inner,
104            conditions: Vec::new(),
105            log: Vec::new(),
106            step_count: 0,
107        }
108    }
109
110    /// Add a breakpoint condition.
111    pub fn add_condition(&mut self, cond: BreakpointCondition) {
112        self.conditions.push(cond);
113    }
114
115    /// View the accumulated log.
116    pub fn log(&self) -> &[IntermediateValue] {
117        &self.log
118    }
119
120    /// Total number of operations executed so far.
121    pub fn step_count(&self) -> usize {
122        self.step_count
123    }
124
125    /// Clear the accumulated log (step count is not reset).
126    pub fn clear_log(&mut self) {
127        self.log.clear();
128    }
129
130    /// Returns true if any logged entry contains NaN.
131    pub fn has_nan_in_log(&self) -> bool {
132        self.log.iter().any(|v| v.has_nan)
133    }
134
135    /// Returns true if any logged entry contains Inf.
136    pub fn has_inf_in_log(&self) -> bool {
137        self.log.iter().any(|v| v.has_inf)
138    }
139
140    /// One-line human-readable summary of the execution log.
141    pub fn summary(&self) -> String {
142        let nan_count = self.log.iter().filter(|v| v.has_nan).count();
143        let inf_count = self.log.iter().filter(|v| v.has_inf).count();
144        format!(
145            "StepExecutor: {} steps executed, {} logged, {} NaN entries, {} Inf entries",
146            self.step_count,
147            self.log.len(),
148            nan_count,
149            inf_count,
150        )
151    }
152
153    // ── private helpers ──────────────────────────────────────────────────────
154
155    fn should_log(&self, step: usize, iv: &IntermediateValue) -> bool {
156        self.conditions.iter().any(|cond| match cond {
157            BreakpointCondition::Always => true,
158            BreakpointCondition::NodeIndex(idx) => *idx == step,
159            BreakpointCondition::OnNaN => iv.has_nan,
160            BreakpointCondition::OnInf => iv.has_inf,
161        })
162    }
163
164    fn record_if_triggered(&mut self, iv: IntermediateValue) {
165        if self.should_log(iv.step, &iv) {
166            self.log.push(iv);
167        }
168    }
169}
170
171/// `TlExecutor` implementation for executors whose tensor type is `ArrayD<f64>`.
172impl<E> TlExecutor for StepExecutor<E>
173where
174    E: TlExecutor<Tensor = ArrayD<f64>>,
175{
176    type Tensor = ArrayD<f64>;
177    type Error = E::Error;
178
179    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
180        let step = self.step_count;
181        self.step_count += 1;
182        let result = self.inner.einsum(spec, inputs)?;
183        let iv = IntermediateValue::from_tensor(step, &format!("einsum({})", spec), &result);
184        self.record_if_triggered(iv);
185        Ok(result)
186    }
187
188    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
189        let step = self.step_count;
190        self.step_count += 1;
191        let result = self.inner.elem_op(op, x)?;
192        let iv = IntermediateValue::from_tensor(step, &format!("elem_op({:?})", op), &result);
193        self.record_if_triggered(iv);
194        Ok(result)
195    }
196
197    fn elem_op_binary(
198        &mut self,
199        op: ElemOp,
200        x: &Self::Tensor,
201        y: &Self::Tensor,
202    ) -> Result<Self::Tensor, Self::Error> {
203        let step = self.step_count;
204        self.step_count += 1;
205        let result = self.inner.elem_op_binary(op, x, y)?;
206        let iv =
207            IntermediateValue::from_tensor(step, &format!("elem_op_binary({:?})", op), &result);
208        self.record_if_triggered(iv);
209        Ok(result)
210    }
211
212    fn reduce(
213        &mut self,
214        op: ReduceOp,
215        x: &Self::Tensor,
216        axes: &[usize],
217    ) -> Result<Self::Tensor, Self::Error> {
218        let step = self.step_count;
219        self.step_count += 1;
220        let result = self.inner.reduce(op, x, axes)?;
221        let iv = IntermediateValue::from_tensor(step, &format!("reduce({:?})", op), &result);
222        self.record_if_triggered(iv);
223        Ok(result)
224    }
225}
226
227// ── Tests ────────────────────────────────────────────────────────────────────
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ExecutorError;
    use ndarray::{Array, IxDyn};

    // Minimal executor whose Tensor = ArrayD<f64> for testing StepExecutor.
    // Every operation returns a clone of its first input, so the tensor handed
    // in is exactly the tensor the StepExecutor gets to inspect.
    struct ArrayExecutor;

    impl TlExecutor for ArrayExecutor {
        type Tensor = ArrayD<f64>;
        type Error = ExecutorError;

        fn einsum(
            &mut self,
            _spec: &str,
            inputs: &[Self::Tensor],
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(inputs[0].clone())
        }

        fn elem_op(&mut self, _op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }

        fn elem_op_binary(
            &mut self,
            _op: ElemOp,
            x: &Self::Tensor,
            _y: &Self::Tensor,
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }

        fn reduce(
            &mut self,
            _op: ReduceOp,
            x: &Self::Tensor,
            _axes: &[usize],
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }
    }

    // Build a 1-D dynamic-rank tensor holding `data`.
    fn make_tensor(data: &[f64]) -> ArrayD<f64> {
        Array::from_shape_vec(IxDyn(&[data.len()]), data.to_vec()).unwrap()
    }

    #[test]
    fn test_step_executor_creates() {
        let exec = StepExecutor::new(ArrayExecutor);
        assert_eq!(exec.step_count(), 0);
        assert!(exec.log().is_empty());
    }

    #[test]
    fn test_intermediate_value_from_tensor() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0]);
        let iv = IntermediateValue::from_tensor(0, "test_op", &t);
        assert_eq!(iv.step, 0);
        assert_eq!(iv.operation, "test_op");
        assert_eq!(iv.element_count, 4);
        assert!((iv.min - 1.0).abs() < 1e-10);
        assert!((iv.max - 4.0).abs() < 1e-10);
        assert!((iv.mean - 2.5).abs() < 1e-10);
        assert!(!iv.has_nan);
        assert!(!iv.has_inf);
    }

    #[test]
    fn test_intermediate_value_empty_tensor() {
        // An empty tensor must not leak the fold seeds (+Inf / -Inf) or divide by zero.
        let t = make_tensor(&[]);
        let iv = IntermediateValue::from_tensor(3, "empty", &t);
        assert_eq!(iv.step, 3);
        assert_eq!(iv.element_count, 0);
        assert_eq!(iv.shape, vec![0]);
        assert_eq!(iv.min, 0.0);
        assert_eq!(iv.max, 0.0);
        assert_eq!(iv.mean, 0.0);
        assert!(!iv.has_nan);
        assert!(!iv.has_inf);
    }

    #[test]
    fn test_no_conditions_logs_nothing() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        let t = make_tensor(&[f64::NAN, f64::INFINITY]);
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        exec.elem_op(ElemOp::Relu, &t).unwrap();
        assert!(exec.log().is_empty(), "no conditions, nothing is logged");
        assert_eq!(exec.step_count(), 2, "steps are counted regardless");
    }

    #[test]
    fn test_always_condition_logs_all() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::Always);
        let t = make_tensor(&[1.0, 2.0]);
        exec.einsum("ij->ij", std::slice::from_ref(&t)).unwrap();
        exec.elem_op(ElemOp::Relu, &t).unwrap();
        exec.elem_op_binary(ElemOp::Add, &t, &t).unwrap();
        assert_eq!(exec.log().len(), 3, "all 3 ops should be logged");
        assert_eq!(exec.step_count(), 3);
    }

    #[test]
    fn test_node_index_condition_logs_single_step() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::NodeIndex(1));
        let t = make_tensor(&[1.0, 2.0]);
        for _ in 0..3 {
            exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        }
        assert_eq!(exec.step_count(), 3);
        assert_eq!(exec.log().len(), 1, "only step 1 should be logged");
        assert_eq!(exec.log()[0].step, 1);
        assert_eq!(exec.log()[0].operation, "einsum(i->i)");
    }

    #[test]
    fn test_nan_detection_in_log() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::OnNaN);
        // Normal tensor should not be logged.
        let normal = make_tensor(&[1.0, 2.0]);
        exec.einsum("i->i", &[normal]).unwrap();
        assert!(exec.log().is_empty(), "no NaN, should not log");

        // NaN tensor should be logged.
        let nan_tensor = make_tensor(&[f64::NAN, 1.0]);
        exec.einsum("i->i", &[nan_tensor]).unwrap();
        assert_eq!(exec.log().len(), 1, "NaN tensor should be logged");
        assert!(exec.has_nan_in_log());
    }

    #[test]
    fn test_inf_detection_in_log() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::OnInf);
        // Finite tensor should not be logged.
        let finite = make_tensor(&[1.0, 2.0]);
        exec.einsum("i->i", &[finite]).unwrap();
        assert!(exec.log().is_empty(), "no Inf, should not log");

        // Tensor containing an infinity should be logged.
        let inf_tensor = make_tensor(&[f64::INFINITY, 1.0]);
        exec.einsum("i->i", &[inf_tensor]).unwrap();
        assert_eq!(exec.log().len(), 1, "Inf tensor should be logged");
        assert!(exec.has_inf_in_log());
        assert!(!exec.has_nan_in_log());
    }

    #[test]
    fn test_step_count_and_clear() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::Always);
        let t = make_tensor(&[1.0]);
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        assert_eq!(exec.step_count(), 2);
        assert_eq!(exec.log().len(), 2);
        exec.clear_log();
        assert_eq!(exec.log().len(), 0);
        assert_eq!(exec.step_count(), 2, "step_count preserved after clear");
        let summary = exec.summary();
        assert!(summary.contains("2 steps"));
    }
}
341}