//! `yscv_tensor/ops.rs` — element-wise operations, reductions, and shape
//! manipulation for [`Tensor`].
1use super::aligned::AlignedVec;
2use super::error::TensorError;
3use super::shape::{
4    broadcast_offset, broadcast_shape, compute_strides, increment_coords, shape_element_count,
5};
6use super::simd;
7use super::tensor::Tensor;
8
9impl Tensor {
10    // ── Binary elementwise with broadcasting ────────────────────────────
11
12    /// Element-wise addition with NumPy-style broadcasting.
13    pub fn add(&self, rhs: &Self) -> Result<Self, TensorError> {
14        if self.shape() == rhs.shape() {
15            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Add);
16        }
17        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Add) {
18            return result;
19        }
20        self.binary_broadcast_op(rhs, |l, r| l + r)
21    }
22
23    /// Element-wise subtraction with NumPy-style broadcasting.
24    pub fn sub(&self, rhs: &Self) -> Result<Self, TensorError> {
25        if self.shape() == rhs.shape() {
26            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Sub);
27        }
28        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Sub) {
29            return result;
30        }
31        self.binary_broadcast_op(rhs, |l, r| l - r)
32    }
33
34    /// Element-wise multiplication with NumPy-style broadcasting.
35    pub fn mul(&self, rhs: &Self) -> Result<Self, TensorError> {
36        if self.shape() == rhs.shape() {
37            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Mul);
38        }
39        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Mul) {
40            return result;
41        }
42        self.binary_broadcast_op(rhs, |l, r| l * r)
43    }
44
45    /// Element-wise division with NumPy-style broadcasting.
46    pub fn div(&self, rhs: &Self) -> Result<Self, TensorError> {
47        if self.shape() == rhs.shape() {
48            return self.binary_same_shape_simd(rhs, simd::BinaryKind::Div);
49        }
50        if let Some(result) = self.binary_broadcast_lastdim_simd(rhs, simd::BinaryKind::Div) {
51            return result;
52        }
53        self.binary_broadcast_op(rhs, |l, r| l / r)
54    }
55
    /// Element-wise power with NumPy-style broadcasting.
    ///
    /// Fast paths:
    /// - Constant exponent 2.0 → `mul(x, x)` (SIMD-accelerated).
    /// - Constant exponent 0.5 → `sqrt(x)` (SIMD-accelerated).
    /// - Constant exponents 1.0 / 0.0 / -1.0 → clone / ones / reciprocal.
    /// - Same-shape general case → SIMD `exp(exp * ln(base))`.
    ///
    /// NOTE(review): the same-shape `exp(exp * ln(base))` path yields NaN for
    /// negative bases even where `f32::powf` is defined (integer exponents,
    /// e.g. (-2)^3) — confirm callers only rely on non-negative bases there.
    #[allow(unsafe_code)]
    pub fn pow(&self, rhs: &Self) -> Result<Self, TensorError> {
        // Fast path: constant exponent (O(1) check instead of O(N) scan).
        // A single-element tensor or a tensor whose total size is 1 (broadcast scalar)
        // is always constant.  For same-shape tensors we skip the all-equal scan
        // entirely — the per-element SIMD path handles them efficiently.
        let rhs_total: usize = rhs.shape().iter().product();
        if rhs_total == 1 {
            let exp_val = rhs.data()[0];
            if exp_val == 2.0 {
                return self.mul(self);
            }
            if exp_val == 0.5 {
                return Ok(self.sqrt());
            }
            if exp_val == 1.0 {
                return Ok(self.clone());
            }
            if exp_val == 0.0 {
                // IEEE pow(x, 0) == 1 for every x, so a ones tensor of the
                // same shape is exact.
                return Tensor::ones(self.shape().to_vec());
            }
            if exp_val == -1.0 {
                return Ok(self.reciprocal());
            }
        }
        // Same-shape SIMD path: pow(base, exp) = exp(exp * ln(base))
        if self.shape() == rhs.shape() {
            return self.pow_simd(rhs);
        }
        // General broadcasting: scalar `powf` per element.
        self.binary_broadcast_op(rhs, |l, r| l.powf(r))
    }
93
    /// SIMD-accelerated pow for same-shape tensors via exp(exp * ln(base)).
    ///
    /// NOTE(review): diverges from `f32::powf` for negative bases (the `ln`
    /// step produces NaN) — see the caveat on [`Tensor::pow`].
    #[allow(unsafe_code)]
    fn pow_simd(&self, rhs: &Self) -> Result<Self, TensorError> {
        let len = self.len();
        // SAFETY: each `uninitialized` buffer below is fully written by the
        // dispatch call immediately following it, before anything reads it.
        // Compute ln(base)
        let mut ln_buf = AlignedVec::<f32>::uninitialized(len);
        simd::ln_dispatch(self.data(), &mut ln_buf);
        // Multiply by exponent: exp_val * ln(base)
        let mut prod_buf = AlignedVec::<f32>::uninitialized(len);
        simd::binary_dispatch(&ln_buf, rhs.data(), &mut prod_buf, simd::BinaryKind::Mul);
        // Compute exp(exp_val * ln(base))
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::exp_dispatch(&prod_buf, &mut out);
        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
109
110    /// Element-wise atan2(self, other), with broadcasting.
111    ///
112    /// Same-shape case uses a SIMD-friendly polynomial approximation of atan
113    /// with quadrant correction, avoiding scalar `f32::atan2`.
114    #[allow(unsafe_code)]
115    pub fn atan2(&self, rhs: &Self) -> Result<Self, TensorError> {
116        if self.shape() == rhs.shape() {
117            return self.atan2_fast(rhs);
118        }
119        self.binary_broadcast_op(rhs, f32::atan2)
120    }
121
    /// Vectorized atan2 for same-shape tensors. Uses Cephes-style range
    /// reduction to [0, tan(pi/12)] then a polynomial approximation,
    /// giving < 1e-6 max error across all quadrants.
    ///
    /// Relies on `fast_atan2_scalar` — presumably a free function defined
    /// elsewhere in this module; its accuracy claims come from there.
    #[allow(unsafe_code)]
    fn atan2_fast(&self, rhs: &Self) -> Result<Self, TensorError> {
        // Convention: `self` supplies y, `rhs` supplies x (atan2(y, x)).
        let y_data = self.data();
        let x_data = rhs.data();
        let len = self.len();
        // SAFETY: every slot of `out` is written by the two loops below
        // before the buffer is wrapped into a tensor.
        let mut out = AlignedVec::<f32>::uninitialized(len);

        // Process in chunks of 4 for ILP (instruction-level parallelism).
        // Each call to the scalar fast path avoids the overhead of the
        // generic broadcasting machinery.
        let mut i = 0;
        while i + 4 <= len {
            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
            out[i + 1] = fast_atan2_scalar(y_data[i + 1], x_data[i + 1]);
            out[i + 2] = fast_atan2_scalar(y_data[i + 2], x_data[i + 2]);
            out[i + 3] = fast_atan2_scalar(y_data[i + 3], x_data[i + 3]);
            i += 4;
        }
        // Scalar tail for the last len % 4 elements.
        while i < len {
            out[i] = fast_atan2_scalar(y_data[i], x_data[i]);
            i += 1;
        }

        Ok(Tensor::from_raw_parts(self.shape(), self.strides(), out))
    }
150
151    /// Element-wise minimum with NumPy-style broadcasting.
152    pub fn minimum(&self, rhs: &Self) -> Result<Self, TensorError> {
153        if self.shape() == rhs.shape() {
154            return self.binary_same_shape(rhs, f32::min);
155        }
156        self.binary_broadcast_op(rhs, f32::min)
157    }
158
159    /// Element-wise maximum with NumPy-style broadcasting.
160    pub fn maximum(&self, rhs: &Self) -> Result<Self, TensorError> {
161        if self.shape() == rhs.shape() {
162            return self.binary_same_shape(rhs, f32::max);
163        }
164        self.binary_broadcast_op(rhs, f32::max)
165    }
166
167    // ── Unary elementwise ───────────────────────────────────────────────
168
169    /// Element-wise negation.
170    pub fn neg(&self) -> Self {
171        self.unary_simd_op(simd::UnaryKind::Neg)
172    }
173
174    /// Element-wise absolute value.
175    pub fn abs(&self) -> Self {
176        self.unary_simd_op(simd::UnaryKind::Abs)
177    }
178
179    /// Element-wise natural exponential.
180    #[allow(unsafe_code)]
181    pub fn exp(&self) -> Self {
182        let len = self.len();
183        // SAFETY: `uninitialized` allocates without zeroing.  `exp_dispatch`
184        // writes every element before anything reads from `out`.
185        let mut out = AlignedVec::<f32>::uninitialized(len);
186        simd::exp_dispatch(self.data(), &mut out);
187        Tensor::from_raw_parts(self.shape(), self.strides(), out)
188    }
189
190    /// Element-wise natural logarithm.
191    #[allow(unsafe_code)]
192    pub fn ln(&self) -> Self {
193        let len = self.len();
194        // SAFETY: `uninitialized` allocates without zeroing.  `ln_dispatch`
195        // writes every element before we ever read from `out`.
196        let mut out = AlignedVec::<f32>::uninitialized(len);
197        simd::ln_dispatch(self.data(), &mut out);
198        Tensor::from_raw_parts(self.shape(), self.strides(), out)
199    }
200
201    /// Element-wise square root.
202    pub fn sqrt(&self) -> Self {
203        self.unary_simd_op(simd::UnaryKind::Sqrt)
204    }
205
206    /// Element-wise reciprocal (`1 / x`).
207    pub fn reciprocal(&self) -> Self {
208        self.unary_simd_op(simd::UnaryKind::Recip)
209    }
210
211    /// Element-wise sign (`-1`, `0`, or `1`).
212    pub fn sign(&self) -> Self {
213        self.unary_simd_op(simd::UnaryKind::Sign)
214    }
215
216    /// Element-wise floor.
217    pub fn floor(&self) -> Self {
218        self.unary_simd_op(simd::UnaryKind::Floor)
219    }
220
221    /// Element-wise ceil.
222    pub fn ceil(&self) -> Self {
223        self.unary_simd_op(simd::UnaryKind::Ceil)
224    }
225
226    /// Element-wise round.
227    pub fn round(&self) -> Self {
228        self.unary_simd_op(simd::UnaryKind::Round)
229    }
230
231    /// Element-wise sine (SIMD-accelerated polynomial approximation).
232    #[allow(unsafe_code)]
233    pub fn sin(&self) -> Self {
234        let len = self.len();
235        let mut out = AlignedVec::<f32>::uninitialized(len);
236        simd::sin_dispatch(self.data(), &mut out);
237        Tensor::from_raw_parts(self.shape(), self.strides(), out)
238    }
239
240    /// Element-wise cosine (SIMD-accelerated polynomial approximation).
241    #[allow(unsafe_code)]
242    pub fn cos(&self) -> Self {
243        let len = self.len();
244        let mut out = AlignedVec::<f32>::uninitialized(len);
245        simd::cos_dispatch(self.data(), &mut out);
246        Tensor::from_raw_parts(self.shape(), self.strides(), out)
247    }
248
249    /// Element-wise tangent (SIMD-accelerated via sin/cos).
250    #[allow(unsafe_code)]
251    pub fn tan(&self) -> Self {
252        let len = self.len();
253        let mut out = AlignedVec::<f32>::uninitialized(len);
254        simd::tan_dispatch(self.data(), &mut out);
255        Tensor::from_raw_parts(self.shape(), self.strides(), out)
256    }
257
258    /// Element-wise arcsine.
259    pub fn asin(&self) -> Self {
260        self.unary_op(f32::asin)
261    }
262
263    /// Element-wise arccosine.
264    pub fn acos(&self) -> Self {
265        self.unary_op(f32::acos)
266    }
267
268    /// Element-wise arctangent.
269    pub fn atan(&self) -> Self {
270        self.unary_op(f32::atan)
271    }
272
273    /// Element-wise hyperbolic sine.
274    pub fn sinh(&self) -> Self {
275        self.unary_op(f32::sinh)
276    }
277
278    /// Element-wise hyperbolic cosine.
279    pub fn cosh(&self) -> Self {
280        self.unary_op(f32::cosh)
281    }
282
283    /// Element-wise base-2 logarithm.
284    pub fn log2(&self) -> Self {
285        self.unary_op(f32::log2)
286    }
287
288    /// Element-wise base-10 logarithm.
289    pub fn log10(&self) -> Self {
290        self.unary_op(f32::log10)
291    }
292
293    /// Convert radians to degrees.
294    pub fn degrees(&self) -> Self {
295        self.unary_op(|v| v.to_degrees())
296    }
297
298    /// Convert degrees to radians.
299    pub fn radians(&self) -> Self {
300        self.unary_op(|v| v.to_radians())
301    }
302
303    /// Clamp all elements to `[min, max]`.
304    #[allow(unsafe_code)]
305    pub fn clamp(&self, min: f32, max: f32) -> Self {
306        let len = self.len();
307        let mut out = AlignedVec::<f32>::uninitialized(len);
308        simd::clamp_dispatch(self.data(), &mut out, min, max);
309        Tensor::from_raw_parts(self.shape(), self.strides(), out)
310    }
311
312    /// Scalar multiplication (broadcast multiply by a constant).
313    pub fn scale(&self, factor: f32) -> Self {
314        self.unary_op(|v| v * factor)
315    }
316
317    /// Add a scalar to all elements.
318    pub fn add_scalar(&self, value: f32) -> Self {
319        self.unary_op(|v| v + value)
320    }
321
322    // ── Reductions ──────────────────────────────────────────────────────
323
324    /// Sum reduction over all elements.
325    pub fn sum(&self) -> f32 {
326        simd::sum_dispatch(self.data())
327    }
328
329    /// Mean reduction over all elements.
330    pub fn mean(&self) -> f32 {
331        if self.is_empty() {
332            return f32::NAN;
333        }
334        self.sum() / self.len() as f32
335    }
336
337    /// Global max reduction. Returns `f32::NEG_INFINITY` for empty tensors.
338    pub fn max_value(&self) -> f32 {
339        simd::max_dispatch(self.data())
340    }
341
342    /// Global min reduction. Returns `f32::INFINITY` for empty tensors.
343    pub fn min_value(&self) -> f32 {
344        simd::min_dispatch(self.data())
345    }
346
347    /// Global argmax (flat index of maximum value).
348    pub fn argmax(&self) -> Option<usize> {
349        if self.is_empty() {
350            return None;
351        }
352        let mut best = 0;
353        let mut best_val = self.data()[0];
354        for (i, &v) in self.data().iter().enumerate().skip(1) {
355            if v > best_val {
356                best_val = v;
357                best = i;
358            }
359        }
360        Some(best)
361    }
362
363    /// Global argmin (flat index of minimum value).
364    pub fn argmin(&self) -> Option<usize> {
365        if self.is_empty() {
366            return None;
367        }
368        let mut best = 0;
369        let mut best_val = self.data()[0];
370        for (i, &v) in self.data().iter().enumerate().skip(1) {
371            if v < best_val {
372                best_val = v;
373                best = i;
374            }
375        }
376        Some(best)
377    }
378
379    /// Variance over all elements (population variance).
380    pub fn var(&self) -> f32 {
381        if self.is_empty() {
382            return f32::NAN;
383        }
384        let m = self.mean();
385        self.data().iter().map(|&v| (v - m) * (v - m)).sum::<f32>() / self.len() as f32
386    }
387
388    /// Standard deviation over all elements (population).
389    pub fn std_dev(&self) -> f32 {
390        self.var().sqrt()
391    }
392
    /// Sum reduction over one axis. Reduced axis is removed from output shape.
    ///
    /// NOTE(review): both 2D fast paths index `data()` as `row * cols + col`,
    /// i.e. they assume a contiguous row-major buffer — confirm `Tensor`
    /// guarantees that invariant here.
    pub fn sum_axis(&self, axis: usize) -> Result<Self, TensorError> {
        let shape = self.shape();
        let rank = shape.len();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }

        // Fast path: 2D contiguous tensor, axis 0 → accumulate rows with SIMD
        if rank == 2 && axis == 0 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            // Per-column accumulators start at zero; rows are added in order
            // 0, 1, …, so the float summation order is deterministic.
            let mut out = vec![0.0f32; cols];
            for row in 0..rows {
                let row_start = row * cols;
                simd::add_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
            }
            return Self::from_vec(vec![cols], out);
        }

        // Fast path: 2D contiguous tensor, axis 1 → sum each row with SIMD
        if rank == 2 && axis == 1 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            let mut out = Vec::with_capacity(rows);
            for row in 0..rows {
                out.push(simd::sum_dispatch(&data[row * cols..(row + 1) * cols]));
            }
            return Self::from_vec(vec![rows], out);
        }

        // General case: scalar accumulation via the shared reduction helper.
        self.reduce_axis(axis, 0.0, |acc, v| acc + v)
    }
