tensorlogic_scirs_backend/
profiled_executor.rs1use crate::{Scirs2Exec, Scirs2Tensor};
4use tensorlogic_infer::{ExecutorError, Profiler, TlAutodiff, TlExecutor, TlProfiledExecutor};
5use tensorlogic_ir::EinsumGraph;
6
7pub struct ProfiledScirs2Exec {
9 executor: Scirs2Exec,
11 profiler: Option<Profiler>,
13}
14
15impl ProfiledScirs2Exec {
16 pub fn new() -> Self {
18 ProfiledScirs2Exec {
19 executor: Scirs2Exec::new(),
20 profiler: Some(Profiler::new()),
21 }
22 }
23
24 pub fn with_memory_pool() -> Self {
26 ProfiledScirs2Exec {
27 executor: Scirs2Exec::with_memory_pool(),
28 profiler: Some(Profiler::new()),
29 }
30 }
31
32 pub fn executor(&self) -> &Scirs2Exec {
34 &self.executor
35 }
36
37 pub fn executor_mut(&mut self) -> &mut Scirs2Exec {
39 &mut self.executor
40 }
41}
42
43impl Default for ProfiledScirs2Exec {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl TlExecutor for ProfiledScirs2Exec {
50 type Tensor = Scirs2Tensor;
51 type Error = ExecutorError;
52
53 fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
54 if let Some(profiler) = &mut self.profiler {
55 profiler.time_op(format!("einsum({})", spec), || {
56 self.executor.einsum(spec, inputs)
57 })
58 } else {
59 self.executor.einsum(spec, inputs)
60 }
61 }
62
63 fn elem_op(
64 &mut self,
65 op: tensorlogic_infer::ElemOp,
66 x: &Self::Tensor,
67 ) -> Result<Self::Tensor, Self::Error> {
68 if let Some(profiler) = &mut self.profiler {
69 profiler.time_op(format!("elem_op({:?})", op), || {
70 self.executor.elem_op(op, x)
71 })
72 } else {
73 self.executor.elem_op(op, x)
74 }
75 }
76
77 fn elem_op_binary(
78 &mut self,
79 op: tensorlogic_infer::ElemOp,
80 x: &Self::Tensor,
81 y: &Self::Tensor,
82 ) -> Result<Self::Tensor, Self::Error> {
83 if let Some(profiler) = &mut self.profiler {
84 profiler.time_op(format!("elem_op_binary({:?})", op), || {
85 self.executor.elem_op_binary(op, x, y)
86 })
87 } else {
88 self.executor.elem_op_binary(op, x, y)
89 }
90 }
91
92 fn reduce(
93 &mut self,
94 op: tensorlogic_infer::ReduceOp,
95 x: &Self::Tensor,
96 axes: &[usize],
97 ) -> Result<Self::Tensor, Self::Error> {
98 if let Some(profiler) = &mut self.profiler {
99 profiler.time_op(format!("reduce({:?})", op), || {
100 self.executor.reduce(op, x, axes)
101 })
102 } else {
103 self.executor.reduce(op, x, axes)
104 }
105 }
106}
107
108impl TlAutodiff for ProfiledScirs2Exec {
109 type Tape = <Scirs2Exec as TlAutodiff>::Tape;
110
111 fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
112 if let Some(profiler) = &mut self.profiler {
113 profiler.time_op("forward_pass", || self.executor.forward(graph))
114 } else {
115 self.executor.forward(graph)
116 }
117 }
118
119 fn backward(
120 &mut self,
121 graph: &EinsumGraph,
122 loss_grad: &Self::Tensor,
123 ) -> Result<Self::Tape, Self::Error> {
124 if let Some(profiler) = &mut self.profiler {
125 profiler.time_op("backward_pass", || self.executor.backward(graph, loss_grad))
126 } else {
127 self.executor.backward(graph, loss_grad)
128 }
129 }
130}
131
132impl TlProfiledExecutor for ProfiledScirs2Exec {
133 fn profiler(&self) -> Option<&Profiler> {
134 self.profiler.as_ref()
135 }
136
137 fn profiler_mut(&mut self) -> Option<&mut Profiler> {
138 self.profiler.as_mut()
139 }
140
141 fn enable_profiling(&mut self) {
142 if self.profiler.is_none() {
143 self.profiler = Some(Profiler::new());
144 }
145 }
146
147 fn disable_profiling(&mut self) {
148 self.profiler = None;
149 }
150}
151
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_infer::ElemOp;
    use tensorlogic_ir::{TLExpr, Term};

    /// Builds a tensor of the given shape filled with `value`.
    fn create_test_tensor(shape: &[usize], value: f64) -> ArrayD<f64> {
        ArrayD::from_elem(shape.to_vec(), value)
    }

    #[test]
    fn test_profiled_executor_basic() {
        let mut executor = ProfiledScirs2Exec::new();

        let a = create_test_tensor(&[3, 3], 1.0);
        let b = create_test_tensor(&[3, 3], 2.0);

        // Matrix product of all-1.0 and all-2.0 3x3 matrices: every entry is
        // sum_j 1.0 * 2.0 over 3 terms = 6.0.
        let result = executor.einsum("ij,jk->ik", &[a, b]).unwrap();

        assert_eq!(result.shape(), &[3, 3]);
        assert!(result.iter().all(|&v| (v - 6.0).abs() < 1e-12));
        assert!(executor.profiler().is_some());
    }

    #[test]
    fn test_profiled_forward_pass() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = ProfiledScirs2Exec::new();
        executor
            .executor_mut()
            .add_tensor(graph.tensors[0].clone(), create_test_tensor(&[5], 1.0));
        executor
            .executor_mut()
            .add_tensor(graph.tensors[1].clone(), create_test_tensor(&[5], 2.0));

        let _result = executor.forward(&graph).unwrap();

        assert!(executor.profiler().is_some());
    }

    #[test]
    fn test_enable_disable_profiling() {
        let mut executor = ProfiledScirs2Exec::new();

        let a = create_test_tensor(&[2, 2], 1.0);

        let profiled = executor.elem_op(ElemOp::Relu, &a).unwrap();
        assert!(executor.profiler().is_some());

        executor.disable_profiling();
        assert!(executor.profiler().is_none());

        // The unprofiled path must still execute and agree with the profiled one.
        let unprofiled = executor.elem_op(ElemOp::Relu, &a).unwrap();
        assert_eq!(unprofiled.shape(), &[2, 2]);
        assert_eq!(profiled, unprofiled);

        executor.enable_profiling();
        assert!(executor.profiler().is_some());
    }
}