torsh_python/tensor/core.rs

1//! Core tensor implementation - PyTensor struct and fundamental operations
2
3use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
4// ✅ SCIRS2 Policy: Use scirs2_core::ndarray instead of direct ndarray
5use numpy::{PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods, ToPyArray};
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyList};
8use scirs2_core::ndarray::{Array, IxDyn};
9use torsh_core::device::DeviceType;
10use torsh_tensor::Tensor;
11
12/// Python wrapper for ToRSh Tensor (simplified version)
13#[pyclass(name = "Tensor")]
14#[derive(Clone)]
15pub struct PyTensor {
16    pub(crate) tensor: Tensor<f32>, // For now, default to f32
17}
18
19#[pymethods]
20impl PyTensor {
21    #[new]
22    pub fn new(
23        data: &Bound<'_, PyAny>,
24        _dtype: Option<PyDType>,
25        device: Option<PyDevice>,
26        requires_grad: Option<bool>,
27    ) -> PyResult<Self> {
28        let device = device.map(|d| d.device).unwrap_or(DeviceType::Cpu);
29        let requires_grad = requires_grad.unwrap_or(false);
30
31        // Convert Python data to Rust tensor
32        let tensor = if let Ok(arr) = data.clone().cast_into::<PyArray1<f32>>() {
33            // 1D NumPy array
34            let data = arr.to_vec()?;
35            let shape = vec![data.len()];
36            py_result!(Tensor::from_data(data, shape, device))?
37        } else if let Ok(arr) = data.clone().cast_into::<PyArray2<f32>>() {
38            // 2D NumPy array
39            let data = arr.to_vec()?;
40            let shape = arr.shape().to_vec();
41            py_result!(Tensor::from_data(data, shape, device))?
42        } else if let Ok(arr) = data.clone().cast_into::<PyArrayDyn<f32>>() {
43            // N-D NumPy array
44            let data = arr.to_vec()?;
45            let shape = arr.shape().to_vec();
46            py_result!(Tensor::from_data(data, shape, device))?
47        } else if let Ok(list) = data.clone().cast_into::<PyList>() {
48            // Python list - simplified version
49            Self::from_py_list(&list, device)?
50        } else if let Ok(scalar) = data.extract::<f32>() {
51            // Scalar value
52            py_result!(Tensor::from_data(vec![scalar], vec![], device))?
53        } else {
54            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
55                "Unsupported data type for tensor creation",
56            ));
57        };
58
59        let tensor = tensor.requires_grad_(requires_grad);
60
61        Ok(Self { tensor })
62    }
63
64    // Basic properties
65    #[getter]
66    fn shape(&self) -> Vec<usize> {
67        self.tensor.shape().dims().to_vec()
68    }
69
70    #[getter]
71    fn ndim(&self) -> usize {
72        self.tensor.ndim()
73    }
74
75    #[getter]
76    pub fn numel(&self) -> usize {
77        self.tensor.numel()
78    }
79
80    #[getter]
81    fn dtype(&self) -> PyDType {
82        PyDType::from(self.tensor.dtype())
83    }
84
85    #[getter]
86    fn device(&self) -> PyDevice {
87        PyDevice::from(self.tensor.device())
88    }
89
90    #[getter]
91    fn requires_grad(&self) -> bool {
92        self.tensor.requires_grad()
93    }
94
95    // String representation
96    fn __repr__(&self) -> String {
97        let binding = self.tensor.shape();
98        let shape = binding.dims();
99        let device_str = match self.tensor.device() {
100            DeviceType::Cpu => String::new(),
101            dev => format!(", device='{}'", PyDevice::from(dev)),
102        };
103        let grad_str = if self.tensor.requires_grad() {
104            ", requires_grad=True"
105        } else {
106            ""
107        };
108
109        format!(
110            "tensor({:?}, shape={}{}{}, dtype={})",
111            // For now, just show shape info instead of actual data
112            shape,
113            shape
114                .iter()
115                .map(|d| d.to_string())
116                .collect::<Vec<_>>()
117                .join(", "),
118            device_str,
119            grad_str,
120            PyDType::from(self.tensor.dtype())
121        )
122    }
123
124    fn __str__(&self) -> String {
125        self.__repr__()
126    }
127
128    /// Convert tensor to NumPy array with proper shape preservation
129    fn numpy(&self) -> PyResult<Py<PyAny>> {
130        Python::attach(|py| {
131            let data = py_result!(self.tensor.to_vec())?;
132            let binding = self.tensor.shape();
133            let shape = binding.dims();
134
135            // Convert shape from usize to PyArrayDyn compatible format
136            let shape_vec: Vec<usize> = shape.iter().copied().collect();
137
138            match shape.len() {
139                0 => {
140                    // Scalar
141                    Ok(data[0].into_pyobject(py)?.into_any().unbind())
142                }
143                1 => {
144                    // 1D array
145                    let array = data.to_pyarray(py);
146                    Ok(array.into_pyobject(py)?.into_any().unbind())
147                }
148                2 => {
149                    // 2D array - properly reshape
150                    // Create ndarray and convert to PyArrayDyn
151                    let ndarray = Array::from_vec(data)
152                        .into_dyn()
153                        .to_shape(IxDyn(&shape_vec))
154                        .map_err(|e| {
155                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
156                                "Shape error: {}",
157                                e
158                            ))
159                        })?
160                        .into_owned();
161                    let array = ndarray.to_pyarray(py);
162                    Ok(array.into_pyobject(py)?.into_any().unbind())
163                }
164                _ => {
165                    // N-D array - properly reshape for arbitrary dimensions
166                    let ndarray = Array::from_vec(data)
167                        .into_dyn()
168                        .to_shape(IxDyn(&shape_vec))
169                        .map_err(|e| {
170                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
171                                "Shape error: {}",
172                                e
173                            ))
174                        })?
175                        .into_owned();
176                    let array = ndarray.to_pyarray(py);
177                    Ok(array.into_pyobject(py)?.into_any().unbind())
178                }
179            }
180        })
181    }
182
183    /// Extract single scalar value from tensor
184    fn item(&self) -> PyResult<f32> {
185        if self.tensor.numel() != 1 {
186            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
187                "Only one element tensors can be converted to Python scalars",
188            ));
189        }
190        let data = py_result!(self.tensor.to_vec())?;
191        Ok(data[0])
192    }
193
194    /// Convert tensor to nested Python lists
195    fn tolist(&self) -> PyResult<Py<PyAny>> {
196        Python::attach(|py| {
197            let data = py_result!(self.tensor.to_vec())?;
198            let binding = self.tensor.shape();
199            let shape = binding.dims();
200
201            if shape.is_empty() {
202                // Scalar
203                Ok(data[0].into_pyobject(py)?.into_any().unbind())
204            } else {
205                // Create nested lists based on tensor shape
206                let nested_list = self.create_nested_list(py, &data, shape, 0, &mut 0)?;
207                Ok(nested_list)
208            }
209        })
210    }
211
212    /// Create a copy of tensor on specified device (NumPy-compatible method)
213    fn copy(&self) -> PyResult<PyTensor> {
214        Ok(PyTensor {
215            tensor: self.tensor.clone(),
216        })
217    }
218
219    /// Get tensor stride information (NumPy-compatible)
220    fn stride(&self) -> Vec<usize> {
221        // For now, return a simple stride calculation
222        let binding = self.tensor.shape();
223        let shape = binding.dims();
224        let mut strides = vec![1; shape.len()];
225        for i in (0..shape.len().saturating_sub(1)).rev() {
226            strides[i] = strides[i + 1] * shape[i + 1];
227        }
228        strides
229    }
230
231    /// Get number of bytes per element (NumPy-compatible)
232    fn itemsize(&self) -> usize {
233        std::mem::size_of::<f32>()
234    }
235
236    /// Get total number of bytes (NumPy-compatible)
237    fn nbytes(&self) -> usize {
238        self.tensor.numel() * self.itemsize()
239    }
240
241    /// Check if tensor data is C-contiguous (NumPy-compatible)
242    fn is_c_contiguous(&self) -> bool {
243        self.tensor.is_contiguous()
244    }
245
246    // ===============================
247    // Mathematical Operations
248    // ===============================
249
250    /// Add operation (tensor + other)
251    fn add(&self, other: &PyTensor) -> PyResult<PyTensor> {
252        let result = py_result!(self.tensor.add(&other.tensor))?;
253        Ok(PyTensor { tensor: result })
254    }
255
256    /// Subtract operation (tensor - other)
257    fn sub(&self, other: &PyTensor) -> PyResult<PyTensor> {
258        let result = py_result!(self.tensor.sub(&other.tensor))?;
259        Ok(PyTensor { tensor: result })
260    }
261
262    /// Multiply operation (tensor * other)
263    fn mul(&self, other: &PyTensor) -> PyResult<PyTensor> {
264        let result = py_result!(self.tensor.mul(&other.tensor))?;
265        Ok(PyTensor { tensor: result })
266    }
267
268    /// Divide operation (tensor / other)
269    fn div(&self, other: &PyTensor) -> PyResult<PyTensor> {
270        let result = py_result!(self.tensor.div(&other.tensor))?;
271        Ok(PyTensor { tensor: result })
272    }
273
274    /// Power operation (tensor ** exponent)
275    fn pow(&self, exponent: f32) -> PyResult<PyTensor> {
276        let result = py_result!(self.tensor.pow(exponent))?;
277        Ok(PyTensor { tensor: result })
278    }
279
280    /// Scalar addition
281    fn add_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
282        let result = py_result!(self.tensor.add_scalar(scalar))?;
283        Ok(PyTensor { tensor: result })
284    }
285
286    /// Scalar multiplication
287    fn mul_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
288        let result = py_result!(self.tensor.mul_scalar(scalar))?;
289        Ok(PyTensor { tensor: result })
290    }
291
292    // ===============================
293    // Tensor Manipulation Operations
294    // ===============================
295
296    /// Reshape tensor to new shape
297    fn reshape(&self, shape: Vec<i64>) -> PyResult<PyTensor> {
298        let i32_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
299        let result = py_result!(self.tensor.reshape(&i32_shape))?;
300        Ok(PyTensor { tensor: result })
301    }
302
303    /// Transpose tensor (swap dimensions)
304    fn transpose(&self, dim0: i64, dim1: i64) -> PyResult<PyTensor> {
305        let result = py_result!(self.tensor.transpose(dim0 as i32, dim1 as i32))?;
306        Ok(PyTensor { tensor: result })
307    }
308
309    /// Transpose 2D tensor
310    fn t(&self) -> PyResult<PyTensor> {
311        if self.tensor.ndim() != 2 {
312            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
313                "t() can only be called on 2D tensors",
314            ));
315        }
316        self.transpose(0, 1)
317    }
318
319    /// Squeeze tensor (remove dimensions of size 1)
320    fn squeeze(&self, dim: Option<i64>) -> PyResult<PyTensor> {
321        let dim_to_squeeze = dim.map(|d| d as i32).unwrap_or(0i32);
322        let result = py_result!(self.tensor.squeeze(dim_to_squeeze))?;
323        Ok(PyTensor { tensor: result })
324    }
325
326    /// Unsqueeze tensor (add dimension of size 1)
327    fn unsqueeze(&self, dim: i64) -> PyResult<PyTensor> {
328        let result = py_result!(self.tensor.unsqueeze(dim as i32))?;
329        Ok(PyTensor { tensor: result })
330    }
331
332    /// Flatten tensor
333    fn flatten(&self, _start_dim: Option<i64>, _end_dim: Option<i64>) -> PyResult<PyTensor> {
334        // For now, use basic flatten - may need different implementation
335        let result = py_result!(self.tensor.flatten())?;
336        Ok(PyTensor { tensor: result })
337    }
338
339    // ===============================
340    // Reduction Operations
341    // ===============================
342
343    /// Sum along specified dimensions
344    fn sum(&self, _dim: Option<Vec<i64>>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
345        // For now, use basic sum - may need different implementation
346        let result = py_result!(self.tensor.sum())?;
347        Ok(PyTensor { tensor: result })
348    }
349
350    /// Mean along specified dimensions
351    fn mean(&self, dim: Option<Vec<i64>>, keepdim: Option<bool>) -> PyResult<PyTensor> {
352        let keepdim = keepdim.unwrap_or(false);
353        let result = if let Some(dims) = dim {
354            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
355            py_result!(self.tensor.mean(Some(&usize_dims), keepdim))?
356        } else {
357            py_result!(self.tensor.mean(None, keepdim))?
358        };
359        Ok(PyTensor { tensor: result })
360    }
361
362    /// Maximum along specified dimensions
363    fn max(&self, dim: Option<i64>, keepdim: Option<bool>) -> PyResult<PyTensor> {
364        let dim_opt = dim.map(|d| d as usize);
365        let keepdim = keepdim.unwrap_or(false);
366        let result = py_result!(self.tensor.max(dim_opt, keepdim))?;
367        Ok(PyTensor { tensor: result })
368    }
369
370    /// Minimum along specified dimensions
371    fn min(&self, _dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
372        // For now, use basic min - may need different implementation
373        let result = py_result!(self.tensor.min())?;
374        Ok(PyTensor { tensor: result })
375    }
376
377    // ===============================
378    // Linear Algebra Operations
379    // ===============================
380
381    /// Matrix multiplication
382    fn matmul(&self, other: &PyTensor) -> PyResult<PyTensor> {
383        let result = py_result!(self.tensor.matmul(&other.tensor))?;
384        Ok(PyTensor { tensor: result })
385    }
386
387    /// Dot product (1D tensors)
388    fn dot(&self, other: &PyTensor) -> PyResult<PyTensor> {
389        let result = py_result!(self.tensor.dot(&other.tensor))?;
390        Ok(PyTensor { tensor: result })
391    }
392
393    // ===============================
394    // Activation Functions
395    // ===============================
396
397    /// ReLU activation function
398    fn relu(&self) -> PyResult<PyTensor> {
399        let result = py_result!(self.tensor.relu())?;
400        Ok(PyTensor { tensor: result })
401    }
402
403    /// Sigmoid activation function
404    fn sigmoid(&self) -> PyResult<PyTensor> {
405        let result = py_result!(self.tensor.sigmoid())?;
406        Ok(PyTensor { tensor: result })
407    }
408
409    /// Tanh activation function
410    fn tanh(&self) -> PyResult<PyTensor> {
411        let result = py_result!(self.tensor.tanh())?;
412        Ok(PyTensor { tensor: result })
413    }
414
415    /// Softmax along specified dimension
416    fn softmax(&self, dim: i64) -> PyResult<PyTensor> {
417        let result = py_result!(self.tensor.softmax(dim as i32))?;
418        Ok(PyTensor { tensor: result })
419    }
420
421    // ===============================
422    // Trigonometric Functions
423    // ===============================
424
425    /// Sine function
426    fn sin(&self) -> PyResult<PyTensor> {
427        let result = py_result!(self.tensor.sin())?;
428        Ok(PyTensor { tensor: result })
429    }
430
431    /// Cosine function
432    fn cos(&self) -> PyResult<PyTensor> {
433        let result = py_result!(self.tensor.cos())?;
434        Ok(PyTensor { tensor: result })
435    }
436
437    /// Exponential function
438    fn exp(&self) -> PyResult<PyTensor> {
439        let result = py_result!(self.tensor.exp())?;
440        Ok(PyTensor { tensor: result })
441    }
442
443    /// Natural logarithm
444    fn log(&self) -> PyResult<PyTensor> {
445        let result = py_result!(self.tensor.log())?;
446        Ok(PyTensor { tensor: result })
447    }
448
449    /// Square root
450    fn sqrt(&self) -> PyResult<PyTensor> {
451        let result = py_result!(self.tensor.sqrt())?;
452        Ok(PyTensor { tensor: result })
453    }
454
455    /// Absolute value
456    fn abs(&self) -> PyResult<PyTensor> {
457        let result = py_result!(self.tensor.abs())?;
458        Ok(PyTensor { tensor: result })
459    }
460
461    // ===============================
462    // Comparison Operations
463    // ===============================
464
465    /// Element-wise equality (returns f32 tensor with 0.0/1.0 values)
466    fn eq(&self, other: &PyTensor) -> PyResult<PyTensor> {
467        let bool_result = py_result!(self.tensor.eq(&other.tensor))?;
468        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
469        Ok(PyTensor {
470            tensor: float_tensor,
471        })
472    }
473
474    /// Element-wise inequality (returns f32 tensor with 0.0/1.0 values)
475    fn ne(&self, other: &PyTensor) -> PyResult<PyTensor> {
476        let bool_result = py_result!(self.tensor.ne(&other.tensor))?;
477        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
478        Ok(PyTensor {
479            tensor: float_tensor,
480        })
481    }
482
483    /// Element-wise less than (returns f32 tensor with 0.0/1.0 values)
484    fn lt(&self, other: &PyTensor) -> PyResult<PyTensor> {
485        let bool_result = py_result!(self.tensor.lt(&other.tensor))?;
486        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
487        Ok(PyTensor {
488            tensor: float_tensor,
489        })
490    }
491
492    /// Element-wise greater than (returns f32 tensor with 0.0/1.0 values)
493    fn gt(&self, other: &PyTensor) -> PyResult<PyTensor> {
494        let bool_result = py_result!(self.tensor.gt(&other.tensor))?;
495        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
496        Ok(PyTensor {
497            tensor: float_tensor,
498        })
499    }
500
501    /// Element-wise less than or equal (returns f32 tensor with 0.0/1.0 values)
502    fn le(&self, other: &PyTensor) -> PyResult<PyTensor> {
503        let bool_result = py_result!(self.tensor.le(&other.tensor))?;
504        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
505        Ok(PyTensor {
506            tensor: float_tensor,
507        })
508    }
509
510    /// Element-wise greater than or equal (returns f32 tensor with 0.0/1.0 values)
511    fn ge(&self, other: &PyTensor) -> PyResult<PyTensor> {
512        let bool_result = py_result!(self.tensor.ge(&other.tensor))?;
513        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
514        Ok(PyTensor {
515            tensor: float_tensor,
516        })
517    }
518
519    /// Element-wise maximum
520    fn maximum(&self, other: &PyTensor) -> PyResult<PyTensor> {
521        let result = py_result!(self.tensor.maximum(&other.tensor))?;
522        Ok(PyTensor { tensor: result })
523    }
524
525    /// Element-wise minimum
526    fn minimum(&self, other: &PyTensor) -> PyResult<PyTensor> {
527        let result = py_result!(self.tensor.minimum(&other.tensor))?;
528        Ok(PyTensor { tensor: result })
529    }
530
531    // Scalar comparison operations
532
533    /// Scalar equality (returns f32 tensor with 0.0/1.0 values)
534    fn eq_scalar(&self, value: f32) -> PyResult<PyTensor> {
535        let bool_result = py_result!(self.tensor.eq_scalar(value))?;
536        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
537        Ok(PyTensor {
538            tensor: float_tensor,
539        })
540    }
541
542    /// Scalar inequality (returns f32 tensor with 0.0/1.0 values)
543    fn ne_scalar(&self, value: f32) -> PyResult<PyTensor> {
544        let bool_result = py_result!(self.tensor.ne_scalar(value))?;
545        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
546        Ok(PyTensor {
547            tensor: float_tensor,
548        })
549    }
550
551    /// Scalar less than (returns f32 tensor with 0.0/1.0 values)
552    fn lt_scalar(&self, value: f32) -> PyResult<PyTensor> {
553        let bool_result = py_result!(self.tensor.lt_scalar(value))?;
554        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
555        Ok(PyTensor {
556            tensor: float_tensor,
557        })
558    }
559
560    /// Scalar greater than (returns f32 tensor with 0.0/1.0 values)
561    fn gt_scalar(&self, value: f32) -> PyResult<PyTensor> {
562        let bool_result = py_result!(self.tensor.gt_scalar(value))?;
563        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
564        Ok(PyTensor {
565            tensor: float_tensor,
566        })
567    }
568
569    /// Scalar less than or equal (returns f32 tensor with 0.0/1.0 values)
570    fn le_scalar(&self, value: f32) -> PyResult<PyTensor> {
571        let bool_result = py_result!(self.tensor.le_scalar(value))?;
572        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
573        Ok(PyTensor {
574            tensor: float_tensor,
575        })
576    }
577
578    /// Scalar greater than or equal (returns f32 tensor with 0.0/1.0 values)
579    fn ge_scalar(&self, value: f32) -> PyResult<PyTensor> {
580        let bool_result = py_result!(self.tensor.ge_scalar(value))?;
581        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
582        Ok(PyTensor {
583            tensor: float_tensor,
584        })
585    }
586
587    // ===============================
588    // Utility Methods
589    // ===============================
590
591    /// Create a copy of the tensor
592    fn clone_tensor(&self) -> PyResult<PyTensor> {
593        Ok(PyTensor {
594            tensor: self.tensor.clone(),
595        })
596    }
597
598    /// Create a detached copy (no gradients)
599    fn detach(&self) -> PyResult<PyTensor> {
600        Ok(PyTensor {
601            tensor: self.tensor.detach(),
602        })
603    }
604
605    /// Move tensor to specified device
606    fn to_device(&self, device: PyDevice) -> PyResult<PyTensor> {
607        let result = py_result!(self.tensor.clone().to(device.device))?;
608        Ok(PyTensor { tensor: result })
609    }
610
611    /// Check if tensor is contiguous in memory
612    fn is_contiguous(&self) -> bool {
613        self.tensor.is_contiguous()
614    }
615
616    /// Make tensor contiguous in memory
617    fn contiguous(&self) -> PyResult<PyTensor> {
618        let result = py_result!(self.tensor.contiguous())?;
619        Ok(PyTensor { tensor: result })
620    }
621
622    // ===============================
623    // Additional PyTorch-Compatible Operations
624    // ===============================
625
626    /// Clamp tensor values to specified range
627    fn clamp(&self, min: Option<f32>, max: Option<f32>) -> PyResult<PyTensor> {
628        let result = if let (Some(min_val), Some(max_val)) = (min, max) {
629            py_result!(self.tensor.clamp(min_val, max_val))?
630        } else if let Some(min_val) = min {
631            // Use clamp with a very large max value for min-only clamping
632            py_result!(self.tensor.clamp(min_val, f32::MAX))?
633        } else if let Some(max_val) = max {
634            // Use clamp with a very small min value for max-only clamping
635            py_result!(self.tensor.clamp(f32::MIN, max_val))?
636        } else {
637            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
638                "At least one of min or max must be specified",
639            ));
640        };
641        Ok(PyTensor { tensor: result })
642    }
643
644    /// Fill tensor with specified value (in-place operation)
645    fn fill_(&mut self, value: f32) -> PyResult<PyTensor> {
646        py_result!(self.tensor.fill_(value))?;
647        Ok(PyTensor {
648            tensor: self.tensor.clone(),
649        })
650    }
651
652    /// Zero out tensor (in-place operation)
653    fn zero_(&mut self) -> PyResult<PyTensor> {
654        py_result!(self.tensor.zero_())?;
655        Ok(PyTensor {
656            tensor: self.tensor.clone(),
657        })
658    }
659
660    /// Apply uniform random initialization
661    fn uniform_(&mut self, from: Option<f32>, to: Option<f32>) -> PyResult<PyTensor> {
662        // ✅ SciRS2 POLICY: Use scirs2_core::random for RNG
663        use scirs2_core::random::{thread_rng, Distribution, Uniform};
664
665        let from = from.unwrap_or(0.0);
666        let to = to.unwrap_or(1.0);
667
668        let mut rng = thread_rng();
669        let dist = Uniform::new(from, to).map_err(|e| {
670            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
671                "Invalid uniform distribution parameters: {}",
672                e
673            ))
674        })?;
675
676        let mut data = py_result!(self.tensor.data())?;
677        for val in data.iter_mut() {
678            *val = dist.sample(&mut rng);
679        }
680
681        let shape = self.tensor.shape().dims().to_vec();
682        let device = self.tensor.device();
683        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
684            .requires_grad_(self.tensor.requires_grad());
685
686        Ok(PyTensor {
687            tensor: self.tensor.clone(),
688        })
689    }
690
691    /// Apply normal random initialization
692    fn normal_(&mut self, mean: Option<f32>, std: Option<f32>) -> PyResult<PyTensor> {
693        // ✅ SciRS2 POLICY: Use scirs2_core::random for RNG
694        use scirs2_core::random::{thread_rng, Distribution, Normal};
695
696        let mean = mean.unwrap_or(0.0);
697        let std = std.unwrap_or(1.0);
698
699        let normal = Normal::new(mean as f64, std as f64).map_err(|e| {
700            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
701                "Invalid normal distribution parameters: {}",
702                e
703            ))
704        })?;
705
706        let mut rng = thread_rng();
707        let mut data = py_result!(self.tensor.data())?;
708        for val in data.iter_mut() {
709            *val = normal.sample(&mut rng) as f32;
710        }
711
712        let shape = self.tensor.shape().dims().to_vec();
713        let device = self.tensor.device();
714        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
715            .requires_grad_(self.tensor.requires_grad());
716
717        Ok(PyTensor {
718            tensor: self.tensor.clone(),
719        })
720    }
721
722    /// Repeat tensor along specified dimensions
723    fn repeat(&self, repeats: Vec<usize>) -> PyResult<PyTensor> {
724        let result = py_result!(self.tensor.repeat(&repeats))?;
725        Ok(PyTensor { tensor: result })
726    }
727
728    /// Expand tensor to specified shape (broadcasting)
729    fn expand(&self, size: Vec<i64>) -> PyResult<PyTensor> {
730        let size: Vec<usize> = size.iter().map(|&x| x as usize).collect();
731        let result = py_result!(self.tensor.expand(&size))?;
732        Ok(PyTensor { tensor: result })
733    }
734
735    /// Expand tensor to match another tensor's shape
736    fn expand_as(&self, other: &PyTensor) -> PyResult<PyTensor> {
737        let other_shape: Vec<usize> = other.tensor.shape().dims().iter().map(|&x| x).collect();
738        let result = py_result!(self.tensor.expand(&other_shape))?;
739        Ok(PyTensor { tensor: result })
740    }
741
742    /// Select elements from tensor along specified dimension
743    fn index_select(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
744        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
745        let index_i64 = py_result!(index_i32.to_i64_simd())?;
746        let result = py_result!(self.tensor.index_select(dim as i32, &index_i64))?;
747        Ok(PyTensor { tensor: result })
748    }
749
750    /// Gather elements from tensor
751    fn gather(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
752        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
753        let index_i64 = py_result!(index_i32.to_i64_simd())?;
754        let result = py_result!(self.tensor.gather(dim as usize, &index_i64))?;
755        Ok(PyTensor { tensor: result })
756    }
757
758    /// Scatter elements into tensor
759    fn scatter(&self, dim: i64, index: &PyTensor, src: &PyTensor) -> PyResult<PyTensor> {
760        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
761        let index_i64 = py_result!(index_i32.to_i64_simd())?;
762        let result = py_result!(self.tensor.scatter(dim as usize, &index_i64, &src.tensor))?;
763        Ok(PyTensor { tensor: result })
764    }
765
766    /// Masked fill operation
767    fn masked_fill(&self, mask: &PyTensor, value: f32) -> PyResult<PyTensor> {
768        // ✅ Proper masked fill implementation
769        let mut data = py_result!(self.tensor.data())?;
770        let mask_data = py_result!(mask.tensor.data())?;
771
772        if data.len() != mask_data.len() {
773            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
774                "Mask must have the same number of elements as tensor",
775            ));
776        }
777
778        for (val, &mask_val) in data.iter_mut().zip(mask_data.iter()) {
779            if mask_val != 0.0 {
780                *val = value;
781            }
782        }
783
784        let shape = self.tensor.shape().dims().to_vec();
785        let device = self.tensor.device();
786        let result = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
787            .requires_grad_(self.tensor.requires_grad());
788
789        Ok(PyTensor { tensor: result })
790    }
791
792    /// Masked select operation
793    fn masked_select(&self, mask: &PyTensor) -> PyResult<PyTensor> {
794        // Convert f32 mask to bool by treating non-zero as true
795        let mask_data = py_result!(mask.tensor.data())?;
796        let bool_data: Vec<bool> = mask_data.iter().map(|&x| x != 0.0).collect();
797        let mask_bool = py_result!(torsh_tensor::Tensor::from_data(
798            bool_data,
799            mask.tensor.shape().dims().to_vec(),
800            mask.tensor.device()
801        ))?;
802        let result = py_result!(self.tensor.masked_select(&mask_bool))?;
803        Ok(PyTensor { tensor: result })
804    }
805
806    /// Concatenate tensors along specified dimension
807    fn cat(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
808        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
809            tensors.iter().map(|t| &t.tensor).collect();
810        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
811        Ok(PyTensor { tensor: result })
812    }
813
814    /// Stack tensors along new dimension
815    fn stack(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
816        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
817            tensors.iter().map(|t| &t.tensor).collect();
818        // Use cat as a simplified stacking implementation
819        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
820        Ok(PyTensor { tensor: result })
821    }
822
823    /// Split tensor into chunks
824    fn chunk(&self, chunks: usize, dim: i64) -> PyResult<Vec<PyTensor>> {
825        // ✅ Proper chunk implementation
826        let shape = self.tensor.shape().dims().to_vec();
827        let dim_usize = dim as usize;
828
829        if dim_usize >= shape.len() {
830            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
831                "Dimension {} out of range for tensor with {} dimensions",
832                dim,
833                shape.len()
834            )));
835        }
836
837        let dim_size = shape[dim_usize];
838        let chunk_size = (dim_size + chunks - 1) / chunks; // Ceiling division
839
840        let mut result = Vec::new();
841        for i in 0..chunks {
842            let start = i * chunk_size;
843            if start >= dim_size {
844                break;
845            }
846            let end = std::cmp::min(start + chunk_size, dim_size);
847
848            // Use narrow to extract chunk
849            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, end - start))?;
850            result.push(PyTensor { tensor: chunk });
851        }
852
853        Ok(result)
854    }
855
856    /// Split tensor at specified sizes
857    fn split(&self, split_sizes: Vec<usize>, dim: i64) -> PyResult<Vec<PyTensor>> {
858        // ✅ Proper split implementation
859        let shape = self.tensor.shape().dims().to_vec();
860        let dim_usize = dim as usize;
861
862        if dim_usize >= shape.len() {
863            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
864                "Dimension {} out of range for tensor with {} dimensions",
865                dim,
866                shape.len()
867            )));
868        }
869
870        let total_size: usize = split_sizes.iter().sum();
871        if total_size != shape[dim_usize] {
872            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
873                "Split sizes sum to {} but dimension {} has size {}",
874                total_size, dim, shape[dim_usize]
875            )));
876        }
877
878        let mut result = Vec::new();
879        let mut start = 0;
880
881        for &size in &split_sizes {
882            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, size))?;
883            result.push(PyTensor { tensor: chunk });
884            start += size;
885        }
886
887        Ok(result)
888    }
889
890    /// Permute tensor dimensions
891    fn permute(&self, dims: Vec<i64>) -> PyResult<PyTensor> {
892        let dims: Vec<i32> = dims.iter().map(|&x| x as i32).collect();
893        let result = py_result!(self.tensor.permute(&dims))?;
894        Ok(PyTensor { tensor: result })
895    }
896
897    /// Get diagonal elements
898    fn diag(&self, diagonal: Option<i64>) -> PyResult<PyTensor> {
899        // ✅ Proper diagonal extraction implementation
900
901        let offset = diagonal.unwrap_or(0);
902        let shape = self.tensor.shape().dims().to_vec();
903
904        if shape.len() == 1 {
905            // 1D tensor -> create diagonal matrix
906            let n = shape[0];
907            let size = n + offset.abs() as usize;
908            let mut data = vec![0.0; size * size];
909            let input_data = py_result!(self.tensor.data())?;
910
911            for (i, &val) in input_data.iter().enumerate() {
912                if offset >= 0 {
913                    let row = i;
914                    let col = i + offset as usize;
915                    data[row * size + col] = val;
916                } else {
917                    let row = i + (-offset) as usize;
918                    let col = i;
919                    data[row * size + col] = val;
920                }
921            }
922
923            let result = py_result!(torsh_tensor::Tensor::from_data(
924                data,
925                vec![size, size],
926                self.tensor.device()
927            ))?;
928            Ok(PyTensor { tensor: result })
929        } else if shape.len() == 2 {
930            // 2D tensor -> extract diagonal
931            let rows = shape[0];
932            let cols = shape[1];
933            let data = py_result!(self.tensor.data())?;
934
935            let mut diag_data = Vec::new();
936
937            if offset >= 0 {
938                let offset_u = offset as usize;
939                for i in 0..std::cmp::min(rows, cols.saturating_sub(offset_u)) {
940                    diag_data.push(data[i * cols + i + offset_u]);
941                }
942            } else {
943                let offset_u = (-offset) as usize;
944                for i in 0..std::cmp::min(rows.saturating_sub(offset_u), cols) {
945                    diag_data.push(data[(i + offset_u) * cols + i]);
946                }
947            }
948
949            let diag_len = diag_data.len();
950            let result = py_result!(torsh_tensor::Tensor::from_data(
951                diag_data,
952                vec![diag_len],
953                self.tensor.device()
954            ))?;
955            Ok(PyTensor { tensor: result })
956        } else {
957            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
958                "diag() only supports 1D or 2D tensors",
959            ))
960        }
961    }
962
963    /// Trace of matrix
964    fn trace(&self) -> PyResult<PyTensor> {
965        // ✅ Proper trace computation
966        let shape = self.tensor.shape().dims().to_vec();
967
968        if shape.len() != 2 {
969            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
970                "trace() requires a 2D tensor",
971            ));
972        }
973
974        let rows = shape[0];
975        let cols = shape[1];
976        let min_dim = std::cmp::min(rows, cols);
977
978        let data = py_result!(self.tensor.data())?;
979        let mut trace_sum = 0.0;
980
981        for i in 0..min_dim {
982            trace_sum += data[i * cols + i];
983        }
984
985        let result = py_result!(torsh_tensor::Tensor::from_data(
986            vec![trace_sum],
987            vec![],
988            self.tensor.device()
989        ))?;
990
991        Ok(PyTensor { tensor: result })
992    }
993
994    /// Norm calculation
995    fn norm(
996        &self,
997        p: Option<f32>,
998        _dim: Option<Vec<i64>>,
999        _keepdim: Option<bool>,
1000    ) -> PyResult<PyTensor> {
1001        let _p = p.unwrap_or(2.0);
1002        // For now, use simple L2 norm regardless of parameters
1003        // TODO: Implement full norm_lp functionality when ops module is exposed
1004        let result = py_result!(self.tensor.norm())?;
1005        Ok(PyTensor { tensor: result })
1006    }
1007
1008    /// Standard deviation
1009    fn std(
1010        &self,
1011        dim: Option<Vec<i64>>,
1012        keepdim: Option<bool>,
1013        unbiased: Option<bool>,
1014    ) -> PyResult<PyTensor> {
1015        let keepdim = keepdim.unwrap_or(false);
1016        let unbiased = unbiased.unwrap_or(true);
1017        let stat_mode = if unbiased {
1018            torsh_tensor::stats::StatMode::Sample
1019        } else {
1020            torsh_tensor::stats::StatMode::Population
1021        };
1022        let result = if let Some(dims) = dim {
1023            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
1024            py_result!(self.tensor.std(Some(&usize_dims), keepdim, stat_mode))?
1025        } else {
1026            py_result!(self.tensor.std(None, keepdim, stat_mode))?
1027        };
1028        Ok(PyTensor { tensor: result })
1029    }
1030
1031    /// Variance calculation
1032    fn var(
1033        &self,
1034        dim: Option<Vec<i64>>,
1035        keepdim: Option<bool>,
1036        unbiased: Option<bool>,
1037    ) -> PyResult<PyTensor> {
1038        let keepdim = keepdim.unwrap_or(false);
1039        let unbiased = unbiased.unwrap_or(true);
1040        let stat_mode = if unbiased {
1041            torsh_tensor::stats::StatMode::Sample
1042        } else {
1043            torsh_tensor::stats::StatMode::Population
1044        };
1045        let result = if let Some(dims) = dim {
1046            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
1047            py_result!(self.tensor.var(Some(&usize_dims), keepdim, stat_mode))?
1048        } else {
1049            py_result!(self.tensor.var(None, keepdim, stat_mode))?
1050        };
1051        Ok(PyTensor { tensor: result })
1052    }
1053
1054    /// Argmax operation
1055    fn argmax(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
1056        let result = if let Some(d) = dim {
1057            py_result!(self.tensor.argmax(Some(d as i32)))?
1058        } else {
1059            py_result!(self.tensor.argmax(None))?
1060        };
1061        let result_f32 = py_result!(result.to_f32_simd())?;
1062        Ok(PyTensor { tensor: result_f32 })
1063    }
1064
1065    /// Argmin operation
1066    fn argmin(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
1067        let result = if let Some(d) = dim {
1068            py_result!(self.tensor.argmin(Some(d as i32)))?
1069        } else {
1070            py_result!(self.tensor.argmin(None))?
1071        };
1072        let result_f32 = py_result!(result.to_f32_simd())?;
1073        Ok(PyTensor { tensor: result_f32 })
1074    }
1075}
1076
1077impl PyTensor {
1078    /// Convert from Python list to tensor (simplified)
1079    fn from_py_list(list: &Bound<'_, PyList>, device: DeviceType) -> PyResult<Tensor<f32>> {
1080        let mut data = Vec::new();
1081        let len = list.len();
1082
1083        for i in 0..len {
1084            let item = list.get_item(i)?;
1085            if let Ok(val) = item.extract::<f32>() {
1086                data.push(val);
1087            } else {
1088                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
1089                    "Cannot convert item at index {} to f32",
1090                    i
1091                )));
1092            }
1093        }
1094
1095        py_result!(Tensor::from_data(data, vec![len], device))
1096    }
1097
1098    /// Helper method to create nested Python lists from tensor data
1099    fn create_nested_list(
1100        &self,
1101        py: Python<'_>,
1102        data: &[f32],
1103        shape: &[usize],
1104        dim: usize,
1105        index: &mut usize,
1106    ) -> PyResult<Py<PyAny>> {
1107        if dim == shape.len() - 1 {
1108            // Leaf dimension: create list of values
1109            let mut items = Vec::new();
1110            for _ in 0..shape[dim] {
1111                items.push(data[*index].into_pyobject(py)?.into_any().unbind());
1112                *index += 1;
1113            }
1114            Ok(items.into_pyobject(py)?.into_any().unbind())
1115        } else {
1116            // Intermediate dimension: create list of nested lists
1117            let mut items = Vec::new();
1118            for _ in 0..shape[dim] {
1119                let nested = self.create_nested_list(py, data, shape, dim + 1, index)?;
1120                items.push(nested);
1121            }
1122            Ok(items.into_pyobject(py)?.into_any().unbind())
1123        }
1124    }
1125
1126    /// Convert boolean tensor to float tensor (0.0 for false, 1.0 for true)
1127    fn bool_to_float_tensor(
1128        bool_tensor: torsh_tensor::Tensor<bool>,
1129    ) -> torsh_core::error::Result<torsh_tensor::Tensor<f32>> {
1130        let bool_data = bool_tensor.data()?;
1131        let float_data: Vec<f32> = bool_data
1132            .iter()
1133            .map(|&b| if b { 1.0 } else { 0.0 })
1134            .collect();
1135        let shape = bool_tensor.shape().dims().to_vec();
1136        let device = bool_tensor.device();
1137
1138        torsh_tensor::Tensor::from_data(float_data, shape, device)
1139    }
1140}