Skip to main content

tensorlogic_scirs_backend/
profiled_executor.rs

//! Performance profiling support for execution monitoring.
2
3use crate::{Scirs2Exec, Scirs2Tensor};
4use tensorlogic_infer::{ExecutorError, Profiler, TlAutodiff, TlExecutor, TlProfiledExecutor};
5use tensorlogic_ir::EinsumGraph;
6
7/// Profiling-enabled executor wrapper
8pub struct ProfiledScirs2Exec {
9    /// Underlying executor
10    executor: Scirs2Exec,
11    /// Profiler for tracking operations
12    profiler: Option<Profiler>,
13}
14
15impl ProfiledScirs2Exec {
16    /// Create a new profiled executor
17    pub fn new() -> Self {
18        ProfiledScirs2Exec {
19            executor: Scirs2Exec::new(),
20            profiler: Some(Profiler::new()),
21        }
22    }
23
24    /// Create with memory pooling enabled
25    pub fn with_memory_pool() -> Self {
26        ProfiledScirs2Exec {
27            executor: Scirs2Exec::with_memory_pool(),
28            profiler: Some(Profiler::new()),
29        }
30    }
31
32    /// Access the underlying executor
33    pub fn executor(&self) -> &Scirs2Exec {
34        &self.executor
35    }
36
37    /// Access the underlying executor mutably
38    pub fn executor_mut(&mut self) -> &mut Scirs2Exec {
39        &mut self.executor
40    }
41}
42
43impl Default for ProfiledScirs2Exec {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl TlExecutor for ProfiledScirs2Exec {
50    type Tensor = Scirs2Tensor;
51    type Error = ExecutorError;
52
53    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
54        if let Some(profiler) = &mut self.profiler {
55            profiler.time_op(format!("einsum({})", spec), || {
56                self.executor.einsum(spec, inputs)
57            })
58        } else {
59            self.executor.einsum(spec, inputs)
60        }
61    }
62
63    fn elem_op(
64        &mut self,
65        op: tensorlogic_infer::ElemOp,
66        x: &Self::Tensor,
67    ) -> Result<Self::Tensor, Self::Error> {
68        if let Some(profiler) = &mut self.profiler {
69            profiler.time_op(format!("elem_op({:?})", op), || {
70                self.executor.elem_op(op, x)
71            })
72        } else {
73            self.executor.elem_op(op, x)
74        }
75    }
76
77    fn elem_op_binary(
78        &mut self,
79        op: tensorlogic_infer::ElemOp,
80        x: &Self::Tensor,
81        y: &Self::Tensor,
82    ) -> Result<Self::Tensor, Self::Error> {
83        if let Some(profiler) = &mut self.profiler {
84            profiler.time_op(format!("elem_op_binary({:?})", op), || {
85                self.executor.elem_op_binary(op, x, y)
86            })
87        } else {
88            self.executor.elem_op_binary(op, x, y)
89        }
90    }
91
92    fn reduce(
93        &mut self,
94        op: tensorlogic_infer::ReduceOp,
95        x: &Self::Tensor,
96        axes: &[usize],
97    ) -> Result<Self::Tensor, Self::Error> {
98        if let Some(profiler) = &mut self.profiler {
99            profiler.time_op(format!("reduce({:?})", op), || {
100                self.executor.reduce(op, x, axes)
101            })
102        } else {
103            self.executor.reduce(op, x, axes)
104        }
105    }
106}
107
108impl TlAutodiff for ProfiledScirs2Exec {
109    type Tape = <Scirs2Exec as TlAutodiff>::Tape;
110
111    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
112        if let Some(profiler) = &mut self.profiler {
113            profiler.time_op("forward_pass", || self.executor.forward(graph))
114        } else {
115            self.executor.forward(graph)
116        }
117    }
118
119    fn backward(
120        &mut self,
121        graph: &EinsumGraph,
122        loss_grad: &Self::Tensor,
123    ) -> Result<Self::Tape, Self::Error> {
124        if let Some(profiler) = &mut self.profiler {
125            profiler.time_op("backward_pass", || self.executor.backward(graph, loss_grad))
126        } else {
127            self.executor.backward(graph, loss_grad)
128        }
129    }
130}
131
132impl TlProfiledExecutor for ProfiledScirs2Exec {
133    fn profiler(&self) -> Option<&Profiler> {
134        self.profiler.as_ref()
135    }
136
137    fn profiler_mut(&mut self) -> Option<&mut Profiler> {
138        self.profiler.as_mut()
139    }
140
141    fn enable_profiling(&mut self) {
142        if self.profiler.is_none() {
143            self.profiler = Some(Profiler::new());
144        }
145    }
146
147    fn disable_profiling(&mut self) {
148        self.profiler = None;
149    }
150}
151
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_infer::ElemOp;
    use tensorlogic_ir::{TLExpr, Term};

    /// Builds a tensor of the given shape with every element set to `value`.
    fn create_test_tensor(shape: &[usize], value: f64) -> ArrayD<f64> {
        ArrayD::from_elem(shape.to_vec(), value)
    }

    #[test]
    fn test_profiled_executor_basic() {
        let mut executor = ProfiledScirs2Exec::new();

        let a = create_test_tensor(&[3, 3], 1.0);
        let b = create_test_tensor(&[3, 3], 2.0);

        // The inputs are not reused afterwards, so move them instead of cloning.
        let result = executor.einsum("ij,jk->ik", &[a, b]).unwrap();

        // (3x3 of ones) · (3x3 of twos): every entry is 3 * (1 * 2) = 6.
        assert_eq!(result.shape(), &[3, 3]);
        assert!(result.iter().all(|&v| (v - 6.0).abs() < 1e-12));

        // Profiling is enabled by `new()`. This only confirms the profiler is
        // present; it does not inspect what was recorded.
        assert!(executor.profiler().is_some());
    }

    #[test]
    fn test_profiled_forward_pass() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = ProfiledScirs2Exec::new();
        executor
            .executor_mut()
            .add_tensor(graph.tensors[0].clone(), create_test_tensor(&[5], 1.0));
        executor
            .executor_mut()
            .add_tensor(graph.tensors[1].clone(), create_test_tensor(&[5], 2.0));

        let _result = executor.forward(&graph).unwrap();

        // Only confirms the profiler is still present after a profiled forward pass.
        assert!(executor.profiler().is_some());
    }

    #[test]
    fn test_enable_disable_profiling() {
        let mut executor = ProfiledScirs2Exec::new();

        let a = create_test_tensor(&[2, 2], 1.0);

        // Execute with profiling enabled.
        let profiled = executor.elem_op(ElemOp::Relu, &a).unwrap();
        assert!(executor.profiler().is_some());

        // Disable profiling.
        executor.disable_profiling();
        assert!(executor.profiler().is_none());

        // The unprofiled (`None`) path must still run and agree with the
        // profiled path.
        let unprofiled = executor.elem_op(ElemOp::Relu, &a).unwrap();
        assert_eq!(profiled, unprofiled);

        // Re-enable profiling.
        executor.enable_profiling();
        assert!(executor.profiler().is_some());
    }
}