//! shrew_core/tensor.rs — the `Tensor` type and its core operations.

1use std::sync::{Arc, RwLock};
2
3use crate::backend::{Backend, BinaryOp, CmpOp, ReduceOp, UnaryOp};
4use crate::dtype::DType;
5use crate::error::{Error, Result};
6use crate::layout::Layout;
7use crate::op::{Op, TensorId};
8use crate::shape::Shape;
9
10// Tensor — The fundamental data structure
11//
12// A Tensor is an n-dimensional array of numbers, the building block of all
13// neural network computations. Like in PyTorch, our Tensor:
14//
15//   1. Holds data on a specific device (CPU, GPU)
16//   2. Has a shape (e.g., [batch, channels, height, width])
17//   3. Has a dtype (f32, f64, etc.)
18//   4. Optionally tracks the operation that created it (for autograd)
19//
20// ARCHITECTURE:
21//
22//   Tensor<B: Backend> is generic over the backend. This means:
23//     - Tensor<CpuBackend> holds data in CPU memory
24//     - Tensor<CudaBackend> holds data in GPU memory
25//     - Operations are dispatched via the Backend trait
26//
27// MEMORY MODEL:
28//
29//   The inner data is wrapped in Arc (atomic reference counting).
30//   This means cloning a Tensor is cheap (just increments a counter).
31//   Multiple tensors can share the same underlying storage (views).
32//
33//   Storage is behind Arc<RwLock<Storage>> so that:
34//   - Multiple tensors can read concurrently
35//   - In-place ops can write when there's only one reference
36//
37// WHY Arc + inner struct?
38//
39//   We separate Tensor (the handle) from TensorInner (the data) so that:
40//   - Cloning Tensor is O(1) — just copies the Arc pointer
41//   - The autograd graph can hold TensorIds without owning data
42//   - Views (transpose, narrow) share the same storage via Arc<RwLock<>>
43
44/// Inner data of a tensor, shared via Arc.
/// Inner data of a tensor, shared via `Arc`.
///
/// `Tensor` is a cheap handle; all state lives here so that clones and
/// views can share it. Not constructed outside this module.
struct TensorInner<B: Backend> {
    /// Unique identifier for this tensor (used as the node key in the
    /// autograd graph).
    id: TensorId,
    /// The raw data stored on the backend's device. Behind `RwLock` so
    /// views can read concurrently and in-place updates can write.
    storage: Arc<RwLock<B::Storage>>,
    /// Memory layout: shape + strides + offset into `storage`.
    layout: Layout,
    /// Data type of the elements.
    dtype: DType,
    /// The device this tensor lives on.
    device: B::Device,
    /// The operation that created this tensor (for autograd).
    /// `Op::None` for leaf tensors (inputs, parameters).
    op: Op<B>,
    /// Whether this tensor is a trainable variable.
    /// Only variables accumulate gradients during backward().
    is_variable: bool,
}
63
/// An n-dimensional array of numbers on a specific backend.
///
/// Tensors are the fundamental data type in Shrew. All neural network
/// operations accept and return tensors. Cloning is O(1): the handle
/// shares its `TensorInner` via `Arc`.
///
/// # Type Parameter
/// - `B: Backend` — the compute backend (e.g., `CpuBackend`, `CudaBackend`)
///
/// # Example
/// ```ignore
/// use shrew_core::Tensor;
/// use shrew_cpu::CpuBackend;
///
/// let a = Tensor::<CpuBackend>::from_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2))?;
/// let b = Tensor::<CpuBackend>::ones((2, 2), DType::F32, &CpuDevice)?;
/// let c = a.add(&b)?;
/// ```
pub struct Tensor<B: Backend> {
    inner: Arc<TensorInner<B>>,
}
84
85// Manual Clone: Arc::clone is cheap (just increment refcount).
86impl<B: Backend> Clone for Tensor<B> {
87    fn clone(&self) -> Self {
88        Tensor {
89            inner: Arc::clone(&self.inner),
90        }
91    }
92}
93
94impl<B: Backend> std::fmt::Debug for Tensor<B> {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        write!(
97            f,
98            "Tensor(id={:?}, shape={}, dtype={}, device={:?})",
99            self.inner.id,
100            self.inner.layout.shape(),
101            self.inner.dtype,
102            self.inner.device,
103        )
104    }
105}
106
107impl<B: Backend> Tensor<B> {
108    // Internal constructors
109
110    /// Create a tensor from existing storage and layout.
111    pub(crate) fn from_storage(
112        storage: B::Storage,
113        layout: Layout,
114        dtype: DType,
115        device: B::Device,
116        op: Op<B>,
117    ) -> Self {
118        Tensor {
119            inner: Arc::new(TensorInner {
120                id: TensorId::new(),
121                storage: Arc::new(RwLock::new(storage)),
122                layout,
123                dtype,
124                device,
125                op,
126                is_variable: false,
127            }),
128        }
129    }
130
131    /// Create a view tensor sharing the same storage but with a different layout.
132    fn view_with_layout(&self, layout: Layout, op: Op<B>) -> Self {
133        Tensor {
134            inner: Arc::new(TensorInner {
135                id: TensorId::new(),
136                storage: Arc::clone(&self.inner.storage),
137                layout,
138                dtype: self.inner.dtype,
139                device: self.inner.device.clone(),
140                op,
141                is_variable: false,
142            }),
143        }
144    }
145
    // Accessors

    /// Unique tensor ID.
    ///
    /// Assigned once at construction; the autograd graph refers to tensors
    /// by this id rather than owning them.
    pub fn id(&self) -> TensorId {
        self.inner.id
    }

    /// The shape of this tensor.
    pub fn shape(&self) -> &Shape {
        self.inner.layout.shape()
    }

    /// The dimensions as a slice (shortcut for shape().dims()).
    pub fn dims(&self) -> &[usize] {
        self.inner.layout.dims()
    }

    /// Number of dimensions (rank).
    pub fn rank(&self) -> usize {
        self.inner.layout.rank()
    }

    /// Total number of elements.
    pub fn elem_count(&self) -> usize {
        self.inner.layout.elem_count()
    }

    /// Data type of the elements.
    pub fn dtype(&self) -> DType {
        self.inner.dtype
    }

    /// The device this tensor is on.
    pub fn device(&self) -> &B::Device {
        &self.inner.device
    }

    /// The memory layout (shape + strides + offset).
    pub fn layout(&self) -> &Layout {
        &self.inner.layout
    }

    /// Whether this tensor is contiguous in memory.
    pub fn is_contiguous(&self) -> bool {
        self.inner.layout.is_contiguous()
    }

    /// Whether this tensor tracks gradients (set via `set_variable`).
    pub fn is_variable(&self) -> bool {
        self.inner.is_variable
    }

    /// Access the underlying storage (read lock).
    ///
    /// # Panics
    /// Panics if the storage lock is poisoned (a writer panicked while
    /// holding it). Internal code uses the fallible `read_storage` instead.
    pub fn storage(&self) -> std::sync::RwLockReadGuard<'_, B::Storage> {
        self.inner.storage.read().expect("storage lock poisoned")
    }

    /// Try to acquire a read lock on storage, returning an error instead of panicking.
    fn read_storage(&self) -> Result<std::sync::RwLockReadGuard<'_, B::Storage>> {
        self.inner
            .storage
            .read()
            .map_err(|_| Error::msg("storage lock poisoned"))
    }

    /// Try to acquire a write lock on storage, returning an error instead of panicking.
    fn write_storage(&self) -> Result<std::sync::RwLockWriteGuard<'_, B::Storage>> {
        self.inner
            .storage
            .write()
            .map_err(|_| Error::msg("storage lock poisoned"))
    }

    /// The op that created this tensor (`Op::None` for leaf tensors).
    pub fn op(&self) -> &Op<B> {
        &self.inner.op
    }
223
224    // In-place mutation
225
226    /// Update the underlying storage data in place.
227    ///
228    /// This writes `new_data` directly into the existing `Arc<RwLock<Storage>>`,
229    /// so any other tensor sharing this storage (e.g., a clone held by a Module)
230    /// will also see the updated values.
231    ///
232    /// This is the mechanism that makes optimizer parameter updates visible to
233    /// model layers without needing to re-assign parameters.
234    ///
235    /// # Safety (logical)
236    /// The new data must have the same number of elements and dtype as the
237    /// current storage. The shape is not changed.
238    pub fn update_data_inplace(&self, new_data: &[f64]) -> Result<()> {
239        let expected = self.elem_count();
240        if new_data.len() != expected {
241            return Err(Error::msg(format!(
242                "update_data_inplace: expected {} elements, got {}",
243                expected,
244                new_data.len()
245            )));
246        }
247        let new_storage = B::from_f64_slice(new_data, self.dtype(), self.device())?;
248        let mut guard = self.write_storage()?;
249        *guard = new_storage;
250        Ok(())
251    }
252
253    // Creation methods
254
255    /// Create a tensor filled with zeros.
256    pub fn zeros(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
257        let shape = shape.into();
258        let layout = Layout::contiguous(shape.clone());
259        let storage = B::zeros(&shape, dtype, device)?;
260        Ok(Self::from_storage(
261            storage,
262            layout,
263            dtype,
264            device.clone(),
265            Op::None,
266        ))
267    }
268
269    /// Create a tensor filled with ones.
270    pub fn ones(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
271        let shape = shape.into();
272        let layout = Layout::contiguous(shape.clone());
273        let storage = B::ones(&shape, dtype, device)?;
274        Ok(Self::from_storage(
275            storage,
276            layout,
277            dtype,
278            device.clone(),
279            Op::None,
280        ))
281    }
282
283    /// Create a tensor filled with a constant value.
284    pub fn full(
285        shape: impl Into<Shape>,
286        val: f64,
287        dtype: DType,
288        device: &B::Device,
289    ) -> Result<Self> {
290        let shape = shape.into();
291        let layout = Layout::contiguous(shape.clone());
292        let storage = B::full(&shape, val, dtype, device)?;
293        Ok(Self::from_storage(
294            storage,
295            layout,
296            dtype,
297            device.clone(),
298            Op::None,
299        ))
300    }
301
302    /// Create a tensor from a flat slice of f64 values.
303    /// The data is converted to the specified dtype.
304    pub fn from_f64_slice(
305        data: &[f64],
306        shape: impl Into<Shape>,
307        dtype: DType,
308        device: &B::Device,
309    ) -> Result<Self> {
310        let shape = shape.into();
311        if data.len() != shape.elem_count() {
312            return Err(Error::ElementCountMismatch {
313                shape: shape.clone(),
314                expected: shape.elem_count(),
315                got: data.len(),
316            });
317        }
318        let layout = Layout::contiguous(shape);
319        let storage = B::from_f64_slice(data, dtype, device)?;
320        Ok(Self::from_storage(
321            storage,
322            layout,
323            dtype,
324            device.clone(),
325            Op::None,
326        ))
327    }
328
329    /// Create a tensor with random uniform values in [0, 1).
330    pub fn rand(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
331        let shape = shape.into();
332        let layout = Layout::contiguous(shape.clone());
333        let storage = B::rand_uniform(&shape, dtype, device)?;
334        Ok(Self::from_storage(
335            storage,
336            layout,
337            dtype,
338            device.clone(),
339            Op::None,
340        ))
341    }
342
343    /// Create a tensor with random normal values (mean=0, std=1).
344    pub fn randn(shape: impl Into<Shape>, dtype: DType, device: &B::Device) -> Result<Self> {
345        let shape = shape.into();
346        let layout = Layout::contiguous(shape.clone());
347        let storage = B::rand_normal(&shape, dtype, device)?;
348        Ok(Self::from_storage(
349            storage,
350            layout,
351            dtype,
352            device.clone(),
353            Op::None,
354        ))
355    }
356
357    /// Create a 1-D tensor with `steps` evenly spaced values from `start` to `end` (inclusive).
358    ///
359    /// ```ignore
360    /// let t = Tensor::linspace(0.0, 1.0, 5, DType::F64, &dev)?;
361    /// // => [0.0, 0.25, 0.5, 0.75, 1.0]
362    /// ```
363    pub fn linspace(
364        start: f64,
365        end: f64,
366        steps: usize,
367        dtype: DType,
368        device: &B::Device,
369    ) -> Result<Self> {
370        if steps == 0 {
371            return Err(Error::msg("linspace requires steps >= 1"));
372        }
373        if steps == 1 {
374            return Self::from_f64_slice(&[start], 1, dtype, device);
375        }
376        let step = (end - start) / (steps as f64 - 1.0);
377        let data: Vec<f64> = (0..steps).map(|i| start + step * i as f64).collect();
378        Self::from_f64_slice(&data, steps, dtype, device)
379    }
380
381    /// Create an identity matrix of size `n × n`.
382    ///
383    /// ```ignore
384    /// let I = Tensor::eye(3, DType::F64, &dev)?;
385    /// // [[1, 0, 0],
386    /// //  [0, 1, 0],
387    /// //  [0, 0, 1]]
388    /// ```
389    pub fn eye(n: usize, dtype: DType, device: &B::Device) -> Result<Self> {
390        let mut data = vec![0.0f64; n * n];
391        for i in 0..n {
392            data[i * n + i] = 1.0;
393        }
394        Self::from_f64_slice(&data, (n, n), dtype, device)
395    }
396
    /// Create a tensor of zeros with the same shape, dtype, and device as `other`.
    pub fn zeros_like(other: &Self) -> Result<Self> {
        Self::zeros(other.shape().clone(), other.dtype(), other.device())
    }

    /// Create a tensor of ones with the same shape, dtype, and device as `other`.
    pub fn ones_like(other: &Self) -> Result<Self> {
        Self::ones(other.shape().clone(), other.dtype(), other.device())
    }

    /// Create a tensor filled with `val`, with the same shape, dtype, and device as `other`.
    /// Note: the result is a fresh leaf tensor — it does not share storage with `other`.
    pub fn full_like(other: &Self, val: f64) -> Result<Self> {
        Self::full(other.shape().clone(), val, other.dtype(), other.device())
    }
