1use ndarray::ArrayD;
7
8use crate::ops::{ElemOp, ReduceOp};
9use crate::traits::TlExecutor;
10
/// Condition under which a [`StepExecutor`] records the output of an operation.
///
/// Conditions registered on the same executor are combined with logical OR: an
/// operation's output is recorded when at least one of them matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BreakpointCondition {
    /// Match only the operation whose step index equals this value
    /// (`StepExecutor` numbers operations from 0).
    NodeIndex(usize),
    /// Match any operation whose output contains at least one NaN element.
    OnNaN,
    /// Match any operation whose output contains at least one infinite element.
    OnInf,
    /// Match every operation.
    Always,
}
23
/// Summary statistics describing the output tensor of one traced operation.
///
/// Values are normally produced by `IntermediateValue::from_tensor`.
#[derive(Debug, Clone, PartialEq)]
pub struct IntermediateValue {
    /// Step index the value was recorded at.
    pub step: usize,
    /// Description of the operation that produced the tensor (e.g. `"einsum(i->i)"`).
    pub operation: String,
    /// Shape of the output tensor.
    pub shape: Vec<usize>,
    /// Minimum element value, as computed by `from_tensor`.
    pub min: f64,
    /// Maximum element value, as computed by `from_tensor`.
    pub max: f64,
    /// Arithmetic mean of the elements, as computed by `from_tensor`.
    pub mean: f64,
    /// Whether at least one element is NaN.
    pub has_nan: bool,
    /// Whether at least one element is positive or negative infinity.
    pub has_inf: bool,
    /// Total number of elements in the tensor.
    pub element_count: usize,
}
46
47impl IntermediateValue {
48 pub fn from_tensor(step: usize, op: &str, tensor: &ArrayD<f64>) -> Self {
50 let element_count = tensor.len();
51 let has_nan = tensor.iter().any(|x| x.is_nan());
52 let has_inf = tensor.iter().any(|x| x.is_infinite());
53
54 let (min, max, sum) = tensor.iter().cloned().fold(
55 (f64::INFINITY, f64::NEG_INFINITY, 0.0f64),
56 |(mn, mx, s), v| (mn.min(v), mx.max(v), s + v),
57 );
58
59 let (min, max) = if element_count == 0 {
60 (0.0, 0.0)
61 } else {
62 (min, max)
63 };
64
65 let mean = if element_count == 0 {
66 0.0
67 } else {
68 sum / element_count as f64
69 };
70
71 Self {
72 step,
73 operation: op.to_owned(),
74 shape: tensor.shape().to_vec(),
75 min,
76 max,
77 mean,
78 has_nan,
79 has_inf,
80 element_count,
81 }
82 }
83}
84
85pub struct StepExecutor<E> {
91 pub inner: E,
93 conditions: Vec<BreakpointCondition>,
94 pub log: Vec<IntermediateValue>,
96 step_count: usize,
97}
98
99impl<E> StepExecutor<E> {
100 pub fn new(inner: E) -> Self {
102 Self {
103 inner,
104 conditions: Vec::new(),
105 log: Vec::new(),
106 step_count: 0,
107 }
108 }
109
110 pub fn add_condition(&mut self, cond: BreakpointCondition) {
112 self.conditions.push(cond);
113 }
114
115 pub fn log(&self) -> &[IntermediateValue] {
117 &self.log
118 }
119
120 pub fn step_count(&self) -> usize {
122 self.step_count
123 }
124
125 pub fn clear_log(&mut self) {
127 self.log.clear();
128 }
129
130 pub fn has_nan_in_log(&self) -> bool {
132 self.log.iter().any(|v| v.has_nan)
133 }
134
135 pub fn has_inf_in_log(&self) -> bool {
137 self.log.iter().any(|v| v.has_inf)
138 }
139
140 pub fn summary(&self) -> String {
142 let nan_count = self.log.iter().filter(|v| v.has_nan).count();
143 let inf_count = self.log.iter().filter(|v| v.has_inf).count();
144 format!(
145 "StepExecutor: {} steps executed, {} logged, {} NaN entries, {} Inf entries",
146 self.step_count,
147 self.log.len(),
148 nan_count,
149 inf_count,
150 )
151 }
152
153 fn should_log(&self, step: usize, iv: &IntermediateValue) -> bool {
156 self.conditions.iter().any(|cond| match cond {
157 BreakpointCondition::Always => true,
158 BreakpointCondition::NodeIndex(idx) => *idx == step,
159 BreakpointCondition::OnNaN => iv.has_nan,
160 BreakpointCondition::OnInf => iv.has_inf,
161 })
162 }
163
164 fn record_if_triggered(&mut self, iv: IntermediateValue) {
165 if self.should_log(iv.step, &iv) {
166 self.log.push(iv);
167 }
168 }
169}
170
171impl<E> TlExecutor for StepExecutor<E>
173where
174 E: TlExecutor<Tensor = ArrayD<f64>>,
175{
176 type Tensor = ArrayD<f64>;
177 type Error = E::Error;
178
179 fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
180 let step = self.step_count;
181 self.step_count += 1;
182 let result = self.inner.einsum(spec, inputs)?;
183 let iv = IntermediateValue::from_tensor(step, &format!("einsum({})", spec), &result);
184 self.record_if_triggered(iv);
185 Ok(result)
186 }
187
188 fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
189 let step = self.step_count;
190 self.step_count += 1;
191 let result = self.inner.elem_op(op, x)?;
192 let iv = IntermediateValue::from_tensor(step, &format!("elem_op({:?})", op), &result);
193 self.record_if_triggered(iv);
194 Ok(result)
195 }
196
197 fn elem_op_binary(
198 &mut self,
199 op: ElemOp,
200 x: &Self::Tensor,
201 y: &Self::Tensor,
202 ) -> Result<Self::Tensor, Self::Error> {
203 let step = self.step_count;
204 self.step_count += 1;
205 let result = self.inner.elem_op_binary(op, x, y)?;
206 let iv =
207 IntermediateValue::from_tensor(step, &format!("elem_op_binary({:?})", op), &result);
208 self.record_if_triggered(iv);
209 Ok(result)
210 }
211
212 fn reduce(
213 &mut self,
214 op: ReduceOp,
215 x: &Self::Tensor,
216 axes: &[usize],
217 ) -> Result<Self::Tensor, Self::Error> {
218 let step = self.step_count;
219 self.step_count += 1;
220 let result = self.inner.reduce(op, x, axes)?;
221 let iv = IntermediateValue::from_tensor(step, &format!("reduce({:?})", op), &result);
222 self.record_if_triggered(iv);
223 Ok(result)
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ExecutorError;
    use ndarray::{Array, IxDyn};

    /// Trivial executor that echoes its first input, so the tests observe exactly the
    /// tensors they feed in.
    struct ArrayExecutor;

    impl TlExecutor for ArrayExecutor {
        type Tensor = ArrayD<f64>;
        type Error = ExecutorError;

        fn einsum(
            &mut self,
            _spec: &str,
            inputs: &[Self::Tensor],
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(inputs[0].clone())
        }

        fn elem_op(&mut self, _op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }

        fn elem_op_binary(
            &mut self,
            _op: ElemOp,
            x: &Self::Tensor,
            _y: &Self::Tensor,
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }

        fn reduce(
            &mut self,
            _op: ReduceOp,
            x: &Self::Tensor,
            _axes: &[usize],
        ) -> Result<Self::Tensor, Self::Error> {
            Ok(x.clone())
        }
    }

    /// Builds a 1-D tensor holding `data`.
    fn make_tensor(data: &[f64]) -> ArrayD<f64> {
        Array::from_shape_vec(IxDyn(&[data.len()]), data.to_vec()).unwrap()
    }

    #[test]
    fn test_step_executor_creates() {
        let exec = StepExecutor::new(ArrayExecutor);
        assert_eq!(exec.step_count(), 0);
        assert!(exec.log().is_empty());
    }

    #[test]
    fn test_intermediate_value_from_tensor() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0]);
        let iv = IntermediateValue::from_tensor(0, "test_op", &t);
        assert_eq!(iv.step, 0);
        assert_eq!(iv.operation, "test_op");
        assert_eq!(iv.element_count, 4);
        assert!((iv.min - 1.0).abs() < 1e-10);
        assert!((iv.max - 4.0).abs() < 1e-10);
        assert!((iv.mean - 2.5).abs() < 1e-10);
        assert!(!iv.has_nan);
        assert!(!iv.has_inf);
    }

    #[test]
    fn test_empty_tensor_statistics() {
        let t = make_tensor(&[]);
        let iv = IntermediateValue::from_tensor(7, "empty", &t);
        assert_eq!(iv.step, 7);
        assert_eq!(iv.element_count, 0);
        assert_eq!(iv.shape, vec![0]);
        assert!(iv.min.abs() < 1e-12);
        assert!(iv.max.abs() < 1e-12);
        assert!(iv.mean.abs() < 1e-12);
        assert!(!iv.has_nan);
        assert!(!iv.has_inf);
    }

    #[test]
    fn test_always_condition_logs_all() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::Always);
        let t = make_tensor(&[1.0, 2.0]);
        exec.einsum("ij->ij", std::slice::from_ref(&t)).unwrap();
        exec.elem_op(ElemOp::Relu, &t).unwrap();
        exec.elem_op_binary(ElemOp::Add, &t, &t).unwrap();
        assert_eq!(exec.log().len(), 3, "all 3 ops should be logged");
        assert_eq!(exec.step_count(), 3);
    }

    #[test]
    fn test_no_conditions_logs_nothing() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        let t = make_tensor(&[f64::NAN, f64::INFINITY]);
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        exec.elem_op(ElemOp::Relu, &t).unwrap();
        assert_eq!(exec.step_count(), 2, "steps are counted even when nothing is logged");
        assert!(exec.log().is_empty(), "without conditions nothing is recorded");
    }

    #[test]
    fn test_node_index_condition_logs_only_matching_step() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::NodeIndex(1));
        let t = make_tensor(&[1.0, 2.0]);
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        exec.elem_op(ElemOp::Relu, &t).unwrap();
        exec.elem_op_binary(ElemOp::Add, &t, &t).unwrap();
        assert_eq!(exec.log().len(), 1, "only step 1 should be logged");
        assert_eq!(exec.log()[0].step, 1);
        assert!(exec.log()[0].operation.starts_with("elem_op("));
    }

    #[test]
    fn test_nan_detection_in_log() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::OnNaN);
        let normal = make_tensor(&[1.0, 2.0]);
        exec.einsum("i->i", &[normal]).unwrap();
        assert!(exec.log().is_empty(), "no NaN, should not log");

        let nan_tensor = make_tensor(&[f64::NAN, 1.0]);
        exec.einsum("i->i", &[nan_tensor]).unwrap();
        assert_eq!(exec.log().len(), 1, "NaN tensor should be logged");
        assert!(exec.has_nan_in_log());
    }

    #[test]
    fn test_inf_detection_in_log() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::OnInf);
        exec.einsum("i->i", &[make_tensor(&[1.0, 2.0])]).unwrap();
        assert!(exec.log().is_empty(), "no Inf, should not log");

        exec.einsum("i->i", &[make_tensor(&[f64::INFINITY, 1.0])]).unwrap();
        assert_eq!(exec.log().len(), 1, "Inf tensor should be logged");
        assert!(exec.has_inf_in_log());
        assert!(!exec.has_nan_in_log());
    }

    #[test]
    fn test_step_count_and_clear() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::Always);
        let t = make_tensor(&[1.0]);
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        exec.einsum("i->i", std::slice::from_ref(&t)).unwrap();
        assert_eq!(exec.step_count(), 2);
        assert_eq!(exec.log().len(), 2);
        exec.clear_log();
        assert_eq!(exec.log().len(), 0);
        assert_eq!(exec.step_count(), 2, "step_count preserved after clear");
        let summary = exec.summary();
        assert!(summary.contains("2 steps"));
    }

    #[test]
    fn test_summary_counts_nan_and_inf_entries() {
        let mut exec = StepExecutor::new(ArrayExecutor);
        exec.add_condition(BreakpointCondition::Always);
        exec.einsum("i->i", &[make_tensor(&[f64::NAN])]).unwrap();
        exec.einsum("i->i", &[make_tensor(&[f64::INFINITY])]).unwrap();
        exec.einsum("i->i", &[make_tensor(&[1.0])]).unwrap();
        let summary = exec.summary();
        assert!(summary.contains("3 steps executed"), "{summary}");
        assert!(summary.contains("3 logged"), "{summary}");
        assert!(summary.contains("1 NaN entries"), "{summary}");
        assert!(summary.contains("1 Inf entries"), "{summary}");
    }
}