426
427    /// Mean reduction over one axis. Reduced axis is removed from output shape.
428    pub fn mean_axis(&self, axis: usize) -> Result<Self, TensorError> {
429        if axis >= self.rank() {
430            return Err(TensorError::InvalidAxis {
431                axis,
432                rank: self.rank(),
433            });
434        }
435        let axis_len = self.shape()[axis] as f32;
436        let sum = self.sum_axis(axis)?;
437        Ok(sum.scale(1.0 / axis_len))
438    }
439
    /// Max reduction over one axis. Reduced axis is removed from output shape.
    ///
    /// NOTE(review): both 2D fast paths index `data()` as `row * cols + col`,
    /// i.e. they assume a contiguous row-major buffer — confirm `Tensor`
    /// guarantees that invariant here.
    pub fn max_axis(&self, axis: usize) -> Result<Self, TensorError> {
        let shape = self.shape();
        let rank = shape.len();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }

        // Fast path: 2D contiguous tensor, axis 0 → accumulate rows with SIMD max
        if rank == 2 && axis == 0 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            // Seed the accumulator with row 0 (no NEG_INFINITY init needed),
            // then fold the remaining rows in with the SIMD max kernel.
            let mut out = data[..cols].to_vec();
            for row in 1..rows {
                let row_start = row * cols;
                simd::max_inplace_dispatch(&mut out, &data[row_start..row_start + cols]);
            }
            return Self::from_vec(vec![cols], out);
        }

        // Fast path: 2D contiguous tensor, axis 1 → max each row with SIMD
        if rank == 2 && axis == 1 {
            let (rows, cols) = (shape[0], shape[1]);
            let data = self.data();
            let mut out = Vec::with_capacity(rows);
            for row in 0..rows {
                out.push(simd::max_dispatch(&data[row * cols..(row + 1) * cols]));
            }
            return Self::from_vec(vec![rows], out);
        }

        // General case: scalar fold via the shared reduction helper.
        self.reduce_axis(axis, f32::NEG_INFINITY, f32::max)
    }
473
474    /// Min reduction over one axis. Reduced axis is removed from output shape.
475    pub fn min_axis(&self, axis: usize) -> Result<Self, TensorError> {
476        self.reduce_axis(axis, f32::INFINITY, f32::min)
477    }
478
479    /// Variance reduction over one axis (population variance).
480    pub fn var_axis(&self, axis: usize) -> Result<Self, TensorError> {
481        let m = self.mean_axis(axis)?;
482        let diff = self.sub(&m.unsqueeze(axis)?)?;
483        let sq = diff.mul(&diff)?;
484        sq.mean_axis(axis)
485    }
486
487    /// Global median of all elements.
488    pub fn median(&self) -> f32 {
489        let mut sorted = self.data().to_vec();
490        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
491        let n = sorted.len();
492        if n == 0 {
493            return 0.0;
494        }
495        if n % 2 == 1 {
496            sorted[n / 2]
497        } else {
498            (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0
499        }
500    }
501
    /// Median along a given axis.
    ///
    /// The reduced axis is removed from the output shape; reducing a 1-D
    /// tensor yields shape `[1]`. Cost is O(outer · inner · k log k) where
    /// `k` is the axis length, since each lane is gathered and sorted.
    ///
    /// NOTE(review): the strided gather below assumes a contiguous
    /// row-major `data()` buffer — confirm `Tensor` guarantees that.
    pub fn median_axis(&self, axis: usize) -> Result<Self, TensorError> {
        let shape = self.shape();
        let rank = shape.len();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }
        let axis_len = shape[axis];
        // Decompose the index space around the reduced axis:
        // flat = o * axis_len * inner + a * inner + i.
        let outer: usize = shape[..axis].iter().product();
        let inner: usize = shape[axis + 1..].iter().product();
        let mut new_shape = shape.to_vec();
        new_shape.remove(axis);
        if new_shape.is_empty() {
            // Reducing the only axis of a 1-D tensor: keep a scalar-like [1].
            new_shape.push(1);
        }
        let data = self.data();
        let mut result = Vec::with_capacity(outer * inner);
        for o in 0..outer {
            for i in 0..inner {
                // Gather the lane along `axis` for this (outer, inner) pair.
                let mut vals: Vec<f32> = (0..axis_len)
                    .map(|a| data[o * axis_len * inner + a * inner + i])
                    .collect();
                // NaNs compare as Equal, i.e. they sort arbitrarily in place.
                vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                let n = vals.len();
                let med = if n % 2 == 1 {
                    vals[n / 2]
                } else {
                    (vals[n / 2 - 1] + vals[n / 2]) / 2.0
                };
                result.push(med);
            }
        }
        Self::from_vec(new_shape, result)
    }
536
537    /// Returns true if any element is non-zero.
538    pub fn any(&self) -> bool {
539        self.data().iter().any(|&v| v != 0.0)
540    }
541
542    /// Returns true if all elements are non-zero.
543    pub fn all(&self) -> bool {
544        self.data().iter().all(|&v| v != 0.0)
545    }
546
547    /// Quantile of all elements. q must be in [0, 1].
548    pub fn quantile(&self, q: f32) -> f32 {
549        let mut sorted = self.data().to_vec();
550        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
551        let n = sorted.len();
552        if n == 0 {
553            return 0.0;
554        }
555        let idx = q * (n - 1) as f32;
556        let lo = idx.floor() as usize;
557        let hi = idx.ceil() as usize;
558        if lo == hi || hi >= n {
559            sorted[lo.min(n - 1)]
560        } else {
561            let frac = idx - lo as f32;
562            sorted[lo] * (1.0 - frac) + sorted[hi] * frac
563        }
564    }
565
566    // ── Shape manipulation ──────────────────────────────────────────────
567
    /// 2D matrix transpose. Requires rank-2 input.
    ///
    /// Non-2D input is rejected with `InvalidAxis { axis: 1, rank }` —
    /// a reuse of the axis error for a rank problem (existing behavior).
    ///
    /// # Safety
    /// `AlignedVec::uninitialized` allocates without zeroing. The tiled loop
    /// writes every element before anything reads from the buffer.
    #[allow(unsafe_code)]
    pub fn transpose_2d(&self) -> Result<Self, TensorError> {
        if self.rank() != 2 {
            return Err(TensorError::InvalidAxis {
                axis: 1,
                rank: self.rank(),
            });
        }
        let rows = self.shape()[0];
        let cols = self.shape()[1];
        // SAFETY: every element is written by the tiled loop below before we read.
        let mut out_data = AlignedVec::<f32>::uninitialized(rows * cols);
        let src = self.data();

        // Tiled transpose with 8x8 blocks for cache efficiency.
        const TILE: usize = 8;
        // rr/cc: largest multiples of TILE inside rows/cols; the full tiles
        // cover [0, rr) x [0, cc), edges are handled separately below.
        let rr = rows / TILE * TILE;
        let cc = cols / TILE * TILE;

        for ii in (0..rr).step_by(TILE) {
            for jj in (0..cc).step_by(TILE) {
                for r in ii..ii + TILE {
                    for c in jj..jj + TILE {
                        out_data[c * rows + r] = src[r * cols + c];
                    }
                }
            }
            // Right edge (columns beyond cc)
            for r in ii..ii + TILE {
                for c in cc..cols {
                    out_data[c * rows + r] = src[r * cols + c];
                }
            }
        }
        // Bottom edge (rows beyond rr)
        for r in rr..rows {
            for c in 0..cols {
                out_data[c * rows + r] = src[r * cols + c];
            }
        }

        Tensor::from_aligned(vec![cols, rows], out_data)
    }
616
    /// General axis permutation (like NumPy `transpose`/`permute`).
    ///
    /// `axes` must be a permutation of `0..rank`: output axis `dst` takes
    /// its length from input axis `axes[dst]`.
    ///
    /// NOTE(review): a duplicated/missing axis is reported as
    /// `InvalidAxis { axis: 0 }` regardless of which axis is at fault —
    /// imprecise, but kept as existing behavior.
    pub fn permute(&self, axes: &[usize]) -> Result<Self, TensorError> {
        if axes.len() != self.rank() {
            return Err(TensorError::InvalidIndexRank {
                expected: self.rank(),
                got: axes.len(),
            });
        }
        let rank = self.rank();
        // Validate: every entry in range, and (via `seen`) no duplicates.
        let mut seen = vec![false; rank];
        for &a in axes {
            if a >= rank {
                return Err(TensorError::InvalidAxis { axis: a, rank });
            }
            seen[a] = true;
        }
        if seen.iter().any(|&s| !s) {
            return Err(TensorError::InvalidAxis { axis: 0, rank });
        }

        let src_shape = self.shape();
        let mut out_shape = vec![0usize; rank];
        for (dst, &src_axis) in axes.iter().enumerate() {
            out_shape[dst] = src_shape[src_axis];
        }
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        let mut out_data = vec![0.0f32; out_count];

        // Walk the source in flat order, tracking its multi-index, and
        // scatter each value to the permuted offset in the output.
        let mut in_coords = vec![0usize; rank];
        for &val in self.data().iter() {
            let mut out_offset = 0usize;
            for (dst_axis, &src_axis) in axes.iter().enumerate() {
                out_offset += in_coords[src_axis] * out_strides[dst_axis];
            }
            out_data[out_offset] = val;
            increment_coords(&mut in_coords, src_shape);
        }

        Tensor::from_vec(out_shape, out_data)
    }
