tenrso_exec/executor/types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use crate::hints::ExecHints;
use crate::ops::execute_dense_contraction;
use anyhow::{anyhow, Result};
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use std::collections::HashMap;
use tenrso_core::{DenseND, TensorHandle};
use tenrso_planner::{greedy_planner, EinsumSpec, Plan, PlanHints};

// Re-export ScatterMode from advanced_indexing
pub use super::advanced_indexing::ScatterMode;

/// Reduction operation types
#[derive(Clone, Debug)]
pub enum ReduceOp {
    /// Sum of all elements
    Sum,
    /// Maximum element
    Max,
    /// Minimum element
    Min,
    /// Arithmetic mean of all elements
    Mean,
    /// Product of all elements
    Prod,
    /// Logical all-reduction
    All,
    /// Logical any-reduction
    Any,
    /// Index of the maximum element
    ArgMax,
    /// Index of the minimum element
    ArgMin,
}
/// Binary element-wise operation types (operations on two tensors)
#[derive(Clone, Debug)]
pub enum BinaryOp {
    /// Element-wise addition: x + y
    Add,
    /// Element-wise subtraction: x - y
    Sub,
    /// Element-wise multiplication: x * y
    Mul,
    /// Element-wise division: x / y
    Div,
    /// Element-wise power: x^y
    Pow,
    /// Element-wise maximum: max(x, y)
    Maximum,
    /// Element-wise minimum: min(x, y)
    Minimum,
}
/// Memory pool for tensor allocation reuse
///
/// Tracks allocated tensors by their shape signature for reuse.
/// This reduces allocation overhead for repeated operations.
///
/// # Type Safety
///
/// The pool is generic over type T, ensuring type-safe buffer reuse.
/// T must implement `bytemuck::Pod` (Plain Old Data) and `bytemuck::Zeroable`
/// for safe memory operations.
///
/// # Statistics
///
/// The pool reports (via [`PoolStats`]):
/// - **hits**: Number of successful buffer reuses
/// - **misses**: Number of new allocations
/// - **total_bytes_pooled**: Total bytes currently in pools
/// - **unique_shapes**: Number of distinct shape signatures
pub(crate) struct MemoryPool<T>
where
    T: bytemuck::Pod + bytemuck::Zeroable,
{
    /// Map from shape signature to available buffer pool.
    /// A shape signature is a string like "2x3x4" for a [2, 3, 4] tensor.
    pools: HashMap<String, Vec<Vec<T>>>,
    /// Statistics for monitoring
    hits: usize,
    misses: usize,
    total_allocations: usize,
    total_releases: usize,
    enabled: bool,
    /// Phantom data for type parameter
    _phantom: std::marker::PhantomData<T>,
}

/// Memory pool statistics
#[derive(Debug, Clone, PartialEq)]
pub struct PoolStats {
    /// Number of buffer reuses (cache hits)
    pub hits: usize,
    /// Number of new allocations (cache misses)
    pub misses: usize,
    /// Total number of allocation requests
    pub total_allocations: usize,
    /// Total number of buffer releases
    pub total_releases: usize,
    /// Cache hit rate (hits / total)
    pub hit_rate: f64,
    /// Number of unique shape signatures in pool
    pub unique_shapes: usize,
    /// Total bytes currently pooled
    pub total_bytes_pooled: usize,
    /// Total number of buffers currently pooled
    pub total_buffers_pooled: usize,
}

impl<T> MemoryPool<T>
where
    T: bytemuck::Pod + bytemuck::Zeroable,
{
    /// Create a new memory pool
    pub(crate) fn new() -> Self {
        Self {
            pools: HashMap::new(),
            hits: 0,
            misses: 0,
            total_allocations: 0,
            total_releases: 0,
            enabled: true,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Create a disabled memory pool (no actual pooling)
    pub(crate) fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::new()
        }
    }

    /// Enable or disable the memory pool
    pub(crate) fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
        if !enabled {
            // Clear pools when disabling
            self.pools.clear();
        }
    }

    /// Check if pooling is enabled
    pub(crate) fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Get a buffer for the given shape, reusing if available
    ///
    /// **Phase 2 Status**: Now using type-safe generic buffer pooling.
    ///
    /// Returns a Vec<T> with the specified total size (product of shape dimensions).
    /// If a matching buffer is available in the pool, it is reused (cache hit);
    /// otherwise a new zeroed buffer is allocated (cache miss). Note that a
    /// reused buffer retains its previous contents.
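    ///
    /// A minimal sketch of the hit/miss accounting (shape and type chosen
    /// purely for illustration):
    ///
    /// ```ignore
    /// let mut pool: MemoryPool<f32> = MemoryPool::new();
    /// let buf = pool.acquire(&[2, 3]);   // miss: allocates 6 zeroed elements
    /// pool.release(&[2, 3], buf);
    /// let _buf = pool.acquire(&[2, 3]);  // hit: reuses the pooled buffer
    /// let (hits, misses, _rate) = pool.stats();
    /// assert_eq!((hits, misses), (1, 1));
    /// ```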
    #[allow(dead_code)]
    pub(crate) fn acquire(&mut self, shape: &[usize]) -> Vec<T> {
        self.total_allocations += 1;

        let total_size: usize = shape.iter().product();

        if !self.enabled {
            self.misses += 1;
            return vec![T::zeroed(); total_size];
        }

        let signature = Self::shape_signature(shape);

        if let Some(pool) = self.pools.get_mut(&signature) {
            if let Some(mut buffer) = pool.pop() {
                self.hits += 1;
                // Resize buffer to match the requested size; only newly added
                // slots are zeroed, existing contents are kept.
                buffer.resize(total_size, T::zeroed());
                return buffer;
            }
        }

        self.misses += 1;
        vec![T::zeroed(); total_size]
    }

    /// Return a buffer to the pool for reuse
    ///
    /// **Phase 2 Status**: Now using type-safe generic buffer pooling.
    ///
    /// Adds the buffer back to the pool for the given shape signature.
    /// Pools are limited to MAX_POOL_SIZE buffers per shape to prevent unbounded growth.
    #[allow(dead_code)]
    pub(crate) fn release(&mut self, shape: &[usize], buffer: Vec<T>) {
        self.total_releases += 1;

        if !self.enabled {
            // Don't pool if disabled - buffer will be dropped
            return;
        }

        let signature = Self::shape_signature(shape);
        let pool = self.pools.entry(signature).or_default();

        const MAX_POOL_SIZE: usize = 16;
        if pool.len() < MAX_POOL_SIZE {
            pool.push(buffer);
        }
        // If the pool is full, the buffer is dropped
    }

    /// Create a shape signature for hashing
    ///
    /// Converts a shape like `[2, 3, 4]` to a string like `"2x3x4"`.
    /// Used as a key for the buffer pool HashMap.
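    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// assert_eq!(MemoryPool::<f32>::shape_signature(&[2, 3, 4]), "2x3x4");
    /// ```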
    #[allow(dead_code)]
    pub(crate) fn shape_signature(shape: &[usize]) -> String {
        shape
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join("x")
    }

    /// Get pool statistics as `(hits, misses, hit_rate)` (deprecated; use `detailed_stats`)
    pub(crate) fn stats(&self) -> (usize, usize, f64) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

    /// Get detailed pool statistics
    pub(crate) fn detailed_stats(&self) -> PoolStats {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };

        let unique_shapes = self.pools.len();
        let mut total_bytes_pooled = 0;
        let mut total_buffers_pooled = 0;

        let elem_size = std::mem::size_of::<T>();

        for pool in self.pools.values() {
            total_buffers_pooled += pool.len();
            for buffer in pool {
                total_bytes_pooled += buffer.len() * elem_size;
            }
        }

        PoolStats {
            hits: self.hits,
            misses: self.misses,
            total_allocations: self.total_allocations,
            total_releases: self.total_releases,
            hit_rate,
            unique_shapes,
            total_bytes_pooled,
            total_buffers_pooled,
        }
    }

    /// Clear all pooled buffers and reset statistics
    pub(crate) fn clear(&mut self) {
        self.pools.clear();
        self.hits = 0;
        self.misses = 0;
        self.total_allocations = 0;
        self.total_releases = 0;
    }

    /// Get the number of unique shapes in the pool
    pub(crate) fn num_shapes(&self) -> usize {
        self.pools.len()
    }

    /// Get the total number of buffers in the pool
    pub(crate) fn num_buffers(&self) -> usize {
        self.pools.values().map(|v| v.len()).sum()
    }
}
/// CPU executor implementation with memory pooling and parallel execution
pub struct CpuExecutor {
    /// Memory pool for f32 tensors
    ///
    /// **Phase 2 Status**: Type-safe buffer pooling now operational.
    memory_pool_f32: MemoryPool<f32>,
    /// Memory pool for f64 tensors
    ///
    /// **Phase 2 Status**: Type-safe buffer pooling now operational.
    memory_pool_f64: MemoryPool<f64>,
    /// Number of threads to use (0 = auto-detect)
    pub num_threads: usize,
    /// Enable parallel execution for large tensors
    pub enable_parallel: bool,
    /// Enable SIMD-optimized element-wise operations
    pub enable_simd: bool,
    /// Enable tiled/blocked reductions for large tensors
    pub enable_tiled_reductions: bool,
    /// Enable vectorized broadcasting optimizations
    pub enable_vectorized_broadcast: bool,
    /// Enable memory pooling
    pub enable_memory_pool: bool,
}
impl CpuExecutor {
    /// Create a new CPU executor with default settings.
    /// All optimizations are enabled by default.
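    ///
    /// A minimal sketch of builder-style configuration (the flags shown are
    /// the builder methods defined below):
    ///
    /// ```ignore
    /// let exec = CpuExecutor::new()
    ///     .with_simd(false)
    ///     .with_tiled_reductions(true)
    ///     .with_memory_pool(true);
    /// ```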
    pub fn new() -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads: 0,
            enable_parallel: true,
            enable_simd: true,
            enable_tiled_reductions: true,
            enable_vectorized_broadcast: true,
            enable_memory_pool: true,
        }
    }
    /// Create a CPU executor with custom thread count
    pub fn with_threads(num_threads: usize) -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads,
            enable_parallel: true,
            enable_simd: true,
            enable_tiled_reductions: true,
            enable_vectorized_broadcast: true,
            enable_memory_pool: true,
        }
    }
    /// Create a single-threaded CPU executor with parallel execution and the
    /// other runtime optimizations (SIMD, tiling, broadcasting, pooling) disabled
    pub fn serial() -> Self {
        Self {
            memory_pool_f32: MemoryPool::new(),
            memory_pool_f64: MemoryPool::new(),
            num_threads: 1,
            enable_parallel: false,
            enable_simd: false,
            enable_tiled_reductions: false,
            enable_vectorized_broadcast: false,
            enable_memory_pool: false,
        }
    }

    /// Create a CPU executor with all optimizations disabled (for debugging/testing)
    pub fn unoptimized() -> Self {
        Self {
            memory_pool_f32: MemoryPool::disabled(),
            memory_pool_f64: MemoryPool::disabled(),
            num_threads: 1,
            enable_parallel: false,
            enable_simd: false,
            enable_tiled_reductions: false,
            enable_vectorized_broadcast: false,
            enable_memory_pool: false,
        }
    }

    /// Configure SIMD optimization
    pub fn with_simd(mut self, enabled: bool) -> Self {
        self.enable_simd = enabled;
        self
    }

    /// Configure tiled reductions
    pub fn with_tiled_reductions(mut self, enabled: bool) -> Self {
        self.enable_tiled_reductions = enabled;
        self
    }

    /// Configure vectorized broadcasting
    pub fn with_vectorized_broadcast(mut self, enabled: bool) -> Self {
        self.enable_vectorized_broadcast = enabled;
        self
    }

    /// Configure memory pooling
    pub fn with_memory_pool(mut self, enabled: bool) -> Self {
        self.enable_memory_pool = enabled;
        self.memory_pool_f32.set_enabled(enabled);
        self.memory_pool_f64.set_enabled(enabled);
        self
    }

    /// Get memory pool statistics for f32 tensors (hits, misses, hit_rate)
    ///
    /// **Deprecated**: Use `get_pool_stats_f32()` for detailed statistics.
    pub fn pool_stats(&self) -> (usize, usize, f64) {
        self.memory_pool_f32.stats()
    }

    /// Get detailed memory pool statistics for f32 tensors
    ///
    /// Returns comprehensive statistics about f32 memory pool usage including:
    /// - Hit/miss counts and rate
    /// - Total allocations and releases
    /// - Number of unique shapes and buffers
    /// - Total bytes currently pooled
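    ///
    /// A short usage sketch:
    ///
    /// ```ignore
    /// let stats = executor.get_pool_stats();
    /// println!("hit rate: {:.1}%", stats.hit_rate * 100.0);
    /// ```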
    pub fn get_pool_stats(&self) -> PoolStats {
        self.memory_pool_f32.detailed_stats()
    }

    /// Get detailed memory pool statistics for f32 tensors
    pub fn get_pool_stats_f32(&self) -> PoolStats {
        self.memory_pool_f32.detailed_stats()
    }

    /// Get detailed memory pool statistics for f64 tensors
    pub fn get_pool_stats_f64(&self) -> PoolStats {
        self.memory_pool_f64.detailed_stats()
    }

    /// Clear all memory pools
    ///
    /// Releases all pooled buffers and resets statistics for both f32 and f64 pools.
    pub fn clear_pool(&mut self) {
        self.memory_pool_f32.clear();
        self.memory_pool_f64.clear();
    }

    /// Check if memory pooling is enabled
    pub fn is_pool_enabled(&self) -> bool {
        self.enable_memory_pool
            && self.memory_pool_f32.is_enabled()
            && self.memory_pool_f64.is_enabled()
    }

    /// Enable or disable memory pooling at runtime
    pub fn set_pool_enabled(&mut self, enabled: bool) {
        self.enable_memory_pool = enabled;
        self.memory_pool_f32.set_enabled(enabled);
        self.memory_pool_f64.set_enabled(enabled);
    }

    /// Get the number of unique shapes in the f32 pool
    pub fn pool_num_shapes(&self) -> usize {
        self.memory_pool_f32.num_shapes()
    }

    /// Get the number of unique shapes in the f32 pool
    pub fn pool_num_shapes_f32(&self) -> usize {
        self.memory_pool_f32.num_shapes()
    }

    /// Get the number of unique shapes in the f64 pool
    pub fn pool_num_shapes_f64(&self) -> usize {
        self.memory_pool_f64.num_shapes()
    }

    /// Get the total number of buffers in the f32 pool
    pub fn pool_num_buffers(&self) -> usize {
        self.memory_pool_f32.num_buffers()
    }

    /// Get the total number of buffers in the f32 pool
    pub fn pool_num_buffers_f32(&self) -> usize {
        self.memory_pool_f32.num_buffers()
    }

    /// Get the total number of buffers in the f64 pool
    pub fn pool_num_buffers_f64(&self) -> usize {
        self.memory_pool_f64.num_buffers()
    }

    /// Acquire a buffer from the f32 pool
    ///
    /// # Phase 2 Memory Pool API
    ///
    /// This is the public API for manually acquiring buffers from the f32 pool.
    /// Useful for custom tensor operations and benchmarking.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut executor = CpuExecutor::new();
    /// let buffer = executor.acquire_f32(&[64, 64]);
    /// // Use buffer...
    /// executor.release_f32(&[64, 64], buffer);
    /// ```
    pub fn acquire_f32(&mut self, shape: &[usize]) -> Vec<f32> {
        if self.enable_memory_pool {
            self.memory_pool_f32.acquire(shape)
        } else {
            vec![0.0; shape.iter().product()]
        }
    }

    /// Release a buffer to the f32 pool
    ///
    /// # Phase 2 Memory Pool API
    ///
    /// Returns a buffer to the pool for reuse. The buffer will be available
    /// for future `acquire_f32` calls with the same shape.
    pub fn release_f32(&mut self, shape: &[usize], buffer: Vec<f32>) {
        if self.enable_memory_pool {
            self.memory_pool_f32.release(shape, buffer);
        }
    }

    /// Acquire a buffer from the f64 pool
    ///
    /// # Phase 2 Memory Pool API
    ///
    /// This is the public API for manually acquiring buffers from the f64 pool.
    /// Useful for custom tensor operations and benchmarking.
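    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut executor = CpuExecutor::new();
    /// let buffer = executor.acquire_f64(&[64, 64]);
    /// // Use buffer...
    /// executor.release_f64(&[64, 64], buffer);
    /// ```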
    pub fn acquire_f64(&mut self, shape: &[usize]) -> Vec<f64> {
        if self.enable_memory_pool {
            self.memory_pool_f64.acquire(shape)
        } else {
            vec![0.0; shape.iter().product()]
        }
    }

    /// Release a buffer to the f64 pool
    ///
    /// # Phase 2 Memory Pool API
    ///
    /// Returns a buffer to the pool for reuse. The buffer will be available
    /// for future `acquire_f64` calls with the same shape.
    pub fn release_f64(&mut self, shape: &[usize], buffer: Vec<f64>) {
        if self.enable_memory_pool {
            self.memory_pool_f64.release(shape, buffer);
        }
    }

    // ========================================================================
    // Generic Pooling Helpers (Phase 5: Automatic Pooling Integration)
    // ========================================================================

    /// Acquire a pooled buffer with automatic type dispatch
    ///
    /// This is an internal helper that automatically selects the appropriate
    /// pool based on the type T. Only f32 and f64 are pooled; other types
    /// allocate directly.
    ///
    /// **Phase 5 Status**: Automatic pooling for all operations.
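    ///
    /// A minimal sketch of the dispatch (types chosen purely for illustration):
    ///
    /// ```ignore
    /// let a: Vec<f32> = executor.acquire_pooled_generic(&[4, 4]); // pooled
    /// let b: Vec<i32> = executor.acquire_pooled_generic(&[4, 4]); // direct allocation
    /// ```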
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn acquire_pooled_generic<T>(&mut self, shape: &[usize]) -> Vec<T>
    where
        T: Clone + std::default::Default + 'static,
    {
        if !self.enable_memory_pool {
            return vec![T::default(); shape.iter().product()];
        }

        // Use type introspection to dispatch to the correct pool
        use std::any::TypeId;

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            let buffer_f32 = self.memory_pool_f32.acquire(shape);
            // SAFETY: We checked that T == f32, so transmuting Vec<f32> to Vec<T> is safe
            unsafe { std::mem::transmute::<Vec<f32>, Vec<T>>(buffer_f32) }
        } else if TypeId::of::<T>() == TypeId::of::<f64>() {
            let buffer_f64 = self.memory_pool_f64.acquire(shape);
            // SAFETY: We checked that T == f64, so transmuting Vec<f64> to Vec<T> is safe
            unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(buffer_f64) }
        } else {
            // For other types, allocate directly (no pooling)
            vec![T::default(); shape.iter().product()]
        }
    }

    /// Release a pooled buffer with automatic type dispatch
    ///
    /// This is an internal helper that automatically returns buffers to the
    /// appropriate pool based on the type T.
    ///
    /// **Phase 5 Status**: Automatic pooling for all operations.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn release_pooled_generic<T>(&mut self, shape: &[usize], buffer: Vec<T>)
    where
        T: Clone + std::default::Default + 'static,
    {
        if !self.enable_memory_pool {
            return;
        }

        use std::any::TypeId;

        if TypeId::of::<T>() == TypeId::of::<f32>() {
            // SAFETY: We checked that T == f32, so transmuting Vec<T> to Vec<f32> is safe
            let buffer_f32: Vec<f32> = unsafe { std::mem::transmute::<Vec<T>, Vec<f32>>(buffer) };
            self.memory_pool_f32.release(shape, buffer_f32);
        } else if TypeId::of::<T>() == TypeId::of::<f64>() {
            // SAFETY: We checked that T == f64, so transmuting Vec<T> to Vec<f64> is safe
            let buffer_f64: Vec<f64> = unsafe { std::mem::transmute::<Vec<T>, Vec<f64>>(buffer) };
            self.memory_pool_f64.release(shape, buffer_f64);
        }
        // For other types, the buffer is dropped (no pooling)
    }

    /// Execute a computation with a pooled buffer (RAII pattern)
    ///
    /// This helper automatically acquires and releases a pooled buffer,
    /// ensuring the buffer is returned to the pool even if an error occurs.
    ///
    /// **Phase 5 Status**: Automatic pooling for all operations.
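    ///
    /// A minimal sketch (the closure receives a mutable borrow of the pooled
    /// buffer):
    ///
    /// ```ignore
    /// let n = executor.with_pooled_buffer::<f32, _, _>(&[64], |buf| {
    ///     buf.fill(1.0);
    ///     Ok(buf.len())
    /// })?;
    /// ```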
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn with_pooled_buffer<T, F, R>(&mut self, shape: &[usize], f: F) -> Result<R>
    where
        T: Clone + std::default::Default + 'static,
        F: FnOnce(&mut Vec<T>) -> Result<R>,
    {
        let mut buffer = self.acquire_pooled_generic::<T>(shape);
        // Pass a mutable borrow so the buffer itself (not a clone, which would
        // defeat pooling) is returned to the pool afterwards, even when the
        // closure reports an error.
        let result = f(&mut buffer);
        self.release_pooled_generic::<T>(shape, buffer);
        result
    }

    // ========================================================================
    // End Generic Pooling Helpers
    // ========================================================================

    /// Execute einsum with planner integration
    pub(crate) fn execute_einsum_with_planner<T>(
        &mut self,
        spec: &EinsumSpec,
        inputs: &[DenseND<T>],
        _hints: &ExecHints,
    ) -> Result<DenseND<T>>
    where
        T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        // Fast path: a pairwise contraction needs no planning.
        if inputs.len() == 2 {
            return execute_dense_contraction(spec, &inputs[0], &inputs[1]);
        }
        let shapes: Vec<Vec<usize>> = inputs.iter().map(|t| t.shape().to_vec()).collect();
        let plan_hints = PlanHints::default();
        let plan = greedy_planner(spec, &shapes, &plan_hints)?;
        self.execute_plan(&plan, inputs)
    }
    /// Execute a binary operation with full NumPy-style broadcasting support
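    ///
    /// A short usage sketch (tensor names assumed for illustration):
    ///
    /// ```ignore
    /// // [2, 3] + [3] broadcasts to [2, 3]
    /// let z = executor.binary_op_with_broadcast(BinaryOp::Add, &x, &y)?;
    /// ```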
    pub(crate) fn binary_op_with_broadcast<T>(
        &mut self,
        op: BinaryOp,
        x: &DenseND<T>,
        y: &DenseND<T>,
    ) -> Result<TensorHandle<T>>
    where
        T: Clone
            + Num
            + std::ops::AddAssign
            + std::default::Default
            + Float
            + FromPrimitive
            + 'static,
    {
        let x_shape = x.shape();
        let y_shape = y.shape();
        let output_shape = self.broadcast_shapes(x_shape, y_shape)?;
        let x_is_scalar = x_shape.is_empty() || (x_shape.len() == 1 && x_shape[0] == 1);
        let y_is_scalar = y_shape.is_empty() || (y_shape.len() == 1 && y_shape[0] == 1);
        if x_is_scalar {
            let x_val = if x_shape.is_empty() {
                x.view()[[]]
            } else {
                x.view()[[0]]
            };
            let result_data = match op {
                BinaryOp::Add => y.view().mapv(|y_val| x_val + y_val),
                BinaryOp::Sub => y.view().mapv(|y_val| x_val - y_val),
                BinaryOp::Mul => y.view().mapv(|y_val| x_val * y_val),
                BinaryOp::Div => y.view().mapv(|y_val| x_val / y_val),
                BinaryOp::Pow => y.view().mapv(|y_val| x_val.powf(y_val)),
                BinaryOp::Maximum => y
                    .view()
                    .mapv(|y_val| if x_val > y_val { x_val } else { y_val }),
                BinaryOp::Minimum => y
                    .view()
                    .mapv(|y_val| if x_val < y_val { x_val } else { y_val }),
            };
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(
                result_data,
            )));
        }
        if y_is_scalar {
            let y_val = if y_shape.is_empty() {
                y.view()[[]]
            } else {
                y.view()[[0]]
            };
            let result_data = match op {
                BinaryOp::Add => x.view().mapv(|x_val| x_val + y_val),
                BinaryOp::Sub => x.view().mapv(|x_val| x_val - y_val),
                BinaryOp::Mul => x.view().mapv(|x_val| x_val * y_val),
                BinaryOp::Div => x.view().mapv(|x_val| x_val / y_val),
                BinaryOp::Pow => x.view().mapv(|x_val| x_val.powf(y_val)),
                BinaryOp::Maximum => x
                    .view()
                    .mapv(|x_val| if x_val > y_val { x_val } else { y_val }),
                BinaryOp::Minimum => x
                    .view()
                    .mapv(|x_val| if x_val < y_val { x_val } else { y_val }),
            };
            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(
                result_data,
            )));
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let output_size: usize = output_shape.iter().product();

        // Use pooled buffer for output allocation (Phase 5: Automatic Pooling)
        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
        output_data.clear(); // Ensure the buffer starts empty; reused buffers keep old contents

        for flat_idx in 0..output_size {
            let out_idx = self.flat_to_multidim(flat_idx, &output_shape);
            let x_idx = self.broadcast_index(&out_idx, x_shape, &output_shape);
            let y_idx = self.broadcast_index(&out_idx, y_shape, &output_shape);
            let x_val = x.view()[x_idx.as_slice()];
            let y_val = y.view()[y_idx.as_slice()];
            let result_val = match op {
                BinaryOp::Add => x_val + y_val,
                BinaryOp::Sub => x_val - y_val,
                BinaryOp::Mul => x_val * y_val,
                BinaryOp::Div => x_val / y_val,
                BinaryOp::Pow => x_val.powf(y_val),
                BinaryOp::Maximum => {
                    if x_val > y_val {
                        x_val
                    } else {
                        y_val
                    }
                }
                BinaryOp::Minimum => {
                    if x_val < y_val {
                        x_val
                    } else {
                        y_val
                    }
                }
            };
            output_data.push(result_val);
        }

        // `from_shape_vec` takes ownership of its data, so hand it a copy and
        // return the pooled buffer for reuse.
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output_data);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    /// Convert flat index to multi-dimensional index
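    ///
    /// A worked sketch (row-major):
    ///
    /// ```ignore
    /// // 5 = 1 * 3 + 2, so flat index 5 in shape [2, 3] is [1, 2]
    /// assert_eq!(exec.flat_to_multidim(5, &[2, 3]), vec![1, 2]);
    /// ```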
    pub(crate) fn flat_to_multidim(&self, flat_idx: usize, shape: &[usize]) -> Vec<usize> {
        let mut idx = Vec::with_capacity(shape.len());
        let mut remaining = flat_idx;
        for &dim_size in shape.iter().rev() {
            idx.push(remaining % dim_size);
            remaining /= dim_size;
        }
        idx.reverse();
        idx
    }
    /// Convert multi-dimensional index to flat index
    pub(crate) fn multidim_to_flat(&self, idx: &[usize], shape: &[usize]) -> usize {
        let mut flat_idx = 0;
        let mut multiplier = 1;
        for i in (0..shape.len()).rev() {
            flat_idx += idx[i] * multiplier;
            multiplier *= shape[i];
        }
        flat_idx
    }
    /// Map an output index to an input index under broadcasting
    fn broadcast_index(
        &self,
        out_idx: &[usize],
        in_shape: &[usize],
        out_shape: &[usize],
    ) -> Vec<usize> {
        let mut in_idx = Vec::with_capacity(in_shape.len());
        let ndim_diff = out_shape.len() - in_shape.len();
        for (i, &in_dim) in in_shape.iter().enumerate() {
            let out_i = i + ndim_diff;
            // Broadcast dimensions of size 1 always map to index 0
            if in_dim == 1 {
                in_idx.push(0);
            } else {
                in_idx.push(out_idx[out_i]);
            }
        }
        in_idx
    }
    /// Compute the broadcast shape of two shapes
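    ///
    /// A short sketch of the alignment rule (trailing dimensions align and
    /// size-1 dimensions stretch):
    ///
    /// ```ignore
    /// // [2, 1, 3] vs [4, 3] -> [2, 4, 3]
    /// assert_eq!(exec.broadcast_shapes(&[2, 1, 3], &[4, 3])?, vec![2, 4, 3]);
    /// ```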
    fn broadcast_shapes(&self, x_shape: &[usize], y_shape: &[usize]) -> Result<Vec<usize>> {
        let max_ndim = x_shape.len().max(y_shape.len());
        let mut result_shape = Vec::with_capacity(max_ndim);
        // Walk dimensions from the trailing end, padding the shorter shape with 1s
        for i in 0..max_ndim {
            let x_dim = if i < x_shape.len() {
                x_shape[x_shape.len() - 1 - i]
            } else {
                1
            };
            let y_dim = if i < y_shape.len() {
                y_shape[y_shape.len() - 1 - i]
            } else {
                1
            };
            if x_dim == y_dim || x_dim == 1 || y_dim == 1 {
                result_shape.push(x_dim.max(y_dim));
            } else {
                return Err(anyhow!(
                    "Shapes {:?} and {:?} are not broadcast-compatible at dimension {}",
                    x_shape,
                    y_shape,
                    i
                ));
            }
        }
        result_shape.reverse();
        Ok(result_shape)
    }
    /// Execute a multi-step contraction plan
    fn execute_plan<T>(&mut self, plan: &Plan, inputs: &[DenseND<T>]) -> Result<DenseND<T>>
    where
        T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        let mut intermediates: Vec<DenseND<T>> = inputs.to_vec();
        for (step_idx, &(i, j)) in plan.order.iter().enumerate() {
            if i >= intermediates.len() || j >= intermediates.len() {
                return Err(anyhow!(
                    "Step {}: Invalid indices ({}, {}) for {} intermediates",
                    step_idx,
                    i,
                    j,
                    intermediates.len()
                ));
            }
            let node = &plan.nodes[step_idx];
            // Remove the higher index first so the lower index stays valid.
            let (tensor_a, tensor_b) = if i < j {
                let b = intermediates.remove(j);
                let a = intermediates.remove(i);
                (a, b)
            } else {
                let a = intermediates.remove(i);
                let b = intermediates.remove(j);
                (a, b)
            };
            let spec_str = format!(
                "{},{}->{}",
                node.output_spec.input_specs[0],
                node.output_spec.input_specs[1],
                node.output_spec.output_spec
            );
            let step_spec = EinsumSpec::parse(&spec_str)?;
            let result = execute_dense_contraction(&step_spec, &tensor_a, &tensor_b)?;
            intermediates.push(result);
        }
        if intermediates.len() != 1 {
            return Err(anyhow!(
                "Expected 1 final tensor, got {}",
                intermediates.len()
            ));
        }
        Ok(intermediates.into_iter().next().unwrap())
    }
    /// Helper: Compute determinant of a 2D matrix using LU decomposition
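    ///
    /// A worked sketch of the closed-form 2x2 case handled below:
    ///
    /// ```ignore
    /// // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2
    /// ```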
    pub(crate) fn compute_determinant_2d<T2>(
        &self,
        matrix: &scirs2_core::ndarray_ext::Array2<T2>,
    ) -> Result<T2>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        let n = matrix.nrows();
        if n == 0 {
            return Ok(T2::one());
        }
        if n == 1 {
            return Ok(matrix[[0, 0]]);
        }
        if n == 2 {
            let a = matrix[[0, 0]];
            let b = matrix[[0, 1]];
            let c = matrix[[1, 0]];
            let d = matrix[[1, 1]];
            return Ok(a * d - b * c);
        }
        let mut a = matrix.clone();
        let mut det = T2::one();
        let mut sign = T2::one();
        for i in 0..n {
            // Partial pivoting: pick the row with the largest |pivot|
            let mut pivot = i;
            let mut max_val = a[[i, i]].abs();
            for k in (i + 1)..n {
                let val = a[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            // Treat a near-zero pivot column as a singular matrix
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Ok(T2::zero());
            }
            if pivot != i {
                // A row swap flips the sign of the determinant
                for j in 0..n {
                    let temp = a[[i, j]];
                    a[[i, j]] = a[[pivot, j]];
                    a[[pivot, j]] = temp;
                }
                sign = -sign;
            }
            det = det * a[[i, i]];
            // Eliminate entries below the pivot
            for k in (i + 1)..n {
                let factor = a[[k, i]] / a[[i, i]];
                for j in i..n {
                    a[[k, j]] = a[[k, j]] - factor * a[[i, j]];
                }
            }
        }
        Ok(sign * det)
    }
    /// Helper: Compute inverse of a 2D matrix using Gauss-Jordan elimination
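    ///
    /// A worked sketch (diagonal case):
    ///
    /// ```ignore
    /// // inv([[4, 0], [0, 2]]) = [[0.25, 0.0], [0.0, 0.5]]
    /// ```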
    pub(crate) fn compute_inverse_2d<T2>(
        &self,
        matrix: &scirs2_core::ndarray_ext::Array2<T2>,
    ) -> Result<scirs2_core::ndarray_ext::Array2<T2>>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        use scirs2_core::ndarray_ext::Array2;
        let n = matrix.nrows();
        if n == 0 {
            return Err(anyhow!("Cannot invert empty matrix"));
        }
        // Build the augmented matrix [A | I]
        let mut aug = Array2::zeros((n, 2 * n));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = matrix[[i, j]];
            }
            aug[[i, n + i]] = T2::one();
        }
        for i in 0..n {
            // Partial pivoting
            let mut pivot = i;
            let mut max_val = aug[[i, i]].abs();
            for k in (i + 1)..n {
                let val = aug[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Err(anyhow!("Matrix is singular and cannot be inverted"));
            }
            if pivot != i {
                for j in 0..(2 * n) {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[pivot, j]];
                    aug[[pivot, j]] = temp;
                }
            }
            // Normalize the pivot row, then eliminate the column in all other rows
            let pivot_val = aug[[i, i]];
            for j in 0..(2 * n) {
                aug[[i, j]] = aug[[i, j]] / pivot_val;
            }
            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..(2 * n) {
                        aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
                    }
                }
            }
        }
        // The right half of the augmented matrix now holds the inverse
        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, n + j]];
            }
        }
        Ok(inv)
    }
    /// Helper: Solve linear system Ax = b using LU decomposition
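    ///
    /// A worked sketch:
    ///
    /// ```ignore
    /// // A = [[2, 0], [0, 4]], b = [2, 8]  =>  x = [1, 2]
    /// ```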
    pub(crate) fn solve_2d_1d<T2>(
        &self,
        a: &scirs2_core::ndarray_ext::Array2<T2>,
        b: &scirs2_core::ndarray_ext::Array1<T2>,
    ) -> Result<scirs2_core::ndarray_ext::Array1<T2>>
    where
        T2: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive,
    {
        use scirs2_core::ndarray_ext::Array1;
        let n = a.nrows();
        if n != b.len() {
            return Err(anyhow!("Dimension mismatch in solve"));
        }
        let mut a_work = a.clone();
        let mut b_work = b.clone();
        // Forward elimination with partial pivoting
        for i in 0..n {
            let mut pivot = i;
            let mut max_val = a_work[[i, i]].abs();
            for k in (i + 1)..n {
                let val = a_work[[k, i]].abs();
                if val > max_val {
                    max_val = val;
                    pivot = k;
                }
            }
            if max_val < T2::from_f64(1e-10).unwrap() {
                return Err(anyhow!("Matrix is singular, cannot solve"));
            }
            if pivot != i {
                for j in 0..n {
                    let temp = a_work[[i, j]];
                    a_work[[i, j]] = a_work[[pivot, j]];
                    a_work[[pivot, j]] = temp;
                }
                let temp = b_work[i];
                b_work[i] = b_work[pivot];
                b_work[pivot] = temp;
            }
            for k in (i + 1)..n {
                let factor = a_work[[k, i]] / a_work[[i, i]];
                for j in i..n {
                    a_work[[k, j]] = a_work[[k, j]] - factor * a_work[[i, j]];
                }
                b_work[k] = b_work[k] - factor * b_work[i];
            }
        }
        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = b_work[i];
            for j in (i + 1)..n {
                sum = sum - a_work[[i, j]] * x[j];
            }
            x[i] = sum / a_work[[i, i]];
        }
        Ok(x)
    }
}
/// Element-wise operation types
#[derive(Clone, Debug)]
pub enum ElemOp {
    /// Negation: -x
    Neg,
    /// Absolute value: |x|
    Abs,
    /// Exponential: e^x
    Exp,
    /// Natural logarithm: ln(x)
    Log,
    /// Sine: sin(x)
    Sin,
    /// Cosine: cos(x)
    Cos,
    /// Square root: sqrt(x)
    Sqrt,
    /// Square: x^2
    Sqr,
    /// Reciprocal: 1/x
    Recip,
    /// Hyperbolic tangent: tanh(x)
    Tanh,
    /// Sigmoid: 1 / (1 + e^(-x))
    Sigmoid,
    /// Rectified Linear Unit: max(0, x)
    ReLU,
    /// Gaussian Error Linear Unit: x * Φ(x) where Φ is the CDF of the standard normal
    /// Approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    Gelu,
    /// Exponential Linear Unit: x if x > 0, else e^x - 1
    Elu,
    /// Scaled Exponential Linear Unit: scale * (x if x > 0, else alpha * (e^x - 1))
    /// where scale ≈ 1.0507, alpha ≈ 1.67326
    Selu,
    /// Softplus: ln(1 + e^x)
    Softplus,
    /// Sign function: -1 if x < 0, 0 if x == 0, 1 if x > 0
    Sign,
}