Skip to main content

torsh_python/tensor/
core.rs

1//! Core tensor implementation - PyTensor struct and fundamental operations
2
3use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
4// Note: For Python bindings, we use numpy's PyArray API directly
5// which handles type conversions internally (numpy uses ndarray 0.15.x)
6use numpy::{
7    IxDyn as NpIxDyn, PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods,
8};
9use pyo3::prelude::*;
10use pyo3::types::{PyAny, PyList};
11use torsh_core::device::DeviceType;
12use torsh_tensor::Tensor;
13
14/// Python wrapper for ToRSh Tensor (simplified version)
15#[pyclass(name = "Tensor")]
16#[derive(Clone)]
17pub struct PyTensor {
18    pub(crate) tensor: Tensor<f32>, // For now, default to f32
19}
20
21#[pymethods]
22impl PyTensor {
23    #[new]
24    pub fn new(
25        data: &Bound<'_, PyAny>,
26        _dtype: Option<PyDType>,
27        device: Option<PyDevice>,
28        requires_grad: Option<bool>,
29    ) -> PyResult<Self> {
30        let device = device.map(|d| d.device).unwrap_or(DeviceType::Cpu);
31        let requires_grad = requires_grad.unwrap_or(false);
32
33        // Convert Python data to Rust tensor
34        let tensor = if let Ok(arr) = data.clone().cast_into::<PyArray1<f32>>() {
35            // 1D NumPy array
36            let data = arr.to_vec()?;
37            let shape = vec![data.len()];
38            py_result!(Tensor::from_data(data, shape, device))?
39        } else if let Ok(arr) = data.clone().cast_into::<PyArray2<f32>>() {
40            // 2D NumPy array
41            let data = arr.to_vec()?;
42            let shape = arr.shape().to_vec();
43            py_result!(Tensor::from_data(data, shape, device))?
44        } else if let Ok(arr) = data.clone().cast_into::<PyArrayDyn<f32>>() {
45            // N-D NumPy array
46            let data = arr.to_vec()?;
47            let shape = arr.shape().to_vec();
48            py_result!(Tensor::from_data(data, shape, device))?
49        } else if let Ok(list) = data.clone().cast_into::<PyList>() {
50            // Python list - simplified version
51            Self::from_py_list(&list, device)?
52        } else if let Ok(scalar) = data.extract::<f32>() {
53            // Scalar value
54            py_result!(Tensor::from_data(vec![scalar], vec![], device))?
55        } else {
56            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
57                "Unsupported data type for tensor creation",
58            ));
59        };
60
61        let tensor = tensor.requires_grad_(requires_grad);
62
63        Ok(Self { tensor })
64    }
65
66    // Basic properties
67    #[getter]
68    fn shape(&self) -> Vec<usize> {
69        self.tensor.shape().dims().to_vec()
70    }
71
72    #[getter]
73    fn ndim(&self) -> usize {
74        self.tensor.ndim()
75    }
76
77    #[getter]
78    pub fn numel(&self) -> usize {
79        self.tensor.numel()
80    }
81
82    #[getter]
83    fn dtype(&self) -> PyDType {
84        PyDType::from(self.tensor.dtype())
85    }
86
87    #[getter]
88    fn device(&self) -> PyDevice {
89        PyDevice::from(self.tensor.device())
90    }
91
92    #[getter]
93    fn requires_grad(&self) -> bool {
94        self.tensor.requires_grad()
95    }
96
97    // String representation
98    fn __repr__(&self) -> String {
99        let binding = self.tensor.shape();
100        let shape = binding.dims();
101        let device_str = match self.tensor.device() {
102            DeviceType::Cpu => String::new(),
103            dev => format!(", device='{}'", PyDevice::from(dev)),
104        };
105        let grad_str = if self.tensor.requires_grad() {
106            ", requires_grad=True"
107        } else {
108            ""
109        };
110
111        format!(
112            "tensor({:?}, shape={}{}{}, dtype={})",
113            // For now, just show shape info instead of actual data
114            shape,
115            shape
116                .iter()
117                .map(|d| d.to_string())
118                .collect::<Vec<_>>()
119                .join(", "),
120            device_str,
121            grad_str,
122            PyDType::from(self.tensor.dtype())
123        )
124    }
125
126    fn __str__(&self) -> String {
127        self.__repr__()
128    }
129
130    /// Convert tensor to NumPy array with proper shape preservation
131    fn numpy(&self) -> PyResult<Py<PyAny>> {
132        Python::attach(|py| {
133            let data = py_result!(self.tensor.to_vec())?;
134            let binding = self.tensor.shape();
135            let shape = binding.dims();
136
137            // Convert shape from usize to Vec for reshape
138            let shape_vec: Vec<usize> = shape.iter().copied().collect();
139
140            match shape.len() {
141                0 => {
142                    // Scalar
143                    Ok(data[0].into_pyobject(py)?.into_any().unbind())
144                }
145                1 => {
146                    // 1D array - direct creation
147                    let array = PyArray1::from_vec(py, data);
148                    Ok(array.into_pyobject(py)?.into_any().unbind())
149                }
150                2 => {
151                    // 2D array - create 1D and reshape
152                    let array_1d = PyArray1::from_vec(py, data);
153                    let array_2d = array_1d
154                        .reshape([shape_vec[0], shape_vec[1]])
155                        .map_err(|e| {
156                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
157                                "Reshape error: {}",
158                                e
159                            ))
160                        })?;
161                    Ok(array_2d.into_pyobject(py)?.into_any().unbind())
162                }
163                _ => {
164                    // N-D array - create 1D and reshape to dynamic dimension
165                    let array_1d = PyArray1::from_vec(py, data);
166                    let array_nd = array_1d.reshape(NpIxDyn(&shape_vec)).map_err(|e| {
167                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
168                            "Reshape error: {}",
169                            e
170                        ))
171                    })?;
172                    Ok(array_nd.into_pyobject(py)?.into_any().unbind())
173                }
174            }
175        })
176    }
177
178    /// Extract single scalar value from tensor
179    fn item(&self) -> PyResult<f32> {
180        if self.tensor.numel() != 1 {
181            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
182                "Only one element tensors can be converted to Python scalars",
183            ));
184        }
185        let data = py_result!(self.tensor.to_vec())?;
186        Ok(data[0])
187    }
188
189    /// Convert tensor to nested Python lists
190    fn tolist(&self) -> PyResult<Py<PyAny>> {
191        Python::attach(|py| {
192            let data = py_result!(self.tensor.to_vec())?;
193            let binding = self.tensor.shape();
194            let shape = binding.dims();
195
196            if shape.is_empty() {
197                // Scalar
198                Ok(data[0].into_pyobject(py)?.into_any().unbind())
199            } else {
200                // Create nested lists based on tensor shape
201                let nested_list = self.create_nested_list(py, &data, shape, 0, &mut 0)?;
202                Ok(nested_list)
203            }
204        })
205    }
206
207    /// Create a copy of tensor on specified device (NumPy-compatible method)
208    fn copy(&self) -> PyResult<PyTensor> {
209        Ok(PyTensor {
210            tensor: self.tensor.clone(),
211        })
212    }
213
214    /// Get tensor stride information (NumPy-compatible)
215    fn stride(&self) -> Vec<usize> {
216        // For now, return a simple stride calculation
217        let binding = self.tensor.shape();
218        let shape = binding.dims();
219        let mut strides = vec![1; shape.len()];
220        for i in (0..shape.len().saturating_sub(1)).rev() {
221            strides[i] = strides[i + 1] * shape[i + 1];
222        }
223        strides
224    }
225
226    /// Get number of bytes per element (NumPy-compatible)
227    fn itemsize(&self) -> usize {
228        std::mem::size_of::<f32>()
229    }
230
231    /// Get total number of bytes (NumPy-compatible)
232    fn nbytes(&self) -> usize {
233        self.tensor.numel() * self.itemsize()
234    }
235
236    /// Check if tensor data is C-contiguous (NumPy-compatible)
237    fn is_c_contiguous(&self) -> bool {
238        self.tensor.is_contiguous()
239    }
240
241    // ===============================
242    // Mathematical Operations
243    // ===============================
244
245    /// Add operation (tensor + other)
246    fn add(&self, other: &PyTensor) -> PyResult<PyTensor> {
247        let result = py_result!(self.tensor.add(&other.tensor))?;
248        Ok(PyTensor { tensor: result })
249    }
250
251    /// Subtract operation (tensor - other)
252    fn sub(&self, other: &PyTensor) -> PyResult<PyTensor> {
253        let result = py_result!(self.tensor.sub(&other.tensor))?;
254        Ok(PyTensor { tensor: result })
255    }
256
257    /// Multiply operation (tensor * other)
258    fn mul(&self, other: &PyTensor) -> PyResult<PyTensor> {
259        let result = py_result!(self.tensor.mul(&other.tensor))?;
260        Ok(PyTensor { tensor: result })
261    }
262
263    /// Divide operation (tensor / other)
264    fn div(&self, other: &PyTensor) -> PyResult<PyTensor> {
265        let result = py_result!(self.tensor.div(&other.tensor))?;
266        Ok(PyTensor { tensor: result })
267    }
268
269    /// Power operation (tensor ** exponent)
270    fn pow(&self, exponent: f32) -> PyResult<PyTensor> {
271        let result = py_result!(self.tensor.pow(exponent))?;
272        Ok(PyTensor { tensor: result })
273    }
274
275    /// Scalar addition
276    fn add_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
277        let result = py_result!(self.tensor.add_scalar(scalar))?;
278        Ok(PyTensor { tensor: result })
279    }
280
281    /// Scalar multiplication
282    fn mul_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
283        let result = py_result!(self.tensor.mul_scalar(scalar))?;
284        Ok(PyTensor { tensor: result })
285    }
286
287    // ===============================
288    // Tensor Manipulation Operations
289    // ===============================
290
291    /// Reshape tensor to new shape
292    fn reshape(&self, shape: Vec<i64>) -> PyResult<PyTensor> {
293        let i32_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
294        let result = py_result!(self.tensor.reshape(&i32_shape))?;
295        Ok(PyTensor { tensor: result })
296    }
297
298    /// Transpose tensor (swap dimensions)
299    fn transpose(&self, dim0: i64, dim1: i64) -> PyResult<PyTensor> {
300        let result = py_result!(self.tensor.transpose(dim0 as i32, dim1 as i32))?;
301        Ok(PyTensor { tensor: result })
302    }
303
304    /// Transpose 2D tensor
305    fn t(&self) -> PyResult<PyTensor> {
306        if self.tensor.ndim() != 2 {
307            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
308                "t() can only be called on 2D tensors",
309            ));
310        }
311        self.transpose(0, 1)
312    }
313
314    /// Squeeze tensor (remove dimensions of size 1)
315    fn squeeze(&self, dim: Option<i64>) -> PyResult<PyTensor> {
316        let dim_to_squeeze = dim.map(|d| d as i32).unwrap_or(0i32);
317        let result = py_result!(self.tensor.squeeze(dim_to_squeeze))?;
318        Ok(PyTensor { tensor: result })
319    }
320
321    /// Unsqueeze tensor (add dimension of size 1)
322    fn unsqueeze(&self, dim: i64) -> PyResult<PyTensor> {
323        let result = py_result!(self.tensor.unsqueeze(dim as i32))?;
324        Ok(PyTensor { tensor: result })
325    }
326
327    /// Flatten tensor
328    fn flatten(&self, _start_dim: Option<i64>, _end_dim: Option<i64>) -> PyResult<PyTensor> {
329        // For now, use basic flatten - may need different implementation
330        let result = py_result!(self.tensor.flatten())?;
331        Ok(PyTensor { tensor: result })
332    }
333
334    // ===============================
335    // Reduction Operations
336    // ===============================
337
338    /// Sum along specified dimensions
339    fn sum(&self, _dim: Option<Vec<i64>>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
340        // For now, use basic sum - may need different implementation
341        let result = py_result!(self.tensor.sum())?;
342        Ok(PyTensor { tensor: result })
343    }
344
345    /// Mean along specified dimensions
346    fn mean(&self, dim: Option<Vec<i64>>, keepdim: Option<bool>) -> PyResult<PyTensor> {
347        let keepdim = keepdim.unwrap_or(false);
348        let result = if let Some(dims) = dim {
349            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
350            py_result!(self.tensor.mean(Some(&usize_dims), keepdim))?
351        } else {
352            py_result!(self.tensor.mean(None, keepdim))?
353        };
354        Ok(PyTensor { tensor: result })
355    }
356
357    /// Maximum along specified dimensions
358    fn max(&self, dim: Option<i64>, keepdim: Option<bool>) -> PyResult<PyTensor> {
359        let dim_opt = dim.map(|d| d as usize);
360        let keepdim = keepdim.unwrap_or(false);
361        let result = py_result!(self.tensor.max(dim_opt, keepdim))?;
362        Ok(PyTensor { tensor: result })
363    }
364
365    /// Minimum along specified dimensions
366    fn min(&self, _dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
367        // For now, use basic min - may need different implementation
368        let result = py_result!(self.tensor.min())?;
369        Ok(PyTensor { tensor: result })
370    }
371
372    // ===============================
373    // Linear Algebra Operations
374    // ===============================
375
376    /// Matrix multiplication
377    fn matmul(&self, other: &PyTensor) -> PyResult<PyTensor> {
378        let result = py_result!(self.tensor.matmul(&other.tensor))?;
379        Ok(PyTensor { tensor: result })
380    }
381
382    /// Dot product (1D tensors)
383    fn dot(&self, other: &PyTensor) -> PyResult<PyTensor> {
384        let result = py_result!(self.tensor.dot(&other.tensor))?;
385        Ok(PyTensor { tensor: result })
386    }
387
388    // ===============================
389    // Activation Functions
390    // ===============================
391
392    /// ReLU activation function
393    fn relu(&self) -> PyResult<PyTensor> {
394        let result = py_result!(self.tensor.relu())?;
395        Ok(PyTensor { tensor: result })
396    }
397
398    /// Sigmoid activation function
399    fn sigmoid(&self) -> PyResult<PyTensor> {
400        let result = py_result!(self.tensor.sigmoid())?;
401        Ok(PyTensor { tensor: result })
402    }
403
404    /// Tanh activation function
405    fn tanh(&self) -> PyResult<PyTensor> {
406        let result = py_result!(self.tensor.tanh())?;
407        Ok(PyTensor { tensor: result })
408    }
409
410    /// Softmax along specified dimension
411    fn softmax(&self, dim: i64) -> PyResult<PyTensor> {
412        let result = py_result!(self.tensor.softmax(dim as i32))?;
413        Ok(PyTensor { tensor: result })
414    }
415
416    // ===============================
417    // Trigonometric Functions
418    // ===============================
419
420    /// Sine function
421    fn sin(&self) -> PyResult<PyTensor> {
422        let result = py_result!(self.tensor.sin())?;
423        Ok(PyTensor { tensor: result })
424    }
425
426    /// Cosine function
427    fn cos(&self) -> PyResult<PyTensor> {
428        let result = py_result!(self.tensor.cos())?;
429        Ok(PyTensor { tensor: result })
430    }
431
432    /// Exponential function
433    fn exp(&self) -> PyResult<PyTensor> {
434        let result = py_result!(self.tensor.exp())?;
435        Ok(PyTensor { tensor: result })
436    }
437
438    /// Natural logarithm
439    fn log(&self) -> PyResult<PyTensor> {
440        let result = py_result!(self.tensor.log())?;
441        Ok(PyTensor { tensor: result })
442    }
443
444    /// Square root
445    fn sqrt(&self) -> PyResult<PyTensor> {
446        let result = py_result!(self.tensor.sqrt())?;
447        Ok(PyTensor { tensor: result })
448    }
449
450    /// Absolute value
451    fn abs(&self) -> PyResult<PyTensor> {
452        let result = py_result!(self.tensor.abs())?;
453        Ok(PyTensor { tensor: result })
454    }
455
456    // ===============================
457    // Comparison Operations
458    // ===============================
459
460    /// Element-wise equality (returns f32 tensor with 0.0/1.0 values)
461    fn eq(&self, other: &PyTensor) -> PyResult<PyTensor> {
462        let bool_result = py_result!(self.tensor.eq(&other.tensor))?;
463        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
464        Ok(PyTensor {
465            tensor: float_tensor,
466        })
467    }
468
469    /// Element-wise inequality (returns f32 tensor with 0.0/1.0 values)
470    fn ne(&self, other: &PyTensor) -> PyResult<PyTensor> {
471        let bool_result = py_result!(self.tensor.ne(&other.tensor))?;
472        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
473        Ok(PyTensor {
474            tensor: float_tensor,
475        })
476    }
477
478    /// Element-wise less than (returns f32 tensor with 0.0/1.0 values)
479    fn lt(&self, other: &PyTensor) -> PyResult<PyTensor> {
480        let bool_result = py_result!(self.tensor.lt(&other.tensor))?;
481        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
482        Ok(PyTensor {
483            tensor: float_tensor,
484        })
485    }
486
487    /// Element-wise greater than (returns f32 tensor with 0.0/1.0 values)
488    fn gt(&self, other: &PyTensor) -> PyResult<PyTensor> {
489        let bool_result = py_result!(self.tensor.gt(&other.tensor))?;
490        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
491        Ok(PyTensor {
492            tensor: float_tensor,
493        })
494    }
495
496    /// Element-wise less than or equal (returns f32 tensor with 0.0/1.0 values)
497    fn le(&self, other: &PyTensor) -> PyResult<PyTensor> {
498        let bool_result = py_result!(self.tensor.le(&other.tensor))?;
499        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
500        Ok(PyTensor {
501            tensor: float_tensor,
502        })
503    }
504
505    /// Element-wise greater than or equal (returns f32 tensor with 0.0/1.0 values)
506    fn ge(&self, other: &PyTensor) -> PyResult<PyTensor> {
507        let bool_result = py_result!(self.tensor.ge(&other.tensor))?;
508        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
509        Ok(PyTensor {
510            tensor: float_tensor,
511        })
512    }
513
514    /// Element-wise maximum
515    fn maximum(&self, other: &PyTensor) -> PyResult<PyTensor> {
516        let result = py_result!(self.tensor.maximum(&other.tensor))?;
517        Ok(PyTensor { tensor: result })
518    }
519
520    /// Element-wise minimum
521    fn minimum(&self, other: &PyTensor) -> PyResult<PyTensor> {
522        let result = py_result!(self.tensor.minimum(&other.tensor))?;
523        Ok(PyTensor { tensor: result })
524    }
525
526    // Scalar comparison operations
527
528    /// Scalar equality (returns f32 tensor with 0.0/1.0 values)
529    fn eq_scalar(&self, value: f32) -> PyResult<PyTensor> {
530        let bool_result = py_result!(self.tensor.eq_scalar(value))?;
531        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
532        Ok(PyTensor {
533            tensor: float_tensor,
534        })
535    }
536
537    /// Scalar inequality (returns f32 tensor with 0.0/1.0 values)
538    fn ne_scalar(&self, value: f32) -> PyResult<PyTensor> {
539        let bool_result = py_result!(self.tensor.ne_scalar(value))?;
540        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
541        Ok(PyTensor {
542            tensor: float_tensor,
543        })
544    }
545
546    /// Scalar less than (returns f32 tensor with 0.0/1.0 values)
547    fn lt_scalar(&self, value: f32) -> PyResult<PyTensor> {
548        let bool_result = py_result!(self.tensor.lt_scalar(value))?;
549        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
550        Ok(PyTensor {
551            tensor: float_tensor,
552        })
553    }
554
555    /// Scalar greater than (returns f32 tensor with 0.0/1.0 values)
556    fn gt_scalar(&self, value: f32) -> PyResult<PyTensor> {
557        let bool_result = py_result!(self.tensor.gt_scalar(value))?;
558        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
559        Ok(PyTensor {
560            tensor: float_tensor,
561        })
562    }
563
564    /// Scalar less than or equal (returns f32 tensor with 0.0/1.0 values)
565    fn le_scalar(&self, value: f32) -> PyResult<PyTensor> {
566        let bool_result = py_result!(self.tensor.le_scalar(value))?;
567        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
568        Ok(PyTensor {
569            tensor: float_tensor,
570        })
571    }
572
573    /// Scalar greater than or equal (returns f32 tensor with 0.0/1.0 values)
574    fn ge_scalar(&self, value: f32) -> PyResult<PyTensor> {
575        let bool_result = py_result!(self.tensor.ge_scalar(value))?;
576        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
577        Ok(PyTensor {
578            tensor: float_tensor,
579        })
580    }
581
582    // ===============================
583    // Utility Methods
584    // ===============================
585
586    /// Create a copy of the tensor
587    fn clone_tensor(&self) -> PyResult<PyTensor> {
588        Ok(PyTensor {
589            tensor: self.tensor.clone(),
590        })
591    }
592
593    /// Create a detached copy (no gradients)
594    fn detach(&self) -> PyResult<PyTensor> {
595        Ok(PyTensor {
596            tensor: self.tensor.detach(),
597        })
598    }
599
600    /// Move tensor to specified device
601    fn to_device(&self, device: PyDevice) -> PyResult<PyTensor> {
602        let result = py_result!(self.tensor.clone().to(device.device))?;
603        Ok(PyTensor { tensor: result })
604    }
605
606    /// Check if tensor is contiguous in memory
607    fn is_contiguous(&self) -> bool {
608        self.tensor.is_contiguous()
609    }
610
611    /// Make tensor contiguous in memory
612    fn contiguous(&self) -> PyResult<PyTensor> {
613        let result = py_result!(self.tensor.contiguous())?;
614        Ok(PyTensor { tensor: result })
615    }
616
617    // ===============================
618    // Additional PyTorch-Compatible Operations
619    // ===============================
620
621    /// Clamp tensor values to specified range
622    fn clamp(&self, min: Option<f32>, max: Option<f32>) -> PyResult<PyTensor> {
623        let result = if let (Some(min_val), Some(max_val)) = (min, max) {
624            py_result!(self.tensor.clamp(min_val, max_val))?
625        } else if let Some(min_val) = min {
626            // Use clamp with a very large max value for min-only clamping
627            py_result!(self.tensor.clamp(min_val, f32::MAX))?
628        } else if let Some(max_val) = max {
629            // Use clamp with a very small min value for max-only clamping
630            py_result!(self.tensor.clamp(f32::MIN, max_val))?
631        } else {
632            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
633                "At least one of min or max must be specified",
634            ));
635        };
636        Ok(PyTensor { tensor: result })
637    }
638
639    /// Fill tensor with specified value (in-place operation)
640    fn fill_(&mut self, value: f32) -> PyResult<PyTensor> {
641        py_result!(self.tensor.fill_(value))?;
642        Ok(PyTensor {
643            tensor: self.tensor.clone(),
644        })
645    }
646
647    /// Zero out tensor (in-place operation)
648    fn zero_(&mut self) -> PyResult<PyTensor> {
649        py_result!(self.tensor.zero_())?;
650        Ok(PyTensor {
651            tensor: self.tensor.clone(),
652        })
653    }
654
655    /// Apply uniform random initialization
656    fn uniform_(&mut self, from: Option<f32>, to: Option<f32>) -> PyResult<PyTensor> {
657        // ✅ SciRS2 POLICY: Use scirs2_core::random for RNG
658        use scirs2_core::random::{thread_rng, Distribution, Uniform};
659
660        let from = from.unwrap_or(0.0);
661        let to = to.unwrap_or(1.0);
662
663        let mut rng = thread_rng();
664        let dist = Uniform::new(from, to).map_err(|e| {
665            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
666                "Invalid uniform distribution parameters: {}",
667                e
668            ))
669        })?;
670
671        let mut data = py_result!(self.tensor.data())?;
672        for val in data.iter_mut() {
673            *val = dist.sample(&mut rng);
674        }
675
676        let shape = self.tensor.shape().dims().to_vec();
677        let device = self.tensor.device();
678        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
679            .requires_grad_(self.tensor.requires_grad());
680
681        Ok(PyTensor {
682            tensor: self.tensor.clone(),
683        })
684    }
685
686    /// Apply normal random initialization
687    fn normal_(&mut self, mean: Option<f32>, std: Option<f32>) -> PyResult<PyTensor> {
688        // ✅ SciRS2 POLICY: Use scirs2_core::random for RNG
689        use scirs2_core::random::{thread_rng, Distribution, Normal};
690
691        let mean = mean.unwrap_or(0.0);
692        let std = std.unwrap_or(1.0);
693
694        let normal = Normal::new(mean as f64, std as f64).map_err(|e| {
695            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
696                "Invalid normal distribution parameters: {}",
697                e
698            ))
699        })?;
700
701        let mut rng = thread_rng();
702        let mut data = py_result!(self.tensor.data())?;
703        for val in data.iter_mut() {
704            *val = normal.sample(&mut rng) as f32;
705        }
706
707        let shape = self.tensor.shape().dims().to_vec();
708        let device = self.tensor.device();
709        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
710            .requires_grad_(self.tensor.requires_grad());
711
712        Ok(PyTensor {
713            tensor: self.tensor.clone(),
714        })
715    }
716
717    /// Repeat tensor along specified dimensions
718    fn repeat(&self, repeats: Vec<usize>) -> PyResult<PyTensor> {
719        let result = py_result!(self.tensor.repeat(&repeats))?;
720        Ok(PyTensor { tensor: result })
721    }
722
723    /// Expand tensor to specified shape (broadcasting)
724    fn expand(&self, size: Vec<i64>) -> PyResult<PyTensor> {
725        let size: Vec<usize> = size.iter().map(|&x| x as usize).collect();
726        let result = py_result!(self.tensor.expand(&size))?;
727        Ok(PyTensor { tensor: result })
728    }
729
730    /// Expand tensor to match another tensor's shape
731    fn expand_as(&self, other: &PyTensor) -> PyResult<PyTensor> {
732        let other_shape: Vec<usize> = other.tensor.shape().dims().iter().map(|&x| x).collect();
733        let result = py_result!(self.tensor.expand(&other_shape))?;
734        Ok(PyTensor { tensor: result })
735    }
736
737    /// Select elements from tensor along specified dimension
738    fn index_select(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
739        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
740        let index_i64 = py_result!(index_i32.to_i64_simd())?;
741        let result = py_result!(self.tensor.index_select(dim as i32, &index_i64))?;
742        Ok(PyTensor { tensor: result })
743    }
744
745    /// Gather elements from tensor
746    fn gather(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
747        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
748        let index_i64 = py_result!(index_i32.to_i64_simd())?;
749        let result = py_result!(self.tensor.gather(dim as usize, &index_i64))?;
750        Ok(PyTensor { tensor: result })
751    }
752
753    /// Scatter elements into tensor
754    fn scatter(&self, dim: i64, index: &PyTensor, src: &PyTensor) -> PyResult<PyTensor> {
755        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
756        let index_i64 = py_result!(index_i32.to_i64_simd())?;
757        let result = py_result!(self.tensor.scatter(dim as usize, &index_i64, &src.tensor))?;
758        Ok(PyTensor { tensor: result })
759    }
760
761    /// Masked fill operation
762    fn masked_fill(&self, mask: &PyTensor, value: f32) -> PyResult<PyTensor> {
763        // ✅ Proper masked fill implementation
764        let mut data = py_result!(self.tensor.data())?;
765        let mask_data = py_result!(mask.tensor.data())?;
766
767        if data.len() != mask_data.len() {
768            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
769                "Mask must have the same number of elements as tensor",
770            ));
771        }
772
773        for (val, &mask_val) in data.iter_mut().zip(mask_data.iter()) {
774            if mask_val != 0.0 {
775                *val = value;
776            }
777        }
778
779        let shape = self.tensor.shape().dims().to_vec();
780        let device = self.tensor.device();
781        let result = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
782            .requires_grad_(self.tensor.requires_grad());
783
784        Ok(PyTensor { tensor: result })
785    }
786
787    /// Masked select operation
788    fn masked_select(&self, mask: &PyTensor) -> PyResult<PyTensor> {
789        // Convert f32 mask to bool by treating non-zero as true
790        let mask_data = py_result!(mask.tensor.data())?;
791        let bool_data: Vec<bool> = mask_data.iter().map(|&x| x != 0.0).collect();
792        let mask_bool = py_result!(torsh_tensor::Tensor::from_data(
793            bool_data,
794            mask.tensor.shape().dims().to_vec(),
795            mask.tensor.device()
796        ))?;
797        let result = py_result!(self.tensor.masked_select(&mask_bool))?;
798        Ok(PyTensor { tensor: result })
799    }
800
801    /// Concatenate tensors along specified dimension
802    fn cat(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
803        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
804            tensors.iter().map(|t| &t.tensor).collect();
805        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
806        Ok(PyTensor { tensor: result })
807    }
808
809    /// Stack tensors along new dimension
810    fn stack(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
811        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
812            tensors.iter().map(|t| &t.tensor).collect();
813        // Use cat as a simplified stacking implementation
814        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
815        Ok(PyTensor { tensor: result })
816    }
817
818    /// Split tensor into chunks
819    fn chunk(&self, chunks: usize, dim: i64) -> PyResult<Vec<PyTensor>> {
820        // ✅ Proper chunk implementation
821        let shape = self.tensor.shape().dims().to_vec();
822        let dim_usize = dim as usize;
823
824        if dim_usize >= shape.len() {
825            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
826                "Dimension {} out of range for tensor with {} dimensions",
827                dim,
828                shape.len()
829            )));
830        }
831
832        let dim_size = shape[dim_usize];
833        let chunk_size = (dim_size + chunks - 1) / chunks; // Ceiling division
834
835        let mut result = Vec::new();
836        for i in 0..chunks {
837            let start = i * chunk_size;
838            if start >= dim_size {
839                break;
840            }
841            let end = std::cmp::min(start + chunk_size, dim_size);
842
843            // Use narrow to extract chunk
844            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, end - start))?;
845            result.push(PyTensor { tensor: chunk });
846        }
847
848        Ok(result)
849    }
850
851    /// Split tensor at specified sizes
852    fn split(&self, split_sizes: Vec<usize>, dim: i64) -> PyResult<Vec<PyTensor>> {
853        // ✅ Proper split implementation
854        let shape = self.tensor.shape().dims().to_vec();
855        let dim_usize = dim as usize;
856
857        if dim_usize >= shape.len() {
858            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
859                "Dimension {} out of range for tensor with {} dimensions",
860                dim,
861                shape.len()
862            )));
863        }
864
865        let total_size: usize = split_sizes.iter().sum();
866        if total_size != shape[dim_usize] {
867            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
868                "Split sizes sum to {} but dimension {} has size {}",
869                total_size, dim, shape[dim_usize]
870            )));
871        }
872
873        let mut result = Vec::new();
874        let mut start = 0;
875
876        for &size in &split_sizes {
877            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, size))?;
878            result.push(PyTensor { tensor: chunk });
879            start += size;
880        }
881
882        Ok(result)
883    }
884
885    /// Permute tensor dimensions
886    fn permute(&self, dims: Vec<i64>) -> PyResult<PyTensor> {
887        let dims: Vec<i32> = dims.iter().map(|&x| x as i32).collect();
888        let result = py_result!(self.tensor.permute(&dims))?;
889        Ok(PyTensor { tensor: result })
890    }
891
892    /// Get diagonal elements
893    fn diag(&self, diagonal: Option<i64>) -> PyResult<PyTensor> {
894        // ✅ Proper diagonal extraction implementation
895
896        let offset = diagonal.unwrap_or(0);
897        let shape = self.tensor.shape().dims().to_vec();
898
899        if shape.len() == 1 {
900            // 1D tensor -> create diagonal matrix
901            let n = shape[0];
902            let size = n + offset.abs() as usize;
903            let mut data = vec![0.0; size * size];
904            let input_data = py_result!(self.tensor.data())?;
905
906            for (i, &val) in input_data.iter().enumerate() {
907                if offset >= 0 {
908                    let row = i;
909                    let col = i + offset as usize;
910                    data[row * size + col] = val;
911                } else {
912                    let row = i + (-offset) as usize;
913                    let col = i;
914                    data[row * size + col] = val;
915                }
916            }
917
918            let result = py_result!(torsh_tensor::Tensor::from_data(
919                data,
920                vec![size, size],
921                self.tensor.device()
922            ))?;
923            Ok(PyTensor { tensor: result })
924        } else if shape.len() == 2 {
925            // 2D tensor -> extract diagonal
926            let rows = shape[0];
927            let cols = shape[1];
928            let data = py_result!(self.tensor.data())?;
929
930            let mut diag_data = Vec::new();
931
932            if offset >= 0 {
933                let offset_u = offset as usize;
934                for i in 0..std::cmp::min(rows, cols.saturating_sub(offset_u)) {
935                    diag_data.push(data[i * cols + i + offset_u]);
936                }
937            } else {
938                let offset_u = (-offset) as usize;
939                for i in 0..std::cmp::min(rows.saturating_sub(offset_u), cols) {
940                    diag_data.push(data[(i + offset_u) * cols + i]);
941                }
942            }
943
944            let diag_len = diag_data.len();
945            let result = py_result!(torsh_tensor::Tensor::from_data(
946                diag_data,
947                vec![diag_len],
948                self.tensor.device()
949            ))?;
950            Ok(PyTensor { tensor: result })
951        } else {
952            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
953                "diag() only supports 1D or 2D tensors",
954            ))
955        }
956    }
957
958    /// Trace of matrix
959    fn trace(&self) -> PyResult<PyTensor> {
960        // ✅ Proper trace computation
961        let shape = self.tensor.shape().dims().to_vec();
962
963        if shape.len() != 2 {
964            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
965                "trace() requires a 2D tensor",
966            ));
967        }
968
969        let rows = shape[0];
970        let cols = shape[1];
971        let min_dim = std::cmp::min(rows, cols);
972
973        let data = py_result!(self.tensor.data())?;
974        let mut trace_sum = 0.0;
975
976        for i in 0..min_dim {
977            trace_sum += data[i * cols + i];
978        }
979
980        let result = py_result!(torsh_tensor::Tensor::from_data(
981            vec![trace_sum],
982            vec![],
983            self.tensor.device()
984        ))?;
985
986        Ok(PyTensor { tensor: result })
987    }
988
989    /// Norm calculation
990    fn norm(
991        &self,
992        p: Option<f32>,
993        _dim: Option<Vec<i64>>,
994        _keepdim: Option<bool>,
995    ) -> PyResult<PyTensor> {
996        let _p = p.unwrap_or(2.0);
997        // For now, use simple L2 norm regardless of parameters
998        // TODO: Implement full norm_lp functionality when ops module is exposed
999        let result = py_result!(self.tensor.norm())?;
1000        Ok(PyTensor { tensor: result })
1001    }
1002
1003    /// Standard deviation
1004    fn std(
1005        &self,
1006        dim: Option<Vec<i64>>,
1007        keepdim: Option<bool>,
1008        unbiased: Option<bool>,
1009    ) -> PyResult<PyTensor> {
1010        let keepdim = keepdim.unwrap_or(false);
1011        let unbiased = unbiased.unwrap_or(true);
1012        let stat_mode = if unbiased {
1013            torsh_tensor::stats::StatMode::Sample
1014        } else {
1015            torsh_tensor::stats::StatMode::Population
1016        };
1017        let result = if let Some(dims) = dim {
1018            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
1019            py_result!(self.tensor.std(Some(&usize_dims), keepdim, stat_mode))?
1020        } else {
1021            py_result!(self.tensor.std(None, keepdim, stat_mode))?
1022        };
1023        Ok(PyTensor { tensor: result })
1024    }
1025
1026    /// Variance calculation
1027    fn var(
1028        &self,
1029        dim: Option<Vec<i64>>,
1030        keepdim: Option<bool>,
1031        unbiased: Option<bool>,
1032    ) -> PyResult<PyTensor> {
1033        let keepdim = keepdim.unwrap_or(false);
1034        let unbiased = unbiased.unwrap_or(true);
1035        let stat_mode = if unbiased {
1036            torsh_tensor::stats::StatMode::Sample
1037        } else {
1038            torsh_tensor::stats::StatMode::Population
1039        };
1040        let result = if let Some(dims) = dim {
1041            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
1042            py_result!(self.tensor.var(Some(&usize_dims), keepdim, stat_mode))?
1043        } else {
1044            py_result!(self.tensor.var(None, keepdim, stat_mode))?
1045        };
1046        Ok(PyTensor { tensor: result })
1047    }
1048
1049    /// Argmax operation
1050    fn argmax(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
1051        let result = if let Some(d) = dim {
1052            py_result!(self.tensor.argmax(Some(d as i32)))?
1053        } else {
1054            py_result!(self.tensor.argmax(None))?
1055        };
1056        let result_f32 = py_result!(result.to_f32_simd())?;
1057        Ok(PyTensor { tensor: result_f32 })
1058    }
1059
1060    /// Argmin operation
1061    fn argmin(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
1062        let result = if let Some(d) = dim {
1063            py_result!(self.tensor.argmin(Some(d as i32)))?
1064        } else {
1065            py_result!(self.tensor.argmin(None))?
1066        };
1067        let result_f32 = py_result!(result.to_f32_simd())?;
1068        Ok(PyTensor { tensor: result_f32 })
1069    }
1070}
1071
1072impl PyTensor {
1073    /// Convert from Python list to tensor (simplified)
1074    fn from_py_list(list: &Bound<'_, PyList>, device: DeviceType) -> PyResult<Tensor<f32>> {
1075        let mut data = Vec::new();
1076        let len = list.len();
1077
1078        for i in 0..len {
1079            let item = list.get_item(i)?;
1080            if let Ok(val) = item.extract::<f32>() {
1081                data.push(val);
1082            } else {
1083                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
1084                    "Cannot convert item at index {} to f32",
1085                    i
1086                )));
1087            }
1088        }
1089
1090        py_result!(Tensor::from_data(data, vec![len], device))
1091    }
1092
1093    /// Helper method to create nested Python lists from tensor data
1094    fn create_nested_list(
1095        &self,
1096        py: Python<'_>,
1097        data: &[f32],
1098        shape: &[usize],
1099        dim: usize,
1100        index: &mut usize,
1101    ) -> PyResult<Py<PyAny>> {
1102        if dim == shape.len() - 1 {
1103            // Leaf dimension: create list of values
1104            let mut items = Vec::new();
1105            for _ in 0..shape[dim] {
1106                items.push(data[*index].into_pyobject(py)?.into_any().unbind());
1107                *index += 1;
1108            }
1109            Ok(items.into_pyobject(py)?.into_any().unbind())
1110        } else {
1111            // Intermediate dimension: create list of nested lists
1112            let mut items = Vec::new();
1113            for _ in 0..shape[dim] {
1114                let nested = self.create_nested_list(py, data, shape, dim + 1, index)?;
1115                items.push(nested);
1116            }
1117            Ok(items.into_pyobject(py)?.into_any().unbind())
1118        }
1119    }
1120
1121    /// Convert boolean tensor to float tensor (0.0 for false, 1.0 for true)
1122    fn bool_to_float_tensor(
1123        bool_tensor: torsh_tensor::Tensor<bool>,
1124    ) -> torsh_core::error::Result<torsh_tensor::Tensor<f32>> {
1125        let bool_data = bool_tensor.data()?;
1126        let float_data: Vec<f32> = bool_data
1127            .iter()
1128            .map(|&b| if b { 1.0 } else { 0.0 })
1129            .collect();
1130        let shape = bool_tensor.shape().dims().to_vec();
1131        let device = bool_tensor.device();
1132
1133        torsh_tensor::Tensor::from_data(float_data, shape, device)
1134    }
1135}