Skip to main content

tensorlogic_scirs_backend/
executor.rs

1//! SciRS2 executor implementation.
2
3use scirs2_core::ndarray::Axis;
4use std::collections::HashMap;
5use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};
6
7use crate::autodiff::ForwardTape;
8use crate::memory_pool::TensorPool;
9use crate::Scirs2Tensor;
10
11pub struct Scirs2Exec {
12    pub tensors: HashMap<String, Scirs2Tensor>,
13    pub(crate) tape: Option<ForwardTape>,
14    /// Optional memory pool for tensor reuse
15    pub(crate) pool: Option<TensorPool>,
16}
17
18impl Default for Scirs2Exec {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl Scirs2Exec {
25    pub fn new() -> Self {
26        Scirs2Exec {
27            tensors: HashMap::new(),
28            tape: None,
29            pool: None,
30        }
31    }
32
33    /// Create executor with memory pooling enabled
34    pub fn with_memory_pool() -> Self {
35        Scirs2Exec {
36            tensors: HashMap::new(),
37            tape: None,
38            pool: Some(TensorPool::new()),
39        }
40    }
41
42    /// Enable memory pooling
43    pub fn enable_pooling(&mut self) {
44        if self.pool.is_none() {
45            self.pool = Some(TensorPool::new());
46        }
47    }
48
49    /// Disable memory pooling
50    pub fn disable_pooling(&mut self) {
51        self.pool = None;
52    }
53
54    /// Get pool statistics if pooling is enabled
55    pub fn pool_stats(&self) -> Option<crate::memory_pool::PoolStats> {
56        self.pool.as_ref().map(|p| p.stats())
57    }
58
59    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
60        self.tensors.insert(name.into(), tensor);
61    }
62
63    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor> {
64        self.tensors.get(name)
65    }
66}
67
68impl TlExecutor for Scirs2Exec {
69    type Tensor = Scirs2Tensor;
70    type Error = ExecutorError;
71
72    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
73        if inputs.is_empty() {
74            return Err(ExecutorError::InvalidEinsumSpec(
75                "No input tensors provided".to_string(),
76            ));
77        }
78
79        let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
80        let view_refs: Vec<_> = views.iter().collect();
81
82        scirs2_linalg::einsum(spec, &view_refs)
83            .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
84    }
85
86    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
87        let result = match op {
88            ElemOp::Relu => x.mapv(|v| v.max(0.0)),
89            ElemOp::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
90            ElemOp::OneMinus => x.mapv(|v| 1.0 - v),
91            _ => {
92                return Err(ExecutorError::UnsupportedOperation(format!(
93                    "Unary operation {:?} not supported",
94                    op
95                )))
96            }
97        };
98
99        Ok(result)
100    }
101
102    fn elem_op_binary(
103        &mut self,
104        op: ElemOp,
105        x: &Self::Tensor,
106        y: &Self::Tensor,
107    ) -> Result<Self::Tensor, Self::Error> {
108        // Handle scalar broadcasting: if one tensor is scalar (shape []) and the other isn't,
109        // broadcast the scalar to match the shape of the other tensor
110        let x_is_scalar = x.ndim() == 0;
111        let y_is_scalar = y.ndim() == 0;
112
113        let (x_broadcast, y_broadcast);
114        let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
115            // x is scalar, broadcast to y's shape
116            let scalar_value = x.iter().next().unwrap();
117            x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
118            (&x_broadcast.view(), &y.view())
119        } else if y_is_scalar && !x_is_scalar {
120            // y is scalar, broadcast to x's shape
121            let scalar_value = y.iter().next().unwrap();
122            y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
123            (&x.view(), &y_broadcast.view())
124        } else if x.shape() != y.shape() {
125            // Shapes don't match and neither is a scalar
126            return Err(ExecutorError::ShapeMismatch(format!(
127                "Shape mismatch: {:?} vs {:?}",
128                x.shape(),
129                y.shape()
130            )));
131        } else {
132            // Shapes match exactly (including both being scalars)
133            (&x.view(), &y.view())
134        };
135
136        let result = match op {
137            // Arithmetic operations
138            ElemOp::Add => x_ref + y_ref,
139            ElemOp::Subtract => x_ref - y_ref,
140            ElemOp::Multiply => x_ref * y_ref,
141            ElemOp::Divide => x_ref / y_ref,
142            ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
143                .and(y_ref)
144                .map_collect(|&a, &b| a.min(b)),
145            ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
146                .and(y_ref)
147                .map_collect(|&a, &b| a.max(b)),
148
149            // Comparison operations (return 0.0 or 1.0)
150            ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
151                .and(y_ref)
152                .map_collect(|&a, &b| if (a - b).abs() < 1e-10 { 1.0 } else { 0.0 }),
153            ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
154                .and(y_ref)
155                .map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
156            ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
157                .and(y_ref)
158                .map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
159            ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
160                .and(y_ref)
161                .map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
162            ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
163                .and(y_ref)
164                .map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),
165
166            // Extended logical operations
167            ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
168                .and(y_ref)
169                .map_collect(|&a, &b| a.max(b)),
170            ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
171                .and(y_ref)
172                .map_collect(|&a, &b| a + b - a * b), // 1 - (1-a)(1-b) = a + b - ab
173            ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
174                .and(y_ref)
175                .map_collect(|&a, &b| 1.0 - a * b),
176            ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
177                .and(y_ref)
178                .map_collect(|&a, &b| 1.0 - a.max(b)),
179            ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
180                .and(y_ref)
181                .map_collect(|&a, &b| a + b - 2.0 * a * b), // Soft XOR: (a XOR b) = a + b - 2ab
182
183            _ => {
184                return Err(ExecutorError::UnsupportedOperation(format!(
185                    "Binary operation {:?} not supported",
186                    op
187                )))
188            }
189        };
190
191        Ok(result)
192    }
193
194    fn reduce(
195        &mut self,
196        op: ReduceOp,
197        x: &Self::Tensor,
198        axes: &[usize],
199    ) -> Result<Self::Tensor, Self::Error> {
200        if axes.is_empty() {
201            return Ok(x.clone());
202        }
203
204        for &axis in axes {
205            if axis >= x.ndim() {
206                return Err(ExecutorError::ShapeMismatch(format!(
207                    "Axis {} out of bounds for tensor with {} dimensions",
208                    axis,
209                    x.ndim()
210                )));
211            }
212        }
213
214        let mut result = x.clone();
215        for &axis in axes.iter().rev() {
216            result = match op {
217                ReduceOp::Sum => result.sum_axis(Axis(axis)),
218                ReduceOp::Max => result.fold_axis(Axis(axis), f64::NEG_INFINITY, |&a, &b| a.max(b)),
219                ReduceOp::Min => result.fold_axis(Axis(axis), f64::INFINITY, |&a, &b| a.min(b)),
220                ReduceOp::Mean => result.mean_axis(Axis(axis)).unwrap(),
221                ReduceOp::Product => result.fold_axis(Axis(axis), 1.0, |&a, &b| a * b),
222            };
223        }
224
225        Ok(result)
226    }
227}