use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
use numpy::{
    IxDyn as NpIxDyn, PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods,
};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyList};
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;

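/// Python-facing wrapper around a single-precision (`f32`) tensor.
///
/// All operations delegate to the underlying `torsh_tensor::Tensor<f32>`.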
#[pyclass(name = "Tensor")]
#[derive(Clone)]
pub struct PyTensor {
    pub(crate) tensor: Tensor<f32>,
}

#[pymethods]
impl PyTensor {
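    /// Create a tensor from a NumPy `f32` array (1-D, 2-D, or N-D), a flat
    /// Python list of numbers, or a single scalar. `_dtype` is accepted for
    /// API parity but currently ignored: all tensors are stored as `f32`.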
    #[new]
    pub fn new(
        data: &Bound<'_, PyAny>,
        _dtype: Option<PyDType>,
        device: Option<PyDevice>,
        requires_grad: Option<bool>,
    ) -> PyResult<Self> {
        let device = device.map(|d| d.device).unwrap_or(DeviceType::Cpu);
        let requires_grad = requires_grad.unwrap_or(false);

        let tensor = if let Ok(arr) = data.clone().cast_into::<PyArray1<f32>>() {
            let data = arr.to_vec()?;
            let shape = vec![data.len()];
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(arr) = data.clone().cast_into::<PyArray2<f32>>() {
            let data = arr.to_vec()?;
            let shape = arr.shape().to_vec();
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(arr) = data.clone().cast_into::<PyArrayDyn<f32>>() {
            let data = arr.to_vec()?;
            let shape = arr.shape().to_vec();
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(list) = data.clone().cast_into::<PyList>() {
            Self::from_py_list(&list, device)?
        } else if let Ok(scalar) = data.extract::<f32>() {
            py_result!(Tensor::from_data(vec![scalar], vec![], device))?
        } else {
            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
                "Unsupported data type for tensor creation",
            ));
        };

        let tensor = tensor.requires_grad_(requires_grad);

        Ok(Self { tensor })
    }

    #[getter]
    fn shape(&self) -> Vec<usize> {
        self.tensor.shape().dims().to_vec()
    }

    #[getter]
    fn ndim(&self) -> usize {
        self.tensor.ndim()
    }

    #[getter]
    pub fn numel(&self) -> usize {
        self.tensor.numel()
    }

    #[getter]
    fn dtype(&self) -> PyDType {
        PyDType::from(self.tensor.dtype())
    }

    #[getter]
    fn device(&self) -> PyDevice {
        PyDevice::from(self.tensor.device())
    }

    #[getter]
    fn requires_grad(&self) -> bool {
        self.tensor.requires_grad()
    }

    fn __repr__(&self) -> String {
        let binding = self.tensor.shape();
        let shape = binding.dims();
        let device_str = match self.tensor.device() {
            DeviceType::Cpu => String::new(),
            dev => format!(", device='{}'", PyDevice::from(dev)),
        };
        let grad_str = if self.tensor.requires_grad() {
            ", requires_grad=True"
        } else {
            ""
        };

        format!(
            "tensor(shape=[{}]{}{}, dtype={})",
            shape
                .iter()
                .map(|d| d.to_string())
                .collect::<Vec<_>>()
                .join(", "),
            device_str,
            grad_str,
            PyDType::from(self.tensor.dtype())
        )
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

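    /// Convert the tensor to a NumPy object: a Python scalar for 0-D tensors,
    /// otherwise a NumPy array reshaped to the tensor's dimensions.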
    fn numpy(&self) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let data = py_result!(self.tensor.to_vec())?;
            let binding = self.tensor.shape();
            let shape = binding.dims();

            let shape_vec: Vec<usize> = shape.iter().copied().collect();

            match shape.len() {
                0 => Ok(data[0].into_pyobject(py)?.into_any().unbind()),
                1 => {
                    let array = PyArray1::from_vec(py, data);
                    Ok(array.into_pyobject(py)?.into_any().unbind())
                }
                2 => {
                    let array_1d = PyArray1::from_vec(py, data);
                    let array_2d = array_1d
                        .reshape([shape_vec[0], shape_vec[1]])
                        .map_err(|e| {
                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                                "Reshape error: {}",
                                e
                            ))
                        })?;
                    Ok(array_2d.into_pyobject(py)?.into_any().unbind())
                }
                _ => {
                    let array_1d = PyArray1::from_vec(py, data);
                    let array_nd = array_1d.reshape(NpIxDyn(&shape_vec)).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Reshape error: {}",
                            e
                        ))
                    })?;
                    Ok(array_nd.into_pyobject(py)?.into_any().unbind())
                }
            }
        })
    }

    fn item(&self) -> PyResult<f32> {
        if self.tensor.numel() != 1 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Only one element tensors can be converted to Python scalars",
            ));
        }
        let data = py_result!(self.tensor.to_vec())?;
        Ok(data[0])
    }

    fn tolist(&self) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let data = py_result!(self.tensor.to_vec())?;
            let binding = self.tensor.shape();
            let shape = binding.dims();

            if shape.is_empty() {
                Ok(data[0].into_pyobject(py)?.into_any().unbind())
            } else {
                let nested_list = self.create_nested_list(py, &data, shape, 0, &mut 0)?;
                Ok(nested_list)
            }
        })
    }

    fn copy(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

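    /// Row-major strides in elements, derived from the shape; this assumes a
    /// contiguous layout rather than reading stored stride metadata.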
    fn stride(&self) -> Vec<usize> {
        let binding = self.tensor.shape();
        let shape = binding.dims();
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    fn itemsize(&self) -> usize {
        std::mem::size_of::<f32>()
    }

    fn nbytes(&self) -> usize {
        self.tensor.numel() * self.itemsize()
    }

    fn is_c_contiguous(&self) -> bool {
        self.tensor.is_contiguous()
    }

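    // ---- Element-wise arithmetic (tensor-tensor and tensor-scalar) ----
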
    fn add(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.add(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn sub(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sub(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn mul(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.mul(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn div(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.div(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn pow(&self, exponent: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.pow(exponent))?;
        Ok(PyTensor { tensor: result })
    }

    fn add_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.add_scalar(scalar))?;
        Ok(PyTensor { tensor: result })
    }

    fn mul_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.mul_scalar(scalar))?;
        Ok(PyTensor { tensor: result })
    }

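    // ---- Shape manipulation ----
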
    fn reshape(&self, shape: Vec<i64>) -> PyResult<PyTensor> {
        let i32_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
        let result = py_result!(self.tensor.reshape(&i32_shape))?;
        Ok(PyTensor { tensor: result })
    }

    fn transpose(&self, dim0: i64, dim1: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.transpose(dim0 as i32, dim1 as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn t(&self) -> PyResult<PyTensor> {
        if self.tensor.ndim() != 2 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "t() can only be called on 2D tensors",
            ));
        }
        self.transpose(0, 1)
    }

    fn squeeze(&self, dim: Option<i64>) -> PyResult<PyTensor> {
        // NOTE: unlike PyTorch, omitting `dim` squeezes dimension 0 only,
        // not every size-1 dimension.
        let dim_to_squeeze = dim.map(|d| d as i32).unwrap_or(0i32);
        let result = py_result!(self.tensor.squeeze(dim_to_squeeze))?;
        Ok(PyTensor { tensor: result })
    }

    fn unsqueeze(&self, dim: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.unsqueeze(dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn flatten(&self, _start_dim: Option<i64>, _end_dim: Option<i64>) -> PyResult<PyTensor> {
        // `_start_dim` / `_end_dim` are accepted for API parity but ignored;
        // the tensor is always flattened to 1-D.
        let result = py_result!(self.tensor.flatten())?;
        Ok(PyTensor { tensor: result })
    }

    fn sum(&self, _dim: Option<Vec<i64>>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // `_dim` / `_keepdim` are ignored; this always sums over all elements.
        let result = py_result!(self.tensor.sum())?;
        Ok(PyTensor { tensor: result })
    }

    fn mean(&self, dim: Option<Vec<i64>>, keepdim: Option<bool>) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.mean(Some(&usize_dims), keepdim))?
        } else {
            py_result!(self.tensor.mean(None, keepdim))?
        };
        Ok(PyTensor { tensor: result })
    }

    fn max(&self, dim: Option<i64>, keepdim: Option<bool>) -> PyResult<PyTensor> {
        let dim_opt = dim.map(|d| d as usize);
        let keepdim = keepdim.unwrap_or(false);
        let result = py_result!(self.tensor.max(dim_opt, keepdim))?;
        Ok(PyTensor { tensor: result })
    }

    fn min(&self, _dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // `_dim` / `_keepdim` are ignored; this always returns the global minimum.
        let result = py_result!(self.tensor.min())?;
        Ok(PyTensor { tensor: result })
    }

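    // ---- Linear algebra ----
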
    fn matmul(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.matmul(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn dot(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.dot(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

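    // ---- Activation functions and element-wise math ----
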
    fn relu(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.relu())?;
        Ok(PyTensor { tensor: result })
    }

    fn sigmoid(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sigmoid())?;
        Ok(PyTensor { tensor: result })
    }

    fn tanh(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.tanh())?;
        Ok(PyTensor { tensor: result })
    }

    fn softmax(&self, dim: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.softmax(dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn sin(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sin())?;
        Ok(PyTensor { tensor: result })
    }

    fn cos(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.cos())?;
        Ok(PyTensor { tensor: result })
    }

    fn exp(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.exp())?;
        Ok(PyTensor { tensor: result })
    }

    fn log(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.log())?;
        Ok(PyTensor { tensor: result })
    }

    fn sqrt(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sqrt())?;
        Ok(PyTensor { tensor: result })
    }

    fn abs(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.abs())?;
        Ok(PyTensor { tensor: result })
    }

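    // ---- Comparisons: results are f32 masks (1.0 = true, 0.0 = false) ----
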
    fn eq(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.eq(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ne(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ne(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn lt(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.lt(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn gt(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.gt(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn le(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.le(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ge(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ge(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn maximum(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.maximum(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn minimum(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.minimum(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn eq_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.eq_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ne_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ne_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn lt_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.lt_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn gt_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.gt_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn le_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.le_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ge_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ge_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

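    // ---- Copies, autograd, device transfer, and in-place fills ----
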
    fn clone_tensor(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn detach(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.detach(),
        })
    }

    fn to_device(&self, device: PyDevice) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.clone().to(device.device))?;
        Ok(PyTensor { tensor: result })
    }

    fn is_contiguous(&self) -> bool {
        self.tensor.is_contiguous()
    }

    fn contiguous(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.contiguous())?;
        Ok(PyTensor { tensor: result })
    }

    fn clamp(&self, min: Option<f32>, max: Option<f32>) -> PyResult<PyTensor> {
        let result = if let (Some(min_val), Some(max_val)) = (min, max) {
            py_result!(self.tensor.clamp(min_val, max_val))?
        } else if let Some(min_val) = min {
            py_result!(self.tensor.clamp(min_val, f32::MAX))?
        } else if let Some(max_val) = max {
            py_result!(self.tensor.clamp(f32::MIN, max_val))?
        } else {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "At least one of min or max must be specified",
            ));
        };
        Ok(PyTensor { tensor: result })
    }

    fn fill_(&mut self, value: f32) -> PyResult<PyTensor> {
        py_result!(self.tensor.fill_(value))?;
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn zero_(&mut self) -> PyResult<PyTensor> {
        py_result!(self.tensor.zero_())?;
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

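    // ---- In-place random initialization ----
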
    fn uniform_(&mut self, from: Option<f32>, to: Option<f32>) -> PyResult<PyTensor> {
        use scirs2_core::random::{thread_rng, Distribution, Uniform};

        let from = from.unwrap_or(0.0);
        let to = to.unwrap_or(1.0);

        let mut rng = thread_rng();
        let dist = Uniform::new(from, to).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Invalid uniform distribution parameters: {}",
                e
            ))
        })?;

        let mut data = py_result!(self.tensor.data())?;
        for val in data.iter_mut() {
            *val = dist.sample(&mut rng);
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn normal_(&mut self, mean: Option<f32>, std: Option<f32>) -> PyResult<PyTensor> {
        use scirs2_core::random::{thread_rng, Distribution, Normal};

        let mean = mean.unwrap_or(0.0);
        let std = std.unwrap_or(1.0);

        let normal = Normal::new(mean as f64, std as f64).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Invalid normal distribution parameters: {}",
                e
            ))
        })?;

        let mut rng = thread_rng();
        let mut data = py_result!(self.tensor.data())?;
        for val in data.iter_mut() {
            *val = normal.sample(&mut rng) as f32;
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn repeat(&self, repeats: Vec<usize>) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.repeat(&repeats))?;
        Ok(PyTensor { tensor: result })
    }

    fn expand(&self, size: Vec<i64>) -> PyResult<PyTensor> {
        let size: Vec<usize> = size.iter().map(|&x| x as usize).collect();
        let result = py_result!(self.tensor.expand(&size))?;
        Ok(PyTensor { tensor: result })
    }

    fn expand_as(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let other_shape: Vec<usize> = other.tensor.shape().dims().to_vec();
        let result = py_result!(self.tensor.expand(&other_shape))?;
        Ok(PyTensor { tensor: result })
    }

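    // ---- Indexing, gathering, and masking ----
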
    fn index_select(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.index_select(dim as i32, &index_i64))?;
        Ok(PyTensor { tensor: result })
    }

    fn gather(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.gather(dim as usize, &index_i64))?;
        Ok(PyTensor { tensor: result })
    }

    fn scatter(&self, dim: i64, index: &PyTensor, src: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.scatter(dim as usize, &index_i64, &src.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn masked_fill(&self, mask: &PyTensor, value: f32) -> PyResult<PyTensor> {
        let mut data = py_result!(self.tensor.data())?;
        let mask_data = py_result!(mask.tensor.data())?;

        if data.len() != mask_data.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Mask must have the same number of elements as tensor",
            ));
        }

        for (val, &mask_val) in data.iter_mut().zip(mask_data.iter()) {
            if mask_val != 0.0 {
                *val = value;
            }
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        let result = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor { tensor: result })
    }

    fn masked_select(&self, mask: &PyTensor) -> PyResult<PyTensor> {
        let mask_data = py_result!(mask.tensor.data())?;
        let bool_data: Vec<bool> = mask_data.iter().map(|&x| x != 0.0).collect();
        let mask_bool = py_result!(torsh_tensor::Tensor::from_data(
            bool_data,
            mask.tensor.shape().dims().to_vec(),
            mask.tensor.device()
        ))?;
        let result = py_result!(self.tensor.masked_select(&mask_bool))?;
        Ok(PyTensor { tensor: result })
    }

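    // ---- Joining and splitting ----
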
    fn cat(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
            tensors.iter().map(|t| &t.tensor).collect();
        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn stack(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
        // Stacking inserts a new dimension: unsqueeze each input at `dim`,
        // then concatenate along it (the previous version concatenated the
        // inputs directly, which is `cat`, not `stack`).
        let mut unsqueezed = Vec::with_capacity(tensors.len());
        for t in &tensors {
            unsqueezed.push(py_result!(t.tensor.unsqueeze(dim as i32))?);
        }
        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> = unsqueezed.iter().collect();
        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn chunk(&self, chunks: usize, dim: i64) -> PyResult<Vec<PyTensor>> {
        if chunks == 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "chunk expects `chunks` to be greater than 0",
            ));
        }

        let shape = self.tensor.shape().dims().to_vec();
        let dim_usize = dim as usize;

        if dim_usize >= shape.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                shape.len()
            )));
        }

        let dim_size = shape[dim_usize];
        // Ceiling division: the last chunk may be smaller than the rest.
        let chunk_size = (dim_size + chunks - 1) / chunks;

        let mut result = Vec::new();
        for i in 0..chunks {
            let start = i * chunk_size;
            if start >= dim_size {
                break;
            }
            let end = std::cmp::min(start + chunk_size, dim_size);

            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, end - start))?;
            result.push(PyTensor { tensor: chunk });
        }

        Ok(result)
    }

    fn split(&self, split_sizes: Vec<usize>, dim: i64) -> PyResult<Vec<PyTensor>> {
        let shape = self.tensor.shape().dims().to_vec();
        let dim_usize = dim as usize;

        if dim_usize >= shape.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                shape.len()
            )));
        }

        let total_size: usize = split_sizes.iter().sum();
        if total_size != shape[dim_usize] {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Split sizes sum to {} but dimension {} has size {}",
                total_size, dim, shape[dim_usize]
            )));
        }

        let mut result = Vec::new();
        let mut start = 0;

        for &size in &split_sizes {
            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, size))?;
            result.push(PyTensor { tensor: chunk });
            start += size;
        }

        Ok(result)
    }

    fn permute(&self, dims: Vec<i64>) -> PyResult<PyTensor> {
        let dims: Vec<i32> = dims.iter().map(|&x| x as i32).collect();
        let result = py_result!(self.tensor.permute(&dims))?;
        Ok(PyTensor { tensor: result })
    }

    fn diag(&self, diagonal: Option<i64>) -> PyResult<PyTensor> {
        let offset = diagonal.unwrap_or(0);
        let shape = self.tensor.shape().dims().to_vec();

        if shape.len() == 1 {
            // 1-D input: build a square matrix with the values placed on the
            // `offset`-th diagonal.
            let n = shape[0];
            let size = n + offset.abs() as usize;
            let mut data = vec![0.0; size * size];
            let input_data = py_result!(self.tensor.data())?;

            for (i, &val) in input_data.iter().enumerate() {
                if offset >= 0 {
                    let row = i;
                    let col = i + offset as usize;
                    data[row * size + col] = val;
                } else {
                    let row = i + (-offset) as usize;
                    let col = i;
                    data[row * size + col] = val;
                }
            }

            let result = py_result!(torsh_tensor::Tensor::from_data(
                data,
                vec![size, size],
                self.tensor.device()
            ))?;
            Ok(PyTensor { tensor: result })
        } else if shape.len() == 2 {
            // 2-D input: extract the `offset`-th diagonal as a 1-D tensor.
            let rows = shape[0];
            let cols = shape[1];
            let data = py_result!(self.tensor.data())?;

            let mut diag_data = Vec::new();

            if offset >= 0 {
                let offset_u = offset as usize;
                for i in 0..std::cmp::min(rows, cols.saturating_sub(offset_u)) {
                    diag_data.push(data[i * cols + i + offset_u]);
                }
            } else {
                let offset_u = (-offset) as usize;
                for i in 0..std::cmp::min(rows.saturating_sub(offset_u), cols) {
                    diag_data.push(data[(i + offset_u) * cols + i]);
                }
            }

            let diag_len = diag_data.len();
            let result = py_result!(torsh_tensor::Tensor::from_data(
                diag_data,
                vec![diag_len],
                self.tensor.device()
            ))?;
            Ok(PyTensor { tensor: result })
        } else {
            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "diag() only supports 1D or 2D tensors",
            ))
        }
    }

    fn trace(&self) -> PyResult<PyTensor> {
        let shape = self.tensor.shape().dims().to_vec();

        if shape.len() != 2 {
            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "trace() requires a 2D tensor",
            ));
        }

        let rows = shape[0];
        let cols = shape[1];
        let min_dim = std::cmp::min(rows, cols);

        let data = py_result!(self.tensor.data())?;
        let mut trace_sum = 0.0;

        for i in 0..min_dim {
            trace_sum += data[i * cols + i];
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            vec![trace_sum],
            vec![],
            self.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    fn norm(
        &self,
        p: Option<f32>,
        _dim: Option<Vec<i64>>,
        _keepdim: Option<bool>,
    ) -> PyResult<PyTensor> {
        // Only the global L2 norm is implemented; `p`, `_dim`, and `_keepdim`
        // are accepted for API parity but ignored.
        let _p = p.unwrap_or(2.0);
        let result = py_result!(self.tensor.norm())?;
        Ok(PyTensor { tensor: result })
    }

    fn std(
        &self,
        dim: Option<Vec<i64>>,
        keepdim: Option<bool>,
        unbiased: Option<bool>,
    ) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let unbiased = unbiased.unwrap_or(true);
        let stat_mode = if unbiased {
            torsh_tensor::stats::StatMode::Sample
        } else {
            torsh_tensor::stats::StatMode::Population
        };
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.std(Some(&usize_dims), keepdim, stat_mode))?
        } else {
            py_result!(self.tensor.std(None, keepdim, stat_mode))?
        };
        Ok(PyTensor { tensor: result })
    }

    fn var(
        &self,
        dim: Option<Vec<i64>>,
        keepdim: Option<bool>,
        unbiased: Option<bool>,
    ) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let unbiased = unbiased.unwrap_or(true);
        let stat_mode = if unbiased {
            torsh_tensor::stats::StatMode::Sample
        } else {
            torsh_tensor::stats::StatMode::Population
        };
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.var(Some(&usize_dims), keepdim, stat_mode))?
        } else {
            py_result!(self.tensor.var(None, keepdim, stat_mode))?
        };
        Ok(PyTensor { tensor: result })
    }

    fn argmax(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // `_keepdim` is ignored; indices are returned as an f32 tensor.
        let result = if let Some(d) = dim {
            py_result!(self.tensor.argmax(Some(d as i32)))?
        } else {
            py_result!(self.tensor.argmax(None))?
        };
        let result_f32 = py_result!(result.to_f32_simd())?;
        Ok(PyTensor { tensor: result_f32 })
    }

    fn argmin(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // `_keepdim` is ignored; indices are returned as an f32 tensor.
        let result = if let Some(d) = dim {
            py_result!(self.tensor.argmin(Some(d as i32)))?
        } else {
            py_result!(self.tensor.argmin(None))?
        };
        let result_f32 = py_result!(result.to_f32_simd())?;
        Ok(PyTensor { tensor: result_f32 })
    }
}

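// Internal helpers that are not exposed to Python.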
impl PyTensor {
    /// Build a 1-D tensor from a flat Python list of numbers. Nested lists
    /// are rejected by the per-item `f32` extraction below.
    fn from_py_list(list: &Bound<'_, PyList>, device: DeviceType) -> PyResult<Tensor<f32>> {
        let mut data = Vec::new();
        let len = list.len();

        for i in 0..len {
            let item = list.get_item(i)?;
            if let Ok(val) = item.extract::<f32>() {
                data.push(val);
            } else {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Cannot convert item at index {} to f32",
                    i
                )));
            }
        }

        py_result!(Tensor::from_data(data, vec![len], device))
    }

    fn create_nested_list(
        &self,
        py: Python<'_>,
        data: &[f32],
        shape: &[usize],
        dim: usize,
        index: &mut usize,
    ) -> PyResult<Py<PyAny>> {
        if dim == shape.len() - 1 {
            // Innermost dimension: emit the scalar values themselves.
            let mut items = Vec::new();
            for _ in 0..shape[dim] {
                items.push(data[*index].into_pyobject(py)?.into_any().unbind());
                *index += 1;
            }
            Ok(items.into_pyobject(py)?.into_any().unbind())
        } else {
            // Outer dimensions: recurse to build one nested list per slice.
            let mut items = Vec::new();
            for _ in 0..shape[dim] {
                let nested = self.create_nested_list(py, data, shape, dim + 1, index)?;
                items.push(nested);
            }
            Ok(items.into_pyobject(py)?.into_any().unbind())
        }
    }

    fn bool_to_float_tensor(
        bool_tensor: torsh_tensor::Tensor<bool>,
    ) -> torsh_core::error::Result<torsh_tensor::Tensor<f32>> {
        let bool_data = bool_tensor.data()?;
        let float_data: Vec<f32> = bool_data
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        let shape = bool_tensor.shape().dims().to_vec();
        let device = bool_tensor.device();

        torsh_tensor::Tensor::from_data(float_data, shape, device)
    }
}