Skip to main content

tensorlogic_infer/
dummy_executor.rs

1//! Dummy executor implementation for testing.
2
3use std::collections::HashMap;
4
5use tensorlogic_ir::{EinsumGraph, OpType};
6
7use crate::batch::{BatchResult, TlBatchExecutor};
8use crate::capabilities::{BackendCapabilities, DType, DeviceType, Feature, TlCapabilities};
9use crate::dummy_tensor::DummyTensor;
10use crate::error::ExecutorError;
11use crate::ops::{ElemOp, ReduceOp};
12use crate::profiling::{Profiler, TlProfiledExecutor};
13use crate::traits::{TlAutodiff, TlExecutor};
14
15/// Minimal executor implementation for testing and prototyping.
16///
17/// This provides a simple, reference implementation that verifies
18/// the execution logic without requiring heavy dependencies.
19pub struct DummyExecutor {
20    pub tensors: HashMap<String, DummyTensor>,
21    capabilities: BackendCapabilities,
22    profiler: Option<Profiler>,
23}
24
25impl DummyExecutor {
26    pub fn new() -> Self {
27        let capabilities = BackendCapabilities::new("DummyExecutor", "0.1.0")
28            .with_device(DeviceType::CPU)
29            .with_dtype(DType::F64)
30            .with_feature(Feature::Autodiff)
31            .with_feature(Feature::BatchExecution)
32            .with_max_dims(16);
33
34        DummyExecutor {
35            tensors: HashMap::new(),
36            capabilities,
37            profiler: None,
38        }
39    }
40
41    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: DummyTensor) {
42        self.tensors.insert(name.into(), tensor);
43    }
44
45    pub fn get_tensor(&self, name: &str) -> Option<&DummyTensor> {
46        self.tensors.get(name)
47    }
48}
49
50impl Default for DummyExecutor {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl TlExecutor for DummyExecutor {
57    type Tensor = DummyTensor;
58    type Error = ExecutorError;
59
60    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
61        if inputs.is_empty() {
62            return Err(ExecutorError::InvalidEinsumSpec(
63                "No input tensors".to_string(),
64            ));
65        }
66
67        // Simple stub: just return a tensor with the same shape as the first input
68        let output_shape = inputs[0].shape.clone();
69        let output_size: usize = output_shape.iter().product();
70
71        let result_data = vec![1.0; output_size];
72
73        Ok(DummyTensor {
74            name: format!("einsum({})", spec),
75            shape: output_shape,
76            data: result_data,
77        })
78    }
79
80    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
81        // Check if this is actually a unary operation
82        match op {
83            ElemOp::Relu | ElemOp::Sigmoid | ElemOp::OneMinus => {}
84            _ => {
85                return Err(ExecutorError::UnsupportedOperation(format!(
86                    "Operation {:?} is not a unary operation",
87                    op
88                )))
89            }
90        }
91
92        let result_data: Vec<f64> = x
93            .data
94            .iter()
95            .map(|&val| match op {
96                ElemOp::Relu => val.max(0.0),
97                ElemOp::Sigmoid => 1.0 / (1.0 + (-val).exp()),
98                ElemOp::OneMinus => 1.0 - val,
99                _ => unreachable!(),
100            })
101            .collect();
102
103        Ok(DummyTensor {
104            name: format!("{:?}({})", op, x.name),
105            shape: x.shape.clone(),
106            data: result_data,
107        })
108    }
109
110    fn elem_op_binary(
111        &mut self,
112        op: ElemOp,
113        x: &Self::Tensor,
114        y: &Self::Tensor,
115    ) -> Result<Self::Tensor, Self::Error> {
116        if x.shape != y.shape {
117            return Err(ExecutorError::ShapeMismatch(format!(
118                "{:?} vs {:?}",
119                x.shape, y.shape
120            )));
121        }
122
123        let result_data: Vec<f64> = x
124            .data
125            .iter()
126            .zip(y.data.iter())
127            .map(|(&a, &b)| match op {
128                // Arithmetic operations
129                ElemOp::Add => a + b,
130                ElemOp::Subtract => a - b,
131                ElemOp::Multiply => a * b,
132                ElemOp::Divide => {
133                    if b.abs() < 1e-10 {
134                        0.0 // Avoid division by zero
135                    } else {
136                        a / b
137                    }
138                }
139                ElemOp::Min => a.min(b),
140                ElemOp::Max => a.max(b),
141
142                // Comparison operations (return 0.0 or 1.0)
143                ElemOp::Eq => {
144                    if (a - b).abs() < 1e-10 {
145                        1.0
146                    } else {
147                        0.0
148                    }
149                }
150                ElemOp::Lt => {
151                    if a < b {
152                        1.0
153                    } else {
154                        0.0
155                    }
156                }
157                ElemOp::Gt => {
158                    if a > b {
159                        1.0
160                    } else {
161                        0.0
162                    }
163                }
164                ElemOp::Lte => {
165                    if a <= b {
166                        1.0
167                    } else {
168                        0.0
169                    }
170                }
171                ElemOp::Gte => {
172                    if a >= b {
173                        1.0
174                    } else {
175                        0.0
176                    }
177                }
178
179                // Extended logical operations
180                ElemOp::OrMax => a.max(b),
181                ElemOp::OrProbSum => a + b - a * b, // Probabilistic sum: 1 - (1-a)(1-b)
182                ElemOp::Nand => 1.0 - (a * b),
183                ElemOp::Nor => 1.0 - a.max(b),
184                ElemOp::Xor => (a - b).abs(), // Soft XOR: |a - b|
185
186                // Unary operations shouldn't be called on binary
187                ElemOp::Relu | ElemOp::Sigmoid | ElemOp::OneMinus => {
188                    unreachable!("Unary operation {:?} called on binary", op)
189                }
190            })
191            .collect();
192
193        Ok(DummyTensor {
194            name: format!("{:?}({},{})", op, x.name, y.name),
195            shape: x.shape.clone(),
196            data: result_data,
197        })
198    }
199
200    fn reduce(
201        &mut self,
202        op: ReduceOp,
203        x: &Self::Tensor,
204        axes: &[usize],
205    ) -> Result<Self::Tensor, Self::Error> {
206        if axes.is_empty() {
207            return Ok(x.clone());
208        }
209
210        let rank = x.shape.len();
211        let mut output_shape = x.shape.clone();
212        for &axis in axes.iter().rev() {
213            if axis >= rank {
214                return Err(ExecutorError::InvalidAxis { axis, rank });
215            }
216            output_shape.remove(axis);
217        }
218
219        let output_size: usize = if output_shape.is_empty() {
220            1
221        } else {
222            output_shape.iter().product()
223        };
224
225        let result_data = match op {
226            ReduceOp::Sum => vec![x.data.iter().sum::<f64>(); output_size],
227            ReduceOp::Max => {
228                vec![x.data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); output_size]
229            }
230            ReduceOp::Min => vec![x.data.iter().fold(f64::INFINITY, |a, &b| a.min(b)); output_size],
231            ReduceOp::Mean => vec![x.data.iter().sum::<f64>() / x.size() as f64; output_size],
232            ReduceOp::Product => vec![x.data.iter().product::<f64>(); output_size],
233        };
234
235        Ok(DummyTensor {
236            name: format!("{:?}({},axes={:?})", op, x.name, axes),
237            shape: if output_shape.is_empty() {
238                vec![1]
239            } else {
240                output_shape
241            },
242            data: result_data,
243        })
244    }
245}
246
247// TlCapabilities implementation
248impl TlCapabilities for DummyExecutor {
249    fn capabilities(&self) -> &BackendCapabilities {
250        &self.capabilities
251    }
252
253    fn supports_elem_op(&self, _op: ElemOp) -> bool {
254        true // DummyExecutor supports all element ops
255    }
256
257    fn supports_reduce_op(&self, _op: ReduceOp) -> bool {
258        true // DummyExecutor supports all reduce ops
259    }
260
261    fn supports_einsum(&self, _spec: &str) -> bool {
262        true // DummyExecutor has basic einsum support
263    }
264}
265
266// TlProfiledExecutor implementation
267impl TlProfiledExecutor for DummyExecutor {
268    fn profiler(&self) -> Option<&Profiler> {
269        self.profiler.as_ref()
270    }
271
272    fn profiler_mut(&mut self) -> Option<&mut Profiler> {
273        self.profiler.as_mut()
274    }
275
276    fn enable_profiling(&mut self) {
277        let mut profiler = Profiler::new();
278        profiler.start();
279        self.profiler = Some(profiler);
280    }
281
282    fn disable_profiling(&mut self) {
283        if let Some(mut profiler) = self.profiler.take() {
284            profiler.stop();
285        }
286    }
287}
288
289// TlBatchExecutor implementation
290impl TlBatchExecutor for DummyExecutor {
291    type Tensor = DummyTensor;
292    type Error = ExecutorError;
293
294    fn execute_batch(
295        &mut self,
296        graph: &EinsumGraph,
297        batch_inputs: Vec<Vec<Self::Tensor>>,
298    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
299        if batch_inputs.is_empty() {
300            return Err(ExecutorError::EmptyInput(
301                "Batch inputs cannot be empty".to_string(),
302            ));
303        }
304
305        let mut outputs = Vec::with_capacity(batch_inputs.len());
306        for inputs in batch_inputs {
307            let output = self.execute_graph_internal(graph, &inputs)?;
308            outputs.push(output);
309        }
310
311        Ok(BatchResult::new(outputs))
312    }
313
314    fn execute_batch_parallel(
315        &mut self,
316        graph: &EinsumGraph,
317        batch_inputs: Vec<Vec<Self::Tensor>>,
318        _num_threads: Option<usize>,
319    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
320        // DummyExecutor doesn't support true parallel execution
321        // Fall back to sequential execution
322        self.execute_batch(graph, batch_inputs)
323    }
324
325    fn optimal_batch_size(&self) -> usize {
326        16 // Conservative batch size for dummy executor
327    }
328}
329
330// TlAutodiff implementation
331impl TlAutodiff for DummyExecutor {
332    type Tape = HashMap<usize, DummyTensor>;
333
334    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
335        if graph.nodes.is_empty() {
336            return Err(ExecutorError::EmptyInput(
337                "Graph has no nodes to execute".to_string(),
338            ));
339        }
340
341        // Execute the graph and return the last tensor
342        let mut tensors: HashMap<usize, DummyTensor> = HashMap::new();
343
344        // Initialize input tensors (first N tensors in the graph)
345        // Note: In a real implementation, these would be provided as inputs
346        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
347            // Create dummy tensors with default shape [10]
348            tensors.insert(idx, DummyTensor::ones(tensor_name.clone(), vec![10]));
349        }
350
351        // Execute each node
352        for (node_idx, node) in graph.nodes.iter().enumerate() {
353            let output_idx = graph.tensors.len() + node_idx;
354            let output = self.execute_node_internal(node, &tensors)?;
355            tensors.insert(output_idx, output);
356        }
357
358        // Return the last computed tensor (or from outputs if specified)
359        let output_idx = if graph.outputs.is_empty() {
360            graph.tensors.len() + graph.nodes.len() - 1
361        } else {
362            graph.outputs[0]
363        };
364
365        tensors
366            .remove(&output_idx)
367            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor".to_string()))
368    }
369
370    fn backward(
371        &mut self,
372        graph: &EinsumGraph,
373        _loss: &Self::Tensor,
374    ) -> Result<Self::Tape, Self::Error> {
375        // Simplified backward pass: just return unit gradients for all tensors
376        let mut gradients = HashMap::new();
377
378        for (idx, tensor_name) in graph.tensors.iter().enumerate() {
379            gradients.insert(
380                idx,
381                DummyTensor::ones(format!("grad_{}", tensor_name), vec![10]),
382            );
383        }
384
385        Ok(gradients)
386    }
387}
388
389// Helper methods for DummyExecutor
390impl DummyExecutor {
391    fn execute_graph_internal(
392        &mut self,
393        graph: &EinsumGraph,
394        _inputs: &[DummyTensor],
395    ) -> Result<DummyTensor, ExecutorError> {
396        // Simplified: just execute forward pass
397        self.forward(graph)
398    }
399
400    fn execute_node_internal(
401        &mut self,
402        node: &tensorlogic_ir::EinsumNode,
403        tensors: &HashMap<usize, DummyTensor>,
404    ) -> Result<DummyTensor, ExecutorError> {
405        match &node.op {
406            OpType::Einsum { spec } => {
407                let inputs: Vec<DummyTensor> =
408                    node.inputs
409                        .iter()
410                        .map(|&idx| {
411                            tensors.get(&idx).cloned().ok_or_else(|| {
412                                ExecutorError::TensorNotFound(format!("Tensor {}", idx))
413                            })
414                        })
415                        .collect::<Result<Vec<_>, _>>()?;
416
417                self.einsum(spec, &inputs)
418            }
419            OpType::ElemUnary { op } => {
420                if node.inputs.is_empty() {
421                    return Err(ExecutorError::EmptyInput(
422                        "ElemUnary requires an input".to_string(),
423                    ));
424                }
425                let input = tensors.get(&node.inputs[0]).ok_or_else(|| {
426                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
427                })?;
428                let elem_op = Self::parse_elem_op(op)?;
429                self.elem_op(elem_op, input)
430            }
431            OpType::ElemBinary { op } => {
432                if node.inputs.len() < 2 {
433                    return Err(ExecutorError::EmptyInput(
434                        "ElemBinary requires two inputs".to_string(),
435                    ));
436                }
437                let input1 = tensors.get(&node.inputs[0]).ok_or_else(|| {
438                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
439                })?;
440                let input2 = tensors.get(&node.inputs[1]).ok_or_else(|| {
441                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[1]))
442                })?;
443                let elem_op = Self::parse_elem_op(op)?;
444                self.elem_op_binary(elem_op, input1, input2)
445            }
446            OpType::Reduce { op, axes } => {
447                if node.inputs.is_empty() {
448                    return Err(ExecutorError::EmptyInput(
449                        "Reduce requires an input".to_string(),
450                    ));
451                }
452                let input = tensors.get(&node.inputs[0]).ok_or_else(|| {
453                    ExecutorError::TensorNotFound(format!("Tensor {}", node.inputs[0]))
454                })?;
455                let reduce_op = Self::parse_reduce_op(op)?;
456                self.reduce(reduce_op, input, axes)
457            }
458        }
459    }
460
461    fn parse_elem_op(op_str: &str) -> Result<ElemOp, ExecutorError> {
462        match op_str.to_lowercase().as_str() {
463            "relu" => Ok(ElemOp::Relu),
464            "sigmoid" => Ok(ElemOp::Sigmoid),
465            "oneminus" | "one_minus" => Ok(ElemOp::OneMinus),
466            "add" => Ok(ElemOp::Add),
467            "subtract" | "sub" => Ok(ElemOp::Subtract),
468            "multiply" | "mul" => Ok(ElemOp::Multiply),
469            "divide" | "div" => Ok(ElemOp::Divide),
470            "eq" | "equal" => Ok(ElemOp::Eq),
471            "lt" | "less" => Ok(ElemOp::Lt),
472            "gt" | "greater" => Ok(ElemOp::Gt),
473            "lte" | "le" => Ok(ElemOp::Lte),
474            "gte" | "ge" => Ok(ElemOp::Gte),
475            "ormax" | "or_max" => Ok(ElemOp::OrMax),
476            "orprobsum" | "or_prob_sum" => Ok(ElemOp::OrProbSum),
477            "nand" => Ok(ElemOp::Nand),
478            "nor" => Ok(ElemOp::Nor),
479            "xor" => Ok(ElemOp::Xor),
480            _ => Err(ExecutorError::UnsupportedOperation(format!(
481                "Unknown element operation: {}",
482                op_str
483            ))),
484        }
485    }
486
487    fn parse_reduce_op(op_str: &str) -> Result<ReduceOp, ExecutorError> {
488        match op_str.to_lowercase().as_str() {
489            "sum" => Ok(ReduceOp::Sum),
490            "max" => Ok(ReduceOp::Max),
491            "min" => Ok(ReduceOp::Min),
492            "mean" => Ok(ReduceOp::Mean),
493            "product" | "prod" => Ok(ReduceOp::Product),
494            _ => Err(ExecutorError::UnsupportedOperation(format!(
495                "Unknown reduce operation: {}",
496                op_str
497            ))),
498        }
499    }
500}