use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
use numpy::{PyArray1, PyArray2, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods, ToPyArray};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyList};
use scirs2_core::ndarray::{Array, IxDyn};
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;

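/// Python-facing wrapper around a single-precision (`f32`) torsh tensor.
/// All methods delegate to the underlying `torsh_tensor::Tensor<f32>`.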
#[pyclass(name = "Tensor")]
#[derive(Clone)]
pub struct PyTensor {
    pub(crate) tensor: Tensor<f32>,
}

#[pymethods]
impl PyTensor {
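    /// Build a tensor from a numpy `float32` array (1-D, 2-D, or N-D), a flat
    /// Python list of numbers, or a single scalar (which yields a 0-dim
    /// tensor). The `dtype` argument is accepted but currently ignored, and
    /// nested lists are not traversed.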
    #[new]
    pub fn new(
        data: &Bound<'_, PyAny>,
        _dtype: Option<PyDType>,
        device: Option<PyDevice>,
        requires_grad: Option<bool>,
    ) -> PyResult<Self> {
        let device = device.map(|d| d.device).unwrap_or(DeviceType::Cpu);
        let requires_grad = requires_grad.unwrap_or(false);

        let tensor = if let Ok(arr) = data.clone().cast_into::<PyArray1<f32>>() {
            let data = arr.to_vec()?;
            let shape = vec![data.len()];
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(arr) = data.clone().cast_into::<PyArray2<f32>>() {
            let data = arr.to_vec()?;
            let shape = arr.shape().to_vec();
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(arr) = data.clone().cast_into::<PyArrayDyn<f32>>() {
            let data = arr.to_vec()?;
            let shape = arr.shape().to_vec();
            py_result!(Tensor::from_data(data, shape, device))?
        } else if let Ok(list) = data.clone().cast_into::<PyList>() {
            Self::from_py_list(&list, device)?
        } else if let Ok(scalar) = data.extract::<f32>() {
            py_result!(Tensor::from_data(vec![scalar], vec![], device))?
        } else {
            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
                "Unsupported data type for tensor creation",
            ));
        };

        let tensor = tensor.requires_grad_(requires_grad);

        Ok(Self { tensor })
    }

    #[getter]
    fn shape(&self) -> Vec<usize> {
        self.tensor.shape().dims().to_vec()
    }

    #[getter]
    fn ndim(&self) -> usize {
        self.tensor.ndim()
    }

    #[getter]
    pub fn numel(&self) -> usize {
        self.tensor.numel()
    }

    #[getter]
    fn dtype(&self) -> PyDType {
        PyDType::from(self.tensor.dtype())
    }

    #[getter]
    fn device(&self) -> PyDevice {
        PyDevice::from(self.tensor.device())
    }

    #[getter]
    fn requires_grad(&self) -> bool {
        self.tensor.requires_grad()
    }

    fn __repr__(&self) -> String {
        let binding = self.tensor.shape();
        let shape = binding.dims();
        let device_str = match self.tensor.device() {
            DeviceType::Cpu => String::new(),
            dev => format!(", device='{}'", PyDevice::from(dev)),
        };
        let grad_str = if self.tensor.requires_grad() {
            ", requires_grad=True"
        } else {
            ""
        };

        // NOTE: unlike torch.Tensor.__repr__, this prints the shape (as a
        // debug slice) in place of the element data.
        format!(
            "tensor({:?}, shape={}{}{}, dtype={})",
            shape,
            shape
                .iter()
                .map(|d| d.to_string())
                .collect::<Vec<_>>()
                .join(", "),
            device_str,
            grad_str,
            PyDType::from(self.tensor.dtype())
        )
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

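    /// Convert to a numpy object: a 0-dim tensor becomes a Python float,
    /// a 1-D tensor a flat array, and anything higher-dimensional an ndarray
    /// reshaped to the tensor's dimensions.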
    fn numpy(&self) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let data = py_result!(self.tensor.to_vec())?;
            let binding = self.tensor.shape();
            let shape = binding.dims();

            let shape_vec: Vec<usize> = shape.to_vec();

            match shape.len() {
                0 => Ok(data[0].into_pyobject(py)?.into_any().unbind()),
                1 => {
                    let array = data.to_pyarray(py);
                    Ok(array.into_pyobject(py)?.into_any().unbind())
                }
                // The 2-D and N-D cases were identical, so they share one arm.
                _ => {
                    let ndarray = Array::from_vec(data)
                        .into_dyn()
                        .to_shape(IxDyn(&shape_vec))
                        .map_err(|e| {
                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                                "Shape error: {}",
                                e
                            ))
                        })?
                        .into_owned();
                    let array = ndarray.to_pyarray(py);
                    Ok(array.into_pyobject(py)?.into_any().unbind())
                }
            }
        })
    }

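    /// Return the value of a one-element tensor as a Python scalar, matching
    /// `torch.Tensor.item` semantics.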
    fn item(&self) -> PyResult<f32> {
        if self.tensor.numel() != 1 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Only one element tensors can be converted to Python scalars",
            ));
        }
        let data = py_result!(self.tensor.to_vec())?;
        Ok(data[0])
    }

    fn tolist(&self) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let data = py_result!(self.tensor.to_vec())?;
            let binding = self.tensor.shape();
            let shape = binding.dims();

            if shape.is_empty() {
                Ok(data[0].into_pyobject(py)?.into_any().unbind())
            } else {
                let nested_list = self.create_nested_list(py, &data, shape, 0, &mut 0)?;
                Ok(nested_list)
            }
        })
    }

    fn copy(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

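    /// Row-major (C-contiguous) strides derived from the shape, measured in
    /// elements rather than bytes: shape `[2, 3, 4]` gives `[12, 4, 1]`.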
    fn stride(&self) -> Vec<usize> {
        let binding = self.tensor.shape();
        let shape = binding.dims();
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    fn itemsize(&self) -> usize {
        std::mem::size_of::<f32>()
    }

    fn nbytes(&self) -> usize {
        self.tensor.numel() * self.itemsize()
    }

    fn is_c_contiguous(&self) -> bool {
        self.tensor.is_contiguous()
    }

    fn add(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.add(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn sub(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sub(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn mul(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.mul(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn div(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.div(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn pow(&self, exponent: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.pow(exponent))?;
        Ok(PyTensor { tensor: result })
    }

    fn add_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.add_scalar(scalar))?;
        Ok(PyTensor { tensor: result })
    }

    fn mul_scalar(&self, scalar: f32) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.mul_scalar(scalar))?;
        Ok(PyTensor { tensor: result })
    }

    fn reshape(&self, shape: Vec<i64>) -> PyResult<PyTensor> {
        let i32_shape: Vec<i32> = shape.iter().map(|&x| x as i32).collect();
        let result = py_result!(self.tensor.reshape(&i32_shape))?;
        Ok(PyTensor { tensor: result })
    }

    fn transpose(&self, dim0: i64, dim1: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.transpose(dim0 as i32, dim1 as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn t(&self) -> PyResult<PyTensor> {
        if self.tensor.ndim() != 2 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "t() can only be called on 2D tensors",
            ));
        }
        self.transpose(0, 1)
    }

    fn squeeze(&self, dim: Option<i64>) -> PyResult<PyTensor> {
        // NOTE: unlike torch.squeeze, a missing dim defaults to dimension 0
        // instead of squeezing all size-1 dimensions.
        let dim_to_squeeze = dim.map(|d| d as i32).unwrap_or(0i32);
        let result = py_result!(self.tensor.squeeze(dim_to_squeeze))?;
        Ok(PyTensor { tensor: result })
    }

    fn unsqueeze(&self, dim: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.unsqueeze(dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn flatten(&self, _start_dim: Option<i64>, _end_dim: Option<i64>) -> PyResult<PyTensor> {
        // NOTE: start_dim/end_dim are accepted for API compatibility but
        // currently ignored; the tensor is always flattened to 1-D.
        let result = py_result!(self.tensor.flatten())?;
        Ok(PyTensor { tensor: result })
    }

    fn sum(&self, _dim: Option<Vec<i64>>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // NOTE: dim/keepdim are currently ignored; this always reduces over
        // all elements.
        let result = py_result!(self.tensor.sum())?;
        Ok(PyTensor { tensor: result })
    }

    fn mean(&self, dim: Option<Vec<i64>>, keepdim: Option<bool>) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.mean(Some(&usize_dims), keepdim))?
        } else {
            py_result!(self.tensor.mean(None, keepdim))?
        };
        Ok(PyTensor { tensor: result })
    }

    fn max(&self, dim: Option<i64>, keepdim: Option<bool>) -> PyResult<PyTensor> {
        let dim_opt = dim.map(|d| d as usize);
        let keepdim = keepdim.unwrap_or(false);
        let result = py_result!(self.tensor.max(dim_opt, keepdim))?;
        Ok(PyTensor { tensor: result })
    }

    fn min(&self, _dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        // NOTE: unlike max, dim/keepdim are currently ignored; this always
        // returns the global minimum.
        let result = py_result!(self.tensor.min())?;
        Ok(PyTensor { tensor: result })
    }

    fn matmul(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.matmul(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn dot(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.dot(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn relu(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.relu())?;
        Ok(PyTensor { tensor: result })
    }

    fn sigmoid(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sigmoid())?;
        Ok(PyTensor { tensor: result })
    }

    fn tanh(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.tanh())?;
        Ok(PyTensor { tensor: result })
    }

    fn softmax(&self, dim: i64) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.softmax(dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn sin(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sin())?;
        Ok(PyTensor { tensor: result })
    }

    fn cos(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.cos())?;
        Ok(PyTensor { tensor: result })
    }

    fn exp(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.exp())?;
        Ok(PyTensor { tensor: result })
    }

    fn log(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.log())?;
        Ok(PyTensor { tensor: result })
    }

    fn sqrt(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.sqrt())?;
        Ok(PyTensor { tensor: result })
    }

    fn abs(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.abs())?;
        Ok(PyTensor { tensor: result })
    }

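    /// Element-wise equality. Like the other comparison methods below, this
    /// returns a float tensor of 0.0/1.0 values (via `bool_to_float_tensor`)
    /// rather than a bool tensor, since this wrapper only exposes `f32` data.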
    fn eq(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.eq(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ne(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ne(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn lt(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.lt(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn gt(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.gt(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn le(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.le(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ge(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ge(&other.tensor))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn maximum(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.maximum(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn minimum(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.minimum(&other.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    fn eq_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.eq_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ne_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ne_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn lt_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.lt_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn gt_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.gt_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn le_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.le_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn ge_scalar(&self, value: f32) -> PyResult<PyTensor> {
        let bool_result = py_result!(self.tensor.ge_scalar(value))?;
        let float_tensor = py_result!(Self::bool_to_float_tensor(bool_result))?;
        Ok(PyTensor {
            tensor: float_tensor,
        })
    }

    fn clone_tensor(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn detach(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.tensor.detach(),
        })
    }

    fn to_device(&self, device: PyDevice) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.clone().to(device.device))?;
        Ok(PyTensor { tensor: result })
    }

    fn is_contiguous(&self) -> bool {
        self.tensor.is_contiguous()
    }

    fn contiguous(&self) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.contiguous())?;
        Ok(PyTensor { tensor: result })
    }

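    /// Clamp values to `[min, max]`. One-sided clamping substitutes
    /// `f32::MIN`/`f32::MAX` for the missing bound; omitting both is an error.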
    fn clamp(&self, min: Option<f32>, max: Option<f32>) -> PyResult<PyTensor> {
        let result = if let (Some(min_val), Some(max_val)) = (min, max) {
            py_result!(self.tensor.clamp(min_val, max_val))?
        } else if let Some(min_val) = min {
            py_result!(self.tensor.clamp(min_val, f32::MAX))?
        } else if let Some(max_val) = max {
            py_result!(self.tensor.clamp(f32::MIN, max_val))?
        } else {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "At least one of min or max must be specified",
            ));
        };
        Ok(PyTensor { tensor: result })
    }

    fn fill_(&mut self, value: f32) -> PyResult<PyTensor> {
        py_result!(self.tensor.fill_(value))?;
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn zero_(&mut self) -> PyResult<PyTensor> {
        py_result!(self.tensor.zero_())?;
        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

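    /// Fill the tensor in place with uniform samples over `[from, to)`
    /// (assuming the re-exported rand semantics), defaulting to `[0, 1)`.
    /// Note this rebuilds the tensor from the sampled data rather than
    /// writing through the existing storage.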
    fn uniform_(&mut self, from: Option<f32>, to: Option<f32>) -> PyResult<PyTensor> {
        use scirs2_core::random::{thread_rng, Distribution, Uniform};

        let from = from.unwrap_or(0.0);
        let to = to.unwrap_or(1.0);

        let mut rng = thread_rng();
        let dist = Uniform::new(from, to).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Invalid uniform distribution parameters: {}",
                e
            ))
        })?;

        let mut data = py_result!(self.tensor.data())?;
        for val in data.iter_mut() {
            *val = dist.sample(&mut rng);
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn normal_(&mut self, mean: Option<f32>, std: Option<f32>) -> PyResult<PyTensor> {
        use scirs2_core::random::{thread_rng, Distribution, Normal};

        let mean = mean.unwrap_or(0.0);
        let std = std.unwrap_or(1.0);

        let normal = Normal::new(mean as f64, std as f64).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Invalid normal distribution parameters: {}",
                e
            ))
        })?;

        let mut rng = thread_rng();
        let mut data = py_result!(self.tensor.data())?;
        for val in data.iter_mut() {
            *val = normal.sample(&mut rng) as f32;
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        self.tensor = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor {
            tensor: self.tensor.clone(),
        })
    }

    fn repeat(&self, repeats: Vec<usize>) -> PyResult<PyTensor> {
        let result = py_result!(self.tensor.repeat(&repeats))?;
        Ok(PyTensor { tensor: result })
    }

    fn expand(&self, size: Vec<i64>) -> PyResult<PyTensor> {
        let size: Vec<usize> = size.iter().map(|&x| x as usize).collect();
        let result = py_result!(self.tensor.expand(&size))?;
        Ok(PyTensor { tensor: result })
    }

    fn expand_as(&self, other: &PyTensor) -> PyResult<PyTensor> {
        let other_shape: Vec<usize> = other.tensor.shape().dims().to_vec();
        let result = py_result!(self.tensor.expand(&other_shape))?;
        Ok(PyTensor { tensor: result })
    }

    fn index_select(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.index_select(dim as i32, &index_i64))?;
        Ok(PyTensor { tensor: result })
    }

    fn gather(&self, dim: i64, index: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.gather(dim as usize, &index_i64))?;
        Ok(PyTensor { tensor: result })
    }

    fn scatter(&self, dim: i64, index: &PyTensor, src: &PyTensor) -> PyResult<PyTensor> {
        let index_i32 = py_result!(index.tensor.to_i32_simd())?;
        let index_i64 = py_result!(index_i32.to_i64_simd())?;
        let result = py_result!(self.tensor.scatter(dim as usize, &index_i64, &src.tensor))?;
        Ok(PyTensor { tensor: result })
    }

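    /// Replace elements where `mask` is nonzero with `value`. The mask is a
    /// float tensor with the same element count; broadcasting is not
    /// supported.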
    fn masked_fill(&self, mask: &PyTensor, value: f32) -> PyResult<PyTensor> {
        let mut data = py_result!(self.tensor.data())?;
        let mask_data = py_result!(mask.tensor.data())?;

        if data.len() != mask_data.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Mask must have the same number of elements as tensor",
            ));
        }

        for (val, &mask_val) in data.iter_mut().zip(mask_data.iter()) {
            if mask_val != 0.0 {
                *val = value;
            }
        }

        let shape = self.tensor.shape().dims().to_vec();
        let device = self.tensor.device();
        let result = py_result!(torsh_tensor::Tensor::from_data(data, shape, device))?
            .requires_grad_(self.tensor.requires_grad());

        Ok(PyTensor { tensor: result })
    }

    fn masked_select(&self, mask: &PyTensor) -> PyResult<PyTensor> {
        let mask_data = py_result!(mask.tensor.data())?;
        let bool_data: Vec<bool> = mask_data.iter().map(|&x| x != 0.0).collect();
        let mask_bool = py_result!(torsh_tensor::Tensor::from_data(
            bool_data,
            mask.tensor.shape().dims().to_vec(),
            mask.tensor.device()
        ))?;
        let result = py_result!(self.tensor.masked_select(&mask_bool))?;
        Ok(PyTensor { tensor: result })
    }

    fn cat(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
            tensors.iter().map(|t| &t.tensor).collect();
        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

    fn stack(&self, tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
        // NOTE: currently delegates to cat, so unlike torch.stack this does
        // not insert a new dimension at `dim`.
        let tensor_refs: Vec<&torsh_tensor::Tensor<f32>> =
            tensors.iter().map(|t| &t.tensor).collect();
        let result = py_result!(torsh_tensor::Tensor::cat(&tensor_refs, dim as i32))?;
        Ok(PyTensor { tensor: result })
    }

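    /// Split the tensor into `chunks` pieces along `dim` using ceiling
    /// division, so trailing chunks may be smaller. For example, a dimension
    /// of size 10 split into 3 chunks yields pieces of size 4, 4, and 2.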
    fn chunk(&self, chunks: usize, dim: i64) -> PyResult<Vec<PyTensor>> {
        if chunks == 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "chunk expects `chunks` to be greater than 0",
            ));
        }

        let shape = self.tensor.shape().dims().to_vec();
        let dim_usize = dim as usize;

        if dim_usize >= shape.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                shape.len()
            )));
        }

        let dim_size = shape[dim_usize];
        let chunk_size = (dim_size + chunks - 1) / chunks; // ceiling division

        let mut result = Vec::new();
        for i in 0..chunks {
            let start = i * chunk_size;
            if start >= dim_size {
                break;
            }
            let end = std::cmp::min(start + chunk_size, dim_size);

            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, end - start))?;
            result.push(PyTensor { tensor: chunk });
        }

        Ok(result)
    }

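    /// Split along `dim` into consecutive pieces of the given sizes, which
    /// must sum to the size of that dimension.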
    fn split(&self, split_sizes: Vec<usize>, dim: i64) -> PyResult<Vec<PyTensor>> {
        let shape = self.tensor.shape().dims().to_vec();
        let dim_usize = dim as usize;

        if dim_usize >= shape.len() {
            return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                shape.len()
            )));
        }

        let total_size: usize = split_sizes.iter().sum();
        if total_size != shape[dim_usize] {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Split sizes sum to {} but dimension {} has size {}",
                total_size, dim, shape[dim_usize]
            )));
        }

        let mut result = Vec::new();
        let mut start = 0;

        for &size in &split_sizes {
            let chunk = py_result!(self.tensor.narrow(dim as i32, start as i64, size))?;
            result.push(PyTensor { tensor: chunk });
            start += size;
        }

        Ok(result)
    }

    fn permute(&self, dims: Vec<i64>) -> PyResult<PyTensor> {
        let dims: Vec<i32> = dims.iter().map(|&x| x as i32).collect();
        let result = py_result!(self.tensor.permute(&dims))?;
        Ok(PyTensor { tensor: result })
    }

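    /// For a 1-D input, build a square matrix with the values placed on the
    /// `diagonal`-th diagonal; for a 2-D input, extract that diagonal as a
    /// 1-D tensor.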
    fn diag(&self, diagonal: Option<i64>) -> PyResult<PyTensor> {
        let offset = diagonal.unwrap_or(0);
        let shape = self.tensor.shape().dims().to_vec();

        if shape.len() == 1 {
            let n = shape[0];
            let size = n + offset.abs() as usize;
            let mut data = vec![0.0; size * size];
            let input_data = py_result!(self.tensor.data())?;

            for (i, &val) in input_data.iter().enumerate() {
                if offset >= 0 {
                    let row = i;
                    let col = i + offset as usize;
                    data[row * size + col] = val;
                } else {
                    let row = i + (-offset) as usize;
                    let col = i;
                    data[row * size + col] = val;
                }
            }

            let result = py_result!(torsh_tensor::Tensor::from_data(
                data,
                vec![size, size],
                self.tensor.device()
            ))?;
            Ok(PyTensor { tensor: result })
        } else if shape.len() == 2 {
            let rows = shape[0];
            let cols = shape[1];
            let data = py_result!(self.tensor.data())?;

            let mut diag_data = Vec::new();

            if offset >= 0 {
                let offset_u = offset as usize;
                for i in 0..std::cmp::min(rows, cols.saturating_sub(offset_u)) {
                    diag_data.push(data[i * cols + i + offset_u]);
                }
            } else {
                let offset_u = (-offset) as usize;
                for i in 0..std::cmp::min(rows.saturating_sub(offset_u), cols) {
                    diag_data.push(data[(i + offset_u) * cols + i]);
                }
            }

            let diag_len = diag_data.len();
            let result = py_result!(torsh_tensor::Tensor::from_data(
                diag_data,
                vec![diag_len],
                self.tensor.device()
            ))?;
            Ok(PyTensor { tensor: result })
        } else {
            Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "diag() only supports 1D or 2D tensors",
            ))
        }
    }

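    /// Sum of the main diagonal of a 2-D tensor, returned as a 0-dim tensor.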
    fn trace(&self) -> PyResult<PyTensor> {
        let shape = self.tensor.shape().dims().to_vec();

        if shape.len() != 2 {
            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "trace() requires a 2D tensor",
            ));
        }

        let rows = shape[0];
        let cols = shape[1];
        let min_dim = std::cmp::min(rows, cols);

        let data = py_result!(self.tensor.data())?;
        let mut trace_sum = 0.0;

        for i in 0..min_dim {
            trace_sum += data[i * cols + i];
        }

        let result = py_result!(torsh_tensor::Tensor::from_data(
            vec![trace_sum],
            vec![],
            self.tensor.device()
        ))?;

        Ok(PyTensor { tensor: result })
    }

    fn norm(
        &self,
        p: Option<f32>,
        _dim: Option<Vec<i64>>,
        _keepdim: Option<bool>,
    ) -> PyResult<PyTensor> {
        // NOTE: p/dim/keepdim are currently ignored; this always delegates to
        // the underlying Tensor::norm over all elements.
        let _p = p.unwrap_or(2.0);
        let result = py_result!(self.tensor.norm())?;
        Ok(PyTensor { tensor: result })
    }

    fn std(
        &self,
        dim: Option<Vec<i64>>,
        keepdim: Option<bool>,
        unbiased: Option<bool>,
    ) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let unbiased = unbiased.unwrap_or(true);
        let stat_mode = if unbiased {
            torsh_tensor::stats::StatMode::Sample
        } else {
            torsh_tensor::stats::StatMode::Population
        };
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.std(Some(&usize_dims), keepdim, stat_mode))?
        } else {
            py_result!(self.tensor.std(None, keepdim, stat_mode))?
        };
        Ok(PyTensor { tensor: result })
    }

    fn var(
        &self,
        dim: Option<Vec<i64>>,
        keepdim: Option<bool>,
        unbiased: Option<bool>,
    ) -> PyResult<PyTensor> {
        let keepdim = keepdim.unwrap_or(false);
        let unbiased = unbiased.unwrap_or(true);
        let stat_mode = if unbiased {
            torsh_tensor::stats::StatMode::Sample
        } else {
            torsh_tensor::stats::StatMode::Population
        };
        let result = if let Some(dims) = dim {
            let usize_dims: Vec<usize> = dims.iter().map(|&x| x as usize).collect();
            py_result!(self.tensor.var(Some(&usize_dims), keepdim, stat_mode))?
        } else {
            py_result!(self.tensor.var(None, keepdim, stat_mode))?
        };
        Ok(PyTensor { tensor: result })
    }

    // NOTE: argmax/argmin ignore keepdim and cast the integer indices back to
    // f32, since only f32 tensors are exposed to Python.
    fn argmax(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        let result = if let Some(d) = dim {
            py_result!(self.tensor.argmax(Some(d as i32)))?
        } else {
            py_result!(self.tensor.argmax(None))?
        };
        let result_f32 = py_result!(result.to_f32_simd())?;
        Ok(PyTensor { tensor: result_f32 })
    }

    fn argmin(&self, dim: Option<i64>, _keepdim: Option<bool>) -> PyResult<PyTensor> {
        let result = if let Some(d) = dim {
            py_result!(self.tensor.argmin(Some(d as i32)))?
        } else {
            py_result!(self.tensor.argmin(None))?
        };
        let result_f32 = py_result!(result.to_f32_simd())?;
        Ok(PyTensor { tensor: result_f32 })
    }
}

impl PyTensor {
    fn from_py_list(list: &Bound<'_, PyList>, device: DeviceType) -> PyResult<Tensor<f32>> {
        let mut data = Vec::new();
        let len = list.len();

        for i in 0..len {
            let item = list.get_item(i)?;
            if let Ok(val) = item.extract::<f32>() {
                data.push(val);
            } else {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Cannot convert item at index {} to f32",
                    i
                )));
            }
        }

        py_result!(Tensor::from_data(data, vec![len], device))
    }

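    /// Recursively build the nested Python list for `tolist`, walking the
    /// flat data buffer in row-major order with a running `index` cursor.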
    fn create_nested_list(
        &self,
        py: Python<'_>,
        data: &[f32],
        shape: &[usize],
        dim: usize,
        index: &mut usize,
    ) -> PyResult<Py<PyAny>> {
        if dim == shape.len() - 1 {
            let mut items = Vec::new();
            for _ in 0..shape[dim] {
                items.push(data[*index].into_pyobject(py)?.into_any().unbind());
                *index += 1;
            }
            Ok(items.into_pyobject(py)?.into_any().unbind())
        } else {
            let mut items = Vec::new();
            for _ in 0..shape[dim] {
                let nested = self.create_nested_list(py, data, shape, dim + 1, index)?;
                items.push(nested);
            }
            Ok(items.into_pyobject(py)?.into_any().unbind())
        }
    }

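    /// Convert a bool tensor to an f32 tensor of 0.0/1.0 values, preserving
    /// shape and device; used by the comparison methods above.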
    fn bool_to_float_tensor(
        bool_tensor: torsh_tensor::Tensor<bool>,
    ) -> torsh_core::error::Result<torsh_tensor::Tensor<f32>> {
        let bool_data = bool_tensor.data()?;
        let float_data: Vec<f32> = bool_data
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        let shape = bool_tensor.shape().dims().to_vec();
        let device = bool_tensor.device();

        torsh_tensor::Tensor::from_data(float_data, shape, device)
    }
}