torsh_tensor/convenience.rs

//! Convenience methods for tensor manipulation
//!
//! This module provides convenient shortcuts and aliases for common tensor operations
//! to improve ergonomics and match PyTorch/NumPy APIs.

use crate::{Tensor, TensorElement};
use torsh_core::error::Result;

/// Convenience trait for tensor manipulation shortcuts
pub trait TensorConvenience<T: TensorElement> {
    /// Transpose shortcut (equivalent to .transpose())
    ///
    /// # Examples
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorConvenience};
    /// let tensor = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// let transposed = tensor.T().expect("transpose failed");
    /// ```
    #[allow(non_snake_case)]
    fn T(&self) -> Result<Tensor<T>>;

    /// Matrix transpose (alias for .T())
    #[allow(non_snake_case)]
    fn mT(&self) -> Result<Tensor<T>>;

    /// Hermitian transpose (conjugate transpose for complex numbers)
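    ///
    /// For real element types conjugation is the identity, so `H()` currently
    /// behaves exactly like `T()`.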
    #[allow(non_snake_case)]
    fn H(&self) -> Result<Tensor<T>>;

    /// Transpose shortcut (snake_case version)
    fn t(&self) -> Result<Tensor<T>>;

    /// Matrix transpose (snake_case version)
    fn m_t(&self) -> Result<Tensor<T>>;

    /// Hermitian transpose (snake_case version)
    fn h(&self) -> Result<Tensor<T>>;

    /// Detach tensor from computational graph (creates a new tensor without gradients)
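    ///
    /// A minimal usage sketch (gradient tracking is not yet implemented, so this
    /// currently behaves like a clone):
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorConvenience};
    /// let t = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// let detached = t.detach();
    /// assert_eq!(t.shape().dims(), detached.shape().dims());
    /// ```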
    fn detach(&self) -> Tensor<T>;

    /// Clone tensor data (creates a deep copy)
    fn clone_tensor(&self) -> Result<Tensor<T>>;

    /// Check if tensor is contiguous in memory
    fn is_contiguous(&self) -> bool;

    /// Make tensor contiguous (reorganize memory layout)
    fn contiguous(&self) -> Result<Tensor<T>>;

    /// Get number of elements in tensor
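    ///
    /// A minimal sketch, mirroring `test_tensor_properties` below:
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorConvenience};
    /// let t = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// assert_eq!(t.numel(), 4);
    /// ```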
    fn numel(&self) -> usize;

    /// Get tensor size (alias for shape().dims())
    fn size(&self) -> Vec<usize>;

    /// Check if tensor is empty (has zero elements)
    fn is_empty(&self) -> bool;

    /// Check if tensor is scalar (zero dimensions)
    fn is_scalar(&self) -> bool;

    /// Get tensor item as a scalar (errors unless the tensor has exactly one element)
    fn item(&self) -> Result<T>;

    /// Convert tensor to scalar (squeezes all dimensions of size 1 first)
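    ///
    /// A short sketch of the squeeze-then-extract behaviour (assumes `tensor_2d!`
    /// accepts a 1×1 literal):
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorConvenience};
    /// // A [1, 1] tensor squeezes down to a single element before extraction
    /// let t = tensor_2d!([&[42.0f32]]).expect("tensor creation failed");
    /// assert_eq!(t.to_scalar().expect("to_scalar failed"), 42.0);
    /// ```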
    fn to_scalar(&self) -> Result<T>;
}

impl<T: TensorElement + Copy + torsh_core::FloatElement> TensorConvenience<T> for Tensor<T> {
    #[allow(non_snake_case)]
    fn T(&self) -> Result<Tensor<T>> {
        let ndim = self.shape().dims().len();
        if ndim < 2 {
            // 0-D and 1-D tensors are unchanged by transposition
            Ok(self.clone())
        } else {
            // Transpose the last two dimensions, e.g. [2, 3, 4] -> [2, 4, 3];
            // for 2-D tensors this is the ordinary matrix transpose
            self.transpose((ndim - 2) as i32, (ndim - 1) as i32)
        }
    }

    #[allow(non_snake_case)]
    fn mT(&self) -> Result<Tensor<T>> {
        self.T()
    }

    #[allow(non_snake_case)]
    fn H(&self) -> Result<Tensor<T>> {
        // The Hermitian transpose is the conjugate transpose. This impl is
        // bounded on real element types, where conjugation is the identity,
        // so H() reduces to a plain transpose.
        self.T()
    }

    fn t(&self) -> Result<Tensor<T>> {
        self.T()
    }

    fn m_t(&self) -> Result<Tensor<T>> {
        self.T()
    }

    fn h(&self) -> Result<Tensor<T>> {
        self.H()
    }

    fn detach(&self) -> Tensor<T> {
        // Create a new tensor without gradient tracking.
        // For now, just return a clone since gradient tracking is not implemented.
        self.clone()
    }

    fn clone_tensor(&self) -> Result<Tensor<T>> {
        Ok(self.detach())
    }

    fn is_contiguous(&self) -> bool {
        // TODO: check the actual strides against row-major order once stride
        // information is available; until then, all tensors are assumed contiguous
        true
    }

    fn contiguous(&self) -> Result<Tensor<T>> {
        if self.is_contiguous() {
            Ok(self.clone())
        } else {
            // Reorganize memory layout to be contiguous
            self.clone_tensor()
        }
    }

    fn numel(&self) -> usize {
        self.shape().dims().iter().product()
    }

    fn size(&self) -> Vec<usize> {
        self.shape().dims().to_vec()
    }

    fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    fn is_scalar(&self) -> bool {
        self.shape().dims().is_empty()
    }

    fn item(&self) -> Result<T> {
        // Extract the single element; error out instead of panicking so that
        // callers (e.g. to_scalar) can propagate the failure
        if self.numel() != 1 {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "item() requires a tensor with exactly one element, got {}",
                self.numel()
            )));
        }
        let data = self.to_vec()?;
        Ok(data[0])
    }

    fn to_scalar(&self) -> Result<T> {
        // First squeeze all dimensions of size 1, then extract the element
        let squeezed = self.squeeze_all()?;
        squeezed.item()
    }
}

/// Additional convenience methods for specific tensor operations
pub trait TensorShapeConvenience<T: TensorElement> {
    /// Add singleton dimension at specified position
    fn unsqueeze_at(&self, dim: i32) -> Result<Tensor<T>>;

    /// Remove all singleton dimensions
    fn squeeze_all(&self) -> Result<Tensor<T>>;

    /// Flatten tensor to 1D (preserving total number of elements)
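    ///
    /// A minimal sketch, mirroring the `tensor_2d!` doc examples above:
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorShapeConvenience};
    /// let t = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// // [2, 2] -> [4]
    /// assert_eq!(t.flatten().expect("flatten failed").shape().dims(), &[4]);
    /// ```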
    fn flatten(&self) -> Result<Tensor<T>>;

    /// Flatten tensor starting from specified dimension
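    ///
    /// A shape-only sketch, mirroring `test_shape_convenience` below:
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorShapeConvenience};
    /// let t = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]])
    ///     .expect("tensor creation failed")
    ///     .reshape(&[2, 1, 2])
    ///     .expect("reshape failed");
    /// // Dimensions from start_dim onwards collapse into one: [2, 1, 2] -> [2, 2]
    /// assert_eq!(t.flatten_from(1).expect("flatten_from failed").shape().dims(), &[2, 2]);
    /// ```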
    fn flatten_from(&self, start_dim: i32) -> Result<Tensor<T>>;

    /// Unflatten tensor back to specified shape
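    ///
    /// A minimal sketch of splitting one dimension into several:
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, convenience::TensorShapeConvenience};
    /// let t = tensor_2d!([&[1.0, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// // Split dimension 1 (size 2) into [2, 1]: [2, 2] -> [2, 2, 1]
    /// let unflat = t.unflatten(1, &[2, 1]).expect("unflatten failed");
    /// assert_eq!(unflat.shape().dims(), &[2, 2, 1]);
    /// ```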
    fn unflatten(&self, dim: i32, sizes: &[usize]) -> Result<Tensor<T>>;
}

impl<T: TensorElement + Copy> TensorShapeConvenience<T> for Tensor<T> {
    fn unsqueeze_at(&self, dim: i32) -> Result<Tensor<T>> {
        self.unsqueeze(dim)
    }

    fn squeeze_all(&self) -> Result<Tensor<T>> {
        let mut result = self.clone();
        let shape_ref = self.shape();
        let dims = shape_ref.dims();

        // Remove all dimensions of size 1, iterating in reverse so that the
        // indices of dimensions not yet visited stay valid after each squeeze
        for (i, &size) in dims.iter().enumerate().rev() {
            if size == 1 {
                result = result.squeeze(i as i32)?;
            }
        }

        Ok(result)
    }

    fn flatten(&self) -> Result<Tensor<T>> {
        // Compute the element count directly from the shape rather than via
        // TensorConvenience::numel, whose impl requires a float element type
        let total_elements: usize = self.shape().dims().iter().product();
        self.reshape(&[total_elements as i32])
    }

    fn flatten_from(&self, start_dim: i32) -> Result<Tensor<T>> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        let ndim = shape.len() as i32;
        // Normalize negative dimensions (e.g. -1 refers to the last dimension)
        let start_dim = if start_dim < 0 {
            ndim + start_dim
        } else {
            start_dim
        };

        if start_dim < 0 || start_dim >= ndim {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Invalid start_dim {start_dim} for tensor with {ndim} dimensions"
            )));
        }

        // Keep dimensions before start_dim
        let mut new_shape = Vec::new();
        for &dim in shape.iter().take(start_dim as usize) {
            new_shape.push(dim);
        }

        // Collapse dimensions from start_dim onwards into a single dimension
        let flattened_size: usize = shape[start_dim as usize..].iter().product();
        new_shape.push(flattened_size);

        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
        self.reshape(&new_shape_i32)
    }

    fn unflatten(&self, dim: i32, sizes: &[usize]) -> Result<Tensor<T>> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        let ndim = shape.len() as i32;
        // Normalize negative dimensions
        let dim = if dim < 0 { ndim + dim } else { dim };

        if dim < 0 || dim >= ndim {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Invalid dim {dim} for tensor with {ndim} dimensions"
            )));
        }

        // The product of the new sizes must match the size of the dimension being split
        let expected_size = shape[dim as usize];
        let actual_size: usize = sizes.iter().product();

        if expected_size != actual_size {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Product of sizes ({actual_size}) does not match dimension size ({expected_size})"
            )));
        }

        // Build the new shape: dimensions before `dim`, then the unflattened
        // sizes, then the dimensions after `dim`
        let mut new_shape = Vec::new();

        for &dim_size in shape.iter().take(dim as usize) {
            new_shape.push(dim_size);
        }

        new_shape.extend_from_slice(sizes);

        for &dim_size in shape.iter().skip(dim as usize + 1) {
            new_shape.push(dim_size);
        }

        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
        self.reshape(&new_shape_i32)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_shortcuts() {
        let tensor = crate::creation::tensor_2d_arrays(&[[1.0f32, 2.0], [3.0, 4.0]])
            .expect("tensor creation failed");

        // Test .T() shortcut
        let transposed = tensor.T().expect("T() failed");
        assert_eq!(transposed.shape().dims(), &[2, 2]);

        // Test .mT() alias
        let mt_transposed = tensor.mT().expect("mT() failed");
        assert_eq!(mt_transposed.shape().dims(), &[2, 2]);

        // Test .H() (should be same as .T() for real numbers)
        let hermitian = tensor.H().expect("H() failed");
        assert_eq!(hermitian.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_properties() {
        let tensor = crate::creation::tensor_2d_arrays(&[[1.0f32, 2.0], [3.0, 4.0]])
            .expect("tensor creation failed");

        assert_eq!(tensor.numel(), 4);
        assert_eq!(tensor.shape().dims(), &[2, 2]);
        assert!(!tensor.is_empty());
        assert!(!tensor.is_scalar());
        assert!(tensor.is_contiguous());

        // Test scalar tensor
        let scalar = crate::creation::tensor_scalar(42.0f32).expect("scalar creation failed");
        assert!(scalar.is_scalar());
        assert_eq!(scalar.item().expect("item retrieval failed"), 42.0);
    }

    #[test]
    fn test_shape_convenience() {
        // Create a 3D tensor with shape [2, 1, 2] using zeros and reshape
        let tensor = crate::creation::zeros::<f32>(&[4])
            .expect("zeros creation failed")
            .reshape(&[2, 1, 2])
            .expect("reshape failed");

        // Test squeeze_all (should remove dimension of size 1)
        let squeezed = tensor.squeeze_all().expect("squeeze_all failed");
        assert_eq!(squeezed.shape().dims(), &[2, 2]);

        // Test flatten
        let flattened = tensor.flatten().expect("flatten failed");
        assert_eq!(flattened.shape().dims(), &[4]);

        // Test flatten_from
        let flat_from_1 = tensor.flatten_from(1).expect("flatten_from failed");
        assert_eq!(flat_from_1.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_detach() {
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor creation failed");
        let detached = tensor.detach();

        // Should have same data and shape
        assert_eq!(tensor.shape().dims(), detached.shape().dims());
        assert_eq!(
            tensor.data().expect("data retrieval failed"),
            detached.data().expect("detached data retrieval failed")
        );
    }

    #[test]
    fn test_fluent_api() {
        use crate::TensorFluentExt;
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor creation failed");

        // Test method chaining with fluent API
        let result = tensor
            .fluent()
            .add_scalar(1.0) // [2.0, 3.0, 4.0, 5.0]
            .mul_scalar(2.0) // [4.0, 6.0, 8.0, 10.0]
            .sub_scalar(1.0) // [3.0, 5.0, 7.0, 9.0]
            .unwrap()
            .unwrap();

        let expected = vec![3.0, 5.0, 7.0, 9.0];
        let actual = result.to_vec().expect("to_vec failed");

        for (exp, act) in expected.iter().zip(actual.iter()) {
            assert!((exp - act).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_fluent_api_operations() {
        use crate::TensorFluentExt;
        let tensor1 =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor1 creation failed");
        let tensor2 =
            crate::creation::tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor2 creation failed");

        // Test tensor operations with fluent API
        let result = tensor1
            .fluent()
            .add(&tensor2) // [3.0, 4.0, 5.0, 6.0]
            .mul_scalar(0.5) // [1.5, 2.0, 2.5, 3.0]
            .sum() // 9.0
            .unwrap()
            .unwrap();

        let actual = result.to_vec().expect("to_vec failed");
        assert!((actual[0] - 9.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_fluent_api_mathematical_operations() {
        use crate::TensorFluentExt;
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor creation failed");

        // Test mathematical operations with fluent API
        let result = tensor
            .fluent()
            .relu() // [1.0, 2.0, 3.0, 4.0] (no change since all positive)
            .pow(2.0) // [1.0, 4.0, 9.0, 16.0]
            .sigmoid() // sigmoid values
            .unwrap()
            .unwrap();

        let actual = result.to_vec().expect("to_vec failed");
        // Check that all values are between 0 and 1 (sigmoid property)
        for val in actual.iter() {
            assert!(*val > 0.0 && *val < 1.0);
        }
    }
}

/// Fluent API trait for method chaining operations
///
/// This trait provides a PyTorch-like fluent interface that allows chaining operations
/// in a readable and natural way. Unlike lazy evaluation, these operations are executed
/// immediately, but each call returns the wrapper so further calls can be chained.
///
/// Note that a failed intermediate operation is currently swallowed: the tensor
/// from the last successful step is carried forward. Call `.unwrap()` (or
/// `.tensor()`) at the end of the chain to retrieve the result.
///
/// # Examples
/// ```rust
/// use torsh_tensor::{Tensor, TensorFluentExt};
/// use torsh_core::device::DeviceType;
///
/// let result = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
///     .expect("tensor creation failed")
///     .fluent()
///     .add_scalar(1.0)
///     .mul_scalar(2.0)
///     .relu()
///     .sum()
///     .unwrap()
///     .expect("operation should succeed");
/// ```
pub trait TensorFluentExt<T: TensorElement> {
    /// Start fluent chaining
    fn fluent(self) -> FluentTensor<T>;
}

/// Wrapper for fluent tensor operations
pub struct FluentTensor<T: TensorElement> {
    tensor: Tensor<T>,
}

impl<T: TensorElement> TensorFluentExt<T> for Tensor<T> {
    fn fluent(self) -> FluentTensor<T> {
        FluentTensor { tensor: self }
    }
}

impl<
        T: TensorElement
            + Copy
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + num_traits::Zero,
    > FluentTensor<T>
{
    /// Get the wrapped tensor, consuming the fluent wrapper
    pub fn tensor(self) -> Tensor<T> {
        self.tensor
    }

    /// Unwrap the chain and return the tensor as a `Result` (currently always
    /// `Ok`, since failed chain steps are skipped rather than recorded)
    pub fn unwrap(self) -> Result<Tensor<T>> {
        Ok(self.tensor)
    }

    /// Chain scalar addition
    pub fn add_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.add_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    /// Chain scalar multiplication
    pub fn mul_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.mul_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    /// Chain scalar subtraction
    pub fn sub_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.sub_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    /// Chain scalar division
    pub fn div_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.div_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    /// Chain tensor addition
    pub fn add(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.add_op(other) {
            self.tensor = result;
        }
        self
    }

    /// Chain tensor multiplication
    pub fn mul(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.mul_op(other) {
            self.tensor = result;
        }
        self
    }

    /// Chain tensor subtraction
    pub fn sub(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.sub(other) {
            self.tensor = result;
        }
        self
    }

    /// Chain tensor division
    pub fn div(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.div(other) {
            self.tensor = result;
        }
        self
    }

    /// Chain reshape operation
    pub fn reshape(mut self, shape: &[i32]) -> Self {
        if let Ok(result) = self.tensor.reshape(shape) {
            self.tensor = result;
        }
        self
    }

    /// Chain transpose operation
    pub fn transpose(mut self, dim0: i32, dim1: i32) -> Self {
        if let Ok(result) = self.tensor.transpose(dim0, dim1) {
            self.tensor = result;
        }
        self
    }

    /// Chain transpose (last two dimensions)
    pub fn t(mut self) -> Self {
        if let Ok(result) = self.tensor.t() {
            self.tensor = result;
        }
        self
    }

    /// Chain sum operation
    pub fn sum(mut self) -> Self {
        if let Ok(result) = self.tensor.sum() {
            self.tensor = result;
        }
        self
    }

    /// Chain sum along dimension
    pub fn sum_dim(mut self, dims: &[i32], keepdim: bool) -> Self {
        if let Ok(result) = self.tensor.sum_dim(dims, keepdim) {
            self.tensor = result;
        }
        self
    }

    /// Chain squeeze operation
    pub fn squeeze(mut self, dim: i32) -> Self {
        if let Ok(result) = self.tensor.squeeze(dim) {
            self.tensor = result;
        }
        self
    }

    /// Chain unsqueeze operation
    pub fn unsqueeze(mut self, dim: i32) -> Self {
        if let Ok(result) = self.tensor.unsqueeze(dim) {
            self.tensor = result;
        }
        self
    }
}

/// Mathematical operations for fluent chaining
impl<T: TensorElement + Copy + num_traits::Float> FluentTensor<T> {
    /// Chain ReLU activation
    pub fn relu(mut self) -> Self {
        if let Ok(result) = self.tensor.relu() {
            self.tensor = result;
        }
        self
    }

    /// Chain sigmoid activation
    pub fn sigmoid(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.sigmoid() {
            self.tensor = result;
        }
        self
    }

    /// Chain tanh activation
    pub fn tanh(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.tanh() {
            self.tensor = result;
        }
        self
    }

    /// Chain exponential function
    pub fn exp(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.exp() {
            self.tensor = result;
        }
        self
    }

    /// Chain logarithm function
    pub fn log(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.log() {
            self.tensor = result;
        }
        self
    }

    /// Chain power operation
    pub fn pow(mut self, exponent: T) -> Self
    where
        T: torsh_core::dtype::FloatElement + Into<f32>,
    {
        if let Ok(result) = self.tensor.pow(exponent) {
            self.tensor = result;
        }
        self
    }

    // Note: abs() and neg() methods removed due to complex trait requirements;
    // users can call these methods directly on the tensor when needed
}

/// Matrix operations for fluent chaining
impl<T: TensorElement + Copy> FluentTensor<T>
where
    T: num_traits::Float + std::iter::Sum,
{
    /// Chain matrix multiplication
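    ///
    /// A minimal sketch, mirroring the `tensor_2d!` doc examples above:
    ///
    /// ```
    /// # use torsh_tensor::{tensor_2d, TensorFluentExt};
    /// let a = tensor_2d!([&[1.0f32, 2.0], &[3.0, 4.0]]).expect("tensor creation failed");
    /// let eye = tensor_2d!([&[1.0f32, 0.0], &[0.0, 1.0]]).expect("tensor creation failed");
    /// // Multiplying by the identity leaves `a` unchanged
    /// let c = a.fluent().matmul(&eye).unwrap().expect("matmul failed");
    /// assert_eq!(c.shape().dims(), &[2, 2]);
    /// ```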
    pub fn matmul(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.matmul(other) {
            self.tensor = result;
        }
        self
    }
}

/// Mean operations with specific trait bounds
impl<
        T: TensorElement
            + Copy
            + num_traits::FromPrimitive
            + std::ops::Div<Output = T>
            + num_traits::Zero
            + num_traits::One,
    > FluentTensor<T>
{
    /// Chain mean operation
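    ///
    /// e.g. `t.fluent().mean(None, false)` to average over all elements; the
    /// reduction semantics are those of the underlying `Tensor::mean`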
    pub fn mean(mut self, dims: Option<&[usize]>, keepdim: bool) -> Self {
        if let Ok(result) = self.tensor.mean(dims, keepdim) {
            self.tensor = result;
        }
        self
    }
}