663
664    /// Insert a length-1 dimension at the given axis.
665    pub fn unsqueeze(&self, axis: usize) -> Result<Self, TensorError> {
666        if axis > self.rank() {
667            return Err(TensorError::InvalidAxis {
668                axis,
669                rank: self.rank() + 1,
670            });
671        }
672        let mut new_shape = self.shape().to_vec();
673        new_shape.insert(axis, 1);
674        self.reshape(new_shape)
675    }
676
677    /// Remove a length-1 dimension at the given axis.
678    pub fn squeeze(&self, axis: usize) -> Result<Self, TensorError> {
679        if axis >= self.rank() {
680            return Err(TensorError::InvalidAxis {
681                axis,
682                rank: self.rank(),
683            });
684        }
685        if self.shape()[axis] != 1 {
686            return Err(TensorError::InvalidAxis {
687                axis,
688                rank: self.rank(),
689            });
690        }
691        let mut new_shape = self.shape().to_vec();
692        new_shape.remove(axis);
693        self.reshape(new_shape)
694    }
695
    /// Concatenate tensors along an axis. All tensors must have the same
    /// shape except along the concatenation axis.
    ///
    /// # Errors
    /// - `SizeMismatch` (with an empty shape) when `tensors` is empty —
    ///   a reuse of that variant as "nothing to concatenate".
    /// - `InvalidAxis` when `axis >= rank` of the first tensor.
    /// - `ShapeMismatch` when ranks differ or any non-`axis` dimension differs.
    ///
    /// NOTE(review): the copy loop below assumes each tensor's `data()` is a
    /// contiguous row-major buffer — confirm `Tensor` guarantees that.
    pub fn cat(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
        if tensors.is_empty() {
            return Err(TensorError::SizeMismatch {
                shape: vec![],
                data_len: 0,
            });
        }
        let rank = tensors[0].rank();
        if axis >= rank {
            return Err(TensorError::InvalidAxis { axis, rank });
        }
        // All tensors must agree with tensors[0] on every dim except `axis`.
        for t in &tensors[1..] {
            if t.rank() != rank {
                return Err(TensorError::ShapeMismatch {
                    left: tensors[0].shape().to_vec(),
                    right: t.shape().to_vec(),
                });
            }
            for (a, (&d0, &di)) in tensors[0].shape().iter().zip(t.shape().iter()).enumerate() {
                if a != axis && d0 != di {
                    return Err(TensorError::ShapeMismatch {
                        left: tensors[0].shape().to_vec(),
                        right: t.shape().to_vec(),
                    });
                }
            }
        }

        // Output shape: axis length is the sum of the inputs' axis lengths.
        let mut out_shape = tensors[0].shape().to_vec();
        out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;

        // Row-major interleave: for each outer index, copy every tensor's
        // (axis × inner) chunk in input order.
        let outer: usize = out_shape[..axis].iter().product();
        let inner: usize = out_shape[axis + 1..].iter().product();
        let mut out_data = Vec::with_capacity(out_count);

        for o in 0..outer {
            for t in tensors {
                let t_axis_len = t.shape()[axis];
                let chunk_start = o * t_axis_len * inner;
                let chunk_end = chunk_start + t_axis_len * inner;
                out_data.extend_from_slice(&t.data()[chunk_start..chunk_end]);
            }
        }

        Tensor::from_vec(out_shape, out_data)
    }
748
749    /// Stack tensors along a new axis. All tensors must have identical shapes.
750    pub fn stack(tensors: &[&Self], axis: usize) -> Result<Self, TensorError> {
751        if tensors.is_empty() {
752            return Err(TensorError::SizeMismatch {
753                shape: vec![],
754                data_len: 0,
755            });
756        }
757        if axis > tensors[0].rank() {
758            return Err(TensorError::InvalidAxis {
759                axis,
760                rank: tensors[0].rank() + 1,
761            });
762        }
763        let expanded: Vec<Self> = tensors
764            .iter()
765            .map(|t| t.unsqueeze(axis))
766            .collect::<Result<_, _>>()?;
767        let refs: Vec<&Self> = expanded.iter().collect();
768        Self::cat(&refs, axis)
769    }
770
771    /// Select a single slice along an axis, removing that axis from the output.
772    pub fn select(&self, axis: usize, index: usize) -> Result<Self, TensorError> {
773        if axis >= self.rank() {
774            return Err(TensorError::InvalidAxis {
775                axis,
776                rank: self.rank(),
777            });
778        }
779        if index >= self.shape()[axis] {
780            return Err(TensorError::IndexOutOfBounds {
781                axis,
782                index,
783                dim: self.shape()[axis],
784            });
785        }
786        let outer: usize = self.shape()[..axis].iter().product();
787        let axis_len = self.shape()[axis];
788        let inner: usize = self.shape()[axis + 1..].iter().product();
789
790        let mut out_data = Vec::with_capacity(outer * inner);
791        for o in 0..outer {
792            let base = o * axis_len * inner + index * inner;
793            out_data.extend_from_slice(&self.data()[base..base + inner]);
794        }
795
796        let mut out_shape = self.shape().to_vec();
797        out_shape.remove(axis);
798        Tensor::from_vec(out_shape, out_data)
799    }
800
801    /// Narrow (slice) along an axis: extract elements `start..start+length`.
802    pub fn narrow(&self, axis: usize, start: usize, length: usize) -> Result<Self, TensorError> {
803        if axis >= self.rank() {
804            return Err(TensorError::InvalidAxis {
805                axis,
806                rank: self.rank(),
807            });
808        }
809        if start + length > self.shape()[axis] {
810            return Err(TensorError::IndexOutOfBounds {
811                axis,
812                index: start + length,
813                dim: self.shape()[axis],
814            });
815        }
816        let outer: usize = self.shape()[..axis].iter().product();
817        let axis_len = self.shape()[axis];
818        let inner: usize = self.shape()[axis + 1..].iter().product();
819
820        let mut out_data = Vec::with_capacity(outer * length * inner);
821        for o in 0..outer {
822            let base = o * axis_len * inner + start * inner;
823            out_data.extend_from_slice(&self.data()[base..base + length * inner]);
824        }
825
826        let mut out_shape = self.shape().to_vec();
827        out_shape[axis] = length;
828        Tensor::from_vec(out_shape, out_data)
829    }
830
831    // ── Comparison ──────────────────────────────────────────────────────
832
833    /// Element-wise equality check: 1.0 where equal, 0.0 otherwise.
834    pub fn eq_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
835        if self.shape() == rhs.shape() {
836            return self.binary_same_shape(rhs, |l, r| {
837                if (l - r).abs() < f32::EPSILON {
838                    1.0
839                } else {
840                    0.0
841                }
842            });
843        }
844        self.binary_broadcast_op(rhs, |l, r| {
845            if (l - r).abs() < f32::EPSILON {
846                1.0
847            } else {
848                0.0
849            }
850        })
851    }
852
853    /// Element-wise greater-than: 1.0 where `self > rhs`, 0.0 otherwise.
854    pub fn gt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
855        if self.shape() == rhs.shape() {
856            return self.binary_same_shape(rhs, |l, r| if l > r { 1.0 } else { 0.0 });
857        }
858        self.binary_broadcast_op(rhs, |l, r| if l > r { 1.0 } else { 0.0 })
859    }
860
861    /// Element-wise less-than: 1.0 where `self < rhs`, 0.0 otherwise.
862    pub fn lt_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
863        if self.shape() == rhs.shape() {
864            return self.binary_same_shape(rhs, |l, r| if l < r { 1.0 } else { 0.0 });
865        }
866        self.binary_broadcast_op(rhs, |l, r| if l < r { 1.0 } else { 0.0 })
867    }
868
    /// Element-wise greater-than writing into a pre-allocated output tensor.
    /// `self`, `rhs`, and `output` must all have the same shape.
    ///
    /// Allocation-free variant of [`Tensor::gt_tensor`] (presumably the same
    /// 1.0/0.0 mask semantics — delegated to the SIMD kernel). The shape
    /// contract is only checked by `debug_assert`, so release builds rely on
    /// the caller to pass matching shapes.
    pub fn gt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Gt,
        );
    }
881
    /// Element-wise equality check writing into a pre-allocated output tensor.
    /// `self`, `rhs`, and `output` must all have the same shape.
    ///
    /// Allocation-free variant of [`Tensor::eq_tensor`] (equality semantics
    /// delegated to the SIMD kernel). Shapes are only checked by
    /// `debug_assert`; release builds rely on the caller.
    pub fn eq_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Eq,
        );
    }
894
    /// Element-wise less-than writing into a pre-allocated output tensor.
    /// `self`, `rhs`, and `output` must all have the same shape.
    ///
    /// Allocation-free variant of [`Tensor::lt_tensor`] (comparison delegated
    /// to the SIMD kernel). Shapes are only checked by `debug_assert`;
    /// release builds rely on the caller.
    pub fn lt_tensor_into(&self, rhs: &Self, output: &mut Self) {
        debug_assert_eq!(self.shape(), rhs.shape());
        debug_assert_eq!(self.shape(), output.shape());
        simd::cmp_dispatch(
            self.data(),
            rhs.data(),
            output.data_mut(),
            simd::CmpKind::Lt,
        );
    }
