// yscv_tensor/ops.rs
1use super::aligned::AlignedVec;
2use super::error::TensorError;
3use super::shape::{
4    broadcast_offset, broadcast_shape, compute_strides, increment_coords, shape_element_count,
5};
6use super::simd;
7use super::tensor::Tensor;
8
9impl Tensor {
10    // ── Binary elementwise with broadcasting ────────────────────────────
11
12    /// Element-wise addition with NumPy-style broadcasting.
13    pub fn add(&self, rhs: &Self) -> Result<Self, TensorError> {
14        if self.shape() == rhs.shape() {
15            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Add);
16        }
17        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Add) {
18            return result;
19        }
20        self.binary_broadcast_op(rhs, |l, r| l + r)
21    }
22
23    /// Element-wise subtraction with NumPy-style broadcasting.
24    pub fn sub(&self, rhs: &Self) -> Result<Self, TensorError> {
25        if self.shape() == rhs.shape() {
26            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Sub);
27        }
28        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Sub) {
29            return result;
30        }
31        self.binary_broadcast_op(rhs, |l, r| l - r)
32    }
33
34    /// Element-wise multiplication with NumPy-style broadcasting.
35    pub fn mul(&self, rhs: &Self) -> Result<Self, TensorError> {
36        if self.shape() == rhs.shape() {
37            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Mul);
38        }
39        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Mul) {
40            return result;
41        }
42        self.binary_broadcast_op(rhs, |l, r| l * r)
43    }
44
45    /// Element-wise division with NumPy-style broadcasting.
46    pub fn div(&self, rhs: &Self) -> Result<Self, TensorError> {
47        if self.shape() == rhs.shape() {
48            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Div);
49        }
50        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Div) {
51            return result;
52        }
53        self.binary_broadcast_op(rhs, |l, r| l / r)
54    }
55
    /// Element-wise power with NumPy-style broadcasting.
    ///
    /// Fast paths (taken when the exponent tensor holds exactly one value):
    /// - Constant exponent 2.0 → `mul(x, x)` (SIMD-accelerated).
    /// - Constant exponent 0.5 → `sqrt(x)` (SIMD-accelerated).
    /// - Constant exponent 1.0 → clone, 0.0 → ones, -1.0 → reciprocal.
    /// - Same-shape general case → SIMD `exp(exp * ln(base))`.
    ///
    /// NOTE(review): the constant-exponent fast paths return a tensor with
    /// `self`'s shape. If `rhs` has a higher rank than `self` (e.g. shape
    /// `[1, 1, 1]` vs `[3]`), NumPy broadcasting would produce a higher-rank
    /// result — confirm callers never rely on that corner case.
    #[allow(unsafe_code)]
    pub fn pow(&self, rhs: &Self) -> Result<Self, TensorError> {
        // Fast path: constant exponent (O(1) check instead of O(N) scan).
        // A single-element tensor or a tensor whose total size is 1 (broadcast scalar)
        // is always constant.  For same-shape tensors we skip the all-equal scan
        // entirely — the per-element SIMD path handles them efficiently.
        let rhs_total: usize = rhs.shape().iter().product();
        if rhs_total == 1 {
            let exp_val = rhs.data()[0];
            if exp_val == 2.0 {
                return self.mul(self);
            }
            if exp_val == 0.5 {
                return Ok(self.sqrt());
            }
            if exp_val == 1.0 {
                return Ok(self.clone());
            }
            if exp_val == 0.0 {
                // IEEE 754 pow(x, ±0) == 1 for every x, so a ones tensor matches
                // what `powf` would produce element-wise.
                return Tensor::ones(self.shape().to_vec());
            }
            if exp_val == -1.0 {
                return Ok(self.reciprocal());
            }
        }
        // Same-shape SIMD path: pow(base, exp) = exp(exp * ln(base))
        if self.shape() == rhs.shape() {
            return self.pow_simd(rhs);
        }
        // General broadcasting fallback: scalar powf per element pair.
        self.binary_broadcast_op(rhs, |l, r| l.powf(r))
    }
93
    /// SIMD-accelerated pow for same-shape tensors via exp(exp * ln(base)).
    ///
    /// NOTE(review): `ln` is undefined for negative bases, so this identity
    /// presumably yields NaN where `f32::powf` would return a finite value
    /// (e.g. `(-2)^3 = -8`) — confirm this tradeoff is acceptable to callers.
    #[allow(unsafe_code)]
    fn pow_simd(&self, rhs: &Self) -> Result<Self, TensorError> {
        let len = self.len();
        // SAFETY (applies to all three buffers below): `uninitialized`
        // allocates without zeroing; each dispatch call fully writes its
        // output buffer before anything reads from it.
        // Compute ln(base)
        let mut ln_buf = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut ln_buf);
        // Multiply by exponent: exp_val * ln(base)
        let mut prod_buf = AlignedVec::<f32>::uninitialized(len);
        simd::binary_dispatch(&ln_buf, rhs.data(), &mut prod_buf, simd::BinaryKind::Mul);
        // Compute exp(exp_val * ln(base))
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(&prod_buf, &mut out);
        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