411
412    /// Mark this tensor as a variable (trainable parameter).
413    /// Variables accumulate gradients during backward().
414    pub fn set_variable(self) -> Self {
415        Tensor {
416            inner: Arc::new(TensorInner {
417                id: self.inner.id,
418                storage: Arc::clone(&self.inner.storage),
419                layout: self.inner.layout.clone(),
420                dtype: self.inner.dtype,
421                device: self.inner.device.clone(),
422                op: self.inner.op.clone(),
423                is_variable: true,
424            }),
425        }
426    }
427
428    // Shape manipulation (these create views, no data copy)
429
430    /// Transpose two dimensions (no data copy).
431    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Self> {
432        let new_layout = self.inner.layout.transpose(dim0, dim1)?;
433        let op = Op::Transpose {
434            input: self.clone(),
435            dim0,
436            dim1,
437        };
438        Ok(self.view_with_layout(new_layout, op))
439    }
440
441    /// Transpose a 2D matrix (shorthand for transpose(0, 1)).
442    pub fn t(&self) -> Result<Self> {
443        if self.rank() != 2 {
444            return Err(Error::RankMismatch {
445                expected: 2,
446                got: self.rank(),
447            });
448        }
449        self.transpose(0, 1)
450    }
451
452    /// Narrow (slice) along a dimension.
453    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
454        let new_layout = self.inner.layout.narrow(dim, start, len)?;
455        let op = Op::Narrow {
456            input: self.clone(),
457            dim,
458            start,
459            len,
460        };
461        Ok(self.view_with_layout(new_layout, op))
462    }
463
464    /// Reshape to a new shape. The new shape must have the same total elements.
465    /// If the tensor is not contiguous, it will be made contiguous first.
466    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Self> {
467        let new_shape = new_shape.into();
468        let current_count = self.elem_count();
469        let new_count = new_shape.elem_count();
470        if current_count != new_count {
471            return Err(Error::ReshapeElementMismatch {
472                src: current_count,
473                dst: new_count,
474                dst_shape: new_shape,
475            });
476        }
477        // If not contiguous, make a contiguous copy first
478        let tensor = if self.is_contiguous() {
479            self.clone()
480        } else {
481            self.contiguous()?
482        };
483        let src_shape = tensor.shape().clone();
484        let new_layout = Layout::contiguous(new_shape);
485        let op = Op::Reshape {
486            input: tensor.clone(),
487            src_shape,
488        };
489        Ok(tensor.view_with_layout(new_layout, op))
490    }
491
492    /// Ensure the tensor is contiguous in memory.
493    /// If already contiguous, returns a clone (cheap Arc copy).
494    /// Otherwise, copies the data into a new contiguous storage.
495    pub fn contiguous(&self) -> Result<Self> {
496        if self.is_contiguous() {
497            return Ok(self.clone());
498        }
499        let storage = self.read_storage()?;
500        let new_storage = B::to_contiguous(&storage, &self.inner.layout)?;
501        let new_layout = Layout::contiguous(self.shape().clone());
502        Ok(Self::from_storage(
503            new_storage,
504            new_layout,
505            self.inner.dtype,
506            self.inner.device.clone(),
507            Op::Contiguous {
508                input: self.clone(),
509            },
510        ))
511    }
512
513    /// Add a dimension of size 1 at the given position.
514    /// unsqueeze(0) on [3, 4] → [1, 3, 4]
515    /// unsqueeze(2) on [3, 4] → [3, 4, 1]
516    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
517        let rank = self.rank();
518        if dim > rank {
519            return Err(Error::DimOutOfRange {
520                dim,
521                rank: rank + 1,
522            });
523        }
524        let mut new_dims = self.dims().to_vec();
525        let mut new_strides = self.layout().strides().to_vec();
526        // The stride for a size-1 dim doesn't matter (you never move along it),
527        // but convention is to use the stride of the next dimension (or 1 if last).
528        let stride_val = if dim < rank { new_strides[dim] } else { 1 };
529        new_dims.insert(dim, 1);
530        new_strides.insert(dim, stride_val);
531        let new_layout = Layout::new(Shape::new(new_dims), new_strides, self.layout().offset());
532        let op = Op::Reshape {
533            input: self.clone(),
534            src_shape: self.shape().clone(),
535        };
536        Ok(self.view_with_layout(new_layout, op))
537    }
538
539    /// Remove dimensions of size 1.
540    /// squeeze on [1, 3, 1, 4] → [3, 4]
541    pub fn squeeze_all(&self) -> Self {
542        let new_dims: Vec<usize> = self.dims().iter().copied().filter(|&d| d != 1).collect();
543        let new_strides: Vec<usize> = self
544            .dims()
545            .iter()
546            .zip(self.layout().strides().iter())
547            .filter(|(&d, _)| d != 1)
548            .map(|(_, &s)| s)
549            .collect();
550        let new_layout = Layout::new(
551            Shape::new(if new_dims.is_empty() {
552                vec![]
553            } else {
554                new_dims
555            }),
556            new_strides,
557            self.layout().offset(),
558        );
559        let op = Op::Reshape {
560            input: self.clone(),
561            src_shape: self.shape().clone(),
562        };
563        self.view_with_layout(new_layout, op)
564    }
565
566    /// Remove a specific dimension of size 1.
567    ///
568    /// squeeze(1) on [3, 1, 4] → [3, 4]
569    ///
570    /// Returns an error if the specified dimension is not size 1.
571    pub fn squeeze(&self, dim: usize) -> Result<Self> {
572        let rank = self.rank();
573        if dim >= rank {
574            return Err(Error::DimOutOfRange { dim, rank });
575        }
576        if self.dims()[dim] != 1 {
577            return Err(Error::msg(format!(
578                "squeeze: dimension {} has size {}, expected 1",
579                dim,
580                self.dims()[dim]
581            )));
582        }
583        let mut new_dims = self.dims().to_vec();
584        let mut new_strides = self.layout().strides().to_vec();
585        new_dims.remove(dim);
586        new_strides.remove(dim);
587        let new_layout = Layout::new(
588            Shape::new(if new_dims.is_empty() {
589                vec![]
590            } else {
591                new_dims
592            }),
593            new_strides,
594            self.layout().offset(),
595        );
596        let op = Op::Reshape {
597            input: self.clone(),
598            src_shape: self.shape().clone(),
599        };
600        Ok(self.view_with_layout(new_layout, op))
601    }
602
    /// Permute the dimensions of this tensor.
    ///
    /// permute(&[2, 0, 1]) on [A, B, C] → [C, A, B]
    ///
    /// This is a generalization of transpose to arbitrary dimension orderings.
    /// No data copy — just changes strides.
    pub fn permute(&self, dims: &[usize]) -> Result<Self> {
        let rank = self.rank();
        if dims.len() != rank {
            return Err(Error::msg(format!(
                "permute: expected {} dimensions, got {}",
                rank,
                dims.len()
            )));
        }
        // `dims` must be a permutation of 0..rank: reject out-of-range
        // entries and duplicates.
        let mut seen = vec![false; rank];
        for &d in dims {
            if d >= rank {
                return Err(Error::DimOutOfRange { dim: d, rank });
            }
            if seen[d] {
                return Err(Error::msg(format!("permute: duplicate dimension {}", d)));
            }
            seen[d] = true;
        }

        // Reorder dims and strides together; offset is unchanged.
        let old_dims = self.dims();
        let old_strides = self.layout().strides();
        let new_dims: Vec<usize> = dims.iter().map(|&d| old_dims[d]).collect();
        let new_strides: Vec<usize> = dims.iter().map(|&d| old_strides[d]).collect();
        let new_layout = Layout::new(Shape::new(new_dims), new_strides, self.layout().offset());
        // Use a chain of transposes conceptually, but represent as reshape
        // for backward compatibility. The backward pass handles reshapes correctly.
        //
        // NOTE(review): unlike `unsqueeze`/`squeeze`, a permute reorders
        // elements relative to row-major order, so recording this view as
        // `Op::Reshape` looks suspect for autograd — a reshape-style backward
        // would not un-permute the gradient. Verify against the backward
        // implementation; a dedicated permute/transpose-chain op may be needed.
        let op = Op::Reshape {
            input: self.clone(),
            src_shape: self.shape().clone(),
        };
        Ok(self.view_with_layout(new_layout, op))
    }
