tenrso_exec/executor/cpuexecutor_traits.rs

//! # CpuExecutor - Trait Implementations
//!
//! This module contains trait implementations for `CpuExecutor`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `TenrsoExecutor`
//!
//! ## File Size Note
//!
//! This file is 2,152 lines (7.6% over the 2000-line policy limit).
//! It contains a single, coherent `impl TenrsoExecutor<T> for CpuExecutor` block
//! with 42 trait methods. Splitting this impl block across multiple files would
//! require significant architectural changes (delegation pattern, helper traits, etc.)
//! and would compromise code cohesion. The methods are already organized by category
//! and delegate to specialized modules where appropriate (simd_ops, tiled_reductions,
//! advanced_indexing, etc.).
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
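//!
//! ## Example
//!
//! A minimal usage sketch (not compiled; the tensor-construction helpers and
//! re-export paths shown here are assumptions based on the types used below):
//!
//! ```ignore
//! use scirs2_core::ndarray_ext::{Array, IxDyn};
//! use tenrso_core::{DenseND, TensorHandle};
//!
//! // Build a small dense tensor and apply an element-wise ReLU.
//! let data = Array::from_shape_vec(IxDyn(&[4]), vec![-1.0_f64, 0.0, 2.0, -3.0])?;
//! let x = TensorHandle::from_dense_auto(DenseND::from_array(data));
//! let mut exec = CpuExecutor::new();
//! let y = exec.elem_op(ElemOp::ReLU, &x)?;
//! ```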

use super::functions::*;
use super::types::*;
use crate::hints::ExecHints;
use anyhow::{anyhow, Result};
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use tenrso_core::{Axis, DenseND, TensorHandle};
use tenrso_planner::EinsumSpec;

impl Default for CpuExecutor {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> TenrsoExecutor<T> for CpuExecutor
where
    T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
{
    fn einsum(
        &mut self,
        spec: &str,
        inputs: &[TensorHandle<T>],
        hints: &ExecHints,
    ) -> Result<TensorHandle<T>> {
        let parsed_spec = EinsumSpec::parse(spec)?;
        if parsed_spec.num_inputs() != inputs.len() {
            return Err(anyhow!(
                "Spec expects {} inputs, got {}",
                parsed_spec.num_inputs(),
                inputs.len()
            ));
        }
        let dense_inputs: Vec<&DenseND<T>> = inputs
            .iter()
            .map(|h| {
                h.as_dense()
                    .ok_or_else(|| anyhow!("Only dense tensors supported for now"))
            })
            .collect::<Result<Vec<_>>>()?;
        let dense_inputs_owned: Vec<DenseND<T>> = dense_inputs.iter().map(|&t| t.clone()).collect();
        let result = self.execute_einsum_with_planner(&parsed_spec, &dense_inputs_owned, hints)?;
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn elem_op(&mut self, op: ElemOp, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for elem_op"))?;
        let result_data = match op {
            ElemOp::Neg => dense.view().mapv(|v| -v),
            ElemOp::Abs => dense.view().mapv(|v| v.abs()),
            ElemOp::Exp => dense.view().mapv(|v| v.exp()),
            ElemOp::Log => dense.view().mapv(|v| v.ln()),
            ElemOp::Sin => dense.view().mapv(|v| v.sin()),
            ElemOp::Cos => dense.view().mapv(|v| v.cos()),
            ElemOp::Sqrt => dense.view().mapv(|v| v.sqrt()),
            ElemOp::Sqr => dense.view().mapv(|v| v * v),
            ElemOp::Recip => dense.view().mapv(|v| v.recip()),
            ElemOp::Tanh => dense.view().mapv(|v| v.tanh()),
            ElemOp::Sigmoid => dense.view().mapv(|v| {
                let one = T::one();
                one / (one + (-v).exp())
            }),
            ElemOp::ReLU => dense.view().mapv(|v| {
                let zero = T::zero();
                if v > zero {
                    v
                } else {
                    zero
                }
            }),
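            // GELU, tanh approximation:
            // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
            // where sqrt(2/pi) ~= 0.7978845608 (the `coeff` constant below).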
            ElemOp::Gelu => dense.view().mapv(|v| {
                let half = T::from_f64(0.5).unwrap_or_else(T::one);
                let one = T::one();
                let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
                let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
                let x_cubed = v * v * v;
                let inner = coeff * (v + cubic_coeff * x_cubed);
                half * v * (one + inner.tanh())
            }),
            ElemOp::Elu => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                if v > zero {
                    v
                } else {
                    v.exp() - one
                }
            }),
            ElemOp::Selu => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
                let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
                if v > zero {
                    scale * v
                } else {
                    scale * alpha * (v.exp() - one)
                }
            }),
            ElemOp::Softplus => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let abs_v = v.abs();
                let max_part = if v > zero { v } else { zero };
                max_part + (one + (-abs_v).exp()).ln()
            }),
            ElemOp::Sign => dense.view().mapv(|v| {
                let zero = T::zero();
                let one = T::one();
                let neg_one = -one;
                if v > zero {
                    one
                } else if v < zero {
                    neg_one
                } else {
                    zero
                }
            }),
        };
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn binary_op(
        &mut self,
        op: BinaryOp,
        x: &TensorHandle<T>,
        y: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for binary_op"))?;
        let dense_y = y
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for binary_op"))?;
        if dense_x.shape() != dense_y.shape() {
            return self.binary_op_with_broadcast(op, dense_x, dense_y);
        }
        use scirs2_core::ndarray_ext::Zip;
        let result_data = match op {
            BinaryOp::Add => &dense_x.view() + &dense_y.view(),
            BinaryOp::Sub => &dense_x.view() - &dense_y.view(),
            BinaryOp::Mul => &dense_x.view() * &dense_y.view(),
            BinaryOp::Div => &dense_x.view() / &dense_y.view(),
            BinaryOp::Pow => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = x_val.powf(y_val);
                    });
                result
            }
            BinaryOp::Maximum => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = if x_val > y_val { x_val } else { y_val };
                    });
                result
            }
            BinaryOp::Minimum => {
                let mut result = dense_x.view().to_owned();
                Zip::from(&mut result)
                    .and(&dense_x.view())
                    .and(&dense_y.view())
                    .for_each(|r, &x_val, &y_val| {
                        *r = if x_val < y_val { x_val } else { y_val };
                    });
                result
            }
        };
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
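    // Reduction is applied one axis at a time. Axes are processed in descending
    // index order so that collapsing one axis does not shift the indices of the
    // axes that still remain to be reduced.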
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &TensorHandle<T>,
        axes: &[Axis],
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for reduce"))?;
        if axes.is_empty() {
            return Err(anyhow!("No axes specified for reduction"));
        }
        let axis_indices: Vec<usize> = axes.to_vec();
        let ndim = dense.shape().len();
        for &axis_idx in &axis_indices {
            if axis_idx >= ndim {
                return Err(anyhow!(
                    "Axis index {} out of range for tensor with {} dimensions",
                    axis_idx,
                    ndim
                ));
            }
        }
        let mut result = dense.view().to_owned();
        let mut sorted_axes = axis_indices.clone();
        sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
        for &axis_idx in &sorted_axes {
            let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
            result = match op {
                ReduceOp::Sum => result.sum_axis(axis),
                ReduceOp::Max => result.map_axis(axis, |view| {
                    view.iter()
                        .cloned()
                        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .unwrap_or_else(T::default)
                }),
                ReduceOp::Min => result.map_axis(axis, |view| {
                    view.iter()
                        .cloned()
                        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .unwrap_or_else(T::default)
                }),
                ReduceOp::Mean => result.mean_axis(axis).ok_or_else(|| {
                    anyhow!(
                        "Mean reduction failed - axis might be empty or type doesn't support division"
                    )
                })?,
                ReduceOp::Prod => {
                    result.map_axis(axis, |view| view.iter().cloned().fold(T::one(), |acc, x| acc * x))
                }
                ReduceOp::All => result.map_axis(axis, |view| {
                    let all_nonzero = view.iter().all(|&x| x != T::zero());
                    if all_nonzero { T::one() } else { T::zero() }
                }),
                ReduceOp::Any => result.map_axis(axis, |view| {
                    let any_nonzero = view.iter().any(|&x| x != T::zero());
                    if any_nonzero { T::one() } else { T::zero() }
                }),
                ReduceOp::ArgMax | ReduceOp::ArgMin => {
                    return Err(anyhow!(
                        "ArgMax and ArgMin should use dedicated argmax/argmin methods, not reduce"
                    ));
                }
            };
        }
        let result_tensor = DenseND::from_array(result);
        Ok(TensorHandle::from_dense_auto(result_tensor))
    }
    fn clip(&mut self, x: &TensorHandle<T>, min_val: T, max_val: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for clip"))?;
        if min_val > max_val {
            return Err(anyhow!("Invalid clip bounds: min_val > max_val"));
        }
        let result_data = dense.view().mapv(|v| {
            if v < min_val {
                min_val
            } else if v > max_val {
                max_val
            } else {
                v
            }
        });
        let result = DenseND::from_array(result_data);
        Ok(TensorHandle::from_dense_auto(result))
    }
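    // Softmax subtracts the per-lane maximum before exponentiating so exp()
    // cannot overflow; the shift cancels out after normalization.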
    fn softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for softmax"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!(
                "Axis {} out of range for tensor with {} dimensions",
                axis,
                ndim
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let axis_obj = scirs2_core::ndarray_ext::Axis(axis);
        let max_vals = dense.view().map_axis(axis_obj, |view| {
            view.iter()
                .cloned()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::zero)
        });
        let mut exp_vals = dense.view().to_owned();
        Zip::from(exp_vals.lanes_mut(axis_obj))
            .and(max_vals.view())
            .for_each(|mut lane, &max_val| {
                lane.mapv_inplace(|v| (v - max_val).exp());
            });
        let sum_exp = exp_vals.sum_axis(axis_obj);
        let mut result = exp_vals;
        Zip::from(result.lanes_mut(axis_obj))
            .and(sum_exp.view())
            .for_each(|mut lane, &sum_val| {
                lane.mapv_inplace(|v| v / sum_val);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
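    // Stable log-softmax via log-sum-exp:
    // log_softmax(x) = x - max(x) - ln(sum(exp(x - max(x)))),
    // which avoids taking ln() of underflowed softmax values.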
    fn log_softmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for log_softmax"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!(
                "Axis {} out of range for tensor with {} dimensions",
                axis,
                ndim
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let axis_obj = scirs2_core::ndarray_ext::Axis(axis);
        let max_vals = dense.view().map_axis(axis_obj, |view| {
            view.iter()
                .cloned()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::zero)
        });
        let mut exp_vals = dense.view().to_owned();
        Zip::from(exp_vals.lanes_mut(axis_obj))
            .and(max_vals.view())
            .for_each(|mut lane, &max_val| {
                lane.mapv_inplace(|v| (v - max_val).exp());
            });
        let sum_exp = exp_vals.sum_axis(axis_obj);
        let log_sum_exp = sum_exp.mapv(|v| v.ln());
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(axis_obj))
            .and(max_vals.view())
            .and(log_sum_exp.view())
            .for_each(|mut lane, &max_val, &lse| {
                lane.mapv_inplace(|v| v - max_val - lse);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn transpose(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for transpose"))?;
        let ndim = dense.shape().len();
        if axes.len() != ndim {
            return Err(anyhow!(
                "Axes length ({}) must match tensor dimensionality ({})",
                axes.len(),
                ndim
            ));
        }
        let mut seen = vec![false; ndim];
        for &axis in axes {
            if axis >= ndim {
                return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
            }
            if seen[axis] {
                return Err(anyhow!("Duplicate axis {} in permutation", axis));
            }
            seen[axis] = true;
        }
        let permuted = dense.view().permuted_axes(axes);
        let result = DenseND::from_array(permuted.to_owned());
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn reshape(&mut self, x: &TensorHandle<T>, new_shape: &[usize]) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for reshape"))?;
        let old_size: usize = dense.shape().iter().product();
        let new_size: usize = new_shape.iter().product();
        if old_size != new_size {
            return Err(anyhow!(
                "Cannot reshape tensor of size {} to size {}",
                old_size,
                new_size
            ));
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let data: Vec<T> = dense.view().iter().cloned().collect();
        let reshaped = Array::from_shape_vec(IxDyn(new_shape), data)
            .map_err(|e| anyhow!("Reshape failed: {}", e))?;
        let result = DenseND::from_array(reshaped);
        Ok(TensorHandle::from_dense_auto(result))
    }
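    // Concatenation walks every output element, maps its coordinate along `axis`
    // back to the source tensor that contributes it, and copies the value over.
    // The output shape matches the inputs except along `axis`, where sizes add up.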
    fn concatenate(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>> {
        if tensors.is_empty() {
            return Err(anyhow!("Cannot concatenate empty tensor list"));
        }
        let dense_tensors: Vec<&DenseND<T>> = tensors
            .iter()
            .map(|t| {
                t.as_dense()
                    .ok_or_else(|| anyhow!("Only dense tensors supported for concatenate"))
            })
            .collect::<Result<Vec<_>>>()?;
        let ndim = dense_tensors[0].shape().len();
        for t in dense_tensors.iter().skip(1) {
            if t.shape().len() != ndim {
                return Err(anyhow!(
                    "All tensors must have same number of dimensions, got {} and {}",
                    ndim,
                    t.shape().len()
                ));
            }
        }
        if axis >= ndim {
            return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
        }
        for dim in 0..ndim {
            if dim != axis {
                let expected_size = dense_tensors[0].shape()[dim];
                for (i, t) in dense_tensors.iter().enumerate().skip(1) {
                    if t.shape()[dim] != expected_size {
                        return Err(anyhow!(
                            "Dimension {} mismatch: tensor 0 has size {}, tensor {} has size {}",
                            dim,
                            expected_size,
                            i,
                            t.shape()[dim]
                        ));
                    }
                }
            }
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let mut output_shape = dense_tensors[0].shape().to_vec();
        for t in dense_tensors.iter().skip(1) {
            output_shape[axis] += t.shape()[axis];
        }
        let output_size: usize = output_shape.iter().product();

        // Use pooled buffer for concatenate output (Phase 5: Automatic Pooling)
        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
        output_data.clear(); // Ensure buffer starts empty
        output_data.reserve(output_size);

        for flat_idx in 0..output_size {
            let out_idx = self.flat_to_multidim(flat_idx, &output_shape);
            let mut cumulative_axis_size = 0;
            let mut tensor_idx = 0;
            let mut local_axis_pos = out_idx[axis];
            for (i, t) in dense_tensors.iter().enumerate() {
                let t_axis_size = t.shape()[axis];
                if local_axis_pos < cumulative_axis_size + t_axis_size {
                    tensor_idx = i;
                    local_axis_pos -= cumulative_axis_size;
                    break;
                }
                cumulative_axis_size += t_axis_size;
            }
            let mut src_idx = out_idx.clone();
            src_idx[axis] = local_axis_pos;
            let val = dense_tensors[tensor_idx].view()[src_idx.as_slice()];
            output_data.push(val);
        }

        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
            .map_err(|e| anyhow!("Concatenation failed: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output_data);
        let result = DenseND::from_array(result_array);
        Ok(TensorHandle::from_dense_auto(result))
    }
    fn split(
        &mut self,
        x: &TensorHandle<T>,
        num_splits: usize,
        axis: Axis,
    ) -> Result<Vec<TensorHandle<T>>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for split"))?;
        let ndim = dense.shape().len();
        if axis >= ndim {
            return Err(anyhow!("Axis {} out of range for {}D tensor", axis, ndim));
        }
        let axis_size = dense.shape()[axis];
        if axis_size % num_splits != 0 {
            return Err(anyhow!(
                "Cannot split axis of size {} into {} equal parts",
                axis_size,
                num_splits
            ));
        }
        let split_size = axis_size / num_splits;
        let mut results = Vec::with_capacity(num_splits);
        use scirs2_core::ndarray_ext::Axis as NdAxis;
        for i in 0..num_splits {
            let start = i * split_size;
            let end = start + split_size;
            let sliced = dense
                .view()
                .slice_axis(NdAxis(axis), (start..end).into())
                .to_owned();
            results.push(TensorHandle::from_dense_auto(DenseND::from_array(sliced)));
        }
        Ok(results)
    }
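    // Layer norm over the last axis: y = (x - mean) / sqrt(var + eps), with the
    // mean and (biased) variance computed per lane along that axis. No learned
    // scale/shift parameters are applied here.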
    fn layer_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for layer_norm"))?;
        let ndim = dense.shape().len();
        if ndim == 0 {
            return Err(anyhow!("Cannot normalize scalar tensor"));
        }
        let last_axis = ndim - 1;
        use scirs2_core::ndarray_ext::{Axis as NdAxis, Zip};
        let mean = dense.view().mean_axis(NdAxis(last_axis)).ok_or_else(|| {
            anyhow!(
                "Mean computation failed - axis might be empty or type doesn't support division"
            )
        })?;
        let mut variance = dense.view().to_owned();
        Zip::from(variance.lanes_mut(NdAxis(last_axis)))
            .and(mean.view())
            .for_each(|mut lane, &m| {
                lane.mapv_inplace(|v| {
                    let diff = v - m;
                    diff * diff
                });
            });
        let variance = variance
            .mean_axis(NdAxis(last_axis))
            .ok_or_else(|| anyhow!("Variance computation failed"))?;
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(NdAxis(last_axis)))
            .and(mean.view())
            .and(variance.view())
            .for_each(|mut lane, &m, &v| {
                let std = (v + eps).sqrt();
                lane.mapv_inplace(|x_val| (x_val - m) / std);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn batch_norm(&mut self, x: &TensorHandle<T>, eps: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for batch_norm"))?;
        let ndim = dense.shape().len();
        if ndim == 0 {
            return Err(anyhow!("Cannot normalize scalar tensor"));
        }
        let batch_axis = 0;
        use scirs2_core::ndarray_ext::{Axis as NdAxis, Zip};
        let mean = dense.view().mean_axis(NdAxis(batch_axis)).ok_or_else(|| {
            anyhow!(
                "Mean computation failed - axis might be empty or type doesn't support division"
            )
        })?;
        let mut variance = dense.view().to_owned();
        Zip::from(variance.lanes_mut(NdAxis(batch_axis)))
            .and(mean.view())
            .for_each(|mut lane, &m| {
                lane.mapv_inplace(|v| {
                    let diff = v - m;
                    diff * diff
                });
            });
        let variance = variance
            .mean_axis(NdAxis(batch_axis))
            .ok_or_else(|| anyhow!("Variance computation failed"))?;
        let mut result = dense.view().to_owned();
        Zip::from(result.lanes_mut(NdAxis(batch_axis)))
            .and(mean.view())
            .and(variance.view())
            .for_each(|mut lane, &m, &v| {
                let std = (v + eps).sqrt();
                lane.mapv_inplace(|x_val| (x_val - m) / std);
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn where_op(
        &mut self,
        condition: &TensorHandle<T>,
        x: &TensorHandle<T>,
        y: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let cond_dense = condition
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        let x_dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        let y_dense = y
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for where_op"))?;
        if cond_dense.shape() != x_dense.shape() || x_dense.shape() != y_dense.shape() {
            return Err(anyhow!(
                "Shape mismatch: condition={:?}, x={:?}, y={:?}",
                cond_dense.shape(),
                x_dense.shape(),
                y_dense.shape()
            ));
        }
        use scirs2_core::ndarray_ext::Zip;
        let mut result = x_dense.view().to_owned();
        Zip::from(&mut result)
            .and(&cond_dense.view())
            .and(&x_dense.view())
            .and(&y_dense.view())
            .for_each(|r, &c, &x_val, &y_val| {
                *r = if c > T::zero() { x_val } else { y_val };
            });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(result)))
    }
    fn masked_select(
        &mut self,
        x: &TensorHandle<T>,
        mask: &TensorHandle<T>,
    ) -> Result<TensorHandle<T>> {
        let x_dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for masked_select"))?;
        let mask_dense = mask
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for masked_select"))?;
        if x_dense.shape() != mask_dense.shape() {
            return Err(anyhow!(
                "Shape mismatch: x={:?}, mask={:?}",
                x_dense.shape(),
                mask_dense.shape()
            ));
        }
        let mut selected = Vec::new();
        for (x_val, mask_val) in x_dense.view().iter().zip(mask_dense.view().iter()) {
            if *mask_val > T::zero() {
                selected.push(*x_val);
            }
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[selected.len()]), selected)
            .map_err(|e| anyhow!("Failed to create result array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn modulo(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for modulo"))?;
        if divisor == T::zero() {
            return Err(anyhow!("Division by zero in modulo operation"));
        }
        let result_data = dense.view().mapv(|v| {
            let quot = (v / divisor).floor();
            v - quot * divisor
        });
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_data,
        )))
    }
    fn remainder(&mut self, x: &TensorHandle<T>, divisor: T) -> Result<TensorHandle<T>> {
        self.modulo(x, divisor)
    }
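    // Pooling uses no implicit padding: output_len = (input_len - kernel_size) / stride + 1,
    // so trailing elements that do not fill a complete window are dropped.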
    fn max_pool_1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: usize,
        stride: usize,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for max_pool_1d"))?;
        if kernel_size == 0 || stride == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 1 {
            return Err(anyhow!(
                "Expected 1D tensor for max_pool_1d, got {:?}D",
                shape.len()
            ));
        }
        let input_len = shape[0];
        if kernel_size > input_len {
            return Err(anyhow!(
                "Kernel size {} larger than input length {}",
                kernel_size,
                input_len
            ));
        }
        let output_len = (input_len - kernel_size) / stride + 1;
        let mut output = Vec::with_capacity(output_len);
        let view = dense.view();
        for i in 0..output_len {
            let start = i * stride;
            let end = start + kernel_size;
            let max_val = (start..end)
                .map(|j| view[[j]])
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or_else(T::default);
            output.push(max_val);
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[output_len]), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn avg_pool_1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: usize,
        stride: usize,
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for avg_pool_1d"))?;
        if kernel_size == 0 || stride == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 1 {
            return Err(anyhow!(
                "Expected 1D tensor for avg_pool_1d, got {:?}D",
                shape.len()
            ));
        }
        let input_len = shape[0];
        if kernel_size > input_len {
            return Err(anyhow!(
                "Kernel size {} larger than input length {}",
                kernel_size,
                input_len
            ));
        }
        let output_len = (input_len - kernel_size) / stride + 1;
        let mut output = Vec::with_capacity(output_len);
        let view = dense.view();
        let kernel_size_t = T::from_usize(kernel_size).unwrap_or_else(T::one);
        for i in 0..output_len {
            let start = i * stride;
            let end = start + kernel_size;
            let mut sum = T::zero();
            for j in start..end {
                sum += view[[j]];
            }
            let avg = sum / kernel_size_t;
            output.push(avg);
        }
        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&[output_len]), output)
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn max_pool_2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for max_pool_2d"))?;
        let (kh, kw) = kernel_size;
        let (sh, sw) = stride;
        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 2 {
            return Err(anyhow!(
                "Expected 2D tensor for max_pool_2d, got {:?}D",
                shape.len()
            ));
        }
        let (h, w) = (shape[0], shape[1]);
        if kh > h || kw > w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than input ({}, {})",
                kh,
                kw,
                h,
                w
            ));
        }
        let out_h = (h - kh) / sh + 1;
        let out_w = (w - kw) / sw + 1;
        let output_shape = [out_h, out_w];

        // Use pooled buffer for max_pool_2d output (Phase 5: Automatic Pooling)
        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear(); // Ensure buffer starts empty
        output.reserve(out_h * out_w);

        let view = dense.view();
        for i in 0..out_h {
            for j in 0..out_w {
                let start_h = i * sh;
                let start_w = j * sw;
                let mut max_val = T::default();
                let mut first = true;
                for di in 0..kh {
                    for dj in 0..kw {
                        let val = view[[start_h + di, start_w + dj]];
                        if first || val > max_val {
                            max_val = val;
                            first = false;
                        }
                    }
                }
                output.push(max_val);
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn avg_pool_2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for avg_pool_2d"))?;
        let (kh, kw) = kernel_size;
        let (sh, sw) = stride;
        if kh == 0 || kw == 0 || sh == 0 || sw == 0 {
            return Err(anyhow!("Kernel size and stride must be positive"));
        }
        let shape = dense.shape();
        if shape.len() != 2 {
            return Err(anyhow!(
                "Expected 2D tensor for avg_pool_2d, got {:?}D",
                shape.len()
            ));
        }
        let (h, w) = (shape[0], shape[1]);
        if kh > h || kw > w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than input ({}, {})",
                kh,
                kw,
                h,
                w
            ));
        }
        let out_h = (h - kh) / sh + 1;
        let out_w = (w - kw) / sw + 1;
        let output_shape = [out_h, out_w];

        // Use pooled buffer for avg_pool_2d output (Phase 5: Automatic Pooling)
        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear(); // Ensure buffer starts empty
        output.reserve(out_h * out_w);

        let view = dense.view();
        let kernel_count = T::from_usize(kh * kw).unwrap_or_else(T::one);
        for i in 0..out_h {
            for j in 0..out_w {
                let start_h = i * sh;
                let start_w = j * sw;
                let mut sum = T::zero();
                for di in 0..kh {
                    for dj in 0..kw {
                        sum += view[[start_h + di, start_w + dj]];
                    }
                }
                let avg = sum / kernel_count;
                output.push(avg);
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
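    // Direct (naive) convolution: nested loops over batch, output channels,
    // output positions, input channels, and kernel taps. Zero padding is handled
    // by skipping out-of-range input positions, and the output length follows
    // (padded_length - kernel_size) / stride + 1.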
    fn conv1d(
        &mut self,
        x: &TensorHandle<T>,
        kernel: &TensorHandle<T>,
        bias: Option<&TensorHandle<T>>,
        stride: usize,
        padding: (usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv1d"))?;
        let dense_kernel = kernel
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv1d kernel"))?;
        if stride == 0 {
            return Err(anyhow!("Stride must be positive"));
        }
        let x_shape = dense_x.shape();
        let k_shape = dense_kernel.shape();
        if x_shape.len() != 3 {
            return Err(anyhow!(
                "Expected 3D input tensor [batch, in_channels, length], got {:?}D",
                x_shape.len()
            ));
        }
        if k_shape.len() != 3 {
            return Err(anyhow!(
                "Expected 3D kernel tensor [out_channels, in_channels, kernel_size], got {:?}D",
                k_shape.len()
            ));
        }
        let (batch, in_channels, in_length) = (x_shape[0], x_shape[1], x_shape[2]);
        let (out_channels, k_in_channels, kernel_size) = (k_shape[0], k_shape[1], k_shape[2]);
        if in_channels != k_in_channels {
            return Err(anyhow!(
                "Input channels mismatch: input has {}, kernel expects {}",
                in_channels,
                k_in_channels
            ));
        }
        if let Some(bias_tensor) = bias {
            let bias_dense = bias_tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
            let bias_shape = bias_dense.shape();
            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
                return Err(anyhow!(
                    "Expected bias shape [{}], got {:?}",
                    out_channels,
                    bias_shape
                ));
            }
        }
        let (pad_left, pad_right) = padding;
        let padded_length = in_length + pad_left + pad_right;
        if kernel_size > padded_length {
            return Err(anyhow!(
                "Kernel size {} larger than padded input length {}",
                kernel_size,
                padded_length
            ));
        }
        let out_length = (padded_length - kernel_size) / stride + 1;
        let output_shape = [batch, out_channels, out_length];

        // Use pooled buffer for output allocation (Phase 5: Automatic Pooling)
        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear(); // Ensure buffer starts empty
        output.resize(batch * out_channels * out_length, T::zero());

        let x_view = dense_x.view();
        let k_view = dense_kernel.view();
        for b in 0..batch {
            for oc in 0..out_channels {
                for o in 0..out_length {
                    let mut sum = T::zero();
                    let in_start = (o * stride) as isize - pad_left as isize;
                    for ic in 0..in_channels {
                        for k in 0..kernel_size {
                            let in_pos = in_start + k as isize;
                            if in_pos >= 0 && (in_pos as usize) < in_length {
                                let x_val = x_view[[b, ic, in_pos as usize]];
                                let k_val = k_view[[oc, ic, k]];
                                sum += x_val * k_val;
                            }
                        }
                    }
                    if let Some(bias_tensor) = bias {
                        let bias_dense = bias_tensor.as_dense().unwrap();
                        let bias_view = bias_dense.view();
                        sum += bias_view[[oc]];
                    }
                    output[b * out_channels * out_length + oc * out_length + o] = sum;
                }
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
    fn conv2d(
        &mut self,
        x: &TensorHandle<T>,
        kernel: &TensorHandle<T>,
        bias: Option<&TensorHandle<T>>,
        stride: (usize, usize),
        padding: (usize, usize, usize, usize),
    ) -> Result<TensorHandle<T>> {
        let dense_x = x
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv2d"))?;
        let dense_kernel = kernel
            .as_dense()
            .ok_or_else(|| anyhow!("Only dense tensors supported for conv2d kernel"))?;
        let (stride_h, stride_w) = stride;
        if stride_h == 0 || stride_w == 0 {
            return Err(anyhow!("Stride must be positive"));
        }
        let x_shape = dense_x.shape();
        let k_shape = dense_kernel.shape();
        if x_shape.len() != 4 {
            return Err(anyhow!(
                "Expected 4D input tensor [batch, in_channels, height, width], got {:?}D",
                x_shape.len()
            ));
        }
        if k_shape.len() != 4 {
            return Err(
                anyhow!(
                    "Expected 4D kernel tensor [out_channels, in_channels, kernel_h, kernel_w], got {:?}D",
                    k_shape.len()
                ),
            );
        }
        let (batch, in_channels, in_h, in_w) = (x_shape[0], x_shape[1], x_shape[2], x_shape[3]);
        let (out_channels, k_in_channels, kernel_h, kernel_w) =
            (k_shape[0], k_shape[1], k_shape[2], k_shape[3]);
        if in_channels != k_in_channels {
            return Err(anyhow!(
                "Input channels mismatch: input has {}, kernel expects {}",
                in_channels,
                k_in_channels
            ));
        }
        if let Some(bias_tensor) = bias {
            let bias_dense = bias_tensor
                .as_dense()
                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
            let bias_shape = bias_dense.shape();
            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
                return Err(anyhow!(
                    "Expected bias shape [{}], got {:?}",
                    out_channels,
                    bias_shape
                ));
            }
        }
        let (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right) = padding;
        let padded_h = in_h + pad_h_top + pad_h_bottom;
        let padded_w = in_w + pad_w_left + pad_w_right;
        if kernel_h > padded_h || kernel_w > padded_w {
            return Err(anyhow!(
                "Kernel size ({}, {}) larger than padded input ({}, {})",
                kernel_h,
                kernel_w,
                padded_h,
                padded_w
            ));
        }
        let out_h = (padded_h - kernel_h) / stride_h + 1;
        let out_w = (padded_w - kernel_w) / stride_w + 1;
        let output_shape = [batch, out_channels, out_h, out_w];

        // Use pooled buffer for output allocation (Phase 5: Automatic Pooling)
        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
        output.clear(); // Ensure buffer starts empty
        output.resize(batch * out_channels * out_h * out_w, T::zero());

        let x_view = dense_x.view();
        let k_view = dense_kernel.view();
        for b in 0..batch {
            for oc in 0..out_channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = T::zero();
                        let in_start_h = (oh * stride_h) as isize - pad_h_top as isize;
                        let in_start_w = (ow * stride_w) as isize - pad_w_left as isize;
                        for ic in 0..in_channels {
                            for kh in 0..kernel_h {
                                for kw in 0..kernel_w {
                                    let in_h_pos = in_start_h + kh as isize;
                                    let in_w_pos = in_start_w + kw as isize;
                                    if in_h_pos >= 0
                                        && (in_h_pos as usize) < in_h
                                        && in_w_pos >= 0
                                        && (in_w_pos as usize) < in_w
                                    {
                                        let x_val =
                                            x_view[[b, ic, in_h_pos as usize, in_w_pos as usize]];
                                        let k_val = k_view[[oc, ic, kh, kw]];
                                        sum += x_val * k_val;
                                    }
                                }
                            }
                        }
                        if let Some(bias_tensor) = bias {
                            let bias_dense = bias_tensor.as_dense().unwrap();
                            let bias_view = bias_dense.view();
                            sum += bias_view[[oc]];
                        }
                        let out_idx = ((b * out_channels + oc) * out_h + oh) * out_w + ow;
                        output[out_idx] = sum;
                    }
                }
            }
        }

        use scirs2_core::ndarray_ext::{Array, IxDyn};
        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
        self.release_pooled_generic::<T>(&output_shape, output);
        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
            result_array,
        )))
    }
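    // Gather example (axis = 0, 1D indices): for x of shape [4, 3] and
    // indices [2, 0], the result has shape [2, 3] and holds rows 2 and 0 of x.
    // Only this axis-0 / 1D-indices case is implemented below.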
1171    fn gather(
1172        &mut self,
1173        x: &TensorHandle<T>,
1174        axis: Axis,
1175        indices: &TensorHandle<T>,
1176    ) -> Result<TensorHandle<T>> {
1177        let dense_x = x
1178            .as_dense()
1179            .ok_or_else(|| anyhow!("Only dense tensors supported for gather"))?;
1180        let dense_indices = indices
1181            .as_dense()
1182            .ok_or_else(|| anyhow!("Only dense tensors supported for indices"))?;
1183        let x_shape = dense_x.shape();
1184        let axis_idx = axis;
1185        if axis_idx >= x_shape.len() {
1186            return Err(anyhow!(
1187                "Axis {} out of bounds for tensor with {} dimensions",
1188                axis_idx,
1189                x_shape.len()
1190            ));
1191        }
1192        let indices_shape = dense_indices.shape();
1193        let num_indices: usize = indices_shape.iter().product();
1194        let mut output_shape = Vec::new();
1195        for (i, &dim) in x_shape.iter().enumerate() {
1196            if i == axis_idx {
1197                output_shape.extend_from_slice(indices_shape);
1198            } else if i != axis_idx {
1199                output_shape.push(dim);
1200            }
1201        }
1202        let output_size: usize = output_shape.iter().product();
1203        let mut output = Vec::with_capacity(output_size);
1204        let x_view = dense_x.view();
1205        let indices_view = dense_indices.view();
1206        if axis_idx == 0 && indices_shape.len() == 1 {
1207            let axis_size = x_shape[0];
1208            let elements_per_item: usize = x_shape[1..].iter().product();
1209            for idx_flat in 0..num_indices {
1210                let idx_multi = self.flat_to_multidim(idx_flat, indices_shape);
1211                let idx_val = indices_view[idx_multi.as_slice()];
1212                let idx = idx_val
1213                    .to_usize()
1214                    .ok_or_else(|| anyhow!("Invalid index value: cannot convert to usize"))?;
1215                if idx >= axis_size {
1216                    return Err(anyhow!(
1217                        "Index {} out of bounds for axis {} with size {}",
1218                        idx,
1219                        axis_idx,
1220                        axis_size
1221                    ));
1222                }
1223                for elem_idx in 0..elements_per_item {
1224                    let mut x_index = vec![idx];
1225                    let elem_multi = self.flat_to_multidim(elem_idx, &x_shape[1..]);
1226                    x_index.extend(elem_multi);
1227                    output.push(x_view[x_index.as_slice()]);
1228                }
1229            }
1230        } else {
1231            return Err(anyhow!(
1232                "Gather only supports axis=0 with 1D indices in this implementation"
1233            ));
1234        }
1235        use scirs2_core::ndarray_ext::{Array, IxDyn};
1236        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output)
1237            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
1238        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1239            result_array,
1240        )))
1241    }
1242    fn scatter(
1243        &mut self,
1244        shape: &[usize],
1245        axis: Axis,
1246        indices: &TensorHandle<T>,
1247        values: &TensorHandle<T>,
1248    ) -> Result<TensorHandle<T>> {
1249        let dense_indices = indices
1250            .as_dense()
1251            .ok_or_else(|| anyhow!("Only dense tensors supported for indices"))?;
1252        let dense_values = values
1253            .as_dense()
1254            .ok_or_else(|| anyhow!("Only dense tensors supported for values"))?;
1255        let axis_idx = axis;
1256        if axis_idx >= shape.len() {
1257            return Err(anyhow!(
1258                "Axis {} out of bounds for output shape with {} dimensions",
1259                axis_idx,
1260                shape.len()
1261            ));
1262        }
1263        let indices_shape = dense_indices.shape();
1264        let values_shape = dense_values.shape();
1265        let num_indices: usize = indices_shape.iter().product();
1266        let mut expected_values_shape = Vec::new();
1267        expected_values_shape.extend_from_slice(&shape[..axis_idx]);
1268        expected_values_shape.extend_from_slice(indices_shape);
1269        expected_values_shape.extend_from_slice(&shape[axis_idx + 1..]);
1270        if values_shape != expected_values_shape.as_slice() {
1271            return Err(anyhow!(
1272                "Values shape {:?} doesn't match expected shape {:?}",
1273                values_shape,
1274                expected_values_shape
1275            ));
1276        }
1277        let output_size: usize = shape.iter().product();
1278        let mut output = vec![T::zero(); output_size];
1279        let indices_view = dense_indices.view();
1280        let values_view = dense_values.view();
1281        if axis_idx == 0 && indices_shape.len() == 1 {
1282            let axis_size = shape[0];
1283            let elements_per_item: usize = shape[1..].iter().product();
1284            for idx_flat in 0..num_indices {
1285                let idx_multi = self.flat_to_multidim(idx_flat, indices_shape);
1286                let idx_val = indices_view[idx_multi.as_slice()];
1287                let idx = idx_val
1288                    .to_usize()
1289                    .ok_or_else(|| anyhow!("Invalid index value: cannot convert to usize"))?;
1290                if idx >= axis_size {
1291                    return Err(anyhow!(
1292                        "Index {} out of bounds for axis {} with size {}",
1293                        idx,
1294                        axis_idx,
1295                        axis_size
1296                    ));
1297                }
1298                for elem_idx in 0..elements_per_item {
1299                    let mut values_index = vec![idx_flat];
1300                    let elem_multi = self.flat_to_multidim(elem_idx, &shape[1..]);
1301                    values_index.extend(elem_multi.clone());
1302                    let mut out_index = vec![idx];
1303                    out_index.extend(elem_multi);
1304                    let out_flat = self.multidim_to_flat(&out_index, shape);
1305                    output[out_flat] = values_view[values_index.as_slice()];
1306                }
1307            }
1308        } else {
1309            return Err(anyhow!(
1310                "Scatter only supports axis=0 with 1D indices in this implementation"
1311            ));
1312        }
1313        use scirs2_core::ndarray_ext::{Array, IxDyn};
1314        let result_array = Array::from_shape_vec(IxDyn(shape), output)
1315            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
1316        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1317            result_array,
1318        )))
1319    }
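    /// Direct (non-FFT) 3D convolution over a 5D input [batch, in_channels, D, H, W] with a
    /// 5D kernel [out_channels, in_channels, kD, kH, kW], zero padding, and an optional 1D bias.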
1320    fn conv3d(
1321        &mut self,
1322        x: &TensorHandle<T>,
1323        kernel: &TensorHandle<T>,
1324        bias: Option<&TensorHandle<T>>,
1325        stride: (usize, usize, usize),
1326        padding: (usize, usize, usize, usize, usize, usize),
1327    ) -> Result<TensorHandle<T>> {
1328        let dense_x = x
1329            .as_dense()
1330            .ok_or_else(|| anyhow!("Only dense tensors supported for conv3d"))?;
1331        let dense_kernel = kernel
1332            .as_dense()
1333            .ok_or_else(|| anyhow!("Only dense tensors supported for conv3d kernel"))?;
1334        let (stride_d, stride_h, stride_w) = stride;
1335        if stride_d == 0 || stride_h == 0 || stride_w == 0 {
1336            return Err(anyhow!("Stride must be positive"));
1337        }
1338        let x_shape = dense_x.shape();
1339        let k_shape = dense_kernel.shape();
1340        if x_shape.len() != 5 {
1341            return Err(anyhow!(
1342                "Expected 5D input tensor [batch, in_channels, depth, height, width], got {:?}D",
1343                x_shape.len()
1344            ));
1345        }
1346        if k_shape.len() != 5 {
1347            return Err(
1348                anyhow!(
1349                    "Expected 5D kernel tensor [out_channels, in_channels, kernel_d, kernel_h, kernel_w], got {:?}D",
1350                    k_shape.len()
1351                ),
1352            );
1353        }
1354        let (batch, in_channels, in_d, in_h, in_w) =
1355            (x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]);
1356        let (out_channels, k_in_channels, kernel_d, kernel_h, kernel_w) =
1357            (k_shape[0], k_shape[1], k_shape[2], k_shape[3], k_shape[4]);
1358        if in_channels != k_in_channels {
1359            return Err(anyhow!(
1360                "Input channels mismatch: input has {}, kernel expects {}",
1361                in_channels,
1362                k_in_channels
1363            ));
1364        }
1365        if let Some(bias_tensor) = bias {
1366            let bias_dense = bias_tensor
1367                .as_dense()
1368                .ok_or_else(|| anyhow!("Only dense tensors supported for bias"))?;
1369            let bias_shape = bias_dense.shape();
1370            if bias_shape.len() != 1 || bias_shape[0] != out_channels {
1371                return Err(anyhow!(
1372                    "Expected bias shape [{}], got {:?}",
1373                    out_channels,
1374                    bias_shape
1375                ));
1376            }
1377        }
1378        let (pad_d_front, pad_d_back, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right) = padding;
1379        let padded_d = in_d + pad_d_front + pad_d_back;
1380        let padded_h = in_h + pad_h_top + pad_h_bottom;
1381        let padded_w = in_w + pad_w_left + pad_w_right;
1382        if kernel_d > padded_d || kernel_h > padded_h || kernel_w > padded_w {
1383            return Err(anyhow!(
1384                "Kernel size ({}, {}, {}) larger than padded input ({}, {}, {})",
1385                kernel_d,
1386                kernel_h,
1387                kernel_w,
1388                padded_d,
1389                padded_h,
1390                padded_w
1391            ));
1392        }
1393        let out_d = (padded_d - kernel_d) / stride_d + 1;
1394        let out_h = (padded_h - kernel_h) / stride_h + 1;
1395        let out_w = (padded_w - kernel_w) / stride_w + 1;
1396        let output_shape = [batch, out_channels, out_d, out_h, out_w];
1397
1398        // Use pooled buffer for output allocation (Phase 5: Automatic Pooling)
1399        let mut output = self.acquire_pooled_generic::<T>(&output_shape);
1400        output.clear(); // Ensure buffer starts empty
1401        output.resize(batch * out_channels * out_d * out_h * out_w, T::zero());
1402
1403        let x_view = dense_x.view();
1404        let k_view = dense_kernel.view();
        // Bias was validated above, so its dense view can be hoisted out of the hot loop.
        let bias_view = bias.map(|bt| bt.as_dense().unwrap().view());
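        // Naive direct convolution: visit every output voxel and accumulate the kernel
        // window, skipping positions that fall into the zero-padded border.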
1405        for b in 0..batch {
1406            for oc in 0..out_channels {
1407                for od in 0..out_d {
1408                    for oh in 0..out_h {
1409                        for ow in 0..out_w {
1410                            let mut sum = T::zero();
1411                            let in_start_d = (od * stride_d) as isize - pad_d_front as isize;
1412                            let in_start_h = (oh * stride_h) as isize - pad_h_top as isize;
1413                            let in_start_w = (ow * stride_w) as isize - pad_w_left as isize;
1414                            for ic in 0..in_channels {
1415                                for kd in 0..kernel_d {
1416                                    for kh in 0..kernel_h {
1417                                        for kw in 0..kernel_w {
1418                                            let in_d_pos = in_start_d + kd as isize;
1419                                            let in_h_pos = in_start_h + kh as isize;
1420                                            let in_w_pos = in_start_w + kw as isize;
1421                                            if in_d_pos >= 0
1422                                                && (in_d_pos as usize) < in_d
1423                                                && in_h_pos >= 0
1424                                                && (in_h_pos as usize) < in_h
1425                                                && in_w_pos >= 0
1426                                                && (in_w_pos as usize) < in_w
1427                                            {
1428                                                let x_val = x_view[[
1429                                                    b,
1430                                                    ic,
1431                                                    in_d_pos as usize,
1432                                                    in_h_pos as usize,
1433                                                    in_w_pos as usize,
1434                                                ]];
1435                                                let k_val = k_view[[oc, ic, kd, kh, kw]];
1436                                                sum += x_val * k_val;
1437                                            }
1438                                        }
1439                                    }
1440                                }
1441                            }
1442                            if let Some(bv) = &bias_view {
1443                                sum += bv[[oc]];
1444                            }
1447                            let out_idx = ((((b * out_channels + oc) * out_d + od) * out_h + oh)
1448                                * out_w)
1449                                + ow;
1450                            output[out_idx] = sum;
1451                        }
1452                    }
1453                }
1454            }
1455        }
1456
1457        use scirs2_core::ndarray_ext::{Array, IxDyn};
1458        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output.clone())
1459            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
1460        self.release_pooled_generic::<T>(&output_shape, output);
1461        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1462            result_array,
1463        )))
1464    }
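    /// Computes the determinant of the last two (square) dimensions, batching over any
    /// leading dimensions; a 2D input yields a scalar (rank-0) tensor.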
1465    fn determinant(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
1466        let dense = x
1467            .as_dense()
1468            .ok_or_else(|| anyhow!("Only dense tensors supported for determinant"))?;
1469        let shape = dense.shape();
1470        if shape.len() < 2 {
1471            return Err(anyhow!("Input must be at least 2D for determinant"));
1472        }
1473        let n = shape[shape.len() - 1];
1474        let m = shape[shape.len() - 2];
1475        if n != m {
1476            return Err(anyhow!(
1477                "Last two dimensions must be square for determinant, got {}x{}",
1478                m,
1479                n
1480            ));
1481        }
1482        if shape.len() == 2 {
1483            use scirs2_core::ndarray_ext::Array2;
1484            let view = dense.view();
1485            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| view[[i, j]]);
1486            let det = self.compute_determinant_2d(&matrix)?;
1487            return Ok(TensorHandle::from_dense_auto(DenseND::from_vec(
1488                vec![det],
1489                &[],
1490            )?));
1491        }
1492        let batch_size: usize = shape[..shape.len() - 2].iter().product();
1493        let mut determinants = Vec::with_capacity(batch_size);
1494        let view = dense.view();
1495        for batch_idx in 0..batch_size {
1496            let batch_multi = self.flat_to_multidim(batch_idx, &shape[..shape.len() - 2]);
1497            use scirs2_core::ndarray_ext::Array2;
1498            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| {
1499                let mut idx = batch_multi.clone();
1500                idx.push(i);
1501                idx.push(j);
1502                view[idx.as_slice()]
1503            });
1504            let det = self.compute_determinant_2d(&matrix)?;
1505            determinants.push(det);
1506        }
1507        let output_shape = &shape[..shape.len() - 2];
1508        use scirs2_core::ndarray_ext::{Array, IxDyn};
1509        let result_array = Array::from_shape_vec(IxDyn(output_shape), determinants)
1510            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
1511        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1512            result_array,
1513        )))
1514    }
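    /// Inverts the matrix formed by the last two (square) dimensions, batching over any
    /// leading dimensions.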
1515    fn matrix_inverse(&mut self, x: &TensorHandle<T>) -> Result<TensorHandle<T>> {
1516        let dense = x
1517            .as_dense()
1518            .ok_or_else(|| anyhow!("Only dense tensors supported for matrix_inverse"))?;
1519        let shape = dense.shape();
1520        if shape.len() < 2 {
1521            return Err(anyhow!("Input must be at least 2D for matrix inverse"));
1522        }
1523        let n = shape[shape.len() - 1];
1524        let m = shape[shape.len() - 2];
1525        if n != m {
1526            return Err(anyhow!(
1527                "Last two dimensions must be square for matrix inverse, got {}x{}",
1528                m,
1529                n
1530            ));
1531        }
1532        if shape.len() == 2 {
1533            use scirs2_core::ndarray_ext::Array2;
1534            let view = dense.view();
1535            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| view[[i, j]]);
1536            let inv = self.compute_inverse_2d(&matrix)?;
1537            let inv_dyn = inv.into_dyn();
1538            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(inv_dyn)));
1539        }
1540        let batch_size: usize = shape[..shape.len() - 2].iter().product();
1541        let output_size = batch_size * n * n;
1542        let mut output = Vec::with_capacity(output_size);
1543        let view = dense.view();
1544        for batch_idx in 0..batch_size {
1545            let batch_multi = self.flat_to_multidim(batch_idx, &shape[..shape.len() - 2]);
1546            use scirs2_core::ndarray_ext::Array2;
1547            let matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| {
1548                let mut idx = batch_multi.clone();
1549                idx.push(i);
1550                idx.push(j);
1551                view[idx.as_slice()]
1552            });
1553            let inv = self.compute_inverse_2d(&matrix)?;
1554            output.extend(inv.iter().copied());
1555        }
1556        use scirs2_core::ndarray_ext::{Array, IxDyn};
1557        let result_array = Array::from_shape_vec(IxDyn(shape), output)
1558            .map_err(|e| anyhow!("Failed to create output array: {}", e))?;
1559        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1560            result_array,
1561        )))
1562    }
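    /// Solves the linear system `A x = b`; currently restricted to a 2D square `A` with a
    /// 1D right-hand side `b`.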
1563    fn solve(&mut self, a: &TensorHandle<T>, b: &TensorHandle<T>) -> Result<TensorHandle<T>> {
1564        let dense_a = a
1565            .as_dense()
1566            .ok_or_else(|| anyhow!("Only dense tensors supported for solve (A)"))?;
1567        let dense_b = b
1568            .as_dense()
1569            .ok_or_else(|| anyhow!("Only dense tensors supported for solve (b)"))?;
1570        let a_shape = dense_a.shape();
1571        let b_shape = dense_b.shape();
1572        if a_shape.len() < 2 {
1573            return Err(anyhow!("Matrix A must be at least 2D"));
1574        }
1575        if b_shape.is_empty() {
1576            return Err(anyhow!("Vector/matrix b must be at least 1D"));
1577        }
1578        let n = a_shape[a_shape.len() - 1];
1579        let m = a_shape[a_shape.len() - 2];
1580        if n != m {
1581            return Err(anyhow!("Matrix A must be square, got {}x{}", m, n));
1582        }
1583        // b is a vector when it has exactly one fewer dimension than A, else a matrix.
1584        let b_is_vector = b_shape.len() + 1 == a_shape.len();
1585        if !b_is_vector && b_shape.len() < 2 {
1586            return Err(anyhow!("b shape {:?} incompatible with A shape {:?}", b_shape, a_shape));
1587        }
1588        let b_rows = b_shape[b_shape.len() - (if b_is_vector { 1 } else { 2 })];
1589        if b_rows != n {
1590            return Err(anyhow!(
1591                "Dimension mismatch: A is {}x{}, b has {} rows",
1592                m,
1593                n,
1594                b_rows
1595            ));
1596        }
1597        if a_shape.len() == 2 && b_shape.len() == 1 {
1598            use scirs2_core::ndarray_ext::{Array1, Array2};
1599            let a_view = dense_a.view();
1600            let b_view = dense_b.view();
1601            let a_matrix: Array2<T> = Array2::from_shape_fn((n, n), |(i, j)| a_view[[i, j]]);
1602            let b_vector: Array1<T> = Array1::from_shape_fn(n, |i| b_view[[i]]);
1603            let x = self.solve_2d_1d(&a_matrix, &b_vector)?;
1604            let x_dyn = x.into_dyn();
1605            return Ok(TensorHandle::from_dense_auto(DenseND::from_array(x_dyn)));
1606        }
1607        Err(anyhow!(
1608            "Solve only supports 2D matrix A with 1D vector b in this implementation"
1609        ))
1610    }
1611
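    /// Gathers along `axis`, delegating to `super::advanced_indexing::advanced_gather` and
    /// passing through the `allow_negative` flag.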
1612    fn advanced_gather(
1613        &mut self,
1614        x: &TensorHandle<T>,
1615        axis: Axis,
1616        indices: &TensorHandle<T>,
1617        allow_negative: bool,
1618    ) -> Result<TensorHandle<T>> {
1619        let dense = x
1620            .as_dense()
1621            .ok_or_else(|| anyhow!("Only dense tensors supported for advanced_gather"))?;
1622        let indices_dense = indices
1623            .as_dense()
1624            .ok_or_else(|| anyhow!("Indices must be dense tensor"))?;
1625
1626        let result =
1627            super::advanced_indexing::advanced_gather(dense, axis, indices_dense, allow_negative)?;
1628        Ok(TensorHandle::from_dense_auto(result))
1629    }
1630
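    /// Scatter with an explicit accumulation `mode`, delegating to
    /// `super::advanced_indexing::advanced_scatter`.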
1631    fn advanced_scatter(
1632        &mut self,
1633        shape: &[usize],
1634        axis: Axis,
1635        indices: &TensorHandle<T>,
1636        values: &TensorHandle<T>,
1637        mode: ScatterMode,
1638    ) -> Result<TensorHandle<T>> {
1639        let indices_dense = indices
1640            .as_dense()
1641            .ok_or_else(|| anyhow!("Indices must be dense tensor"))?;
1642        let values_dense = values
1643            .as_dense()
1644            .ok_or_else(|| anyhow!("Values must be dense tensor"))?;
1645
1646        let result = super::advanced_indexing::advanced_scatter(
1647            shape,
1648            axis,
1649            indices_dense,
1650            values_dense,
1651            mode,
1652        )?;
1653        Ok(TensorHandle::from_dense_auto(result))
1654    }
1655
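    /// Boolean-mask selection, delegating to `super::advanced_indexing::fancy_index_mask`.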
1656    fn fancy_index_mask(
1657        &mut self,
1658        x: &TensorHandle<T>,
1659        mask: &TensorHandle<T>,
1660    ) -> Result<TensorHandle<T>> {
1661        let dense = x
1662            .as_dense()
1663            .ok_or_else(|| anyhow!("Only dense tensors supported for fancy_index_mask"))?;
1664        let mask_dense = mask
1665            .as_dense()
1666            .ok_or_else(|| anyhow!("Mask must be dense tensor"))?;
1667
1668        let result = super::advanced_indexing::fancy_index_mask(dense, mask_dense)?;
1669        Ok(TensorHandle::from_dense_auto(result))
1670    }
1671
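    /// Tiles the tensor `reps[d]` times along each dimension `d`; each output position maps
    /// back to the input via `out_idx % input_shape[d]`.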
1672    fn tile(&mut self, x: &TensorHandle<T>, reps: &[usize]) -> Result<TensorHandle<T>> {
1673        let dense = x
1674            .as_dense()
1675            .ok_or_else(|| anyhow!("Only dense tensors supported for tile"))?;
1676
1677        let input_shape = dense.shape();
1678
1679        // Ensure reps has same length as input shape
1680        if reps.len() != input_shape.len() {
1681            return Err(anyhow!(
1682                "Reps length {} must match input dimensions {}",
1683                reps.len(),
1684                input_shape.len()
1685            ));
1686        }
1687
1688        // Calculate output shape
1689        let output_shape: Vec<usize> = input_shape
1690            .iter()
1691            .zip(reps.iter())
1692            .map(|(&dim, &rep)| dim * rep)
1693            .collect();
1694
1695        let input_view = dense.view();
1696        let output_size: usize = output_shape.iter().product();
1697
1698        // Use pooled buffer for tile output (Phase 5: Automatic Pooling)
1699        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
1700        output_data.clear(); // Ensure buffer starts empty
1701        output_data.reserve(output_size);
1702
1703        // Generate all output indices and map to input indices
1704        for i in 0..output_size {
1705            let out_idx = self.flat_to_multidim(i, &output_shape);
1706            // Map output index to input index by taking modulo
1707            let in_idx: Vec<usize> = out_idx
1708                .iter()
1709                .zip(input_shape.iter())
1710                .map(|(&o, &s)| o % s)
1711                .collect();
1712            output_data.push(input_view[in_idx.as_slice()]);
1713        }
1714
1715        use scirs2_core::ndarray_ext::{Array, IxDyn};
1716        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
1717            .map_err(|e| anyhow!("Failed to create tiled array: {}", e))?;
1718        self.release_pooled_generic::<T>(&output_shape, output_data);
1719
1720        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1721            result_array,
1722        )))
1723    }
1724
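    /// Pads each dimension by `(before, after)` amounts, filling new elements with
    /// `constant_value`.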
1725    fn pad(
1726        &mut self,
1727        x: &TensorHandle<T>,
1728        pad_width: &[(usize, usize)],
1729        constant_value: T,
1730    ) -> Result<TensorHandle<T>> {
1731        let dense = x
1732            .as_dense()
1733            .ok_or_else(|| anyhow!("Only dense tensors supported for pad"))?;
1734
1735        let input_shape = dense.shape();
1736
1737        if pad_width.len() != input_shape.len() {
1738            return Err(anyhow!(
1739                "Pad width length {} must match input dimensions {}",
1740                pad_width.len(),
1741                input_shape.len()
1742            ));
1743        }
1744
1745        // Calculate output shape
1746        let output_shape: Vec<usize> = input_shape
1747            .iter()
1748            .zip(pad_width.iter())
1749            .map(|(&dim, &(before, after))| dim + before + after)
1750            .collect();
1751
1752        let input_view = dense.view();
1753        let output_size: usize = output_shape.iter().product();
1754
1755        // Use pooled buffer for pad output (Phase 5: Automatic Pooling)
1756        let mut output_data = self.acquire_pooled_generic::<T>(&output_shape);
1757        output_data.clear(); // Ensure buffer starts empty
1758        output_data.resize(output_size, constant_value); // Initialize with constant_value
1759
1760        // Copy input data to the appropriate region in output
1761        let input_size: usize = input_shape.iter().product();
1762        for i in 0..input_size {
1763            let in_idx = self.flat_to_multidim(i, input_shape);
1764            // Calculate output index by adding padding offsets
1765            let out_idx: Vec<usize> = in_idx
1766                .iter()
1767                .zip(pad_width.iter())
1768                .map(|(&idx, &(before, _))| idx + before)
1769                .collect();
1770            let out_flat = self.multidim_to_flat(&out_idx, &output_shape);
1771            output_data[out_flat] = input_view[in_idx.as_slice()];
1772        }
1773
1774        use scirs2_core::ndarray_ext::{Array, IxDyn};
1775        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data.clone())
1776            .map_err(|e| anyhow!("Failed to create padded array: {}", e))?;
1777        self.release_pooled_generic::<T>(&output_shape, output_data);
1778
1779        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1780            result_array,
1781        )))
1782    }
1783
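    /// Reverses the tensor along each axis listed in `axes`.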
1784    fn flip(&mut self, x: &TensorHandle<T>, axes: &[Axis]) -> Result<TensorHandle<T>> {
1785        let dense = x
1786            .as_dense()
1787            .ok_or_else(|| anyhow!("Only dense tensors supported for flip"))?;
1788
1789        let shape = dense.shape();
1790
1791        // Validate axes
1792        for &axis in axes {
1793            if axis >= shape.len() {
1794                return Err(anyhow!(
1795                    "Axis {} out of bounds for tensor with {} dimensions",
1796                    axis,
1797                    shape.len()
1798                ));
1799            }
1800        }
1801
1802        let input_view = dense.view();
1803        let total_elements: usize = shape.iter().product();
1804
1805        // Use pooled buffer for flip output (Phase 5: Automatic Pooling)
1806        let mut output_data = self.acquire_pooled_generic::<T>(shape);
1807        output_data.clear(); // Ensure buffer starts empty
1808        output_data.reserve(total_elements);
1809
1810        // For each output position, compute the corresponding flipped input position
1811        for i in 0..total_elements {
1812            let out_idx = self.flat_to_multidim(i, shape);
1813            // Flip specified axes
1814            let in_idx: Vec<usize> = out_idx
1815                .iter()
1816                .enumerate()
1817                .map(|(axis, &idx)| {
1818                    if axes.contains(&axis) {
1819                        // Flip this axis
1820                        shape[axis] - 1 - idx
1821                    } else {
1822                        idx
1823                    }
1824                })
1825                .collect();
1826            output_data.push(input_view[in_idx.as_slice()]);
1827        }
1828
1829        use scirs2_core::ndarray_ext::{Array, IxDyn};
1830        let result_array = Array::from_shape_vec(IxDyn(shape), output_data.clone())
1831            .map_err(|e| anyhow!("Failed to create flipped array: {}", e))?;
1832        self.release_pooled_generic::<T>(shape, output_data);
1833
1834        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
1835            result_array,
1836        )))
1837    }
1838
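    /// Removes size-1 dimensions; with `axes = None` all size-1 dimensions are removed,
    /// otherwise only the listed axes (which must have size 1).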
1839    fn squeeze(&mut self, x: &TensorHandle<T>, axes: Option<&[Axis]>) -> Result<TensorHandle<T>> {
1840        let dense = x
1841            .as_dense()
1842            .ok_or_else(|| anyhow!("Only dense tensors supported for squeeze"))?;
1843
1844        let shape = dense.shape();
1845
1846        // Determine which axes to squeeze
1847        let axes_to_squeeze: Vec<usize> = if let Some(ax) = axes {
1848            // Validate and collect specified axes
1849            for &axis in ax {
1850                if axis >= shape.len() {
1851                    return Err(anyhow!(
1852                        "Axis {} out of bounds for tensor with {} dimensions",
1853                        axis,
1854                        shape.len()
1855                    ));
1856                }
1857                if shape[axis] != 1 {
1858                    return Err(anyhow!(
1859                        "Cannot squeeze axis {} with size {}",
1860                        axis,
1861                        shape[axis]
1862                    ));
1863                }
1864            }
1865            ax.to_vec()
1866        } else {
1867            // Find all axes with size 1
1868            shape
1869                .iter()
1870                .enumerate()
1871                .filter_map(|(i, &s)| if s == 1 { Some(i) } else { None })
1872                .collect()
1873        };
1874
1875        // Build new shape by removing squeezed axes
1876        let new_shape: Vec<usize> = shape
1877            .iter()
1878            .enumerate()
1879            .filter_map(|(i, &s)| {
1880                if axes_to_squeeze.contains(&i) {
1881                    None
1882                } else {
1883                    Some(s)
1884                }
1885            })
1886            .collect();
1887
1888        // If no dimensions to squeeze, return original
1889        if new_shape.len() == shape.len() {
1890            return Ok(x.clone());
1891        }
1892
1893        // Reshape - data order is preserved
1894        self.reshape(x, &new_shape)
1895    }
1896
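    /// Inserts a new size-1 dimension at `axis` (valid range 0..=ndim).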
1897    fn unsqueeze(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
1898        let dense = x
1899            .as_dense()
1900            .ok_or_else(|| anyhow!("Only dense tensors supported for unsqueeze"))?;
1901
1902        let shape = dense.shape();
1903
1904        // axis can be at most shape.len() (to append at the end)
1905        if axis > shape.len() {
1906            return Err(anyhow!(
1907                "Axis {} out of bounds for unsqueeze (max {})",
1908                axis,
1909                shape.len()
1910            ));
1911        }
1912
1913        // Build new shape by inserting 1 at the specified axis
1914        let mut new_shape = shape.to_vec();
1915        new_shape.insert(axis, 1);
1916
1917        self.reshape(x, &new_shape)
1918    }
1919
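    /// Stacks same-shaped tensors along a new axis by unsqueezing each input and
    /// concatenating the results.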
1920    fn stack(&mut self, tensors: &[TensorHandle<T>], axis: Axis) -> Result<TensorHandle<T>> {
1921        if tensors.is_empty() {
1922            return Err(anyhow!("Cannot stack empty sequence of tensors"));
1923        }
1924
1925        // Get shape of first tensor
1926        let first_shape = tensors[0]
1927            .as_dense()
1928            .ok_or_else(|| anyhow!("Only dense tensors supported for stack"))?
1929            .shape();
1930
1931        // Validate all tensors have same shape
1932        for (i, tensor) in tensors.iter().enumerate().skip(1) {
1933            let shape = tensor
1934                .as_dense()
1935                .ok_or_else(|| anyhow!("Only dense tensors supported for stack"))?
1936                .shape();
1937            if shape != first_shape {
1938                return Err(anyhow!(
1939                    "All tensors must have the same shape for stacking. Tensor 0: {:?}, Tensor {}: {:?}",
1940                    first_shape,
1941                    i,
1942                    shape
1943                ));
1944            }
1945        }
1946
1947        if axis > first_shape.len() {
1948            return Err(anyhow!(
1949                "Axis {} out of bounds for stack (max {})",
1950                axis,
1951                first_shape.len()
1952            ));
1953        }
1954
1955        // Unsqueeze all tensors at the specified axis
1956        let mut unsqueezed = Vec::new();
1957        for tensor in tensors {
1958            unsqueezed.push(self.unsqueeze(tensor, axis)?);
1959        }
1960
1961        // Concatenate along the new axis
1962        self.concatenate(&unsqueezed, axis)
1963    }
1964
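    /// Repeats each element `repeats` times consecutively along `axis` (element-wise repeat).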
1965    fn repeat(
1966        &mut self,
1967        x: &TensorHandle<T>,
1968        repeats: usize,
1969        axis: Axis,
1970    ) -> Result<TensorHandle<T>> {
1971        let dense = x
1972            .as_dense()
1973            .ok_or_else(|| anyhow!("Only dense tensors supported for repeat"))?;
1974
1975        let shape = dense.shape();
1976
1977        if axis >= shape.len() {
1978            return Err(anyhow!(
1979                "Axis {} out of bounds for tensor with {} dimensions",
1980                axis,
1981                shape.len()
1982            ));
1983        }
1984
1985        if repeats == 0 {
1986            return Err(anyhow!("Repeat count must be greater than 0"));
1987        }
1988
1989        if repeats == 1 {
1990            return Ok(x.clone());
1991        }
1992
1993        // Calculate output shape
1994        let mut output_shape = shape.to_vec();
1995        output_shape[axis] *= repeats;
1996
1997        let input_view = dense.view();
1998        let total_elements: usize = output_shape.iter().product();
1999        let mut output_data = Vec::with_capacity(total_elements);
2000
2001        // Generate output by repeating each element along the specified axis
2002        for i in 0..total_elements {
2003            let out_idx = self.flat_to_multidim(i, &output_shape);
2004            // Map output index to input index by dividing by repeats
2005            let mut in_idx = out_idx.clone();
2006            in_idx[axis] /= repeats;
2007            output_data.push(input_view[in_idx.as_slice()]);
2008        }
2009
2010        use scirs2_core::ndarray_ext::{Array, IxDyn};
2011        let result_array = Array::from_shape_vec(IxDyn(&output_shape), output_data)
2012            .map_err(|e| anyhow!("Failed to create repeated array: {}", e))?;
2013
2014        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
2015            result_array,
2016        )))
2017    }
2018
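    /// Circularly shifts elements by `shift` positions along `axis`; negative shifts roll
    /// toward lower indices.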
2019    fn roll(&mut self, x: &TensorHandle<T>, shift: isize, axis: Axis) -> Result<TensorHandle<T>> {
2020        let dense = x
2021            .as_dense()
2022            .ok_or_else(|| anyhow!("Only dense tensors supported for roll"))?;
2023
2024        let shape = dense.shape();
2025
2026        if axis >= shape.len() {
2027            return Err(anyhow!(
2028                "Axis {} out of bounds for tensor with {} dimensions",
2029                axis,
2030                shape.len()
2031            ));
2032        }
2033
2034        if shift == 0 {
2035            return Ok(x.clone());
2036        }
2037
2038        let axis_size = shape[axis] as isize;
2039        // Normalize shift to [0, axis_size)
2040        let normalized_shift = ((shift % axis_size) + axis_size) % axis_size;
2041
2042        if normalized_shift == 0 {
2043            return Ok(x.clone());
2044        }
2045
2046        let input_view = dense.view();
2047        let total_elements: usize = shape.iter().product();
2048        let mut output_data = Vec::with_capacity(total_elements);
2049
2050        // Generate output by rolling indices along the specified axis
2051        for i in 0..total_elements {
2052            let out_idx = self.flat_to_multidim(i, shape);
2053            // Calculate rolled index for the axis
2054            let mut in_idx = out_idx.clone();
2055            let old_idx = out_idx[axis] as isize;
2056            let new_idx = ((old_idx - normalized_shift + axis_size) % axis_size) as usize;
2057            in_idx[axis] = new_idx;
2058            output_data.push(input_view[in_idx.as_slice()]);
2059        }
2060
2061        use scirs2_core::ndarray_ext::{Array, IxDyn};
2062        let result_array = Array::from_shape_vec(IxDyn(shape), output_data)
2063            .map_err(|e| anyhow!("Failed to create rolled array: {}", e))?;
2064
2065        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
2066            result_array,
2067        )))
2068    }
2069
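    /// Returns the index of the maximum value along `axis` (encoded as values of `T`),
    /// removing that axis from the output shape.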
2070    fn argmax(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
2071        let dense = x
2072            .as_dense()
2073            .ok_or_else(|| anyhow!("Only dense tensors supported for argmax"))?;
2074
2075        let shape = dense.shape();
2076
2077        if axis >= shape.len() {
2078            return Err(anyhow!(
2079                "Axis {} out of bounds for tensor with {} dimensions",
2080                axis,
2081                shape.len()
2082            ));
2083        }
2084
2085        // Calculate output shape (remove the reduction axis)
2086        let output_shape: Vec<usize> = shape
2087            .iter()
2088            .enumerate()
2089            .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
2090            .collect();
2091
2092        let input_view = dense.view();
2093        let output_size: usize = if output_shape.is_empty() {
2094            1
2095        } else {
2096            output_shape.iter().product()
2097        };
2098        let mut output_data = Vec::with_capacity(output_size);
2099
2100        // For each output position, find the argmax along the reduction axis
2101        for i in 0..output_size {
2102            let base_idx = if output_shape.is_empty() {
2103                vec![]
2104            } else {
2105                self.flat_to_multidim(i, &output_shape)
2106            };
2107
2108            let mut max_val = T::neg_infinity();
2109            let mut max_idx = 0usize;
2110
2111            for j in 0..shape[axis] {
2112                let mut idx = Vec::new();
2113                let mut out_pos = 0;
2114                for (dim, &_size) in shape.iter().enumerate() {
2115                    if dim == axis {
2116                        idx.push(j);
2117                    } else {
2118                        idx.push(base_idx[out_pos]);
2119                        out_pos += 1;
2120                    }
2121                }
2122                let val = input_view[idx.as_slice()];
2123                if val > max_val {
2124                    max_val = val;
2125                    max_idx = j;
2126                }
2127            }
2128
2129            output_data.push(T::from_usize(max_idx).unwrap());
2130        }
2131
2132        use scirs2_core::ndarray_ext::{Array, IxDyn};
2133        let result_shape = if output_shape.is_empty() {
2134            vec![]
2135        } else {
2136            output_shape
2137        };
2138        let result_array = Array::from_shape_vec(IxDyn(&result_shape), output_data)
2139            .map_err(|e| anyhow!("Failed to create argmax array: {}", e))?;
2140
2141        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
2142            result_array,
2143        )))
2144    }
2145
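    /// Returns the index of the minimum value along `axis` (encoded as values of `T`),
    /// removing that axis from the output shape.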
2146    fn argmin(&mut self, x: &TensorHandle<T>, axis: Axis) -> Result<TensorHandle<T>> {
2147        let dense = x
2148            .as_dense()
2149            .ok_or_else(|| anyhow!("Only dense tensors supported for argmin"))?;
2150
2151        let shape = dense.shape();
2152
2153        if axis >= shape.len() {
2154            return Err(anyhow!(
2155                "Axis {} out of bounds for tensor with {} dimensions",
2156                axis,
2157                shape.len()
2158            ));
2159        }
2160
2161        // Calculate output shape (remove the reduction axis)
2162        let output_shape: Vec<usize> = shape
2163            .iter()
2164            .enumerate()
2165            .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
2166            .collect();
2167
2168        let input_view = dense.view();
2169        let output_size: usize = if output_shape.is_empty() {
2170            1
2171        } else {
2172            output_shape.iter().product()
2173        };
2174        let mut output_data = Vec::with_capacity(output_size);
2175
2176        // For each output position, find the argmin along the reduction axis
2177        for i in 0..output_size {
2178            let base_idx = if output_shape.is_empty() {
2179                vec![]
2180            } else {
2181                self.flat_to_multidim(i, &output_shape)
2182            };
2183
2184            let mut min_val = T::infinity();
2185            let mut min_idx = 0usize;
2186
2187            for j in 0..shape[axis] {
2188                let mut idx = Vec::new();
2189                let mut out_pos = 0;
2190                for (dim, &_size) in shape.iter().enumerate() {
2191                    if dim == axis {
2192                        idx.push(j);
2193                    } else {
2194                        idx.push(base_idx[out_pos]);
2195                        out_pos += 1;
2196                    }
2197                }
2198                let val = input_view[idx.as_slice()];
2199                if val < min_val {
2200                    min_val = val;
2201                    min_idx = j;
2202                }
2203            }
2204
2205            output_data.push(T::from_usize(min_idx).unwrap());
2206        }
2207
2208        use scirs2_core::ndarray_ext::{Array, IxDyn};
2209        let result_shape = if output_shape.is_empty() {
2210            vec![]
2211        } else {
2212            output_shape
2213        };
2214        let result_array = Array::from_shape_vec(IxDyn(&result_shape), output_data)
2215            .map_err(|e| anyhow!("Failed to create argmin array: {}", e))?;
2216
2217        Ok(TensorHandle::from_dense_auto(DenseND::from_array(
2218            result_array,
2219        )))
2220    }
2221}