109
110    /// Element-wise atan2(self, other), with broadcasting.
111    ///
112    /// Same-shape case uses a SIMD-friendly polynomial approximation of atan
113    /// with quadrant correction, avoiding scalar `f32::atan2`.
114    #[allow(unsafe_code)]
115    pub fn atan2(&self, rhs: &Self) -> Result<Self, TensorError> {
116        if self.shape() == rhs.shape() {
117            return self.atan2_fast(rhs);
118        }
119        self.binary_broadcast_op(rhs, f32::atan2)
120    }
121
122    /// Vectorized atan2 for same-shape tensors. Uses Cephes-style range
123    /// reduction to [0, tan(pi/12)] then a polynomial approximation,
124    /// giving < 1e-6 max error across all quadrants.
125    #[allow(unsafe_code)]
126    fn atan2_fast(&self, rhs: &Self) -> Result<Self, TensorError> {
127        let y_data = self.data();
128        let x_data = rhs.data();
129        let len = self.len();
130        let mut out = AlignedVec::<f32>::uninitialized(len);
131
132        // Process in chunks of 4 for ILP (instruction-level parallelism).
133        // Each call to the scalar fast path avoids the overhead of the
134        // generic broadcasting machinery.
135        let mut i = 0;
136        while i + 4 <= len {
137            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
138            out[i + 1] = fast_atan2_scalar(y_data[i + 1], x_data[i + 1]);
139            out[i + 2] = fast_atan2_scalar(y_data[i + 2], x_data[i + 2]);
140            out[i + 3] = fast_atan2_scalar(y_data[i + 3], x_data[i + 3]);
141            i += 4;
142        }
143        while i < len {
144            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
145            i += 1;
146        }
147
148        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
149    }
150
151    /// Element-wise minimum with NumPy-style broadcasting.
152    pub fn minimum(&self, rhs: &Self) -> Result<Self, TensorError> {
153        if self.shape() == rhs.shape() {
154            return self.binary_same_shape(rhs, f32::min);
155        }
156        self.binary_broadcast_op(rhs, f32::min)
157    }
158
159    /// Element-wise maximum with NumPy-style broadcasting.
160    pub fn maximum(&self, rhs: &Self) -> Result<Self, TensorError> {
161        if self.shape() == rhs.shape() {
162            return self.binary_same_shape(rhs, f32::max);
163        }
164        self.binary_broadcast_op(rhs, f32::max)
165    }
166
167    // ── Unary elementwise ───────────────────────────────────────────────
168
169    /// Element-wise negation.
170    pub fn neg(&self) -> Self {
171        self.unary_simd_op(simd::UnaryKind::Neg)
172    }
173
174    /// Element-wise absolute value.
175    pub fn abs(&self) -> Self {
176        self.unary_simd_op(simd::UnaryKind::Abs)
177    }
178
179    /// Element-wise natural exponential.
180    #[allow(unsafe_code)]
181    pub fn exp(&self) -> Self {
182        let len = self.len();
183        // SAFETY: `uninitialized` allocates without zeroing.  `exp_dispatch`
184        // writes every element before anything reads from `out`.
185        let mut out = AlignedVec::<f32>::uninitialized(len);
186        simd::exp_dispatch(self.data(), &mut out);
187        Tensor::from_raw_parts(self.shape(), self.strides(), out)
188    }
189
190    /// Element-wise natural logarithm.
191    #[allow(unsafe_code)]
192    pub fn ln(&self) -> Self {
193        let len = self.len();
194        // SAFETY: `uninitialized` allocates without zeroing.  `ln_dispatch`
195        // writes every element before we ever read from `out`.
196        let mut out = AlignedVec::<f32>::uninitialized(len);
197        simd::ln_dispatch(self.data(), &mut out);
198        Tensor::from_raw_parts(self.shape(), self.strides(), out)
199    }
200
201    /// Element-wise square root.
202    pub fn sqrt(&self) -> Self {
203        self.unary_simd_op(simd::UnaryKind::Sqrt)
204    }
205
206    /// Element-wise reciprocal (`1 / x`).
207    pub fn reciprocal(&self) -> Self {
208        self.unary_simd_op(simd::UnaryKind::Recip)
209    }
210
211    /// Element-wise sign (`-1`, `0`, or `1`).
212    pub fn sign(&self) -> Self {
213        self.unary_simd_op(simd::UnaryKind::Sign)
214    }
215
216    /// Element-wise floor.
217    pub fn floor(&self) -> Self {
218        self.unary_simd_op(simd::UnaryKind::Floor)
219    }
220
221    /// Element-wise ceil.
222    pub fn ceil(&self) -> Self {
223        self.unary_simd_op(simd::UnaryKind::Ceil)
224    }
225
226    /// Element-wise round.
227    pub fn round(&self) -> Self {
228        self.unary_simd_op(simd::UnaryKind::Round)
229    }
230
231    /// Element-wise sine (SIMD-accelerated polynomial approximation).
232    #[allow(unsafe_code)]
233    pub fn sin(&self) -> Self {
234        let len = self.len();
235        let mut out = AlignedVec::<f32>::uninitialized(len);
236        simd::sin_dispatch(self.data(), &mut out);
237        Tensor::from_raw_parts(self.shape(), self.strides(), out)
238    }
239
240    /// Element-wise cosine (SIMD-accelerated polynomial approximation).
241    #[allow(unsafe_code)]
242    pub fn cos(&self) -> Self {
243        let len = self.len();
244        let mut out = AlignedVec::<f32>::uninitialized(len);
245        simd::cos_dispatch(self.data(), &mut out);
246        Tensor::from_raw_parts(self.shape(), self.strides(), out)
247    }
248
249    /// Element-wise tangent (SIMD-accelerated via sin/cos).
250    #[allow(unsafe_code)]
251    pub fn tan(&self) -> Self {
252        let len = self.len();
253        let mut out = AlignedVec::<f32>::uninitialized(len);
254        simd::tan_dispatch(self.data(), &mut out);
255        Tensor::from_raw_parts(self.shape(), self.strides(), out)
256    }
257
258    /// Element-wise arcsine.
259    pub fn asin(&self) -> Self {
260        self.unary_op(f32::asin)
261    }
262
263    /// Element-wise arccosine.
264    pub fn acos(&self) -> Self {
265        self.unary_op(f32::acos)
266    }
267
268    /// Element-wise arctangent.
269    pub fn atan(&self) -> Self {
270        self.unary_op(f32::atan)
271    }
272
273    /// Element-wise hyperbolic sine.
274    pub fn sinh(&self) -> Self {
275        self.unary_op(f32::sinh)
276    }
277
278    /// Element-wise hyperbolic cosine.
279    pub fn cosh(&self) -> Self {
280        self.unary_op(f32::cosh)
281    }
282
283    /// Element-wise base-2 logarithm.
284    pub fn log2(&self) -> Self {
285        self.unary_op(f32::log2)
286    }
287
288    /// Element-wise base-10 logarithm.
289    pub fn log10(&self) -> Self {
290        self.unary_op(f32::log10)
291    }
292
293    /// Convert radians to degrees.
294    pub fn degrees(&self) -> Self {
295        self.unary_op(|v| v.to_degrees())
296    }
297
298    /// Convert degrees to radians.
299    pub fn radians(&self) -> Self {
300        self.unary_op(|v| v.to_radians())
301    }
302
303    /// Clamp all elements to `[min, max]`.
304    #[allow(unsafe_code)]
305    pub fn clamp(&self, min: f32, max: f32) -> Self {
306        let len = self.len();
307        let mut out = AlignedVec::<f32>::uninitialized(len);
308        simd::clamp_dispatch(self.data(), &mut out, min, max);
309        Tensor::from_raw_parts(self.shape(), self.strides(), out)
310    }
311
312    /// Scalar multiplication (broadcast multiply by a constant).
313    pub fn scale(&self, factor: f32) -> Self {
314        self.unary_op(|v| v * factor)
315    }
316
317    /// Add a scalar to all elements.
318    pub fn add_scalar(&self, value: f32) -> Self {
319        self.unary_op(|v| v + value)
320    }
321
322    // ── Reductions ──────────────────────────────────────────────────────
323
324    /// Sum reduction over all elements.
325    pub fn sum(&self) -> f32 {
326        simd::sum_dispatch(self.data())
327    }
328
329    /// Mean reduction over all elements.
330    pub fn mean(&self) -> f32 {
331        if self.is_empty() {
332            return f32::NAN;
333        }
334        self.sum() / self.len() as f32
335    }
336
337    /// Global max reduction. Returns `f32::NEG_INFINITY` for empty tensors.
338    pub fn max_value(&self) -> f32 {
339        simd::max_dispatch(self.data())
340    }
341
342    /// Global min reduction. Returns `f32::INFINITY` for empty tensors.
343    pub fn min_value(&self) -> f32 {
344        simd::min_dispatch(self.data())
345    }
346
347    /// Global argmax (flat index of maximum value).
348    pub fn argmax(&self) -> Option<usize> {
349        if self.is_empty() {
350            return None;
351        }
352        let mut best = 0;
353        let mut best_val = self.data()[0];
354        for (i, &v) in self.data().iter().enumerate().skip(1) {
355            if v > best_val {
356                best_val = v;
357                best = i;
358            }
359        }
360        Some(best)
361    }
362
363    /// Global argmin (flat index of minimum value).
364    pub fn argmin(&self) -> Option<usize> {
365        if self.is_empty() {
366            return None;
367        }
368        let mut best = 0;
369        let mut best_val = self.data()[0];
370        for (i, &v) in self.data().iter().enumerate().skip(1) {
371            if v < best_val {
372                best_val = v;
373                best = i;
374            }
375        }
376        Some(best)
377    }
378
379    /// Variance over all elements (population variance).
380    pub fn var(&self) -> f32 {
381        if self.is_empty() {
382            return f32::NAN;
383        }
384        let m = self.mean();
385        self.data().iter().map(|&v| (v - m) * (v - m)).sum::<f32>() / self.len() as f32
386    }
387
388    /// Standard deviation over all elements (population).
389    pub fn std_dev(&self) -> f32 {
390        self.var().sqrt()
391    }
392
393    /// Sum reduction over one axis. Reduced axis is removed from output shape.
394    pub fn sum_axis(&self, axis: usize) -> Result<Self, TensorError> {
395        let shape = self.shape();
396        let rank = shape.len();
397        if axis >= rank {
398            return Err(TensorError::InvalidAxis { axis, rank });
399        }
400
401        // Fast path: 2D contiguous tensor, axis 0 → accumulate rows with SIMD
402        if rank == 2 && axis == 0 {
403            let (rows, cols) = (shape[0], shape[1]);
404            let data = self.data();
405            let mut out = vec![0.0f32; cols];
406            for row in 0..rows {
407                let row_start = row * cols;
408                simd::add_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
409            }
410            return Self::from_vec(vec![cols], out);
411        }
412
413        // Fast path: 2D contiguous tensor, axis 1 → sum each row with SIMD
414        if rank == 2 && axis == 1 {
415            let (rows, cols) = (shape[0], shape[1]);
416            let data = self.data();
417            let mut out = Vec::with_capacity(rows);
418            for row in 0..rows {
419                out.push(simd::sum_dispatch(&data[row * cols..(row + 1) * cols]));
420            }
421            return Self::from_vec(vec![rows], out);
422        }
423
424        self.reduce_axis(axis, 0.0, |acc, v| acc + v)
425    }
426
427    /// Mean reduction over one axis. Reduced axis is removed from output shape.
428    pub fn mean_axis(&self, axis: usize) -> Result<Self, TensorError> {
429        if axis >= self.rank() {
430            return Err(TensorError::InvalidAxis {
431                axis,
432                rank: self.rank(),
433            });
434        }
435        let axis_len = self.shape()[axis] as f32;
436        let sum = self.sum_axis(axis)?;
437        Ok(sum.scale(1.0 / axis_len))
438    }
439
440    /// Max reduction over one axis. Reduced axis is removed from output shape.
441    pub fn max_axis(&self, axis: usize) -> Result<Self, TensorError> {
442        let shape = self.shape();
443        let rank = shape.len();
444        if axis >= rank {
445            return Err(TensorError::InvalidAxis { axis, rank });
446        }
447
448        // Fast path: 2D contiguous tensor, axis 0 → accumulate rows with SIMD max
449        if rank == 2 && axis == 0 {
450            let (rows, cols) = (shape[0], shape[1]);
451            let data = self.data();
452            let mut out = data[..cols].to_vec();
453            for row in 1..rows {
454                let row_start = row * cols;
455                simd::max_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
456            }
457            return Self::from_vec(vec![cols], out);
458        }
459
460        // Fast path: 2D contiguous tensor, axis 1 → max each row with SIMD
461        if rank == 2 && axis == 1 {
462            let (rows, cols) = (shape[0], shape[1]);
463            let data = self.data();
464            let mut out = Vec::with_capacity(rows);
465            for row in 0..rows {
466                out.push(simd::max_dispatch(&data[row * cols..(row + 1) * cols]));
467            }
468            return Self::from_vec(vec![rows], out);
469        }
470
471        self.reduce_axis(axis, f32::NEG_INFINITY, f32::max)
472    }
473
474    /// Min reduction over one axis. Reduced axis is removed from output shape.
475    pub fn min_axis(&self, axis: usize) -> Result<Self, TensorError> {
476        self.reduce_axis(axis, f32::INFINITY, f32::min)
477    }
478
479    /// Variance reduction over one axis (population variance).
480    pub fn var_axis(&self, axis: usize) -> Result<Self, TensorError> {
481        let m = self.mean_axis(axis)?;
482        let diff = self.sub(&m.unsqueeze(axis)?)?;
483        let sq = diff.mul(&diff)?;
484        sq.mean_axis(axis)
485    }
486
487    /// Global median of all elements.
488    pub fn median(&self) -> f32 {
489        let mut sorted = self.data().to_vec();
490        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491        let n = sorted.len();
492        if n == 0 {
493            return 0.0;
494        }
495        if n % 2 == 1 {
496            sorted[n / 2]
497        } else {
498            (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
499        }
500    }
501
502    /// Median along a given axis.
503    pub fn median_axis(&self, axis: usize) -> Result<Self, TensorError> {
504        let shape = self.shape();
505        let rank = shape.len();
506        if axis >= rank {
507            return Err(TensorError::InvalidAxis { axis, rank });
508        }
509        let axis_len = shape[axis];
510        let outer: usize = shape[..axis].iter().product();
511        let inner: usize = shape[axis + 1..].iter().product();
512        let mut new_shape = shape.to_vec();
513        new_shape.remove(axis);
514        if new_shape.is_empty() {
515            new_shape.push(1);
516        }
517        let data = self.data();
518        let mut result = Vec::with_capacity(outer * inner);
519        for o in 0..outer {
520            for i in 0..inner {
521                let mut vals: Vec<f32> = (0..axis_len)
522                    .map(|a| data[o * axis_len * inner + a * inner + i])
523                    .collect();
524                vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
525                let n = vals.len();
526                let med = if n % 2 == 1 {
527                    vals[n / 2]
528                } else {
529                    (vals[n / 2 - 1] + vals[n / 2]) / 2.0
530                };
531                result.push(med);
532            }
533        }
534        Self::from_vec(new_shape, result)
535    }
536
537    /// Returns true if any element is non-zero.
538    pub fn any(&self) -> bool {
539        self.data().iter().any(|&v| v != 0.0)
540    }
541
542    /// Returns true if all elements are non-zero.
543    pub fn all(&self) -> bool {
544        self.data().iter().all(|&v| v != 0.0)
545    }
546
547    /// Quantile of all elements. q must be in [0, 1].
548    pub fn quantile(&self, q: f32) -> f32 {
549        let mut sorted = self.data().to_vec();
550        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
551        let n = sorted.len();
552        if n == 0 {
553            return 0.0;
554        }
555        let idx = q * (n - 1) as f32;
556        let lo = idx.floor() as usize;
557        let hi = idx.ceil() as usize;
558        if lo == hi || hi >= n {
559            sorted[lo.min(n - 1)]
560        } else {
561            let frac = idx - lo as f32;
562            sorted[lo] * (1.0 - frac) + sorted[hi] * frac
563        }
564    }
565
566    // ── Shape manipulation ──────────────────────────────────────────────
567
    /// 2D matrix transpose. Requires rank-2 input; output shape is
    /// `[cols, rows]`.
    ///
    /// # Errors
    /// Returns `TensorError::InvalidAxis` when the tensor is not rank 2.
    ///
    /// # Safety
    /// `AlignedVec::uninitialized` allocates without zeroing. The tiled loop
    /// writes every element before anything reads from the buffer.
    #[allow(unsafe_code)]
    pub fn transpose_2d(&self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::InvalidAxis {
                axis: 1,
                rank: self.rank(),
            });
        }
        let rows = self.shape()[0];
        let cols = self.shape()[1];
        // SAFETY: every element is written by the tiled loop below before we read.
        let mut out_data = AlignedVec::<f32>::uninitialized(rows * cols);
        let src = self.data();

        // Tiled transpose with 8x8 blocks for cache efficiency.
        const TILE: usize = 8;
        // Largest multiples of TILE ≤ rows/cols: bounds of the full-tile region.
        let rr = rows / TILE * TILE;
        let cc = cols / TILE * TILE;

        for ii in (0..rr).step_by(TILE) {
            for jj in (0..cc).step_by(TILE) {
                // Full 8x8 tile: read row-major, write column-major.
                for r in ii..ii + TILE {
                    for c in jj..jj + TILE {
                        out_data[c * rows + r] = src[r * cols + c];
                    }
                }
            }
            // Right edge (columns beyond cc) for this band of rows.
            for r in ii..ii + TILE {
                for c in cc..cols {
                    out_data[c * rows + r] = src[r * cols + c];
                }
            }
        }
        // Bottom edge (rows beyond rr), all columns.
        for r in rr..rows {
            for c in 0..cols {
                out_data[c * rows + r] = src[r * cols + c];
            }
        }

        Tensor::from_aligned(vec![cols, rows], out_data)
    }