643
644    /// Cumulative sum along dimension `dim`.
645    ///
646    /// ```ignore
647    /// // [1, 2, 3] → [1, 3, 6]
648    /// let y = x.cumsum(0)?;
649    /// ```
650    pub fn cumsum(&self, dim: usize) -> Result<Self> {
651        let rank = self.rank();
652        if dim >= rank {
653            return Err(Error::DimOutOfRange { dim, rank });
654        }
655        let t = self.contiguous()?;
656        let data = t.to_f64_vec()?;
657        let shape = t.shape().clone();
658        let dims = shape.dims();
659        let mut out = data.clone();
660
661        // Compute strides for iteration
662        let inner: usize = dims[dim + 1..].iter().product();
663        let outer: usize = dims[..dim].iter().product();
664        let dim_size = dims[dim];
665
666        for o in 0..outer {
667            for i in 0..inner {
668                for d in 1..dim_size {
669                    let idx = (o * dim_size + d) * inner + i;
670                    let prev = (o * dim_size + d - 1) * inner + i;
671                    out[idx] += out[prev];
672                }
673            }
674        }
675
676        Self::from_f64_slice(&out, shape, t.dtype(), t.device())
677    }
678
679    /// Sort along a dimension. Returns `(sorted_values, sorted_indices)`.
680    ///
681    /// ```ignore
682    /// let (vals, idxs) = x.sort(0, false)?; // ascending along dim 0
683    /// ```
684    pub fn sort(&self, dim: usize, descending: bool) -> Result<(Self, Self)> {
685        let rank = self.rank();
686        if dim >= rank {
687            return Err(Error::DimOutOfRange { dim, rank });
688        }
689        let t = self.contiguous()?;
690        let data = t.to_f64_vec()?;
691        let shape = t.shape().clone();
692        let dims = shape.dims();
693        let dim_size = dims[dim];
694        let inner: usize = dims[dim + 1..].iter().product();
695        let outer: usize = dims[..dim].iter().product();
696
697        let mut sorted_data = data.clone();
698        let mut indices = vec![0.0f64; data.len()];
699
700        for o in 0..outer {
701            for i in 0..inner {
702                // Extract the slice along dim
703                let mut slice: Vec<(f64, usize)> = (0..dim_size)
704                    .map(|d| {
705                        let idx = (o * dim_size + d) * inner + i;
706                        (data[idx], d)
707                    })
708                    .collect();
709
710                if descending {
711                    slice
712                        .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
713                } else {
714                    slice
715                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
716                }
717
718                for (d, (val, orig_idx)) in slice.into_iter().enumerate() {
719                    let idx = (o * dim_size + d) * inner + i;
720                    sorted_data[idx] = val;
721                    indices[idx] = orig_idx as f64;
722                }
723            }
724        }
725
726        let vals = Self::from_f64_slice(&sorted_data, shape.clone(), t.dtype(), t.device())?;
727        let idxs = Self::from_f64_slice(&indices, shape, t.dtype(), t.device())?;
728        Ok((vals, idxs))
729    }
730
731    /// Argsort: returns indices that would sort the tensor along `dim`.
732    ///
733    /// ```ignore
734    /// let indices = x.argsort(0, false)?; // ascending
735    /// ```
736    pub fn argsort(&self, dim: usize, descending: bool) -> Result<Self> {
737        let (_, indices) = self.sort(dim, descending)?;
738        Ok(indices)
739    }
740
    // Arithmetic operations — thin wrappers over `binary_op`, which
    // requires matching dtypes and broadcasts shapes.

    /// Element-wise addition: self + rhs.
    pub fn add(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Add)
    }

    /// Element-wise subtraction: self - rhs.
    pub fn sub(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Sub)
    }

    /// Element-wise multiplication: self * rhs.
    pub fn mul(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Mul)
    }

    /// Element-wise division: self / rhs.
    pub fn div(&self, rhs: &Self) -> Result<Self> {
        self.binary_op(rhs, BinaryOp::Div)
    }
762
763    /// Generic binary operation dispatch.
764    fn binary_op(&self, rhs: &Self, op: BinaryOp) -> Result<Self> {
765        if self.dtype() != rhs.dtype() {
766            return Err(Error::DTypeMismatch {
767                expected: self.dtype(),
768                got: rhs.dtype(),
769            });
770        }
771        let storage_lhs = self.read_storage()?;
772        let storage_rhs = rhs.read_storage()?;
773        let result = B::binary_op(
774            op,
775            &storage_lhs,
776            &self.inner.layout,
777            &storage_rhs,
778            &rhs.inner.layout,
779        )?;
780        // Compute broadcast output shape
781        let result_shape = Shape::broadcast_shape(self.shape(), rhs.shape())?;
782        let result_layout = Layout::contiguous(result_shape);
783        let result_op = Op::Binary {
784            lhs: self.clone(),
785            rhs: rhs.clone(),
786            op,
787        };
788        Ok(Self::from_storage(
789            result,
790            result_layout,
791            self.inner.dtype,
792            self.inner.device.clone(),
793            result_op,
794        ))
795    }
796
797    // Comparison operations
798
    /// Element-wise equal: self == rhs. Returns a U8 tensor (0 or 1).
    ///
    /// Note: this inherent `eq` shadows `PartialEq::eq` in method-call
    /// syntax (inherent methods win resolution).
    pub fn eq(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Eq)
    }

    /// Element-wise not-equal: self != rhs. Returns a U8 tensor (0 or 1).
    pub fn ne(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Ne)
    }

    /// Element-wise greater-than: self > rhs. Returns a U8 tensor (0 or 1).
    pub fn gt(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Gt)
    }

    /// Element-wise greater-or-equal: self >= rhs. Returns a U8 tensor (0 or 1).
    pub fn ge(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Ge)
    }

    /// Element-wise less-than: self < rhs. Returns a U8 tensor (0 or 1).
    pub fn lt(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Lt)
    }

    /// Element-wise less-or-equal: self <= rhs. Returns a U8 tensor (0 or 1).
    pub fn le(&self, rhs: &Self) -> Result<Self> {
        self.cmp_op(rhs, CmpOp::Le)
    }
828
829    /// Generic comparison operation dispatch.
830    /// Produces a U8-dtype tensor (non-differentiable, Op::None).
831    fn cmp_op(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
832        let storage_lhs = self.read_storage()?;
833        let storage_rhs = rhs.read_storage()?;
834        let result = B::cmp_op(
835            op,
836            &storage_lhs,
837            &self.inner.layout,
838            &storage_rhs,
839            &rhs.inner.layout,
840        )?;
841        let result_shape = Shape::broadcast_shape(self.shape(), rhs.shape())?;
842        let result_layout = Layout::contiguous(result_shape);
843        // Comparisons are non-differentiable — no autograd tracking
844        Ok(Self::from_storage(
845            result,
846            result_layout,
847            DType::U8,
848            self.inner.device.clone(),
849            Op::None,
850        ))
851    }
852
    // Unary operations — thin wrappers dispatching through the private
    // `unary_op` helper (defined elsewhere in this file).

    /// Element-wise negation: -self.
    pub fn neg(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Neg)
    }

    /// Element-wise absolute value.
    pub fn abs(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Abs)
    }

    /// Element-wise exponential: e^x.
    pub fn exp(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Exp)
    }

    /// Element-wise natural logarithm.
    pub fn log(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Log)
    }

    /// Element-wise square root.
    pub fn sqrt(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sqrt)
    }

    /// Element-wise square: x².
    pub fn square(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Square)
    }

    /// ReLU activation: max(0, x).
    pub fn relu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Relu)
    }

    /// Sigmoid activation: 1 / (1 + e^(-x)).
    pub fn sigmoid(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sigmoid)
    }

    /// Tanh activation.
    pub fn tanh(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Tanh)
    }

    /// GELU activation (Gaussian Error Linear Unit).
    pub fn gelu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Gelu)
    }

    /// SiLU / Swish activation: x * sigmoid(x).
    pub fn silu(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Silu)
    }

    /// Element-wise sine.
    pub fn sin(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Sin)
    }

    /// Element-wise cosine.
    pub fn cos(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Cos)
    }

    /// Element-wise floor: largest integer ≤ x.
    pub fn floor(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Floor)
    }

    /// Element-wise ceiling: smallest integer ≥ x.
    pub fn ceil(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Ceil)
    }

    /// Element-wise round to nearest integer.
    pub fn round(&self) -> Result<Self> {
        self.unary_op(UnaryOp::Round)
    }
