Skip to main content

tensorlogic_infer/
traits.rs

//! Core execution traits for TensorLogic engines.
2
3use tensorlogic_ir::EinsumGraph;
4
5use crate::ops::{ElemOp, ReduceOp};
6
7/// Core tensor execution interface.
8///
9/// Implementations provide the fundamental tensor operations required
10/// for executing compiled TensorLogic programs.
11pub trait TlExecutor {
12    type Tensor;
13    type Error;
14
15    /// Execute an einsum operation on input tensors.
16    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error>;
17
18    /// Apply an element-wise unary operation.
19    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
20
21    /// Apply an element-wise binary operation.
22    fn elem_op_binary(
23        &mut self,
24        op: ElemOp,
25        x: &Self::Tensor,
26        y: &Self::Tensor,
27    ) -> Result<Self::Tensor, Self::Error>;
28
29    /// Reduce a tensor along specified axes.
30    fn reduce(
31        &mut self,
32        op: ReduceOp,
33        x: &Self::Tensor,
34        axes: &[usize],
35    ) -> Result<Self::Tensor, Self::Error>;
36}
37
38/// Automatic differentiation interface.
39///
40/// Extends `TlExecutor` with forward/backward pass capabilities for training.
41pub trait TlAutodiff: TlExecutor {
42    type Tape;
43
44    /// Execute forward pass on an EinsumGraph.
45    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error>;
46
47    /// Execute backward pass to compute gradients.
48    fn backward(
49        &mut self,
50        graph: &EinsumGraph,
51        loss: &Self::Tensor,
52    ) -> Result<Self::Tape, Self::Error>;
53}