616
617    /// General axis permutation (like NumPy `transpose`/`permute`).
618    pub fn permute(&self, axes: &[usize]) -> Result<Self, TensorError> {
619        if axes.len() != self.rank() {
620            return Err(TensorError::InvalidIndexRank {
621                expected: self.rank(),
622                got: axes.len(),
623            });
624        }
625        let rank = self.rank();
626        let mut seen = vec![false; rank];
627        for &a in axes {
628            if a >= rank {
629                return Err(TensorError::InvalidAxis { axis: a, rank });
630            }
631            seen[a] = true;
632        }
633        if seen.iter().any(|&s| !s) {
634            return Err(TensorError::InvalidAxis { axis: 0, rank });
635        }
636
637        let src_shape = self.shape();
638        let mut out_shape = vec![0usize; rank];
639        for (dst, &src_axis) in axes.iter().enumerate() {
640            out_shape[dst] = src_shape[src_axis];
641        }
642        let out_count =
643            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
644                shape: out_shape.clone(),
645            })?;
646
647        // ── Fast path: tiled 2D transpose for common 4D permutations ──
648        // Uses unsafe pointer arithmetic to eliminate bounds checks in hot inner loops.
649        // NHWC→NCHW [0,3,1,2]: transpose inner [H*W, C] → [C, H*W]
650        if rank == 4 && axes == [0, 3, 1, 2] {
651            let (n, h, w, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
652            let hw = h * w;
653            let src = self.data();
654            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
655            const TILE: usize = 32;
656            #[allow(unsafe_code)]
657            unsafe {
658                let src_ptr = src.as_ptr();
659                let dst_ptr = dst.as_mut_ptr();
660                for batch in 0..n {
661                    let s_base = src_ptr.add(batch * hw * c);
662                    let d_base = dst_ptr.add(batch * c * hw);
663                    for i0 in (0..hw).step_by(TILE) {
664                        let ie = (i0 + TILE).min(hw);
665                        for j0 in (0..c).step_by(TILE) {
666                            let je = (j0 + TILE).min(c);
667                            for i in i0..ie {
668                                let s_row = s_base.add(i * c);
669                                for j in j0..je {
670                                    *d_base.add(j * hw + i) = *s_row.add(j);
671                                }
672                            }
673                        }
674                    }
675                }
676            }
677            return Tensor::from_aligned(out_shape, dst);
678        }
679        // NCHW→NHWC [0,2,3,1]: transpose inner [C, H*W] → [H*W, C]
680        if rank == 4 && axes == [0, 2, 3, 1] {
681            let (n, c, h, w) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
682            let hw = h * w;
683            let src = self.data();
684            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
685            const TILE: usize = 32;
686            #[allow(unsafe_code)]
687            unsafe {
688                let src_ptr = src.as_ptr();
689                let dst_ptr = dst.as_mut_ptr();
690                for batch in 0..n {
691                    let s_base = src_ptr.add(batch * c * hw);
692                    let d_base = dst_ptr.add(batch * hw * c);
693                    for i0 in (0..c).step_by(TILE) {
694                        let ie = (i0 + TILE).min(c);
695                        for j0 in (0..hw).step_by(TILE) {
696                            let je = (j0 + TILE).min(hw);
697                            for i in i0..ie {
698                                let s_row = s_base.add(i * hw);
699                                for j in j0..je {
700                                    *d_base.add(j * c + i) = *s_row.add(j);
701                                }
702                            }
703                        }
704                    }
705                }
706            }
707            return Tensor::from_aligned(out_shape, dst);
708        }
709        // 3D swap last two dims [0,2,1]: transpose [A, B, C] → [A, C, B]
710        if rank == 3 && axes == [0, 2, 1] {
711            let (a, b, c) = (src_shape[0], src_shape[1], src_shape[2]);
712            let src = self.data();
713            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
714            const TILE: usize = 32;
715            #[allow(unsafe_code)]
716            unsafe {
717                let src_ptr = src.as_ptr();
718                let dst_ptr = dst.as_mut_ptr();
719                for batch in 0..a {
720                    let s_base = src_ptr.add(batch * b * c);
721                    let d_base = dst_ptr.add(batch * c * b);
722                    for i0 in (0..b).step_by(TILE) {
723                        let ie = (i0 + TILE).min(b);
724                        for j0 in (0..c).step_by(TILE) {
725                            let je = (j0 + TILE).min(c);
726                            for i in i0..ie {
727                                let s_row = s_base.add(i * c);
728                                for j in j0..je {
729                                    *d_base.add(j * b + i) = *s_row.add(j);
730                                }
731                            }
732                        }
733                    }
734                }
735            }
736            return Tensor::from_aligned(out_shape, dst);
737        }
738
739        // [0,1,3,2]: swap last two dims in 4D → [N, A, C, B]
740        // For each (n, a), tiled 2D transpose of [B, C] → [C, B].
741        if rank == 4 && axes == [0, 1, 3, 2] {
742            let (nn, a, b, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
743            let src = self.data();
744            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
745            const TILE: usize = 32;
746            #[allow(unsafe_code)]
747            unsafe {
748                let src_ptr = src.as_ptr();
749                let dst_ptr = dst.as_mut_ptr();
750                for n in 0..nn {
751                    for aa in 0..a {
752                        let base = (n * a + aa) * b * c;
753                        let s_base = src_ptr.add(base);
754                        let d_base = dst_ptr.add(base); // same offset, different shape
755                        for i0 in (0..b).step_by(TILE) {
756                            let ie = (i0 + TILE).min(b);
757                            for j0 in (0..c).step_by(TILE) {
758                                let je = (j0 + TILE).min(c);
759                                for i in i0..ie {
760                                    let s_row = s_base.add(i * c);
761                                    for j in j0..je {
762                                        *d_base.add(j * b + i) = *s_row.add(j);
763                                    }
764                                }
765                            }
766                        }
767                    }
768                }
769            }
770            return Tensor::from_aligned(out_shape, dst);
771        }
772        // [0,2,1,3]: swap dims 1↔2 in 4D → [N, B, A, C]
773        // Each element in the swap is a contiguous block of C floats — use memcpy.
774        if rank == 4 && axes == [0, 2, 1, 3] {
775            let (nn, a, b, c) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
776            let src = self.data();
777            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
778            #[allow(unsafe_code)]
779            unsafe {
780                let src_ptr = src.as_ptr();
781                let dst_ptr = dst.as_mut_ptr();
782                for n in 0..nn {
783                    let s_batch = src_ptr.add(n * a * b * c);
784                    let d_batch = dst_ptr.add(n * b * a * c);
785                    for aa in 0..a {
786                        for bb in 0..b {
787                            std::ptr::copy_nonoverlapping(
788                                s_batch.add(aa * b * c + bb * c),
789                                d_batch.add(bb * a * c + aa * c),
790                                c,
791                            );
792                        }
793                    }
794                }
795            }
796            return Tensor::from_aligned(out_shape, dst);
797        }
798        // [0,3,2,1]: swap dims 1↔3 in 4D → [N, D, B, A]
799        // For each (n, b), tiled 2D transpose of [A, D] → [D, A] with strides.
800        if rank == 4 && axes == [0, 3, 2, 1] {
801            let (nn, a, b, d) = (src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
802            let src = self.data();
803            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
804            let src_a_stride = b * d;
805            let dst_d_stride = b * a;
806            const TILE: usize = 32;
807            #[allow(unsafe_code)]
808            unsafe {
809                let src_ptr = src.as_ptr();
810                let dst_ptr = dst.as_mut_ptr();
811                for n in 0..nn {
812                    for bb in 0..b {
813                        let s_base = src_ptr.add(n * a * b * d + bb * d);
814                        let d_base = dst_ptr.add(n * d * b * a + bb * a);
815                        for i0 in (0..a).step_by(TILE) {
816                            let ie = (i0 + TILE).min(a);
817                            for j0 in (0..d).step_by(TILE) {
818                                let je = (j0 + TILE).min(d);
819                                for i in i0..ie {
820                                    for j in j0..je {
821                                        *d_base.add(j * dst_d_stride + i) =
822                                            *s_base.add(i * src_a_stride + j);
823                                    }
824                                }
825                            }
826                        }
827                    }
828                }
829            }
830            return Tensor::from_aligned(out_shape, dst);
831        }
832        // 2D transpose [1,0]: swap rows and cols
833        if rank == 2 && axes == [1, 0] {
834            let (rows, cols) = (src_shape[0], src_shape[1]);
835            let src = self.data();
836            let mut dst = AlignedVec::<f32>::uninitialized(out_count);
837            const TILE: usize = 32;
838            #[allow(unsafe_code)]
839            unsafe {
840                let src_ptr = src.as_ptr();
841                let dst_ptr = dst.as_mut_ptr();
842                for i0 in (0..rows).step_by(TILE) {
843                    let ie = (i0 + TILE).min(rows);
844                    for j0 in (0..cols).step_by(TILE) {
845                        let je = (j0 + TILE).min(cols);
846                        for i in i0..ie {
847                            let s_row = src_ptr.add(i * cols);
848                            for j in j0..je {
849                                *dst_ptr.add(j * rows + i) = *s_row.add(j);
850                            }
851                        }
852                    }
853                }
854            }
855            return Tensor::from_aligned(out_shape, dst);
856        }
857
858        // ── General fallback: coordinate-based scatter ──
859        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
860            shape: out_shape.clone(),
861        })?;
862        let mut out_data = vec![0.0f32; out_count];
863
864        let mut in_coords = vec![0usize; rank];
865        for &val in self.data().iter() {
866            let mut out_offset = 0usize;
867            for (dst_axis, &src_axis) in axes.iter().enumerate() {
868                out_offset += in_coords[src_axis] * out_strides[dst_axis];
869            }
870            out_data[out_offset] = val;
871            increment_coords(&mut in_coords, src_shape);
872        }
873
874        Tensor::from_vec(out_shape, out_data)
875    }
876
877    /// Insert a length-1 dimension at the given axis.
878    pub fn unsqueeze(&self, axis: usize) -> Result<Self, TensorError> {
879        if axis > self.rank() {
880            return Err(TensorError::InvalidAxis {
881                axis,
882                rank: self.rank() + 1,
883            });
884        }
885        let mut new_shape = self.shape().to_vec();
886        new_shape.insert(axis, 1);
887        self.reshape(new_shape)
888    }
889
890    /// Remove a length-1 dimension at the given axis.
891    pub fn squeeze(&self, axis: usize) -> Result<Self, TensorError> {
892        if axis >= self.rank() {
893            return Err(TensorError::InvalidAxis {
894                axis,
895                rank: self.rank(),
896            });
897        }
898        if self.shape()[axis] != 1 {
899            return Err(TensorError::InvalidAxis {
900                axis,
901                rank: self.rank(),
902            });
903        }
904        let mut new_shape = self.shape().to_vec();
905        new_shape.remove(axis);
906        self.reshape(new_shape)
907    }
908
    /// Concatenate tensors along an axis. All tensors must have the same
    /// shape except along the concatenation axis.
    ///
    /// # Errors
    /// - `SizeMismatch` (with an empty shape) when `tensors` is empty.
    /// - `InvalidAxis` when `axis >= rank` of the first tensor.
    /// - `ShapeMismatch` when ranks differ or any non-concat dimension differs.
    /// - `SizeOverflow` when the output element count overflows `usize`.
    pub fn cat(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
        if tensors.is_empty() {
            return Err(TensorError::SizeMismatch {
                shape: vec![],
                data_len: 0,
            });
        }
        let rank = tensors[0].rank();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }
        // Validate every other tensor against the first: same rank, and same
        // extent on every axis except the concatenation axis.
        for t in &tensors[1..] {
            if t.rank() != rank {
                return Err(TensorError::ShapeMismatch {
                    left: tensors[0].shape().to_vec(),
                    right: t.shape().to_vec(),
                });
            }
            for (a, (&d0, &di)) in tensors[0].shape().iter().zip(t.shape().iter()).enumerate() {
                if a != axis && d0 != di {
                    return Err(TensorError::ShapeMismatch {
                        left: tensors[0].shape().to_vec(),
                        right: t.shape().to_vec(),
                    });
                }
            }
        }

        // Output shape: first tensor's shape with the concat axis summed.
        let mut out_shape = tensors[0].shape().to_vec();
        out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;

        // Decompose the layout as [outer, axis, inner]; each input contributes
        // one contiguous chunk per outer block.
        let outer: usize = out_shape[..axis].iter().product();
        let inner: usize = out_shape[axis + 1..].iter().product();

        // Write directly into AlignedVec to avoid the double-copy through
        // Vec -> AlignedVec::from_vec.
        let mut out_data = AlignedVec::<f32>::uninitialized(out_count);

        // NOTE(review): when inner == 1 the general branch below performs
        // identical copies (chunk_len == axis_len); the `<= 8` cap on this
        // path looks like a heuristic — confirm it is still wanted.
        if inner == 1 && tensors.len() <= 8 {
            // Last-axis concat: write entire output in outer-major order.
            let axis_lens: Vec<usize> = tensors.iter().map(|t| t.shape()[axis]).collect();
            let dst = out_data.as_mut_slice();
            let mut dst_off = 0;
            for o in 0..outer {
                for (ti, t) in tensors.iter().enumerate() {
                    let al = axis_lens[ti];
                    let src_off = o * al;
                    dst[dst_off..dst_off + al].copy_from_slice(&t.data()[src_off..src_off + al]);
                    dst_off += al;
                }
            }
        } else {
            // General case: for each outer block, copy each tensor's
            // [axis_len * inner] chunk in input order.
            let dst = out_data.as_mut_slice();
            let mut dst_off = 0;
            for o in 0..outer {
                for t in tensors {
                    let t_axis_len = t.shape()[axis];
                    let chunk_len = t_axis_len * inner;
                    let chunk_start = o * chunk_len;
                    dst[dst_off..dst_off + chunk_len]
                        .copy_from_slice(&t.data()[chunk_start..chunk_start + chunk_len]);
                    dst_off += chunk_len;
                }
            }
        }

        Tensor::from_aligned(out_shape, out_data)
    }
983
984    /// Stack tensors along a new axis. All tensors must have identical shapes.
985    pub fn stack(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
986        if tensors.is_empty() {
987            return Err(TensorError::SizeMismatch {
988                shape: vec![],
989                data_len: 0,
990            });
991        }
992        if axis > tensors[0].rank() {
993            return Err(TensorError::InvalidAxis {
994                axis,
995                rank: tensors[0].rank() + 1,
996            });
997        }
998        let expanded: Vec<Self> = tensors
999            .iter()
1000            .map(|t| t.unsqueeze(axis))
1001            .collect::<Result<_, _>>()?;
1002        let refs: Vec<&Self> = expanded.iter().collect();
1003        Self::cat(&refs, axis)
1004    }
1005
1006    /// Select a single slice along an axis, removing that axis from the output.
1007    pub fn select(&self, axis: usize, index: usize) -> Result<Self, TensorError> {
1008        if axis >= self.rank() {
1009            return Err(TensorError::InvalidAxis {
1010                axis,
1011                rank: self.rank(),
1012            });
1013        }
1014        if index >= self.shape()[axis] {
1015            return Err(TensorError::IndexOutOfBounds {
1016                axis,
1017                index,
1018                dim: self.shape()[axis],
1019            });
1020        }
1021        let outer: usize = self.shape()[..axis].iter().product();
1022        let axis_len = self.shape()[axis];
1023        let inner: usize = self.shape()[axis + 1..].iter().product();
1024
1025        let mut out_data = Vec::with_capacity(outer * inner);
1026        for o in 0..outer {
1027            let base = o * axis_len * inner + index * inner;
1028            out_data.extend_from_slice(&self.data()[base..base + inner]);
1029        }
1030
1031        let mut out_shape = self.shape().to_vec();
1032        out_shape.remove(axis);
1033        Tensor::from_vec(out_shape, out_data)
1034    }
1035
1036    /// Narrow (slice) along an axis: extract elements `start..start+length`.
1037    pub fn narrow(&self, axis: usize, start: usize, length: usize) -> Result<Self, TensorError> {
1038        if axis >= self.rank() {
1039            return Err(TensorError::InvalidAxis {
1040                axis,
1041                rank: self.rank(),
1042            });
1043        }
1044        if start + length > self.shape()[axis] {
1045            return Err(TensorError::IndexOutOfBounds {
1046                axis,
1047                index: start + length,
1048                dim: self.shape()[axis],
1049            });
1050        }
1051        let outer: usize = self.shape()[..axis].iter().product();
1052        let axis_len = self.shape()[axis];
1053        let inner: usize = self.shape()[axis + 1..].iter().product();
1054
1055        let mut out_data = Vec::with_capacity(outer * length * inner);
1056        for o in 0..outer {
1057            let base = o * axis_len * inner + start * inner;
1058            out_data.extend_from_slice(&self.data()[base..base + length * inner]);
1059        }
1060
1061        let mut out_shape = self.shape().to_vec();
1062        out_shape[axis] = length;
1063        Tensor::from_vec(out_shape, out_data)
1064    }
1065
1066    // ── Comparison ──────────────────────────────────────────────────────
1067
    /// Element-wise equality check: 1.0 where equal, 0.0 otherwise.
    ///
    /// "Equal" means `|l - r| < f32::EPSILON` (~1.19e-7) — an *absolute*
    /// tolerance, so it acts like exact equality for large magnitudes and is
    /// loose only near zero.
    ///
    /// NOTE(review): `ne_tensor` uses a `1e-7` threshold instead, so for
    /// differences in `[1e-7, f32::EPSILON)` both eq and ne report 1.0 —
    /// confirm whether the two thresholds should be unified. Also verify the
    /// SIMD path (`eq_tensor_into`, `simd::CmpKind::Eq`) applies the same
    /// tolerance as this scalar path.
    pub fn eq_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
        // Same-shape fast path; otherwise fall back to broadcasting.
        if self.shape() == rhs.shape() {
            return self.binary_same_shape(rhs, |l, r| {
                if (l - r).abs() < f32::EPSILON {
                    1.0
                } else {
                    0.0
                }
            });
        }
        self.binary_broadcast_op(rhs, |l, r| {
            if (l - r).abs() < f32::EPSILON {
                1.0
            } else {
                0.0
            }
        })
    }