934
935    /// Element-wise power: self^exponent.
936    pub fn powf(&self, exponent: f64) -> Result<Self> {
937        let storage = self.read_storage()?;
938        let result = B::powf(&storage, &self.inner.layout, exponent)?;
939        let result_layout = Layout::contiguous(self.shape().clone());
940        let result_op = Op::Powf {
941            input: self.clone(),
942            exponent,
943        };
944        Ok(Self::from_storage(
945            result,
946            result_layout,
947            self.inner.dtype,
948            self.inner.device.clone(),
949            result_op,
950        ))
951    }
952
953    /// Element-wise clamp to [min, max].
954    pub fn clamp(&self, min: f64, max: f64) -> Result<Self> {
955        let storage = self.read_storage()?;
956        let result = B::clamp(&storage, &self.inner.layout, min, max)?;
957        let result_layout = Layout::contiguous(self.shape().clone());
958        let result_op = Op::Clamp {
959            input: self.clone(),
960            min,
961            max,
962        };
963        Ok(Self::from_storage(
964            result,
965            result_layout,
966            self.inner.dtype,
967            self.inner.device.clone(),
968            result_op,
969        ))
970    }
971
972    /// Conditional select: result[i] = if mask[i] != 0 { on_true[i] } else { on_false[i] }.
973    ///
974    /// `mask` is typically a U8 tensor from comparison ops.
975    /// `on_true` and `on_false` must have the same shape and dtype.
976    pub fn where_cond(mask: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
977        let mask_s = mask.read_storage()?;
978        let true_s = on_true.read_storage()?;
979        let false_s = on_false.read_storage()?;
980        let result = B::where_cond(
981            &mask_s,
982            &mask.inner.layout,
983            &true_s,
984            &on_true.inner.layout,
985            &false_s,
986            &on_false.inner.layout,
987        )?;
988        let result_layout = Layout::contiguous(on_true.shape().clone());
989        let result_op = Op::WhereCond {
990            mask: mask.clone(),
991            on_true: on_true.clone(),
992            on_false: on_false.clone(),
993        };
994        Ok(Self::from_storage(
995            result,
996            result_layout,
997            on_true.inner.dtype,
998            on_true.inner.device.clone(),
999            result_op,
1000        ))
1001    }
1002
1003    /// Gather elements along `dim` using an index tensor.
1004    ///
1005    /// `output[i][j][k] = input[index[i][j][k]][j][k]`  (when dim=0)
1006    ///
1007    /// The index tensor must have the same number of dimensions as self.
1008    /// The output has the same shape as the index tensor.
1009    pub fn gather(&self, dim: usize, index: &Self) -> Result<Self> {
1010        let input_s = self.read_storage()?;
1011        let index_s = index.read_storage()?;
1012        let result = B::gather(
1013            &input_s,
1014            &self.inner.layout,
1015            &index_s,
1016            &index.inner.layout,
1017            dim,
1018        )?;
1019        let result_layout = Layout::contiguous(index.shape().clone());
1020        let result_op = Op::Gather {
1021            input: self.clone(),
1022            index: index.clone(),
1023            dim,
1024        };
1025        Ok(Self::from_storage(
1026            result,
1027            result_layout,
1028            self.inner.dtype,
1029            self.inner.device.clone(),
1030            result_op,
1031        ))
1032    }
1033
1034    /// Fill elements where `mask != 0` with `value`, keeping other elements.
1035    ///
1036    /// `result[i] = if mask[i] != 0 { value } else { self[i] }`
1037    ///
1038    /// This is implemented via `where_cond` so autograd is automatic.
1039    pub fn masked_fill(&self, mask: &Self, value: f64) -> Result<Self> {
1040        let fill = Self::full(self.shape().clone(), value, self.dtype(), self.device())?;
1041        Self::where_cond(mask, &fill, self)
1042    }
1043
    /// Pad the last N dimensions with constant `value`.
    ///
    /// `padding` is a list of `[before, after]` pairs, one per dimension,
    /// applied to the **last** dimensions of the tensor.
    ///
    /// Example: `pad(&[[1, 1], [2, 2]], 0.0)` pads the last 2 dims.
    ///
    /// # Errors
    /// Fails when more padding pairs are supplied than the tensor has dims.
    pub fn pad(&self, padding: &[[usize; 2]], value: f64) -> Result<Self> {
        let rank = self.rank();
        if padding.len() > rank {
            return Err(Error::msg(format!(
                "pad: {} padding pairs but tensor rank is {}",
                padding.len(),
                rank
            )));
        }

        // Build full-rank padding: leading dims (not covered by `padding`)
        // get [0, 0], i.e. no padding.
        let mut full_pad = vec![[0usize; 2]; rank];
        let offset = rank - padding.len();
        for (i, p) in padding.iter().enumerate() {
            full_pad[offset + i] = *p;
        }

        // Compute output shape: each dim grows by before + after.
        let in_dims = self.dims();
        let out_dims: Vec<usize> = in_dims
            .iter()
            .zip(full_pad.iter())
            .map(|(&d, &[b, a])| d + b + a)
            .collect();

        // If no padding at all, just return a clone. Zero padding is the
        // identity, so skipping the Op::Pad node is gradient-correct.
        if full_pad.iter().all(|&[b, a]| b == 0 && a == 0) {
            return Ok(self.clone());
        }

        // Build result by concatenating pad tensors along each dimension:
        // for each padded dim d, current := cat([before-block, current,
        // after-block], d). Iterating dims in reverse keeps the blocks small.
        let mut current = self.clone();
        for d in (0..rank).rev() {
            let [before, after] = full_pad[d];
            if before == 0 && after == 0 {
                continue;
            }
            let mut cur_dims = current.dims().to_vec();
            let mut parts: Vec<Self> = Vec::new();

            if before > 0 {
                // Constant block prepended along dim d.
                cur_dims[d] = before;
                let pad_before = Self::full(
                    Shape::new(cur_dims.clone()),
                    value,
                    self.dtype(),
                    self.device(),
                )?;
                cur_dims[d] = current.dims()[d]; // restore
                parts.push(pad_before);
            }
            parts.push(current);
            if after > 0 {
                // Constant block appended along dim d.
                cur_dims[d] = after;
                let pad_after = Self::full(
                    Shape::new(cur_dims.clone()),
                    value,
                    self.dtype(),
                    self.device(),
                )?;
                parts.push(pad_after);
            }
            current = Self::cat(&parts, d)?;
        }

        // Wrap in Op::Pad for clean backward
        let result_layout = Layout::contiguous(Shape::new(out_dims));
        let pad_op = Op::Pad {
            input: self.clone(),
            padding: full_pad,
        };

        // Re-wrap with the Pad op so backward can narrow back.
        // NOTE(review): `storage.clone()` deref-clones the backend storage
        // (a copy) just to re-tag the op — confirm whether from_storage could
        // share the existing Arc instead.
        let storage = current.read_storage()?;
        Ok(Self::from_storage(
            storage.clone(),
            result_layout,
            self.inner.dtype,
            self.inner.device.clone(),
            pad_op,
        ))
    }
1132
    /// Return the `k` largest elements along `dim`.
    ///
    /// Returns `(values, indices)` where both have the same shape as self
    /// except dimension `dim` has size `k`. `indices` holds, for each output
    /// slot, the position along `dim` of the selected element, laid out in
    /// the output's row-major order.
    ///
    /// Non-differentiable (returns detached values).
    #[allow(clippy::needless_range_loop)]
    pub fn topk(&self, k: usize, dim: usize) -> Result<(Self, Vec<usize>)> {
        if dim >= self.rank() {
            return Err(Error::DimOutOfRange {
                dim,
                rank: self.rank(),
            });
        }
        let dims = self.dims();
        let dim_size = dims[dim];
        if k > dim_size {
            return Err(Error::msg(format!(
                "topk: k={} exceeds dim {} size {}",
                k, dim, dim_size
            )));
        }

        // Materialize into a contiguous f64 buffer; selection runs on the host.
        let data = self.contiguous()?.to_f64_vec()?;

        // Output shape: same as input but dim has size k
        let mut out_dims = dims.to_vec();
        out_dims[dim] = k;
        let out_size: usize = out_dims.iter().product();
        let mut out_values = vec![0.0f64; out_size];
        let mut out_indices = vec![0usize; out_size];

        // View the data as [outer, dim_size, inner] in row-major order, where
        // `outer` is the product of dims before `dim` and `inner` after it.
        let outer: usize = dims[..dim].iter().product();
        let inner: usize = dims[dim + 1..].iter().product();

        for o in 0..outer {
            for i in 0..inner {
                // Collect all elements along this dim-slice
                let mut slice: Vec<(f64, usize)> = (0..dim_size)
                    .map(|d| {
                        // Row-major flat index of element `d` of this slice.
                        let flat = o * (dim_size * inner) + d * inner + i;
                        (data[flat], d)
                    })
                    .collect();
                // Sort descending; incomparable pairs (NaN) are treated as
                // equal, and sort_by is stable, so their input order is kept.
                slice.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

                // Write top-k
                for j in 0..k {
                    let out_flat = o * (k * inner) + j * inner + i;
                    out_values[out_flat] = slice[j].0;
                    out_indices[out_flat] = slice[j].1;
                }
            }
        }

        let values = Self::from_f64_slice(
            &out_values,
            Shape::new(out_dims),
            self.dtype(),
            self.device(),
        )?;
        Ok((values, out_indices))
    }
