1#![deny(warnings)]
16
17pub mod executor;
18pub mod hints;
19pub mod ops;
20
21pub use executor::*;
23pub use hints::*;
24
25use scirs2_core::numeric::{Float, FromPrimitive, Num};
26
27pub fn einsum_ex<T>(spec: &str) -> EinsumBuilder<'_, T>
39where
40 T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
41{
42 EinsumBuilder::new(spec)
43}
44
45pub struct EinsumBuilder<'a, T>
47where
48 T: Clone + Num + Float + FromPrimitive + 'static,
49{
50 spec: String,
51 inputs: Option<&'a [tenrso_core::TensorHandle<T>]>,
52 hints: ExecHints,
53}
54
55impl<'a, T> EinsumBuilder<'a, T>
56where
57 T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
58{
59 pub fn new(spec: impl Into<String>) -> Self {
61 Self {
62 spec: spec.into(),
63 inputs: None,
64 hints: ExecHints::default(),
65 }
66 }
67
68 pub fn inputs(mut self, inputs: &'a [tenrso_core::TensorHandle<T>]) -> Self {
82 self.inputs = Some(inputs);
83 self
84 }
85
86 pub fn hints(mut self, hints: &ExecHints) -> Self {
105 self.hints = hints.clone();
106 self
107 }
108
109 pub fn run(self) -> anyhow::Result<tenrso_core::TensorHandle<T>> {
122 let inputs = self
123 .inputs
124 .ok_or_else(|| anyhow::anyhow!("No inputs provided to einsum_ex"))?;
125
126 let mut executor = CpuExecutor::new();
128 executor.einsum(&self.spec, inputs, &self.hints)
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use tenrso_core::{DenseND, TensorHandle};

    /// Absolute tolerance used when comparing floating-point results.
    const TOL: f64 = 1e-10;

    /// Asserts that every entry `get(i, j)` of a 2x2 result matches `expected`
    /// to within `TOL`, naming the offending entry on failure.
    fn assert_entries_2x2(get: impl Fn(usize, usize) -> f64, expected: [[f64; 2]; 2]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                let got = get(i, j);
                assert!(
                    (got - want).abs() < TOL,
                    "entry [{i}, {j}]: expected {want}, got {got}"
                );
            }
        }
    }

    #[test]
    fn test_einsum_ex_builder_matmul() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);

        let result = einsum_ex::<f64>("ij,jk->ik")
            .inputs(&[handle_a, handle_b])
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let result_view = result_dense.view();
        assert_entries_2x2(|i, j| result_view[[i, j]], [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_einsum_ex_builder_with_hints() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);

        let hints = ExecHints::default();

        let result = einsum_ex::<f64>("ij,jk->ik")
            .inputs(&[handle_a, handle_b])
            .hints(&hints)
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // Supplying hints must not change the numerical result.
        let result_view = result_dense.view();
        assert_entries_2x2(|i, j| result_view[[i, j]], [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_einsum_ex_builder_three_tensors() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let c = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);
        let handle_c = TensorHandle::from_dense_auto(c);

        let result = einsum_ex::<f64>("ij,jk,kl->il")
            .inputs(&[handle_a, handle_b, handle_c])
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // A·B = [[22, 28], [49, 64]]; (A·B)·C = [[106, 156], [241, 354]]
        let result_view = result_dense.view();
        assert_entries_2x2(
            |i, j| result_view[[i, j]],
            [[106.0, 156.0], [241.0, 354.0]],
        );
    }

    #[test]
    fn test_einsum_ex_builder_no_inputs() {
        let result = einsum_ex::<f64>("ij,jk->ik").run();

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No inputs provided"));
    }
}
216}