1use scirs2_core::ndarray::{ArrayD, Axis};
7use std::collections::HashMap;
8use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};
9
/// Dynamic-rank `f32` tensor type used by this executor (scirs2's `ArrayD`).
pub type Scirs2Tensor32 = ArrayD<f32>;
12
/// Single-precision tensor executor backed by scirs2 / ndarray.
///
/// Implements [`TlExecutor`] over [`Scirs2Tensor32`] and additionally keeps a
/// named tensor store (see `add_tensor` / `get_tensor`).
pub struct Scirs2Exec32 {
    /// Named tensor store; written by `add_tensor`, read by `get_tensor`.
    pub tensors: HashMap<String, Scirs2Tensor32>,
}
21
impl Default for Scirs2Exec32 {
    /// Equivalent to [`Scirs2Exec32::new`]: an executor with no stored tensors.
    fn default() -> Self {
        Self::new()
    }
}
27
28impl Scirs2Exec32 {
29 pub fn new() -> Self {
31 Scirs2Exec32 {
32 tensors: HashMap::new(),
33 }
34 }
35
36 pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor32) {
38 self.tensors.insert(name.into(), tensor);
39 }
40
41 pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor32> {
43 self.tensors.get(name)
44 }
45}
46
47impl TlExecutor for Scirs2Exec32 {
48 type Tensor = Scirs2Tensor32;
49 type Error = ExecutorError;
50
51 fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
52 if inputs.is_empty() {
53 return Err(ExecutorError::InvalidEinsumSpec(
54 "No input tensors provided".to_string(),
55 ));
56 }
57
58 let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
59 let view_refs: Vec<_> = views.iter().collect();
60
61 scirs2_linalg::einsum(spec, &view_refs)
62 .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
63 }
64
65 fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
66 let result = match op {
67 ElemOp::Relu => x.mapv(|v| v.max(0.0_f32)),
68 ElemOp::Sigmoid => x.mapv(|v| 1.0_f32 / (1.0_f32 + (-v).exp())),
69 ElemOp::OneMinus => x.mapv(|v| 1.0_f32 - v),
70 _ => {
71 return Err(ExecutorError::UnsupportedOperation(format!(
72 "Unary operation {:?} not supported",
73 op
74 )))
75 }
76 };
77
78 Ok(result)
79 }
80
81 fn elem_op_binary(
82 &mut self,
83 op: ElemOp,
84 x: &Self::Tensor,
85 y: &Self::Tensor,
86 ) -> Result<Self::Tensor, Self::Error> {
87 let x_is_scalar = x.ndim() == 0;
90 let y_is_scalar = y.ndim() == 0;
91
92 let (x_broadcast, y_broadcast);
93 let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
94 let scalar_value = x
95 .iter()
96 .next()
97 .expect("scalar tensor has at least one element");
98 x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
99 (&x_broadcast.view(), &y.view())
100 } else if y_is_scalar && !x_is_scalar {
101 let scalar_value = y
102 .iter()
103 .next()
104 .expect("scalar tensor has at least one element");
105 y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
106 (&x.view(), &y_broadcast.view())
107 } else if x.shape() != y.shape() {
108 return Err(ExecutorError::ShapeMismatch(format!(
109 "Shape mismatch: {:?} vs {:?}",
110 x.shape(),
111 y.shape()
112 )));
113 } else {
114 (&x.view(), &y.view())
115 };
116
117 let result = match op {
118 ElemOp::Add => x_ref + y_ref,
119 ElemOp::Subtract => x_ref - y_ref,
120 ElemOp::Multiply => x_ref * y_ref,
121 ElemOp::Divide => x_ref / y_ref,
122 ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
123 .and(y_ref)
124 .map_collect(|&a, &b| a.min(b)),
125 ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
126 .and(y_ref)
127 .map_collect(|&a, &b| a.max(b)),
128
129 ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
130 .and(y_ref)
131 .map_collect(|&a, &b| if (a - b).abs() < 1e-7_f32 { 1.0 } else { 0.0 }),
132 ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
133 .and(y_ref)
134 .map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
135 ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
136 .and(y_ref)
137 .map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
138 ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
139 .and(y_ref)
140 .map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
141 ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
142 .and(y_ref)
143 .map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),
144
145 ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
146 .and(y_ref)
147 .map_collect(|&a, &b| a.max(b)),
148 ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
149 .and(y_ref)
150 .map_collect(|&a, &b| a + b - a * b),
151 ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
152 .and(y_ref)
153 .map_collect(|&a, &b| 1.0_f32 - a * b),
154 ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
155 .and(y_ref)
156 .map_collect(|&a, &b| 1.0_f32 - a.max(b)),
157 ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
158 .and(y_ref)
159 .map_collect(|&a, &b| a + b - 2.0_f32 * a * b),
160
161 _ => {
162 return Err(ExecutorError::UnsupportedOperation(format!(
163 "Binary operation {:?} not supported",
164 op
165 )))
166 }
167 };
168
169 Ok(result)
170 }
171
172 fn reduce(
173 &mut self,
174 op: ReduceOp,
175 x: &Self::Tensor,
176 axes: &[usize],
177 ) -> Result<Self::Tensor, Self::Error> {
178 if axes.is_empty() {
179 return Ok(x.clone());
180 }
181
182 for &axis in axes {
183 if axis >= x.ndim() {
184 return Err(ExecutorError::ShapeMismatch(format!(
185 "Axis {} out of bounds for tensor with {} dimensions",
186 axis,
187 x.ndim()
188 )));
189 }
190 }
191
192 let mut result = x.clone();
193 for &axis in axes.iter().rev() {
194 result = match op {
195 ReduceOp::Sum => result.sum_axis(Axis(axis)),
196 ReduceOp::Max => result.fold_axis(Axis(axis), f32::NEG_INFINITY, |&a, &b| a.max(b)),
197 ReduceOp::Min => result.fold_axis(Axis(axis), f32::INFINITY, |&a, &b| a.min(b)),
198 ReduceOp::Mean => result
199 .mean_axis(Axis(axis))
200 .expect("axis is valid as validated earlier"),
201 ReduceOp::Product => result.fold_axis(Axis(axis), 1.0_f32, |&a, &b| a * b),
202 };
203 }
204
205 Ok(result)
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Builds an f32 tensor of the given shape from flat row-major data.
    fn make_tensor(shape: &[usize], data: Vec<f32>) -> ArrayD<f32> {
        ArrayD::from_shape_vec(shape, data).expect("valid shape/data for test tensor")
    }

    /// Collects a tensor's elements into a `Vec` in iteration order.
    fn flat(t: &ArrayD<f32>) -> Vec<f32> {
        t.iter().copied().collect()
    }

    /// Runs a binary element-wise op on two 1-d tensors, returning flat output.
    fn run_binary(op: ElemOp, xs: Vec<f32>, ys: Vec<f32>) -> Vec<f32> {
        let n = xs.len();
        let x = make_tensor(&[n], xs);
        let y = make_tensor(&[n], ys);
        let mut exec = Scirs2Exec32::new();
        flat(&exec.elem_op_binary(op, &x, &y).expect("binary op"))
    }

    #[test]
    fn test_exec32_einsum_matmul() {
        let a = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_tensor(&[3, 2], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.einsum("ij,jk->ik", &[a, b]).expect("einsum matmul");
        assert_eq!(result.shape(), &[2, 2]);
        let expected = [58.0_f32, 64.0, 139.0, 154.0];
        for (got, want) in flat(&result).into_iter().zip(expected) {
            assert!((got - want).abs() < 1e-4, "expected {}, got {}", want, got);
        }
    }

    #[test]
    fn test_exec32_relu() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[4], vec![-1.0, 0.0, 1.0, 2.0]);
        let result = exec.elem_op(ElemOp::Relu, &x).expect("relu");
        assert_eq!(flat(&result), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_exec32_sigmoid() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[4], vec![-2.0, 0.0, 1.0, 10.0]);
        let result = exec.elem_op(ElemOp::Sigmoid, &x).expect("sigmoid");
        for &v in result.iter() {
            assert!(0.0 < v && v < 1.0, "sigmoid output {} not in (0,1)", v);
        }
        let mid = flat(&result)[1];
        assert!((mid - 0.5).abs() < 1e-5, "sigmoid(0) should be 0.5");
    }

    #[test]
    fn test_exec32_one_minus() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[3], vec![0.0, 0.3, 1.0]);
        let result = exec.elem_op(ElemOp::OneMinus, &x).expect("one_minus");
        let data = flat(&result);
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[1] - 0.7).abs() < 1e-5);
        assert!(data[2].abs() < 1e-6);
    }

    #[test]
    fn test_exec32_add() {
        let out = run_binary(ElemOp::Add, vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_exec32_sub() {
        let out = run_binary(ElemOp::Subtract, vec![10.0, 5.0, 3.0], vec![1.0, 2.0, 3.0]);
        assert_eq!(out, vec![9.0, 3.0, 0.0]);
    }

    #[test]
    fn test_exec32_mul() {
        let out = run_binary(ElemOp::Multiply, vec![2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0]);
        assert_eq!(out, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_exec32_div() {
        let out = run_binary(ElemOp::Divide, vec![6.0, 9.0, 12.0], vec![2.0, 3.0, 4.0]);
        assert_eq!(out, vec![3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_exec32_reduce_sum() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = exec.reduce(ReduceOp::Sum, &x, &[0]).expect("reduce_sum");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(flat(&result), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_exec32_reduce_max() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[2, 3], vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        let result = exec.reduce(ReduceOp::Max, &x, &[0]).expect("reduce_max");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(flat(&result), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_exec32_reduce_mean() {
        let mut exec = Scirs2Exec32::new();
        let x = make_tensor(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let result = exec.reduce(ReduceOp::Mean, &x, &[0]).expect("reduce_mean");
        assert_eq!(result.shape(), &[4]);
        for (got, want) in flat(&result).into_iter().zip([3.0_f32, 4.0, 5.0, 6.0]) {
            assert!((got - want).abs() < 1e-5, "mean mismatch: {} vs {}", got, want);
        }
    }

    #[test]
    fn test_exec32_zeros() {
        let zeros: ArrayD<f32> = ArrayD::zeros(vec![2, 3]);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(flat(&zeros), vec![0.0_f32; 6]);
    }

    #[test]
    fn test_exec32_ones() {
        let ones: ArrayD<f32> = ArrayD::ones(vec![2, 3]);
        assert_eq!(ones.shape(), &[2, 3]);
        assert_eq!(flat(&ones), vec![1.0_f32; 6]);
    }

    #[test]
    fn test_exec32_from_data() {
        let data = vec![1.5_f32, 2.5, 3.5, 4.5];
        let tensor = make_tensor(&[2, 2], data.clone());
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(flat(&tensor), data);
    }

    #[test]
    fn test_exec32_memory_half_of_f64() {
        let f32_tensor: ArrayD<f32> = ArrayD::zeros(vec![4, 4]);
        let f64_tensor: ArrayD<f64> = ArrayD::zeros(vec![4, 4]);
        let bytes32 = f32_tensor.len() * std::mem::size_of::<f32>();
        let bytes64 = f64_tensor.len() * std::mem::size_of::<f64>();
        assert_eq!(bytes32 * 2, bytes64);
    }
}