1198
1199    /// Generic unary operation dispatch.
1200    fn unary_op(&self, op: UnaryOp) -> Result<Self> {
1201        let storage = self.read_storage()?;
1202        let result = B::unary_op(op, &storage, &self.inner.layout)?;
1203        let result_layout = Layout::contiguous(self.shape().clone());
1204        let result_op = Op::Unary {
1205            input: self.clone(),
1206            op,
1207        };
1208        Ok(Self::from_storage(
1209            result,
1210            result_layout,
1211            self.inner.dtype,
1212            self.inner.device.clone(),
1213            result_op,
1214        ))
1215    }
1216
1217    // Reductions
1218
1219    /// Sum all elements, returning a scalar tensor.
1220    pub fn sum_all(&self) -> Result<Self> {
1221        self.reduce_op(ReduceOp::Sum, &[], false)
1222    }
1223
1224    /// Sum along a specific dimension.
1225    pub fn sum(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1226        self.reduce_op(ReduceOp::Sum, &[dim], keep_dim)
1227    }
1228
1229    /// Mean of all elements, returning a scalar tensor.
1230    pub fn mean_all(&self) -> Result<Self> {
1231        self.reduce_op(ReduceOp::Mean, &[], false)
1232    }
1233
1234    /// Mean along a specific dimension.
1235    pub fn mean(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1236        self.reduce_op(ReduceOp::Mean, &[dim], keep_dim)
1237    }
1238
1239    /// Max along a specific dimension.
1240    pub fn max(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1241        self.reduce_op(ReduceOp::Max, &[dim], keep_dim)
1242    }
1243
1244    /// Min along a specific dimension.
1245    pub fn min(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1246        self.reduce_op(ReduceOp::Min, &[dim], keep_dim)
1247    }
1248
1249    /// ArgMax along a specific dimension (returns i64 indices).
1250    pub fn argmax(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1251        self.reduce_op(ReduceOp::ArgMax, &[dim], keep_dim)
1252    }
1253
1254    /// ArgMin along a specific dimension (returns i64 indices).
1255    pub fn argmin(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1256        self.reduce_op(ReduceOp::ArgMin, &[dim], keep_dim)
1257    }
1258
    /// Generic reduction dispatch.
    ///
    /// `dims` lists the dimensions to reduce; an empty list reduces every
    /// element down to a scalar. `keep_dim` keeps each reduced dim with
    /// size 1 instead of dropping it.
    fn reduce_op(&self, op: ReduceOp, dims: &[usize], keep_dim: bool) -> Result<Self> {
        // Validate dimensions before touching storage.
        for &d in dims {
            if d >= self.rank() {
                return Err(Error::DimOutOfRange {
                    dim: d,
                    rank: self.rank(),
                });
            }
        }
        let storage = self.read_storage()?;
        let result = B::reduce_op(op, &storage, &self.inner.layout, dims, keep_dim)?;

        // Compute result shape
        let result_shape = if dims.is_empty() {
            // Reduce all → scalar
            Shape::from(())
        } else if keep_dim {
            // Reduced dims stay, with size 1 (broadcast-friendly).
            let mut new_dims = self.dims().to_vec();
            for &d in dims {
                new_dims[d] = 1;
            }
            Shape::new(new_dims)
        } else {
            // Drop the reduced dims entirely.
            let new_dims: Vec<usize> = self
                .dims()
                .iter()
                .enumerate()
                .filter(|(i, _)| !dims.contains(i))
                .map(|(_, &d)| d)
                .collect();
            if new_dims.is_empty() {
                // Every dim was reduced — result is a scalar.
                Shape::from(())
            } else {
                Shape::new(new_dims)
            }
        };

        let result_layout = Layout::contiguous(result_shape);
        // ArgMax/ArgMin yield integer indices; all other reductions keep the
        // input dtype.
        let result_dtype = match op {
            ReduceOp::ArgMax | ReduceOp::ArgMin => DType::I64,
            _ => self.inner.dtype,
        };
        let result_op = Op::Reduce {
            input: self.clone(),
            op,
            dims: dims.to_vec(),
            keep_dim,
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            result_dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
1317
1318    // Composite operations (built from primitives)
1319
1320    /// Softmax along a dimension: softmax(x)_i = exp(x_i) / sum(exp(x_j))
1321    ///
1322    /// Uses the numerically stable trick: subtract max before exp.
1323    /// This is built from existing differentiable ops (exp, sum, div, sub)
1324    /// so gradients flow through automatically.
1325    pub fn softmax(&self, dim: usize) -> Result<Self> {
1326        // max(x, dim, keep_dim=true) — used as a constant for stability
1327        let max_val = self.max(dim, true)?;
1328        // We detach max so it's treated as a constant in backward
1329        let max_detached = max_val.detach();
1330        let shifted = self.sub(&max_detached)?; // x - max(x)
1331        let exp_x = shifted.exp()?;
1332        let sum_exp = exp_x.sum(dim, true)?;
1333        exp_x.div(&sum_exp)
1334    }
1335
1336    /// Log-softmax along a dimension: log(softmax(x)) but numerically stable.
1337    ///
1338    /// log_softmax(x)_i = x_i - max(x) - log(sum(exp(x - max(x))))
1339    pub fn log_softmax(&self, dim: usize) -> Result<Self> {
1340        let max_val = self.max(dim, true)?.detach();
1341        let shifted = self.sub(&max_val)?;
1342        let exp_x = shifted.exp()?;
1343        let sum_exp = exp_x.sum(dim, true)?;
1344        let log_sum_exp = sum_exp.log()?;
1345        shifted.sub(&log_sum_exp)
1346    }
1347
1348    /// Variance along a dimension: var(x) = mean((x - mean(x))²)
1349    pub fn var(&self, dim: usize, keep_dim: bool) -> Result<Self> {
1350        let mu = self.mean(dim, true)?;
1351        let centered = self.sub(&mu)?;
1352        let sq = centered.square()?;
1353        sq.mean(dim, keep_dim)
1354    }
1355
1356    /// Concatenate tensors along a dimension.
1357    ///
1358    /// All tensors must have the same shape except in the concatenation dimension.
1359    /// This creates a new tensor by copying data from all inputs.
1360    pub fn cat(tensors: &[Self], dim: usize) -> Result<Self> {
1361        if tensors.is_empty() {
1362            return Err(Error::msg("cat: empty tensor list"));
1363        }
1364        if tensors.len() == 1 {
1365            return Ok(tensors[0].clone());
1366        }
1367
1368        let first = &tensors[0];
1369        let rank = first.rank();
1370        if dim >= rank {
1371            return Err(Error::DimOutOfRange { dim, rank });
1372        }
1373
1374        // Validate shapes: all dims must match except `dim`
1375        for (i, t) in tensors.iter().enumerate().skip(1) {
1376            if t.rank() != rank {
1377                return Err(Error::msg(format!(
1378                    "cat: tensor {} has rank {} but expected {}",
1379                    i,
1380                    t.rank(),
1381                    rank
1382                )));
1383            }
1384            if t.dtype() != first.dtype() {
1385                return Err(Error::DTypeMismatch {
1386                    expected: first.dtype(),
1387                    got: t.dtype(),
1388                });
1389            }
1390            for d in 0..rank {
1391                if d != dim && t.dims()[d] != first.dims()[d] {
1392                    return Err(Error::msg(format!(
1393                        "cat: tensor {} has size {} at dim {} but expected {}",
1394                        i,
1395                        t.dims()[d],
1396                        d,
1397                        first.dims()[d]
1398                    )));
1399                }
1400            }
1401        }
1402
1403        // Compute output shape
1404        let cat_size: usize = tensors.iter().map(|t| t.dims()[dim]).sum();
1405        let mut out_dims = first.dims().to_vec();
1406        out_dims[dim] = cat_size;
1407        let out_shape = Shape::new(out_dims.clone());
1408
1409        // Record per-input sizes along the cat dim for backward
1410        let sizes: Vec<usize> = tensors.iter().map(|t| t.dims()[dim]).collect();
1411
1412        // Collect (storage, layout) pairs for Backend::cat
1413        let inner_guards: Vec<_> = tensors
1414            .iter()
1415            .map(|t| t.inner.storage.read().unwrap())
1416            .collect();
1417        let pairs: Vec<(&B::Storage, &Layout)> = tensors
1418            .iter()
1419            .enumerate()
1420            .map(|(i, t)| (&*inner_guards[i], &t.inner.layout))
1421            .collect();
1422
1423        let storage = B::cat(&pairs, &out_shape, dim)?;
1424        let layout = Layout::contiguous(out_shape);
1425        let op = Op::Cat {
1426            inputs: tensors.to_vec(),
1427            dim,
1428            sizes,
1429        };
1430        Ok(Self::from_storage(
1431            storage,
1432            layout,
1433            first.dtype(),
1434            first.device().clone(),
1435            op,
1436        ))
1437    }
1438
1439    /// Split a tensor into `n` equal chunks along a dimension.
1440    /// If the dimension size is not evenly divisible, the last chunk is smaller.
1441    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Self>> {
1442        if dim >= self.rank() {
1443            return Err(Error::DimOutOfRange {
1444                dim,
1445                rank: self.rank(),
1446            });
1447        }
1448        let dim_size = self.dims()[dim];
1449        let chunk_size = dim_size.div_ceil(n);
1450        let mut chunks = Vec::new();
1451        let mut start = 0;
1452        while start < dim_size {
1453            let len = chunk_size.min(dim_size - start);
1454            chunks.push(self.narrow(dim, start, len)?);
1455            start += len;
1456        }
1457        Ok(chunks)
1458    }
1459
    /// Expand a tensor to a larger shape by repeating data along size-1 dims.
    /// Only dims that are currently size 1 can be expanded.
    ///
    /// NOTE(review): the original doc said a size of -1 (usize::MAX) means
    /// "don't change that dim", but the validation below has no special case
    /// for it — a usize::MAX target errors unless the source dim is 1 or
    /// already usize::MAX. Confirm the intended contract.
    pub fn expand(&self, target_shape: impl Into<Shape>) -> Result<Self> {
        let target = target_shape.into();
        let self_dims = self.dims();
        let target_dims = target.dims();

        if self_dims.len() != target_dims.len() {
            return Err(Error::msg(format!(
                "expand: rank mismatch — self {:?} vs target {:?}",
                self_dims, target_dims
            )));
        }

        // Each dim must either match the target or be expandable (size 1).
        for (i, (&sd, &td)) in self_dims.iter().zip(target_dims.iter()).enumerate() {
            if sd != td && sd != 1 {
                return Err(Error::msg(format!(
                    "expand: can only expand size-1 dims, but dim {} has size {}",
                    i, sd
                )));
            }
        }

        // Zero-copy expand via stride tricks:
        // For dims where self_dim == 1 and target_dim > 1, set stride to 0
        // (the single element is "repeated" without copying data).
        let self_strides = self.inner.layout.strides();
        let mut new_strides = self_strides.to_vec();
        for d in 0..target_dims.len() {
            if self_dims[d] == 1 && target_dims[d] > 1 {
                new_strides[d] = 0;
            }
        }

        let new_layout = Layout::new(target, new_strides, self.inner.layout.offset());

        // Share the same storage — no copy!
        // NOTE(review): the result records Op::None, so autograd never sees
        // this expand (other ops here record their inputs). Confirm gradients
        // are not expected to flow back through expanded tensors.
        Ok(Tensor {
            inner: Arc::new(TensorInner {
                id: TensorId::new(),
                storage: Arc::clone(&self.inner.storage),
                layout: new_layout,
                dtype: self.inner.dtype,
                device: self.inner.device.clone(),
                op: Op::None,
                is_variable: false,
            }),
        })
    }
1510
1511    // Stack — concatenate with a new dimension
1512
1513    /// Stack tensors along a new dimension.
1514    ///
1515    /// All tensors must have the same shape. Inserts a new dimension at `dim`.
1516    /// `stack([a, b], dim=0)` where a,b are shape [2,3] → [2, 2, 3].
1517    pub fn stack(tensors: &[Self], dim: usize) -> Result<Self> {
1518        if tensors.is_empty() {
1519            return Err(Error::msg("stack: empty tensor list"));
1520        }
1521        let first_shape = tensors[0].shape().clone();
1522        let rank = first_shape.dims().len();
1523        if dim > rank {
1524            return Err(Error::DimOutOfRange {
1525                dim,
1526                rank: rank + 1,
1527            });
1528        }
1529        // Validate all shapes match
1530        for (i, t) in tensors.iter().enumerate().skip(1) {
1531            if t.shape() != &first_shape {
1532                return Err(Error::msg(format!(
1533                    "stack: tensor {} has shape {:?} but expected {:?}",
1534                    i,
1535                    t.dims(),
1536                    first_shape.dims()
1537                )));
1538            }
1539        }
1540        // Unsqueeze each tensor at `dim`, then cat
1541        let unsqueezed: Vec<Self> = tensors
1542            .iter()
1543            .map(|t| t.unsqueeze(dim))
1544            .collect::<Result<Vec<_>>>()?;
1545        Self::cat(&unsqueezed, dim)
1546    }
1547
1548    // Arange — generate a sequence
1549
1550    /// Create a 1-D tensor with values [0, 1, ..., n-1].
1551    pub fn arange(n: usize, dtype: DType, device: &B::Device) -> Result<Self> {
1552        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
1553        Self::from_f64_slice(&data, n, dtype, device)
1554    }
1555
1556    /// Create a 1-D tensor with values [start, start+step, ..., <end).
1557    pub fn arange_step(
1558        start: f64,
1559        end: f64,
1560        step: f64,
1561        dtype: DType,
1562        device: &B::Device,
1563    ) -> Result<Self> {
1564        if step == 0.0 {
1565            return Err(Error::msg("arange_step: step cannot be zero"));
1566        }
1567        let mut data = Vec::new();
1568        let mut v = start;
1569        if step > 0.0 {
1570            while v < end {
1571                data.push(v);
1572                v += step;
1573            }
1574        } else {
1575            while v > end {
1576                data.push(v);
1577                v += step;
1578            }
1579        }
1580        let len = data.len();
1581        Self::from_f64_slice(&data, len, dtype, device)
1582    }
1583
1584    // Triangular masks — triu / tril
1585
1586    /// Upper triangular mask: returns a 2-D tensor of shape [n, m] where
1587    /// elements on and above the `diagonal`-th diagonal are 1.0, rest 0.0.
1588    ///
1589    /// `diagonal = 0` → main diagonal. `diagonal > 0` → above. `diagonal < 0` → below.
1590    pub fn triu(
1591        n: usize,
1592        m: usize,
1593        diagonal: i64,
1594        dtype: DType,
1595        device: &B::Device,
1596    ) -> Result<Self> {
1597        let mut data = vec![0.0f64; n * m];
1598        for i in 0..n {
1599            for j in 0..m {
1600                if (j as i64) >= (i as i64) + diagonal {
1601                    data[i * m + j] = 1.0;
1602                }
1603            }
1604        }
1605        Self::from_f64_slice(&data, (n, m), dtype, device)
1606    }
1607
1608    /// Lower triangular mask: returns a 2-D tensor of shape [n, m] where
1609    /// elements on and below the `diagonal`-th diagonal are 1.0, rest 0.0.
1610    pub fn tril(
1611        n: usize,
1612        m: usize,
1613        diagonal: i64,
1614        dtype: DType,
1615        device: &B::Device,
1616    ) -> Result<Self> {
1617        let mut data = vec![0.0f64; n * m];
1618        for i in 0..n {
1619            for j in 0..m {
1620                if (j as i64) <= (i as i64) + diagonal {
1621                    data[i * m + j] = 1.0;
1622                }
1623            }
1624        }
1625        Self::from_f64_slice(&data, (n, m), dtype, device)
1626    }
1627
1628    // Matrix multiplication
1629
    /// Matrix multiplication: self @ rhs.
    ///
    /// - [m, k] @ [k, n] → [m, n]
    /// - Batched: [b, m, k] @ [b, k, n] → [b, m, n]
    ///
    /// NOTE(review): only dtype, rank ≥ 2, and the inner (k) dims are checked
    /// here; leading batch dims are not validated — presumably the backend's
    /// matmul handles or rejects mismatched batches. Confirm.
    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
        if self.dtype() != rhs.dtype() {
            return Err(Error::DTypeMismatch {
                expected: self.dtype(),
                got: rhs.dtype(),
            });
        }
        // Validate shapes for matmul: both sides need at least 2 dims.
        if self.rank() < 2 || rhs.rank() < 2 {
            return Err(Error::RankMismatch {
                expected: 2,
                got: self.rank().min(rhs.rank()),
            });
        }
        let lhs_dims = self.dims();
        let rhs_dims = rhs.dims();
        // Inner dims must agree: lhs [..., m, k1] × rhs [..., k2, n], k1 == k2.
        let k1 = lhs_dims[lhs_dims.len() - 1];
        let k2 = rhs_dims[rhs_dims.len() - 2];
        if k1 != k2 {
            let m = lhs_dims[lhs_dims.len() - 2];
            let n = rhs_dims[rhs_dims.len() - 1];
            return Err(Error::MatmulShapeMismatch { m, k1, k2, n });
        }

        let storage_lhs = self.read_storage()?;
        let storage_rhs = rhs.read_storage()?;
        let result = B::matmul(
            &storage_lhs,
            &self.inner.layout,
            &storage_rhs,
            &rhs.inner.layout,
        )?;

        // Result shape: [..., m, n] — batch dims taken from the lhs.
        let m = lhs_dims[lhs_dims.len() - 2];
        let n = rhs_dims[rhs_dims.len() - 1];
        let mut result_dims: Vec<usize> = lhs_dims[..lhs_dims.len() - 2].to_vec();
        result_dims.push(m);
        result_dims.push(n);
        let result_layout = Layout::contiguous(Shape::new(result_dims));
        let result_op = Op::Matmul {
            lhs: self.clone(),
            rhs: rhs.clone(),
        };
        Ok(Self::from_storage(
            result,
            result_layout,
            self.inner.dtype,
            self.inner.device.clone(),
            result_op,
        ))
    }
1686
1687    // 2D Convolution
1688
1689    /// 2D convolution: applies convolution filters to a 4D input tensor.
1690    ///
1691    /// - `self` (input): `[N, C_in, H, W]`
1692    /// - `weight`:       `[C_out, C_in, kH, kW]`
1693    /// - `bias`:         optional `[C_out]`
1694    /// - `stride`:       `[sH, sW]`
1695    /// - `padding`:      `[pH, pW]`
1696    ///
1697    /// Returns tensor of shape `[N, C_out, H_out, W_out]` where
1698    /// `H_out = (H + 2*pH - kH) / sH + 1`.
1699    #[allow(clippy::needless_range_loop)]
1700    pub fn conv2d(
1701        &self,
1702        weight: &Self,
1703        bias: Option<&Self>,
1704        stride: [usize; 2],
1705        padding: [usize; 2],
1706    ) -> Result<Self> {
1707        // Validate ranks
1708        if self.rank() != 4 {
1709            return Err(Error::msg(format!(
1710                "conv2d input must be 4D [N,C,H,W], got rank {}",
1711                self.rank()
1712            )));
1713        }
1714        if weight.rank() != 4 {
1715            return Err(Error::msg(format!(
1716                "conv2d weight must be 4D [C_out,C_in,kH,kW], got rank {}",
1717                weight.rank()
1718            )));
1719        }
1720
1721        let in_dims = self.dims();
1722        let w_dims = weight.dims();
1723        let (n, c_in, h, w) = (in_dims[0], in_dims[1], in_dims[2], in_dims[3]);
1724        let (c_out, wc_in, kh, kw) = (w_dims[0], w_dims[1], w_dims[2], w_dims[3]);
1725
1726        if c_in != wc_in {
1727            return Err(Error::msg(format!(
1728                "conv2d: input channels {} != weight channels {}",
1729                c_in, wc_in
1730            )));
1731        }
1732
1733        let [sh, sw] = stride;
1734        let [ph, pw] = padding;
1735
1736        if h + 2 * ph < kh || w + 2 * pw < kw {
1737            return Err(Error::msg("conv2d: kernel larger than padded input"));
1738        }
1739
1740        let h_out = (h + 2 * ph - kh) / sh + 1;
1741        let w_out = (w + 2 * pw - kw) / sw + 1;
1742
1743        // Get contiguous data
1744        let input_data = self.contiguous()?.to_f64_vec()?;
1745        let weight_data = weight.contiguous()?.to_f64_vec()?;
1746        let bias_data = match bias {
1747            Some(b) => Some(b.contiguous()?.to_f64_vec()?),
1748            None => None,
1749        };
1750
1751        let out_size = n * c_out * h_out * w_out;
1752        let mut output = vec![0.0f64; out_size];
1753
1754        // im2col + GEMM approach:
1755        // For each batch sample:
1756        //   1. im2col: unroll input patches → columns [c_in*kh*kw, h_out*w_out]
1757        //   2. GEMM: weight [c_out, c_in*kh*kw] × columns → out [c_out, h_out*w_out]
1758        let col_rows = c_in * kh * kw;
1759        let col_cols = h_out * w_out;
1760        let mut columns = vec![0.0f64; col_rows * col_cols];
1761        let sample_size = c_in * h * w;
1762
1763        for ni in 0..n {
1764            // im2col for this sample
1765            let in_offset = ni * sample_size;
1766            im2col(
1767                &input_data[in_offset..in_offset + sample_size],
1768                c_in,
1769                h,
1770                w,
1771                kh,
1772                kw,
1773                sh,
1774                sw,
1775                ph,
1776                pw,
1777                h_out,
1778                w_out,
1779                &mut columns,
1780            );
1781
1782            // GEMM: output[ni] = weight × columns + bias
1783            let out_offset = ni * c_out * h_out * w_out;
1784            gemm(
1785                &weight_data,
1786                &columns,
1787                &mut output[out_offset..out_offset + c_out * col_cols],
1788                c_out,
1789                col_cols,
1790                col_rows,
1791            );
1792
1793            // Add bias
1794            if let Some(ref bd) = bias_data {
1795                for co in 0..c_out {
1796                    let row_start = out_offset + co * col_cols;
1797                    for j in 0..col_cols {
1798                        output[row_start + j] += bd[co];
1799                    }
1800                }
1801            }
1802        }
1803
1804        let result_shape = Shape::new(vec![n, c_out, h_out, w_out]);
1805        let result_op = Op::Conv2d {
1806            input: self.clone(),
1807            weight: weight.clone(),
1808            bias: bias.cloned(),
1809            stride,
1810            padding,
1811        };
1812        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
1813            Self::from_storage(
1814                {
1815                    let s = t.inner.storage.read().expect("storage lock poisoned");
1816                    s.clone()
1817                },
1818                Layout::contiguous(result_shape),
1819                self.inner.dtype,
1820                self.inner.device.clone(),
1821                result_op,
1822            )
1823        })
1824    }
1825
1826    // 2D Max Pooling
1827
1828    /// 2D max pooling on a 4D input tensor `[N, C, H, W]`.
1829    ///
1830    /// Returns `(output, indices)` where `indices` stores argmax positions
1831    /// (flat indices into the input) for backward.
1832    pub fn max_pool2d(
1833        &self,
1834        kernel_size: [usize; 2],
1835        stride: [usize; 2],
1836        padding: [usize; 2],
1837    ) -> Result<Self> {
1838        if self.rank() != 4 {
1839            return Err(Error::msg(format!(
1840                "max_pool2d input must be 4D [N,C,H,W], got rank {}",
1841                self.rank()
1842            )));
1843        }
1844
1845        let dims = self.dims();
1846        let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);
1847        let [kh, kw] = kernel_size;
1848        let [sh, sw] = stride;
1849        let [ph, pw] = padding;
1850
1851        if h + 2 * ph < kh || w + 2 * pw < kw {
1852            return Err(Error::msg("max_pool2d: kernel larger than padded input"));
1853        }
1854
1855        let h_out = (h + 2 * ph - kh) / sh + 1;
1856        let w_out = (w + 2 * pw - kw) / sw + 1;
1857
1858        let input_data = self.contiguous()?.to_f64_vec()?;
1859        let out_size = n * c * h_out * w_out;
1860        let mut output = vec![f64::NEG_INFINITY; out_size];
1861        let mut indices = vec![0usize; out_size];
1862
1863        for ni in 0..n {
1864            for ci in 0..c {
1865                for oh in 0..h_out {
1866                    for ow in 0..w_out {
1867                        let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
1868                        let mut max_val = f64::NEG_INFINITY;
1869                        let mut max_idx = 0usize;
1870                        for ki in 0..kh {
1871                            for kj in 0..kw {
1872                                let ih = (oh * sh + ki) as isize - ph as isize;
1873                                let iw = (ow * sw + kj) as isize - pw as isize;
1874                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1875                                    let ih = ih as usize;
1876                                    let iw = iw as usize;
1877                                    let in_idx = ((ni * c + ci) * h + ih) * w + iw;
1878                                    if input_data[in_idx] > max_val {
1879                                        max_val = input_data[in_idx];
1880                                        max_idx = in_idx;
1881                                    }
1882                                }
1883                            }
1884                        }
1885                        output[out_idx] = max_val;
1886                        indices[out_idx] = max_idx;
1887                    }
1888                }
1889            }
1890        }
1891
1892        let result_shape = Shape::new(vec![n, c, h_out, w_out]);
1893        let result_op = Op::MaxPool2d {
1894            input: self.clone(),
1895            kernel_size,
1896            stride,
1897            padding,
1898            indices: indices.clone(),
1899        };
1900        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
1901            Self::from_storage(
1902                {
1903                    let s = t.inner.storage.read().expect("storage lock poisoned");
1904                    s.clone()
1905                },
1906                Layout::contiguous(result_shape),
1907                self.inner.dtype,
1908                self.inner.device.clone(),
1909                result_op,
1910            )
1911        })
1912    }
1913
1914    // 2D Average Pooling
1915
1916    /// Apply 2D average pooling to a 4D tensor [N, C, H, W].
1917    pub fn avg_pool2d(
1918        &self,
1919        kernel_size: [usize; 2],
1920        stride: [usize; 2],
1921        padding: [usize; 2],
1922    ) -> Result<Self> {
1923        if self.rank() != 4 {
1924            return Err(Error::msg(format!(
1925                "avg_pool2d input must be 4D [N,C,H,W], got rank {}",
1926                self.rank()
1927            )));
1928        }
1929
1930        let dims = self.dims();
1931        let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);
1932        let [kh, kw] = kernel_size;
1933        let [sh, sw] = stride;
1934        let [ph, pw] = padding;
1935
1936        if h + 2 * ph < kh || w + 2 * pw < kw {
1937            return Err(Error::msg("avg_pool2d: kernel larger than padded input"));
1938        }
1939
1940        let h_out = (h + 2 * ph - kh) / sh + 1;
1941        let w_out = (w + 2 * pw - kw) / sw + 1;
1942
1943        let input_data = self.contiguous()?.to_f64_vec()?;
1944        let out_size = n * c * h_out * w_out;
1945        let mut output = vec![0.0f64; out_size];
1946
1947        for ni in 0..n {
1948            for ci in 0..c {
1949                for oh in 0..h_out {
1950                    for ow in 0..w_out {
1951                        let out_idx = ((ni * c + ci) * h_out + oh) * w_out + ow;
1952                        let mut sum = 0.0f64;
1953                        let mut count = 0usize;
1954                        for ki in 0..kh {
1955                            for kj in 0..kw {
1956                                let ih = (oh * sh + ki) as isize - ph as isize;
1957                                let iw = (ow * sw + kj) as isize - pw as isize;
1958                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1959                                    let in_idx =
1960                                        ((ni * c + ci) * h + ih as usize) * w + iw as usize;
1961                                    sum += input_data[in_idx];
1962                                    count += 1;
1963                                }
1964                            }
1965                        }
1966                        output[out_idx] = if count > 0 { sum / count as f64 } else { 0.0 };
1967                    }
1968                }
1969            }
1970        }
1971
1972        let result_shape = Shape::new(vec![n, c, h_out, w_out]);
1973        let result_op = Op::AvgPool2d {
1974            input: self.clone(),
1975            kernel_size,
1976            stride,
1977            padding,
1978        };
1979        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
1980            Self::from_storage(
1981                {
1982                    let s = t.inner.storage.read().expect("storage lock poisoned");
1983                    s.clone()
1984                },
1985                Layout::contiguous(result_shape),
1986                self.inner.dtype,
1987                self.inner.device.clone(),
1988                result_op,
1989            )
1990        })
1991    }
1992
1993    // 1D Convolution
1994
1995    /// Apply 1D convolution to a 3D tensor [N, C_in, L].
1996    /// weight: [C_out, C_in, K]
1997    #[allow(clippy::needless_range_loop)]
1998    pub fn conv1d(
1999        &self,
2000        weight: &Self,
2001        bias: Option<&Self>,
2002        stride: usize,
2003        padding: usize,
2004    ) -> Result<Self> {
2005        if self.rank() != 3 {
2006            return Err(Error::msg(format!(
2007                "conv1d input must be 3D [N,C_in,L], got rank {}",
2008                self.rank()
2009            )));
2010        }
2011        if weight.rank() != 3 {
2012            return Err(Error::msg(format!(
2013                "conv1d weight must be 3D [C_out,C_in,K], got rank {}",
2014                weight.rank()
2015            )));
2016        }
2017
2018        let in_dims = self.dims();
2019        let w_dims = weight.dims();
2020        let (n, c_in, l) = (in_dims[0], in_dims[1], in_dims[2]);
2021        let (c_out, wc_in, k) = (w_dims[0], w_dims[1], w_dims[2]);
2022
2023        if c_in != wc_in {
2024            return Err(Error::msg(format!(
2025                "conv1d: input channels {} != weight channels {}",
2026                c_in, wc_in
2027            )));
2028        }
2029        if let Some(b) = bias {
2030            if b.elem_count() != c_out {
2031                return Err(Error::msg(format!(
2032                    "conv1d: bias size {} != output channels {}",
2033                    b.elem_count(),
2034                    c_out
2035                )));
2036            }
2037        }
2038
2039        if l + 2 * padding < k {
2040            return Err(Error::msg("conv1d: kernel larger than padded input"));
2041        }
2042
2043        let l_out = (l + 2 * padding - k) / stride + 1;
2044
2045        let input_data = self.contiguous()?.to_f64_vec()?;
2046        let weight_data = weight.contiguous()?.to_f64_vec()?;
2047        let bias_data: Option<Vec<f64>> = match bias {
2048            Some(b) => Some(b.to_f64_vec()?),
2049            None => None,
2050        };
2051
2052        let out_size = n * c_out * l_out;
2053        let mut output = vec![0.0f64; out_size];
2054
2055        // im2col + GEMM for conv1d (treat as 2D with h=1)
2056        let col_rows = c_in * k;
2057        let col_cols = l_out;
2058        let mut columns = vec![0.0f64; col_rows * col_cols];
2059        let sample_size = c_in * l;
2060
2061        for ni in 0..n {
2062            // im2col for 1D: unroll patches
2063            let in_offset = ni * sample_size;
2064            im2col(
2065                &input_data[in_offset..in_offset + sample_size],
2066                c_in,
2067                1,
2068                l,
2069                1,
2070                k,
2071                1,
2072                stride,
2073                0,
2074                padding,
2075                1,
2076                l_out,
2077                &mut columns,
2078            );
2079
2080            // GEMM: output[ni] = weight × columns
2081            let out_offset = ni * c_out * l_out;
2082            gemm(
2083                &weight_data,
2084                &columns,
2085                &mut output[out_offset..out_offset + c_out * col_cols],
2086                c_out,
2087                col_cols,
2088                col_rows,
2089            );
2090
2091            // Add bias
2092            if let Some(ref bd) = bias_data {
2093                for co in 0..c_out {
2094                    let row_start = out_offset + co * col_cols;
2095                    for j in 0..col_cols {
2096                        output[row_start + j] += bd[co];
2097                    }
2098                }
2099            }
2100        }
2101
2102        let result_shape = Shape::new(vec![n, c_out, l_out]);
2103        let result_op = Op::Conv1d {
2104            input: self.clone(),
2105            weight: weight.clone(),
2106            bias: bias.cloned(),
2107            stride,
2108            padding,
2109        };
2110        Self::from_f64_slice(&output, result_shape.clone(), self.dtype(), self.device()).map(|t| {
2111            Self::from_storage(
2112                {
2113                    let s = t.inner.storage.read().expect("storage lock poisoned");
2114                    s.clone()
2115                },
2116                Layout::contiguous(result_shape),
2117                self.inner.dtype,
2118                self.inner.device.clone(),
2119                result_op,
2120            )
2121        })
2122    }
2123
2124    // Affine transform
2125
2126    /// Affine transform: result[i] = self[i] * mul + add.
2127    /// Useful for normalization and scaling.
2128    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
2129        let storage = self.read_storage()?;
2130        let result = B::affine(&storage, &self.inner.layout, mul, add)?;
2131        let result_layout = Layout::contiguous(self.shape().clone());
2132        let result_op = Op::Affine {
2133            input: self.clone(),
2134            mul,
2135            add,
2136        };
2137        Ok(Self::from_storage(
2138            result,
2139            result_layout,
2140            self.inner.dtype,
2141            self.inner.device.clone(),
2142            result_op,
2143        ))
2144    }
2145
2146    // Data extraction (for testing and debugging)
2147
    /// Extract all elements as a flat Vec<f64>.
    ///
    /// The backend receives this tensor's layout, so strided views are
    /// resolved by the backend implementation rather than here.
    pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
        let storage = self.read_storage()?;
        B::to_f64_vec(&storage, &self.inner.layout)
    }