1087
1088    /// Element-wise greater-than: 1.0 where `self > rhs`, 0.0 otherwise.
1089    pub fn gt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1090        if self.shape() == rhs.shape() {
1091            return self.binary_same_shape(rhs, |l, r| if l > r { 1.0 } else { 0.0 });
1092        }
1093        self.binary_broadcast_op(rhs, |l, r| if l > r { 1.0 } else { 0.0 })
1094    }
1095
1096    /// Element-wise less-than: 1.0 where `self < rhs`, 0.0 otherwise.
1097    pub fn lt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1098        if self.shape() == rhs.shape() {
1099            return self.binary_same_shape(rhs, |l, r| if l < r { 1.0 } else { 0.0 });
1100        }
1101        self.binary_broadcast_op(rhs, |l, r| if l < r { 1.0 } else { 0.0 })
1102    }
1103
1104    /// Element-wise greater-than writing into a pre-allocated output tensor.
1105    /// `self`, `rhs`, and `output` must all have the same shape.
1106    pub fn gt_tensor_into(&self, rhs: &Self, output: &mut Self) {
1107        debug_assert_eq!(self.shape(), rhs.shape());
1108        debug_assert_eq!(self.shape(), output.shape());
1109        simd::cmp_dispatch(
1110            self.data(),
1111            rhs.data(),
1112            output.data_mut(),
1113            simd::CmpKind::Gt,
1114        );
1115    }
1116
1117    /// Element-wise equality check writing into a pre-allocated output tensor.
1118    /// `self`, `rhs`, and `output` must all have the same shape.
1119    pub fn eq_tensor_into(&self, rhs: &Self, output: &mut Self) {
1120        debug_assert_eq!(self.shape(), rhs.shape());
1121        debug_assert_eq!(self.shape(), output.shape());
1122        simd::cmp_dispatch(
1123            self.data(),
1124            rhs.data(),
1125            output.data_mut(),
1126            simd::CmpKind::Eq,
1127        );
1128    }
1129
1130    /// Element-wise less-than writing into a pre-allocated output tensor.
1131    /// `self`, `rhs`, and `output` must all have the same shape.
1132    pub fn lt_tensor_into(&self, rhs: &Self, output: &mut Self) {
1133        debug_assert_eq!(self.shape(), rhs.shape());
1134        debug_assert_eq!(self.shape(), output.shape());
1135        simd::cmp_dispatch(
1136            self.data(),
1137            rhs.data(),
1138            output.data_mut(),
1139            simd::CmpKind::Lt,
1140        );
1141    }
1142
    /// Element-wise not-equal: 1.0 where not equal (diff.abs() >= 1e-7), 0.0 otherwise.
    ///
    /// NOTE(review): `eq_tensor` uses `f32::EPSILON` (~1.19e-7) as its
    /// tolerance, so for differences in `[1e-7, f32::EPSILON)` both eq and ne
    /// report 1.0 — confirm whether the thresholds should be unified.
    pub fn ne_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
        // Same-shape fast path; otherwise fall back to broadcasting.
        if self.shape() == rhs.shape() {
            return self.binary_same_shape(
                rhs,
                |l, r| {
                    if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 }
                },
            );
        }
        self.binary_broadcast_op(rhs, |l, r| if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 })
    }
