Skip to main content

spynso3/
lib.rs

1use std::{collections::HashMap, ops::Deref};
2
3use anyhow::anyhow;
4
5use library::SpensorLibrary;
6use library_tensor::AtomsOrFloats;
7use network::SpensoNet;
8
9use pyo3::{
10    PyClass,
11    exceptions::{self, PyIndexError, PyOverflowError, PyRuntimeError, PyTypeError},
12    prelude::*,
13    types::{PyComplex, PyFloat, PySlice, PyType},
14};
15
16#[cfg(feature = "python_stubgen")]
17use pyo3_stub_gen::{
18    generate::MethodType,
19    inventory::submit,
20    type_info::{MethodInfo, ParameterDefault, ParameterInfo, ParameterKind, PyMethodsInfo},
21};
22
23use spenso::{
24    algebra::complex::{Complex, RealOrComplex, symbolica_traits::CompiledComplexEvaluatorSpenso},
25    tensors::{
26        data::{DenseTensor, GetTensorData, SetTensorData, SparseOrDense, SparseTensor},
27        parametric::{
28            ConcreteOrParam, EvalTensor, ParamOrConcrete, ParamTensor, atomcore::TensorAtomOps,
29        },
30    },
31};
32
33use spenso::{
34    network::parsing::ShadowedStructure,
35    structure::{
36        HasStructure, PermutedStructure, ScalarTensor, TensorStructure,
37        abstract_index::AbstractIndex, permuted::Perm,
38    },
39    tensors::{
40        complex::RealOrComplexTensor,
41        data::{DataTensor, StorageTensor},
42        parametric::{LinearizedEvalTensor, MixedTensor},
43    },
44};
45use structure::{ConvertibleToStructure, SpensoIndices};
46use symbolica::{
47    api::python::SymbolicaCommunityModule,
48    atom::Atom,
49    domains::{float::Complex as SymComplex, rational::Rational},
50    evaluate::{CompileOptions, ExportSettings, FunctionMap, InlineASM, OptimizationSettings},
51    poly::PolyVariable,
52};
53
54use symbolica::api::python::PythonExpression;
55
56#[cfg(feature = "python_stubgen")]
57use pyo3_stub_gen::{PyStubType, TypeInfo, define_stub_info_gatherer, derive::*};
58
59pub mod library;
60pub mod library_tensor;
61pub mod network;
62pub mod structure;
63
64trait ModuleInit: PyClass {
65    fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
66        m.add_class::<Self>()
67    }
68}
69
70pub struct SpensoModule;
71
72impl SymbolicaCommunityModule for SpensoModule {
73    fn get_name() -> String {
74        "spenso".to_string()
75    }
76
77    fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
78        initialize_spenso(m)
79    }
80
81    fn initialize(_py: Python) -> PyResult<()> {
82        idenso::representations::initialize();
83        Ok(())
84    }
85}
86
87pub(crate) fn initialize_spenso(m: &Bound<'_, PyModule>) -> PyResult<()> {
88    use library_tensor::LibrarySpensor;
89    use network::ExecutionMode;
90
91    // m.add_function(?)?;
92    SpensoNet::init(m)?;
93    ExecutionMode::init(m)?;
94    Spensor::init(m)?;
95    LibrarySpensor::init(m)?;
96    SpensoIndices::init(m)?;
97    SpensorLibrary::init(m)?;
98    Ok(())
99}
100
101/// A tensor class that can be either dense or sparse with flexible data types.
102///
103/// The tensor can store data as floats, complex numbers, or symbolic expressions.
104/// Tensors have an associated structure that defines their shape and index properties.
105///
106/// Examples
107/// --------
108/// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation
109/// >>> structure = TensorIndices(Representation.euc(4)(1))
110/// >>> data = [1.0, 2.0, 3.0, 4.0]
111/// >>> tensor = Tensor.dense(structure, data)
112/// >>> sparse_tensor = Tensor.sparse(structure, float)
113#[cfg_attr(feature = "python_stubgen", gen_stub_pyclass)]
114#[pyclass(name = "Tensor", module = "symbolica.community.spenso")]
115#[derive(Clone)]
116pub struct Spensor {
117    tensor: PermutedStructure<MixedTensor<f64, ShadowedStructure<AbstractIndex>>>,
118}
119
120impl Deref for Spensor {
121    type Target = MixedTensor<f64, ShadowedStructure<AbstractIndex>>;
122
123    fn deref(&self) -> &Self::Target {
124        &self.tensor.structure
125    }
126}
127
128impl ModuleInit for Spensor {}
129
130// #[gen_stub_pyclass_enum]
131
132#[derive(FromPyObject)]
133pub enum SliceOrIntOrExpanded<'a> {
134    Slice(Bound<'a, PySlice>),
135    Int(usize),
136    Expanded(Vec<usize>),
137}
138
139#[cfg(feature = "python_stubgen")]
140impl PyStubType for SliceOrIntOrExpanded<'_> {
141    fn type_input() -> pyo3_stub_gen::TypeInfo {
142        TypeInfo::builtin("slice") | usize::type_input() | TypeInfo::list_of::<usize>()
143    }
144
145    fn type_output() -> pyo3_stub_gen::TypeInfo {
146        TypeInfo::builtin("slice") | usize::type_input() | TypeInfo::list_of::<usize>()
147    }
148}
149
150#[derive(IntoPyObject)]
151pub enum TensorElements {
152    Real(Py<PyFloat>),
153    Complex(Py<PyComplex>),
154    Symbolica(PythonExpression),
155}
156
157#[cfg(feature = "python_stubgen")]
158impl PyStubType for TensorElements {
159    fn type_input() -> pyo3_stub_gen::TypeInfo {
160        PythonExpression::type_input() | Complex::type_input() | PyFloat::type_input()
161    }
162
163    fn type_output() -> TypeInfo {
164        PythonExpression::type_output() | Complex::type_output() | PyFloat::type_output()
165    }
166}
167
168impl From<ConcreteOrParam<RealOrComplex<f64>>> for TensorElements {
169    fn from(value: ConcreteOrParam<RealOrComplex<f64>>) -> Self {
170        match value {
171            ConcreteOrParam::Concrete(RealOrComplex::Real(f)) => {
172                TensorElements::Real(Python::attach(|py| {
173                    PyFloat::new(py, f).as_unbound().to_owned()
174                }))
175            }
176            ConcreteOrParam::Concrete(RealOrComplex::Complex(c)) => {
177                TensorElements::Complex(Python::attach(|py| {
178                    PyComplex::from_doubles(py, c.re, c.im)
179                        .as_unbound()
180                        .to_owned()
181                }))
182            }
183            ConcreteOrParam::Param(p) => TensorElements::Symbolica(PythonExpression::from(p)),
184        }
185    }
186}
187
188#[cfg_attr(feature = "python_stubgen", gen_stub_pymethods)]
189#[pymethods]
190impl Spensor {
191    pub fn structure(&self) -> SpensoIndices {
192        SpensoIndices {
193            structure: PermutedStructure {
194                structure: self.tensor.structure.structure().clone(),
195                rep_permutation: self.tensor.rep_permutation.clone(),
196                index_permutation: self.tensor.index_permutation.clone(),
197            },
198        }
199    }
200
201    #[staticmethod]
202    /// Create a new sparse empty tensor with the given structure and data type.
203    ///
204    /// Parameters
205    /// ----------
206    /// structure : TensorIndices or list of Slots
207    ///     The tensor structure defining shape and index properties
208    /// type_info : type
209    ///     The data type - either `float` or `Expression` class
210    ///
211    /// Returns
212    /// -------
213    /// Tensor
214    ///     A new sparse tensor with all elements initially zero
215    ///
216    /// Examples
217    /// --------
218    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
219    /// >>> structure = TensorIndices(R.euc(3)(1), R.euc(3)(2))
220    /// >>> sparse_float = Tensor.sparse(structure, float)
221    /// >>> sparse_sym = Tensor.sparse(structure, symbolica.Expression)
222    pub fn sparse(
223        structure: ConvertibleToStructure,
224        type_info: Bound<'_, PyType>,
225    ) -> PyResult<Spensor> {
226        if type_info.is_subclass_of::<PyFloat>()? {
227            Ok(Spensor {
228                tensor: structure
229                    .0
230                    .structure
231                    .map_structure(|s| SparseTensor::<f64, _>::empty(s, 0.0).into()),
232            })
233        } else if type_info.is_subclass_of::<PythonExpression>()? {
234            Ok(Spensor {
235                tensor: structure.0.structure.map_structure(|s| {
236                    ParamOrConcrete::Param(ParamTensor::from(SparseTensor::<Atom, _>::empty(
237                        s,
238                        Atom::Zero,
239                    )))
240                }),
241            })
242        } else {
243            Err(PyTypeError::new_err("Only float type supported"))
244        }
245    }
246
247    #[staticmethod]
248    /// Create a new dense tensor with the given structure and data.
249    ///
250    /// Parameters
251    /// ----------
252    /// structure : TensorIndices or list of Slots
253    ///     The tensor structure defining shape and index properties
254    /// data : list of float, complex, or Expression
255    ///     The tensor data in row-major order
256    ///
257    /// Returns
258    /// -------
259    /// Tensor
260    ///     A new dense tensor with the specified data
261    ///
262    /// Examples
263    /// --------
264    /// >>> from symbolica import S
265    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
266    /// >>> structure = TensorIndices(R.euc(2)(1), R.euc(2)(2))
267    /// >>> data = [1.0, 2.0, 3.0, 4.0]
268    /// >>> tensor = Tensor.dense(structure, data)
269    /// >>> x, y = S("x", "y")
270    /// >>> sym_data = [x, y, x * y, x + y]
271    /// >>> sym_tensor = Tensor.dense(structure, sym_data)
272    pub fn dense(structure: ConvertibleToStructure, data: AtomsOrFloats) -> PyResult<Spensor> {
273        let dense = match data {
274            AtomsOrFloats::Floats(f) => {
275                DenseTensor::<f64, _>::from_data(f, structure.0.structure.structure)
276                    .map_err(|e| PyOverflowError::new_err(e.to_string()))?
277                    .into()
278            }
279            AtomsOrFloats::Atoms(a) => ParamOrConcrete::Param(ParamTensor::from(
280                DenseTensor::<Atom, _>::from_data(a, structure.0.structure.structure)
281                    .map_err(|e| PyOverflowError::new_err(e.to_string()))?,
282            )),
283            AtomsOrFloats::Complex(c) => {
284                MixedTensor::Concrete(RealOrComplexTensor::Complex(DataTensor::Dense(
285                    DenseTensor::<Complex<f64>, _>::from_data(c, structure.0.structure.structure)
286                        .map_err(|e| PyOverflowError::new_err(e.to_string()))?,
287                )))
288            }
289        };
290
291        let dense = PermutedStructure {
292            structure: dense,
293            rep_permutation: structure.0.structure.rep_permutation,
294            index_permutation: structure.0.structure.index_permutation,
295        };
296
297        Ok(Spensor {
298            tensor: dense.permute_inds_wrapped(),
299        })
300    }
301    #[staticmethod]
302    /// Create a scalar tensor with value 1.0.
303    ///
304    /// Returns
305    /// -------
306    /// Tensor
307    ///     A scalar tensor containing the value 1.0
308    ///
309    /// Examples
310    /// --------
311    /// >>> from symbolica.community.spenso import Tensor
312    /// >>> one = Tensor.one()
313    pub fn one() -> Spensor {
314        Spensor {
315            tensor: PermutedStructure::identity(ParamOrConcrete::new_scalar(
316                ConcreteOrParam::Concrete(RealOrComplex::Real(1.)),
317            )),
318        }
319    }
320
321    #[staticmethod]
322    /// Create a scalar tensor with value 0.0.
323    ///
324    /// Returns
325    /// -------
326    /// Tensor
327    ///     A scalar tensor containing the value 0.0
328    ///
329    /// Examples
330    /// --------
331    /// >>> from symbolica.community.spenso import Tensor
332    /// >>> zero = Tensor.zero()
333    pub fn zero() -> Spensor {
334        Spensor {
335            tensor: PermutedStructure::identity(ParamOrConcrete::new_scalar(
336                ConcreteOrParam::Concrete(RealOrComplex::Real(2.)),
337            )),
338        }
339    }
340
341    #[allow(clippy::wrong_self_convention)]
342    /// Convert this tensor to dense storage format.
343    ///
344    /// Convert this tensor to dense storage format.
345    ///
346    /// Converts sparse tensors to dense format in-place. Dense tensors are unchanged.
347    /// This allocates memory for all tensor elements.
348    ///
349    /// Examples
350    /// --------
351    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
352    /// >>> structure = TensorIndices(R.euc(4)(2))
353    /// >>> tensor = Tensor.sparse(structure, float)
354    /// >>> tensor[0] = 1.0
355    /// >>> tensor.to_dense()
356    fn to_dense(&mut self) {
357        self.tensor.structure = self.tensor.structure.clone().to_dense();
358    }
359
360    #[allow(clippy::wrong_self_convention)]
361    /// Convert this tensor to sparse storage format.
362    ///
363    /// Convert this tensor to sparse storage format.
364    ///
365    /// Converts dense tensors to sparse format in-place, only storing non-zero elements.
366    /// This can save memory for tensors with many zero elements.
367    ///
368    /// Examples
369    /// --------
370    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
371    /// >>> structure = TensorIndices(R.euc(2)(2), R.euc(2)(1))
372    /// >>> data = [1.0, 0.0, 0.0, 2.0]
373    /// >>> tensor = Tensor.dense(structure, data)
374    /// >>> tensor.to_sparse()
375    fn to_sparse(&mut self) {
376        self.tensor.structure = self.tensor.structure.clone().to_sparse();
377    }
378
379    fn __repr__(&self) -> String {
380        format!("Spensor(\n{})", self.tensor)
381    }
382
383    fn __str__(&self) -> String {
384        format!("{}", self.tensor.structure)
385    }
386
387    fn __len__(&self) -> usize {
388        self.size().unwrap()
389    }
390
391    fn __getitem__(&self, item: SliceOrIntOrExpanded) -> PyResult<Py<PyAny>> {
392        let out = match item {
393            SliceOrIntOrExpanded::Int(i) => self
394                .get_owned_linear(i.into())
395                .ok_or(PyIndexError::new_err("flat index out of bounds"))?,
396            SliceOrIntOrExpanded::Expanded(idxs) => self
397                .get_owned(&idxs)
398                .map_err(|s| PyIndexError::new_err(s.to_string()))?,
399            SliceOrIntOrExpanded::Slice(s) => {
400                let r = s.indices(self.size().unwrap() as isize)?;
401
402                let start = if r.start < 0 {
403                    (r.slicelength as isize + r.start) as usize
404                } else {
405                    r.start as usize
406                };
407
408                let end = if r.stop < 0 {
409                    (r.slicelength as isize + r.stop) as usize
410                } else {
411                    r.stop as usize
412                };
413
414                let (range, step) = if r.step < 0 {
415                    (end..start, -r.step as usize)
416                } else {
417                    (start..end, r.step as usize)
418                };
419
420                let slice: Option<Vec<TensorElements>> = range
421                    .step_by(step)
422                    .map(|i| self.get_owned_linear(i.into()).map(TensorElements::from))
423                    .collect();
424
425                if let Some(slice) = slice {
426                    return Ok(
427                        Python::attach(|py| slice.into_pyobject(py).map(|a| a.unbind()))?
428                            .into_any(),
429                    );
430                } else {
431                    return Err(PyIndexError::new_err("slice out of bounds"));
432                }
433            }
434        };
435
436        Python::attach(|py| {
437            TensorElements::from(out)
438                .into_pyobject(py)
439                .map(|a| a.unbind())
440        })
441    }
442
443    /// Set tensor element(s) at the specified index or indices.
444    ///
445    /// Parameters
446    /// ----------
447    /// item : int or list of int
448    ///     Index specification (int for flat index, list of int for coordinates)
449    /// value : float, complex, or Expression
450    ///     The value to set
451    ///
452    /// Examples
453    /// --------
454    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
455    /// >>> structure = TensorIndices(R.euc(2)(2), R.euc(2)(1))
456    /// >>> tensor = Tensor.sparse(structure, float)
457    /// >>> tensor[0] = 4.0
458    /// >>> tensor[1, 1] = 1.0
459    fn __setitem__<'py>(
460        &mut self,
461        item: Bound<'py, PyAny>,
462        value: Bound<'py, PyAny>,
463    ) -> anyhow::Result<()> {
464        let value = if let Ok(v) = value.extract::<PythonExpression>() {
465            ConcreteOrParam::Param(v.expr)
466        } else if let Ok(v) = value.extract::<f64>() {
467            ConcreteOrParam::Concrete(RealOrComplex::Real(v))
468        } else {
469            return Err(anyhow!("Value must be a PythonExpression or a float"));
470        };
471
472        if let Ok(flat_index) = item.extract::<usize>() {
473            self.tensor.structure.set_flat(flat_index.into(), value)
474        } else if let Ok(expanded_idxs) = item.extract::<Vec<usize>>() {
475            self.tensor.structure.set(&expanded_idxs, value)
476        } else {
477            Err(anyhow!("Index must be an integer"))
478        }
479    }
480
481    #[pyo3(signature =
482           (constants,
483           funs,
484           params,
485           iterations = 100,
486           n_cores = 4,
487           verbose = false),
488           )]
489    /// Create an optimized evaluator for symbolic tensor expressions.
490    ///
491    /// Create an optimized evaluator for symbolic tensor expressions.
492    ///
493    /// Compiles the symbolic expressions in this tensor into an optimized evaluation tree
494    /// that can efficiently compute numerical values for different parameter inputs.
495    ///
496    /// Parameters
497    /// ----------
498    /// constants : dict
499    ///     Dict mapping symbolic expressions to their constant numerical values
500    /// funs : dict
501    ///     Dict mapping function signatures to their symbolic definitions
502    /// params : list of Expression
503    ///     List of symbolic parameters that will be varied during evaluation
504    /// iterations : int, optional
505    ///     Number of optimization iterations for Horner scheme (default: 100)
506    /// n_cores : int, optional
507    ///     Number of CPU cores to use for optimization (default: 4)
508    /// verbose : bool, optional
509    ///     Whether to print optimization progress (default: False)
510    ///
511    /// Returns
512    /// -------
513    /// TensorEvaluator
514    ///     An optimized evaluator for efficient numerical evaluation
515    ///
516    /// Examples
517    /// --------
518    /// >>> from symbolica import S
519    /// >>> from symbolica.community.spenso import Tensor, TensorIndices, Representation as R
520    /// >>> x, y = S("x", "y")
521    /// >>> structure = TensorIndices(R.euc(2)(1))
522    /// >>> tensor = Tensor.dense(structure, [x * y, x + y])
523    /// >>> evaluator = tensor.evaluator(constants={}, funs={}, params=[x, y], iterations=50)
524    /// >>> results = evaluator.evaluate_complex([[1.0, 2.0], [3.0, 4.0]])
525    pub fn evaluator(
526        &self,
527        constants: HashMap<PythonExpression, PythonExpression>,
528        funs: HashMap<(PolyVariable, String, Vec<PolyVariable>), PythonExpression>,
529        params: Vec<PythonExpression>,
530        iterations: usize,
531        n_cores: usize,
532        verbose: bool,
533    ) -> PyResult<SpensoExpressionEvaluator> {
534        let mut fn_map = FunctionMap::new();
535
536        for (k, v) in &constants {
537            if let Ok(r) = v.expr.clone().try_into() {
538                fn_map.add_constant(k.expr.clone(), r);
539            } else {
540                Err(exceptions::PyValueError::new_err(
541                    "Constants must be rationals. If this is not possible, pass the value as a parameter",
542                ))?
543            }
544        }
545
546        for ((symbol, rename, args), body) in &funs {
547            let symbol = symbol
548                .get_id()
549                .ok_or(exceptions::PyValueError::new_err(format!(
550                    "Bad function name {}",
551                    symbol
552                )))?;
553            let args: Vec<_> = args
554                .iter()
555                .map(|x| {
556                    x.get_id().ok_or(exceptions::PyValueError::new_err(format!(
557                        "Bad function name {}",
558                        symbol
559                    )))
560                })
561                .collect::<Result<_, _>>()?;
562
563            fn_map
564                .add_function(symbol, rename.clone(), args, body.expr.clone())
565                .map_err(|e| {
566                    exceptions::PyValueError::new_err(format!("Could not add function: {}", e))
567                })?;
568        }
569
570        let settings = OptimizationSettings {
571            horner_iterations: iterations,
572            n_cores,
573            verbose,
574            ..OptimizationSettings::default()
575        };
576
577        let params: Vec<_> = params.iter().map(|x| x.expr.clone()).collect();
578
579        let mut evaltensor = match &self.tensor.structure {
580            ParamOrConcrete::Param(s) => s.to_evaluation_tree(&fn_map, &params).map_err(|e| {
581                exceptions::PyValueError::new_err(format!("Could not create evaluator: {}", e))
582            })?,
583            ParamOrConcrete::Concrete(_) => return Err(PyRuntimeError::new_err("not atom")),
584        };
585
586        evaltensor.optimize_horner_scheme(&settings);
587
588        evaltensor.common_subexpression_elimination();
589        let linear = evaltensor.linearize(None, false);
590        Ok(SpensoExpressionEvaluator {
591            eval: None,
592            eval_complex: linear
593                .clone()
594                .map_coeff(&|x| Complex::new(x.re.to_f64(), x.im.to_f64())),
595            eval_rat: linear,
596        })
597    }
598
599    /// Extract the scalar value from a rank-0 (scalar) tensor.
600    ///
601    /// Returns
602    /// -------
603    /// Expression
604    ///     The scalar expression contained in this tensor
605    ///
606    /// Raises
607    /// ------
608    /// RuntimeError
609    ///     If the tensor is not a scalar
610    ///
611    /// Examples
612    /// --------
613    /// >>> from symbolica.community.spenso import Tensor
614    /// >>> scalar_tensor = Tensor.one()
615    /// >>> value = scalar_tensor.scalar()
616    fn scalar(&self) -> PyResult<PythonExpression> {
617        self.tensor
618            .structure
619            .clone()
620            .scalar()
621            .map(|r| PythonExpression { expr: r.into() })
622            .ok_or_else(|| PyRuntimeError::new_err("No scalar found"))
623    }
624}
625
626impl From<DataTensor<f64, ShadowedStructure<AbstractIndex>>> for Spensor {
627    fn from(value: DataTensor<f64, ShadowedStructure<AbstractIndex>>) -> Self {
628        Spensor {
629            tensor: PermutedStructure::identity(MixedTensor::Concrete(RealOrComplexTensor::Real(
630                value,
631            ))),
632        }
633    }
634}
635
636impl From<DataTensor<Complex<f64>, ShadowedStructure<AbstractIndex>>> for Spensor {
637    fn from(value: DataTensor<Complex<f64>, ShadowedStructure<AbstractIndex>>) -> Self {
638        Spensor {
639            tensor: PermutedStructure::identity(MixedTensor::Concrete(
640                RealOrComplexTensor::Complex(value.map_data(|c| c)),
641            )),
642        }
643    }
644}
645impl From<MixedTensor<f64, ShadowedStructure<AbstractIndex>>> for Spensor {
646    fn from(value: MixedTensor<f64, ShadowedStructure<AbstractIndex>>) -> Self {
647        Spensor {
648            tensor: PermutedStructure::identity(value),
649        }
650    }
651}
652
653/// An optimized evaluator for symbolic tensor expressions.
654///
655/// An optimized evaluator for symbolic tensor expressions.
656///
657/// This class provides efficient numerical evaluation of symbolic tensor expressions
658/// after optimization. It supports both real and complex-valued evaluations.
659///
660/// Create instances using the `Tensor.evaluator()` method rather than directly.
661///
662/// Examples
663/// --------
664/// >>> evaluator = my_tensor.evaluator(constants={}, funs={}, params=[x, y])
665/// >>> results = evaluator.evaluate([[1.0, 2.0], [3.0, 4.0]])
666#[cfg_attr(feature = "python_stubgen", gen_stub_pyclass)]
667#[pyclass(name = "TensorEvaluator", module = "symbolica.community.spenso")]
668#[derive(Clone)]
669pub struct SpensoExpressionEvaluator {
670    pub eval_rat: LinearizedEvalTensor<SymComplex<Rational>, ShadowedStructure<AbstractIndex>>,
671    pub eval: Option<LinearizedEvalTensor<f64, ShadowedStructure<AbstractIndex>>>,
672    pub eval_complex: LinearizedEvalTensor<Complex<f64>, ShadowedStructure<AbstractIndex>>,
673}
674
675#[cfg_attr(feature = "python_stubgen", gen_stub_pymethods)]
676#[pymethods]
677impl SpensoExpressionEvaluator {
678    /// Evaluate the tensor expression for multiple real-valued parameter inputs.
679    ///
680    /// Parameters
681    /// ----------
682    /// inputs : list of list of float
683    ///     List of parameter value lists, where each inner list contains
684    ///     numerical values for all parameters in the same order as specified
685    ///     when creating the evaluator
686    ///
687    /// Returns
688    /// -------
689    /// list of Tensor
690    ///     List of evaluated tensors, one for each input parameter set
691    ///
692    /// Raises
693    /// ------
694    /// ValueError
695    ///     If the evaluator contains complex coefficients
696    ///
697    /// Examples
698    /// --------
699    /// >>> results = evaluator.evaluate([[1.0, 2.0], [3.0, 4.0]])
700    fn evaluate(&mut self, inputs: Vec<Vec<f64>>) -> PyResult<Vec<Spensor>> {
701        let eval = self.eval.as_mut().ok_or(exceptions::PyValueError::new_err(
702            "Evaluator contains complex coefficients. Use evaluate_complex instead.",
703        ))?;
704
705        Ok(inputs.iter().map(|s| eval.evaluate(s).into()).collect())
706    }
707
708    /// Evaluate the expression for multiple inputs and return the results.
709    fn evaluate_complex(&mut self, inputs: Vec<Vec<Complex<f64>>>) -> Vec<Spensor> {
710        let eval = &mut self.eval_complex;
711
712        inputs.iter().map(|s| eval.evaluate(s).into()).collect()
713    }
714
715    /// Compile the evaluator to a shared library using C++ for maximum performance.
716    ///
717    /// Compile the evaluator to a shared library using C++ for maximum performance.
718    ///
719    /// Generates optimized C++ code with optional inline assembly and compiles it
720    /// into a shared library that can be loaded for extremely fast evaluation.
721    ///
722    /// Parameters
723    /// ----------
724    /// function_name : str
725    ///     Name for the generated C++ function
726    /// filename : str
727    ///     Path for the generated C++ source file
728    /// library_name : str
729    ///     Name for the compiled shared library
730    /// inline_asm : str, optional
731    ///     Type of inline assembly optimization ("default", "x64", "aarch64", "none")
732    /// optimization_level : int, optional
733    ///     Compiler optimization level 0-3 (default: 3)
734    /// compiler_path : str, optional
735    ///     Path to specific C++ compiler (default: system default)
736    /// custom_header : str, optional
737    ///     Additional C++ header code to include
738    ///
739    /// Returns
740    /// -------
741    /// CompiledTensorEvaluator
742    ///     A compiled evaluator for maximum performance evaluation
743    ///
744    /// Examples
745    /// --------
746    /// >>> compiled = evaluator.compile(
747    /// ...     function_name="fast_eval",
748    /// ...     filename="tensor_eval.cpp",
749    /// ...     library_name="tensor_lib",
750    /// ...     optimization_level=3,
751    /// ... )
752    /// >>> results = compiled.evaluate_complex([[1.0, 2.0], [3.0, 4.0]])
753    #[pyo3(signature =
754        (function_name,
755        filename,
756        library_name,
757        // number_type,
758        inline_asm = "default",
759        optimization_level = 3,
760        compiler_path = None,
761        // compiler_flags = None,
762        custom_header = None,
763        // cuda_number_of_evaluations = 1,
764        // cuda_block_size = 512
765    ))]
766    #[allow(clippy::too_many_arguments)]
767    fn compile(
768        &self,
769        function_name: &str,
770        filename: &str,
771        library_name: &str,
772        // number_type: &str,
773        inline_asm: &str,
774        optimization_level: u8,
775        compiler_path: Option<&str>,
776        // compiler_flags: Option<Vec<String>>,
777        custom_header: Option<String>,
778        // cuda_number_of_evaluations: usize,
779        // cuda_block_size: usize,
780        // py: Python<'_>,
781    ) -> PyResult<SpensoCompiledExpressionEvaluator> {
782        let mut options = CompileOptions {
783            optimization_level: optimization_level as usize,
784            ..Default::default()
785        };
786
787        if let Some(compiler_path) = compiler_path {
788            options.compiler = compiler_path.to_string();
789        }
790        let inline_asm = match inline_asm.to_lowercase().as_str() {
791            "default" => InlineASM::default(),
792            "x64" => InlineASM::X64,
793            "aarch64" => InlineASM::AArch64,
794            "none" => InlineASM::None,
795            _ => {
796                return Err(exceptions::PyValueError::new_err(
797                    "Invalid inline assembly type specified.",
798                ));
799            }
800        };
801
802        Ok(SpensoCompiledExpressionEvaluator {
803            eval: self
804                .eval_complex
805                .export_cpp::<Complex<f64>>(
806                    filename,
807                    function_name,
808                    ExportSettings {
809                        include_header: true,
810                        inline_asm,
811                        custom_header,
812                        // ..Default::default()
813                    },
814                )
815                .map_err(|e| exceptions::PyValueError::new_err(format!("Export error: {}", e)))?
816                .compile(library_name, options)
817                .map_err(|e| {
818                    exceptions::PyValueError::new_err(format!("Compilation error: {}", e))
819                })?
820                .load()
821                .map_err(|e| {
822                    exceptions::PyValueError::new_err(format!("Library loading error: {}", e))
823                })?,
824        })
825    }
826}
827
828/// A compiled and optimized evaluator for maximum performance tensor evaluation.
829///
830/// This class wraps a compiled C++ shared library for extremely fast numerical
831/// evaluation of tensor expressions. It only supports complex-valued evaluation
832/// as this is the most general case.
833///
834/// A compiled and optimized evaluator for maximum performance tensor evaluation.
835///
836/// This class wraps a compiled C++ shared library for extremely fast numerical
837/// evaluation of tensor expressions. It only supports complex-valued evaluation
838/// as this is the most general case.
839///
840/// Create instances using the `TensorEvaluator.compile()` method.
841///
842/// Examples
843/// --------
844/// >>> compiled = evaluator.compile("eval_func", "code.cpp", "lib")
845/// >>> results = compiled.evaluate_complex(large_input_batch)
846#[cfg_attr(feature = "python_stubgen", gen_stub_pyclass)]
847#[pyclass(
848    name = "CompiledTensorEvaluator",
849    module = "symbolica.community.spenso"
850)]
851#[derive(Clone)]
852pub struct SpensoCompiledExpressionEvaluator {
853    pub eval: EvalTensor<CompiledComplexEvaluatorSpenso, ShadowedStructure<AbstractIndex>>,
854}
855
856#[cfg_attr(feature = "python_stubgen", gen_stub_pymethods)]
857#[pymethods]
858impl SpensoCompiledExpressionEvaluator {
859    /// Evaluate the tensor expression for multiple complex-valued parameter inputs.
860    ///
861    /// Evaluate the tensor expression for multiple complex-valued parameter inputs.
862    ///
863    /// Uses the compiled C++ code for maximum performance evaluation with complex numbers.
864    ///
865    /// Parameters
866    /// ----------
867    /// inputs : list of list of complex
868    ///     List of parameter value lists, where each inner list contains
869    ///     complex values for all parameters in the same order as specified
870    ///     when creating the original evaluator
871    ///
872    /// Returns
873    /// -------
874    /// list of Tensor
875    ///     List of evaluated tensors, one for each input parameter set
876    ///
877    /// Examples
878    /// --------
879    /// >>> complex_inputs = [
880    /// ...     [1.0+2.0j, 3.0+0.0j],
881    /// ...     [0.0+1.0j, 2.0+1.0j]
882    /// ... ]
883    /// >>> results = compiled_evaluator.evaluate_complex(complex_inputs)
884    fn evaluate_complex(&mut self, inputs: Vec<Vec<Complex<f64>>>) -> Vec<Spensor> {
885        inputs
886            .iter()
887            .map(|s| self.eval.evaluate(s).into())
888            .collect()
889    }
890}
891
#[cfg(feature = "python_stubgen")]
submit! {
    // Manually registered stub signatures for `Representation.__call__`
    // (`crate::structure::SpensoRepresentation`). Two entries are emitted, in this
    // order, both flagged `is_overload: true`.
    PyMethodsInfo {
        struct_id: std::any::TypeId::of::<crate::structure::SpensoRepresentation>,
        attrs: &[],
        getters: &[],
        setters: &[],
        // Source location of this registration, recorded alongside the stub entries.
        file: file!(),
        line: line!(),
        column: column!(),
        methods: &[
            // Overload 1: index-like argument (typed via `ConvertibleToAbstractIndex`)
            // returning a `Slot`.
            MethodInfo {
                name: "__call__",
                parameters: &[
                    ParameterInfo {
                        name: "aind",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: structure::ConvertibleToAbstractIndex::type_input,
                    },
                ],
                r#type: MethodType::Instance,
                r#return: structure::SpensoSlot::type_output,
                doc:r##"Create a slot from this representation, by specifying an index.

Parameters
----------
aind : int, str, or Symbol
    The index specification

Returns
-------
Slot
    A new Slot object with the specified index

Examples
--------
>>> from symbolica.community.spenso import Representation
>>> import symbolica as sp
>>> rep = Representation.euc(3)
>>> slot1 = rep('mu')
>>> slot2 = rep(1)
>>> slot3 = rep(sp.S('nu'))
"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            },
            // Overload 2: symbolic `Expression` argument. The declared return type is
            // the union `Expression | Slot`.
            // NOTE(review): the docstring's Returns section lists only Expression;
            // confirm which of the two the runtime actually returns for this input.
            MethodInfo {
                name: "__call__",
                parameters: &[
                    ParameterInfo {
                        name: "aind",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: PythonExpression::type_input
                    },
                ],
                r#type: MethodType::Instance,
                r#return: || PythonExpression::type_output()| structure::SpensoSlot::type_output(),
                doc:r##"Create a slot or symbolic expression from this representation.

Parameters
----------
aind : Expression
    The index specification (Expression creates symbolic representation)

Returns
-------
Expression
    A symbolic expression representing this representation

Examples
--------
>>> from symbolica.community.spenso import Representation
>>> import symbolica as sp
>>> rep = Representation.euc(3)
>>> expr = rep(sp.E("cos(x)"))
"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            }
        ]
    }
}
980
981// static NONE: LazyLock<String> = LazyLock::new(|| "None".to_string());
982
#[cfg(feature = "python_stubgen")]
submit! {
    // Manually registered stub signatures for `Tensor` (`Spensor`) dunder methods:
    // `__iter__`, two `__getitem__` overloads (slice / index), and `__setitem__`.
    PyMethodsInfo {
        struct_id: std::any::TypeId::of::<Spensor>,
        attrs: &[],
        getters: &[],
        setters: &[],
        // Source location of this registration, recorded alongside the stub entries.
        file: file!(),
        line: line!(),
        column: column!(),
        methods: &[
            // NOTE(review): this is the only `__iter__` entry here but it is flagged
            // `is_overload: true`; type checkers reject a lone `@overload` unless the
            // implementation signature is emitted elsewhere — confirm.
            MethodInfo {
                name: "__iter__",
                parameters: &[],
                r#type: MethodType::Instance,
                // The annotation is spelled with the `typing` module, so declare that
                // module as an import (as pyo3_stub_gen's `TypeInfo::any()` does);
                // an empty import set could leave `typing.` unresolved in the stub.
                r#return: || {
                    let mut import = std::collections::HashSet::new();
                    import.insert("typing".into());
                    TypeInfo {
                        name: "typing.Iterator[typing.Any]".into(),
                        import,
                    }
                },
                doc:r##"Iterator"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            },
            // `__getitem__` overload 1: a slice returns a list of elements.
            MethodInfo {
                name: "__getitem__",
                parameters: &[
                    ParameterInfo {
                        name: "item",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: || TypeInfo::builtin("slice"),
                    },
                ],
                r#type: MethodType::Instance,
                r#return: Vec::<TensorElements>::type_output,
                doc:r##"Get tensor elements at the specified range of indices.

Parameters
----------
item : slice
    Slice object defining the range of indices

Returns
-------
list of float, complex, or Expression
    The tensor elements at the specified range
"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            },
            // `__getitem__` overload 2: a flat index or a coordinate list returns a
            // single element.
            MethodInfo {
                name: "__getitem__",
                parameters: &[
                    ParameterInfo {
                        name: "item",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: || Vec::<usize>::type_input()|usize::type_input()
                    },
                ],
                r#type: MethodType::Instance,
                r#return: TensorElements::type_output,
                doc:r##"Get tensor element at the specified index or indices.

Parameters
----------
item : int or list of int
    Index specification (int for flat index, list of int for coordinates)

Returns
-------
float, complex, or Expression
    The tensor element at the specified index
"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            },
            // NOTE(review): single `__setitem__` entry flagged `is_overload: true`;
            // same lone-overload concern as `__iter__` above — confirm.
            MethodInfo {
                name: "__setitem__",
                parameters: &[
                    ParameterInfo {
                        name: "item",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: || usize::type_input()|Vec::<usize>::type_input()

                    },
                    ParameterInfo {
                        name: "value",
                        kind: ParameterKind::PositionalOrKeyword,
                        default: ParameterDefault::None,
                        type_info: ||TensorElements::type_input()
                    },
                ],
                r#type: MethodType::Instance,
                r#return: TypeInfo::none,
                doc:r##"Set tensor element at the specified index.

Parameters
----------
item : int or list of int
    Index specification (int for flat index, list of int for coordinates)
value : float, complex, or Expression
    The value to set

Examples
--------
>>> from symbolica.community.spenso import Tensor, TensorIndices, Representation
>>> rep = Representation.euc(2)
>>> structure = TensorIndices(rep(1), rep(2))
>>> tensor = Tensor.sparse(structure, float)
>>> tensor[0] = 1.0
>>> tensor[1, 1] = 2.0
"##,
                is_async: false,
                deprecated: None,
                type_ignored: None,
                is_overload: true,
            },
        ]
    }
}
1112
// Defines the `stub_info` function through which pyo3_stub_gen gathers every stub
// entry registered in this crate, including the manual `submit!` blocks, for
// `.pyi` generation. Presumably it is called from a separate stub-generation
// binary — confirm. It is only compiled with the `python_stubgen` feature.
#[cfg(feature = "python_stubgen")]
define_stub_info_gatherer!(stub_info);