2153
2154    /// Extract a scalar value (tensor must have exactly 1 element).
2155    pub fn to_scalar_f64(&self) -> Result<f64> {
2156        if self.elem_count() != 1 {
2157            return Err(Error::NotAScalar {
2158                shape: self.shape().clone(),
2159            });
2160        }
2161        let vec = self.to_f64_vec()?;
2162        Ok(vec[0])
2163    }
2164
2165    /// Convert this tensor to a different dtype.
2166    ///
2167    /// Returns a new tensor with the same shape but different element type.
2168    /// Uses the backend's on-device cast when available, avoiding host round-trips.
2169    /// Records Op::ToDtype so gradients flow back through dtype conversions.
2170    pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
2171        if self.dtype() == dtype {
2172            return Ok(self.clone());
2173        }
2174        let src_dtype = self.dtype();
2175        let guard = self.inner.storage.read().unwrap();
2176        let storage = B::cast(&*guard, &self.inner.layout, dtype, self.device())?;
2177        let layout = Layout::contiguous(self.shape().clone());
2178        let op = if self.is_variable() {
2179            Op::ToDtype {
2180                input: self.clone(),
2181                src_dtype,
2182            }
2183        } else {
2184            Op::None
2185        };
2186        Ok(Self::from_storage(
2187            storage,
2188            layout,
2189            dtype,
2190            self.device().clone(),
2191            op,
2192        ))
2193    }
2194
2195    /// Display the tensor contents in a human-readable format.
2196    pub fn to_string_with_data(&self) -> Result<String> {
2197        let data = self.to_f64_vec()?;
2198        Ok(format!(
2199            "Tensor(shape={}, dtype={}, data={:?})",
2200            self.shape(),
2201            self.dtype(),
2202            data
2203        ))
2204    }
2205
2206    // Autograd
2207
    /// Compute gradients via reverse-mode automatic differentiation.
    ///
    /// This tensor must be a scalar (single element). Returns a GradStore
    /// containing gradients for all tensors in the computation graph.
    ///
    /// # Errors
    /// Propagates any error from `crate::backprop::backward` — presumably
    /// including a non-scalar input, per the requirement above (see that
    /// function for the authoritative list).
    ///
    /// # Example
    /// ```ignore
    /// let a = Tensor::from_f64_slice(&[2.0], 1, DType::F32, &dev)?.set_variable();
    /// let b = Tensor::from_f64_slice(&[3.0], 1, DType::F32, &dev)?.set_variable();
    /// let c = a.mul(&b)?;
    /// let grads = c.backward()?;
    /// // grad_a = b = 3.0, grad_b = a = 2.0
    /// ```
    pub fn backward(&self) -> Result<crate::backprop::GradStore<B>> {
        crate::backprop::backward(self)
    }
