1use super::core::{Tensor, TensorStorage};
8#[cfg(feature = "gpu")]
9use crate::Device;
10use crate::{Result, TensorError};
11use scirs2_core::numeric::Zero;
12
13impl<T: Clone> Tensor<T> {
15 pub fn backward(&self) -> Result<()>
17 where
18 T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
19 {
20 if !self.requires_grad() {
21 return Err(TensorError::GradientNotEnabled {
22 operation: "backward".to_string(),
23 suggestion: "Call tensor.requires_grad_(true) before computation".to_string(),
24 context: None,
25 });
26 }
27
28 if self.shape().dims().iter().product::<usize>() != 1 {
30 return Err(TensorError::invalid_shape_simple(
31 "backward() can only be called on scalar tensors".to_string(),
32 ));
33 }
34
35 self.init_gradient()?;
38
39 Ok(())
53 }
54
55 pub fn backward_with_options(&self, retain_graph: bool, create_graph: bool) -> Result<()>
57 where
58 T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
59 {
60 if !self.requires_grad() {
61 return Err(TensorError::GradientNotEnabled {
62 operation: "backward".to_string(),
63 suggestion: "Call tensor.requires_grad_(true) before computation".to_string(),
64 context: None,
65 });
66 }
67
68 if self.shape().dims().iter().product::<usize>() != 1 {
70 return Err(TensorError::invalid_shape_simple(
71 "backward() can only be called on scalar tensors".to_string(),
72 ));
73 }
74
75 self.init_gradient()?;
77
78 if retain_graph {
83 }
87
88 if create_graph {
89 }
93
94 Ok(())
98 }
99
100 fn init_gradient(&self) -> Result<()>
102 where
103 T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
104 {
105 if self.grad().is_some() {
107 return Ok(());
108 }
109
110 Ok(())
122 }
123}
124
impl<T> Tensor<T>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    /// Element-wise addition; delegates to [`crate::ops::add`].
    pub fn add(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        crate::ops::add(self, other)
    }

    /// Element-wise subtraction; delegates to [`crate::ops::sub`].
    pub fn sub(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Sub<Output = T>,
    {
        crate::ops::sub(self, other)
    }

    /// Element-wise multiplication; delegates to [`crate::ops::mul`].
    pub fn mul(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T>,
    {
        crate::ops::mul(self, other)
    }

    /// Element-wise division; delegates to [`crate::ops::div`].
    pub fn div(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Div<Output = T>,
    {
        crate::ops::div(self, other)
    }

    /// Element-wise power; delegates to [`crate::ops::pow`].
    pub fn pow(&self, other: &Self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float,
    {
        crate::ops::pow(self, other)
    }
176
177 pub fn log(&self) -> Result<Self>
179 where
180 T: scirs2_core::num_traits::Float,
181 {
182 match &self.storage {
183 TensorStorage::Cpu(arr) => {
184 let result = arr.mapv(|x| x.ln());
185 Ok(Self::from_array(result))
186 }
187 #[cfg(feature = "gpu")]
188 TensorStorage::Gpu(buffer) => self.log_gpu_impl(buffer),
189 }
190 }
191
    /// GPU implementation of [`Self::log`]: runs the `Log` unary kernel on
    /// `buffer` and wraps the result buffer with this tensor's shape.
    #[cfg(feature = "gpu")]
    fn log_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float
            + bytemuck::Pod
            + bytemuck::Zeroable
            + Clone
            + Send
            + Sync
            + 'static,
    {
        use crate::gpu::ops::{execute_unary_op, UnaryOp};
        let result_buffer = execute_unary_op(buffer, UnaryOp::Log)?;
        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
    }
207
208 pub fn neg(&self) -> Result<Self>
210 where
211 T: std::ops::Neg<Output = T>,
212 {
213 match &self.storage {
214 TensorStorage::Cpu(arr) => {
215 let result = arr.mapv(|x| -x);
216 Ok(Self::from_array(result))
217 }
218 #[cfg(feature = "gpu")]
219 TensorStorage::Gpu(buffer) => self.neg_gpu_impl(buffer),
220 }
221 }
222
    /// GPU implementation of [`Self::neg`]: runs the `Neg` unary kernel on
    /// `buffer` and wraps the result buffer with this tensor's shape.
    #[cfg(feature = "gpu")]
    fn neg_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
    where
        T: std::ops::Neg<Output = T>
            + bytemuck::Pod
            + bytemuck::Zeroable
            + Clone
            + Send
            + Sync
            + 'static,
    {
        use crate::gpu::ops::{execute_unary_op, UnaryOp};
        let result_buffer = execute_unary_op(buffer, UnaryOp::Neg)?;
        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
    }
238
    /// Matrix multiplication; delegates to [`crate::ops::matmul`].
    pub fn matmul(&self, other: &Self) -> Result<Self> {
        crate::ops::matmul(self, other)
    }

    /// Rectified linear unit activation; delegates to
    /// [`crate::ops::activation::relu`].
    pub fn relu(&self) -> Result<Self>
    where
        T: PartialOrd + scirs2_core::num_traits::Zero + bytemuck::Pod + bytemuck::Zeroable,
    {
        crate::ops::activation::relu(self)
    }

    /// Logistic sigmoid activation; delegates to
    /// [`crate::ops::activation::sigmoid`].
    pub fn sigmoid(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + bytemuck::Pod + bytemuck::Zeroable,
    {
        crate::ops::activation::sigmoid(self)
    }

    /// Hyperbolic tangent activation; delegates to
    /// [`crate::ops::activation::tanh`].
    pub fn tanh(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + bytemuck::Pod + bytemuck::Zeroable,
    {
        crate::ops::activation::tanh(self)
    }

    /// GELU activation; delegates to [`crate::ops::activation::gelu`].
    pub fn gelu(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + bytemuck::Pod,
    {
        crate::ops::activation::gelu(self)
    }

    /// Swish (SiLU) activation; delegates to
    /// [`crate::ops::activation::swish`].
    pub fn swish(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + bytemuck::Pod,
    {
        crate::ops::activation::swish(self)
    }

    /// Mish activation; delegates to [`crate::ops::activation::mish`].
    pub fn mish(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float
            + Send
            + Sync
            + 'static
            + bytemuck::Pod
            + bytemuck::Zeroable,
    {
        crate::ops::activation::mish(self)
    }

    /// Softmax along `axis`; delegates to
    /// [`crate::ops::activation::softmax`]. NOTE(review): the meaning of
    /// `axis = None` (presumably the last axis) is defined by the callee —
    /// confirm there.
    pub fn softmax(&self, axis: Option<i32>) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float
            + std::ops::Sub<Output = T>
            + std::ops::Add<Output = T>
            + std::ops::Div<Output = T>
            + std::iter::Sum
            + Send
            + Sync
            + bytemuck::Pod,
    {
        crate::ops::activation::softmax(self, axis)
    }

    /// ELU activation with slope `alpha` for negative inputs; delegates to
    /// [`crate::ops::activation::elu`].
    pub fn elu(&self, alpha: T) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + PartialOrd + bytemuck::Pod,
    {
        crate::ops::activation::elu(self, alpha)
    }

    /// Leaky ReLU with slope `alpha` for negative inputs; delegates to
    /// [`crate::ops::activation::leaky_relu`].
    pub fn leaky_relu(&self, alpha: T) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + PartialOrd + bytemuck::Pod,
    {
        crate::ops::activation::leaky_relu(self, alpha)
    }

    /// Hard-swish activation; delegates to
    /// [`crate::ops::activation::hard_swish`].
    pub fn hard_swish(&self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + PartialOrd,
    {
        crate::ops::activation::hard_swish(self)
    }

    /// Parametric ReLU with learnable slope tensor `alpha`; delegates to
    /// [`crate::ops::activation::prelu`].
    pub fn prelu(&self, alpha: &Self) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + PartialOrd,
    {
        crate::ops::activation::prelu(self, alpha)
    }

    /// Returns a tensor with the same data viewed as `shape`; delegates to
    /// [`crate::ops::reshape`].
    pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
        crate::ops::reshape(self, shape)
    }

    /// Transposes the tensor; delegates to [`crate::ops::transpose`].
    pub fn transpose(&self) -> Result<Self> {
        crate::ops::transpose(self)
    }

    /// Slices the tensor by one contiguous range per dimension; delegates
    /// to [`crate::ops::slice`].
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Result<Self> {
        crate::ops::slice(self, ranges)
    }

    /// Strided slicing; delegates to [`crate::ops::slice_with_stride`].
    pub fn slice_with_stride(&self, slice_params: &[crate::SliceParams]) -> Result<Self> {
        crate::ops::slice_with_stride(self, slice_params)
    }

    /// Sums over `axes` (all axes when `None`), keeping reduced dims when
    /// `keepdims`; delegates to [`crate::ops::sum`].
    pub fn sum(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: Zero,
    {
        crate::ops::sum(self, axes, keepdims)
    }

    /// Arithmetic mean over `axes`; delegates to [`crate::ops::mean`].
    pub fn mean(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float + scirs2_core::num_traits::FromPrimitive,
    {
        crate::ops::mean(self, axes, keepdims)
    }

    /// Maximum over `axes`; delegates to [`crate::ops::max`].
    pub fn max(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: PartialOrd,
    {
        crate::ops::max(self, axes, keepdims)
    }

    /// Minimum over `axes`; delegates to [`crate::ops::min`].
    pub fn min(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: PartialOrd,
    {
        crate::ops::min(self, axes, keepdims)
    }
396
397 pub fn sqrt(&self) -> Result<Self>
399 where
400 T: scirs2_core::num_traits::Float,
401 {
402 match &self.storage {
403 TensorStorage::Cpu(arr) => {
404 let result = arr.mapv(|x| x.sqrt());
405 Ok(Self::from_array(result))
406 }
407 #[cfg(feature = "gpu")]
408 TensorStorage::Gpu(buffer) => self.sqrt_gpu_impl(buffer),
409 }
410 }
411
    /// GPU implementation of [`Self::sqrt`]: runs the `Sqrt` unary kernel
    /// on `buffer` and wraps the result buffer with this tensor's shape.
    #[cfg(feature = "gpu")]
    fn sqrt_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
    where
        T: scirs2_core::num_traits::Float
            + bytemuck::Pod
            + bytemuck::Zeroable
            + Clone
            + Send
            + Sync
            + 'static,
    {
        use crate::gpu::ops::{execute_unary_op, UnaryOp};
        let result_buffer = execute_unary_op(buffer, UnaryOp::Sqrt)?;
        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
    }
427
428 pub fn abs(&self) -> Result<Self>
430 where
431 T: scirs2_core::num_traits::Signed,
432 {
433 match &self.storage {
434 TensorStorage::Cpu(arr) => {
435 let result = arr.mapv(|x| x.abs());
436 Ok(Self::from_array(result))
437 }
438 #[cfg(feature = "gpu")]
439 TensorStorage::Gpu(buffer) => {
440 use crate::gpu::ops::{execute_unary_op, UnaryOp};
441 let result_buffer = execute_unary_op(buffer, UnaryOp::Abs)?;
442 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
443 }
444 }
445 }
446
447 pub fn exp(&self) -> Result<Self>
449 where
450 T: scirs2_core::num_traits::Float,
451 {
452 match &self.storage {
453 TensorStorage::Cpu(arr) => {
454 let result = arr.mapv(|x| x.exp());
455 Ok(Self::from_array(result))
456 }
457 #[cfg(feature = "gpu")]
458 TensorStorage::Gpu(buffer) => {
459 use crate::gpu::ops::{execute_unary_op, UnaryOp};
460 let result_buffer = execute_unary_op(buffer, UnaryOp::Exp)?;
461 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
462 }
463 }
464 }
465
466 pub fn sin(&self) -> Result<Self>
468 where
469 T: scirs2_core::num_traits::Float,
470 {
471 match &self.storage {
472 TensorStorage::Cpu(arr) => {
473 let result = arr.mapv(|x| x.sin());
474 Ok(Self::from_array(result))
475 }
476 #[cfg(feature = "gpu")]
477 TensorStorage::Gpu(buffer) => {
478 use crate::gpu::ops::{execute_unary_op, UnaryOp};
479 let result_buffer = execute_unary_op(buffer, UnaryOp::Sin)?;
480 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
481 }
482 }
483 }
484
485 pub fn cos(&self) -> Result<Self>
487 where
488 T: scirs2_core::num_traits::Float,
489 {
490 match &self.storage {
491 TensorStorage::Cpu(arr) => {
492 let result = arr.mapv(|x| x.cos());
493 Ok(Self::from_array(result))
494 }
495 #[cfg(feature = "gpu")]
496 TensorStorage::Gpu(buffer) => {
497 use crate::gpu::ops::{execute_unary_op, UnaryOp};
498 let result_buffer = execute_unary_op(buffer, UnaryOp::Cos)?;
499 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
500 }
501 }
502 }
503
504 pub fn tan(&self) -> Result<Self>
506 where
507 T: scirs2_core::num_traits::Float,
508 {
509 match &self.storage {
510 TensorStorage::Cpu(arr) => {
511 let result = arr.mapv(|x| x.tan());
512 Ok(Self::from_array(result))
513 }
514 #[cfg(feature = "gpu")]
515 TensorStorage::Gpu(buffer) => {
516 use crate::gpu::ops::{execute_unary_op, UnaryOp};
517 let result_buffer = execute_unary_op(buffer, UnaryOp::Tan)?;
518 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
519 }
520 }
521 }
522
523 pub fn recip(&self) -> Result<Self>
525 where
526 T: scirs2_core::num_traits::Float,
527 {
528 match &self.storage {
529 TensorStorage::Cpu(arr) => {
530 let result = arr.mapv(|x| x.recip());
531 Ok(Self::from_array(result))
532 }
533 #[cfg(feature = "gpu")]
534 TensorStorage::Gpu(buffer) => {
535 use crate::gpu::ops::{execute_unary_op, UnaryOp};
536 let result_buffer = execute_unary_op(buffer, UnaryOp::Recip)?;
537 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
538 }
539 }
540 }
541
    /// Removes size-1 dimensions (all of them when `axes` is `None`);
    /// delegates to [`crate::ops::squeeze`].
    pub fn squeeze(&self, axes: Option<&[usize]>) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::squeeze(self, axes)
    }

    /// Inserts size-1 dimensions at `axes`; delegates to
    /// [`crate::ops::unsqueeze`].
    pub fn unsqueeze(&self, axes: &[usize]) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::unsqueeze(self, axes)
    }
557
558 pub fn scalar_mul(&self, scalar: T) -> Result<Self>
560 where
561 T: Clone + Default + std::ops::Mul<Output = T> + Send + Sync + 'static,
562 {
563 match &self.storage {
564 TensorStorage::Cpu(arr) => {
565 let result = arr.mapv(|x| x * scalar);
566 Ok(Self::from_array(result))
567 }
568 #[cfg(feature = "gpu")]
569 TensorStorage::Gpu(buffer) => {
570 use crate::gpu::ops::{execute_binary_scalar_op, BinaryScalarOp};
571 let result_buffer = execute_binary_scalar_op(buffer, scalar, BinaryScalarOp::Mul)?;
572 Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
573 }
574 }
575 }
576
577 pub fn to_vec(&self) -> Result<Vec<T>>
579 where
580 T: Clone
581 + Default
582 + Send
583 + Sync
584 + 'static
585 + scirs2_core::num_traits::Zero
586 + scirs2_core::num_traits::One,
587 {
588 match &self.storage {
589 TensorStorage::Cpu(arr) => {
590 if let Some(slice) = arr.as_slice() {
591 Ok(slice.to_vec())
592 } else {
593 Ok(arr.iter().cloned().collect())
595 }
596 }
597 #[cfg(feature = "gpu")]
598 TensorStorage::Gpu(buffer) => {
599 let cpu_array = buffer.to_cpu_array()?;
600 if let Some(slice) = cpu_array.as_slice() {
601 Ok(slice.to_vec())
602 } else {
603 Ok(cpu_array.iter().cloned().collect())
605 }
606 }
607 }
608 }
609
    /// Maximum over `axes` (all axes when `None`); delegates to
    /// [`crate::ops::reduction::max`].
    pub fn max_axis(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: Clone + Default + PartialOrd + Send + Sync + 'static,
    {
        crate::ops::reduction::max(self, axes, keepdims)
    }

    /// Sum over `axes` (all axes when `None`); delegates to
    /// [`crate::ops::reduction::sum`].
    pub fn sum_axis(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
    where
        T: Clone + Default + Zero + std::ops::Add<Output = T> + Send + Sync + 'static,
    {
        crate::ops::reduction::sum(self, axes, keepdims)
    }
625
626 pub fn clamp(&self, min: T, max: T) -> Result<Self>
628 where
629 T: PartialOrd + Clone,
630 {
631 match &self.storage {
632 TensorStorage::Cpu(arr) => {
633 let result = arr.mapv(|x| {
634 if x < min {
635 min
636 } else if x > max {
637 max
638 } else {
639 x
640 }
641 });
642 Ok(Self::from_array(result))
643 }
644 #[cfg(feature = "gpu")]
645 TensorStorage::Gpu(_) => {
646 let cpu_tensor = self.to_cpu()?;
648 let clamped_cpu = cpu_tensor.clamp(min, max)?;
649 if let Device::Gpu(gpu_id) = self.device {
650 clamped_cpu.to_gpu(gpu_id)
651 } else {
652 Ok(clamped_cpu)
653 }
654 }
655 }
656 }
657
658 pub fn allclose(&self, other: &Self, rtol: T, atol: T) -> Result<bool>
660 where
661 T: scirs2_core::num_traits::Float + Clone,
662 {
663 if self.shape() != other.shape() {
664 return Ok(false);
665 }
666
667 match (&self.storage, &other.storage) {
668 (TensorStorage::Cpu(a), TensorStorage::Cpu(b)) => {
669 use scirs2_core::ndarray::Zip;
670 let mut all_close = true;
671 Zip::from(a).and(b).for_each(|&a_val, &b_val| {
672 let diff = (a_val - b_val).abs();
673 let tolerance = atol + rtol * b_val.abs().max(a_val.abs());
674 if diff > tolerance {
675 all_close = false;
676 }
677 });
678 Ok(all_close)
679 }
680 #[cfg(feature = "gpu")]
681 _ => {
682 let self_cpu = self.to_cpu()?;
684 let other_cpu = other.to_cpu()?;
685 self_cpu.allclose(&other_cpu, rtol, atol)
686 }
687 }
688 }
689
    /// Overwrites every element with `value` in place.
    ///
    /// # Errors
    ///
    /// The GPU path can fail while building or transferring the
    /// replacement tensor.
    pub fn fill_(&mut self, value: T) -> Result<()>
    where
        T: Clone,
    {
        match &mut self.storage {
            TensorStorage::Cpu(arr) => {
                arr.fill(value);
                Ok(())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => {
                // No in-place GPU fill kernel: build a filled CPU tensor,
                // move it to this tensor's device, then adopt its storage.
                let filled_cpu = Tensor::full(self.shape().dims(), value);
                let transferred = filled_cpu.to_device(self.device)?;
                self.storage = transferred.storage;
                Ok(())
            }
        }
    }
710
711 pub fn to_scalar(&self) -> Result<T>
713 where
714 T: Clone,
715 {
716 if !self.is_scalar() {
717 return Err(crate::TensorError::invalid_operation_simple(format!(
718 "Cannot extract scalar from tensor with shape {:?}",
719 self.shape().dims()
720 )));
721 }
722
723 match &self.storage {
724 TensorStorage::Cpu(arr) => {
725 if let Some(scalar) = arr.as_slice().and_then(|s| s.first()) {
727 Ok(*scalar)
728 } else {
729 Err(crate::TensorError::invalid_operation_simple(
730 "Failed to extract scalar value".to_string(),
731 ))
732 }
733 }
734 #[cfg(feature = "gpu")]
735 TensorStorage::Gpu(_) => {
736 let cpu_tensor = self.to_cpu()?;
738 cpu_tensor.to_scalar()
739 }
740 }
741 }
742
    /// Indices of the maximum values along `axis` (reduced dims are not
    /// kept); delegates to [`crate::ops::argmax`].
    pub fn argmax(&self, axis: i32) -> Result<Tensor<usize>>
    where
        T: PartialOrd + Clone,
    {
        crate::ops::argmax(self, Some(axis), false)
    }
750
    /// Flattens the tensor to one dimension; delegates to
    /// [`crate::ops::flatten`].
    pub fn flatten(&self) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::flatten(self)
    }

    /// Cumulative sum along `axis`; delegates to [`crate::ops::cumsum`].
    /// NOTE(review): the meaning of `axis = None` is defined by the callee.
    pub fn cumsum(&self, axis: Option<i32>) -> Result<Self>
    where
        T: Clone
            + Default
            + std::ops::Add<Output = T>
            + scirs2_core::num_traits::Zero
            + Send
            + Sync
            + 'static,
    {
        crate::ops::cumsum(self, axis)
    }

    /// Cumulative product along `axis`; delegates to
    /// [`crate::ops::cumprod`].
    pub fn cumprod(&self, axis: Option<i32>) -> Result<Self>
    where
        T: Clone
            + Default
            + std::ops::Mul<Output = T>
            + scirs2_core::num_traits::One
            + Send
            + Sync
            + 'static,
    {
        crate::ops::cumprod(self, axis)
    }

    /// Tiles the tensor `multiples[i]` times along each dimension;
    /// delegates to [`crate::ops::tile`].
    pub fn tile(&self, multiples: &[usize]) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::tile(self, multiples)
    }

    /// Repeats elements `repeats` times, along `axis` when given;
    /// delegates to [`crate::ops::repeat`].
    pub fn repeat(&self, repeats: usize, axis: Option<usize>) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::repeat(self, repeats, axis)
    }

    /// Broadcasts the tensor to `target_shape`; delegates to
    /// [`crate::ops::broadcast_to`].
    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::broadcast_to(self, target_shape)
    }

    /// Broadcasts the tensor to `target`'s shape; delegates to
    /// [`crate::ops::expand_as`].
    pub fn expand_as(&self, target: &Self) -> Result<Self>
    where
        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
    {
        crate::ops::expand_as(self, target)
    }
923
924 pub fn multiply_scalar(&self, scalar: T) -> Result<Self>
926 where
927 T: Clone + std::ops::Mul<Output = T>,
928 {
929 match &self.storage {
930 TensorStorage::Cpu(arr) => {
931 let result = arr.mapv(|x| x * scalar);
932 Ok(Self {
933 storage: TensorStorage::Cpu(result),
934 shape: self.shape.clone(),
935 device: self.device,
936 requires_grad: self.requires_grad,
937 grad: None,
938 })
939 }
940 #[cfg(feature = "gpu")]
941 TensorStorage::Gpu(_) => Err(TensorError::unsupported_operation_simple(
942 "GPU scalar multiply not yet implemented".to_string(),
943 )),
944 }
945 }
946
    /// Vector dot product; delegates to [`crate::ops::dot`].
    pub fn dot(&self, other: &Self) -> Result<Self>
    where
        T: Clone
            + Default
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + std::ops::Mul<Output = T>,
    {
        crate::ops::dot(self, other)
    }

    /// Outer product of two vectors; delegates to [`crate::ops::outer`].
    pub fn outer(&self, other: &Self) -> Result<Self>
    where
        T: Clone
            + Default
            + scirs2_core::num_traits::Zero
            + scirs2_core::num_traits::One
            + std::ops::Add<Output = T>
            + std::ops::Mul<Output = T>,
    {
        crate::ops::outer(self, other)
    }
}