use std::collections::HashMap;

use tensorlogic_ir::{EinsumGraph, OpType};

use crate::batch::{BatchResult, TlBatchExecutor};
use crate::capabilities::{BackendCapabilities, DType, DeviceType, Feature, TlCapabilities};
use crate::dummy_tensor::DummyTensor;
use crate::error::ExecutorError;
use crate::ops::{ElemOp, ReduceOp};
use crate::profiling::{Profiler, TlProfiledExecutor};
use crate::traits::{TlAutodiff, TlExecutor};

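/// In-memory reference executor that fabricates tensor results instead of
/// computing them; useful for exercising the executor traits in tests.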
pub struct DummyExecutor {
    pub tensors: HashMap<String, DummyTensor>,
    capabilities: BackendCapabilities,
    profiler: Option<Profiler>,
}

impl DummyExecutor {
    pub fn new() -> Self {
        let capabilities = BackendCapabilities::new("DummyExecutor", "0.1.0")
            .with_device(DeviceType::CPU)
            .with_dtype(DType::F64)
            .with_feature(Feature::Autodiff)
            .with_feature(Feature::BatchExecution)
            .with_max_dims(16);

        DummyExecutor {
            tensors: HashMap::new(),
            capabilities,
            profiler: None,
        }
    }

    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: DummyTensor) {
        self.tensors.insert(name.into(), tensor);
    }

    pub fn get_tensor(&self, name: &str) -> Option<&DummyTensor> {
        self.tensors.get(name)
    }
}

impl Default for DummyExecutor {
    fn default() -> Self {
        Self::new()
    }
}

impl TlExecutor for DummyExecutor {
    type Tensor = DummyTensor;
    type Error = ExecutorError;

    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        if inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No input tensors".to_string(),
            ));
        }

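        // Placeholder semantics: ignore the contraction in `spec` and return
        // a ones tensor shaped like the first input.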
        let output_shape = inputs[0].shape.clone();
        let output_size: usize = output_shape.iter().product();

        let result_data = vec![1.0; output_size];

        Ok(DummyTensor {
            name: format!("einsum({})", spec),
            shape: output_shape,
            data: result_data,
        })
    }

    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        match op {
            ElemOp::Relu | ElemOp::Sigmoid | ElemOp::OneMinus => {}
            _ => {
                return Err(ExecutorError::UnsupportedOperation(format!(
                    "Operation {:?} is not a unary operation",
                    op
                )))
            }
        }

        let result_data: Vec<f64> = x
            .data
            .iter()
            .map(|&val| match op {
                ElemOp::Relu => val.max(0.0),
                ElemOp::Sigmoid => 1.0 / (1.0 + (-val).exp()),
                ElemOp::OneMinus => 1.0 - val,
                _ => unreachable!(),
            })
            .collect();

        Ok(DummyTensor {
            name: format!("{:?}({})", op, x.name),
            shape: x.shape.clone(),
            data: result_data,
        })
    }

    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        if x.shape != y.shape {
            return Err(ExecutorError::ShapeMismatch(format!(
                "{:?} vs {:?}",
                x.shape, y.shape
            )));
        }

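        // `Divide` guards near-zero denominators by returning 0.0; comparison
        // ops return 1.0/0.0 as soft booleans, and the logical ops use
        // fuzzy-logic formulations over values in [0.0, 1.0].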
        let result_data: Vec<f64> = x
            .data
            .iter()
            .zip(y.data.iter())
            .map(|(&a, &b)| match op {
                ElemOp::Add => a + b,
                ElemOp::Subtract => a - b,
                ElemOp::Multiply => a * b,
                ElemOp::Divide => {
                    if b.abs() < 1e-10 {
                        0.0
                    } else {
                        a / b
                    }
                }
                ElemOp::Min => a.min(b),
                ElemOp::Max => a.max(b),

                ElemOp::Eq => {
                    if (a - b).abs() < 1e-10 {
                        1.0
                    } else {
                        0.0
                    }
                }
                ElemOp::Lt => {
                    if a < b {
                        1.0
                    } else {
                        0.0
                    }
                }
                ElemOp::Gt => {
                    if a > b {
                        1.0
                    } else {
                        0.0
                    }
                }
                ElemOp::Lte => {
                    if a <= b {
                        1.0
                    } else {
                        0.0
                    }
                }
                ElemOp::Gte => {
                    if a >= b {
                        1.0
                    } else {
                        0.0
                    }
                }

                ElemOp::OrMax => a.max(b),
                ElemOp::OrProbSum => a + b - a * b,
                ElemOp::Nand => 1.0 - (a * b),
                ElemOp::Nor => 1.0 - a.max(b),
                ElemOp::Xor => (a - b).abs(),

                ElemOp::Relu | ElemOp::Sigmoid | ElemOp::OneMinus => {
                    unreachable!("Unary operation {:?} called on binary", op)
                }
            })
            .collect();

        Ok(DummyTensor {
            name: format!("{:?}({},{})", op, x.name, y.name),
            shape: x.shape.clone(),
            data: result_data,
        })
    }

    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        if axes.is_empty() {
            return Ok(x.clone());
        }

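        // Drop reduced axes from the output shape. Axes are assumed to be
        // sorted ascending, so removing them in reverse keeps indices valid.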
        let rank = x.shape.len();
        let mut output_shape = x.shape.clone();
        for &axis in axes.iter().rev() {
            if axis >= rank {
                return Err(ExecutorError::InvalidAxis { axis, rank });
            }
            output_shape.remove(axis);
        }

        let output_size: usize = if output_shape.is_empty() {
            1
        } else {
            output_shape.iter().product()
        };

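        // Placeholder semantics: reduce over *all* elements and replicate the
        // scalar across the output, rather than reducing per axis.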
        let result_data = match op {
            ReduceOp::Sum => vec![x.data.iter().sum::<f64>(); output_size],
            ReduceOp::Max => {
                vec![x.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); output_size]
            }
            ReduceOp::Min => vec![x.data.iter().fold(f64::INFINITY, |a, &b| a.min(b)); output_size],
            ReduceOp::Mean => vec![x.data.iter().sum::<f64>() / x.size() as f64; output_size],
            ReduceOp::Product => vec![x.data.iter().product::<f64>(); output_size],
        };

        Ok(DummyTensor {
            name: format!("{:?}({},axes={:?})", op, x.name, axes),
            shape: if output_shape.is_empty() {
                vec![1]
            } else {
                output_shape
            },
            data: result_data,
        })
    }
}

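// Capability queries: the dummy backend unconditionally claims support for
// every element-wise op, reduce op, and einsum spec.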
impl TlCapabilities for DummyExecutor {
    fn capabilities(&self) -> &BackendCapabilities {
        &self.capabilities
    }

    fn supports_elem_op(&self, _op: ElemOp) -> bool {
        true
    }

    fn supports_reduce_op(&self, _op: ReduceOp) -> bool {
        true
    }

    fn supports_einsum(&self, _spec: &str) -> bool {
        true
    }
}

impl TlProfiledExecutor for DummyExecutor {
    fn profiler(&self) -> Option<&Profiler> {
        self.profiler.as_ref()
    }

    fn profiler_mut(&mut self) -> Option<&mut Profiler> {
        self.profiler.as_mut()
    }

    fn enable_profiling(&mut self) {
        let mut profiler = Profiler::new();
        profiler.start();
        self.profiler = Some(profiler);
    }

    fn disable_profiling(&mut self) {
        if let Some(mut profiler) = self.profiler.take() {
            profiler.stop();
        }
    }
}

impl TlBatchExecutor for DummyExecutor {
    type Tensor = DummyTensor;
    type Error = ExecutorError;

    fn execute_batch(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
        if batch_inputs.is_empty() {
            return Err(ExecutorError::EmptyInput(
                "Batch inputs cannot be empty".to_string(),
            ));
        }

        let mut outputs = Vec::with_capacity(batch_inputs.len());
        for inputs in batch_inputs {
            let output = self.execute_graph_internal(graph, &inputs)?;
            outputs.push(output);
        }

        Ok(BatchResult::new(outputs))
    }

    fn execute_batch_parallel(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
        _num_threads: Option<usize>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
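        // No real parallelism in the dummy backend; fall back to the
        // sequential path and ignore `_num_threads`.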
        self.execute_batch(graph, batch_inputs)
    }

    fn optimal_batch_size(&self) -> usize {
        16
    }
}

impl TlAutodiff for DummyExecutor {
    type Tape = HashMap<usize, DummyTensor>;

    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        if graph.nodes.is_empty() {
            return Err(ExecutorError::EmptyInput(
                "Graph has no nodes to execute".to_string(),
            ));
        }

        let mut tensors: HashMap<usize, DummyTensor> = HashMap::new();

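        // Seed every named graph tensor with a rank-1 placeholder of length
        // 10; a real backend would bind caller-provided inputs here.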
        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
            tensors.insert(idx, DummyTensor::ones(tensor_name.clone(), vec![10]));
        }

        for (node_idx, node) in graph.nodes.iter().enumerate() {
            let output_idx = graph.tensors.len() + node_idx;
            let output = self.execute_node_internal(node, &tensors)?;
            tensors.insert(output_idx, output);
        }

        let output_idx = if graph.outputs.is_empty() {
            graph.tensors.len() + graph.nodes.len() - 1
        } else {
            graph.outputs[0]
        };

        tensors
            .remove(&output_idx)
            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor".to_string()))
    }

    fn backward(
        &mut self,
        graph: &EinsumGraph,
        _loss: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        let mut gradients = HashMap::new();

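        // Placeholder gradients: every named graph tensor receives a ones
        // tensor of length 10, regardless of the loss.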
        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
            gradients.insert(
                idx,
                DummyTensor::ones(format!("grad_{}", tensor_name), vec![10]),
            );
        }

        Ok(gradients)
    }
}

impl DummyExecutor {
    fn execute_graph_internal(
        &mut self,
        graph: &EinsumGraph,
        _inputs: &[DummyTensor],
    ) -> Result<DummyTensor, ExecutorError> {
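        // The dummy path ignores the supplied inputs and simply re-runs
        // `forward`, which fabricates its own placeholder tensors.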
        self.forward(graph)
    }

    fn execute_node_internal(
        &mut self,
        node: &tensorlogic_ir::EinsumNode,
        tensors: &HashMap<usize, DummyTensor>,
    ) -> Result<DummyTensor, ExecutorError> {
        match &node.op {
            OpType::Einsum { spec } => {
                let inputs: Vec<DummyTensor> = node
                    .inputs
                    .iter()
                    .map(|&idx| {
                        tensors.get(&idx).cloned().ok_or_else(|| {
                            ExecutorError::TensorNotFound(format!("Tensor {}", idx))
                        })
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                self.einsum(spec, &inputs)
            }
            OpType::ElemUnary { op } => {
                if node.inputs.is_empty() {
                    return Err(ExecutorError::EmptyInput(
                        "ElemUnary requires an input".to_string(),
                    ));
                }
                let input = tensors.get(&node.inputs[0]).ok_or_else(|| {
                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
                })?;
                let elem_op = Self::parse_elem_op(op)?;
                self.elem_op(elem_op, input)
            }
            OpType::ElemBinary { op } => {
                if node.inputs.len() < 2 {
                    return Err(ExecutorError::EmptyInput(
                        "ElemBinary requires two inputs".to_string(),
                    ));
                }
                let input1 = tensors.get(&node.inputs[0]).ok_or_else(|| {
                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
                })?;
                let input2 = tensors.get(&node.inputs[1]).ok_or_else(|| {
                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[1]))
                })?;
                let elem_op = Self::parse_elem_op(op)?;
                self.elem_op_binary(elem_op, input1, input2)
            }
            OpType::Reduce { op, axes } => {
                if node.inputs.is_empty() {
                    return Err(ExecutorError::EmptyInput(
                        "Reduce requires an input".to_string(),
                    ));
                }
                let input = tensors.get(&node.inputs[0]).ok_or_else(|| {
                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
                })?;
                let reduce_op = Self::parse_reduce_op(op)?;
                self.reduce(reduce_op, input, axes)
            }
        }
    }

    fn parse_elem_op(op_str: &str) -> Result<ElemOp, ExecutorError> {
        match op_str.to_lowercase().as_str() {
            "relu" => Ok(ElemOp::Relu),
            "sigmoid" => Ok(ElemOp::Sigmoid),
            "oneminus" | "one_minus" => Ok(ElemOp::OneMinus),
            "add" => Ok(ElemOp::Add),
            "subtract" | "sub" => Ok(ElemOp::Subtract),
            "multiply" | "mul" => Ok(ElemOp::Multiply),
            "divide" | "div" => Ok(ElemOp::Divide),
            "eq" | "equal" => Ok(ElemOp::Eq),
            "lt" | "less" => Ok(ElemOp::Lt),
            "gt" | "greater" => Ok(ElemOp::Gt),
            "lte" | "le" => Ok(ElemOp::Lte),
            "gte" | "ge" => Ok(ElemOp::Gte),
            "ormax" | "or_max" => Ok(ElemOp::OrMax),
            "orprobsum" | "or_prob_sum" => Ok(ElemOp::OrProbSum),
            "nand" => Ok(ElemOp::Nand),
            "nor" => Ok(ElemOp::Nor),
            "xor" => Ok(ElemOp::Xor),
            _ => Err(ExecutorError::UnsupportedOperation(format!(
                "Unknown element operation: {}",
                op_str
            ))),
        }
    }

    fn parse_reduce_op(op_str: &str) -> Result<ReduceOp, ExecutorError> {
        match op_str.to_lowercase().as_str() {
            "sum" => Ok(ReduceOp::Sum),
            "max" => Ok(ReduceOp::Max),
            "min" => Ok(ReduceOp::Min),
            "mean" => Ok(ReduceOp::Mean),
            "product" | "prod" => Ok(ReduceOp::Product),
            _ => Err(ExecutorError::UnsupportedOperation(format!(
                "Unknown reduce operation: {}",
                op_str
            ))),
        }
    }
}
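
// A minimal test sketch (added as an illustration, not from the original
// source): it exercises the placeholder semantics above using only items
// already referenced in this file (`DummyTensor::ones`, `ElemOp`, `ReduceOp`,
// the `TlExecutor` methods), and assumes `DummyTensor::ones` fills its buffer
// with 1.0, as its name and the usage in `forward` suggest.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn elem_ops_preserve_shape_and_values() {
        let mut exec = DummyExecutor::new();
        let x = DummyTensor::ones("x".to_string(), vec![2, 3]);
        let y = DummyTensor::ones("y".to_string(), vec![2, 3]);

        // Relu over ones is a no-op (assuming ones are 1.0).
        let r = exec.elem_op(ElemOp::Relu, &x).unwrap();
        assert_eq!(r.shape, vec![2, 3]);
        assert_eq!(r.data, vec![1.0; 6]);

        // Element-wise add of two ones tensors yields twos.
        let s = exec.elem_op_binary(ElemOp::Add, &x, &y).unwrap();
        assert_eq!(s.data, vec![2.0; 6]);
    }

    #[test]
    fn full_reduce_collapses_to_a_scalar() {
        let mut exec = DummyExecutor::new();
        let x = DummyTensor::ones("x".to_string(), vec![2, 3]);

        // Reducing over every axis yields shape [1] holding the total sum.
        let r = exec.reduce(ReduceOp::Sum, &x, &[0, 1]).unwrap();
        assert_eq!(r.shape, vec![1]);
        assert_eq!(r.data, vec![6.0]);
    }
}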