2224
    /// Create a detached copy: same data but no gradient tracking.
    /// The new tensor has Op::None and a fresh TensorId.
    ///
    /// NOTE(review): implemented via `view_with_layout`, which appears to
    /// share storage rather than copy — confirm in that method.
    pub fn detach(&self) -> Self {
        self.view_with_layout(self.layout().clone(), Op::None)
    }
2230
2231    /// Freeze this tensor: same data and id, but `is_variable = false`.
2232    ///
2233    /// Frozen tensors do NOT accumulate gradients during backward().
2234    /// This is the functional equivalent of PyTorch's `param.requires_grad_(False)`.
2235    pub fn freeze(&self) -> Self {
2236        Tensor {
2237            inner: Arc::new(TensorInner {
2238                id: self.inner.id,
2239                storage: Arc::clone(&self.inner.storage),
2240                layout: self.inner.layout.clone(),
2241                dtype: self.inner.dtype,
2242                device: self.inner.device.clone(),
2243                op: self.inner.op.clone(),
2244                is_variable: false,
2245            }),
2246        }
2247    }
2248
    /// Unfreeze this tensor: same data and id, but `is_variable = true`.
    ///
    /// This is the opposite of `freeze()`.
    pub fn unfreeze(&self) -> Self {
        // Delegates to the &self-based variant of set_variable, which keeps
        // the same id and shares storage while setting is_variable = true.
        self.set_variable_ref()
    }