1155
1156    /// Element-wise less-than-or-equal: 1.0 where `self <= rhs`, 0.0 otherwise.
1157    pub fn le_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1158        if self.shape() == rhs.shape() {
1159            return self.binary_same_shape(rhs, |l, r| if l <= r { 1.0 } else { 0.0 });
1160        }
1161        self.binary_broadcast_op(rhs, |l, r| if l <= r { 1.0 } else { 0.0 })
1162    }
1163
1164    /// Element-wise greater-than-or-equal: 1.0 where `self >= rhs`, 0.0 otherwise.
1165    pub fn ge_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
1166        if self.shape() == rhs.shape() {
1167            return self.binary_same_shape(rhs, |l, r| if l >= r { 1.0 } else { 0.0 });
1168        }
1169        self.binary_broadcast_op(rhs, |l, r| if l >= r { 1.0 } else { 0.0 })
1170    }
1171
1172    /// Returns true if all elements are finite (no NaN or Inf).
1173    pub fn all_finite(&self) -> bool {
1174        self.data().iter().all(|v| v.is_finite())
1175    }
1176
1177    // ── Advanced indexing/selection ─────────────────────────────────────
1178
1179    /// Element-wise where: `condition ? self : other`.
1180    /// `condition` has 1.0 for true, 0.0 for false.
1181    pub fn where_cond(&self, condition: &Self, other: &Self) -> Result<Self, TensorError> {
1182        if self.shape() != condition.shape() || self.shape() != other.shape() {
1183            return Err(TensorError::ShapeMismatch {
1184                left: self.shape().to_vec(),
1185                right: condition.shape().to_vec(),
1186            });
1187        }
1188        let data: Vec<f32> = condition
1189            .data()
1190            .iter()
1191            .zip(self.data().iter())
1192            .zip(other.data().iter())
1193            .map(|((&c, &t), &f)| if c != 0.0 { t } else { f })
1194            .collect();
1195        Tensor::from_vec(self.shape().to_vec(), data)
1196    }
1197
1198    /// Replace elements where `mask != 0` with `value`.
1199    pub fn masked_fill(&self, mask: &Self, value: f32) -> Result<Self, TensorError> {
1200        if self.shape() != mask.shape() {
1201            return Err(TensorError::ShapeMismatch {
1202                left: self.shape().to_vec(),
1203                right: mask.shape().to_vec(),
1204            });
1205        }
1206        let data: Vec<f32> = self
1207            .data()
1208            .iter()
1209            .zip(mask.data().iter())
1210            .map(|(&v, &m)| if m != 0.0 { value } else { v })
1211            .collect();
1212        Tensor::from_vec(self.shape().to_vec(), data)
1213    }
1214
    /// Scatter values into self along `axis` at positions given by `index`.
    /// `src` provides the values. `index` has same shape as `src`.
    ///
    /// Starting from a copy of `self`, for every position `p` in `index` this
    /// writes `src[p]` into `self` at axis-coordinate `index[p]`. Indices are
    /// f32 converted with `as usize`, which in Rust saturates: NaN and
    /// negative values become 0, values beyond `usize::MAX` clamp. Indices
    /// `>= self.shape()[axis]` are silently skipped rather than reported.
    ///
    /// NOTE(review): `index`/`src` are only checked against each other, not
    /// against `self`; if their trailing dims differ from `self`'s, `inner`
    /// and `self_inner` diverge and correctness leans on the
    /// `out_idx < out.len()` guard — confirm whether mismatched trailing dims
    /// should instead be a `ShapeMismatch` error.
    pub fn scatter(&self, axis: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
        if index.shape() != src.shape() {
            return Err(TensorError::ShapeMismatch {
                left: index.shape().to_vec(),
                right: src.shape().to_vec(),
            });
        }
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Output starts as a copy of self; only targeted cells are replaced.
        let mut out = self.data().to_vec();
        // Decompose `index`'s layout as [outer, dim, inner] and `self`'s axis
        // extent/inner separately, since the shapes may differ along `axis`.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let src_idx = (o * dim + d) * inner + i;
                    // Saturating f32 -> usize cast (NaN/negative -> 0).
                    let target_d = index.data()[src_idx] as usize;
                    if target_d < self_dim {
                        let out_idx = (o * self_dim + target_d) * self_inner + i;
                        if out_idx < out.len() {
                            out[out_idx] = src.data()[src_idx];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(self.shape().to_vec(), out)
    }
1254
    /// Gather elements along `axis` at positions given by `index`.
    ///
    /// The output has `index`'s shape; each output cell holds the element of
    /// `self` whose axis-coordinate is the corresponding index value. Indices
    /// are f32 converted with `as usize`, which in Rust saturates: NaN and
    /// negative values become 0, values beyond `usize::MAX` clamp.
    /// Out-of-range indices leave the corresponding output cell at 0.0
    /// rather than producing an error.
    pub fn gather(&self, axis: usize, index: &Self) -> Result<Self, TensorError> {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Decompose `index`'s layout as [outer, dim, inner]; `self` may have a
        // different extent along `axis`, so its dim/inner are tracked apart.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        // Zero-initialized so skipped (out-of-range) lookups read as 0.0.
        let mut out = vec![0.0f32; index.len()];
        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let idx_pos = (o * dim + d) * inner + i;
                    // Saturating f32 -> usize cast (NaN/negative -> 0).
                    let src_d = index.data()[idx_pos] as usize;
                    if src_d < self_dim {
                        let src_pos = (o * self_dim + src_d) * self_inner + i;
                        if src_pos < self.len() {
                            out[idx_pos] = self.data()[src_pos];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
1287
1288    /// Returns the top-k values and their indices along the last axis.
1289    pub fn topk(&self, k: usize) -> Result<(Self, Self), TensorError> {
1290        if self.rank() == 0 {
1291            return Err(TensorError::InvalidAxis { axis: 0, rank: 0 });
1292        }
1293        let last_dim = *self.shape().last().expect("non-empty shape");
1294        let k = k.min(last_dim);
1295        let outer: usize = self.len() / last_dim;
1296
1297        let mut values = Vec::with_capacity(outer * k);
1298        let mut indices = Vec::with_capacity(outer * k);
1299
1300        for o in 0..outer {
1301            let start = o * last_dim;
1302            let slice = &self.data()[start..start + last_dim];
1303            let mut idx_vec: Vec<usize> = (0..last_dim).collect();
1304            idx_vec.sort_unstable_by(|&a, &b| {
1305                slice[b]
1306                    .partial_cmp(&slice[a])
1307                    .unwrap_or(std::cmp::Ordering::Equal)
1308            });
1309            for &i in &idx_vec[..k] {
1310                values.push(slice[i]);
1311                indices.push(i as f32);
1312            }
1313        }
1314
1315        let mut out_shape = self.shape().to_vec();
1316        *out_shape.last_mut().expect("non-empty shape") = k;
1317        let val_t = Tensor::from_vec(out_shape.clone(), values)?;
1318        let idx_t = Tensor::from_vec(out_shape, indices)?;
1319        Ok((val_t, idx_t))
1320    }
1321
1322    /// Upper triangular mask: zero below diagonal, keep above.
1323    /// `diagonal` shifts: 0 = main, positive = above, negative = below.
1324    pub fn triu(&self, diagonal: i64) -> Result<Self, TensorError> {
1325        if self.rank() < 2 {
1326            return Err(TensorError::InvalidAxis {
1327                axis: 0,
1328                rank: self.rank(),
1329            });
1330        }
1331        let shape = self.shape();
1332        let rows = shape[shape.len() - 2];
1333        let cols = shape[shape.len() - 1];
1334        let batch: usize = shape[..shape.len() - 2].iter().product();
1335        let mut out = self.data().to_vec();
1336        for b in 0..batch {
1337            for r in 0..rows {
1338                for c in 0..cols {
1339                    if (c as i64) < (r as i64) + diagonal {
1340                        out[b * rows * cols + r * cols + c] = 0.0;
1341                    }
1342                }
1343            }
1344        }
1345        Tensor::from_vec(shape.to_vec(), out)
1346    }
1347
1348    /// Lower triangular mask: zero above diagonal, keep below.
1349    pub fn tril(&self, diagonal: i64) -> Result<Self, TensorError> {
1350        if self.rank() < 2 {
1351            return Err(TensorError::InvalidAxis {
1352                axis: 0,
1353                rank: self.rank(),
1354            });
1355        }
1356        let shape = self.shape();
1357        let rows = shape[shape.len() - 2];
1358        let cols = shape[shape.len() - 1];
1359        let batch: usize = shape[..shape.len() - 2].iter().product();
1360        let mut out = self.data().to_vec();
1361        for b in 0..batch {
1362            for r in 0..rows {
1363                for c in 0..cols {
1364                    if (c as i64) > (r as i64) + diagonal {
1365                        out[b * rows * cols + r * cols + c] = 0.0;
1366                    }
1367                }
1368            }
1369        }
1370        Tensor::from_vec(shape.to_vec(), out)
1371    }
1372
1373    /// Identity matrix `[n, n]`.
1374    pub fn eye(n: usize) -> Result<Self, TensorError> {
1375        let mut data = vec![0.0f32; n * n];
1376        for i in 0..n {
1377            data[i * n + i] = 1.0;
1378        }
1379        Tensor::from_vec(vec![n, n], data)
1380    }
1381
1382    /// Create a diagonal matrix from a 1D vector.
1383    pub fn diag(vector: &Tensor) -> Result<Self, TensorError> {
1384        let shape = vector.shape();
1385        if shape.len() != 1 {
1386            return Err(TensorError::UnsupportedOperation {
1387                msg: format!("diag requires a 1D tensor, got shape {:?}", shape),
1388            });
1389        }
1390        let n = shape[0];
1391        let mut data = vec![0.0f32; n * n];
1392        for i in 0..n {
1393            data[i * n + i] = vector.data()[i];
1394        }
1395        Self::from_vec(vec![n, n], data)
1396    }
1397
1398    /// Extract the diagonal of a 2D matrix as a 1D vector.
1399    pub fn diag_extract(&self) -> Result<Self, TensorError> {
1400        let shape = self.shape();
1401        if shape.len() != 2 {
1402            return Err(TensorError::UnsupportedOperation {
1403                msg: format!("diag_extract requires a 2D tensor, got shape {:?}", shape),
1404            });
1405        }
1406        let n = shape[0].min(shape[1]);
1407        let cols = shape[1];
1408        let data: Vec<f32> = (0..n).map(|i| self.data()[i * cols + i]).collect();
1409        Self::from_vec(vec![n], data)
1410    }
1411
1412    /// Pad the tensor with a constant value. `padding` is a slice of (before, after) per dimension.
1413    pub fn pad(&self, padding: &[(usize, usize)], value: f32) -> Result<Self, TensorError> {
1414        let shape = self.shape();
1415        if padding.len() != shape.len() {
1416            return Err(TensorError::InvalidIndexRank {
1417                expected: shape.len(),
1418                got: padding.len(),
1419            });
1420        }
1421        let new_shape: Vec<usize> = shape
1422            .iter()
1423            .zip(padding)
1424            .map(|(&s, &(b, a))| s + b + a)
1425            .collect();
1426        let new_size: usize = new_shape.iter().product();
1427        let mut result = vec![value; new_size];
1428        let ndim = shape.len();
1429
1430        // Compute strides for both old and new shapes
1431        let mut old_strides = vec![1usize; ndim];
1432        for i in (0..ndim.saturating_sub(1)).rev() {
1433            old_strides[i] = old_strides[i + 1] * shape[i + 1];
1434        }
1435        let mut new_strides = vec![1usize; ndim];
1436        for i in (0..ndim.saturating_sub(1)).rev() {
1437            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1438        }
1439
1440        let old_size: usize = shape.iter().product();
1441        let data = self.data();
1442        for flat_idx in 0..old_size {
1443            let mut remaining = flat_idx;
1444            let mut new_flat = 0;
1445            for d in 0..ndim {
1446                let coord = remaining / old_strides[d];
1447                remaining %= old_strides[d];
1448                new_flat += (coord + padding[d].0) * new_strides[d];
1449            }
1450            result[new_flat] = data[flat_idx];
1451        }
1452
1453        Self::from_vec(new_shape, result)
1454    }
1455
1456    /// Repeat tensor along each axis by the given counts.
1457    pub fn repeat(&self, counts: &[usize]) -> Result<Self, TensorError> {
1458        if counts.len() != self.rank() {
1459            return Err(TensorError::InvalidIndexRank {
1460                expected: self.rank(),
1461                got: counts.len(),
1462            });
1463        }
1464        let mut out = self.clone();
1465        for (axis, &count) in counts.iter().enumerate() {
1466            if count > 1 {
1467                let refs: Vec<&Tensor> = std::iter::repeat_n(&out, count).collect();
1468                out = Tensor::cat(&refs, axis)?;
1469            }
1470        }
1471        Ok(out)
1472    }
1473
1474    // ── Cumulative operations ──────────────────────────────────────────
1475
1476    /// Cumulative sum along an axis.
1477    pub fn cumsum(&self, axis: usize) -> Result<Self, TensorError> {
1478        if axis >= self.rank() {
1479            return Err(TensorError::InvalidAxis {
1480                axis,
1481                rank: self.rank(),
1482            });
1483        }
1484        let shape = self.shape();
1485        let outer: usize = shape[..axis].iter().product();
1486        let axis_len = shape[axis];
1487        let inner: usize = shape[axis + 1..].iter().product();
1488        let mut out = self.data().to_vec();
1489
1490        for o in 0..outer {
1491            for i in 0..inner {
1492                let mut acc = 0.0f32;
1493                for d in 0..axis_len {
1494                    let idx = (o * axis_len + d) * inner + i;
1495                    acc += out[idx];
1496                    out[idx] = acc;
1497                }
1498            }
1499        }
1500        Tensor::from_vec(shape.to_vec(), out)
1501    }
1502
1503    /// Cumulative product along an axis.
1504    pub fn cumprod(&self, axis: usize) -> Result<Self, TensorError> {
1505        if axis >= self.rank() {
1506            return Err(TensorError::InvalidAxis {
1507                axis,
1508                rank: self.rank(),
1509            });
1510        }
1511        let shape = self.shape();
1512        let outer: usize = shape[..axis].iter().product();
1513        let axis_len = shape[axis];
1514        let inner: usize = shape[axis + 1..].iter().product();
1515        let mut out = self.data().to_vec();
1516
1517        for o in 0..outer {
1518            for i in 0..inner {
1519                let mut acc = 1.0f32;
1520                for d in 0..axis_len {
1521                    let idx = (o * axis_len + d) * inner + i;
1522                    acc *= out[idx];
1523                    out[idx] = acc;
1524                }
1525            }
1526        }
1527        Tensor::from_vec(shape.to_vec(), out)
1528    }
1529
1530    // ── FP16 conversion ────────────────────────────────────────────────
1531
1532    /// Convert all elements to IEEE 754 half-precision (FP16) bytes.
1533    /// Returns `Vec<u16>` where each u16 is an FP16 bit pattern.
1534    pub fn to_fp16(&self) -> Vec<u16> {
1535        self.data().iter().map(|&v| f32_to_fp16(v)).collect()
1536    }
1537
1538    /// Create a tensor from FP16 bit patterns.
1539    pub fn from_fp16(shape: Vec<usize>, fp16_data: &[u16]) -> Result<Self, TensorError> {
1540        let data: Vec<f32> = fp16_data.iter().map(|&v| fp16_to_f32(v)).collect();
1541        Tensor::from_vec(shape, data)
1542    }
1543
1544    // ── Internal helpers ────────────────────────────────────────────────
1545
    /// Apply `op` to every element, producing a tensor with the same shape
    /// and strides as `self`.
    ///
    /// Writes through raw pointers into an uninitialized `AlignedVec` to skip
    /// the zero-fill that a `vec![0.0; len]` allocation would pay.
    #[allow(unsafe_code)]
    fn unary_op<F>(&self, op: F) -> Self
    where
        F: Fn(f32) -> f32,
    {
        let src = self.data();
        let len = src.len();
        // SAFETY: `uninitialized` allocates without zeroing.  The loop below
        // writes every element before we ever read from `out_data`.
        let mut out_data = AlignedVec::<f32>::uninitialized(len);
        let inp = src.as_ptr();
        let outp = out_data.as_mut_ptr();
        unsafe {
            // `i < len` is in-bounds for both the source and the output.
            for i in 0..len {
                *outp.add(i) = op(*inp.add(i));
            }
        }
        Tensor::from_raw_parts(self.shape(), self.strides(), out_data)
    }
1565
    /// Apply a SIMD-dispatched unary kernel to every element, producing a
    /// tensor with the same shape and strides as `self`.
    #[allow(unsafe_code)]
    fn unary_simd_op(&self, kind: simd::UnaryKind) -> Self {
        let len = self.len();
        // SAFETY: `uninitialized` allocates without zeroing.  `unary_dispatch`
        // writes every element before we ever read from `out`.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::unary_dispatch(self.data(), &mut out, kind);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
1575
    /// Fold all elements along `axis` with `combine`, starting each output
    /// cell at `init`. Shared implementation behind axis reductions.
    ///
    /// The output shape is the input shape with `axis` removed; reducing a
    /// 1-D tensor yields a rank-0 result (empty shape, single element).
    fn reduce_axis<F>(&self, axis: usize, init: f32, combine: F) -> Result<Self, TensorError>
    where
        F: Fn(f32, f32) -> f32,
    {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }

        let mut out_shape = self.shape().to_vec();
        out_shape.remove(axis);
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        let mut out_data = vec![init; out_count];

        // Walk the input in storage order while tracking its multi-dim
        // coordinates, folding each element into the output cell addressed by
        // those coordinates with the reduced axis dropped.
        let mut in_coords = vec![0usize; self.rank()];
        for input in self.data().iter().copied() {
            let mut out_offset = 0usize;
            for (src_axis, coord) in in_coords.iter().copied().enumerate() {
                if src_axis == axis {
                    continue;
                }
                // Axes after the reduced one shift down by 1 in the output.
                let dst_axis = if src_axis < axis {
                    src_axis
                } else {
                    src_axis - 1
                };
                // Rank-0 output (1-D input): everything folds into cell 0.
                if !out_shape.is_empty() {
                    out_offset += coord * out_strides[dst_axis];
                }
            }
            out_data[out_offset] = combine(out_data[out_offset], input);
            increment_coords(&mut in_coords, self.shape());
        }

        Tensor::from_vec(out_shape, out_data)
    }
1620
    /// SIMD fast path for broadcasting a last-dim vector across all rows.
    ///
    /// Matches patterns like `[N, C] + [C]`, `[N, H, W, C] + [1, C]`, etc.
    /// Returns `None` if the pattern doesn't match, so the caller falls through
    /// to the generic broadcast loop.
    #[allow(unsafe_code)]
    fn binary_broadcast_lastdim_simd(
        &self,
        rhs: &Self,
        kind: simd::BinaryKind,
    ) -> Option<Result<Self, TensorError>> {
        let lhs_shape = self.shape();
        let rhs_shape = rhs.shape();

        // Detect: rhs is a 1-D vector whose length equals lhs's last dim,
        // or rhs shape is all-1 except the last dim which matches lhs's last dim.
        // (`last()?` also bails out for rank-0 tensors.)
        let lhs_last = *lhs_shape.last()?;
        if lhs_last == 0 {
            return None;
        }

        let rhs_last = *rhs_shape.last()?;

        // Check rhs is effectively a 1-D vector of size lhs_last:
        // either shape [C] or shape [1, 1, ..., C]
        let rhs_is_lastdim_vec =
            rhs_last == lhs_last && rhs_shape.iter().rev().skip(1).all(|&d| d == 1);
        // Also handle the symmetric case: lhs is the vector, rhs is the big tensor
        let lhs_is_lastdim_vec =
            lhs_last == rhs_last && lhs_shape.iter().rev().skip(1).all(|&d| d == 1);

        // If BOTH sides are last-dim vectors (e.g. [C] vs [1, C]) neither
        // branch applies; return None and let the generic path handle it.
        if rhs_is_lastdim_vec && !lhs_is_lastdim_vec {
            // Broadcast rhs across all rows of lhs
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            // lhs splits into `num_rows` contiguous rows of `row_len` each.
            let row_len = lhs_last;
            let num_rows = lhs_data.len() / row_len;
            // SAFETY: `uninitialized` allocates without zeroing; the loop
            // below writes every chunk [start, end), covering the whole
            // buffer, before it is ever read.
            let mut out_data = AlignedVec::<f32>::uninitialized(lhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[start..end],
                    &rhs_data[..row_len],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(lhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                lhs_shape,
                &out_strides,
                out_data,
            )))
        } else if lhs_is_lastdim_vec && !rhs_is_lastdim_vec {
            // Broadcast lhs across all rows of rhs (mirror of the branch above).
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = rhs_last;
            let num_rows = rhs_data.len() / row_len;
            // SAFETY: as above — every chunk of `out_data` is written before
            // the buffer is read.
            let mut out_data = AlignedVec::<f32>::uninitialized(rhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                simd::binary_dispatch(
                    &lhs_data[..row_len],
                    &rhs_data[start..end],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(rhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                rhs_shape,
                &out_strides,
                out_data,
            )))
        } else {
            None
        }
    }
1706
1707    fn binary_broadcast_op<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1708    where
1709        F: Fn(f32, f32) -> f32,
1710    {
1711        let out_shape = broadcast_shape(self.shape(), rhs.shape()).ok_or_else(|| {
1712            TensorError::BroadcastIncompatible {
1713                left: self.shape().to_vec(),
1714                right: rhs.shape().to_vec(),
1715            }
1716        })?;
1717
1718        let out_count =
1719            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
1720                shape: out_shape.clone(),
1721            })?;
1722        let mut out_data = vec![0.0; out_count];
1723        let mut coords = vec![0usize; out_shape.len()];
1724
1725        for value in &mut out_data {
1726            let left_offset = broadcast_offset(self.shape(), self.strides(), &coords);
1727            let right_offset = broadcast_offset(rhs.shape(), rhs.strides(), &coords);
1728            *value = op(self.data()[left_offset], rhs.data()[right_offset]);
1729            increment_coords(&mut coords, &out_shape);
1730        }
1731
1732        Tensor::from_vec(out_shape, out_data)
1733    }
1734
1735    /// SIMD-accelerated binary op for same-shape tensors (add/sub/mul/div).
1736    #[allow(unsafe_code)]
1737    #[allow(unsafe_code)]
1738    fn binary_same_shape_simd(
1739        &self,
1740        rhs: &Self,
1741        kind: simd::BinaryKind,
1742    ) -> Result<Self, TensorError> {
1743        let len = self.len();
1744        let mut out_data = AlignedVec::<f32>::uninitialized(len);
1745
1746        // Multi-threaded for large tensors. Cross-platform:
1747        // macOS → GCD dispatch_apply, others → rayon thread pool.
1748        if len >= 100_000 {
1749            let n = std::thread::available_parallelism()
1750                .map(|p| p.get())
1751                .unwrap_or(4);
1752            let chunk = len.div_ceil(n);
1753            let lp = self.data().as_ptr() as usize;
1754            let rp = rhs.data().as_ptr() as usize;
1755            let op = out_data.as_mut_ptr() as usize;
1756
1757            #[cfg(target_os = "macos")]
1758            {
1759                use std::ffi::c_void;
1760                #[allow(unsafe_code)]
1761                unsafe extern "C" {
1762                    fn dispatch_get_global_queue(id: isize, flags: usize) -> *const c_void;
1763                    fn dispatch_apply_f(
1764                        n: usize,
1765                        q: *const c_void,
1766                        ctx: *mut c_void,
1767                        work: unsafe extern "C" fn(*mut c_void, usize),
1768                    );
1769                }
1770                struct Ctx {
1771                    lp: usize,
1772                    rp: usize,
1773                    op: usize,
1774                    chunk: usize,
1775                    len: usize,
1776                    kind: simd::BinaryKind,
1777                }
1778                let ctx = Ctx {
1779                    lp,
1780                    rp,
1781                    op,
1782                    chunk,
1783                    len,
1784                    kind,
1785                };
1786                unsafe extern "C" fn work(raw: *mut c_void, t: usize) {
1787                    let c = unsafe { &*(raw as *const Ctx) };
1788                    let start = t * c.chunk;
1789                    let end = (start + c.chunk).min(c.len);
1790                    if start >= end {
1791                        return;
1792                    }
1793                    let l = unsafe {
1794                        std::slice::from_raw_parts((c.lp as *const f32).add(start), end - start)
1795                    };
1796                    let r = unsafe {
1797                        std::slice::from_raw_parts((c.rp as *const f32).add(start), end - start)
1798                    };
1799                    let o = unsafe {
1800                        std::slice::from_raw_parts_mut((c.op as *mut f32).add(start), end - start)
1801                    };
1802                    simd::binary_dispatch(l, r, o, c.kind);
1803                }
1804                let q = unsafe { dispatch_get_global_queue(0, 0) };
1805                unsafe {
1806                    dispatch_apply_f(n, q, &ctx as *const Ctx as *mut c_void, work);
1807                }
1808            }
1809
1810            #[cfg(not(target_os = "macos"))]
1811            {
1812                // Rayon global thread pool — threads pre-spawned, ~0.5µs dispatch.
1813                use rayon::prelude::*;
1814                (0..n).into_par_iter().for_each(|t| {
1815                    let start = t * chunk;
1816                    let end = (start + chunk).min(len);
1817                    if start >= end {
1818                        return;
1819                    }
1820                    let l = unsafe {
1821                        std::slice::from_raw_parts((lp as *const f32).add(start), end - start)
1822                    };
1823                    let r = unsafe {
1824                        std::slice::from_raw_parts((rp as *const f32).add(start), end - start)
1825                    };
1826                    let o = unsafe {
1827                        std::slice::from_raw_parts_mut((op as *mut f32).add(start), end - start)
1828                    };
1829                    simd::binary_dispatch(l, r, o, kind);
1830                });
1831            }
1832
1833            return Ok(Tensor::from_raw_parts(
1834                self.shape(),
1835                self.strides(),
1836                out_data,
1837            ));
1838        }
1839
1840        simd::binary_dispatch(self.data(), rhs.data(), &mut out_data, kind);
1841        Ok(Tensor::from_raw_parts(
1842            self.shape(),
1843            self.strides(),
1844            out_data,
1845        ))
1846    }
1847
1848    #[allow(unsafe_code)]
1849    fn binary_same_shape<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1850    where
1851        F: Fn(f32, f32) -> f32,
1852    {
1853        let len = self.len();
1854        // SAFETY: `uninitialized` allocates without zeroing.  The loop below
1855        // writes every element before we ever read from `out_data`.
1856        let mut out_data = AlignedVec::<f32>::uninitialized(len);
1857
1858        let lhs_ptr = self.data().as_ptr();
1859        let rhs_ptr = rhs.data().as_ptr();
1860        let out_ptr = out_data.as_mut_ptr();
1861
1862        // SAFETY:
1863        // - Pointers originate from valid slices/vectors of length `len`.
1864        // - Loop bounds guarantee all pointer arithmetic remains in-bounds.
1865        // - `out_data` is uniquely mutable and does not alias with input buffers.
1866        unsafe {
1867            let mut index = 0usize;
1868            while index + 4 <= len {
1869                *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1870                *out_ptr.add(index + 1) = op(*lhs_ptr.add(index + 1), *rhs_ptr.add(index + 1));
1871                *out_ptr.add(index + 2) = op(*lhs_ptr.add(index + 2), *rhs_ptr.add(index + 2));
1872                *out_ptr.add(index + 3) = op(*lhs_ptr.add(index + 3), *rhs_ptr.add(index + 3));
1873                index += 4;
1874            }
1875            while index < len {
1876                *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
1877                index += 1;
1878            }
1879        }
1880
1881        Ok(Tensor::from_raw_parts(
1882            self.shape(),
1883            self.strides(),
1884            out_data,
1885        ))
1886    }
1887
1888    // ── In-place operations ──────────────────────────────────────────────
1889
1890    /// In-place ReLU: clamp negative values to zero.
1891    pub fn relu_inplace(&mut self) {
1892        simd::relu_inplace_dispatch(self.data_mut());
1893    }
1894
1895    /// In-place element-wise add from another tensor (must have same shape).
1896    pub fn add_inplace(&mut self, rhs: &Self) {
1897        debug_assert_eq!(self.len(), rhs.len());
1898        simd::add_inplace_dispatch(self.data_mut(), rhs.data());
1899    }
1900
1901    /// In-place add scalar to all elements.
1902    pub fn add_scalar_inplace(&mut self, s: f32) {
1903        simd::add_scalar_inplace_dispatch(self.data_mut(), s);
1904    }
1905
1906    /// In-place multiply all elements by scalar.
1907    pub fn mul_scalar_inplace(&mut self, s: f32) {
1908        simd::mul_scalar_inplace_dispatch(self.data_mut(), s);
1909    }
1910
1911    // ── Binary-into operations (write into pre-allocated output) ─────
1912
1913    /// Element-wise addition writing into a pre-allocated output tensor.
1914    /// `output`, `lhs`, and `rhs` must all have the same shape.
1915    pub fn add_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1916        debug_assert_eq!(lhs.shape(), rhs.shape());
1917        debug_assert_eq!(lhs.shape(), output.shape());
1918        simd::binary_dispatch(
1919            lhs.data(),
1920            rhs.data(),
1921            output.data_mut(),
1922            simd::BinaryKind::Add,
1923        );
1924    }
1925
1926    /// Element-wise subtraction writing into a pre-allocated output tensor.
1927    /// `output`, `lhs`, and `rhs` must all have the same shape.
1928    pub fn sub_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1929        debug_assert_eq!(lhs.shape(), rhs.shape());
1930        debug_assert_eq!(lhs.shape(), output.shape());
1931        simd::binary_dispatch(
1932            lhs.data(),
1933            rhs.data(),
1934            output.data_mut(),
1935            simd::BinaryKind::Sub,
1936        );
1937    }
1938
1939    /// Element-wise multiplication writing into a pre-allocated output tensor.
1940    /// `output`, `lhs`, and `rhs` must all have the same shape.
1941    pub fn mul_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1942        debug_assert_eq!(lhs.shape(), rhs.shape());
1943        debug_assert_eq!(lhs.shape(), output.shape());
1944        simd::binary_dispatch(
1945            lhs.data(),
1946            rhs.data(),
1947            output.data_mut(),
1948            simd::BinaryKind::Mul,
1949        );
1950    }
1951}
1952
1953// ── atan2 helper ───────────────────────────────────────────────────
1954
/// Scalar atan2(y, x) using Cephes-style range reduction of atan.
///
/// The argument is reduced to [0, tan(pi/12)] ≈ [0, 0.268]:
///   - if z > tan(3*pi/8) ≈ 2.414: atan(z) = pi/2 - atan(1/z)
///   - if z > tan(pi/8) ≈ 0.414:   atan(z) = pi/4 + atan((z-1)/(z+1))
///
/// A degree-7 polynomial evaluates atan on the reduced argument, then the
/// result is mapped back into the correct quadrant from the signs of `x`
/// and `y`. Max error < 4e-7 across all inputs.
#[allow(clippy::excessive_precision)]
#[inline(always)]
fn fast_atan2_scalar(y: f32, x: f32) -> f32 {
    const PI: f32 = std::f32::consts::PI;
    const FRAC_PI_2: f32 = std::f32::consts::FRAC_PI_2;
    const FRAC_PI_4: f32 = std::f32::consts::FRAC_PI_4;
    const TAN_3PI_8: f32 = 2.414_213_6; // tan(3*pi/8) = 1 + sqrt(2)
    const TAN_PI_8: f32 = 0.414_213_57; // tan(pi/8) = sqrt(2) - 1

    let abs_x = x.abs();
    let abs_y = y.abs();

    // Keep the larger magnitude in the denominator so the ratio stays <= 1;
    // `flipped` records that we must complement with pi/2 afterwards.
    let flipped = abs_y > abs_x;
    let (num, den) = if flipped { (abs_x, abs_y) } else { (abs_y, abs_x) };
    let ratio = if den > 0.0 { num / den } else { 0.0 };

    // Range reduction for atan(ratio), ratio >= 0. The first branch is
    // unreachable given ratio <= 1, but kept as a safety net.
    let (t, bias) = if ratio > TAN_3PI_8 {
        (-1.0 / ratio, FRAC_PI_2)
    } else if ratio > TAN_PI_8 {
        ((ratio - 1.0) / (ratio + 1.0), FRAC_PI_4)
    } else {
        (ratio, 0.0)
    };

    // Polynomial: atan(t) ≈ t + t³·P(t²) for small t.
    // Coefficients from Cephes atanf.c (S. Moshier), Horner order.
    let t2 = t * t;
    let mut poly: f32 = 8.054_666e-02;
    poly = poly.mul_add(t2, -1.384_895_1e-01);
    poly = poly.mul_add(t2, 1.997_075_8e-01);
    poly = poly.mul_add(t2, -3.333_129_8e-01);
    let atan_val = t.mul_add(t2 * poly, t) + bias;

    // Undo the flip: atan(|y|/|x|) = pi/2 - atan(|x|/|y|).
    let first_quadrant = if flipped { FRAC_PI_2 - atan_val } else { atan_val };

    // Quadrant correction from the signs of x and y.
    let with_x = if x < 0.0 {
        PI - first_quadrant
    } else {
        first_quadrant
    };
    if y < 0.0 {
        -with_x
    } else {
        with_x
    }
}
2015
2016// ── FP16 conversion utilities ──────────────────────────────────────
2017
/// Convert f32 to IEEE 754 half-precision (FP16) bit pattern.
///
/// Rounding is by truncation of the extra mantissa bits. Values above the
/// FP16 range become Inf, values below the smallest subnormal become signed
/// zero, and NaN is quietized (mantissa MSB set).
fn f32_to_fp16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exponent == 255 {
        // Inf or NaN
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }

    let unbiased = exponent - 127;
    if unbiased > 15 {
        return sign | 0x7C00; // overflow -> Inf
    }
    if unbiased < -24 {
        return sign; // underflow -> zero
    }
    if unbiased < -14 {
        // Subnormal FP16: the value is (1.m) * 2^unbiased and the FP16
        // subnormal significand is value * 2^24, i.e. the 24-bit f32
        // significand (implicit bit restored) shifted right by
        // (-1 - unbiased) ∈ [14, 23].
        //
        // Bug fix: the shift was previously applied as `shift + 13`, which
        // double-counted the 13-bit mantissa narrowing — producing wrong
        // subnormals and a >31-bit shift (overflow) for unbiased <= -20.
        let shift = (-1 - unbiased) as u32;
        let m = (mantissa | 0x0080_0000) >> shift;
        return sign | m as u16;
    }

    // Normal FP16: re-bias the exponent (+15) and truncate the mantissa
    // from 23 bits down to 10.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