907
908    /// Element-wise not-equal: 1.0 where not equal (diff.abs() >= 1e-7), 0.0 otherwise.
909    pub fn ne_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
910        if self.shape() == rhs.shape() {
911            return self.binary_same_shape(
912                rhs,
913                |l, r| {
914                    if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 }
915                },
916            );
917        }
918        self.binary_broadcast_op(rhs, |l, r| if (l - r).abs() >= 1e-7 { 1.0 } else { 0.0 })
919    }
920
921    /// Element-wise less-than-or-equal: 1.0 where `self <= rhs`, 0.0 otherwise.
922    pub fn le_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
923        if self.shape() == rhs.shape() {
924            return self.binary_same_shape(rhs, |l, r| if l <= r { 1.0 } else { 0.0 });
925        }
926        self.binary_broadcast_op(rhs, |l, r| if l <= r { 1.0 } else { 0.0 })
927    }
928
929    /// Element-wise greater-than-or-equal: 1.0 where `self >= rhs`, 0.0 otherwise.
930    pub fn ge_tensor(&self, rhs: &Self) -> Result<Self, TensorError> {
931        if self.shape() == rhs.shape() {
932            return self.binary_same_shape(rhs, |l, r| if l >= r { 1.0 } else { 0.0 });
933        }
934        self.binary_broadcast_op(rhs, |l, r| if l >= r { 1.0 } else { 0.0 })
935    }
936
937    /// Returns true if all elements are finite (no NaN or Inf).
938    pub fn all_finite(&self) -> bool {
939        self.data().iter().all(|v| v.is_finite())
940    }
941
942    // ── Advanced indexing/selection ─────────────────────────────────────
943
944    /// Element-wise where: `condition ? self : other`.
945    /// `condition` has 1.0 for true, 0.0 for false.
946    pub fn where_cond(&self, condition: &Self, other: &Self) -> Result<Self, TensorError> {
947        if self.shape() != condition.shape() || self.shape() != other.shape() {
948            return Err(TensorError::ShapeMismatch {
949                left: self.shape().to_vec(),
950                right: condition.shape().to_vec(),
951            });
952        }
953        let data: Vec<f32> = condition
954            .data()
955            .iter()
956            .zip(self.data().iter())
957            .zip(other.data().iter())
958            .map(|((&c, &t), &f)| if c != 0.0 { t } else { f })
959            .collect();
960        Tensor::from_vec(self.shape().to_vec(), data)
961    }
962
963    /// Replace elements where `mask != 0` with `value`.
964    pub fn masked_fill(&self, mask: &Self, value: f32) -> Result<Self, TensorError> {
965        if self.shape() != mask.shape() {
966            return Err(TensorError::ShapeMismatch {
967                left: self.shape().to_vec(),
968                right: mask.shape().to_vec(),
969            });
970        }
971        let data: Vec<f32> = self
972            .data()
973            .iter()
974            .zip(mask.data().iter())
975            .map(|(&v, &m)| if m != 0.0 { value } else { v })
976            .collect();
977        Tensor::from_vec(self.shape().to_vec(), data)
978    }
979
    /// Scatter values into self along `axis` at positions given by `index`.
    /// `src` provides the values. `index` has same shape as `src`.
    ///
    /// Returns a copy of `self` where, for every element of `src`, the value
    /// is written at the same coordinates except along `axis`, whose
    /// destination coordinate comes from `index`. `self` is not modified.
    ///
    /// NOTE(review): `index` values are f32 cast to usize — the cast
    /// saturates, so a negative index becomes 0 and silently writes into the
    /// first slot along `axis` rather than being rejected; confirm intended.
    /// Indices >= the axis length are silently skipped, not errors.
    pub fn scatter(&self, axis: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
        if index.shape() != src.shape() {
            return Err(TensorError::ShapeMismatch {
                left: index.shape().to_vec(),
                right: src.shape().to_vec(),
            });
        }
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Start from a copy of self; scattered positions overwrite it.
        let mut out = self.data().to_vec();
        // Decompose the index/src shape into (outer, axis, inner) factors.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        // Destination geometry from self, which may differ along `axis`.
        // NOTE(review): the outer/inner dims of self are assumed to match
        // index's — that is not validated here; confirm against callers.
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let src_idx = (o * dim + d) * inner + i;
                    let target_d = index.data()[src_idx] as usize;
                    // Skip axis positions outside self's extent…
                    if target_d < self_dim {
                        let out_idx = (o * self_dim + target_d) * self_inner + i;
                        // …and guard the final flat offset as well.
                        if out_idx < out.len() {
                            out[out_idx] = src.data()[src_idx];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(self.shape().to_vec(), out)
    }
1019
    /// Gather elements along `axis` at positions given by `index`.
    ///
    /// Output takes `index`'s shape: each element is read from `self` at the
    /// same coordinates except along `axis`, whose source coordinate comes
    /// from `index`. Positions whose index is out of range are left at 0.0
    /// in the output (silently skipped, not an error).
    ///
    /// NOTE(review): f32→usize casts saturate, so a negative index reads
    /// slot 0 along `axis` instead of failing — confirm this is intended.
    pub fn gather(&self, axis: usize, index: &Self) -> Result<Self, TensorError> {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }
        // Decompose index's shape into (outer, axis, inner) factors.
        let shape = index.shape();
        let outer: usize = shape[..axis].iter().product();
        let dim = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        // Source geometry from self, which may differ along `axis`.
        let self_dim = self.shape()[axis];
        let self_inner: usize = self.shape()[axis + 1..].iter().product();

        // Zero-initialized so skipped (out-of-range) positions stay 0.0.
        let mut out = vec![0.0f32; index.len()];
        for o in 0..outer {
            for d in 0..dim {
                for i in 0..inner {
                    let idx_pos = (o * dim + d) * inner + i;
                    let src_d = index.data()[idx_pos] as usize;
                    if src_d < self_dim {
                        let src_pos = (o * self_dim + src_d) * self_inner + i;
                        // Guard the final flat offset against mismatched
                        // outer/inner extents.
                        if src_pos < self.len() {
                            out[idx_pos] = self.data()[src_pos];
                        }
                    }
                }
            }
        }
        Tensor::from_vec(shape.to_vec(), out)
    }
1052
1053    /// Returns the top-k values and their indices along the last axis.
1054    pub fn topk(&self, k: usize) -> Result<(Self, Self), TensorError> {
1055        if self.rank() == 0 {
1056            return Err(TensorError::InvalidAxis { axis: 0, rank: 0 });
1057        }
1058        let last_dim = *self.shape().last().expect("non-empty shape");
1059        let k = k.min(last_dim);
1060        let outer: usize = self.len() / last_dim;
1061
1062        let mut values = Vec::with_capacity(outer * k);
1063        let mut indices = Vec::with_capacity(outer * k);
1064
1065        for o in 0..outer {
1066            let start = o * last_dim;
1067            let slice = &self.data()[start..start + last_dim];
1068            let mut idx_vec: Vec<usize> = (0..last_dim).collect();
1069            idx_vec.sort_unstable_by(|&a, &b| {
1070                slice[b]
1071                    .partial_cmp(&slice[a])
1072                    .unwrap_or(std::cmp::Ordering::Equal)
1073            });
1074            for &i in &idx_vec[..k] {
1075                values.push(slice[i]);
1076                indices.push(i as f32);
1077            }
1078        }
1079
1080        let mut out_shape = self.shape().to_vec();
1081        *out_shape.last_mut().expect("non-empty shape") = k;
1082        let val_t = Tensor::from_vec(out_shape.clone(), values)?;
1083        let idx_t = Tensor::from_vec(out_shape, indices)?;
1084        Ok((val_t, idx_t))
1085    }
1086
1087    /// Upper triangular mask: zero below diagonal, keep above.
1088    /// `diagonal` shifts: 0 = main, positive = above, negative = below.
1089    pub fn triu(&self, diagonal: i64) -> Result<Self, TensorError> {
1090        if self.rank() < 2 {
1091            return Err(TensorError::InvalidAxis {
1092                axis: 0,
1093                rank: self.rank(),
1094            });
1095        }
1096        let shape = self.shape();
1097        let rows = shape[shape.len() - 2];
1098        let cols = shape[shape.len() - 1];
1099        let batch: usize = shape[..shape.len() - 2].iter().product();
1100        let mut out = self.data().to_vec();
1101        for b in 0..batch {
1102            for r in 0..rows {
1103                for c in 0..cols {
1104                    if (c as i64) < (r as i64) + diagonal {
1105                        out[b * rows * cols + r * cols + c] = 0.0;
1106                    }
1107                }
1108            }
1109        }
1110        Tensor::from_vec(shape.to_vec(), out)
1111    }
1112
1113    /// Lower triangular mask: zero above diagonal, keep below.
1114    pub fn tril(&self, diagonal: i64) -> Result<Self, TensorError> {
1115        if self.rank() < 2 {
1116            return Err(TensorError::InvalidAxis {
1117                axis: 0,
1118                rank: self.rank(),
1119            });
1120        }
1121        let shape = self.shape();
1122        let rows = shape[shape.len() - 2];
1123        let cols = shape[shape.len() - 1];
1124        let batch: usize = shape[..shape.len() - 2].iter().product();
1125        let mut out = self.data().to_vec();
1126        for b in 0..batch {
1127            for r in 0..rows {
1128                for c in 0..cols {
1129                    if (c as i64) > (r as i64) + diagonal {
1130                        out[b * rows * cols + r * cols + c] = 0.0;
1131                    }
1132                }
1133            }
1134        }
1135        Tensor::from_vec(shape.to_vec(), out)
1136    }
1137
1138    /// Identity matrix `[n, n]`.
1139    pub fn eye(n: usize) -> Result<Self, TensorError> {
1140        let mut data = vec![0.0f32; n * n];
1141        for i in 0..n {
1142            data[i * n + i] = 1.0;
1143        }
1144        Tensor::from_vec(vec![n, n], data)
1145    }
1146
1147    /// Create a diagonal matrix from a 1D vector.
1148    pub fn diag(vector: &Tensor) -> Result<Self, TensorError> {
1149        let shape = vector.shape();
1150        if shape.len() != 1 {
1151            return Err(TensorError::UnsupportedOperation {
1152                msg: format!("diag requires a 1D tensor, got shape {:?}", shape),
1153            });
1154        }
1155        let n = shape[0];
1156        let mut data = vec![0.0f32; n * n];
1157        for i in 0..n {
1158            data[i * n + i] = vector.data()[i];
1159        }
1160        Self::from_vec(vec![n, n], data)
1161    }
1162
1163    /// Extract the diagonal of a 2D matrix as a 1D vector.
1164    pub fn diag_extract(&self) -> Result<Self, TensorError> {
1165        let shape = self.shape();
1166        if shape.len() != 2 {
1167            return Err(TensorError::UnsupportedOperation {
1168                msg: format!("diag_extract requires a 2D tensor, got shape {:?}", shape),
1169            });
1170        }
1171        let n = shape[0].min(shape[1]);
1172        let cols = shape[1];
1173        let data: Vec<f32> = (0..n).map(|i| self.data()[i * cols + i]).collect();
1174        Self::from_vec(vec![n], data)
1175    }
1176
1177    /// Pad the tensor with a constant value. `padding` is a slice of (before, after) per dimension.
1178    pub fn pad(&self, padding: &[(usize, usize)], value: f32) -> Result<Self, TensorError> {
1179        let shape = self.shape();
1180        if padding.len() != shape.len() {
1181            return Err(TensorError::InvalidIndexRank {
1182                expected: shape.len(),
1183                got: padding.len(),
1184            });
1185        }
1186        let new_shape: Vec<usize> = shape
1187            .iter()
1188            .zip(padding)
1189            .map(|(&s, &(b, a))| s + b + a)
1190            .collect();
1191        let new_size: usize = new_shape.iter().product();
1192        let mut result = vec![value; new_size];
1193        let ndim = shape.len();
1194
1195        // Compute strides for both old and new shapes
1196        let mut old_strides = vec![1usize; ndim];
1197        for i in (0..ndim.saturating_sub(1)).rev() {
1198            old_strides[i] = old_strides[i + 1] * shape[i + 1];
1199        }
1200        let mut new_strides = vec![1usize; ndim];
1201        for i in (0..ndim.saturating_sub(1)).rev() {
1202            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1203        }
1204
1205        let old_size: usize = shape.iter().product();
1206        let data = self.data();
1207        for flat_idx in 0..old_size {
1208            let mut remaining = flat_idx;
1209            let mut new_flat = 0;
1210            for d in 0..ndim {
1211                let coord = remaining / old_strides[d];
1212                remaining %= old_strides[d];
1213                new_flat += (coord + padding[d].0) * new_strides[d];
1214            }
1215            result[new_flat] = data[flat_idx];
1216        }
1217
1218        Self::from_vec(new_shape, result)
1219    }
1220
1221    /// Repeat tensor along each axis by the given counts.
1222    pub fn repeat(&self, counts: &[usize]) -> Result<Self, TensorError> {
1223        if counts.len() != self.rank() {
1224            return Err(TensorError::InvalidIndexRank {
1225                expected: self.rank(),
1226                got: counts.len(),
1227            });
1228        }
1229        let mut out = self.clone();
1230        for (axis, &count) in counts.iter().enumerate() {
1231            if count > 1 {
1232                let refs: Vec<&Tensor> = std::iter::repeat_n(&out, count).collect();
1233                out = Tensor::cat(&refs, axis)?;
1234            }
1235        }
1236        Ok(out)
1237    }
1238
1239    // ── Cumulative operations ──────────────────────────────────────────
1240
1241    /// Cumulative sum along an axis.
1242    pub fn cumsum(&self, axis: usize) -> Result<Self, TensorError> {
1243        if axis >= self.rank() {
1244            return Err(TensorError::InvalidAxis {
1245                axis,
1246                rank: self.rank(),
1247            });
1248        }
1249        let shape = self.shape();
1250        let outer: usize = shape[..axis].iter().product();
1251        let axis_len = shape[axis];
1252        let inner: usize = shape[axis + 1..].iter().product();
1253        let mut out = self.data().to_vec();
1254
1255        for o in 0..outer {
1256            for i in 0..inner {
1257                let mut acc = 0.0f32;
1258                for d in 0..axis_len {
1259                    let idx = (o * axis_len + d) * inner + i;
1260                    acc += out[idx];
1261                    out[idx] = acc;
1262                }
1263            }
1264        }
1265        Tensor::from_vec(shape.to_vec(), out)
1266    }
1267
1268    /// Cumulative product along an axis.
1269    pub fn cumprod(&self, axis: usize) -> Result<Self, TensorError> {
1270        if axis >= self.rank() {
1271            return Err(TensorError::InvalidAxis {
1272                axis,
1273                rank: self.rank(),
1274            });
1275        }
1276        let shape = self.shape();
1277        let outer: usize = shape[..axis].iter().product();
1278        let axis_len = shape[axis];
1279        let inner: usize = shape[axis + 1..].iter().product();
1280        let mut out = self.data().to_vec();
1281
1282        for o in 0..outer {
1283            for i in 0..inner {
1284                let mut acc = 1.0f32;
1285                for d in 0..axis_len {
1286                    let idx = (o * axis_len + d) * inner + i;
1287                    acc *= out[idx];
1288                    out[idx] = acc;
1289                }
1290            }
1291        }
1292        Tensor::from_vec(shape.to_vec(), out)
1293    }
1294
1295    // ── FP16 conversion ────────────────────────────────────────────────
1296
1297    /// Convert all elements to IEEE 754 half-precision (FP16) bytes.
1298    /// Returns `Vec<u16>` where each u16 is an FP16 bit pattern.
1299    pub fn to_fp16(&self) -> Vec<u16> {
1300        self.data().iter().map(|&v| f32_to_fp16(v)).collect()
1301    }
1302
1303    /// Create a tensor from FP16 bit patterns.
1304    pub fn from_fp16(shape: Vec<usize>, fp16_data: &[u16]) -> Result<Self, TensorError> {
1305        let data: Vec<f32> = fp16_data.iter().map(|&v| fp16_to_f32(v)).collect();
1306        Tensor::from_vec(shape, data)
1307    }
1308
1309    // ── Internal helpers ────────────────────────────────────────────────
1310
    /// Apply `op` to every element, producing a new tensor with the same
    /// shape and strides. Scalar fallback used by the unary ops.
    #[allow(unsafe_code)]
    fn unary_op<F>(&self, op: F) -> Self
    where
        F: Fn(f32) -> f32,
    {
        let src = self.data();
        let len = src.len();
        // SAFETY: `uninitialized` allocates without zeroing.  The loop below
        // writes every element before we ever read from `out_data`.
        let mut out_data = AlignedVec::<f32>::uninitialized(len);
        let inp = src.as_ptr();
        let outp = out_data.as_mut_ptr();
        // SAFETY: both pointers come from live buffers of exactly `len`
        // elements, and `i < len` keeps every access in bounds; the buffers
        // never alias (out_data is freshly allocated).
        unsafe {
            for i in 0..len {
                *outp.add(i) = op(*inp.add(i));
            }
        }
        Tensor::from_raw_parts(self.shape(), self.strides(), out_data)
    }
1330
    /// Apply a SIMD unary kernel (`kind`) to every element, producing a new
    /// tensor with the same shape and strides.
    #[allow(unsafe_code)]
    fn unary_simd_op(&self, kind: simd::UnaryKind) -> Self {
        let len = self.len();
        // SAFETY: `uninitialized` allocates without zeroing.  `unary_dispatch`
        // writes every element before we ever read from `out`.
        let mut out = AlignedVec::<f32>::uninitialized(len);
        simd::unary_dispatch(self.data(), &mut out, kind);
        Tensor::from_raw_parts(self.shape(), self.strides(), out)
    }
1340
    /// Fold all elements along `axis` with `combine`, starting from `init`,
    /// producing a tensor whose shape is `self.shape()` with that axis
    /// removed.
    ///
    /// Walks the input in flat (row-major) order while tracking the full
    /// coordinate vector, and maps each input coordinate to its output
    /// offset by dropping the reduced axis.
    fn reduce_axis<F>(&self, axis: usize, init: f32, combine: F) -> Result<Self, TensorError>
    where
        F: Fn(f32, f32) -> f32,
    {
        if axis >= self.rank() {
            return Err(TensorError::InvalidAxis {
                axis,
                rank: self.rank(),
            });
        }

        let mut out_shape = self.shape().to_vec();
        out_shape.remove(axis);
        // Element count / strides can overflow for pathological shapes; both
        // helpers surface that as `SizeOverflow`.
        let out_count =
            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: out_shape.clone(),
            })?;
        let out_strides = compute_strides(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: out_shape.clone(),
        })?;
        // Every output slot starts at the fold's identity element.
        let mut out_data = vec![init; out_count];

        let mut in_coords = vec![0usize; self.rank()];
        for input in self.data().iter().copied() {
            // Project input coordinates onto the output by skipping the
            // reduced axis; axes after it shift left by one.
            let mut out_offset = 0usize;
            for (src_axis, coord) in in_coords.iter().copied().enumerate() {
                if src_axis == axis {
                    continue;
                }
                let dst_axis = if src_axis < axis {
                    src_axis
                } else {
                    src_axis - 1
                };
                // Rank-1 input reduces to a scalar: out_shape is empty and
                // the single output slot stays at offset 0.
                if !out_shape.is_empty() {
                    out_offset += coord * out_strides[dst_axis];
                }
            }
            out_data[out_offset] = combine(out_data[out_offset], input);
            increment_coords(&mut in_coords, self.shape());
        }

        Tensor::from_vec(out_shape, out_data)
    }