2255
2256    // Additional composite operations
2257
2258    /// Select entries along `dim` using the given 1-D index tensor.
2259    ///
2260    /// The output has the same rank, with `dim` resized to `indices.len()`.
2261    /// Wraps the `Backend::index_select` kernel.
2262    pub fn index_select(&self, dim: usize, indices: &Self) -> Result<Self> {
2263        if dim >= self.rank() {
2264            return Err(Error::DimOutOfRange {
2265                dim,
2266                rank: self.rank(),
2267            });
2268        }
2269        let guard = self.inner.storage.read().unwrap();
2270        let idx_guard = indices.inner.storage.read().unwrap();
2271        let storage = B::index_select(
2272            &*guard,
2273            &self.inner.layout,
2274            &*idx_guard,
2275            &indices.inner.layout,
2276            dim,
2277        )?;
2278        let mut out_dims = self.dims().to_vec();
2279        out_dims[dim] = indices.elem_count();
2280        let layout = Layout::contiguous(Shape::new(out_dims));
2281        // Record op for autograd — gradient flows back to input via scatter-add
2282        let op = Op::IndexSelect {
2283            input: self.clone(),
2284            indices: indices.clone(),
2285            dim,
2286        };
2287        Ok(Self::from_storage(
2288            storage,
2289            layout,
2290            self.dtype(),
2291            self.device().clone(),
2292            op,
2293        ))
2294    }
2295
2296    /// Split a tensor into chunks of `split_size` along `dim`.
2297    ///
2298    /// The last chunk may be smaller if the dimension is not evenly divisible.
2299    pub fn split(&self, split_size: usize, dim: usize) -> Result<Vec<Self>> {
2300        if dim >= self.rank() {
2301            return Err(Error::DimOutOfRange {
2302                dim,
2303                rank: self.rank(),
2304            });
2305        }
2306        if split_size == 0 {
2307            return Err(Error::msg("split: split_size must be > 0"));
2308        }
2309        let dim_size = self.dims()[dim];
2310        let mut parts = Vec::new();
2311        let mut start = 0;
2312        while start < dim_size {
2313            let len = split_size.min(dim_size - start);
2314            parts.push(self.narrow(dim, start, len)?);
2315            start += len;
2316        }
2317        Ok(parts)
2318    }
2319
2320    /// Flatten dimensions `start_dim..=end_dim` into a single dimension.
2321    ///
2322    /// Negative-style indexing is **not** supported; both bounds are inclusive
2323    /// and zero-based.
2324    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Self> {
2325        let rank = self.rank();
2326        if start_dim >= rank || end_dim >= rank || start_dim > end_dim {
2327            return Err(Error::msg(format!(
2328                "flatten: invalid range [{}, {}] for rank {}",
2329                start_dim, end_dim, rank
2330            )));
2331        }
2332        let dims = self.dims();
2333        let mut new_dims: Vec<usize> = Vec::new();
2334        new_dims.extend_from_slice(&dims[..start_dim]);
2335        let flat: usize = dims[start_dim..=end_dim].iter().product();
2336        new_dims.push(flat);
2337        if end_dim + 1 < rank {
2338            new_dims.extend_from_slice(&dims[end_dim + 1..]);
2339        }
2340        self.reshape(new_dims)
2341    }
2342
    /// Standard deviation along a dimension.
    ///
    /// Computed as `sqrt(var(x, dim))`; errors from either step propagate.
    pub fn std(&self, dim: usize, keep_dim: bool) -> Result<Self> {
        self.var(dim, keep_dim)?.sqrt()
    }
2349
2350    /// Element-wise reciprocal: `1 / x`.
2351    pub fn reciprocal(&self) -> Result<Self> {
2352        let one = Self::ones(self.dims(), self.dtype(), self.device())?;
2353        one.div(self)
2354    }
2355
    /// Element-wise reciprocal square-root: `1 / sqrt(x)`.
    pub fn rsqrt(&self) -> Result<Self> {
        // Square root first, then invert the result.
        self.sqrt()?.reciprocal()
    }
2360
2361    /// Element-wise sign: returns -1, 0, or +1.
2362    ///
2363    /// Implemented via `x / (|x| + eps)` clamped to [-1, 1], with exact 0 for
2364    /// inputs that are exactly zero.
2365    pub fn sign(&self) -> Result<Self> {
2366        let eps = 1e-12;
2367        let abs_x = self.abs()?;
2368        let denom = abs_x.affine(1.0, eps)?;
2369        let raw = self.div(&denom)?;
2370        raw.clamp(-1.0, 1.0)
2371    }
2372
2373    /// Log-sum-exp along a dimension (numerically stable).
2374    ///
2375    /// `logsumexp(x, d) = max(x,d) + log(sum(exp(x - max(x,d)), d))`
2376    pub fn logsumexp(&self, dim: usize, keep_dim: bool) -> Result<Self> {
2377        let m = self.max(dim, true)?.detach();
2378        let shifted = self.sub(&m)?;
2379        let sum_exp = shifted.exp()?.sum(dim, true)?.log()?;
2380        let result = m.add(&sum_exp)?;
2381        if keep_dim {
2382            Ok(result)
2383        } else {
2384            result.squeeze(dim)
2385        }
2386    }
2387
2388    /// Product of elements along a dimension.
2389    ///
2390    /// Computed as `exp(sum(log(|x|)))` with sign correction.
2391    /// **Warning**: undefined for inputs containing zero.
2392    pub fn prod(&self, dim: usize, keep_dim: bool) -> Result<Self> {
2393        let log_abs = self.abs()?.log()?;
2394        let sum_log = log_abs.sum(dim, keep_dim)?;
2395        let magnitude = sum_log.exp()?;
2396        // Sign: count negatives via (sign < 0) then parity
2397        // For simplicity, assume positive inputs (like PyTorch's default usage).
2398        // A full sign-tracking impl would need additional ops.
2399        Ok(magnitude)
2400    }
2401
2402    /// Like `set_variable(self)` but takes `&self` instead of `self`.
2403    fn set_variable_ref(&self) -> Self {
2404        Tensor {
2405            inner: Arc::new(TensorInner {
2406                id: self.inner.id,
2407                storage: Arc::clone(&self.inner.storage),
2408                layout: self.inner.layout.clone(),
2409                dtype: self.inner.dtype,
2410                device: self.inner.device.clone(),
2411                op: self.inner.op.clone(),
2412                is_variable: true,
2413            }),
2414        }
2415    }
2416}
2417
2418// im2col / col2im — Efficient convolution via matrix multiplication
2419//
2420// im2col extracts all sliding-window patches from the input and arranges them
2421// as columns of a matrix. This converts convolution into a single large GEMM:
2422//
2423//   columns = im2col(input)          shape: [C_in * kH * kW,  H_out * W_out]
2424//   output  = weight × columns       shape: [C_out, H_out * W_out]
2425//
2426// col2im is the reverse: it scatters columns back into an image-shaped buffer,
2427// accumulating overlapping contributions. Used in the backward pass.
2428
/// im2col: Extract sliding-window patches from a single sample.
///
/// Input: `[C_in, H, W]` (one sample, no batch dim)
/// Output: columns `[C_in * kH * kW, H_out * W_out]`
///
/// Each row of `columns` holds one kernel tap (channel, ki, kj) evaluated at
/// every output position; taps falling in the zero-padding contribute 0.0.
#[inline]
#[allow(clippy::too_many_arguments)]
pub(crate) fn im2col(
    input: &[f64],
    c_in: usize,
    h: usize,
    w: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    h_out: usize,
    w_out: usize,
    columns: &mut [f64],
) {
    let n_cols = h_out * w_out;
    for row in 0..c_in * kh * kw {
        // Decompose the flat row index back into (channel, kernel-row, kernel-col):
        // row = (ci * kh + ki) * kw + kj
        let kj = row % kw;
        let ki = (row / kw) % kh;
        let ci = row / (kw * kh);
        let dst = &mut columns[row * n_cols..(row + 1) * n_cols];
        for oh in 0..h_out {
            for ow in 0..w_out {
                let ih = (oh * sh + ki) as isize - ph as isize;
                let iw = (ow * sw + kj) as isize - pw as isize;
                dst[oh * w_out + ow] =
                    if (0..h as isize).contains(&ih) && (0..w as isize).contains(&iw) {
                        input[(ci * h + ih as usize) * w + iw as usize]
                    } else {
                        0.0
                    };
            }
        }
    }
}
2474
/// col2im: Scatter columns back into an image buffer (for backward).
///
/// Accumulates into `output` (which should be zeroed before calling).
/// columns: `[C_in * kH * kW, H_out * W_out]`
/// output: `[C_in, H, W]`
///
/// Exact inverse traversal of `im2col`: overlapping windows sum their
/// contributions; taps that map into the padding region are dropped.
#[inline]
#[allow(clippy::too_many_arguments)]
pub(crate) fn col2im(
    columns: &[f64],
    c_in: usize,
    h: usize,
    w: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    h_out: usize,
    w_out: usize,
    output: &mut [f64],
) {
    let n_cols = h_out * w_out;
    for row in 0..c_in * kh * kw {
        // row = (ci * kh + ki) * kw + kj — recover the three indices.
        let kj = row % kw;
        let ki = (row / kw) % kh;
        let ci = row / (kw * kh);
        let src = &columns[row * n_cols..(row + 1) * n_cols];
        for oh in 0..h_out {
            for ow in 0..w_out {
                let ih = (oh * sh + ki) as isize - ph as isize;
                let iw = (ow * sw + kj) as isize - pw as isize;
                if (0..h as isize).contains(&ih) && (0..w as isize).contains(&iw) {
                    output[(ci * h + ih as usize) * w + iw as usize] += src[oh * w_out + ow];
                }
            }
        }
    }
}
2517
/// Simple GEMM: C = A × B
///
/// A: [m, k], B: [k, n], C: [m, n]
/// All row-major. Accumulates into C (caller zeroes it for a plain product).
#[inline]
pub(crate) fn gemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    for i in 0..m {
        let c_row = &mut c[i * n..i * n + n];
        for p in 0..k {
            // Scale B's row p by A[i,p] and add it into C's row i.
            let a_ip = a[i * k + p];
            let b_row = &b[p * n..p * n + n];
            for (c_elem, b_elem) in c_row.iter_mut().zip(b_row.iter()) {
                *c_elem += a_ip * b_elem;
            }
        }
    }
}
2536
/// GEMM: C = A^T × B
///
/// A: [k, m] (transposed to [m, k]), B: [k, n], C: [m, n]
/// Accumulates into C (caller zeroes it for a plain product).
#[inline]
pub(crate) fn gemm_at_b(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
    for i in 0..m {
        let c_row = &mut c[i * n..i * n + n];
        for p in 0..k {
            let a_ti = a[p * m + i]; // A^T[i,p] lives at A[p,i] in row-major
            let b_row = &b[p * n..p * n + n];
            for (c_elem, b_elem) in c_row.iter_mut().zip(b_row.iter()) {
                *c_elem += a_ti * b_elem;
            }
        }
    }
}
2553
2554/// GEMM: C = A × B^T
2555///
2556/// A: [m, k], B: [n, k] (transposed to [k, n]), C: [m, n]
2557#[inline]
2558pub(crate) fn gemm_a_bt(a: &[f64], b: &[f64], c: &mut [f64], m: usize, n: usize, k: usize) {
2559    for i in 0..m {
2560        let a_row = i * k;
2561        let c_row = i * n;
2562        for j in 0..n {
2563            let b_row = j * k;
2564            let mut val = 0.0f64;
2565            for p in 0..k {
2566                val += a[a_row + p] * b[b_row + p];
2567            }
2568            c[c_row + j] += val;
2569        }
2570    }
2571}