tensorlogic_scirs_backend/executor.rs
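
//! Executor for the SciRS2 backend.
//!
//! `Scirs2Exec` evaluates einsum contractions, element-wise operations, and
//! axis reductions over `Scirs2Tensor` values, with an optional `ForwardTape`
//! for autodiff and an optional `TensorPool` for allocation reuse.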

use scirs2_core::ndarray::Axis;
use std::collections::HashMap;
use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};

use crate::autodiff::ForwardTape;
use crate::memory_pool::TensorPool;
use crate::Scirs2Tensor;

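/// Tensor executor backed by SciRS2 `ndarray` arrays.
///
/// `tensors` holds tensors addressable by name; `tape` optionally records
/// the forward pass for autodiff, and `pool` optionally reuses allocations
/// via the crate's `TensorPool`.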
pub struct Scirs2Exec {
    pub tensors: HashMap<String, Scirs2Tensor>,
    pub(crate) tape: Option<ForwardTape>,
    pub(crate) pool: Option<TensorPool>,
}

impl Default for Scirs2Exec {
    fn default() -> Self {
        Self::new()
    }
}

impl Scirs2Exec {
    /// Creates an executor with no tape and no memory pool.
    pub fn new() -> Self {
        Scirs2Exec {
            tensors: HashMap::new(),
            tape: None,
            pool: None,
        }
    }

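    /// Creates an executor with a fresh `TensorPool` already installed.
    ///
    /// A minimal sketch (ignored doctest, since the `PoolStats` fields
    /// defined in this crate's `memory_pool` module are not shown here):
    ///
    /// ```ignore
    /// let exec = Scirs2Exec::with_memory_pool();
    /// assert!(exec.pool_stats().is_some());
    /// ```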
    pub fn with_memory_pool() -> Self {
        Scirs2Exec {
            tensors: HashMap::new(),
            tape: None,
            pool: Some(TensorPool::new()),
        }
    }

    /// Installs a `TensorPool` if one is not already present (idempotent).
    pub fn enable_pooling(&mut self) {
        if self.pool.is_none() {
            self.pool = Some(TensorPool::new());
        }
    }

    /// Drops the current pool, if any.
    pub fn disable_pooling(&mut self) {
        self.pool = None;
    }

    /// Returns pool statistics, or `None` when pooling is disabled.
    pub fn pool_stats(&self) -> Option<crate::memory_pool::PoolStats> {
        self.pool.as_ref().map(|p| p.stats())
    }

    /// Registers a named tensor, replacing any existing binding for `name`.
    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        self.tensors.insert(name.into(), tensor);
    }

    /// Looks up a named tensor.
    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.tensors.get(name)
    }
}

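// A minimal end-to-end sketch, assuming `Scirs2Tensor` is an alias for
// `scirs2_core::ndarray::ArrayD<f64>` (which is what the `view()`, `mapv`,
// and `sum_axis` calls below imply):
//
//     use scirs2_core::ndarray::{ArrayD, IxDyn};
//     let mut exec = Scirs2Exec::new();
//     let a = ArrayD::from_elem(IxDyn(&[2, 3]), 1.0);
//     let b = ArrayD::from_elem(IxDyn(&[3, 4]), 2.0);
//     let c = exec.einsum("ij,jk->ik", &[a, b])?;    // 2x4, every entry 6.0
//     let s = exec.reduce(ReduceOp::Sum, &c, &[0])?; // column sums: [12.0; 4]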
impl TlExecutor for Scirs2Exec {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        if inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No input tensors provided".to_string(),
            ));
        }

        // Borrow views of the inputs; `scirs2_linalg::einsum` takes a slice
        // of view references, so no tensor data is copied here.
        let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
        let view_refs: Vec<_> = views.iter().collect();

        scirs2_linalg::einsum(spec, &view_refs)
            .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
    }

    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        let result = match op {
            ElemOp::Relu => x.mapv(|v| v.max(0.0)),
            ElemOp::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
            ElemOp::OneMinus => x.mapv(|v| 1.0 - v),
            _ => {
                return Err(ExecutorError::UnsupportedOperation(format!(
                    "Unary operation {:?} not supported",
                    op
                )))
            }
        };

        Ok(result)
    }

    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        let x_is_scalar = x.ndim() == 0;
        let y_is_scalar = y.ndim() == 0;

        // If exactly one operand is a 0-d scalar, materialize it at the other
        // operand's shape; otherwise the shapes must match exactly. Named
        // view bindings (rather than references to temporaries created inside
        // the `if` arms) keep the borrows alive for the match below.
        let (x_broadcast, y_broadcast);
        let (x_view, y_view) = if x_is_scalar && !y_is_scalar {
            let scalar_value = *x.iter().next().unwrap();
            x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), scalar_value);
            (x_broadcast.view(), y.view())
        } else if y_is_scalar && !x_is_scalar {
            let scalar_value = *y.iter().next().unwrap();
            y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), scalar_value);
            (x.view(), y_broadcast.view())
        } else if x.shape() != y.shape() {
            return Err(ExecutorError::ShapeMismatch(format!(
                "Shape mismatch: {:?} vs {:?}",
                x.shape(),
                y.shape()
            )));
        } else {
            (x.view(), y.view())
        };
        let (x_ref, y_ref) = (&x_view, &y_view);

        let result = match op {
            ElemOp::Add => x_ref + y_ref,
            ElemOp::Subtract => x_ref - y_ref,
            ElemOp::Multiply => x_ref * y_ref,
            ElemOp::Divide => x_ref / y_ref,
            ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| a.min(b)),
            ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| a.max(b)),

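            // Comparison ops encode booleans as 0.0 / 1.0; `Eq` uses an
            // absolute tolerance of 1e-10 rather than exact float equality.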
            ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| if (a - b).abs() < 1e-10 { 1.0 } else { 0.0 }),
            ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
            ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
            ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
            ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),

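            // Fuzzy-logic connectives over truth values in [0, 1]:
            // `OrMax` is the Gödel t-conorm max(a, b), `OrProbSum` is the
            // probabilistic sum a + b - a*b, and `Nand`/`Nor`/`Xor` are the
            // matching complements (for a, b in {0, 1} they reduce to the
            // Boolean gates).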
            ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| a.max(b)),
            ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| a + b - a * b),
            ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| 1.0 - a * b),
            ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| 1.0 - a.max(b)),
            ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
                .and(y_ref)
                .map_collect(|&a, &b| a + b - 2.0 * a * b),
            _ => {
                return Err(ExecutorError::UnsupportedOperation(format!(
                    "Binary operation {:?} not supported",
                    op
                )))
            }
        };

        Ok(result)
    }

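    /// Reduces `x` over `axes`, one axis at a time. An empty `axes` slice is
    /// a no-op that returns a clone of `x`.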
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        if axes.is_empty() {
            return Ok(x.clone());
        }

        // Validate every axis up front so we fail before doing any work.
        for &axis in axes {
            if axis >= x.ndim() {
                return Err(ExecutorError::ShapeMismatch(format!(
                    "Axis {} out of bounds for tensor with {} dimensions",
                    axis,
                    x.ndim()
                )));
            }
        }

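        // Reduce in descending axis order so that removing a dimension does
        // not shift the indices of the axes still pending (this assumes
        // `axes` is sorted in ascending order).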
        let mut result = x.clone();
        for &axis in axes.iter().rev() {
            result = match op {
                ReduceOp::Sum => result.sum_axis(Axis(axis)),
                ReduceOp::Max => result.fold_axis(Axis(axis), f64::NEG_INFINITY, |&a, &b| a.max(b)),
                ReduceOp::Min => result.fold_axis(Axis(axis), f64::INFINITY, |&a, &b| a.min(b)),
                // `mean_axis` returns `None` only when the axis has length 0;
                // surface that as an error instead of panicking on `unwrap`.
                ReduceOp::Mean => result.mean_axis(Axis(axis)).ok_or_else(|| {
                    ExecutorError::ShapeMismatch(format!("Cannot take mean over empty axis {}", axis))
                })?,
                ReduceOp::Product => result.fold_axis(Axis(axis), 1.0, |&a, &b| a * b),
            };
        }

        Ok(result)
    }
}
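
#[cfg(test)]
mod tests {
    // Smoke-test sketch: assumes `Scirs2Tensor` is `ArrayD<f64>`; adjust the
    // constructors below if the alias differs.
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};

    #[test]
    fn scalar_broadcast_multiply() {
        let mut exec = Scirs2Exec::new();
        let x = ArrayD::from_elem(IxDyn(&[2, 2]), 3.0);
        let s = ArrayD::from_elem(IxDyn(&[]), 2.0); // 0-d scalar tensor
        let out = exec.elem_op_binary(ElemOp::Multiply, &x, &s).unwrap();
        assert!(out.iter().all(|&v| (v - 6.0).abs() < 1e-12));
    }

    #[test]
    fn mean_reduce_to_scalar() {
        let mut exec = Scirs2Exec::new();
        let x = ArrayD::from_elem(IxDyn(&[4]), 5.0);
        let m = exec.reduce(ReduceOp::Mean, &x, &[0]).unwrap();
        assert_eq!(m.ndim(), 0);
        assert!((*m.iter().next().unwrap() - 5.0).abs() < 1e-12);
    }
}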