Skip to main content

tensorlogic_scirs_backend/
executor_f32.rs

//! SciRS2 f32 executor implementation.
//!
//! This module provides `Scirs2Exec32`, an f32-precision executor that mirrors
//! `Scirs2Exec` but uses `ArrayD<f32>` for all tensor operations.
5
6use scirs2_core::ndarray::{ArrayD, Axis};
7use std::collections::HashMap;
8use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};
9
10/// An f32 dynamic-rank tensor backed by ndarray.
11pub type Scirs2Tensor32 = ArrayD<f32>;
12
13/// SciRS2-backed executor operating in f32 precision.
14///
15/// This executor mirrors `Scirs2Exec` but uses `ArrayD<f32>` for all tensor
16/// storage and computation, providing 50% memory savings compared to f64
17/// at the cost of reduced numerical precision.
18pub struct Scirs2Exec32 {
19    pub tensors: HashMap<String, Scirs2Tensor32>,
20}
21
22impl Default for Scirs2Exec32 {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl Scirs2Exec32 {
29    /// Create a new, empty f32 executor.
30    pub fn new() -> Self {
31        Scirs2Exec32 {
32            tensors: HashMap::new(),
33        }
34    }
35
36    /// Insert a named tensor into the executor's store.
37    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor32) {
38        self.tensors.insert(name.into(), tensor);
39    }
40
41    /// Retrieve a reference to a named tensor.
42    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor32> {
43        self.tensors.get(name)
44    }
45}
46
47impl TlExecutor for Scirs2Exec32 {
48    type Tensor = Scirs2Tensor32;
49    type Error = ExecutorError;
50
51    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
52        if inputs.is_empty() {
53            return Err(ExecutorError::InvalidEinsumSpec(
54                "No input tensors provided".to_string(),
55            ));
56        }
57
58        let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
59        let view_refs: Vec<_> = views.iter().collect();
60
61        scirs2_linalg::einsum(spec, &view_refs)
62            .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
63    }
64
65    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
66        let result = match op {
67            ElemOp::Relu => x.mapv(|v| v.max(0.0_f32)),
68            ElemOp::Sigmoid => x.mapv(|v| 1.0_f32 / (1.0_f32 + (-v).exp())),
69            ElemOp::OneMinus => x.mapv(|v| 1.0_f32 - v),
70            _ => {
71                return Err(ExecutorError::UnsupportedOperation(format!(
72                    "Unary operation {:?} not supported",
73                    op
74                )))
75            }
76        };
77
78        Ok(result)
79    }
80
81    fn elem_op_binary(
82        &mut self,
83        op: ElemOp,
84        x: &Self::Tensor,
85        y: &Self::Tensor,
86    ) -> Result<Self::Tensor, Self::Error> {
87        // Handle scalar broadcasting: if one tensor is scalar (shape []) and the other isn't,
88        // broadcast the scalar to match the shape of the other tensor.
89        let x_is_scalar = x.ndim() == 0;
90        let y_is_scalar = y.ndim() == 0;
91
92        let (x_broadcast, y_broadcast);
93        let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
94            let scalar_value = x
95                .iter()
96                .next()
97                .expect("scalar tensor has at least one element");
98            x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
99            (&x_broadcast.view(), &y.view())
100        } else if y_is_scalar && !x_is_scalar {
101            let scalar_value = y
102                .iter()
103                .next()
104                .expect("scalar tensor has at least one element");
105            y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
106            (&x.view(), &y_broadcast.view())
107        } else if x.shape() != y.shape() {
108            return Err(ExecutorError::ShapeMismatch(format!(
109                "Shape mismatch: {:?} vs {:?}",
110                x.shape(),
111                y.shape()
112            )));
113        } else {
114            (&x.view(), &y.view())
115        };
116
117        let result = match op {
118            ElemOp::Add => x_ref + y_ref,
119            ElemOp::Subtract => x_ref - y_ref,
120            ElemOp::Multiply => x_ref * y_ref,
121            ElemOp::Divide => x_ref / y_ref,
122            ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
123                .and(y_ref)
124                .map_collect(|&a, &b| a.min(b)),
125            ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
126                .and(y_ref)
127                .map_collect(|&a, &b| a.max(b)),
128
129            ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
130                .and(y_ref)
131                .map_collect(|&a, &b| if (a - b).abs() < 1e-7_f32 { 1.0 } else { 0.0 }),
132            ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
133                .and(y_ref)
134                .map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
135            ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
136                .and(y_ref)
137                .map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
138            ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
139                .and(y_ref)
140                .map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
141            ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
142                .and(y_ref)
143                .map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),
144
145            ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
146                .and(y_ref)
147                .map_collect(|&a, &b| a.max(b)),
148            ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
149                .and(y_ref)
150                .map_collect(|&a, &b| a + b - a * b),
151            ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
152                .and(y_ref)
153                .map_collect(|&a, &b| 1.0_f32 - a * b),
154            ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
155                .and(y_ref)
156                .map_collect(|&a, &b| 1.0_f32 - a.max(b)),
157            ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
158                .and(y_ref)
159                .map_collect(|&a, &b| a + b - 2.0_f32 * a * b),
160
161            _ => {
162                return Err(ExecutorError::UnsupportedOperation(format!(
163                    "Binary operation {:?} not supported",
164                    op
165                )))
166            }
167        };
168
169        Ok(result)
170    }
171
172    fn reduce(
173        &mut self,
174        op: ReduceOp,
175        x: &Self::Tensor,
176        axes: &[usize],
177    ) -> Result<Self::Tensor, Self::Error> {
178        if axes.is_empty() {
179            return Ok(x.clone());
180        }
181
182        for &axis in axes {
183            if axis >= x.ndim() {
184                return Err(ExecutorError::ShapeMismatch(format!(
185                    "Axis {} out of bounds for tensor with {} dimensions",
186                    axis,
187                    x.ndim()
188                )));
189            }
190        }
191
192        let mut result = x.clone();
193        for &axis in axes.iter().rev() {
194            result = match op {
195                ReduceOp::Sum => result.sum_axis(Axis(axis)),
196                ReduceOp::Max => result.fold_axis(Axis(axis), f32::NEG_INFINITY, |&a, &b| a.max(b)),
197                ReduceOp::Min => result.fold_axis(Axis(axis), f32::INFINITY, |&a, &b| a.min(b)),
198                ReduceOp::Mean => result
199                    .mean_axis(Axis(axis))
200                    .expect("axis is valid as validated earlier"),
201                ReduceOp::Product => result.fold_axis(Axis(axis), 1.0_f32, |&a, &b| a * b),
202            };
203        }
204
205        Ok(result)
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Build an f32 tensor of `shape` from row-major `data`.
    fn make_tensor(shape: &[usize], data: Vec<f32>) -> ArrayD<f32> {
        ArrayD::from_shape_vec(shape, data).expect("valid shape/data for test tensor")
    }

    /// Collect a tensor's elements in logical (row-major) order.
    fn values(t: &ArrayD<f32>) -> Vec<f32> {
        t.iter().copied().collect()
    }

    /// Assert that `actual` and `expected` agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "element {}: expected {}, got {}",
                i,
                want,
                got
            );
        }
    }

    #[test]
    fn test_exec32_store_roundtrip() {
        let t = make_tensor(&[2], vec![1.0, 2.0]);
        let mut exec = Scirs2Exec32::default();
        assert!(exec.get_tensor("a").is_none());
        exec.add_tensor("a", t.clone());
        assert_eq!(exec.get_tensor("a"), Some(&t));
        // Re-adding under the same name replaces the stored tensor.
        let replacement = make_tensor(&[1], vec![9.0]);
        exec.add_tensor(String::from("a"), replacement.clone());
        assert_eq!(exec.get_tensor("a"), Some(&replacement));
        assert_eq!(exec.tensors.len(), 1);
    }

    #[test]
    fn test_exec32_einsum_matmul() {
        // 2×3 matrix × 3×2 matrix -> 2×2
        let a = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_tensor(&[3, 2], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.einsum("ij,jk->ik", &[a, b]).expect("einsum matmul");
        assert_eq!(result.shape(), &[2, 2]);
        // Row 0: [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
        // Row 1: [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
        assert_close(&values(&result), &[58.0, 64.0, 139.0, 154.0], 1e-4);
    }

    #[test]
    fn test_exec32_einsum_no_inputs_errors() {
        let mut exec = Scirs2Exec32::new();
        let result = exec.einsum("->", &[]);
        assert!(matches!(result, Err(ExecutorError::InvalidEinsumSpec(_))));
    }

    #[test]
    fn test_exec32_relu() {
        let x = make_tensor(&[4], vec![-1.0, 0.0, 1.0, 2.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op(ElemOp::Relu, &x).expect("relu");
        assert_eq!(values(&result), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_exec32_sigmoid() {
        let x = make_tensor(&[4], vec![-2.0, 0.0, 1.0, 10.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op(ElemOp::Sigmoid, &x).expect("sigmoid");
        for &v in result.iter() {
            assert!(v > 0.0 && v < 1.0, "sigmoid output {} not in (0,1)", v);
        }
        // sigmoid(0) == 0.5
        let data = values(&result);
        assert!((data[1] - 0.5).abs() < 1e-5, "sigmoid(0) should be 0.5");
    }

    #[test]
    fn test_exec32_one_minus() {
        let x = make_tensor(&[3], vec![0.0, 0.3, 1.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op(ElemOp::OneMinus, &x).expect("one_minus");
        assert_close(&values(&result), &[1.0, 0.7, 0.0], 1e-5);
    }

    #[test]
    fn test_exec32_unary_unsupported_op_errors() {
        let x = make_tensor(&[2], vec![1.0, 2.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op(ElemOp::Add, &x);
        assert!(matches!(result, Err(ExecutorError::UnsupportedOperation(_))));
    }

    #[test]
    fn test_exec32_add() {
        let x = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
        let y = make_tensor(&[3], vec![4.0, 5.0, 6.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Add, &x, &y).expect("add");
        assert_eq!(values(&result), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_exec32_sub() {
        let x = make_tensor(&[3], vec![10.0, 5.0, 3.0]);
        let y = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Subtract, &x, &y).expect("sub");
        assert_eq!(values(&result), vec![9.0, 3.0, 0.0]);
    }

    #[test]
    fn test_exec32_mul() {
        let x = make_tensor(&[3], vec![2.0, 3.0, 4.0]);
        let y = make_tensor(&[3], vec![5.0, 6.0, 7.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Multiply, &x, &y).expect("mul");
        assert_eq!(values(&result), vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_exec32_div() {
        let x = make_tensor(&[3], vec![6.0, 9.0, 12.0]);
        let y = make_tensor(&[3], vec![2.0, 3.0, 4.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Divide, &x, &y).expect("div");
        assert_eq!(values(&result), vec![3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_exec32_scalar_broadcast_both_sides() {
        let s = make_tensor(&[], vec![2.0]);
        let y = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
        let mut exec = Scirs2Exec32::new();

        let left = exec
            .elem_op_binary(ElemOp::Subtract, &s, &y)
            .expect("scalar - vector");
        assert_eq!(left.shape(), &[3]);
        assert_eq!(values(&left), vec![1.0, 0.0, -1.0]);

        let right = exec
            .elem_op_binary(ElemOp::Subtract, &y, &s)
            .expect("vector - scalar");
        assert_eq!(values(&right), vec![-1.0, 0.0, 1.0]);

        let scaled = exec
            .elem_op_binary(ElemOp::Multiply, &s, &y)
            .expect("scalar * vector");
        assert_eq!(values(&scaled), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_exec32_binary_shape_mismatch_errors() {
        let x = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
        let y = make_tensor(&[2], vec![1.0, 2.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Add, &x, &y);
        assert!(matches!(result, Err(ExecutorError::ShapeMismatch(_))));
    }

    #[test]
    fn test_exec32_binary_unsupported_op_errors() {
        let x = make_tensor(&[2], vec![1.0, 2.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.elem_op_binary(ElemOp::Relu, &x, &x);
        assert!(matches!(result, Err(ExecutorError::UnsupportedOperation(_))));
    }

    #[test]
    fn test_exec32_comparisons() {
        let x = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
        let y = make_tensor(&[3], vec![2.0, 2.0, 2.0]);
        let mut exec = Scirs2Exec32::new();
        let mut run = |op: ElemOp| values(&exec.elem_op_binary(op, &x, &y).expect("comparison"));
        assert_eq!(run(ElemOp::Eq), vec![0.0, 1.0, 0.0]);
        assert_eq!(run(ElemOp::Lt), vec![1.0, 0.0, 0.0]);
        assert_eq!(run(ElemOp::Gt), vec![0.0, 0.0, 1.0]);
        assert_eq!(run(ElemOp::Lte), vec![1.0, 1.0, 0.0]);
        assert_eq!(run(ElemOp::Gte), vec![0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_exec32_fuzzy_logic_ops() {
        let a = make_tensor(&[3], vec![0.0, 0.5, 1.0]);
        let b = make_tensor(&[3], vec![1.0, 0.5, 1.0]);
        let mut exec = Scirs2Exec32::new();
        let mut run = |op: ElemOp| values(&exec.elem_op_binary(op, &a, &b).expect("fuzzy op"));
        assert_eq!(run(ElemOp::Min), vec![0.0, 0.5, 1.0]);
        assert_eq!(run(ElemOp::Max), vec![1.0, 0.5, 1.0]);
        assert_eq!(run(ElemOp::OrMax), vec![1.0, 0.5, 1.0]);
        assert_eq!(run(ElemOp::OrProbSum), vec![1.0, 0.75, 1.0]);
        assert_eq!(run(ElemOp::Nand), vec![1.0, 0.75, 0.0]);
        assert_eq!(run(ElemOp::Nor), vec![0.0, 0.5, 0.0]);
        assert_eq!(run(ElemOp::Xor), vec![1.0, 0.5, 0.0]);
    }

    #[test]
    fn test_exec32_reduce_sum() {
        // 2×3 matrix, sum along axis 0 -> [3] vector
        let x = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Sum, &x, &[0]).expect("reduce_sum");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(values(&result), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_exec32_reduce_max() {
        let x = make_tensor(&[2, 3], vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Max, &x, &[0]).expect("reduce_max");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(values(&result), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_exec32_reduce_min_and_product() {
        let x = make_tensor(&[2, 3], vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        let mut exec = Scirs2Exec32::new();
        let min = exec.reduce(ReduceOp::Min, &x, &[1]).expect("reduce_min");
        assert_eq!(values(&min), vec![1.0, 2.0]);
        let product = exec
            .reduce(ReduceOp::Product, &x, &[0])
            .expect("reduce_product");
        assert_eq!(values(&product), vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_exec32_reduce_mean() {
        let x = make_tensor(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Mean, &x, &[0]).expect("reduce_mean");
        assert_eq!(result.shape(), &[4]);
        assert_close(&values(&result), &[3.0, 4.0, 5.0, 6.0], 1e-5);
    }

    #[test]
    fn test_exec32_reduce_all_axes() {
        let x = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Sum, &x, &[0, 1]).expect("reduce all");
        assert_eq!(result.ndim(), 0);
        assert_eq!(values(&result), vec![21.0]);
    }

    #[test]
    fn test_exec32_reduce_empty_axes_is_identity() {
        let x = make_tensor(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Sum, &x, &[]).expect("no-op reduce");
        assert_eq!(result, x);
    }

    #[test]
    fn test_exec32_reduce_axis_out_of_bounds_errors() {
        let x = make_tensor(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let mut exec = Scirs2Exec32::new();
        let result = exec.reduce(ReduceOp::Sum, &x, &[2]);
        assert!(matches!(result, Err(ExecutorError::ShapeMismatch(_))));
    }

    #[test]
    fn test_exec32_zeros() {
        let zeros: ArrayD<f32> = ArrayD::zeros(vec![2, 3]);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert!(zeros.iter().all(|&v| v == 0.0_f32));
    }

    #[test]
    fn test_exec32_ones() {
        let ones: ArrayD<f32> = ArrayD::ones(vec![2, 3]);
        assert_eq!(ones.shape(), &[2, 3]);
        assert!(ones.iter().all(|&v| v == 1.0_f32));
    }

    #[test]
    fn test_exec32_from_data() {
        let data = vec![1.5_f32, 2.5, 3.5, 4.5];
        let tensor = ArrayD::from_shape_vec(vec![2, 2], data.clone())
            .expect("valid shape for from_data test");
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(values(&tensor), data);
    }

    #[test]
    fn test_exec32_memory_half_of_f64() {
        // f32 is 4 bytes, f64 is 8 bytes; same element count means half total bytes.
        let f32_tensor: ArrayD<f32> = ArrayD::zeros(vec![4, 4]);
        let f64_tensor: ArrayD<f64> = ArrayD::zeros(vec![4, 4]);
        let f32_bytes = f32_tensor.len() * std::mem::size_of::<f32>();
        let f64_bytes = f64_tensor.len() * std::mem::size_of::<f64>();
        assert_eq!(f32_bytes * 2, f64_bytes);
    }
}