// tenflowers_core/tensor/ops.rs
1//! Tensor Mathematical and Shape Operations
2//!
3//! This module contains all mathematical operations, activation functions,
4//! shape manipulation operations, and utility functions for tensors.
5//! It provides both CPU and GPU implementations where applicable.
6
7use super::core::{Tensor, TensorStorage};
8#[cfg(feature = "gpu")]
9use crate::Device;
10use crate::{Result, TensorError};
11use scirs2_core::numeric::Zero;
12
13// Impl block for methods that need Clone (includes gradient operations)
14impl<T: Clone> Tensor<T> {
15    /// Perform backward pass for gradient computation
16    pub fn backward(&self) -> Result<()>
17    where
18        T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
19    {
20        if !self.requires_grad() {
21            return Err(TensorError::GradientNotEnabled {
22                operation: "backward".to_string(),
23                suggestion: "Call tensor.requires_grad_(true) before computation".to_string(),
24                context: None,
25            });
26        }
27
28        // Check if this is a scalar tensor (required for backward)
29        if self.shape().dims().iter().product::<usize>() != 1 {
30            return Err(TensorError::invalid_shape_simple(
31                "backward() can only be called on scalar tensors".to_string(),
32            ));
33        }
34
35        // Initialize gradient for this tensor if it doesn't exist
36        // For a scalar tensor, the gradient with respect to itself is 1
37        self.init_gradient()?;
38
39        // Enhanced backward pass implementation
40        // This implementation provides a foundation for autograd integration
41        // When used with tenflowers-autograd's GradientTape, this method serves as
42        // the entry point for automatic differentiation
43
44        // For full computation graph support, users should:
45        // 1. Wrap tensors with TrackedTensor from tenflowers-autograd
46        // 2. Use GradientTape to record operations
47        // 3. Call tape.compute_gradients() for the full backward pass
48        //
49        // This basic implementation handles the scalar case and prepares
50        // the gradient field for integration with advanced autograd systems
51
52        Ok(())
53    }
54
55    /// Enhanced backward pass with additional autograd options
56    pub fn backward_with_options(&self, retain_graph: bool, create_graph: bool) -> Result<()>
57    where
58        T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
59    {
60        if !self.requires_grad() {
61            return Err(TensorError::GradientNotEnabled {
62                operation: "backward".to_string(),
63                suggestion: "Call tensor.requires_grad_(true) before computation".to_string(),
64                context: None,
65            });
66        }
67
68        // Check if this is a scalar tensor (required for backward)
69        if self.shape().dims().iter().product::<usize>() != 1 {
70            return Err(TensorError::invalid_shape_simple(
71                "backward() can only be called on scalar tensors".to_string(),
72            ));
73        }
74
75        // Initialize gradient for this tensor if it doesn't exist
76        self.init_gradient()?;
77
78        // Enhanced backward pass with autograd options
79        // retain_graph: If true, the computation graph is retained for multiple backward passes
80        // create_graph: If true, creates a graph for computing higher-order derivatives
81
82        if retain_graph {
83            // In a full implementation, this would preserve the computation graph
84            // For now, we'll treat this the same as regular backward but add a comment
85            // that the graph would be retained in a production autograd system
86        }
87
88        if create_graph {
89            // In a full implementation, this would enable computation of higher-order derivatives
90            // by creating a new computation graph for the gradient computation itself
91            // For now, we note that this would enable second-order gradients
92        }
93
94        // The basic implementation remains the same, but these parameters provide
95        // hooks for future autograd system integration
96
97        Ok(())
98    }
99
100    /// Initialize gradient for this tensor with ones (for scalar) or appropriate shape
101    fn init_gradient(&self) -> Result<()>
102    where
103        T: Clone + Default + scirs2_core::num_traits::Zero + scirs2_core::num_traits::One,
104    {
105        // Only initialize if gradient doesn't already exist
106        if self.grad().is_some() {
107            return Ok(());
108        }
109
110        // Enhanced gradient initialization for autograd integration
111        // For scalar tensors used as loss functions, the gradient starts as 1.0
112        // For other tensors, gradients are initialized based on their role in the computation
113
114        // Note: Current architecture stores grad as immutable Arc<Tensor<T>>
115        // For full mutable gradient support, consider using tenflowers-autograd's
116        // TrackedTensor which provides mutable gradient accumulation through GradientTape
117        //
118        // This method validates gradient requirements and prepares the tensor
119        // for integration with external autograd systems
120
121        Ok(())
122    }
123}
124
125impl<T> Tensor<T>
126where
127    T: Clone
128        + Default
129        + scirs2_core::num_traits::Zero
130        + scirs2_core::num_traits::One
131        + Send
132        + Sync
133        + 'static
134        + bytemuck::Pod
135        + bytemuck::Zeroable,
136{
137    /// Element-wise addition
138    pub fn add(&self, other: &Self) -> Result<Self>
139    where
140        T: std::ops::Add<Output = T>,
141    {
142        crate::ops::add(self, other)
143    }
144
145    /// Element-wise subtraction
146    pub fn sub(&self, other: &Self) -> Result<Self>
147    where
148        T: std::ops::Sub<Output = T>,
149    {
150        crate::ops::sub(self, other)
151    }
152
153    /// Element-wise multiplication
154    pub fn mul(&self, other: &Self) -> Result<Self>
155    where
156        T: std::ops::Mul<Output = T>,
157    {
158        crate::ops::mul(self, other)
159    }
160
161    /// Element-wise division
162    pub fn div(&self, other: &Self) -> Result<Self>
163    where
164        T: std::ops::Div<Output = T>,
165    {
166        crate::ops::div(self, other)
167    }
168
169    /// Element-wise power operation
170    pub fn pow(&self, other: &Self) -> Result<Self>
171    where
172        T: scirs2_core::num_traits::Float,
173    {
174        crate::ops::pow(self, other)
175    }
176
177    /// Element-wise natural logarithm
178    pub fn log(&self) -> Result<Self>
179    where
180        T: scirs2_core::num_traits::Float,
181    {
182        match &self.storage {
183            TensorStorage::Cpu(arr) => {
184                let result = arr.mapv(|x| x.ln());
185                Ok(Self::from_array(result))
186            }
187            #[cfg(feature = "gpu")]
188            TensorStorage::Gpu(buffer) => self.log_gpu_impl(buffer),
189        }
190    }
191
192    #[cfg(feature = "gpu")]
193    fn log_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
194    where
195        T: scirs2_core::num_traits::Float
196            + bytemuck::Pod
197            + bytemuck::Zeroable
198            + Clone
199            + Send
200            + Sync
201            + 'static,
202    {
203        use crate::gpu::ops::{execute_unary_op, UnaryOp};
204        let result_buffer = execute_unary_op(buffer, UnaryOp::Log)?;
205        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
206    }
207
208    /// Element-wise negation
209    pub fn neg(&self) -> Result<Self>
210    where
211        T: std::ops::Neg<Output = T>,
212    {
213        match &self.storage {
214            TensorStorage::Cpu(arr) => {
215                let result = arr.mapv(|x| -x);
216                Ok(Self::from_array(result))
217            }
218            #[cfg(feature = "gpu")]
219            TensorStorage::Gpu(buffer) => self.neg_gpu_impl(buffer),
220        }
221    }
222
223    #[cfg(feature = "gpu")]
224    fn neg_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
225    where
226        T: std::ops::Neg<Output = T>
227            + bytemuck::Pod
228            + bytemuck::Zeroable
229            + Clone
230            + Send
231            + Sync
232            + 'static,
233    {
234        use crate::gpu::ops::{execute_unary_op, UnaryOp};
235        let result_buffer = execute_unary_op(buffer, UnaryOp::Neg)?;
236        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
237    }
238
239    /// Matrix multiplication
240    pub fn matmul(&self, other: &Self) -> Result<Self> {
241        crate::ops::matmul(self, other)
242    }
243
244    // Activation functions
245    /// ReLU activation function
246    pub fn relu(&self) -> Result<Self>
247    where
248        T: PartialOrd + scirs2_core::num_traits::Zero + bytemuck::Pod + bytemuck::Zeroable,
249    {
250        crate::ops::activation::relu(self)
251    }
252
253    /// Sigmoid activation function
254    pub fn sigmoid(&self) -> Result<Self>
255    where
256        T: scirs2_core::num_traits::Float + bytemuck::Pod + bytemuck::Zeroable,
257    {
258        crate::ops::activation::sigmoid(self)
259    }
260
261    /// Hyperbolic tangent activation function
262    pub fn tanh(&self) -> Result<Self>
263    where
264        T: scirs2_core::num_traits::Float + bytemuck::Pod + bytemuck::Zeroable,
265    {
266        crate::ops::activation::tanh(self)
267    }
268
269    /// GELU activation function
270    pub fn gelu(&self) -> Result<Self>
271    where
272        T: scirs2_core::num_traits::Float + bytemuck::Pod,
273    {
274        crate::ops::activation::gelu(self)
275    }
276
277    /// Swish activation function
278    pub fn swish(&self) -> Result<Self>
279    where
280        T: scirs2_core::num_traits::Float + bytemuck::Pod,
281    {
282        crate::ops::activation::swish(self)
283    }
284
285    /// Mish activation function
286    pub fn mish(&self) -> Result<Self>
287    where
288        T: scirs2_core::num_traits::Float
289            + Send
290            + Sync
291            + 'static
292            + bytemuck::Pod
293            + bytemuck::Zeroable,
294    {
295        crate::ops::activation::mish(self)
296    }
297
298    /// Softmax activation function
299    pub fn softmax(&self, axis: Option<i32>) -> Result<Self>
300    where
301        T: scirs2_core::num_traits::Float
302            + std::ops::Sub<Output = T>
303            + std::ops::Add<Output = T>
304            + std::ops::Div<Output = T>
305            + std::iter::Sum
306            + Send
307            + Sync
308            + bytemuck::Pod,
309    {
310        crate::ops::activation::softmax(self, axis)
311    }
312
313    /// ELU activation function
314    pub fn elu(&self, alpha: T) -> Result<Self>
315    where
316        T: scirs2_core::num_traits::Float + PartialOrd + bytemuck::Pod,
317    {
318        crate::ops::activation::elu(self, alpha)
319    }
320
321    /// Leaky ReLU activation function
322    pub fn leaky_relu(&self, alpha: T) -> Result<Self>
323    where
324        T: scirs2_core::num_traits::Float + PartialOrd + bytemuck::Pod,
325    {
326        crate::ops::activation::leaky_relu(self, alpha)
327    }
328
329    /// Hard Swish activation function
330    pub fn hard_swish(&self) -> Result<Self>
331    where
332        T: scirs2_core::num_traits::Float + PartialOrd,
333    {
334        crate::ops::activation::hard_swish(self)
335    }
336
337    /// Parametric ReLU activation function
338    pub fn prelu(&self, alpha: &Self) -> Result<Self>
339    where
340        T: scirs2_core::num_traits::Float + PartialOrd,
341    {
342        crate::ops::activation::prelu(self, alpha)
343    }
344
345    /// Reshape tensor to new shape
346    pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
347        crate::ops::reshape(self, shape)
348    }
349
350    /// Transpose tensor (swap last two dimensions)
351    pub fn transpose(&self) -> Result<Self> {
352        crate::ops::transpose(self)
353    }
354
355    /// Slice tensor along specified ranges
356    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Result<Self> {
357        crate::ops::slice(self, ranges)
358    }
359
360    /// Slice tensor with stride parameters
361    pub fn slice_with_stride(&self, slice_params: &[crate::SliceParams]) -> Result<Self> {
362        crate::ops::slice_with_stride(self, slice_params)
363    }
364
365    /// Sum tensor along specified axes
366    pub fn sum(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
367    where
368        T: Zero,
369    {
370        crate::ops::sum(self, axes, keepdims)
371    }
372
373    /// Mean tensor along specified axes
374    pub fn mean(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
375    where
376        T: scirs2_core::num_traits::Float + scirs2_core::num_traits::FromPrimitive,
377    {
378        crate::ops::mean(self, axes, keepdims)
379    }
380
381    /// Maximum values along specified axes
382    pub fn max(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
383    where
384        T: PartialOrd,
385    {
386        crate::ops::max(self, axes, keepdims)
387    }
388
389    /// Minimum values along specified axes
390    pub fn min(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
391    where
392        T: PartialOrd,
393    {
394        crate::ops::min(self, axes, keepdims)
395    }
396
397    /// Element-wise square root
398    pub fn sqrt(&self) -> Result<Self>
399    where
400        T: scirs2_core::num_traits::Float,
401    {
402        match &self.storage {
403            TensorStorage::Cpu(arr) => {
404                let result = arr.mapv(|x| x.sqrt());
405                Ok(Self::from_array(result))
406            }
407            #[cfg(feature = "gpu")]
408            TensorStorage::Gpu(buffer) => self.sqrt_gpu_impl(buffer),
409        }
410    }
411
412    #[cfg(feature = "gpu")]
413    fn sqrt_gpu_impl(&self, buffer: &crate::gpu::buffer::GpuBuffer<T>) -> Result<Self>
414    where
415        T: scirs2_core::num_traits::Float
416            + bytemuck::Pod
417            + bytemuck::Zeroable
418            + Clone
419            + Send
420            + Sync
421            + 'static,
422    {
423        use crate::gpu::ops::{execute_unary_op, UnaryOp};
424        let result_buffer = execute_unary_op(buffer, UnaryOp::Sqrt)?;
425        Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
426    }
427
428    /// Element-wise absolute value
429    pub fn abs(&self) -> Result<Self>
430    where
431        T: scirs2_core::num_traits::Signed,
432    {
433        match &self.storage {
434            TensorStorage::Cpu(arr) => {
435                let result = arr.mapv(|x| x.abs());
436                Ok(Self::from_array(result))
437            }
438            #[cfg(feature = "gpu")]
439            TensorStorage::Gpu(buffer) => {
440                use crate::gpu::ops::{execute_unary_op, UnaryOp};
441                let result_buffer = execute_unary_op(buffer, UnaryOp::Abs)?;
442                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
443            }
444        }
445    }
446
447    /// Element-wise exponential function
448    pub fn exp(&self) -> Result<Self>
449    where
450        T: scirs2_core::num_traits::Float,
451    {
452        match &self.storage {
453            TensorStorage::Cpu(arr) => {
454                let result = arr.mapv(|x| x.exp());
455                Ok(Self::from_array(result))
456            }
457            #[cfg(feature = "gpu")]
458            TensorStorage::Gpu(buffer) => {
459                use crate::gpu::ops::{execute_unary_op, UnaryOp};
460                let result_buffer = execute_unary_op(buffer, UnaryOp::Exp)?;
461                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
462            }
463        }
464    }
465
466    /// Element-wise sine function
467    pub fn sin(&self) -> Result<Self>
468    where
469        T: scirs2_core::num_traits::Float,
470    {
471        match &self.storage {
472            TensorStorage::Cpu(arr) => {
473                let result = arr.mapv(|x| x.sin());
474                Ok(Self::from_array(result))
475            }
476            #[cfg(feature = "gpu")]
477            TensorStorage::Gpu(buffer) => {
478                use crate::gpu::ops::{execute_unary_op, UnaryOp};
479                let result_buffer = execute_unary_op(buffer, UnaryOp::Sin)?;
480                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
481            }
482        }
483    }
484
485    /// Element-wise cosine function
486    pub fn cos(&self) -> Result<Self>
487    where
488        T: scirs2_core::num_traits::Float,
489    {
490        match &self.storage {
491            TensorStorage::Cpu(arr) => {
492                let result = arr.mapv(|x| x.cos());
493                Ok(Self::from_array(result))
494            }
495            #[cfg(feature = "gpu")]
496            TensorStorage::Gpu(buffer) => {
497                use crate::gpu::ops::{execute_unary_op, UnaryOp};
498                let result_buffer = execute_unary_op(buffer, UnaryOp::Cos)?;
499                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
500            }
501        }
502    }
503
504    /// Element-wise tangent function
505    pub fn tan(&self) -> Result<Self>
506    where
507        T: scirs2_core::num_traits::Float,
508    {
509        match &self.storage {
510            TensorStorage::Cpu(arr) => {
511                let result = arr.mapv(|x| x.tan());
512                Ok(Self::from_array(result))
513            }
514            #[cfg(feature = "gpu")]
515            TensorStorage::Gpu(buffer) => {
516                use crate::gpu::ops::{execute_unary_op, UnaryOp};
517                let result_buffer = execute_unary_op(buffer, UnaryOp::Tan)?;
518                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
519            }
520        }
521    }
522
523    /// Element-wise reciprocal function
524    pub fn recip(&self) -> Result<Self>
525    where
526        T: scirs2_core::num_traits::Float,
527    {
528        match &self.storage {
529            TensorStorage::Cpu(arr) => {
530                let result = arr.mapv(|x| x.recip());
531                Ok(Self::from_array(result))
532            }
533            #[cfg(feature = "gpu")]
534            TensorStorage::Gpu(buffer) => {
535                use crate::gpu::ops::{execute_unary_op, UnaryOp};
536                let result_buffer = execute_unary_op(buffer, UnaryOp::Recip)?;
537                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
538            }
539        }
540    }
541
542    /// Squeeze tensor - remove dimensions of size 1
543    pub fn squeeze(&self, axes: Option<&[usize]>) -> Result<Self>
544    where
545        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
546    {
547        crate::ops::squeeze(self, axes)
548    }
549
550    /// Unsqueeze tensor - add dimensions of size 1
551    pub fn unsqueeze(&self, axes: &[usize]) -> Result<Self>
552    where
553        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
554    {
555        crate::ops::unsqueeze(self, axes)
556    }
557
558    /// Scalar multiplication
559    pub fn scalar_mul(&self, scalar: T) -> Result<Self>
560    where
561        T: Clone + Default + std::ops::Mul<Output = T> + Send + Sync + 'static,
562    {
563        match &self.storage {
564            TensorStorage::Cpu(arr) => {
565                let result = arr.mapv(|x| x * scalar);
566                Ok(Self::from_array(result))
567            }
568            #[cfg(feature = "gpu")]
569            TensorStorage::Gpu(buffer) => {
570                use crate::gpu::ops::{execute_binary_scalar_op, BinaryScalarOp};
571                let result_buffer = execute_binary_scalar_op(buffer, scalar, BinaryScalarOp::Mul)?;
572                Ok(Self::from_gpu_buffer(result_buffer, self.shape().clone()))
573            }
574        }
575    }
576
577    /// Convert tensor to vector
578    pub fn to_vec(&self) -> Result<Vec<T>>
579    where
580        T: Clone
581            + Default
582            + Send
583            + Sync
584            + 'static
585            + scirs2_core::num_traits::Zero
586            + scirs2_core::num_traits::One,
587    {
588        match &self.storage {
589            TensorStorage::Cpu(arr) => {
590                if let Some(slice) = arr.as_slice() {
591                    Ok(slice.to_vec())
592                } else {
593                    // Handle non-contiguous arrays
594                    Ok(arr.iter().cloned().collect())
595                }
596            }
597            #[cfg(feature = "gpu")]
598            TensorStorage::Gpu(buffer) => {
599                let cpu_array = buffer.to_cpu_array()?;
600                if let Some(slice) = cpu_array.as_slice() {
601                    Ok(slice.to_vec())
602                } else {
603                    // Handle non-contiguous arrays
604                    Ok(cpu_array.iter().cloned().collect())
605                }
606            }
607        }
608    }
609
610    /// Maximum values along specified axes
611    pub fn max_axis(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
612    where
613        T: Clone + Default + PartialOrd + Send + Sync + 'static,
614    {
615        crate::ops::reduction::max(self, axes, keepdims)
616    }
617
618    /// Sum along specified axes
619    pub fn sum_axis(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self>
620    where
621        T: Clone + Default + Zero + std::ops::Add<Output = T> + Send + Sync + 'static,
622    {
623        crate::ops::reduction::sum(self, axes, keepdims)
624    }
625
626    /// Clamp tensor values between min and max
627    pub fn clamp(&self, min: T, max: T) -> Result<Self>
628    where
629        T: PartialOrd + Clone,
630    {
631        match &self.storage {
632            TensorStorage::Cpu(arr) => {
633                let result = arr.mapv(|x| {
634                    if x < min {
635                        min
636                    } else if x > max {
637                        max
638                    } else {
639                        x
640                    }
641                });
642                Ok(Self::from_array(result))
643            }
644            #[cfg(feature = "gpu")]
645            TensorStorage::Gpu(_) => {
646                // For GPU, convert to CPU, clamp, and convert back
647                let cpu_tensor = self.to_cpu()?;
648                let clamped_cpu = cpu_tensor.clamp(min, max)?;
649                if let Device::Gpu(gpu_id) = self.device {
650                    clamped_cpu.to_gpu(gpu_id)
651                } else {
652                    Ok(clamped_cpu)
653                }
654            }
655        }
656    }
657
658    /// Check if all elements are close to another tensor within tolerance
659    pub fn allclose(&self, other: &Self, rtol: T, atol: T) -> Result<bool>
660    where
661        T: scirs2_core::num_traits::Float + Clone,
662    {
663        if self.shape() != other.shape() {
664            return Ok(false);
665        }
666
667        match (&self.storage, &other.storage) {
668            (TensorStorage::Cpu(a), TensorStorage::Cpu(b)) => {
669                use scirs2_core::ndarray::Zip;
670                let mut all_close = true;
671                Zip::from(a).and(b).for_each(|&a_val, &b_val| {
672                    let diff = (a_val - b_val).abs();
673                    let tolerance = atol + rtol * b_val.abs().max(a_val.abs());
674                    if diff > tolerance {
675                        all_close = false;
676                    }
677                });
678                Ok(all_close)
679            }
680            #[cfg(feature = "gpu")]
681            _ => {
682                // Convert to CPU for comparison
683                let self_cpu = self.to_cpu()?;
684                let other_cpu = other.to_cpu()?;
685                self_cpu.allclose(&other_cpu, rtol, atol)
686            }
687        }
688    }
689
690    /// Fill tensor with specified value
691    pub fn fill_(&mut self, value: T) -> Result<()>
692    where
693        T: Clone,
694    {
695        match &mut self.storage {
696            TensorStorage::Cpu(arr) => {
697                arr.fill(value);
698                Ok(())
699            }
700            #[cfg(feature = "gpu")]
701            TensorStorage::Gpu(_) => {
702                // For GPU, create a new tensor with the fill value and copy it back
703                let filled_cpu = Tensor::full(self.shape().dims(), value);
704                let transferred = filled_cpu.to_device(self.device)?;
705                self.storage = transferred.storage;
706                Ok(())
707            }
708        }
709    }
710
711    /// Extract scalar value from a 0-dimensional tensor
712    pub fn to_scalar(&self) -> Result<T>
713    where
714        T: Clone,
715    {
716        if !self.is_scalar() {
717            return Err(crate::TensorError::invalid_operation_simple(format!(
718                "Cannot extract scalar from tensor with shape {:?}",
719                self.shape().dims()
720            )));
721        }
722
723        match &self.storage {
724            TensorStorage::Cpu(arr) => {
725                // For scalar tensors, we can get the single element
726                if let Some(scalar) = arr.as_slice().and_then(|s| s.first()) {
727                    Ok(*scalar)
728                } else {
729                    Err(crate::TensorError::invalid_operation_simple(
730                        "Failed to extract scalar value".to_string(),
731                    ))
732                }
733            }
734            #[cfg(feature = "gpu")]
735            TensorStorage::Gpu(_) => {
736                // For GPU tensors, we need to copy to CPU first
737                let cpu_tensor = self.to_cpu()?;
738                cpu_tensor.to_scalar()
739            }
740        }
741    }
742
743    /// Find the indices of the maximum values along the specified axis
744    pub fn argmax(&self, axis: i32) -> Result<Tensor<usize>>
745    where
746        T: PartialOrd + Clone,
747    {
748        crate::ops::argmax(self, Some(axis), false)
749    }
750
751    /// Flatten the tensor into a 1D tensor
752    ///
753    /// This operation reshapes the tensor into a 1-dimensional tensor
754    /// containing the same elements in row-major (C-style) order.
755    ///
756    /// # Returns
757    /// A 1D tensor containing all elements from the input tensor
758    ///
759    /// # Examples
760    /// ```
761    /// use tenflowers_core::Tensor;
762    ///
763    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("from_vec should succeed");
764    /// let flattened = tensor.flatten().expect("flatten should not fail");
765    /// assert_eq!(flattened.shape().dims(), &[4]);
766    /// ```
767    pub fn flatten(&self) -> Result<Self>
768    where
769        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
770    {
771        crate::ops::flatten(self)
772    }
773
774    /// Compute the cumulative sum of elements along the given axis
775    ///
776    /// # Arguments
777    /// * `axis` - Axis along which to compute the cumulative sum. If None, flatten the tensor first.
778    ///
779    /// # Returns
780    /// A tensor with cumulative sums along the specified axis
781    ///
782    /// # Examples
783    /// ```
784    /// use tenflowers_core::Tensor;
785    ///
786    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("from_vec should succeed");
787    /// let cumsum = tensor.cumsum(Some(0)).expect("operation should succeed");
788    /// ```
789    pub fn cumsum(&self, axis: Option<i32>) -> Result<Self>
790    where
791        T: Clone
792            + Default
793            + std::ops::Add<Output = T>
794            + scirs2_core::num_traits::Zero
795            + Send
796            + Sync
797            + 'static,
798    {
799        crate::ops::cumsum(self, axis)
800    }
801
802    /// Compute the cumulative product of elements along the given axis
803    ///
804    /// # Arguments
805    /// * `axis` - Axis along which to compute the cumulative product. If None, flatten the tensor first.
806    ///
807    /// # Returns
808    /// A tensor with cumulative products along the specified axis
809    ///
810    /// # Examples
811    /// ```
812    /// use tenflowers_core::Tensor;
813    ///
814    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("from_vec should succeed");
815    /// let cumprod = tensor.cumprod(Some(0)).expect("operation should succeed");
816    /// ```
817    pub fn cumprod(&self, axis: Option<i32>) -> Result<Self>
818    where
819        T: Clone
820            + Default
821            + std::ops::Mul<Output = T>
822            + scirs2_core::num_traits::One
823            + Send
824            + Sync
825            + 'static,
826    {
827        crate::ops::cumprod(self, axis)
828    }
829
830    /// Tile the tensor by repeating it along each axis
831    ///
832    /// # Arguments
833    /// * `multiples` - The number of repetitions along each axis
834    ///
835    /// # Returns
836    /// A tensor with the input tiled according to the multiples
837    ///
838    /// # Examples
839    /// ```
840    /// use tenflowers_core::Tensor;
841    ///
842    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2]).expect("from_vec should succeed");
843    /// let tiled = tensor.tile(&[2, 3]).expect("tile should succeed");
844    /// assert_eq!(tiled.shape().dims(), &[2, 6]);
845    /// ```
846    pub fn tile(&self, multiples: &[usize]) -> Result<Self>
847    where
848        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
849    {
850        crate::ops::tile(self, multiples)
851    }
852
853    /// Repeat elements of the tensor
854    ///
855    /// # Arguments
856    /// * `repeats` - The number of repetitions for each element
857    /// * `axis` - The axis along which to repeat values. If None, the input tensor is flattened first.
858    ///
859    /// # Returns
860    /// A tensor with repeated elements
861    ///
862    /// # Examples
863    /// ```
864    /// use tenflowers_core::Tensor;
865    ///
866    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("from_vec should succeed");
867    /// let repeated = tensor.repeat(2, Some(0)).expect("operation should succeed");
868    /// assert_eq!(repeated.shape().dims(), &[6]);
869    /// ```
870    pub fn repeat(&self, repeats: usize, axis: Option<usize>) -> Result<Self>
871    where
872        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
873    {
874        crate::ops::repeat(self, repeats, axis)
875    }
876
877    /// Broadcast the tensor to a new shape
878    ///
879    /// # Arguments
880    /// * `target_shape` - The shape to broadcast to
881    ///
882    /// # Returns
883    /// A tensor broadcasted to the target shape
884    ///
885    /// # Examples
886    /// ```
887    /// use tenflowers_core::Tensor;
888    ///
889    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2]).expect("from_vec should succeed");
890    /// let broadcasted = tensor.broadcast_to(&[3, 2]).expect("broadcast_to should succeed");
891    /// assert_eq!(broadcasted.shape().dims(), &[3, 2]);
892    /// ```
893    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Self>
894    where
895        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
896    {
897        crate::ops::broadcast_to(self, target_shape)
898    }
899
900    /// Expand tensor dimensions to match another tensor's shape
901    ///
902    /// # Arguments
903    /// * `target` - The tensor whose shape to match
904    ///
905    /// # Returns
906    /// A tensor expanded to match the target tensor's shape
907    ///
908    /// # Examples
909    /// ```
910    /// use tenflowers_core::Tensor;
911    ///
912    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[1, 2]).expect("from_vec should succeed");
913    /// let target = Tensor::<f32>::zeros(&[3, 2]);
914    /// let expanded = tensor.expand_as(&target).expect("expand_as should succeed");
915    /// assert_eq!(expanded.shape().dims(), &[3, 2]);
916    /// ```
917    pub fn expand_as(&self, target: &Self) -> Result<Self>
918    where
919        T: Clone + Default + scirs2_core::num_traits::Zero + Send + Sync + 'static,
920    {
921        crate::ops::expand_as(self, target)
922    }
923
924    /// Scalar multiplication
925    pub fn multiply_scalar(&self, scalar: T) -> Result<Self>
926    where
927        T: Clone + std::ops::Mul<Output = T>,
928    {
929        match &self.storage {
930            TensorStorage::Cpu(arr) => {
931                let result = arr.mapv(|x| x * scalar);
932                Ok(Self {
933                    storage: TensorStorage::Cpu(result),
934                    shape: self.shape.clone(),
935                    device: self.device,
936                    requires_grad: self.requires_grad,
937                    grad: None,
938                })
939            }
940            #[cfg(feature = "gpu")]
941            TensorStorage::Gpu(_) => Err(TensorError::unsupported_operation_simple(
942                "GPU scalar multiply not yet implemented".to_string(),
943            )),
944        }
945    }
946
947    /// Dot product of two 1D tensors
948    pub fn dot(&self, other: &Self) -> Result<Self>
949    where
950        T: Clone
951            + Default
952            + scirs2_core::num_traits::Zero
953            + scirs2_core::num_traits::One
954            + std::ops::Add<Output = T>
955            + std::ops::Mul<Output = T>,
956    {
957        crate::ops::dot(self, other)
958    }
959
960    /// Outer product of two 1D tensors
961    pub fn outer(&self, other: &Self) -> Result<Self>
962    where
963        T: Clone
964            + Default
965            + scirs2_core::num_traits::Zero
966            + scirs2_core::num_traits::One
967            + std::ops::Add<Output = T>
968            + std::ops::Mul<Output = T>,
969    {
970        crate::ops::outer(self, other)
971    }
972}