1385
    /// SIMD fast path for broadcasting a last-dim vector across all rows.
    ///
    /// Matches patterns like `[N, C] + [C]`, `[N, H, W, C] + [1, C]`, etc.
    /// Returns `None` if the pattern doesn't match, so the caller falls through
    /// to the generic broadcast loop.
    #[allow(unsafe_code)]
    fn binary_broadcast_lastdim_simd(
        &self,
        rhs: &Self,
        kind: simd::BinaryKind,
    ) -> Option<Result<Self, TensorError>> {
        let lhs_shape = self.shape();
        let rhs_shape = rhs.shape();

        // Detect: rhs is a 1-D vector whose length equals lhs's last dim,
        // or rhs shape is all-1 except the last dim which matches lhs's last dim.
        let lhs_last = *lhs_shape.last()?;
        if lhs_last == 0 {
            // A zero-length row would make `num_rows` below divide by zero.
            return None;
        }

        let rhs_last = *rhs_shape.last()?;

        // Check rhs is effectively a 1-D vector of size lhs_last:
        // either shape [C] or shape [1, 1, ..., C]
        let rhs_is_lastdim_vec =
            rhs_last == lhs_last && rhs_shape.iter().rev().skip(1).all(|&d| d == 1);
        // Also handle the symmetric case: lhs is the vector, rhs is the big tensor
        let lhs_is_lastdim_vec =
            lhs_last == rhs_last && lhs_shape.iter().rev().skip(1).all(|&d| d == 1);

        // If both (or neither) sides qualify, bail out to the generic path —
        // "both" means the operands are already same-sized vectors.
        if rhs_is_lastdim_vec && !lhs_is_lastdim_vec {
            // Broadcast rhs across all rows of lhs
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = lhs_last;
            let num_rows = lhs_data.len() / row_len;
            // `uninitialized` skips zeroing, but the loop below writes every
            // row (row_len divides lhs_data.len() since it is the last-dim
            // extent), so all elements are written before use.
            let mut out_data = AlignedVec::<f32>::uninitialized(lhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                // Operand order preserved: lhs row on the left, rhs vector on
                // the right — order matters for Sub/Div.
                simd::binary_dispatch(
                    &lhs_data[start..end],
                    &rhs_data[..row_len],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(lhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                lhs_shape,
                &out_strides,
                out_data,
            )))
        } else if lhs_is_lastdim_vec && !rhs_is_lastdim_vec {
            // Broadcast lhs across all rows of rhs
            let lhs_data = self.data();
            let rhs_data = rhs.data();
            let row_len = rhs_last;
            let num_rows = rhs_data.len() / row_len;
            // Same initialization argument as the branch above, with roles
            // swapped: every row of out_data is written exactly once.
            let mut out_data = AlignedVec::<f32>::uninitialized(rhs_data.len());

            for i in 0..num_rows {
                let start = i * row_len;
                let end = start + row_len;
                // lhs vector stays on the left here, mirroring operand order.
                simd::binary_dispatch(
                    &lhs_data[..row_len],
                    &rhs_data[start..end],
                    &mut out_data[start..end],
                    kind,
                );
            }

            let out_strides = compute_strides(rhs_shape).expect("valid shape for strides");
            Some(Ok(Tensor::from_raw_parts(
                rhs_shape,
                &out_strides,
                out_data,
            )))
        } else {
            None
        }
    }
1471
1472    fn binary_broadcast_op<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
1473    where
1474        F: Fn(f32, f32) -> f32,
1475    {
1476        let out_shape = broadcast_shape(self.shape(), rhs.shape()).ok_or_else(|| {
1477            TensorError::BroadcastIncompatible {
1478                left: self.shape().to_vec(),
1479                right: rhs.shape().to_vec(),
1480            }
1481        })?;
1482
1483        let out_count =
1484            shape_element_count(&out_shape).ok_or_else(|| TensorError::SizeOverflow {
1485                shape: out_shape.clone(),
1486            })?;
1487        let mut out_data = vec![0.0; out_count];
1488        let mut coords = vec![0usize; out_shape.len()];
1489
1490        for value in &mut out_data {
1491            let left_offset = broadcast_offset(self.shape(), self.strides(), &coords);
1492            let right_offset = broadcast_offset(rhs.shape(), rhs.strides(), &coords);
1493            *value = op(self.data()[left_offset], rhs.data()[right_offset]);
1494            increment_coords(&mut coords, &out_shape);
1495        }
1496
1497        Tensor::from_vec(out_shape, out_data)
1498    }
1499
1500    /// SIMD-accelerated binary op for same-shape tensors (add/sub/mul/div).
1501    #[allow(unsafe_code)]
1502    #[allow(unsafe_code)]
1503    fn binary_same_shape_simd(
1504        &self,
1505        rhs: &Self,
1506        kind: simd::BinaryKind,
1507    ) -> Result<Self, TensorError> {
1508        let len = self.len();
1509        let mut out_data = AlignedVec::<f32>::uninitialized(len);
1510
1511        // Multi-threaded for large tensors. Cross-platform:
1512        // macOS → GCD dispatch_apply, others → rayon thread pool.
1513        if len >= 100_000 {
1514            let n = std::thread::available_parallelism()
1515                .map(|p| p.get())
1516                .unwrap_or(4);
1517            let chunk = len.div_ceil(n);
1518            let lp = self.data().as_ptr() as usize;
1519            let rp = rhs.data().as_ptr() as usize;
1520            let op = out_data.as_mut_ptr() as usize;
1521
1522            #[cfg(target_os = "macos")]
1523            {
1524                use std::ffi::c_void;
1525                #[allow(unsafe_code)]
1526                unsafe extern "C" {
1527                    fn dispatch_get_global_queue(id: isize, flags: usize) -> *const c_void;
1528                    fn dispatch_apply_f(
1529                        n: usize,
1530                        q: *const c_void,
1531                        ctx: *mut c_void,
1532                        work: unsafe extern "C" fn(*mut c_void, usize),
1533                    );
1534                }
1535                struct Ctx {
1536                    lp: usize,
1537                    rp: usize,
1538                    op: usize,
1539                    chunk: usize,
1540                    len: usize,
1541                    kind: simd::BinaryKind,
1542                }
1543                let ctx = Ctx {
1544                    lp,
1545                    rp,
1546                    op,
1547                    chunk,
1548                    len,
1549                    kind,
1550                };
1551                unsafe extern "C" fn work(raw: *mut c_void, t: usize) {
1552                    let c = unsafe { &*(raw as *const Ctx) };
1553                    let start = t * c.chunk;
1554                    let end = (start + c.chunk).min(c.len);
1555                    if start >= end {
1556                        return;
1557                    }
1558                    let l = unsafe {
1559                        std::slice::from_raw_parts((c.lp as *const f32).add(start), end - start)
1560                    };
1561                    let r = unsafe {
1562                        std::slice::from_raw_parts((c.rp as *const f32).add(start), end - start)
1563                    };
1564                    let o = unsafe {
1565                        std::slice::from_raw_parts_mut((c.op as *mut f32).add(start), end - start)
1566                    };
1567                    simd::binary_dispatch(l, r, o, c.kind);
1568                }
1569                let q = unsafe { dispatch_get_global_queue(0, 0) };
1570                unsafe {
1571                    dispatch_apply_f(n, q, &ctx as *const Ctx as *mut c_void, work);
1572                }
1573            }
1574
1575            #[cfg(not(target_os = "macos"))]
1576            {
1577                // Rayon global thread pool — threads pre-spawned, ~0.5µs dispatch.
1578                use rayon::prelude::*;
1579                (0..n).into_par_iter().for_each(|t| {
1580                    let start = t * chunk;
1581                    let end = (start + chunk).min(len);
1582                    if start >= end {
1583                        return;
1584                    }
1585                    let l = unsafe {
1586                        std::slice::from_raw_parts((lp as *const f32).add(start), end - start)
1587                    };
1588                    let r = unsafe {
1589                        std::slice::from_raw_parts((rp as *const f32).add(start), end - start)
1590                    };
1591                    let o = unsafe {
1592                        std::slice::from_raw_parts_mut((op as *mut f32).add(start), end - start)
1593                    };
1594                    simd::binary_dispatch(l, r, o, kind);
1595                });
1596            }
1597
1598            return Ok(Tensor::from_raw_parts(
1599                self.shape(),
1600                self.strides(),
1601                out_data,
1602            ));
1603        }
1604
1605        simd::binary_dispatch(self.data(), rhs.data(), &mut out_data, kind);
1606        Ok(Tensor::from_raw_parts(
1607            self.shape(),
1608            self.strides(),
1609            out_data,
1610        ))
1611    }
1612
    /// Scalar fallback for same-shape binary ops: applies `op` pairwise via a
    /// manually 4-way-unrolled pointer loop plus a scalar tail. The result
    /// keeps `self`'s shape and strides.
    #[allow(unsafe_code)]
    fn binary_same_shape<F>(&self, rhs: &Self, op: F) -> Result<Self, TensorError>
    where
        F: Fn(f32, f32) -> f32,
    {
        let len = self.len();
        // SAFETY: `uninitialized` allocates without zeroing.  The loop below
        // writes every element before we ever read from `out_data`.
        let mut out_data = AlignedVec::<f32>::uninitialized(len);

        let lhs_ptr = self.data().as_ptr();
        let rhs_ptr = rhs.data().as_ptr();
        let out_ptr = out_data.as_mut_ptr();

        // SAFETY:
        // - Pointers originate from valid slices/vectors of length `len`.
        // - Loop bounds guarantee all pointer arithmetic remains in-bounds.
        // - `out_data` is uniquely mutable and does not alias with input buffers.
        unsafe {
            let mut index = 0usize;
            // 4-wide unrolled main loop.
            while index + 4 <= len {
                *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
                *out_ptr.add(index + 1) = op(*lhs_ptr.add(index + 1), *rhs_ptr.add(index + 1));
                *out_ptr.add(index + 2) = op(*lhs_ptr.add(index + 2), *rhs_ptr.add(index + 2));
                *out_ptr.add(index + 3) = op(*lhs_ptr.add(index + 3), *rhs_ptr.add(index + 3));
                index += 4;
            }
            // Scalar tail for the remaining 0–3 elements.
            while index < len {
                *out_ptr.add(index) = op(*lhs_ptr.add(index), *rhs_ptr.add(index));
                index += 1;
            }
        }

        Ok(Tensor::from_raw_parts(
            self.shape(),
            self.strides(),
            out_data,
        ))
    }