2048
2049impl Tensor {
2050    // ── Sort / argsort / unique / nonzero ──────────────────────────────
2051
2052    /// Sort along `dim`. Returns `(sorted_values, sorted_indices)`.
2053    ///
2054    /// If `descending` is true, values are sorted largest-first.
2055    pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self), TensorError> {
2056        if dim >= self.rank() {
2057            return Err(TensorError::InvalidAxis {
2058                axis: dim,
2059                rank: self.rank(),
2060            });
2061        }
2062        let shape = self.shape();
2063        let outer: usize = shape[..dim].iter().product();
2064        let dim_len = shape[dim];
2065        let inner: usize = shape[dim + 1..].iter().product();
2066        let data = self.data();
2067
2068        let mut out_vals = vec![0.0f32; data.len()];
2069        let mut out_idxs = vec![0.0f32; data.len()];
2070
2071        for o in 0..outer {
2072            for i in 0..inner {
2073                let mut idx_vec: Vec<usize> = (0..dim_len).collect();
2074                idx_vec.sort_unstable_by(|&a, &b| {
2075                    let va = data[(o * dim_len + a) * inner + i];
2076                    let vb = data[(o * dim_len + b) * inner + i];
2077                    if descending {
2078                        vb.partial_cmp(&va).unwrap_or(std::cmp::Ordering::Equal)
2079                    } else {
2080                        va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
2081                    }
2082                });
2083                for (rank, &src) in idx_vec.iter().enumerate() {
2084                    let dst = (o * dim_len + rank) * inner + i;
2085                    let src_pos = (o * dim_len + src) * inner + i;
2086                    out_vals[dst] = data[src_pos];
2087                    out_idxs[dst] = src as f32;
2088                }
2089            }
2090        }
2091
2092        let v = Tensor::from_vec(shape.to_vec(), out_vals)?;
2093        let idx = Tensor::from_vec(shape.to_vec(), out_idxs)?;
2094        Ok((v, idx))
2095    }
2096
2097    /// Return indices that would sort along `dim`.
2098    pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self, TensorError> {
2099        let (_, indices) = self.sort(dim, descending)?;
2100        Ok(indices)
2101    }
2102
2103    /// Return unique elements (sorted), inverse indices, and counts.
2104    pub fn unique(&self) -> (Self, Self, Self) {
2105        let data = self.data();
2106        let mut sorted: Vec<f32> = data.to_vec();
2107        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
2108        sorted.dedup();
2109
2110        let mut inverse = vec![0.0f32; data.len()];
2111        let mut counts = vec![0.0f32; sorted.len()];
2112        for (i, &v) in data.iter().enumerate() {
2113            let pos = sorted
2114                .binary_search_by(|probe| {
2115                    probe.partial_cmp(&v).unwrap_or(std::cmp::Ordering::Equal)
2116                })
2117                .expect("value exists in sorted list");
2118            inverse[i] = pos as f32;
2119            counts[pos] += 1.0;
2120        }
2121
2122        let vals = Tensor::from_vec(vec![sorted.len()], sorted).expect("unique vals");
2123        let inv = Tensor::from_vec(self.shape().to_vec(), inverse).expect("unique inv");
2124        let cnt = Tensor::from_vec(vec![counts.len()], counts).expect("unique counts");
2125        (vals, inv, cnt)
2126    }
2127
2128    /// Return coordinates of nonzero elements as a 2D tensor `[N, rank]`.
2129    pub fn nonzero(&self) -> Self {
2130        let shape = self.shape();
2131        let rank = shape.len().max(1);
2132        let data = self.data();
2133        let mut coords: Vec<Vec<usize>> = Vec::new();
2134
2135        if shape.is_empty() {
2136            // scalar
2137            if data[0] != 0.0 {
2138                coords.push(vec![0]);
2139            }
2140        } else {
2141            let mut idx = vec![0usize; shape.len()];
2142            for pos in 0..data.len() {
2143                if data[pos] != 0.0 {
2144                    coords.push(idx.clone());
2145                }
2146                // increment multi-dim index
2147                for d in (0..shape.len()).rev() {
2148                    idx[d] += 1;
2149                    if idx[d] < shape[d] {
2150                        break;
2151                    }
2152                    idx[d] = 0;
2153                }
2154            }
2155        }
2156
2157        let n = coords.len();
2158        let mut flat = Vec::with_capacity(n * rank);
2159        for c in &coords {
2160            for &v in c {
2161                flat.push(v as f32);
2162            }
2163        }
2164        if n == 0 {
2165            Tensor::from_vec(vec![0, rank], flat).expect("nonzero empty")
2166        } else {
2167            Tensor::from_vec(vec![n, rank], flat).expect("nonzero")
2168        }
2169    }
2170
2171    // ── Flip / roll ────────────────────────────────────────────────────
2172
2173    /// Reverse elements along the given dimensions.
2174    pub fn flip(&self, dims: &[usize]) -> Result<Self, TensorError> {
2175        for &d in dims {
2176            if d >= self.rank() {
2177                return Err(TensorError::InvalidAxis {
2178                    axis: d,
2179                    rank: self.rank(),
2180                });
2181            }
2182        }
2183        let shape = self.shape();
2184        let data = self.data();
2185        let total = data.len();
2186        let mut out = vec![0.0f32; total];
2187        let rank = shape.len();
2188
2189        let mut src_idx = vec![0usize; rank];
2190        for pos in 0..total {
2191            // compute destination index by flipping specified dims
2192            let mut dst_idx = src_idx.clone();
2193            for &d in dims {
2194                dst_idx[d] = shape[d] - 1 - src_idx[d];
2195            }
2196            // linear offset
2197            let mut dst_pos = 0;
2198            let mut stride = 1;
2199            for d in (0..rank).rev() {
2200                dst_pos += dst_idx[d] * stride;
2201                stride *= shape[d];
2202            }
2203            out[dst_pos] = data[pos];
2204
2205            // increment src_idx
2206            for d in (0..rank).rev() {
2207                src_idx[d] += 1;
2208                if src_idx[d] < shape[d] {
2209                    break;
2210                }
2211                src_idx[d] = 0;
2212            }
2213        }
2214        Tensor::from_vec(shape.to_vec(), out)
2215    }
2216
2217    /// Circular shift elements along `dim` by `shift` positions.
2218    pub fn roll(&self, shift: i64, dim: usize) -> Result<Self, TensorError> {
2219        if dim >= self.rank() {
2220            return Err(TensorError::InvalidAxis {
2221                axis: dim,
2222                rank: self.rank(),
2223            });
2224        }
2225        let shape = self.shape();
2226        let outer: usize = shape[..dim].iter().product();
2227        let dim_len = shape[dim];
2228        let inner: usize = shape[dim + 1..].iter().product();
2229        let data = self.data();
2230
2231        let mut out = vec![0.0f32; data.len()];
2232        for o in 0..outer {
2233            for d in 0..dim_len {
2234                let dst_d = ((d as i64 + shift).rem_euclid(dim_len as i64)) as usize;
2235                for i in 0..inner {
2236                    out[(o * dim_len + dst_d) * inner + i] = data[(o * dim_len + d) * inner + i];
2237                }
2238            }
2239        }
2240        Tensor::from_vec(shape.to_vec(), out)
2241    }
2242
2243    // ── Factory: linspace / arange / meshgrid ──────────────────────────
2244
2245    /// Create a 1-D tensor of `steps` evenly spaced values from `start` to `end` (inclusive).
2246    pub fn linspace(start: f32, end: f32, steps: usize) -> Result<Self, TensorError> {
2247        if steps == 0 {
2248            return Tensor::from_vec(vec![0], vec![]);
2249        }
2250        if steps == 1 {
2251            return Tensor::from_vec(vec![1], vec![start]);
2252        }
2253        let step = (end - start) / (steps - 1) as f32;
2254        let data: Vec<f32> = (0..steps).map(|i| start + step * i as f32).collect();
2255        Tensor::from_vec(vec![steps], data)
2256    }
2257
2258    /// Create a 1-D tensor with values in `[start, end)` with given `step`.
2259    pub fn arange(start: f32, end: f32, step: f32) -> Result<Self, TensorError> {
2260        if step == 0.0 {
2261            return Err(TensorError::ShapeMismatch {
2262                left: vec![],
2263                right: vec![],
2264            });
2265        }
2266        let mut data = Vec::new();
2267        let mut v = start;
2268        if step > 0.0 {
2269            while v < end {
2270                data.push(v);
2271                v += step;
2272            }
2273        } else {
2274            while v > end {
2275                data.push(v);
2276                v += step;
2277            }
2278        }
2279        let n = data.len();
2280        Tensor::from_vec(vec![n], data)
2281    }
2282
2283    /// Create coordinate grids from 1-D tensors (numpy-style `meshgrid` with `indexing='ij'`).
2284    pub fn meshgrid(tensors: &[Self]) -> Result<Vec<Self>, TensorError> {
2285        let shape: Vec<usize> = tensors.iter().map(|t| t.len()).collect();
2286        let total: usize = shape.iter().product();
2287        let n = tensors.len();
2288        let mut result = Vec::with_capacity(n);
2289
2290        for (idx, t) in tensors.iter().enumerate() {
2291            let t_data = t.data();
2292            let mut out = vec![0.0f32; total];
2293            // stride pattern: product of dims after idx
2294            let inner: usize = shape[idx + 1..].iter().product();
2295            let outer: usize = shape[..idx].iter().product();
2296            let dim_len = shape[idx];
2297            for o in 0..outer {
2298                for d in 0..dim_len {
2299                    for i in 0..inner {
2300                        out[(o * dim_len + d) * inner + i] = t_data[d];
2301                    }
2302                }
2303            }
2304            result.push(Tensor::from_vec(shape.clone(), out)?);
2305        }
2306        Ok(result)
2307    }
2308
2309    // ── Advanced indexing extras ────────────────────────────────────────
2310
2311    /// Select elements where `mask` (f32, nonzero = true) is true, returned as 1-D.
2312    pub fn boolean_mask(&self, mask: &Self) -> Result<Self, TensorError> {
2313        if self.shape() != mask.shape() {
2314            return Err(TensorError::ShapeMismatch {
2315                left: self.shape().to_vec(),
2316                right: mask.shape().to_vec(),
2317            });
2318        }
2319        let data = self.data();
2320        let m = mask.data();
2321        let out: Vec<f32> = data
2322            .iter()
2323            .zip(m.iter())
2324            .filter(|(_, mv)| **mv != 0.0)
2325            .map(|(v, _)| *v)
2326            .collect();
2327        let n = out.len();
2328        Tensor::from_vec(vec![n], out)
2329    }
2330
2331    /// Select slices along `dim` using integer `indices` tensor (1-D).
2332    pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self, TensorError> {
2333        if dim >= self.rank() {
2334            return Err(TensorError::InvalidAxis {
2335                axis: dim,
2336                rank: self.rank(),
2337            });
2338        }
2339        let shape = self.shape();
2340        let idx_data = indices.data();
2341        let n_idx = idx_data.len();
2342        let outer: usize = shape[..dim].iter().product();
2343        let dim_len = shape[dim];
2344        let inner: usize = shape[dim + 1..].iter().product();
2345        let data = self.data();
2346
2347        let mut out = Vec::with_capacity(outer * n_idx * inner);
2348        for o in 0..outer {
2349            for &idx_f in idx_data {
2350                let idx = idx_f as usize;
2351                if idx >= dim_len {
2352                    return Err(TensorError::IndexOutOfBounds {
2353                        axis: dim,
2354                        index: idx,
2355                        dim: dim_len,
2356                    });
2357                }
2358                let src_start = (o * dim_len + idx) * inner;
2359                out.extend_from_slice(&data[src_start..src_start + inner]);
2360            }
2361        }
2362
2363        let mut out_shape = shape.to_vec();
2364        out_shape[dim] = n_idx;
2365        Tensor::from_vec(out_shape, out)
2366    }
2367
2368    // ── Random tensor creation ──────────────────────────────────────────
2369
2370    /// Create a tensor filled with uniform random values in [0, 1).
2371    pub fn rand(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2372        let n: usize = shape.iter().product();
2373        let mut rng = seed;
2374        let data: Vec<f32> = (0..n)
2375            .map(|_| {
2376                rng ^= rng << 13;
2377                rng ^= rng >> 7;
2378                rng ^= rng << 17;
2379                (rng as f32) / (u64::MAX as f32)
2380            })
2381            .collect();
2382        Self::from_vec(shape, data)
2383    }
2384
2385    /// Create a tensor filled with normally distributed random values (mean=0, std=1).
2386    /// Uses Box-Muller transform.
2387    pub fn randn(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2388        let n: usize = shape.iter().product();
2389        let mut rng = seed;
2390        let mut next_rng = || -> f32 {
2391            rng ^= rng << 13;
2392            rng ^= rng >> 7;
2393            rng ^= rng << 17;
2394            // Map to (0, 1) exclusive to avoid log(0)
2395            ((rng as f64) / (u64::MAX as f64)).clamp(1e-15, 1.0 - 1e-15) as f32
2396        };
2397        let mut data = Vec::with_capacity(n);
2398        let mut i = 0;
2399        while i < n {
2400            let u1 = next_rng();
2401            let u2 = next_rng();
2402            let r = (-2.0 * (u1 as f64).ln()).sqrt();
2403            let theta = 2.0 * std::f64::consts::PI * u2 as f64;
2404            data.push((r * theta.cos()) as f32);
2405            i += 1;
2406            if i < n {
2407                data.push((r * theta.sin()) as f32);
2408                i += 1;
2409            }
2410        }
2411        Self::from_vec(shape, data)
2412    }
2413
2414    /// Create a tensor filled with random integers in [low, high).
2415    pub fn randint(shape: Vec<usize>, low: i64, high: i64, seed: u64) -> Result<Self, TensorError> {
2416        if high <= low {
2417            return Err(TensorError::UnsupportedOperation {
2418                msg: format!("randint requires high > low, got low={low}, high={high}"),
2419            });
2420        }
2421        let range = (high - low) as u64;
2422        let n: usize = shape.iter().product();
2423        let mut rng = seed;
2424        let data: Vec<f32> = (0..n)
2425            .map(|_| {
2426                rng ^= rng << 13;
2427                rng ^= rng >> 7;
2428                rng ^= rng << 17;
2429                (low + (rng % range) as i64) as f32
2430            })
2431            .collect();
2432        Self::from_vec(shape, data)
2433    }
2434
2435    /// Create a random permutation of integers [0, n).
2436    pub fn randperm(n: usize, seed: u64) -> Result<Self, TensorError> {
2437        let mut perm: Vec<f32> = (0..n).map(|i| i as f32).collect();
2438        let mut rng = seed;
2439        for i in (1..n).rev() {
2440            rng ^= rng << 13;
2441            rng ^= rng >> 7;
2442            rng ^= rng << 17;
2443            let j = (rng as usize) % (i + 1);
2444            perm.swap(i, j);
2445        }
2446        Self::from_vec(vec![n], perm)
2447    }
2448}
2449
2450// ── Advanced tensor operations ──────────────────────────────────────────
2451
2452impl Tensor {
2453    /// Slice with step: extract every `step`-th element along `dim` from `start` to `end`.
2454    pub fn step_slice(
2455        &self,
2456        dim: usize,
2457        start: usize,
2458        end: usize,
2459        step: usize,
2460    ) -> Result<Self, TensorError> {
2461        let rank = self.rank();
2462        if dim >= rank {
2463            return Err(TensorError::InvalidAxis { axis: dim, rank });
2464        }
2465        if step == 0 {
2466            return Err(TensorError::UnsupportedOperation {
2467                msg: "step must be > 0".to_string(),
2468            });
2469        }
2470        let shape = self.shape();
2471        let dim_len = shape[dim];
2472        let end = end.min(dim_len);
2473        if start >= end {
2474            // empty along this dim
2475            let mut out_shape = shape.to_vec();
2476            out_shape[dim] = 0;
2477            return Tensor::from_vec(out_shape, vec![]);
2478        }
2479
2480        let selected_indices: Vec<usize> = (start..end).step_by(step).collect();
2481        let new_dim = selected_indices.len();
2482
2483        let outer: usize = shape[..dim].iter().product();
2484        let inner: usize = shape[dim + 1..].iter().product();
2485        let data = self.data();
2486
2487        let mut out = Vec::with_capacity(outer * new_dim * inner);
2488        for o in 0..outer {
2489            for &idx in &selected_indices {
2490                let src_start = (o * dim_len + idx) * inner;
2491                out.extend_from_slice(&data[src_start..src_start + inner]);
2492            }
2493        }
2494
2495        let mut out_shape = shape.to_vec();
2496        out_shape[dim] = new_dim;
2497        Tensor::from_vec(out_shape, out)
2498    }
2499
2500    /// Einstein summation for common patterns.
2501    ///
2502    /// Supported equations:
2503    /// - `"ij,jk->ik"` — matrix multiply
2504    /// - `"ij->ji"` — transpose
2505    /// - `"ii->i"` — diagonal
2506    /// - `"ij->i"` — row sum
2507    /// - `"ij->j"` — column sum
2508    /// - `"ij->"` — total sum
2509    /// - `"i,i->"` — dot product
2510    /// - `"ij,ij->"` — Frobenius inner product
2511    pub fn einsum(equation: &str, tensors: &[&Tensor]) -> Result<Tensor, TensorError> {
2512        let equation = equation.replace(' ', "");
2513        let parts: Vec<&str> = equation.split("->").collect();
2514        if parts.len() != 2 {
2515            return Err(TensorError::UnsupportedOperation {
2516                msg: format!("invalid einsum equation: {equation}"),
2517            });
2518        }
2519        let inputs_str = parts[0];
2520        let output_str = parts[1];
2521        let input_parts: Vec<&str> = inputs_str.split(',').collect();
2522
2523        if input_parts.len() != tensors.len() {
2524            return Err(TensorError::UnsupportedOperation {
2525                msg: format!(
2526                    "einsum equation has {} inputs but {} tensors provided",
2527                    input_parts.len(),
2528                    tensors.len()
2529                ),
2530            });
2531        }
2532
2533        // Match known patterns
2534        let pattern = format!(
2535            "{}->{}",
2536            input_parts
2537                .iter()
2538                .map(|s| s.to_string())
2539                .collect::<Vec<_>>()
2540                .join(","),
2541            output_str
2542        );
2543
2544        match pattern.as_str() {
2545            // matrix multiply: ij,jk->ik
2546            "ij,jk->ik" => {
2547                let a = tensors[0];
2548                let b = tensors[1];
2549                if a.rank() != 2 || b.rank() != 2 {
2550                    return Err(TensorError::UnsupportedOperation {
2551                        msg: "ij,jk->ik requires 2D tensors".to_string(),
2552                    });
2553                }
2554                let (m, k1) = (a.shape()[0], a.shape()[1]);
2555                let (k2, n) = (b.shape()[0], b.shape()[1]);
2556                if k1 != k2 {
2557                    return Err(TensorError::ShapeMismatch {
2558                        left: a.shape().to_vec(),
2559                        right: b.shape().to_vec(),
2560                    });
2561                }
2562                let ad = a.data();
2563                let bd = b.data();
2564                let mut out = vec![0.0f32; m * n];
2565                for i in 0..m {
2566                    for j in 0..n {
2567                        let mut sum = 0.0f32;
2568                        for p in 0..k1 {
2569                            sum += ad[i * k1 + p] * bd[p * n + j];
2570                        }
2571                        out[i * n + j] = sum;
2572                    }
2573                }
2574                Tensor::from_vec(vec![m, n], out)
2575            }
2576            // transpose: ij->ji
2577            "ij->ji" => {
2578                let a = tensors[0];
2579                if a.rank() != 2 {
2580                    return Err(TensorError::UnsupportedOperation {
2581                        msg: "ij->ji requires a 2D tensor".to_string(),
2582                    });
2583                }
2584                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2585                let ad = a.data();
2586                let mut out = vec![0.0f32; rows * cols];
2587                for i in 0..rows {
2588                    for j in 0..cols {
2589                        out[j * rows + i] = ad[i * cols + j];
2590                    }
2591                }
2592                Tensor::from_vec(vec![cols, rows], out)
2593            }
2594            // diagonal: ii->i
2595            "ii->i" => {
2596                let a = tensors[0];
2597                if a.rank() != 2 || a.shape()[0] != a.shape()[1] {
2598                    return Err(TensorError::UnsupportedOperation {
2599                        msg: "ii->i requires a square 2D tensor".to_string(),
2600                    });
2601                }
2602                let n = a.shape()[0];
2603                let ad = a.data();
2604                let out: Vec<f32> = (0..n).map(|i| ad[i * n + i]).collect();
2605                Tensor::from_vec(vec![n], out)
2606            }
2607            // row sum: ij->i
2608            "ij->i" => {
2609                let a = tensors[0];
2610                if a.rank() != 2 {
2611                    return Err(TensorError::UnsupportedOperation {
2612                        msg: "ij->i requires a 2D tensor".to_string(),
2613                    });
2614                }
2615                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2616                let ad = a.data();
2617                let out: Vec<f32> = (0..rows)
2618                    .map(|i| ad[i * cols..(i + 1) * cols].iter().sum())
2619                    .collect();
2620                Tensor::from_vec(vec![rows], out)
2621            }
2622            // column sum: ij->j
2623            "ij->j" => {
2624                let a = tensors[0];
2625                if a.rank() != 2 {
2626                    return Err(TensorError::UnsupportedOperation {
2627                        msg: "ij->j requires a 2D tensor".to_string(),
2628                    });
2629                }
2630                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2631                let ad = a.data();
2632                let mut out = vec![0.0f32; cols];
2633                for i in 0..rows {
2634                    for j in 0..cols {
2635                        out[j] += ad[i * cols + j];
2636                    }
2637                }
2638                Tensor::from_vec(vec![cols], out)
2639            }
2640            // total sum: ij->
2641            "ij->" => {
2642                let a = tensors[0];
2643                if a.rank() != 2 {
2644                    return Err(TensorError::UnsupportedOperation {
2645                        msg: "ij-> requires a 2D tensor".to_string(),
2646                    });
2647                }
2648                let sum: f32 = a.data().iter().sum();
2649                Ok(Tensor::scalar(sum))
2650            }
2651            // dot product: i,i->
2652            "i,i->" => {
2653                let a = tensors[0];
2654                let b = tensors[1];
2655                if a.rank() != 1 || b.rank() != 1 {
2656                    return Err(TensorError::UnsupportedOperation {
2657                        msg: "i,i-> requires 1D tensors".to_string(),
2658                    });
2659                }
2660                if a.shape()[0] != b.shape()[0] {
2661                    return Err(TensorError::ShapeMismatch {
2662                        left: a.shape().to_vec(),
2663                        right: b.shape().to_vec(),
2664                    });
2665                }
2666                let sum: f32 = a
2667                    .data()
2668                    .iter()
2669                    .zip(b.data().iter())
2670                    .map(|(x, y)| x * y)
2671                    .sum();
2672                Ok(Tensor::scalar(sum))
2673            }
2674            // Frobenius inner product: ij,ij->
2675            "ij,ij->" => {
2676                let a = tensors[0];
2677                let b = tensors[1];
2678                if a.rank() != 2 || b.rank() != 2 {
2679                    return Err(TensorError::UnsupportedOperation {
2680                        msg: "ij,ij-> requires 2D tensors".to_string(),
2681                    });
2682                }
2683                if a.shape() != b.shape() {
2684                    return Err(TensorError::ShapeMismatch {
2685                        left: a.shape().to_vec(),
2686                        right: b.shape().to_vec(),
2687                    });
2688                }
2689                let sum: f32 = a
2690                    .data()
2691                    .iter()
2692                    .zip(b.data().iter())
2693                    .map(|(x, y)| x * y)
2694                    .sum();
2695                Ok(Tensor::scalar(sum))
2696            }
2697            _ => Err(TensorError::UnsupportedOperation {
2698                msg: format!("unsupported einsum pattern: {pattern}"),
2699            }),
2700        }
2701    }
2702
2703    // ── Chunk ───────────────────────────────────────────────────────────
2704
2705    /// Split tensor into `n_chunks` pieces along `axis`. Last chunk may be smaller.
2706    pub fn chunk(&self, n_chunks: usize, axis: usize) -> Result<Vec<Self>, TensorError> {
2707        if axis >= self.rank() {
2708            return Err(TensorError::InvalidAxis {
2709                axis,
2710                rank: self.rank(),
2711            });
2712        }
2713        if n_chunks == 0 {
2714            return Err(TensorError::UnsupportedOperation {
2715                msg: "n_chunks must be > 0".to_string(),
2716            });
2717        }
2718        let dim = self.shape()[axis];
2719        let chunk_size = dim.div_ceil(n_chunks); // ceil division
2720        let mut chunks = Vec::new();
2721        let mut start = 0;
2722        while start < dim {
2723            let length = chunk_size.min(dim - start);
2724            chunks.push(self.narrow(axis, start, length)?);
2725            start += length;
2726        }
2727        Ok(chunks)
2728    }
2729
2730    // ── Histogram ───────────────────────────────────────────────────────
2731
2732    /// Counts elements in each bin, returns 1D tensor of shape `[bins]`.
2733    /// Bins are evenly spaced between `min` and `max`.
2734    pub fn histogram(&self, bins: usize, min: f32, max: f32) -> Result<Self, TensorError> {
2735        let mut counts = vec![0.0f32; bins];
2736        let range = max - min;
2737        for &v in self.data() {
2738            if v >= min && v <= max {
2739                let idx = if v == max {
2740                    bins - 1
2741                } else {
2742                    ((v - min) / range * bins as f32) as usize
2743                };
2744                counts[idx] += 1.0;
2745            }
2746        }
2747        Tensor::from_vec(vec![bins], counts)
2748    }
2749
2750    // ── Bincount ────────────────────────────────────────────────────────
2751
2752    /// Treats values as integer indices, counts occurrences.
2753    /// Returns 1D tensor of shape `[num_bins]`.
2754    pub fn bincount(&self, num_bins: usize) -> Result<Self, TensorError> {
2755        let mut counts = vec![0.0f32; num_bins];
2756        for &v in self.data() {
2757            let idx = v as usize;
2758            if idx < num_bins {
2759                counts[idx] += 1.0;
2760            }
2761        }
2762        Tensor::from_vec(vec![num_bins], counts)
2763    }
2764
2765    // ── Scalar convenience ──────────────────────────────────────────────
2766
2767    /// Returns the single scalar value if tensor has exactly one element.
2768    /// Errors if tensor has more than one element.
2769    pub fn item(&self) -> Result<f32, TensorError> {
2770        if self.len() != 1 {
2771            return Err(TensorError::ShapeMismatch {
2772                left: vec![1],
2773                right: self.shape().to_vec(),
2774            });
2775        }
2776        Ok(self.data()[0])
2777    }
2778
2779    /// Returns true if tensor has exactly one element.
2780    pub fn is_scalar(&self) -> bool {
2781        self.len() == 1
2782    }
2783
2784    // ── Scatter Add ──────────────────────────────────────────────────
2785
2786    /// Like `scatter` but adds instead of replacing values.
2787    ///
2788    /// For `dim=1`: `self[i][index[i][j][k]][k] += src[i][j][k]`
2789    pub fn scatter_add(&self, dim: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
2790        if dim >= self.rank() {
2791            return Err(TensorError::InvalidAxis {
2792                axis: dim,
2793                rank: self.rank(),
2794            });
2795        }
2796        if index.rank() != self.rank() {
2797            return Err(TensorError::InvalidIndexRank {
2798                expected: self.rank(),
2799                got: index.rank(),
2800            });
2801        }
2802        if src.shape() != index.shape() {
2803            return Err(TensorError::ShapeMismatch {
2804                left: src.shape().to_vec(),
2805                right: index.shape().to_vec(),
2806            });
2807        }
2808
2809        let self_shape = self.shape();
2810        let idx_shape = index.shape();
2811        let idx_data = index.data();
2812        let src_data = src.data();
2813        let ndim = self.rank();
2814
2815        let mut out = self.data().to_vec();
2816        let mut coords = vec![0usize; ndim];
2817
2818        for pos in 0..index.len() {
2819            let idx_val = idx_data[pos] as usize;
2820            if idx_val >= self_shape[dim] {
2821                return Err(TensorError::IndexOutOfBounds {
2822                    axis: dim,
2823                    index: idx_val,
2824                    dim: self_shape[dim],
2825                });
2826            }
2827
2828            let mut dst_offset = 0;
2829            for d in 0..ndim {
2830                let c = if d == dim { idx_val } else { coords[d] };
2831                dst_offset += c * self.strides()[d];
2832            }
2833            out[dst_offset] += src_data[pos];
2834
2835            increment_coords(&mut coords, idx_shape);
2836        }
2837
2838        Tensor::from_vec(self_shape.to_vec(), out)
2839    }
2840}
2841
/// Convert IEEE 754 half-precision (FP16) bit pattern to f32.
///
/// Handles all four half classes: signed zero, subnormals, normals, and
/// infinity/NaN (NaN payload bits are preserved in the widened mantissa).
fn fp16_to_f32(half: u16) -> f32 {
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;

    if exponent == 0 {
        if mantissa == 0 {
            return f32::from_bits(sign); // signed zero
        }
        // Subnormal half: value = mantissa * 2^-24. Normalize by shifting the
        // mantissa up until the implicit bit (0x0400) is set; `e` counts the
        // shifts.
        let mut m = mantissa;
        let mut e = 0i32;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        m &= 0x03FF;
        // After e shifts the value is (1 + m/1024) * 2^(-14 - e), so the f32
        // biased exponent is 127 - 14 - e = 127 - 15 + 1 - e. (Without the
        // "+ 1" every subnormal decodes to half its true value: 0x0001 must
        // be 2^-24, not 2^-25.)
        let f32_exp = ((127 - 15 + 1 - e) as u32) << 23;
        let f32_man = m << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Inf (mantissa == 0) or NaN (mantissa != 0).
        let f32_exp = 0xFF << 23;
        let f32_man = mantissa << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }

    // Normal: rebias exponent from 15 (half) to 127 (f32), widen mantissa.
    let f32_exp = ((exponent as i32 - 15 + 127) as u32 & 0xFF) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}