use crate::hints::ExecHints;
use crate::ops::execute_dense_contraction;
use anyhow::{anyhow, Result};
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use std::collections::HashMap;
use tenrso_core::{DenseND, TensorHandle};
use tenrso_planner::{greedy_planner, EinsumSpec, Plan, PlanHints};

pub use super::advanced_indexing::ScatterMode;

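/// Reduction operations supported by the CPU executor.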
#[derive(Clone, Debug)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
    Mean,
    Prod,
    All,
    Any,
    ArgMax,
    ArgMin,
}
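
/// Element-wise binary operations, applied with broadcasting.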
#[derive(Clone, Debug)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Maximum,
    Minimum,
}
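
/// Shape-keyed memory pool for reusing intermediate buffers.
///
/// Free buffers are bucketed by a string signature of their shape (see
/// `shape_signature`), so repeated operations on identically shaped tensors
/// can reuse allocations instead of hitting the allocator each time.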
pub(crate) struct MemoryPool<T>
where
    T: bytemuck::Pod + bytemuck::Zeroable,
{
    /// Free buffers, keyed by shape signature (e.g. "2x3x4").
    pools: HashMap<String, Vec<Vec<T>>>,
    /// Acquisitions served from the pool.
    hits: usize,
    /// Acquisitions that required a fresh allocation.
    misses: usize,
    total_allocations: usize,
    total_releases: usize,
    /// When false, `acquire` always allocates and `release` drops buffers.
    enabled: bool,
    _phantom: std::marker::PhantomData<T>,
}

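/// Snapshot of pool counters returned by `detailed_stats`.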
#[derive(Debug, Clone, PartialEq)]
pub struct PoolStats {
    pub hits: usize,
    pub misses: usize,
    pub total_allocations: usize,
    pub total_releases: usize,
    /// hits / (hits + misses), or 0.0 if nothing has been acquired yet.
    pub hit_rate: f64,
    pub unique_shapes: usize,
    /// Total size of all pooled buffers, in bytes.
    pub total_bytes_pooled: usize,
    pub total_buffers_pooled: usize,
}

impl<T> MemoryPool<T>
where
    T: bytemuck::Pod + bytemuck::Zeroable,
{
    /// Create an enabled pool with empty buckets and zeroed counters.
    pub(crate) fn new() -> Self {
        Self {
            pools: HashMap::new(),
            hits: 0,
            misses: 0,
            total_allocations: 0,
            total_releases: 0,
            enabled: true,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Create a pool that never caches: every acquire allocates fresh.
    pub(crate) fn disabled() -> Self {
        Self {
            pools: HashMap::new(),
            hits: 0,
            misses: 0,
            total_allocations: 0,
            total_releases: 0,
            enabled: false,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Enable or disable pooling; disabling drops all cached buffers.
    pub(crate) fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
        if !enabled {
            self.pools.clear();
        }
    }

    pub(crate) fn is_enabled(&self) -> bool {
        self.enabled
    }

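    /// Acquire a buffer sized for `shape`, reusing a pooled buffer when one
    /// with the same shape signature is available. The returned buffer is
    /// always zero-filled.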
    #[allow(dead_code)]
    pub(crate) fn acquire(&mut self, shape: &[usize]) -> Vec<T> {
        self.total_allocations += 1;

        let total_size: usize = shape.iter().product();

        if !self.enabled {
            self.misses += 1;
            return vec![T::zeroed(); total_size];
        }

        let signature = Self::shape_signature(shape);

        if let Some(pool) = self.pools.get_mut(&signature) {
            if let Some(mut buffer) = pool.pop() {
                self.hits += 1;
                // Re-zero the reused buffer so `acquire` always returns
                // zeroed memory, matching the freshly allocated path below.
                buffer.clear();
                buffer.resize(total_size, T::zeroed());
                return buffer;
            }
        }

        self.misses += 1;
        vec![T::zeroed(); total_size]
    }

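    /// Return a buffer to the pool for future reuse. Buffers beyond the
    /// per-shape cap are dropped to bound the memory held by the pool.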
    #[allow(dead_code)]
    pub(crate) fn release(&mut self, shape: &[usize], buffer: Vec<T>) {
        self.total_releases += 1;

        if !self.enabled {
            return;
        }

        let signature = Self::shape_signature(shape);
        let pool = self.pools.entry(signature).or_default();

        // Cap each per-shape bucket so the pool cannot grow without bound.
        const MAX_POOL_SIZE: usize = 16;
        if pool.len() < MAX_POOL_SIZE {
            pool.push(buffer);
        }
    }

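    /// Canonical pool key for a shape: dimensions joined with 'x',
    /// e.g. `[2, 3, 4]` -> "2x3x4".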
    #[allow(dead_code)]
    pub(crate) fn shape_signature(shape: &[usize]) -> String {
        shape
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join("x")
    }

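    /// Compact counters: `(hits, misses, hit_rate)`.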
    pub(crate) fn stats(&self) -> (usize, usize, f64) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

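    /// Full counter snapshot, including how many buffers and bytes are
    /// currently parked in the pool.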
    pub(crate) fn detailed_stats(&self) -> PoolStats {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };

        let unique_shapes = self.pools.len();
        let mut total_bytes_pooled = 0;
        let mut total_buffers_pooled = 0;

        let elem_size = std::mem::size_of::<T>();

        for pool in self.pools.values() {
            total_buffers_pooled += pool.len();
            for buffer in pool {
                total_bytes_pooled += buffer.len() * elem_size;
            }
        }

        PoolStats {
            hits: self.hits,
            misses: self.misses,
            total_allocations: self.total_allocations,
            total_releases: self.total_releases,
            hit_rate,
            unique_shapes,
            total_bytes_pooled,
            total_buffers_pooled,
        }
    }

    /// Drop all cached buffers and reset every counter.
    pub(crate) fn clear(&mut self) {
        self.pools.clear();
        self.hits = 0;
        self.misses = 0;
        self.total_allocations = 0;
        self.total_releases = 0;
    }

    /// Number of distinct shape signatures currently bucketed.
    pub(crate) fn num_shapes(&self) -> usize {
        self.pools.len()
    }

    /// Total number of buffers cached across all buckets.
    pub(crate) fn num_buffers(&self) -> usize {
        self.pools.values().map(|v| v.len()).sum()
    }
}
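
/// CPU executor for dense tensor kernels, with typed memory pools and
/// feature toggles for parallelism, SIMD, tiling, and broadcasting.
///
/// A minimal usage sketch built from the methods below (the shape and
/// toggle values are illustrative only):
///
/// ```ignore
/// let mut exec = CpuExecutor::new()
///     .with_simd(true)
///     .with_memory_pool(true);
/// let buf = exec.acquire_f32(&[64, 64]);
/// // ... fill and use `buf` ...
/// exec.release_f32(&[64, 64], buf);
/// ```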
pub struct CpuExecutor {
    /// Pool of reusable f32 buffers.
    memory_pool_f32: MemoryPool<f32>,
    /// Pool of reusable f64 buffers.
    memory_pool_f64: MemoryPool<f64>,
    /// Worker thread count (0 in `new()`; explicit via `with_threads`).
    pub num_threads: usize,
    pub enable_parallel: bool,
    pub enable_simd: bool,
    pub enable_tiled_reductions: bool,
    pub enable_vectorized_broadcast: bool,
    pub enable_memory_pool: bool,
}
impl CpuExecutor {
    /// Default executor: every optimization enabled, `num_threads` left at 0.
    pub fn new() -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads: 0,
            enable_parallel: true,
            enable_simd: true,
            enable_tiled_reductions: true,
            enable_vectorized_broadcast: true,
            enable_memory_pool: true,
        }
    }

    /// Like `new()`, but with an explicit worker thread count.
    pub fn with_threads(num_threads: usize) -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads,
            enable_parallel: true,
            enable_simd: true,
            enable_tiled_reductions: true,
            enable_vectorized_broadcast: true,
            enable_memory_pool: true,
        }
    }

    /// Single-threaded executor with all optimization flags off; the pools
    /// are constructed but `enable_memory_pool` is false.
    pub fn serial() -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads: 1,
            enable_parallel: false,
            enable_simd: false,
            enable_tiled_reductions: false,
            enable_vectorized_broadcast: false,
            enable_memory_pool: false,
        }
    }

    /// Like `serial()`, but the pools themselves are created disabled.
    pub fn unoptimized() -> Self {
        Self {
            memory_pool_f32: MemoryPool::disabled(),
            memory_pool_f64: MemoryPool::disabled(),
            num_threads: 1,
            enable_parallel: false,
            enable_simd: false,
            enable_tiled_reductions: false,
            enable_vectorized_broadcast: false,
            enable_memory_pool: false,
        }
    }

    /// Toggle SIMD kernels (builder style).
    pub fn with_simd(mut self, enabled: bool) -> Self {
        self.enable_simd = enabled;
        self
    }

    /// Toggle tiled reduction kernels (builder style).
    pub fn with_tiled_reductions(mut self, enabled: bool) -> Self {
        self.enable_tiled_reductions = enabled;
        self
    }

    /// Toggle vectorized broadcasting (builder style).
    pub fn with_vectorized_broadcast(mut self, enabled: bool) -> Self {
        self.enable_vectorized_broadcast = enabled;
        self
    }

    /// Toggle memory pooling (builder style); also flips both typed pools.
    pub fn with_memory_pool(mut self, enabled: bool) -> Self {
        self.enable_memory_pool = enabled;
        self.memory_pool_f32.set_enabled(enabled);
        self.memory_pool_f64.set_enabled(enabled);
        self
    }

    /// Compact stats `(hits, misses, hit_rate)` for the f32 pool.
    pub fn pool_stats(&self) -> (usize, usize, f64) {
        self.memory_pool_f32.stats()
    }

    /// Detailed stats for the f32 pool; see `get_pool_stats_f64` for f64.
    pub fn get_pool_stats(&self) -> PoolStats {
        self.memory_pool_f32.detailed_stats()
    }

    pub fn get_pool_stats_f32(&self) -> PoolStats {
        self.memory_pool_f32.detailed_stats()
    }

    pub fn get_pool_stats_f64(&self) -> PoolStats {
        self.memory_pool_f64.detailed_stats()
    }

    /// Drop all cached buffers in both pools and reset their counters.
    pub fn clear_pool(&mut self) {
        self.memory_pool_f32.clear();
        self.memory_pool_f64.clear();
    }

    pub fn is_pool_enabled(&self) -> bool {
        self.enable_memory_pool
            && self.memory_pool_f32.is_enabled()
            && self.memory_pool_f64.is_enabled()
    }

    pub fn set_pool_enabled(&mut self, enabled: bool) {
        self.enable_memory_pool = enabled;
        self.memory_pool_f32.set_enabled(enabled);
        self.memory_pool_f64.set_enabled(enabled);
    }

    pub fn pool_num_shapes(&self) -> usize {
        self.memory_pool_f32.num_shapes()
    }

    pub fn pool_num_shapes_f32(&self) -> usize {
        self.memory_pool_f32.num_shapes()
    }

    pub fn pool_num_shapes_f64(&self) -> usize {
        self.memory_pool_f64.num_shapes()
    }

    pub fn pool_num_buffers(&self) -> usize {
        self.memory_pool_f32.num_buffers()
    }

    pub fn pool_num_buffers_f32(&self) -> usize {
        self.memory_pool_f32.num_buffers()
    }

    pub fn pool_num_buffers_f64(&self) -> usize {
        self.memory_pool_f64.num_buffers()
    }

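    /// Acquire an f32 buffer for `shape`, from the pool when enabled.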
    pub fn acquire_f32(&mut self, shape: &[usize]) -> Vec<f32> {
        if self.enable_memory_pool {
            self.memory_pool_f32.acquire(shape)
        } else {
            vec![0.0; shape.iter().product()]
        }
    }

    /// Return an f32 buffer to the pool; a no-op when pooling is disabled.
    pub fn release_f32(&mut self, shape: &[usize], buffer: Vec<f32>) {
        if self.enable_memory_pool {
            self.memory_pool_f32.release(shape, buffer);
        }
    }

    /// Acquire an f64 buffer for `shape`, from the pool when enabled.
    pub fn acquire_f64(&mut self, shape: &[usize]) -> Vec<f64> {
        if self.enable_memory_pool {
            self.memory_pool_f64.acquire(shape)
        } else {
            vec![0.0; shape.iter().product()]
        }
    }

    /// Return an f64 buffer to the pool; a no-op when pooling is disabled.
    pub fn release_f64(&mut self, shape: &[usize], buffer: Vec<f64>) {
        if self.enable_memory_pool {
            self.memory_pool_f64.release(shape, buffer);
        }
    }

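    /// Type-dispatched pooled allocation: routes `f32`/`f64` requests to the
    /// matching typed pool via a `TypeId` check, and falls back to a plain
    /// allocation for any other element type.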
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn acquire_pooled_generic<T>(&mut self, shape: &[usize]) -> Vec<T>
    where
        T: Clone + std::default::Default + 'static,
    {
        if !self.enable_memory_pool {
            return vec![T::default(); shape.iter().product()];
        }

        use std::any::TypeId;

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            let buffer_f32 = self.memory_pool_f32.acquire(shape);
            // SAFETY: the TypeId check above guarantees T == f32 here.
            unsafe { std::mem::transmute::<Vec<f32>, Vec<T>>(buffer_f32) }
        } else if TypeId::of::<T>() == TypeId::of::<f64>() {
            let buffer_f64 = self.memory_pool_f64.acquire(shape);
            // SAFETY: the TypeId check above guarantees T == f64 here.
            unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(buffer_f64) }
        } else {
            vec![T::default(); shape.iter().product()]
        }
    }

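    /// Counterpart to `acquire_pooled_generic`: returns `f32`/`f64` buffers
    /// to their typed pools and silently drops any other element type.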
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn release_pooled_generic<T>(&mut self, shape: &[usize], buffer: Vec<T>)
    where
        T: Clone + std::default::Default + 'static,
    {
        if !self.enable_memory_pool {
            return;
        }

        use std::any::TypeId;

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            // SAFETY: the TypeId check above guarantees T == f32 here.
            let buffer_f32: Vec<f32> = unsafe { std::mem::transmute::<Vec<T>, Vec<f32>>(buffer) };
            self.memory_pool_f32.release(shape, buffer_f32);
        } else if TypeId::of::<T>() == TypeId::of::<f64>() {
            // SAFETY: the TypeId check above guarantees T == f64 here.
            let buffer_f64: Vec<f64> = unsafe { std::mem::transmute::<Vec<T>, Vec<f64>>(buffer) };
            self.memory_pool_f64.release(shape, buffer_f64);
        }
    }

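    /// Run `f` with a pooled scratch buffer, returning the buffer to the
    /// pool afterwards. The buffer is cloned before being handed to `f`,
    /// so `f` may consume its copy freely.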
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn with_pooled_buffer<T, F, R>(&mut self, shape: &[usize], f: F) -> Result<R>
    where
        T: Clone + std::default::Default + 'static,
        F: FnOnce(Vec<T>) -> Result<R>,
    {
        let buffer = self.acquire_pooled_generic::<T>(shape);
        // `f` takes ownership, so pass a clone and return the original.
        let result = f(buffer.clone());
        self.release_pooled_generic::<T>(shape, buffer);
        result
    }

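    /// Plan and execute an einsum over `inputs` using the greedy planner.
    /// Two-operand contractions bypass plan execution and run as a single
    /// dense contraction.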
    pub(crate) fn execute_einsum_with_planner<T>(
        &mut self,
        spec: &EinsumSpec,
        inputs: &[DenseND<T>],
        _hints: &ExecHints,
    ) -> Result<DenseND<T>>
    where
        T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        let shapes: Vec<Vec<usize>> = inputs.iter().map(|t| t.shape().to_vec()).collect();
        let plan_hints = PlanHints::default();
        let plan = greedy_planner(spec, &shapes, &plan_hints)?;
        // Fast path: a pairwise contraction needs no plan execution.
        if inputs.len() == 2 {
            return execute_dense_contraction(spec, &inputs[0], &inputs[1]);
        }
        self.execute_plan(&plan, inputs)
    }
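    /// Apply `op` element-wise to `x` and `y` with NumPy-style broadcasting.
    /// Scalar operands take a fast path that maps over the other tensor;
    /// the general path walks the broadcast output index space.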
    pub(crate) fn binary_op_with_broadcast<T>(
        &mut self,
        op: BinaryOp,
        x: &DenseND<T>,
        y: &DenseND<T>,
    ) -> Result<TensorHandle<T>>
    where
        T: Clone
            + Num
            + std::ops::AddAssign
            + std::default::Default
            + Float
            + FromPrimitive
            + 'static,
    {
        let x_shape = x.shape();
        let y_shape = y.shape();
        let output_shape = self.broadcast_shapes(x_shape, y_shape)?;
        let x_is_scalar = x_shape.is_empty() || (x_shape.len() == 1 && x_shape[0] == 1);
        let y_is_scalar = y_shape.is_empty() || (y_shape.len() == 1 && y_shape[0] == 1);
        // Fast path: scalar `x` broadcast against every element of `y`.
        if x_is_scalar {
            let x_val = if x_shape.is_empty() {
                x.view()[[]]
            } else {
                x.view()[[0]]
            };
            let result_data = match op {
                BinaryOp::Add => y.view().mapv(|y_val| x_val + y_val),
                BinaryOp::Sub => y.view().mapv(|y_val| x_val - y_val),
                BinaryOp::Mul => y.view().mapv(|y_val| x_val * y_val),
                BinaryOp::Div => y.view().mapv(|y_val| x_val / y_val),
                BinaryOp::Pow => y.view().mapv(|y_val| x_val.powf(y_val)),
                BinaryOp::Maximum => y
                    .view()
                    .mapv(|y_val| if x_val > y_val { x_val } else { y_val }),
                BinaryOp::Minimum => y
                    .view()
                    .mapv(|y_val| if x_val < y_val { x_val } else { y_val }),
            };
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(
                result_data,
            )));
        }
        // Fast path: scalar `y` broadcast against every element of `x`.
        if y_is_scalar {
            let y_val = if y_shape.is_empty() {
                y.view()[[]]
            } else {
                y.view()[[0]]
            };
            let result_data = match op {
                BinaryOp::Add => x.view().mapv(|x_val| x_val + y_val),
                BinaryOp::Sub => x.view().mapv(|x_val| x_val - y_val),
                BinaryOp::Mul => x.view().mapv(|x_val| x_val * y_val),
                BinaryOp::Div => x.view().mapv(|x_val| x_val / y_val),
                BinaryOp::Pow => x.view().mapv(|x_val| x_val.powf(y_val)),
                BinaryOp::Maximum => x
                    .view()
                    .mapv(|x_val| if x_val > y_val { x_val } else { y_val }),
                BinaryOp::Minimum => x
                    .view()
                    .mapv(|x_val| if x_val < y_val { x_val } else { y_val }),
            };
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(
                result_data,
            )));
        }
        // General path: iterate the broadcast output space, mapping each
        // output index back to the (possibly size-1) input dimensions.
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let output_size: usize = output_shape.iter().product();

        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
        output_data.clear();
        for flat_idx in 0..output_size {
            let out_idx = self.flat_to_multidim(flat_idx, &output_shape);
            let x_idx = self.broadcast_index(&out_idx, x_shape, &output_shape);
            let y_idx = self.broadcast_index(&out_idx, y_shape, &output_shape);
            let x_val = x.view()[x_idx.as_slice()];
            let y_val = y.view()[y_idx.as_slice()];
            let result_val = match op {
                BinaryOp::Add => x_val + y_val,
                BinaryOp::Sub => x_val - y_val,
                BinaryOp::Mul => x_val * y_val,
                BinaryOp::Div => x_val / y_val,
                BinaryOp::Pow => x_val.powf(y_val),
                BinaryOp::Maximum => {
                    if x_val > y_val {
                        x_val
                    } else {
                        y_val
                    }
                }
                BinaryOp::Minimum => {
                    if x_val < y_val {
                        x_val
                    } else {
                        y_val
                    }
                }
            };
            output_data.push(result_val);
        }

        // `from_shape_vec` consumes its vec, so clone before returning the
        // original buffer to the pool.
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output_data);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
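    /// Convert a flat (row-major) index into per-dimension indices.
    /// E.g. for shape `[2, 3]`, flat index 4 maps to `[1, 1]`.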
    pub(crate) fn flat_to_multidim(&self, flat_idx: usize, shape: &[usize]) -> Vec<usize> {
        let mut idx = Vec::with_capacity(shape.len());
        let mut remaining = flat_idx;
        for &dim_size in shape.iter().rev() {
            idx.push(remaining % dim_size);
            remaining /= dim_size;
        }
        idx.reverse();
        idx
    }

    /// Inverse of `flat_to_multidim`: row-major flattening of `idx`.
    pub(crate) fn multidim_to_flat(&self, idx: &[usize], shape: &[usize]) -> usize {
        let mut flat_idx = 0;
        let mut multiplier = 1;
        for i in (0..shape.len()).rev() {
            flat_idx += idx[i] * multiplier;
            multiplier *= shape[i];
        }
        flat_idx
    }

    /// Map an output index back to an input index under broadcasting:
    /// size-1 input dimensions pin to 0, and missing leading dimensions
    /// are dropped.
    fn broadcast_index(
        &self,
        out_idx: &[usize],
        in_shape: &[usize],
        out_shape: &[usize],
    ) -> Vec<usize> {
        let mut in_idx = Vec::with_capacity(in_shape.len());
        let ndim_diff = out_shape.len() - in_shape.len();
        for (i, &in_dim) in in_shape.iter().enumerate() {
            let out_i = i + ndim_diff;
            if in_dim == 1 {
                in_idx.push(0);
            } else {
                in_idx.push(out_idx[out_i]);
            }
        }
        in_idx
    }

    /// NumPy-style shape broadcasting, aligning from the trailing dimension:
    /// e.g. `[2, 1, 3]` and `[4, 3]` broadcast to `[2, 4, 3]`.
    fn broadcast_shapes(&self, x_shape: &[usize], y_shape: &[usize]) -> Result<Vec<usize>> {
        let max_ndim = x_shape.len().max(y_shape.len());
        let mut result_shape = Vec::with_capacity(max_ndim);
        for i in 0..max_ndim {
            let x_dim = if i < x_shape.len() {
                x_shape[x_shape.len() - 1 - i]
            } else {
                1
            };
            let y_dim = if i < y_shape.len() {
                y_shape[y_shape.len() - 1 - i]
            } else {
                1
            };
            if x_dim == y_dim || x_dim == 1 || y_dim == 1 {
                result_shape.push(x_dim.max(y_dim));
            } else {
                return Err(anyhow!(
                    "Shapes {:?} and {:?} are not broadcast-compatible at dimension {}",
                    x_shape,
                    y_shape,
                    i
                ));
            }
        }
        result_shape.reverse();
        Ok(result_shape)
    }
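    /// Execute a pairwise contraction plan: each step pops the two operands
    /// it names, contracts them, and pushes the result until a single
    /// tensor remains.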
    fn execute_plan<T>(&mut self, plan: &Plan, inputs: &[DenseND<T>]) -> Result<DenseND<T>>
    where
        T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        let mut intermediates: Vec<DenseND<T>> = inputs.to_vec();
        for (step_idx, &(i, j)) in plan.order.iter().enumerate() {
            if i >= intermediates.len() || j >= intermediates.len() {
                return Err(anyhow!(
                    "Step {}: Invalid indices ({}, {}) for {} intermediates",
                    step_idx,
                    i,
                    j,
                    intermediates.len()
                ));
            }
            let node = &plan.nodes[step_idx];
            // Remove the higher index first so the lower one stays valid.
            let (tensor_a, tensor_b) = if i < j {
                let b = intermediates.remove(j);
                let a = intermediates.remove(i);
                (a, b)
            } else {
                let a = intermediates.remove(i);
                let b = intermediates.remove(j);
                (a, b)
            };
            let spec_str = format!(
                "{},{}->{}",
                node.output_spec.input_specs[0],
                node.output_spec.input_specs[1],
                node.output_spec.output_spec
            );
            let step_spec = EinsumSpec::parse(&spec_str)?;
            let result = execute_dense_contraction(&step_spec, &tensor_a, &tensor_b)?;
            intermediates.push(result);
        }
        if intermediates.len() != 1 {
            return Err(anyhow!(
                "Expected 1 final tensor, got {}",
                intermediates.len()
            ));
        }
        Ok(intermediates.into_iter().next().unwrap())
    }
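    /// Determinant via Gaussian elimination with partial pivoting; returns
    /// zero when a pivot column falls below the 1e-10 tolerance.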
    pub(crate) fn compute_determinant_2d<T2>(
        &self,
        matrix: &scirs2_core::ndarray_ext::Array2<T2>,
    ) -> Result<T2>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        let n = matrix.nrows();
        if n == 0 {
            return Ok(T2::one());
        }
        if n == 1 {
            return Ok(matrix[[0, 0]]);
        }
        if n == 2 {
            let a = matrix[[0, 0]];
            let b = matrix[[0, 1]];
            let c = matrix[[1, 0]];
            let d = matrix[[1, 1]];
            return Ok(a * d - b * c);
        }
        // General case: the determinant is the product of the pivots, with
        // the sign flipped once per row swap.
        let mut a = matrix.clone();
        let mut det = T2::one();
        let mut sign = T2::one();
        for i in 0..n {
            // Select the largest-magnitude pivot in column i.
            let mut pivot = i;
            let mut max_val = a[[i, i]].abs();
            for k in (i + 1)..n {
                let val = a[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            // Treat a near-zero pivot column as a singular matrix.
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Ok(T2::zero());
            }
            if pivot != i {
                for j in 0..n {
                    let temp = a[[i, j]];
                    a[[i, j]] = a[[pivot, j]];
                    a[[pivot, j]] = temp;
                }
                sign = -sign;
            }
            det = det * a[[i, i]];
            // Eliminate column i below the pivot.
            for k in (i + 1)..n {
                let factor = a[[k, i]] / a[[i, i]];
                for j in i..n {
                    a[[k, j]] = a[[k, j]] - factor * a[[i, j]];
                }
            }
        }
        Ok(sign * det)
    }
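    /// Matrix inverse via Gauss-Jordan elimination on the augmented system
    /// `[A | I]`, with partial pivoting.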
    pub(crate) fn compute_inverse_2d<T2>(
        &self,
        matrix: &scirs2_core::ndarray_ext::Array2<T2>,
    ) -> Result<scirs2_core::ndarray_ext::Array2<T2>>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        use scirs2_core::ndarray_ext::Array2;
        let n = matrix.nrows();
        if n == 0 {
            return Err(anyhow!("Cannot invert empty matrix"));
        }
        // Build the augmented matrix [A | I].
        let mut aug = Array2::zeros((n, 2 * n));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = matrix[[i, j]];
            }
            aug[[i, n + i]] = T2::one();
        }
        for i in 0..n {
            // Select the largest-magnitude pivot in column i.
            let mut pivot = i;
            let mut max_val = aug[[i, i]].abs();
            for k in (i + 1)..n {
                let val = aug[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Err(anyhow!("Matrix is singular and cannot be inverted"));
            }
            if pivot != i {
                for j in 0..(2 * n) {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[pivot, j]];
                    aug[[pivot, j]] = temp;
                }
            }
            // Normalize the pivot row, then eliminate column i everywhere else.
            let pivot_val = aug[[i, i]];
            for j in 0..(2 * n) {
                aug[[i, j]] = aug[[i, j]] / pivot_val;
            }
            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..(2 * n) {
                        aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
                    }
                }
            }
        }
        // The right half of the reduced augmented matrix is A^-1.
        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, n + j]];
            }
        }
        Ok(inv)
    }
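    /// Solve `A x = b` by forward elimination with partial pivoting followed
    /// by back-substitution.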
    pub(crate) fn solve_2d_1d<T2>(
        &self,
        a: &scirs2_core::ndarray_ext::Array2<T2>,
        b: &scirs2_core::ndarray_ext::Array1<T2>,
    ) -> Result<scirs2_core::ndarray_ext::Array1<T2>>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        use scirs2_core::ndarray_ext::Array1;
        let n = a.nrows();
        if n != b.len() {
            return Err(anyhow!("Dimension mismatch in solve"));
        }
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        // Forward elimination with partial pivoting.
        for i in 0..n {
            let mut pivot = i;
            let mut max_val = a_work[[i, i]].abs();
            for k in (i + 1)..n {
                let val = a_work[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Err(anyhow!("Matrix is singular, cannot solve"));
            }
            if pivot != i {
                for j in 0..n {
                    let temp = a_work[[i, j]];
                    a_work[[i, j]] = a_work[[pivot, j]];
                    a_work[[pivot, j]] = temp;
                }
                let temp = b_work[i];
                b_work[i] = b_work[pivot];
                b_work[pivot] = temp;
            }
            for k in (i + 1)..n {
                let factor = a_work[[k, i]] / a_work[[i, i]];
                for j in i..n {
                    a_work[[k, j]] = a_work[[k, j]] - factor * a_work[[i, j]];
                }
                b_work[k] = b_work[k] - factor * b_work[i];
            }
        }
        // Back-substitution on the upper-triangular system.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = b_work[i];
            for j in (i + 1)..n {
                sum = sum - a_work[[i, j]] * x[j];
            }
            x[i] = sum / a_work[[i, i]];
        }
        Ok(x)
    }
}
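
/// Element-wise unary operations.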
#[derive(Clone, Debug)]
pub enum ElemOp {
    Neg,
    Abs,
    Exp,
    Log,
    Sin,
    Cos,
    Sqrt,
    Sqr,
    Recip,
    Tanh,
    Sigmoid,
    ReLU,
    Gelu,
    Elu,
    Selu,
    Softplus,
    Sign,
}