1652
1653    // ── In-place operations ──────────────────────────────────────────────
1654
1655    /// In-place ReLU: clamp negative values to zero.
1656    pub fn relu_inplace(&mut self) {
1657        simd::relu_inplace_dispatch(self.data_mut());
1658    }
1659
1660    /// In-place element-wise add from another tensor (must have same shape).
1661    pub fn add_inplace(&mut self, rhs: &Self) {
1662        debug_assert_eq!(self.len(), rhs.len());
1663        simd::add_inplace_dispatch(self.data_mut(), rhs.data());
1664    }
1665
1666    /// In-place add scalar to all elements.
1667    pub fn add_scalar_inplace(&mut self, s: f32) {
1668        simd::add_scalar_inplace_dispatch(self.data_mut(), s);
1669    }
1670
1671    /// In-place multiply all elements by scalar.
1672    pub fn mul_scalar_inplace(&mut self, s: f32) {
1673        simd::mul_scalar_inplace_dispatch(self.data_mut(), s);
1674    }
1675
1676    // ── Binary-into operations (write into pre-allocated output) ─────
1677
1678    /// Element-wise addition writing into a pre-allocated output tensor.
1679    /// `output`, `lhs`, and `rhs` must all have the same shape.
1680    pub fn add_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1681        debug_assert_eq!(lhs.shape(), rhs.shape());
1682        debug_assert_eq!(lhs.shape(), output.shape());
1683        simd::binary_dispatch(
1684            lhs.data(),
1685            rhs.data(),
1686            output.data_mut(),
1687            simd::BinaryKind::Add,
1688        );
1689    }
1690
1691    /// Element-wise subtraction writing into a pre-allocated output tensor.
1692    /// `output`, `lhs`, and `rhs` must all have the same shape.
1693    pub fn sub_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1694        debug_assert_eq!(lhs.shape(), rhs.shape());
1695        debug_assert_eq!(lhs.shape(), output.shape());
1696        simd::binary_dispatch(
1697            lhs.data(),
1698            rhs.data(),
1699            output.data_mut(),
1700            simd::BinaryKind::Sub,
1701        );
1702    }
1703
1704    /// Element-wise multiplication writing into a pre-allocated output tensor.
1705    /// `output`, `lhs`, and `rhs` must all have the same shape.
1706    pub fn mul_into(output: &mut Self, lhs: &Self, rhs: &Self) {
1707        debug_assert_eq!(lhs.shape(), rhs.shape());
1708        debug_assert_eq!(lhs.shape(), output.shape());
1709        simd::binary_dispatch(
1710            lhs.data(),
1711            rhs.data(),
1712            output.data_mut(),
1713            simd::BinaryKind::Mul,
1714        );
1715    }
1716}
1717
1718// ── atan2 helper ───────────────────────────────────────────────────
1719
/// Scalar atan2(y, x) using Cephes-style range reduction of atan.
///
/// Range-reduces the argument to [0, tan(pi/12)] ≈ [0, 0.268] using:
///   - If z > tan(3*pi/8) ≈ 2.414: atan(z) = pi/2 - atan(1/z)
///   - If z > tan(pi/8) ≈ 0.414: atan(z) = pi/4 + atan((z-1)/(z+1))
///
/// Then uses a degree-7 polynomial on the reduced argument.
/// Max error < 4e-7 across all inputs.
#[allow(clippy::excessive_precision)]
#[inline(always)]
fn fast_atan2_scalar(y: f32, x: f32) -> f32 {
    const PI: f32 = std::f32::consts::PI;
    const FRAC_PI_2: f32 = std::f32::consts::FRAC_PI_2;
    const FRAC_PI_4: f32 = std::f32::consts::FRAC_PI_4;
    const TAN_3PI_8: f32 = 2.414_213_6; // tan(3*pi/8) = 1 + sqrt(2)
    const TAN_PI_8: f32 = 0.414_213_57; // tan(pi/8) = sqrt(2) - 1

    // Work on magnitudes; octant bookkeeping restores signs at the end.
    let abs_x = x.abs();
    let abs_y = y.abs();

    // Divide the smaller magnitude by the larger so the ratio is <= 1.
    let swapped = abs_x < abs_y;
    let (num, den) = if swapped { (abs_x, abs_y) } else { (abs_y, abs_x) };
    let ratio = if den > 0.0 { num / den } else { 0.0 };

    // Range reduction for atan(ratio), ratio >= 0.
    let (t, offset) = if ratio > TAN_3PI_8 {
        // Unreachable in practice (ratio <= 1 after the swap) — kept defensively.
        (-1.0 / ratio, FRAC_PI_2)
    } else if ratio > TAN_PI_8 {
        ((ratio - 1.0) / (ratio + 1.0), FRAC_PI_4)
    } else {
        (ratio, 0.0)
    };

    // atan(t) ≈ t + t³·P(t²) for small t; Horner evaluation of the
    // Cephes atanf.c coefficients (S. Moshier).
    let t2 = t * t;
    let poly = 8.054_666e-02_f32
        .mul_add(t2, -1.384_895_1e-01)
        .mul_add(t2, 1.997_075_8e-01)
        .mul_add(t2, -3.333_129_8e-01);
    let atan_val = t.mul_add(t2 * poly, t) + offset;

    // Undo the swap: atan(|y|/|x|) = pi/2 - atan(|x|/|y|).
    let first_octants = if swapped { FRAC_PI_2 - atan_val } else { atan_val };

    // Quadrant correction from the signs of x and y.
    let signed_x = if x < 0.0 { PI - first_octants } else { first_octants };
    if y < 0.0 {
        -signed_x
    } else {
        signed_x
    }
}
1780
1781// ── FP16 conversion utilities ──────────────────────────────────────
1782
/// Convert f32 to IEEE 754 half-precision (FP16) bit pattern.
///
/// Rounding is truncation (round toward zero) in both the normal and
/// subnormal paths. f32 subnormals (unbiased exponent < -24 after the check
/// below) flush to signed zero; NaNs map to a quiet FP16 NaN.
fn f32_to_fp16(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exponent == 255 {
        // Inf or NaN; force a high mantissa bit so NaNs stay quiet NaNs.
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }

    let unbiased = exponent - 127;
    if unbiased > 15 {
        return sign | 0x7C00; // overflow -> Inf
    }
    if unbiased < -24 {
        return sign; // underflow -> signed zero
    }
    if unbiased < -14 {
        // Subnormal FP16: value = M * 2^-24 with M = 1.m * 2^(unbiased + 24).
        // The implicit-1 significand is (mantissa | 0x0080_0000) = 1.m * 2^23,
        // so M is obtained by shifting right (23 - (unbiased + 24)) =
        // (-unbiased - 1) bits. The previous code shifted by an extra 13
        // bits, collapsing every subnormal result toward zero.
        let shift = -1 - unbiased; // 14..=23 for unbiased in -24..=-15
        let m = (mantissa | 0x0080_0000) >> shift;
        return sign | m as u16;
    }

    // Normal FP16: re-bias the exponent and truncate the mantissa to 10 bits.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
1813
1814impl Tensor {
1815    // ── Sort / argsort / unique / nonzero ──────────────────────────────
1816
1817    /// Sort along `dim`. Returns `(sorted_values, sorted_indices)`.
1818    ///
1819    /// If `descending` is true, values are sorted largest-first.
1820    pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self), TensorError> {
1821        if dim >= self.rank() {
1822            return Err(TensorError::InvalidAxis {
1823                axis: dim,
1824                rank: self.rank(),
1825            });
1826        }
1827        let shape = self.shape();
1828        let outer: usize = shape[..dim].iter().product();
1829        let dim_len = shape[dim];
1830        let inner: usize = shape[dim + 1..].iter().product();
1831        let data = self.data();
1832
1833        let mut out_vals = vec![0.0f32; data.len()];
1834        let mut out_idxs = vec![0.0f32; data.len()];
1835
1836        for o in 0..outer {
1837            for i in 0..inner {
1838                let mut idx_vec: Vec<usize> = (0..dim_len).collect();
1839                idx_vec.sort_unstable_by(|&a, &b| {
1840                    let va = data[(o * dim_len + a) * inner + i];
1841                    let vb = data[(o * dim_len + b) * inner + i];
1842                    if descending {
1843                        vb.partial_cmp(&va).unwrap_or(std::cmp::Ordering::Equal)
1844                    } else {
1845                        va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
1846                    }
1847                });
1848                for (rank, &src) in idx_vec.iter().enumerate() {
1849                    let dst = (o * dim_len + rank) * inner + i;
1850                    let src_pos = (o * dim_len + src) * inner + i;
1851                    out_vals[dst] = data[src_pos];
1852                    out_idxs[dst] = src as f32;
1853                }
1854            }
1855        }
1856
1857        let v = Tensor::from_vec(shape.to_vec(), out_vals)?;
1858        let idx = Tensor::from_vec(shape.to_vec(), out_idxs)?;
1859        Ok((v, idx))
1860    }
1861
1862    /// Return indices that would sort along `dim`.
1863    pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self, TensorError> {
1864        let (_, indices) = self.sort(dim, descending)?;
1865        Ok(indices)
1866    }
1867
1868    /// Return unique elements (sorted), inverse indices, and counts.
1869    pub fn unique(&self) -> (Self, Self, Self) {
1870        let data = self.data();
1871        let mut sorted: Vec<f32> = data.to_vec();
1872        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1873        sorted.dedup();
1874
1875        let mut inverse = vec![0.0f32; data.len()];
1876        let mut counts = vec![0.0f32; sorted.len()];
1877        for (i, &v) in data.iter().enumerate() {
1878            let pos = sorted
1879                .binary_search_by(|probe| {
1880                    probe.partial_cmp(&v).unwrap_or(std::cmp::Ordering::Equal)
1881                })
1882                .expect("value exists in sorted list");
1883            inverse[i] = pos as f32;
1884            counts[pos] += 1.0;
1885        }
1886
1887        let vals = Tensor::from_vec(vec![sorted.len()], sorted).expect("unique vals");
1888        let inv = Tensor::from_vec(self.shape().to_vec(), inverse).expect("unique inv");
1889        let cnt = Tensor::from_vec(vec![counts.len()], counts).expect("unique counts");
1890        (vals, inv, cnt)
1891    }
1892
1893    /// Return coordinates of nonzero elements as a 2D tensor `[N, rank]`.
1894    pub fn nonzero(&self) -> Self {
1895        let shape = self.shape();
1896        let rank = shape.len().max(1);
1897        let data = self.data();
1898        let mut coords: Vec<Vec<usize>> = Vec::new();
1899
1900        if shape.is_empty() {
1901            // scalar
1902            if data[0] != 0.0 {
1903                coords.push(vec![0]);
1904            }
1905        } else {
1906            let mut idx = vec![0usize; shape.len()];
1907            for pos in 0..data.len() {
1908                if data[pos] != 0.0 {
1909                    coords.push(idx.clone());
1910                }
1911                // increment multi-dim index
1912                for d in (0..shape.len()).rev() {
1913                    idx[d] += 1;
1914                    if idx[d] < shape[d] {
1915                        break;
1916                    }
1917                    idx[d] = 0;
1918                }
1919            }
1920        }
1921
1922        let n = coords.len();
1923        let mut flat = Vec::with_capacity(n * rank);
1924        for c in &coords {
1925            for &v in c {
1926                flat.push(v as f32);
1927            }
1928        }
1929        if n == 0 {
1930            Tensor::from_vec(vec![0, rank], flat).expect("nonzero empty")
1931        } else {
1932            Tensor::from_vec(vec![n, rank], flat).expect("nonzero")
1933        }
1934    }
1935
1936    // ── Flip / roll ────────────────────────────────────────────────────
1937
1938    /// Reverse elements along the given dimensions.
1939    pub fn flip(&self, dims: &[usize]) -> Result<Self, TensorError> {
1940        for &d in dims {
1941            if d >= self.rank() {
1942                return Err(TensorError::InvalidAxis {
1943                    axis: d,
1944                    rank: self.rank(),
1945                });
1946            }
1947        }
1948        let shape = self.shape();
1949        let data = self.data();
1950        let total = data.len();
1951        let mut out = vec![0.0f32; total];
1952        let rank = shape.len();
1953
1954        let mut src_idx = vec![0usize; rank];
1955        for pos in 0..total {
1956            // compute destination index by flipping specified dims
1957            let mut dst_idx = src_idx.clone();
1958            for &d in dims {
1959                dst_idx[d] = shape[d] - 1 - src_idx[d];
1960            }
1961            // linear offset
1962            let mut dst_pos = 0;
1963            let mut stride = 1;
1964            for d in (0..rank).rev() {
1965                dst_pos += dst_idx[d] * stride;
1966                stride *= shape[d];
1967            }
1968            out[dst_pos] = data[pos];
1969
1970            // increment src_idx
1971            for d in (0..rank).rev() {
1972                src_idx[d] += 1;
1973                if src_idx[d] < shape[d] {
1974                    break;
1975                }
1976                src_idx[d] = 0;
1977            }
1978        }
1979        Tensor::from_vec(shape.to_vec(), out)
1980    }
1981
1982    /// Circular shift elements along `dim` by `shift` positions.
1983    pub fn roll(&self, shift: i64, dim: usize) -> Result<Self, TensorError> {
1984        if dim >= self.rank() {
1985            return Err(TensorError::InvalidAxis {
1986                axis: dim,
1987                rank: self.rank(),
1988            });
1989        }
1990        let shape = self.shape();
1991        let outer: usize = shape[..dim].iter().product();
1992        let dim_len = shape[dim];
1993        let inner: usize = shape[dim + 1..].iter().product();
1994        let data = self.data();
1995
1996        let mut out = vec![0.0f32; data.len()];
1997        for o in 0..outer {
1998            for d in 0..dim_len {
1999                let dst_d = ((d as i64 + shift).rem_euclid(dim_len as i64)) as usize;
2000                for i in 0..inner {
2001                    out[(o * dim_len + dst_d) * inner + i] = data[(o * dim_len + d) * inner + i];
2002                }
2003            }
2004        }
2005        Tensor::from_vec(shape.to_vec(), out)
2006    }
2007
2008    // ── Factory: linspace / arange / meshgrid ──────────────────────────
2009
2010    /// Create a 1-D tensor of `steps` evenly spaced values from `start` to `end` (inclusive).
2011    pub fn linspace(start: f32, end: f32, steps: usize) -> Result<Self, TensorError> {
2012        if steps == 0 {
2013            return Tensor::from_vec(vec![0], vec![]);
2014        }
2015        if steps == 1 {
2016            return Tensor::from_vec(vec![1], vec![start]);
2017        }
2018        let step = (end - start) / (steps - 1) as f32;
2019        let data: Vec<f32> = (0..steps).map(|i| start + step * i as f32).collect();
2020        Tensor::from_vec(vec![steps], data)
2021    }
2022
2023    /// Create a 1-D tensor with values in `[start, end)` with given `step`.
2024    pub fn arange(start: f32, end: f32, step: f32) -> Result<Self, TensorError> {
2025        if step == 0.0 {
2026            return Err(TensorError::ShapeMismatch {
2027                left: vec![],
2028                right: vec![],
2029            });
2030        }
2031        let mut data = Vec::new();
2032        let mut v = start;
2033        if step > 0.0 {
2034            while v < end {
2035                data.push(v);
2036                v += step;
2037            }
2038        } else {
2039            while v > end {
2040                data.push(v);
2041                v += step;
2042            }
2043        }
2044        let n = data.len();
2045        Tensor::from_vec(vec![n], data)
2046    }
2047
2048    /// Create coordinate grids from 1-D tensors (numpy-style `meshgrid` with `indexing='ij'`).
2049    pub fn meshgrid(tensors: &[Self]) -> Result<Vec<Self>, TensorError> {
2050        let shape: Vec<usize> = tensors.iter().map(|t| t.len()).collect();
2051        let total: usize = shape.iter().product();
2052        let n = tensors.len();
2053        let mut result = Vec::with_capacity(n);
2054
2055        for (idx, t) in tensors.iter().enumerate() {
2056            let t_data = t.data();
2057            let mut out = vec![0.0f32; total];
2058            // stride pattern: product of dims after idx
2059            let inner: usize = shape[idx + 1..].iter().product();
2060            let outer: usize = shape[..idx].iter().product();
2061            let dim_len = shape[idx];
2062            for o in 0..outer {
2063                for d in 0..dim_len {
2064                    for i in 0..inner {
2065                        out[(o * dim_len + d) * inner + i] = t_data[d];
2066                    }
2067                }
2068            }
2069            result.push(Tensor::from_vec(shape.clone(), out)?);
2070        }
2071        Ok(result)
2072    }
2073
2074    // ── Advanced indexing extras ────────────────────────────────────────
2075
2076    /// Select elements where `mask` (f32, nonzero = true) is true, returned as 1-D.
2077    pub fn boolean_mask(&self, mask: &Self) -> Result<Self, TensorError> {
2078        if self.shape() != mask.shape() {
2079            return Err(TensorError::ShapeMismatch {
2080                left: self.shape().to_vec(),
2081                right: mask.shape().to_vec(),
2082            });
2083        }
2084        let data = self.data();
2085        let m = mask.data();
2086        let out: Vec<f32> = data
2087            .iter()
2088            .zip(m.iter())
2089            .filter(|(_, mv)| **mv != 0.0)
2090            .map(|(v, _)| *v)
2091            .collect();
2092        let n = out.len();
2093        Tensor::from_vec(vec![n], out)
2094    }
2095
2096    /// Select slices along `dim` using integer `indices` tensor (1-D).
2097    pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self, TensorError> {
2098        if dim >= self.rank() {
2099            return Err(TensorError::InvalidAxis {
2100                axis: dim,
2101                rank: self.rank(),
2102            });
2103        }
2104        let shape = self.shape();
2105        let idx_data = indices.data();
2106        let n_idx = idx_data.len();
2107        let outer: usize = shape[..dim].iter().product();
2108        let dim_len = shape[dim];
2109        let inner: usize = shape[dim + 1..].iter().product();
2110        let data = self.data();
2111
2112        let mut out = Vec::with_capacity(outer * n_idx * inner);
2113        for o in 0..outer {
2114            for &idx_f in idx_data {
2115                let idx = idx_f as usize;
2116                if idx >= dim_len {
2117                    return Err(TensorError::IndexOutOfBounds {
2118                        axis: dim,
2119                        index: idx,
2120                        dim: dim_len,
2121                    });
2122                }
2123                let src_start = (o * dim_len + idx) * inner;
2124                out.extend_from_slice(&data[src_start..src_start + inner]);
2125            }
2126        }
2127
2128        let mut out_shape = shape.to_vec();
2129        out_shape[dim] = n_idx;
2130        Tensor::from_vec(out_shape, out)
2131    }
2132
2133    // ── Random tensor creation ──────────────────────────────────────────
2134
2135    /// Create a tensor filled with uniform random values in [0, 1).
2136    pub fn rand(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2137        let n: usize = shape.iter().product();
2138        let mut rng = seed;
2139        let data: Vec<f32> = (0..n)
2140            .map(|_| {
2141                rng ^= rng << 13;
2142                rng ^= rng >> 7;
2143                rng ^= rng << 17;
2144                (rng as f32) / (u64::MAX as f32)
2145            })
2146            .collect();
2147        Self::from_vec(shape, data)
2148    }
2149
2150    /// Create a tensor filled with normally distributed random values (mean=0, std=1).
2151    /// Uses Box-Muller transform.
2152    pub fn randn(shape: Vec<usize>, seed: u64) -> Result<Self, TensorError> {
2153        let n: usize = shape.iter().product();
2154        let mut rng = seed;
2155        let mut next_rng = || -> f32 {
2156            rng ^= rng << 13;
2157            rng ^= rng >> 7;
2158            rng ^= rng << 17;
2159            // Map to (0, 1) exclusive to avoid log(0)
2160            ((rng as f64) / (u64::MAX as f64)).clamp(1e-15, 1.0 - 1e-15) as f32
2161        };
2162        let mut data = Vec::with_capacity(n);
2163        let mut i = 0;
2164        while i < n {
2165            let u1 = next_rng();
2166            let u2 = next_rng();
2167            let r = (-2.0 * (u1 as f64).ln()).sqrt();
2168            let theta = 2.0 * std::f64::consts::PI * u2 as f64;
2169            data.push((r * theta.cos()) as f32);
2170            i += 1;
2171            if i < n {
2172                data.push((r * theta.sin()) as f32);
2173                i += 1;
2174            }
2175        }
2176        Self::from_vec(shape, data)
2177    }
2178
2179    /// Create a tensor filled with random integers in [low, high).
2180    pub fn randint(shape: Vec<usize>, low: i64, high: i64, seed: u64) -> Result<Self, TensorError> {
2181        if high <= low {
2182            return Err(TensorError::UnsupportedOperation {
2183                msg: format!("randint requires high > low, got low={low}, high={high}"),
2184            });
2185        }
2186        let range = (high - low) as u64;
2187        let n: usize = shape.iter().product();
2188        let mut rng = seed;
2189        let data: Vec<f32> = (0..n)
2190            .map(|_| {
2191                rng ^= rng << 13;
2192                rng ^= rng >> 7;
2193                rng ^= rng << 17;
2194                (low + (rng % range) as i64) as f32
2195            })
2196            .collect();
2197        Self::from_vec(shape, data)
2198    }
2199
2200    /// Create a random permutation of integers [0, n).
2201    pub fn randperm(n: usize, seed: u64) -> Result<Self, TensorError> {
2202        let mut perm: Vec<f32> = (0..n).map(|i| i as f32).collect();
2203        let mut rng = seed;
2204        for i in (1..n).rev() {
2205            rng ^= rng << 13;
2206            rng ^= rng >> 7;
2207            rng ^= rng << 17;
2208            let j = (rng as usize) % (i + 1);
2209            perm.swap(i, j);
2210        }
2211        Self::from_vec(vec![n], perm)
2212    }
2213}
2214
2215// ── Advanced tensor operations ──────────────────────────────────────────
2216
2217impl Tensor {
2218    /// Slice with step: extract every `step`-th element along `dim` from `start` to `end`.
2219    pub fn step_slice(
2220        &self,
2221        dim: usize,
2222        start: usize,
2223        end: usize,
2224        step: usize,
2225    ) -> Result<Self, TensorError> {
2226        let rank = self.rank();
2227        if dim >= rank {
2228            return Err(TensorError::InvalidAxis { axis: dim, rank });
2229        }
2230        if step == 0 {
2231            return Err(TensorError::UnsupportedOperation {
2232                msg: "step must be > 0".to_string(),
2233            });
2234        }
2235        let shape = self.shape();
2236        let dim_len = shape[dim];
2237        let end = end.min(dim_len);
2238        if start >= end {
2239            // empty along this dim
2240            let mut out_shape = shape.to_vec();
2241            out_shape[dim] = 0;
2242            return Tensor::from_vec(out_shape, vec![]);
2243        }
2244
2245        let selected_indices: Vec<usize> = (start..end).step_by(step).collect();
2246        let new_dim = selected_indices.len();
2247
2248        let outer: usize = shape[..dim].iter().product();
2249        let inner: usize = shape[dim + 1..].iter().product();
2250        let data = self.data();
2251
2252        let mut out = Vec::with_capacity(outer * new_dim * inner);
2253        for o in 0..outer {
2254            for &idx in &selected_indices {
2255                let src_start = (o * dim_len + idx) * inner;
2256                out.extend_from_slice(&data[src_start..src_start + inner]);
2257            }
2258        }
2259
2260        let mut out_shape = shape.to_vec();
2261        out_shape[dim] = new_dim;
2262        Tensor::from_vec(out_shape, out)
2263    }
2264
2265    /// Einstein summation for common patterns.
2266    ///
2267    /// Supported equations:
2268    /// - `"ij,jk->ik"` — matrix multiply
2269    /// - `"ij->ji"` — transpose
2270    /// - `"ii->i"` — diagonal
2271    /// - `"ij->i"` — row sum
2272    /// - `"ij->j"` — column sum
2273    /// - `"ij->"` — total sum
2274    /// - `"i,i->"` — dot product
2275    /// - `"ij,ij->"` — Frobenius inner product
2276    pub fn einsum(equation: &str, tensors: &[&Tensor]) -> Result<Tensor, TensorError> {
2277        let equation = equation.replace(' ', "");
2278        let parts: Vec<&str> = equation.split("->").collect();
2279        if parts.len() != 2 {
2280            return Err(TensorError::UnsupportedOperation {
2281                msg: format!("invalid einsum equation: {equation}"),
2282            });
2283        }
2284        let inputs_str = parts[0];
2285        let output_str = parts[1];
2286        let input_parts: Vec<&str> = inputs_str.split(',').collect();
2287
2288        if input_parts.len() != tensors.len() {
2289            return Err(TensorError::UnsupportedOperation {
2290                msg: format!(
2291                    "einsum equation has {} inputs but {} tensors provided",
2292                    input_parts.len(),
2293                    tensors.len()
2294                ),
2295            });
2296        }
2297
2298        // Match known patterns
2299        let pattern = format!(
2300            "{}->{}",
2301            input_parts
2302                .iter()
2303                .map(|s| s.to_string())
2304                .collect::<Vec<_>>()
2305                .join(","),
2306            output_str
2307        );
2308
2309        match pattern.as_str() {
2310            // matrix multiply: ij,jk->ik
2311            "ij,jk->ik" => {
2312                let a = tensors[0];
2313                let b = tensors[1];
2314                if a.rank() != 2 || b.rank() != 2 {
2315                    return Err(TensorError::UnsupportedOperation {
2316                        msg: "ij,jk->ik requires 2D tensors".to_string(),
2317                    });
2318                }
2319                let (m, k1) = (a.shape()[0], a.shape()[1]);
2320                let (k2, n) = (b.shape()[0], b.shape()[1]);
2321                if k1 != k2 {
2322                    return Err(TensorError::ShapeMismatch {
2323                        left: a.shape().to_vec(),
2324                        right: b.shape().to_vec(),
2325                    });
2326                }
2327                let ad = a.data();
2328                let bd = b.data();
2329                let mut out = vec![0.0f32; m * n];
2330                for i in 0..m {
2331                    for j in 0..n {
2332                        let mut sum = 0.0f32;
2333                        for p in 0..k1 {
2334                            sum += ad[i * k1 + p] * bd[p * n + j];
2335                        }
2336                        out[i * n + j] = sum;
2337                    }
2338                }
2339                Tensor::from_vec(vec![m, n], out)
2340            }
2341            // transpose: ij->ji
2342            "ij->ji" => {
2343                let a = tensors[0];
2344                if a.rank() != 2 {
2345                    return Err(TensorError::UnsupportedOperation {
2346                        msg: "ij->ji requires a 2D tensor".to_string(),
2347                    });
2348                }
2349                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2350                let ad = a.data();
2351                let mut out = vec![0.0f32; rows * cols];
2352                for i in 0..rows {
2353                    for j in 0..cols {
2354                        out[j * rows + i] = ad[i * cols + j];
2355                    }
2356                }
2357                Tensor::from_vec(vec![cols, rows], out)
2358            }
2359            // diagonal: ii->i
2360            "ii->i" => {
2361                let a = tensors[0];
2362                if a.rank() != 2 || a.shape()[0] != a.shape()[1] {
2363                    return Err(TensorError::UnsupportedOperation {
2364                        msg: "ii->i requires a square 2D tensor".to_string(),
2365                    });
2366                }
2367                let n = a.shape()[0];
2368                let ad = a.data();
2369                let out: Vec<f32> = (0..n).map(|i| ad[i * n + i]).collect();
2370                Tensor::from_vec(vec![n], out)
2371            }
2372            // row sum: ij->i
2373            "ij->i" => {
2374                let a = tensors[0];
2375                if a.rank() != 2 {
2376                    return Err(TensorError::UnsupportedOperation {
2377                        msg: "ij->i requires a 2D tensor".to_string(),
2378                    });
2379                }
2380                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2381                let ad = a.data();
2382                let out: Vec<f32> = (0..rows)
2383                    .map(|i| ad[i * cols..(i + 1) * cols].iter().sum())
2384                    .collect();
2385                Tensor::from_vec(vec![rows], out)
2386            }
2387            // column sum: ij->j
2388            "ij->j" => {
2389                let a = tensors[0];
2390                if a.rank() != 2 {
2391                    return Err(TensorError::UnsupportedOperation {
2392                        msg: "ij->j requires a 2D tensor".to_string(),
2393                    });
2394                }
2395                let (rows, cols) = (a.shape()[0], a.shape()[1]);
2396                let ad = a.data();
2397                let mut out = vec![0.0f32; cols];
2398                for i in 0..rows {
2399                    for j in 0..cols {
2400                        out[j] += ad[i * cols + j];
2401                    }
2402                }
2403                Tensor::from_vec(vec![cols], out)
2404            }
2405            // total sum: ij->
2406            "ij->" => {
2407                let a = tensors[0];
2408                if a.rank() != 2 {
2409                    return Err(TensorError::UnsupportedOperation {
2410                        msg: "ij-> requires a 2D tensor".to_string(),
2411                    });
2412                }
2413                let sum: f32 = a.data().iter().sum();
2414                Ok(Tensor::scalar(sum))
2415            }
2416            // dot product: i,i->
2417            "i,i->" => {
2418                let a = tensors[0];
2419                let b = tensors[1];
2420                if a.rank() != 1 || b.rank() != 1 {
2421                    return Err(TensorError::UnsupportedOperation {
2422                        msg: "i,i-> requires 1D tensors".to_string(),
2423                    });
2424                }
2425                if a.shape()[0] != b.shape()[0] {
2426                    return Err(TensorError::ShapeMismatch {
2427                        left: a.shape().to_vec(),
2428                        right: b.shape().to_vec(),
2429                    });
2430                }
2431                let sum: f32 = a
2432                    .data()
2433                    .iter()
2434                    .zip(b.data().iter())
2435                    .map(|(x, y)| x * y)
2436                    .sum();
2437                Ok(Tensor::scalar(sum))
2438            }
2439            // Frobenius inner product: ij,ij->
2440            "ij,ij->" => {
2441                let a = tensors[0];
2442                let b = tensors[1];
2443                if a.rank() != 2 || b.rank() != 2 {
2444                    return Err(TensorError::UnsupportedOperation {
2445                        msg: "ij,ij-> requires 2D tensors".to_string(),
2446                    });
2447                }
2448                if a.shape() != b.shape() {
2449                    return Err(TensorError::ShapeMismatch {
2450                        left: a.shape().to_vec(),
2451                        right: b.shape().to_vec(),
2452                    });
2453                }
2454                let sum: f32 = a
2455                    .data()
2456                    .iter()
2457                    .zip(b.data().iter())
2458                    .map(|(x, y)| x * y)
2459                    .sum();
2460                Ok(Tensor::scalar(sum))
2461            }
2462            _ => Err(TensorError::UnsupportedOperation {
2463                msg: format!("unsupported einsum pattern: {pattern}"),
2464            }),
2465        }
2466    }
2467
2468    // ── Chunk ───────────────────────────────────────────────────────────
2469
2470    /// Split tensor into `n_chunks` pieces along `axis`. Last chunk may be smaller.
2471    pub fn chunk(&self, n_chunks: usize, axis: usize) -> Result<Vec<Self>, TensorError> {
2472        if axis >= self.rank() {
2473            return Err(TensorError::InvalidAxis {
2474                axis,
2475                rank: self.rank(),
2476            });
2477        }
2478        if n_chunks == 0 {
2479            return Err(TensorError::UnsupportedOperation {
2480                msg: "n_chunks must be > 0".to_string(),
2481            });
2482        }
2483        let dim = self.shape()[axis];
2484        let chunk_size = dim.div_ceil(n_chunks); // ceil division
2485        let mut chunks = Vec::new();
2486        let mut start = 0;
2487        while start < dim {
2488            let length = chunk_size.min(dim - start);
2489            chunks.push(self.narrow(axis, start, length)?);
2490            start += length;
2491        }
2492        Ok(chunks)
2493    }
2494
2495    // ── Histogram ───────────────────────────────────────────────────────
2496
2497    /// Counts elements in each bin, returns 1D tensor of shape `[bins]`.
2498    /// Bins are evenly spaced between `min` and `max`.
2499    pub fn histogram(&self, bins: usize, min: f32, max: f32) -> Result<Self, TensorError> {
2500        let mut counts = vec![0.0f32; bins];
2501        let range = max - min;
2502        for &v in self.data() {
2503            if v >= min && v <= max {
2504                let idx = if v == max {
2505                    bins - 1
2506                } else {
2507                    ((v - min) / range * bins as f32) as usize
2508                };
2509                counts[idx] += 1.0;
2510            }
2511        }
2512        Tensor::from_vec(vec![bins], counts)
2513    }
2514
2515    // ── Bincount ────────────────────────────────────────────────────────
2516
2517    /// Treats values as integer indices, counts occurrences.
2518    /// Returns 1D tensor of shape `[num_bins]`.
2519    pub fn bincount(&self, num_bins: usize) -> Result<Self, TensorError> {
2520        let mut counts = vec![0.0f32; num_bins];
2521        for &v in self.data() {
2522            let idx = v as usize;
2523            if idx < num_bins {
2524                counts[idx] += 1.0;
2525            }
2526        }
2527        Tensor::from_vec(vec![num_bins], counts)
2528    }
2529
2530    // ── Scalar convenience ──────────────────────────────────────────────
2531
2532    /// Returns the single scalar value if tensor has exactly one element.
2533    /// Errors if tensor has more than one element.
2534    pub fn item(&self) -> Result<f32, TensorError> {
2535        if self.len() != 1 {
2536            return Err(TensorError::ShapeMismatch {
2537                left: vec![1],
2538                right: self.shape().to_vec(),
2539            });
2540        }
2541        Ok(self.data()[0])
2542    }
2543
2544    /// Returns true if tensor has exactly one element.
2545    pub fn is_scalar(&self) -> bool {
2546        self.len() == 1
2547    }
2548
2549    // ── Scatter Add ──────────────────────────────────────────────────
2550
2551    /// Like `scatter` but adds instead of replacing values.
2552    ///
2553    /// For `dim=1`: `self[i][index[i][j][k]][k] += src[i][j][k]`
2554    pub fn scatter_add(&self, dim: usize, index: &Self, src: &Self) -> Result<Self, TensorError> {
2555        if dim >= self.rank() {
2556            return Err(TensorError::InvalidAxis {
2557                axis: dim,
2558                rank: self.rank(),
2559            });
2560        }
2561        if index.rank() != self.rank() {
2562            return Err(TensorError::InvalidIndexRank {
2563                expected: self.rank(),
2564                got: index.rank(),
2565            });
2566        }
2567        if src.shape() != index.shape() {
2568            return Err(TensorError::ShapeMismatch {
2569                left: src.shape().to_vec(),
2570                right: index.shape().to_vec(),
2571            });
2572        }
2573
2574        let self_shape = self.shape();
2575        let idx_shape = index.shape();
2576        let idx_data = index.data();
2577        let src_data = src.data();
2578        let ndim = self.rank();
2579
2580        let mut out = self.data().to_vec();
2581        let mut coords = vec![0usize; ndim];
2582
2583        for pos in 0..index.len() {
2584            let idx_val = idx_data[pos] as usize;
2585            if idx_val >= self_shape[dim] {
2586                return Err(TensorError::IndexOutOfBounds {
2587                    axis: dim,
2588                    index: idx_val,
2589                    dim: self_shape[dim],
2590                });
2591            }
2592
2593            let mut dst_offset = 0;
2594            for d in 0..ndim {
2595                let c = if d == dim { idx_val } else { coords[d] };
2596                dst_offset += c * self.strides()[d];
2597            }
2598            out[dst_offset] += src_data[pos];
2599
2600            increment_coords(&mut coords, idx_shape);
2601        }
2602
2603        Tensor::from_vec(self_shape.to_vec(), out)
2604    }
2605}
2606
/// Convert IEEE 754 half-precision (FP16) bit pattern to f32.
///
/// Handles every FP16 class: ±zero, subnormals, normals, ±infinity, and NaN
/// (the 10-bit NaN payload is preserved, left-aligned into the f32 mantissa).
fn fp16_to_f32(half: u16) -> f32 {
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;

    if exponent == 0 {
        if mantissa == 0 {
            return f32::from_bits(sign); // ±zero
        }
        // Subnormal: value = mantissa * 2^-24. Normalize by shifting left
        // until the implicit leading bit (bit 10) is set.
        let mut m = mantissa;
        let mut e = 0i32;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        m &= 0x03FF;
        // After e shifts: value = (1 + m/1024) * 2^(-14 - e), so the f32
        // biased exponent is 127 - 14 - e = 113 - e. (The previous
        // `127 - 15 - e` was off by one, halving every subnormal.)
        let f32_exp = ((113 - e) as u32) << 23;
        let f32_man = m << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        let f32_exp = 0xFF << 23;
        let f32_man = mantissa << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }

    // Normal: rebias exponent from 15 (FP16) to 127 (f32) and widen the
    // mantissa from 10 to 23 bits.
    let f32_exp = ((exponent as i32 - 15 + 127) as u32 & 0xFF) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}